In [2]:
!pip install torchmetrics
!pip install omegaconf
!pip install wandb
!pip install einops
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from omegaconf import DictConfig
import wandb
from termcolor import cprint
from tqdm import tqdm
import os
import numpy as np
import torch
from typing import Tuple
from termcolor import cprint
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import zipfile
import random
import numpy as np
import torch




In [4]:
class ThingsMEGDataset(torch.utils.data.Dataset):
    """One split of the MEG dataset, loaded from pre-saved tensor files.

    Files read from ``data_dir``:
      * ``{split}_X.pt``            -- trial data (dim 1 is reported as channels,
                                       dim 2 as sequence length)
      * ``{split}_subject_idxs.pt`` -- per-trial subject index
      * ``{split}_y.pt``            -- labels; loaded for "train" and "val" only

    Items are ``(X, y, subject_idx)`` when labels were loaded, otherwise
    ``(X, subject_idx)``.
    """

    def __init__(self, split: str, data_dir: str = "data") -> None:
        super().__init__()

        assert split in ["train", "val", "test"], f"Invalid split: {split}"
        self.split = split
        self.num_classes = 1854

        def tensor_file(stem: str) -> str:
            # e.g. stem "X" -> "<data_dir>/train_X.pt"
            return os.path.join(data_dir, f"{split}_{stem}.pt")

        self.X = torch.load(tensor_file("X"))
        self.subject_idxs = torch.load(tensor_file("subject_idxs"))

        # Unlabelled split: nothing more to load.
        if split not in ("train", "val"):
            return

        self.y = torch.load(tensor_file("y"))
        # Every class must appear at least once in a labelled split.
        assert len(torch.unique(self.y)) == self.num_classes, "Number of classes do not match."

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, i):
        if not hasattr(self, "y"):
            return self.X[i], self.subject_idxs[i]
        return self.X[i], self.y[i], self.subject_idxs[i]

    @property
    def num_channels(self) -> int:
        """Size of dimension 1 of ``X``."""
        return self.X.size(1)

    @property
    def seq_len(self) -> int:
        """Size of dimension 2 of ``X``."""
        return self.X.size(2)

In [None]:
class BasicConvClassifier(nn.Module):
    """Two stacked ConvBlocks, then global average pooling and a linear head.

    ``seq_len`` is accepted so the constructor matches the other classifiers,
    but it is not used: the adaptive pooling removes the time dimension.
    """

    def __init__(
        self,
        num_classes: int,
        seq_len: int,
        in_channels: int,
        hid_dim: int = 128
    ) -> None:
        super().__init__()

        entry_block = ConvBlock(in_channels, hid_dim)
        inner_block = ConvBlock(hid_dim, hid_dim)
        self.blocks = nn.Sequential(entry_block, inner_block)

        pool_time = nn.AdaptiveAvgPool1d(1)
        drop_time_axis = Rearrange("b d 1 -> b d")
        classifier = nn.Linear(hid_dim, num_classes)
        self.head = nn.Sequential(pool_time, drop_time_axis, classifier)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map a batch of sequences to class logits.

        Args:
            X: tensor of shape (b, c, t).
        Returns:
            Logits of shape (b, num_classes).
        """
        features = self.blocks(X)
        logits = self.head(features)
        return logits


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        kernel_size: int = 3,
        p_drop: float = 0.1,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.conv0 = nn.Conv1d(in_dim, out_dim, kernel_size, padding="same")
        self.conv1 = nn.Conv1d(out_dim, out_dim, kernel_size, padding="same")
        # self.conv2 = nn.Conv1d(out_dim, out_dim, kernel_size) # , padding="same")

        self.batchnorm0 = nn.BatchNorm1d(num_features=out_dim)
        self.batchnorm1 = nn.BatchNorm1d(num_features=out_dim)

        self.dropout = nn.Dropout(p_drop)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if self.in_dim == self.out_dim:
            X = self.conv0(X) + X  # skip connection
        else:
            X = self.conv0(X)

        X = F.gelu(self.batchnorm0(X))

        X = self.conv1(X) + X  # skip connection
        X = F.gelu(self.batchnorm1(X))

        # X = self.conv2(X)
        # X = F.glu(X, dim=-2)

        return self.dropout(X)

In [None]:
class BasicLSTMClassifier(nn.Module):
    """Single-layer bidirectional LSTM classifier over (batch, channels, time) input.

    Args:
        num_classes: number of output logits.
        seq_len: number of time steps; sizes the input BatchNorm, so inputs must
            have exactly this length.
        in_channels: per-time-step feature size fed to the LSTM.
        hid_dim: LSTM hidden size per direction.
        dropout: dropout probability applied before the linear head.
    """

    def __init__(self, num_classes: int, seq_len: int, in_channels: int, hid_dim: int = 50, dropout: float = 0.5):
        super().__init__()
        # NOTE(review): BatchNorm1d(seq_len) applied to (batch, seq_len, channels) normalises
        # each time step separately (statistics over batch and channels). Confirm this is
        # intended rather than per-channel normalisation. Kept so existing checkpoints load.
        self.batch_norm_input = nn.BatchNorm1d(seq_len)
        self.lstm = nn.LSTM(in_channels, hid_dim, 1, batch_first=True, bidirectional=True)
        self.batch_norm_lstm = nn.BatchNorm1d(hid_dim * 2)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hid_dim * 2, num_classes)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Args:
            X: tensor of shape (batch, in_channels, seq_len).
        Returns:
            Logits of shape (batch, num_classes).
        """
        X = X.permute(0, 2, 1)  # -> (batch, seq_len, in_channels) for the batch_first LSTM
        X = self.batch_norm_input(X)
        _, (h_n, _) = self.lstm(X)
        # h_n is (num_layers * 2, batch, hid_dim). Its last two entries are the top layer's
        # final forward and backward states, and each has seen the whole sequence.
        # (output[:, -1] would give the backward direction after only one step.)
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        X = self.batch_norm_lstm(summary)
        X = self.dropout(X)
        return self.linear(X)

In [None]:
class SimpleGRUClassifier(nn.Module):
    """Single-layer bidirectional GRU classifier over (batch, channels, time) input.

    Args:
        num_classes: number of output logits.
        seq_len: number of time steps; sizes the input BatchNorm, so inputs must
            have exactly this length.
        in_channels: per-time-step feature size fed to the GRU.
        hid_dim: GRU hidden size per direction.
        dropout: dropout probability applied before the linear head.
    """

    def __init__(self, num_classes: int, seq_len: int, in_channels: int, hid_dim: int = 500, dropout: float = 0.5):
        super().__init__()
        # NOTE(review): BatchNorm1d(seq_len) applied to (batch, seq_len, channels) normalises
        # each time step separately (statistics over batch and channels). Confirm this is
        # intended rather than per-channel normalisation. Kept so existing checkpoints load.
        self.batch_norm_input = nn.BatchNorm1d(seq_len)
        self.gru = nn.GRU(in_channels, hid_dim, batch_first=True, bidirectional=True)
        self.batch_norm_gru = nn.BatchNorm1d(hid_dim * 2)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hid_dim * 2, num_classes)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Args:
            X: tensor of shape (batch, in_channels, seq_len).
        Returns:
            Logits of shape (batch, num_classes).
        """
        X = X.permute(0, 2, 1)  # -> (batch, seq_len, in_channels) for the batch_first GRU
        X = self.batch_norm_input(X)
        _, h_n = self.gru(X)
        # h_n is (num_layers * 2, batch, hid_dim). Its last two entries are the top layer's
        # final forward and backward states, and each has seen the whole sequence.
        # (output[:, -1] would give the backward direction after only one step.)
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        X = self.batch_norm_gru(summary)
        X = self.dropout(X)
        return self.linear(X)

In [None]:
# Colab only: make Google Drive contents available under MOUNT_POINT.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
# Unpack the zipped dataset from Drive into local storage.
zip_file_path = '/content/drive/MyDrive/Colab Notebooks/DL_basic_final/data.zip'

save_path = '/content/MEG_data/data'

# Tensor files ThingsMEGDataset loads from its data_dir for each split.
_expected_files = [
    f"{split}_{stem}.pt"
    for split in ("train", "val", "test")
    for stem in ("X", "subject_idxs")
] + [f"{split}_y.pt" for split in ("train", "val")]

# Extraction is slow, so skip it when an earlier run already unpacked every file.
# Delete save_path to force a fresh extraction (e.g. after replacing data.zip).
if all(os.path.exists(os.path.join(save_path, name)) for name in _expected_files):
    print(f"Dataset already extracted at {save_path}; skipping unzip.")
else:
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(save_path)

In [6]:
def set_seed(seed: int = 0) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    Args:
        seed: the seed passed to every generator (default 0).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [None]:
# Hyper-parameters and paths consumed by run() (wrapped in a DictConfig there).
config = {
    "batch_size": 128,
    "epochs" : 80,
    "lr" : 0.001,
    # Fall back to CPU so the notebook still runs on a runtime without a GPU.
    "device" : "cuda:0" if torch.cuda.is_available() else "cpu",
    "num_workers" : 4,
    "seed" : 1234,
    # NOTE(review): run() never reads this flag; no wandb logging is wired up.
    "use_wandb" : True,
    "data_dir" : '/content/MEG_data/data',
}

In [7]:
def _train_one_epoch(model, loader, optimizer, metric, device):
    """Run one optimisation pass over a labelled ``loader``.

    Returns:
        (mean batch cross-entropy, mean batch ``metric`` value) as floats. Both are
        unweighted means over batches, so a smaller final batch counts like the others.
    """
    model.train()
    losses, accs = [], []
    for X, y, _subject_idxs in tqdm(loader, desc="Train"):
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = F.cross_entropy(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        accs.append(metric(y_pred.detach(), y).item())
    return float(np.mean(losses)), float(np.mean(accs))


@torch.no_grad()
def _evaluate(model, loader, metric, device):
    """Score ``model`` on a labelled ``loader`` without gradient tracking.

    Returns:
        (mean batch cross-entropy, mean batch ``metric`` value) as floats.
    """
    model.eval()
    losses, accs = [], []
    for X, y, _subject_idxs in tqdm(loader, desc="Validation"):
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        losses.append(F.cross_entropy(y_pred, y).item())
        accs.append(metric(y_pred, y).item())
    return float(np.mean(losses)), float(np.mean(accs))


@torch.no_grad()
def _predict(model, loader, device):
    """Return the model's raw outputs for every batch of an unlabelled ``loader``,
    concatenated along dim 0 into one NumPy array."""
    model.eval()
    preds = [model(X.to(device)).cpu() for X, _subject_idxs in tqdm(loader, desc="Test")]
    return torch.cat(preds, dim=0).numpy()


def run(args: DictConfig):
    """Train SimpleGRUClassifier on ThingsMEG, checkpoint it, and save test-set outputs.

    Reads from ``args``: seed, batch_size, num_workers, data_dir, device, lr, epochs,
    and optionally ``logdir`` (default '/content/MEG_data/log').

    Every epoch writes ``model_last.pt``. Whenever mean validation top-10 accuracy
    improves, it also writes ``model_best.pt``. After training, the best checkpoint
    is reloaded, and its raw outputs on the test split are saved as ``submission.npy``
    in ``logdir``.
    """
    set_seed(args.seed)

    logdir = args.get("logdir", '/content/MEG_data/log')
    # torch.save / np.save do not create missing parent directories.
    os.makedirs(logdir, exist_ok=True)

    # ------------------
    #    Dataloader
    # ------------------
    loader_args = {"batch_size": args.batch_size, "num_workers": args.num_workers}

    train_set = ThingsMEGDataset("train", args.data_dir)
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)
    val_set = ThingsMEGDataset("val", args.data_dir)
    val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)
    test_set = ThingsMEGDataset("test", args.data_dir)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)

    # ------------------
    #       Model
    # ------------------
    model = SimpleGRUClassifier(
        train_set.num_classes, train_set.seq_len, train_set.num_channels, dropout=0.5
    ).to(args.device)

    # ------------------
    #     Optimizer
    # ------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    # ------------------
    # Learning Rate Scheduler
    # ------------------
    # Divides the LR by 10 after 10 epochs without validation-loss improvement.
    # Stepped once per epoch below. LR changes are logged manually instead of via
    # the deprecated `verbose` flag.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

    # ------------------
    #   Start training
    # ------------------
    best_path = os.path.join(logdir, "model_best.pt")
    # -inf guarantees the first completed epoch writes model_best.pt.
    max_val_acc = float("-inf")
    accuracy = Accuracy(
        task="multiclass", num_classes=train_set.num_classes, top_k=10
    ).to(args.device)

    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")

        train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, accuracy, args.device)
        val_loss, val_acc = _evaluate(model, val_loader, accuracy, args.device)

        print(f"Epoch {epoch+1}/{args.epochs} | train loss: {train_loss:.3f} | train acc: {train_acc:.3f} | val loss: {val_loss:.3f} | val acc: {val_acc:.3f}")
        torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))

        if val_acc > max_val_acc:
            cprint("New best.", "cyan")
            torch.save(model.state_dict(), best_path)
            max_val_acc = val_acc

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"Reducing learning rate to {lr_after:.2e}")

    # ----------------------------------
    #  Start evaluation with best model
    # ----------------------------------
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=args.device))
    else:
        # Only reachable when no epoch ran (or validation produced NaN accuracy).
        cprint(f"No checkpoint at {best_path}; predicting with current weights.", "yellow")

    preds = _predict(model, test_loader, args.device)
    np.save(os.path.join(logdir, "submission"), preds)
    cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")


In [None]:
args = DictConfig(config)
run(args)