In [1]:
import socket
# On the 'dlm' host only: enumerate CUDA devices in PCI bus order and expose only device 0.
# This cell runs before torch is imported (next cell); CUDA_VISIBLE_DEVICES is typically
# only honored if it is set before CUDA is initialized, so keep this cell first.
if socket.gethostname() == 'dlm':
  %env CUDA_DEVICE_ORDER=PCI_BUS_ID
  %env CUDA_VISIBLE_DEVICES=0

In [2]:
import os
import sys
import re
import collections
import functools
import itertools
import requests, zipfile, io
import pickle
import copy
import time

import pandas
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import sklearn
import sklearn.decomposition
import sklearn.metrics
import networkx

import torch
import torch.nn as nn

lib_path = 'I:/code'
if not os.path.exists(lib_path):
  lib_path = '/media/6T/.tianle/.lib'
if not os.path.exists(lib_path):
  lib_path = '/projects/academic/azhang/tianlema/lib'
if os.path.exists(lib_path) and lib_path not in sys.path:
  sys.path.append(lib_path)

from dl.models.transformer import MultiheadAttention, EncoderAttention, DecoderAttention, Transformer, StackedEncoder
from dl.models.dag import *
from dl.models.basic_models import *
from dl.models.factor_graph import BipartiteGraph1d, BipartiteGraph, EmbedCell, GeneNet
from dl.utils.visualization.visualization import *
from dl.utils.outlier import *
from dl.utils.train import *
from autoencoder.autoencoder import *
from vin.vin import *
from dl.utils.utils import get_overlap_samples, filter_clinical_dict, get_target_variable
from dl.utils.utils import get_shuffled_data, target_to_numpy, discrete_to_id, get_mi_acc
from dl.utils.utils import get_label_distribution, normalize_continuous_variable, tensor_to_num
from dl.utils.utils import adj_list_to_mat, adj_list_to_attention_mats, count_model_parameters

%load_ext autoreload
%autoreload 2


use_gpu = True
# Run on CUDA when it is both requested (use_gpu) and available; otherwise use the CPU.
device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
print('Using GPU:)' if device.type == 'cuda' else 'Using CPU:(')

# Shorthand for positive infinity.
inf = float('Inf')

Using GPU:)


In [3]:
overwrite = False
save_filepath = 'htseq_cnt_go_clinical.pkl'
if overwrite or not os.path.exists(save_filepath):
  # The variables below are not created anywhere in this notebook; they must already exist in
  # the kernel. Build the payload BEFORE touching the file, so a missing variable raises
  # before an existing cache is truncated or an empty cache file is created.
  try:
    cache_payload = {'htseq_cnt_mat': htseq_cnt_mat, 'gene_ids': gene_ids,
                     'aliquot_ids': aliquot_ids, 'go_gene_list': go_gene_list,
                     'go_edges': go_edges, 'string_ppi': ppi,
                     'percent_tumor_nuclei': percent_tumor_nuclei,
                     'sample_type': sample_type, 'diagnoses': diagnoses,
                     'clinical': clinical, 'cases': cases, 'survival_plot': survival_plot}
  except NameError as err:
    raise NameError(f'Cannot write {save_filepath}: {err}. These inputs are not created in '
                    'this notebook; define them first or provide an existing cache file.') from err
  print(f'Write file {save_filepath}')
  # Write to a temporary file and rename it into place, so an interrupted dump never leaves
  # a partial cache at save_filepath.
  tmp_filepath = save_filepath + '.tmp'
  with open(tmp_filepath, 'wb') as f:
    pickle.dump(cache_payload, f)
  os.replace(tmp_filepath, save_filepath)

# NOTE: pickle.load can execute arbitrary code; only load cache files produced by trusted code.
with open(save_filepath, 'rb') as f:
  data = pickle.load(f)

go_gene_list = data['go_gene_list']
go_edges = data['go_edges']
gene_ids = data['gene_ids']
htseq_cnt_mat = data['htseq_cnt_mat']
ppi = data['string_ppi']

# log2(count + 1) transform of the HTSeq count matrix.
mat = np.log2(htseq_cnt_mat+1)

In [4]:
num_gene = 5000           # number of most variable genes to keep
min_num_gene_per_go = 5   # minimum number of kept genes a GO term needs (used in the next cell)

# Per-gene statistics along axis 1; rows of mat line up with entries of gene_ids.
mean, std = mat.mean(axis=1), mat.std(axis=1)
# Row indices sorted by descending std, truncated to the top num_gene.
idx = np.argsort(-std)[:num_gene]
mat, gene_ids = mat[idx], gene_ids[idx]
print(mat.shape, gene_ids.shape)

(5000, 11417) (5000,)


In [5]:
# Iteratively prune genes and GO terms until a fixed point where every kept GO term has at
# least min_num_gene_per_go kept genes and every kept gene belongs to some kept GO term.
# Dropping GO terms can orphan genes, and dropping genes can shrink GO terms below the
# threshold, hence the loop. The body always runs at least once, so the size threshold is
# applied even if the initial GO gene union already equals gene_ids.
# set(chain(...)) is linear in the number of annotations and handles an empty go_gene_list
# (functools.reduce with list + is quadratic and raises TypeError on an empty sequence).
common_geneset = set(itertools.chain.from_iterable(go_gene_list.values()))
while True:
  common_geneset = common_geneset.intersection(gene_ids)
  # remove genes that are not in common_geneset
  go_gene_list = {k: sorted(common_geneset.intersection(v)) for k, v in go_gene_list.items()}
  # remove GO terms that have less than min_num_gene_per_go genes
  go_gene_list = {k: v for k, v in go_gene_list.items() if len(v) >= min_num_gene_per_go}
  go_ids = set(go_gene_list)
  # some genes may no longer appear in go_gene_list, remove them
  common_geneset_new = set(itertools.chain.from_iterable(go_gene_list.values()))
  # A boolean row mask keeps gene_ids' dtype and mat's row order, and works when nothing is
  # kept (an array of (index, name) pairs coerces indices to strings and breaks when empty).
  keep = np.array([g in common_geneset_new for g in gene_ids], dtype=bool)
  gene_ids = gene_ids[keep]
  mat = mat[keep]
  print(gene_ids.shape, mat.shape, len(common_geneset))
  if common_geneset == set(gene_ids):
    break

(4994,) (4994, 11417) 5000
(4994,) (4994, 11417) 4994


In [6]:
# Number of propagation steps used when building the attention matrices in the next cell.
num_steps = 4 # len(chain_graph_go)
# Map each retained gene name to its row index in mat / gene_ids.
name_to_id_gene = {n:i for i, n in enumerate(gene_ids)}
num_gene = len(gene_ids)

# Number of retained GO terms (including those without any GO-GO edge).
num_go = len(go_gene_list)
go_ids = set(go_gene_list)
# Keep only GO-GO edges whose two endpoints are both retained GO terms.
go_edges = go_edges[[s[0] in go_ids and s[1] in go_ids for s in go_edges]]
# Topologically order GO terms using the column-swapped edges (s[1], s[0]).
# Assumes get_topological_order returns ({name: id}, [terms per level]) with ids assigned
# contiguously, level by level, starting at 0 -- TODO confirm against dl.models.dag.
name_to_id_go, chain_graph_go = get_topological_order(go_edges[:,[1,0]])
# (level index, number of GO terms at that level)
print([(i, len(v)) for i, v in enumerate(chain_graph_go)])
# there are some isolated GO terms; include them as well in the model
# (they receive ids after every id returned by get_topological_order)
for go in go_gene_list:
  if go not in name_to_id_go:
    name_to_id_go[go] = len(name_to_id_go)
# prepare dag: dag[id of s[1]] = sorted unique ids of every s[0] sharing an edge with it
dag = collections.defaultdict(list)
for s in go_edges:
  left = name_to_id_go[s[0]]
  right = name_to_id_go[s[1]]
  dag[right].append(left)
dag = {k: sorted(set(v)) for k, v in dag.items()}
# prepare bigraph: bigraph[i] = sorted gene indices annotated to GO id i, for i < num_leaf_go
id_to_name_go = {i: n for n, i in name_to_id_go.items()}
# Smallest GO id that is a key of dag; ids below it are treated as leaf GO terms (relies on
# the level-by-level id assignment assumed above). Raises ValueError if go_edges is empty.
# NOTE(review): given that assumption, isolated GO terms have ids above every dag key, so
# they never get a bigraph row -- confirm this is intended.
num_leaf_go = min(dag)
bigraph = []
for i in range(num_leaf_go):
  bigraph.append(sorted([name_to_id_gene[v] for v in go_gene_list[id_to_name_go[i]]]))
print(num_go, num_gene, num_steps)

[(0, 1194), (1, 467), (2, 214), (3, 100), (4, 59), (5, 28), (6, 10), (7, 4), (8, 3), (9, 372)]
2696 4994 4


In [7]:
# Size the GO axis by name_to_id_go (every GO id assigned so far) rather than num_go: a later
# cell in this notebook reassigns num_go to a smaller value, and re-running this cell after it
# must still size gene_go_adj_mat for every id. On a fresh run both values are equal.
num_go_total = len(name_to_id_go)
attention_mats_filepath = f'attention_mats-{num_gene}-{num_go_total}-{num_steps}-GeneNet.pkl'
overwrite = True             # recompute even when attention_mats_filepath already exists
save_attention_mats = False  # also write the computed matrices to attention_mats_filepath
if overwrite or not os.path.exists(attention_mats_filepath):
  # Weighted gene-gene adjacency from the PPI list: entry (left, right) = s[2] / 1000
  # (presumably a STRING combined score in [0, 1000] -- TODO confirm); diagonal set to 1.
  gene_gene_adj_mat = np.zeros((num_gene, num_gene))
  for s in ppi:
    if s[0] in name_to_id_gene and s[1] in name_to_id_gene:
      left = name_to_id_gene[s[0]]
      right = name_to_id_gene[s[1]]
      gene_gene_adj_mat[left, right] = float(s[2])/1000
  gene_gene_adj_mat[range(num_gene), range(num_gene)] = 1

  # Directed GO-GO adjacency over name_to_id_go ids, without self-loops.
  go_go_adj_mat, _ = adj_list_to_mat(go_edges, name_to_id=name_to_id_go, bipartite=False,
                                     add_self_loop=False, symmetric=False, return_list=False)

  # Gene-to-GO membership: 1 where gene row gene_id is annotated to GO column go_id.
  gene_go_adj_mat = np.zeros((num_gene, num_go_total))
  for k, v in go_gene_list.items():
    go_id = name_to_id_go[k]
    for g in v:
      gene_id = name_to_id_gene[g]
      gene_go_adj_mat[gene_id, go_id] = 1

  attention_mats = {}
  start_time = time.time()
  attention_mats['gene1->gene0'], id_to_name_gene = adj_list_to_attention_mats(
    adj_list=None, num_steps=num_steps, name_to_id=name_to_id_gene, bipartite=False,
    add_self_loop=True, symmetric=True, target_to_source=None, use_transition_matrix=True,
    Ms=gene_gene_adj_mat, softmax_normalization=False, min_value=-100, device=device)

  attention_mats['pathway1->pathway0'], id_to_name_go = adj_list_to_attention_mats(
    adj_list=None, num_steps=num_steps, name_to_id=name_to_id_go, bipartite=False,
    add_self_loop=False, symmetric=False, target_to_source=None, use_transition_matrix=True,
    Ms=go_go_adj_mat, softmax_normalization=False, min_value=-100, device=device)

  mats, _ = adj_list_to_attention_mats(
    adj_list=None, num_steps=num_steps*2, name_to_id=[name_to_id_gene, name_to_id_go],
    bipartite=True, add_self_loop=False, symmetric=False, target_to_source=None,
    use_transition_matrix=True, Ms=gene_go_adj_mat, softmax_normalization=False, min_value=-100,
    device=device)
  # this is very tricky:
  # the even positions are all gene->pathway in mats[0], while odd ones gene->gene
  # (with num_steps*2 steps, keeping the even positions yields num_steps matrices per key)
  attention_mats['gene0->pathway1'] = [m for i, m in enumerate(mats[0]) if i%2==0]
  attention_mats['pathway0->gene1'] = [m for i, m in enumerate(mats[1]) if i%2==0]

  # CUDA kernels run asynchronously; wait for them so the timing covers the actual work.
  if device.type == 'cuda':
    torch.cuda.synchronize()
  end_time = time.time()
  print(f'Time spent on generating attention_mats {end_time - start_time} s')

  if save_attention_mats:
    start_time = time.time()
    # Pickle CPU/numpy copies so the live (possibly GPU) tensors in attention_mats stay usable.
    attention_mats_cpu = {k: [m.detach().cpu().numpy() if torch.is_tensor(m) else m for m in v]
                          for k, v in attention_mats.items()}
    with open(attention_mats_filepath, 'wb') as f:
      pickle.dump(attention_mats_cpu, f)
    end_time = time.time()
    print(f'Time spent on writing attention_mats {end_time - start_time} s')
else:
  # Reuse a cache written by the branch above. pickle.load can execute arbitrary code, so only
  # load files this notebook produced.
  # NOTE: id_to_name_gene is only defined by the compute branch, and id_to_name_go keeps the
  # value built in the previous cell.
  start_time = time.time()
  with open(attention_mats_filepath, 'rb') as f:
    attention_mats = pickle.load(f)
  attention_mats = {k: [torch.as_tensor(m).float().to(device) for m in v]
                    for k, v in attention_mats.items()}
  end_time = time.time()
  print(f'Time spent on loading attention_mats {end_time - start_time} s')

Graph with 4994 nodes and 147228 edges
Graph with 2696 nodes and 2815 edges
Bipartite graph: in_features=4994, out_features=2696
Time spent on generating attention_mats 3.1285510063171387 s


In [8]:
# use_dag_layer = False
# dag_in_channel_list = [1,1,1]
# dag_kwargs = {'residual':True, 'duplicate_dag':True}
# batch_size = 500
# model = GeneNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
#                 use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
#                 dag_kwargs=dag_kwargs,
#                 nonlinearity=nn.ReLU(), use_layer_norm=True).to(device)

# x = torch.randn(batch_size, num_gene).to(device)

# start_time = time.time()
# y = model(x, attention_mats=attention_mats, max_num_layers=num_steps, min_num_layers=num_steps, 
#           return_layers='all')
# end_time = time.time()
# print(f'Time spent on forward pass {end_time - start_time} s')
# print(y[0].shape, y[1].shape, y[2].shape, y[3].shape)

# start_time = time.time()
# loss = y[0].sum() + y[1].sum() + y[2].sum() + y[3].sum()
# loss.backward()
# end_time = time.time()
# print(f'Time spent on backward pass {end_time - start_time} s')

In [9]:
# use_dag_layer = True
# dag_in_channel_list = [1]
# dag_kwargs = {'residual':True, 'duplicate_dag':True}
# batch_size = 500
# model = GeneNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
#                 use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
#                 dag_kwargs=dag_kwargs,
#                 nonlinearity=nn.ReLU(), use_layer_norm=True).to(device)

# x = torch.randn(batch_size, num_gene).to(device)

# start_time = time.time()
# y = model(x, attention_mats=attention_mats, max_num_layers=num_steps, min_num_layers=num_steps, 
#           return_layers='all')
# end_time = time.time()
# print(f'Time spent on forward pass {end_time - start_time} s')
# print(y[0].shape, y[1].shape, y[2].shape, y[3].shape)

# start_time = time.time()
# loss = y[0].sum() + y[1].sum() + y[2].sum() + y[3].sum()
# loss.backward()
# end_time = time.time()
# print(f'Time spent on backward pass {end_time - start_time} s')

In [10]:
# Hyper-parameters for a small DAGEncoder smoke test (forward + backward on random input).
embedding_dim = 5
in_channels_list = [5]
key_dim = 5
value_dim = 5
fc_dim = 5
dim_per_cls = 1
num_heads = 1
num_attention = 1
# NOTE: this overwrites num_go (len(go_gene_list) in an earlier cell) with 1 + the largest GO
# id that is a key of dag; GO ids above it (presumably the isolated terms appended after the
# topological ordering) are excluded from graph_encoder below.
num_go = max(dag)+1
print(num_go)
# Top-left num_go x num_go block of each pathway->pathway attention matrix.
graph_encoder = [m[:num_go, :num_go] for m in attention_mats['pathway1->pathway0']]
graph_weight_encoder = 0.5
knn = 50

residual = True
duplicate_dag = True
gibbs_sampling = True
duplicated_attention = True
feature_max_norm = 1
use_layer_norm = True
bias = True
nonlinearity = nn.ReLU()

batch_size = 15
x = torch.randn(batch_size, num_gene).to(device)
num_features = x.size(1)
num_cls = 2

model = DAGEncoder(num_features=num_features, embedding_dim=embedding_dim,
         in_channels_list=in_channels_list,
         bigraph=bigraph, dag=dag, key_dim=key_dim, value_dim=value_dim, fc_dim=fc_dim,
         num_cls=num_cls, dim_per_cls=dim_per_cls, feature_max_norm=feature_max_norm,
         use_layer_norm=use_layer_norm, bias=bias,
         nonlinearity=nonlinearity, residual=residual, duplicate_dag=duplicate_dag,
         gibbs_sampling=gibbs_sampling, num_heads=num_heads, num_attention=num_attention,
         knn=knn, duplicated_attention=duplicated_attention,
         graph_encoder=None, graph_weight_encoder=graph_weight_encoder,
         graph_decoder=None, graph_weight_decoder=0.5, use_encoders=True).to(device)

start_time = time.time()
y = model(x, graph_encoder=graph_encoder)
# CUDA kernels launch asynchronously; synchronize so the measured time covers the work itself.
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = y.sum()
loss.backward()
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')

2451
