In [1]:
import plotly.graph_objects as go
import urllib, json
import os
os.getcwd()

'/lambda_stor/homes/ac.tfeng/git/DrugCell'

In [2]:
# import anndata
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.distributions import MultivariateNormal

import os
# os.environ["CUDA_VISIBLE_DEVICES"]="4"

import copy

import tqdm

from codes.utils.util import *
from codes.drugcell_NN import *

if torch.cuda.is_available():
  DEVICE = 'cuda'
else:
  DEVICE = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import plotly
import plotly.express as px
import plotly.io as pio

from captum.attr import LayerIntegratedGradients
from collections import defaultdict


In [4]:
!nvidia-smi

Mon Nov 27 10:33:03 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:1A:00.0 Off |                    0 |
| N/A   28C    P0              41W / 300W |      3MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  | 00000000:1B:00.0 Off |  

In [5]:
DEVICE = 'cuda:7'

# GO Term

# TCGA

### Data loading

In [6]:
training_file = "data/drugcell_train.txt"
testing_file = "data/drugcell_test.txt"
val_file = "data/drugcell_val.txt"
cell2id_file = "data/cell2ind.txt"
drug2id_file = "data/drug2ind.txt"
genotype_file = "data/cell2mutation.txt"
fingerprint_file = "data/drug2fingerprint.txt"
onto_file = "data/drugcell_ont.txt"
gene2id_file = "data/gene2ind.txt"

train_data, feature_dict, cell2id_mapping, drug2id_mapping = prepare_train_data(training_file, 
                                                                  testing_file, cell2id_file, 
                                                                  drug2id_file)

gene2id_mapping = load_mapping(gene2id_file)

# load cell/drug features
cell_features = np.genfromtxt(genotype_file, delimiter=',')
drug_features = np.genfromtxt(fingerprint_file, delimiter=',')

num_cells = len(cell2id_mapping)
num_drugs = len(drug2id_mapping)
num_genes = len(gene2id_mapping)
drug_dim = len(drug_features[0,:])

# load ontology
dG, root, term_size_map, \
    term_direct_gene_map = load_ontology(onto_file, 
                                         gene2id_mapping)

Total number of cell lines = 1225
Total number of drugs = 684
There are 3008 genes
There are 1 roots: GO:0008150
There are 2086 terms
There are 1 connected componenets


In [7]:
class RNASeqData(Dataset):
    
    def __init__(self, X, c=None, y=None, transform=None):
        self.X = X
        self.y = y
        self.c = c
        self.transform = transform
        
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, index):
        sample = self.X[index,:]
        
        if self.transform is not None:
            sample = self.transform(sample)
        
        if self.y is not None and self.c is None:
            return sample, self.y[index]
        elif self.y is not None and self.c is not None:
            return sample, self.y[index], self.c[index]
        elif self.y is None and self.c is not None:
            return sample, self.c[index]
        else:
            return sample

In [8]:
## expression and IHC data

rna_seq = pd.read_csv('data/tcga/train_tcga_expression_matrix_processed.tsv',sep='\t',index_col=0)
tcga_mad_genes = pd.read_csv('data/tcga/tcga_mad_genes.tsv', sep='\t')
tcga_sample_counts = pd.read_csv('data/tcga/tcga_sample_counts.tsv', sep='\t')
tcga_sample_identifiers = pd.read_csv('data/tcga/tcga_sample_identifiers.tsv', sep='\t',index_col=0)

# rna_seq = rna_seq.loc[:,tcga_mad_genes.gene_id[:3000].apply(str)]
rna_seq = rna_seq / rna_seq.std()
rna_seq = np.log(rna_seq + 1)

In [9]:
tcga_df = rna_seq
gene_id_dict = pd.read_csv('data/tcga/gene_dict.csv')

new_column = []
for col_name in tcga_df.columns:
    col_loc = (gene_id_dict.entrezgene_id == int(col_name))
    if np.sum(col_loc) == 0:
        tcga_df = tcga_df.drop(columns = col_name)
    else:
        new_column.append(gene_id_dict.hgnc_symbol[gene_id_dict.entrezgene_id == int(col_name)].iloc[0])
        

tcga_df.columns = new_column

In [10]:
tcga_df.columns = new_column

#### Create gene2ID for tcga

In [11]:
tcga_gene2id = {}
for idx, gene in enumerate(tcga_df.columns):
    tcga_gene2id[gene] = idx

#### Align gene id with tcga

In [12]:
gene_intersect_list = list(set(gene2id_mapping.keys()) & set(tcga_df.columns))
tcga_tensor = torch.zeros(tcga_df.shape[0], num_genes)

for gene in gene_intersect_list:
    idx = gene2id_mapping[gene]
    tcga_tensor[:,idx] = torch.tensor(tcga_df[gene])

In [13]:
cancer_2_idx = {}
idx_2_cancer = {}
cancer_type_idx = []
cancer_type = tcga_sample_identifiers.loc[tcga_df.index, 'cancer_type']

i = 0
for cancer in cancer_type:
    if cancer not in cancer_2_idx:
        cancer_2_idx[cancer] = i
        idx_2_cancer[i] = cancer
        cancer_type_idx.append(i)
        
        i += 1
    else:
        cancer_type_idx.append(cancer_2_idx[cancer])

y = torch.tensor(cancer_type_idx)

In [30]:
torch.manual_seed(0)

tcga_dataset = RNASeqData(X = tcga_tensor, y = y)
training_set, testing_set = random_split(tcga_dataset, [0.7, 0.3])

In [15]:
class dcell_vae(nn.Module):

    def __init__(self, term_size_map, term_direct_gene_map, dG, ngene, root, 
                 num_hiddens_genotype, num_hiddens_final, n_class, inter_loss_penalty = 0.2):

        super(dcell_vae, self).__init__()

        self.root = root
        self.num_hiddens_genotype = num_hiddens_genotype
        self.num_hiddens_final = num_hiddens_final
        self.n_class = n_class
        self.inter_loss_penalty = inter_loss_penalty
        self.dG = copy.deepcopy(dG)

        # dictionary from terms to genes directly annotated with the term
        self.term_direct_gene_map = term_direct_gene_map

        self.term_visit_count = {}
        self.init_term_visits(term_size_map)
        
        # calculate the number of values in a state (term): term_size_map is the number of all genes annotated with the term
        self.term_dim_map = {}
        self.cal_term_dim(term_size_map)

        # ngenes, gene_dim are the number of all genes
        self.gene_dim = ngene

        # add modules for neural networks to process genotypes
        self.contruct_direct_gene_layer()
        self.construct_NN_graph(self.dG)

        # add modules for final layer TODO: modify it into VAE
        final_input_size = num_hiddens_genotype # + num_hiddens_drug[-1]
        self.add_module('final_linear_layer', nn.Linear(final_input_size, num_hiddens_final * 2))
        self.add_module('final_batchnorm_layer', nn.BatchNorm1d(num_hiddens_final * 2))
        self.add_module('final_aux_linear_layer', nn.Linear(num_hiddens_final * 2, 1))
        self.add_module('final_linear_layer_output', nn.Linear(1, 1))
        
        self.decoder_affine = nn.Linear(num_hiddens_final, n_class)

    def init_term_visits(self, term_size_map):
        
        for term in term_size_map:
            self.term_visit_count[term] = 0
    
    # calculate the number of values in a state (term)
    def cal_term_dim(self, term_size_map):

        for term, term_size in term_size_map.items():
            num_output = self.num_hiddens_genotype

            # log the number of hidden variables per each term
            num_output = int(num_output)
#            print("term\t%s\tterm_size\t%d\tnum_hiddens\t%d" % (term, term_size, num_output))
            self.term_dim_map[term] = num_output


    # build a layer for forwarding gene that are directly annotated with the term
    def contruct_direct_gene_layer(self):

        for term, gene_set in self.term_direct_gene_map.items():
            if len(gene_set) == 0:
                print('There are no directed asscoiated genes for', term)
                sys.exit(1)

            # if there are some genes directly annotated with the term, add a layer taking in all genes and forwarding out only those genes
            self.add_module(term+'_direct_gene_layer', nn.Linear(self.gene_dim, len(gene_set)))

    # start from bottom (leaves), and start building a neural network using the given ontology
    # adding modules --- the modules are not connected yet
    def construct_NN_graph(self, dG):

        self.term_layer_list = []   # term_layer_list stores the built neural network
        self.term_neighbor_map = {}

        # term_neighbor_map records all children of each term
        for term in dG.nodes():
            self.term_neighbor_map[term] = []
            for child in dG.neighbors(term):
                self.term_neighbor_map[term].append(child)

        while True:
            leaves = [n for n in dG.nodes() if dG.out_degree(n) == 0]
            #leaves = [n for n,d in dG.out_degree().items() if d==0]
            #leaves = [n for n,d in dG.out_degree() if d==0]

            if len(leaves) == 0:
                break

            self.term_layer_list.append(leaves)

            for term in leaves:

                # input size will be #chilren + #genes directly annotated by the term
                input_size = 0

                for child in self.term_neighbor_map[term]:
                    input_size += self.term_dim_map[child]

                if term in self.term_direct_gene_map:
                    input_size += len(self.term_direct_gene_map[term])

                # term_hidden is the number of the hidden variables in each state
                term_hidden = self.term_dim_map[term]

                self.add_module(term+'_linear_layer', nn.Linear(input_size, term_hidden))
                self.add_module(term+'_batchnorm_layer', nn.BatchNorm1d(term_hidden))
                self.add_module(term+'_aux_linear_layer1', nn.Linear(term_hidden, self.n_class))
                self.add_module(term+'_aux_linear_layer2', nn.Linear(self.n_class, self.n_class))

            dG.remove_nodes_from(leaves)


    # definition of encoder
    def encoder(self, x):
        gene_input = x.narrow(1, 0, self.gene_dim)

        # define forward function for genotype dcell #############################################
        term_gene_out_map = {}

        for term, _ in self.term_direct_gene_map.items():
            term_gene_out_map[term] = self._modules[term + '_direct_gene_layer'](gene_input)

        term_NN_out_map = {}
        aux_out_map = {}

        for i, layer in enumerate(self.term_layer_list):

            for term in layer:

                child_input_list = []

                self.term_visit_count[term] += 1
                
                for child in self.term_neighbor_map[term]:
                    child_input_list.append(term_NN_out_map[child])

                if term in self.term_direct_gene_map:
                    child_input_list.append(term_gene_out_map[term])

                child_input = torch.cat(child_input_list,1)

                term_NN_out = self._modules[term+'_linear_layer'](child_input)

                Tanh_out = torch.tanh(term_NN_out)
                term_NN_out_map[term] = self._modules[term+'_batchnorm_layer'](Tanh_out)
                aux_layer1_out = torch.tanh(self._modules[term+'_aux_linear_layer1'](term_NN_out_map[term]))
                aux_out_map[term] = self._modules[term+'_aux_linear_layer2'](aux_layer1_out)

        # connect two neural networks at the top #################################################
        final_input = term_NN_out_map[self.root] # torch.cat((term_NN_out_map[self.root], drug_out), 1)

        out = self._modules['final_batchnorm_layer'](torch.tanh(self._modules['final_linear_layer'](final_input)))
        term_NN_out_map['final'] = out

        aux_layer_out = torch.tanh(self._modules['final_aux_linear_layer'](out))
        aux_out_map['final'] = self._modules['final_linear_layer_output'](aux_layer_out)

        return aux_out_map, term_NN_out_map
    
    def forward(self, x):
        
        aux_out_map, term_NN_out_map = self.encoder(x)
        
        mu = term_NN_out_map['final'][..., :self.num_hiddens_final]
        log_var = term_NN_out_map['final'][..., :self.num_hiddens_final]  # T X batch X z_dim
        std_dec = log_var.mul(0.5).exp_()
        # std_dec = 1
        
        latent = MultivariateNormal(loc = mu, 
                                    scale_tril=torch.diag_embed(std_dec))
        z = latent.rsample()
        
        recon_mean = self.decoder_affine(z)
        logits = F.softmax(recon_mean, -1)

        return logits #, mu, log_var, aux_out_map, term_NN_out_map
    
    def loss_log_vae(self, logits, y, mu, log_var, beta = 0.001):
        # y: true labels
        ori_y_shape = y.shape
        
        class_loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), 
                                     y.reshape(-1), reduction = 'none').div(np.log(2)).view(*ori_y_shape)
        
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), 
                              dim = -1)
        
        log_loss = class_loss + beta * KLD
        log_loss = torch.mean(torch.logsumexp(log_loss, 0))
        
        return log_loss
    
    def intermediate_loss(self, aux_out_map, y):
        
        inter_loss = 0
        for name, output in aux_out_map.items():
            if name == 'final':
                inter_loss += 0
            else: # change 0.2 to smaller one for big terms
                ori_y_shape = y.shape
        
                term_loss = F.cross_entropy(output.view(-1, output.shape[-1]), 
                                             y.reshape(-1), 
                                             reduction = 'none').div(np.log(2)).view(*ori_y_shape)
                inter_loss += term_loss

        return inter_loss

In [16]:
model = torch.load("model_200.pt", map_location=DEVICE)

In [101]:
def prepare_node_edge(model, data, target = 0):
    
    ## Generate node list
    layer_ig_dict = {}
    node_label = []
    node_col = []
    node_level = []
    node_x = []
    node_y = []
    
    # Loop through each layer
    for idx, leaves in enumerate(model.term_layer_list):
        term_list = []
        # add leaves in the layer to node list, with color, level and x axis for plot
        node_label.extend(leaves)
        node_col.extend([px.colors.qualitative.Plotly[idx]] * (len(leaves) + 1))
        node_level.extend([idx] * (len(leaves) + 1))
        node_x.extend([(idx )/len(model.term_layer_list)] * (len(leaves) + 1))
        
        # Prepare module list for LayerIG to explain the whole level at once
        for idx_leaf, term in enumerate(leaves):
            term_list.append(model._modules[term+'_linear_layer'])
            node_y.append((idx_leaf)/(len(leaves) + 1))
        
        layer_ig_dict[idx] = LayerIntegratedGradients(model, term_list)

        node_label.append('L' + str(idx) + '_other')
        node_y.append((idx_leaf+1)/(len(leaves) + 1))
    
    node_id_label = {i: l for i, l in enumerate(node_label)}
    node_label_id = {l: i for i, l in enumerate(node_label)}
    
    ## Generate edges
    
    line_col = np.array(['#eb675e', '#d6d6d6', '#3db1ff'])
    
    torch.manual_seed(0)
    np.random.seed(0)

    sankey_source = []
    sankey_target = []
    sankey_value = []
    sankey_color = []
    gene_attr_dict = {}

    for idx, leaves in enumerate(model.term_layer_list[1:]):
        # Explain the layer idx+1
        layer_attr_list = layer_ig_dict[idx+1].attribute(data.to(DEVICE), target=target, 
                                    attribute_to_layer_input=True)
        
        for idx_term, term in tqdm.tqdm(enumerate(leaves), desc="Level"+str(idx+1)):
            # For each leaf at the layer, get its explanation and list of children
            child_list_tmp = model.term_neighbor_map[term]
            term_attr = layer_attr_list[idx_term].detach().cpu()
            loc = 0
            
            # Loop through the list of children, for each child, find the location of the explanation
            # Here, the child is sorted, and the length of vector for the child is defined in model.term_dim_map[child_tmp]
            for child_tmp in child_list_tmp:
                attr_child = term_attr[:, loc:(loc + model.term_dim_map[child_tmp])]
                loc += model.term_dim_map[child_tmp]
                
                weight_child = attr_child.sum(dim=1).cpu().numpy()
                
                sankey_source.append(node_label_id[child_tmp])
                sankey_target.append(node_label_id[term])
                sankey_value.append(weight_child)
                sankey_color.append(line_col[(np.sign(weight_child) + 1).astype('int')])

            gene_attr_dict[term] = term_attr[:,loc:]

    sankey_source = np.array(sankey_source)
    sankey_target = np.array(sankey_target)
    sankey_value = np.array(sankey_value)
    sankey_color = np.array(sankey_color) 
    
    ## Simplify the plot
    
    # sankey_value_id = np.copy(sankey_value[:,sample_id])
    # sankey_source_id = np.copy(sankey_source)
    # sankey_target_id = np.copy(sankey_target)
    
    # level_argsort = []
    # for idx, leaves in enumerate(model.term_layer_list[:-1]):
    #     tmp_leaves_total_attr = []
    #     for term in leaves:
    #         tmp_leaves_total_attr.append(np.sum(np.abs(sankey_value_id[sankey_source_id == node_label_id[term]])))
    #     level_argsort.append(np.argsort(np.abs(tmp_leaves_total_attr)))  # index from smallest to largest

    # for idx, pruned_leaves in enumerate(level_argsort):
    #     for pruned_leaf_idx in pruned_leaves[:int((1-percent_keep[idx])*len(pruned_leaves))]:
    #         term = model.term_layer_list[idx][pruned_leaf_idx]
    #         sankey_source_id[sankey_source_id == node_label_id[term]] = node_label_id['L' + str(idx) + '_other']
    #         sankey_target_id[sankey_target_id == node_label_id[term]] = node_label_id['L' + str(idx) + '_other']
    
    # def list_duplicates(seq):
    #     tally = defaultdict(list)
    #     for i,item in enumerate(seq):
    #         tally[item].append(i)
    #     return ((key,locs) for key,locs in tally.items() 
    #                             if len(locs)>1)

    # loc_delete_list = []
    # for pair, loc in list_duplicates(zip(sankey_source_id, sankey_target_id)):
    #     total_attr = np.sum(sankey_value_id[loc])
    #     sankey_value_id[loc] = total_attr
    #     loc_delete_list.extend(loc[1:])

    # sankey_source_id = np.delete(sankey_source_id, loc_delete_list)
    # sankey_target_id = np.delete(sankey_target_id, loc_delete_list)
    # sankey_value_id = np.delete(sankey_value_id, loc_delete_list)

    # sankey_color = [line_col[int(term_sign + 1)] for term_sign in np.sign(sankey_value_id)]
    
    return sankey_value, sankey_source, sankey_target,\
        node_label_id, node_label, node_col, node_level

In [102]:
def sankey_trimming(sankey_value, sankey_source, sankey_target, 
                    node_label_id,
                    sample_id, 
                    percent_keep = [0.01, 0.02, 0.02, 0.02, 0.1]):
    ## Simplify the plot
        
    line_col = np.array(['#eb675e', '#d6d6d6', '#3db1ff'])    

    sankey_value_id = np.copy(sankey_value[:,sample_id])
    sankey_source_id = np.copy(sankey_source)
    sankey_target_id = np.copy(sankey_target)
    
    level_argsort = []
    for idx, leaves in enumerate(model.term_layer_list[:-1]):
        # For leaves in the layer, go through its edge values
        # The iteration starts with the lowest level (finest), and we look at edges start with the leaf
        # Then, the edges are ordered based on the absolute values of the edge values.
        # sort is from smallest to the largest
        tmp_leaves_total_attr = []
        for term in leaves:
            tmp_leaves_total_attr.append(np.sum(np.abs(sankey_value_id[sankey_source_id == node_label_id[term]])))
        level_argsort.append(np.argsort(np.abs(tmp_leaves_total_attr)))  # index from smallest to largest

    for idx, pruned_leaves in enumerate(level_argsort):
        # Here, we rename the edges with the smallest values to other, level_argsort includes the index of source 
        # nodes ordered based on the value of edges.
        # Both source and target will be renamed to other, then merged in the following steps.
        for pruned_leaf_idx in pruned_leaves[:int((1-percent_keep[idx])*len(pruned_leaves))]:
            term = model.term_layer_list[idx][pruned_leaf_idx]
            sankey_source_id[sankey_source_id == node_label_id[term]] = node_label_id['L' + str(idx) + '_other']
            sankey_target_id[sankey_target_id == node_label_id[term]] = node_label_id['L' + str(idx) + '_other']
    
    def list_duplicates(seq):
        # Use this function to list duplicate pairs of source-target
        tally = defaultdict(list)
        for i,item in enumerate(seq):
            tally[item].append(i)
        return ((key,locs) for key,locs in tally.items() 
                                if len(locs)>1)

    loc_delete_list = []
    for pair, loc in list_duplicates(zip(sankey_source_id, sankey_target_id)):
        # For each of the duplicated pair, assign the sum of edge value to all of them based on location, 
        # then remove all but the first duplicate pair
        total_attr = np.sum(sankey_value_id[loc])
        sankey_value_id[loc] = total_attr
        loc_delete_list.extend(loc[1:])

    sankey_source_id = np.delete(sankey_source_id, loc_delete_list)
    sankey_target_id = np.delete(sankey_target_id, loc_delete_list)
    sankey_value_id = np.delete(sankey_value_id, loc_delete_list)

    sankey_color = [line_col[int(term_sign + 1)] for term_sign in np.sign(sankey_value_id)]
    
    return sankey_value_id, sankey_source_id, sankey_target_id

In [116]:
def plot_sankey(sankey_value_id, sankey_source_id, sankey_target_id,
                node_label, node_col, node_level,
                filename:str = 'test',
                renderers_default = "vscode", positive_flow = True,
                font_size = 15):
    
    pio.renderers.default = renderers_default
    
    if positive_flow:
        selected_edge_list = sankey_value_id >0
        sankey_plot_value_id = np.copy(sankey_value_id[selected_edge_list])
    else:
        selected_edge_list = sankey_value_id <0
        sankey_plot_value_id = -np.copy(sankey_value_id[selected_edge_list])

    node_tmp_list = np.concatenate((sankey_source_id[selected_edge_list], 
                                    sankey_target_id[selected_edge_list]))
    sankey_plot_source_id = np.copy(sankey_source_id[selected_edge_list])
    sankey_plot_target_id = np.copy(sankey_target_id[selected_edge_list])
    


    level_counter = {idx:0 for idx in range(6)}
    sankey_node_label = []
    sankey_node_id = []
    sankey_node_origin_id = []
    sankey_node_col = []
    sankey_node_x = []
    sankey_node_y = []
    sankey_node_level = []

    id_current = 0
    for node_index in node_tmp_list:
        if node_index not in sankey_node_origin_id:
            sankey_node_label.append(node_label[node_index])
            sankey_node_id.append(id_current)
            sankey_node_origin_id.append(node_index)
            sankey_node_col.append(node_col[node_index])
            sankey_node_level.append(node_level[node_index])
            
            sankey_plot_source_id[sankey_source_id[selected_edge_list] == node_index] = id_current
            sankey_plot_target_id[sankey_target_id[selected_edge_list] == node_index] = id_current
            
            level_counter[node_level[node_index]] += 1
            sankey_node_x.append(node_level[node_index]/(max(node_level) + 1))
            sankey_node_y.append(level_counter[node_level[node_index]])
            
            id_current += 1

    sankey_node_x = np.array(sankey_node_x).astype('float')
    sankey_node_y = np.array(sankey_node_y).astype('float')
    sankey_node_level = np.array(sankey_node_level).astype('float')
    
    for level in range(6):
        sankey_node_y[sankey_node_level == level] = sankey_node_y[sankey_node_level == level]/np.max(sankey_node_y[sankey_node_level == level] )
    sankey_node_y[-1] = 0.5
    
    fig = go.Figure(data=[go.Sankey(
        arrangement = "snap",
        node = dict(
            pad = 15,
            thickness = 20,
            # line = dict(color = "black", width = 0.5),
            label = sankey_node_label, #node_label,
            color = sankey_node_col, #node_col,
            # groups = node_groups,
            x = sankey_node_x + 0.01,  #(np.array(node_x)+0.0001),
            y = sankey_node_y - 0.01 # [0.2] * len(sankey_node_x.tolist())   # ,  #(np.array(node_y)+0.0001)
            ),
            link = dict(
            source = sankey_plot_source_id, #, # np.array(sankey_source_id), # indices correspond to labels, eg A1, A2, A1, B1, ...
            target = sankey_plot_target_id, # sankey_target_id,
            value = sankey_plot_value_id,
            # color = sankey_color
            )
        )])

    fig.update_layout(font_size=font_size)
    fig.write_html(filename + ".html")
    fig.show()
    
    return fig

## Draw the plot

In [126]:
torch.manual_seed(0)

train_loader = DataLoader(training_set, batch_size=64, shuffle=True)
test_loader = DataLoader(testing_set, batch_size=256, shuffle=False)
data, labels = testing_set[0:]  # 256

In [127]:
labels.shape

torch.Size([2986])

In [128]:
labels_str = [idx_2_cancer[idx.item()] for idx in labels]
(list(zip(*np.unique(labels_str, return_counts=True))))

[('ACC', 23),
 ('BLCA', 115),
 ('BRCA', 332),
 ('CESC', 90),
 ('CHOL', 11),
 ('COAD', 127),
 ('DLBC', 19),
 ('ESCA', 50),
 ('GBM', 39),
 ('HNSC', 155),
 ('KICH', 28),
 ('KIRC', 148),
 ('KIRP', 86),
 ('LAML', 44),
 ('LGG', 148),
 ('LIHC', 101),
 ('LUAD', 171),
 ('LUSC', 140),
 ('MESO', 28),
 ('OV', 83),
 ('PAAD', 58),
 ('PCPG', 55),
 ('PRAD', 138),
 ('READ', 49),
 ('SARC', 72),
 ('SKCM', 139),
 ('STAD', 112),
 ('TGCT', 48),
 ('THCA', 152),
 ('THYM', 36),
 ('UCEC', 160),
 ('UCS', 14),
 ('UVM', 15)]

In [22]:
sankey_value, sankey_source, sankey_target,\
        node_label_id, node_label, node_col, node_level\
                = prepare_node_edge(model, data, target=labels.to(DEVICE))

Level1: 357it [00:00, 3888.29it/s]
Level2: 183it [00:00, 3721.94it/s]
Level3: 115it [00:00, 3057.85it/s]
Level4: 66it [00:00, 2991.04it/s]
Level5: 1it [00:00, 76.81it/s]


In [40]:
sankey_value_id, sankey_source_id, sankey_target_id \
    = sankey_trimming(sankey_value, sankey_source, sankey_target, 
                    node_label_id,
                    sample_id = 38, 
                    percent_keep = [0.02, 0.04, 0.04, 0.06, 0.1])

In [41]:
tmp = plot_sankey(sankey_value_id, sankey_source_id, sankey_target_id,
                node_label, node_col, node_level)

In [48]:
model.term_neighbor_map['GO:1902533']

['GO:0014068',
 'GO:0043123',
 'GO:0043410',
 'GO:0043950',
 'GO:0050850',
 'GO:0051897',
 'GO:0070304',
 'GO:1901224',
 'GO:2001244',
 'GO:0046579']

### At disease level

In [150]:
test_data, test_labels = testing_set[0:]
labels_str = [idx_2_cancer[idx.item()] for idx in test_labels]
data_ov, labels_ov = testing_set[np.nonzero(np.array(labels_str) == 'OV')[0].tolist()]

data_brca, labels_brca = testing_set[np.nonzero(np.array(labels_str) == 'BRCA')[0].tolist()]



torch.Size([83, 3008])

In [232]:
(inputdata, labels) = testing_set[:280]
inputdata = torch.cat((inputdata, data_ov))
labels = torch.cat((labels, labels_ov))

labels_str = [idx_2_cancer[idx.item()] for idx in labels]
list(zip(*np.unique(labels_str, return_counts=True)))

[('ACC', 2),
 ('BLCA', 16),
 ('BRCA', 27),
 ('CESC', 10),
 ('CHOL', 1),
 ('COAD', 18),
 ('DLBC', 1),
 ('ESCA', 2),
 ('GBM', 3),
 ('HNSC', 12),
 ('KICH', 1),
 ('KIRC', 15),
 ('KIRP', 7),
 ('LAML', 3),
 ('LGG', 16),
 ('LIHC', 4),
 ('LUAD', 20),
 ('LUSC', 7),
 ('MESO', 5),
 ('OV', 90),
 ('PAAD', 5),
 ('PCPG', 5),
 ('PRAD', 11),
 ('READ', 4),
 ('SARC', 12),
 ('SKCM', 11),
 ('STAD', 10),
 ('TGCT', 4),
 ('THCA', 11),
 ('THYM', 4),
 ('UCEC', 22),
 ('UVM', 4)]

In [233]:
torch.manual_seed(0)
logits = model(inputdata.to(DEVICE))

In [235]:
pred_res = torch.argmax(logits, 1).cpu() == labels
torch.sum(pred_res)/len(labels)

tensor(0.8843)

In [236]:
accu_per_type = {}

for cancer_type in range(33):
    accu_per_type[idx_2_cancer[cancer_type]] = (pred_res[labels == cancer_type]).float().mean().item()

In [237]:
accu_per_type

{'BRCA': 1.0,
 'LUAD': 0.75,
 'DLBC': 0.0,
 'UCEC': 0.8636363744735718,
 'SKCM': 1.0,
 'PRAD': 1.0,
 'HNSC': 1.0,
 'KIRP': 0.7142857313156128,
 'CESC': 0.0,
 'THCA': 1.0,
 'KIRC': 1.0,
 'STAD': 1.0,
 'COAD': 0.9444444179534912,
 'READ': 0.0,
 'LGG': 1.0,
 'MESO': 0.0,
 'LAML': 1.0,
 'BLCA': 0.9375,
 'OV': 0.9888888597488403,
 'LUSC': 1.0,
 'ACC': 0.0,
 'THYM': 1.0,
 'ESCA': 1.0,
 'PAAD': 0.800000011920929,
 'LIHC': 1.0,
 'SARC': 1.0,
 'GBM': 1.0,
 'TGCT': 1.0,
 'KICH': 0.0,
 'PCPG': 1.0,
 'UCS': nan,
 'UVM': 0.0,
 'CHOL': 0.0}

In [238]:
sankey_value, sankey_source, sankey_target, \
        node_label_id, node_label, node_col, node_level\
                = prepare_node_edge(model, inputdata, target=labels.to(DEVICE))


Multiple layers provided. Please ensure that each layer is**not** solely dependent on the outputs ofanother layer. Please refer to the documentation for moredetail.

Level1: 357it [00:00, 4890.51it/s]
Level2: 183it [00:00, 3067.79it/s]
Level3: 115it [00:00, 2604.65it/s]
Level4: 66it [00:00, 2048.64it/s]
Level5: 1it [00:00, 52.78it/s]


In [239]:
sankey_value[:, labels == 18].shape

(3167, 90)

In [240]:
node_label[2068]

'GO:0009725'

In [241]:
sankey_source[2984]

2068

In [242]:
sankey_target[2984]

2090

In [243]:
node_label[2090]

'GO:0008150'

In [244]:
sankey_value[2984,:].mean()

0.012110865949486897

In [245]:
cancer_2_idx['OV']

18

In [246]:
sankey_value_id_ov, sankey_source_id_ov, sankey_target_id_ov\
    = sankey_trimming(sankey_value[:, labels == cancer_2_idx['OV']].mean(axis=1, keepdims=True), 
                      sankey_source, sankey_target, 
                    node_label_id, 
                    sample_id = 0, 
                    percent_keep = [0.015, 0.03, 0.07, 0.1, 0.2])

In [247]:
ov_sankey = plot_sankey(sankey_value_id_ov, sankey_source_id_ov, sankey_target_id_ov,
                node_label, node_col, node_level, positive_flow=True)

In [250]:
sankey_value_id_brca, sankey_source_id_brca, sankey_target_id_brca\
    = sankey_trimming(sankey_value[:, labels == cancer_2_idx['BRCA']].mean(axis=1, keepdims=True), 
                      sankey_source, sankey_target, 
                    node_label_id, 
                    sample_id = 0, 
                    percent_keep = [0.03, 0.05, 0.06, 0.06, 0.15])

In [251]:
brca_sankey = plot_sankey(sankey_value_id_brca, sankey_source_id_brca, sankey_target_id_brca,
                node_label, node_col, node_level, positive_flow=True)

### For L0 level importance

In [252]:
for idx, leaves in enumerate(model.term_layer_list):
    term_list = []
    
    for idx_leaf, term in enumerate(leaves):
        term_list.append(model._modules[term+'_linear_layer'])
    
    layer_ig_dict[idx] = LayerIntegratedGradients(model, term_list)

layer_attr_list = layer_ig_dict[0].attribute(data.to(DEVICE), target=target, 
                                    attribute_to_layer_input=True)

[['GO:0007006',
  'GO:0008637',
  'GO:0006284',
  'GO:0006283',
  'GO:0019985',
  'GO:0000724',
  'GO:0006303',
  'GO:0044030',
  'GO:0010569',
  'GO:0045830',
  'GO:0045739',
  'GO:2000279',
  'GO:0032212',
  'GO:0051973',
  'GO:0000070',
  'GO:0000083',
  'GO:0000281',
  'GO:0000038',
  'GO:0019372',
  'GO:0006635',
  'GO:0043651',
  'GO:0042759',
  'GO:0046457',
  'GO:0006826',
  'GO:0097711',
  'GO:0090307',
  'GO:0006656',
  'GO:0009309',
  'GO:0033014',
  'GO:0042398',
  'GO:0042423',
  'GO:0043043',
  'GO:0046112',
  'GO:0072525',
  'GO:1901687',
  'GO:0046513',
  'GO:0006497',
  'GO:0009066',
  'GO:1901607',
  'GO:0009065',
  'GO:0006536',
  'GO:0006541',
  'GO:0006658',
  'GO:0006749',
  'GO:0006760',
  'GO:0043303',
  'GO:0051303',
  'GO:0047496',
  'GO:0007050',
  'GO:0045736',
  'GO:0031572',
  'GO:0043407',
  'GO:0043506',
  'GO:0045737',
  'GO:0000086',
  'GO:0034401',
  'GO:0043620',
  'GO:2000142',
  'GO:0000096',
  'GO:0044273',
  'GO:0050427',
  'GO:0035384',
  'GO:00