In [1]:
###-------------------------------------------------------------------------------------------------------------------
#         imports
###-------------------------------------------------------------------------------------------------------------------

from training import balanced_data_shuffle, training_loop, evaluate
from utils import get_dict_raw_data
from models import MRIVisionTransformers
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import OneHotEncoder, LabelEncoder


In [2]:
###-------------------------------------------------------------------------------------------------------------------
#         hyperparameters
###-------------------------------------------------------------------------------------------------------------------

config = {
    # general
    "epochs": 100,
    "batch_size": 4,
    "lr": 1e-3,

    # model
    "d_model_output": 512,
    "d_model_input": 400,
    "dropout" : 0.1,
    "attention_dropout" : 0.1,
    "num_heads": 4,
    "num_layers": 0, # TBA?

    # optimizer
    "lambda_si": 0.6,
    "lambda_td": 0.4
}

In [3]:
###-------------------------------------------------------------------------------------------------------------------
#         subject ID list
###-------------------------------------------------------------------------------------------------------------------

IDs = [100307,  117122,  131722,  153025,  211720,
100408,  118528,  133019,  154734,  212318,      
101107,  118730,  133928,  156637,  214423,        
101309,  118932,  135225,  159340,  221319,       
101915,  120111,  135932,  160123,  239944 ,      
103111,  122317,  136833,  161731,  245333,        
103414,  122620,  138534,  162733,  280739,        
103818, 123117,  139637,  163129,  298051,        
105014,  123925,  140925,  176542,  366446,        
105115,  124422,  144832,  178950,  397760,        
106016,  125525,  146432,  188347,  414229,        
108828,  126325,  147737,  189450,  499566,
110411,  127630,  148335,  190031,  654754,
111312,  127933,  148840,  192540,  672756,
111716,  128127,  149337,  196750,  751348,
113619,  128632,  149539,  198451,  756055,
113922,  129028,  149741,  199655,  792564,
114419,  130013,  151223,  201111,  856766,
115320,  130316,  151526,  208226,  857263]

In [4]:
###-------------------------------------------------------------------------------------------------------------------
#         joining train and test dataframes from all subjects
###-------------------------------------------------------------------------------------------------------------------

# "C:/Users/emy8/OneDrive/Documents/EPFL/Master/MA3/DeepLbiomed/Project/DATA/", "1003" 

data_dict_train, data_dict_test = get_dict_raw_data("/Users/eddyvonmatt/Desktop/MIPLab-TeamCEE-DeepLearningforBiomed-main/DATA/", IDs[0:3])

In [5]:
print(data_dict_train)

   label_id     task_id                                                mat
0    100307    GAMBLING  [[0.2356709594115424, 0.03883497545044236, 0.1...
1    100307       REST1  [[0.21854491103466994, 0.07509374392964863, 0....
2    100307       MOTOR  [[0.2141270371266362, 0.040754342863046, 0.084...
3    100307    LANGUAGE  [[0.2317390561241142, 0.06537822245634475, 0.0...
4    100307      SOCIAL  [[0.27075755129825896, 0.07942572217389814, 0....
5    100307       REST2  [[0.2509722712619662, 0.06429771271159306, 0.1...
6    100307          WM  [[0.28122430896568573, 0.12358947079320645, 0....
7    100307     EMOTION  [[0.27626702525883573, 0.03827488524289221, 0....
8    100307  RELATIONAL  [[0.2709434948110919, 0.08915439190003989, 0.1...
9    117122       REST2  [[0.38164660255395016, 0.24027807540608934, 0....
10   117122     EMOTION  [[0.2922083137463598, 0.04759277700431855, 0.1...
11   117122    GAMBLING  [[0.25858517393823743, 0.05327319230138714, 0....
12   117122      SOCIAL  

In [6]:
###-------------------------------------------------------------------------------------------------------------------
#         label encoding
###-------------------------------------------------------------------------------------------------------------------

# one hot encoding
if False:
    enc_labels = OneHotEncoder(handle_unknown='ignore')
    enc_tasks = OneHotEncoder(handle_unknown='ignore')

    enc_labels.fit(data_dict_train["label_id"].to_numpy().reshape(-1, 1))
    enc_tasks.fit(data_dict_train["task_id"].to_numpy().reshape(-1, 1))

    enc_train_label_encodings = enc_labels.transform(data_dict_train["label_id"].to_numpy().reshape(-1, 1)).toarray()
    enc_train_task_encodings = enc_tasks.transform(data_dict_train["task_id"].to_numpy().reshape(-1, 1)).toarray()

    enc_test_label_encodings = enc_labels.transform(data_dict_test["label_id"].to_numpy().reshape(-1, 1)).toarray()
    enc_test_task_encodings = enc_tasks.transform(data_dict_test["task_id"].to_numpy().reshape(-1, 1)).toarray()

    data_dict_train["enc_label_id"] = enc_train_label_encodings.tolist()
    data_dict_train["enc_task_id"] = enc_train_task_encodings.tolist()

    data_dict_test["enc_label_id"] = enc_test_label_encodings.tolist()
    data_dict_test["enc_task_id"] = enc_test_task_encodings.tolist()

# label encoding
enc_labels = LabelEncoder()
enc_tasks = LabelEncoder()

enc_labels.fit(data_dict_train["label_id"].tolist())
enc_tasks.fit(data_dict_train["task_id"].tolist())

enc_train_label_encodings = enc_labels.transform(data_dict_train["label_id"].tolist())
enc_train_task_encodings = enc_tasks.transform(data_dict_train["task_id"].tolist())

enc_test_label_encodings = enc_labels.transform(data_dict_test["label_id"].tolist())
enc_test_task_encodings = enc_tasks.transform(data_dict_test["task_id"].tolist())

data_dict_train["enc_label_id"] = enc_train_label_encodings
data_dict_train["enc_task_id"] = enc_train_task_encodings
data_dict_test["enc_label_id"] = enc_test_label_encodings
data_dict_test["enc_task_id"] = enc_test_task_encodings

#enc.inverse_transform() to reverse

In [7]:
###-------------------------------------------------------------------------------------------------------------------
#         initializing dataloader objects
###-------------------------------------------------------------------------------------------------------------------

train_dataset = TensorDataset(torch.tensor(data_dict_train["mat"][:]).float(), torch.tensor(data_dict_train["enc_label_id"][:]), torch.tensor(data_dict_train["enc_task_id"][:]))
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)

test_dataset = TensorDataset(torch.tensor(data_dict_test["mat"][:]).float(), torch.tensor(data_dict_test["enc_label_id"][:]), torch.tensor(data_dict_test["enc_task_id"][:]))
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)

  train_dataset = TensorDataset(torch.tensor(data_dict_train["mat"][:]).float(), torch.tensor(data_dict_train["enc_label_id"][:]), torch.tensor(data_dict_train["enc_task_id"][:]))


In [8]:
###-------------------------------------------------------------------------------------------------------------------
#         initializing model
###-------------------------------------------------------------------------------------------------------------------

model = MRIVisionTransformers(
        output_size = config["d_model_output"],
        input_size = config["d_model_input"],
        num_heads = config["num_heads"],
        dropout = config["dropout"],
        attention_dropout = config["attention_dropout"]
)

x = torch.randn(1, 400, 400)
y = model(x)

# x_si, x_td, attn_weights
print(y[0].size())
print(y[1].size())
print(y[2].size())

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 400, 400])


In [9]:
###-------------------------------------------------------------------------------------------------------------------
#         training
###-------------------------------------------------------------------------------------------------------------------

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

# change to cuda
device = "cpu"
training_loop(config["epochs"], model, train_loader, test_loader, criterion, optimizer, device, config)

Epoch: 1/100 - loss_total: 4.2294 - acc: 0.0000 - val-loss_total: 3.4689 - val-acc: 0.0000 (17.57s/epoch)
Epoch: 2/100 - loss_total: 2.2377 - acc: 0.0000 - val-loss_total: 2.9964 - val-acc: 0.0000 (13.21s/epoch)
Epoch: 3/100 - loss_total: 2.1823 - acc: 0.0000 - val-loss_total: 1.3685 - val-acc: 0.0000 (12.67s/epoch)
Epoch: 4/100 - loss_total: 1.1520 - acc: 0.0000 - val-loss_total: 1.7359 - val-acc: 0.0000 (11.70s/epoch)
Epoch: 5/100 - loss_total: 0.9093 - acc: 0.0000 - val-loss_total: 2.2340 - val-acc: 0.0000 (14.06s/epoch)
Epoch: 6/100 - loss_total: 0.9250 - acc: 0.0000 - val-loss_total: 2.1779 - val-acc: 0.0000 (13.16s/epoch)
Epoch: 7/100 - loss_total: 0.6000 - acc: 0.0000 - val-loss_total: 1.8633 - val-acc: 0.0000 (12.39s/epoch)


KeyboardInterrupt: 