In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
from tqdm import tqdm
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from thop import profile, clever_format
# Ask cuDNN for deterministic kernels so repeated runs are reproducible.
# Benchmark mode auto-tunes the convolution algorithm per input shape and may
# pick a different (possibly non-deterministic) one on each run, so it is
# switched off here to match what set_seed() configures.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def set_seed(seed=42):
    """Seed every random number generator in use and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global generator and PyTorch's CPU and CUDA generators.
    """
    # Local import: `random` is not part of the file's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds the generator of every CUDA device, which includes the current one.
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels on, algorithm auto-tuning off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all generators once, before any dataset or model is built.
set_seed(seed=42)

# Experiment settings shared by the code below.
config = dict(
    subjects_num=12,    # number of per-subject .npy file pairs
    n_epochs=50,        # passed to the Trainer as max_epochs
    batch_size=64,      # passed to every DataLoader
    save_name='logs/EEG-Transformer-{epoch:02d}-{val_acc:.2f}',  # checkpoint filename template
    log_path1='logs/EEG-Transformer_logs',                       # TensorBoard log directory
    num_class=2,
)

# Module-level default for the split protocol; the __main__ section reassigns it.
isIntraSub = False

def get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

class EEG_IntraSub_Dataset(Dataset):
    """Pooled-subject EEG dataset with a reproducible 80/10/10 split.

    Trials from every subject file are concatenated, shuffled once with a
    seeded generator and cut into train (first 80%), val (next 10%) and test
    (last 10%) parts. Because the permutation depends only on ``split_seed``
    and ``test_sub``, the three modes built for the same fold partition the
    data without overlap.

    Args:
        path: Directory prefix (including the trailing separator) that holds
            ``sub_{i}_eeg.npy`` and ``sub_{i}_labels.npy``.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        test_sub: Fold index. It does not select subjects; it only offsets
            the shuffle seed so different folds get different splits.
        split_seed: Base seed of the shuffle.
        subjects_num: Number of subject files to load; defaults to
            ``config['subjects_num']``.

    Raises:
        ValueError: If ``mode`` is unknown or data and label counts differ.
    """

    def __init__(self, path, mode, test_sub, split_seed=42, subjects_num=None):
        if mode not in ('train', 'val', 'test'):
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")
        self.mode = mode
        if subjects_num is None:
            subjects_num = config['subjects_num']

        data_parts = []
        label_parts = []
        for i in range(subjects_num):
            data_parts.append(np.load(path + f'sub_{i}_eeg.npy'))
            # Flatten per subject so (N,), (N, 1) and (1, N) label files all work.
            label_parts.append(np.asarray(np.load(path + f'sub_{i}_labels.npy')).reshape(-1))
        data = np.concatenate(data_parts, axis=0)
        label = np.concatenate(label_parts, axis=0)
        if len(data) != len(label):
            raise ValueError(
                f"found {len(data)} trials but {len(label)} labels under {path!r}")

        # A local, seeded generator gives every mode the SAME permutation.
        # Drawing from the global NumPy state would produce a new permutation
        # per instance and let train/val/test samples overlap.
        fold = 0 if test_sub is None else int(test_sub)
        rng = np.random.default_rng([split_seed, fold])
        shuffle_idx = rng.permutation(len(data))

        total = len(data)
        train_end = int(total * 0.8)
        val_end = int(total * 0.9)
        bounds = {
            'train': (0, train_end),
            'val': (train_end, val_end),
            'test': (val_end, total),
        }
        start, stop = bounds[mode]
        selected = shuffle_idx[start:stop]

        # unsqueeze(1) inserts a singleton dimension right after the trial axis.
        self.data = torch.FloatTensor(data[selected]).unsqueeze(1)
        self.label = torch.LongTensor(label[selected])

    def __len__(self):
        """Return the number of trials in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]
        
class EEG_InterSub_Dataset(Dataset):
    """Leave-one-subject-out EEG dataset.

    ``'train'`` and ``'val'`` pool every subject except ``test_sub``, shuffle
    the pooled trials once with a seeded generator and take the first 80% /
    last 20% respectively. Both modes see the same permutation, so the two
    parts do not overlap. ``'test'`` returns the held-out subject's trials in
    file order.

    Args:
        path: Directory prefix (including the trailing separator) that holds
            ``sub_{i}_eeg.npy`` and ``sub_{i}_labels.npy``.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        test_sub: Index of the held-out subject.
        split_seed: Seed of the train/val shuffle.
        subjects_num: Total number of subjects; defaults to
            ``config['subjects_num']``.

    Raises:
        ValueError: If ``mode`` is unknown, if ``test_sub`` is not one of the
            subject indices (train/val), or if data and label counts differ.
    """

    def __init__(self, path, mode, test_sub, split_seed=42, subjects_num=None):
        if mode not in ('train', 'val', 'test'):
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")
        self.mode = mode
        self.test_sub = test_sub

        if mode == 'test':
            data = np.load(path + f'sub_{test_sub}_eeg.npy')
            # Flattened so test labels have the same 1-D layout as train/val labels.
            label = np.asarray(np.load(path + f'sub_{test_sub}_labels.npy')).reshape(-1)
        else:
            if subjects_num is None:
                subjects_num = config['subjects_num']
            train_sub = list(range(subjects_num))
            train_sub.remove(test_sub)  # raises ValueError for an unknown subject

            data_parts = []
            label_parts = []
            for i in train_sub:
                data_parts.append(np.load(path + f'sub_{i}_eeg.npy'))
                label_parts.append(np.asarray(np.load(path + f'sub_{i}_labels.npy')).reshape(-1))
            data = np.concatenate(data_parts, axis=0)
            label = np.concatenate(label_parts, axis=0)

            # Seeded local generator: the train and val instances must share
            # one permutation, otherwise their index sets overlap.
            rng = np.random.default_rng(split_seed)
            shuffle_idx = rng.permutation(len(data))
            cut = int(len(data) * 0.8)
            selected = shuffle_idx[:cut] if mode == 'train' else shuffle_idx[cut:]
            if len(data) != len(label):
                raise ValueError(
                    f"found {len(data)} trials but {len(label)} labels under {path!r}")
            data = data[selected]
            label = label[selected]

        if len(data) != len(label):
            raise ValueError(
                f"found {len(data)} trials but {len(label)} labels under {path!r}")

        # unsqueeze(1) inserts a singleton dimension right after the trial axis.
        self.data = torch.FloatTensor(data).unsqueeze(1)
        self.label = torch.LongTensor(label)

    def __len__(self):
        """Return the number of trials in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]

def prep_dataloader(path, mode, batch_size, test_sub, isIntraSub = False, njobs=1):
    """Build the dataset for ``mode`` and wrap it in a ``DataLoader``.

    Args:
        path: Directory prefix handed to the dataset class.
        mode: Split name handed to the dataset class; batches are shuffled
            only when it equals ``'train'``.
        batch_size: Number of samples per batch.
        test_sub: Subject index handed to the dataset class.
        isIntraSub: Selects ``EEG_IntraSub_Dataset`` when true, otherwise
            ``EEG_InterSub_Dataset``.
        njobs: Number of DataLoader worker processes.

    Returns:
        The configured ``DataLoader`` (memory pinning on, no batch dropped).
    """
    if isIntraSub:
        protocol, dataset_cls = "IntraSub", EEG_IntraSub_Dataset
    else:
        protocol, dataset_cls = "InterSub", EEG_InterSub_Dataset
    print(protocol)
    dataset = dataset_cls(path, mode, test_sub)

    shuffle_batches = mode == 'train'
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle_batches,
        drop_last=False,
        num_workers=njobs,
        pin_memory=True,
    )


In [None]:

class PatchEmbedding(nn.Module):
    """Convolutional tokenizer that also prepends a learnable class token.

    Two convolutions (kernel (1, 51), then kernel (16, 5) with stride (1, 5))
    produce an ``emb_size``-channel feature map, which is flattened into a
    token sequence of shape ``(batch, tokens, emb_size)``.

    Args:
        emb_size: Embedding width of every token.
    """

    def __init__(self, emb_size):
        super().__init__()
        first_conv = nn.Conv2d(1, 2, (1, 51), (1, 1))
        norm = nn.BatchNorm2d(2)
        activation = nn.LeakyReLU(0.2)
        patch_conv = nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5))
        # Flatten the two spatial axes into one token axis, embedding last.
        to_tokens = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(first_conv, norm, activation, patch_conv, to_tokens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the token sequence for ``x`` with the class token at index 0."""
        batch_size = x.shape[0]
        tokens = self.projection(x)
        # One copy of the class token per sample in the batch.
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
        return torch.cat((cls_tokens, tokens), dim=1)

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, emb_size)`` sequence.

    Args:
        emb_size: Token embedding width; must be divisible by ``num_heads``
            for the head split to succeed.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Three separate emb_size -> emb_size projections: query, key, value.
        projections = []
        for _ in range(3):
            projections.append(nn.Linear(emb_size, emb_size))
        self.qkv = nn.ModuleList(projections)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: Input sequence.
            mask: Optional boolean tensor; positions where it is False are
                filled with ``-inf`` before the softmax.

        Returns:
            Tensor with the same layout as ``x``.
        """
        per_head = []
        for proj in self.qkv:
            per_head.append(rearrange(proj(x), "b n (h d) -> b h n d", h=self.num_heads))
        queries, keys, values = per_head

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            energy.masked_fill_(~mask, float('-inf'))

        # The scale uses the full embedding width, not the per-head width.
        scale = self.emb_size ** 0.5
        weights = self.att_drop(F.softmax(energy / scale, dim=-1))

        mixed = torch.einsum('bhal, bhlv -> bhav', weights, values)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.projection(merged)

class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x) = x + fn(x)``.

    Args:
        fn: Callable (typically a module) whose output is added to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``x`` plus ``fn(x, **kwargs)``."""
        branch_out = self.fn(x, **kwargs)
        return x + branch_out

class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output width.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm encoder block: residual self-attention, then residual MLP.

    Args:
        emb_size: Token embedding width.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used in both branches.
        forward_expansion: Hidden-width multiplier of the MLP.
    """

    def __init__(self, emb_size, num_heads=5, drop_p=0.5, forward_expansion=4):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        mlp_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, forward_expansion, drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(ResidualAdd(attention_branch), ResidualAdd(mlp_branch))

class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token embedding width given to every block.
    """

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Mean-pool the token axis, normalize, and project to class logits.

    Args:
        emb_size: Token embedding width.
        n_classes: Number of output logits.
    """

    def __init__(self, emb_size, n_classes):
        super().__init__()
        # The mean is taken over every token in the sequence.
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        to_logits = nn.Linear(emb_size, n_classes)
        self.clshead = nn.Sequential(pool_tokens, norm, to_logits)

    def forward(self, x):
        """Return the logits produced by the ``clshead`` pipeline."""
        logits = self.clshead(x)
        return logits

class ChannelAttention(nn.Module):
    """Attention across the channel axis of a ``(b, o, c, s)`` tensor.

    Query and key are produced by per-time-step mixers over the channel axis,
    average-pooled along the sequence axis, and compared to give a
    ``(c, c)`` score matrix per sample. The scores then re-weight the
    channels of the input, and the result goes through a final channel mixer.

    Args:
        sequence_num: Length of the sequence axis ``s``.
        inter: Pooling window and stride along the sequence axis.
        channels: Size of the channel axis ``c``.
    """

    def __init__(self, sequence_num=384, inter=30, channels=17):
        super().__init__()
        # Pooled sequence length; used only to scale the attention logits.
        self.extract_sequence = sequence_num // inter

        def channel_mixer():
            # Linear -> LayerNorm -> Dropout acting on the (last) channel axis.
            return nn.Sequential(
                nn.Linear(channels, channels),
                nn.LayerNorm(channels),
                nn.Dropout(0.3)
            )

        self.qk = nn.ModuleList([channel_mixer() for _ in range(2)])
        self.projection = channel_mixer()

        self.dropout = nn.Dropout(0)
        self.pooling = nn.AvgPool2d((1, inter), (1, inter))

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return the channel-attended tensor, same shape as ``x``."""
        # The mixers act on the last axis, so move channels last and back.
        channels_last = x.transpose(-1, -2)                      # b o s c
        query, key = [layer(channels_last).transpose(-1, -2)     # b o c s
                      for layer in self.qk]

        channel_query = self.pooling(query)
        channel_key = self.pooling(key)

        scaling = self.extract_sequence ** 0.5
        attention = torch.einsum('bocs,boms->bocm',
                                 channel_query, channel_key) / scaling

        scores = self.dropout(F.softmax(attention, dim=-1))
        # Each output channel c is the score-weighted sum over input channels m.
        # Contracting only the scores over m would sum a softmax row (== 1)
        # and return x unchanged, leaving query/key without any effect.
        out = torch.einsum('bocm,boms->bocs', scores, x)

        out = self.projection(out.transpose(-1, -2))
        return out.transpose(-1, -2)

class ViT(pl.LightningModule):
    """EEG transformer classifier wrapped as a LightningModule.

    Pipeline: residual channel attention -> convolutional patch embedding ->
    transformer encoder -> classification head.

    Args:
        emb_size: Token embedding width.
        depth: Number of transformer encoder blocks.
        n_classes: Number of output classes.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, emb_size=60, depth=3, n_classes=2, **kwargs):
        super(ViT, self).__init__()
        self.channel_attention = ResidualAdd(nn.Sequential(
            nn.LayerNorm(384),
            ChannelAttention(),
            nn.Dropout(0.5)
        ))
        self.patch_embedding = PatchEmbedding(emb_size)
        self.transformer = TransformerEncoder(depth, emb_size)
        self.classifier = ClassificationHead(emb_size, n_classes)

        # Per-batch targets/predictions collected during testing so that the
        # metrics are computed once over the whole test set.
        self._test_targets = []
        self._test_preds = []

        self._init_weights()

    def _init_weights(self):
        """Initialize convolution, linear and batch-norm layers.

        The model is built from ``Conv2d``/``BatchNorm2d`` layers, so the 2-D
        variants are matched alongside the 1-D ones; checking only the 1-D
        classes would never match any layer of this model.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        x = self.channel_attention(x)
        x = self.patch_embedding(x)
        x = self.transformer(x)
        x = self.classifier(x)

        return x

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr = 0.001)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch, batch_idx=None):
        """Compute and log the cross-entropy loss of one training batch."""
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        self.log('training_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx=None):
        """Log epoch-level validation loss and accuracy.

        ``val_acc`` is logged because the checkpoint filename template in
        ``config['save_name']`` formats it.
        """
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        acc = (torch.argmax(preds, dim=1) == y).float().mean()
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)

    def on_test_epoch_start(self):
        """Drop any targets/predictions left over from an earlier test run."""
        self._test_targets.clear()
        self._test_preds.clear()

    def test_step(self, batch, batch_idx=None):
        """Store this batch's targets and predicted classes for epoch-end metrics."""
        x, y = batch
        preds = self(x)

        # argmax of the logits equals argmax of their log-softmax.
        y_pre = torch.argmax(preds, dim=1)
        self._test_targets.append(y.detach().reshape(-1).cpu())
        self._test_preds.append(y_pre.detach().cpu())

    def on_test_epoch_end(self):
        """Compute and log accuracy/precision/recall/F1 over the full test set.

        Weighted precision, recall and F1 are not averages of their per-batch
        values, so they are evaluated once on all collected predictions.
        """
        if not self._test_targets:
            return
        y_true = torch.cat(self._test_targets).numpy()
        y_pred = torch.cat(self._test_preds).numpy()
        self._test_targets.clear()
        self._test_preds.clear()

        acc = accuracy_score(y_true, y_pred)
        pre = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

        self.log('test_acc', float(acc))
        self.log('test_pre', float(pre))
        self.log('test_recall', float(recall))
        self.log('test_f1', float(f1))


def predict(model, dataloader):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Predictions from all batches are gathered first and the metrics are
    computed once over the complete set, so the result describes the whole
    loader rather than only its final batch.

    Args:
        model: Module returning class logits of shape ``(batch, classes)``.
        dataloader: Iterable yielding ``(x, y)`` batches.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and F1
        use weighted averaging.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    # Run the inputs on whatever device holds the model's parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    all_true = []
    all_pred = []
    with torch.no_grad():
        for x, y in dataloader:
            if device is not None:
                x = x.to(device)
            preds = model(x)
            # argmax of the logits equals argmax of their log-softmax.
            all_pred.append(torch.argmax(preds, dim=1).cpu())
            all_true.append(y.reshape(-1).cpu())

    if not all_true:
        raise ValueError("dataloader yielded no batches")

    y_true = torch.cat(all_true).numpy()
    y_pred = torch.cat(all_pred).numpy()
    acc = accuracy_score(y_true, y_pred)
    pre = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')

    return acc, pre, recall, f1
       
# Checkpointing policy: keep the single checkpoint with the lowest logged
# 'val_loss', plus the most recent one; files are named from config['save_name'].
checkpoint_callback = ModelCheckpoint(
    filename=config['save_name'],
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True,
)

if __name__ == '__main__':
    # Train and test one model per subject index, then report the mean of the
    # four test metrics over all folds.
    tr_path = val_path = test_path =  "/home/jie/Program/872/Dataset/SEED-VIG-Subset/"
    device = get_device()
    isIntraSub=True

    # Use the GPU only when one is available instead of always requesting it.
    use_gpu = device == 'cuda'

    AC,PR,RE,F1 = 0,0,0,0
    for test_sub in range(config['subjects_num']):
        tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], test_sub, isIntraSub, njobs=6)
        val_set = prep_dataloader(val_path, 'val', config['batch_size'], test_sub, isIntraSub, njobs=6)
        test_set = prep_dataloader(test_path, 'test', config['batch_size'], test_sub, isIntraSub, njobs=1)
        model =  ViT().to(device)
        logger = TensorBoardLogger(config['log_path1'])

        # A fresh callback per fold: a ModelCheckpoint keeps its best score and
        # best path, so sharing one instance would carry the previous
        # subject's best 'val_loss' into this run.
        fold_checkpoint = ModelCheckpoint(
            monitor='val_loss',
            filename=config['save_name'],
            save_top_k=1,
            mode='min',
            save_last=True
        )
        trainer = Trainer(val_check_interval=1.0, max_epochs=config['n_epochs'],
                        devices=[0] if use_gpu else 1,
                        accelerator='gpu' if use_gpu else 'cpu',
                        logger=logger,
                        callbacks=[fold_checkpoint]
                        )

        trainer.fit(model, train_dataloaders=tr_set, val_dataloaders=val_set)
        test_results = trainer.test(model, dataloaders=test_set)

        AC += test_results[0]['test_acc']
        PR += test_results[0]['test_pre']
        RE += test_results[0]['test_recall']
        F1 += test_results[0]['test_f1']

    AC /= config['subjects_num']
    PR /= config['subjects_num']
    RE /= config['subjects_num']
    F1 /= config['subjects_num']
    print(f"&{AC*100:.2f}",f"&{PR*100:.2f}",f"&{RE*100:.2f}",f"&{F1*100:.2f}")
        