In [4]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [5]:
import torch
from torch import Tensor
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.transforms import ToDense


class ToGraphEBM(BaseTransform):
    def __init__(self,max_nodes = 38, n_atoms = 28, n_edge_types = 3):
        self.max_nodes = max_nodes
        self.n_atoms = n_atoms
        
    def forward(self, data : Data):
        nodes, _ = data.x.shape
        
        
        # adding virtual nodes if max nodes not present
        if data.x.shape[0] < self.max_nodes:
            virt_rows = torch.ones([self.max_nodes - data.x.shape[0], 1])
            virt_rows = virt_rows * self.n_atoms
            data.x = torch.cat([data.x, virt_rows], dim=0)
        data.x = F.one_hot(data.x.squeeze().long(), num_classes = self.n_atoms + 1)
        
        return data
        
class ToDenseAdj(BaseTransform):
    def __init__(self,max_nodes = 38, n_atoms = 28, n_edge_types = 3):
        self.max_nodes = max_nodes
        self.n_atoms = n_atoms
        self.trans = ToDense(num_nodes=38)
        self.n_edge_types = n_edge_types 

    def forward(self, data : Data):
        
        # working with dense dataset and adding virtual edges
        d = self.trans(data)
        adj = d.adj
        adj_mask = (adj > 0) * 1
        adj = F.one_hot(adj.long(), num_classes = self.n_edge_types + 1)
        adj[:, :, 0] = adj[:, :, 0] * adj_mask
        adj[:,:,-1] = torch.logical_not(torch.any(adj,dim=-1, keepdim=False)) * 1
        for i in range(adj.shape[0]):
            adj[i,i,:] = 0
        d.adj = adj.permute(2,0,1)
        
        return d
    

In [6]:
import torch_geometric as pyg
from torch_geometric.transforms import ToDense

dataset = pyg.datasets.ZINC(root='/kaggle/working/', transform = ToDenseAdj(),pre_transform=ToGraphEBM())

Downloading https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1
Extracting /kaggle/working/molecules.zip
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/train.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/val.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/test.index
Processing...
Processing train dataset: 100%|██████████| 220011/220011 [00:43<00:00, 5065.17it/s]
Processing val dataset: 100%|██████████| 24445/24445 [00:05<00:00, 4274.72it/s]
Processing test dataset: 100%|██████████| 5000/5000 [00:01<00:00, 3303.89it/s]
Done!


In [7]:
from torch_geometric.loader import DenseDataLoader

loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)

In [8]:
import os
import torch_geometric as pyg
from torch_geometric.transforms import ToDense
from torch_geometric.loader import DenseDataLoader
from torch.nn.utils import spectral_norm
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [9]:
class GraphConv(nn.Module):

    def __init__(self, in_channels, out_channels, num_edge_type=4, add_self=False):
        super(GraphConv, self).__init__()

        self.add_self = add_self
        if self.add_self:
            self.linear_node = spectral_norm(nn.Linear(in_channels, out_channels))
        self.linear_edge = spectral_norm(nn.Linear(in_channels, out_channels * num_edge_type))
        self.num_edge_type = num_edge_type
        self.in_ch = in_channels
        self.out_ch = out_channels

    def forward(self, adj, h):
        mb, node, _ = h.shape
        if self.add_self:
            h_node = self.linear_node(h)
        m = self.linear_edge(h)
        m = m.reshape(mb, node, self.out_ch, self.num_edge_type)
        m = m.permute(0, 3, 1, 2) # m: (batchsize, edge_type, node, ch)
        hr = torch.matmul(adj, m)  # hr: (batchsize, edge_type, node, ch)
        hr = hr.sum(dim=1)   # hr: (batchsize, node, ch)
        if self.add_self:
            return hr+h_node  #
        else:
            return hr

In [10]:
class EnergyFunction(nn.Module):

  def __init__(self, in_channels, num_edge_type, latent_channels, layers):
    super(EnergyFunction, self).__init__()
    self.conv1 = GraphConv(in_channels, latent_channels, num_edge_type, add_self=True)
    self.conv_list = nn.ModuleList([GraphConv(latent_channels, latent_channels) for i in range(layers-1)])
    self.linear = nn.Linear(latent_channels, 1)
    self.layers = layers
    self.latent_channels = latent_channels

  def forward(self, adj, h):
    h = self.conv1(adj, h)
    h = F.relu(h)
    for i in range(self.layers-1):
      h = self.conv_list[i](adj, h)
      h = F.relu(h)
    h = torch.sum(h, dim=1)
    energy = self.linear(h)
    return energy


In [11]:
def normalize_adj(adj):
    degree = adj.sum(dim=(1,2))
    degree_inv = degree.pow(-1) 
    diags_inv = [torch.diag(degree_inv[i, :]) for i in range(adj.shape[0])]
    D_inv = torch.stack(diags_inv, dim=0)
    s = [1 for i in range(len(D_inv.shape))]
    s.append(adj.shape[1])
    D_inv = D_inv.unsqueeze(-1).repeat(s).permute(0,3,1,2)
    adj_pos = torch.matmul(D_inv, adj)
    return adj_pos
    
def requires_grad(parameters, flag):
    for p in parameters:
        p.requires_grad = flag

In [12]:
class GraphEBM():
  def __init__(self, num_nodes, num_atoms, num_edges, latent_size, layers, device):
      self.energy_function = EnergyFunction(num_atoms, num_edges, latent_size, layers).to(device)
      self.num_nodes = num_nodes
      self.num_atoms = num_atoms
      self.num_edges = num_edges
      self.device = device

  def train_rand(self, loader, lr, wd, max_epochs, c, ld_step, ld_noise_std, ld_step_size, clamp, alpha, save_dir):
     
     parameters = self.energy_function.parameters()
     optimizer = torch.optim.Adam(parameters, lr=lr, betas=(0.0, 0.999), weight_decay=wd)
     for epoch in range(max_epochs):
         l_loss = []
         l_loss_energy= []
         l_loss_reg = []
         for _, batch in enumerate(tqdm(loader)):
         
             pos_x = batch.x.to(self.device).to(dtype=torch.float32)
             pos_adj = batch.adj.to(self.device).to(dtype=torch.float32)

             # dequantization
             pos_x += c * torch.rand_like(pos_x, device=self.device)
             pos_adj += c * torch.rand_like(pos_adj, device=self.device)
             pos_adj = normalize_adj(pos_adj)

             neg_x = torch.rand_like(pos_x, device=self.device) * (1 + c)
             neg_adj = torch.rand_like(pos_adj, device=self.device)

             noise_x = torch.rand_like(pos_x, device=self.device)
             noise_adj = torch.rand_like(pos_adj, device=self.device)
             neg_x.requires_grad =True
             neg_adj.requires_grad= True
             for _ in range(ld_step):

                 noise_x.normal_(0, ld_noise_std)
                 noise_adj.normal_(0, ld_noise_std)
             
                 # calculate energy gradient wrt negative data
                 neg_energy = self.energy_function(neg_adj, neg_x)
                 neg_energy.sum().backward()

                 if (clamp):
                     neg_x.grad.clamp_(-0.01, 0.01)
                     neg_adj.grad.clamp_(-0.01, 0.01)
                 
                 # langevin steps
                 neg_x.data.add_(noise_x.data)
                 neg_adj.data.add_(noise_adj.data)
                 neg_x.data.add_(neg_x.grad.data, alpha=-ld_step_size)
                 neg_adj.data.add_(neg_adj.grad.data, alpha=-ld_step_size)
             
                 # manual zero step since optims are not used
                 neg_x.grad.detach_()
                 neg_x.grad.zero_()
                 neg_adj.grad.detach_()
                 neg_adj.grad.zero_()

                 # ensuring value ranges
                 neg_x.data.clamp_(0, 1 + c)
                 neg_adj.data.clamp_(0, 1)


             neg_adj.requires_grad = False
             neg_x.requires_grad = False
             requires_grad(parameters, True)
             neg_energy = self.energy_function(neg_adj, neg_x)
             pos_energy = self.energy_function(pos_adj, pos_x)
             loss_energy = (pos_energy - neg_energy).mean()
             loss_reg = (pos_energy ** 2 + neg_energy ** 2).mean()
             loss = loss_energy + alpha * loss_reg 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             l_loss_energy.append(loss_energy)
             l_loss_reg.append(loss_reg)
             l_loss.append(loss)
         print('Epoch: {:03d}, Loss: {:.6f}, Energy Loss: {:.6f}, Regularizer Loss: {:.6f}'.format(epoch+1, (sum(l_loss)/len(l_loss)).item(), (sum(l_loss_energy)/len(l_loss_energy)).item(), (sum(l_loss_reg)/len(l_loss_reg)).item()))
         if((epoch + 1 % 5) == 0):
             torch.save(self.energy_function, save_dir)
  def gen_rand(self, n_molecules, c, ld_step, ld_noise_std, ld_step_size, clamp):
         
         gen_x = torch.rand(n_molecules, self.num_nodes, self.num_atoms, device=self.device) * (1 + c)
         gen_adj = torch.rand(n_molecules, self.num_nodes, self.num_nodes, self.num_edges, device=self.device)

         noise_x = torch.rand_like(gen_x, device=self.device)
         noise_adj = torch.rand_like(gen_adj, device=self.device)
         gen_x.requires_grad = True
         gen_adj.requires_grad = True
         for _ in range(ld_step):

             noise_x.normal_(0, ld_noise_std)
             noise_adj.normal_(0, ld_noise_std)

             # calculate energy gradient wrt generated data
             gen_energy = self.energy_function(gen_adj, gen_x)
             gen_energy = gen_energy.sum()
             gen_energy.backward()

             if (clamp):
                 gen_x.grad.clamp_(-0.01, 0.01)
                 gen_adj.grad.clamp_(-0.01, 0.01)
                 
             # langevin steps
             gen_x.data.add_(noise_x.data)
             gen_adj.data.add_(noise_adj.data)
             gen_x.data.add_(gen_x.grad.data, alpha=-ld_step_size)
             gen_adj.data.add_(gen_adj.grad.data, alpha=-ld_step_size)
             
             # manual zero step since optims are not used
             gen_x.grad.detach_()
             gen_x.grad.zero_()
             gen_adj.grad.detach_()
             gen_adj.grad.zero_()

             # ensuring value ranges
             gen_x.data.clamp_(0, 1 + c)
             gen_adj.data.clamp_(0, 1)

         gen_adj += gen_adj.permute(0, 2, 1, 3)
         gen_adj  = gen_adj * 0.5
         return gen_adj        

In [13]:
model = GraphEBM( num_nodes=38, num_atoms = 29, num_edges = 4, latent_size = 64, layers = 2, device=torch.device('cuda'))

In [14]:
model.train_rand(loader, lr=1e-4, wd=0, max_epochs=5, c=0, ld_step=150, ld_noise_std=0.005, ld_step_size=30, clamp=True, alpha=1,save_dir='/kaggle/working/model.pth')

100%|██████████| 6876/6876 [59:53<00:00,  1.91it/s] 


Epoch: 001, Loss: 360551.843750, Energy Loss: 81.415825, Regularizer Loss: 360470.906250


100%|██████████| 6876/6876 [1:00:53<00:00,  1.88it/s]


Epoch: 002, Loss: 747.655762, Energy Loss: -0.591725, Regularizer Loss: 748.250366


100%|██████████| 6876/6876 [1:00:45<00:00,  1.89it/s]


Epoch: 003, Loss: 25.795780, Energy Loss: -1.015117, Regularizer Loss: 26.810881


100%|██████████| 6876/6876 [1:00:56<00:00,  1.88it/s]


Epoch: 004, Loss: 0.729441, Energy Loss: -0.998932, Regularizer Loss: 1.728377


100%|██████████| 6876/6876 [1:00:54<00:00,  1.88it/s]


Epoch: 005, Loss: 0.005470, Energy Loss: -0.995745, Regularizer Loss: 1.001217


In [16]:
torch.save(model.energy_function, '/kaggle/working/model.pth')