# 加载数据集


In [1]:
from scipy.io import loadmat
import numpy as np

# Location of the MATLAB file that holds the fMRI arrays.
file_path = 'dataset/OCD_90_200_fMRI.mat'

# Class names; a name's position in this list is its integer class id.
labels = ['OCD', 'NC']
label_mapping = {name: class_id for class_id, name in enumerate(labels)}

# Read the .mat file into a dict keyed by MATLAB variable name.
data = loadmat(file_path)

## data格式
data={  
    'OCD':array(62, 90, 200),   
    'NC':array(20, 90, 200),  
}  
转换为  
data:(82, 90, 200)  
label:(82,)

数据尚未归一化。  
后续计划：对不同样本之间同一脑区、同一时间点的数据进行归一化。

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


# Per-group arrays taken from the loaded .mat dictionary.
ocd, nc = data['OCD'], data['NC']

# Float labels: 0.0 for every OCD sample, followed by 1.0 for every NC sample.
all_labels = np.concatenate([np.zeros(len(ocd)), np.ones(len(nc))])

# Stack the two groups along the first axis, OCD first, so rows line up with all_labels.
data_combined = np.vstack([ocd, nc])

# Notebook display of the combined array's shape.
data_combined.shape

## 分为训练和验证

In [3]:
from sklearn.model_selection import train_test_split
# Hold out 20% of the samples for validation; the fixed random_state keeps the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    data_combined,
    all_labels,
    test_size=0.2,
    random_state=0,
)

展示第一个样本的第一行数据（`data_combined[0][0]`，长度为200的序列）

In [None]:
# First row of the first sample in the combined array.
d = data_combined[0, 0]

# One dot per value; the x coordinate is the value's position within the row.
plt.scatter(range(len(d)), d)
plt.xlabel('Index')

# Mamba

## 导入基本库

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from einops import rearrange
from tqdm import tqdm

import math
import os
import urllib.request
from zipfile import ZipFile

from transformers import AutoTokenizer

# Turn on autograd anomaly detection globally (a debugging aid).
# NOTE(review): the PyTorch docs describe this mode as slowing execution down;
# consider disabling it once training runs cleanly.
torch.autograd.set_detect_anomaly(True)

## 设置标志和超参数

In [6]:
# Configuration flags and hyperparameters
# USE_MAMBA is not read by any code visible in this notebook.
USE_MAMBA = 1
# Read by S6.forward: when truthy it combines the input with the persistent
# self.h buffer and publishes the new state through the module-level
# temp_buffer; when falsy it starts from a fresh all-zero state on every call.
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 定义超参数和初始化  
这里定义模型的超参数，如模型维度(d_model)、状态大小(state_size)、序列长度(seq_len)和批大小(batch_size)。

In [7]:
d_model = 200  # feature width per sequence position (the data's last axis has size 200)
state_size = 128  # Example state size
seq_len = 90  # Example sequence length (the data's middle axis has size 90)
batch_size = 64  # Example batch size; S6.__init__ also uses it to size its pre-allocated tensors
last_batch_size = 81  # only for the very last batch of the dataset; not referenced by the visible code
current_batch_size = batch_size  # overwritten by S6.forward when the recurrent-update flag is on
different_batch_size = False  # not reassigned at module level by the visible code
h_new = None  # placeholder; not reassigned at module level by the visible code
temp_buffer = None  # receives a copy of the latest state from S6.forward in recurrent-update mode

## S6模块
S6模块是Mamba架构中的一个复杂组件，负责通过一系列线性变换和离散化过程处理输入序列。它在捕获序列的时间动态方面起着关键作用，这是序列建模任务(如语言建模)的一个关键方面。这里包括张量运算和自定义离散化方法来处理序列数据的复杂需求。

下面的S6类实现了离散化过程和前向传播。


In [8]:
class S6(nn.Module):
    """Selective state-space (S6) layer used inside each Mamba block.

    Three linear layers turn the input into the input-dependent tensors
    ``delta`` (through softplus), ``B`` and ``C``. These are discretized
    together with the learned matrix ``A`` and the output is the contraction
    of ``C`` with the resulting state ``h`` over the state axis.

    Args:
        seq_len: Sequence length used to size the state tensors.
        d_model: Size of the last input dimension.
        state_size: Size of the state axis ``n``.
        device: Device on which parameters and tensors are created.

    Note:
        The tensors pre-allocated in ``__init__`` are sized with the
        module-level ``batch_size``. ``dA``, ``dB`` and ``h`` each hold
        ``batch_size * seq_len * d_model * state_size`` elements, which can be
        a large amount of memory for wide models.
    """

    def __init__(self, seq_len, d_model, state_size, device):
        super().__init__()

        # fc1 -> delta, fc2 -> B, fc3 -> C (see forward).
        self.fc1 = nn.Linear(d_model, d_model, device=device)
        self.fc2 = nn.Linear(d_model, state_size, device=device)
        self.fc3 = nn.Linear(d_model, state_size, device=device)

        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size

        # Learned state matrix; the Xavier initialisation below overwrites the
        # normalised-ones starting value.
        self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
        nn.init.xavier_uniform_(self.A)

        # Placeholders that forward() replaces on every call.
        self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)

        self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
        self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)

        # h  [batch_size, seq_len, d_model, state_size]
        # Only read by forward() in recurrent-update mode.
        self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)

    def discretization(self):
        """Compute and store the discretized ``dA`` and ``dB``.

        Reads ``self.delta``, ``self.A`` and ``self.B``.

        Returns:
            Tuple ``(dA, dB)``, both shaped
            ``[batch, seq_len, d_model, state_size]``.
        """
        self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)

        self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))

        return self.dA, self.dB

    def forward(self, x):
        """Apply the layer to ``x``.

        Args:
            x: Input tensor of shape ``[batch, seq_len, d_model]``.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.

        Side effects:
            Overwrites ``self.B``, ``self.C``, ``self.delta``, ``self.dA`` and
            ``self.dB``. In recurrent-update mode it also overwrites
            ``self.y`` and the module-level ``current_batch_size`` and
            ``temp_buffer``.
        """
        global current_batch_size, temp_buffer

        # Algorithm 2  MAMBA paper
        self.B = self.fc2(x)
        self.C = self.fc3(x)
        self.delta = F.softplus(self.fc1(x))

        self.discretization()

        # x with a trailing singleton axis so it broadcasts over the state axis.
        x_expanded = rearrange(x, "b l d -> b l d 1")

        if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            current_batch_size = x.shape[0]

            # Slicing is a no-op when the batch matches the stored state and
            # trims the stored state for a smaller (e.g. final) batch.
            h_prev = self.h[:current_batch_size, ...]
            h_new = torch.einsum('bldn,bldn->bldn', self.dA, h_prev) + x_expanded * self.dB

            # y  [batch_size, seq_len, d_model]
            # Computed for every batch size; previously a mismatched batch
            # skipped this and forward() returned None.
            self.y = torch.einsum('bln,bldn->bld', self.C, h_new)

            # Publish the new state; detached unless self.h itself tracks gradients.
            temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()

            return self.y

        # Default mode: start from an all-zero state on every call.
        # h [batch_size, seq_len, d_model, state_size]
        h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
        h = torch.einsum('bldn,bldn->bldn', self.dA, h) + x_expanded * self.dB

        # y  [batch_size, seq_len, d_model]
        return torch.einsum('bln,bldn->bld', self.C, h)

## MambaBlock类
MambaBlock类是一个定制的神经网络模块，被设计为Mamba模型的关键构建块。它封装了几个层和操作来处理输入数据。

它包括线性投影、卷积、激活函数、自定义S6模块和残差连接。该块是Mamba模型的基本组件，负责通过一系列转换处理输入序列，以捕获数据中的相关模式和特征。这些不同层和操作的组合使MambaBlock能够有效地处理复杂的序列建模任务。

In [9]:
class MambaBlock(nn.Module):
    """Single Mamba block.

    Main path: RMSNorm -> input projection to ``2*d_model`` -> Conv1d -> SiLU
    -> linear -> S6 -> SiLU. A second path projects the normalised input to
    ``2*d_model`` and applies SiLU. The two paths are multiplied elementwise
    and projected back to ``d_model``.

    Args:
        seq_len: Sequence length; also the channel count of the Conv1d.
        d_model: Size of the last input/output dimension.
        state_size: State size handed to the S6 layer.
        device: Device on which the sub-layers are created.
    """

    def __init__(self, seq_len, d_model, state_size, device):
        super().__init__()

        wide = 2 * d_model

        self.inp_proj = nn.Linear(d_model, wide, device=device)
        self.out_proj = nn.Linear(wide, d_model, device=device)

        # Projection for the multiplicative skip path.
        self.D = nn.Linear(d_model, wide, device=device)

        # Tag the output bias with _no_weight_decay and start it at the constant 1.0.
        self.out_proj.bias._no_weight_decay = True
        nn.init.constant_(self.out_proj.bias, 1.0)

        self.S6 = S6(seq_len, wide, state_size, device)

        # Kernel-size-3 convolution with seq_len as both in- and out-channels,
        # so the kernel slides along the feature axis.
        self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)

        # Linear layer applied to the activated convolution output.
        self.conv_linear = nn.Linear(wide, wide, device=device)

        # Input normalisation.
        self.norm = RMSNorm(d_model, device=device)

    def forward(self, x):
        """Run the block.

        Intermediate tensors on the main path (projection, convolution and
        activation outputs) all have shape ``[batch_size, seq_len, 2*d_model]``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, d_model]``.
        """
        # Refer to Figure 3 in the MAMBA paper
        normed = self.norm(x)

        # Main path up to the S6 input.
        conv_activated = F.silu(self.conv(self.inp_proj(normed)))
        ssm_input = self.conv_linear(conv_activated)

        # S6 followed by SiLU (x * sigmoid(x)).
        ssm_activated = F.silu(self.S6(ssm_input))

        # Skip path; multiplying the paths introduces the nonlinearity.
        gate = F.silu(self.D(normed))

        return self.out_proj(ssm_activated * gate)

## Mamba模型

In [10]:
class Mamba(nn.Module):
    """Three identically configured MambaBlock layers applied one after another.

    Args:
        seq_len: Sequence length passed to every block.
        d_model: Feature width passed to every block.
        state_size: State size passed to every block.
        device: Device passed to every block.
    """

    def __init__(self, seq_len, d_model, state_size, device):
        super().__init__()
        block_args = (seq_len, d_model, state_size, device)
        self.mamba_block1 = MambaBlock(*block_args)
        self.mamba_block2 = MambaBlock(*block_args)
        self.mamba_block3 = MambaBlock(*block_args)

    def forward(self, x):
        """Feed ``x`` through block 1, then block 2, then block 3."""
        for block in (self.mamba_block1, self.mamba_block2, self.mamba_block3):
            x = block(x)
        return x

## 分类器

In [11]:
class MambaClassifier(nn.Module):
    """Mamba backbone followed by a linear classification head.

    Args:
        seq_len: Sequence length passed to the backbone.
        d_model: Feature width of the backbone output and head input.
        state_size: State size passed to the backbone.
        num_classes: Number of output logits.
        device: Device on which the backbone and the head are created.
    """

    def __init__(self, seq_len, d_model, state_size, num_classes, device):
        super().__init__()
        self.mamba = Mamba(seq_len, d_model, state_size, device)
        # Classification head. It is created on the same device as the
        # backbone so the model is usable without a follow-up .to(device).
        self.fc = nn.Linear(d_model, num_classes, device=device)

    def forward(self, x):
        """Return class logits of shape ``[batch, num_classes]``.

        Args:
            x: Tensor of shape ``[batch, seq_len, d_model]``.
        """
        x = self.mamba(x)
        x = x[:, -1, :]  # keep only the last position along the sequence axis
        x = self.fc(x)
        return x


## RMSNorm  
RMSNorm是一个自定义的归一化层，用于对神经网络的激活值进行归一化，有助于稳定并加快训练。

In [12]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the last axis is divided by the square root of
    (its mean square + ``eps``) and then multiplied elementwise by a learned
    weight that starts at ones.

    Args:
        d_model: Size of the last dimension, i.e. the length of ``weight``.
        eps: Constant added to the mean square before the reciprocal root.
        device: Device on which ``weight`` is created.
    """

    def __init__(self,
        d_model: int,
        eps: float = 1e-5,
        device: str ='cuda'):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model, device=device))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` normalised over its last axis and scaled by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

## 初始化模型

In [13]:
# Two output classes, matching the two label values 0 and 1.
num_classes = 2

# Build the classifier from the shared hyperparameters, then place it on the device.
model = MambaClassifier(
    seq_len=seq_len,
    d_model=d_model,
    state_size=state_size,
    num_classes=num_classes,
    device=device,
)
model = model.to(device)

## 加载数据集至DataLoader

In [14]:
# Features become float32 tensors laid out as [num_samples, seq_len, d_model].
X_train_tensor, X_val_tensor = (
    torch.tensor(split, dtype=torch.float32) for split in (X_train, X_val)
)

# Labels become int64 tensors of shape [num_samples].
y_train_tensor, y_val_tensor = (
    torch.tensor(split, dtype=torch.long) for split in (y_train, y_val)
)

# Pick the device name again (as a plain string); no tensor is moved here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from torch.utils.data import Dataset, DataLoader


class OCDDataset(Dataset):
    """Map-style dataset that pairs ``X[idx]`` with ``y[idx]``.

    Args:
        X: Indexable collection of samples; its length is the dataset length.
        y: Indexable collection of labels, indexed in parallel with ``X``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset size is taken from X alone.
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target


# Wrap each split's tensors in a dataset object.
train_dataset, val_dataset = (
    OCDDataset(X_train_tensor, y_train_tensor),
    OCDDataset(X_val_tensor, y_val_tensor),
)

# Training batches are reshuffled; validation batches keep their order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

## 训练和验证

In [15]:
# Cross-entropy loss on the class logits; AdamW updates all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=5e-4)

# One list per tracked metric; the training loop appends one value per epoch.
history = {metric: [] for metric in ('train_loss', 'val_loss', 'train_acc', 'val_acc')}

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # The per-batch labels are named `targets` so the module-level `labels`
    # list of class names is not overwritten.
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_count = targets.size(0)

        optimizer.zero_grad()  # clear gradients from the previous step

        # Forward pass
        outputs = model(inputs)

        # Loss, backward pass and parameter update
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The criterion returns the mean over the batch. Weight it by the
        # batch size so a short final batch does not count as much as a full
        # one in the epoch average.
        total_loss += loss.item() * batch_count

        # Accuracy bookkeeping
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == targets).sum().item()
        total_samples += batch_count

    # Per-sample training loss and accuracy for this epoch
    avg_train_loss = total_loss / total_samples
    train_accuracy = correct_predictions / total_samples
    history['train_loss'].append(avg_train_loss)
    history['train_acc'].append(train_accuracy)

    # ---- Validation phase ----
    model.eval()
    total_val_loss = 0.0
    correct_val_predictions = 0
    total_val_samples = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_count = targets.size(0)

            val_outputs = model(inputs)  # forward pass on the validation batch
            val_loss = criterion(val_outputs, targets)
            total_val_loss += val_loss.item() * batch_count

            # Validation accuracy bookkeeping
            _, predicted = torch.max(val_outputs, 1)
            correct_val_predictions += (predicted == targets).sum().item()
            total_val_samples += batch_count

    # Per-sample validation loss and accuracy for this epoch
    avg_val_loss = total_val_loss / total_val_samples
    val_accuracy = correct_val_predictions / total_val_samples
    history['val_loss'].append(avg_val_loss)
    history['val_acc'].append(val_accuracy)

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, '
          f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}')


## 绘图

In [None]:
import matplotlib.pyplot as plt
def _plot_history(train_key, val_key, train_label, val_label, y_label, title):
    """Draw one 4x4 figure with the train and validation curves of a metric."""
    plt.figure(figsize=(4, 4))
    plt.plot(history[train_key], label=train_label)
    plt.plot(history[val_key], label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.show()


# Loss curves first, then accuracy curves.
_plot_history('train_loss', 'val_loss', 'Train Loss', 'Validation Loss',
              'Loss', 'Training and Validation Loss over Epochs')
_plot_history('train_acc', 'val_acc', 'Train Accuracy', 'Validation Accuracy',
              'Accuracy', 'Training and Validation Accuracy over Epochs')

In [None]:
from torchsummary import summary

# Print a layer-by-layer summary for one input of shape (seq_len, d_model).
# The shared hyperparameters replace hard-coded numbers so the summary cannot
# drift from the shape the model was built for.
summary(model, input_size=(seq_len, d_model))