In [None]:
import os
import torch, shap, glob
import numpy as np, pandas as pd, torchvision.transforms as transforms
import random

from feedback import *
from tqdm.notebook import tqdm
from torchvision.models import alexnet
from torch.nn import Module
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image

# Seed NumPy, Python's `random` and PyTorch with the same fixed value so that
# stochastic steps are repeatable. The cell that defines the dataset class
# further down repeats these three calls with the same seed.
np.random.seed(2020)
random.seed(2020)
torch.manual_seed(2020)

In [None]:
# Directory that holds the checkpoint to evaluate. The commented-out line is the
# alternative run tagged [NACGTRYKMSWBDHV]; switch both the directory and the
# file name together.
# NOTE(review): the directory tag reads "[NATCG]" while the file name below reads
# "[NACGT]" — confirm both match what is actually on disk.
save_weight_path ='./models/weights_Multiclass_Covid19(Non-kmer3)_IndexRemark.2022.03.30[NATCG]/'
# save_weight_path = './models/weights_Multiclass_Covid19(Non-kmer3)_IndexRemark.2022.03.24[NACGTRYKMSWBDHV]/'

# Checkpoint file name inside `save_weight_path`.
weights_name = "weights_Multiclass_Covid19(Non-kmer3)[NACGT].2022.03.30.pt"
# weights_name = "weights_Multiclass_Covid19(Non-kmer3)[NACGTRYKMSWBDHV].2022.03.24.pt"

# Full path handed to `torch.load` later in the notebook.
path2weights = os.path.join(save_weight_path,weights_name)

In [None]:
# Root folder of the pre-rendered sample arrays: it is expected to contain an
# `image_npy/` sub-folder and a `label.npy` file (both are read in the next cell).
# The commented-out line points at the alternative "totalunit" data set.
# npy_path = './np_image_totalunit/multiclass_totalunit/'
npy_path = './np_image_totalunit/multiclass_nactg/'

In [None]:
# Gather the per-sample image files (ordered by file name) and load the label
# vector that accompanies them.
image_npy_dir = os.path.join(npy_path, 'image_npy')
npy_data_list = [
    os.path.join(image_npy_dir, file_name)
    for file_name in sorted(os.listdir(image_npy_dir))
]
label_ = np.load(os.path.join(npy_path, 'label.npy'))

In [None]:
nas_path = "./dataset_1401/"

# Keep three metadata columns per sample; missing cells become the string "None"
# before the frame is turned into a plain NumPy array.
lineage_label = pd.read_csv('./dataset_1401/1404_lineage_report and metadata 20220316.csv')[['scorpio_call_y','diff','region']]
lineage_label = np.array(lineage_label.fillna("None"))

# label_s           -> [first word of the scorpio call, region] per record
# name_             -> the full scorpio call per record
# new_lineage_label -> the aligned sequence per record, with gaps ('-') rewritten as 'N'
label_s, name_, new_lineage_label = [], [], []

# Metadata row `idx` is paired with the idx-th FASTA record.
# NOTE(review): this assumes the CSV and the FASTA share the same ordering — confirm.
for idx, rna in enumerate(SeqIO.parse('./dataset_1401/1404.sequences.aln.fasta',"fasta")):
    metadata_row = lineage_label[idx]
    scorpio_call = metadata_row[0]
    label_s.append([scorpio_call.split(' ')[0], metadata_row[2]])
    name_.append(scorpio_call)
    new_lineage_label.append(str(rna.seq).replace('-','N'))

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split

# Split the image paths, the labels and the (lineage, region) metadata in ONE
# call. `train_test_split` applies the same train/test indices to every array it
# is given, so `label_country` is guaranteed to line up element-for-element with
# `X_test` / `y_test` instead of relying on two separate calls reproducing the
# same permutation.
X_train, X_test, y_train, y_test, _, label_country = train_test_split(
    npy_data_list, label_, label_s,
    stratify=label_, test_size=0.25, random_state=42,
)
print(len(X_train), len(y_train))

print(len(X_test), len(y_test))

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import glob
from PIL import Image
import torch
import numpy as np
import random
# Re-seed NumPy, `random` and PyTorch with the same value used in the first cell,
# which resets all three global generators at this point in the notebook.
np.random.seed(2020)
random.seed(2020)
torch.manual_seed(2020)

class TransferDataset(Dataset):
    """Dataset that pairs `.npy` arrays on disk with integer class labels.

    Args:
        s_path: indexable collection of paths to `.npy` files.
        labels: indexable collection of labels, aligned with `s_path`.
        transform: callable applied to every loaded float32 array.
    """

    def __init__(self, s_path, labels, transform):
        self.s_path, self.labels, self.transform = s_path, labels, transform

    def __len__(self):
        """Number of samples, i.e. the number of paths."""
        return len(self.s_path)

    def __getitem__(self, idx):
        """Return `(transformed_array, int_label)` for sample `idx`."""
        array = np.load(self.s_path[idx]).astype(np.float32)

        # Draw one fresh seed and push it into both global generators before the
        # transform runs. Note that this mutates global RNG state on every access.
        fresh_seed = np.random.randint(1e9)
        random.seed(fresh_seed)
        np.random.seed(fresh_seed)

        return self.transform(array), int(self.labels[idx])

In [None]:
# Only tensor conversion is applied; a `transforms.Normalize(mean, std)` step is
# not part of the pipeline (it was left commented out).
transformer = transforms.Compose([transforms.ToTensor()])

# Wrap each split in the dataset class, sharing the same transform.
train_ds = TransferDataset(s_path=X_train, labels=y_train, transform=transformer)
test_ds = TransferDataset(s_path=X_test, labels=y_test, transform=transformer)
print(len(train_ds), len(test_ds))

In [None]:
# Pull a single sample as a smoke test, then build mini-batch loaders: the
# training loader shuffles, the test loader keeps file order and uses a batch
# twice as large.
imgs, label = train_ds[10]
batch_size = 32
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=batch_size * 2, shuffle=False)

In [None]:
# Rebuild an untrained AlexNet with one output per label id (largest id + 1),
# restore the saved weights on the CPU, then move the model to the GPU when one
# is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
models = alexnet(pretrained=False, num_classes=max(label_)+1)

# NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
checkpoint = torch.load(path2weights, map_location=torch.device('cpu'))
models.load_state_dict(checkpoint['model_state_dict'])
models.to(device)

In [None]:
# Rebuild both loaders so that ONE batch contains the whole split, in file order.
# Later cells take a single batch from `test_dl` and zip it against every entry
# of `X_test`, so the batch must cover the full test set; sizing it from the
# dataset (instead of the hard-coded 1053 / 351) keeps that alignment intact if
# the number of samples ever changes. The single training batch is used as the
# SHAP background.
train_dl = DataLoader(train_ds, batch_size=len(train_ds), shuffle=False)
test_dl = DataLoader(test_ds, batch_size=len(test_ds), shuffle=False)

In [None]:
# Class ids used by the model:
# {'Alpha': 0, 'B.1.1.318-like': 1, 'Beta': 2, 'Delta': 3, 'Eta': 4, 'Gamma': 5, 'Iota': 6, 'Lambda': 7, 'Mu': 8, 'None': 9}
# Run the first test batch through the network in eval mode without tracking
# gradients; keep the raw outputs in `pred` and the ground truth in `Y_val`.
models.eval()
with torch.no_grad():
    batch = next(iter(test_dl))
    images, label = batch
    Y_val = label
    pred = models(images.to(device))

In [None]:
# Both loaders were built with shuffle=False, so these are the FIRST batches in
# file order (not random samples). The analysis below depends on that ordering
# to line `t_images` up with `X_test`.
test_batch = next(iter(test_dl))
batch_background = next(iter(train_dl))

t_images, t_label = test_batch
b_images, b_label = batch_background

# Background shapes first, then test shapes.
for shown_images, shown_labels in ((b_images, b_label), (t_images, t_label)):
    print(shown_images.shape, shown_labels.shape)

In [None]:
# Build the SHAP DeepExplainer, using the training batch (moved to the model's
# device) as the background distribution.
background_images = b_images.to(device)
e = shap.DeepExplainer(models, background_images)

In [None]:
# One record per test sample: [position, image path, label, [lineage, region]].
seq_list = [
    [position, *sample]
    for position, sample in enumerate(zip(X_test, y_test, label_country))
]
# Numeric id of each sample, parsed from characters [-8:-4] of its path
# (the four characters in front of the 4-character extension).
seq_index_list = [int(record[1][-8:-4]) for record in seq_list]
# print(len(t_images), len(seq_index_list))
# print(len(t_images), len(seq_index_list))

In [None]:
# Display one record to eyeball its layout: [position, image path, label, [lineage, region]].
seq_list[1]

In [None]:
from tqdm.notebook import tqdm
# Folder of precomputed SHAP arrays; the analysis cell below loads
# `<save_path>/<name>.npy` for each test sample instead of recomputing them.
save_path = './shap_npy/multiclass_nactg_2022.03.30'
# The disabled loop below is what produced those files: one `e.shap_values` call
# per test image, saved under the characters [-8:-4] of the image path.
# NOTE(review): the [150::] slices presumably resumed an interrupted run — confirm
# before re-enabling, and check the image is on the same device as the model.
# for idx, (img_, seq_n) in enumerate(tqdm(zip(t_images[150::],seq_list[150::]))):
#     sav_name = seq_n[1][-8:-4]
#     sv = e.shap_values(torch.unsqueeze(img_, axis=0))
#     np.save(f"{save_path}/{sav_name}.npy", sv)
#     # break


In [None]:
# Map each correctly classified test sample's SHAP values back onto positions of
# its original sequence.
#
# Result: total_sv_image_class_dict[class name] is a list with one entry per
# correctly predicted sample, each entry being
#   [ [[sequence index, nucleotide, [x, y] pixel, SHAP value], ...] sorted by
#     SHAP value in descending order,
#     region of the sample ]
label_class = ['Alpha', 'B.1.1.318-like', 'Beta', 'Delta', 'Eta', 'Gamma', 'Iota', 'Lambda', 'Mu', 'None']
# Pixel coordinate assigned to every sequence position (one row per position).
location_map =np.load('./deepinsight_location_npy/coords_[NACGT]-multiclass=1404.npy')
# Loaded for reference; not used by the loop below.
square_map = np.load('./deepinsight_location_npy/feature_density_matrix_[NACGT]-multiclass=1404.npy')
total_sv_image_class_dict = {'Alpha': [], 'B.1.1.318-like':[], 'Beta':[], 'Delta':[], 'Eta':[], 'Gamma':[], 'Iota':[], 'Lambda':[], 'Mu':[], 'None':[]}

# Predicted class id of every test sample. This is loop-invariant, so compute it
# once instead of moving `pred` to the CPU several times per iteration.
pred_class_ids = np.argmax(pred.cpu().numpy(), axis=1)

# `total` lets tqdm draw a real progress bar (zip() has no length of its own).
for idx1, (exp_image, sv_npy) in enumerate(tqdm(zip(t_images, seq_list), total=min(len(t_images), len(seq_list)))):
    pred_id = pred_class_ids[idx1]
    if label_class[Y_val[idx1]]==label_class[pred_id]: # if ground truth == predict result
        sav_name = sv_npy[1][-8:-4]

        load_ = np.load(os.path.join(save_path,f"{sav_name}.npy"))
        # Move the channel axis to the end for every per-class SHAP array.
        shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in load_]

        # Extract Seq Image feature: SHAP map of the predicted class, first channel only.
        # NOTE(review): despite the name, `image_sum` is channel 0, not a sum over
        # channels — confirm that is intended.
        image = np.squeeze(shap_numpy[pred_id])
        image_sum = image[:,:,0]

        single_seq = new_lineage_label[int(sav_name)] #get original sequence

        # (row, column) -> SHAP value for every non-zero pixel, visited in row-major order.
        important_location = {}
        for x_id, y_id in zip(*np.nonzero(image_sum)):
            important_location[int(x_id), int(y_id)] = round(image_sum[x_id, y_id], 10)

        Sample_filter_important_value_list = []
        for idx2, (seq_rna, location_xy) in enumerate(zip(single_seq, location_map)):
            xy_key = (location_xy[0], location_xy[1])
            # Direct dict membership: O(1) per position instead of scanning a fresh key list.
            if xy_key in important_location:
                Sample_filter_important_value_list.append([idx2, seq_rna, location_xy, important_location[xy_key]]) # sequence index, nucleotide, [X, Y] pixel, SHAP value
        Sample_filter_important_value_list.sort(key = lambda s: s[3], reverse = True)
        total_sv_image_class_dict[label_class[pred_id]].append([Sample_filter_important_value_list, sv_npy[3][1]]) # ranked positions + region of the sample

In [None]:
# For every class, add up the SHAP value of each (sequence index, nucleotide)
# pair over all samples collected for that class. Classes without samples keep
# an empty mapping; for the others the number of distinct pairs is printed.
class_dict = {}
for classes, class_samples in total_sv_image_class_dict.items():
    class_dict[classes] = {}
    if len(class_samples) == 0:
        continue

    index_location_dict = {}
    for single_seq in class_samples:
        for rna_position in single_seq[0]:
            pair_key = (rna_position[0], rna_position[1])
            if pair_key in index_location_dict:
                index_location_dict[pair_key] += rna_position[-1]
            else:
                index_location_dict[pair_key] = rna_position[-1]

    print(len(index_location_dict))
    class_dict[classes] = index_location_dict
        # #     break
        # # break
    

In [None]:
# Re-order every class mapping in place so the largest summed SHAP value comes first.
for i in class_dict.keys():
    class_dict[i] = dict(sorted(class_dict[i].items(), key=lambda item: item[1], reverse=True))

# Companion view with the smallest (most negative) summed SHAP value first.
class_dict_minus_sign = {
    class_name: dict(sorted(class_dict[class_name].items(), key=lambda item: item[1]))
    for class_name in class_dict.keys()
}

In [None]:
# Show the first ten (sequence index, nucleotide) keys of every class mapping.
for i, ranked_pairs in class_dict.items():
    print(i, list(ranked_pairs)[:10])


In [None]:

# Class name -> class id, followed by a ten-slot list of int8 zeros.
lineage_index = {
    'Alpha': 0,
    'B.1.1.318-like': 1,
    'Beta': 2,
    'Delta': 3,
    'Eta': 4,
    'Gamma': 5,
    'Iota': 6,
    'Lambda': 7,
    'Mu': 8,
    'None': 9,
}
error_class_list = list(np.zeros(10, dtype=np.int8))

In [None]:
# Collect every recorded position whose sequence index is 7, 17 or 217 as
# [index, class name, region, nucleotide].
error_seq_array = []
for class_, class_samples in total_sv_image_class_dict.items():
    for class_seq in class_samples:
        for single_se in class_seq[0]:
            for watched_index in (7, 17, 217):
                if single_se[0] == watched_index:
                    error_seq_array.append([watched_index, class_, class_seq[1], single_se[1]])

In [None]:
# Write the collected rows to disk as a CSV (the default integer index is included).
error_seq_frame = pd.DataFrame(
    error_seq_array,
    columns=['Index Location', 'Lineage', 'Country', 'RNA'],
)
error_seq_frame.to_csv('./seven_error.csv')