In [1]:
import math
import torch
import torch.nn as nn
from einops import rearrange
from easydict import EasyDict
from typing import List, Tuple


In [2]:

def get_sincos_pos_embed(dim: int, seq_len: int, cls_token: bool = False):
    if cls_token:
        pe = torch.zeros(seq_len + 1, dim)
        position = torch.arange(0, seq_len + 1, dtype=torch.float).unsqueeze(1)
    else:
        pe = torch.zeros(seq_len, dim)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)

    return pe



In [3]:

def calculate_output_size(input_size, layers):
    output_size = input_size
    for layer in layers:
        kernel_size, stride = layer[2], layer[3]
        output_size = math.floor((output_size - kernel_size) / stride) + 1
    return output_size


In [3]:
# import math

# def calculate_output_size(input_size, layers):
#     output_size = input_size
#     for layer in layers:
#         kernel_size, stride, padding = layer[2], layer[3], layer[4]
#         output_size = math.floor((output_size + 2 * padding - kernel_size) / stride) + 1
#     return output_size


In [4]:

class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


In [5]:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_gradients = None
        self.attention_map = None
        self.value_gradients = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def save_value_gradients(self, value_gradients):
        self.value_gradients = value_gradients

    def get_value_gradients(self):
        return self.value_gradients

    def forward(self, x, register_hook=False):
        b, n, _, h = *x.shape, self.num_heads

        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)

        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = dots.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = torch.einsum("bhij,bhjd->bhid", attn, v)

        self.save_attention_map(attn)
        if register_hook:
            v.register_hook(self.save_value_gradients)
            attn.register_hook(self.save_attn_gradients)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


In [6]:

class TransposeLast(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.transpose(-2, -1)


In [7]:

class PatchEmbed(nn.Module):
    def __init__(
        self,
        layers: List[Tuple[int, int, int, int]],
        bias: bool = False,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(
                        in_dim,
                        out_dim,
                        kernel_size=kernel,
                        stride=stride,
#                         Changes
#                         padding= padding,
                        bias=bias,
                    ),
                    TransposeLast(),
                    nn.LayerNorm(out_dim),
                    TransposeLast(),
                    nn.GELU(),
                )
                for (in_dim, out_dim, kernel, stride) in layers
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


In [8]:


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, register_hook=False):
        x = x + self.attn(self.norm1(x), register_hook=register_hook)
        x = x + self.mlp(self.norm2(x))
        return x


In [9]:
class TemporalSpatialEncoder(nn.Module):
    def __init__(self, embed_dim: int, nhead: int, dropout_rate: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim

        self.temporal_block = Block(
            dim=self.embed_dim,
            num_heads=nhead,
            mlp_ratio=1.0,
            qkv_bias=False,
            drop=dropout_rate,
            attn_drop=dropout_rate,
        )

        self.spatial_block = Block(
            dim=self.embed_dim,
            num_heads=nhead,
            mlp_ratio=1.0,
            qkv_bias=False,
            drop=dropout_rate,
            attn_drop=dropout_rate,
        )

        
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, register_hook=False):
        B, C, D, T = x.shape

        # Temporal Block
        x = x.reshape(B * C, D, T)  # BC x D x T
        x = x.transpose(1, 2)  # BC x T x D
        x = self.temporal_block(x, register_hook=register_hook)
        x = x.reshape(B, C, T, D)  # B x C x T x D
        x = x.transpose(1, 2)  # B x T x C x D

        # Spatial Block
        x = x.reshape(B * T, C, D)  # BT x C x D
        x = self.spatial_block(x, register_hook=register_hook)
        x = x.reshape(B, T, C, D)  # B x T x C x D
        x = x.permute(0, 2, 3, 1)  # B x C x D x T

        return x

In [10]:
class Embedding(nn.Module):
    def __init__(
        self,
        embed_dim,
        nhead,
        spatial_len,
        input_size,
        cnn_layers,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = calculate_output_size(input_size, cnn_layers)
        self.spatial_len = spatial_len

        self.patch_embed = PatchEmbed(cnn_layers)

        self.temporal_block = Block(
            dim=self.embed_dim,
            num_heads=nhead,
            mlp_ratio=1.0,
            qkv_bias=False,
            drop=dropout_rate,
            attn_drop=dropout_rate,
        )

        self.spatial_block = Block(
            dim=self.embed_dim,
            num_heads=nhead,
            mlp_ratio=1.0,
            qkv_bias=False,
            drop=dropout_rate,
            attn_drop=dropout_rate,
        )

        self.temporal_pos_embed = nn.Parameter(
            torch.zeros(1, self.seq_len + 1, embed_dim), requires_grad=False
        )
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.spatial_len + 1, self.embed_dim),
            requires_grad=False,
        )

        self.temporal_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.spatial_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.initialize_weights()
        
    
    def initialize_weights(self):
        temporal_pos_embed = get_sincos_pos_embed(
            dim=self.embed_dim, seq_len=self.seq_len, cls_token=True
        )
        self.temporal_pos_embed.data.copy_(temporal_pos_embed)

        spatial_pos_embed = get_sincos_pos_embed(
            dim=self.embed_dim,
            seq_len=self.spatial_len,
            cls_token=True,
        )
        self.spatial_pos_embed.data.copy_(spatial_pos_embed)

        torch.nn.init.normal_(self.temporal_token, std=0.02)
        torch.nn.init.normal_(self.spatial_token, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, register_hook=False):
        B, C, D, T = x.shape

        # Path embedding
        x = x.reshape(B * C, D, T)  # BC x D x T
        x = self.patch_embed(x)

        # Temporal position embedding & block
        x = x.transpose(1, 2)  # BC x T x D
        x = x + self.temporal_pos_embed[:, 1:, :]
        token = self.temporal_token + self.temporal_pos_embed[:, :1, :]
        token = token.expand(B * C, -1, -1)
        x = torch.cat((token, x), dim=1)
        x = self.temporal_block(x, register_hook=register_hook)
        x = x.reshape(B, C, -1, self.embed_dim)  # B x C x T x D
        x = x.transpose(1, 2)  # B x T x C x D

        # Spatial position embedding & block
        B, T, C, D = x.shape
        x = x.reshape(B * T, C, D)  # BT x C x D
        x = x + self.spatial_pos_embed[:, 1:, :]
        token = self.spatial_token + self.spatial_pos_embed[:, :1, :]
        token = token.expand(B * T, -1, -1)
        x = torch.cat((token, x), dim=1)
        x = self.spatial_block(x, register_hook=register_hook)
        x = x.reshape(B, -1, C + 1, self.embed_dim)  # B x T x C x D
        x = x.permute(0, 2, 3, 1)  # B x C x D x T

        return x


In [11]:
class Embedding(nn.Module):
    def __init__(
        self,
        embed_dim,
        nhead,
        spatial_len,
        input_size,
        cnn_layers,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = calculate_output_size(input_size, cnn_layers)
        self.spatial_len = spatial_len

        self.patch_embed = PatchEmbed(cnn_layers)

        self.temporal_block = Block(
            dim=self.embed_dim,
            num_heads=nhead,
            mlp_ratio=1.0,
            qkv_bias=False,
            drop=dropout_rate,
            attn_drop=dropout_rate,
        )

        self.spatial_block = Block(
            dim=self.embed_dim,
            num_heads=nhead,
            mlp_ratio=1.0,
            qkv_bias=False,
            drop=dropout_rate,
            attn_drop=dropout_rate,
        )

        self.temporal_pos_embed = nn.Parameter(
            torch.zeros(1, self.seq_len + 1, embed_dim), requires_grad=False
        )
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.spatial_len + 1, embed_dim),
            requires_grad=False,
        )

        self.temporal_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.spatial_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.initialize_weights()
        
    
    def initialize_weights(self):
        temporal_pos_embed = get_sincos_pos_embed(
            dim=self.embed_dim, seq_len=self.seq_len, cls_token=True
        )
        self.temporal_pos_embed.data.copy_(temporal_pos_embed)

        spatial_pos_embed = get_sincos_pos_embed(
            dim=self.embed_dim,
            seq_len=self.spatial_len,
            cls_token=True,
        )
        self.spatial_pos_embed.data.copy_(spatial_pos_embed)

        torch.nn.init.normal_(self.temporal_token, std=0.02)
        torch.nn.init.normal_(self.spatial_token, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, register_hook=False):
        B, C, D, T = x.shape

        # Path embedding
        x = x.reshape(B * C, D, T)  # BC x D x T
        x = self.patch_embed(x)

        # Temporal position embedding & block
        x = x.transpose(1, 2)  # BC x T x D
        x = x + self.temporal_pos_embed[:, :T, :]
        token = self.temporal_token.expand(B * C, -1, -1)
        x = torch.cat((token, x), dim=1)
        x = self.temporal_block(x, register_hook=register_hook)
        x = x.reshape(B, C, -1, self.embed_dim)  # B x C x T x D
        x = x.transpose(1, 2)  # B x T x C x D

        # Spatial position embedding & block
        B, T, C, D = x.shape
        x = x.reshape(B * T, C, D)  # BT x C x D
        x = x + self.spatial_pos_embed[:, :D, :]
        token = self.spatial_token.expand(B * T, -1, -1)
        x = torch.cat((token, x), dim=1)
        x = self.spatial_block(x, register_hook=register_hook)
        x = x.reshape(B, -1, C + 1, self.embed_dim)  # B x T x C x D
        x = x.permute(0, 2, 3, 1)  # B x C x D x T

        return x


In [11]:
class ClassifierHead(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_classes,
        num_channels,
        seq_len,
        use_token,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.use_token = use_token

        if self.use_token:
            self.mlp_head = nn.Sequential(
                nn.Flatten(),
                nn.Linear((num_channels + 1 + seq_len) * embed_dim, num_classes),
            )
        else:
            self.mlp_head = nn.Sequential(
                nn.Conv2d(
                    embed_dim,
                    embed_dim,
                    kernel_size=(num_channels, 1),
                    bias=False,
                ),
                nn.ELU(),
                nn.BatchNorm2d(embed_dim),
                nn.Dropout2d(dropout_rate),
                nn.Flatten(),
                nn.Linear(seq_len * embed_dim, num_classes),
            )

    def forward(self, x):
        if not self.use_token:
            B, C, D, T = x.shape
            x = x.transpose(1, 2)  # B x D x C x T
        else:
            B, C, D = x.shape
            x = x.reshape(B * C, self.embed_dim)

        x = self.mlp_head(x)
        return x


In [18]:
class DFformer(nn.Module):
    def __init__(
        self,
        embed_dim=128,
        nhead=8,
        inter_information_length=22,
        origin_ival=(1, 64, 3, 1),
        cnn_layers=[(1, 64, 3, 1), (64, 128, 3, 1)],
        nlayer=4,
        num_classes=2,
        use_token=True,
        apply_cls_head=True,
#         db_name="BCIC2a",
        dropout_rate=0.1,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.use_token = use_token
        self.apply_cls_head = apply_cls_head

        self.embedding = Embedding(
            self.embed_dim,
            nhead,
            inter_information_length,
            origin_ival[-1],
            cnn_layers,
            dropout_rate,
        )

        self.blocks = nn.ModuleList(
            [
                TemporalSpatialEncoder(self.embed_dim, nhead, dropout_rate)
                for _ in range(nlayer)
            ]
        )

        if self.apply_cls_head:
            self.classifier_head = ClassifierHead(
                self.embed_dim,
                num_classes,
                inter_information_length,
                self.embedding.seq_len,
#                 db_name,
                use_token,
                dropout_rate,
            )

    def forward(self, x, register_hook=False):
        x = self.embedding(x, register_hook=register_hook)

        for block in self.blocks:
            x = x + block(x, register_hook=register_hook)  # B x C x D x T

        if self.apply_cls_head:
            x = self.classifier_head(x[:, 1:, :, 1:])

        return x

In [15]:
class DFformer(nn.Module):
    def __init__(
        self,
        embed_dim=128,
        nhead=8,
        inter_information_length=22,
        origin_ival=(1, 64, 3, 1),
        cnn_layers=[(1, 64, 3, 1), (64, 128, 3, 1)],
        nlayer=4,
        num_classes=2,
        use_token=True,
        apply_cls_head=True,
        dropout_rate=0.1,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.use_token = use_token
        self.apply_cls_head = apply_cls_head

        self.embedding = Embedding(
            self.embed_dim,
            nhead,
            inter_information_length,
            origin_ival[-1],
            cnn_layers,
            dropout_rate,
        )

        self.blocks = nn.ModuleList(
            [
                TemporalSpatialEncoder(self.embed_dim, nhead, dropout_rate)
                for _ in range(nlayer)
            ]
        )

        if self.apply_cls_head:
            self.classifier_head = ClassifierHead(
                self.embed_dim,
                num_classes,
                inter_information_length,
                self.embedding.seq_len,
                use_token,
                dropout_rate,
            )

    def forward(self, x, register_hook=False):
        x = self.embedding(x, register_hook=register_hook)

        for block in self.blocks:
            x = x + block(x, register_hook=register_hook)

        if self.apply_cls_head:
            x = self.classifier_head(x[:, 1:, :, 1:])

        return x

In [19]:
model = DFformer()
print(model)

RuntimeError: Trying to create tensor with negative dimension -2: [1, -2, 128]

In [13]:
import os

def load_BCI2a_data_directory(data_dir, training, all_trials=True):
    """ Loading and Dividing of the data set based on the subject-specific 
    (subject-dependent) approach.
    In this approach, we used the same training and testing data as the original
    competition, i.e., 288 x 9 trials in session 1 for training, 
    and 288 x 9 trials in session 2 for testing.  
   
    Parameters
    ----------
    data_dir: string
        Directory containing dataset files
    training: bool
        If True, load training data
        If False, load testing data
    all_trials: bool, optional
        If True, load all trials
        If False, ignore trials with artifacts 

    Returns
    -------
    data: numpy array
        Loaded data
    labels: numpy array
        Loaded labels
    """

    # Define MI-trials parameters
    n_channels = 22
    n_tests = 6 * 48
    window_Length = 7 * 250

    # Define MI trial window
    fs = 250          # sampling rate
    t1 = int(1.5 * fs)  # start time_point
    t2 = int(6 * fs)    # end time_point

    class_return = []
    data_return = []

    for filename in os.listdir(data_dir):
        if filename.endswith(".mat"):
            filepath = os.path.join(data_dir, filename)
            a = sio.loadmat(filepath)
            a_data = a['data']
            for ii in range(0, a_data.size):
                a_data1 = a_data[0, ii]
                a_data2 = [a_data1[0, 0]]
                a_data3 = a_data2[0]
                a_X = a_data3[0]
                a_trial = a_data3[1]
                a_y = a_data3[2]
                a_artifacts = a_data3[5]

                for trial in range(0, a_trial.size):
                    if a_artifacts[trial] != 0 and not all_trials:
                        continue
                    data = np.transpose(a_X[int(a_trial[trial]):(int(a_trial[trial]) + window_Length), :22])
                    data_return.append(data[:, t1:t2])  # Adjusted indexing here
                    class_return.append(int(a_y[trial]) - 1)

    data_return = np.array(data_return)
    class_return = np.array(class_return)

    return data_return, class_return


In [None]:
# import numpy as np
# import scipy.io as sio

# def load_BCI2a_data_directory(data_dir, training, all_trials=True):
#     n_channels = 22
#     window_length = 7 * 250  # 7 seconds * 250 samples/second
#     t1 = int(1.5 * 250)  # Start time point
#     t2 = int(6 * 250)    # End time point

#     data_return = []
#     class_return = []

#     for filename in os.listdir(data_dir):
#         if filename.endswith(".mat"):
#             filepath = os.path.join(data_dir, filename)
#             mat_data = sio.loadmat(filepath)
#             data = mat_data['data'][0, 0][0]
#             trial_info = mat_data['data'][0, 0][1]
#             labels = mat_data['data'][0, 0][2]
#             artifacts = mat_data['data'][0, 0][5]

#             for i in range(len(trial_info)):
#                 if artifacts[i] != 0 and not all_trials:
#                     continue
#                 trial_data = np.transpose(data[int(trial_info[i]):int(trial_info[i]) + window_length, :n_channels])
#                 trial_data = trial_data[:, t1:t2]  # Adjusted indexing
#                 data_return.append(trial_data)
#                 class_return.append(int(labels[i]) - 1)

#     data_return = np.array(data_return, dtype=np.float32)
#     class_return = np.array(class_return, dtype=np.int64)

#     return data_return, class_return


In [14]:
data_dir = 'BCICIV_2a_mat'

In [15]:
import scipy.io as sio
import numpy as np
data , label = load_BCI2a_data_directory(data_dir, True, True)

In [17]:
data

array([[[  2.63671875,  -4.00390625, -11.9140625 , ...,  -0.09765625,
          -1.953125  ,  -6.8359375 ],
        [  2.734375  ,  -3.515625  ,  -7.32421875, ...,   1.46484375,
           0.68359375,  -4.54101562],
        [  4.54101562,  -4.58984375, -11.9140625 , ...,   3.125     ,
           1.80664062,  -5.22460938],
        ...,
        [  1.85546875,  -8.05664062, -10.83984375, ...,  -0.1953125 ,
           1.41601562,  -2.05078125],
        [  0.53710938,  -8.15429688, -10.9375    , ...,  -0.390625  ,
           0.73242188,  -3.75976562],
        [  1.51367188,  -9.1796875 , -10.64453125, ...,  -0.04882812,
           2.49023438,  -0.92773438]],

       [[ -0.29296875,   0.29296875,  -0.83007812, ...,  -9.08203125,
          -9.66796875,  -8.15429688],
        [  0.390625  ,   1.85546875,  -1.31835938, ...,  -4.54101562,
          -7.91015625,  -7.08007812],
        [ -2.34375   ,  -2.24609375,  -3.85742188, ...,  -9.22851562,
         -10.44921875,  -9.27734375],
        ...,


In [19]:
data.shape

(4896, 22, 1125)

In [20]:
label.shape

(4896,)

In [17]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [21]:
from torch.utils.data import Dataset, DataLoader  # Import Dataset
import torch.nn as nn
import torch.optim as optim
class BCI2aDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample_data = torch.from_numpy(self.data[idx]).float()  # Convert to float tensor
        sample_label = torch.tensor(self.labels[idx])  # Convert to tensor

        # Optionally, you can add data preprocessing or augmentation here

        return sample_data, sample_label

# Load your data
# # Assuming you have your data and labels loaded as NumPy arrays
# data = np.random.rand(4896, 22, 1125)  # Replace with your actual data
# labels = np.random.randint(0, 2, size=4896)  # Replace with your actual labels

# Create the dataset
dataset = BCI2aDataset(data, label)

# Create the dataloader
batch_size = 1
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define your loss function
criterion = nn.CrossEntropyLoss()

cnn_layers = [(22, 64, 3, 1, 1), (64, 128, 3, 1, 1)] 

# Initialize the DFformer model
model = DFformer(
    embed_dim=128,
    nhead=8,
    inter_information_length=22,
    origin_ival=(1, 22, 1125, 1),
#     cnn_layers=[(1, 64, 3, 1), (64, 128, 3, 1)],
    cnn_layers=cnn_layers,
    nlayer=4,
    num_classes=2,
    use_token=True,
    apply_cls_head=True,
    dropout_rate=0.1,
).to(device)
# Define your optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)  # You can adjust the learning rate as needed
num_epochs=10
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_data, batch_labels in dataloader:
        # Add batch dimension to batch_data
        batch_data = batch_data.unsqueeze(1).to(device) # (batch_size, 1, 22, 1125)

        # Forward pass through the DFformer model
        output = model(batch_data).to(device)

        # Compute loss
        loss = criterion(output, batch_labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    print(f"Epoch {epoch + 1} loss: {epoch_loss:.4f}")

print("Training finished.")

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 6.44 GiB is allocated by PyTorch, and 368.08 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)