In [1]:
import timm
import os
import h5py
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import Dataset,DataLoader
from baselines.ViT.ViT_LRP_copy import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP
import warnings
warnings.filterwarnings('ignore')
import os
# Compute device: prefer the second GPU (cuda:1) as before, but fall back to the
# first GPU when only one is visible (cuda:1 would fail with an invalid device
# ordinal on .to()), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [2]:
class MyDataset(Dataset):
    """Seven-class HCP task dataset of functional-connectivity matrices.

    Every subject id in ``vit_data`` contributes one sample per task directory,
    so the dataset holds ``7 * len(vit_data)`` samples ordered task-major: all
    emotion files first, then gambling, language, motor, relational, social and
    wm. A sample's label is the position of its task in that order (0-6).

    Args:
        vit_data: iterable of subject ids; id ``i`` maps to ``<task_dir>/<i>.mat``.
        emotion, gambling, language, motor, relational, social, wm: directories
            holding the per-subject ``.mat`` files for each task.
    """

    def __init__(self, vit_data,emotion,gambling,language,motor,relational,social,wm):
        task_dirs = (emotion, gambling, language, motor, relational, social, wm)
        # Materialise the ids once so a one-shot iterator is not exhausted by
        # the first task and left empty for the other six.
        subject_ids = list(vit_data)
        per_task = [[os.path.join(task_dir, str(item) + '.mat') for item in subject_ids]
                    for task_dir in task_dirs]
        # Per-task path lists stay available as attributes for callers that use them.
        (self.emotion, self.gambling, self.language, self.motor,
         self.relational, self.social, self.wm) = per_task
        self.data = [path for paths in per_task for path in paths]
        # Labels by position: O(1) per lookup in __getitem__ (instead of scanning
        # each task list) and still correct if two task directories coincide.
        self.labels = [label for label, paths in enumerate(per_task) for _ in paths]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(fc_matrix, label)`` for sample ``idx``.

        The matrix is the ``'fc_matrix'`` dataset of the HDF5-backed ``.mat``
        file, returned as a float32 tensor; the label is a 0-d integer tensor.
        """
        path = self.data[idx]
        # Context manager closes the HDF5 handle after every read, so iterating
        # the dataset does not accumulate open files.
        with h5py.File(path, 'r') as mat_file:
            fc_matrix = np.array(mat_file['fc_matrix'])
        image = torch.from_numpy(fc_matrix).float()
        label = torch.tensor(self.labels[idx])
        return image, label

In [3]:
# class MyDataset(Dataset):
#     def __init__(self, vit_data,rest,task):
#         self.rest=[os.path.join(rest,str(item)+'.mat') for item in vit_data]
#         self.task=[os.path.join(task,str(item)+'.mat') for item in vit_data]
#         self.data=self.rest+self.task
#     def __len__(self):
#         return len(self.data)
#     def __getitem__(self, idx):
#         image=self.data[idx]
#         image=h5py.File(image,'r')
#         image = np.array(image['fc_matrix'])
#         image=torch.from_numpy(image).float()
#         if self.data[idx] in self.rest:
#             label=torch.tensor(0)
#         else:
#             label=torch.tensor(1)
#         return image,label

In [3]:
def generate_visualization(original_image, class_index=None, generator=None, target_device=None):
    """Compute LRP relevance maps for a single connectivity matrix.

    Calls ``generate_LRP`` with ``method="transformer_attribution_conn"`` on
    ``original_image`` (a batch dimension of 1 is added here), then reshapes the
    relevance into a square token-by-token map.

    Args:
        original_image: one input sample, without batch dimension.
        class_index: target class for the explanation, forwarded as ``index``.
        generator: object exposing ``generate_LRP``; defaults to the notebook
            global ``attribution_generator`` so existing calls keep working.
        target_device: device the input is moved to; defaults to the notebook
            global ``device``.

    Returns:
        ``(edge, node)`` CPU tensors, unnormalised. ``edge`` is the map with its
        first row and column dropped and its diagonal zeroed. ``node`` is row 0
        without its first entry (presumably the CLS-token row — confirm against
        the LRP implementation).

    Raises:
        ValueError: if the relevance does not have a square number of elements.
    """
    if generator is None:
        generator = attribution_generator
    if target_device is None:
        target_device = device
    relevance = generator.generate_LRP(
        original_image.unsqueeze(0).to(target_device),
        method="transformer_attribution_conn",
        index=class_index,
    ).detach().cpu()
    # Derive the side length from the data instead of hard-coding 220
    # (219 ROI tokens + 1 extra token in the current model setup).
    total = relevance.numel()
    side = int(round(total ** 0.5))
    if side * side != total:
        raise ValueError(
            "relevance has {} elements, which is not a square token map".format(total))
    relevance_map = relevance.reshape(side, side)
    edge = relevance_map[1:, 1:].clone()
    # Self-connections carry no edge information; zero them in place.
    edge.fill_diagonal_(0)
    node = relevance_map[0, 1:]
    return edge, node

In [5]:
# Build the ViT once; only its weights change between folds. The image patch
# embedding is replaced by a linear projection of each of the NUM_ROIS rows of
# the connectivity matrix, and the positional embedding is resized to
# NUM_ROIS tokens + 1 extra token.
NUM_ROIS = 219
CLS2IDX =  ['emotion','gambling','language','motor','relational','social','wm']
TASK_DIRS = ['EMOTION', 'GAMBLING', 'LANGUAGE', 'MOTOR', 'RELATIONAL', 'SOCIAL', 'WM']
FC_ROOT = "/media/D/zephyr/functional_connectivity/HCP/Schaefer/Schaefer_7net"
CKPT_TEMPLATE = '/media/D/zephyr/vit_155_155/other_map/aparc_path/Schaefer_rondom_{}_{}_{}_new.pth'
VIS_ROOT = '/media/D/zephyr/vit_155_155/other_map/aparc_vis'

model = vit_LRP(pretrained=False,num_classes=len(CLS2IDX),in_chans=1)
model.patch_embed=torch.nn.Linear(NUM_ROIS,model.pos_embed.shape[2])
model.pos_embed=torch.nn.Parameter(torch.zeros(1, NUM_ROIS+1, model.pos_embed.shape[2]))

# Loop-invariant inputs: subject ids 1..1007 and the threshold percentages.
data = np.arange(1,1008)
percent_value = [15,20,25,30]
for random_state_num in range(1998,2004):
    kf=KFold(n_splits=5,shuffle=True,random_state=random_state_num)
    for percent, pct in enumerate(percent_value):
        print('start percent divide:{} {}'.format(random_state_num,pct))
        # Task directories depend only on the threshold, not on the fold.
        emotion, gambling, language, motor, relational, social, wm = (
            os.path.join(FC_ROOT, str(pct), task) for task in TASK_DIRS)
        for k,(train_index, test_index) in enumerate(kf.split(data)):
            print('=' * 50)
            test_data=data[test_index]
            # Number of held-out subjects (name kept for later cells).
            num_train = len(test_data)
            dataset=MyDataset(test_data,emotion,gambling,language,motor,relational,social,wm)
            model.load_state_dict(torch.load(CKPT_TEMPLATE.format(random_state_num,pct,k),
                                             map_location=torch.device('cpu')))
            model=model.to(device)
            model.eval()
            # Module-level name: generate_visualization reads this global by default.
            attribution_generator = LRP(model)
            print(num_train)
            for cls, cls_name in enumerate(CLS2IDX):
                out_dir = os.path.join(VIS_ROOT, str(pct), cls_name)
                # np.save does not create missing parent directories.
                os.makedirs(out_dir, exist_ok=True)
                # MyDataset is task-major: class `cls` occupies indices
                # num_train*cls .. num_train*(cls+1)-1, in test_data order.
                for offset, subject in enumerate(test_data):
                    ima_num = num_train*cls + offset
                    # Index the dataset once: every access re-reads the .mat file.
                    image, label = dataset[ima_num]
                    result_1,result_2 = generate_visualization(original_image=image,class_index=label.numpy())
                    np.save(os.path.join(out_dir, 'Schaefer_{}_edge_{}_7class_original'
                                         .format(subject,random_state_num)), result_1.numpy())
                    np.save(os.path.join(out_dir, 'Schaefer_{}_node_{}_7class_original'
                                         .format(subject,random_state_num)), result_2.numpy())




    

start percent divide:1998 15
202
202
201
201
201
start percent divide:1998 20
202
202
201
201
201
start percent divide:1998 25
202
202
201
201
201
start percent divide:1998 30
202
202
201
201
201
start percent divide:1999 15
202
202
201
201
201
start percent divide:1999 20
202
202
201
201
201
start percent divide:1999 25
202
202
201
201
201
start percent divide:1999 30
202
202
201
201
201
start percent divide:2000 15
202
202
201
201
201
start percent divide:2000 20
202
202
201
201
201
start percent divide:2000 25
202
202
201
201
201
start percent divide:2000 30
202
202
201
201
201
start percent divide:2001 15
202
202
201
201
201
start percent divide:2001 20
202
202
201
201
201
start percent divide:2001 25
202
202
201
201
201
start percent divide:2001 30
202
202
201
201
201
start percent divide:2002 15
202
202
201
201
201
start percent divide:2002 20
202
202
201
201
201
start percent divide:2002 25
202
202
201
201
201
start percent divide:2002 30
202
202
201
201
201
start percent divide