In [99]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from torch_geometric.nn import global_add_pool, GATConv, CGConv, VGAE,GCNConv
import pickle
from itertools import combinations
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader
from torch.utils.data import Dataset, DataLoader


In [32]:
input_dim = 10
hidden_dim = 5
num_heads = 3

In [66]:
#class VAEGeneratorEncoder(nn.Module):
#    def __init__(self, input_dim, hidden_dim, num_heads):
#        super().__init__()
#        self.conv_gat = GATConv(input_dim, hidden_dim, heads=num_heads)
#        self.conv1 = GCNConv(hidden_dim * num_heads, 2 * hidden_dim)
#        self.conv_mu = GCNConv(2 * hidden_dim, hidden_dim)
#        self.conv_logstd = GCNConv(2 * hidden_dim, hidden_dim)
#        self.relu = nn.ReLU()

#    def forward(self, x, edge_index, edge_attr):
#        x = x.float()
#        x = self.relu(self.conv_gat(x, edge_index, edge_attr))
#        x = self.relu(self.conv1(x, edge_index))
#        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

#model = VGAE(VAEGeneratorEncoder(input_dim, hidden_dim, num_heads))


In [176]:
input_dim = 11
hidden_dim_1 = 20
hidden_dim_2 = 10
num_bond_types = 5
num_atomic_features = 11

class VGAE(nn.Module):
    def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, num_bond_types, num_atomic_features):
        super(VGAE,self).__init__()
        self.input_dim = input_dim
        self.hidden_dim_1 = hidden_dim_1
        self.hidden_dim_2 = hidden_dim_2
        self.num_bond_types = num_bond_types
        self.base_gcn = GraphConvSparse(input_dim, hidden_dim_1, num_bond_types)
        self.gcn_mu = GraphConvSparse(hidden_dim_1, hidden_dim_2, num_bond_types, activation=lambda x:x)
        self.gcn_logstd = GraphConvSparse(hidden_dim_1, hidden_dim_2, num_bond_types, activation=lambda x:x)
        self.atomic_pred = nn.Linear(hidden_dim_2, num_atomic_features)
        self.decoder = nn.Linear(hidden_dim_2, 63 * num_bond_types)
        self.atomic_decoder = nn.Linear(hidden_dim_2, num_atomic_features)


    def reparameterize(self, mu, logstd):
        if self.training:
            std = torch.exp(logstd)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def encode(self, x, adj):
        hidden = self.base_gcn(x, adj)
        self.mean = self.gcn_mu(hidden, adj)
        self.logstd = self.gcn_logstd(hidden, adj)
        return self.mean, self.logstd

    def forward(self, x, adj):
        mu, logstd = self.encode(x, adj)
        z = self.reparameterize(mu, logstd)
        A_pred = self.decode_adj(z, x.size(1), self.num_bond_types)
        atomic_numbers_pred = self.atomic_decoder(z)
        atomic_numbers_pred = torch.sigmoid(atomic_numbers_pred)
        return A_pred, atomic_numbers_pred, mu, logstd

    def decode_adj(self, Z, num_nodes, num_bond_types):
        batch_size = Z.size(0)
        logits = self.decoder(Z)
        A_pred = logits.view(batch_size, num_nodes, num_nodes, num_bond_types)
        A_pred = torch.sigmoid(A_pred)
        return A_pred

class GraphConvSparse(nn.Module):
    def __init__(self, input_dim, output_dim, num_bond_types, activation = F.relu, **kwargs):
        super(GraphConvSparse, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.weights = nn.Parameter(torch.randn(num_bond_types, input_dim, output_dim))
        self.num_bond_types = num_bond_types
        self.activation = activation

    def forward(self, x, adj):
        batch_size, num_nodes, _ = x.size()
        _, _, _, num_relations = adj.size()
        outputs = torch.zeros(batch_size, num_nodes, self.weights[0].size(1), device=x.device)
        for r in range(num_relations):

            weight = self.weights[r]

            support = torch.bmm(x, weight.unsqueeze(0).expand(batch_size, *weight.size()))

            adj_relation = adj[:, :, :, r]

            output = torch.bmm(adj_relation, support)
            outputs += output
        outputs = self.activation(outputs)
        return outputs




model = VGAE(input_dim, hidden_dim_1, hidden_dim_2, num_bond_types, num_atomic_features)

In [174]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)


In [85]:
with open("../data/combined_data.pkl", "rb") as file:
    combined_data = pickle.load(file)

atomic_num_tensors = torch.load("../data/atomic_number_tensors.pt")
bond_tensors = torch.load("../data/bond_tensors.pt")

In [10]:
print(combined_data[0])
print(len(combined_data))

Data(x=[8, 10], edge_index=[2, 7], edge_attr=[7, 5], atom_data=[8, 35])
3208


In [86]:
print(atomic_num_tensors.shape)

torch.Size([3208, 63, 11])


In [87]:
print(bond_tensors.shape)

torch.Size([3208, 63, 63, 5])


In [100]:
class GraphDataset(Dataset):
    def __init__(self, atomic_numbers, bond_tensor):
        self.atomic_numbers = atomic_numbers
        self.bond_tensor = bond_tensor

    def __len__(self):
        return self.atomic_numbers.shape[0]

    def __getitem__(self, idx):
        node_features = self.atomic_numbers[idx]
        adj_matrix = self.bond_tensor[idx]
        return node_features, adj_matrix

In [101]:
dataset = GraphDataset(atomic_num_tensors, bond_tensors)

In [102]:
train_data, test_data = train_test_split(dataset, test_size=0.15, random_state=48)

batch_size = 32

train_data, valid_data = train_test_split(train_data, test_size=0.1, random_state=48)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size)


In [94]:
def preprocess_bonds(bond_tensor):
    ''' bond_tensor is of shape [num_nodes, num_nodes, 5]
    We split it to create a list of 5 [num_nodes, num_nodes] tensors'''
    return [bond_tensor[:, :, i] for i in range(bond_tensor.shape[2])]

In [47]:
def create_pos_edge_index(data):
    return data.edge_index

def create_neg_edge_index(num_nodes, pos_edge_index):
    all_possible_edges = list(combinations(range(num_nodes), 2))

    pos_edges = [(src.item(), dst.item()) for src, dst in pos_edge_index.t()]

    neg_edges = [edge for edge in all_possible_edges if edge not in pos_edges]

    neg_edge_index = torch.tensor(neg_edges).t()

    return neg_edge_index

In [180]:
def calculate_kl_divergence(mu, logvar):
    logvar = torch.clamp(logvar, min=-10, max=10)

    kl_divergence = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    term1 = 1 + logvar
    term2 = mu.pow(2)
    term3 = logvar.exp()
    kl_divergence = -0.5 * torch.sum(term1 - term2 - term3)

    return kl_divergence

In [181]:
def loss_function(preds, labels, mu, logvar, preds_atomic, atomic_numbers):
    labels = torch.argmax(labels, dim=-1)
    BCE = torch.nn.functional.cross_entropy(preds.view(-1, preds.size(-1)), labels.view(-1), reduction='sum')
    KLD = calculate_kl_divergence(mu, logvar)
    MSE_atomic = F.mse_loss(preds_atomic, atomic_numbers, reduction='sum')
    return BCE + KLD + MSE_atomic

In [182]:
num_epochs = 50
for epoch in range(num_epochs):
    total_loss = 0
    for node_features, adj_matrix in train_loader:
        optimizer.zero_grad()

        # Ensure adj_matrix is in the correct format
        adj_matrix_dense = adj_matrix.to_dense() if adj_matrix.is_sparse else adj_matrix

        # Forward pass: compute the reconstructed matrix, mu, and logvar
        reconstructed, reconstructed_atoms, mu, logvar = model(node_features, adj_matrix_dense)

        # Compute the loss
        loss = loss_function(reconstructed, adj_matrix_dense, mu, logvar, reconstructed_atoms, node_features)
        total_loss += loss.item()

        # Backward pass: compute gradient and update params
        loss.backward()
        optimizer.step()

    # Calculate average loss for reporting
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

Epoch 1, Loss: 9586603.0
Epoch 2, Loss: 11914190.0
Epoch 3, Loss: 9378263.0
Epoch 4, Loss: 8511828.0
Epoch 5, Loss: 7259391.5
Epoch 6, Loss: 8440001.0
Epoch 7, Loss: 9503491.0
Epoch 8, Loss: 9329538.0
Epoch 9, Loss: 9527532.0
Epoch 10, Loss: 9338478.0
Epoch 11, Loss: 9302072.0
Epoch 12, Loss: 10026549.0
Epoch 13, Loss: 7777640.5
Epoch 14, Loss: 10640034.0
Epoch 15, Loss: 7919554.5
Epoch 16, Loss: 9311752.0
Epoch 17, Loss: 11546408.0
Epoch 18, Loss: 10369405.0
Epoch 19, Loss: 12746552.0
Epoch 20, Loss: 7528205.0
Epoch 21, Loss: 10939868.0
Epoch 22, Loss: 8023830.0
Epoch 23, Loss: 9061849.0
Epoch 24, Loss: 11810005.0
Epoch 25, Loss: 8558274.0
Epoch 26, Loss: 8935355.0
Epoch 27, Loss: 8115506.5
Epoch 28, Loss: 10183830.0
Epoch 29, Loss: 10590085.0
Epoch 30, Loss: 9525789.0
Epoch 31, Loss: 10794018.0
Epoch 32, Loss: 11314034.0
Epoch 33, Loss: 8560840.0
Epoch 34, Loss: 10057752.0
Epoch 35, Loss: 9003507.0
Epoch 36, Loss: 8306642.5
Epoch 37, Loss: 10779516.0
Epoch 38, Loss: 7522270.0
Epoch 3

In [192]:
model.eval()
num_epochs = 1
for epoch in range(num_epochs):
    total_loss = 0
    for node_features, adj_matrix in train_loader:
        #optimizer.zero_grad()

        adj_matrix_dense = adj_matrix.to_dense() if adj_matrix.is_sparse else adj_matrix

        reconstructed, reconstructed_atoms, mu, logvar = model(node_features, adj_matrix_dense)

        loss = loss_function(reconstructed, adj_matrix_dense, mu, logvar, reconstructed_atoms, node_features)
        total_loss += loss.item()

        #optimizer.step()

    #avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
print(torch.argmax(reconstructed, dim=-1), torch.argmax(reconstructed_atoms, dim=-1))

Epoch 1, Loss: 9867673.0
tensor([[[4, 4, 3,  ..., 2, 2, 1],
         [0, 4, 3,  ..., 0, 2, 1],
         [0, 4, 3,  ..., 0, 2, 1],
         ...,
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4]],

        [[0, 4, 3,  ..., 0, 2, 4],
         [0, 4, 3,  ..., 0, 2, 4],
         [0, 4, 3,  ..., 0, 2, 1],
         ...,
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4]],

        [[0, 4, 4,  ..., 0, 3, 0],
         [0, 4, 3,  ..., 0, 2, 4],
         [4, 4, 3,  ..., 0, 2, 4],
         ...,
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4]],

        ...,

        [[3, 2, 1,  ..., 1, 2, 3],
         [3, 2, 1,  ..., 1, 2, 3],
         [0, 4, 3,  ..., 0, 2, 4],
         ...,
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4],
         [0, 4, 3,  ..., 1, 0, 4]],

        [[4, 4, 3,  ..., 0, 2, 3],
         [0, 0, 1,  ..

In [69]:
## Train model
#num_epochs = 200
#losses = []
#valid_losses = []
#for epoch in range(num_epochs):
#    model.train()
#    for batch in train_loader:
#        optimizer.zero_grad()
#        output = model.encode(batch.x, batch.edge_index, batch.edge_attr)
#        pos_edge = create_pos_edge_index(batch)
#        neg_edge = create_neg_edge_index(len(batch.x), pos_edge)
#        loss = model.recon_loss(output, pos_edge, neg_edge)
#        loss = loss + (1 / batch.x.shape[0]) * model.kl_loss()
#        #loss.requires_grad = True
#        loss.backward()
#        optimizer.step()
#        #scheduler.step()
#    #model.eval()
#    #with torch.no_grad():
#    #    valid_loss = 0
#    #    for graph in valid_loader:
#    #        output_valid = model.encode(graph.x, graph.edge_index)
#    #        pos_edge = create_pos_edge_index(batch)
#    #        neg_edge = create_neg_edge_index(batch)
#    #        valid_loss_calc = model.recon_loss(output, pos_edge, neg_edge)
#    #        loss = loss + (1 / graph.x.shape[0]) * model.kl_loss()
#    #        valid_loss += valid_loss_calc.item()
#    #    valid_loss /= len(valid_loader)
#    #    valid_losses.append(valid_loss)

#    if epoch % 10 == 0:
#        losses.append(loss.item())
#    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')



IndexError: index 0 is out of bounds for dimension 0 with size 0