### Authors: Jennifer Gao, Helen Cai
### April 2025
The purpose of this notebook is to create a diffusion/decoder architecture that will be used in our final project.


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
import tqdm

# NetworkX is a Python package used to create, manipulate, and mine graphs
import networkx as nx

# further libraries for working with graphs
import torch_geometric
from torch_geometric.nn import GCNConv, pool
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

# For visualization
import phate

## Hyperparameters

In [25]:
batch_size = 4

## 0. Import data

In [15]:
hidden_values = torch.load("hidden_values_from_insecticides.pt", weights_only=False)
insecticides = torch.load("./data/insecticides_graphs_small.pt", weights_only=False)

# note that logP values are in the y slot of the data

# Add the hidden values to the graphs
for i in range(len(insecticides)):
    graph = insecticides[i]
    graph.hidden_values = hidden_values[i]


In [26]:
# Create dataloaders 
# split into training and test
train_dataset, test_dataset = train_test_split(insecticides, test_size=0.2, random_state=2025)

all_data = DataLoader(insecticides, batch_size=1, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 1. Training a conditional diffusion model

Recall that the aim of our project is to sample from a well-constrained latent space that preserves certain properties: for example, molecules with similar logP values cluster together. We can take advantage of this well-organized space in our sampling process.

For instance, at sampling time we want to specify a target logP and generate new latent points from the corresponding region of the latent space.

For our training process, the inputs are latent space points and logP values. The goal is to generate new latent space points (e.g. 16x1 tensors) conditioned on a logP value (what might traditionally be thought of as a "class label").
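As a rough illustration of the training setup described above, here is a minimal sketch of one DDPM-style training step on 16-dimensional latents conditioned on logP. The `CondDenoiser` class, the linear noise schedule, `T = 100`, and the random placeholder tensors are all assumptions made for this example, not the project's final design. In the real notebook the latents would come from `hidden_values` and the logP values from the `y` slot of each graph.

```python
import torch
import torch.nn as nn


class CondDenoiser(nn.Module):
    """Hypothetical noise-prediction MLP (illustrative, not the final model)."""

    def __init__(self, latent_dim: int = 16, hidden: int = 64):
        super().__init__()
        # Input = noisy latent (latent_dim) + normalized timestep (1) + logP (1)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + 2, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, latent_dim),
        )

    def forward(self, z_t, t_frac, logp):
        return self.net(torch.cat([z_t, t_frac, logp], dim=-1))


T = 100                                          # assumed number of diffusion steps
betas = torch.linspace(1e-4, 0.02, T)            # assumed linear noise schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)   # cumulative signal fraction


def diffusion_loss(model, z0, logp):
    """Noise a clean latent z0 to a random step t and score the noise prediction."""
    t = torch.randint(0, T, (z0.shape[0],))
    a_bar = alpha_bars[t].unsqueeze(-1)
    noise = torch.randn_like(z0)
    z_t = a_bar.sqrt() * z0 + (1.0 - a_bar).sqrt() * noise
    t_frac = (t.float() / T).unsqueeze(-1)
    return nn.functional.mse_loss(model(z_t, t_frac, logp), noise)


torch.manual_seed(0)
model = CondDenoiser()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

z0 = torch.randn(4, 16)    # placeholder latents (stand-in for hidden_values)
logp = torch.randn(4, 1)   # placeholder logP values (stand-in for graph.y)

loss = diffusion_loss(model, z0, logp)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Concatenating logP to the input is the simplest way to condition the model; other conditioning schemes (e.g. classifier-free guidance) would need a different training loop.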

## 2. Training a latent space decoder

Now that we have accomplished diffusion in the latent space, we want to be able to convert that latent space representation back to a graph representation. 

For our training process: our inputs are latent space points and the original graphs used to generate them. Refer to the `encoder` notebook for information on how these inputs are generated. 

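We note that the graph data has the following (rough) dimensions: `x` $\in \mathbb{R}^{D \times 3}$, `edge_index` $\in \mathbb{R}^{2 \times D}$, `edge_attr` $\in \mathbb{R}^{D \times 1}$. These are therefore the shapes that our decoder must produce. In the `DecodeNet` below, `D` is the maximum number of edges, and `x` is given `D/2` rows.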
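Because real graphs vary in size while the decoder emits fixed-size tensors, the training targets need to be brought to a common shape. The sketch below shows one possible way to do this: pad each graph's tensors up to `D` edges and `D // 2` nodes and keep boolean masks so that padded entries can be ignored in a loss. The helper name `pad_graph_targets`, the `-1` "no edge" marker, and the toy graph are assumptions for illustration only.

```python
import torch


def pad_graph_targets(x, edge_index, edge_attr, D=88):
    """Pad one graph's tensors to fixed decoder shapes and return validity masks."""
    n_nodes, n_edges = x.shape[0], edge_index.shape[1]
    max_nodes = D // 2  # follows the D/2 node convention used by DecodeNet
    if n_nodes > max_nodes or n_edges > D:
        raise ValueError("graph is larger than the fixed decoder size")

    x_pad = torch.zeros(max_nodes, x.shape[1])
    x_pad[:n_nodes] = x

    ei_pad = torch.full((2, D), -1, dtype=torch.long)  # -1 marks "no edge" (assumed convention)
    ei_pad[:, :n_edges] = edge_index

    ea_pad = torch.zeros(D, edge_attr.shape[1])
    ea_pad[:n_edges] = edge_attr

    node_mask = torch.arange(max_nodes) < n_nodes
    edge_mask = torch.arange(D) < n_edges
    return x_pad, ei_pad, ea_pad, node_mask, edge_mask


# Toy graph: 3 nodes with 3 features each, 4 directed edges with 1 attribute each
x = torch.rand(3, 3)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.ones(4, 1)

x_pad, ei_pad, ea_pad, node_mask, edge_mask = pad_graph_targets(x, edge_index, edge_attr)
```

The masks make it possible to compute a reconstruction loss only over real nodes and edges, which keeps the padding from dominating the gradient.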

In [43]:
# What's the maximum number of edges we need to predict?
max([data.edge_index.shape[1] for data in insecticides])

88

In [66]:
class DecodeNet(nn.Module):
    """
    Initialize an MLP for re-creating graph representations from latent space.
    Input (x) will be latent space embeddings (tensor 16x1).
    We need to be able to predict x, edge_index, edge_attr.
    TODO: the number of nodes and edges should also be predicted dynamically;
    for now both are fixed by D.
    """
    def __init__(
        self,
        num_features: int = 16, # This is the size of the latent space representations.
        D: int = 88, # TODO figure this out
        p: float = 0.0,
    ):
        super().__init__()
        self.D = D

        # Use an MLP to expand dimensions of latent space
        self.mlp = nn.Sequential(
            Linear(num_features, 32),
            nn.ReLU(),
            Linear(32, 64),
            nn.ReLU(),
            Linear(64, 128),
            nn.ReLU()
        )
        
        # one head for predicting features x
        x_dim = (D // 2) * 3  # D // 2 nodes with 3 features each; matches the view() in forward
        self.x_fc = Linear(128, x_dim)

        # one head for predicting edge_index
        self.edge_index_fc = Linear(128, 2*D)

        # one head for predicting edge_attr
        self.edge_attr_fc = Linear(128, 1*D)


    def forward(self, data):
        
        # Apply the MLP
        h = self.mlp(data)

        print(h.shape)
        # Predict x and reshape
        x = self.x_fc(h)
        x = x.view(-1, self.D // 2, 3)

        # Predict edge_index and reshape
        edge_index = self.edge_index_fc(h)
        edge_index = edge_index.view(-1, 2, self.D)

        # Predict edge_attr and reshape
        edge_attr = self.edge_attr_fc(h)
        edge_attr = edge_attr.view(-1, self.D, 1)

        return x, edge_index, edge_attr 

In [68]:
decoder = DecodeNet()

x, edge_index, edge_attr = decoder(hidden_values[2])

torch.Size([128])


## 3. Generation: conditional sampling in the latent space

We want to be able to walk around the latent space to get new molecules with a certain property, e.g. with logP < ***.
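One way to realize this is standard DDPM ancestral sampling with the logP value passed as a conditioning input. The sketch below is only illustrative: `TinyDenoiser` is an untrained stand-in for whatever noise-prediction model is eventually trained, the schedule and `T = 100` are assumed, and `target_logp=1.0` is an arbitrary placeholder rather than the project's threshold.

```python
import torch
import torch.nn as nn

T = 100                                  # assumed number of diffusion steps
betas = torch.linspace(1e-4, 0.02, T)    # assumed linear noise schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


class TinyDenoiser(nn.Module):
    """Untrained stand-in for a trained conditional noise predictor (hypothetical)."""

    def __init__(self, latent_dim: int = 16):
        super().__init__()
        self.net = nn.Linear(latent_dim + 2, latent_dim)

    def forward(self, z_t, t_frac, logp):
        return self.net(torch.cat([z_t, t_frac, logp], dim=-1))


@torch.no_grad()
def sample_latents(eps_model, target_logp, n_samples=8, latent_dim=16):
    """Start from Gaussian noise and denoise step by step, conditioning on logP."""
    z = torch.randn(n_samples, latent_dim)
    logp = torch.full((n_samples, 1), float(target_logp))
    for t in reversed(range(T)):
        t_frac = torch.full((n_samples, 1), t / T)
        eps = eps_model(z, t_frac, logp)
        # Posterior mean of the previous step given the predicted noise
        mean = (z - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()
        # No noise is added on the final step
        noise = torch.randn_like(z) if t > 0 else torch.zeros_like(z)
        z = mean + betas[t].sqrt() * noise
    return z


torch.manual_seed(0)
new_latents = sample_latents(TinyDenoiser(), target_logp=1.0)
```

Since this conditions on a single target value rather than an inequality, a constraint such as "logP below some threshold" could be approximated by sampling at several target values under that threshold.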

## 4. Generation: going from latent space into graph space

Call the latent space decoder.
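The decoder's linear heads produce continuous values, so its raw `edge_index` output is not yet a valid set of node indices. The following sketch shows one possible post-processing step, assuming a simple rule: round indices to integers, then drop edges that point outside the node range or form self-loops. The helper name `to_graph_tensors` and the random stand-in tensors are hypothetical.

```python
import torch


def to_graph_tensors(x, edge_index, edge_attr):
    """Turn one decoder output (no batch dimension) into index-valid graph tensors."""
    n_nodes = x.shape[0]
    ei = edge_index.round().long()
    # Keep an edge only if both endpoints are real nodes and it is not a self-loop
    in_range = ((ei >= 0) & (ei < n_nodes)).all(dim=0)
    valid = in_range & (ei[0] != ei[1])
    return x, ei[:, valid], edge_attr[valid]


torch.manual_seed(0)
# Random stand-ins shaped like a single DecodeNet output with D = 88
raw_x = torch.randn(44, 3)
raw_edge_index = torch.rand(2, 88) * 50 - 3
raw_edge_attr = torch.rand(88, 1)

x_out, ei_out, ea_out = to_graph_tensors(raw_x, raw_edge_index, raw_edge_attr)
```

The cleaned tensors can then be wrapped in a `torch_geometric.data.Data(x=..., edge_index=..., edge_attr=...)` object for downstream use.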

## 5. Export the graphs for evaluation 

This is where we pass things back to Tobias to work on for evaluation. 
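A minimal sketch of the export step, assuming the generated graphs are stored as a list of tensor dictionaries and written with `torch.save`, mirroring how the input data is loaded at the top of this notebook. The file name `generated_graphs.pt` and the toy graph are placeholders.

```python
import os
import tempfile

import torch

# Toy stand-in for a list of generated graphs
graphs = [
    {
        "x": torch.rand(3, 3),
        "edge_index": torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
        "edge_attr": torch.ones(4, 1),
    }
]

# Hypothetical output location; a temporary directory is used here for illustration
path = os.path.join(tempfile.mkdtemp(), "generated_graphs.pt")
torch.save(graphs, path)

# Round-trip check: reload the file the same way the inputs are loaded above
reloaded = torch.load(path, weights_only=False)
```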