In [None]:
# Load the trained models (architectures must match the ones used at training time).
# TODO: set these to the saved state_dict files before running.
gat_path = ''
gcn_path = ''


def load_trained(model: torch.nn.Module, path: str) -> torch.nn.Module:
    """Load a saved state_dict from `path` into `model` and switch it to eval mode.

    map_location='cpu' lets checkpoints written from a GPU run be loaded on a
    CPU-only machine; the model can be moved to another device afterwards
    (predict() calls model.to(device)).
    NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    """
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model


gat = load_trained(torch_geometric.nn.GAT(-1, 20, num_layers=2, out_channels=1, dropout=0.5), gat_path)
gcn = load_trained(torch_geometric.nn.GCN(-1, 10, num_layers=2, out_channels=1, dropout=0.5), gcn_path)

# Loss

In [None]:
#functions to calculate loss on test

def predict(model, batch, device):
    """Return node-wise probabilities of being a coordinating atom for one batch.

    Moves `model` and `batch` to `device`, feeds the batch's node features,
    edge index (cast to int64) and edge attributes to the model, and maps the
    resulting logits through a sigmoid.
    """
    model.to(device)
    moved = batch.to(device)

    logits = model(
        x=moved.x,
        edge_index=moved.edge_index.to(torch.int64),
        edge_attr=moved.edge_attr,
    )
    return torch.sigmoid(logits)

def _segment_mean(values: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Mean of the rows of `values` grouped by `segment_ids` along dim 0.

    Always returns exactly `num_segments` rows; a segment with no members gets 0
    (the same convention as torch_scatter.scatter_mean). Differentiable w.r.t. `values`.
    """
    out_shape = (num_segments,) + tuple(values.shape[1:])
    sums = values.new_zeros(out_shape).index_add(0, segment_ids, values)
    # Clamp to 1 so empty segments yield 0 instead of 0/0 = nan.
    counts = torch.bincount(segment_ids, minlength=num_segments).clamp(min=1)
    counts = counts.to(values.dtype).reshape((num_segments,) + (1,) * (values.dim() - 1))
    return sums / counts


def compute_batch_loss(preds: torch.Tensor, labels: torch.Tensor, inds: torch.Tensor):
    """
    Class-balanced binary cross-entropy, averaged per ligand and then over the batch.

    For each ligand (graph), the mean BCE over its positive atoms (label != 0) and the
    mean BCE over its negative atoms (label == 0) are added together. The result is then
    averaged over all ligands in the batch. A ligand with no atoms of one class
    contributes 0 for that class.

    Parameters
    ----------
    preds : torch.Tensor (N,1) or (N,)
        Atom-wise predicted probabilities (in [0, 1], e.g. sigmoid outputs) of being a
        coordinating atom.
    labels : torch.Tensor, same shape and dtype as `preds`
        Atom-wise labels: 0 for a non-coordinating atom, 1 for a coordinating atom.
    inds : torch.Tensor (batch_size+1)
        Offsets delimiting the ligands within the batch (the batch.ptr generated by the
        torch_geometric dataloader).

    Return
    ------
    torch.Tensor (0-dim)
        Mean batch loss.

    Raises
    ------
    ValueError
        If `labels` does not hold exactly one value per atom described by `inds`.
    """
    # No `weight` argument: the former weight of 1 was a no-op and was pinned to
    # CUDA device 0, which broke CPU-only runs.
    loss_per_node = torch.nn.functional.binary_cross_entropy(preds, labels, reduction='none')

    device = preds.device
    graph_sizes = torch.diff(inds.to(device)).to(torch.long)
    num_graphs = graph_sizes.numel()
    # Graph index of every atom: [0, ..., 0, 1, ..., 1, ..., num_graphs-1].
    node_graph_ids = torch.repeat_interleave(torch.arange(num_graphs, device=device), graph_sizes)

    flat_labels = labels.reshape(-1)
    if flat_labels.numel() != node_graph_ids.numel():
        raise ValueError(
            f"expected one label per atom ({node_graph_ids.numel()}), got {flat_labels.numel()}"
        )
    pos_mask = flat_labels != 0
    neg_mask = flat_labels == 0

    # Fixing the output size to num_graphs keeps pos/neg aligned per graph even when
    # some graphs (including trailing ones) have no atoms of a class.
    pos_loss = _segment_mean(loss_per_node[pos_mask], node_graph_ids[pos_mask], num_graphs)
    neg_loss = _segment_mean(loss_per_node[neg_mask], node_graph_ids[neg_mask], num_graphs)
    combined_loss_per_graph = pos_loss + neg_loss  # element-wise for each graph

    return combined_loss_per_graph.mean()

# Held-out test set, batched for evaluation.
# NOTE(review): torch.load unpickles the file; on PyTorch >= 2.6 it defaults to
# weights_only=True, which may reject pickled dataset objects — pass
# weights_only=False only if this file is trusted.
test_data = torch.load(
    'data/test_dataset.pt',
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=100,
    shuffle=False,  # deterministic batch order, so averaged test losses are reproducible
)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Baseline: the loss obtained when every prediction is the exact opposite of its
# label, i.e. a worst-case reference value for compute_batch_loss.
pred_loss = 0.0
num_batches = 0

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        labels = batch.y.to(torch.float64)
        out_probs = 1 - labels
        loss = compute_batch_loss(out_probs, labels, batch.ptr)
        pred_loss += loss.item()
        num_batches += 1

if num_batches == 0:
    raise ValueError('test_loader yielded no batches')
pred_loss = pred_loss / num_batches

pred_loss

# Denticity

# Identity