In [1]:
import timm
import os
import h5py
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import Dataset,DataLoader
from timm.scheduler.cosine_lr import CosineLRScheduler
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_LRP import FGM
from baselines.ViT.ViT_explanation_generator import LRP
import warnings
from early_stopping import EarlyStopping
device=torch.device('cuda:2')
warnings.filterwarnings('ignore')

In [2]:
class MyDataset(Dataset):
    def __init__(self, vit_data,emotion,gambling,language):
        self.emotion=[os.path.join(emotion,str(item)+'.mat') for item in vit_data]
        self.gambling=[os.path.join(gambling,str(item)+'.mat') for item in vit_data]
        self.language=[os.path.join(language,str(item)+'.mat') for item in vit_data]
        self.data=self.emotion+self.gambling+self.language
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        image=self.data[idx]
        image=h5py.File(image,'r')
        image = np.array(image['fc_matrix'])
        image=torch.from_numpy(image).float()
        if self.data[idx] in self.emotion:
            label=torch.tensor(0)
        elif self.data[idx] in self.gambling:
            label=torch.tensor(1)
        elif self.data[idx] in self.language:
            label=torch.tensor(2)
        return image,label

In [3]:
for random_state_num in range(1998,2004):
    data = np.arange(1,1008)
    kf=KFold(n_splits=5,shuffle=True,random_state=random_state_num)
    percent_value = [0.2]
    for percent in range(1):
        print('start percent divide:{} {}'.format(random_state_num,percent_value[percent]))
        for k,(train_index, test_index) in list(enumerate(kf.split(data))):
            save_path = '/media/D/zephyr/vit_155_155/other_baseline/path_save/3class_rondom_{}_{}_{}_new.pth'.format(random_state_num,int(percent_value[percent]*100),k)
            early_stopping = EarlyStopping(save_path)
            emotion="/media/D/zephyr/functional_connectivity/HCP/s1200/percent_{}/EMOTION".format(percent_value[percent])
            gambling="/media/D/zephyr/functional_connectivity/HCP/s1200/percent_{}/GAMBLING".format(percent_value[percent])
            language="/media/D/zephyr/functional_connectivity/HCP/s1200/percent_{}/LANGUAGE".format(percent_value[percent])
            print('=' * 50)
            train_data=data[train_index]
            test_data=data[test_index]
            train_dataset=MyDataset(train_data, emotion,gambling,language)
            test_dataset=MyDataset(test_data, emotion,gambling,language)
            train_dataloader=DataLoader(train_dataset,batch_size =32, shuffle=True)
            test_dataloader=DataLoader(test_dataset,batch_size =1, shuffle=False)
            model = vit_LRP(pretrained=False,num_classes=3,in_chans=1)
            model.patch_embed=torch.nn.Linear(155,model.pos_embed.shape[2])
            model.pos_embed=torch.nn.Parameter(torch.zeros(1, 155+1, model.pos_embed.shape[2]))
            fgm = FGM(model,epsilon=0.001,emb_name='patch_embed')
            optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)
            lr_schedule=CosineLRScheduler(optimizer=optimizer,t_initial=10,lr_min=1e-5,warmup_t=5)
            loss_fn= torch.nn.CrossEntropyLoss()
            epochs=60
            loss_fn=loss_fn.to(device)
            model=model.to(device)
            for epoch in range(epochs):
                model.train()
                train_loss,test_acc,test_loss=.0,.0,.0
                for image,label in train_dataloader:
                    image=image.to(device)
                    label=label.to(device)
                    pred=model(image)
                    loss=loss_fn(pred,label)
                    loss.backward(retain_graph=True)
                    fgm.attack() # 在embedding上添加对抗扰动
                    loss_adv = loss_fn(pred, label)
                    loss_adv.backward() # 反向传播，并在正常的grad基础上，累加对抗训练的梯度
                    fgm.restore() # 恢复embedding参数
                    optimizer.step()
                    optimizer.zero_grad()
                    train_loss+=loss.item()
                lr_schedule.step(epoch)
                with torch.no_grad():
                    model.eval()
                    for image,label in test_dataloader:
                        image=image.to(device)
                        label=label.to(device)
                        pred=model(image)
                        loss=loss_fn(pred,label)
                        acc = (pred.argmax(dim=1) == label).float().mean()
                        test_acc += acc.item()
                        test_loss += loss.item()
                print('Epoch: {:2d}  Train Loss: {:.4f}  Test Loss: {:.4f}  Test Acc: {:.4f}'.format(epoch,train_loss/len(train_dataloader),test_loss/len(test_dataloader),test_acc/len(test_dataloader)))
                early_stopping(test_loss, model)
        #达到早停止条件时，early_stop会被置为True
                if early_stopping.early_stop:
                    print("Early stopping")
                    break #跳出迭代，结束训练     
            # torch.save(model.state_dict(), '/media/D/zephyr/vit_155_155/other_map/path_save/rondom_1998_{}_{}'.format(percent_value[percent],k))

start percent divide:1998 0.2
Epoch:  0  Train Loss: 1.1036  Test Loss: 1.0986  Test Acc: 0.3482
Epoch:  1  Train Loss: 1.1037  Test Loss: 1.0986  Test Acc: 0.3482
Epoch:  2  Train Loss: 1.1148  Test Loss: 0.8783  Test Acc: 0.5726
Epoch:  3  Train Loss: 0.6238  Test Loss: 0.3275  Test Acc: 0.8696
Epoch:  4  Train Loss: 0.2396  Test Loss: 0.2761  Test Acc: 0.9076
Epoch:  5  Train Loss: 0.1288  Test Loss: 0.1865  Test Acc: 0.9175
Epoch:  6  Train Loss: 0.0651  Test Loss: 0.1376  Test Acc: 0.9521
Epoch:  7  Train Loss: 0.0371  Test Loss: 0.1676  Test Acc: 0.9521
EarlyStopping counter: 1 out of 7
Epoch:  8  Train Loss: 0.0119  Test Loss: 0.2094  Test Acc: 0.9488
EarlyStopping counter: 2 out of 7
Epoch:  9  Train Loss: 0.0021  Test Loss: 0.1898  Test Acc: 0.9620
EarlyStopping counter: 3 out of 7
Epoch: 10  Train Loss: 0.0006  Test Loss: 0.1963  Test Acc: 0.9620
EarlyStopping counter: 4 out of 7
Epoch: 11  Train Loss: 0.0005  Test Loss: 0.2018  Test Acc: 0.9620
EarlyStopping counter: 5 out o