# Feature VAE
This notebook contains the code for generating the VAE responsible for compressing graph features (i.e., node and edge features) into a low dimension space to be sampled and to generate a new feature set to be passed to the aligner.

In [1]:
# 15 s
import torch
import torch.nn as nn
from torch_geometric.datasets import ZINC
from tqdm import tqdm

**Make the dataset**

In [None]:
dataset = ZINC(root = '', split='train') # valid, test

In [2]:
# Check which device we are currently using.
if torch.cuda.is_available():
  device = torch.cuda.current_device()
  print("Current device:", torch.cuda.get_device_name(device))
  
else:
  print("Using CPU")

Current device: NVIDIA GeForce GTX 1660 Ti with Max-Q Design


In [3]:
# Print out a sample of the dataset.
# dataset[0].to_namedtuple() # x, edge_index, edge_attr, y

In [4]:
"""
Converts a row of data from the pytorch geometric ZINC dataset into
a matrix representation by stacking edges and padding. 

Args:
    layer: a row of data from the pytorch geometric ZINC dataset. 

Returns:
    A graph of the zinc dataset represented as a matrix.
"""
def convert_row(row):
    x = row.x
    edge_index = row.edge_index
    edge_attr = row.edge_attr
    edge_reprs = [set(), set(), set()]

    # Unpack row and store result in edge_reprs.
    for edge_idx, (edge_i, edge_j) in enumerate(zip(edge_index[0], edge_index[1])):
        # Extract the attributes. 
        src_node_att = x[edge_i.item()].item()
        dst_node_att = x[edge_j.item()].item()
        edge_att = edge_attr[edge_idx].item()

        # Append to edge_reprs.
        edge_reprs[0].add(edge_att)
        edge_reprs[1].add(src_node_att)
        edge_reprs[2].add(dst_node_att)

    # Convert the edge_repr sets to lists. 
    edge_reprs = [list(edge_repr) for edge_repr in edge_reprs]

    # Add padding to make each list the same size.
    maxlen = max([len(s) for s in edge_reprs])
    for edge_repr in edge_reprs:
        while len(edge_repr) < maxlen:
            edge_repr.append(0)

    # Convert edge_reprs into a tensor.
    graph_repr = []
    for e, v1, v2 in zip(edge_reprs[0], edge_reprs[1], edge_reprs[2]):
        e_t = torch.tensor([e])
        v1_t = torch.tensor([v1])
        v2_t = torch.tensor([v2])
        graph_repr.append(torch.cat([torch.nn.functional.one_hot(e_t, num_classes=4).squeeze(),
                            torch.nn.functional.one_hot(v1_t, num_classes=28).squeeze(),
                            torch.nn.functional.one_hot(v2_t, num_classes=28).squeeze()]))

    return torch.stack(graph_repr)
    

In [None]:
# Print out what one converted row of data looks like.
# convert_row(dataset[0])

In [6]:
# 5:20 mins
# graph_lst = []
# for row in dataset:
#     graph_lst.append(convert_row(row))

In [7]:
# 10 s
# torch.save(graph_lst, "data/graph_lst.pth")

In [8]:
# 10 s
graph_lst = torch.load("data/graph_lst.pth")

In [9]:
max_len = max(graph.shape[0] for graph in graph_lst)
print(max_len)

10


In [10]:
# 22 s
# Pad all the graphs in the dataset to make them the same size.
padded_graph_lst = []
for graph in graph_lst:
    pad_len = max_len - graph.shape[0]
    # print(pad_len)
    pad_module = nn.ConstantPad2d((0, 0, 0, pad_len), value=0)
    padded_graph = pad_module(graph)
    padded_graph_lst.append(torch.flatten(padded_graph).float())
    # print(padded_graph.shape)

In [11]:
input_dim = padded_graph_lst[0].shape[0]
print(input_dim)

600


**Convert padded_graph_lst into a VAEDataset**

In [12]:
from typing import List, Optional, Sequence, Union
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


# Add your custom dataset class here
class MyDataset(Dataset):
    def __init__(self, padded_graph_lst):
        self.padded_graph_lst = padded_graph_lst
    
    def __len__(self):
        return len(self.padded_graph_lst)
    
    def __getitem__(self, idx):
        return self.padded_graph_lst[idx]

class VAEDataset(LightningDataModule):
    """
    PyTorch Lightning data module 

    Args:
        data_dir: root directory of your dataset.
        train_batch_size: the batch size to use during training.
        val_batch_size: the batch size to use during validation.
        patch_size: the size of the crop to take from the original images.
        num_workers: the number of parallel workers to create to load data
            items (see PyTorch's Dataloader documentation for more details).
        pin_memory: whether prepared items should be loaded into pinned memory
            or not. This can improve performance on GPUs.
    """

    def __init__(
        self,
        data_path: str,
        train_batch_size: int = 8,
        val_batch_size: int = 8,
        patch_size: Union[int, Sequence[int]] = (256, 256),
        num_workers: int = 0,
        pin_memory: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.data_dir = data_path
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.patch_size = patch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Optional[str] = None) -> None:
        
        self.train_dataset = MyDataset(padded_graph_lst)
        
        # Replace CelebA with your dataset
        self.val_dataset = MyDataset(padded_graph_lst)
#       ===============================================================
        
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=self.pin_memory,
        )
    
    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=144,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )
     

In [13]:
vae_data = VAEDataset("")
vae_data.setup()

**Make the VAE**

In [15]:
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np

In [16]:
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dims):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, latent_dims)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)

In [17]:
class Decoder(nn.Module):
    def __init__(self, input_dim, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, input_dim)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z

In [18]:
class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dims):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(input_dim, latent_dims)
        self.decoder = Decoder(input_dim, latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

In [34]:
def revert_onehot(tensor):
    splits = torch.chunk(tensor.squeeze(), max_len)
    row_lst = []
    for split in splits:
        subsplits = torch.split(split, [4, 28, 28])
        # print(subsplits)
        e = torch.argmax(subsplits[0])
        v1 = torch.argmax(subsplits[1])
        v2 = torch.argmax(subsplits[2])
        row_lst.append(torch.tensor((e, v1, v2)))
    return torch.stack(row_lst)

**Train the VAE**

In [22]:
def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in tqdm(range(epochs)):
        for x in data:
            x = x.to(device) # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum()
            loss.backward()
            opt.step()
        latent_dim = 2
        z = torch.randn(1, latent_dim).to(device)
        print(revert_onehot(autoencoder.decoder(z)))
    return autoencoder

In [35]:
latent_dims = 2
autoencoder = Autoencoder(input_dim, latent_dims).to(device) # GPU
z = torch.randn(1, latent_dims).to(device)
print(revert_onehot(autoencoder.decoder(z)))
autoencoder = train(autoencoder, vae_data.train_dataloader())

tensor([[ 0,  1, 18],
        [ 3,  8,  5],
        [ 0, 11, 11],
        [ 0,  2,  2],
        [ 2, 11, 20],
        [ 1, 15, 20],
        [ 1,  1, 20],
        [ 1, 12, 27],
        [ 0, 15,  3],
        [ 0, 16, 12]])


  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [01:14<23:27, 74.10s/it]

tensor([[ 1,  0,  0],
        [ 2,  8,  8],
        [ 0,  1,  1],
        [ 0,  1,  1],
        [ 2,  9,  9],
        [ 1,  5, 10],
        [ 1, 12, 12],
        [ 3,  8, 16],
        [ 0, 27,  8],
        [ 1, 12,  2]])


 10%|█         | 2/20 [02:28<22:11, 74.00s/it]

tensor([[ 1,  0,  0],
        [ 2,  1,  1],
        [ 0,  4,  4],
        [ 0,  9,  9],
        [ 3, 10, 10],
        [ 2, 10, 10],
        [ 0, 12, 12],
        [ 0,  1, 25],
        [ 2,  9, 19],
        [ 1,  0, 18]])


 15%|█▌        | 3/20 [03:40<20:42, 73.11s/it]

tensor([[ 1,  0,  0],
        [ 0,  1,  1],
        [ 0,  4,  4],
        [ 0,  9,  9],
        [ 3, 10, 10],
        [ 3, 11, 11],
        [ 0, 12, 12],
        [ 0, 11,  1],
        [ 0,  4, 15],
        [ 0,  1,  3]])


 20%|██        | 4/20 [04:55<19:42, 73.88s/it]

tensor([[ 1,  0,  0],
        [ 2, 11, 11],
        [ 0,  4,  4],
        [ 0,  8,  8],
        [ 3, 10, 10],
        [ 3, 11, 11],
        [ 2, 23, 17],
        [ 2, 19,  7],
        [ 1, 25, 16],
        [ 0, 15,  3]])


 25%|██▌       | 5/20 [06:08<18:25, 73.69s/it]

tensor([[ 1,  0,  0],
        [ 2, 11, 11],
        [ 0,  4,  4],
        [ 0,  6,  6],
        [ 3, 10, 10],
        [ 3, 18, 11],
        [ 2, 23, 25],
        [ 0, 19, 13],
        [ 1, 25, 16],
        [ 0, 15, 13]])


 30%|███       | 6/20 [07:24<17:22, 74.49s/it]

tensor([[ 1,  0,  0],
        [ 2,  2,  2],
        [ 0,  4,  4],
        [ 0,  6,  6],
        [ 0, 10, 10],
        [ 3,  7,  5],
        [ 2,  2, 23],
        [ 0, 19, 14],
        [ 2, 27, 16],
        [ 0, 20, 13]])


 35%|███▌      | 7/20 [08:45<16:36, 76.67s/it]

tensor([[ 1,  0,  0],
        [ 0,  8,  8],
        [ 0,  4,  4],
        [ 0,  1,  1],
        [ 2,  9, 14],
        [ 1,  5,  5],
        [ 1,  1,  7],
        [ 0,  3, 16],
        [ 0, 27,  0],
        [ 1,  1,  2]])


 40%|████      | 8/20 [10:07<15:39, 78.30s/it]

tensor([[ 1,  0,  0],
        [ 2,  2,  2],
        [ 3,  5,  0],
        [ 0, 15, 15],
        [ 0, 14, 14],
        [ 1,  5,  5],
        [ 2, 14,  7],
        [ 0, 19, 14],
        [ 2, 27,  6],
        [ 1, 20, 13]])


 45%|████▌     | 9/20 [11:28<14:30, 79.10s/it]

tensor([[ 1,  0,  0],
        [ 2, 11, 11],
        [ 0,  4,  4],
        [ 0,  6,  6],
        [ 1, 14, 14],
        [ 3,  5,  5],
        [ 2, 14, 25],
        [ 0, 19, 13],
        [ 0, 25, 16],
        [ 0, 15,  3]])


 50%|█████     | 10/20 [12:48<13:14, 79.49s/it]

tensor([[ 1,  0,  0],
        [ 0,  1,  1],
        [ 0,  4,  4],
        [ 0,  9,  9],
        [ 1, 10, 10],
        [ 3, 11,  5],
        [ 1, 14,  2],
        [ 2, 26,  2],
        [ 0, 25, 16],
        [ 0,  1,  3]])


 55%|█████▌    | 11/20 [14:09<11:59, 79.97s/it]

tensor([[ 1,  0,  0],
        [ 2,  1,  1],
        [ 0,  4,  4],
        [ 0,  9,  9],
        [ 1, 15, 15],
        [ 3, 12, 12],
        [ 0, 12, 12],
        [ 0, 17,  3],
        [ 0,  7, 26],
        [ 1,  1,  3]])


 60%|██████    | 12/20 [15:30<10:41, 80.16s/it]

tensor([[ 1,  0,  0],
        [ 0,  1,  1],
        [ 0,  4,  4],
        [ 0, 10, 10],
        [ 1, 14, 14],
        [ 3,  5,  5],
        [ 1, 14,  2],
        [ 2, 26,  2],
        [ 0, 25, 16],
        [ 0, 14,  3]])


 65%|██████▌   | 13/20 [16:50<09:22, 80.29s/it]

tensor([[ 1,  0,  0],
        [ 2, 11, 11],
        [ 0,  4,  4],
        [ 0, 15, 15],
        [ 1, 14, 14],
        [ 3,  5,  5],
        [ 2, 14, 23],
        [ 0, 19, 13],
        [ 2, 27, 16],
        [ 1, 20, 13]])


 70%|███████   | 14/20 [18:09<07:58, 79.67s/it]

tensor([[ 1,  0,  0],
        [ 2,  2,  2],
        [ 0,  4,  6],
        [ 0, 15, 15],
        [ 1, 14, 14],
        [ 3,  5,  5],
        [ 2, 14, 23],
        [ 0, 19, 13],
        [ 2, 27, 16],
        [ 1, 20, 13]])


 75%|███████▌  | 15/20 [19:27<06:36, 79.28s/it]

tensor([[ 1,  0,  0],
        [ 0, 11, 11],
        [ 0,  6,  6],
        [ 0, 15,  1],
        [ 2, 14, 14],
        [ 1,  5,  5],
        [ 1,  5,  7],
        [ 0,  3, 16],
        [ 0, 27,  0],
        [ 1, 25, 20]])


 80%|████████  | 16/20 [20:45<05:15, 78.77s/it]

tensor([[ 1,  0,  0],
        [ 0, 11, 11],
        [ 0,  4,  4],
        [ 1,  1,  1],
        [ 1, 14, 14],
        [ 1,  5,  5],
        [ 1,  5,  7],
        [ 0,  3, 16],
        [ 0, 27,  0],
        [ 1, 25, 20]])


 85%|████████▌ | 17/20 [22:07<03:59, 79.85s/it]

tensor([[ 1,  0,  0],
        [ 2, 11, 11],
        [ 0,  4,  4],
        [ 0, 15, 15],
        [ 1, 14, 14],
        [ 3,  5,  5],
        [ 2, 14, 25],
        [ 0, 19, 13],
        [ 0, 25, 16],
        [ 0, 20, 13]])


 90%|█████████ | 18/20 [23:34<02:43, 81.93s/it]

tensor([[ 1,  0,  0],
        [ 2,  1,  1],
        [ 0,  4,  4],
        [ 0,  6,  6],
        [ 1, 15, 15],
        [ 3, 11,  9],
        [ 1,  8, 13],
        [ 0,  8,  3],
        [ 0, 25,  0],
        [ 1,  1, 20]])


 95%|█████████▌| 19/20 [24:53<01:20, 80.98s/it]

tensor([[ 1,  0,  0],
        [ 2,  2,  2],
        [ 3,  6,  6],
        [ 0, 15, 15],
        [ 3, 14, 14],
        [ 1,  5,  5],
        [ 2, 14,  7],
        [ 0, 19, 14],
        [ 2, 27, 12],
        [ 1, 20, 13]])


100%|██████████| 20/20 [26:11<00:00, 78.59s/it]

tensor([[ 1,  0,  0],
        [ 2, 10, 10],
        [ 0, 15, 15],
        [ 0, 16, 16],
        [ 1, 16, 16],
        [ 3, 11,  5],
        [ 1, 20,  7],
        [ 0,  8, 16],
        [ 2, 21,  0],
        [ 1, 20, 20]])





In [36]:
torch.save(autoencoder, "autoencoder.pth")

**Observe the VAE Performance 📈**

In [37]:
# After training the VAE

# Sample a random point in the latent space
z = torch.randn(1, latent_dims).to(device)

# Decode the sampled latent vector to get a new vector
print(revert_onehot(autoencoder.decoder(z)))


tensor([[ 1,  0,  0],
        [ 2,  1,  1],
        [ 0,  4,  4],
        [ 0,  9,  9],
        [ 1, 15, 15],
        [ 3, 11,  5],
        [ 3, 15,  2],
        [ 0,  8,  3],
        [ 0, 25, 22],
        [ 1, 20, 20]])


In [None]:
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

def plot_latent(autoencoder, data, num_batches=100):
    for i, (x, y) in enumerate(data):
        z = autoencoder.encoder(x.to(device))
        z = z.to('cpu').detach().numpy()
        plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
        if i > num_batches:
            plt.colorbar()
            break