In [None]:
import timm
import torch
from timm.models.vision_transformer import VisionTransformer
from torch.utils.data import Dataset,DataLoader
import random
import scipy.io as scio
from timm.scheduler.cosine_lr import CosineLRScheduler
import warnings
import numpy as np
warnings.filterwarnings('ignore')


In [None]:
class MyDataset(Dataset):
    """Dataset of paired FC / SC matrices stored as MATLAB ``.mat`` files.

    Each entry of ``data`` is a ``(fc_path, label_name)`` pair. The FC file is
    read from ``fc_path`` (variable ``'fc'``); the SC file path is derived by
    replacing every ``'fc'`` substring in ``fc_path`` with ``'sc'`` (variable
    ``'sc'``).

    Args:
        data: sequence of ``(fc_path, label_name)`` pairs; ``label_name`` must
            be one of ``'T-T'``, ``'P-T'``, ``'P-P'``.
        topk: number of SC rows (those with the smallest row sums) flagged
            ``True`` in the returned mask.

    Returns (from ``__getitem__``):
        image: float32 tensor holding the FC matrix.
        mask:  bool tensor with one entry per SC row, ``True`` for the
               ``topk`` rows with the smallest row sum.
        label: integer class id tensor (0, 1 or 2).
    """

    def __init__(self,data,topk=10):
        self.data=data
        self.topk=topk
        # class-name -> class-id mapping
        self.label={
            'T-T':2,
            'P-T':1,
            'P-P':0
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path_1,label=self.data[idx]
        image=scio.loadmat(path_1)['fc'] #load FC
        image=torch.from_numpy(image).float()

        # NOTE(review): replaces *every* 'fc' in the path (directories too);
        # confirm this matches the on-disk layout.
        path_2=path_1.replace('fc','sc')
        sc=scio.loadmat(path_2)['sc']  #load SC
        # *1.0 promotes integer/bool matrices to float before summing
        row_strength=torch.from_numpy(sc*1.0).sum(dim=1)

        # Fix: honour the configured ``topk`` (it used to be hard-coded to 10,
        # silently ignoring the constructor argument).
        _,index=torch.topk(row_strength,self.topk,largest=False)

        # Fix: size the mask from the SC matrix instead of a hard-coded 64 and
        # build it with a vectorized scatter rather than a Python loop.
        mask=torch.zeros(row_strength.shape[0],dtype=torch.bool)
        mask[index]=True

        label=torch.tensor(self.label[label])
        return image,mask,label

In [None]:
# Load the subject list: one tab-separated record per line. MyDataset unpacks
# each record as (fc_path, label_name), so every kept line must have exactly
# two fields.
with open('your subject list.txt','r') as f:  #load subject list
    # Skip blank lines (e.g. a trailing empty line): they would otherwise
    # become [''] records and fail when unpacked in MyDataset.__getitem__.
    data=[line.rstrip('\n').split('\t') for line in f if line.strip()]
# Fixed seed so the shuffled order (and hence the fold assignment) is reproducible.
random.seed(2022)
random.shuffle(data)

In [None]:
class MaskViT(VisionTransformer):
    """ViT (patch 16, dim 384, depth 12, 6 heads, 3 classes) with a masked early stage.

    The first ``layer`` transformer blocks only see the class token plus the
    patch tokens selected by a boolean mask. Their outputs are then added back
    onto the full token sequence, which is processed by the remaining blocks.

    Args:
        layer: number of leading blocks that run on the masked token subset.
            Defaults to 11 when ``None``.
    """

    def __init__(self,layer=None):
        super().__init__(patch_size=16, embed_dim=384, depth=12, num_heads=6,num_classes=3)
        self.layer = 11 if layer is None else layer

    def forward_with_mask(self,x,mask):
        """Forward pass where the first ``self.layer`` blocks see only masked tokens.

        Args:
            x: input accepted by ``self.patch_embed``; after embedding it must
               have shape (B, N, C).
            mask: boolean-like tensor reshapeable to (B, N). ``True`` marks the
               tokens fed to the early blocks. Every sample must select the
               same number of tokens.

        Returns:
            Classification logits ``self.head(cls_token)`` of shape (B, num_classes).

        Raises:
            NotImplementedError: if the model carries a distillation token.
            ValueError: if samples select different numbers of tokens.
        """
        # Fix: the original only defined ``cls_token_nums`` when no distillation
        # token existed (NameError otherwise) and accessed ``self.dist_token``
        # directly, which fails on timm versions without that attribute.
        if getattr(self,'dist_token',None) is not None:
            raise NotImplementedError('forward_with_mask does not support a distillation token')
        cls_token_nums=1

        x = self.patch_embed(x)
        # (B,N,C)
        B,num_patches=x.shape[0],x.shape[1]

        mask = mask.reshape(B,num_patches).to(torch.bool)
        per_sample = mask.sum(dim=1)
        # Tokens are regrouped with a plain reshape below, which is only valid
        # when every sample contributes the same number of masked tokens.
        if not bool((per_sample == per_sample[0]).all()):
            raise ValueError('mask must select the same number of tokens for every sample')
        mask_size = int(per_sample[0])

        cls_token = self.cls_token.expand(B, -1, -1)  # cls_tokens impl from Phil Wang, thanks
        # (B,1,C)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        patch_tokens = x[:,cls_token_nums:,:]

        # Gather the masked patch tokens: (B*mask_size*C,) -> (B,mask_size,C)
        res_x = torch.masked_select(patch_tokens,mask.unsqueeze(-1))
        res_x = res_x.reshape(B,mask_size,-1)
        # Prepend the class token (the raw one, without positional embedding,
        # exactly as in the original code): (B,mask_size+1,C)
        res_x = torch.cat((cls_token,res_x),dim=1)
        for i in range(self.layer):
            res_x=self.blocks[i](res_x)

        # Flat row index (into the (B*N, C) view) of every masked token.
        # Fix: torch.range is deprecated; torch.arange(n) yields the same 0..n-1.
        index=torch.arange(B*num_patches,dtype=torch.int64,device=mask.device)
        index=torch.masked_select(index.reshape(B,-1),mask)

        cls_token = res_x[:,:cls_token_nums,:]
        res_x = res_x[:,cls_token_nums:,:].reshape(B*mask_size,-1)

        # Add the early-block outputs back onto their original positions.
        # Out-of-place index_add avoids mutating a view of the activation
        # tensor (the reshape can alias ``x`` when B == 1).
        x = patch_tokens.reshape(B*num_patches,-1)
        x = x.index_add(0,index,res_x)
        x = x.reshape(B,num_patches,-1)
        x = torch.cat((cls_token,x),dim=1)

        for i in range(self.layer,len(self.blocks)):
            x = self.blocks[i](x)
        x = self.norm(x)
        return self.head(x[:,0])


In [None]:
from sklearn.metrics import roc_auc_score,f1_score,accuracy_score
from sklearn.metrics import confusion_matrix
import numpy as np

# k-fold cross-validation of MaskViT on the shuffled subject list.
# For every fold a fresh model is initialised from ImageNet-pretrained
# ViT-Small weights, trained, and evaluated after each epoch; the best epoch's
# metrics are kept per fold and summarised at the end.
acc,f1,auc=0.,0.,0.
k_fold=5
num_classes=3   # must match the num_classes hard-coded in MaskViT / create_model below
epochs=100
acc_list=[]
# Per fold: rows = (precision, recall/sensitivity, specificity), columns = class id.
result_index_matrix=torch.zeros(k_fold,3,num_classes)

# Prefer the originally targeted GPU, but fall back instead of crashing on
# machines with fewer devices.
if torch.cuda.device_count()>1:
    device=torch.device('cuda:1')
elif torch.cuda.is_available():
    device=torch.device('cuda:0')
else:
    device=torch.device('cpu')

for kk in range(k_fold):
    # Samples beyond k_fold*k_fold_len never appear in a test fold (they stay in train).
    k_fold_len = len(data)//k_fold
    test_data = data[k_fold_len*kk:k_fold_len*(kk+1)]
    train_data = data[:k_fold_len*kk] + data[k_fold_len*(kk+1):]
    train_dataset=MyDataset(train_data,16)
    test_dataset=MyDataset(test_data,16)
    train_dataloader=DataLoader(train_dataset,batch_size=32,shuffle=True)
    test_dataloader=DataLoader(test_dataset,1)

    # Start from pretrained ViT-Small weights, then swap the patch embedding and
    # positional embedding for freshly initialised ones sized for 64 tokens of 64 features.
    mymodel=MaskViT(6)
    ViTmodel=timm.create_model('vit_small_patch16_224',pretrained=True,num_classes=3)
    pretrain_weight=ViTmodel.state_dict()
    mymodel.load_state_dict(pretrain_weight)
    mymodel.patch_embed=torch.nn.Linear(64,384)
    mymodel.pos_embed=torch.nn.Parameter(torch.zeros(1, 64+1, 384))

    optimizer=torch.optim.Adam(mymodel.parameters(),lr=5e-4)
    # NOTE(review): t_initial=10 with 100 epochs means the LR sits at lr_min
    # after the first cycle — confirm this is intended.
    lr_schedule=CosineLRScheduler(optimizer=optimizer,t_initial=10,lr_min=5e-6,warmup_t=5)
    loss_fn= torch.nn.CrossEntropyLoss()

    loss_fn=loss_fn.to(device)
    mymodel=mymodel.to(device)

    result_acc,result_f1,result_auc=0.,0.,0.
    acc_training_line,auc_training_line=[],[]
    index_matrix=torch.zeros(3,num_classes)
    one_hot=np.eye(num_classes)
    class_ids=list(range(num_classes))

    for epoch in range(epochs):
        mymodel.train()
        train_true,train_pred_prob,train_pred=[],[],[]
        test_true,test_pred_prob,test_pred=[],[],[]
        for image,mask,label in train_dataloader:
            image=image.to(device)
            mask=mask.to(device)
            label=label.to(device)
            pred=mymodel.forward_with_mask(image,mask)
            loss=loss_fn(pred,label)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mymodel.parameters(), 1e-1)
            optimizer.step()
            train_true.extend(label.tolist())
            train_pred_prob.extend(pred.tolist())
            train_pred.extend(pred.argmax(dim=1).tolist())
        # Fix: called at the end of an epoch, the scheduler must be given the
        # index of the *next* epoch; step(epoch) made the schedule lag by one
        # (two epochs at the warmup start LR).
        lr_schedule.step(epoch+1)

        mymodel.eval()
        with torch.no_grad():
            for image,mask,label in test_dataloader:
                image=image.to(device)
                label=label.to(device)
                mask=mask.to(device)
                pred=mymodel.forward_with_mask(image,mask)
                test_pred.extend(pred.argmax(dim=1).tolist())
                test_true.extend(label.tolist())
                test_pred_prob.extend(pred.tolist())

        train_acc=accuracy_score(train_true,train_pred)
        train_f1=f1_score(train_true,train_pred,average='micro')
        test_acc=accuracy_score(test_true,test_pred)
        test_f1=f1_score(test_true,test_pred,average='micro')
        # AUC is computed on one-hot targets against the raw logits.
        train_auc=roc_auc_score(one_hot[train_true],train_pred_prob)
        test_auc=roc_auc_score(one_hot[test_true],test_pred_prob)

        # Fix: pin the label set so the matrix is always num_classes x num_classes,
        # even if a class is absent from this fold's targets and predictions.
        c_matrix=confusion_matrix(test_true,test_pred,labels=class_ids)
        total=np.sum(c_matrix)
        for idm in class_ids:
            tp=c_matrix[idm][idm]
            predicted_as=np.sum(c_matrix[:,idm])   # column sum: samples predicted as idm
            actually_is=np.sum(c_matrix[idm,:])    # row sum: samples whose true class is idm
            index_matrix[0][idm]=tp/predicted_as #precision
            index_matrix[1][idm]=tp/actually_is #sensitivity, recall
            index_matrix[2][idm]=(total-predicted_as-actually_is+tp)/(total-actually_is) #specificity
        acc_training_line.append(train_acc)
        auc_training_line.append(train_auc)
        # NOTE(review): the best epoch is selected on the test fold, so the
        # reported numbers are optimistic — consider a separate validation split.
        if test_acc >= result_acc and (test_f1+test_auc) > (result_f1+result_auc):
            result_acc,result_f1,result_auc=test_acc,test_f1,test_auc
            result_index_matrix[kk]=index_matrix
    print('{:.0f}_fold_result: [Acc:{:.4f}  F1:{:.4f}  AUC:{:.4f}]'.format(kk+1,result_acc,result_f1,result_auc))
    # Fix: show the matrix of the selected epoch, not whatever the last epoch produced.
    print(result_index_matrix[kk])
    acc_list.append(result_acc)

print('Result: [mean:{:.4f}  std:{:.4f}]'.format(np.mean(acc_list),np.std(acc_list)))
print(result_index_matrix.sum(dim=0)/k_fold)
