In [1]:

import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
import hydra
from omegaconf import DictConfig
import wandb
from termcolor import cprint
from tqdm import tqdm

from src.datasets import ThingsMEGDataset
from src.densenet import DenseNetClassifier
from src.resnet2d import resnet50_2d
from src.resnet1d import resnet50_1d

import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms

from src.utils import set_seed
from scipy.signal import butter, sosfiltfilt
from datetime import datetime as dt

In [9]:
# ハイパーパラメータの設定
args = DictConfig({
    'seed': 1234,
    'data_dir': '/mnt/mp_nas_mks/labmember/d.fukunaga/data',
    'batch_size': 128,
    'num_workers': 8,
    'lr': 0.001,
    'epochs': 30,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'use_wandb': False
})

set_seed(args.seed)
logdir = 'outputs'

if args.use_wandb:
    wandb.init(mode="online", dir=logdir, project="MEG-classification")


In [3]:
import torchaudio.transforms as T
from torchvision.transforms import Compose

class ResampleTransform:
    def __init__(self, orig_freq, new_freq):
        self.resample = T.Resample(orig_freq, new_freq)

    def __call__(self, x):
        return self.resample(x)

class BandpassFilterTransform:
    def __init__(self, low_cutoff, high_cutoff, fs, order=5):
        self.sos = butter(order, [low_cutoff / (0.5 * fs), high_cutoff / (0.5 * fs)], btype='band', output='sos')

    def __call__(self, x):
        x_np = np.ascontiguousarray(x.detach().cpu().numpy())  # NumPy配列を連続配列としてコピー
        filtered = sosfiltfilt(self.sos, x_np, axis=-1)  # axisを指定してフィルタを適用
        return torch.tensor(filtered, dtype=torch.float32, device=x.device)  # 元のデバイスに戻す

class NormalizeTransform:
    def __call__(self, x):
        return (x - x.mean(dim=-1, keepdim=True)) / x.std(dim=-1, keepdim=True)

# 複合transform
wave_transforms = Compose([
    ResampleTransform(orig_freq=1000, new_freq=128),
    # BandpassFilterTransform(low_cutoff=1, high_cutoff=40, fs=128),
    NormalizeTransform()
])


In [4]:
# ------------------
#    Dataloader
# ------------------
loader_args = {"batch_size": args.batch_size, "num_workers": args.num_workers}

In [5]:
train_set = ThingsMEGDataset("train", args.data_dir, transforms=wave_transforms)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)
val_set = ThingsMEGDataset("val", args.data_dir, transforms=wave_transforms)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)
test_set = ThingsMEGDataset("test", args.data_dir, transforms=wave_transforms)
test_loader = torch.utils.data.DataLoader(
    test_set,
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)

In [6]:
test_set[4000]

(tensor([[-0.4431, -2.0549, -0.5833,  ...,  1.7070,  0.8035,  0.0636],
         [-0.6210, -2.2190, -0.6947,  ...,  1.4975,  0.3503,  0.7291],
         [-0.7366, -2.4049, -0.9725,  ...,  1.0039, -0.1360,  0.9885],
         ...,
         [-1.1306,  1.4246,  0.3441,  ..., -0.6665,  0.9779,  0.3473],
         [ 0.9576,  0.8476,  1.3057,  ..., -1.2348,  0.7442, -0.0424],
         [-1.8650, -0.2206, -0.7674,  ...,  0.8498,  0.2762,  0.7946]]),
 tensor(0))

In [7]:
# ------------------
#       Model
# ------------------
# model = DenseNetClassifier(
#     train_set.num_classes, train_set.seq_len, train_set.num_channels
# ).to(args.device)

model = resnet50_1d(
    num_classes=train_set.num_classes, in_channels=train_set.num_channels
).to(args.device)

# ------------------
#     Optimizer
# ------------------
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# スケジューラ―
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# ------------------
#   Start training
# ------------------
max_val_acc = 0
accuracy = Accuracy(
    task="multiclass", num_classes=train_set.num_classes, top_k=10
).to(args.device)

In [10]:
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

# # CUDAの設定を確認・調整
cudnn.enabled = False
torch.cuda.empty_cache()

# トレーニングループ
max_val_acc = 0
accuracy = Accuracy(task="multiclass", num_classes=train_set.num_classes, top_k=10).to(args.device)

for epoch in range(args.epochs):
    print(f"Epoch {epoch+1}/{args.epochs}")

    train_loss, train_acc, val_loss, val_acc = [], [], [], []

    model.train()
    for X, y, subject_idxs in tqdm(train_loader, desc="Train"):
        X, y, subject_idxs = X.to(args.device), y.to(args.device), subject_idxs.to(args.device)
        # print(f"Input shape: {X.shape}")
        # print(f"Subject indices shape: {subject_idxs.shape}")

        try:
            X = X.clone().detach()
            subject_idxs = subject_idxs.clone().detach()

            y_pred = model(X, subject_idxs)

            loss = F.cross_entropy(y_pred, y)
            train_loss.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = accuracy(y_pred, y)
            train_acc.append(acc.item())
        except RuntimeError as e:
            print(f"RuntimeError: {e}")
            torch.cuda.empty_cache()
            continue

    model.eval()
    for X, y, subject_idxs in tqdm(val_loader, desc="Validation"):
        X, y, subject_idxs = X.to(args.device), y.to(args.device), subject_idxs.to(args.device)

        with torch.no_grad():
            y_pred = model(X, subject_idxs)

        val_loss.append(F.cross_entropy(y_pred, y).item())
        val_acc.append(accuracy(y_pred, y).item())

    scheduler.step()

    print(f"Epoch {epoch+1}/{args.epochs} | train loss: {np.mean(train_loss):.3f} | train acc: {np.mean(train_acc):.3f} | val loss: {np.mean(val_loss):.3f} | val acc: {np.mean(val_acc):.3f}")
    torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))
    if args.use_wandb:
        wandb.log({
            "train_loss": np.mean(train_loss),
            "train_acc": np.mean(train_acc),
            "val_loss": np.mean(val_loss),
            "val_acc": np.mean(val_acc),
        })

    if np.mean(val_acc) > max_val_acc:
        cprint("New best.", "cyan")
        torch.save(model.state_dict(), os.path.join(logdir, "model_best.pt"))
        max_val_acc = np.mean(val_acc)


Epoch 1/30


Train:   0%|          | 0/514 [00:00<?, ?it/s]

In [None]:
# ベストモデルでの評価
model.load_state_dict(torch.load(os.path.join(logdir, "model_best.pt"), map_location=args.device))

preds = []
model.eval()
for X, subject_idxs in tqdm(test_loader, desc="Validation"):
    with torch.no_grad():
        pred = model(X.to(args.device), subject_idxs.to(args.device))
    preds.append(pred.detach().cpu())

preds = torch.cat(preds, dim=0).numpy()
np.save(os.path.join(logdir, "submission"), preds)
cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")


Validation: 100%|██████████| 129/129 [00:02<00:00, 48.24it/s]


[36mSubmission (16432, 1854) saved at outputs[0m
