In [1]:
import torch, shap, glob, os
import numpy as np, pandas as pd, torchvision.transforms as transforms
import random

from Bio import SeqIO
from tqdm.notebook import tqdm
from torch.nn import Module
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from torchvision.models import  resnet18, alexnet

# Seed every RNG the notebook touches (NumPy, stdlib random, PyTorch) with one shared value.
SEED = 2020
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7febcc074130>

In [2]:
# Checkpoint of the model trained on the [NACGT] encoding (run of 2022.03.30).
save_weight_path = './models/weights_Multiclass_Covid19(Non-kmer3)_IndexRemark.2022.03.30[NATCG]/'
weights_name = "weights_Multiclass_Covid19(Non-kmer3)[NACGT].2022.03.30.pt"

# Alternative run (2022.03.24, [NACGTRYKMSWBDHV] encoding) -- swap in to analyse that model:
# save_weight_path = './models/weights_Multiclass_Covid19(Non-kmer3)_IndexRemark.2022.03.24[NACGTRYKMSWBDHV]/'
# weights_name = "weights_Multiclass_Covid19(Non-kmer3)[NACGTRYKMSWBDHV].2022.03.24.pt"

# Full path handed to torch.load further down.
path2weights = os.path.join(save_weight_path, weights_name)

In [3]:
# Root folder of the pre-rendered sequence images and their labels. The next cell reads an
# 'image_npy/' directory and a 'label.npy' file from it.
# npy_path = './np_image_totalunit/multiclass_totalunit/'
npy_path = './np_image_totalunit/multiclass_nactg/'

In [4]:
# Paths of all per-sample image arrays in file-name order, plus the label array stored beside them.
image_dir = os.path.join(npy_path, 'image_npy')
npy_data_list = [os.path.join(image_dir, file_name) for file_name in sorted(os.listdir(image_dir))]
label_ = np.load(os.path.join(npy_path, 'label.npy'))

In [5]:
# Stratified 75/25 split of image paths and labels; random_state fixes the partition.
X_train, X_test, y_train, y_test = train_test_split(
    npy_data_list,
    label_,
    test_size=0.25,
    stratify=label_,
    random_state=42,
)

# Paths and labels of each split must have matching lengths.
for split_paths, split_labels in ((X_train, y_train), (X_test, y_test)):
    print(len(split_paths), len(split_labels))

1053 1053
351 351


In [6]:
class TransferDataset(Dataset):
    """Dataset that loads one ``.npy`` image per sample on demand.

    Args:
        s_path: sequence of paths to ``.npy`` files, one per sample.
        labels: sequence of class labels aligned with ``s_path``.
        transform: callable applied to the loaded float32 array.
    """

    def __init__(self, s_path, labels, transform):
        self.s_path = s_path
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.s_path)

    def __getitem__(self, idx):
        """Return ``(transformed image, int label)`` for sample ``idx``."""
        array = np.load(self.s_path[idx]).astype(np.float32)

        # Draw a fresh seed and push it into both global RNGs before the transform runs.
        shared_seed = np.random.randint(1e9)
        random.seed(shared_seed)
        np.random.seed(shared_seed)

        return self.transform(array), int(self.labels[idx])

In [7]:
# Only tensor conversion is applied (normalisation stays disabled, as before:
# transforms.Normalize(mean, std) is intentionally not part of the pipeline).
transformer = transforms.Compose([transforms.ToTensor()])

# Same dataset class and transform for both splits.
train_ds, test_ds = (
    TransferDataset(s_path=split_paths, labels=split_labels, transform=transformer)
    for split_paths, split_labels in ((X_train, y_train), (X_test, y_test))
)
print(len(train_ds), len(test_ds))

1053 351


In [8]:
# Rebuild the AlexNet classifier and restore the trained weights for evaluation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# One output unit per class; labels are integer ids, so the count is max id + 1.
models = alexnet(pretrained=False, num_classes=max(label_)+1)

# Load the checkpoint onto the CPU first, then move the restored model to `device`.
checkpoint = torch.load(path2weights, map_location=torch.device('cpu'))
# checkpoint = torch.load('./models/weights_Multiclass_Covid19(Non-kmer3)_IndexRemark.2022.03.24[NACGTRYKMSWBDHV]/weights_Multiclass_Covid19(Non-kmer3)[NACGTRYKMSWBDHV].2022.03.24.pt', map_location=torch.device('cpu'))
models.load_state_dict(checkpoint['model_state_dict'])
models.to(device)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [9]:
# Sanity check: number of samples in the train and test datasets.
print(len(train_ds), len(test_ds))

1053 351


In [10]:
# Each loader returns its whole split as a single batch, in dataset order (no shuffling).
# The batch size is taken from the dataset instead of being hard-coded (was 1053 / 351),
# so the "one batch == full split" assumption used below survives a change of data or split ratio.
train_dl = DataLoader(train_ds, batch_size=len(train_ds), shuffle=False)
test_dl = DataLoader(test_ds, batch_size=len(test_ds), shuffle=False)

In [11]:
# Class-id mapping used by the labels:
# {'Alpha': 0, 'B.1.1.318-like': 1, 'Beta': 2, 'Delta': 3, 'Eta': 4, 'Gamma': 5, 'Iota': 6, 'Lambda': 7, 'Mu': 8, 'None': 9}
models.eval()

# First (and only) batch of the test loader.
batch = next(iter(test_dl))
images, label = batch

# Inference only -- no gradients are needed for the forward pass.
with torch.no_grad():
    pred = models(images.to(device))

# Ground-truth labels in the same order as `pred`.
Y_val = label

In [12]:
# Both loaders are created with shuffle=False and a batch size equal to the split size,
# so the first batch is the complete split in dataset order (not a random sample).
test_batch = next(iter(test_dl))
t_images, t_label = test_batch

# The whole training split is later used as the background set for the SHAP explainer.
batch_background = next(iter(train_dl))
b_images, b_label = batch_background

print(b_images.shape, b_label.shape)
print(t_images.shape, t_label.shape)

torch.Size([1053, 3, 100, 100]) torch.Size([1053])
torch.Size([351, 3, 100, 100]) torch.Size([351])


In [13]:
# Build the SHAP DeepExplainer for the loaded model, using the batch drawn from train_dl
# (b_images) as the background set; the background is moved to the same device as the model.
e = shap.DeepExplainer(models, b_images.to(device))
# e = shap.DeepExplainer(models, b_images.cpu())

In [14]:
# One record per test sample: [position in the test split, path of the .npy image, label].
seq_list = [
    [position, npy_file, target]
    for position, (npy_file, target) in enumerate(zip(X_test, y_test))
]
# Numeric id parsed from the four characters preceding the 4-character extension of each path.
seq_index_list = [int(record[1][-8:-4]) for record in seq_list]
# print(len(t_images), len(seq_index_list))

In [15]:
# Peek at one record: [position in the test split, path of the .npy image, label].
seq_list[1]

[1, './np_image_totalunit/multiclass_nactg/image_npy/0180.npy', 3]

In [16]:
# The test image batch and seq_list are zipped together below, so their lengths should match.
t_images.shape, len(seq_list)

(torch.Size([351, 3, 100, 100]), 351)

In [17]:
# Compute SHAP values for each remaining test image and cache them as one .npy file per
# sample, named after the four-character id taken from the source image path.
save_path = './shap_npy/multiclass_nactg_2022.03.30-2'
# makedirs also creates a missing parent ('./shap_npy') and tolerates an existing folder.
os.makedirs(save_path, exist_ok=True)

# Only test samples from this position onwards are processed.
shap_start = 250
pending = seq_list[shap_start:]

for img_, seq_n in tqdm(zip(t_images[shap_start:], pending), total=len(pending)):
    sav_name = seq_n[1][-8:-4]
    # The explainer's model and background live on `device`; the sample must be on the same
    # device (t_images comes straight from the DataLoader, i.e. from the CPU).
    sv = e.shap_values(img_.unsqueeze(0).to(device))
    np.save(f"{save_path}/{sav_name}.npy", sv)


0it [00:00, ?it/s]

Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


In [18]:
# First record of seq_list: [position in the test split, path of the .npy image, label].
seq_list[0]

[0, './np_image_totalunit/multiclass_nactg/image_npy/0194.npy', 3]

In [19]:
# Load the lineage metadata table and the aligned sequences.
nas_path = "./dataset_1401/"

# Keep three metadata columns and replace missing entries with the string "None".
lineage_label = pd.read_csv('./dataset_1401/1404_lineage_report and metadata 20220316.csv')[['scorpio_call_y','diff','region']]
lineage_label = np.array(lineage_label.fillna("None"))

label_s = []            # [first word of scorpio_call_y, region] per record
name_ = []              # full scorpio_call_y value per record
new_lineage_label = []  # sequence string per record, alignment gaps '-' rewritten as 'N'

# Metadata rows are matched to FASTA records purely by position.
for row_no, record in enumerate(SeqIO.parse('./dataset_1401/1404.sequences.aln.fasta',"fasta")):
    meta_row = lineage_label[row_no]
    label_s.append([meta_row[0].split(' ')[0], meta_row[2]])
    name_.append(meta_row[0])
    new_lineage_label.append(str(record.seq).replace('-','N'))

In [84]:
# Inspect, for the first correctly classified sample among the cached SHAP files, which image
# pixels carry non-zero attribution and which sequence positions map onto those pixels.
label_class = ['Alpha', 'B.1.1.318-like', 'Beta', 'Delta', 'Eta', 'Gamma', 'Iota', 'Lambda', 'Mu', 'None']
location_map = np.load('./deepinsight_location_npy/coords_[NACGT]-multiclass=1404.npy')
square_map = np.load('./deepinsight_location_npy/feature_density_matrix_[NACGT]-multiclass=1404.npy')
# One bucket per class; stays empty while the aggregation step at the bottom is disabled.
total_sv_image_class_dict = {class_name: [] for class_name in label_class}

# Predicted class id of every test sample, computed once instead of on every iteration.
predicted_class = np.argmax(pred.cpu().numpy(), axis=1)

shap_start = 250    # same offset as the slice used when the SHAP files were written
preview_limit = 3   # number of non-zero pixels printed for the inspected sample

pending = seq_list[shap_start:]
for exp_image, sv_npy in tqdm(zip(t_images[shap_start:], pending), total=len(pending)):
    # sv_npy is [position in the test split, path, label]. Y_val and pred cover the whole
    # test split, so they must be indexed with that position -- a counter restarting at 0
    # over the sliced iterator would read the truth/prediction of a different sample.
    sample_idx = sv_npy[0]
    class_idx = int(predicted_class[sample_idx])
    if int(Y_val[sample_idx]) != class_idx:
        continue  # only correctly classified samples are explained

    sav_name = sv_npy[1][-8:-4]
    load_ = np.load(os.path.join(save_path, f"{sav_name}.npy"))
    # Move the channel axis last for every per-class SHAP array.
    shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in load_]

    # Attribution map of the predicted class, summed over the three channels.
    image = np.squeeze(shap_numpy[class_idx])
    image_sum = image[:, :, 0] + image[:, :, 1] + image[:, :, 2]
    single_seq = new_lineage_label[int(sav_name)]  # original sequence of this sample

    # (row, col) -> rounded attribution for every non-zero pixel, in row-major order.
    important_location = {
        (int(x_id), int(y_id)): round(image_sum[x_id, y_id], 10)
        for x_id, y_id in zip(*np.nonzero(image_sum))
    }

    for (x_id, y_id), shap_value in list(important_location.items())[:preview_limit]:
        print("t-sne [X,Y]: ", (x_id, y_id))
        print("shap [X,Y] values: ", shap_value)
        # Rows of location_map whose coordinates equal this pixel; the row numbers are
        # used as positions in the sequence.
        temp = np.where((location_map[:, 0] == x_id) & (location_map[:, 1] == y_id))[0]
        print("shap map = t-sne[X,Y]: ", temp)
        print("Seq Index RNA: ", [single_seq[i] for i in temp])

    # Disabled aggregation step that would fill total_sv_image_class_dict:
    # Sample_filter_important_value_list = []
    # for idx2, (seq_rna, location_xy) in enumerate(zip(single_seq, location_map)):
    #     if (location_xy[0], location_xy[1]) in important_location:
    #         # sequence index, base, [X, Y] position, attribution value
    #         Sample_filter_important_value_list.append([idx2, seq_rna, location_xy, important_location[location_xy[0], location_xy[1]]])
    # Sample_filter_important_value_list.sort(key = lambda s: s[3], reverse = True)
    # total_sv_image_class_dict[label_class[class_idx]].append([Sample_filter_important_value_list, sv_npy[3][1]])
    # NOTE(review): sv_npy has only three fields, so sv_npy[3][1] must be replaced before this
    # is re-enabled, and the `break` below must go so that every sample is visited.

    break  # preview mode: stop after the first correctly classified sample

0it [00:00, ?it/s]

t-sne [X,Y]:  (10, 29)
shap [X,Y] values:  5.101e-07
shap map = t-sne[X,Y]:  [14853 14856 14859 14862 14865 14868 14873 14880 14884 14889 14892 14893]
Seq Index RNA:  ['G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G']
t-sne [X,Y]:  (10, 30)
shap [X,Y] values:  2.38e-08
shap map = t-sne[X,Y]:  [14676 14677 14691 14700 14706 14709 14711 14717 14718 14719 14729 14730
 14733 14734 14737 14742 14745 14766 14771 14772 14775 14776 15877 15893
 15897 15905 15906 15907 15909 15912 15918 15920 15936 15946 15954 15955
 15956 15957 15960 15961 15964 15969 15972 15975 15981 15990 15993 15994
 16004 16012 16013 16017 16019 16026 16032 16035 16064 16065 16067 16071
 16074 16077 16088 16105 16109 16116 16119 16121 16128 16129 16136 16140]
Seq Index RNA:  ['G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 

In [None]:
# Per class, add up the attribution of every (sequence index, base) pair over all samples.
class_dict = {}
for classes, per_sample_records in total_sv_image_class_dict.items():
    class_dict[classes] = {}
    if len(per_sample_records) == 0:
        continue  # nothing collected for this class; keep the empty dict

    index_location_dict = {}
    for single_seq in per_sample_records:
        # single_seq[0] holds the per-position records; the last field is the value to add.
        for rna_position in single_seq[0]:
            position_key = (rna_position[0], rna_position[1])
            if position_key in index_location_dict:
                index_location_dict[position_key] += rna_position[-1]
            else:
                index_location_dict[position_key] = rna_position[-1]

    print(len(index_location_dict))
    class_dict[classes] = index_location_dict
    

In [None]:
# Re-order every class dict by summed value, largest first (in place) ...
for class_name in class_dict:
    class_dict[class_name] = dict(
        sorted(class_dict[class_name].items(), key=lambda pair: pair[1], reverse=True)
    )

# ... and keep a second view ordered smallest first.
class_dict_minus_sign = {
    class_name: dict(sorted(ranked.items(), key=lambda pair: pair[1]))
    for class_name, ranked in class_dict.items()
}

In [None]:
# Show the 20 highest-ranked keys of every class.
for class_name, ranked in class_dict.items():
    print(class_name, list(ranked)[:20])


In [None]:
import matplotlib.pyplot as plt

# Horizontal bar chart of the 30 largest and 30 smallest summed values of one class.
# NOTE(review): the class dict built in this notebook is keyed by the lineage names of
# total_sv_image_class_dict ('Alpha' ... 'None'); 'N' is not one of them -- confirm which
# class this figure is meant to show.
title_label = 'N'
top_n = 30  # bars drawn per side

# .get avoids a KeyError for an unknown class; an empty result is reported below.
top_plus = list(class_dict.get(title_label, {}).items())[:top_n]
top_minus = list(class_dict_minus_sign.get(title_label, {}).items())[:top_n]

label_plus = [str(key) for key, _ in top_plus]
values_plus = [value for _, value in top_plus]
label_minus_sign = [str(key) for key, _ in top_minus]
values_minus_sign = [value for _, value in top_minus]

if not values_plus or not values_minus_sign:
    # min()/max() would raise on empty input; say why there is no figure instead.
    print(f"No aggregated SHAP values for class {title_label!r}; nothing to plot.")
else:
    plt.figure(figsize=(8, 10))
    plt.barh(label_minus_sign, values_minus_sign)
    # Reversed so the largest value ends up at the top of the chart.
    plt.barh(label_plus[::-1], values_plus[::-1])

    # Integer ticks every 2 units; for small-magnitude values this range is empty, in which
    # case matplotlib's automatic ticks are kept instead of removing every tick.
    x_ticks = np.arange(int(min(values_minus_sign)), int(max(values_plus)), 2)
    if x_ticks.size:
        plt.xticks(x_ticks)
    # plt.xticks(rotation='vertical')
    plt.xlabel("value")

    plt.ylabel("RNA Index")
    plt.title(title_label)
    plt.tight_layout()

    plt.grid()
    plt.show()

In [None]:
import matplotlib.pyplot as plt

# Horizontal bar chart of the 30 largest and 30 smallest summed values of one class.
# NOTE(review): the class dict built in this notebook is keyed by the lineage names of
# total_sv_image_class_dict ('Alpha' ... 'None'); 'Y' is not one of them -- confirm which
# class this figure is meant to show.
title_label = 'Y'
top_n = 30  # bars drawn per side

# .get avoids a KeyError for an unknown class; an empty result is reported below.
top_plus = list(class_dict.get(title_label, {}).items())[:top_n]
top_minus = list(class_dict_minus_sign.get(title_label, {}).items())[:top_n]

label_plus = [str(key) for key, _ in top_plus]
values_plus = [value for _, value in top_plus]
label_minus_sign = [str(key) for key, _ in top_minus]
values_minus_sign = [value for _, value in top_minus]

if not values_plus or not values_minus_sign:
    # min()/max() would raise on empty input; say why there is no figure instead.
    print(f"No aggregated SHAP values for class {title_label!r}; nothing to plot.")
else:
    plt.figure(figsize=(8, 10))
    plt.barh(label_minus_sign, values_minus_sign)
    # Reversed so the largest value ends up at the top of the chart.
    plt.barh(label_plus[::-1], values_plus[::-1])

    # Integer ticks every 2 units; for small-magnitude values this range is empty, in which
    # case matplotlib's automatic ticks are kept instead of removing every tick.
    x_ticks = np.arange(int(min(values_minus_sign)), int(max(values_plus)), 2)
    if x_ticks.size:
        plt.xticks(x_ticks)
    # plt.xticks(rotation='vertical')
    plt.xlabel("value")

    plt.ylabel("RNA Index")
    plt.title(title_label)
    plt.tight_layout()

    plt.grid()
    plt.show()