In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import pickle
import copy

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch_geometric
from torchsummary import summary
import torch_scatter

from rdkit import Chem
from rdkit.Chem import Descriptors, rdmolops

from sklearn.model_selection import train_test_split

# Start Here

In [2]:
class LigandDataset():
    def __init__(self, atom_list, natom_list, edge_index_list, edge_feats_list, y_list, denticities_list):
        self.atom_list = atom_list
        self.natom_list = natom_list
        self.edge_index_list = edge_index_list
        self.edge_feats_list = edge_feats_list
        self.y_list = y_list
        self.denticities_list = denticities_list
        
    def __len__(self):
        return len(self.atom_list)
        
    def __getitem__(self, idx):
        return Data(x=torch.Tensor(self.atom_list[idx]),
                    natoms=torch.Tensor([self.natom_list[idx]]),
                    edge_index=torch.Tensor(np.array(self.edge_index_list[idx])),
                    edge_attr=torch.Tensor(self.edge_feats_list[idx]),
                    y=torch.Tensor(self.y_list[idx]).unsqueeze(1).to(torch.long),
                    denticity=torch.Tensor([self.denticities_list[idx]]),
                    # y=torch.nn.functional.one_hot(torch.Tensor(self.y_list[idx]).to(torch.long), num_classes=2) # one-hot
                   )

train_data = torch.load('data/train_dataset.pt')
test_data = torch.load('data/test_dataset.pt')
val_data = torch.load('data/val_dataset.pt')

In [3]:
train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
val_loader = DataLoader(val_data, batch_size=100, shuffle=True)
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)

In [4]:
def compute_batch_loss(preds: torch.Tensor, labels: torch.Tensor, inds: torch.Tensor):
    """ 
    Computes a cross-entropy loss for each atom.
    Then, computes the mean of that loss for each ligand, and then across all ligands in the batch
    Parameters
    ----------
    preds : torch.Tensor (N,1)
        Atom-wise predicted logits for not-being or being a coordinating atom
    labels : torch.Tensor (N,1)
        Atom-wise labels for whether it isn't or is a coordinating atom
    inds : torch.Tensor (batch_size+1)
        The indices defining the ligands within each batch. Uses the batch.ptr generated by the torch_geometric dataloader.
    Return
    ------
    torch.Tensor (1,)
        Mean batch loss
    """
    # this is (N,2) for some reason
    loss_per_node = torch.nn.functional.binary_cross_entropy(preds, labels, reduction='none',
                                                             weight=torch.Tensor([1]).to(0))
    # Compute the mean cross-entropy across each individual graph, then the mean across the entire batch
    # graph_sizes = torch.diff(inds)
    # segment_ids = torch.repeat_interleave(torch.arange(len(graph_sizes), device=preds.device), graph_sizes)
    # graph_losses = torch_scatter.scatter_mean(loss_per_node, segment_ids, dim=0)
    # return graph_losses.mean()
    
    # Experimenting with averaging negative and positive losses
    # Note: this does not average over graphs
    # neg_loss = loss_per_node[labels[:,0].nonzero()]
    # pos_loss = loss_per_node[labels[:,1].nonzero()]
    # return neg_loss.mean() + pos_loss.mean()
    
    ## Average negative and positive losses separately per graph
    graph_sizes = torch.diff(inds)
    # Get how many ones/zeros are in each individual graph
    num_ones_per_graph = torch.Tensor([len(labels[inds[i-1]:inds[i]].nonzero()) for i in range(1,len(inds))],
                                     ).to(torch.long)
    num_zeros_per_graph = torch.Tensor([len(torch.where(labels[inds[i-1]:inds[i]]==0)[0]) for i in range(1,len(inds))],
                                     ).to(torch.long)
    ones_seg_ids = torch.repeat_interleave(torch.arange(len(num_ones_per_graph)), num_ones_per_graph).to(preds.device)
    zeros_seg_ids = torch.repeat_interleave(torch.arange(len(num_zeros_per_graph)), num_zeros_per_graph).to(preds.device)
    # compute mean loss for each pos/neg for each graph
    pos_loss = torch_scatter.scatter_mean(loss_per_node[labels.flatten().nonzero().flatten()], ones_seg_ids, dim=0)
    neg_loss = torch_scatter.scatter_mean(loss_per_node[torch.where(labels==0)[0]], zeros_seg_ids, dim=0)
    combined_loss_per_graph = pos_loss + neg_loss # element-wise for each graph

    # pred_num_one = torch_scatter.scatter_add(loss_per_node[labels.flatten().nonzero().flatten()], ones_seg_ids, dim=0)
    # pred_num_zero = torch_scatter.scatter_add(loss_per_node[torch.where(labels==0)[0]], zeros_seg_ids, dim=0)
    
    return (combined_loss_per_graph.mean())
            # + 0.5*torch.mean(torch.square(pred_num_one/denticities - 1))
            # + 0.01*torch.mean(torch.square(pred_num_zero/(natoms-denticities))))


In [13]:
gat = torch_geometric.nn.GAT(-1, 20, num_layers=2, out_channels=1, dropout=0.5)
optimizer = torch.optim.Adam(gat.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_epochs = 50
gat.to(device)

train_epoch_losses = []
val_epoch_losses = []

# what is the loss if we predict 0s for everything
pred_0_loss_val = 0
pred_opp_loss_val = 0
with torch.no_grad():
    for i, batch in enumerate(val_loader):
        batch.to(device)
        out_probs = torch.zeros(batch.y.shape, dtype=torch.float64).to(device)
        loss = compute_batch_loss(out_probs, batch.y.to(torch.float64), batch.ptr)
        pred_0_loss_val += loss.item()
pred_0_loss_val = pred_0_loss_val / (i+1)

# What is the loss if we predict the opposite for everything
with torch.no_grad():
    for i, batch in enumerate(val_loader):
        batch.to(device)
        out_probs = 1-batch.y.to(torch.float64).to(device)
        loss = compute_batch_loss(out_probs, batch.y.to(torch.float64), batch.ptr)
        pred_opp_loss_val += loss.item()
pred_opp_loss_val = pred_opp_loss_val / (i+1)
print(f'Val Loss for all 0 predictions: {pred_0_loss_val:.3}')
print(f'Val Loss for opposite predictions: {pred_opp_loss_val:.3}')

best_loss = 10000
# Training Loop
for epoch in range(num_epochs):
    epoch_train_loss = 0
    gat.train()
    for i, batch in enumerate(train_loader):
        batch.to(device)

        out_logits = gat(x=batch.x, edge_index=batch.edge_index.to(torch.int64), edge_attr=batch.edge_attr)
        out_probs = torch.nn.functional.sigmoid(out_logits)
        loss = compute_batch_loss(out_probs, batch.y.to(torch.float32), batch.ptr)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()
    scheduler.step()

    epoch_train_loss = epoch_train_loss / (i+1)
    train_epoch_losses.append(epoch_train_loss)
    
    # Validation
    epoch_val_loss = 0
    gat.eval()
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            batch.to(0)
            out_logits = gat(x=batch.x, edge_index=batch.edge_index.to(torch.int64), edge_attr=batch.edge_attr)
            out_probs = torch.nn.functional.sigmoid(out_logits)
            loss = compute_batch_loss(out_probs, batch.y.to(torch.float32), batch.ptr)
            epoch_val_loss += loss.item()
    epoch_val_loss = epoch_val_loss / (i+1)
    val_epoch_losses.append(epoch_val_loss)
    
    print(f'Epoch: {epoch+1} | Avg Train Loss: {epoch_train_loss:.3} | Avg Val Loss: {epoch_val_loss:.3}')

    # Early stopping
    if epoch_val_loss < best_loss:
        best_loss = epoch_val_loss
        best_model_weights = copy.deepcopy(gat.state_dict())  # Deep copy here      
        patience = 5  # Reset patience counter
    else:
        patience -= 1
        if patience == 0:
            break

    # Load the best model weights
    gat.load_state_dict(best_model_weights)

Val Loss for all 0 predictions: 1e+02
Val Loss for opposite predictions: 2e+02
Epoch: 1 | Avg Train Loss: 1.21 | Avg Val Loss: 0.933
Epoch: 2 | Avg Train Loss: 1.05 | Avg Val Loss: 0.776
Epoch: 3 | Avg Train Loss: 1.02 | Avg Val Loss: 0.723
Epoch: 4 | Avg Train Loss: 0.992 | Avg Val Loss: 0.699
Epoch: 5 | Avg Train Loss: 0.977 | Avg Val Loss: 0.681
Epoch: 6 | Avg Train Loss: 0.971 | Avg Val Loss: 0.669
Epoch: 7 | Avg Train Loss: 0.959 | Avg Val Loss: 0.648
Epoch: 8 | Avg Train Loss: 0.953 | Avg Val Loss: 0.65
Epoch: 9 | Avg Train Loss: 0.951 | Avg Val Loss: 0.652
Epoch: 10 | Avg Train Loss: 0.953 | Avg Val Loss: 0.655
Epoch: 11 | Avg Train Loss: 0.955 | Avg Val Loss: 0.654
Epoch: 12 | Avg Train Loss: 0.955 | Avg Val Loss: 0.653


In [None]:
plt.hist(out_probs.detach().cpu().numpy(), 20)
plt.show()

In [None]:
plt.plot(range(epoch+1), train_epoch_losses, label='Train Loss')
plt.plot(range(epoch+1), val_epoch_losses, label='Val Loss')
plt.legend()

In [9]:
gat.eval()
test_loss = 0
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch.to(0)
        out_logits = gat(x=batch.x, edge_index=batch.edge_index.to(torch.int64), edge_attr=batch.edge_attr)
        out_probs = torch.nn.functional.sigmoid(out_logits)
        loss = compute_batch_loss(out_probs, batch.y.to(torch.float32), batch.ptr)
        test_loss += loss.item()
    test_loss /= (i+1)
print(f'Test Loss: {test_loss}')

Test Loss: 0.8008538877710383


In [94]:
summary(gat)

Layer (type:depth-idx)                   Param #
├─ReLU: 1-1                              --
├─ModuleList: 1-2                        --
|    └─GATConv: 2-1                      --
|    |    └─SumAggregation: 3-1          --
|    |    └─Linear: 3-2                  20
|    └─GATConv: 2-2                      --
|    |    └─SumAggregation: 3-3          --
|    |    └─Linear: 3-4                  4
Total params: 24
Trainable params: 24
Non-trainable params: 0


Layer (type:depth-idx)                   Param #
├─ReLU: 1-1                              --
├─ModuleList: 1-2                        --
|    └─GATConv: 2-1                      --
|    |    └─SumAggregation: 3-1          --
|    |    └─Linear: 3-2                  20
|    └─GATConv: 2-2                      --
|    |    └─SumAggregation: 3-3          --
|    |    └─Linear: 3-4                  4
Total params: 24
Trainable params: 24
Non-trainable params: 0