In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3


In [8]:
# Install required packages.
import os
import torch
from torch import nn
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
#!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# Helper function for visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt


def visualize_graph(G, color):
    """Draw graph ``G`` on a 7x7 inch figure without axis ticks.

    :param G: graph object accepted by ``nx.spring_layout`` / ``nx.draw_networkx``
    :param color: per-node colour values, mapped through the "Set2" colormap
    """
    plt.figure(figsize=(7, 7))
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    # Fixed seed keeps the node placement identical between runs.
    layout = nx.spring_layout(G, seed=42)
    nx.draw_networkx(G, pos=layout, with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()


def visualize_embedding(h, color, epoch=None, loss=None):
    """Scatter-plot the first two columns of embedding tensor ``h``.

    :param h: tensor with at least two columns; detached and moved to CPU before plotting
    :param color: per-point colour values, mapped through the "Set2" colormap
    :param epoch: optional epoch number shown in the x-axis label
    :param loss: optional scalar tensor (``.item()`` is called on it) shown in the label

    The label is only drawn when both ``epoch`` and ``loss`` are given.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    points = h.detach().cpu().numpy()
    xs, ys = points[:, 0], points[:, 1]
    plt.scatter(xs, ys, s=140, c=color, cmap="Set2")
    show_caption = not (epoch is None or loss is None)
    if show_caption:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()

import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
from torch_geometric.utils import degree

2.3.0+cu121


Load the Data

In [3]:
!pip install rdkit
import pandas as pd
from rdkit import Chem

def load_qm9_smiles(csv_file):
    """Return the ``smiles`` column of a CSV file as a Python list.

    :param csv_file: path or file-like object passed straight to ``pd.read_csv``
    :return: list with one entry per CSV row, in file order
    :raises KeyError: if the file has no ``smiles`` column
    """
    table = pd.read_csv(csv_file)
    return table.loc[:, 'smiles'].tolist()

# Example usage
# NOTE: this is a relative path, so the file is looked up in the kernel's working directory.
csv_file = "qm9.csv"  # Replace with the path to your QM9 CSV file
qm9_smiles = load_qm9_smiles(csv_file)

# Sanity check of what was loaded: total count and the entry at index 1.
print("Number of SMILES in QM9 dataset:", len(qm9_smiles))
print("Example SMILES:", qm9_smiles[1])

Collecting rdkit
  Downloading rdkit-2023.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.9/34.9 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.6
Number of SMILES in QM9 dataset: 84780
Example SMILES: N


In [4]:
def remove_hydrogen_from_smiles(smiles_list):
    """Round-trip every SMILES string through RDKit with hydrogens removed.

    :param smiles_list: iterable of SMILES strings
    :return: list of SMILES produced by ``Chem.MolToSmiles`` after ``Chem.RemoveHs``

    Strings that ``Chem.MolFromSmiles`` cannot parse are reported on stdout and
    skipped, so the result may be shorter than the input.
    """
    cleaned = []
    for raw in smiles_list:
        parsed = Chem.MolFromSmiles(raw)
        if parsed is not None:
            cleaned.append(Chem.MolToSmiles(Chem.RemoveHs(parsed)))
        else:
            print("Invalid SMILES:", raw)
    return cleaned

# Unparseable entries are dropped by the helper, so this list can be shorter than qm9_smiles.
modified_smiles = remove_hydrogen_from_smiles(qm9_smiles)

In [5]:
def smiles_to_graph(smiles):
    """Convert a SMILES string into node features and an edge index.

    :param smiles: SMILES string to parse with RDKit
    :return: ``(node_features, edge_index)`` where ``node_features`` is a float
        tensor of shape ``(num_atoms, 1)`` holding atomic numbers and
        ``edge_index`` is a long tensor of shape ``(2, num_bonds)``; returns
        ``(None, None)`` when the SMILES cannot be parsed

    Each bond is stored once, as (begin atom, end atom); the reverse direction
    is not added here.
    """
    # Parse the SMILES string
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, None

    # Node features: one atomic number per atom
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]

    # Edge list: one (begin, end) pair per bond
    bond_pairs = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
                  for bond in mol.GetBonds()]

    # Explicit dtype and row count keep the result a long (2, num_bonds) tensor
    # even when the molecule has no bonds (an empty list would otherwise become
    # a float tensor of shape (0,)).
    edge_index = (torch.tensor(bond_pairs, dtype=torch.long)
                  .reshape(len(bond_pairs), 2)
                  .t()
                  .contiguous())

    # Column vector of atomic numbers
    node_features = torch.tensor(atomic_numbers, dtype=torch.float).unsqueeze(1)

    return node_features, edge_index

In [6]:
import torch
from torch_geometric.data import Data
filtered_dataset = []

# One-hot rows keyed by atomic number (6=C, 7=N, 8=O, 9=F); column 0 is not
# used by any entry. Float keys such as 6.0 hash like the int keys, which is
# why the lookup with `num.item()` below works.
encoding_mappings = {
    7: [0, 0, 1, 0, 0],
    8: [0, 0, 0, 1, 0],
    6: [0, 1, 0, 0, 0],
    9: [0, 0, 0, 0, 1]
}

# Build one Data object per molecule that has at least two atoms.
for smile in modified_smiles:
    try:
        node_features, edge_index1 = smiles_to_graph(smile)

        num_nodes = node_features.shape[0]
        if num_nodes <= 1:
            # Single-atom molecules are skipped.
            continue

        # Replace each atomic number by its one-hot row.
        one_hot_rows = [encoding_mappings[num.item()] for num in node_features]
        one_hot_encoded = torch.tensor(one_hot_rows, dtype=torch.float32)

        graph = Data(x=one_hot_encoded, edge_index=edge_index1, num_nodes=num_nodes)
        filtered_dataset.append(graph)
    except Exception as e:
        # Any failure (unparseable SMILES, atom type missing from the mapping, ...)
        # is reported and the molecule is left out.
        print(f"Error processing SMILES: {smile}. {e}")


Defining the forward heat equation

In [29]:
"""Taken from https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py
Some modifications have been made to work with newer versions of Pytorch"""

import numpy as np
import torch
import torch.nn as nn


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    # Collapse all leading dimensions so that every row is one length-N signal.
    x = x.contiguous().view(-1, N)

    # Reorder each row: even-indexed samples first, then the odd-indexed samples reversed.
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    #Vc = torch.fft.rfft(v, 1)
    # Complex FFT along each row, viewed as a trailing (real, imag) pair.
    Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    # Phase angles -pi*k/(2N) for k = 0..N-1, shaped (1, N) to broadcast over rows.
    k = - torch.arange(N, dtype=x.dtype,
                       device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    # Real part of Vc * exp(i*k).
    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        # In-place rescaling; the first coefficient gets a different factor from the rest.
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    # Restore the caller's shape and apply the overall factor of 2.
    V = 2 * V.view(*x_shape)

    return V


def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct(dct(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """

    x_shape = X.shape
    N = x_shape[-1]

    # The division allocates a new tensor, so the in-place scaling below does
    # not modify the caller's X.
    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        # Multiply back the factors that dct(..., norm='ortho') divides by.
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    # Phase angles +pi*k/(2N) for k = 0..N-1, shaped (1, N) to broadcast over rows.
    k = torch.arange(x_shape[-1], dtype=X.dtype,
                     device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    # Real part is the scaled input; imaginary part is a zero followed by the
    # negated, reversed input without its last reversed element.
    V_t_r = X_v
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)

    # Complex multiplication (V_t_r + i*V_t_i) * (W_r + i*W_i).
    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    # Stack into a trailing (real, imag) pair so it can be viewed as complex.
    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    #v = torch.fft.irfft(V, 1)
    v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
    # Interleave: the first N - N//2 values go to even positions, the reversed
    # tail goes to odd positions.
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)


def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    # Transform the last axis, swap the last two axes, transform again, swap back.
    along_last = dct(x, norm=norm)
    along_both = dct(along_last.transpose(-1, -2), norm=norm)
    return along_both.transpose(-1, -2)


def idct_2d(X, norm=None):
    """
    The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct_2d(dct_2d(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    # Invert the last axis, swap the last two axes, invert again, swap back.
    along_last = idct(X, norm=norm)
    along_both = idct(along_last.transpose(-1, -2), norm=norm)
    return along_both.transpose(-1, -2)


def dct_3d(x, norm=None):
    """
    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 3 dimensions
    """
    # Transform the last axis, then bring axes -2 and -3 into last position in
    # turn and transform each; the final two swaps restore the axis order.
    out = dct(x, norm=norm)
    for axis in (-2, -3):
        out = dct(out.transpose(-1, axis), norm=norm)
    return out.transpose(-1, -3).transpose(-1, -2)


def idct_3d(X, norm=None):
    """
    The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct_3d(dct_3d(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 3 dimensions
    """
    # Invert the last axis, then bring axes -2 and -3 into last position in
    # turn and invert each; the final two swaps restore the axis order.
    out = idct(X, norm=norm)
    for axis in (-2, -3):
        out = idct(out.transpose(-1, axis), norm=norm)
    return out.transpose(-1, -3).transpose(-1, -2)

class DCTBlur1D(nn.Module):
    """Blur a signal along its last dimension by damping its DCT coefficients.

    Each coefficient with frequency ``f`` is multiplied by ``exp(-f**2 * t)``,
    with ``t = sigma**2 / 2`` and ``sigma`` looked up from ``blur_sigmas``.
    """

    def __init__(self, blur_sigmas, device):
        """
        :param blur_sigmas: sequence/array of blur widths, indexed by step
        :param device: device on which the sigma table is stored
        """
        super().__init__()
        self.blur_sigmas = torch.tensor(blur_sigmas).to(device)

    def forward(self, x, img, fwd_steps):
        """
        :param x: input whose first dimension is the batch and whose last
            dimension has length ``img``
        :param img: length of the last dimension of ``x``
        :param fwd_steps: integer index (tensor) into ``blur_sigmas``, one per batch row
        :return: blurred signal with the same shape as ``x``
        """
        # Build the frequencies on the input's device so GPU inputs work too.
        freqs = np.pi * torch.linspace(0, img - 1, img, device=x.device) / img
        frequencies_squared = freqs[None, :] ** 2

        # One sigma per batch row, with trailing singleton axes so it broadcasts
        # against x for any number of dimensions.
        sigmas = self.blur_sigmas[fwd_steps].to(x.device)
        sigmas = sigmas.reshape(-1, *([1] * (x.dim() - 1)))

        t = sigmas ** 2 / 2
        dct_coefs = dct(x, norm='ortho')
        dct_coefs = dct_coefs * torch.exp(- frequencies_squared * t)
        return idct(dct_coefs, norm='ortho')



GNN model helpers

In [11]:
from torch import nn
import torch
import math

def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Sum the rows of ``data`` into ``num_segments`` buckets chosen by ``segment_ids``.

    :param data: 2-D tensor, one row per item
    :param segment_ids: 1-D index tensor giving the target bucket of every row
    :param num_segments: number of rows in the result
    :param normalization_factor: divisor applied when ``aggregation_method == 'sum'``
    :param aggregation_method: 'sum' divides the totals by ``normalization_factor``;
        'mean' divides by the number of rows in each bucket (empty buckets stay 0);
        any other value returns the raw totals
    :return: tensor of shape ``(num_segments, data.size(1))``
    """
    n_features = data.size(1)
    index = segment_ids.unsqueeze(-1).expand(-1, n_features)
    totals = data.new_zeros((num_segments, n_features)).scatter_add_(0, index, data)

    if aggregation_method == 'sum':
        totals = totals / normalization_factor

    if aggregation_method == 'mean':
        counts = torch.zeros_like(totals).scatter_add_(0, index, torch.ones_like(data))
        # Avoid dividing by zero for buckets that received no rows.
        counts = torch.where(counts == 0, torch.ones_like(counts), counts)
        totals = totals / counts
    return totals

class GCL(nn.Module):
    """Graph convolution layer: an edge MLP builds messages, a node MLP applies a residual update.

    Messages are aggregated per node with ``unsorted_segment_sum`` using the
    first row of ``edge_index`` as the receiving index.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super().__init__()
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        # Edge MLP input: source features + target features + optional edge attributes.
        edge_in_features = 2 * input_nf + edges_in_d
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
        )

        # Node MLP input: node features + aggregated messages + optional node attributes.
        node_in_features = hidden_nf + input_nf + nodes_att_dim
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf),
        )

        if attention:
            # Per-edge gate in (0, 1) computed from the message.
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        """Return ``(gated_and_masked_message, raw_message)`` for every edge."""
        pieces = [source, target]
        if edge_attr is not None:
            pieces.append(edge_attr)
        mij = self.edge_mlp(torch.cat(pieces, dim=1))

        out = mij * self.att_mlp(mij) if self.attention else mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Aggregate edge messages per node and return ``(updated_x, node_mlp_input)``."""
        row, _ = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        parts = [x, agg] if node_attr is None else [x, agg, node_attr]
        agg = torch.cat(parts, dim=1)
        # Residual connection: requires the node MLP output width to match x.
        return x + self.node_mlp(agg), agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        """Run one message-passing step and return ``(h, mij)``."""
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, _ = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij


In [12]:
class GNN(nn.Module):
    """Stack of ``GCL`` layers between an input and an output linear embedding."""

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, out_node_nf,aggregation_method='sum', device='cpu',
                 act_fn=nn.SiLU(), n_layers=4, attention=False,
                 normalization_factor=100, ):
        """
        :param in_node_nf: number of input node features
        :param in_edge_nf: number of edge features passed to every layer
        :param hidden_nf: width of the hidden node representation
        :param out_node_nf: number of output node features; ``None`` means ``in_node_nf``
        :param aggregation_method: forwarded to every ``GCL``
        :param device: device the whole module is moved to
        :param act_fn: activation module shared by all layers
        :param n_layers: number of ``GCL`` layers
        :param attention: forwarded to every ``GCL``
        :param normalization_factor: forwarded to every ``GCL``
        """
        super().__init__()
        out_node_nf = in_node_nf if out_node_nf is None else out_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        # Input / output projections
        self.embedding = nn.Linear(in_node_nf, hidden_nf)
        self.embedding_out = nn.Linear(hidden_nf, out_node_nf)

        # Layers are registered as gcl_0 ... gcl_{n_layers-1}.
        for layer_idx in range(n_layers):
            layer = GCL(
                hidden_nf, hidden_nf, hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, act_fn=act_fn,
                attention=attention)
            self.add_module(f"gcl_{layer_idx}", layer)
        self.to(device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        """Embed ``h``, apply every layer in order, project out, and mask if requested."""
        h = self.embedding(h)
        for layer_idx in range(self.n_layers):
            layer = self._modules[f"gcl_{layer_idx}"]
            h, _ = layer(h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # The output projection has a bias, so masked nodes must be zeroed again.
        if node_mask is not None:
            h = h * node_mask
        return h


In [17]:

def coord2diff(x, edge_index, norm_constant=1):
    """Per-edge squared distance and normalised coordinate difference.

    :param x: node coordinates, one row per node
    :param edge_index: pair of index tensors ``(row, col)``
    :param norm_constant: added to the distance in the denominator
    :return: ``(radial, coord_diff)`` where ``radial`` is the squared distance
        with shape ``(num_edges, 1)`` and ``coord_diff`` is
        ``(x[row] - x[col]) / (distance + norm_constant)``
    """
    src, dst = edge_index
    delta = x[src] - x[dst]
    radial = delta.pow(2).sum(dim=1, keepdim=True)
    # The small epsilon keeps the square root finite for coincident points.
    distance = torch.sqrt(radial + 1e-8)
    return radial, delta / (distance + norm_constant)
# NOTE(review): this is a second definition of GCL in this notebook; running this
# cell rebinds the name defined in the earlier "GNN model helpers" cell. Keep the
# two copies in sync or drop one of them.
class GCL(nn.Module):
    """Graph convolution layer: an edge MLP builds messages, a node MLP applies a residual update."""

    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super(GCL, self).__init__()
        # Edge MLP consumes source and target node features (plus edges_in_d edge attributes).
        input_edge = input_nf * 2
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # Node MLP consumes node features, aggregated messages and optional node attributes.
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        if self.attention:
            # Per-edge gate in (0, 1) computed from the message.
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        """Return ``(gated_and_masked_message, raw_message)`` for every edge."""
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        mij = self.edge_mlp(out)

        if self.attention:
            att_val = self.att_mlp(mij)
            out = mij * att_val
        else:
            out = mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Aggregate edge messages per node and return ``(updated_x, node_mlp_input)``."""
        row, col = edge_index
        # Messages are summed into the node given by the first row of edge_index.
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        # Residual connection: requires the node MLP output width to match x.
        out = x + self.node_mlp(agg)
        return out, agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        """Run one message-passing step and return ``(h, mij)``."""
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij


class EquivariantUpdate(nn.Module):
    """Coordinate update: moves every node along its edge directions by a learned scalar per edge."""

    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, act_fn=nn.SiLU(), tanh=False, coords_range=10.0):
        super().__init__()
        self.tanh = tanh
        self.coords_range = coords_range

        # Final projection to one scalar per edge, initialised with a very
        # small gain so the initial coordinate updates are tiny.
        final_layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(final_layer.weight, gain=0.001)

        self.coord_mlp = nn.Sequential(
            nn.Linear(2 * hidden_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            final_layer,
        )
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

    def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask):
        """Return ``coord`` plus the aggregated per-edge translations."""
        row, col = edge_index
        edge_scalar = self.coord_mlp(torch.cat([h[row], h[col], edge_attr], dim=1))
        if self.tanh:
            # Bounded variant: the scalar is squashed and rescaled by coords_range.
            trans = coord_diff * torch.tanh(edge_scalar) * self.coords_range
        else:
            trans = coord_diff * edge_scalar
        if edge_mask is not None:
            trans = trans * edge_mask
        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        return coord + agg

    def forward(self, h, coord, edge_index, coord_diff, edge_attr=None, node_mask=None, edge_mask=None):
        """Apply the coordinate update and zero the coordinates of masked nodes."""
        coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask)
        return coord if node_mask is None else coord * node_mask


class EquivariantBlock(nn.Module):
    """``n_layers`` feature layers (``GCL``) followed by one coordinate update (``EquivariantUpdate``)."""

    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', act_fn=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum'):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.norm_constant = norm_constant
        self.sin_embedding = sin_embedding
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        # Aggregation settings shared by every sub-layer.
        agg_kwargs = dict(normalization_factor=self.normalization_factor,
                          aggregation_method=self.aggregation_method)

        # Feature layers are registered as gcl_0 ... gcl_{n_layers-1}.
        for layer_idx in range(n_layers):
            feature_layer = GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                                edges_in_d=edge_feat_nf, act_fn=act_fn,
                                attention=attention, **agg_kwargs)
            self.add_module(f"gcl_{layer_idx}", feature_layer)

        # The coordinate update always uses its own SiLU instance.
        coord_layer = EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, act_fn=nn.SiLU(),
                                        tanh=tanh, coords_range=self.coords_range_layer,
                                        **agg_kwargs)
        self.add_module("gcl_equiv", coord_layer)
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None, edge_attr=None):
        """Update node features ``h`` and coordinates ``x``; returns ``(h, x)``."""
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        # Edge features: current distances followed by the attributes supplied by the caller.
        edge_attr = torch.cat([distances, edge_attr], dim=1)

        for layer_idx in range(self.n_layers):
            feature_layer = self._modules[f"gcl_{layer_idx}"]
            h, _ = feature_layer(h, edge_index, edge_attr=edge_attr,
                                 node_mask=node_mask, edge_mask=edge_mask)
        x = self._modules["gcl_equiv"](h, x, edge_index, coord_diff, edge_attr, node_mask, edge_mask)

        # Zero the features of masked nodes once more before returning.
        if node_mask is not None:
            h = h * node_mask
        return h, x


class EGNN(nn.Module):
    """Stack of ``EquivariantBlock`` modules between an input and an output linear embedding."""

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
        super().__init__()
        out_node_nf = in_node_nf if out_node_nf is None else out_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        # Stored per-layer range; the blocks below are built with the full coords_range.
        self.coords_range_layer = float(coords_range/n_layers) if n_layers > 0 else float(coords_range)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        if sin_embedding:
            # NOTE(review): SinusoidsEmbeddingNew is not defined in the visible part of
            # this notebook — confirm it is available before passing sin_embedding=True.
            self.sin_embedding = SinusoidsEmbeddingNew()
            edge_feat_nf = self.sin_embedding.dim * 2
        else:
            self.sin_embedding = None
            # Two edge features: the block's current distance and the initial distance.
            edge_feat_nf = 2

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)

        # Blocks are registered as e_block_0 ... e_block_{n_layers-1}.
        for block_idx in range(n_layers):
            block = EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                     act_fn=act_fn, n_layers=inv_sublayers,
                                     attention=attention, norm_diff=norm_diff, tanh=tanh,
                                     coords_range=coords_range, norm_constant=norm_constant,
                                     sin_embedding=self.sin_embedding,
                                     normalization_factor=self.normalization_factor,
                                     aggregation_method=self.aggregation_method)
            self.add_module(f"e_block_{block_idx}", block)
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None):
        """Return updated node features and coordinates ``(h, x)``."""
        # Distances of the input geometry are passed to every block as edge attributes.
        distances, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)

        h = self.embedding(h)
        for block_idx in range(self.n_layers):
            block = self._modules[f"e_block_{block_idx}"]
            h, x = block(h, x, edge_index, node_mask=node_mask, edge_mask=edge_mask, edge_attr=distances)

        # The output projection has a bias, so masked nodes must be zeroed again.
        h = self.embedding_out(h)
        if node_mask is not None:
            h = h * node_mask
        return h, x

def fully_connected_graph_with_self_loops(num_nodes):
    """
    Generates edge indices for a fully connected graph with self-loops.

    Args:
        num_nodes (int): Number of nodes in the graph.

    Returns:
        torch.Tensor: Long tensor of shape ``(2, num_nodes ** 2)``. Edges are
        ordered by source node first, then by target node; ``num_nodes == 0``
        yields an empty ``(2, 0)`` tensor.
    """
    node_ids = torch.arange(num_nodes)
    # Every source index is repeated num_nodes times in a row, while the
    # target indices cycle 0..num_nodes-1 — the same ordering as a nested
    # `for i ... for j ...` loop, without building a Python list.
    sources = node_ids.repeat_interleave(num_nodes)
    targets = node_ids.repeat(num_nodes)
    return torch.stack([sources, targets]).contiguous()


def edge_index_to_adj(edge_index, num_nodes):
    """Build a dense, symmetric 0/1 adjacency matrix from an edge index.

    :param edge_index: tensor of shape ``(2, num_edges)``; an empty tensor is
        treated as "no edges"
    :param num_nodes: number of rows/columns of the result
    :return: float tensor of shape ``(num_nodes, num_nodes)`` on the same
        device as ``edge_index``
    """
    # Allocate on the edge index's device so GPU indices can be used directly.
    adj = torch.zeros(num_nodes, num_nodes, device=edge_index.device)

    # A graph without edges (including a degenerate 1-D empty tensor) has nothing to fill in.
    if edge_index.numel() == 0:
        return adj

    src, dst = edge_index[0].long(), edge_index[1].long()
    adj[src, dst] = 1
    adj[dst, src] = 1  # Undirected: edge (i, j) implies edge (j, i).

    return adj



Autoencoder model


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

from torch.nn import Linear
from torch_geometric.nn import PNAConv

class GraphAutoencoder(nn.Module):
    """Graph autoencoder that compresses 5-dim node features to one scalar per node.

    The encoder (``conv1`` -> ``conv2``) maps 5 -> 3 -> 1 features, the decoder
    (``dconv1`` -> ``dconv2``) maps 1 -> 3 -> 5, and ``edge_classifier`` scores
    every unordered node pair from the concatenation of the two decoded rows.
    """

    def __init__(self):
        super(GraphAutoencoder, self).__init__()

        in_edge_nf = 0  # no edge features are used
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        act_fn = torch.nn.SiLU()
        n_layers = 4

        def make_gnn(in_nf, hidden_nf, out_nf):
            # All four sub-networks share everything except their widths.
            return GNN(
                in_node_nf=in_nf,
                in_edge_nf=in_edge_nf,
                hidden_nf=hidden_nf,
                out_node_nf=out_nf,
                device=device,
                act_fn=act_fn,
                n_layers=n_layers
            )

        # Encoder
        self.conv1 = make_gnn(5, 8, 3)
        self.conv2 = make_gnn(3, 5, 1)

        # Decoder
        self.dconv1 = make_gnn(1, 5, 3)
        self.dconv2 = make_gnn(3, 8, 5)

        # Edge classifier MLP: input is two decoded 5-dim rows concatenated.
        self.edge_classifier = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()  # Output probability between 0 and 1
        )

        # The GNNs move themselves to `device`; move the classifier as well so
        # every parameter lives on the same device.
        self.to(device)

    def encode(self, x, edge_index):
        """Map node features ``x`` to one latent value per node."""
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_index):
        """Map latent values ``z`` back to 5-dim node features."""
        z = self.dconv1(z, edge_index)
        z = self.dconv2(z, edge_index)
        return z

    def edge_prediction(self, recon, num_rows2):
        """Score every unordered node pair ``(i, j)`` with ``i < j``.

        :param recon: decoded node features, one row per node
        :param num_rows2: number of rows of ``recon`` to pair up
        :return: 1-D tensor with one probability per pair, ordered by ``i``
            and then ``j``; empty when there are fewer than two rows
        """
        # Strict upper-triangle indices enumerate the pairs in the same order
        # as a nested `for i ... for j in range(i + 1, ...)` loop.
        first, second = torch.triu_indices(num_rows2, num_rows2, offset=1,
                                           device=recon.device)
        pair_features = torch.cat([recon[first], recon[second]], dim=1)
        return self.edge_classifier(pair_features).squeeze(-1)

# Define other necessary components and hyperparameters
# NOTE(review): the autoencoder training cell below loops a fixed 10 times, so
# confirm whether 50 is still the intended setting.
epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `model` and `optimizer` are rebound later in the notebook for the reverse heat model.
model= GraphAutoencoder()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
# Mean squared error between predicted edge probabilities and the 0/1 adjacency targets.
criterion = nn.MSELoss()

Training the latent space autoencoder

In [None]:
# Train the autoencoder to reconstruct each graph's adjacency from its latent code.
num_epochs = 10  # number of passes actually run by this cell
data1 = filtered_dataset[:1000]
for epoch in range(num_epochs):
    total_loss = 0

    print(epoch)
    for data in data1:
        edge_index_1 = data.edge_index.to(device)
        node_features_1 = data.x.to(device)[:, :5]
        num_nodes_1 = node_features_1.size(0)

        # Target: strict upper triangle of the dense adjacency matrix, as a
        # column vector with one 0/1 entry per unordered node pair (i < j).
        # It is built from the CPU edge index and then moved to `device`, so
        # it sits on the same device as the predictions it is compared with.
        adjacency_matrix_1 = edge_index_to_adj(data.edge_index, num_nodes_1)
        upper_i, upper_j = torch.triu_indices(num_nodes_1, num_nodes_1, offset=1)
        column_tensor_1 = adjacency_matrix_1[upper_i, upper_j].view(-1, 1).to(device)

        optimizer.zero_grad()

        # Encode on the molecular graph; the latent is made positive with exp().
        z = model.encode(node_features_1, edge_index_1)
        z = torch.exp(z)

        # Decode on a fully connected graph so the decoder never sees the true edges.
        edge_index1 = fully_connected_graph_with_self_loops(num_nodes_1).to(device)
        recon = model.decode(z, edge_index1)

        # One predicted probability per node pair, same ordering as the target.
        num_rows2 = recon.size(0)
        edge = model.edge_prediction(recon, num_rows2).unsqueeze(1)

        loss = criterion(edge, column_tensor_1)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(data1)

    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {average_loss:.4f}")




Save and load the model

In [22]:
import pickle
# Serialise the whole module object (not just its weights) to the working directory.
# NOTE(review): unpickling this file later requires the same class definitions to be
# importable/defined; saving `model.state_dict()` would avoid that coupling.
filename = f"autoencoder.pickle"
with open(filename, 'wb') as f:
    pickle.dump(model, f)

In [23]:
from sys import path
import pickle

# NOTE: this assignment rebinds `path`, hiding the `path` imported from `sys` above.
# SECURITY: pickle.load executes arbitrary code from the file — only load files
# produced by this notebook.
path=  'autoencoder.pickle'
with open(path, 'rb') as file:
      loaded_auto = pickle.load(file)

Reverse Heat model


In [24]:
torch.set_default_dtype(torch.float32)
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops

# Hyperparameters of the reverse heat network: one scalar in and out per node,
# no edge features.
in_node_nf = 1
out_node_nf = 1
in_edge_nf = 0
hidden_nf = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
act_fn = torch.nn.SiLU()
n_layers = 4



# Instantiate the GNN model
# NOTE: this rebinds `model` (previously the GraphAutoencoder) and, below, `optimizer`.
model = GNN(
    in_node_nf=in_node_nf,
    in_edge_nf=in_edge_nf,
    hidden_nf=hidden_nf,
    out_node_nf= out_node_nf,
    device=device,
    act_fn=act_fn,
    n_layers=n_layers
)


optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
!pip install einops
import os
from os.path import join as pjoin
from pathlib import Path

import numpy as np
import torch
from einops import repeat, rearrange


import matplotlib.pyplot as plt
from IPython.display import HTML, Image
from matplotlib.animation import FuncAnimation

Blur Parameters

In [26]:
import numpy as np
# Largest blur width of the schedule.
blur_sigma_max =1

# Smallest blur width of the schedule.
blur_sigma_min = 0.2
# 100 sigmas spaced evenly in log space from blur_sigma_min to blur_sigma_max.
blur_schedule = np.exp(np.linspace(np.log(blur_sigma_min), np.log(blur_sigma_max),100))

# NOTE(review): this rebinds the global `device` to CPU for all following cells, while
# the reverse heat model above was created on CUDA when available — confirm that
# everything is meant to run on CPU from here on.
device='cpu'
blur_schedule

array([0.2       , 0.20327796, 0.20660965, 0.20999594, 0.21343774,
       0.21693594, 0.22049148, 0.22410529, 0.22777834, 0.23151158,
       0.23530601, 0.23916263, 0.24308247, 0.24706654, 0.25111592,
       0.25523166, 0.25941486, 0.26366662, 0.26798806, 0.27238034,
       0.2768446 , 0.28138203, 0.28599383, 0.29068121, 0.29544543,
       0.30028772, 0.30520938, 0.3102117 , 0.31529601, 0.32046366,
       0.325716  , 0.33105442, 0.33648034, 0.34199519, 0.34760043,
       0.35329753, 0.35908801, 0.3649734 , 0.37095524, 0.37703513,
       0.38321466, 0.38949548, 0.39587924, 0.40236762, 0.40896235,
       0.41566517, 0.42247784, 0.42940218, 0.43644   , 0.44359317,
       0.45086357, 0.45825314, 0.46576383, 0.47339761, 0.48115651,
       0.48904257, 0.49705789, 0.50520457, 0.51348478, 0.5219007 ,
       0.53045455, 0.5391486 , 0.54798515, 0.55696652, 0.5660951 ,
       0.57537329, 0.58480355, 0.59438837, 0.60413028, 0.61403186,
       0.62409573, 0.63432454, 0.644721  , 0.65528786, 0.66602

In [27]:
# Forward heat (blur) operator using the sigma schedule defined above.
mod = DCTBlur1D(blur_schedule, device)

Training


In [None]:
import random
# Train the reverse heat network: given a blurred, slightly noised latent at a
# random step, predict the correction that yields the latent one step less blurred.
epochs = 10
data1 = filtered_dataset[:1000]
sigma = 0.01  # standard deviation of the training noise

# Schedule indices 1..99; row k of the blurred batch therefore uses sigma index k + 1.
fwd_steps = torch.linspace(1, 99, 99, dtype=torch.long, device=device)

for epoch in range(epochs):
    total_loss = 0

    print(epoch)

    for data in data1:
        optimizer.zero_grad()

        random_integer = random.randint(1, 98)
        num_nodes = data.x.size(0)

        # The autoencoder is frozen here: it is not part of `optimizer`, so run
        # it without autograd instead of back-propagating through it each step.
        with torch.no_grad():
            x = loaded_auto.encode(data.x, data.edge_index)
            x = torch.exp(x)

        dat = torch.squeeze(x)

        # Blur the same latent at every step of the schedule, then pick the
        # randomly chosen step and its less-blurred predecessor.
        blurred_batch = mod(repeat(dat, 'd -> N d', N=99), num_nodes, fwd_steps).float()
        blurred = blurred_batch[random_integer]
        less_blurred = blurred_batch[random_integer - 1]

        noise = torch.randn_like(blurred) * sigma
        perturbed_data = noise + blurred
        pert = perturbed_data.unsqueeze(1)

        edge_index1 = fully_connected_graph_with_self_loops(num_nodes).to(device)

        # NOTE(review): the network only receives the perturbed latent; the step
        # index is not passed in — confirm this is intended.
        output = model(pert, edge_index1)

        # The network predicts the difference to the less-blurred latent.
        diff = torch.squeeze(output)
        prediction = perturbed_data + diff

        loss = F.mse_loss(less_blurred, prediction) * 100

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(data1)

    print(f"Epoch [{epoch + 1}/{epochs}] Loss: {average_loss:.4f}")

In [35]:
import pickle
# Serialise the whole module object bound to `model` (not just its weights) to the
# working directory.
# NOTE(review): unpickling this file later requires the same class definitions to be
# importable/defined; saving `model.state_dict()` would avoid that coupling.
filename = f"reverse_heat.pickle"
with open(filename, 'wb') as f:
    pickle.dump(model, f)

In [36]:
from sys import path
import pickle

# NOTE: this assignment rebinds `path`, hiding the `path` imported from `sys` above.
# SECURITY: pickle.load executes arbitrary code from the file — only load files
# produced by this notebook.
path=  'reverse_heat.pickle'
with open(path, 'rb') as file:
      loaded_heat = pickle.load(file)

Generation

In [39]:
# Generation: start from a heavily blurred latent and apply the reverse heat
# network K times, adding a little noise after every step.
deblur = []
data = filtered_dataset[901]
num_nodes = data.x.size(0)
K = 90  # number of reverse steps; also the row of the blurred batch used as the start
intermediate_samples_out = []

# Sampling needs no gradients; without no_grad every step would keep the whole
# history of previous steps alive in the autograd graph.
with torch.no_grad():
    x = loaded_auto.encode(data.x, data.edge_index)
    x = torch.exp(x)

    dat = torch.squeeze(x)

    fwd_steps = torch.linspace(1, 99, 99, dtype=torch.long, device=device)
    intial_batch = mod(repeat(dat, 'd -> N d', N=99), num_nodes, fwd_steps).float()
    initial_sample = intial_batch[K]

    # Independent noise for every node and every step (one value per latent
    # entry, rather than a single scalar shared by all nodes).
    noises = [torch.randn_like(initial_sample) for _ in range(K)]

    u = initial_sample.to(device).float()

    # The graph is the same at every step, so build the edge index once.
    edge_index = fully_connected_graph_with_self_loops(num_nodes).to(device)

    for i in range(K, 0, -1):
        pert = u.unsqueeze(1)

        # Predicted correction towards the less blurred latent.
        output = loaded_heat(pert, edge_index)
        u_mean = torch.squeeze(output) + u

        noise = noises[i - 1]
        u = u_mean + noise * 0.0125

        deblur.append(u)




The generated latent samples are stored in `deblur` (one entry per reverse step); they can be decoded with the autoencoder.