In [None]:
# Host-specific GPU setup: on the 'dlm' machine, number CUDA devices by PCI bus order
# and expose only device 0. These variables only take effect if they are set before
# CUDA is initialized, which is why this runs before torch is imported in the next cell.
import socket
if socket.gethostname() == 'dlm':
  %env CUDA_DEVICE_ORDER=PCI_BUS_ID
  %env CUDA_VISIBLE_DEVICES=0

In [None]:
import os
import sys
import re
import collections
import functools
import itertools
import requests, zipfile, io
import pickle
import copy
import time

import pandas
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import sklearn
import sklearn.decomposition
import sklearn.metrics
import networkx

import torch
import torch.nn as nn

# Locate the local library checkout (provides the `dl`, `autoencoder` and `vin` packages).
# The first existing candidate wins; if none exist, lib_path ends up as the last
# candidate and nothing is added to sys.path.
lib_path_candidates = ['I:/code', '/media/6T/.tianle/.lib',
                       '/projects/academic/azhang/tianlema/lib']
for lib_path in lib_path_candidates:
  if os.path.exists(lib_path):
    break
if os.path.exists(lib_path) and lib_path not in sys.path:
  sys.path.append(lib_path)

from dl.models.transformer import MultiheadAttention, EncoderAttention, DecoderAttention, Transformer, StackedEncoder
from dl.models.dag import *
from dl.models.basic_models import *
from dl.models.factor_graph import BipartiteGraph1d, BipartiteGraph, EmbedCell, GeneNet, PathNet, GraphConvolution1d
from dl.utils.visualization.visualization import *
from dl.utils.outlier import *
from dl.utils.train import *
from autoencoder.autoencoder import *
from vin.vin import *
from dl.utils.utils import get_overlap_samples, filter_clinical_dict, get_target_variable
from dl.utils.utils import get_shuffled_data, target_to_numpy, discrete_to_id, get_mi_acc
from dl.utils.utils import get_label_distribution, normalize_continuous_variable, tensor_to_num
from dl.utils.utils import adj_list_to_mat, adj_list_to_attention_mats, count_model_parameters
from dl.utils.utils import get_split

# Reload edited modules (e.g. the local `dl` library) automatically before executing each cell.
%load_ext autoreload
%autoreload 2


# Pick the compute device: CUDA when requested and available, otherwise CPU.
use_gpu = True
gpu_ready = use_gpu and torch.cuda.is_available()
device = torch.device('cuda' if gpu_ready else 'cpu')
print('Using GPU:)' if gpu_ready else 'Using CPU:(')

inf = float('Inf')

In [None]:
# ---- data / task selection ----
use_jia_data = False          # True: use the JIA dataset instead of TCGA
sel_proj_id = 'BRCA'          # TCGA project code; 'TCGA-' is prepended later if missing
target_name = 'tumor_stage'   # 'tumor_stage', 'sample_type', 'pfi' or 'overall_survival'

# ---- splitting / reproducibility ----
split_portion = [1, 1, 8]     # relative split sizes; the splits are later unpacked as train/val/test
seed = 0                      # torch RNG seed used for the row permutations
randomize_labels = False      # replace labels with random class ids (sanity-check baseline)

# ---- feature selection ----
init_num_gene = 5000          # number of highest-std genes kept before GO filtering
min_num_gene_per_go = 5       # GO terms with fewer retained genes are dropped

# ---- optimization (forwarded to train_single_loss) ----
lr = 1e-3
weight_decay = 1e-4
num_epochs = 50
reduce_every = 100            # NOTE(review): presumably the LR-reduction period -- confirm in train_single_loss
batch_size = None             # NOTE(review): presumably None means full-batch training -- confirm

# ---- logging / evaluation ----
print_every = 1
eval_every = 1
return_best_val = True        # forwarded to train_single_loss
result_folder = 'results'     # created by the next cell if missing

In [None]:
# Result-file naming: split proportions as integer percentages, e.g. [1, 1, 8] -> '10-10-80'.
# NOTE: the prefix uses sel_proj_id as configured (before 'TCGA-' is prepended below).
split_percent = np.array(split_portion) * 100 / np.sum(split_portion)
split_portion_str = '-'.join(str(int(p)) for p in split_percent)
res_filename_prefix = f'{sel_proj_id}_{split_portion_str}_seed{seed}_{target_name}'
res_filename = f'{res_filename_prefix}.pkl'
os.makedirs(result_folder, exist_ok=True)

# Normalize the project id(s) to full 'TCGA-XXXX' form and build sel_proj_list,
# which the sample_type / pfi / overall_survival branches iterate over.
if use_jia_data:
  sel_proj_id = 'JIA'
else:
  if isinstance(sel_proj_id, str):
    if not sel_proj_id.startswith('TCGA-'):
      sel_proj_id = 'TCGA-' + sel_proj_id
    # Assigned unconditionally: it used to be set only when the prefix had to be added,
    # leaving sel_proj_list undefined for ids such as 'TCGA-BRCA'.
    sel_proj_list = [sel_proj_id]
  elif isinstance(sel_proj_id, (list, tuple)): # currently no use
    sel_proj_id = [proj_id if proj_id.startswith('TCGA-') else 'TCGA-' + proj_id 
                   for proj_id in sel_proj_id]
    sel_proj_list = sel_proj_id

# Locate the GDC 13.0 data directory. The first existing candidate wins; if none
# exist, the last candidate is used (and the open() below will fail).
data_folder_candidates = ['F:/TCGA/GDC13.0', '/media/6T/.Trash-1014/GDC13.0',
                          '/projects/academic/azhang/tianlema/TCGA/GDC13.0']
for data_folder in data_folder_candidates:
  if os.path.exists(data_folder):
    break
data_filepath = f'{data_folder}/htseq_cnt_go_clinical.pkl'

# The pickle is a dict; the export code that produced it wrote the keys
# htseq_cnt_mat, gene_ids, aliquot_ids, go_gene_list, go_edges, string_ppi,
# percent_tumor_nuclei, sample_type, diagnoses, clinical, cases, survival_plot.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(data_filepath, 'rb') as f:
  data = pickle.load(f)

(go_gene_list, go_edges, gene_ids, htseq_cnt_mat, ppi, clinical, aliquot_ids) = (
  data[key] for key in ('go_gene_list', 'go_edges', 'gene_ids', 'htseq_cnt_mat',
                        'string_ppi', 'clinical', 'aliquot_ids'))

In [None]:
# Build the expression matrix `mat` (rows: genes, columns: samples) and integer labels
# `y_true` for the selected target; xs / ys hold per-group matrices / labels.
if use_jia_data:
  print('Use JIA data')
  # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
  with open(f'{data_folder}/JIA.pkl', 'rb') as f:
    jia = pickle.load(f)
  gene_cnt = jia['gene_cnt']
  sample_info = jia['sample_info']
  gene_ids = np.array(gene_cnt.index)
  mat = np.log2(gene_cnt.values+1)
  type_to_id = {'ADT': 1, 'CRM':2, 'HC':0}
  y_true = np.array([type_to_id[s] for s in sample_info['disease.status']])
  # binary task: keep ids 0 ('HC') and 1 ('ADT'), drop 'CRM' (id 2)
  sel_idx = y_true<2
  y_true = y_true[sel_idx]
  mat = mat[:, sel_idx]
  xs = [mat]
  ys = [y_true]
else:
  print('Use TCGA data')
  mat = np.log2(htseq_cnt_mat+1)
  # most are from TCGA projects, only a few are from TARGET
  # there are technical duplicates; for simplicity, treat them as different samples
  # case-level ids: first 12 chars of TCGA aliquot barcodes, first 16 otherwise
  submitter_ids = [s[:12] if s.startswith('TCGA') else s[:16] for s in aliquot_ids]
  submitter_set = set(submitter_ids)
  # sample-level ids: first 16 chars of TCGA aliquot barcodes, first 20 otherwise
  sample_submitter_ids = [s[:16] if s.startswith('TCGA') else s[:20] for s in aliquot_ids]
  sample_submitter_set = set(sample_submitter_ids)

  # case id -> tumor stage, taken from diagnosis records whose submitter_id starts with
  # 'TCGA' and ends with 'diagnosis'; the last 10 characters are stripped to get the case id.
  # Columns are looked up by name rather than by position so that a column reorder in
  # the diagnoses table cannot silently mis-assign stages.
  sample_tumor_stage = {}
  diagnoses = data['diagnoses']
  for diag_id, stage in zip(diagnoses['submitter_id'], diagnoses['tumor_stage']):
    if (isinstance(diag_id, str) and diag_id.startswith('TCGA')
        and diag_id.endswith('diagnosis')):
      sample_tumor_stage[diag_id[:-10]] = stage
  # keep cases present in the expression data whose stage is a string starting with
  # 'stage'; non-string values (e.g. NaN) are skipped instead of raising AttributeError
  sample_tumor_stage = {k: v for k, v in sample_tumor_stage.items()
                        if k in submitter_set and isinstance(v, str) and v.startswith('stage')}

  sample_type = data['sample_type']
  common_sample_submitter_ids = set(sample_type.submitter_id).intersection(sample_submitter_ids)
  sample_type = {k: v for k, v in zip(sample_type.submitter_id, sample_type.sample_type) 
                 if k in common_sample_submitter_ids}

  # currently not used
  percent_tumor_nuclei = data['percent_tumor_nuclei']
  percent_tumor_nuclei = {k: v for k, v in zip(percent_tumor_nuclei.submitter_id, 
                                          percent_tumor_nuclei.percent_tumor_nuclei)
                          if k in sample_submitter_set}

  # case id -> project id, and project id -> column positions of its aliquots
  submitter_proj_dict = {s: p for s, p in zip(clinical.submitter_id, clinical.project_id)}
  proj_case_cnt = collections.Counter([submitter_proj_dict[s] for s in submitter_ids])
  proj_case_loc_list = collections.defaultdict(list)
  for i, s in enumerate(submitter_ids):
    proj_id = submitter_proj_dict[s]
    proj_case_loc_list[proj_id].append(i)

  if target_name == 'tumor_stage':
    # Projects qualify when >= 2 stages have >= min_num_per_cls cases each and those
    # stages together cover >= min_num_per_type cases.
    tumor_stage_cnt = {}
    min_num_per_cls = 100
    min_num_per_type = 200
    submitter_id_arr = np.array(submitter_ids)  # hoisted out of the per-project loop
    for proj, loc in proj_case_loc_list.items():
      if proj.startswith('TCGA') and len(loc)>0:
        sample_tumor_stage_cnt = collections.Counter(sample_tumor_stage[s]
                                                     for s in submitter_id_arr[loc]
                                                     if s in sample_tumor_stage)
        sample_tumor_stage_cnt = {stage: n for stage, n in sample_tumor_stage_cnt.items()
                                  if n>=min_num_per_cls}
        if (len(sample_tumor_stage_cnt)>1 
            and sum(sample_tumor_stage_cnt.values())>=min_num_per_type):
          tumor_stage_cnt[proj] = sample_tumor_stage_cnt
    # only consider TCGA projects
    assert isinstance(sel_proj_id, str)
    sel_cls = set(tumor_stage_cnt[sel_proj_id])
    # primary-tumor aliquots (barcode chars 13:15 == '01') of the selected project
    # whose stage is one of the qualifying classes
    sel_aliquot_loc = []
    y_target = []
    for i, s in enumerate(aliquot_ids):
      if (submitter_proj_dict[submitter_ids[i]] == sel_proj_id 
          and submitter_ids[i] in sample_tumor_stage
          and sample_tumor_stage[submitter_ids[i]] in sel_cls
          and s[13:15] == '01'):
        sel_aliquot_loc.append(i)
        y_target.append(sample_tumor_stage[submitter_ids[i]])
    sel_aliquot_loc = np.array(sel_aliquot_loc)
    y_target = np.array(y_target)
    y_true, target_name_to_id = discrete_to_id(y_target)
    ys = [y_true]
    mat = mat[:, sel_aliquot_loc]
    xs = [mat]
  elif target_name == 'sample_type':
    xs = []
    ys = []
    for proj_id in sel_proj_list:
      sel_aliquot_loc = proj_case_loc_list[proj_id]
      x = mat[:, sel_aliquot_loc]
      proj_aliquot_ids = aliquot_ids[sel_aliquot_loc]
      # positions are relative to sel_aliquot_loc, i.e. to the columns of x
      sample_type_01_loc = [i for i, s in enumerate(proj_aliquot_ids) if s[13:15]=='01']
      sample_type_11_loc = [i for i, s in enumerate(proj_aliquot_ids) if s[13:15]=='11']

      x = np.concatenate([x[:, sample_type_01_loc], x[:, sample_type_11_loc]], axis=1)
      xs.append(x)
      # sample_type_01: primary solid tumor label -> 1
      # sample_type_11: solid normal label -> 0
      y = [1]*len(sample_type_01_loc) + [0]*len(sample_type_11_loc)
      ys.append(y)
    mat = np.concatenate(xs, axis=1)
    y_true = np.concatenate(ys)
  elif target_name == 'pfi' or target_name == 'overall_survival':
    # Use Pan-Cancer clinical dataset published in Cell
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open(f'{data_folder}/clinical_pancan.pkl', 'rb') as f:
      clinical_pancan = pickle.load(f)
    pfi = {}
    overall_survival = {}
    numeric_types = (int, float, np.integer, np.floating)
    for sid, p, o in zip(clinical_pancan.bcr_patient_barcode, clinical_pancan.PFI, 
                         clinical_pancan.OS):
      if sid in submitter_set:
        # keep numeric, non-NaN values (x == x is False only for NaN); integer-typed
        # values are accepted too instead of being silently dropped
        if isinstance(p, numeric_types) and p==p:
          pfi[sid] = int(p)
        if isinstance(o, numeric_types) and o==o:
          overall_survival[sid] = int(o)
    clinical_pancan = {'pfi': pfi, 'overall_survival': overall_survival}
    sel_aliquot_loc = []
    y_target = []
    for i, s in enumerate(aliquot_ids):
      if (submitter_proj_dict[submitter_ids[i]] in sel_proj_list
          and submitter_ids[i] in clinical_pancan[target_name]
          and s[13:15] == '01'):
        # there are technical duplicates; for simplicity, treat them as different samples
        sel_aliquot_loc.append(i)
        y_target.append(clinical_pancan[target_name][submitter_ids[i]])
    sel_aliquot_loc = np.array(sel_aliquot_loc)
    y_target = np.array(y_target)
    y_true, target_name_to_id = discrete_to_id(y_target)
    ys = [y_true]
    mat = mat[:, sel_aliquot_loc]
    xs = [mat]
  else:
    # fail here instead of with a NameError on y_true below
    raise ValueError(f'Unsupported target_name: {target_name}')

num_cls = len(np.unique(y_true))
print(sel_proj_id, target_name, mat.shape, y_true.shape)

In [None]:
# Keep the init_num_gene genes (rows) with the largest standard deviation across samples.
std = mat.std(axis=1)
mean = mat.mean(axis=1)  # not used in the cells visible here
idx = np.argsort(-std)[:init_num_gene]
gene_ids, mat = gene_ids[idx], mat[idx]
print(mat.shape, gene_ids.shape)

In [None]:
# Iteratively prune genes and GO terms until they are mutually consistent:
#   * every retained gene belongs to at least one retained GO term, and
#   * every retained GO term has at least min_num_gene_per_go retained genes.
# Each pass that does not end the loop removes at least one gene, so it terminates.
# itertools.chain replaces reduce(+) list concatenation (quadratic, and it raised on an empty dict).
common_geneset = set(itertools.chain.from_iterable(go_gene_list.values()))
while common_geneset != set(gene_ids):
  common_geneset = common_geneset.intersection(gene_ids)
  # remove genes that are not in common_geneset
  go_gene_list = {k: sorted(common_geneset.intersection(v)) for k, v in go_gene_list.items()}
  # remove GO terms that have less than min_num_gene_per_go genes
  go_gene_list = {k: v for k, v in go_gene_list.items() if len(v)>=min_num_gene_per_go}
  go_ids = set(go_gene_list)
  # some genes may no longer appear in go_gene_list, remove them
  common_geneset_new = set(itertools.chain.from_iterable(go_gene_list.values()))
  # Index by integer positions directly. This keeps gene_ids' dtype, and it also works
  # when nothing survives; the former (index, name) array round-trip failed in that case.
  keep_loc = np.array([i for i, g in enumerate(gene_ids) if g in common_geneset_new], dtype=int)
  gene_ids = gene_ids[keep_loc]
  mat = mat[keep_loc]
  print(gene_ids.shape, mat.shape, len(common_geneset))

# When xs holds several column groups (concatenated in order into mat), re-slice them
# from the gene-filtered mat.
if len(xs) > 1:
  cnt = 0
  for i, x in enumerate(xs):
    xs[i] = mat[:, cnt:cnt+x.shape[1]]
    cnt = cnt + x.shape[1]

In [None]:
# Index genes and GO terms, restrict GO edges to retained terms, and build the GO DAG
# (parent id -> sorted child ids) plus the per-GO gene lists used as `bigraph`.
num_steps = 4 # len(chain_graph_go)
name_to_id_gene = {name: pos for pos, name in enumerate(gene_ids)}
num_gene = len(gene_ids)

num_go = len(go_gene_list)
go_ids = set(go_gene_list)
# keep only edges whose two endpoints both survived the filtering above
edge_mask = [edge[0] in go_ids and edge[1] in go_ids for edge in go_edges]
go_edges = go_edges[edge_mask]
name_to_id_go, chain_graph_go = get_topological_order(go_edges[:,[1,0]])
print('chain_graph_go', [(level, len(nodes)) for level, nodes in enumerate(chain_graph_go)])
# there are some isolated GO terms; give them ids after all DAG nodes so they are included
for go in go_gene_list:
  if go not in name_to_id_go:
    name_to_id_go[go] = len(name_to_id_go)

# dag: id of edge[1] (parent) -> sorted unique ids of edge[0] (children)
dag_children = collections.defaultdict(set)
for edge in go_edges:
  child = name_to_id_go[edge[0]]
  parent = name_to_id_go[edge[1]]
  dag_children[parent].add(child)
dag = {parent: sorted(children) for parent, children in dag_children.items()}

id_to_name_go = {pos: name for name, pos in name_to_id_go.items()}
# smallest id that has children; NOTE(review): ids below it are presumably the leaf GO terms
# if get_topological_order numbers leaves first -- confirm
num_leaf_go = min(dag)
bigraph = [sorted(name_to_id_gene[g] for g in go_gene_list[id_to_name_go[go_pos]])
           for go_pos in range(num_leaf_go)]
print(f'num_go={num_go}, num_gene={num_gene}, num_steps={num_steps}')

In [None]:
# Precompute the multi-step attention (transition) matrices used by the GeneNet, PathNet
# and BipartiteGraph cells: gene->gene (PPI), GO->GO (GO DAG) and gene<->GO (membership).
attention_mats_filepath = f'attention_mats-{num_gene}-{num_go}-{num_steps}-GeneNet.pkl'
overwrite = True
if overwrite or not os.path.exists(attention_mats_filepath):
  # Gene-gene adjacency from the PPI list, restricted to retained genes; the edge weight is
  # the third field divided by 1000 (presumably STRING scores in [0, 1000] -- confirm).
  gene_gene_adj_mat = np.zeros((num_gene, num_gene))
  for s in ppi:
    if s[0] in name_to_id_gene and s[1] in name_to_id_gene:
      left = name_to_id_gene[s[0]]
      right = name_to_id_gene[s[1]]
      gene_gene_adj_mat[left, right] = float(s[2])/1000
  # self loops
  np.fill_diagonal(gene_gene_adj_mat, 1)

  go_go_adj_mat, _ = adj_list_to_mat(go_edges, name_to_id=name_to_id_go, bipartite=False, 
                   add_self_loop=False, symmetric=False, return_list=False)

  # gene x GO membership matrix
  gene_go_adj_mat = np.zeros((num_gene, num_go))
  for k, v in go_gene_list.items():
    go_id = name_to_id_go[k]
    for g in v:
      gene_id = name_to_id_gene[g]
      gene_go_adj_mat[gene_id, go_id] = 1

  attention_mats = {}
  start_time = time.time()
  attention_mats['gene1->gene0'], id_to_name_gene = adj_list_to_attention_mats(
    adj_list=None, num_steps=num_steps, name_to_id=name_to_id_gene, bipartite=False, 
    add_self_loop=True, symmetric=True, target_to_source=None, use_transition_matrix=True, 
    Ms=gene_gene_adj_mat, softmax_normalization=False, min_value=-100, device=device)

  attention_mats['pathway1->pathway0'], id_to_name_go = adj_list_to_attention_mats(
    adj_list=None, num_steps=num_steps, name_to_id=name_to_id_go, bipartite=False, 
    add_self_loop=False, symmetric=False, target_to_source=None, use_transition_matrix=True, 
    Ms=go_go_adj_mat, softmax_normalization=False, min_value=-100, device=device)

  # `mats` is also used directly by the BipartiteGraph cell
  mats, _ = adj_list_to_attention_mats(
    adj_list=None, num_steps=num_steps*2, name_to_id=[name_to_id_gene, name_to_id_go], 
    bipartite=True, add_self_loop=False, symmetric=False, target_to_source=None, 
    use_transition_matrix=True, Ms=gene_go_adj_mat, softmax_normalization=False, min_value=-100, 
    device=device)
  # this is very tricky: 
  # the even positions are all gene->pathway in mats[0], while odd ones gene->gene
  attention_mats['gene0->pathway1'] = [m for i, m in enumerate(mats[0]) if i%2==0]
  attention_mats['pathway0->gene1'] = [m for i, m in enumerate(mats[1]) if i%2==0]

  end_time = time.time()
  print(f'Time spent on generating attention_mats {end_time - start_time} s')
else:
  # Caching was never wired up (the writer/loader were commented out). Previously this
  # branch silently left attention_mats / mats undefined, or stale from an earlier run,
  # for the model cells below.
  raise RuntimeError(
    f'{attention_mats_filepath} exists and overwrite is False, but loading cached '
    'attention_mats is not implemented; set overwrite = True')

In [None]:
def permute_rows(xs, seed=None, randperm=None):
  """Apply one shared row permutation to a tensor/array or to a list/tuple of them.

  Args:
    xs: a tensor/array, or a list/tuple of tensors/arrays with equal length.
    seed: if not None, reseeds torch's *global* RNG (torch.random.manual_seed)
      before the permutation is drawn.
    randperm: optional index permutation; drawn with torch.randperm when None.

  Returns:
    For a list: the same list object, with its elements replaced by permuted copies.
    For a tuple: a new tuple of permuted elements (tuples cannot be updated in place).
    For an empty list/tuple: xs unchanged.
    Otherwise: xs[randperm].

  Raises:
    AssertionError: if the elements of a list/tuple differ in length (checked
      before anything is modified).
  """
  if seed is not None:
    torch.random.manual_seed(seed)
  if isinstance(xs, (list, tuple)):
    if len(xs) == 0:
      return xs
    num_rows = len(xs[0])
    for x in xs:
      assert len(x) == num_rows
    if randperm is None:
      randperm = torch.randperm(num_rows)
    permuted = [x[randperm] for x in xs]
    if isinstance(xs, tuple):
      return tuple(permuted)
    xs[:] = permuted
    return xs
  if randperm is None:
    randperm = torch.randperm(len(xs))
  return xs[randperm]

# Build train/val/test tensors. With a single data group, do a per-class (stratified)
# split of the shuffled samples according to split_portion; otherwise each group in xs
# is used as one split.
if len(xs) == 1:
  x_all = torch.tensor(mat.T).float().to(device)
  y_all = torch.tensor(y_true).long().to(device)
  # this permutation is needed because later we sequentially split them into several buckets
  x_all, y_all = permute_rows([x_all, y_all], seed=seed)
  if randomize_labels:
    # Sanity-check baseline: random class ids drawn from all num_cls classes (at least 2).
    # Previously always drawn from {0, 1}, which never produced classes >= 2.
    y_all = torch.randint(max(num_cls, 2), y_all.size()).long().to(device)

  # class id -> row positions in x_all / y_all (tolist avoids one .item() call per element)
  cls_loc = collections.defaultdict(list)
  for i, label in enumerate(y_all.tolist()):
    cls_loc[label].append(i)

  num_split = len(split_portion)
  preset_split_size = None
  # each class's positions are cut into consecutive chunks sized by get_split
  split_loc = collections.defaultdict(list)
  for c, v in cls_loc.items():
    # split_portion and split_size have been defined before
    split_size = get_split(len(v), split_portion, split_size=preset_split_size)
    cnt = 0
    for s in split_size:
      split_loc[c].append(v[cnt:cnt+s])
      cnt = cnt + s
  xs = []
  ys = []
  for i in range(num_split):
    x_split = []
    y_split = []
    for c, v in sorted(split_loc.items()):
      x_split.append(x_all[v[i]])
      y_split.append(y_all[v[i]])
    xs.append(torch.cat(x_split, dim=0))
    ys.append(torch.cat(y_split, dim=0))
else:
  xs = [torch.tensor(x.T).float().to(device) for x in xs]
  ys = [torch.tensor(y).long().to(device) for y in ys]

x_train, x_val, x_test = xs[:3]
y_train, y_val, y_test = ys[:3]

# permute rows to remove the location dependency for data points of the same classes
x_train, y_train = permute_rows([x_train, y_train], seed=seed)
x_val, y_val = permute_rows([x_val, y_val], seed=seed)
x_test, y_test = permute_rows([x_test, y_test], seed=seed)
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape, x_test.shape, y_test.shape)

plot_scatter(y_=x_train, colors=['r' if i==1 else 'g' for i in y_train], title='train')
plot_scatter(y_=x_val, colors=['r' if i==1 else 'g' for i in y_val], title='val')
plot_scatter(y_=x_test, colors=['r' if i==1 else 'g' for i in y_test], title='test')

# accumulators shared by all model cells below (one entry per model)
model_names = []
split_names = ['train', 'val', 'test']
metric_names = ['acc', 'precision', 'recall', 'f1_score', 'adjusted_mutual_info', 'auc', 
                'average_precision']
metric_all = []
confusion_mat_all = []
loss_his_all = []
acc_his_all = []

# a single class means the target is treated as regression (MSE)
loss_fn_cls = nn.CrossEntropyLoss()
loss_fn_reg = nn.MSELoss()
if num_cls == 1:
  loss_fn = loss_fn_reg
else:
  loss_fn = loss_fn_cls

In [None]:
# GeneNet: uses the gene/pathway attention_mats computed above
model_names.append('GeneNet')
print(f'{model_names[-1]} Model')

use_dag_layer = False
dag_in_channel_list = [1,1]
dag_kwargs = {'residual':True, 'duplicate_dag':True}

forward_kwargs = {'attention_mats': attention_mats, 'max_num_layers': num_steps, 
          'min_num_layers': num_steps, 'return_layers': 'cls_score'}

model = GeneNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
                use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
                dag_kwargs=dag_kwargs, nonlinearity=nn.ReLU(), use_layer_norm=True,
               num_cls=num_cls).to(device)

# Time one forward and one backward pass on the full training set.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# The timing backward pass accumulated gradients into the parameters' .grad; clear them
# so they cannot leak into the first optimizer step of training.
model.zero_grad()

loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

print('Before training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
print('After training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# PathNet: use GO gene set and GO hierarchy
model_names.append('PathNet')
print(f'{model_names[-1]} Model')

# re-key the relevant attention matrices under the names PathNet's forward expects
attention_mats_pathnet = {'pathway0->gene': attention_mats['pathway0->gene1'],
                 'pathway1->pathway0': attention_mats['pathway1->pathway0'],
                 'gene->pathway1': attention_mats['gene0->pathway1']}

forward_kwargs = {'attention_mats': attention_mats_pathnet, 'max_num_layers': num_steps, 
          'min_num_layers': num_steps, 'return_layers': 'cls_score'}

model = PathNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
                use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
                dag_kwargs=dag_kwargs, nonlinearity=nn.ReLU(), use_layer_norm=True,
               num_cls=num_cls).to(device)

# Time one forward and one backward pass on the full training set.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# The timing backward pass accumulated gradients into the parameters' .grad; clear them
# so they cannot leak into the first optimizer step of training.
model.zero_grad()

loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

print('Before training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
print('After training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# BipartiteGraph 
model_names.append('BipartiteGraph')
print(f'{model_names[-1]} Model')

# mats had been calculated when we first calculate attention_mats
forward_kwargs = {'attention_mats': mats, 'max_num_layers': num_steps, 
          'min_num_layers': num_steps, 'return_layers': 'cls_score'}

model = BipartiteGraph1d(in_features=num_gene, out_features=num_go, 
                         use_layer_norm=True, num_cls=num_cls).to(device)

# Time one forward and one backward pass on the full training set.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# The timing backward pass accumulated gradients into the parameters' .grad; clear them
# so they cannot leak into the first optimizer step of training.
model.zero_grad()

loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

print('Before training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
print('After training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
# classifier weights of the final (not necessarily best) model, kept for later inspection
go_weight_bipartite_graph = model.classifier.weight.detach().cpu().numpy()

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# MLP baseline: one hidden layer of 100 units on the filtered gene features.
model_names.append('MLP')
print(f'{model_names[-1]}')

in_dim, hidden_dim = x_train.shape[1], [100]
dense, residual = False, False
model = DenseLinear(in_dim, hidden_dim+[num_cls], dense=dense, residual=residual).to(device)
multi_heads = False

# Regression (num_cls == 1) trains on float targets of shape (N, 1);
# classification uses the integer labels unchanged.
if num_cls > 1:
  y_train_fit, y_val_fit, y_test_fit = y_train, y_val, y_test
else:
  y_train_fit, y_val_fit, y_test_fit = [y.unsqueeze(-1).float() for y in (y_train, y_val, y_test)]

loss_train_his, loss_val_his, loss_test_his = [], [], []
acc_train_his, acc_val_his, acc_test_his = [], [], []
best_model, best_val_acc, best_epoch = model, 0, 0

train_result = train_single_loss(
    model, x_train, y_train_fit, x_val, y_val_fit, x_test, y_test_fit,
    loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, amsgrad=True,
    batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every,
    eval_every=eval_every, print_every=print_every, verbose=False,
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his,
    return_best_val=return_best_val)
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result

# accuracy / scatter plots only make sense for classification
is_classification = num_cls > 1
metric = get_result(
    model, best_model, best_val_acc, best_epoch,
    x_train, y_train, x_val, y_val, x_test, y_test,
    batch_size=batch_size, multi_heads=False, average='weighted',
    show_results_in_notebook=True, loss_idx=0, acc_idx=0, forward_kwargs={},
    predict_func=None, pred_kwargs=None, plot_loss=True,
    plot_acc=is_classification, plot_scatter=is_classification,
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([split_metric[0] for split_metric in metric])
confusion_mat_all.append([split_metric[1] for split_metric in metric])

In [None]:
# PrototypicalNet, requires_grad_prototype=False
model_names.append('PrototypicalNet0')
print(f'{model_names[-1]}')

in_dim, hidden_dim = x_train.shape[1], [100]
dense, residual = False, False
model = PrototypicalNet(in_dim, hidden_dim, num_cls, dense=dense, 
                        residual=residual, requires_grad_prototype=False).to(device)
multi_heads = False

# Regression (num_cls == 1) trains on float targets of shape (N, 1);
# classification uses the integer labels unchanged.
if num_cls > 1:
  y_train_fit, y_val_fit, y_test_fit = y_train, y_val, y_test
else:
  y_train_fit, y_val_fit, y_test_fit = [y.unsqueeze(-1).float() for y in (y_train, y_val, y_test)]

loss_train_his, loss_val_his, loss_test_his = [], [], []
acc_train_his, acc_val_his, acc_test_his = [], [], []
best_model, best_val_acc, best_epoch = model, 0, 0

# PrototypicalNet needs the training labels at train time (forward_kwargs_train);
# everything else matches the MLP cell.
train_result = train_single_loss(
    model, x_train, y_train_fit, x_val, y_val_fit, x_test, y_test_fit,
    loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, amsgrad=True,
    batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every,
    eval_every=eval_every, print_every=print_every, verbose=False,
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his,
    return_best_val=return_best_val, forward_kwargs_train={'y': y_train, 'train': True})
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result

# accuracy / scatter plots only make sense for classification
is_classification = num_cls > 1
metric = get_result(
    model, best_model, best_val_acc, best_epoch,
    x_train, y_train, x_val, y_val, x_test, y_test,
    batch_size=batch_size, multi_heads=False, average='weighted',
    show_results_in_notebook=True, loss_idx=0, acc_idx=0, forward_kwargs={},
    predict_func=None, pred_kwargs=None, plot_loss=True,
    plot_acc=is_classification, plot_scatter=is_classification,
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([split_metric[0] for split_metric in metric])
confusion_mat_all.append([split_metric[1] for split_metric in metric])

In [None]:
# PrototypicalNet with requires_grad_prototype=True
model_names.append('PrototypicalNet1')
print(f'{model_names[-1]}')

in_dim = x_train.shape[1]
hidden_dim = [100]
dense = False
residual = False

model = PrototypicalNet(in_dim, hidden_dim, num_cls, dense=dense,
                        residual=residual, requires_grad_prototype=True).to(device)
multi_heads = False

# Fresh per-model history lists, handed to train_single_loss and get_result below.
loss_train_his, loss_val_his, loss_test_his = [], [], []
acc_train_his, acc_val_his, acc_test_his = [], [], []
his_kwargs = dict(loss_train_his=loss_train_his, loss_val_his=loss_val_his,
                  loss_test_his=loss_test_his, acc_train_his=acc_train_his,
                  acc_val_his=acc_val_his, acc_test_his=acc_test_his)
# Fallback "best" values, kept when train_single_loss returns None.
best_model, best_val_acc, best_epoch = model, 0, 0

# Multi-class targets are used as-is; with a single output the labels get a
# trailing dimension and are cast to float.
is_multiclass = bool(num_cls > 1)
if is_multiclass:
  t_train, t_val, t_test = y_train, y_val, y_test
else:
  t_train, t_val, t_test = (t.unsqueeze(-1).float() for t in (y_train, y_val, y_test))

# PrototypicalNet needs forward_kwargs_train (labels + train flag); everything
# else matches the DenseLinear setup.
# NOTE(review): 'y' is the full y_train -- confirm train_single_loss slices
# forward_kwargs_train per mini-batch.
train_result = train_single_loss(model, x_train, t_train, x_val, t_val, x_test, t_test,
    loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, amsgrad=True,
    batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every,
    eval_every=eval_every, print_every=print_every, verbose=False,
    return_best_val=return_best_val, forward_kwargs_train={'y': y_train, 'train': True},
    **his_kwargs)
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result

metric = get_result(model, best_model, best_val_acc, best_epoch,
    x_train, y_train, x_val, y_val, x_test, y_test,
    batch_size=batch_size, multi_heads=multi_heads, average='weighted',
    show_results_in_notebook=True, loss_idx=0, acc_idx=0, forward_kwargs={},
    predict_func=None, pred_kwargs=None, plot_loss=True,
    plot_acc=is_multiclass, plot_scatter=is_multiclass, **his_kwargs)

# Record this model's histories, per-split metrics and confusion matrices.
loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([split_res[0] for split_res in metric])
confusion_mat_all.append([split_res[1] for split_res in metric])

In [None]:
# NOTE(review): this entire cell is commented out, so the DAGEncoder ('DAG') model is
# currently disabled and adds nothing to model_names / metric_all.
# Uncomment the whole cell to re-enable it.
# # DAGEncoder
# model_names.append('DAG')
# print(f'{model_names[-1]} Model')

# embedding_dim = 5
# in_channels_list = [5]
# key_dim = 5
# value_dim = 5
# fc_dim = 5
# dim_per_cls = 1
# num_heads = 1
# num_attention = 1
# use_encoders = False
# # num_go_dag = max(dag)+1
# # print(num_go_dag)
# graph_encoder = None # [m[:num_go_dag][:,:num_go_dag] for m in attention_mats['pathway1->pathway0']]
# graph_weight_encoder = 0.5
# knn = 50

# residual = True 
# duplicate_dag = True 
# gibbs_sampling = True 
# duplicated_attention = True 
# feature_max_norm = 1
# use_layer_norm = True 
# bias = True 
# nonlinearity = nn.ReLU()

# forward_kwargs = {'return_attention': False, 'graph_encoder': None, 'graph_decoder': None, 
#     'encoder_stochastic_depth': False}

# model = DAGEncoder(num_features=num_gene, embedding_dim=embedding_dim, 
#          in_channels_list=in_channels_list, 
#          bigraph=bigraph, dag=dag, key_dim=key_dim, value_dim=value_dim, fc_dim=fc_dim, 
#          num_cls=num_cls, dim_per_cls=dim_per_cls, feature_max_norm=feature_max_norm, 
#          use_layer_norm=use_layer_norm, bias=bias, 
#          nonlinearity=nonlinearity, residual=residual, duplicate_dag=duplicate_dag, 
#          gibbs_sampling=gibbs_sampling, num_heads=num_heads, num_attention=num_attention,
#          knn=knn, duplicated_attention=duplicated_attention,
#          graph_encoder=None, graph_weight_encoder=graph_weight_encoder, 
#          graph_decoder=None, graph_weight_decoder=0.5, use_encoders=use_encoders).to(device)

# start_time = time.time()
# y_pred = model(x_train, **forward_kwargs)
# end_time = time.time()
# print(f'Time spent on forward pass {end_time - start_time} s')

# start_time = time.time()
# loss = loss_fn(y_pred, y_train)
# loss.backward()
# end_time = time.time()
# print(f'Time spent on backward pass {end_time - start_time} s')

# loss_train_his = []
# loss_val_his = []
# loss_test_his = []
# acc_train_his = []
# acc_val_his = []
# acc_test_his = []
# best_model = model
# best_val_acc = 0
# best_epoch = 0

# best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, 
#     x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
#     amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
#     eval_every=eval_every, print_every=print_every, verbose=False, 
#     loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
#     acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
#     return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
#     forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)

# metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
#             x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
#             show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
#             forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
#             plot_loss=True, plot_acc=True, plot_scatter=False, 
#             loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
#             acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

# loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
# acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
# metric_all.append([v[0] for v in metric])
# confusion_mat_all.append([v[1] for v in metric])

In [None]:
# GraphConvolutionNetwork
model_names.append('GCN')
print(f'{model_names[-1]} Model')

# Gene->gene graph: entry 0 of attention_mats['gene1->gene0'] when use_string_ppi is
# True (presumably the STRING PPI network -- TODO confirm), entry 1 otherwise.
use_string_ppi = True
if use_string_ppi:
  attention_mat_gcn = attention_mats['gene1->gene0'][0]
else:
  attention_mat_gcn = attention_mats['gene1->gene0'][1]
# Symmetrize (average with the transpose) so the graph is undirected.
attention_mat_gcn = (attention_mat_gcn + attention_mat_gcn.t())/2

# Architecture options. No trailing commas here: `duplicate_layers=False,` would
# bind the one-element tuple (False,), which is truthy, instead of False.
duplicate_layers = False
dense = False
residual = True
use_bias = True
use_layer_norm = False
nonlinearity = nn.ReLU()
classifier_bias = True

forward_kwargs = {'attention_mats':attention_mat_gcn, 'max_num_layers':num_steps, 
                  'min_num_layers':num_steps, 'return_layers': 'last-layer'}

model = GraphConvolution1d(num_features=num_gene, num_layers=num_steps, 
          duplicate_layers=duplicate_layers, dense=dense, residual=residual, use_bias=use_bias, 
          use_layer_norm=use_layer_norm, nonlinearity=nonlinearity, num_cls=num_cls, 
          classifier_bias=classifier_bias).to(device)

# Smoke test + timing: one full-batch forward and backward pass.
# CUDA kernels run asynchronously, so synchronize before reading the clock.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# Discard the timing-pass gradients so they cannot leak into the first training step.
model.zero_grad()

# Fresh per-model history lists and fallback "best" values.
loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

train_result = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
# Other cells in this notebook guard against train_single_loss returning None;
# do the same instead of failing on tuple unpacking (keeps the defaults above).
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

# sklearn classifiers

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier

# Classical baselines. Their names are appended after the deep models so that
# model_names stays aligned with metric_all.
model_names_sklearn = ['kNN', 'Naive Bayes', 'SVM', 'Decision Tree', 
                       'Random Forest', 'AdaBoost']
# list(...) because model_names may already be a numpy array (the result-comparison
# cell rebinds it to one).
model_names = list(model_names) + model_names_sklearn

# The randomized models are seeded with `seed` so reruns are reproducible.
classifiers = [KNeighborsClassifier(5), 
               GaussianNB(), 
               SVC(kernel="linear", C=0.025),
               DecisionTreeClassifier(max_depth=5, random_state=seed),
               RandomForestClassifier(max_depth=5, n_estimators=10, random_state=seed),
               AdaBoostClassifier(random_state=seed)
              ]

def _to_numpy(a):
  """Return `a` as a NumPy array, detaching and copying torch tensors to the CPU first.

  The deep-model cells feed x_*/y_* to models placed on `device`, so these may be
  CUDA tensors, which np.asarray (and therefore sklearn) cannot consume directly.
  """
  if torch.is_tensor(a):
    return a.detach().cpu().numpy()
  return np.asarray(a)

x_train_np, y_train_np = _to_numpy(x_train), _to_numpy(y_train)
# Only the features given to sklearn are converted; the labels are passed to
# eval_classification unchanged, as before.
eval_splits = [(_to_numpy(x_), y_) for x_, y_ in
               zip([x_train, x_val, x_test], [y_train, y_val, y_test])]

# assert train_portion > 0 and val_portion > 0 and test_portion > 0 # Assume there are 3 splits
for name, classifier in zip(model_names_sklearn, classifiers):
  print(name)
  classifier.fit(x_train_np, y_train_np)
  metric = []
  for x_np, y_ in eval_splits:
    if name == 'SVM':
      # SVC only exposes predict_proba when constructed with probability=True.
      y_score = classifier.decision_function(x_np)
    else:
      y_score = classifier.predict_proba(x_np)
    metric.append(eval_classification(y_true=y_, y_pred=y_score, 
                                      average='weighted', verbose=True))
  metric_all.append([v[0] for v in metric])
  confusion_mat_all.append([v[1] for v in metric])
  # No loss/accuracy history exists for these classifiers; empty entries keep the
  # per-model lists aligned with model_names.
  loss_his_all.append([])
  acc_his_all.append([])

# Randomize the order of genes

In [None]:
# Apply one random permutation to the feature (gene) columns of every split.
# A dedicated generator seeded with `seed` makes the permutation reproducible,
# independent of how many random numbers earlier cells consumed.
# NOTE: x_train/x_val/x_test are rebound here; running this cell twice composes
# two permutations.
random_idx_feature = torch.randperm(x_train.size(1),
                                    generator=torch.Generator().manual_seed(seed))
x_train = x_train[:, random_idx_feature]
x_val = x_val[:, random_idx_feature]
x_test = x_test[:, random_idx_feature]

In [None]:
# GeneNet on the gene-permuted data
model_names.append('RandGeneNet')
print(f'{model_names[-1]} Model')

# NOTE: the RandPathNet cell below reuses these three variables.
use_dag_layer = False
dag_in_channel_list = [1,1]
dag_kwargs = {'residual':True, 'duplicate_dag':True}

# The model is built with attention_mats=None; the graphs are supplied at forward time.
forward_kwargs = {'attention_mats': attention_mats, 'max_num_layers': num_steps, 
          'min_num_layers': num_steps, 'return_layers': 'cls_score'}

model = GeneNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
                use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
                dag_kwargs=dag_kwargs, nonlinearity=nn.ReLU(), use_layer_norm=True,
               num_cls=num_cls).to(device)

# Smoke test + timing: one full-batch forward and backward pass.
# CUDA kernels run asynchronously, so synchronize before reading the clock.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# Discard the timing-pass gradients so they cannot leak into the first training step.
model.zero_grad()

# Fresh per-model history lists and fallback "best" values.
loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

# Sorted weights of the classifier's first output row, before and after training.
print('Before training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
train_result = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
# Other cells in this notebook guard against train_single_loss returning None;
# do the same instead of failing on tuple unpacking (keeps the defaults above).
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result
print('After training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# PathNet: use GO gene set and GO hierarchy
model_names.append('RandPathNet')
print(f'{model_names[-1]} Model')

# Re-key the shared attention_mats under the names passed to PathNet at forward time.
attention_mats_pathnet = {'pathway0->gene': attention_mats['pathway0->gene1'],
                 'pathway1->pathway0': attention_mats['pathway1->pathway0'],
                 'gene->pathway1': attention_mats['gene0->pathway1']}

forward_kwargs = {'attention_mats': attention_mats_pathnet, 'max_num_layers': num_steps, 
          'min_num_layers': num_steps, 'return_layers': 'cls_score'}

# NOTE: use_dag_layer, dag_in_channel_list and dag_kwargs are not set in this cell;
# they come from the RandGeneNet cell above, which must run first.
model = PathNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
                use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
                dag_kwargs=dag_kwargs, nonlinearity=nn.ReLU(), use_layer_norm=True,
               num_cls=num_cls).to(device)

# Smoke test + timing: one full-batch forward and backward pass.
# CUDA kernels run asynchronously, so synchronize before reading the clock.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# Discard the timing-pass gradients so they cannot leak into the first training step.
model.zero_grad()

# Fresh per-model history lists and fallback "best" values.
loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

# Sorted weights of the classifier's first output row, before and after training.
print('Before training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
train_result = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
# Other cells in this notebook guard against train_single_loss returning None;
# do the same instead of failing on tuple unpacking (keeps the defaults above).
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result
print('After training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# BipartiteGraph 
model_names.append('RandBipartiteGraph')
print(f'{model_names[-1]} Model')

# mats had been calculated when we first calculate attention_mats
forward_kwargs = {'attention_mats': mats, 'max_num_layers': num_steps, 
          'min_num_layers': num_steps, 'return_layers': 'cls_score'}

model = BipartiteGraph1d(in_features=num_gene, out_features=num_go, 
                         use_layer_norm=True, num_cls=num_cls).to(device)

# Smoke test + timing: one full-batch forward and backward pass.
# CUDA kernels run asynchronously, so synchronize before reading the clock.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# Discard the timing-pass gradients so they cannot leak into the first training step.
model.zero_grad()

# Fresh per-model history lists and fallback "best" values.
loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

# Sorted weights of the classifier's first output row, before and after training.
print('Before training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()
train_result = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
# Other cells in this notebook guard against train_single_loss returning None;
# do the same instead of failing on tuple unpacking (keeps the defaults above).
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result
print('After training: last layer weight distribution')
plt.plot(sorted(model.classifier.weight[0].detach().cpu().numpy()))
plt.show()

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# MLP (DenseLinear) on the gene-permuted data
model_names.append('RandMLP')
print(f'{model_names[-1]}')

in_dim = x_train.shape[1]
hidden_dim = [100]
dense = False
residual = False

# Layer sizes: the hidden layers followed by the num_cls-unit output layer.
model = DenseLinear(in_dim, hidden_dim+[num_cls], dense=dense, residual=residual).to(device)
multi_heads = False

# Fresh per-model history lists, handed to train_single_loss and get_result below.
loss_train_his, loss_val_his, loss_test_his = [], [], []
acc_train_his, acc_val_his, acc_test_his = [], [], []
his_kwargs = dict(loss_train_his=loss_train_his, loss_val_his=loss_val_his,
                  loss_test_his=loss_test_his, acc_train_his=acc_train_his,
                  acc_val_his=acc_val_his, acc_test_his=acc_test_his)
# Fallback "best" values, kept when train_single_loss returns None.
best_model, best_val_acc, best_epoch = model, 0, 0

# Multi-class targets are used as-is; with a single output the labels get a
# trailing dimension and are cast to float.
is_multiclass = bool(num_cls > 1)
if is_multiclass:
  t_train, t_val, t_test = y_train, y_val, y_test
else:
  t_train, t_val, t_test = (t.unsqueeze(-1).float() for t in (y_train, y_val, y_test))

train_result = train_single_loss(model, x_train, t_train, x_val, t_val, x_test, t_test,
    loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, amsgrad=True,
    batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every,
    eval_every=eval_every, print_every=print_every, verbose=False,
    return_best_val=return_best_val, **his_kwargs)
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result

metric = get_result(model, best_model, best_val_acc, best_epoch,
    x_train, y_train, x_val, y_val, x_test, y_test,
    batch_size=batch_size, multi_heads=multi_heads, average='weighted',
    show_results_in_notebook=True, loss_idx=0, acc_idx=0, forward_kwargs={},
    predict_func=None, pred_kwargs=None, plot_loss=True,
    plot_acc=is_multiclass, plot_scatter=is_multiclass, **his_kwargs)

# Record this model's histories, per-split metrics and confusion matrices.
loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([split_res[0] for split_res in metric])
confusion_mat_all.append([split_res[1] for split_res in metric])

In [None]:
# NOTE(review): this entire cell is commented out, so the DAGEncoder ('RandDAG') model
# is currently disabled and adds nothing to model_names / metric_all.
# Uncomment the whole cell to re-enable it.
# # DAGEncoder
# model_names.append('RandDAG')
# print(f'{model_names[-1]} Model')

# embedding_dim = 5
# in_channels_list = [5]
# key_dim = 5
# value_dim = 5
# fc_dim = 5
# dim_per_cls = 1
# num_heads = 1
# num_attention = 1
# NOTE(review): if re-enabled, the next line rebinds the global num_go, which the
# GeneNet/PathNet/BipartiteGraph cells pass as num_pathways/out_features
# (the non-random DAG cell uses a separate num_go_dag instead).
# num_go = max(dag)+1
# print(num_go)
# use_encoders = False
# graph_encoder = None # [m[:num_go][:,:num_go] for m in attention_mats['pathway1->pathway0']]
# graph_weight_encoder = 0.5
# knn = 50

# residual = True 
# duplicate_dag = True 
# gibbs_sampling = True 
# duplicated_attention = True 
# feature_max_norm = 1
# use_layer_norm = True 
# bias = True 
# nonlinearity = nn.ReLU()

# forward_kwargs = {'return_attention': False, 'graph_encoder': None, 'graph_decoder': None, 
#     'encoder_stochastic_depth': False}

# model = DAGEncoder(num_features=num_gene, embedding_dim=embedding_dim, 
#          in_channels_list=in_channels_list, 
#          bigraph=bigraph, dag=dag, key_dim=key_dim, value_dim=value_dim, fc_dim=fc_dim, 
#          num_cls=num_cls, dim_per_cls=dim_per_cls, feature_max_norm=feature_max_norm, 
#          use_layer_norm=use_layer_norm, bias=bias, 
#          nonlinearity=nonlinearity, residual=residual, duplicate_dag=duplicate_dag, 
#          gibbs_sampling=gibbs_sampling, num_heads=num_heads, num_attention=num_attention,
#          knn=knn, duplicated_attention=duplicated_attention,
#          graph_encoder=None, graph_weight_encoder=graph_weight_encoder, 
#          graph_decoder=None, graph_weight_decoder=0.5, use_encoders=use_encoders).to(device)

# start_time = time.time()
# y_pred = model(x_train, **forward_kwargs)
# end_time = time.time()
# print(f'Time spent on forward pass {end_time - start_time} s')

# start_time = time.time()
# loss = loss_fn(y_pred, y_train)
# loss.backward()
# end_time = time.time()
# print(f'Time spent on backward pass {end_time - start_time} s')

# loss_train_his = []
# loss_val_his = []
# loss_test_his = []
# acc_train_his = []
# acc_val_his = []
# acc_test_his = []
# best_model = model
# best_val_acc = 0
# best_epoch = 0

# best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, 
#     x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
#     amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
#     eval_every=eval_every, print_every=print_every, verbose=False, 
#     loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
#     acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
#     return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
#     forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)

# metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
#             x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
#             show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
#             forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
#             plot_loss=True, plot_acc=True, plot_scatter=False, 
#             loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
#             acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

# loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
# acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
# metric_all.append([v[0] for v in metric])
# confusion_mat_all.append([v[1] for v in metric])

In [None]:
# GraphConvolutionNetwork
model_names.append('RandGCN')
print(f'{model_names[-1]} Model')

# Gene->gene graph: entry 0 of attention_mats['gene1->gene0'] when use_string_ppi is
# True (presumably the STRING PPI network -- TODO confirm), entry 1 otherwise.
use_string_ppi = True
if use_string_ppi:
  attention_mat_gcn = attention_mats['gene1->gene0'][0]
else:
  attention_mat_gcn = attention_mats['gene1->gene0'][1]
# Symmetrize (average with the transpose) so the graph is undirected.
attention_mat_gcn = (attention_mat_gcn + attention_mat_gcn.t())/2

# Architecture options. No trailing commas here: `duplicate_layers=False,` would
# bind the one-element tuple (False,), which is truthy, instead of False.
duplicate_layers = False
dense = False
residual = True
use_bias = True
use_layer_norm = False
nonlinearity = nn.ReLU()
classifier_bias = True

forward_kwargs = {'attention_mats':attention_mat_gcn, 'max_num_layers':num_steps, 
                  'min_num_layers':num_steps, 'return_layers': 'last-layer'}

model = GraphConvolution1d(num_features=num_gene, num_layers=num_steps, 
          duplicate_layers=duplicate_layers, dense=dense, residual=residual, use_bias=use_bias, 
          use_layer_norm=use_layer_norm, nonlinearity=nonlinearity, num_cls=num_cls, 
          classifier_bias=classifier_bias).to(device)

# Smoke test + timing: one full-batch forward and backward pass.
# CUDA kernels run asynchronously, so synchronize before reading the clock.
start_time = time.time()
y_pred = model(x_train, **forward_kwargs)
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on forward pass {end_time - start_time} s')

start_time = time.time()
loss = loss_fn(y_pred, y_train)
loss.backward()
if device.type == 'cuda':
  torch.cuda.synchronize()
end_time = time.time()
print(f'Time spent on backward pass {end_time - start_time} s')
# Discard the timing-pass gradients so they cannot leak into the first training step.
model.zero_grad()

# Fresh per-model history lists and fallback "best" values.
loss_train_his = []
loss_val_his = []
loss_test_his = []
acc_train_his = []
acc_val_his = []
acc_test_his = []
best_model = model
best_val_acc = 0
best_epoch = 0

train_result = train_single_loss(model, x_train, y_train, 
    x_val, y_val, x_test, y_test, loss_fn=loss_fn, lr=lr, weight_decay=weight_decay, 
    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every, 
    eval_every=eval_every, print_every=print_every, verbose=False, 
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=return_best_val, forward_kwargs_train=forward_kwargs, 
    forward_kwargs_val=forward_kwargs, forward_kwargs_test=forward_kwargs)
# Other cells in this notebook guard against train_single_loss returning None;
# do the same instead of failing on tuple unpacking (keeps the defaults above).
if train_result is not None:
  best_model, best_val_acc, best_epoch = train_result

metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
            x_test, y_test, batch_size=batch_size, multi_heads=False, average='weighted', 
            show_results_in_notebook=True, loss_idx=0, acc_idx=0, 
            forward_kwargs=forward_kwargs, predict_func=None, pred_kwargs=None, 
            plot_loss=True, plot_acc=True, plot_scatter=False, 
            loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
            acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([v[0] for v in metric])
confusion_mat_all.append([v[1] for v in metric])

In [None]:
# result comparison: rank the models by one metric on one split
model_names_all = np.array(model_names)
res_all = np.array(metric_all)

metric_idx = 4  # which metric column to rank by
split_idx = 2   # which split: index into the [train, val, test] order the metrics were built in
# Alternative subsets:
# i = 5
# subset = range(6*i, 6*i+6)
# subset = range(i, len(model_names_all), 6)
subset = range(len(model_names_all))

# A 4-D res_all presumably carries a leading axis of repeated runs; summarize it by mean/std.
# Otherwise there is a single run and the std is zero.
if res_all.ndim != 4:
  mean = res_all[subset]
  std = np.zeros_like(mean)
else:
  mean = res_all.mean(axis=0)[subset]
  std = res_all.std(axis=0)[subset]

# NOTE(review): this rebinds the global model_names to a numpy array; a later
# `model_names.append(...)` would fail until it is rebuilt as a list.
model_names = model_names_all[subset]

# Descending order of the chosen (split, metric) column.
sorted_idx = np.argsort(-mean, axis=0)
order = sorted_idx[:, split_idx, metric_idx]
mean = mean[order, split_idx, metric_idx]
std = std[order, split_idx, metric_idx]
names = model_names[order]

res = [(rank, row_name, row_mean, row_std)
       for rank, (row_name, row_mean, row_std) in enumerate(zip(names, mean, std), start=1)]
print('{:^4} {:^50} {:^5} \t {:^5}'.format('Rank', 'Name', 'Mean', 'Std'))
for rank, row_name, row_mean, row_std in res:
  print(f'{rank:^4} {row_name:^50} {row_mean:^.3f} \t {row_std:^.3f}')

In [None]:
# Persist all histories and metrics, plus the name lists and graph metadata needed
# to interpret them.
res_path = f'{result_folder}/{res_filename}'
# open(..., 'wb') raises FileNotFoundError if the target folder does not exist yet.
if result_folder:
  os.makedirs(result_folder, exist_ok=True)
with open(res_path, 'wb') as f:
  print(f'Write result to file {res_path}')
  pickle.dump({'loss_his_all': loss_his_all,
               'acc_his_all': acc_his_all,
               'metric_all': metric_all,
               'confusion_mat_all': confusion_mat_all,
               'model_names': model_names,
               'split_names': split_names,
               'metric_names': metric_names,
               'name_to_id_gene': name_to_id_gene, 
               'name_to_id_go': name_to_id_go, 
               'dag': dag, 
               'bigraph': bigraph, 
               'go_weight_bipartite_graph': go_weight_bipartite_graph
              }, f)

In [None]:
# NOTE(review): disabled scratch cell. It times one forward/backward pass of a GeneNet
# (built with use_dag_layer=True) on random input x of shape (batch_size, num_gene).
# If re-enabled, it overwrites the globals use_dag_layer, dag_in_channel_list,
# dag_kwargs, batch_size, model, x, y and loss used elsewhere in the notebook.
# use_dag_layer = True
# dag_in_channel_list = [1]
# dag_kwargs = {'residual':True, 'duplicate_dag':True}
# batch_size = 500
# model = GeneNet(num_genes=num_gene, num_pathways=num_go, attention_mats=None, dense=True,
#                 use_dag_layer=use_dag_layer, dag=dag, dag_in_channel_list=dag_in_channel_list, 
#                 dag_kwargs=dag_kwargs,
#                 nonlinearity=nn.ReLU(), use_layer_norm=True).to(device)

# x = torch.randn(batch_size, num_gene).to(device)

# start_time = time.time()
# y = model(x, attention_mats=attention_mats, max_num_layers=num_steps, min_num_layers=num_steps, 
#           return_layers='all')
# end_time = time.time()
# print(f'Time spent on forward pass {end_time - start_time} s')
# print(y[0].shape, y[1].shape, y[2].shape, y[3].shape)

# start_time = time.time()
# loss = y[0].sum() + y[1].sum() + y[2].sum() + y[3].sum()
# loss.backward()
# end_time = time.time()
# print(f'Time spent on backward pass {end_time - start_time} s')