In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
from tqdm import tqdm
from thop import profile, clever_format


# cuDNN flags for reproducible runs. benchmark=True lets cuDNN autotune and pick
# kernels that may be non-deterministic, which contradicts deterministic=True,
# so autotuning is kept off here (consistent with set_seed below).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), and switches cuDNN to deterministic, non-autotuned kernels.

    :param seed: integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import ``random``

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; PyTorch silently ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # deterministic convolution algorithms only
    torch.backends.cudnn.benchmark = False  # autotuning may pick non-deterministic kernels

set_seed(seed=42)  # one fixed seed for the whole notebook run


# Experiment settings shared by the datasets, the training loop and logging.
config = dict(
    subjects_num=12,  # subjects are loaded as sub_0 ... sub_{subjects_num - 1}
    n_epochs=30,
    batch_size=60,
    save_name='logs/DualFusion-{epoch:02d}-{val_acc:.2f}',  # checkpoint filename template
    log_path1='logs/DualFusion_logs',  # TensorBoard log directory
    num_class=2,  # binary classification: 0-awake, 1-fatigue
)

# False -> leave-one-subject-out (inter-subject) split; True -> pooled random split.
isIntraSub = False


def get_device():
    """Return ``'cuda'`` when PyTorch can see a CUDA device, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'



class EEG_IntraSub_Dataset(Dataset):
    """Pooled (intra-subject) EEG dataset: all subjects mixed, then split 80/10/10.

    Every ``sub_{i}_eeg.npy`` / ``sub_{i}_labels.npy`` pair under ``path`` is
    concatenated along the first axis (labels flattened). The pool is shuffled
    with a permutation drawn from a private RNG seeded by ``split_seed``. It is
    then sliced into train (first 80%), val (next 10%) and test (last 10%).

    :param path: directory prefix; file names are appended directly, so it
        must end with a path separator.
    :param mode: ``'train'``, ``'val'`` or ``'test'``.
    :param test_sub: accepted for interface parity with ``EEG_InterSub_Dataset``;
        not used by this dataset.
    :param split_seed: seed of the split permutation. Datasets built with the
        same seed share one permutation, so their splits are disjoint.
    :raises ValueError: if ``mode`` is not one of the three splits.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(self, path, mode, test_sub, split_seed=42):
        if mode not in self._SPLITS:
            raise ValueError(f"mode must be one of {self._SPLITS}, got {mode!r}")
        self.mode = mode

        data, label = self._load_subjects(path, range(config['subjects_num']))

        # A dedicated seeded generator (not the global np.random state) gives the
        # three separately constructed split datasets the SAME permutation. With
        # the global RNG each instance shuffled differently, and train/val/test
        # overlapped (data leakage).
        shuffle_idx = np.random.default_rng(split_seed).permutation(len(data))
        data = data[shuffle_idx]
        label = label[shuffle_idx]

        n = len(data)
        train_end, val_end = int(n * 0.8), int(n * 0.9)
        start, stop = {
            'train': (0, train_end),
            'val': (train_end, val_end),
            'test': (val_end, n),
        }[mode]

        self.data = torch.FloatTensor(data[start:stop])
        self.label = torch.LongTensor(label[start:stop])

    @staticmethod
    def _load_subjects(path, subjects):
        """Concatenate the EEG arrays and flattened labels of ``subjects`` along axis 0."""
        data = np.concatenate([np.load(path + f'sub_{i}_eeg.npy') for i in subjects], axis=0)
        label = np.concatenate([np.load(path + f'sub_{i}_labels.npy').reshape(-1) for i in subjects])
        return data, label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index]
        
class EEG_InterSub_Dataset(Dataset):
    """Leave-one-subject-out (inter-subject) EEG dataset.

    ``'train'`` and ``'val'`` pool every subject except ``test_sub``. The pool
    is shuffled with a permutation from a private RNG seeded by ``split_seed``,
    and split into the first 80% (train) and the remaining 20% (val).
    ``'test'`` is the complete, unshuffled data of ``test_sub``. Labels are
    flattened in every mode.

    :param path: directory prefix; file names are appended directly, so it
        must end with a path separator.
    :param mode: ``'train'``, ``'val'`` or ``'test'``.
    :param test_sub: index of the held-out subject, in ``range(config['subjects_num'])``.
    :param split_seed: seed of the train/val permutation. Train and val datasets
        built with the same seed share one permutation, so they are disjoint.
    :raises ValueError: if ``mode`` or ``test_sub`` is invalid.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(self, path, mode, test_sub, split_seed=42):
        if mode not in self._SPLITS:
            raise ValueError(f"mode must be one of {self._SPLITS}, got {mode!r}")
        n_subjects = config['subjects_num']
        if test_sub not in range(n_subjects):
            raise ValueError(f"test_sub must be in range({n_subjects}), got {test_sub!r}")
        self.mode = mode
        self.test_sub = test_sub

        if mode == 'test':
            # Flattened like train/val so every split yields 1-D label tensors.
            data, label = self._load_subjects(path, [test_sub])
        else:
            train_subs = [i for i in range(n_subjects) if i != test_sub]
            data, label = self._load_subjects(path, train_subs)
            # Seeded private generator: the separately built train and val datasets
            # get the same permutation, so the 80/20 slices do not overlap.
            shuffle_idx = np.random.default_rng(split_seed).permutation(len(data))
            data, label = data[shuffle_idx], label[shuffle_idx]
            train_end = int(len(data) * 0.8)
            if mode == 'train':
                data, label = data[:train_end], label[:train_end]
            else:
                data, label = data[train_end:], label[train_end:]

        self.data = torch.FloatTensor(data)
        self.label = torch.LongTensor(label)

    @staticmethod
    def _load_subjects(path, subjects):
        """Concatenate the EEG arrays and flattened labels of ``subjects`` along axis 0."""
        data = np.concatenate([np.load(path + f'sub_{i}_eeg.npy') for i in subjects], axis=0)
        label = np.concatenate([np.load(path + f'sub_{i}_labels.npy').reshape(-1) for i in subjects])
        return data, label

    def __len__(self):
        return len(self.data)  # Return total number of data points

    def __getitem__(self, index):
        return self.data[index], self.label[index]


def prep_dataloader(path, mode, batch_size, test_sub, isIntraSub = False, njobs=1):
    """Build a DataLoader for one split of the intra- or inter-subject EEG data.

    :param path: directory prefix holding the ``sub_{i}_*.npy`` files.
    :param mode: ``'train'``, ``'val'`` or ``'test'``; only ``'train'`` is shuffled.
    :param batch_size: samples per batch (the last, possibly smaller, batch is kept).
    :param test_sub: held-out subject index, forwarded to the dataset.
    :param isIntraSub: True selects the pooled dataset, False the leave-one-subject-out one.
    :param njobs: number of DataLoader worker processes.
    :return: a ``DataLoader`` with pinned memory.
    """
    if not isIntraSub:
        print("InterSub")
        dataset_cls = EEG_InterSub_Dataset
    else:
        print("IntraSub")
        dataset_cls = EEG_IntraSub_Dataset
    dataset = dataset_cls(path, mode, test_sub)

    is_train = mode == 'train'
    return DataLoader(
        dataset,
        batch_size,
        shuffle=is_train,
        drop_last=False,
        num_workers=njobs,
        pin_memory=True,
    )


In [None]:
class channel_MLP(pl.LightningModule):
    """Per-channel MLP classifier over a time/frequency fusion of each EEG channel.

    For every channel, the magnitude of its FFT and the raw signal are concatenated
    (length ``2 * input_dim``) and fed to that channel's own two-layer MLP. The
    ``num_channels`` outputs (``output_dim`` each) are concatenated and mapped to
    ``final_output_dim`` logits by a two-layer head.

    :param num_channels: number of EEG channels; one MLP is built per channel.
    :param input_dim: samples per channel in the input.
    :param hidden_dim: hidden width of each per-channel MLP.
    :param output_dim: output width of each per-channel MLP.
    :param final_output_dim: number of classes (logits produced by the head).
    :param dropout_prob: dropout probability inside each per-channel MLP.
    :param activation: ``'silu'`` (default, the activation this model has always
        used), ``'gelu'`` or ``'relu'``.
    :param fc_hidden_dim: hidden width of the classification head.
    :raises ValueError: for an unknown ``activation`` name.
    """

    _ACTIVATIONS = {'silu': nn.SiLU, 'gelu': nn.GELU, 'relu': nn.ReLU}

    def __init__(self, num_channels=17, input_dim=384, hidden_dim=300, output_dim=30, final_output_dim=2,
                 dropout_prob=0.3, activation='silu', fc_hidden_dim=128):
        super(channel_MLP, self).__init__()
        try:
            act_cls = self._ACTIVATIONS[activation.lower()]
        except (KeyError, AttributeError):
            raise ValueError(f"activation must be one of {sorted(self._ACTIVATIONS)}, got {activation!r}") from None

        self.activation = act_cls()

        # One MLP per channel; the Sequential layout (indices 0-6) matches earlier
        # versions so existing checkpoints keep loading.
        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim * 2, hidden_dim),  # input = [|FFT(x)|, x]
                nn.BatchNorm1d(hidden_dim),
                act_cls(),
                nn.Dropout(p=dropout_prob),
                nn.Linear(hidden_dim, output_dim),
                nn.BatchNorm1d(output_dim),
                act_cls(),
            )
            for _ in range(num_channels)
        ])

        # Head sizes derived from the constructor arguments (510 -> 128 -> 2 with defaults).
        self.fc1 = nn.Linear(num_channels * output_dim, fc_hidden_dim)
        self.fc2 = nn.Linear(fc_hidden_dim, final_output_dim)

    def forward(self, x):
        """
        :param x: input tensor of shape [batch_size, num_channels, input_dim]
        :return: logits of shape [batch_size, final_output_dim]
        :raises ValueError: if ``x`` is not 3-D or its channel count differs from ``num_channels``.
        """
        if x.dim() != 3 or x.shape[1] != len(self.mlps):
            raise ValueError(
                f"expected input of shape [batch, {len(self.mlps)}, input_dim], got {tuple(x.shape)}"
            )

        x_mag = torch.abs(torch.fft.fft(x, dim=-1))
        fusion_output = torch.cat((x_mag, x), dim=2)  # [batch, channels, 2 * input_dim]

        # Each channel goes through its own MLP -> [batch, output_dim] per channel.
        channel_outputs = [mlp(fusion_output[:, i, :]) for i, mlp in enumerate(self.mlps)]
        aggregated_output = torch.cat(channel_outputs, dim=1)

        out = self.activation(self.fc1(aggregated_output))
        return self.fc2(out)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx=None):
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        self.log('training_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx=None):
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        acc = (preds.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
        # Logged so the checkpoint filename template's {val_acc} field has a real value.
        self.log('val_acc', acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx=None):
        """Log accuracy and weighted precision/recall/F1 for one test batch.

        Metrics are computed per batch with scikit-learn; Lightning aggregates
        the logged per-batch values over the test epoch.
        """
        x, y = batch
        y_pre = torch.argmax(self(x), dim=1)  # argmax of logits == argmax of log-softmax
        y_true, y_pred = y.cpu(), y_pre.cpu()

        # zero_division=0 returns the same value sklearn's default 'warn' does,
        # without emitting UndefinedMetricWarning when a class is absent from a batch.
        self.log('test_acc', accuracy_score(y_true, y_pred))
        self.log('test_pre', precision_score(y_true, y_pred, average='weighted', zero_division=0))
        self.log('test_recall', recall_score(y_true, y_pred, average='weighted', zero_division=0))
        self.log('test_f1', f1_score(y_true, y_pred, average='weighted', zero_division=0))

# Keep only the best epoch by validation loss, plus a last.ckpt of the final epoch.
# NOTE(review): ModelCheckpoint tracks the best score it has seen, so a single
# instance shared by several Trainer runs carries that state between them.
checkpoint_callback = ModelCheckpoint(
    filename=config['save_name'],
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True,
)

if __name__ == '__main__':
    tr_path = val_path = test_path = "/home/jie/Program/872/Dataset/SEED-VIG-Subset/"
    device = get_device()
    use_gpu = device == 'cuda'
    # isIntraSub is taken from the configuration cell; it is deliberately not reassigned here.

    # Profile a throw-away instance in eval mode. In training mode, BatchNorm1d rejects
    # a batch containing a single sample, which is exactly what the dummy input is.
    profile_model = channel_MLP().eval()
    dummy_input = torch.randn(1, 17, 384)
    flops, params = profile(profile_model, inputs=(dummy_input,))
    flops, params = clever_format([flops, params], "%.3f")
    # \033 is ESC; \033[0m resets the terminal colour afterwards.
    print("\033[42m" + f"FLOPs: {flops}, Parameters: {params}" + "\033[0m")

    all_test_results = []
    for test_sub in range(config['subjects_num']):
        print(f"Testing subject {test_sub}")
        tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], test_sub, isIntraSub, njobs=6)
        val_set = prep_dataloader(val_path, 'val', config['batch_size'], test_sub, isIntraSub, njobs=6)
        test_set = prep_dataloader(test_path, 'test', config['batch_size'], test_sub, isIntraSub, njobs=1)
        model = channel_MLP().to(device)
        logger = TensorBoardLogger(config['log_path1'])

        # A fresh callback per fold. ModelCheckpoint remembers the best score it has
        # seen, so sharing one instance would compare this subject's epochs
        # against earlier folds.
        fold_checkpoint = ModelCheckpoint(
            monitor='val_loss',
            filename=config['save_name'],
            save_top_k=1,
            mode='min',
            save_last=True,
        )
        trainer = Trainer(
            val_check_interval=1.0,
            max_epochs=config['n_epochs'],
            accelerator='gpu' if use_gpu else 'cpu',  # follow get_device() instead of assuming a GPU
            devices=[0] if use_gpu else 1,
            logger=logger,
            callbacks=[fold_checkpoint],
        )

        trainer.fit(model, train_dataloaders=tr_set, val_dataloaders=val_set)

        # NOTE(review): this evaluates the final-epoch weights. Pass ckpt_path='best' to
        # evaluate the checkpoint with the lowest val_loss instead.
        test_results = trainer.test(model, dataloaders=test_set)
        all_test_results.append(test_results[0])

    if all_test_results:
        mean_acc = np.mean([r['test_acc'] for r in all_test_results])
        print(f"Mean test accuracy over {len(all_test_results)} subjects: {mean_acc:.4f}")
        
        