# Preparation

## Requirements

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split, Subset, ConcatDataset
import torchaudio
from torchaudio import transforms
import math, copy, time, random, os
import numpy as np
import scipy.io as sio
import scipy.sparse as sp
import pandas as pd
import itertools
import matplotlib.pyplot as plt
from einops import rearrange
from tqdm import tqdm
import shutil
import re
from scipy.io import loadmat
from sklearn.model_selection import KFold

## Config

In [2]:
class Config:
    """Experiment-wide settings, read directly as class attributes.

    The class itself (not an instance) is passed around as ``config``; the
    driver code overwrites ``file_name``, ``feature_name`` and
    ``current_fold`` while iterating over subjects and folds.
    """

    # --- data locations ---------------------------------------------------
    root = '/home/test/Desktop/python/EEG_data/AAD_dataset/AAD_DTU/Processed/Dataset'
    root_feature = '/home/test/Desktop/python/2023/dr/Auditory_Attention_Detection/Features/DTU_1s'
    file_name = 'S1_Dataset_1s.npz'
    feature_name = 'S1_Features.npz'

    # --- protocol ---------------------------------------------------------
    mode = 'cross_trails'
    current_fold = 5
    seed = 3407

    # --- signal geometry --------------------------------------------------
    tim_len = 1          # presumably the window length in seconds; used as a length multiplier
    chan_num = 64
    band_num = 5
    inc = 32
    eeg_len = 128        # presumably EEG samples per second of window — confirm
    audio_len = 16000    # presumably audio samples per second of window — confirm

    # --- optimisation -----------------------------------------------------
    batch_size = 8
    epochs = 30
    lr = 1e-4
    weight_decay = 1e-4
    patience = 2
    factor = 0.8
    dropout = 0.1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Utils

In [3]:
def print_size(net, keyword=None):
    """Print how many trainable parameters ``net`` has, in millions.

    Prints nothing unless ``net`` is a ``torch.nn.Module``. If ``keyword`` is
    given, a second figure restricted to trainable parameters whose name
    contains ``keyword`` is printed on the same line.
    """
    if not isinstance(net, torch.nn.Module):
        return

    def _count(tensors):
        # Total element count of the given tensors.
        return sum([np.prod(t.size()) for t in tensors])

    trainable = [p for p in net.parameters() if p.requires_grad]
    total = _count(trainable)
    print("{} Parameters: {:.6f}M".format(type(net).__name__, total / 1e6),
          flush=True, end="; ")

    if keyword is not None:
        matching = [p for name, p in net.named_parameters()
                    if p.requires_grad and keyword in name]
        print("{} Parameters: {:.6f}M".format(keyword, _count(matching) / 1e6),
              flush=True, end="; ")

    print(" ")

In [4]:
class AvgMeter:
    """Count-weighted running mean of a scalar metric (e.g. a per-batch loss)."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running sum, sample count and average."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in ``val`` weighted by ``count`` samples and refresh ``avg``."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)

def get_lr(optimizer):
    """Return the ``lr`` of the optimizer's first parameter group, or None if it has none."""
    return next((group["lr"] for group in optimizer.param_groups), None)

In [5]:
def pre_evaluate(dataloader, model, device=None):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Switches ``model`` to eval mode (it is not switched back) and runs without
    gradient tracking. Each batch is ``(eeg, eeg_feature, wavA, wavB, event)``.
    The predicted class is the argmax of the model output along dim 1. It is
    compared with ``event`` after all singleton dimensions are squeezed out.

    Args:
        dataloader: iterable of batches that also exposes ``.dataset``.
        model: module called as ``model(eeg, eeg_feature, wavA, wavB)``.
        device: device each batch is moved to; defaults to ``Config.device``.

    Returns:
        Fraction of correctly classified samples in ``dataloader.dataset``,
        or 0.0 for an empty dataset (previously a ZeroDivisionError).
    """
    if device is None:
        device = Config.device
    model.eval()
    total = len(dataloader.dataset)
    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0
    correct = 0
    with torch.no_grad():
        for eeg, eeg_feature, wavA, wavB, event in dataloader:
            eeg = eeg.to(device)
            eeg_feature = eeg_feature.to(device)
            wavA = wavA.to(device)
            wavB = wavB.to(device)
            event = event.to(device).squeeze()
            pred = model(eeg, eeg_feature, wavA, wavB)
            _, predicted = torch.max(pred, 1)
            correct += (predicted == event).sum().item()
    return correct / total

# Model

## Temporal block

In [6]:
class down_sample(nn.Module):
    """Strided ``(1, kernel_size)`` Conv2d -> BatchNorm2d -> ELU that shortens the last axis.

    The channel count ``inc`` is preserved. The last dimension shrinks by the
    usual convolution arithmetic for ``kernel_size``/``stride``/``padding``.
    """

    def __init__(self, inc, kernel_size, stride, padding):
        super(down_sample, self).__init__()
        conv_kwargs = dict(
            kernel_size=(1, kernel_size),
            stride=(1, stride),
            padding=(0, padding),
            bias=False,
        )
        self.conv = nn.Conv2d(inc, inc, **conv_kwargs)
        self.bn = nn.BatchNorm2d(inc)
        self.elu = nn.ELU(inplace=False)
        self.initialize()

    def initialize(self):
        """Xavier-uniform conv weights; BatchNorm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight, gain=1)

    def forward(self, x):
        features = self.conv(x)
        features = self.bn(features)
        return self.elu(features)

In [7]:
class Residual_Block(nn.Module): 
    """Two ``(1, 3)`` Conv2d+BatchNorm layers with an additive skip connection.

    If ``inc`` differs from ``outc``, the skip path is a 1x1 convolution
    (``conv_expand``) so channel counts match. Otherwise the input is added
    unchanged. Spatial size is preserved (padding ``(0, 1)``).
    """

    def __init__(self, inc, outc, groups = 1):
        super(Residual_Block, self).__init__()
        # Compare channel counts by value. The original ``inc is not outc`` was
        # an identity test that only worked through CPython's small-int cache.
        # It added a spurious 1x1 projection for equal counts held in distinct
        # int objects (e.g. anything above 256).
        if inc != outc:
            self.conv_expand = nn.Conv2d(in_channels = inc, out_channels = outc, kernel_size = 1, 
                                       stride = 1, padding = 0, groups = groups, bias = False)
        else:
            self.conv_expand = None
          
        self.conv1 = nn.Conv2d(in_channels = inc, out_channels = outc, kernel_size = (1, 3), 
                               stride = 1, padding = (0, 1), groups = groups, bias = False)
        self.bn1 = nn.BatchNorm2d(outc)
        self.conv2 = nn.Conv2d(in_channels = outc, out_channels = outc, kernel_size = (1, 3), 
                               stride = 1, padding = (0, 1), groups = groups, bias = False)
        self.bn2 = nn.BatchNorm2d(outc)
        # NOTE(review): ``elu`` is created but never applied in ``forward`` (no
        # activation between the convs or after the sum). It is left as-is
        # because inserting it would change the outputs of already-trained models.
        self.elu = nn.ELU(inplace = False)
        self.initialize()

    def initialize(self):
        """Xavier-uniform conv weights; BatchNorm scale 1 and shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain = 1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def forward(self, x): 
        # Skip path: projected when channel counts differ, identity otherwise.
        if self.conv_expand is not None:
            identity_data = self.conv_expand(x)
        else:
            identity_data = x
        output = self.bn1(self.conv1(x))
        output = self.bn2(self.conv2(output))
        output = torch.add(output,identity_data)
        return output 

In [8]:
class input_layer(nn.Module):
    """Lift a single-channel map to ``outc`` channels with a ``(1, 3)`` Conv2d + BatchNorm2d.

    Spatial size is preserved (padding ``(0, 1)``). No activation is applied
    in ``forward``.
    """

    def __init__(self, outc, groups):
        super(input_layer, self).__init__()
        self.conv_input = nn.Conv2d(1, outc, (1, 3), stride=1, padding=(0, 1),
                                    groups=groups, bias=False)
        self.bn_input = nn.BatchNorm2d(outc)
        # Registered but not called in forward.
        self.elu = nn.ELU(inplace=False)
        self.initialize()

    def initialize(self):
        """Xavier-uniform conv weights; BatchNorm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight, gain=1)

    def forward(self, x):
        lifted = self.conv_input(x)
        return self.bn_input(lifted)

In [9]:
def embedding_network(input_block, Residual_Block, num_of_layer, outc, groups = 1):
    """Stack ``input_block`` and ``num_of_layer`` channel-doubling residual blocks.

    Residual block ``i`` (0-based) maps ``outc * 2**i`` to ``outc * 2**(i + 1)``
    channels, so the stack ends at ``outc * 2**num_of_layer`` channels.
    Returns an ``nn.Sequential``.
    """
    blocks = [input_block(outc, groups=groups)]
    blocks.extend(
        Residual_Block(inc=int(math.pow(2, depth) * outc),
                       outc=int(math.pow(2, depth + 1) * outc),
                       groups=groups)
        for depth in range(num_of_layer)
    )
    return nn.Sequential(*blocks)

In [10]:
class Multi_Scale_Temporal_Block(nn.Module):
    """Embed a one-channel input, re-append it, and pool it at five temporal scales.

    The residual embedding produces ``outc * 2**num_of_layer`` channels. The
    raw input is concatenated as one extra channel along dim 1. Five strided
    ``down_sample`` branches shorten the last axis, and their outputs are
    concatenated along that axis (dim 3).
    """

    # (kernel_size, stride, padding) of downsampled1..downsampled5, in order.
    # NOTE(review): branches 4 and 5 use the same kernel/stride (32); confirm
    # whether a different fifth scale was intended. For a 128-sample input
    # the concatenated length is 32 + 16 + 8 + 4 + 4 = 64, which later layers
    # presumably rely on.
    _SCALES = ((4, 4, 0), (8, 8, 0), (16, 16, 0), (32, 32, 0), (32, 32, 0))

    def __init__(self, outc, num_of_layer = 1):
        super().__init__()
        self.num_of_layer = num_of_layer
        self.embedding = embedding_network(input_layer, Residual_Block, num_of_layer = num_of_layer, outc = outc, groups=1)
        branch_channels = outc*int(math.pow(2, num_of_layer))+1
        for idx, (kernel, stride, pad) in enumerate(self._SCALES, start=1):
            # Attribute names downsampled1..5 are kept so checkpoints still load.
            setattr(self, f"downsampled{idx}", down_sample(branch_channels, kernel, stride, pad))

    def forward(self, x):
        stacked = torch.cat((self.embedding(x), x), 1)
        pooled = [getattr(self, f"downsampled{idx}")(stacked)
                  for idx in range(1, len(self._SCALES) + 1)]
        return torch.concat(pooled, 3)

In [11]:
class Temporal_Block(nn.Module):
    """Apply an independent ``Multi_Scale_Temporal_Block`` to each of five input slices.

    Slice ``i`` is ``x[:, i, :, :]`` given a singleton channel axis. The five
    outputs are concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()
        for band in range(1, 6):
            setattr(self, f"mstblock{band}", Multi_Scale_Temporal_Block(outc=2))
        # NOTE(review): ``fc`` is never used in forward. It is kept so the
        # state_dict keys and parameter-initialisation order are unchanged.
        self.fc = nn.Linear(640,256)

    def forward(self,x):
        per_band = [
            getattr(self, f"mstblock{band + 1}")(x[:, band, :, :].unsqueeze(1))
            for band in range(5)
        ]
        return torch.cat(per_band, 1)

## GCN

In [12]:
class Electrodes:
    """Fixed 64-electrode montage: 3-D positions, channel names, and adjacency builders.

    ``positions_3d`` is written as floats (presumably metres) and stored
    multiplied by 1000 and truncated to integers (presumably millimetres).
    ``channel_names[i]`` labels row ``i`` of ``positions_3d``.
    """

    def __init__(self):
        self.positions_3d = np.array([
            [-0.029338731, 0.09029533, -0.003315452],
            [-0.055805583, 0.076809795, -0.003315452],
            [-0.038593441, 0.082763901, 0.026185549],
            [-0.027261703, 0.067475084, 0.061064823],
            [-0.051775707, 0.063937674, 0.0475],
            [-0.06925438, 0.060201914, 0.024587809],
            [-0.076809795, 0.055805583, -0.003315452],
            [-0.09029533, 0.029338731, -0.003315452],
            [-0.084349336, 0.032378676, 0.029356614],
            [-0.064255824, 0.034165428, 0.061064823],
            [-0.035597403, 0.035597403, 0.080564569],
            [-0.037119457, 0, 0.087447961],
            [-0.068337281, 0, 0.065992545],
            [-0.088690141, 0, 0.034044955],
            [-0.094942129, 0, -0.003315452],
            [-0.09029533, -0.029338731, -0.003315452],
            [-0.084349336, -0.032378676, 0.029356614],
            [-0.064255824, -0.034165428, 0.061064823],
            [-0.035597403, -0.035597403, 0.080564569],
            [-0.027261703, -0.067475084, 0.061064823],
            [-0.051775707, -0.063937674, 0.0475],
            [-0.06925438, -0.060201914, 0.024587809],
            [-0.076809795, -0.055805583, -0.003315452],
            [-0.069655748, -0.050607863, -0.040148735],
            [-0.055805583, -0.076809795, -0.003315452],
            [-0.038593441, -0.082763901, 0.026185549],
            [-0.029338731, -0.09029533, -0.003315452],
            [5.27206E-18, -0.08609924, -0.040148735],
            [5.81353E-18, -0.094942129, -0.003315452],
            [5.4307E-18, -0.088690141, 0.034044955],
            [4.18445E-18, -0.068337281, 0.065992545],
            [2.27291E-18, -0.037119457, 0.087447961],
            [5.81353E-18, 0.094942129, -0.003315452],
            [0.029338731, 0.09029533, -0.003315452],
            [0.055805583, 0.076809795, -0.003315452],
            [0.038593441, 0.082763901, 0.026185549],
            [5.4307E-18, 0.088690141, 0.034044955],
            [4.18445E-18, 0.068337281, 0.065992545],
            [0.027261703, 0.067475084, 0.061064823],
            [0.051775707, 0.063937674, 0.0475],
            [0.06925438, 0.060201914, 0.024587809],
            [0.076809795, 0.055805583, -0.003315452],
            [0.09029533, 0.029338731, -0.003315452],
            [0.084349336, 0.032378676, 0.029356614],
            [0.064255824, 0.034165428, 0.061064823],
            [0.035597403, 0.035597403, 0.080564569],
            [2.27291E-18, 0.037119457, 0.087447961],
            [0, 0, 0.095],
            [0.037119457, 0, 0.087447961],
            [0.068337281, 0, 0.065992545],
            [0.088690141, 0, 0.034044955],
            [0.094942129, 0, -0.003315452],
            [0.09029533, -0.029338731, -0.003315452],
            [0.084349336, -0.032378676, 0.029356614],
            [0.064255824, -0.034165428, 0.061064823],
            [0.035597403, -0.035597403, 0.080564569],
            [0.027261703, -0.067475084, 0.061064823],
            [0.051775707, -0.063937674, 0.0475],
            [0.06925438, -0.060201914, 0.024587809],
            [0.076809795, -0.055805583, -0.003315452],
            [0.069655748, -0.050607863, -0.040148735],
            [0.055805583, -0.076809795, -0.003315452],
            [0.038593441, -0.082763901, 0.026185549],
            [0.029338731, -0.09029533, -0.003315452]])
        # Scale by 1000 and truncate to integers.
        self.positions_3d = np.int_(self.positions_3d * 1000)
        self.channel_names = np.array([
            'Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7',
            'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz',
            'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz',
            'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'])
        self.channel_to_index = {name: idx for idx, name in enumerate(self.channel_names)}
        # Pick the device at runtime. The original hard-coded 'cuda', which made
        # construction crash on CPU-only machines even though the rest of the
        # code falls back to the CPU. On a CUDA machine the result is unchanged.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        num_channels = len(self.channel_names)
        # NOTE(review): Electrodes is a plain class, not an nn.Module, so this
        # Parameter is not registered with any module that stores an Electrodes
        # instance. It is absent from model.parameters() (so never stepped by
        # the optimizer) and from state_dict(), and it is not moved by .to().
        # Its .grad is also never cleared by optimizer.zero_grad(). Confirm
        # this is intended.
        self.edge_importance = nn.Parameter(
            torch.eye(num_channels, device=device) * 0.1
            + torch.randn(num_channels, num_channels, device=device) * 0.01,
            requires_grad=True,
        )

    def get_adjacency_matrix(self, calibration_constant=6, active_threshold=0.1):
        """Return a numpy adjacency built from inverse inter-electrode distances.

        Steps:
          1. Each entry is ``calibration_constant / distance`` (0 for zero
             distance), using the integer-scaled positions.
          2. Entries not greater than ``active_threshold`` are zeroed.
          3. The diagonal is zeroed and the matrix is min-max scaled over all
             entries.
          4. The diagonal is set to 1.

        The result is symmetric with values in [0, 1].
        """
        distance_3d_matrix = np.linalg.norm(self.positions_3d[:, np.newaxis] - self.positions_3d, axis=-1)
        with np.errstate(divide='ignore', invalid='ignore'):
            distance_3d_matrix = np.where(distance_3d_matrix != 0, calibration_constant / distance_3d_matrix, 0)

        local_conn_mask = distance_3d_matrix > active_threshold
        local_connections = distance_3d_matrix * local_conn_mask

        np.fill_diagonal(local_connections, 0)
        min_conn = local_connections.min()
        max_conn = local_connections.max()
        if max_conn > min_conn:
            adj_matrix = (local_connections - min_conn) / (max_conn - min_conn)
        else:
            adj_matrix = np.zeros_like(local_connections)
        np.fill_diagonal(adj_matrix, 1)

        return adj_matrix

    def get_adj(self):
        """Return the symmetrically normalised adjacency ``D^-1/2 (A + I) D^-1/2``.

        ``A`` is the distance-based adjacency plus ``edge_importance``,
        averaged with its transpose. ``D`` holds the row sums of ``A + I``.
        Zero row sums give zero weights instead of inf.
        """
        base_adj_matrix = self.get_adjacency_matrix()
        A = torch.tensor(base_adj_matrix, dtype=torch.float32, device=self.edge_importance.device) + self.edge_importance
        A = (A + A.t()) / 2.0
        I = torch.eye(A.size(-1), device=A.device)
        A = A + I
        rowsum = A.sum(dim=-1, keepdim=True)
        d_inv_sqrt = rowsum.pow(-0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0 
        # Broadcasting the column and row vectors scales A[i, j] by d_i * d_j.
        A = d_inv_sqrt * A * d_inv_sqrt.transpose(-1, -2)

        return A

In [13]:
class GATENet(nn.Module):
    """Bottleneck MLP gate: Linear -> ELU -> Linear -> Tanh -> ReLU, outputs in [0, 1).

    The hidden width is ``inc // reduction_ratio``; both linear layers have no bias.
    """

    def __init__(self, inc, reduction_ratio=128):
        super(GATENet, self).__init__()
        hidden = inc // reduction_ratio
        layers = [
            nn.Linear(inc, hidden, bias=False),
            nn.ELU(inplace=False),
            nn.Linear(hidden, inc, bias=False),
            nn.Tanh(),
            nn.ReLU(inplace=False),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)


class resGCN(nn.Module):
    """Grouped conv feature transform, graph propagation over the last axis, and a residual.

    ``forward(x, x_p, L)`` proceeds as follows:
      1. Applies conv(1x3, no padding) -> BN -> ELU -> conv(1x1, padding
         (0, 1)) -> BN to ``x``. The first conv trims two columns and the
         padded 1x1 conv restores them, so the last axis keeps its length.
      2. Multiplies that axis by ``L`` via ``einsum('bijk,kp->bijp')``.
      3. Adds ``x_p`` and applies ELU.
    """

    def __init__(self, inc, outc,band_num):
        super(resGCN, self).__init__()
        shared = dict(stride=(1, 1), groups=band_num, bias=False)
        self.GConv1 = nn.Conv2d(inc, outc, kernel_size=(1, 3), padding=(0, 0), **shared)
        self.bn1 = nn.BatchNorm2d(outc)
        self.GConv2 = nn.Conv2d(outc, outc, kernel_size=(1, 1), padding=(0, 1), **shared)
        self.bn2 = nn.BatchNorm2d(outc)
        self.ELU = nn.ELU(inplace=False)
        self.initialize()

    def initialize(self):
        """Xavier-uniform conv weights; BatchNorm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight, gain=1)

    def forward(self, x, x_p, L):
        hidden = self.ELU(self.bn1(self.GConv1(x)))
        hidden = self.bn2(self.GConv2(hidden))
        propagated = torch.einsum('bijk,kp->bijp', hidden, L)
        return self.ELU(propagated + x_p)



class HGCN(nn.Module):
    """Electrode-graph convolution wrapping a single ``resGCN``.

    ``forward(x, A_ds)`` proceeds as follows:
      1. Forms ``A = A_ds + electrodes.edge_importance`` and scales each column
         ``k`` by ``1 / A[k].sum()`` (``L = A @ diag(1 / rowsum(A))``).
      2. Swaps the last two axes of ``x`` and runs ``resGCN(x, x, L)``.
      3. Squeezes axis 2 (only if it is singleton) and swaps axes 1 and 2.

    ``chan_num``, ``dim`` and ``ELU`` are stored but not used by ``forward``.
    NOTE(review): ``edge_importance`` lives on an ``Electrodes`` object, which
    is not an nn.Module, so it is not part of this module's parameters or
    state_dict. It is presumably never updated by the optimizer; confirm.
    """

    def __init__(self, dim, chan_num, band_num):
        super(HGCN, self).__init__()
        self.chan_num = chan_num
        self.dim = dim
        self.electrodes = Electrodes()
        self.resGCN = resGCN(inc=dim * band_num, outc=dim * band_num, band_num=band_num)
        self.ELU = nn.ELU(inplace=False)
        self.initialize()

    def initialize(self):
        """Xavier-uniform conv weights; BatchNorm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight, gain=1)

    def forward(self, x, A_ds):
        A = torch.tensor(A_ds, dtype=torch.float32).to(x.device) + self.electrodes.edge_importance
        inv_row_sum = torch.reciprocal(A.sum(dim=-1))
        L = torch.einsum('ik,kp->ip', A, torch.diag(inv_row_sum))
        swapped = x.permute(0, 1, 3, 2)
        graph_out = self.resGCN(swapped, swapped, L).contiguous()
        return graph_out.squeeze(2).transpose(2, 1)

In [14]:
class SGCN(nn.Module):
    """Apply ``HGCN`` with the fixed, position-derived electrode adjacency.

    Args:
        dim: per-band feature width forwarded to ``HGCN``.
        device: accepted for interface compatibility; not used.
    """

    def __init__(self, dim, device=Config.device):
        super().__init__()
        self.electrodes = Electrodes()
        # NOTE(review): GATENet is constructed (its weights appear in the
        # state_dict and the optimizer) but is never called in forward.
        self.GATENet = GATENet(Config.chan_num * Config.chan_num, reduction_ratio=Config.chan_num)
        self.gcn = HGCN(dim=dim, chan_num=Config.chan_num, band_num=Config.band_num)
        # The adjacency depends only on the fixed electrode positions, so it is
        # computed once here rather than on every forward pass. It is kept as a
        # plain numpy attribute (not a buffer) so state_dict keys are unchanged.
        self._A_ds = self.electrodes.get_adjacency_matrix()

    def forward(self, x):
        return self.gcn(x, self._A_ds)

## EEGEncoder

In [15]:
class EEGEncoder(nn.Module):
    """Fuse graph-convolved temporal EEG features with two precomputed feature groups.

    ``forward(filtered, prefeat)`` produces three parts:
      * ``filtered`` is replicated 5x along a new dim 1, then passed through
        ``Temporal_Block``, the temporal ``SGCN``, a 1x1 channel conv and
        ``proj1``.
      * ``prefeat[:, :, :5]`` goes through ``dgcn`` then ``proj3``.
      * ``prefeat[:, :, 5:]`` goes through ``pgcn`` then ``proj2``.
    The three results are concatenated along dim 2.
    Assumes ``prefeat`` is (batch, channels, features) — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        window_chans = int(Config.chan_num*Config.tim_len)
        half_chans = int(Config.chan_num*Config.tim_len*0.5)
        self.t_block = Temporal_Block()
        self.tgcn = SGCN(dim=5)
        self.dgcn = SGCN(dim=1)
        self.pgcn = SGCN(dim=1)

        self.chanattn = nn.Conv2d(in_channels=window_chans, out_channels=Config.inc, kernel_size=1)

        self.proj1 = nn.Linear(Config.inc*25, window_chans)
        self.proj2 = nn.Linear(5, half_chans)
        self.proj3 = nn.Linear(5, half_chans)
        # NOTE(review): defined but never applied in forward.
        self.dpout = nn.Dropout(p=0.5)

    def forward(self, filtered, prefeat):
        batch, chans = prefeat.shape[0], prefeat.shape[1]

        bands = filtered.unsqueeze(1).expand(-1, 5, -1, -1)
        temporal = self.tgcn(self.t_block(bands))
        temporal = self.chanattn(temporal).permute(0, 3, 2, 1).contiguous()
        temporal = self.proj1(temporal.view(batch, chans, -1))

        first_group = self.dgcn(prefeat[:, :, :5].permute(0, 2, 1).unsqueeze(3))
        second_group = self.pgcn(prefeat[:, :, 5:].permute(0, 2, 1).unsqueeze(3))

        return torch.cat((temporal, self.proj3(first_group), self.proj2(second_group)), dim=2)

## AudioEncoder

In [16]:
class AudioEncoder(nn.Module):
    """Three Conv1d(k=3, stride=5, pad=1) -> BatchNorm1d -> ReLU stages.

    Channels go ``in_channels -> 16 -> 32 -> out_channels``. Each stage
    shortens the time axis about 5x (16000 -> 3200 -> 640 -> 128 for a
    16000-sample input). ``downsample_factor`` is stored but not used by
    ``forward``.
    """

    def __init__(self, in_channels=1, out_channels=Config.chan_num, input_length=int(Config.audio_len*Config.tim_len), target_length=64):
        super(AudioEncoder, self).__init__()
        self.downsample_factor = input_length // target_length
        widths = (in_channels, 16, 32, out_channels)
        for stage in range(1, 4):
            c_in, c_out = widths[stage - 1], widths[stage]
            # Names conv1/bn1/relu1 ... are kept so checkpoints still load.
            setattr(self, f"conv{stage}", nn.Conv1d(c_in, c_out, kernel_size=3, stride=5, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(c_out))
            setattr(self, f"relu{stage}", nn.ReLU(inplace=True))

    def forward(self, x):
        for stage in (1, 2, 3):
            x = getattr(self, f"conv{stage}")(x)
            x = getattr(self, f"bn{stage}")(x)
            x = getattr(self, f"relu{stage}")(x)
        return x

## Fusion

In [17]:
class ScaledDotProductAttention(nn.Module):
    """``softmax(q k^T / temperature) v`` with optional masking and dropout on the weights.

    ``k`` is transposed over dims 2 and 3, so inputs are expected to be 4-D.
    Scores where ``mask == 0`` are set to -1e9 before the softmax. Returns
    ``(output, attention_weights)``.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        scaled_q = q / self.temperature
        scores = torch.matmul(scaled_q, k.transpose(2, 3))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v), weights

class ConvCrossAttention(nn.Module):
    """Multi-head cross-attention over channel-first sequences ``(B, in_ch, L)``.

    Despite the name, no convolution is used. ``d_model``, ``d_k``, ``d_v``,
    ``kernel_size`` and ``dilation`` are accepted but ignored; the head width
    is ``in_ch // n_head``. The output is
    ``GroupNorm(dropout(out_proj(attention)) + v)``, returned channel-first.
    Because the residual is the value input ``v``, ``q`` and ``v`` must have
    the same length.
    """

    def __init__(self, n_head, d_model, d_k, d_v, in_ch,
                 kernel_size, dilation, dropout=0.1):
        super(ConvCrossAttention, self).__init__()
        head_dim = in_ch // n_head
        inner = n_head * head_dim
        self.n_head = n_head
        self.in_ch = in_ch
        self.d_k = head_dim
        self.d_v = head_dim

        self.q_proj = nn.Linear(in_ch, inner)
        self.k_proj = nn.Linear(in_ch, inner)
        self.v_proj = nn.Linear(in_ch, inner)
        self.out_proj = nn.Linear(inner, in_ch)

        self.attention = ScaledDotProductAttention(temperature=head_dim ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.GroupNorm(1, in_ch, eps=1e-8)

    def forward(self, q, k, v, mask=None):
        residual = v
        batch = q.size(0)
        heads, d_k, d_v = self.n_head, self.d_k, self.d_v

        def _to_heads(seq, proj, width):
            # (B, C, L) -> (B, L, C) -> projected -> (B, heads, L, width)
            tokens = seq.transpose(1, 2)
            return proj(tokens).view(batch, tokens.size(1), heads, width).transpose(1, 2)

        q_heads = _to_heads(q, self.q_proj, d_k)
        k_heads = _to_heads(k, self.k_proj, d_k)
        v_heads = _to_heads(v, self.v_proj, d_v)

        if mask is not None:
            mask = mask.unsqueeze(1)
        attended, _ = self.attention(q_heads, k_heads, v_heads, mask=mask)

        len_q = q_heads.size(2)
        merged = attended.transpose(1, 2).contiguous().view(batch, len_q, heads * d_v)
        projected = self.out_proj(merged).transpose(1, 2)
        return self.layer_norm(self.dropout(projected) + residual)


class MultiLayerCrossAttention(nn.Module):
    """Stack of ``layer`` cross-attention blocks that attend EEG features to audio.

    Each layer computes ``norm_i(attn_i(h, audio, audio) + h)``, where ``h``
    is the previous layer's output (initially ``spike``). All layer outputs
    are summed. That sum is concatenated with the original ``spike`` along
    channels, projected back to ``in_ch`` channels with a same-padded Conv1d,
    and GroupNorm-ed.
    """

    def __init__(self, input_size=int(Config.eeg_len * Config.tim_len),
                 layer=4, in_ch=Config.chan_num, kernel_size=3, dilation=1):
        super(MultiLayerCrossAttention, self).__init__()
        self.layer = layer
        self.in_ch = in_ch
        self.spike_encoder = nn.ModuleList()
        self.LayernormList_spike = nn.ModuleList()
        self.projection = nn.Conv1d(in_ch * 2, in_ch, kernel_size, padding='same')
        self.layernorm_out = nn.GroupNorm(1, in_ch, eps=1e-8)
        self.LayernormList_spike.extend(
            nn.GroupNorm(1, in_ch, eps=1e-8) for _ in range(layer))
        self.spike_encoder.extend(
            ConvCrossAttention(
                n_head=1,
                d_model=input_size,
                d_k=input_size,
                d_v=input_size,
                in_ch=in_ch,
                kernel_size=kernel_size,
                dilation=dilation,
            )
            for _ in range(layer)
        )

    def forward(self, spike, audio):
        hidden = spike
        layer_sum = torch.zeros_like(spike)
        for idx in range(self.layer):
            encoder = self.spike_encoder[idx]
            norm = self.LayernormList_spike[idx]
            # Residual is the previous layer's (normalised) output.
            hidden = norm(encoder(hidden, audio, audio) + hidden)
            layer_sum = layer_sum + hidden

        fused = self.projection(torch.cat((layer_sum, spike), dim=1))
        return self.layernorm_out(fused)

## Classifier

In [18]:
class Classifier(nn.Module):
    """Two-stage MLP head producing 2 logits per sample.

    ``input_stack`` maps the last axis (``2 * eeg_len * tim_len`` wide) to 64
    features. The result is flattened per sample and ``output_stack`` maps
    4096 features to 2 logits. For a 3-D input this requires 64 rows, i.e.
    ``(batch, 64, 2 * eeg_len * tim_len)``.
    """

    def __init__(self):
        super(Classifier, self).__init__()
        in_features = int(Config.eeg_len*Config.tim_len*2)
        self.input_stack = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )
        self.output_stack = nn.Sequential(
            nn.Linear(4096, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )

    def forward(self, x):
        per_row = self.input_stack(x)
        flat = per_row.view(per_row.shape[0], -1)
        return self.output_stack(flat)

## GCANet

In [19]:
class GCANet(nn.Module):
    """Auditory-attention classifier: encode EEG and both audio streams, fuse, classify.

    The same ``audio_encoder`` and ``fusion`` weights are shared between
    streams A and B. The two fused representations are concatenated along
    the last axis before ``classifier``.
    """

    def __init__(self):
        super(GCANet, self).__init__()
        self.eeg_encoder = EEGEncoder()
        self.audio_encoder = AudioEncoder()
        self.fusion = MultiLayerCrossAttention()
        self.classifier = Classifier()

    def forward(self, eeg, eeg_feature, waveA, waveB):
        eeg_repr = self.eeg_encoder(eeg, eeg_feature)
        audio_feats = [self.audio_encoder(wave) for wave in (waveA, waveB)]
        aligned = [self.fusion(eeg_repr, feat) for feat in audio_feats]
        return self.classifier(torch.cat(aligned, dim=-1))

# Dataloader

## Seed

In [20]:
def set_seed(seed, deterministic=True):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: integer seed applied to every generator.
        deterministic: value for ``torch.backends.cudnn.deterministic``. It
            now defaults to True. The previous hard-coded False let cuDNN
            choose non-deterministic kernels, which defeats the purpose of
            seeding. Pass ``False`` to restore the old (faster,
            non-reproducible) behaviour.

    ``torch.backends.cudnn.benchmark`` is always disabled, since autotuning
    can select different algorithms from run to run.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = False
    
def collate_fn(item):
    """Stack a list of ``(eeg, eeg_feature, wavA, wavB, event)`` samples field by field.

    Returns a 5-tuple of tensors, each with a new leading batch dimension.
    Raises ValueError if the samples do not have exactly five fields.
    """
    stacked = [torch.stack(column) for column in zip(*item)]
    eeg, eeg_feature, wavA, wavB, event = stacked
    return eeg, eeg_feature, wavA, wavB, event

## Cross trials

In [21]:
def Dataset_cross_trails(root, file_name, root_feature, feature_name, batch_size, fold_num=None):
    """Build contiguous (unshuffled) 5-fold train/validation DataLoaders for one subject.

    The whole subject is loaded via ``EEGDataset`` (not defined in this part of
    the file). Folds come from ``KFold(n_splits=5, shuffle=False)``. Training
    loaders shuffle; validation loaders do not. Both use ``collate_fn``.

    Args:
        fold_num: 1-based fold to build, or None for all folds.

    Returns:
        ``(train_loader, valid_loader)`` when ``fold_num`` is given, otherwise
        a list of such pairs, one per fold.

    Raises:
        ValueError: if ``fold_num`` is not None and not in 1..5. Previously the
            function silently returned None, which only failed later when the
            caller unpacked the result.
    """
    n_splits = 5
    if fold_num is not None and not 1 <= fold_num <= n_splits:
        raise ValueError(f"fold_num must be in 1..{n_splits} or None, got {fold_num!r}")

    set_seed(seed=Config.seed)
    TotalDataset = EEGDataset(root, file_name, root_feature, feature_name)
    total_len = len(TotalDataset)
    kf = KFold(n_splits=n_splits, shuffle=False)
    # Identical for every fold, so built once.
    kwargs = {"batch_size": batch_size, "num_workers": 4, "pin_memory": False, "drop_last": False}
    fold_dataloaders = []

    for fold_idx, (train_indices, val_indices) in enumerate(kf.split(range(total_len))):
        if fold_num is not None and fold_idx != fold_num - 1:
            continue

        Train_dataloader = DataLoader(Subset(TotalDataset, train_indices), collate_fn=collate_fn, shuffle=True, **kwargs)
        Valid_dataloader = DataLoader(Subset(TotalDataset, val_indices), collate_fn=collate_fn, shuffle=False, **kwargs)

        if fold_num is not None:
            return (Train_dataloader, Valid_dataloader)
        fold_dataloaders.append((Train_dataloader, Valid_dataloader))

    return fold_dataloaders

# Pretrain

## Train epoch

In [22]:
def preTrain_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one optimisation pass over ``train_loader`` and return its loss ``AvgMeter``.

    Each batch is ``(eeg, eeg_feature, wavA, wavB, event)``; ``event`` has its
    last axis squeezed to form the cross-entropy targets. ``lr_scheduler``
    and ``step`` are accepted but unused (the caller steps the scheduler).
    The model's train/eval mode is not changed here.
    """
    meter = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    criterion = nn.CrossEntropyLoss()
    device = Config.device
    for eeg, eeg_feature, wavA, wavB, event in progress:
        eeg, eeg_feature, wavA, wavB = (t.to(device) for t in (eeg, eeg_feature, wavA, wavB))
        targets = event.to(device).squeeze(-1)

        logits = model(eeg, eeg_feature, wavA, wavB)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter.update(loss.item(), eeg.size(0))
        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter

def preVal_epoch(model, valid_loader):
    """Return an ``AvgMeter`` holding the sample-weighted cross-entropy over ``valid_loader``.

    Each batch is ``(eeg, eeg_feature, wavA, wavB, event)``; ``event`` has its
    last axis squeezed to form the class targets. The pass runs under
    ``torch.no_grad()``, so calling it outside a no-grad context no longer
    builds autograd graphs. Nesting inside the training loop's own
    ``no_grad`` is a no-op. The model's train/eval mode is left as the
    caller set it.
    """
    loss_meter = AvgMeter()
    criterion = nn.CrossEntropyLoss()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    with torch.no_grad():
        for eeg, eeg_feature, wavA, wavB, event in tqdm_object:
            eeg = eeg.to(Config.device)
            eeg_feature = eeg_feature.to(Config.device)
            wavA = wavA.to(Config.device)
            wavB = wavB.to(Config.device)
            event = event.to(Config.device).squeeze(-1)
            output = model(eeg, eeg_feature, wavA, wavB)
            loss = criterion(output, event)

            count = eeg.size(0)
            loss_meter.update(loss.item(), count)

            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter

## train_and_save_model

In [23]:
def train_and_save_model(config):
    """Train a fresh GCANet on one subject/fold and keep its best-validation checkpoint.

    Whenever validation accuracy strictly improves, the weights are saved as
    ``<epoch>_best.pt`` under ``/DTU_<mode>_<tim_len>s/fold_<k>/<subject>``.
    After training, this run's best checkpoint is copied to
    ``<mode>_<subject lowercased>_best.pt`` in the same directory.

    Returns:
        ``(folder_name, acc_max)``: ``"<file stem>_fold<k>"`` and the best
        validation accuracy (0 if it never rose above 0).
    """
    train_dataloader, valid_dataloader = Dataset_cross_trails(config.root, config.file_name, config.root_feature, config.feature_name,
                                                              config.batch_size, config.current_fold)

    # NOTE(review): this path is absolute (directly under the filesystem
    # root); confirm that is intended.
    save_dir = f"/DTU_{config.mode}_{config.tim_len}s/fold_{config.current_fold}"

    sub_dir_name = config.file_name.split('_')[0]
    final_save_dir = os.path.join(save_dir, sub_dir_name)
    os.makedirs(final_save_dir, exist_ok=True)

    model = GCANet().to(config.device)

    params = [{"params": model.parameters(), "lr": config.lr, "weight_decay": config.weight_decay}]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=config.patience, factor=config.factor
    )
    step = "epoch"
    acc_max = 0
    # Track this run's best checkpoint directly. The old code re-scanned the
    # directory, which had three problems:
    #   - it could pick a stale file left by an earlier run;
    #   - it crashed on int('cross') once the renamed copy existed;
    #   - it crashed on max() of an empty list when nothing had been saved.
    best_model_path = None

    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        preTrain_epoch(model, train_dataloader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            train_acc = pre_evaluate(train_dataloader, model)
            print(f"train accuracy: {train_acc:.6f}")
            valid_loss = preVal_epoch(model, valid_dataloader)
            val_acc = pre_evaluate(valid_dataloader, model)
            print(f"valid accuracy: {val_acc:.6f}")
            if val_acc > acc_max:
                acc_max = val_acc
                model_filename = f"{epoch + 1}_best.pt"
                save_path = os.path.join(final_save_dir, model_filename)
                torch.save(model.state_dict(), save_path)
                best_model_path = save_path
                print(f"Saved Best Model as {save_path}!")

        # Plateau scheduler driven by the mean validation loss.
        lr_scheduler.step(valid_loss.avg)
    print_size(model)
    print(f"Final accuracy max for {config.file_name}: {acc_max:.3f}")

    if best_model_path is not None:
        best_model_filename = os.path.basename(best_model_path)
        new_filename = f"{config.mode}_{sub_dir_name.lower()}_best.pt"
        shutil.copy(best_model_path, os.path.join(final_save_dir, new_filename))
        print(f"Copied and renamed {best_model_filename} to {new_filename}")
    else:
        print(f"No checkpoint was saved for {config.file_name}; nothing to copy.")

    folder_name = f"{config.file_name.split('.')[0]}_fold{config.current_fold}"
    return folder_name, acc_max

## Run all files

In [24]:
def _sorted_by_subject(directory):
    """List regular files in ``directory`` sorted by subject number (e.g. 'S12_...' -> 12)."""
    names = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    return sorted(names, key=lambda x: int(x.split('_')[0][1:]))


def run_all_files_in_directory(config):
    """Train and evaluate all five folds for every subject and summarise the accuracies.

    EEG files (``config.root``) and feature files (``config.root_feature``) are
    each sorted by subject number and paired by position. ``config`` is
    mutated (``file_name``, ``feature_name``, ``current_fold``, and ``LOO`` in
    cross_subject mode) before each ``train_and_save_model`` call.

    Returns:
        dict mapping subject id (e.g. ``'S1'``) to
        ``{'avg_acc', 'std_acc', 'fold_accs'}``.

    Raises:
        ValueError: if the file counts differ, if a pair's subject prefixes
            differ, or (cross_subject mode) if no subject number can be parsed
            from an EEG file name.
    """
    files = _sorted_by_subject(config.root)
    feature_files = _sorted_by_subject(config.root_feature)

    if len(files) != len(feature_files):
        raise ValueError("The number of EEG files and feature files does not match!")

    # Pairing is by sort position only, so verify each pair is the same
    # subject before any (long) training starts.
    for eeg_file, feature_file in zip(files, feature_files):
        if eeg_file.split('_')[0].lower() != feature_file.split('_')[0].lower():
            raise ValueError(
                f"Subject mismatch between EEG file {eeg_file!r} and feature file {feature_file!r}")

    results = {}

    for eeg_file, feature_file in zip(files, feature_files):
        if config.mode == 'cross_subject':
            # Case-insensitive: file names look like 'S1_Dataset_1s.npz', which
            # the old lowercase-only pattern never matched (crashing on None.group).
            match = re.search(r's(\d+)', eeg_file, flags=re.IGNORECASE)
            if match is None:
                raise ValueError(f"Cannot parse a subject number from {eeg_file!r}")
            config.LOO = int(match.group(1))

        config.file_name = eeg_file
        config.feature_name = feature_file
        print(f"Processing EEG file: {config.file_name}, Feature file: {config.feature_name}")
        subject_acc = []

        for fold_num in range(1, 6):
            print(f"Fold {fold_num}/5")
            config.current_fold = fold_num

            _, fold_acc = train_and_save_model(config)
            subject_acc.append(fold_acc)

            print(f"Fold {fold_num} Accuracy: {fold_acc:.3f}")

        avg_acc = np.mean(subject_acc)
        std_acc = np.std(subject_acc)

        subj_id = config.file_name.split('_')[0]
        results[subj_id] = {
            'avg_acc': avg_acc,
            'std_acc': std_acc,
            'fold_accs': subject_acc
        }

        print(f"Subject {subj_id} Final: {avg_acc:.3f} ± {std_acc:.3f}")

    print("\nFinal Results:")
    for subj, metrics in results.items():
        print(f"{subj}:")
        print(f"  Fold Accuracies: {[f'{x:.3f}' for x in metrics['fold_accs']]}")
        print(f"  Average Accuracy: {metrics['avg_acc']:.3f}")
        print(f"  Standard Deviation: {metrics['std_acc']:.3f}")

    return results

In [None]:
# Entry point (notebook cell): train and evaluate every subject found in
# Config.root over folds 1-5. Runs as soon as this cell or module executes.
# ``results`` maps subject id -> {'avg_acc', 'std_acc', 'fold_accs'}.
results = run_all_files_in_directory(Config)