In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from torch_geometric.nn import global_add_pool, GATConv, CGConv, VGAE,GCNConv
import pickle
from itertools import combinations
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader
from torch.utils.data import Dataset, DataLoader


In [49]:
#class VAEGeneratorEncoder(nn.Module):
#    def __init__(self, input_dim, hidden_dim, num_heads):
#        super().__init__()
#        self.conv_gat = GATConv(input_dim, hidden_dim, heads=num_heads)
#        self.conv1 = GCNConv(hidden_dim * num_heads, 2 * hidden_dim)
#        self.conv_mu = GCNConv(2 * hidden_dim, hidden_dim)
#        self.conv_logstd = GCNConv(2 * hidden_dim, hidden_dim)
#        self.relu = nn.ReLU()

#    def forward(self, x, edge_index, edge_attr):
#        x = x.float()
#        x = self.relu(self.conv_gat(x, edge_index, edge_attr))
#        x = self.relu(self.conv1(x, edge_index))
#        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

#model = VGAE(VAEGeneratorEncoder(input_dim, hidden_dim, num_heads))


In [71]:
def debug_nan(tensor, name="Tensor"):
    if torch.isnan(tensor).any():
        print(f"NaN detected in {name}")
        print(tensor)

In [92]:
input_dim = 11
hidden_dim_1 = 20
hidden_dim_2 = 10
num_bond_types = 5
num_atomic_features = 11

class VGAE(nn.Module):
    def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, num_bond_types, num_atomic_features):
        super(VGAE,self).__init__()
        self.input_dim = input_dim
        self.hidden_dim_1 = hidden_dim_1
        self.hidden_dim_2 = hidden_dim_2
        self.num_bond_types = num_bond_types
        self.base_gcn = GraphConvSparse(input_dim, hidden_dim_1, num_bond_types)
        self.gcn_mu = GraphConvSparse(hidden_dim_1, hidden_dim_2, num_bond_types, activation=lambda x:x)
        self.gcn_logstd = GraphConvSparse(hidden_dim_1, hidden_dim_2, num_bond_types, activation=lambda x:x)
        self.atomic_pred = nn.Linear(hidden_dim_2, num_atomic_features)
        self.decoder = nn.Linear(hidden_dim_2, 63 * num_bond_types)
        self.atomic_decoder = nn.Linear(hidden_dim_2, num_atomic_features)


    def reparameterize(self, mu, logstd):
        logstd = torch.clamp(logstd, min=-20, max=20)
        if self.training:
            std = torch.exp(logstd)
            eps = torch.randn_like(std)
            z = eps.mul(std).add_(mu)
            return z
        else:
            return mu

    def encode(self, x, adj):
        hidden = self.base_gcn(x, adj)
        #debug_nan(hidden, "hidden")
        self.mean = self.gcn_mu(hidden, adj)
        #debug_nan(self.mean, "mean")
        self.logstd = self.gcn_logstd(hidden, adj)
        #debug_nan(self.logstd, "logstd")
        return self.mean, self.logstd

    def forward(self, x, adj):
        mu, logstd = self.encode(x, adj)
        z = self.reparameterize(mu, logstd)
        #debug_nan(z, "Z")
        A_pred = self.decode_adj(z, x.size(1), self.num_bond_types)
        #debug_nan(A_pred, "A_pred")
        atomic_numbers_pred = self.atomic_decoder(z)
        atomic_numbers_pred = torch.clamp(atomic_numbers_pred, min=-1e10, max=1e10)
        #debug_nan(atomic_numbers_pred, "atomic number pred")
        atomic_numbers_pred = torch.sigmoid(atomic_numbers_pred)
        #debug_nan(atomic_numbers_pred, "atomic_numbers_pred_sigmoid")
        return A_pred, atomic_numbers_pred, mu, logstd

    def decode_adj(self, Z, num_nodes, num_bond_types):
        batch_size = Z.size(0)
        logits = self.decoder(Z)
        logits = torch.clamp(logits, min=-1e10, max=1e10)
        #debug_nan(logits, "logits_decode_adj")
        A_pred = logits.view(batch_size, num_nodes, num_nodes, num_bond_types)
        A_pred = torch.sigmoid(A_pred)
        return A_pred

class GraphConvSparse(nn.Module):
    def __init__(self, input_dim, output_dim, num_bond_types, activation = F.relu, **kwargs):
        super(GraphConvSparse, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.weights = nn.Parameter(torch.randn(num_bond_types, input_dim, output_dim))
        self.num_bond_types = num_bond_types
        self.activation = activation

    def forward(self, x, adj):
        batch_size, num_nodes, _ = x.size()
        _, _, _, num_relations = adj.size()
        outputs = torch.zeros(batch_size, num_nodes, self.weights[0].size(1), device=x.device)
        for r in range(num_relations):

            weight = self.weights[r]

            support = torch.bmm(x, weight.unsqueeze(0).expand(batch_size, *weight.size()))

            adj_relation = adj[:, :, :, r]

            output = torch.bmm(adj_relation, support)
            outputs += output
        outputs = self.activation(outputs)
        return outputs




model = VGAE(input_dim, hidden_dim_1, hidden_dim_2, num_bond_types, num_atomic_features)

In [96]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


In [52]:
with open("../data/combined_data.pkl", "rb") as file:
    combined_data = pickle.load(file)

atomic_num_tensors = torch.load("../data/atomic_number_tensors.pt")
bond_tensors = torch.load("../data/bond_tensors.pt")

In [53]:
print(combined_data[0])
print(len(combined_data))

Data(x=[8, 10], edge_index=[2, 7], edge_attr=[7, 5], atom_data=[8, 35])
3208


In [54]:
print(atomic_num_tensors.shape)

torch.Size([3208, 63, 11])


In [55]:
print(bond_tensors.shape)

torch.Size([3208, 63, 63, 5])


In [56]:
class GraphDataset(Dataset):
    def __init__(self, atomic_numbers, bond_tensor):
        self.atomic_numbers = atomic_numbers
        self.bond_tensor = bond_tensor

    def __len__(self):
        return self.atomic_numbers.shape[0]

    def __getitem__(self, idx):
        node_features = self.atomic_numbers[idx]
        adj_matrix = self.bond_tensor[idx]
        return node_features, adj_matrix

In [57]:
dataset = GraphDataset(atomic_num_tensors, bond_tensors)

In [44]:
print(dataset.bond_tensor[0])

tensor([[[0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]],

        [[1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]],

        [[1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]],

        ...,

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         

In [58]:
train_data, test_data = train_test_split(dataset, test_size=0.15, random_state=48)

batch_size = 32

train_data, valid_data = train_test_split(train_data, test_size=0.1, random_state=48)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size)


In [59]:
def preprocess_bonds(bond_tensor):
    ''' bond_tensor is of shape [num_nodes, num_nodes, 5]
    We split it to create a list of 5 [num_nodes, num_nodes] tensors'''
    return [bond_tensor[:, :, i] for i in range(bond_tensor.shape[2])]

In [31]:
def create_pos_edge_index(data):
    return data.edge_index

def create_neg_edge_index(num_nodes, pos_edge_index):
    all_possible_edges = list(combinations(range(num_nodes), 2))

    pos_edges = [(src.item(), dst.item()) for src, dst in pos_edge_index.t()]

    neg_edges = [edge for edge in all_possible_edges if edge not in pos_edges]

    neg_edge_index = torch.tensor(neg_edges).t()

    return neg_edge_index

In [60]:
def calculate_kl_divergence(mu, logvar):
    logvar = torch.clamp(logvar, min=-10, max=10)

    kl_divergence = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    term1 = 1 + logvar
    term2 = mu.pow(2)
    term3 = logvar.exp()
    kl_divergence = -0.5 * torch.sum(term1 - term2 - term3)

    return kl_divergence

In [61]:
def loss_function(preds, labels, mu, logvar, preds_atomic, atomic_numbers):
    labels = torch.argmax(labels, dim=-1)
    BCE = torch.nn.functional.cross_entropy(preds.view(-1, preds.size(-1)), labels.view(-1), reduction='sum')
    KLD = calculate_kl_divergence(mu, logvar)
    MSE_atomic = F.mse_loss(preds_atomic, atomic_numbers, reduction='sum')
    return BCE + KLD + MSE_atomic

In [97]:
num_epochs = 50
for epoch in range(num_epochs):
    total_loss = 0
    for node_features, adj_matrix in train_loader:
        optimizer.zero_grad()

        # Ensure adj_matrix is in the correct format
        adj_matrix_dense = adj_matrix.to_dense() if adj_matrix.is_sparse else adj_matrix

        # Forward pass: compute the reconstructed matrix, mu, and logvar
        reconstructed, reconstructed_atoms, mu, logvar = model(node_features, adj_matrix_dense)

        # Compute the loss
        loss = loss_function(reconstructed, adj_matrix_dense, mu, logvar, reconstructed_atoms, node_features)
        total_loss += loss.item()

        # Backward pass: compute gradient and update params
        loss.backward()
        optimizer.step()

    # Calculate average loss for reporting
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

Epoch 1, Loss: 1353390620672.0
Epoch 2, Loss: 1211925004288.0
Epoch 3, Loss: 1291636834304.0
Epoch 4, Loss: 1130517233664.0
Epoch 5, Loss: 1069577666560.0
Epoch 6, Loss: 948908326912.0
Epoch 7, Loss: 946907512832.0
Epoch 8, Loss: 931140141056.0
Epoch 9, Loss: 923661369344.0
Epoch 10, Loss: 833265795072.0
Epoch 11, Loss: 904246394880.0
Epoch 12, Loss: 739094757376.0
Epoch 13, Loss: 858286194688.0
Epoch 14, Loss: 677460901888.0
Epoch 15, Loss: 798977097728.0
Epoch 16, Loss: 616941813760.0
Epoch 17, Loss: 653211140096.0
Epoch 18, Loss: 573961338880.0
Epoch 19, Loss: 630993453056.0
Epoch 20, Loss: 624450011136.0
Epoch 21, Loss: 532213825536.0
Epoch 22, Loss: 474545127424.0
Epoch 23, Loss: 502955540480.0
Epoch 24, Loss: 469013725184.0
Epoch 25, Loss: 418938421248.0
Epoch 26, Loss: 413619355648.0
Epoch 27, Loss: 410917535744.0
Epoch 28, Loss: 347372847104.0
Epoch 29, Loss: 347329560576.0
Epoch 30, Loss: 343248371712.0
Epoch 31, Loss: 334639398912.0
Epoch 32, Loss: 336853336064.0
Epoch 33, Lo

In [101]:
model.eval()
num_epochs = 1
for epoch in range(num_epochs):
    total_loss = 0
    for node_features, adj_matrix in train_loader:
        #optimizer.zero_grad()

        adj_matrix_dense = adj_matrix.to_dense() if adj_matrix.is_sparse else adj_matrix

        reconstructed, reconstructed_atoms, mu, logvar = model(node_features, adj_matrix_dense)

        loss = loss_function(reconstructed, adj_matrix_dense, mu, logvar, reconstructed_atoms, node_features)
        total_loss += loss.item()

        #optimizer.step()

    #avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

max_indices = torch.argmax(reconstructed_atoms, dim=-1)
one_hot_atoms = torch.nn.functional.one_hot(max_indices, num_classes=11)
max_indices_2 = torch.argmax(reconstructed, dim=-1)
one_hot_bond = torch.nn.functional.one_hot(max_indices_2, num_classes=5)
print(one_hot_atoms, one_hot_bond)

Epoch 1, Loss: 98040864768.0
tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         ...,
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         ...,
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0]],

        [[0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         ...,
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0]],

        [[0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,

In [69]:
## Train model
#num_epochs = 200
#losses = []
#valid_losses = []
#for epoch in range(num_epochs):
#    model.train()
#    for batch in train_loader:
#        optimizer.zero_grad()
#        output = model.encode(batch.x, batch.edge_index, batch.edge_attr)
#        pos_edge = create_pos_edge_index(batch)
#        neg_edge = create_neg_edge_index(len(batch.x), pos_edge)
#        loss = model.recon_loss(output, pos_edge, neg_edge)
#        loss = loss + (1 / batch.x.shape[0]) * model.kl_loss()
#        #loss.requires_grad = True
#        loss.backward()
#        optimizer.step()
#        #scheduler.step()
#    #model.eval()
#    #with torch.no_grad():
#    #    valid_loss = 0
#    #    for graph in valid_loader:
#    #        output_valid = model.encode(graph.x, graph.edge_index)
#    #        pos_edge = create_pos_edge_index(batch)
#    #        neg_edge = create_neg_edge_index(batch)
#    #        valid_loss_calc = model.recon_loss(output, pos_edge, neg_edge)
#    #        loss = loss + (1 / graph.x.shape[0]) * model.kl_loss()
#    #        valid_loss += valid_loss_calc.item()
#    #    valid_loss /= len(valid_loader)
#    #    valid_losses.append(valid_loss)

#    if epoch % 10 == 0:
#        losses.append(loss.item())
#    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')



IndexError: index 0 is out of bounds for dimension 0 with size 0