In [None]:
import os
import random
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset,TensorDataset,SubsetRandomSampler
from sklearn.metrics import classification_report
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support
import seaborn as sns
from sklearn.metrics import precision_recall_curve,f1_score,roc_curve,roc_auc_score,auc,accuracy_score,average_precision_score,precision_score,recall_score
from sklearn.preprocessing import label_binarize
from sklearn.utils.class_weight import compute_class_weight
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        # Use torch.cuda.manual_seed_all(seed) instead when training on multiple GPUs.
        torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

In [None]:
def make_img(t_img):
    """Load a pickled table of images and stack the first field of every row into one array.

    Args:
        t_img: path to a pickle readable by ``pd.read_pickle``.

    Returns:
        ``np.ndarray`` built from ``row[0]`` of each row of the loaded object's ``.values``.
        Presumably (n_samples, C, H, W) when each cell holds one image — TODO confirm.
    """
    # SECURITY: unpickling can execute arbitrary code; only load files this project produced.
    frame = pd.read_pickle(t_img)
    # Materialise ``.values`` once. Accessing it inside the loop (as before) may rebuild
    # the underlying array on every iteration.
    rows = frame.values
    return np.array([row[0] for row in rows])


In [None]:
def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size, epochs, figsize=(7, 7), ax=None,
                               class_names=None):
    """Draw per-class precision / recall / F1 / support as an annotated heatmap and save it as PNG.

    Args:
        y_tru: true class labels.
        y_prd: predicted class labels.
        mode, learning_rate, batch_size, epochs: only used to build the output file name
            ``report_{mode}_{learning_rate}_{batch_size}_{epochs}.png``.
        figsize: size of the figure created when ``ax`` is None.
        ax: optional existing Axes to draw on; the figure that owns it is the one saved.
        class_names: optional row labels, one per class (in sorted label order). Defaults to
            ``["Control", "Moderate", "Alzheimer's"]`` when exactly three classes are present,
            otherwise ``"class 0"``, ``"class 1"``, ...

    Returns:
        The Axes holding the heatmap.

    Raises:
        ValueError: if ``class_names`` does not have one entry per class in the report.
    """
    # Rows = classes, columns = precision, recall, f1-score, support.
    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = rep.mean(axis=0)          # unweighted mean of each metric ...
    avg[-1] = rep[:, -1].sum()      # ... except support, which is totalled
    rep = np.vstack([rep, avg])

    n_classes = rep.shape[0] - 1
    if class_names is None:
        legacy_names = ["Control", "Moderate", "Alzheimer's"]
        # The three legacy names only fit a 3-class report; the 4-class model needs generic names.
        class_names = legacy_names if n_classes == len(legacy_names) else [f"class {k}" for k in range(n_classes)]
    class_names = list(class_names)
    if len(class_names) != n_classes:
        raise ValueError(f"class_names has {len(class_names)} entries but the report has {n_classes} classes")

    # Previously a new figure was always created, so with an explicit ``ax`` the heatmap went
    # to ``ax`` while ``savefig`` wrote out the new, empty figure.
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    sns.heatmap(rep,
                annot=True,
                cbar=False,
                xticklabels=['precision', 'recall', 'f1-score', 'support'],
                yticklabels=class_names + ['avg'],
                ax=ax, cmap="Blues")

    fig.savefig('report_' + str(mode) + '_' + str(learning_rate) + '_' + str(batch_size) + '_' + str(epochs) + '.png')
    return ax

In [None]:
def calc_confusion_matrix(result, test_label, mode, learning_rate, batch_size, epochs, num_classes=4):
    """Print a classification report and compute one-vs-rest PR curves from hard predictions.

    Args:
        result: 1-D tensor or array-like of predicted class ids.
        test_label: 1-D tensor or array-like of true class ids, same length as ``result``.
        mode, learning_rate, batch_size, epochs: unused; kept for call-site compatibility.
        num_classes: number of classes for the per-class PR curves (default 4, as before).

    Returns:
        (cr, precision, recall, thres): ``cr`` is ``classification_report(..., output_dict=True)``;
        ``precision``, ``recall`` and ``thres`` are dicts keyed by class id holding the outputs
        of ``precision_recall_curve``. The curves are built from 0/1 one-hot predictions rather
        than scores, so they contain only a couple of operating points.

    Raises:
        ValueError: if any id lies outside ``[0, num_classes)``.
    """
    def _as_class_ids(values, what):
        # Accept torch tensors (any device) or array-likes; return a flat int64 NumPy array.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        ids = np.asarray(values).astype(np.int64).ravel()
        if ids.size and (ids.min() < 0 or ids.max() >= num_classes):
            raise ValueError(
                f"{what} must contain class ids in [0, {num_classes}); "
                f"got values in [{ids.min()}, {ids.max()}]")
        return ids

    # The old one_hot -> argmax round trip just returned the input ids, so use them directly.
    predicted_label = _as_class_ids(result, "result")
    true_label = _as_class_ids(test_label, "test_label")

    eye = np.eye(num_classes, dtype=np.int64)
    predicted_onehot = eye[predicted_label]
    true_onehot = eye[true_label]

    precision, recall, thres = {}, {}, {}
    for i in range(num_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(true_onehot[:, i], predicted_onehot[:, i])

    print("Classification Report :")
    print(classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres

In [None]:
class CNN(nn.Module):
    """Two-stage convolutional baseline ending in a 4-way softmax.

    The classifier input ``50 * 16 * 16`` matches a (3, 72, 72) input (conv1 takes 3 channels):
    72 -conv3-> 70 -pool-> 35 -conv3-> 33 -pool-> 16.

    NOTE(review): ``forward`` returns probabilities. If this model is trained with
    ``nn.CrossEntropyLoss`` (which applies log-softmax itself), softmax is applied twice.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the original order, so a given seed yields the same initial
        # weights and the state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=100, kernel_size=3, stride=1)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(p=0.5)
        self.conv2 = nn.Conv2d(in_channels=100, out_channels=50, kernel_size=3, stride=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(p=0.3)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(50 * 16 * 16, 4)

    def forward(self, x):
        stages = ((self.conv1, self.maxpool1, self.dropout1),
                  (self.conv2, self.maxpool2, self.dropout2))
        for conv, pool, drop in stages:
            x = drop(pool(self.relu(conv(x))))
        logits = self.fc(self.flatten(x))
        return torch.softmax(logits, dim=1)

In [None]:

class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Each channel is global-average-pooled. The pooled vector goes through a
    ``channel -> channel // reduction -> channel`` bias-free bottleneck with ReLU and a final
    sigmoid, and the input channels are rescaled by the resulting gates.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Unpacking four dims requires a 4-D (batch, channels, H, W) input.
        batch, channels, _, _ = x.size()
        gates = self.fc(self.avg_pool(x).view(batch, channels))
        # Broadcast one gate per (sample, channel) over the spatial dimensions.
        return x * gates.view(batch, channels, 1, 1).expand_as(x)


# Multi-scale convolution block
class MultiScaleConv(nn.Module):
    """Parallel 3x3 / 5x5 / 7x7 same-padding convolutions, concatenated and batch-normalised.

    The output has ``3 * out_channels`` channels and the input's spatial size (stride 1,
    padding = kernel // 2). The attribute names do not match the kernel sizes: ``conv1`` is 3x3,
    ``conv3`` is 5x5 and ``conv5`` is 7x7. They are kept as-is for checkpoint compatibility.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm2d(out_channels * 3)

    def forward(self, x):
        branches = [conv(x) for conv in (self.conv1, self.conv3, self.conv5)]
        return self.bn(torch.cat(branches, dim=1))
class new_AttentionMultiScaleCNN(nn.Module):
    """Three ``MultiScaleConv`` + ``SELayer`` stages followed by dropout and a linear classifier.

    Each stage is ReLU(MultiScaleConv) -> SELayer -> 2x2 max-pool. The spatial side is therefore
    floor-divided by 8 overall, and the last stage has 50 * 3 = 150 channels. ``forward``
    returns raw logits (no softmax).

    Args:
        num_classes: size of the output layer.
        input_size: side length of the square input image. The default 72 reproduces the
            original hard-coded ``150 * 9 * 9`` classifier input; other sizes previously
            crashed in ``fc``.

    Raises:
        ValueError: if ``input_size`` is too small to survive three 2x2 poolings.
    """

    def __init__(self, num_classes=10, input_size=72):
        super().__init__()
        # Submodules are created in the original order so seeding reproduces the same init.
        self.multi_scale_conv1 = MultiScaleConv(3, 64)
        self.se1 = SELayer(64 * 3)

        self.multi_scale_conv2 = MultiScaleConv(64 * 3, 100)
        self.se2 = SELayer(100 * 3)

        self.multi_scale_conv3 = MultiScaleConv(100 * 3, 50)
        self.se3 = SELayer(50 * 3)

        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.5)

        # MultiScaleConv keeps H and W; three floor-halving pools give ((s // 2) // 2) // 2 == s // 8.
        final_side = input_size // 8
        if final_side < 1:
            raise ValueError(f"input_size must be at least 8, got {input_size}")
        self.fc = nn.Linear(50 * 3 * final_side * final_side, num_classes)

    def forward(self, x):
        stages = ((self.multi_scale_conv1, self.se1),
                  (self.multi_scale_conv2, self.se2),
                  (self.multi_scale_conv3, self.se3))
        for conv, attention in stages:
            x = self.maxpool(attention(self.relu(conv(x))))
        return self.fc(self.dropout(self.flatten(x)))


In [None]:
class FC(nn.Module):
    """Linear layer optionally followed by an in-place ReLU and dropout.

    Args:
        in_size, out_size: input / output feature sizes of the linear layer.
        dropout_r: dropout probability; no dropout module is created unless it is > 0.
        use_relu: whether to apply ReLU after the linear layer.
    """

    def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
        super().__init__()
        self.dropout_r = dropout_r
        self.use_relu = use_relu

        self.linear = nn.Linear(in_size, out_size)
        if use_relu:
            self.relu = nn.ReLU(inplace=True)
        if dropout_r > 0:
            self.dropout = nn.Dropout(dropout_r)

    def forward(self, x):
        out = self.linear(x)
        out = self.relu(out) if self.use_relu else out
        return self.dropout(out) if self.dropout_r > 0 else out
class MLP(nn.Module):
    """Two-layer perceptron: an ``FC`` hidden layer (optional ReLU / dropout) then a plain linear output.

    Args:
        in_size, mid_size, out_size: input, hidden and output feature sizes.
        dropout_r: dropout probability forwarded to the hidden ``FC`` layer.
        use_relu: whether the hidden ``FC`` layer applies ReLU.
    """

    def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
        super().__init__()
        self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
        self.linear = nn.Linear(mid_size, out_size)

    def forward(self, x):
        hidden = self.fc(x)
        return self.linear(hidden)
class FFN(nn.Module):
    """Feed-forward block wrapping an ``MLP``.

    The layer sizes used to be hard-coded. They are now keyword arguments whose defaults
    (50 -> 100 -> 50, dropout 0.5, ReLU) reproduce the original configuration, so ``FFN()``
    behaves exactly as before.

    Args:
        in_size, mid_size, out_size: input, hidden and output feature sizes.
        dropout_r: dropout probability of the hidden layer.
        use_relu: whether the hidden layer applies ReLU.
    """

    def __init__(self, in_size=50, mid_size=100, out_size=50, dropout_r=0.5, use_relu=True):
        super().__init__()

        self.mlp = MLP(
            in_size=in_size,
            mid_size=mid_size,
            out_size=out_size,
            dropout_r=dropout_r,
            use_relu=use_relu,
        )

    def forward(self, x):
        return self.mlp(x)

In [None]:
def _load_ppmi_split(img_path, label_path):
    """Load one data split: images via ``make_img`` and integer labels from a CSV.

    The label CSV's ``"PATNO Visit"`` column is dropped and the remaining values are flattened,
    so the file presumably has a single label column besides it — TODO confirm.

    Returns:
        (float32 image tensor on ``device``, int64 label tensor on ``device``, label ndarray)
    """
    images = make_img(img_path)
    labels = (pd.read_csv(label_path)
              .drop("PATNO Visit", axis=1)
              .values.astype("int")
              .flatten())
    image_tensor = torch.tensor(images, dtype=torch.float).to(device)
    label_tensor = torch.tensor(labels, dtype=torch.long).to(device)
    return image_tensor, label_tensor, labels


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return (mean batch loss, macro-F1 of train predictions)."""
    model.train()
    running_loss = 0.0
    predicted_label, true_label = [], []
    for img, labels in loader:
        optimizer.zero_grad()
        outputs = model(img)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        predicted_label.extend(predicted.tolist())
        true_label.extend(labels.tolist())
    return running_loss / len(loader), f1_score(true_label, predicted_label, average='macro')


def _predict(model, loader):
    """Return the model's outputs for every batch of ``loader`` concatenated in loader order (eval mode, no grad)."""
    model.eval()
    with torch.no_grad():
        return torch.cat([model(img) for img, _ in loader])


def train(mode, batch_size, epochs, learning_rate, seed):
    """Train ``new_AttentionMultiScaleCNN`` on PPMI MRI images and evaluate it on the test split.

    Model selection uses validation macro-F1 (the logs call it "acc"). The best weights are saved
    to ``../../models/PPMI/MRI/best_{seed}.pth`` and restored before testing.

    Args:
        mode: label forwarded to ``calc_confusion_matrix``.
        batch_size: mini-batch size for all DataLoaders (previously ignored; 32 was hard-coded).
        epochs: number of training epochs.
        learning_rate: Adam learning rate.
        seed: only names the checkpoint and is echoed back. Seeding itself is done by the caller.

    Returns:
        (cr, batch_size, learning_rate, epochs, seed, auc_score): ``cr`` is the
        ``classification_report`` dict from ``calc_confusion_matrix``; ``auc_score`` is the macro
        one-vs-rest ROC AUC on the test split, computed from raw logits.
    """
    import copy  # local: deep-copies the best state_dict snapshot

    train_x, train_y, y_train = _load_ppmi_split(
        "../../processed_data/PPMI/MRI_nobrain/X_train_img.pkl",
        "../../processed_data/PPMI/MRI_nobrain/y_train.csv")
    val_x, val_y, y_val = _load_ppmi_split(
        "../../processed_data/PPMI/overlap/X_val_img.pkl",
        "../../processed_data/PPMI/overlap/y_val.csv")
    test_x, test_y, y_test = _load_ppmi_split(
        "../../processed_data/PPMI/overlap/X_test_img.pkl",
        "../../processed_data/PPMI/overlap/y_test.csv")
    print(tuple(train_x.shape), tuple(test_x.shape))

    train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_x, val_y), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=batch_size, shuffle=False)

    # 'balanced' weights for the classes present in the training labels, in sorted order.
    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    model = new_AttentionMultiScaleCNN(num_classes=4).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    ckpt_path = f'../../models/PPMI/MRI/best_{seed}.pth'
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    best_val_acc = 0
    best_state = None
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        # Evaluate on the validation set in eval mode and without gradients: dropout is off and
        # BatchNorm uses, rather than updates, its running statistics.
        _, val_pred = _predict(model, val_loader).max(1)
        val_acc = f1_score(val_y.tolist(), val_pred.tolist(), average='macro')
        print(f"Epoch {epoch+1}, train Loss: {train_loss},train acc:{train_acc},val acc:{val_acc},best acc:{best_val_acc}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot a copy. Keeping a reference to ``model`` (as before) made the "best" model
            # identical to the final-epoch model at test time.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, ckpt_path)
            print(f'Epoch {epoch+1} get best modal')

    if best_state is None:
        # Validation macro-F1 never rose above 0. Fall back to the final weights instead of
        # failing with NameError, and still write the checkpoint the evaluation step expects.
        print('No validation improvement over 0; using final-epoch weights')
        best_state = copy.deepcopy(model.state_dict())
        torch.save(best_state, ckpt_path)

    model.load_state_dict(best_state)
    outputs = _predict(model, test_loader)
    _, predicted = outputs.max(1)
    predicted_label = predicted.tolist()
    test_acc = f1_score(y_test, predicted_label, average='macro')  # macro-F1, printed as "Acc"

    # Compute AUC (one-vs-rest, macro) using the raw logits as per-class scores.
    num_classes = 4
    true_label_binarized = label_binarize(y_test, classes=list(range(num_classes)))
    auc_score = roc_auc_score(true_label_binarized, outputs.cpu().numpy(), average='macro', multi_class='ovr')

    print(f'test Acc: {test_acc:.4f},test auc: {auc_score:.4f}')
    print(predicted_label)
    cr, precision, recall, thresholds = calc_confusion_matrix(torch.tensor(predicted_label), test_y.cpu(), mode, learning_rate, batch_size, epochs)

    return cr, batch_size, learning_rate, epochs, seed, auc_score

In [None]:
# Multi-seed experiment: train one model per seed and report mean / std of the test metrics.
# The seeds are fixed so that the checkpoints written here (best_{seed}.pth) are exactly the ones
# the evaluation cell loads through modal_list. Previously random.sample(range(1, 200), 5) drew
# from the unseeded global RNG, so every kernel restart trained a different set of seeds.
seeds = [22, 45, 152, 174, 176]
MODE = 'new_AttentionMultiScaleCNN'
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 0.00001

# Per-seed test metrics (macro averages); names kept as-is in case other cells read them.
accurancy = []
precision = []
recall = []
f1 = []
auc_score_list = []
for s in seeds:
    set_random_seed(s)
    print('seeds:', s)
    cr, bs_, lr_, e_, seed, auc_score = train(MODE, BATCH_SIZE, EPOCHS, LEARNING_RATE, s)
    accurancy.append(cr['accuracy'])
    precision.append(cr["macro avg"]["precision"])
    recall.append(cr["macro avg"]["recall"])
    f1.append(cr["macro avg"]["f1-score"])
    auc_score_list.append(auc_score)
    print('-' * 55)

print("Mean accuracy is: ", sum(accurancy) / len(accurancy))
print("precision:", sum(precision) / len(precision))
print("recall:", sum(recall) / len(recall))
print("f1:", sum(f1) / len(f1))
print("auc_score:", sum(auc_score_list) / len(auc_score_list))
# np.std defaults to the population standard deviation (ddof=0).
print("Std accuracy: " + str(np.array(accurancy).std()))
print("Std precision: " + str(np.array(precision).std()))
print("Std recall: " + str(np.array(recall).std()))
print("Std f1: " + str(np.array(f1).std()))
print("Std auc_score: " + str(np.array(auc_score_list).std()))

In [None]:
# standard_scaler = StandardScaler()
test_img = make_img("../../processed_data/PPMI/overlap/X_test_img.pkl")
# test_img = make_img("../../processed_data/PPMI/overlap/X_test_img.pkl")
test_label = pd.read_csv("../../processed_data/PPMI/overlap/y_test.csv").drop("PATNO Visit", axis=1).values.astype("int").flatten()
test_img_tensor = torch.tensor(test_img, dtype=torch.float).to(device)
# test_img_tensor = torch.tensor(test_img, dtype=torch.float).to(device)
modal_list=[22,45,152,174,176]
print('f1_score                 acc                 precision           recall              auc             aupr')
accurancy=[]
precision=[]
recall=[]
f1=[]
auc_score_list=[]
aupr_score_list=[]
for i in modal_list:
    img_modal=new_AttentionMultiScaleCNN(num_classes=4).to(device)
    img_modal.load_state_dict(torch.load(f'../../models/PPMI/MRI/best_{i}.pth'))
    img_modal.eval()
    predicted_label_0=[]
    with torch.no_grad():
        outputs_0 = img_modal(test_img_tensor)
        # probs_0 = nn.functional.softmax(outputs_0, dim=1)
        # all_probs_0.extend(probs_0.cpu().detach().numpy())
        _, predicted_0 = outputs_0.max(1)
        predicted_label_0.extend(predicted_0.tolist())

    num_classes=4
    true_label_binarized = label_binarize(test_label, classes=list(range(num_classes)))
    test_acc_0=accuracy_score(test_label,predicted_label_0)
    test_f1_0 = f1_score(test_label,predicted_label_0,average='macro')
    test_precision_0=precision_score(test_label,predicted_label_0,average='macro')
    recall_0=recall_score(test_label,predicted_label_0,average='macro')
    auc_score_0 = roc_auc_score(true_label_binarized, outputs_0.cpu().detach().numpy(), average='macro', multi_class='ovr')
    aupr_score_0 = average_precision_score(true_label_binarized, outputs_0.cpu().detach().numpy(), average='macro')
    print(test_f1_0,test_acc_0,test_precision_0,recall_0,auc_score_0,aupr_score_0)
    accurancy.append(test_acc_0)
    precision.append(test_precision_0)
    recall.append(recall_0)
    f1.append(test_f1_0)
    auc_score_list.append(auc_score_0)
    aupr_score_list.append(aupr_score_0)
print('avg')
print(sum(f1)/len(modal_list),sum(accurancy)/len(modal_list),sum(precision)/len(modal_list),sum(recall)/len(modal_list),sum(auc_score_list)/len(modal_list),sum(aupr_score_list)/len(modal_list))