In [None]:
import torch, shap, glob, os
import numpy as np, pandas as pd, torchvision.transforms as transforms
import random

from Bio import SeqIO
from tqdm.notebook import tqdm
from torch.nn import Module
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from torchvision.models import  resnet18, alexnet

# Seed NumPy, Python's `random`, and PyTorch (in that order) with the same value
# so that a fresh kernel run starts from a known RNG state.
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(2020)

In [None]:
# --- Run configuration -------------------------------------------------------
# Folder that holds the `image_npy/` directory and `label.npy` for this run.
npy_path = './np_image_totalunit/tsne-binary-perplexity=5-pixel=400[onehot]/'
# Multiclass alternative (disabled):
# npy_path = './np_image_totalunit/multiclass_totalunit/'

# Folder and file name of the trained checkpoint.
save_weight_path = './models/weights_res18_tsne-binary-perplexity=5-pixel=400[onehot]/'
# Multiclass alternative (disabled):
# save_weight_path = './models/weights_Multiclass_Covid19(Non-kmer3)_IndexRemark.2022.03.24[NACGTRYKMSWBDHV]/'

weights_name = "weights_binaryclass_Covid19(Non-kmer3)[NACGT].2022.05.09-onehot.pt"
# Multiclass alternative (disabled):
# weights_name = "weights_Multiclass_Covid19(Non-kmer3)[NACGTRYKMSWBDHV].2022.03.24.pt"

# Full checkpoint path = folder + file name.
path2weights = os.path.join(save_weight_path, weights_name)

In [None]:
# Collect the per-sample image files in sorted (lexicographic) file-name order.
# NOTE(review): this order must line up with the rows of label.npy — confirm the
# file names sort the same way the labels were written (e.g. zero-padded indices).
image_dir = os.path.join(npy_path, 'image_npy')
npy_data_list = [os.path.join(image_dir, file_name) for file_name in sorted(os.listdir(image_dir))]
label_ = np.load(os.path.join(npy_path, 'label.npy'))

In [None]:
nas_path = "./dataset_1401/"

# Keep only the two metadata columns used below; missing cells become the string "None".
lineage_frame = pd.read_csv('./dataset_1401/1404_lineage_report and metadata 20220316.csv')[['scorpio_call_y','diff']]
lineage_label = np.array(lineage_frame.fillna("None"))

label_s, name_, new_lineage_label = [], [], []

# NOTE(review): CSV rows are matched to FASTA records purely by position —
# confirm both files list the sequences in the same order.
for idx, rna in enumerate(SeqIO.parse('./dataset_1401/1404.sequences.aln.fasta',"fasta")):
    scorpio_call, diff_text = lineage_label[idx][0], lineage_label[idx][1]
    if scorpio_call != "B.1.617.2":
        continue  # only records whose scorpio call is exactly B.1.617.2 are kept
    label_s.append(diff_text.split(' ')[0])   # first space-separated token of 'diff'
    name_.append(scorpio_call)
    new_lineage_label.append(str(rna.seq))    # despite the name, this stores the sequence string

In [None]:
# Distinct label values collected in `label_s` (np.unique returns them sorted).
class_ = np.unique(label_s)
print(class_)

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
# 75/25 split of file paths and labels, stratified on the labels, fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    npy_data_list,
    label_,
    test_size=0.25,
    stratify=label_,
    random_state=42,
)

# Sanity check: paths and labels must have matching lengths in each partition.
for paths_part, labels_part in ((X_train, y_train), (X_test, y_test)):
    print(len(paths_part), len(labels_part))

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import glob
from PIL import Image
import torch
import numpy as np
import random
# Reset all three RNGs to the same fixed seed before the dataset is defined.
_cell_seed = 2020
np.random.seed(_cell_seed)
random.seed(_cell_seed)
torch.manual_seed(_cell_seed)

class TransferDataset(Dataset):
    """Map-style dataset that reads one ``.npy`` array per sample from disk.

    Args:
        s_path: sequence of paths to ``.npy`` files, one per sample.
        labels: sequence of labels, aligned index-for-index with ``s_path``.
        transform: callable applied to the loaded float32 array.
    """

    def __init__(self, s_path, labels, transform):
        self.s_path = s_path
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (the number of stored paths)."""
        return len(self.s_path)

    def __getitem__(self, idx):
        """Return ``(transformed_array, label_as_float)`` for sample ``idx``."""
        raw_array = np.load(self.s_path[idx]).astype(np.float32)

        # Draw a new seed and push it into both global RNGs before the transform
        # runs. Note this overwrites the global `random` / `np.random` state on
        # every item access.
        item_seed = np.random.randint(1e9)
        random.seed(item_seed)
        np.random.seed(item_seed)

        return self.transform(raw_array), float(self.labels[idx])

In [None]:
# The only preprocessing step is ToTensor; the Normalize step stays disabled.
transformer = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean, std),
])

# Same transform for both partitions.
train_ds = TransferDataset(X_train, y_train, transformer)
test_ds = TransferDataset(X_test, y_test, transformer)
print(len(train_ds), len(test_ds))

In [None]:
# Pull one training sample (index 10) to check that loading works.
imgs, label = train_ds[10]

batch_size = 32
# Training batches are shuffled; the test loader keeps file order and uses a
# batch twice as large.
train_dl = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
test_dl = DataLoader(test_ds, shuffle=False, batch_size=2 * batch_size)

In [None]:
# eval
def reload_model():
    """Build a single-output ResNet-18 and restore its weights from ``path2weights``.

    The checkpoint is loaded onto the CPU and is expected to be a mapping with a
    ``'model_state_dict'`` entry.

    Returns:
        The ResNet-18 instance with the checkpoint weights loaded.
    """
    # torch.load unpickles the file: only point path2weights at trusted checkpoints.
    saved_state = torch.load(path2weights, map_location=torch.device('cpu'))
    net = resnet18(pretrained=False, num_classes=1)
    net.load_state_dict(saved_state['model_state_dict'])
    return net

In [None]:
# Rebuild the test loader with one sample per batch (replaces the earlier test_dl),
# then peek at the first sample.
test_dl = DataLoader(test_ds, shuffle=False, batch_size=1)
first_batch = next(iter(test_dl))
images, label = first_batch
print(images.shape, label)

In [None]:
import matplotlib.pyplot as plt
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchcam.methods import SmoothGradCAMpp, LayerCAM, GradCAM
from torchcam.utils import overlay_mask
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Maps the integer label (0 / 1) to the letter shown in figure titles.
class_dict = dict(enumerate(('N', 'Y')))
def loader_cam(image_, lable_, idx, plot=False): #single classes loader
    """Run LayerCAM on one image and return the region the model attended to.

    A model is loaded via ``reload_model()``, a LayerCAM map is extracted from
    ``layer4``, resized to the input's spatial size and thresholded at 0.5.
    Pixels of the input's first channel outside the thresholded mask are zeroed.

    All image data is taken from the ``image_`` argument; the function does not
    read any module-level image variable.

    Args:
        image_: input batch fed to the model. Only ``image_[0][0]`` is used for
            the returned arrays (assumes a single-sample ``(1, C, H, W)`` tensor
            — TODO confirm against the caller).
        lable_: ground-truth label; ``int(lable_)`` must be a key of ``class_dict``.
        idx: sample index, used only to name the saved figure when ``plot`` is true.
        plot: when true, draw the five-panel figure, save it under
            ``save_fig_path/project_name`` and show it. Defaults to ``False``.

    Returns:
        ``(capture_image, source_plane)``: the masked first channel and the
        unmodified first channel, both as NumPy arrays.

    Raises:
        RuntimeError: if the CAM extractor yields no activation map.
    """
    model = reload_model().to(device).eval()
    cam_extractor = LayerCAM(model, ["layer4"])
    classes__ = class_dict[int(lable_)]

    # No torch.no_grad() here: the CAM extractor is handed `out` and needs its graph.
    out = model(image_.to(device))
    cams = cam_extractor(out.squeeze(0).argmax().item(), out)
    cam_extractor.clear_hooks()

    ths = 0.5
    target_size = image_.shape[-2:]
    source_plane = image_[0][0].cpu().numpy()

    # Heat-map at input resolution, and a binary mask where the resized CAM >= ths.
    resized_cams = [resize(to_pil_image(cam), target_size) for cam in cams]
    segmaps = [
        to_pil_image((resize(cam.unsqueeze(0), target_size).squeeze(0) >= ths).to(dtype=torch.float32))
        for cam in cams
    ]

    # One (heat-map, mask, masked input) triple per target layer.
    panels = [
        (cam, seg, np.where(np.array(seg), source_plane, source_plane * 0))
        for _name, cam, seg in zip(cam_extractor.target_names, resized_cams, segmaps)
    ]
    if not panels:
        raise RuntimeError("LayerCAM returned no activation maps")
    # With several target layers the last one wins.
    capture_image = panels[-1][2]

    if plot:
        for cam, seg, masked in panels:
            _, axes = plt.subplots(1, 5, figsize=(25, 7))
            axes[0].imshow(cam); axes[0].axis('off'); axes[0].set_title(f'heatmap')
            axes[1].imshow(seg); axes[1].axis('off'); axes[1].set_title(f' mask > {ths} [ths]')
            axes[2].imshow(source_plane); axes[2].axis('off'); axes[2].set_title(f'seq image - Ground Truth: {classes__}')
            axes[3].imshow(source_plane, cmap='bone')
            axes[3].imshow(seg, alpha=0.5, cmap='bone'); axes[3].axis('off'); axes[3].set_title(f'MIX image - Pred Results: {class_dict[torch.sigmoid(out).item()>0.5]}')
            axes[4].imshow(masked); axes[4].axis('off'); axes[4].set_title(f'Captrue Area')
            plt.savefig(os.path.join(save_fig_path, project_name, f'{idx}.png'))
            plt.show()
    return capture_image, source_plane
    
save_fig_path = './GramCam_FIG/'
project_name = 'weights_res18_tsne-binary-perplexity=5-pixel=400[onehot]'
# makedirs creates both directory levels in one call and tolerates an existing
# directory, replacing the two separate exists()/mkdir() checks.
os.makedirs(os.path.join(save_fig_path, project_name), exist_ok=True)

# Running sums over all positive test samples.
# NOTE(review): assumes every test image is 440x440 — confirm against the data.
sum_image = np.zeros((440,440))    # stacked CAM-masked images
tsne_image = np.zeros((440,440))   # stacked unmasked input images
for idx, (img, lab) in enumerate(test_dl):
    if int(lab) != 1:
        continue  # only samples labelled 1 contribute to the stacks
    # `img`, `images` and `label` are (re)bound at module level on every
    # iteration; they are kept under these names in case other cells or helpers
    # read them.
    images = img
    label = lab
    temp, temp2 = loader_cam(images, label, idx)
    tsne_image = tsne_image + temp2
    sum_image = sum_image + temp

In [None]:
import pickle as pk, os
save_model_path = './deepinsight_location_npy/'
save_name = 'tsne-binary-perplexity=50-pixel=400[onehot].pkl'
# save_name = 'tsne-binary-perplexity=50-pixel=400.pkl'

# NOTE(review): this file is tagged perplexity=50, while the image folder and
# project name used elsewhere in this notebook are tagged perplexity=5. If the
# pickled object comes from a different t-SNE fit than the images, the
# pixel-to-feature mapping derived from it will be wrong — confirm.
#
# SECURITY: pickle executes arbitrary code while loading; only open files
# produced by a trusted run of this project.
with open(os.path.join(save_model_path, save_name), 'rb') as pickle_file:
    # The context manager closes the handle even if unpickling fails.
    it = pk.load(pickle_file)

In [None]:
# Three panels: stacked inputs, stacked inputs with the CAM stack overlaid,
# and the CAM stack on its own.
_, axes = plt.subplots(1, 3, figsize=(30, 15))
ax_inputs, ax_overlay, ax_cam = axes

ax_inputs.imshow(tsne_image, cmap='bone')
ax_inputs.axis('off')
ax_inputs.set_title(f'lineage diff = [Y]')

ax_overlay.imshow(tsne_image, cmap='bone')
ax_overlay.imshow(sum_image, alpha=0.7, cmap='hot')
ax_overlay.axis('on')
ax_overlay.set_title(f'MIX')

ax_cam.imshow(sum_image, cmap='hot')
ax_cam.axis('on')
ax_cam.set_title(f'test diff= [Y] image stack')

plt.show()

In [None]:
# Map every pixel with a strictly positive accumulated value to that value.
# np.argwhere yields (row, col) pairs in row-major order, the same order the
# nested row/column loops would visit them; keys are plain Python ints.
feature_dict = {
    (int(row), int(col)): sum_image[row, col]
    for row, col in np.argwhere(sum_image > 0.0)
}

In [None]:
# Symbol stored in each of the consecutive feature slots of one sequence position.
RNA_SEQ = {0:'-', 1:'N', 2:'A', 3:'C', 4:'G', 5:'T'}
symbols_per_position = len(RNA_SEQ)

# For every feature index, look up its pixel coordinate and keep the feature if
# that pixel carries a positive accumulated weight.
# Each entry is [weight, 1-based sequence position, symbol].
total_feature_stack = []
for seq_index, xy in enumerate(it.coords()):
    pixel = tuple(xy)
    # Direct dict membership is a hash lookup; building list(feature_dict.keys())
    # for every coordinate made this loop quadratic.
    if pixel in feature_dict:
        position, channel = divmod(seq_index, symbols_per_position)
        total_feature_stack.append([feature_dict[pixel], position + 1, RNA_SEQ[channel]])

In [None]:
# len(it.coords())
# ml_pos_stack = [1048, 13482, 15952, 17236, 21846, 21987, 22792, 23593, 23896, 24928, 25352, 26107]
# ml_gene_stack = [['G', 'T'], ['G'], ['A'], ['A'], ['C', 'T'], ['G', 'N', 'A'], ['C', 'T'], ['C','G'], ['C'], ['G'], ['G'], ['C', 'G']]


In [None]:
# ml_save_feature_stack = []
# for pos, genes in zip(ml_pos_stack, ml_gene_stack):
#     for ge_ in genes:
#         ml_save_feature_stack.append([pos, ge_])
# # (pd.DataFrame(ml_save_feature_stack, columns =['Position', 'Gene'])).to_csv('../Gene-Translation/demo/ml_feature_list.csv', index=False)

In [None]:
# pos2xy = []
# NUM_SEQ = {'-':0, 'N':1, 'A':2, 'C':3, 'G':4, 'T':5}
# for i in ml_save_feature_stack:
#     print((i[0]-1)*6, ((i[0]-1)*6)-(5-NUM_SEQ[i[1]]), i[1])
#     # print(it.coords()[((i[0]-1)*6)-(5-NUM_SEQ[i[1]])])
#     pos2xy.append(it.coords()[((i[0]-1)*6)-(5-NUM_SEQ[i[1]])])

In [None]:
# ml_image = np.zeros((440,440))
# for draw_ in pos2xy:
#     ml_image[draw_[0]][draw_[1]] = 0.5

# for draw_ in ml_image:
#     for draw_2 in draw_:
#         if (draw_2 >0):
#             print(draw_, draw_2)

In [None]:

# _, axes = plt.subplots(1, 2, figsize=(20, 15))

# axes[0].imshow(tsne_image, cmap='bone')
# axes[0].imshow(ml_image, alpha=0.7, cmap='hot'); axes[0].axis('on'); axes[0].set_title(f'Machine Learning')
# axes[1].imshow(tsne_image, cmap='bone')
# axes[1].imshow(sum_image, alpha=0.7, cmap='hot'); axes[1].axis('on'); axes[1].set_title(f'Deep Learning')

# plt.show()

In [None]:
# Order the features from highest to lowest weight. Entries are lists of
# [weight, position, symbol], so list comparison breaks ties on weight by
# position and then by symbol.
total_feature_stack = sorted(total_feature_stack, reverse=True)
# Alternative ordering (disabled): sort by sequence position instead of weight.
# total_feature_stack = sorted(total_feature_stack, key = lambda seq_index : seq_index[1])

In [None]:
# Drop the weight (element 0) and keep [position, symbol] in the current order.
save_feature_stack = [list(entry[1:3]) for entry in total_feature_stack]

In [None]:
# Write the ranked (Position, Gene) pairs as a comma-separated file without the index.
feature_table = pd.DataFrame(save_feature_stack, columns =['Position', 'Gene'])
feature_table.to_csv('../Gene-Translation/demo/feature_list_ths0.5.csv', index=False)
# Tab-separated alternative (disabled):
# feature_table.to_csv('../Gene-Translation/demo/feature_list_ths0.5.csv', sep='\t', index=False)