In [1]:
# Standard library imports
from pathlib import Path
import os
import sys
import math
import itertools
from typing import List, Union, Sequence


# Third-party library imports
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from numpy import ndarray
from tqdm import tqdm
from scipy.spatial import distance
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import ChebConv
from torch_geometric.nn.conv import GeneralConv

# Project-specific import
from Utils_GAN.classes import DynamicGraphTemporalSignal, GConvGRU, BoidDatasetLoader,RecurrentGCN
from Utils_GAN.classes import Encoder,Decoder,GraphSeqGenerator
from Utils_GAN.classes import GraphSeqDiscriminator,GraphSeqGenerator


from Utils_GAN.subfunctions import temporal_signal_split, train,plot_predictions,test,create_video_from_images
from Utils_GAN.subfunctions import train_generator,test_generator,test_generator_plot

### For now: copying parts of PyTorch Geometric Temporal into this project, because setting up its environment has been difficult

## Set Up BoidDatasetLoader To Load Entire Dataset
This follows the example of `EnglandCovidDatasetLoader` from PyTorch Geometric Temporal.

In [2]:
# Build the dataset through BoidDatasetLoader.
# If this cell misbehaves, restart the kernel and run the notebook from the top.
loader = BoidDatasetLoader()
dataset = loader.get_dataset()
# Last expression of the cell: displays how many snapshots the dataset holds.
dataset.snapshot_count

9999

In [None]:
# NOTE(review): this cell repeats the dataset-loading cell directly above it;
# running it rebuilds `loader` and `dataset` a second time. Consider deleting one copy.
# If this cell misbehaves, restart the kernel and run the notebook from the top.
loader = BoidDatasetLoader()
dataset = loader.get_dataset()
# Last expression of the cell: displays how many snapshots the dataset holds.
dataset.snapshot_count

In [None]:
# Show the loader's recorded minimum / maximum for the first two features.
# Entries 0 and 1 are the ones later handed to the generator as the x and y bounds.
loader.min_features[0:2], loader.max_features[0:2]

# Split Dataset Into Training and Testing

In [4]:
# Split the temporal signal into a training part and a testing part using the
# helper's default settings (the recorded output below shows 7999 / 2000 snapshots).
train_dataset, test_dataset = temporal_signal_split(dataset)

# Last expression of the cell: displays the size of each split.
train_dataset.snapshot_count, test_dataset.snapshot_count

(7999, 2000)

In [None]:
# NOTE(review): this cell repeats the split cell directly above it; consider deleting one copy.
# Split the temporal signal into a training part and a testing part using the
# helper's default settings.
train_dataset, test_dataset = temporal_signal_split(dataset)

# Last expression of the cell: displays the size of each split.
train_dataset.snapshot_count, test_dataset.snapshot_count

## Create GConvGRU For Recurrent Layer In Our GNN

## Basic Graph Recurrent Neural Network

In [None]:
# Train a basic model for testing purposes.
model = RecurrentGCN(node_features=4, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Number of consecutive snapshots fed through the model before each weight update.
seq_len = 5

model.train()
for epoch in range(3):
    # "+ 1" keeps the last full window: a sequence starting at
    # snapshot_count - seq_len still has seq_len snapshots available.
    for seq_start in tqdm(range(0, train_dataset.snapshot_count - seq_len + 1, seq_len)):
        # Clear gradients before the forward pass so nothing left over from an
        # interrupted run is mixed into this update.
        optimizer.zero_grad()

        # Each sequence starts from an empty hidden state.
        h_t_prev = None
        for i in range(seq_len):
            snapshot = train_dataset[seq_start + i]
            y_hat, h_t = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h_t_prev)
            h_t_prev = h_t

        # The loss is taken on the last snapshot of the sequence only.
        cost = F.mse_loss(y_hat, snapshot.y)
        cost.backward()
        optimizer.step()


## Training

In [None]:
# Start from a freshly initialised model and optimizer so this run does not
# continue from the weights of the quick experiment above.
model = RecurrentGCN(node_features=4, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# `train` is imported from Utils_GAN.subfunctions; the second argument is
# presumably the number of epochs — confirm against the helper.
train(train_dataset, 10, model, optimizer)

In [None]:
# Evaluate the trained model on the held-out split. The loader is passed along,
# presumably so the helper can undo the feature normalisation — confirm in Utils_GAN.subfunctions.
test(test_dataset, model,loader=loader)

In [38]:
# NOTE(review): cv2, os and re are not referenced in this cell. `os` is already
# imported in the first cell; if cv2 / re are still needed, move them there too.
import cv2
import os
import re

# Example usage: build an mp4 at 30 fps from the images saved in the given
# directory (behaviour of the helper is defined in Utils_GAN.subfunctions).
create_video_from_images('../generator_test_plots_at_epoch_10', 'generator_test_plots_at_epoch_10.2.mp4', fps=30)

In [6]:
#create_video_from_images('../generator2_test_plots_at_epoch_30', 'generator2_test_plots_at_epoch_30.mp4', fps=30)

## Working on GAN [Generator: (Encoder, Decoder)] [Discriminator: (Encoder)]

In [9]:
class GraphSeqGenerator(torch.nn.Module):
    """Encoder/decoder generator that rolls a graph sequence forward in time.

    The input sequence is first consumed snapshot by snapshot ("warm-up") to build
    up the recurrent hidden states. The last warm-up output is the first
    prediction; every further prediction is produced by feeding the previous
    prediction back in, with its edges rebuilt from the predicted positions.

    Args:
        node_feat_dim: Number of features per node (input and output width).
        enc_hidden_dim: Hidden width handed to the encoder.
        enc_latent_dim: Width of the latent code passed from encoder to decoder.
        dec_hidden_dim: Hidden width handed to the decoder.
        pred_horizon: Number of snapshots returned by ``forward``.
        min_max_x: ``(min, max)`` used to undo the normalisation of feature 0.
        min_max_y: ``(min, max)`` used to undo the normalisation of feature 1.
        min_max_edge_weight: ``(min, max)`` used to normalise rebuilt edge weights.
        visualRange: Two nodes are connected when their de-normalised distance
            is strictly below this value.
        device: Optional device. When given, snapshot tensors and rebuilt edge
            tensors are moved to it; when ``None`` tensors stay where they are.
    """

    def __init__(self, node_feat_dim, enc_hidden_dim, enc_latent_dim, dec_hidden_dim, pred_horizon, min_max_x, min_max_y, min_max_edge_weight, visualRange, device=None):
        super().__init__()
        self.encoder = Encoder(node_feat_dim, enc_hidden_dim, enc_latent_dim)
        self.decoder = Decoder(enc_latent_dim, dec_hidden_dim, node_feat_dim)
        self.out_steps = pred_horizon
        self.min_x, self.max_x = min_max_x
        self.min_y, self.max_y = min_max_y
        self.min_edge_weight, self.max_edge_weight = min_max_edge_weight
        self.visualRange = visualRange
        self.device = device

    def _to_device(self, tensor):
        """Move ``tensor`` to ``self.device`` when a device was configured."""
        if self.device is None or tensor is None:
            return tensor
        return tensor.to(self.device)

    def _compute_edge_index_and_weight(self, y_hat):
        """Rebuild the graph connectivity from predicted node features.

        Columns 0 and 1 of ``y_hat`` are treated as normalised x / y positions.
        Not designed for batches: ``y_hat`` is a single ``(nodes, features)`` frame.

        Returns:
            ``(edge_index, edge_weight)``: a ``(2, E)`` long tensor of node pairs
            closer than ``visualRange`` (self-pairs excluded) and the matching
            min-max normalised distances as a float tensor.
        """
        target_device = self.device if self.device is not None else y_hat.device

        # The edges are built outside autograd (detach) and in NumPy, so the
        # values have to be brought to host memory first.
        y_hat_x = y_hat[:, 0].detach().cpu().numpy()
        y_hat_y = y_hat[:, 1].detach().cpu().numpy()

        # Undo normalization
        y_hat_x = y_hat_x * (self.max_x - self.min_x) + self.min_x
        y_hat_y = y_hat_y * (self.max_y - self.min_y) + self.min_y

        # Pairwise Euclidean distance between all nodes.
        coords = np.stack((y_hat_x, y_hat_y), axis=1)
        dist_matrix = np.linalg.norm(coords[:, np.newaxis, :] - coords[np.newaxis, :, :], axis=2)

        # Keep pairs inside the visual range; "> 0" drops the diagonal (and
        # any pair of nodes at exactly the same position).
        edge_indices = np.where((dist_matrix < self.visualRange) & (dist_matrix > 0))

        edge_index = np.vstack((edge_indices[0], edge_indices[1]))
        edge_weight = dist_matrix[edge_indices]

        # Normalize edge_weight
        edge_weight = (edge_weight - self.min_edge_weight) / (self.max_edge_weight - self.min_edge_weight)

        edge_index = torch.tensor(edge_index, dtype=torch.long, device=target_device)
        edge_weight = torch.tensor(edge_weight, dtype=torch.float, device=target_device)
        return edge_index, edge_weight

    def forward(self, sequence, h_enc, h_dec):
        """Warm up on ``sequence`` and return ``out_steps`` predicted frames.

        Args:
            sequence: Object exposing ``snapshot_count`` and integer indexing;
                each snapshot provides ``x``, ``edge_index`` and ``edge_attr``.
            h_enc: Initial encoder hidden state (``None`` to start empty).
            h_dec: Initial decoder hidden state (``None`` to start empty).

        Returns:
            A list of ``out_steps`` tensors, one predicted node-feature frame each.

        Raises:
            ValueError: If ``sequence`` contains no snapshots.
        """
        if sequence.snapshot_count == 0:
            raise ValueError("GraphSeqGenerator.forward needs at least one warm-up snapshot")

        # Warmup Section
        for i in range(sequence.snapshot_count):
            snapshot = sequence[i]
            x = self._to_device(snapshot.x)
            edge_index = self._to_device(snapshot.edge_index)
            # Encoder and decoder must see the same edge features.
            edge_attr = self._to_device(snapshot.edge_attr)

            z, h_enc = self.encoder(x, edge_index, edge_attr, h_enc)
            y_hat, h_dec = self.decoder(z, edge_index, edge_attr, h_dec)

        predictions = [y_hat]

        # Prediction Section
        for _ in range(self.out_steps - 1):
            y_hat_edge_index, y_hat_edge_attr = self._compute_edge_index_and_weight(y_hat)

            # Carry the hidden states forward; otherwise every predicted step
            # would restart from the state at the end of the warm-up.
            z, h_enc = self.encoder(y_hat, y_hat_edge_index, y_hat_edge_attr, h_enc)
            y_hat, h_dec = self.decoder(z, y_hat_edge_index, y_hat_edge_attr, h_dec)

            predictions.append(y_hat)
        return predictions

In [42]:
class GraphSeqDiscriminator(torch.nn.Module):
    """Scores a graph sequence with a single value in (0, 1).

    A recurrent encoder consumes the sequence snapshot by snapshot; the encoding
    of the final snapshot is pooled over nodes and mapped to one sigmoid output.
    """

    def __init__(self, node_feat_dim, enc_hidden_dim, enc_latent_dim, pred_horizon, device):
        super().__init__()

        self.encoder = Encoder(node_feat_dim, enc_hidden_dim, enc_latent_dim)
        self.linear = torch.nn.Linear(enc_latent_dim, 1)
        self.out_steps = pred_horizon
        self.device = device

    def forward(self, sequence, h_enc, shouldDetach=False):
        """Run the sequence through the encoder and classify it.

        Args:
            sequence: Object exposing ``snapshot_count`` and integer indexing;
                each snapshot provides ``x``, ``edge_index`` and ``edge_attr``.
            h_enc: Initial encoder hidden state (``None`` to start empty).
            shouldDetach: When true, the snapshot tensors are detached before
                they reach the encoder, so no gradient flows back into whatever
                produced them.

        Returns:
            ``(out, hidden)``: the sigmoid score and the encoder hidden state
            after the last snapshot.
        """
        hidden = h_enc
        for step in range(sequence.snapshot_count):
            graph = sequence[step]
            inputs = [
                tensor.to(self.device)
                for tensor in (graph.x, graph.edge_index, graph.edge_attr)
            ]
            if shouldDetach:
                inputs = [tensor.detach() for tensor in inputs]

            z, hidden = self.encoder(*inputs, hidden)

        # Only the encoding of the final snapshot is classified: ReLU, mean over
        # the node dimension (dim=0), linear layer, sigmoid.
        pooled = F.relu(z).mean(dim=0)
        score = torch.sigmoid(self.linear(pooled))

        return score, hidden

GraphSeqGenerator(
  (encoder): Encoder(
    (recurrent): GConvGRU(
      (conv_x_z): ChebConv(4, 32, K=2, normalization=sym)
      (conv_h_z): ChebConv(32, 32, K=2, normalization=sym)
      (conv_x_r): ChebConv(4, 32, K=2, normalization=sym)
      (conv_h_r): ChebConv(32, 32, K=2, normalization=sym)
      (conv_x_h): ChebConv(4, 32, K=2, normalization=sym)
      (conv_h_h): ChebConv(32, 32, K=2, normalization=sym)
    )
    (linear): Linear(in_features=32, out_features=16, bias=True)
  )
  (decoder): Decoder(
    (recurrent): GConvGRU(
      (conv_x_z): ChebConv(16, 32, K=2, normalization=sym)
      (conv_h_z): ChebConv(32, 32, K=2, normalization=sym)
      (conv_x_r): ChebConv(16, 32, K=2, normalization=sym)
      (conv_h_r): ChebConv(32, 32, K=2, normalization=sym)
      (conv_x_h): ChebConv(16, 32, K=2, normalization=sym)
      (conv_h_h): ChebConv(32, 32, K=2, normalization=sym)
    )
    (linear): Linear(in_features=32, out_features=4, bias=True)
  )
)

In [11]:
def train_generator(train_data, num_epochs, generator, optimizer, window=8, delay=0, horizon=1, stride=1):
    """
    Trains the generator with a mean-squared-error loss on sliding windows.

    For every window start, ``window`` snapshots are fed to the generator and its
    predictions are compared with the node features (``.x``) of the ``horizon``
    snapshots that begin ``delay`` steps after the window ends.

    Note: defining this function in the notebook replaces the ``train_generator``
    imported from ``Utils_GAN.subfunctions``.

    Args:
        train_data (Dataset): The dataset containing the training data. Must expose
            ``snapshot_count`` and support slicing.
        num_epochs (int): The number of epochs to train the model.
        generator (nn.Module): The generator to be trained. It is called as
            ``generator(input_seq, None, None)`` and must return a list of tensors.
        optimizer (Optimizer): The optimizer used for training the generator.
        window (int, optional): The size of the input sequence window. Defaults to 8.
        delay (int, optional): The delay between the input sequence and the target sequence. Defaults to 0.
        horizon (int, optional): The prediction horizon. Should equal the number of
            frames the generator returns; otherwise the loss either broadcasts
            silently or fails on a shape mismatch. Defaults to 1.
        stride (int, optional): The stride for iterating over the training data. Defaults to 1.

    Returns:
        None
    """
    total_timesteps = train_data.snapshot_count
    sample_span = window + delay + horizon

    generator.train()
    for epoch in range(num_epochs):
        print(f'Epoch: {epoch+1}/{num_epochs}')
        # Sum (not mean) of the per-window losses over the epoch.
        epoch_cost = 0
        for start in tqdm(range(0, total_timesteps - sample_span + 1, stride), desc='Training'):
            target_start = start + window + delay
            input_seq = train_data[start:start + window]
            target_seq = train_data[target_start:target_start + horizon]

            # Clear gradients BEFORE the forward/backward pass. Clearing only
            # after step() lets gradients left by an interrupted run (or by
            # other code using the same parameters) leak into the first update.
            optimizer.zero_grad()

            predictions = torch.stack(generator(input_seq, None, None), dim=0)
            targets = torch.stack([target_seq[i].x for i in range(target_seq.snapshot_count)], dim=0)

            cost = torch.mean((predictions - targets) ** 2)
            cost.backward()
            optimizer.step()

            epoch_cost += cost.item()
        print(f'Cost after epoch {epoch+1}: {epoch_cost}')

Epoch: 1/10


Training:   0%|          | 7/7984 [00:03<1:09:32,  1.91it/s]


KeyboardInterrupt: 

In [16]:
# NOTE(review): `generator` must already exist in the kernel when this cell runs;
# in this notebook it is only constructed in a cell further down, so a fresh
# Restart & Run All will fail here. Move the construction above this cell.
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)

# The target horizon is tied to the number of frames the generator predicts so
# that predictions and targets have the same number of steps.
train_generator(train_dataset, 10, generator, optimizer, horizon=generator.out_steps)

Epoch: 1/10


Training:  53%|█████▎    | 4268/7984 [06:40<05:48, 10.66it/s]


KeyboardInterrupt: 

In [11]:
def test_generator(test_data, generator, window=8, delay=0, horizon=1, stride=1):
    """
    Tests the given generator model using the provided test data.

            y_hat_seq_feats, y_hat_seq = generator(input_seq, None, None)


            # Discriminator Step
            # Drop gradients accumulated on the discriminator by earlier backward passes.
            discriminator.zero_grad()

            # The real target sequence should be scored as 1 ("real").
            output_of_real, _ = discriminator(y_seq, h_enc=None)

            errD_real = criterion(output_of_real, torch.ones(1, device=device))

            errD_real.backward()

            # The generated sequence must be labelled 0 ("fake") here; labelling it 1
            # would train the discriminator to accept generated data as real.
            # shouldDetach=True keeps this loss from back-propagating into the generator.
            output_of_fake, _ = discriminator(y_hat_seq, h_enc=None, shouldDetach=True)
            errD_fake = criterion(output_of_fake, torch.zeros(1, device=device))

            errD_fake.backward()

            # Total discriminator loss for this sample.
            errD = errD_real + errD_fake

            optimizerD.step()

            # Generator Step
            generator.zero_grad()
            min_mse_loss = float('inf')
            best_y_hat_seq = None

            for _ in range(k):
                y_hat_seq_feats, y_hat_seq = generator(input_seq, None, None)
                output_of_fake2, _ = discriminator(y_hat_seq, h_enc=None)

                y_hat = torch.stack([y_hat_seq[i].x for i in range(y_hat_seq.snapshot_count)], dim=0).to(device=device)
                y_actual = torch.stack([y_seq[i].x for i in range(y_seq.snapshot_count)], dim=0).to(device=device)

                mse_loss = F.mse_loss(y_hat, y_actual)
                if mse_loss < min_mse_loss:
                    min_mse_loss = mse_loss
                    best_y_hat_seq = y_hat_seq

            output_of_fake2, _ = discriminator(best_y_hat_seq, h_enc=None)

            errG = criterion(output_of_fake2, torch.ones(1, device=device))

            errG += min_mse_loss

            errG.backward()

            optimizerG.step()         

In [None]:
# Everything in this cell runs on the CPU; `device` is forwarded to both networks
# and to the training routine.
device = torch.device("cpu")
# Generator: 4 node features, 16-dimensional latent code, 8 predicted steps.
# The min/max pairs come from the dataset loader (feature 0 -> x, feature 1 -> y,
# plus the edge-weight range); visualRange is the distance threshold for edges.
generator = GraphSeqGenerator(node_feat_dim=4,
                              enc_hidden_dim=32,
                              enc_latent_dim=16,
                              dec_hidden_dim=32,
                              pred_horizon=8,
                              min_max_x=(loader.min_features[0], loader.max_features[0]),
                              min_max_y=(loader.min_features[1], loader.max_features[1]),
                              min_max_edge_weight=(loader.min_edge_weight, loader.max_edge_weight),
                              visualRange=75,
                              device=device,
                            )
generator.to(device)

# Discriminator with the same feature, hidden and latent widths as the generator's encoder.
discriminator = GraphSeqDiscriminator(node_feat_dim=4,
                                      enc_hidden_dim=32,
                                      enc_latent_dim=16,
                                      pred_horizon=8,
                                      device=device)
discriminator.to(device)

# Binary cross-entropy for the real/fake decision; one Adam optimizer per network.
criterion = torch.nn.BCELoss()
optimizerG = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=0.001)
# NOTE(review): `train_gan` is not among the names imported in the first cell, so it
# has to be defined by an earlier cell before this one runs. The positional `1` is
# presumably the number of epochs — confirm against its signature.
# window / horizon are both 8, matching pred_horizon above.
train_gan(train_dataset, 1, generator, discriminator, criterion, optimizerG, optimizerD, device, k=6, window=8, horizon=8 )