In [None]:
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
import hydra
from omegaconf import DictConfig
import wandb
from termcolor import cprint
from tqdm import tqdm

from src.datasets import ThingsMEGDataset
from src.models import BasicConvClassifier
from src.utils import set_seed

# Show whether PyTorch can see a CUDA device; as the last expression of the
# cell the boolean is only displayed, it is not stored in a variable here.
torch.cuda.is_available()

In [None]:
# ------------------
#  Hyper-parameters
# ------------------
batch_size = 128
epochs = 20
lr = 0.0625

# ------------------
#  Runtime settings
# ------------------
cuda = torch.cuda.is_available()  # True when a GPU is usable
device = "cuda" if cuda else "cpu"
num_workers = 4
seed = 1234
use_wandb = False

data_dir = "data"
logdir = ""  # empty string -> checkpoints/submission land in the working directory

# Seed before any dataset / loader is created.
set_seed(seed)

# ------------------
#    Dataloader
# ------------------
# Keyword arguments shared by all three loaders; only `shuffle` differs.
loader_args = dict(batch_size=batch_size, num_workers=num_workers)

train_set = ThingsMEGDataset("train", data_dir)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)

val_set = ThingsMEGDataset("val", data_dir)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)

test_set = ThingsMEGDataset("test", data_dir)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import numpy as np

class BasicConvClassifier1(nn.Module):
    """Classifier made of a single shallow-ConvNet block.

    The wrapped block already ends in a ``num_classes``-way convolution
    followed by log-softmax, so no separate classification head is added.

    Args:
        num_classes: Number of output classes; forwarded to the block.
        seq_len: Accepted for signature parity with ``BasicConvClassifier2``;
            not used when building the layers.
        in_channels: Number of input channels (the ``c`` axis of ``X``).
        hid_dim: Forwarded to the block as ``out_dim``; the block only
            stores it.
    """

    def __init__(
        self,
        num_classes: int,
        seq_len: int,
        in_channels: int,
        hid_dim: int = 128
    ) -> None:
        super().__init__()

        # Reference ShallowConvBlock by its own name: the notebook later
        # defines another class called ``ConvBlock`` which would otherwise be
        # picked up here when the cells are run top to bottom.
        self.blocks = nn.Sequential(
            ShallowConvBlock(in_channels, hid_dim, n_classes=num_classes),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run the block on a batch.

        Args:
            X ( b, c, t ): input batch.
        Returns:
            X ( b, num_classes ): log-probabilities, provided the time axis
            collapses to one step inside the block (see ``ShallowConvBlock``).
        """
        return self.blocks(X)


class ShallowConvBlock(nn.Module):
    """Temporal conv -> spatial conv -> batch norm -> square -> mean pool ->
    log -> dropout -> classifier conv -> log-softmax.

    Args:
        in_dim: Number of input channels; used as the width of the spatial
            convolution kernel.
        out_dim: Stored on the instance only; no layer depends on it.
        input_window_samples: Stored only; not used to build layers.
        n_filters_time: Output channels of the temporal convolution.
        filter_time_length: Kernel length of the temporal convolution.
        n_filters_spat: Output channels of the spatial convolution.
        pool_time_length: Kernel length of the average pooling (time axis).
        pool_time_stride: Stride of the average pooling (time axis).
        final_conv_length: Kernel length of the classifier convolution.
        pool_mode: Stored only; pooling is always ``nn.AvgPool2d``.
        split_first_layer: Stored only; the first layer is always split into
            a temporal and a spatial convolution.
        batch_norm: Only decides whether the spatial convolution has a bias
            (no bias when True). NOTE(review): the BatchNorm2d layer is
            created and applied regardless of this flag.
        batch_norm_alpha: ``momentum`` of the BatchNorm2d layer.
        drop_prob: Dropout probability.
        n_classes: Output channels of the classifier convolution. Defaults to
            1854, the value that used to be hard-coded.

    With the default kernel sizes, an input of 281 time samples is reduced to
    a single time step, so ``forward`` returns ``(b, n_classes)``; other
    lengths may leave a trailing time axis.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        input_window_samples=None,
        n_filters_time=20,
        filter_time_length=25,
        n_filters_spat=40,
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length=13,
        pool_mode="mean",
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        drop_prob=0.5,
        n_classes=1854,
    ) -> None:
        super().__init__()
        self.input_window_samples = input_window_samples
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.final_conv_length = final_conv_length
        self.pool_mode = pool_mode
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.drop_prob = drop_prob
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_classes = n_classes

        # Attribute names are kept as-is so existing state_dicts still load.
        # Temporal convolution: slides along the time axis only.
        self.conv1 = nn.Conv2d(1, self.n_filters_time, (self.filter_time_length, 1), stride=1)
        # Spatial convolution: one kernel spanning all `in_dim` input channels.
        self.conv2 = nn.Conv2d(
            self.n_filters_time,
            self.n_filters_spat,
            (1, self.in_dim),
            stride=1,
            bias=not self.batch_norm,
        )
        self.norm3 = nn.BatchNorm2d(
            self.n_filters_spat,
            momentum=self.batch_norm_alpha,
            affine=True,
            track_running_stats=True,
        )
        self.avgp4 = nn.AvgPool2d(
            kernel_size=(self.pool_time_length, 1),
            stride=(self.pool_time_stride, 1),
            padding=0,
        )
        self.drop5 = nn.Dropout(p=self.drop_prob, inplace=False)
        # Classifier convolution over the remaining pooled time steps.
        self.conv6 = nn.Conv2d(
            self.n_filters_spat,
            self.n_classes,
            (self.final_conv_length, 1),
            stride=(1, 1),
            bias=True,
        )
        self.soft7 = nn.LogSoftmax(dim=1)

    def ensure4d(self, x):
        """Append trailing singleton axes until ``x`` has at least 4 dims."""
        while len(x.shape) < 4:
            x = x.unsqueeze(-1)
        return x

    def square(self, x):
        """Element-wise square."""
        return x * x

    def safe_log(self, x, eps=1e-6):
        """Logarithm with the input clamped to at least ``eps``."""
        return torch.log(torch.clamp(x, min=eps))

    def transpose_time_to_spat(self, x):
        """Swap axes 1 and 3 of a 4-D tensor."""
        return x.permute(0, 3, 2, 1)

    def np_to_th(self,
        X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
    ):
        """Convert a scalar or array-like to a ``torch.Tensor``.

        Values without ``__len__`` are wrapped in a list first. ``dtype`` is a
        numpy dtype applied before conversion; extra keyword arguments are
        passed to ``torch.tensor``.
        """
        if not hasattr(X, "__len__"):
            X = [X]
        X = np.asarray(X)
        if dtype is not None:
            X = X.astype(dtype)
        X_tensor = torch.tensor(X, requires_grad=requires_grad, **tensor_kwargs)
        if pin_memory:
            X_tensor = X_tensor.pin_memory()
        return X_tensor

    def squeeze_final_output(self, x):
        """Drop the last axis (must have size 1) and, if it has size 1, the
        remaining time axis as well."""
        assert x.size()[3] == 1
        x = x[:, :, :, 0]
        if x.size()[2] == 1:
            x = x[:, :, 0]
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(b, c, t)`` input to log-probabilities over ``n_classes``."""
        x = self.ensure4d(x)                # (b, c, t)    -> (b, c, t, 1)
        x = self.transpose_time_to_spat(x)  # (b, c, t, 1) -> (b, 1, t, c)
        x = self.conv1(x)
        x = self.conv2(x)                   # channel axis collapses to 1
        x = self.norm3(x)
        x = self.square(x)
        x = self.avgp4(x)
        x = self.safe_log(x, eps=1e-6)
        x = self.drop5(x)
        x = self.conv6(x)
        x = self.soft7(x)
        x = self.squeeze_final_output(x)
        return x


# Backward-compatible alias for the name this block was originally defined
# under. A later cell rebinds ``ConvBlock`` to a different class, which is why
# BasicConvClassifier1 refers to ShallowConvBlock directly.
ConvBlock = ShallowConvBlock

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import numpy as np

class BasicConvClassifier2(nn.Module):
    """Two stacked ``ConvBlock``s followed by a pooled linear head.

    Args:
        num_classes: Output size of the final linear layer.
        seq_len: Accepted but not used when building the layers.
        in_channels: Channel count of the input fed to the first block.
        hid_dim: Channel count produced by both blocks and consumed by the
            linear layer.
    """

    def __init__(
        self,
        num_classes: int,
        seq_len: int,
        in_channels: int,
        hid_dim: int = 128
    ) -> None:
        super().__init__()

        # Feature extractor: first block changes the channel count, the
        # second keeps it.
        feature_layers = [
            ConvBlock(in_channels, hid_dim),
            ConvBlock(hid_dim, hid_dim),
        ]
        self.blocks = nn.Sequential(*feature_layers)

        # Head: average over time, drop the singleton axis, project to classes.
        head_layers = [
            nn.AdaptiveAvgPool1d(1),
            Rearrange("b d 1 -> b d"),
            nn.Linear(hid_dim, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Classify a batch.

        Args:
            X ( b, c, t ): input batch.
        Returns:
            X ( b, num_classes ): output of the linear head.
        """
        features = self.blocks(X)
        logits = self.head(features)
        return logits


class ConvBlock(nn.Module):
    """Two 1-D "same"-padded convolutions, each followed by batch norm and
    GELU, with residual additions and a final dropout.

    Args:
        in_dim: Channel count of the input.
        out_dim: Channel count of the output.
        kernel_size: Kernel size shared by both convolutions.
        p_drop: Dropout probability applied to the block output.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        kernel_size: int = 3,
        p_drop: float = 0.1,
    ) -> None:
        super().__init__()

        self.in_dim, self.out_dim = in_dim, out_dim

        # Both convolutions keep the sequence length unchanged.
        conv_kwargs = dict(kernel_size=kernel_size, padding="same")
        self.conv0 = nn.Conv1d(in_dim, out_dim, **conv_kwargs)
        self.conv1 = nn.Conv1d(out_dim, out_dim, **conv_kwargs)

        self.batchnorm0 = nn.BatchNorm1d(num_features=out_dim)
        self.batchnorm1 = nn.BatchNorm1d(num_features=out_dim)

        self.dropout = nn.Dropout(p_drop)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``X`` and return the dropped-out activations."""
        hidden = self.conv0(X)
        if self.in_dim == self.out_dim:
            # Residual connection is only possible when channel counts match.
            hidden = hidden + X
        hidden = F.gelu(self.batchnorm0(hidden))

        # Second stage always has matching channels, so always add the skip.
        hidden = F.gelu(self.batchnorm1(self.conv1(hidden) + hidden))

        return self.dropout(hidden)

In [24]:
# NOTE: the data loaders were already built in the earlier configuration cell,
# so re-assigning batch_size here does not change their batch size.
batch_size = 128
epochs = 50
lr = 0.0625 * 0.1
print(device)
# ------------------
#       Model
# ------------------
model = BasicConvClassifier1(
    train_set.num_classes, train_set.seq_len, train_set.num_channels
).to(device)
# To train the residual 1-D CNN instead, build BasicConvClassifier2 with the
# same three arguments.

# ------------------
#     Optimizer
# ------------------
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# ------------------
#   Start training
# ------------------
# Start below any reachable accuracy so the first epoch always writes
# model_best.pt; with an initial value of 0 and a strict ">" comparison a run
# whose validation accuracy stays at 0 would never save the checkpoint that is
# loaded after the loop.
max_val_acc = float("-inf")
accuracy = Accuracy(
    task="multiclass", num_classes=train_set.num_classes, top_k=10
).to(device)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    train_loss, train_acc, val_loss, val_acc = [], [], [], []

    model.train()
    for X, y, subject_idxs in tqdm(train_loader, desc="Train"):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)

        loss = F.cross_entropy(y_pred, y)
        train_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = accuracy(y_pred, y)
        train_acc.append(acc.item())

    model.eval()
    for X, y, subject_idxs in tqdm(val_loader, desc="Validation"):
        X, y = X.to(device), y.to(device)

        with torch.no_grad():
            y_pred = model(X)

        val_loss.append(F.cross_entropy(y_pred, y).item())
        val_acc.append(accuracy(y_pred, y).item())

    # Mean of the per-batch accuracies; computed once and reused below.
    epoch_val_acc = np.mean(val_acc)

    print(f"Epoch {epoch+1}/{epochs} | train loss: {np.mean(train_loss):.3f} | train acc: {np.mean(train_acc):.3f} | val loss: {np.mean(val_loss):.3f} | val acc: {epoch_val_acc:.3f}")
    torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))

    if epoch_val_acc > max_val_acc:
        cprint("New best.", "cyan")
        torch.save(model.state_dict(), os.path.join(logdir, "model_best.pt"))
        max_val_acc = epoch_val_acc


# ----------------------------------
#  Start evaluation with best model
# ----------------------------------
model.load_state_dict(torch.load(os.path.join(logdir, "model_best.pt"), map_location=device))

preds = []
model.eval()
# Inference only: disable autograd so no graph is built for each test batch.
with torch.no_grad():
    for X, subject_idxs in tqdm(test_loader, desc="Validation"):
        preds.append(model(X.to(device)).cpu())

preds = torch.cat(preds, dim=0).numpy()
np.save(os.path.join(logdir, "submission"), preds)
cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")

cuda
Epoch 1/50


Train: 100%|██████████| 514/514 [02:27<00:00,  3.49it/s]
Validation: 100%|██████████| 129/129 [00:04<00:00, 29.86it/s]


Epoch 1/50 | train loss: 7.890 | train acc: 0.007 | val loss: 7.516 | val acc: 0.011
[36mNew best.[0m
Epoch 2/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.45it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.25it/s]


Epoch 2/50 | train loss: 7.434 | train acc: 0.019 | val loss: 7.496 | val acc: 0.015
[36mNew best.[0m
Epoch 3/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.46it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 32.93it/s]


Epoch 3/50 | train loss: 7.262 | train acc: 0.033 | val loss: 7.483 | val acc: 0.022
[36mNew best.[0m
Epoch 4/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.46it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 32.98it/s]


Epoch 4/50 | train loss: 7.040 | train acc: 0.058 | val loss: 7.512 | val acc: 0.021
Epoch 5/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.46it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.02it/s]


Epoch 5/50 | train loss: 6.823 | train acc: 0.083 | val loss: 7.579 | val acc: 0.022
Epoch 6/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.47it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.03it/s]


Epoch 6/50 | train loss: 6.637 | train acc: 0.109 | val loss: 7.618 | val acc: 0.022
Epoch 7/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.45it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.19it/s]


Epoch 7/50 | train loss: 6.469 | train acc: 0.135 | val loss: 7.698 | val acc: 0.022
Epoch 8/50


Train: 100%|██████████| 514/514 [02:29<00:00,  3.44it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.11it/s]


Epoch 8/50 | train loss: 6.320 | train acc: 0.157 | val loss: 7.742 | val acc: 0.023
[36mNew best.[0m
Epoch 9/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.46it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.18it/s]


Epoch 9/50 | train loss: 6.200 | train acc: 0.175 | val loss: 7.757 | val acc: 0.021
Epoch 10/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.45it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.61it/s]


Epoch 10/50 | train loss: 6.099 | train acc: 0.191 | val loss: 7.836 | val acc: 0.021
Epoch 11/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.46it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.32it/s]


Epoch 11/50 | train loss: 6.001 | train acc: 0.208 | val loss: 7.858 | val acc: 0.022
Epoch 12/50


Train: 100%|██████████| 514/514 [02:29<00:00,  3.45it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.26it/s]


Epoch 12/50 | train loss: 5.918 | train acc: 0.222 | val loss: 7.914 | val acc: 0.023
[36mNew best.[0m
Epoch 13/50


Train: 100%|██████████| 514/514 [02:28<00:00,  3.46it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.10it/s]


Epoch 13/50 | train loss: 5.847 | train acc: 0.234 | val loss: 7.930 | val acc: 0.024
[36mNew best.[0m
Epoch 14/50


Train: 100%|██████████| 514/514 [02:29<00:00,  3.45it/s]
Validation: 100%|██████████| 129/129 [00:03<00:00, 33.17it/s]


Epoch 14/50 | train loss: 5.781 | train acc: 0.246 | val loss: 7.961 | val acc: 0.023
Epoch 15/50


Train:  71%|███████▏  | 367/514 [01:46<00:42,  3.43it/s]


KeyboardInterrupt: 

In [None]:
model_path = "model_best.pt"
set_seed(seed)
savedir = os.path.dirname(model_path)  # "" for a bare file name -> working directory

# ------------------
#    Dataloader
# ------------------
test_set = ThingsMEGDataset("test", data_dir)
test_loader = torch.utils.data.DataLoader(
    test_set, shuffle=False, batch_size=batch_size, num_workers=num_workers
)

# ------------------
#       Model
# ------------------
# The training cell saves model_best.pt from a BasicConvClassifier1, so the
# same class must be rebuilt here for the state_dict keys to match.
# NOTE(review): if the checkpoint at `model_path` was produced by the imported
# src.models.BasicConvClassifier instead, instantiate that class here.
model = BasicConvClassifier1(
    test_set.num_classes, test_set.seq_len, test_set.num_channels
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

# ------------------
#  Start evaluation
# ------------------
preds = []
model.eval()
# Inference only: disable autograd so no graph is built for each test batch.
with torch.no_grad():
    for X, subject_idxs in tqdm(test_loader, desc="Validation"):
        preds.append(model(X.to(device)).cpu())

preds = torch.cat(preds, dim=0).numpy()
np.save(os.path.join(savedir, "submission"), preds)
cprint(f"Submission {preds.shape} saved at {savedir}", "cyan")