In [1]:
import torch as th 
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from model import Generator, Discriminator, gradient_penalty

In [None]:
""" 
Import Data
"""


In [None]:
""" 
Hyperparameters
"""

# --- Adam optimizer settings ---
g_lr = 1e-3                 # learning rate handed to the generator's optimizer
d_lr = 1e-3                 # learning rate handed to the discriminator's optimizer
b1, b2 = 0.5, 0.999         # Adam beta coefficients, passed as betas=(b1, b2)

# --- WGAN-GP settings ---
N_critic = 5                # discriminator updates performed per generator update
lambda_gp = 10              # weight of the gradient-penalty term in the discriminator loss

# --- Training schedule ---
MAX_EPOCHS = 500            # number of passes over the data loader
BATCH_SIZE = 16             # presumably the DataLoader batch size -- not used in the cells shown

In [None]:
""" 
Model Definitions
"""
# Build the two adversarial networks with their default constructor arguments.
generator, discriminator = Generator(), Discriminator()

# Each network gets its own Adam optimizer: separate learning rates, shared betas.
optimizer_G = th.optim.Adam(
    generator.parameters(),
    lr=g_lr,
    betas=(b1, b2),
)
optimizer_D = th.optim.Adam(
    discriminator.parameters(),
    lr=d_lr,
    betas=(b1, b2),
)

In [None]:
def _generate_fake_batch(generator, graphs):
    """Run the generator on fresh Gaussian noise and package the result as a Batch.

    One noise graph is built per entry of ``graphs``. Each noise graph has the
    same node-feature shape and the same ``edge_index`` as its real counterpart,
    so the generated batch mirrors the topology of the real mini-batch.

    Args:
        generator: Generator instance, called with a ``Batch`` of noise graphs.
            Assumed (per the notebook's own comments) to return a single tensor
            of shape (total nodes in batch, output_features) -- confirm against
            model.py.
        graphs: List of ``Data`` objects making up the real mini-batch.

    Returns:
        A ``Batch`` whose ``x`` holds the generated node features, split per
        graph and paired with that graph's ``edge_index``.
    """
    # Noise must match each real graph node-for-node so the generator sees the
    # same connectivity as the data it is trying to imitate.
    noise_graphs = [
        Data(
            x=th.randn(graph.x.shape, device=graph.x.device),
            edge_index=graph.edge_index,
        )
        for graph in graphs
    ]
    generated = generator(Batch.from_data_list(noise_graphs))

    # The generator output is stacked over all nodes of the batch; cut it back
    # into one chunk per graph and wrap every chunk in a Data object, because
    # Batch.from_data_list expects Data objects rather than bare tensors.
    nodes_per_graph = [graph.x.shape[0] for graph in graphs]
    chunks = th.split(generated, nodes_per_graph, dim=0)
    fake_graphs = [
        Data(x=chunk, edge_index=graph.edge_index)
        for chunk, graph in zip(chunks, graphs)
    ]
    return Batch.from_data_list(fake_graphs)


# Main training loop
def train(generator, discriminator, optimizer_g, optimizer_d, data_loader):
    """Train the generator/discriminator pair with the WGAN gradient-penalty scheme.

    For every mini-batch the discriminator is updated ``N_critic`` times, then
    the generator is updated once. Reads the module-level settings
    ``MAX_EPOCHS``, ``N_critic`` and ``lambda_gp`` and the imported
    ``gradient_penalty`` helper.

    Args:
        generator: Generator instance (callable on a ``Batch`` of noise graphs).
        discriminator: Discriminator instance (callable on a ``Batch``); its
            output is flattened to one score per entry.
        optimizer_g: Optimizer stepping the generator's parameters.
        optimizer_d: Optimizer stepping the discriminator's parameters.
        data_loader: Iterable yielding ``Batch`` objects built from a list of
            graphs (as the torch_geometric ``DataLoader`` does), so that
            ``to_data_list()`` can recover the individual graphs.

    Returns:
        None. The models and optimizers are updated in place.
    """
    for epoch in range(MAX_EPOCHS):
        # `real` is one mini-batch of real graphs.
        for real in data_loader:
            # Iterating a Batch directly yields (attribute name, value) pairs,
            # so the individual graphs have to be recovered explicitly.
            graphs = real.to_data_list()

            for _ in range(N_critic):
                # The discriminator update must not backpropagate into the
                # generator, so the fakes are produced without a graph.
                # NOTE(review): assumes gradient_penalty does not need gradients
                # to flow back through `fake` into the generator -- confirm.
                with th.no_grad():
                    fake = _generate_fake_batch(generator, graphs)

                discriminator_fake = discriminator(fake).reshape(-1)        # discriminator scores for fakes
                discriminator_real = discriminator(real).reshape(-1)        # discriminator scores for reals
                gp = gradient_penalty(discriminator, real, fake)

                # Discriminator loss and train: maximise (real - fake) score gap,
                # regularised by the gradient penalty.
                loss_discriminator = (
                    -(th.mean(discriminator_real) - th.mean(discriminator_fake))
                    + lambda_gp * gp
                )
                discriminator.zero_grad()
                loss_discriminator.backward()
                optimizer_d.step()

            # Generator loss and train. A fresh fake batch is generated with the
            # autograd graph attached; reusing the batch from the discriminator
            # step would mean backpropagating through an already-freed graph.
            fake = _generate_fake_batch(generator, graphs)
            output = discriminator(fake).reshape(-1)        # discriminator scores for fake
            loss_generator = -th.mean(output)               # higher discriminator score = better generator
            generator.zero_grad()
            loss_generator.backward()
            optimizer_g.step()

        # TODO: Evaluation and logging code??

In [10]:
""" 
Code for testing dimensionality of batches
"""

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.nn import TAGConv

class SomeModel(torch.nn.Module):
    """Minimal one-layer graph model used to inspect output shapes on batched graphs."""

    def __init__(self, in_channels, out_channels):
        """Create the single ``TAGConv`` layer mapping ``in_channels`` to ``out_channels``."""
        super().__init__()
        self.layer = TAGConv(in_channels, out_channels)

    def forward(self, data):
        """Apply the layer to ``data.x`` using ``data.edge_index`` and return the result."""
        features, connectivity = data.x, data.edge_index
        return self.layer(features, connectivity)


# Create individual directed-ring graphs with different numbers of nodes
graph1 = Data(x=torch.randn(5, 16), edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]), y=torch.randn(5, 1))
graph2 = Data(x=torch.randn(7, 16), edge_index=torch.tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 0]]), y=torch.randn(7, 1))
graph3 = Data(x=torch.randn(8, 16), edge_index=torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 0]]), y=torch.randn(8, 1))
graphs = [graph1, graph2, graph3]

# Create a batch of graphs
batch_data = Batch.from_data_list(graphs)

# Initialize and forward pass through the model
model = SomeModel(in_channels=16, out_channels=1)
output = model(batch_data)

# The model returns one row per node, stacked over every graph in the batch
print(output.shape)  # Shape of the output tensor

# Split the stacked output back into one tensor per graph. The sizes are read
# from the graphs themselves instead of being repeated by hand, so the check
# keeps working when the graphs above are changed.
nodes_per_graph = [graph.num_nodes for graph in graphs]
assert output.shape[0] == sum(nodes_per_graph), "model output must have one row per node"
new_output = torch.split(output, nodes_per_graph, dim=0)

for tens in new_output:
    print(tens.shape)

torch.Size([20, 1])
torch.Size([5, 1])
torch.Size([7, 1])
torch.Size([8, 1])
