Author:
        
        PARK, JunHo, junho@ccnets.org

        
        KIM, JeongYoong, jeongyoong@ccnets.org
        
    COPYRIGHT (c) 2024. CCNets. All Rights reserved.

In [1]:
import sys
path_append = "../"
sys.path.append(path_append)  # Go up one directory from where you are.

import torch
import pandas as pd
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler

In [2]:
from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed
set_random_seed(0)

from trainer_hub import TrainerHub

In [3]:
dataroot = path_append + "../data/eeg/sub-01_ses-01.csv"
df = pd.read_csv(dataroot)
df

Unnamed: 0,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,...,D24,D25,D26,D27,D28,D29,D30,D31,D32,event
0,3549.790315,4533.538497,3619.665186,3077.291188,-1380.325575,6120.066816,-4072.820600,-2256.511456,1820.012261,-2815.635423,...,-7240.845997,7034.252627,8458.062496,5905.223463,6147.660515,2458.073582,-7465.876831,-3604.133966,-5445.224315,5
1,3551.227812,4534.850995,3622.540181,3077.322438,-1377.575581,6123.066810,-4069.851856,-2252.167714,1825.168502,-2803.072947,...,-7227.283522,7039.627617,8463.874985,5911.598451,6153.504254,2463.354822,-7461.033090,-3594.258985,-5435.693082,5
2,3556.727802,4539.850986,3629.040169,3081.978679,-1370.419344,6130.348047,-4063.508118,-2249.292720,1828.074746,-2804.041695,...,-7227.158522,7048.502600,8473.562467,5921.348433,6163.004236,2469.854810,-7460.470591,-3591.540240,-5433.568086,5
3,3557.915300,4541.225983,3628.540169,3083.197427,-1372.263090,6130.410547,-4062.070620,-2251.667715,1825.856000,-2803.572946,...,-7224.189777,7042.346362,8464.593734,5917.660940,6160.972990,2467.011066,-7458.158095,-3597.008980,-5437.474329,5
4,3553.352808,4535.757243,3622.477681,3079.572434,-1377.763080,6125.598056,-4066.570612,-2255.136459,1821.981008,-2808.041687,...,-7219.971035,7044.658857,8466.843729,5914.848445,6156.785498,2466.948566,-7457.501846,-3585.821500,-5428.630595,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
234436,2357.917517,4456.663639,4075.414344,1823.043506,-766.732959,4533.569747,-2612.042050,184.640283,2023.293136,-1143.232264,...,-6307.816471,6492.316128,4752.756842,4841.569178,5098.693703,3744.352455,-12741.898332,-2260.261450,-4432.538686,11
234437,2352.636277,4451.976148,4071.008102,1817.887265,-769.482954,4530.913502,-2614.854544,183.546535,2021.199389,-1144.044762,...,-6303.441479,6490.691131,4749.506848,4839.506682,5096.131208,3742.602459,-12744.398327,-2258.417703,-4432.788685,11
234438,2359.573764,4459.976133,4077.320590,1823.699754,-761.951718,4538.132239,-2605.792061,191.702770,2028.980625,-1134.451030,...,-6292.753999,6499.097365,4758.819331,4846.725419,5105.256191,3752.164941,-12733.742097,-2251.948965,-4424.444951,11
234439,2365.261254,4466.038622,4083.258079,1831.824739,-754.389232,4542.850980,-2599.667072,198.015258,2036.261862,-1126.576044,...,-6286.066511,6509.972345,4768.381813,4852.506658,5111.506179,3758.696179,-12725.992112,-2243.886480,-4415.101218,11


In [4]:
mm = StandardScaler()
df.iloc[:,:-1] = mm.fit_transform(df.iloc[:,:-1])

tmp = df[df.event==5]
df = pd.concat([df, tmp, tmp])
df.reset_index(drop=True, inplace=True)
df.event = df.event.apply(lambda x: x if x<5 else x-1)

In [5]:
from torch.utils.data import Dataset, DataLoader
from random import shuffle

class EEG_Dataset(Dataset):
    def __init__(self, df, indices, **kwargs):
        self.df = df
        self.indices = indices

    def __len__(self):
        return len(self.indices)
        # return len(self.df)-self.window_size+1
    
    def __getitem__(self, idx):
        seq = self.df.loc[self.indices[idx]]
        X,y = seq.values[:,:-1], seq.values[:,-1]
        X = torch.from_numpy(X).float()
        y = torch.from_numpy(y).long()
        y = torch.nn.functional.one_hot(y, num_classes=13)
        
        return X, y

In [6]:
window_size = 128
indices = [list(df.index[i:i + window_size]) for i in range(0, len(df), window_size//2) if len(df.index[i:i + window_size]) == window_size and len(df['event'][i:i+window_size].unique()) == 1]
shuffle(indices)
train_indices, test_indices = train_test_split(indices, test_size=0.1, shuffle=False)

trainset = EEG_Dataset(df, train_indices)
testset = EEG_Dataset(df, test_indices)

In [7]:
data_config = DataConfig(dataset_name = 'eeg-sub-01', task_type='multi_class_classification', obs_shape=[128], label_size=13)

#  Set training configuration from the AlgorithmConfig class, returning them as a Namespace object.
ml_params = MLParameters()

In [8]:
ml_params.core_model_name = 'gpt' 
ml_params.encoder_model_name = 'none'
ml_params.training.max_epoch = 10
ml_params.training.batch_size = 64

# Set the device to GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# Initialize the TrainerHub class with the training configuration, data configuration, device, and use_print and use_wandb flags
trainer_hub = TrainerHub(ml_params, data_config, device, use_print=True, use_wandb=False) 

In [9]:
trainer_hub.train(trainset, testset)

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Iterations:   0%|          | 0/48 [00:00<?, ?it/s]

In [None]:
trainer_hub.test(testset)