In [17]:
# basics + plotting
import os, sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.dpi"] = 250
# "sans-serif" (hyphenated) is matplotlib's generic family name; "sans serif"
# with a space is looked up as a concrete font, is not found, and triggers a
# findfont fallback warning.
plt.rcParams["font.family"] = "sans-serif"

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler

# custom
# Make the parent of the notebook's working directory importable (for the
# project-local `utils` / `models` packages). os.path.dirname is portable,
# unlike splitting on '/', which yields '' for Windows-style paths.
PROJECT_PATH = os.path.dirname(os.getcwd())
# guard so re-running this cell does not stack duplicate sys.path entries
if PROJECT_PATH not in sys.path:
    sys.path.insert(1, PROJECT_PATH)

from utils import (
    data_utils, 
    eval_utils, 
    plotting_utils, 
    train_test_utils
)

from models import (
    esm_transfer_learning,
    gnn
)

import importlib
# Re-execute the project modules so source edits take effect without
# restarting the kernel. reload() may in principle return a different module
# object, so each name is rebound to its result.
data_utils = importlib.reload(data_utils)
eval_utils = importlib.reload(eval_utils)
# plotting_utils is imported above as well; reload it for consistency
plotting_utils = importlib.reload(plotting_utils)
train_test_utils = importlib.reload(train_test_utils)
# NOTE(review): gnn is reloaded before esm_transfer_learning on the assumption
# that the latter uses gnn (it is built with predict_method="gat") — TODO
# confirm. A module that did `from models.gnn import X` only picks up the new
# X if gnn was reloaded first.
gnn = importlib.reload(gnn)
esm_transfer_learning = importlib.reload(esm_transfer_learning)

In [3]:
# Pick the compute device once for the whole notebook; announce GPU use only.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('using cuda')
device = torch.device('cuda' if use_cuda else 'cpu')

using cuda


In [4]:
# Load per-gene variant tables, join phenotypes and the PPI graph.
# Positional arguments are passed in the same order as before.
data = data_utils.load_variation_dataset(
    "../data/data/",
    "../gene_list.txt",
    ["embeddings", "indicators"],
    "../data/phenotypes_hcm_only.parquet",
    predict=["hcm"],
    low_memory=True,
    embeddings_file='esm2s_embeddings.npy',
    ppi_graph_path='../ppi_networks/string_interactions_short.tsv',
)

Fetching FHL1 data ... Done
Fetching ACTC1 data ... Done
Fetching ACTN2 data ... Done
Fetching CSRP3 data ... Done
Fetching MYBPC3 data ... Done
Fetching MYH6 data ... Done
Fetching MYH7 data ... Done
Fetching MYL2 data ... Done
Fetching MYL3 data ... Done
Fetching MYOZ2 data ... Done
Fetching LDB3 data ... Done
Fetching TCAP data ... Done
Fetching TNNC1 data ... Done
Fetching TNNI3 data ... Done
Fetching TNNT2 data ... Done
Fetching TPM1 data ... Done
Fetching TRIM63 data ... Done
Fetching PLN data ... Done
Fetching JPH2 data ... Done
Fetching FLNC data ... Done
Fetching ALPK3 data ... Done
Fetching LMNA data ... Done
Fetching NEXN data ... Done
Fetching VCL data ... Done
Fetching MYOM2 data ... Done
Fetching CASQ2 data ... Done
Fetching CAV3 data ... Done
Fetching MYLK2 data ... Done
Fetching CRYAB data ... Done
Combining tables ... Done
Integrating with phenotypes data ...Done
Processing PPI graph ...Done


In [5]:
# stratify the split on the target and on ethnicity
split_balance_columns = ['hcm', 'ethnicity']
train_dataset, test_dataset = data.train_test_split(balance_on=split_balance_columns)

In [49]:
def train_epoch(model, 
                train_loader, 
                optimizer, 
                loss_fn, 
                log_every=10):
    model.train()
    total_loss = 0
    
    all_labels, all_preds = [],[]
    for i, batch in enumerate(train_loader):
        # move batch dictionary to device
        data_utils.batch_dict_to_device(batch, device)
        labels = batch['labels']
        
        # compute prediction and loss
        preds = model(batch)
        loss = loss_fn(preds, labels)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # tracking
        total_loss += loss.item()
        all_labels.append(labels.cpu())
        all_preds.append(preds.flatten().detach().cpu())
        
        # logging
        if (i % log_every == 0):
            print(f"\tBatch {i} | BCE Loss: {loss.item():.4f}")
    
    metrics = eval_utils.get_metrics(torch.cat(all_labels), 
                                     torch.cat(all_preds))
    metrics['loss'] = total_loss
    
    return metrics

def test(model, 
         test_loader, 
         loss_fn):
    model.eval()
    total_loss = 0
    all_labels, all_preds = [],[]
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            # move batch dictionary to device
            data_utils.batch_dict_to_device(batch, device)
            labels = batch['labels']

            # compute prediction and loss
            preds = model(batch)
            loss = loss_fn(preds, labels)

            # tracking
            total_loss += loss.item()
            
            all_labels.append(labels.cpu())
            all_preds.append(preds.flatten().detach().cpu())
            
            # logging
            if (i % 10 == 0):
                print(f"\tBatch {i} | BCE Loss: {loss.item():.4f}")
    
    metrics = eval_utils.get_metrics(torch.cat(all_labels), 
                                     torch.cat(all_preds))
    metrics['loss'] = total_loss
    
    return metrics

def train(model, 
          train_dataset,
          test_dataset,
          lr=1e-3, 
          n_epochs=16,
          batch_size=32):
    
    train_loader = DataLoader(
        dataset = train_dataset, 
        batch_size = batch_size,
        sampler = WeightedRandomSampler(train_dataset.weights('hcm',
                                                              flatten_factor=1), 
                                        num_samples = len(train_dataset)),
        num_workers=18
    )
    
    test_loader = DataLoader(
        dataset = test_dataset,
        batch_size = batch_size,
        num_workers=18
    )
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    track_metrics = {'train':{i:None for i in range(n_epochs)}, 
                     'test': {i:None for i in range(n_epochs)}}
    
    for epoch in range(n_epochs):
        print(f"Epoch #{epoch}:")
        train_metrics = train_epoch(model, 
                                    train_loader, 
                                    optimizer, 
                                    loss_fn,
                                    log_every=10)
        print("Train metrics:")
        eval_utils.print_metrics(train_metrics)
        test_metrics = test(model, 
                            test_loader, 
                            loss_fn)
        print("Test metrics:")
        eval_utils.print_metrics(test_metrics)
        
        track_metrics['train'][epoch] = train_metrics
        track_metrics['test'][epoch] = test_metrics
    
    return track_metrics


In [57]:
# Predictor hyper-parameters. Only gat_params is used below (predict_method="gat");
# the GCN and MLP dicts are kept as alternative configurations.
# NOTE(review): the zero-valued 'in_dim' / 'n_nodes' / 'out_dim' entries are
# presumably filled in by ESMTransferLearner — TODO confirm.
gat_params = {
    'in_dim': 0,
    'embed_dim': 256,
    'n_heads': 4,
    'n_nodes': 0,
    'mlp_hidden_dims': [128],
    'mlp_actn': 'gelu'
}

gcn_params = {
    'in_dim': 0,
    'embed_dim': 256,
    'n_nodes': 0,
    'mlp_hidden_dims': [256],
    'mlp_actn': 'leakyrelu'
}

# NOTE(review): 'actn' is the string 'None', not None — confirm which the model expects.
mlp_params = {
    'in_dim': 0,
    'hidden_dims': [],
    'out_dim': 0,
    'actn': 'None'
}

# NOTE(review): the embeddings loaded earlier come from 'esm2s_embeddings.npy',
# while the model name below is an ESM-1v checkpoint — confirm which model
# produced the precomputed embeddings.
model = esm_transfer_learning.ESMTransferLearner(
    esm_model_name="esm1v_t33_650M_UR90S_1",
    agg_emb_method="weighted_average",
    predict_method="gat",
    n_genes=29,  # number of genes fetched by load_variation_dataset above
    predictor_params=gat_params,
    add_residue_features=data.get_feature_dimensions(['indicators']),
    # follow the notebook-wide `device` instead of hard-coding 'cuda', so the
    # cell also runs on CPU-only machines
    edge_index=data.edge_index.to(device)
)

In [58]:
# Wrap for multi-GPU data parallelism. The guard keeps the cell idempotent:
# re-running it would otherwise nest DataParallel(DataParallel(...)).
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)

In [61]:
# total element count over every parameter tensor (trainable and frozen alike)
sum(param.numel() for param in model.parameters())

11032891

In [59]:
# Display the module tree (DataParallel wrapper around ESMTransferLearner).
model

DataParallel(
  (module): ESMTransferLearner(
    (agg_embeddings): ProteinEmbeddingsGenerator(
      (weight_logits): ModuleList(
        (0): Linear(in_features=1281, out_features=1, bias=True)
        (1): Linear(in_features=1281, out_features=1, bias=True)
        (2): Linear(in_features=1281, out_features=1, bias=True)
        (3): Linear(in_features=1281, out_features=1, bias=True)
        (4): Linear(in_features=1281, out_features=1, bias=True)
        (5): Linear(in_features=1281, out_features=1, bias=True)
        (6): Linear(in_features=1281, out_features=1, bias=True)
        (7): Linear(in_features=1281, out_features=1, bias=True)
        (8): Linear(in_features=1281, out_features=1, bias=True)
        (9): Linear(in_features=1281, out_features=1, bias=True)
        (10): Linear(in_features=1281, out_features=1, bias=True)
        (11): Linear(in_features=1281, out_features=1, bias=True)
        (12): Linear(in_features=1281, out_features=1, bias=True)
        (13): Linear(

In [53]:
# Short exploratory run: 3 epochs at lr=1e-4 with large batches.
model_metrics = train(
    model=model,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    lr=1e-4,
    n_epochs=3,
    batch_size=128,
)

Epoch #0:
	Batch 0 | BCE Loss: 0.6923
	Batch 10 | BCE Loss: 0.6927
	Batch 20 | BCE Loss: 0.6948
	Batch 30 | BCE Loss: 0.6929
	Batch 40 | BCE Loss: 0.6901
	Batch 50 | BCE Loss: 0.6911
	Batch 60 | BCE Loss: 0.6987
	Batch 70 | BCE Loss: 0.6945
	Batch 80 | BCE Loss: 0.7019
	Batch 90 | BCE Loss: 0.6956
	Batch 100 | BCE Loss: 0.6931
	Batch 110 | BCE Loss: 0.6949
	Batch 120 | BCE Loss: 0.6928


KeyboardInterrupt: 

In [None]:
def plot_metrics(metrics, axes):
    """Plot per-epoch train and test curves, one metric per subplot.

    Parameters
    ----------
    metrics : dict
        ``{'train': {epoch: {name: value}}, 'test': {epoch: {name: value}}}``,
        as returned by ``train``. Metric names are taken from training epoch 0.
    axes : numpy.ndarray of Axes
        Grid of axes to draw into. Metrics beyond the number of axes are not
        plotted, because ``zip`` stops at the shorter sequence.
    """
    flat_axes = axes.flatten()
    metric_names = list(metrics['train'][0].keys())
    for name, ax in zip(metric_names, flat_axes):
        # read from the `metrics` argument, not the notebook-global model_metrics
        ax.plot([epoch_metrics[name] for epoch_metrics in metrics['train'].values()])
        ax.plot([epoch_metrics[name] for epoch_metrics in metrics['test'].values()])
        ax.set_title(name)
        ax.set_xlabel('epoch')
    # flat indexing works for 1-D and 2-D axes grids alike
    flat_axes[0].legend(['train', 'test'])
    
# Needs `model_metrics` from the training cell; that run was interrupted
# (KeyboardInterrupt), so training must complete before this cell will work.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 6), constrained_layout=True)
plot_metrics(model_metrics, axes)
# fig.savefig('../figures/lr_initial.png')