In [1]:
import timm
import os
import h5py
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import Dataset,DataLoader
from baselines.ViT.ViT_LRP_copy import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP
import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
import os  # NOTE(review): `os` is already imported earlier in this cell; this repeat is harmless.
# Run on CUDA device index 2 when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [2]:
class MyDataset(Dataset):
    """Two-class dataset of ``.mat`` (HDF5) files read from two directories.

    For every identifier in ``vit_data`` one file ``<id>.mat`` is expected in
    the ``emotion`` directory and one in the ``gambling`` directory. Samples
    are ordered as all emotion files followed by all gambling files.

    Args:
        vit_data: iterable of identifiers; each is converted with ``str()``
            to build the file name.
        emotion: directory holding the class-0 files.
        gambling: directory holding the class-1 files.

    Attributes:
        emotion: list of class-0 file paths.
        gambling: list of class-1 file paths.
        data: ``emotion + gambling``.
        labels: integer label per entry of ``data`` (0 for the first half,
            1 for the second half).
    """

    def __init__(self, vit_data,emotion,gambling):
        self.emotion=[os.path.join(emotion,str(item)+'.mat') for item in vit_data]
        self.gambling=[os.path.join(gambling,str(item)+'.mat') for item in vit_data]
        self.data=self.emotion+self.gambling
        # The label is determined by position in `data`, so it can be looked
        # up in O(1) and stays correct even if both directories coincide
        # (a path-membership test would label every sample 0 in that case).
        self.labels=[0]*len(self.emotion)+[1]*len(self.gambling)

    def __len__(self):
        """Return the total number of samples (both classes)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label)`` where ``image`` is the ``'fc_matrix'`` dataset
            of the file as a float32 tensor and ``label`` is a 0-dim integer
            tensor (0 = emotion, 1 = gambling).
        """
        # Close the HDF5 handle as soon as the matrix has been copied out;
        # np.array() materialises the data, so nothing refers to the file
        # after the `with` block.
        with h5py.File(self.data[idx],'r') as mat_file:
            image = np.array(mat_file['fc_matrix'])
        image=torch.from_numpy(image).float()
        label=torch.tensor(self.labels[idx])
        return image,label

In [3]:
def generate_visualization(original_image, class_index=None):
    """Compute LRP attributions for one sample and split them into edge/node parts.

    Relies on the module-level ``attribution_generator`` (must provide
    ``generate_LRP``) and ``device``; both have to be defined before calling.

    Args:
        original_image: tensor for a single sample; a batch dimension is
            added before it is passed to the attribution generator.
        class_index: forwarded to ``generate_LRP`` as ``index``.

    Returns:
        ``(edge, node)`` where, with the attribution viewed as a square
        ``(T, T)`` matrix:
          * ``edge`` is rows/columns ``1:`` of that matrix, shape
            ``(T-1, T-1)``, with its main diagonal set to zero;
          * ``node`` is row 0, columns ``1:``, shape ``(T-1,)``.
        Index 0 is presumably the class token — TODO confirm against the model.

    Raises:
        ValueError: if the attribution is empty or its element count is not a
            perfect square.
    """
    relevance = attribution_generator.generate_LRP(
        original_image.unsqueeze(0).to(device),
        method="transformer_attribution_conn",
        index=class_index,
    ).detach().cpu()

    # Infer the side length from the element count instead of hard-coding 156,
    # so the same code works for any token count.
    n_tokens = int(round(relevance.numel() ** 0.5))
    if n_tokens < 1 or n_tokens * n_tokens != relevance.numel():
        raise ValueError(
            "expected a non-empty square attribution map, got {} elements".format(relevance.numel())
        )
    relevance = relevance.reshape(n_tokens, n_tokens)

    transformer_attribution_edge = relevance[1:, 1:]
    transformer_attribution_node = relevance[0, 1:]

    # Remove self-attribution: subtract a matrix that holds only the diagonal.
    diag_only = torch.diag_embed(torch.diag(transformer_attribution_edge))
    transformer_attribution_edge = transformer_attribution_edge - diag_only
    return transformer_attribution_edge, transformer_attribution_node

In [5]:
# Build the 2-class ViT-LRP model. The patch embedding is replaced by a linear
# layer taking 155 input features, and the positional embedding is resized to
# 155 + 1 positions (the extra one is presumably the class token — TODO confirm).
model = vit_LRP(pretrained=False,num_classes=2,in_chans=1)
model.patch_embed=torch.nn.Linear(155,model.pos_embed.shape[2])
model.pos_embed=torch.nn.Parameter(torch.zeros(1, 155+1, model.pos_embed.shape[2]))

# Task names used in output file names; only the first two are used here
# because the loop below iterates over two classes.
CLS2IDX =  ['emotion','gambling','language','motor','relational','social','wm']

# Identifiers 1..1007, used both for the K-fold split and as file-name stems.
data = np.arange(1,1008)

for random_state_num in range(1998,2004):
    kf=KFold(n_splits=5,shuffle=True,random_state=random_state_num)
    percent_value = [0.15,0.2,0.25,0.3]
    # range(1,2) selects only percent_value[1] (0.2).
    for percent in range(1,2):
        print('start percent divide:{} {}'.format(random_state_num,percent_value[percent]))

        # These depend only on `percent`, so compute them once per setting
        # instead of once per fold.
        percent_tag = int(percent_value[percent]*100)
        emotion="/media/D/zephyr/functional_connectivity/HCP/s1200/percent_{}/EMOTION".format(percent_value[percent])
        gambling="/media/D/zephyr/functional_connectivity/HCP/s1200/percent_{}/GAMBLING".format(percent_value[percent])

        for k,(train_index, test_index) in enumerate(kf.split(data)):
            print('=' * 50)
            test_data=data[test_index]
            dataset = MyDataset(test_data,emotion,gambling)

            # Load the checkpoint saved for this (seed, percent, fold) onto the
            # CPU first, then move the model to the target device.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint_path = '/media/D/zephyr/vit_155_155/other_baseline/path_save/2class_rondom_{}_{}_{}_new.pth'.format(
                random_state_num,percent_tag,k)
            model.load_state_dict(torch.load(checkpoint_path,map_location=torch.device('cpu')))
            model=model.to(device)
            model.eval()

            # generate_visualization() reads this module-level name.
            attribution_generator = LRP(model)

            # Number of identifiers in this fold's test split (name kept for
            # compatibility; it counts test items, not training items).
            num_train = len(test_data)
            print(num_train)

            for cls in np.arange(2):
                # `dataset` holds all class-0 samples followed by all class-1
                # samples, so class `cls` occupies this contiguous index range.
                for ima_num in np.arange(num_train*cls,num_train*(cls+1)):
                    # Index the dataset once: every access re-reads the file from disk.
                    image,label = dataset[ima_num]
                    result_1,result_2 = generate_visualization(original_image = image,class_index=label.numpy())
                    subject = test_data[ima_num%num_train]
                    np.save('/media/D/zephyr/vit_155_155/other_baseline/vis/{}/{}_{}_edge_{}_2class_original'
                            .format(percent_tag,CLS2IDX[cls],subject,random_state_num),result_1)
                    np.save('/media/D/zephyr/vit_155_155/other_baseline/vis/{}/{}_{}_node_{}_2class_original'
                            .format(percent_tag,CLS2IDX[cls],subject,random_state_num),result_2)




    

start percent divide:1998 0.2
202
202
