In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, "../MultiOmicsGraphEmbedding/")

import logging
# Only let ERROR-level (and above) records through the "wandb" logger.
logger = logging.getLogger("wandb")
logger.setLevel(logging.ERROR)

# Necessary imports
import pickle, os, time, random, datetime, itertools, warnings
from argparse import Namespace

import networkx as nx
import numpy as np
import pandas as pd

from sklearn.metrics import classification_report, roc_auc_score, average_precision_score

from umap import UMAP
from openomics import MultiOmics
from openomics.database import GeneOntology
from moge.network.heterogeneous import HeterogeneousNetwork
from moge.visualization.data import heatmap, plot_training_history, heatmap_compare, clf_report, clf_report_compare
from moge.visualization.network import graph_viz, graph_viz3d

from moge.evaluation.utils import largest_indices
from moge.evaluation.embedding import distances_correlation
from moge.visualization.evaluation import plot_roc_curve_multiclass, plot_roc_curve, plot_pr_curve_multiclass
from moge.generator import SubgraphGenerator, SubgraphDataset

import plotly.graph_objects as go

import wandb
from pytorch_lightning.loggers import WandbLogger

# Notebook-wide display settings: ignore every warning, print NumPy floats with
# three decimals and no scientific notation, and let pandas render up to 500 rows.
warnings.filterwarnings(action='ignore')
np.set_printoptions(suppress=True, precision=3)
pd.options.display.max_rows = 500

In [2]:
import torch
from torch.utils import data
from torch_geometric import datasets

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import loggers

from torchsample.callbacks import ReduceLROnPlateau
from torchsample.regularizers import L1Regularizer, L2Regularizer

# Display whether PyTorch can see a CUDA device (cell output).
torch.cuda.is_available()

True

In [3]:
# from model.classification.transformer import Transformer, StarEncoderLayer

# Import network data

In [4]:
# Load the pre-built heterogeneous network from its pickle.
# NOTE: pickle.load can execute arbitrary code; only open pickles from a trusted source.
with open('../MultiOmicsGraphEmbedding/moge/data/gtex_string_network.pickle', mode='rb') as pickle_fh:
    network = pickle.load(pickle_fh)

# Store the transposed protein expression matrix under `annotation_expressions`.
network.multiomics.Protein.annotation_expressions = network.multiomics.Protein.expressions.T
# Show the columns and non-null counts of the node annotation table.
network.annotations.info()

<class 'pandas.core.frame.DataFrame'>
Index: 18815 entries, A1BG to ZZEF1
Data columns (total 8 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   gene_name             18815 non-null  object
 1   protein_size          17965 non-null  object
 2   protein_id            17965 non-null  object
 3   annotation            17965 non-null  object
 4   Transcript sequence   17965 non-null  object
 5   go_id                 18815 non-null  object
 6   disease_associations  8755 non-null   object
 7   omic                  18815 non-null  object
dtypes: object(8)
memory usage: 1.3+ MB


In [5]:
# Summary statistics of the per-node sequence length (in characters).
# `.str.len()` is the vectorized form of the per-element `len`; missing sequences
# stay NaN and are therefore excluded from `describe()`.
network.annotations["Transcript sequence"].str.len().describe()

count    17965.000000
mean       582.315001
std        608.918757
min         24.000000
25%        278.000000
50%        437.000000
75%        699.000000
max      35991.000000
Name: Transcript sequence, dtype: float64

# Data Parameters

In [6]:
# INPUT PARAMETERS
variables = []
# variables = ['chromosome_name', 'transcript_start', 'transcript_end']
targets = ['go_id']

# Fit the label transformer for the target column.
# NOTE(review): `min_count=100` presumably drops labels with fewer than 100 occurrences — confirm in moge.
network.process_feature_tranformer(filter_label=targets[0], min_count=100, verbose=False)
classes = network.feature_transformer[targets[0]].classes_
n_classes = len(classes)

test_frac = 0.05
max_length = 1000
input_shape = (None, )
batch_size = 2000
# Integer division keeps n_steps an int without a float round-trip.
n_steps = 400000 // batch_size

directed = False

# Fixed seed so the sampling in the data generators can be reproduced across kernel
# restarts (previously drawn with random.randint and never recorded). Change it here
# to explore other splits/samplings.
seed = 42
n_classes

2522

In [7]:
split_idx = 0

# Training generator over split `split_idx` (BFS traversal, "cycle" sampling, n_steps batches).
dataset_train = network.get_train_generator(
    SubgraphGenerator, split_idx=split_idx, variables=variables, targets=targets,
    traversal="bfs", batch_size=batch_size, agg_mode=None,
    method="GAT", adj_output="coo",
    sampling="cycle", n_steps=n_steps, directed=directed,
    maxlen=max_length, padding='post', truncating='post', variable_length=False,
    seed=seed, verbose=True)

# Test generator for the same split: traversal='all' with a single step.
dataset_test = network.get_test_generator(
    SubgraphGenerator, split_idx=split_idx, variables=variables, targets=targets,
    traversal='all', batch_size=batch_size, agg_mode=None,
    method="GAT", adj_output="coo",
    sampling="log", n_steps=1, directed=directed,
    maxlen=max_length, padding='post', truncating='post', variable_length=False,
    seed=seed, verbose=True)

# Both generators must tokenize sequences identically, otherwise token ids seen at
# test time would not match the ones the model was trained on. The original bare
# comparison discarded its result; fail loudly instead.
assert dataset_train.tokenizer.word_index == dataset_test.tokenizer.word_index, \
    "train and test generators were built with different tokenizer vocabularies"
vocab = dataset_train.tokenizer.word_index

node_list 17067 {'directed': 17067, 'undirected': 17067}
node_list 898 {'directed': 898, 'undirected': 898}


In [8]:
# class_weights = network.feature_transformer["go_id"].transform(network.annotations["go_id"].str.split("|")).sum(0)
# class_weights = 1/class_weights

In [9]:
# X_train, y, idx_train = dataset_train.__getitem__()
# print({k: v.shape if not isinstance(v, list) else (len(v), len(v[0])) for k, v in X_train.items()}, 
#       {"y_train": y.shape}, idx_train.shape)

# X, y, idx_train = dataset_test.__getitem__()
# print({k: v.shape if not isinstance(v, list) else (len(v), len(v[0])) for k, v in X.items()}, 
#       {"y_train": y.shape}, idx_train.shape)

# Prepare Data Generators

In [10]:
# def padding_tensor(sequences):
#     num = len(sequences)
#     max_len = max([s.size(-1) for s in sequences])
#     out_dims = (num, 2, max_len)
#     out_tensor = sequences[0].data.new(*out_dims).fill_(0)
# #     mask = sequences[0].data.new(*out_dims).fill_(0)
#     for i, tensor in enumerate(sequences):
#         length = tensor.size(-1)
#         out_tensor[i, :, :length] = tensor
# #         mask[i, :length] = 1
#     return out_tensor

# def collate_fn(batch):
#     input_seqs_all, subnetwork_all, y_all, idx_all = [], [], [] ,[]
#     for X, y, idx in batch:
#         input_seqs_all.append(torch.tensor(X["input_seqs"]))
#         subnetwork_all.append(torch.tensor(X["subnetwork"]))
#         y_all.append(torch.tensor(y))
#         idx_all.append(torch.tensor(idx))
        
#     X_all = {"input_seqs": torch.cat(input_seqs_all), "subnetwork": padding_tensor(subnetwork_all)}
#     return X_all, torch.cat(y_all), torch.cat(idx_all)


# DataLoader options shared by both splits. With batch_size=None PyTorch does no
# automatic batching, so each dataset item is passed through as-is (the generators
# already yield whole batches, see the shapes printed in the next cell).
params = dict(
    batch_size=None,
    shuffle=False,
    num_workers=10,
    # collate_fn=collate_fn,
)

dataloader_train, dataloader_test = (
    data.DataLoader(split_dataset, **params)
    for split_dataset in (dataset_train, dataset_test)
)

In [11]:
def _batch_shapes(inputs):
    """Map each input name to its shape; list inputs are reported as (len(v), len(v[0]))."""
    return {k: v.shape if not isinstance(v, list) else (len(v), len(v[0])) for k, v in inputs.items()}


# Smoke-test the data loaders: pull one batch from each and print its shapes.
X_train, y, idx_train = next(iter(dataloader_train))
print(_batch_shapes(X_train), {"y_train": y.shape}, idx_train.shape)

# The test batch gets its own names so it no longer overwrites `y` / `idx_train`
# from the training batch, and its label shape is reported as "y_test".
X_test_batch, y_test_batch, idx_test_batch = next(iter(dataloader_test))
print(_batch_shapes(X_test_batch), {"y_test": y_test_batch.shape}, idx_test_batch.shape)

{'input_seqs': torch.Size([2000, 1000]), 'subnetwork': torch.Size([2, 34516])} {'y_train': torch.Size([2000, 2522])} torch.Size([2000])
{'input_seqs': torch.Size([898, 1000]), 'subnetwork': torch.Size([2, 2164])} {'y_train': torch.Size([898, 2522])} torch.Size([898])


# Build Model

In [12]:
from moge.module.trainer import LightningModel
from moge.module.classifier import EncoderEmbedderClassifier

# CUDA for PyTorch
# Release cached GPU memory, then choose the first GPU when available, else the CPU.
torch.cuda.empty_cache()
use_cuda = torch.cuda.is_available()
# NOTE(review): `device` is not referenced again in the visible cells; the Trainer is given gpus=1 directly.
device = torch.device("cuda:0" if use_cuda else "cpu")
max_epochs = 100

# Hyperparameters handed to EncoderEmbedderClassifier (as an argparse Namespace below).
# Section labels are inferred from the key names.
hparams = {
    # Sequence encoder
    "encoder": "ConvLSTM",
    "encoding_dim": 128,
    "vocab_size": len(vocab),
    "word_embedding_size": 18,
    "max_length": max_length,
    
#     "num_hidden_layers": 1,
#     "num_hidden_groups": 1,
#     "hidden_dropout_prob": 0.16,
#     "attention_probs_dropout_prob": 0.1356,
#     "num_attention_heads": 8,
#     "intermediate_size": 1024,

    # First convolution
    "nb_conv1_filters": 154,
    "nb_conv1_kernel_size": 9,
    "nb_conv1_dropout": 0.4838,
    "nb_conv1_batchnorm": False,
    
    # Second convolution
    "nb_conv2_filters": 43,
    "nb_conv2_kernel_size": 6,
    "nb_conv2_batchnorm": True,

    "nb_max_pool_size": 19,

    # LSTM
    "nb_lstm_bidirectional": False,
    "nb_lstm_units": 162,
    "nb_lstm_hidden_dropout": 0.1615,
    "nb_lstm_layernorm": True,

    # Graph embedder
    "embedder": "GAT",
    "embedding_dim": 256,
    "nb_attn_heads": 8,
    "nb_attn_dropout": 0.32,

    # Classifier head; n_classes comes from the fitted label transformer
    "classifier": "Dense",
    "nb_cls_dense_size": 1536,
    "nb_cls_dropout": 0.3245,
    "n_classes": n_classes,

    # Optimisation and loss
    "nb_weight_decay": 0.03,
    "lr": 1e-3,
    
    "loss_type": "SIGMOID_FOCAL_CROSS_ENTROPY",
    
    "optimizer": "adam",
#     "class_weights": class_weights,
}

# Build the model from the hyperparameters and wrap it for PyTorch Lightning training.
eec = EncoderEmbedderClassifier(Namespace(**hparams))
model = LightningModel(eec)

INFO:transformers.file_utils:PyTorch version 1.5.0 available.
INFO:transformers.file_utils:TensorFlow version 2.2.0 available.


INFO: Output of `_classifier` is logits


In [13]:
# y_hat = eec.forward(X_train)
# loss = eec.loss(y_hat, y, idx_train)

In [14]:
# L1 penalty with module filter 'conv*' and L2 penalty with module filter 'fc*'.
# NOTE: currently unused, since `regularizers=regularizers` is commented out in the Trainer cell.
regularizers = [
    L1Regularizer(module_filter='conv*', scale=1e-3),
    L2Regularizer(module_filter='fc*', scale=1e-5),
]

# wandb_logger = WandbLogger(project="multiplex-rna-embedding")
# wandb_logger.log_hyperparams(hparams)

In [15]:
# Single-GPU trainer running between 20 and `max_epochs` epochs in 16-bit precision
# (amp_level 'O1'). Distributed training, early stopping, regularizers and the
# W&B logger are disabled (commented out) in this configuration.
trainer = Trainer(
    gpus=1,
#     distributed_backend='ddp',
    min_epochs=20,
    max_epochs=max_epochs,
#     early_stop_callback=EarlyStopping(monitor='val_loss', patience=3),
#     regularizers=regularizers,
#     logger=wandb_logger,
    weights_summary='top',
    amp_level='O1', precision=16
)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:Using 16bit precision.


# Run Model

In [None]:
# Train the model. NOTE: the test dataloader doubles as the validation set here,
# so the reported val_* metrics are computed on the test split.
trainer.fit(model, train_dataloader=dataloader_train, val_dataloaders=dataloader_test)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


INFO:lightning:
  | Name   | Type                      | Params
-------------------------------------------------
0 | _model | EncoderEmbedderClassifier | 4 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

{'val_precision': 0.03791998211523064, 'val_recall': 0.4929057625208655, 'val_top_k': 0.029660927957309645, 'val_loss': 0.12888985872268677}


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

{'val_precision': 0.0, 'val_recall': 0.0, 'val_top_k': 0.30427948881629313, 'val_loss': 0.12596221268177032}


In [None]:
# Display the LightningModel's repr as the cell output.
model

In [None]:
# wandb_logger.experiment._summary

# Evaluate

In [None]:
# Load the complete train and test splits (dropna=True) and print the shape of
# every input plus the label matrix. This rebinds X_train / idx_train, which
# previously held a single batch.
X_train, y_train, idx_train = dataset_train.load_data(dropna=True)
print(dict((name, arr.shape) for name, arr in X_train.items()), dict(y_train=y_train.shape))
X_test, y_test, idx_test = dataset_test.load_data(dropna=True)
print(dict((name, arr.shape) for name, arr in X_test.items()), dict(y_test=y_test.shape))

In [None]:
# Test split: embeddings and predictions are computed with cuda=True.
emb_test = eec.get_embeddings(X_test, cuda=True)
y_test_pred = eec.predict(emb_test, cuda=True)

# Train split: `eec.cpu()` is called first, then the same methods run with cuda=False.
# NOTE(review): presumably `.cpu()` moves the model off the GPU in place (torch.nn.Module
# semantics), which is why the CUDA test pass above has to run first — TODO confirm.
emb_train = eec.cpu().get_embeddings(X_train, cuda=False)
y_train_pred = eec.predict(emb_train, cuda=False)

In [None]:
# Give the raw predictions the same row/column labels as the ground-truth frames.
y_train_pred = pd.DataFrame(y_train_pred, columns=y_train.columns, index=y_train.index)
y_test_pred = pd.DataFrame(y_test_pred, columns=y_test.columns, index=y_test.index)

# Stack train rows followed by test rows (axis 0 is np.concatenate's default).
emb_all = np.concatenate((emb_train, emb_test))
y_all = np.concatenate((y_train, y_test))

# Per-node "train"/"test" tag, indexed like the label frames.
train_tags = pd.Series(["train"] * idx_train.shape[0], index=y_train.index)
test_tags = pd.Series(["test"] * idx_test.shape[0], index=y_test.index)
idx_all = pd.concat([train_tags, test_tags])
y_all.shape

# Classification report

In [None]:
# Mean number of labels per sample, as (true, predicted at threshold 0.5),
# for the test split first and the train split second.
(
    (y_test.sum(axis=1).mean(), (y_test_pred > 0.5).sum(axis=1).mean()),
    (y_train.sum(axis=1).mean(), (y_train_pred > 0.5).sum(axis=1).mean()),
)

In [None]:
# Side-by-side classification report for train and test predictions, with
# scores binarised at threshold 0.5 and classes taken from the fitted label transformer.
clf_report_compare(y_train, y_train_pred, 
                   y_test, y_test_pred, 
                   classes=network.feature_transformer[targets[0]].classes_, 
                   threshold=0.5)

In [None]:
# Label with the highest count in the test split (only referenced by the
# commented-out PR-curve cell below).
test_label_counts = y_test.sum(axis=0).sort_values(ascending=False)
top_classes = test_label_counts.iloc[:1].index
# ROC curve over the test predictions, without per-class curves.
plot_roc_curve_multiclass(y_test, y_test_pred, classes=None, #sample_weight=idx_test.astype(int).values,
                          plot_classes=False,
                          width=400, height=400)

In [None]:
# plot_pr_curve_multiclass(y_test, (y_test_pred>0.5).astype(int), classes=top_classes, #sample_weight=idx_test.astype(int).values, 
#                          plot_classes=False,
#                           width=400, height=400)

In [None]:
# Test split: take 50 rows chosen by `largest_indices` on the summed predicted scores,
# then 100 label columns chosen by `largest_indices` on the true label counts of
# those rows, and plot truth against prediction.
idx = largest_indices(y_test_pred.sum(axis=1), 50)
cols = y_test.columns[largest_indices(y_test.iloc[idx].sum(axis=0), 100)].sort_values()
heatmap_compare(y_test.iloc[idx].loc[:, cols], y_test_pred.iloc[idx].loc[:, cols],
                title=f"Predictions on {n_classes} GO Terms")

In [None]:
# Train split: same comparison as for the test split, but on 100 rows.
# Rebinds `idx` and `cols` from the previous cell.
idx = largest_indices(y_train_pred.sum(axis=1), 100)
cols = y_train.columns[largest_indices(y_train.iloc[idx].sum(axis=0), 100)].sort_values()
heatmap_compare(y_train.iloc[idx].loc[:, cols], y_train_pred.iloc[idx].loc[:, cols],
                title=f"Predictions on {n_classes} GO Terms")

# Visualize Graph

In [None]:
# Project all embeddings to 3-D with UMAP for plotting. UMAP is stochastic, so
# `random_state` ties the layout to the notebook's `seed` and makes re-runs comparable.
umap_coords = UMAP(n_components=3, random_state=seed).fit_transform(emb_all)
# Map each node name to its 3-D coordinate; rows of emb_all follow idx_all's order.
pos = dict(zip(idx_all.index, umap_coords))

In [None]:
# Nodes to draw: every node in idx_all (train nodes followed by test nodes).
nodelist = idx_all.index

In [None]:
# Load the Gene Ontology, restrict it to the "biological_process" namespace and
# read per-term colours from a CSV file.
# NOTE(review): the CSV path points into the user's home directory and is machine-specific.
gene_ontology = GeneOntology()
gene_ontology.filter_network("biological_process")
go_id_colors = gene_ontology.get_node_color("~/Bioinformatics_ExternalData/GeneOntology/go_colors_biological.csv")

In [None]:
# Per-node colours derived from the "go_id" labels, restricted to terms present in the filtered ontology.
labels_color = network.get_labels_color("go_id", go_id_colors, label_filter=set(gene_ontology.node_list))

In [None]:
# 3-D plot of the undirected graph at the UMAP positions: marker symbol encodes
# the train/test tag, marker colour the GO-derived label colour; at most 5000 edges.
graph_viz3d(network.G_u, nodelist=nodelist, 
            pos=pos,
            node_text=nodelist,# + ", " +labels_terms.loc[nodelist],
            node_symbol=idx_all.loc[nodelist],
            node_color=labels_color.loc[nodelist],
#             edge_label="database",
#             iterations=100,
            max_edges=5000, showlegend=False, )

# Clustering Results

In [None]:
from moge.model.static_graph_embedding import ImportedGraphEmbedding
from moge.evaluation.clustering import evaluate_clustering

In [None]:
# Wrap the learned embeddings in an ImportedGraphEmbedding so the clustering
# evaluation can consume them: node names and the embedding matrix are set directly.
# NOTE(review): the first argument is hparams["encoding_dim"] (128) while `emb_all` comes from
# get_embeddings — if that argument is the embedding width, "embedding_dim" (256) may be intended; confirm.
ige = ImportedGraphEmbedding(hparams["encoding_dim"], "MultiplexEmbedding")
ige.node_list = idx_all.index.to_list()
ige._X = emb_all

In [None]:
# Nodes that have a GO annotation, are present in the imported embedding, and belong
# to the training labels. `Index.intersection` is used explicitly because `&` between
# Index objects is deprecated as a set operation and is an element-wise logical AND
# in pandas >= 2.0.
go_annotations = network.annotations["go_id"]
nodelist = (
    go_annotations[go_annotations.notnull()].index
    .intersection(pd.Index(ige.node_list))
    .intersection(y_train.index)
)
evaluate_clustering(ige, network.annotations, nodelist, node_label="go_id",
                    metrics=['homogeneity', 'completeness', 'nmi', 'ami'], max_clusters=1000)

# Distance correlation analysis


In [None]:
# Embeddings vs sequences
# Correlate distances between embeddings with the "Transcript sequence" annotation,
# with n_nodes=200, for the train split and then the test split.
print(distances_correlation(emb_train, network.annotations[["Transcript sequence"]], index=y_train.index, n_nodes=200))
print(distances_correlation(emb_test, network.annotations[["Transcript sequence"]], index=y_test.index, n_nodes=200))

In [None]:
# Encodings vs sequences
# NOTE(review): `train_protein_encodings` / `test_protein_encodings` are not defined in any
# visible cell, so this cell raises NameError on a fresh kernel — compute them first or remove the cell.
# NOTE(review): uses n_nodes=100 while the other distance-correlation cells use 200 — confirm intent.
print(distances_correlation(train_protein_encodings, network.annotations[["Transcript sequence"]], index=y_train.index, n_nodes=100))
print(distances_correlation(test_protein_encodings, network.annotations[["Transcript sequence"]], index=y_test.index, n_nodes=100))

In [None]:
# Embedding vs Adj
# Correlate embedding distances with the dense adjacency matrix of the undirected
# graph, restricted to the nodes of each split. `adj` is rebound for the test split.
adj = nx.adjacency_matrix(network.G_u, nodelist=y_train.index).toarray()
print(distances_correlation(emb_train, adj, index=y_train.index, n_nodes=200, verbose=False))
adj = nx.adjacency_matrix(network.G_u, nodelist=y_test.index).toarray()
print(distances_correlation(emb_test, adj, index=y_test.index, n_nodes=200, verbose=False))

In [None]:
# Encoding vs Adj
# NOTE(review): `train_protein_encodings` / `test_protein_encodings` are not defined in any
# visible cell, so this cell raises NameError on a fresh kernel — compute them first or remove the cell.
# The adjacency matrices are rebuilt here exactly as in the "Embedding vs Adj" cell.
adj = nx.adjacency_matrix(network.G_u, nodelist=y_train.index).toarray()
print(distances_correlation(train_protein_encodings, adj, index=y_train.index, n_nodes=200, verbose=False))
adj = nx.adjacency_matrix(network.G_u, nodelist=y_test.index).toarray()
print(distances_correlation(test_protein_encodings, adj, index=y_test.index, n_nodes=200, verbose=False))