In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GIN, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.loader import DataLoader

from torch.utils.data import DataLoader as TorchDataLoader
import numpy as np
from pathlib import Path


from libraries.dataloader import GraphDataset

In [39]:
class GINEncoder(torch.nn.Module):
    """Graph-level encoder made of three stacked GINConv layers.

    Node features go through the three convolutions in sequence. The output
    of every convolution is pooled per graph with ``global_add_pool``, the
    three pooled vectors are concatenated, and a two-layer head maps the
    result to ``dim_encode`` values per graph.

    Args:
        dim_features: Width of the input node features.
        dim_h: Hidden width used inside every GIN layer.
        dim_encode: Width of the returned graph embedding.
    """

    def __init__(self, dim_features, dim_h, dim_encode):
        super().__init__()
        self.conv1 = GINConv(self._mlp(dim_features, dim_h))
        self.conv2 = GINConv(self._mlp(dim_h, dim_h))
        self.conv3 = GINConv(self._mlp(dim_h, dim_h))

        # Three pooled readouts of width dim_h are concatenated in forward().
        readout_width = dim_h * 3
        self.lin1 = nn.Linear(readout_width, readout_width)
        self.lin2 = nn.Linear(readout_width, dim_encode)

    @staticmethod
    def _mlp(dim_in, dim_out):
        """Return the Linear-BatchNorm-ReLU-Linear-ReLU network wrapped by a GINConv."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.BatchNorm1d(dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
            nn.ReLU(),
        )

    def forward(self, graph_batch):
        """Encode every graph in ``graph_batch`` into a ``dim_encode``-wide vector.

        Args:
            graph_batch: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            The output of the final linear layer, one row per graph.
        """
        node_feats = graph_batch.x
        edges = graph_batch.edge_index
        graph_ids = graph_batch.batch

        # Run the convolutions in order and keep one pooled readout per layer.
        readouts = []
        for conv in (self.conv1, self.conv2, self.conv3):
            node_feats = conv(node_feats, edges)
            readouts.append(global_add_pool(node_feats, graph_ids))

        # Projection head applied to the concatenated readouts.
        hidden = F.relu(self.lin1(torch.cat(readouts, dim=1)))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.lin2(hidden)

class GINDecoder(torch.nn.Module):
    """GIN-based decoder stack.

    Node features of width ``dim_encode`` go through three GINConv layers
    (``dim_encode -> dim_h -> dim_h -> dim_features``). The output of every
    layer is pooled per graph with ``global_add_pool``, the pooled vectors
    are concatenated, and a two-layer head returns ``dim_h * 3`` values per
    graph.

    NOTE(review): the layout mirrors GINEncoder; confirm that a pooled,
    graph-level output is what the decoder is meant to produce.

    Args:
        dim_features: Output width of the last GIN layer.
        dim_h: Hidden width used inside the GIN layers and the head.
        dim_encode: Width of the node features fed to the first GIN layer.
    """

    def __init__(self, dim_features, dim_h, dim_encode):
        # The original called super(GIN, self), which raises TypeError because
        # GINDecoder is not a subclass of GIN.
        super().__init__()
        self.conv1 = GINConv(nn.Sequential(nn.Linear(dim_encode, dim_h),
                                           nn.BatchNorm1d(dim_h), nn.ReLU(),
                                           nn.Linear(dim_h, dim_h), nn.ReLU()))
        self.conv2 = GINConv(nn.Sequential(nn.Linear(dim_h, dim_h),
                                           nn.BatchNorm1d(dim_h), nn.ReLU(),
                                           nn.Linear(dim_h, dim_h), nn.ReLU()))
        self.conv3 = GINConv(nn.Sequential(nn.Linear(dim_h, dim_h),
                                           nn.BatchNorm1d(dim_h), nn.ReLU(),
                                           nn.Linear(dim_h, dim_features), nn.ReLU()))
        # forward() concatenates readouts of width dim_h, dim_h and
        # dim_features, so lin1 must accept exactly that many inputs. The
        # previous in_features=dim_encode only matched when
        # dim_encode == 2 * dim_h + dim_features.
        self.lin1 = nn.Linear(2 * dim_h + dim_features, dim_h * 3)
        self.lin2 = nn.Linear(dim_h * 3, dim_h * 3)

    def forward(self, x):
        """Run the decoder on a graph batch.

        Args:
            x: Graph batch exposing ``x``, ``edge_index`` and ``batch``. The
                parameter keeps its original name for compatibility.

        Returns:
            The output of the final linear layer, ``dim_h * 3`` values per graph.
        """
        # Use the argument itself rather than a notebook-global ``graph_batch``.
        graph_batch = x
        node_feats = graph_batch.x
        edge_index = graph_batch.edge_index
        batch = graph_batch.batch

        # Node embeddings
        h1 = self.conv1(node_feats, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)
        # Graph-level readout
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)
        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)
        # Two-layer head
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)

        return h

class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    Args:
        input_size: Width of one flattened input sample.
        hidden_size: Width of the hidden layer in encoder and decoder.
        latent_size: Width of the latent code.
    """

    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()

        # Kept so forward() can flatten inputs to the configured width
        # instead of a hard-coded 784.
        self.input_size = input_size

        # Encoder
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)  # produces mu
        self.fc22 = nn.Linear(hidden_size, latent_size)  # produces logvar

        # Decoder
        self.fc3 = nn.Linear(latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """Return the ``(mu, logvar)`` pair computed from ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps`` drawn by ``randn_like``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map a latent code back to input space; a sigmoid keeps outputs in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        ``x`` is first flattened to rows of ``input_size`` values.

        Returns:
            A ``(reconstruction, mu, logvar)`` tuple.
        """
        mu, logvar = self.encode(x.view(-1, self.input_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Build the VAE from its three size hyperparameters.
# input_size has to equal the flattened width of the samples fed to the model.
input_size, hidden_size, latent_size = 784, 400, 20
vae = VAE(
    input_size=input_size,
    hidden_size=hidden_size,
    latent_size=latent_size,
)

# VAE loss: summed BCE reconstruction term plus KL-divergence regulariser.
def loss_function(recon_x, x, mu, logvar):
    """Return the summed reconstruction BCE plus the KL divergence term.

    Args:
        recon_x: Reconstruction with values in [0, 1], one row per sample.
        x: Target tensor; it is flattened to rows as wide as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        A scalar tensor ``BCE + KLD``.
    """
    # Flatten the target to the reconstruction's width rather than a fixed
    # 784, so the loss works for any input size.
    target = x.view(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # Kullback-Leibler Divergence loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD



In [40]:
# Mini-batch size passed to both the validation and the training DataLoader.
batch_size = 32

In [41]:
# One data root, so the embedding file and both dataset splits resolve
# against the same directory.
data_dir = Path.cwd() / 'data'

# NOTE(security): allow_pickle=True unpickles the file, which can execute
# arbitrary code. Only load a token_embedding_dict.npy from a trusted source.
# Indexing with [()] unwraps the stored object (presumably a dict saved via
# np.save -- confirm against the script that produced the file).
gt = np.load(data_dir / "token_embedding_dict.npy", allow_pickle=True)[()]

val_dataset = GraphDataset(root=data_dir, gt=gt, split='val')
train_dataset = GraphDataset(root=data_dir, gt=gt, split='train')

# Only the training split is reshuffled; a fixed validation order keeps
# evaluation runs repeatable.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [42]:
# Encoder for 300-wide node features, 300 hidden units, 50-wide graph embedding.
model = GINEncoder(dim_features=300, dim_h=300, dim_encode=50)

In [50]:
# Fetch one mini-batch and push it through the encoder as a smoke test.
# (Replaces a for-loop that broke after its first iteration, and drops the
# never-updated `maxi` counter that was only printed as 0.)
batch = next(iter(train_loader))
print(batch)

# Pull the text fields off the batch so that only the remaining graph
# attributes are handed to the encoder.
input_ids = batch.input_ids
batch.pop('input_ids')
attention_mask = batch.attention_mask
batch.pop('attention_mask')

graph_batch = batch
out = model(graph_batch)

DataBatch(x=[1171, 300], edge_index=[2, 2472], input_ids=[32, 256], attention_mask=[32, 256], batch=[1171], ptr=[33])
0


In [52]:
# Display the edge_index tensor of the batch fetched by the smoke-test cell.
batch.edge_index

tensor([[   0,    1,    1,  ..., 1146, 1166, 1157],
        [   1,    0,    2,  ..., 1168, 1157, 1166]])

In [44]:
# Shape of the encoder output computed for the sampled batch.
out.shape

torch.Size([32, 50])