In [3]:
# basics + plotting
import os, sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.dpi"] = 250
# Matplotlib's generic family is spelled "sans-serif" (hyphenated); the
# unhyphenated "sans serif" is not recognised and only produces a
# findfont fallback warning.
plt.rcParams["font.family"] = "sans-serif"

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler

# custom
# The notebook is expected to sit one directory below the project root (the
# data paths used later are "../data/..."); put that root on sys.path so the
# project-local `utils` and `models` packages are importable.
# os.path.dirname is separator-aware, unlike splitting on '/', which yields
# '' for Windows-style paths.
PROJECT_PATH = os.path.dirname(os.getcwd())
sys.path.insert(1, PROJECT_PATH)

from utils import (
    data_utils, 
    eval_utils, 
    plotting_utils, 
    train_test_utils
)

from models import (
    gat
)

In [8]:
import importlib

# Re-import the helper modules after editing them, without restarting the
# kernel. NOTE(review): only data_utils and eval_utils are reloaded here;
# plotting_utils, train_test_utils and models.gat are left as first imported.
data_utils, eval_utils = (
    importlib.reload(module) for module in (data_utils, eval_utils)
)

In [6]:
# Choose the compute device for the notebook: CUDA when available, else CPU.
# `cuda` is not imported as a bare name anywhere, so go through `torch.cuda`.
if torch.cuda.is_available():
    print('using cuda')
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

using cuda


In [9]:
# Load the variation dataset using only the "embeddings" feature set and join it
# with the HCM-only phenotype table (the captured output below shows per-gene
# fetching followed by phenotype integration).
# NOTE(review): 'esm2s_embeddings' presumably names a precomputed ESM-2 embedding
# file — confirm in data_utils.
data = data_utils.load_variation_dataset(
    "../data/data/",
    "../gene_list.txt",
    ["embeddings"],
    "../data/phenotypes_hcm_only.parquet",
    predict=["hcm"],
    low_memory=True,
    embeddings_file='esm2s_embeddings',
)

Fetching FHL1 data ... Done
Fetching ACTC1 data ... Done
Fetching ACTN2 data ... Done
Fetching CSRP3 data ... Done
Fetching MYBPC3 data ... Done
Fetching MYH6 data ... Done
Fetching MYH7 data ... Done
Fetching MYL2 data ... Done
Fetching MYL3 data ... Done
Fetching MYOZ2 data ... Done
Fetching LDB3 data ... Done
Fetching TCAP data ... Done
Fetching TNNC1 data ... Done
Fetching TNNI3 data ... Done
Fetching TNNT2 data ... Done
Fetching TPM1 data ... Done
Fetching TRIM63 data ... Done
Fetching PLN data ... Done
Fetching JPH2 data ... Done
Fetching FLNC data ... Done
Fetching ALPK3 data ... Done
Fetching LMNA data ... Done
Fetching NEXN data ... Done
Fetching VCL data ... Done
Fetching MYOM2 data ... Done
Fetching CASQ2 data ... Done
Fetching CAV3 data ... Done
Fetching MYLK2 data ... Done
Fetching CRYAB data ... Done
Combining tables ... Done
Integrating with phenotypes data ...Done


In [6]:
def convertPPIDataframeToNX(df):
    """Build a PyTorch Geometric graph from a protein-interaction edge list.

    Despite the name, the return value is the output of
    ``torch_geometric.utils.from_networkx``; the networkx graph is only an
    intermediate.

    Parameters
    ----------
    df : pandas.DataFrame
        One interaction per row. The endpoints are read from the ``'#node1'``
        and ``'node2'`` columns; all other columns are ignored.

    Returns
    -------
    torch_geometric.data.Data
        Graph whose ``edge_index`` is consumed by the training loop.
    """
    # Both names were previously used without being imported, so this cell
    # raised NameError on a fresh kernel.
    import networkx as nx
    from torch_geometric.utils import from_networkx

    # nx.Graph is undirected, so an interaction listed in both directions
    # collapses into a single edge.
    nx_graph = nx.from_pandas_edgelist(df,
                                       source='#node1',
                                       target='node2',
                                       edge_attr=None,
                                       create_using=nx.Graph)
    # NOTE(review): node indices follow networkx insertion order (first
    # appearance in `df`), not necessarily the gene order of the feature
    # tensors — confirm the model aligns the two.
    return from_networkx(nx_graph)

# Tab-separated PPI edge list -> graph object used by the training loop.
df = pd.read_csv("ppi_networks/ppi2.tsv", sep="\t")
G = convertPPIDataframeToNX(df)
 

In [10]:
def train_epoch(model, 
                train_loader, 
                optimizer, 
                loss_fn, 
                graph, 
                log_every=10):
    """Run one optimisation pass over ``train_loader``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(features, edge_index)``; switched to train mode here.
    train_loader : iterable of dict
        Each batch has ``'labels'`` and ``'embeddings'`` entries; it is moved
        to the model's device with ``data_utils.batch_dict_to_device``.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.
    loss_fn : callable
        ``loss_fn(preds, labels)`` returning a scalar tensor.
    graph
        Object exposing ``edge_index``, shared by every batch.
    log_every : int
        Print the batch loss every ``log_every`` batches; ``0`` disables logging.

    Returns
    -------
    dict
        ``eval_utils.get_metrics(labels, preds)`` plus ``'loss'``, the mean
        per-batch loss (comparable across loaders with different batch counts).

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches.
    """
    model.train()
    # Take the device from the model instead of a notebook-global `device`,
    # so batches always land where the weights are.
    device = next(model.parameters()).device
    # The graph is identical for every batch: move its edges once, up front,
    # so they sit on the same device as the node features.
    edge_index = graph.edge_index.to(device)

    total_loss = 0.0
    n_batches = 0
    all_labels, all_preds = [], []
    for i, batch in enumerate(train_loader):
        # move batch dictionary to device (in place)
        data_utils.batch_dict_to_device(batch, device)
        labels, features = batch['labels'], batch['embeddings']

        # compute prediction and loss
        preds = model(features, edge_index)
        loss = loss_fn(preds, labels)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # tracking
        total_loss += loss.item()
        n_batches += 1
        all_labels.append(labels.cpu())
        # NOTE(review): these are raw model outputs (logits when loss_fn is
        # BCEWithLogitsLoss); confirm eval_utils.get_metrics expects logits.
        all_preds.append(preds.flatten().detach().cpu())

        # logging
        if log_every and i % log_every == 0:
            print(f"\tBatch {i} | BCE Loss: {loss.item():.4f}")

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")

    metrics = eval_utils.get_metrics(torch.cat(all_labels), 
                                     torch.cat(all_preds))
    metrics['loss'] = total_loss / n_batches

    return metrics

def test(model, 
         test_loader, 
         loss_fn,
         graph):
    """Evaluate ``model`` on ``test_loader`` without updating any weights.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(features, edge_index)``; switched to eval mode here.
    test_loader : iterable of dict
        Each batch has ``'labels'`` and ``'embeddings'`` entries; it is moved
        to the model's device with ``data_utils.batch_dict_to_device``.
    loss_fn : callable
        ``loss_fn(preds, labels)`` returning a scalar tensor.
    graph
        Object exposing ``edge_index``, shared by every batch.

    Returns
    -------
    dict
        ``eval_utils.get_metrics(labels, preds)`` plus ``'loss'``, the mean
        per-batch loss (comparable with the value ``train_epoch`` reports).

    Raises
    ------
    ValueError
        If ``test_loader`` yields no batches.
    """
    model.eval()
    # Take the device from the model instead of a notebook-global `device`.
    device = next(model.parameters()).device
    # Same graph for every batch; keep its edges next to the features.
    edge_index = graph.edge_index.to(device)

    total_loss = 0.0
    n_batches = 0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for batch in test_loader:
            # move batch dictionary to device (in place)
            data_utils.batch_dict_to_device(batch, device)
            labels, features = batch['labels'], batch['embeddings']

            # compute prediction and loss
            preds = model(features, edge_index)
            loss = loss_fn(preds, labels)

            # tracking
            total_loss += loss.item()
            n_batches += 1
            all_labels.append(labels.cpu())
            # NOTE(review): raw model outputs (logits under BCEWithLogitsLoss);
            # confirm eval_utils.get_metrics expects logits.
            all_preds.append(preds.flatten().cpu())

    if n_batches == 0:
        raise ValueError("test_loader yielded no batches")

    metrics = eval_utils.get_metrics(torch.cat(all_labels), 
                                     torch.cat(all_preds))
    metrics['loss'] = total_loss / n_batches

    return metrics

def train(model, 
          train_dataset,
          test_dataset,
          graph,
          lr=1e-3, 
          n_epochs=10,
          batch_size=32):
    """Train ``model`` for ``n_epochs``, evaluating on ``test_dataset`` after each epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(features, edge_index)``; moved to CUDA when available.
    train_dataset
        Must support ``len()`` and ``weights('hcm', flatten_factor=...)``,
        whose result drives a ``WeightedRandomSampler``.
    test_dataset
        Iterated in order (no sampler, no shuffling).
    graph
        Graph whose ``edge_index`` is fed to the model. It is moved with
        ``graph.to(device)``, presumably a PyG ``Data`` (for which ``.to``
        moves the tensors in place — the caller's graph moves too).
    lr : float
        AdamW learning rate.
    n_epochs : int
        Number of passes over the re-sampled training set.
    batch_size : int
        Batch size for both loaders.

    Returns
    -------
    dict
        ``{'train': {epoch: metrics}, 'test': {epoch: metrics}}``, where each
        ``metrics`` is the dict returned by ``train_epoch`` / ``test``.
    """
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        # Each epoch draws len(train_dataset) examples according to the
        # per-example 'hcm' weights (with replacement, the sampler's default);
        # presumably this rebalances the classes — confirm in data_utils.
        sampler=WeightedRandomSampler(train_dataset.weights('hcm',
                                                            flatten_factor=1),
                                      num_samples=len(train_dataset)),
        num_workers=2
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=2
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # edge_index must live on the same device as the batch features.
    graph = graph.to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    track_metrics = {'train': {i: None for i in range(n_epochs)}, 
                     'test': {i: None for i in range(n_epochs)}}

    for epoch in range(n_epochs):
        print(f"Epoch #{epoch}:")
        train_metrics = train_epoch(model, 
                                    train_loader, 
                                    optimizer, 
                                    loss_fn,
                                    graph,
                                    log_every=10)
        print("Train metrics:")
        eval_utils.print_metrics(train_metrics)
        test_metrics = test(model, 
                            test_loader, 
                            loss_fn,
                            graph)
        print("Test metrics:")
        eval_utils.print_metrics(test_metrics)

        track_metrics['train'][epoch] = train_metrics
        track_metrics['test'][epoch] = test_metrics

    return track_metrics


In [15]:
# You can keep this in json and load it up
# Hyperparameters for the GAT model and for the training loop.
# NOTE(review): these could live in a JSON file and be loaded here instead.
hparams = dict(
    model_params=dict(
        input_size=320,
        hid_dim=100,
        alpha=0.2,
        num_heads=5,
    ),
    train_params=dict(
        bsz=16,
        num_epochs=100,
        lr=1e-3,
    ),
)

# The model also needs the number of genes in the loaded dataset.
hparams["model_params"]["ngenes"] = data.n_genes

# Alternative to hard-coding input_size — derive it from the training data:
# hparams["model_params"]["input_size"] = list(train_dataset.dataset.data['embeddings'].values())[0].shape[2]


In [16]:
# Instantiate the GAT model from the hyperparameters; the bare name on the
# last line makes Jupyter render the module summary.
model = gat.GAT_Protein(
    **hparams["model_params"],
)
model

GAT_Protein(
  (input_layer): Linear(in_features=320, out_features=500, bias=True)
  (GATLayer1): GATv2Conv(500, 500, heads=1)
  (output_layer1): Linear(in_features=14500, out_features=500, bias=True)
  (output_layer): Linear(in_features=500, out_features=1, bias=True)
  (leakyrelu): LeakyReLU(negative_slope=0.2)
)

In [None]:
# `train` requires the PPI graph (built as `G` above); omitting it raised
# TypeError for the missing positional argument.
# NOTE(review): train_dataset / test_dataset are not created anywhere in this
# notebook as shown — a train/test split cell appears to be missing.
model_metrics = train(model, 
                      train_dataset,
                      test_dataset, 
                      graph=G,
                      lr=hparams["train_params"]["lr"], 
                      n_epochs=hparams["train_params"]["num_epochs"],
                      batch_size=hparams["train_params"]["bsz"])

In [None]:
def plot_metrics(metrics, axes):
    """Plot per-epoch train and test curves, one metric per subplot.

    Parameters
    ----------
    metrics : dict
        ``{'train': {epoch: {name: value}}, 'test': {epoch: {name: value}}}``,
        as returned by ``train``. Metric names are taken from epoch 0 of the
        training history.
    axes : array of matplotlib Axes (any shape) or a single Axes
        Metrics are assigned to axes in flattened order. Metrics beyond the
        number of axes are not drawn; surplus axes are left empty.
    """
    # Flatten so 1-D, 2-D or single-Axes inputs all work.
    flat_axes = np.ravel(axes)
    metric_types = list(metrics['train'][0].keys())
    for mt, ax in zip(metric_types, flat_axes):
        # Read from the `metrics` argument, not a notebook global.
        ax.plot([epoch_metrics[mt] for epoch_metrics in metrics['train'].values()])
        ax.plot([epoch_metrics[mt] for epoch_metrics in metrics['test'].values()])
        ax.set_title(mt)
        ax.set_xlabel('epoch')
    flat_axes[0].legend(['train', 'test'])
    
# 2 x 3 grid: one panel per tracked metric (panels are filled in order; any
# metrics beyond six are not drawn).
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 6), constrained_layout=True)

plot_metrics(model_metrics, axes)
# fig.savefig('../figures/lr_initial.png')