In [1]:
from http.client import ImproperConnectionState
import torch
from sklearn.model_selection import RepeatedKFold
from sklearn.metrics import accuracy_score, cohen_kappa_score, precision_score, recall_score, f1_score
import numpy as np
from model import fNIRS_T, fNIRS_PreT
from dataloader import Dataset, Load_Dataset_A, Load_Dataset_B, Load_Dataset_C
import os
import argparse
from model_init import init_model_MK2
# Command-line options. parse_known_args() is used instead of parse_args() so
# that any unrecognised arguments are ignored rather than causing an error.
parser = argparse.ArgumentParser(
    description="传入您的device_id, dataset_id(0->A, 1->B, 2->C), models_id(0->T, 1->PreT, 2->Ours)"
)
# GPU index, kept as a string (e.g. '0').
parser.add_argument(
    '--device_id', type=str, default='0',
    help='传入您的device_id',
)
# A --dataset_id option (0->A, 1->B, 2->C) is currently disabled.
parser.add_argument(
    '--models_id', type=int, default='16',
    # argparse passes string defaults through `type`, so args.models_id is the
    # int 16 when the flag is omitted.
    help='传入您的models_id(0->T, 1->PreT, 2->Ours)',
)
known_and_unknown = parser.parse_known_args()
args = known_and_unknown[0]

def average_log(save_path):
    """Average per-epoch training curves across every sub-folder of ``save_path``.

    Each immediate sub-directory of ``save_path`` must contain the files
    ``train_loss_history.txt``, ``train_acc_history.txt``,
    ``test_loss_history.txt`` and ``test_acc_history.txt``. Each file holds a
    Python list literal with one value per epoch. The element-wise mean across
    sub-folders, rounded to 3 decimals, is written to
    ``final_<name>.txt`` (e.g. ``final_train_loss_history.txt``) in ``save_path``.

    Args:
        save_path: directory whose sub-folders hold the per-run histories.

    Raises:
        ValueError: if ``save_path`` has no sub-folders, or if the histories
            have different lengths.
        FileNotFoundError: if a sub-folder lacks one of the history files.
    """
    history_names = (
        'train_loss_history',
        'train_acc_history',
        'test_loss_history',
        'test_acc_history',
    )

    # Use the iterator as a context manager so the directory handle is closed;
    # sort for a deterministic accumulation order.
    with os.scandir(save_path) as entries:
        sub_folders = sorted(entry.path for entry in entries if entry.is_dir())
    if not sub_folders:
        raise ValueError(f"No sub-folders with history files found in {save_path!r}")

    totals = {name: None for name in history_names}
    for sub_folder in sub_folders:
        for name in history_names:
            with open(os.path.join(sub_folder, name + '.txt'), 'r') as f:
                # NOTE(review): eval() executes arbitrary code. It is only acceptable
                # because these files are produced by this project. ast.literal_eval
                # would be safer but rejects numpy reprs such as "np.float64(0.5)".
                history = np.asarray(eval(f.read()), dtype=np.float64)
            # Accumulate in float64 (not in place on the first array's dtype) so an
            # integer-valued first history cannot make the running sum fail.
            totals[name] = history if totals[name] is None else totals[name] + history

    # Average once, after all sub-folders have been summed.
    num_subs = len(sub_folders)
    for name in history_names:
        average = [round(value, 3) for value in (totals[name] / num_subs).tolist()]
        with open(os.path.join(save_path, 'final_' + name + '.txt'), 'w') as f:
            f.write(str(average))


class LabelSmoothing(torch.nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    For logits ``x`` and integer class targets, the per-sample loss is
    ``(1 - smoothing) * NLL(true class) + smoothing * mean_over_classes(-log p)``.
    The mean over the batch is returned.

    Args:
        smoothing: weight of the uniform (smoothed) term; 0 gives plain
            cross-entropy.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        """Return the batch-mean smoothed loss.

        Args:
            x: logits with classes on the last dimension. The ``unsqueeze(1)``
                below assumes a 2-D (batch, classes) tensor.
            target: class indices of shape (batch,), int64 as required by gather.
        """
        log_p = torch.nn.functional.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class, one value per sample.
        picked = torch.gather(log_p, -1, target.unsqueeze(1))
        true_class_nll = -picked.squeeze(1)
        # Cross-entropy against a uniform distribution over classes.
        uniform_term = -(log_p.mean(dim=-1))
        per_sample = self.confidence * true_class_nll + self.smoothing * uniform_term
        return per_sample.mean()





  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Training epochs
EPOCH = 1

# Device setting: GPU index string from the --device_id command-line option.
device_id = args.device_id

# Mini-batch size
batch_size = 128

# Datasets to prepare. NOTE: the loop below stops after its first iteration
# (trailing `break`), so only dataset[0] is actually loaded.
dataset = ['C']

models = ['fNIRS-T', 'fNIRS-PreT',  'CT-Net', 'fNIRS_TTT_LM', 'fNIRS_TTT_M', 'fNIRS_TTT_LL', 'fNIRS_TTT_L']
# NOTE(review): hard-coded; the parsed --models_id option (args.models_id) is not used.
models_id = 2
print(models[models_id])

for dataset_id, ds in enumerate(dataset):
    print(ds)

    if ds == 'A':
        flooding_level = [0, 0, 0]
        # Every model, including fNIRS-PreT, loads dataset A with the fNIRS-T
        # preprocessing. The previous condition `models[models_id] == 'fNIRS-T' or 1`
        # was always true, so its model='fNIRS-PreT' branch could never run.
        # TODO(review): confirm fNIRS-PreT should not use model='fNIRS-PreT'.
        feature, label = Load_Dataset_A("data/A", model='fNIRS-T')
    elif ds == 'B':
        # Always this schedule. The alternative [0.40, 0.38, 0.35] sat behind an
        # always-true `... or 1` condition and was unreachable.
        flooding_level = [0.45, 0.40, 0.35]
        feature, label = Load_Dataset_B("data/B")
    elif ds == 'C':
        flooding_level = [0.45, 0.40, 0.35]
        feature, label = Load_Dataset_C("data/C")
    else:
        # Fail fast. Otherwise `feature`/`label` are undefined on a fresh kernel,
        # or silently reused from an earlier run.
        raise ValueError(f"Unknown dataset {ds!r}; expected 'A', 'B' or 'C'")

    # feature is 4-D here; keep the last two dims (channels, sampling points).
    _, _, channels, sampling_points = feature.shape

    # Flatten each sample to one row for the cross-validation splitter.
    feature = feature.reshape((label.shape[0], -1))
    # 5 × 5-fold-CV
    rkf = RepeatedKFold(n_splits=5, n_repeats=5, random_state=42)
    n_runs = 0

    # Metric accumulators (accuracy, precision, recall, F1, kappa), presumably
    # filled by later training cells.
    result_acc = []
    result_pre = []
    result_rec = []
    result_f1  = []
    result_kap = []
    break


CT-Net
C
1  OK
2  OK
3  OK
4  OK
5  OK
6  OK
7  OK
8  OK
9  OK
10  OK
11  OK
12  OK
13  OK
14  OK
15  OK
16  OK
17  OK
18  OK
19  OK
20  OK
21  OK
22  OK
23  OK
24  OK
25  OK
26  OK
27  OK
28  OK
29  OK
30  OK
feature  (2250, 2, 20, 256)
label  (2250,)


In [2]:
import os
# GPU index list. NOTE(review): not referenced in the code shown here; presumably used by later cells.
gpus = [0]
import numpy as np
import pandas as pd
import random
import datetime
import time
# Presumably a project-local module (not shown here) — TODO confirm.
import ttt
from pandas import ExcelWriter
# from torchsummary import summary
import torch
from torch.backends import cudnn
# from utils import calMetrics
# from utils import calculatePerClass
# from utils import numberClassChannel
import math
import warnings
# NOTE(review): silences ALL Python warnings for the rest of the session, which
# can hide real problems (e.g. deprecations or numerical warnings).
warnings.filterwarnings("ignore")
# Favour reproducibility over speed: disable cuDNN autotuning and request
# deterministic cuDNN algorithms.
cudnn.benchmark = False
cudnn.deterministic = True



# NOTE(review): torch / numpy / pandas below are re-imports of modules already
# imported above in this cell; redundant but harmless.
import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat
import torch.nn.functional as F

# from utils import numberClassChannel
# from utils import load_data_evaluate
import numpy as np
import pandas as pd
from torch.autograd import Variable


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from sympy import false, true


class _ScaleModule(nn.Module):
    def __init__(self, dims, init_scale=1.0, init_bias=0):
        super(_ScaleModule, self).__init__()
        self.dims = dims
        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
        self.bias = None
    
    def forward(self, x):
        return torch.mul(self.weight, x)

class WTConv2d(nn.Module):
    """Depth-wise convolution augmented with a multi-level wavelet-domain branch.

    Output = base_scale(base_conv(x)) + wavelet branch, optionally subsampled by
    ``stride``. The wavelet branch works level by level:
      - pad odd H/W to even;
      - apply ``wavelet.wavelet_transform``;
      - run a depth-wise conv plus a learnable scale over all 4 sub-bands.
    It then reconstructs from the coarsest level with
    ``wavelet.inverse_wavelet_transform``.

    NOTE(review): ``wavelet`` is not imported in the cells shown here. It must be
    provided elsewhere, presumably by a local helper module exposing
    ``create_wavelet_filter``, ``wavelet_transform`` and
    ``inverse_wavelet_transform`` — TODO confirm.

    Args:
        in_channels: number of input channels; the output has the same number.
        out_channels: accepted but unused (the in == out assert is commented out).
        kernel_size: kernel size of the base and per-level depth-wise convs.
        stride: if > 1, the result is subsampled by a kernel-1 AvgPool2d.
        bias: whether ``base_conv`` has a bias.
        wt_levels: number of wavelet decomposition levels.
        wt_type: wavelet name passed to ``create_wavelet_filter``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
        super(WTConv2d, self).__init__()
        # super().__init__()
        # assert in_channels == out_channels

        self.in_channels = in_channels
        self.wt_levels = wt_levels
        self.stride = stride
        self.dilation = 1

        # Fixed (non-trainable) analysis / synthesis filters.
        self.wt_filter, self.iwt_filter = wavelet.create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)

        # Spatial-domain path: depth-wise conv (groups=in_channels) + per-channel scale.
        self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels, bias=bias)
        self.base_scale = _ScaleModule([1,in_channels,1,1])

        # One depth-wise conv + scale per level, acting on the 4 stacked sub-bands.
        self.wavelet_convs = nn.ModuleList(
            [nn.Conv2d(in_channels*4, in_channels*4, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels*4, bias=False) for _ in range(self.wt_levels)]
        )
        self.wavelet_scale = nn.ModuleList(
            [_ScaleModule([1,in_channels*4,1,1], init_scale=0.1) for _ in range(self.wt_levels)]
        )

        # Kernel-1 average pooling with stride = plain subsampling of the output.
        if self.stride > 1:
            self.do_stride = nn.AvgPool2d(kernel_size=1, stride=stride)
        else:
            self.do_stride = None

    def forward(self, x):

        # Per-level processed LL band, processed high-frequency bands, and the
        # pre-padding input shape (needed to crop during reconstruction).
        x_ll_in_levels = []
        x_h_in_levels = []
        shapes_in_levels = []

        curr_x_ll = x

        # Analysis: decompose level by level.
        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            # Pad right/bottom by one when H or W is odd so the transform can halve it.
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                curr_x_ll = F.pad(curr_x_ll, curr_pads)

            curr_x = wavelet.wavelet_transform(curr_x_ll, self.wt_filter)
            # The un-filtered LL band is what the next (coarser) level decomposes.
            curr_x_ll = curr_x[:,:,0,:,:]
            
            # Fold the 4 sub-bands into the channel dim for the depth-wise conv.
            # Assumes curr_x is (B, C, 4, H', W') — TODO confirm against wavelet_transform.
            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
            curr_x_tag = curr_x_tag.reshape(shape_x)

            x_ll_in_levels.append(curr_x_tag[:,:,0,:,:])
            x_h_in_levels.append(curr_x_tag[:,:,1:4,:,:])

        next_x_ll = 0

        # Synthesis: from the coarsest level to the finest, add the coarser
        # reconstruction to this level's LL, invert, and crop off the padding.
        for i in range(self.wt_levels-1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop()
            curr_x_h = x_h_in_levels.pop()
            curr_shape = shapes_in_levels.pop()

            curr_x_ll = curr_x_ll + next_x_ll

            curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
            next_x_ll = wavelet.inverse_wavelet_transform(curr_x, self.iwt_filter)

            next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]

        x_tag = next_x_ll
        assert len(x_ll_in_levels) == 0
        
        # Spatial path plus wavelet path.
        x = self.base_scale(self.base_conv(x))
        x = x + x_tag
        
        if self.do_stride is not None:
            x = self.do_stride(x)

        return x

class Spatial_layer(nn.Module):  # spatial attention layer
    """Spatial attention gate.

    The channel-wise mean and max maps are stacked into 2 channels and passed
    through a 3x3 convolution and a sigmoid. The resulting single-channel map
    rescales every channel of the input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        gate = self.sigmoid(self.conv1(torch.cat((mean_map, max_map), dim=1)))
        return gate * x
    
class Channel_layer(nn.Module):
    """Channel attention gate in the style of an ECA module.

    Each channel is globally average-pooled, and a 1-D convolution of width
    ``k_size`` runs across the channel axis. The sigmoid of the result rescales
    the corresponding input channel.

    Args:
        k_size: kernel size of the 1-D convolution over channels. With the
            ``(k_size - 1) // 2`` padding used here, only odd values keep the
            channel count unchanged.
    """

    def __init__(self, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # For a 4-D input (B, C, H, W), `pooled` is (B, C, 1, 1).
        pooled = self.avg_pool(x)
        # Treat the channels as a 1-D sequence: (B, 1, C).
        channel_seq = pooled.squeeze(-1).transpose(-1, -2)
        # Convolve across neighbouring channels, then restore (B, C, 1, 1).
        weights = self.conv(channel_seq).transpose(-1, -2).unsqueeze(-1)
        weights = self.sigmoid(weights)
        return x * weights.expand_as(x)
    
class Temporal_layer(nn.Module):
    """Fuse spatial- and channel-attention outputs into a sigmoid gate.

    The input goes through ``Spatial_layer`` and ``Channel_layer`` in parallel.
    Their outputs are concatenated along dim 1, reduced back to ``num_T``
    channels by a 3x3 convolution, and squashed with a sigmoid.

    Note: the returned tensor is the gate itself (values in (0, 1)); it is not
    multiplied by the input.

    Args:
        num_T: number of input channels. Both attention branches preserve the
            channel count, so the fusion conv receives ``2 * num_T`` channels.
    """

    def __init__(self, num_T=16):
        super().__init__()

        self.sa_layer = Spatial_layer()
        self.ch_layer = Channel_layer()

        self.conv = nn.Conv2d(2 * num_T, num_T, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        spatial_out = self.sa_layer(x)
        channel_out = self.ch_layer(x)
        fused = self.conv(torch.cat((spatial_out, channel_out), dim=1))
        return self.sigmoid(fused)



class PatchEmbeddingCNN(nn.Module):
    """Inception-style CNN that turns a 2-channel input map into a token sequence.

    Layers used by ``forward``:
      1. ``cnn_module``: Conv2d(2 -> f1, kernel (1, kernel_size), 'same' padding)
         followed by BatchNorm.
      2. Four parallel branches on that output:
         - ``Conv1`` (1x1) feeding ``Conv2``, ``Conv3`` and ``Conv4``, temporal
           kernels of width sampling_points//4, //8 and //16;
         - ``InputMaxPooling`` (a 1x1 max-pool, i.e. identity) feeding ``Conv1_2`` (1x1).
      3. Concatenate the branches (4*f1 channels), then ``bottleneck`` (1x1 conv
         back to f1, plus BatchNorm), ``Conv1_3`` (1x1) and ``BN``.
      4. ``projection``: (b, f1, h, w) -> (b, h*w, f1) tokens.

    Args:
        f1: number of feature maps, i.e. the token width of the output.
        kernel_size: temporal kernel width of the first convolution.
        D: depth multiplier. It only sizes ``pre_att`` (``f2 = D * f1``), which
            ``forward`` does not call.
        pooling_size1, pooling_size2, dropout_rate, number_channel, emb_size:
            accepted for interface compatibility; unused by the active layers.
        sampling_points: sets the temporal branch kernel widths. It should be
            >= 16 so that ``sampling_points // 16`` is at least 1.
    """

    def __init__(self, f1=40, kernel_size=16, D=2, pooling_size1=2, pooling_size2=2, dropout_rate=0.3, number_channel=52, emb_size=40, sampling_points=100):
        super().__init__()
        f2 = D*f1
        # Only the temporal conv + BN are active. The EEGNet-style depth-wise /
        # spatial conv, pooling and dropout stages are disabled.
        self.cnn_module = nn.Sequential(
            nn.Conv2d(2, f1, (1, kernel_size), (1, 1), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        )
        self.Conv1 = nn.Conv2d(f1, f1, (1,1), padding='same')
        self.Conv2 = nn.Conv2d(f1, f1, (1, (sampling_points//4)),  padding='same')
        self.Conv3 = nn.Conv2d(f1, f1, (1, (sampling_points//8)),  padding='same')
        self.Conv4 = nn.Conv2d(f1, f1, (1, (sampling_points//16)), padding='same')
        self.InputMaxPooling = nn.MaxPool2d((1,1))
        self.Conv1_2 = nn.Conv2d(f1, f1, (1,1), padding='same')
        self.bottleneck = nn.Sequential(
            nn.Conv2d(4 * f1, f1, (1, 1)),  # input channels = 4*f1 (four concatenated branches)
            nn.BatchNorm2d(f1)
        )
        self.Conv1_3 = nn.Conv2d(f1, f1, (1,1), padding='same')

        self.BN = nn.BatchNorm2d(f1)

        # Built (so its parameters stay in the state_dict) but not used in forward.
        self.pre_att = Temporal_layer(num_T=f2)
        self.projection = nn.Sequential(
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` of shape (batch, 2, H, W) into tokens of shape (batch, H*W, f1).

        Raises:
            ValueError: if ``x`` is not 4-D with exactly 2 channels.
        """
        # Fail with a clear message instead of printing and continuing.
        if x.dim() != 4 or x.shape[1] != 2:
            raise ValueError(
                f"PatchEmbeddingCNN expects input of shape (batch, 2, H, W), got {tuple(x.shape)}"
            )
        x = self.cnn_module(x)
        c1 = self.Conv1(x)
        xmp = self.InputMaxPooling(x)
        c2 = self.Conv2(c1)
        c3 = self.Conv3(c1)
        c4 = self.Conv4(c1)
        c12 = self.Conv1_2(xmp)
        xc = torch.cat([c2,c3,c4,c12], dim=1)
        xc = self.bottleneck(xc)
        x13 = self.Conv1_3(xc)
        y = self.BN(x13)

        y = self.projection(y)
        return y
    


class MultiHeadAttention(nn.Module):
    """Batch-first multi-head self-attention built on einops.

    Args:
        emb_size: token width; split evenly across ``num_heads`` heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Self-attend over the tokens of ``x``.

        Args:
            x: (batch, tokens, emb_size).
            mask: optional boolean tensor broadcastable to
                (batch, num_heads, tokens, tokens). Positions where it is False
                are excluded from attention.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned. The
            # fill value is the most negative finite number of energy's own dtype,
            # which keeps it representable under fp16.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE: scores are divided by sqrt(emb_size), not the more common
        # sqrt(emb_size // num_heads). Kept as-is so existing results stay comparable.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    


# Position-wise feed-forward network
class FeedForwardBlock(nn.Sequential):
    """Linear -> GELU -> Dropout -> Linear, applied independently to each token.

    Args:
        emb_size: input and output width.
        expansion: hidden-width multiplier (hidden = expansion * emb_size).
        drop_p: dropout probability after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden, emb_size),
        ]
        super().__init__(*layers)


class ClassificationHead(nn.Sequential):
    """Dropout(0.5) followed by a linear layer that produces class logits.

    Args:
        flatten_number: size of the last input dimension.
        n_classes: number of output logits.
    """

    def __init__(self, flatten_number, n_classes):
        super().__init__()
        head_layers = [nn.Dropout(0.5), nn.Linear(flatten_number, n_classes)]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.fc(x)


class ResidualAdd(nn.Module):
    """Post-norm residual wrapper computing ``LayerNorm(dropout(fn(x)) + x)``.

    Extra keyword arguments given to ``forward`` are passed through to ``fn``.

    Args:
        fn: the wrapped sub-layer.
        emb_size: feature width normalised by the LayerNorm.
        drop_p: dropout probability applied to ``fn``'s output.
    """

    def __init__(self, fn, emb_size, drop_p):
        super().__init__()
        self.fn = fn
        self.drop = nn.Dropout(drop_p)
        self.layernorm = nn.LayerNorm(emb_size)

    def forward(self, x, **kwargs):
        branch = self.drop(self.fn(x, **kwargs))
        return self.layernorm(branch + x)


class BranchEEGNetTransformer_wavelet(nn.Sequential):
    """One-element Sequential wrapping ``PatchEmbeddingCNN_wavelet``.

    ``heads``, ``depth`` and any extra ``**kwargs`` are accepted but ignored,
    because the transformer-encoder stage is disabled.

    NOTE(review): ``PatchEmbeddingCNN_wavelet`` is not defined in the visible part
    of this notebook. Construction raises NameError unless it is defined
    elsewhere — TODO confirm.
    """

    def __init__(self, heads=4,
                 depth=6,
                 emb_size=40,
                 number_channel=22,
                 f1 = 20,
                 kernel_size = 64,
                 D = 2,
                 pooling_size1 = 8,
                 pooling_size2 = 8,
                 dropout_rate = 0.3,
                 **kwargs):
        embedding_kwargs = dict(
            f1=f1,
            kernel_size=kernel_size,
            D=D,
            pooling_size1=pooling_size1,
            pooling_size2=pooling_size2,
            dropout_rate=dropout_rate,
            number_channel=number_channel,
            emb_size=emb_size,
        )
        super().__init__(PatchEmbeddingCNN_wavelet(**embedding_kwargs))


class BranchEEGNetTransformer(nn.Sequential):
    """One-element Sequential wrapping ``PatchEmbeddingCNN``.

    ``heads``, ``depth`` and any extra ``**kwargs`` are accepted but ignored,
    because the transformer-encoder stage is disabled.

    NOTE(review): ``sampling_points`` is not forwarded, so ``PatchEmbeddingCNN``
    always uses its default of 100. A ``sampling_points`` passed here lands in
    ``**kwargs`` and is silently dropped.
    """

    def __init__(self, heads=4,
                 depth=6,
                 emb_size=40,
                 number_channel=22,
                 f1 = 20,
                 kernel_size = 64,
                 D = 2,
                 pooling_size1 = 8,
                 pooling_size2 = 8,
                 dropout_rate = 0.3,
                 **kwargs):
        embedding_kwargs = dict(
            f1=f1,
            kernel_size=kernel_size,
            D=D,
            pooling_size1=pooling_size1,
            pooling_size2=pooling_size2,
            dropout_rate=dropout_rate,
            number_channel=number_channel,
            emb_size=emb_size,
        )
        super().__init__(PatchEmbeddingCNN(**embedding_kwargs))
    

class PositioinalEncoding(nn.Module):
    """Learned positional encoding added to a token sequence, followed by dropout.

    Args:
        embedding: token width (last dimension of the input).
        length: maximum supported sequence length.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, embedding, length=100, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.encoding = nn.Parameter(torch.randn(1, length, embedding))

    def forward(self, x):  # x -> [batch, length, embedding]
        # Follow the input's device instead of hard-coding .cuda(). This keeps the
        # module usable on CPU and on non-default GPUs; `.to` is a no-op when the
        # devices already match.
        positions = self.encoding[:, :x.shape[1], :].to(x.device)
        return self.dropout(x + positions)




class MultiheadAttention(nn.Module):
    """
    from: https://github.com/yaohungt/Multimodal-Transformer
    Multi-headed attention.
    See "Attention Is All You Need" for more details.

    Sequence-first layout: every input is (seq_len, batch, embed_dim). The
    query/key/value projections are packed into one ``in_proj_weight`` of shape
    (3 * embed_dim, embed_dim), sliced in the order [q; k; v].

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_dropout: dropout probability on the attention weights (training only).
        bias: add biases to the packed input projection and to ``out_proj``.
        add_bias_kv: append a learned bias position to the keys and values.
        add_zero_attn: append an all-zero position to the keys and values.
    """

    def __init__(self, embed_dim, num_heads, attn_dropout=0.,
                 bias=True, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attn_dropout = attn_dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # Queries are pre-multiplied by 1/sqrt(head_dim).
        self.scaling = self.head_dim ** -0.5

        # Uninitialised storage; filled in by reset_parameters().
        self.in_proj_weight = nn.Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        self.register_parameter('in_proj_bias', None)
        if bias:
            self.in_proj_bias = nn.Parameter(torch.Tensor(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform projection weights, zero biases, Xavier-normal kv biases."""
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(self, query, key, value, attn_mask=None):
        """Compute multi-head attention.

        Args:
            query: (tgt_len, batch, embed_dim).
            key, value: (src_len, batch, embed_dim), with identical shapes. Passing
                the same tensor for query, key and value selects the fused
                self-attention projection.
            attn_mask: optional additive mask of shape (tgt_len, src_len), added
                to the attention scores of every batch element and head (e.g.
                ``-inf`` at disallowed positions).

        Returns:
            ``(attn, attn_weights)``: ``attn`` is (tgt_len, batch, embed_dim);
            ``attn_weights`` is (batch, tgt_len, src_len), averaged over heads.

        Raises:
            ValueError: if ``attn_mask`` cannot be added to the attention scores.
        """
        # Inputs that share storage select the fused projections below.
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention (key and value are the same tensor)
            q = self.in_proj_q(query)
            k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)

        # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # In-place add: the mask must broadcast to the score shape exactly.
            # Report a mismatch explicitly instead of relying on `assert`, which
            # is stripped under `python -O`.
            try:
                attn_weights += attn_mask.unsqueeze(0)
            except RuntimeError as err:
                raise ValueError(
                    f"attn_mask of shape {tuple(attn_mask.shape)} cannot be added to "
                    f"attention weights of shape {tuple(attn_weights.shape)}"
                ) from err

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        # average attention weights over heads
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=1) / self.num_heads
        return attn, attn_weights

    def in_proj_qkv(self, query):
        """Project ``query`` with the full packed weight and split into q, k, v."""
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        """Project ``key`` with the k and v slices and split into k, v."""
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query, **kwargs):
        return self._in_proj(query, end=self.embed_dim, **kwargs)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None, **kwargs):
        """Linear projection with rows ``start:end`` of the packed weight/bias.

        ``weight``/``bias`` may be overridden through ``kwargs``.
        """
        weight = kwargs.get('weight', self.in_proj_weight)
        bias = kwargs.get('bias', self.in_proj_bias)
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)


class TransformerEncoderLayer(nn.Module):
    """Encoder layer: multi-head (cross-)attention, then a ReLU feed-forward network.

    Adapted from the Multimodal-Transformer encoder layer. Each sub-layer is
    wrapped in ``dropout -> residual add``. A LayerNorm is applied either before
    the sub-layer (tensor2tensor-style pre-norm) or after the residual add
    (post-norm, as in the original paper). ``normalize_before`` is fixed to
    ``True`` here, so the pre-norm variant is always used.

    Args:
        embed_dim: model width (input and output feature size).
        num_heads: number of attention heads.
        attn_dropout: dropout on the attention weights.
        relu_dropout: dropout after the FFN's ReLU.
        res_dropout: dropout on each sub-layer output before the residual add.
        attn_mask: stored as ``self.attn_mask`` but not used by ``forward``;
            attention is always computed without a mask.
    """

    def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1,
                 attn_mask=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.self_attn = MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout,
        )
        self.attn_mask = attn_mask

        self.relu_dropout = relu_dropout
        self.res_dropout = res_dropout
        self.normalize_before = True

        hidden_dim = 4 * embed_dim
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(2)])

    def forward(self, x, x_k=None, x_v=None):
        """Run the attention and feed-forward sub-layers.

        Args:
            x: query sequence, (seq_len, batch, embed_dim).
            x_k, x_v: key/value sequences for cross-attention. When both are
                None the layer self-attends over ``x``. Pass both or neither,
                since each one is layer-normalised independently.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer.
        shortcut = x
        h = self.maybe_layer_norm(0, x, before=True)
        if x_k is None and x_v is None:
            h, _ = self.self_attn(query=h, key=h, value=h, attn_mask=None)
        else:
            keys_in = self.maybe_layer_norm(0, x_k, before=True)
            values_in = self.maybe_layer_norm(0, x_v, before=True)
            h, _ = self.self_attn(query=h, key=keys_in, value=values_in, attn_mask=None)
        h = shortcut + F.dropout(h, p=self.res_dropout, training=self.training)
        h = self.maybe_layer_norm(0, h, after=True)

        # Feed-forward sub-layer.
        shortcut = h
        y = self.maybe_layer_norm(1, h, before=True)
        y = F.dropout(F.relu(self.fc1(y)), p=self.relu_dropout, training=self.training)
        y = F.dropout(self.fc2(y), p=self.res_dropout, training=self.training)
        y = shortcut + y
        return self.maybe_layer_norm(1, y, after=True)

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply ``layer_norms[i]`` only at the position chosen by ``normalize_before``.

        Exactly one of ``before``/``after`` must be true. With
        ``normalize_before`` true the ``before`` call normalises (pre-norm);
        otherwise the ``after`` call does (post-norm).
        """
        assert before ^ after
        return self.layer_norms[i](x) if after != self.normalize_before else x


class TransformerEncoderBlock(nn.Sequential):
    """Post-norm transformer block: an attention sub-layer, then a feed-forward sub-layer.

    Each sub-layer is wrapped in ``ResidualAdd``, which computes
    ``LayerNorm(dropout(f(x)) + x)``.

    Args:
        emb_size: token width.
        num_heads: number of attention heads.
        drop_p: dropout used inside attention and by both residual wrappers.
        forward_expansion: hidden-width multiplier of the feed-forward network.
        forward_drop_p: dropout inside the feed-forward network.
    """

    def __init__(self,
                 emb_size,
                 num_heads=4,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention = nn.Sequential(
            MultiHeadAttention(emb_size, num_heads, drop_p),
        )
        feed_forward = nn.Sequential(
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
        )
        super().__init__(
            ResidualAdd(attention, emb_size, drop_p),
            ResidualAdd(feed_forward, emb_size, drop_p),
        )
        
        
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` ``TransformerEncoderBlock``s sharing one configuration.

    Args:
        heads: attention heads per block.
        depth: number of blocks.
        emb_size: token width.
    """

    def __init__(self, heads, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size, num_heads=heads))
        super().__init__(*blocks)


class FeatureExtractor:
    """Holds the most recently captured feature tensor as a NumPy array (e.g. for t-SNE)."""

    def __init__(self):
        self.features = None

    def save_features(self, features):
        """Detach ``features``, copy it to CPU and store it as NumPy, replacing any previous value."""
        detached = features.detach()
        on_cpu = detached.cpu()
        self.features = on_cpu.numpy()

    def get_features(self):
        """Return the stored array, or None if nothing has been saved yet."""
        return self.features


class fNIRS_TTT_LM(nn.Module):
    def __init__(self, n_class, sampling_point, dim, depth, heads, mlp_dim, pool='cls', 
                 dim_head=64, dropout=0., emb_dropout=0.1,intermediate_size=4, dataset="A",
                 mini_batch_size=16, device='cpu', batch_size=128):
        super().__init__()
        self.device = device
        self.bs = batch_size
        kernel_time = 30
        match dataset:
            case "A":
                patch_num_head = 2
                input_ch = 52
                patch_channel_size = 5
                inner_channels = 8

            case "B":
                patch_num_head = 4
                input_ch = 36
                patch_channel_size = 5
                inner_channels = 8

            case "C":
                patch_num_head = 8
                input_ch = 20
                patch_channel_size = 2
                inner_channels = 8


        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=inner_channels, kernel_size=(patch_channel_size, kernel_time), padding="same"),
            # nn.Conv2d(in_channels=2, out_channels=4, kernel_size=(2, 30), padding="same"),
            Rearrange('b c h w  -> b w (c h)'), #bs, feature_dim, ch, time -> bs, time, feature_dim*ch
            nn.Linear(inner_channels*input_ch, dim),
            nn.LayerNorm(dim)
            )

        self.to_channel_embedding = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=inner_channels, kernel_size=(1, kernel_time), padding="same"),
            Rearrange('b c h w  -> b w (c h)'), #bs, feature_dim, ch, time -> bs, time, feature_dim*ch
            nn.Linear(inner_channels*input_ch, dim),
            nn.LayerNorm(dim)
            )

        self.pos_embedding_patch = nn.Parameter(torch.randn(1, sampling_point, dim))
        self.cls_token_patch = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_patch = nn.Dropout(emb_dropout)
        self.layernormalization = nn.LayerNorm(dim)
        self.cross_transformer = TransformerEncoderLayer(dim, heads)
        self.feature_extractor = FeatureExtractor()  # 初始化特征提取器

        config_p = ttt.TTTConfig(
                                hidden_size=sampling_point,           # 隐藏层大小
                                intermediate_size=sampling_point*intermediate_size,    # MLP中间层的大小，可以设置为hidden_size的倍数
                                num_hidden_layers=1,      # 隐藏层的数量
                                num_attention_heads=patch_num_head,    # 注意力头的数量
                                rms_norm_eps=1e-6,        # RMS归一化epsilon值
                                mini_batch_size=mini_batch_size )
        self.ttt_PreNorm_patch = nn.LayerNorm(dim)
        self.tttMLP = ttt.TTTMLP(config_p, layer_idx=0).to(device)
        self.patch_cache = ttt.TTTCache_MK2(self.tttMLP, batch_size, self.tttMLP.config.mini_batch_size).to(device)
        
        self.pos_embedding_channel = nn.Parameter(torch.randn(1, sampling_point, dim))
        self.cls_token_channel = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_channel = nn.Dropout(emb_dropout)


        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, n_class))
        
    def fulfill(source, aim):
        tensor1 = aim
        tensor2 = source
        padded_tensor1 = torch.cat([tensor2, torch.zeros((tensor1.shape[0] - tensor2.shape[0]), 
                                                        tensor2.shape[1], 
                                                        tensor2.shape[2],
                                                        tensor2.shape[3]).to(tensor1.device)], dim=0)
        return(padded_tensor1)

    def forward(self, img, mask=None):
        n_samples = img.shape[0]
        if n_samples != self.bs:
            img = torch.cat([img, torch.zeros((self.bs - n_samples, 
                                               img.shape[1],
                                               img.shape[2],
                                               img.shape[3])).to(self.device)], dim=0)
        x = self.to_patch_embedding(img)
        x2 = self.to_channel_embedding(img.squeeze())

        # pos embedding
        b, n, _ = x.shape
        x += self.pos_embedding_patch[:, :(n + 1)]
        x = self.dropout_patch(x)
        b, n, _ = x2.shape
        x2 += self.pos_embedding_channel[:, :(n + 1)]
        x2 = self.dropout_channel(x2)
        
        # #cross attn
        # cr_s1 = x
        # cr_s2 = x2
        # x = self.cross_transformer(cr_s1, cr_s2, cr_s2)
        # concat
        x = torch.cat([x,x2],dim=2)

        #ttt
        x = self.ttt_PreNorm_patch(x)
        x = x.transpose(1, 2)
        patch_position_ids = torch.arange(x.shape[1]).unsqueeze(0).repeat(x.shape[0],1).to(self.device)
        x = self.tttMLP(x, position_ids=patch_position_ids ,cache_params=self.patch_cache)
        x = x.transpose(1, 2)
        
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x_cross = self.to_latent(x)

        # 保存特征
        self.feature_extractor.save_features(x_cross)

        def get_saved_features(self):
            return self.feature_extractor.get_features()
        
        return self.mlp_head(x_cross)[:n_samples]


class CTNet(nn.Module):
    """CNN + Transformer classifier (CTNet-style) applied to fNIRS input.

    A convolutional branch (``BranchEEGNetTransformer``) produces a token
    sequence. A stack of ``depth`` encoder blocks refines it, and the encoder
    output is added back to the CNN tokens before flattening and classification.

    Args:
        heads: attention heads per encoder block.
        emb_size: token embedding width.
        depth: number of encoder blocks.
        eeg1_f1, eeg1_kernel_size, eeg1_D, eeg1_pooling_size1, eeg1_pooling_size2,
        eeg1_dropout_rate: hyper-parameters forwarded to the convolutional branch.
        eeg1_number_channel: number of input channels.
        flatten_eeg1: flattened feature size fed to the classification head.
        n_class: number of output classes.
        **kwargs: accepted and ignored.
    """

    def __init__(self, heads=4,
                 emb_size=16,
                 depth=6,
                 eeg1_f1=40,
                 eeg1_kernel_size=64,
                 eeg1_D=2,
                 eeg1_pooling_size1=8,
                 eeg1_pooling_size2=8,
                 eeg1_dropout_rate=0.1,
                 eeg1_number_channel=22,
                 flatten_eeg1=240,
                 n_class=2,
                 **kwargs):
        super().__init__()
        self.number_class = n_class
        self.number_channel = eeg1_number_channel
        self.emb_size = emb_size
        self.flatten_eeg1 = flatten_eeg1
        self.flatten = nn.Flatten()

        # Sub-modules are created in the original order so seeded initialisation is unchanged.
        self.cnn = BranchEEGNetTransformer(
            heads, depth, emb_size,
            number_channel=self.number_channel,
            f1=eeg1_f1,
            kernel_size=eeg1_kernel_size,
            D=eeg1_D,
            pooling_size1=eeg1_pooling_size1,
            pooling_size2=eeg1_pooling_size2,
            dropout_rate=eeg1_dropout_rate,
        )
        # Registered (it is part of the state dict) but not applied in forward().
        self.position = PositioinalEncoding(emb_size, dropout=0.1)
        self.trans = TransformerEncoder(heads, depth, emb_size)
        self.classification = ClassificationHead(self.flatten_eeg1, self.number_class)

    def forward(self, x):
        """Return class logits for ``x`` (an earlier debug note showed (128, 2, 52, 140); TODO confirm)."""
        tokens = self.cnn(x)
        # Residual fusion: encoder output plus the CNN tokens it was computed from.
        fused = self.trans(tokens) + tokens
        return self.classification(self.flatten(fused))
    


In [None]:
for train_index, test_index in rkf.split(feature):
    n_runs += 1
    print('======================================\n', n_runs)

    X_train = feature[train_index]
    y_train = label[train_index]
    X_test = feature[test_index]
    y_test = label[test_index]

    X_train = X_train.reshape((X_train.shape[0], 2, channels, -1))
    X_test = X_test.reshape((X_test.shape[0], 2, channels, -1))

    train_set = Dataset(X_train, y_train, transform=True)
    test_set = Dataset(X_test, y_test, transform=True)
    ########### fix seed ###########
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    ################################
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True)

    # sample = train_set[0]
    # in_shape = sample.shape
    # -------------------------------------------------------------------------------------------------------------------- #
    device = torch.device(f'cuda:{device_id}' if torch.cuda.is_available() else 'cpu')

    match dataset[dataset_id]:
        case "A":
            n_class = 2
            sampling_point = 140
            h = 52
            ct_f = int(5600)
        case "B":
            n_class = 2
            sampling_point = 200
            h = 36
            ct_f = int(8000)
        case "C":
            n_class = 3
            sampling_point = 256
            h = 20
            ct_f = int(sampling_point*h*2)

    # net = CTNet(heads=4, emb_size=40,depth=6, eeg1_f1=40, 
    #                     eeg1_D=2,eeg1_kernel_size=32, eeg1_pooling_size1=2, eeg1_pooling_size2=2,
    #                     eeg1_dropout_rate=0.1,eeg1_number_channel=h,flatten_eeg1=ct_f,n_class=n_class).to(device) # A 80 B 120 C 160

    net = fNIRS_TTT_LM(n_class=n_class, sampling_point=sampling_point,
                             dim=64, depth=6, heads=8, mlp_dim=64, device=device,
                             batch_size=batch_size, dataset=dataset[dataset_id]).to(device)
    criterion = LabelSmoothing(0.1)
    optimizer = torch.optim.AdamW(net.parameters())
    lrStep = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
    # -------------------------------------------------------------------------------------------------------------------- #
    test_max_acc = 0


    train_loss_history = []
    train_acc_history = []
    test_loss_history = []
    test_acc_history = []
    break

# Train on the fold prepared above and evaluate on its held-out part after every epoch.
for epoch in range(10):
    # ------------------------------------ train ------------------------------------ #
    net.train()

    # Piecewise-decayed flooding level b (b = 0 disables flooding). It depends only on
    # the epoch, so it is chosen once per epoch instead of once per batch.
    if epoch < 30:
        b = flooding_level[0]
    elif epoch < 50:
        b = flooding_level[1]
    else:
        b = flooding_level[2]

    train_running_acc = 0
    total = 0
    loss_steps = []
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, labels.long())

        # Flooding: |loss - b| + b flips the gradient sign once the loss drops below b.
        loss = (loss - b).abs() + b

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_steps.append(loss.item())
        total += labels.shape[0]
        pred = outputs.argmax(dim=1, keepdim=True)
        train_running_acc += pred.eq(labels.view_as(pred)).sum().item()

    train_running_loss = float(np.mean(loss_steps))
    train_running_acc = 100 * train_running_acc / total
    train_loss_history.append(train_running_loss)
    train_acc_history.append(train_running_acc)
    print('[%d, %d] Train loss: %0.4f' % (n_runs, epoch, train_running_loss))
    print('[%d, %d] Train acc: %0.3f%%' % (n_runs, epoch, train_running_acc))

    # ------------------------------------ test ------------------------------------- #
    net.eval()
    test_running_acc = 0
    total = 0
    loss_steps = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels.long())

            loss_steps.append(loss.item())
            total += labels.shape[0]
            pred = outputs.argmax(dim=1, keepdim=True)
            test_running_acc += pred.eq(labels.view_as(pred)).sum().item()

    test_running_acc = 100 * test_running_acc / total
    test_running_loss = float(np.mean(loss_steps))
    test_loss_history.append(test_running_loss)
    test_acc_history.append(test_running_acc)
    print('     [%d, %d] Test loss: %0.4f' % (n_runs, epoch, test_running_loss))
    print('     [%d, %d] Test acc: %0.3f%%' % (n_runs, epoch, test_running_acc))

    # Track the best test accuracy (initialised in the setup code but never updated before).
    test_max_acc = max(test_max_acc, test_running_acc)

    # Advance the cosine LR schedule once per epoch. Without this call the scheduler
    # created in the setup code never changes the learning rate.
    lrStep.step()



 5
torch.Size([1800, 2, 20, 256])
torch.Size([1800])
torch.Size([450, 2, 20, 256])
torch.Size([450])


OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 GiB. GPU 0 has a total capacty of 79.32 GiB of which 10.14 GiB is free. Process 111527 has 3.60 GiB memory in use. Process 12272 has 61.14 GiB memory in use. Process 130920 has 4.37 GiB memory in use. Of the allocated memory 2.68 GiB is allocated by PyTorch, and 108.90 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF