In [1]:
import numpy as np
import pandas as pd
import os
import networkx as nx


import torch
from torch_geometric.data import Data


# Load the pre-processed train / validation / test splits.
# NOTE: unpickling can execute arbitrary code — only load files you produced yourself.
train_df, val_df, test_df = (
    pd.read_pickle(path)
    for path in ('processed/train_df.pkl', 'processed/val_df.pkl', 'processed/test_df.pkl')
)

# Number of residues (graph nodes) kept from the wild-type sequence.
max_resid = 168
# The 20 canonical amino acids; each maps to a one-hot row of an identity matrix.
aas = 'ACDEFGHIKLMNPQRSTVWY'
aa_dict = dict(zip(aas, np.eye(len(aas), dtype=int)))

wt_seq = "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQGVDDAFYTLVREIRKHKEK"
# Wild-type node-feature matrix: one one-hot row per retained residue.
wt_x = np.stack([aa_dict[residue] for residue in wt_seq[:max_resid]])
wt_x.shape


# Function to convert edge matrix to PyG Data object
def create_pyg_data(df):
    data_list = []
    for index, row in df.iterrows():
        edge_matrix = row['edge_matrix']
        edge_index = torch.tensor(np.nonzero(edge_matrix), dtype=torch.long)
        edge_attr = torch.tensor(edge_matrix[edge_matrix != 0], dtype=torch.float)
        label = torch.tensor((row['inactive_dist'],  row['active_dist']), dtype=torch.float)
        
        x = wt_x.copy()
        mut_resid = row['resid']
        mut_aa = row['variant'].split('_')[0][-1]
        x[mut_resid-1] = aa_dict[mut_aa]
        x = torch.tensor(x, dtype=torch.float)
        
        data = Data(edge_index=edge_index, edge_attr=edge_attr, x=x, y=label)
        data_list.append(data)
    return data_list

# Build one list of PyG graphs per split (train, validation, test — in that order).
train_list, val_list, test_list = map(create_pyg_data, (train_df, val_df, test_df))

  from .autonotebook import tqdm as notebook_tqdm
  edge_index = torch.tensor(np.nonzero(edge_matrix), dtype=torch.long)


In [2]:
from torch_geometric.loader import DataLoader, DenseDataLoader


train_loader = DataLoader(train_list, batch_size=128, shuffle=True)
val_loader = DataLoader(val_list, batch_size=128, shuffle=True)
test_loader = DataLoader(test_list, batch_size=128, shuffle=False)

len(train_list), len(val_list), len(test_list)

(806, 101, 101)

In [3]:
from tqdm import tqdm

import torch
from model import DenseDiffPool, DenseGNN
from torch_geometric.utils import to_dense_batch, to_dense_adj

import torch.nn.functional as F

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def train(model, loader, optimizer, loss_func):

    total_loss, total_lp, total_er = 0, 0, 0
    model.train()
    for data in loader:
        data = data.to(device)

        x, m = to_dense_batch(data.x, data.batch)
        a = to_dense_adj(data.edge_index, data.batch, data.edge_attr)
        

        optimizer.zero_grad()

        if isinstance(model, DenseDiffPool):
            output, lp_loss, er_loss = model(x, a, m)
        else:
            output = model(x, a, m)
            lp_loss, er_loss = torch.tensor(0), torch.tensor(0)

        loss = loss_func(output, data.y.view(-1, 2))

        combined_loss = 50 * loss +  lp_loss +  er_loss
        combined_loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_lp += lp_loss.item()
        total_er += er_loss.item()

    return total_loss / len(loader), total_lp / len(loader), total_er / len(loader)

def test(model, loader):

    mse = 0
    model.eval()
    for data in loader:
        data = data.to(device)

        x, m = to_dense_batch(data.x, data.batch)
        a = to_dense_adj(data.edge_index, data.batch, data.edge_attr)

        with torch.no_grad():
            if isinstance(model, DenseDiffPool):
                out, _, _ = model(x, a, m)
            else:
                out = model(x, a, m)
            mse += F.mse_loss(out, data.y.view(-1, 2)).item()

    

    return mse / len(loader)
     


In [4]:
epochs = 500
learning_rate = 1e-4
weight_decay = 1e-2
skip = True

max_resid = 168
num_layers = 3
in_channels = 20
hidden_channels = 64
out_channels = 2

node_ratio = 0.25

heads = 5
dropout = 0.2

In [5]:
# Accumulators appended to by run_results: one row per finished run (df_results)
# and one row per training epoch (df_train).
df_results = pd.DataFrame(
    columns=['model', 'conv_type', 'test_mse'],
)
df_train = pd.DataFrame(
    columns=[
        'model', 'conv_type', 'epoch',
        'train_loss', 'train_lp', 'train_er',
        'val_loss',
    ],
)
def run_results(model_name, conv_type, n_runs=5):


    for run in range(n_runs):

        if model_name == 'DenseGNN':
            model = DenseGNN(conv_type=conv_type, 
                            num_layers=num_layers, in_channels=in_channels, hidden_channels=hidden_channels, out_channels=out_channels, 
                            skip=skip, pred=True,
                            heads=heads, dropout=dropout)
        elif model_name == 'DenseDiffPool':
            model = DenseDiffPool(max_nodes=max_resid, node_ratio=node_ratio, 
                            conv_type=conv_type, num_layers=num_layers, in_channels=in_channels, hidden_channels=hidden_channels, out_channels=out_channels, skip=skip,
                            heads=heads, dropout=dropout,
                            diff_skip=True)
            
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        loss_func = torch.nn.MSELoss()

        model = model.to(device)

        for epoch in tqdm(range(1, epochs)):
            loss, lp_loss, er_loss = train(model, train_loader, optimizer, loss_func)
            val_mse = test(model, val_loader)

            df_train.loc[len(df_train)] = [model_name, conv_type, epoch, loss, lp_loss, er_loss, val_mse]

        test_mse = test(model, test_loader)
        df_results.loc[len(df_results)] = [model_name, conv_type, test_mse]


In [None]:
run_results('DenseDiffPool', 'MLP')
run_results('DenseDiffPool', 'GCN')
run_results('DenseDiffPool', 'SAGE')
run_results('DenseDiffPool', 'GAT')
run_results('DenseDiffPool', 'GCN-GAT')

df_results