This notebook performs pretraining using the ClinVar, GPD and Annovar annotations for 324 FoundationOne genes.

In [1]:
import sys
sys.path.append("../src/")

In [3]:
import numpy as np
import pandas as pd

import datetime
import logging
import os
import time
import torch

from torch import nn
from torch.nn import functional as F

from functools import cached_property

from torch.nn import Linear, ReLU, Sequential

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score

from datasets_drug_filtered import (
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
)
# from metric import NdcgMetric
from utils import get_kld_loss, get_zinb_loss, get_zinorm_loss

# from testbed import EvaluationTestbed
from seaborn import scatterplot

from sklearn.metrics import pairwise_distances

In [4]:
sample_id = 0 # replace with 

## Model Definition - Gene expression

In [5]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [6]:
input_dim_vae = 324 * 6 * 4
k_list = [128, 64] # original
# k_list = [1024, 128]
actf_list = ["tanh", "tanh"]
is_real = True
eps = 1e-10
ridge_lambda = 0.05
is_mean = True
weight_decay = 1e-4

# The below modules are expected to be available in the scope where this module is instialized
from ffnzinb import ffnzinb
from vae import vae

ffb_zinb_model1 = ffnzinb(input_dim_vae).cuda(device)
vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real).cuda(device)

ffb_zinb_model2 = ffnzinb(input_dim_vae).cuda(device)
vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real).cuda(device)

#
zinb_layers_mu: 
OrderedDict([('mu', Linear(in_features=7776, out_features=7776, bias=True))])
#
zinb_layers_theta: 
OrderedDict([('theta', Linear(in_features=7776, out_features=7776, bias=True))])
#
zinb_layers_pi: 
OrderedDict([('pi', Linear(in_features=7776, out_features=7776, bias=True)), ('pi-actf', Sigmoid())])
#
U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=128, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=128, out_features=64, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
sigma_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=32, out_features=64, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=64, out_features=128, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=128, out_features=7776, bias=True)
  (act-1): Tanh()
)
#
zinb_layers_mu: 
OrderedDict([('mu', Linear(in_features=7776, out_features=7776,

In [7]:
from utils import get_kld_loss, get_zinb_loss, get_zinorm_loss

# alignment loss
def coral(source, target):

    d = source.size(1)  # dim vector

    source_c = compute_covariance(source)
    target_c = compute_covariance(target)

    loss = torch.sum(torch.mul((source_c - target_c), (source_c - target_c)))

    #     loss = loss / (4 * d * d)
    return loss


def compute_covariance(input_data):
    """
    Compute Covariance matrix of the input data
    """
    n = input_data.size(0)  # batch_size

    # Check if using gpu or cpu
    if input_data.is_cuda:
        device = torch.device("cuda:1")
    else:
        device = torch.device("cpu")

    id_row = torch.ones(n).resize(1, n).to(device=device)
    sum_column = torch.mm(id_row, input_data)
    mean_column = torch.div(sum_column, n)
    term_mul_2 = torch.mm(mean_column.t(), mean_column)
    d_t_d = torch.mm(input_data.t(), input_data)
    c = torch.add(d_t_d, (-1 * term_mul_2)) * 1 / (n - 1)

    return c


def get_cell_line_tcga(vae_model, zinb_model):

    with torch.no_grad():

        test_dataset = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(
            is_train=False, filter_for="tcga", sample_id = sample_id
        )
        test_depmap_ids, drug_names, _ = list(
            test_dataset[: len(test_dataset)].values()
        )
        uniq_test_depmap_ids = np.unique(np.array(test_depmap_ids))

        test_model_in = torch.tensor(
            test_dataset.clinvar_gpd_annovar_annotated.loc[uniq_test_depmap_ids].to_numpy(),
            device=device,
            dtype=torch.float,
        )

        x_enc, X_test_mu, logvar, X_recons = vae_model(test_model_in)
        X_mu, X_theta, X_pi = zinb_model(X_recons)

        error = F.mse_loss(
            X_mu.detach(), test_model_in.detach(), reduction="sum"
        ).item()

    return error

In [8]:
num_iterations = 50
learning_rate = 1e-3
convg_thres = 1e-4


criterion = nn.MSELoss(reduction="mean")

params_list = []
params_list += list(vae_model1.parameters())
params_list += list(ffb_zinb_model1.parameters())
params_list += list(vae_model2.parameters())
params_list += list(ffb_zinb_model2.parameters())

optimizer = torch.optim.Adam(params_list, lr=learning_rate, weight_decay=weight_decay)


train_cell_line_dataset = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(
    is_train=True, filter_for="tcga", sample_id = sample_id
)
train_depmap_ids, drug_names, _ = list(
    train_cell_line_dataset[: len(train_cell_line_dataset)].values()
)
uniq_train_depmap_ids = np.unique(np.array(train_depmap_ids))

train_pdx_dataset = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id = sample_id)

train_pdx_ids, drug_names, _ = list(
    train_pdx_dataset[: len(train_pdx_dataset)].values()
)
uniq_train_pdx_ids = np.unique(np.array(train_pdx_ids))

train_losses = []
val_errors = []
best_error = None
for epoch in range(num_iterations):

    optimizer.zero_grad()
    model1_in = torch.tensor(
        train_cell_line_dataset.clinvar_gpd_annovar_annotated.loc[uniq_train_depmap_ids].to_numpy(),
        device=device,
        dtype=torch.float,
    )
    model2_in = torch.tensor(
        train_pdx_dataset.clinvar_gpd_annovar_annotated.loc[uniq_train_pdx_ids].to_numpy(),
        device=device,
        dtype=torch.float,
    )

    # ===================forward=====================
    x_enc1, X_train_mu1, logvar1, X_recons1 = vae_model1(model1_in)
    X_mu1, X_theta1, X_pi1 = ffb_zinb_model1(X_recons1)

    x_enc2, X_train_mu2, logvar2, X_recons2 = vae_model2(model2_in)
    X_mu2, X_theta2, X_pi2 = ffb_zinb_model2(X_recons2)

    coral_loss = coral(X_train_mu1, X_train_mu2)

    loss_zinb1 = get_zinb_loss(
        model1_in,
        X_mu1,
        X_theta1,
        X_pi1,
        is_mean=True,
        eps=eps,
        ridge_lambda=ridge_lambda,
    )
    loss_vae1 = get_kld_loss(X_train_mu1, logvar1, is_mean=True)

    loss_zinb2 = get_zinb_loss(
        model2_in,
        X_mu2,
        X_theta2,
        X_pi2,
        is_mean=True,
        eps=eps,
        ridge_lambda=ridge_lambda,
    )
    loss_vae2 = get_kld_loss(X_train_mu2, logvar2, is_mean=True)

    loss_epoch = loss_zinb1 + loss_vae1 + coral_loss + loss_zinb2 + loss_vae2

    loss_epoch.backward()

    optimizer.step()
    train_losses.append(loss_epoch.item())
    if epoch % 10 == 0:
        curr_error = get_cell_line_tcga(vae_model1, ffb_zinb_model1)
        val_errors.append(curr_error)
        if (best_error is None) or (curr_error < best_error):
            best_error = curr_error
            
            if epoch > 20:
                torch.save(
                    vae_model1.state_dict(),
                    f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
                )
                torch.save(
                    ffb_zinb_model1.state_dict(),
                    f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_zinb_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
                )

                torch.save(
                    vae_model2.state_dict(),
                    f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
                )
                torch.save(
                    ffb_zinb_model2.state_dict(),
                    f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_zinb_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
                )

    print(
        "epoch [{}/{}], loss:{:.6f}".format(
            epoch + 1, num_iterations, loss_epoch.item()
        )
    )

    if (len(train_losses) > 2) and abs(
        train_losses[-1] - train_losses[-2]
    ) < convg_thres:
        print("Training converged, exiting")
        break



epoch [1/50], loss:1.277349




epoch [2/50], loss:0.940689




epoch [3/50], loss:0.656264
epoch [4/50], loss:0.507137




epoch [5/50], loss:0.425920
epoch [6/50], loss:0.378338




epoch [7/50], loss:0.337705
epoch [8/50], loss:0.309702




epoch [9/50], loss:0.288941
epoch [10/50], loss:0.271330




epoch [11/50], loss:0.250265




epoch [12/50], loss:0.231707
epoch [13/50], loss:0.217306




epoch [14/50], loss:0.208442
epoch [15/50], loss:0.203055




epoch [16/50], loss:0.199637
epoch [17/50], loss:0.194367




epoch [18/50], loss:0.188004
epoch [19/50], loss:0.178567




epoch [20/50], loss:0.167325
epoch [21/50], loss:0.154403




epoch [22/50], loss:0.142167




epoch [23/50], loss:0.129109
epoch [24/50], loss:0.117697




epoch [25/50], loss:0.108433
epoch [26/50], loss:0.101021




epoch [27/50], loss:0.095384
epoch [28/50], loss:0.090952




epoch [29/50], loss:0.087775
epoch [30/50], loss:0.085474




epoch [31/50], loss:0.084376
epoch [32/50], loss:0.083110




epoch [33/50], loss:0.082417
epoch [34/50], loss:0.082516
Training converged, exiting




In [9]:
vae_model1.load_state_dict(
    torch.load(
        f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
        map_location="cuda:1",
    )
)
ffb_zinb_model1.load_state_dict(
    torch.load(
        f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_zinb_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
        map_location="cuda:1",
    )
)
vae_model2.load_state_dict(
    torch.load(
        f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
        map_location="cuda:1",
    )
)
ffb_zinb_model2.load_state_dict(
    torch.load(
        f"../data/model_checkpoints/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_zinb_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
        map_location="cuda:1",
    )
)

<All keys matched successfully>

In [10]:
vae_model1.eval()
ffb_zinb_model1.eval()
vae_model2.eval()
ffb_zinb_model2.eval()

ffnzinb(
  (zinb_layers_mu): Sequential(
    (mu): Linear(in_features=7776, out_features=7776, bias=True)
  )
  (zinb_layers_theta): Sequential(
    (theta): Linear(in_features=7776, out_features=7776, bias=True)
  )
  (zinb_layers_pi): Sequential(
    (pi): Linear(in_features=7776, out_features=7776, bias=True)
    (pi-actf): Sigmoid()
  )
)