In [3]:
from tqdm import tqdm
import numpy as np
import os, sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch_geometric.utils.convert import from_networkx
from torch import cuda
import networkx as nx

from utils import (
    data_utils, 
    eval_utils, 
    plotting_utils, 
    train_test_utils
)

from models import (
    GAT
)

# print(os.getcwd())


In [7]:
import importlib

# Re-import the helper modules so edits made to them on disk are picked up
# without restarting the kernel; the names are rebound to the reloaded modules.
data_utils, eval_utils = (
    importlib.reload(module) for module in (data_utils, eval_utils)
)

In [5]:
# Select the compute device once for the whole notebook: CUDA when torch
# reports it as available, otherwise the CPU (announced with a message).
device = torch.device('cuda' if cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print("this is running in cpu!")

In [None]:
# Read the gene list: one entry per line of ../gene_list.txt, with
# surrounding whitespace stripped.
with open("../gene_list.txt", 'r') as gene_file:
    gene_list = [line.strip() for line in gene_file]

# Build the train/test datasets for the listed genes.
# NOTE(review): only 'seq-var-matrix' is requested here, while the training and
# evaluation loops below read batch['embeddings'] — confirm the loader really
# provides that key for this configuration.
train_dataset, test_dataset = data_utils.load_variation_dataset(
    data_dir='../data/data',
    gene_list=gene_list,
    data_types=['seq-var-matrix'],
    phenotypes_path="../data/131338.parquet",
    keep_genes_separate=False,
)

In [6]:
def convertPPIDataframeToNX(df):
    """Turn a PPI edge table into a torch_geometric graph object.

    Args:
        df: pandas DataFrame with one row per edge; the endpoints are read
            from the '#node1' and 'node2' columns. No edge attributes are
            carried over.

    Returns:
        The result of ``from_networkx`` applied to an undirected
        ``networkx.Graph`` built from those edges.
    """
    undirected_ppi = nx.from_pandas_edgelist(
        df,
        source='#node1',
        target='node2',
        edge_attr=None,
        create_using=nx.Graph(),
    )
    return from_networkx(undirected_ppi)

# Load the tab-separated PPI edge table and convert it into the graph object
# (G) that is used later for the node count and for training.
df = pd.read_csv("ppi_networks/ppi2.tsv", sep='\t')
G = convertPPIDataframeToNX(df)
 

In [None]:
def train_epoch(model, train_loader, optimizer, loss_fn, graph, log_every=10):
    """Run one optimisation pass over ``train_loader`` and report its metrics.

    Args:
        model: network called as ``model(features, graph.edge_index)``.
        train_loader: iterable of batch dictionaries providing the 'labels'
            and 'embeddings' entries.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable invoked as ``loss_fn(preds, labels)``.
        graph: object whose ``edge_index`` attribute is passed to the model.
        log_every: the batch loss is printed whenever the batch index is a
            multiple of this value.

    Returns:
        The object returned by ``eval_utils.get_metrics`` for the epoch's
        concatenated labels and flattened predictions, with an extra 'loss'
        entry holding the sum of the per-batch losses.

    Note:
        Batches are moved to the notebook-level ``device`` variable.
    """
    model.train()
    running_loss = 0
    label_chunks = []
    pred_chunks = []

    for step, batch in enumerate(train_loader):
        # The helper's return value is not used; the batch dict is read
        # directly afterwards.
        data_utils.batch_dict_to_device(batch, device)
        labels = batch['labels']
        features = batch['embeddings']  # previously batch['seq-var-matrix']

        # Forward pass.
        preds = model(features, graph.edge_index)
        loss = loss_fn(preds, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Book-keeping happens on CPU copies, detached from the autograd graph.
        batch_loss = loss.item()
        running_loss += batch_loss
        label_chunks.append(labels.cpu())
        pred_chunks.append(preds.flatten().detach().cpu())

        if step % log_every == 0:
            print(f"\tBatch {step} | BCE Loss: {batch_loss:.4f}")

    epoch_metrics = eval_utils.get_metrics(torch.cat(label_chunks),
                                           torch.cat(pred_chunks))
    epoch_metrics['loss'] = running_loss
    return epoch_metrics
def test(model, test_loader, loss_fn,graph):
    """Evaluate ``model`` on ``test_loader`` without updating its parameters.

    Args:
        model: network called as ``model(features, graph.edge_index)``; it is
            switched to eval mode before the loop.
        test_loader: iterable of batch dictionaries providing the 'labels'
            and 'embeddings' entries.
        loss_fn: callable invoked as ``loss_fn(preds, labels)``.
        graph: object whose ``edge_index`` attribute is passed to the model.

    Returns:
        The object returned by ``eval_utils.get_metrics`` for the concatenated
        labels and flattened predictions, with an extra 'loss' entry holding
        the sum of the per-batch losses.

    Note:
        Batches are moved to the notebook-level ``device`` variable.
    """
    model.eval()
    total_loss = 0
    all_labels, all_preds = [], []
    with torch.no_grad():
        # Wrap the loader itself (not enumerate(...)) so tqdm can read its
        # length and display a total; the batch index was never used.
        for batch in tqdm(test_loader):
            # move batch dictionary to device
            data_utils.batch_dict_to_device(batch, device)
            labels, features = batch['labels'], batch['embeddings']  # previously batch['seq-var-matrix']

            # compute prediction and loss
            preds = model(features, graph.edge_index)
            loss = loss_fn(preds, labels)

            # tracking (on CPU copies)
            total_loss += loss.item()
            all_labels.append(labels.cpu())
            all_preds.append(preds.flatten().detach().cpu())

    metrics = eval_utils.get_metrics(torch.cat(all_labels),
                                     torch.cat(all_preds))
    metrics['loss'] = total_loss

    return metrics

def train(model, 
          train_dataset,
          test_dataset,
          graph,
          lr=1e-3, 
          n_epochs=10,
          batch_size=256,
          num_workers=12):
    """Train ``model`` for ``n_epochs`` epochs, evaluating after each one.

    Args:
        model: torch module to optimise; it is moved to the selected device.
        train_dataset: dataset for training. It must provide
            ``weights('131338-0.0')``, whose result is used as the sampling
            weights of a ``WeightedRandomSampler``.
        test_dataset: dataset evaluated after every epoch.
        graph: object exposing ``edge_index`` (forwarded to ``train_epoch`` and
            ``test``) and ``.to(device)``.
        lr: learning rate given to AdamW.
        n_epochs: number of train/evaluate rounds.
        batch_size: batch size for both data loaders.
        num_workers: worker processes for both data loaders (default 12, the
            value that used to be hard-coded).

    Returns:
        ``{'train': {epoch: metrics}, 'test': {epoch: metrics}}`` with one
        entry per epoch index in ``range(n_epochs)``.
    """
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        # Draw len(train_dataset) samples per epoch according to the
        # per-sample weights supplied by the dataset.
        sampler=WeightedRandomSampler(train_dataset.weights('131338-0.0'),
                                      num_samples=len(train_dataset)),
        num_workers=num_workers,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # Same selection rule as the notebook-level `device` that train_epoch/test
    # use when moving batches.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Keep the graph on the same device as the model and the batches so that
    # graph.edge_index can be consumed together with the features.
    graph = graph.to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    # Fixed: the optimizer was bound to the misspelled name `ptimizer`
    # (NameError on first use) and ignored the `lr` argument in favour of the
    # global hparams dictionary.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    track_metrics = {'train': {i: None for i in range(n_epochs)},
                     'test': {i: None for i in range(n_epochs)}}

    for epoch in range(n_epochs):
        print(f"Epoch #{epoch}:")
        train_metrics = train_epoch(model,
                                    train_loader,
                                    optimizer,
                                    loss_fn,
                                    graph,
                                    log_every=10)
        print("Train metrics:")
        eval_utils.print_metrics(train_metrics)
        test_metrics = test(model,
                            test_loader,
                            loss_fn,
                            graph)
        print("Test metrics:")
        eval_utils.print_metrics(test_metrics)

        track_metrics['train'][epoch] = train_metrics
        track_metrics['test'][epoch] = test_metrics

    return track_metrics


In [13]:
# Hyper-parameters for the model constructor and the training run.
# (These could live in a JSON file and be loaded from there instead.)
hparams = {
    "model_params": {
        "input_size": 1280,
        "hid_dim": 500,
        "alpha": 0.2,
        "num_heads": 5,
        # One entry per node of the PPI graph G.
        "ngenes": G.num_nodes,
    },
    "train_params": {
        "bsz": 16,
        "num_epochs": 100,
        "lr": 1e-3,
    },
}

# 
# hparams["model_params"]["input_size"] = list(train_dataset.dataset.data['embeddings'].values())[0].shape[2]


In [19]:
# Instantiate the GAT model from the "model_params" section of hparams and
# leave it as the cell's last expression so its layer summary is displayed.
gat_kwargs = hparams["model_params"]
model = GAT.GAT_Protein(**gat_kwargs)
model

GAT_Protein(
  (input_layer): Linear(in_features=1280, out_features=500, bias=True)
  (GATLayer1): GATv2Conv(500, 500, heads=1)
  (output_layer1): Linear(in_features=11500, out_features=500, bias=True)
  (output_layer): Linear(in_features=500, out_features=1, bias=True)
  (leakyrelu): LeakyReLU(negative_slope=0.2)
)

In [None]:
# Run training with the configured hyper-parameters.
# Fixed: train() requires the graph as its fourth argument; the call omitted
# it, which raises a TypeError. G is the PPI graph built earlier.
model_metrics = train(model, 
                      train_dataset,
                      test_dataset, 
                      graph=G,
                      lr=hparams["train_params"]["lr"], 
                      n_epochs=hparams["train_params"]["num_epochs"],
                      batch_size=hparams["train_params"]["bsz"])

In [None]:
def plot_metrics(metrics, axes):
    """Plot every tracked metric over epochs, train and test on shared axes.

    Args:
        metrics: mapping with 'train' and 'test' entries, each mapping an
            epoch index to a dict of ``{metric_name: value}``. The metric
            names are taken from ``metrics['train'][0]``.
        axes: 2-D array of axes objects (supports ``.flatten()`` and
            ``axes[0, 0]``). One axis is used per metric; if there are more
            metrics than axes (or vice versa) the extras are left out.
    """
    metric_types = list(metrics['train'][0].keys())
    for mt, ax in zip(metric_types, axes.flatten()):
        # Fixed: read from the `metrics` argument instead of the notebook
        # global `model_metrics`, so the function plots what it is given.
        ax.plot([x[mt] for x in metrics['train'].values()])
        ax.plot([x[mt] for x in metrics['test'].values()])
        ax.set_title(mt)
    # A single legend on the first panel; line order matches the plot calls.
    axes[0,0].legend(['train', 'test'])
    
# 2 x 3 grid of panels, one per tracked metric.
fig, axes = plt.subplots(
    nrows=2,
    ncols=3,
    figsize=(12, 6),
    constrained_layout=True,
)

plot_metrics(model_metrics, axes)
# fig.savefig('../figures/lr_initial.png')