### Author: Helen Cai
### April 2025
The purpose of this notebook is to create a decoder architecture that will be used in our final project.


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
import tqdm

# NetworkX is a Python package used to create, manipulate, and mine graphs
import networkx as nx

# further libraries for working with graphs
import torch_geometric
from torch_geometric.nn import GCNConv, pool
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

# For visualization
import phate

# Graph scattering functionality
from LEGS_module import *

# Home-grown functions
from utils import *

## Hyperparameters

In [2]:
batch_size = 1
learning_rate = 1e-4
epochs = 100
p = 0.2     # probability for dropout

## 0. Import data

In [3]:
hidden_values = torch.load("hidden_values_from_insecticides.pt", weights_only=False)
insecticides = torch.load("./data/insecticides_graphs_small.pt", weights_only=False)

# note that logP values are in the y slot of the data


Because the value of D changes for each of the graphs in our dataset, we need some way to account for this. I will allow the decoding MLP to project to a size of fixed dimension, (e.g. the dimension that corresponds to the largest possible graph). Then, in my training dataset, the graphs will be padded with 0's to match the resulting dimension. 

In [4]:
# Function to add padding to target matrices to match the output dimensions
def pad_graphs(input_x, input_edge_index, input_edge_attr, D = 88):
    extra_x = int((D/2) - input_x.shape[0])
    padded_x = F.pad(input_x, (0, 0, 0, extra_x), value=0)

    extra_edge_index = int(D - input_edge_index.shape[1])
    padded_edge_index = F.pad(input_edge_index, (0, extra_edge_index, 0, 0), value=0)

    extra_edge_attr = int(88 - input_edge_attr.shape[0])
    padded_edge_attr = F.pad(input_edge_attr, (0, 0, 0, extra_edge_attr), value=0)

    return padded_x, padded_edge_index, padded_edge_attr


In [5]:
# Add the hidden values to the graphs & add padding
for i in range(len(insecticides)):
    graph = insecticides[i]
    graph.hidden_values = hidden_values[i]
    graph.x, graph.edge_index, graph.edge_attr = pad_graphs(graph.x, graph.edge_index, graph.edge_attr)

In [6]:
# Create dataloaders 
# split into training and test
train_dataset, test_dataset = train_test_split(insecticides, test_size=0.2, random_state=2025)

all_data = DataLoader(insecticides, batch_size=1, shuffle = False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 1. Training a latent space decoder

Now that we have accomplished diffusion in the latent space, we want to be able to convert that latent space representation back to a graph representation. 

For our training process: our inputs are latent space points and the original graphs used to generate them. Refer to the `encoder` notebook for information on how these inputs are generated. 

We note that the graph data has the following (rough) dimensions: `x` $ \in R^{D\times3}$, `edge_index` $ \in R^{2 \times D}$, `edge_attr` $ \in R^{D \times 1}$. Thusly, these are the dimensions that our decoder will work to.


In [7]:
# What's the maximum number of edges we need to predict?
max([data.edge_index.shape[1] for data in insecticides])

88

In [8]:
class DecodeNet(nn.Module):
    """
    Initialize an MLP for re-creating graph representations from latent space.
    Input (x) will be latent space embeddings (tensor 16x1).
    We need to be able to predict x, edge_index, edge_attr.
    The number of nodes and edges will also be dynamically predicted.
    """
    def __init__(
        self,
        num_features: int = 16, # This is the size of the latent space representations.
        D: int = 88, # TODO figure this out
        p: float = 0.0,
    ):
        super().__init__()
        self.D = D

        # Use an MLP to expand dimensions of latent space
        self.mlp = nn.Sequential(
            Linear(num_features, 32),
            nn.ReLU(),
            Linear(32, 64),
            nn.Dropout(p),
            nn.ReLU(),
            Linear(64, 128),
            nn.ReLU()
        )
        
        # one head for predicting features x
        x_dim = int(D * 3 / 2)
        self.x_fc = nn.Sequential(
            Linear(128, 128),
            nn.Dropout(p),
            nn.ReLU(),
            Linear(128, x_dim))

        # one head for predicting edge_index
        self.edge_index_fc = nn.Sequential(
            Linear(128, 128),
            nn.Dropout(p),
            nn.ReLU(),
            Linear(128, 2*D))

        # one head for predicting edge_attr
        self.edge_attr_fc = nn.Sequential(
            Linear(128, 128),
            nn.Dropout(p),
            nn.ReLU(),
            Linear(128, 1*D))


    def forward(self, data):
        
        # Apply the MLP
        h = self.mlp(data)
        
        # Predict x and reshape
        x = self.x_fc(h)
        x = x.view(-1, int(self.D / 2), 3)
        x = torch.squeeze(x, 0)
        

        # Predict edge_index and reshape
        edge_index = self.edge_index_fc(h)
        edge_index = edge_index.view(-1, 2, self.D)
        edge_index = torch.squeeze(edge_index, 0)
        edge_index = edge_index.to(torch.int64)

        # Predict edge_attr and reshape
        edge_attr = self.edge_attr_fc(h)
        edge_attr = edge_attr.view(-1, self.D, 1)
        edge_attr = torch.squeeze(edge_attr, 0)

        return x, edge_index, edge_attr 

We should be thoughtful about how we calculate loss for each of these 3 different things. 

* for x (node features): these are floats, so we can compare them using MSE loss
* for edge_index: these are categorical, so we should use a categorical-type loss such as  cross entropy
* for edge_attr: these are also floats, so we can use MSE loss

In [9]:
def train_decoder_epoch(model, optimizer, train_loader):
    """Train the model for one epoch.
    Args:
        model: the model
        optimizer: the optimizer
        train_loader: contains all information needed for training, including graphs and their latent space representations.
    Returns:
        train_loss: the loss of the epoch
    """
    
    model.train()
    optimizer.zero_grad()
    loss_epoch = 0

    # what loss functions are used for each of our three targets?
    x_criterion = nn.MSELoss()
    edge_index_criterion = nn.MSELoss()  # TODO: think about cross entropy loss here?
    edge_attr_criterion = nn.MSELoss()
    
    # evaluate on the train nodes
    for data in train_loader:
        target_x = data.x
        target_edge_index = data.edge_index
        target_edge_attr = data.edge_attr
        
        batch_size = data.num_graphs

        # get the outputs
        x, edge_index, edge_attr  = model(data.hidden_values)

        # calculate loss for each of the three outputs
        x_loss = x_criterion(x, target_x)
        edge_index_loss = edge_index_criterion(edge_index.to(torch.float), target_edge_index.to(torch.float))  ## something strange w data types here
        edge_attr_loss = edge_attr_criterion(edge_attr, target_edge_attr)
        
        # aggregate the total loss
        loss = x_loss + edge_index_loss + edge_attr_loss

        loss.backward()
        loss_epoch += loss.detach().numpy() * batch_size

        optimizer.step()

    # calculate training loss for the epoch
    loss_epoch = loss_epoch / len(train_loader.dataset)

    return loss_epoch 
    

In [10]:
def test_decoder_epoch(model, train_loader):
    """Test the model for one epoch.
    Args:
        model: the model
        train_loader: contains all information needed for training, including graphs and their latent space representations.
    Returns:
        train_loss: the loss of the epoch
    """
    model.eval()  # set model to evaluation mode

    loss_epoch = 0
    
    # what loss functions are used for each of our three targets?
    x_criterion = nn.MSELoss()
    edge_index_criterion = nn.MSELoss()  # TODO: think about cross entropy loss here?
    edge_attr_criterion = nn.MSELoss()
    
    with torch.no_grad():  # disable gradient calculation

        for data in train_loader:
            target_x = data.x
            target_edge_index = data.edge_index
            target_edge_attr = data.edge_attr
            
            batch_size = data.num_graphs
    
            # get the outputs
            x, edge_index, edge_attr  = model(data.hidden_values)
    
            # calculate loss for each of the three outputs
            x_loss = x_criterion(x, target_x)
            edge_index_loss = edge_index_criterion(edge_index.to(torch.float), target_edge_index.to(torch.float))  ## something strange w data types here
            edge_attr_loss = edge_attr_criterion(edge_attr, target_edge_attr)
            
            # aggregate the total loss
            loss = x_loss + edge_index_loss + edge_attr_loss
            
            loss_epoch += loss.detach().numpy() * batch_size

        # calculate test loss for the epoch
        loss_epoch = loss_epoch / len(test_loader.dataset)
                
    return loss_epoch

In [11]:
def train_decoder(model, train_loader, test_loader, optimizer, epochs=5):
    """Train the model.
    Args:
        model: the model
        loss_fn: the loss function
        train_loader: the training data loader
        test_loader: the testing data loader
        optimizer: the optimizer
        epochs: the number of epochs to train
    Returns:
        train_losses: the training losses
        test_losses: the testing losses
    """
    train_losses = []
    
    test_losses = []
    
    loop = tqdm.tqdm(range(1, epochs + 1))

    for epoch in loop:

        # train the model for one epoch
        train_loss_epoch = train_decoder_epoch(model, optimizer, train_loader)
        
        # test the model for one epoch        
        test_loss_epoch = test_decoder_epoch(model, train_loader)

        # put into our storage vectors
        train_losses.append(train_loss_epoch)
        test_losses.append(test_loss_epoch)
        
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(train_loss=train_loss_epoch, test_loss=test_loss_epoch)
    
    return train_losses, test_losses

In [None]:
decoder = DecodeNet(p = p) 

# Define the optimizer
optimizer = torch.optim.RMSprop(decoder.parameters(), lr = learning_rate)


# Call the functions
decoder_train_losses, decoder_test_losses  = train_decoder(decoder, train_loader, test_loader, optimizer, epochs=epochs)



Epoch 5:   5%|▌         | 5/100 [01:29<25:21, 16.01s/it, test_loss=353, train_loss=88.7]

In [None]:
plot_metrics(decoder_train_losses, decoder_test_losses, 
             xlabel="Epoch", ylabel="Loss", title="Decoder training loss", 
             fname="Loss.png", subdir="./training-figs/decoder/")

## 2. Generation: Going from latent space into graph space

Call the latent space decoder.

In [None]:
torch.save(decoder.state_dict(), "decoder-trained.pth")

In [None]:
decoder = DecodeNet()
decoder.load_state_dict(torch.load("decoder-trained.pth"))

generated_x, generated_edge_index, generated_edge_attr = decoder(insecticides[2].hidden_values)  # TODO: replace this with diffused hidden values

## 3. Export the graphs for evaluation 

This is where we pass things back to Tobias to work on for evaluation. 