In [1]:
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
import hydra
from omegaconf import DictConfig
import wandb
from termcolor import cprint
from tqdm import tqdm

from src.datasets import ThingsMEGDataset
from src.densenet import DenseNetClassifier
from src.resnet2d import resnet50_2d
from src.resnet1d import resnet50_1d

import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms

from src.utils import set_seed
from scipy.signal import butter, sosfiltfilt
from datetime import datetime as dt

In [2]:
# Hyperparameter settings.
args = DictConfig({
    'seed': 1234,          # passed to set_seed below
    'data_dir': 'data',    # directory handed to ThingsMEGDataset
    'batch_size': 128,
    'num_workers': 8,      # DataLoader worker processes
    'lr': 0.001,           # learning rate given to the Adam optimizer
    'epochs': 30,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'use_wandb': False     # set True to log metrics to Weights & Biases
})

set_seed(args.seed)

# Checkpoints and the submission file are written under `logdir`. Create it up
# front: torch.save / np.save do not create missing parent directories, so a
# fresh checkout would otherwise fail at the first checkpoint.
logdir = 'outputs'
os.makedirs(logdir, exist_ok=True)

if args.use_wandb:
    wandb.init(mode="online", dir=logdir, project="MEG-classification")


In [3]:
# ------------------
#    Dataloader
# ------------------
# Keyword arguments reused when constructing the DataLoaders.
loader_args = dict(
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)

In [4]:
import torch

class Normalize:
    """Callable transform that standardizes data with fixed statistics.

    The statistics are stored as given at construction time (no validation or
    conversion). Calling the instance returns ``(data - mean) / std``.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, data):
        """Return ``data`` shifted by ``self.mean`` and divided by ``self.std``."""
        centered = data - self.mean
        return centered / self.std

def calculate_mean_std(dataset):
    """Compute the global (scalar) mean and standard deviation of a dataset.

    Only the first element of each item (``dataset[i][0]``) is used. The
    statistics are taken over every value of every sample, and the standard
    deviation uses Bessel's correction, matching the default of
    ``torch.Tensor.std``.

    Statistics are accumulated one sample at a time in float64 (pairwise
    mean / sum-of-squared-deviations update), so the dataset is never
    concatenated into a second in-memory copy.

    Args:
        dataset: Indexable object supporting ``len()`` whose items are
            sequences holding a tensor in position 0.

    Returns:
        Tuple ``(mean, std)`` of 0-dim tensors. They use the dtype and device
        of the first sample when that dtype is floating point, float64
        otherwise. ``std`` is NaN when only a single value is present.

    Raises:
        ValueError: If the dataset contains no values at all.
    """
    count = 0      # number of values folded in so far
    mean = 0.0     # running mean of those values
    m2 = 0.0       # running sum of squared deviations from the running mean
    dtype, device = None, None

    for i in range(len(dataset)):
        sample = dataset[i][0]
        if dtype is None:
            dtype, device = sample.dtype, sample.device

        n = sample.numel()
        if n == 0:
            continue

        values = sample.to(torch.float64)
        sample_mean = values.mean().item()
        sample_m2 = ((values - sample_mean) ** 2).sum().item()

        # Merge this sample's statistics into the running ones.
        delta = sample_mean - mean
        total = count + n
        mean += delta * n / total
        m2 += sample_m2 + delta * delta * count * n / total
        count = total

    if count == 0:
        raise ValueError("cannot compute mean/std of an empty dataset")

    # Unbiased variance; undefined (NaN) for a single value, like Tensor.std().
    variance = m2 / (count - 1) if count > 1 else float("nan")

    if not dtype.is_floating_point:
        dtype = torch.float64
    return (
        torch.tensor(mean, dtype=dtype, device=device),
        torch.tensor(variance ** 0.5, dtype=dtype, device=device),
    )


In [5]:
# Fit the normalization statistics on the un-normalized training split only.
train_set = ThingsMEGDataset("train", args.data_dir, transforms=None)
mean, std = calculate_mean_std(train_set)
normalize_transform = Normalize(mean, std)

# Rebuild the training split with the transform attached, and build the other
# splits with the same training-derived statistics.
train_set = ThingsMEGDataset("train", args.data_dir, transforms=normalize_transform)
val_set = ThingsMEGDataset("val", args.data_dir, transforms=normalize_transform)
test_set = ThingsMEGDataset("test", args.data_dir, transforms=normalize_transform)

# Only the training loader shuffles; all three share batch size / worker count.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)


In [6]:
# Spot-check one item of the normalized training set. The training loop
# unpacks items from this dataset's loader as (X, y, subject_idxs).
train_set[4]

(tensor([[-0.0133, -0.1427, -0.2028,  ...,  0.5933,  0.4664,  0.1367],
         [-0.2018, -0.2515, -0.2093,  ...,  0.3536,  0.1010, -0.3133],
         [-0.0569, -0.1831, -0.1927,  ...,  0.1002, -0.3236, -0.6687],
         ...,
         [ 0.0532,  0.5752,  0.5588,  ...,  0.0176, -0.1719, -0.0257],
         [ 0.6802,  0.8113,  0.6213,  ...,  0.1411, -0.1011, -0.2018],
         [-0.2484,  0.1821,  0.3607,  ...,  0.4149,  0.2179, -0.0297]]),
 tensor(1556),
 tensor(0))

In [7]:
# ------------------
#       Model
# ------------------
# Alternative backbone, kept for reference:
# model = DenseNetClassifier(
#     train_set.num_classes, train_set.seq_len, train_set.num_channels
# ).to(args.device)

model = resnet50_1d(
    in_channels=train_set.num_channels,
    num_classes=train_set.num_classes,
)
model = model.to(args.device)

# ------------------
#     Optimizer
# ------------------
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Learning-rate scheduler: StepLR scales the learning rate by `gamma` every
# `step_size` calls to scheduler.step().
scheduler = lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10)

# ------------------
#   Start training
# ------------------
# Best validation accuracy seen so far, and the top-10 accuracy metric.
max_val_acc = 0
accuracy = Accuracy(
    task="multiclass",
    num_classes=train_set.num_classes,
    top_k=10,
)
accuracy = accuracy.to(args.device)

In [8]:
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

# # Inspect / adjust the CUDA settings
# cudnn.enabled = False
# torch.cuda.empty_cache()

# Training loop.
# Re-initialised here so that re-running this cell starts from a clean slate.
max_val_acc = 0
accuracy = Accuracy(task="multiclass", num_classes=train_set.num_classes, top_k=10).to(args.device)

for epoch in range(args.epochs):
    print(f"Epoch {epoch+1}/{args.epochs}")

    # Per-batch metrics for this epoch; the epoch summary is their plain
    # (unweighted) mean over batches.
    train_loss, train_acc, val_loss, val_acc = [], [], [], []
    skipped_batches = 0

    model.train()
    for X, y, subject_idxs in tqdm(train_loader, desc="Train"):
        X, y, subject_idxs = X.to(args.device), y.to(args.device), subject_idxs.to(args.device)

        try:
            y_pred = model(X, subject_idxs)
            loss = F.cross_entropy(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() synchronises with the device, so keep it inside the try
            # block where a runtime failure of this batch is handled.
            loss_value = loss.item()
            acc_value = accuracy(y_pred, y).item()
        except RuntimeError as e:
            # Best-effort: skip the failing batch, but keep count so the
            # failure is visible in the epoch summary instead of being lost
            # in the progress output.
            skipped_batches += 1
            print(f"RuntimeError: {e}")
            torch.cuda.empty_cache()
            continue
        else:
            # Record metrics only for batches that completed the full step,
            # so the loss and accuracy lists always stay the same length.
            train_loss.append(loss_value)
            train_acc.append(acc_value)

    if skipped_batches:
        if not train_loss:
            # Nothing was learned this epoch; continuing would only report NaN
            # metrics and overwrite checkpoints with unchanged weights.
            raise RuntimeError(
                f"all {skipped_batches} training batches failed in epoch {epoch+1}; "
                "see the errors printed above"
            )
        print(f"Skipped {skipped_batches} training batch(es) due to RuntimeError")

    model.eval()
    with torch.no_grad():
        for X, y, subject_idxs in tqdm(val_loader, desc="Validation"):
            X, y, subject_idxs = X.to(args.device), y.to(args.device), subject_idxs.to(args.device)

            y_pred = model(X, subject_idxs)

            val_loss.append(F.cross_entropy(y_pred, y).item())
            val_acc.append(accuracy(y_pred, y).item())

    # One scheduler step per epoch.
    scheduler.step()

    mean_train_loss, mean_train_acc = np.mean(train_loss), np.mean(train_acc)
    mean_val_loss, mean_val_acc = np.mean(val_loss), np.mean(val_acc)

    print(f"Epoch {epoch+1}/{args.epochs} | train loss: {mean_train_loss:.3f} | train acc: {mean_train_acc:.3f} | val loss: {mean_val_loss:.3f} | val acc: {mean_val_acc:.3f}")
    torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))
    if args.use_wandb:
        wandb.log({
            "train_loss": mean_train_loss,
            "train_acc": mean_train_acc,
            "val_loss": mean_val_loss,
            "val_acc": mean_val_acc,
        })

    # Keep a separate checkpoint of the weights with the best validation accuracy.
    if mean_val_acc > max_val_acc:
        cprint("New best.", "cyan")
        torch.save(model.state_dict(), os.path.join(logdir, "model_best.pt"))
        max_val_acc = mean_val_acc


Epoch 1/30


Train: 100%|██████████| 514/514 [00:34<00:00, 15.11it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.26it/s]


Epoch 1/30 | train loss: 7.592 | train acc: 0.008 | val loss: 7.516 | val acc: 0.010
[36mNew best.[0m
Epoch 2/30


Train: 100%|██████████| 514/514 [00:34<00:00, 15.05it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.23it/s]


Epoch 2/30 | train loss: 7.511 | train acc: 0.008 | val loss: 7.511 | val acc: 0.010
[36mNew best.[0m
Epoch 3/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.22it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.46it/s]


Epoch 3/30 | train loss: 7.503 | train acc: 0.008 | val loss: 7.500 | val acc: 0.010
Epoch 4/30


Train: 100%|██████████| 514/514 [00:30<00:00, 16.82it/s]
Validation: 100%|██████████| 129/129 [00:04<00:00, 26.60it/s]


Epoch 4/30 | train loss: 7.503 | train acc: 0.008 | val loss: 24.368 | val acc: 0.008
Epoch 5/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.20it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.85it/s]


Epoch 5/30 | train loss: 7.498 | train acc: 0.008 | val loss: 7.507 | val acc: 0.010
[36mNew best.[0m
Epoch 6/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.24it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.20it/s]


Epoch 6/30 | train loss: 7.488 | train acc: 0.009 | val loss: 7.526 | val acc: 0.011
[36mNew best.[0m
Epoch 7/30


Train: 100%|██████████| 514/514 [00:32<00:00, 16.06it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.05it/s]


Epoch 7/30 | train loss: 7.481 | train acc: 0.009 | val loss: 7.507 | val acc: 0.010
Epoch 8/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.09it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.83it/s]


Epoch 8/30 | train loss: 7.475 | train acc: 0.010 | val loss: 7.527 | val acc: 0.010
Epoch 9/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.08it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.36it/s]


Epoch 9/30 | train loss: 7.468 | train acc: 0.010 | val loss: 7.601 | val acc: 0.011
[36mNew best.[0m
Epoch 10/30


Train: 100%|██████████| 514/514 [00:32<00:00, 16.03it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 35.70it/s]


Epoch 10/30 | train loss: 7.460 | train acc: 0.011 | val loss: 7.543 | val acc: 0.010
Epoch 11/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.06it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.86it/s]


Epoch 11/30 | train loss: 7.433 | train acc: 0.014 | val loss: 7.561 | val acc: 0.011
Epoch 12/30


Train: 100%|██████████| 514/514 [00:32<00:00, 15.96it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 32.90it/s]


Epoch 12/30 | train loss: 7.426 | train acc: 0.014 | val loss: 7.556 | val acc: 0.011
Epoch 13/30


Train: 100%|██████████| 514/514 [00:32<00:00, 16.01it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.34it/s]


Epoch 13/30 | train loss: 7.421 | train acc: 0.014 | val loss: 7.573 | val acc: 0.011
Epoch 14/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.09it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.29it/s]


Epoch 14/30 | train loss: 7.416 | train acc: 0.015 | val loss: 7.575 | val acc: 0.011
Epoch 15/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.12it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.18it/s]


Epoch 15/30 | train loss: 7.407 | train acc: 0.016 | val loss: 7.604 | val acc: 0.011
Epoch 16/30


Train: 100%|██████████| 514/514 [00:32<00:00, 16.06it/s]
Validation: 100%|██████████| 129/129 [00:04<00:00, 27.11it/s]


Epoch 16/30 | train loss: 7.400 | train acc: 0.016 | val loss: 7.604 | val acc: 0.011
Epoch 17/30


Train: 100%|██████████| 514/514 [00:32<00:00, 16.04it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.87it/s]


Epoch 17/30 | train loss: 7.391 | train acc: 0.017 | val loss: 7.670 | val acc: 0.011
Epoch 18/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.11it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 32.69it/s]


Epoch 18/30 | train loss: 7.376 | train acc: 0.018 | val loss: 7.664 | val acc: 0.011
Epoch 19/30


Train: 100%|██████████| 514/514 [00:32<00:00, 15.87it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.10it/s]


Epoch 19/30 | train loss: 7.361 | train acc: 0.019 | val loss: 7.657 | val acc: 0.010
Epoch 20/30


Train: 100%|██████████| 514/514 [00:32<00:00, 15.92it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.53it/s]


Epoch 20/30 | train loss: 7.343 | train acc: 0.019 | val loss: 7.702 | val acc: 0.011
Epoch 21/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.09it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.83it/s]


Epoch 21/30 | train loss: 7.308 | train acc: 0.023 | val loss: 7.713 | val acc: 0.011
Epoch 22/30


Train: 100%|██████████| 514/514 [00:32<00:00, 16.04it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.63it/s]


Epoch 22/30 | train loss: 7.300 | train acc: 0.024 | val loss: 7.747 | val acc: 0.012
[36mNew best.[0m
Epoch 23/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.15it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.51it/s]


Epoch 23/30 | train loss: 7.297 | train acc: 0.024 | val loss: 7.718 | val acc: 0.012
[36mNew best.[0m
Epoch 24/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.11it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.39it/s]


Epoch 24/30 | train loss: 7.292 | train acc: 0.024 | val loss: 7.731 | val acc: 0.011
Epoch 25/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.38it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 35.55it/s]


Epoch 25/30 | train loss: 7.290 | train acc: 0.024 | val loss: 7.718 | val acc: 0.011
Epoch 26/30


Train: 100%|██████████| 514/514 [00:30<00:00, 16.63it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.74it/s]


Epoch 26/30 | train loss: 7.287 | train acc: 0.025 | val loss: 7.749 | val acc: 0.011
Epoch 27/30


Train: 100%|██████████| 514/514 [00:30<00:00, 16.59it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 34.48it/s]


Epoch 27/30 | train loss: 7.284 | train acc: 0.025 | val loss: 7.789 | val acc: 0.011
Epoch 28/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.51it/s]
Validation: 100%|██████████| 129/129 [00:04<00:00, 27.12it/s]


Epoch 28/30 | train loss: 7.279 | train acc: 0.026 | val loss: 7.739 | val acc: 0.011
Epoch 29/30


Train: 100%|██████████| 514/514 [00:30<00:00, 17.10it/s]
Validation: 100%|██████████| 129/129 [00:04<00:00, 27.62it/s]


Epoch 29/30 | train loss: 7.277 | train acc: 0.026 | val loss: 7.777 | val acc: 0.012
Epoch 30/30


Train: 100%|██████████| 514/514 [00:31<00:00, 16.41it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.17it/s]


Epoch 30/30 | train loss: 7.274 | train acc: 0.026 | val loss: 7.747 | val acc: 0.011


In [9]:
# Evaluate with the best checkpoint (the one saved as model_best.pt).
model.load_state_dict(torch.load(os.path.join(logdir, "model_best.pt"), map_location=args.device))

preds = []
model.eval()
with torch.no_grad():
    # The test loader yields (X, subject_idxs) pairs; this is the test split,
    # so the progress bar is labelled accordingly.
    for X, subject_idxs in tqdm(test_loader, desc="Test"):
        pred = model(X.to(args.device), subject_idxs.to(args.device))
        preds.append(pred.cpu())

# Stack the per-batch outputs along the sample axis before saving.
preds = torch.cat(preds, dim=0).numpy()
# np.save appends ".npy", so the file written is <logdir>/submission.npy.
np.save(os.path.join(logdir, "submission"), preds)
cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")


Validation: 100%|██████████| 129/129 [00:03<00:00, 34.45it/s]


[36mSubmission (16432, 1854) saved at outputs[0m
