In [None]:
import timm
import os
import numpy as np
import random
import torch
import scipy.io as scio
from torch.utils.data import Dataset,DataLoader
from timm.scheduler.cosine_lr import CosineLRScheduler
import warnings
warnings.filterwarnings('ignore')

In [None]:
class MyDataset(Dataset):
    """Dataset over ``(mat_file_path, label_name)`` records.

    Fetching an item loads the ``'fc'`` entry of the record's ``.mat`` file
    and converts the label name into its class index through ``self.label``.

    Args:
        data: indexable collection of two-element records
            ``(path, label_name)``; ``label_name`` must be a key of
            ``self.label`` (``'T-T'``, ``'P-T'`` or ``'P-P'``).
        topk: kept on the instance as ``self.topk``; this class itself never
            reads it.
    """

    def __init__(self, data,topk=16):
        self.data=data
        self.topk=topk
        # Label name -> class index used as the classification target.
        self.label=dict([('T-T',2),('P-T',1),('P-P',0)])

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for record ``idx``.

        ``image`` is the file's ``'fc'`` array as a float32 tensor and
        ``label`` is a 0-dim tensor holding the class index.
        """
        mat_path,label_name=self.data[idx]
        fc_array=scio.loadmat(mat_path)['fc']
        return torch.from_numpy(fc_array).float(),torch.tensor(self.label[label_name])
        

In [None]:
# Read the sample index: every line of fc_list.txt is split on tabs into one
# record (MyDataset later unpacks each record as a path and a label name).
with open('fc_list.txt','r') as f:
    data=[line.replace('\n','').split('\t') for line in f]

# Seed the RNG first so the shuffled order, and therefore the fold split
# derived from it, is the same on every run.
random.seed(2022)
random.shuffle(data)

In [None]:
from sklearn.metrics import roc_auc_score,f1_score,accuracy_score,confusion_matrix

# k-fold cross-validation of a 3-class ViT classifier over the shuffled `data`
# records. Every fold trains a fresh model for `epochs` epochs, evaluates it on
# the held-out slice after each epoch, and keeps the metrics of the best epoch
# as the fold result.
#
# NOTE(review): the "best" epoch is picked with the test-fold metrics
# themselves, so the reported numbers are optimistic. Consider carving a
# validation split out of the training data for model selection.

NUM_CLASSES=3                # number of class indices in MyDataset.label

acc,f1,auc=0.,0.,0.          # running sums of the per-fold best metrics
k_fold=5
acc_list=[]                  # best test accuracy of every fold
# result_index_matrix[fold, metric, class]; the metric rows are
# 0 = precision, 1 = recall (sensitivity), 2 = specificity.
result_index_matrix=torch.zeros(k_fold,3,NUM_CLASSES)

epochs=80
# Fall back to the CPU so the cell still runs on machines without CUDA.
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fold size is loop-invariant. The trailing len(data) % k_fold records never
# land in a test slice; they are part of every training split.
k_fold_len = int(len(data)//k_fold)

# The final summary line is labelled with a round number. Only one round is
# run here; the previous code formatted the builtin `round` instead, which
# raised TypeError after all folds had finished training.
round_idx=0

for kk in range(k_fold):
    test_data = data[k_fold_len*kk:k_fold_len*(kk+1)]
    train_data = data[:k_fold_len*kk] + data[k_fold_len*(kk+1):]
    train_dataset=MyDataset(train_data,16)
    test_dataset=MyDataset(test_data,16)
    train_dataloader=DataLoader(train_dataset,batch_size=32,shuffle=True)
    test_dataloader=DataLoader(test_dataset,1)
    print('{:.0f}_fold = {:.0f}'.format(k_fold,kk+1))

    # Pretrained ViT whose image patch embedding is swapped for a linear
    # projection of 64-dim rows, with a fresh zero positional embedding for
    # 64 tokens + the class token.
    # assumes each 'fc' sample is a 64x64 matrix -- TODO confirm
    mymodel=timm.create_model('vit_small_patch16_224',pretrained=True,num_classes=NUM_CLASSES)
    mymodel.patch_embed=torch.nn.Linear(64,mymodel.pos_embed.shape[2])
    mymodel.pos_embed=torch.nn.Parameter(torch.zeros(1, 64+1, mymodel.pos_embed.shape[2]))

    optimizer=torch.optim.Adam(mymodel.parameters(),lr=1e-3)
    # NOTE(review): t_initial=10 is much shorter than the 80 training epochs;
    # confirm that this is the intended schedule length.
    lr_schedule=CosineLRScheduler(optimizer=optimizer,t_initial=10,lr_min=1e-5,warmup_t=5)
    loss_fn= torch.nn.CrossEntropyLoss()

    loss_fn=loss_fn.to(device)
    mymodel=mymodel.to(device)

    result_acc,result_f1,result_auc=0.,0.,0.
    acc_training_line,auc_training_line=[],[]
    index_matrix=torch.zeros(3,NUM_CLASSES)

    for epoch in range(epochs):
        mymodel.train()
        train_true,train_pred_prob,test_true,test_pred_prob=[],[],[],[]
        train_pred,test_pred=[],[]
        for image,label in train_dataloader:
            image=image.to(device)
            label=label.to(device)
            pred=mymodel(image)
            loss=loss_fn(pred,label)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mymodel.parameters(), 1e-1)
            optimizer.step()
            train_true.extend(label.tolist())
            train_pred_prob.extend(pred.tolist())
            train_pred.extend(pred.argmax(dim=1).tolist())
        # timm schedulers take the index of the *upcoming* epoch. Passing
        # `epoch` re-applied the value for the epoch that just ended, which
        # left the learning rate at the warm-up start for two epochs.
        lr_schedule.step(epoch+1)

        mymodel.eval()
        with torch.no_grad():
            for image,label in test_dataloader:
                image=image.to(device)
                label=label.to(device)
                pred=mymodel(image)
                test_pred.extend(pred.argmax(dim=1).tolist())
                test_true.extend(label.tolist())
                test_pred_prob.extend(pred.tolist())

        train_acc=accuracy_score(train_true,train_pred)
        train_f1=f1_score(train_true,train_pred,average='micro')
        test_acc=accuracy_score(test_true,test_pred)
        test_f1=f1_score(test_true,test_pred,average='micro')

        # roc_auc_score is fed one-hot targets and the raw model outputs.
        train_true=np.eye(NUM_CLASSES)[train_true]
        tmp_test_true=test_true
        test_true=np.eye(NUM_CLASSES)[test_true]
        train_auc=roc_auc_score(train_true,train_pred_prob)
        test_auc=roc_auc_score(test_true,test_pred_prob)
        acc_training_line.append(train_acc)
        auc_training_line.append(train_auc)

        # Pin the label set so the matrix is always NUM_CLASSES x NUM_CLASSES,
        # even when a class is missing from this fold's targets/predictions.
        # Rows are true classes, columns are predicted classes.
        c_matrix=confusion_matrix(tmp_test_true,test_pred,labels=list(range(NUM_CLASSES)))
        total=np.sum(c_matrix)
        for idm in range(NUM_CLASSES):
            true_pos=c_matrix[idm][idm]
            predicted_as=np.sum(c_matrix[:,idm])
            actually_is=np.sum(c_matrix[idm,:])
            index_matrix[0][idm]=true_pos/predicted_as   # precision
            index_matrix[1][idm]=true_pos/actually_is    # recall / sensitivity
            # specificity = TN / (TN + FP)
            index_matrix[2][idm]=(total-predicted_as-actually_is+true_pos)/(total-actually_is)

        # Keep this epoch if accuracy did not drop and F1+AUC improved.
        if test_acc >= result_acc and (test_f1+test_auc) > (result_f1+result_auc):
            result_acc,result_f1,result_auc=test_acc,test_f1,test_auc
            result_index_matrix[kk]=index_matrix   # slice assignment copies the values

    print('{:.0f}_fold_result: [Acc:{:.4f}  F1:{:.4f}  AUC:{:.4f}]'.format(kk+1,result_acc,result_f1,result_auc))
    acc=result_acc+acc
    f1=result_f1+f1
    auc=result_auc+auc
    acc_list.append(result_acc)
    print(result_index_matrix[kk])

print('Round_Result_{:.0f}: [mean:{:.4f}  std:{:.4f}]'.format(round_idx+1,np.mean(acc_list),np.std(acc_list)))

