In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import itertools
import torch
from typing import List
from typing import Union
from numpy import ndarray
from torch_geometric.data import Data
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import ChebConv
from scipy.spatial import distance
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from PIL import Image
from torch_geometric.nn.conv import GeneralConv
import math


## Grabbing Simulation Data From CSV

In [None]:
path_to_sim = '../data/simulation.csv'
sim_df = pd.read_csv(path_to_sim)
sim_df.drop(columns='Simulation', inplace=True)
sim_df.head(5)

In [2]:
path_to_sim_edges = '../data/simulation_edges.csv'
sim_edges_df = pd.read_csv(path_to_sim_edges)
sim_edges_df.drop(columns='Simulation', inplace=True)
sim_edges_df.head(5)

Unnamed: 0,Boid_i,Boid_j,Timestep
0,0,39,0
1,0,57,0
2,1,32,0
3,1,34,0
4,1,83,0


## Process Dataframes for all edge indices and all node features per timetstep

In [3]:

path_to_sim_edges = '../data/simulation_edges.csv'
sim_edges_df = pd.read_csv(path_to_sim_edges)

path_to_sim = '../data/simulation.csv'
sim_df = pd.read_csv(path_to_sim)


sim_edges_df.head(5)

frame_rate = 1

sim_edges_df= sim_edges_df[sim_edges_df['Timestep']% frame_rate== 0]
sim_df= sim_df[sim_df['Timestep']% frame_rate== 0]

sim_edges_df_final = sim_edges_df[sim_edges_df['Simulation'] == 0]
sim_df_final = sim_df[sim_df['Simulation'] == 0]

testing_edges = sim_edges_df[sim_edges_df['Simulation'] == 1]
testing = sim_df[sim_df['Simulation'] == 1]

total_length = len(testing)
total_length_edges = len(testing_edges)

subset_length = math.ceil(total_length * 0.2)  # Get 20% of the total data
subset_length_edges = math.ceil(total_length_edges * 0.2)  # Get 20% of the total data

testing_edges_20 = testing_edges[:subset_length_edges]
testing_20 = testing[:subset_length]

In [4]:
def process_dfs(sim_df, sim_edges_df):
    # Group the dataframes by 'Timestep'
    sim_grouped = sim_df.groupby('Timestep')
    edges_grouped = sim_edges_df.groupby('Timestep')
    
    # Initialize lists to store edge indices and node features
    edge_indices = []
    node_features = []
    
    # Iterate over each group
    for timestep, _ in sim_grouped:
        # Extract relevant columns for the current timestep
        timestep_df = sim_grouped.get_group(timestep)[['x', 'y', 'dx', 'dy']]
        timestep_edges_df = edges_grouped.get_group(timestep)[['Boid_i', 'Boid_j']]
        
        # Convert dataframes to numpy arrays
        node_array = timestep_df.to_numpy()
        edge_array = timestep_edges_df.to_numpy().T

        # Append the numpy arrays to the respective lists
        edge_indices.append(edge_array)
        node_features.append(node_array)

    # Return the lists of edge indices and node features
    return edge_indices, node_features
all_edge_indices_test, all_node_features_test = process_dfs(testing_20, testing_edges_20)
all_edge_indices, all_node_features = process_dfs(sim_df, sim_edges_df)
#all_node_features = [minmax_scale(all_node_features[i], axis=1) for i in range(len(all_node_features))]

## Get all possible combinations of edges given there are 100 nodes (boids)

In [5]:
combinations = list(itertools.combinations(range(100), 2))
all_possible_edge_indices = np.array(combinations).T
all_possible_edge_indices.shape

(2, 4950)

## Compute edge attributes per timestep (whether boids are close enough and distance)

In [6]:
def compute_edge_attributes(sim_df, distance_threshold=75):
    """
    Compute edge attributes for each timestep in the simulation DataFrame.

    Parameters:
    sim_df (pd.DataFrame): DataFrame containing simulation data with columns 'Timestep', 'x', 'y', and 'Boids'.
    distance_threshold (float): Distance threshold to determine if boids are close enough.

    Returns:
    list: A list of numpy arrays containing edge attributes for each timestep.
    """
    bruh = sim_df.groupby('Timestep')

    all_edge_attr = []
    for key, sim_df_t in tqdm(bruh, total=len(bruh)):
        # Extract coordinates
        coordinates = sim_df_t[['x', 'y']].values
        
        # Compute pairwise distances using scipy
        dist_matrix = distance.cdist(coordinates, coordinates, 'euclidean')
        
        # Get the indices of the combinations
        num_boids = len(coordinates)
        combinations_array = np.array(np.triu_indices(num_boids, k=1)).T

        # Filter distances and create edge attributes
        edge_attr_t = []
        for edge in combinations_array:
            dist = dist_matrix[edge[0], edge[1]]
            closeEnough = 1 if dist < distance_threshold else 0
            edge_attr_t.append([closeEnough, dist])
        
        edge_attr_t = np.array(edge_attr_t)
        all_edge_attr.append(edge_attr_t)
    
    return all_edge_attr

# Example usage:
all_edge_attr_test = compute_edge_attributes(testing_20)
all_edge_attr = compute_edge_attributes(sim_df)
#all_edge_attr = [minmax_scale(all_edge_attr[i], axis=1) for i in range(len(all_edge_attr))]

0it [00:00, ?it/s]
100%|██████████| 10000/10000 [00:44<00:00, 224.41it/s]


## Normalizing our data using min max scale

In [8]:
def compute_feature_min_max(all_features_list):
    """
    Compute the minimum and maximum of features across all node features.

    Parameters:
    all_node_features (list of np.ndarray): List of numpy arrays containing node features.

    Returns:
    tuple: A tuple containing two numpy arrays: (final_min, final_max).
    """
    if not all_features_list:
        raise ValueError("The input list 'all_node_features' is empty.")
    
    # Initialize final_min and final_max with appropriate dimensions
    feature_dim = all_features_list[0].shape[1]
    final_min = np.array([float('inf')] * feature_dim)
    final_max = np.array([float('-inf')] * feature_dim)
    
    # Iterate through all node features to compute final_min and final_max
    for features in all_features_list:
        curr_max = np.max(features, axis=0)
        curr_min = np.min(features, axis=0)
        final_max = np.max(np.array([final_max, curr_max]), axis=0)
        final_min = np.min(np.array([final_min, curr_min]), axis=0)
    
    return final_min, final_max

# Example usage:

# final_min_node_feat_test, final_max_node_feat_test = compute_feature_min_max(all_node_features_test)
# final_min_edge_attr_test, final_max_edge_attr_test= compute_feature_min_max(all_edge_attr_test)


final_min_node_feat, final_max_node_feat = compute_feature_min_max(all_node_features)
final_min_edge_attr, final_max_edge_attr = compute_feature_min_max(all_edge_attr)

In [9]:
def minmax_scale(all_features_list, final_min, final_max):
    normalized = []
    for i in range(len(all_features_list)):
        X = all_features_list[i]
        X_std = (X - final_min) / (final_max - final_min)
        normalized.append(X_std)
    return normalized

def undo_minmax_scale(all_normalized_feat_list, final_min, final_max):
    unnormalized = []
    for i in range(len(all_normalized_feat_list)):
        X_std = all_normalized_feat_list[i]
        X_scaled = X_std * (final_max - final_min) + final_min
        unnormalized.append(X_scaled)
    return unnormalized

In [10]:
all_node_features_normalized = minmax_scale(all_node_features, final_min_node_feat, final_max_node_feat)
all_edge_attr_normalized = minmax_scale(all_edge_attr, final_min_edge_attr, final_max_edge_attr)


# all_node_features_normalized_test = minmax_scale(all_node_features_test, final_min_node_feat_test, final_max_node_feat_test)
# all_edge_attr_normalized_test= minmax_scale(all_edge_attr_test, final_min_edge_attr_test, final_max_edge_attr_test)

## Split our dataset such that we have sequences and next sequence graph feature nodes

In [12]:
def split_dataset(temporal_data, window, delay=0, horizon=1, stride=1):
    # ONLY HORIZON=1 WORKS FOR PYTORCH TEMPORAL
    # Idea From torch spatio temporal https://torch-spatiotemporal.readthedocs.io/en/latest/_images/sliding_window.svg
    # Initialize lists to store input and target sequences
    input_sequences = []
    target_sequences = []
    
    # Calculate the total number of timesteps
    total_timesteps = len(temporal_data)
    sample_span = window + delay + horizon
    # Iterate over the list with the given stride
    for start in range(0, total_timesteps - sample_span + 1, stride):
        # Extract the input sequence
        input_seq = np.array(temporal_data[start:start + window])
        # Extract the target sequence
        target_seq = np.array(temporal_data[start + window + delay: start + window + delay + horizon])
        # Append the sequences to their respective lists
        input_sequences.append(input_seq)
        target_sequences.append(target_seq)
    
    # Return the lists of input and target sequences
    return input_sequences, target_sequences

node_feature_sequences, node_feature_targets = split_dataset(all_node_features_normalized, window=5, horizon=1)
edge_weights_sequences, edge_weights_targets = split_dataset(all_edge_attr_normalized, window=5, horizon=1)
edge_indices_sequence = [all_possible_edge_indices for _ in range(len(node_feature_sequences))]


# node_feature_sequence_test, node_feature_targets_test = split_dataset(all_node_features_normalized_test, window=5, horizon=1)
# edge_weights_sequences_test, edge_weights_targets_test = split_dataset(all_edge_attr_normalized_test, window=5, horizon=1)
# edge_indices_sequence_test = [all_possible_edge_indices for _ in range(len(node_feature_sequence_test))]

## Create CustomStaticGraphTemporalSignal that gives us a sample of our dataset

In [13]:
class CustomStaticGraphTemporalSignal(object):
    def __init__(self, edge_index: ndarray | None, 
                 edge_weight: ndarray | None, 
                 features: List[ndarray | None], 
                 targets: List[ndarray | None], 
                 **kwargs: List[ndarray]):
        
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.features = features
        self.targets = targets
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        assert len(self.features) == len(
            self.targets
        ), "Temporal dimension inconsistency."
        for key in self.additional_feature_keys:
            assert len(self.targets) == len(
                getattr(self, key)
            ), "Temporal dimension inconsistency."

    def _set_snapshot_count(self):
        self.snapshot_count = len(self.features)

    def _get_edge_index(self, time_index: int):
        if self.edge_index is None:
            return self.edge_index[time_index]
        else:
            return torch.LongTensor(self.edge_index[time_index])

    def _get_edge_weight(self, time_index: int):
        if self.edge_weight is None:
            return self.edge_weight[time_index]
        else:
            return torch.FloatTensor(self.edge_weight[time_index])


    def _get_features(self, time_index: int):
        if self.features[time_index] is None:
            return self.features[time_index]
        else:
            return torch.FloatTensor(self.features[time_index])

    def _get_target(self, time_index: int):
        if self.targets[time_index] is None:
            return self.targets[time_index]
        else:
            if self.targets[time_index].dtype.kind == "i":
                return torch.LongTensor(self.targets[time_index])
            elif self.targets[time_index].dtype.kind == "f":
                return torch.FloatTensor(self.targets[time_index])

    def _get_additional_feature(self, time_index: int, feature_key: str):
        feature = getattr(self, feature_key)[time_index]
        if feature.dtype.kind == "i":
            return torch.LongTensor(feature)
        elif feature.dtype.kind == "f":
            return torch.FloatTensor(feature)

    def _get_additional_features(self, time_index: int):
        additional_features = {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }
        return additional_features
    def __getitem__(self, time_index: Union[int, slice]):
        if isinstance(time_index, slice):
            snapshot = CustomStaticGraphTemporalSignal(
                self.edge_index[time_index],
                self.edge_weight[time_index],
                self.features[time_index],
                self.targets[time_index],
                **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
            )
        else:
            x = self._get_features(time_index)
            edge_index = self._get_edge_index(time_index)
            edge_weight = self._get_edge_weight(time_index)
            y = self._get_target(time_index)
            additional_features = self._get_additional_features(time_index)

            snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,
                            y=y, **additional_features)
        return snapshot

    def __next__(self):
        if self.t < len(self.features):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        else:
            self.t = 0
            raise StopIteration

    def __iter__(self):
        self.t = 0
        return self

In [14]:
train_data = CustomStaticGraphTemporalSignal(edge_index=edge_indices_sequence, 
                                       edge_weight=edge_weights_sequences, 
                                       features=node_feature_sequences, 
                                       targets=node_feature_targets
                                       )

# test_data = CustomStaticGraphTemporalSignal(edge_index=edge_indices_sequence_test, 
#                                        edge_weight=edge_weights_sequences_test, 
#                                        features=node_feature_sequence_test, 
#                                        targets=node_feature_targets_test
#                                        )

## Create GConvGRU for recurrent layer in our gnn

In [43]:
class CustomChebConv(ChebConv):
    def message(self, x_j, norm):
        # Handle multi-dimensional edge weights by using a weighted sum or mean
        if norm.dim() == 2:
            norm = norm.mean(dim=1)  # Example: mean across edge features
        return norm.view(-1, 1) * x_j


class GConvGRU(torch.nn.Module):
    def __init__(self, in_channels, out_channels, K, normalization="sym", bias=True):
        super(GConvGRU, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalization = normalization
        self.bias = bias
        self._create_parameters_and_layers()

    def _create_parameters_and_layers(self):
        self.conv_x_z = CustomChebConv(self.in_channels, self.out_channels, K=self.K, normalization=self.normalization, bias=self.bias)
        self.conv_h_z = CustomChebConv(self.out_channels, self.out_channels, K=self.K, normalization=self.normalization, bias=self.bias)
        self.conv_x_r = CustomChebConv(self.in_channels, self.out_channels, K=self.K, normalization=self.normalization, bias=self.bias)
        self.conv_h_r = CustomChebConv(self.out_channels, self.out_channels, K=self.K, normalization=self.normalization, bias=self.bias)
        self.conv_x_h = CustomChebConv(self.in_channels, self.out_channels, K=self.K, normalization=self.normalization, bias=self.bias)
        self.conv_h_h = CustomChebConv(self.out_channels, self.out_channels, K=self.K, normalization=self.normalization, bias=self.bias)

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H, lambda_max):
        Z = self.conv_x_z(X, edge_index, edge_weight, lambda_max=lambda_max)
        Z = Z + self.conv_h_z(H, edge_index, edge_weight, lambda_max=lambda_max)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H, lambda_max):
        R = self.conv_x_r(X, edge_index, edge_weight, lambda_max=lambda_max)
        R = R + self.conv_h_r(H, edge_index, edge_weight, lambda_max=lambda_max)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R, lambda_max):
        H_tilde = self.conv_x_h(X, edge_index, edge_weight, lambda_max=lambda_max)
        H_tilde = H_tilde + self.conv_h_h(H * R, edge_index, edge_weight, lambda_max=lambda_max)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(self, X, edge_index, edge_weight=None, H=None, lambda_max=None):
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H, lambda_max)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H, lambda_max)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R, lambda_max)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H

## Basic Graph Recurrent Neural Network

In [None]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 4)

    def forward(self, x, edge_index, edge_weight, H=None):
        h = self.recurrent(x, edge_index, edge_weight, H)
        x = F.relu(h)
        x = self.linear(x)
        return x, h


model = RecurrentGCN(node_features=4, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in range(3):
    for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
        h_t_prev = None
        for seq_num in range(train_data[0].x.shape[0]):
            y_hat, h_t = model(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_t_prev)
            h_t_prev = h_t

        cost = F.mse_loss(y_hat, snapshot.y[0])
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Assuming undo_minmax_scale, final_min_node_feat, final_max_node_feat, and train_data are defined
# RUN IF YOU WANT GIF (GIF COULD BE BETTER)

model.eval()
cost = 0
frames = []

for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
    h_t_prev = None
    for seq_num in range(train_data[0].x.shape[0]):
        y_hat, h_t = model(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_t_prev)
        h_t_prev = h_t
        
    y_hat_scaled = undo_minmax_scale([y_hat.numpy(force=True)], final_min_node_feat, final_max_node_feat)[0]
    y_actual_scaled = undo_minmax_scale([snapshot.y[0].numpy(force=True)], final_min_node_feat, final_max_node_feat)[0]

    y_hat_scaled_x = y_hat_scaled[:, 0]
    y_hat_scaled_y = y_hat_scaled[:, 1]
    y_actual_scaled_x = y_actual_scaled[:, 0]
    y_actual_scaled_y = y_actual_scaled[:, 1]
    
    # Plot the predicted vs actual values
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(list(y_hat_scaled_x), list(y_hat_scaled_y), label='Predicted', alpha=0.6)
    ax.scatter(list(y_actual_scaled_x), list(y_actual_scaled_y), label='Actual', alpha=0.6)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend()
    ax.set_title('Predicted vs Actual Values')
    
    # Save the current plot as a frame
    fig.canvas.draw()
    frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    frames.append(Image.fromarray(frame))
    plt.close(fig)
    
    cost = cost + torch.mean((y_hat - snapshot.y[0])**2)
    cost = cost / (time + 1)
    cost = cost.item()
    #print("MSE: {:.4f}".format(cost))

# Save frames as a GIF
frames[0].save('predicted_vs_actual.gif', save_all=True, append_images=frames[1:], duration=500, loop=0)

## Working on GAN [Generator: (Encoder, Decoder)] [Discriminator: (Encoder)]

In [50]:
class Encoder(torch.nn.Module):
    def __init__(self, input_recurrent_dim, output_recurrent_dim, hidden_dim, k=2):
        super(Encoder, self).__init__()
        # self.conv1 = GeneralConv(node_feature_dim, hidden_dim, in_edge_channels)
        # self.conv2 = GeneralConv(hidden_dim, input_recurrent_dim, in_edge_channels)
        self.recurrent = GConvGRU(input_recurrent_dim, output_recurrent_dim, k)
        self.linear = torch.nn.Linear(output_recurrent_dim, hidden_dim)
    
    def forward(self, x, edge_index, edge_attr, H=None):
        print(x.shape, edge_index.shape, edge_attr.shape)
        encoder_h = self.recurrent(X=x, edge_index=edge_index, edge_weight=edge_attr, H=H)
        encoder_h = F.relu(encoder_h)
        print("Bruh", encoder_h.shape)
        encoder_h = self.linear(encoder_h)
        print('Bruh2', encoder_h.shape)
        return encoder_h

class Decoder(torch.nn.Module):
    def __init__(self, node_feature_dim, input_recurrent_dim, output_recurrent_dim, k=2):
        super(Decoder, self).__init__()
        self.recurrent = GConvGRU(input_recurrent_dim, output_recurrent_dim, k)
        self.linear = torch.nn.Linear(output_recurrent_dim, node_feature_dim)

    def forward(self, h, edge_index, edge_attr, H=None):
        decoder_h = self.recurrent(X=h, edge_index=edge_index, edge_weight=edge_attr, H=H)
        x = F.relu(decoder_h)
        x = self.linear(x)
        return x, decoder_h

In [51]:
encoder = Encoder(4, 16, 64)
decoder = Decoder(4, 64, 16)

In [52]:
for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
        h_encoder_prev, h_decoder_prev = None, None
        for seq_num in range(train_data[0].x.shape[0]):
            #print(snapshot.x[seq_num].shape, snapshot.edge_index.shape, snapshot.edge_attr[seq_num][:, 1].shape,snapshot.y.shape)
            h_encoder = encoder(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_encoder_prev)
            h_t_prev = h_encoder
            print(h_encoder.shape)

            y_hat, h_decoder = decoder(h_encoder, snapshot.edge_index, snapshot.edge_attr[seq_num], h_decoder_prev)
            h_decoder_prev = h_decoder
            print('*******************')
            print(h_decoder.shape)
            print(y_hat.shape)
            break
        break

Train:   0%|          | 0/9995 [00:00<?, ?it/s]

torch.Size([100, 4]) torch.Size([2, 4950]) torch.Size([4950, 2])
Bruh torch.Size([100, 16])
Bruh2 torch.Size([100, 64])
torch.Size([100, 64])
*******************
torch.Size([100, 16])
torch.Size([100, 4])





In [53]:
class GraphSeqGenerator(torch.nn.Module):
    
    def __init__(self, node_features_dim, in_edge_channels, seq_length, 
                 hidden_dim_encoder, input_recurrent_dim_encoder, output_recurrent_dim_encoder,
                 input_recurrent_dim_decoder, output_recurrent_dim_decoder, k=2
                 ):
        super(GraphSeqGenerator, self).__init__()
        # Useful Parameters
        self.node_feature_dim = node_features_dim
        self.in_edge_channels = in_edge_channels
        self.seq_length = seq_length
        self.k = k

        # Encoder Parameters
        self.hidden_dim_encoder = hidden_dim_encoder
        self.input_recurrent_dim_encoder = input_recurrent_dim_encoder
        self.output_recurrent_dim_encoder = output_recurrent_dim_encoder

        # Decoder Parameters
        self.input_recurrent_dim_decoder = input_recurrent_dim_decoder
        self.output_recurrent_dim_decoder = output_recurrent_dim_decoder

        print("Bruh", self.input_recurrent_dim_encoder, self.output_recurrent_dim_encoder, self.hidden_dim_encoder)
        self.encoder = Encoder(
                                input_recurrent_dim=self.input_recurrent_dim_encoder,
                                output_recurrent_dim=self.output_recurrent_dim_encoder,
                                hidden_dim=self.hidden_dim_encoder,
                                k=self.k
                                )
        print("Brah", self.node_feature_dim, self.input_recurrent_dim_decoder, self.output_recurrent_dim_decoder )
        self.decoder = Decoder(
                                node_feature_dim=self.node_feature_dim,
                                input_recurrent_dim=self.input_recurrent_dim_decoder,
                                output_recurrent_dim=self.output_recurrent_dim_decoder,
                                k=self.k
                                )

    def forward(self, snapshot, device):
        prev_encoder_h, prev_decoder_h = None, None
        for seq_num in range(self.seq_length):
            #print(snapshot.x[seq_num].shape, snapshot.edge_index.shape, snapshot.edge_attr[seq_num][:, 1].shape,snapshot.y.shape)
            node_features = snapshot.x[seq_num].to(device)
            edge_index = snapshot.edge_index.to(device)
            edge_attr = snapshot.edge_attr[seq_num].to(device)

            curr_encoder_h = self.encoder(node_features, edge_index, edge_attr, prev_encoder_h)
            prev_encoder_h = curr_encoder_h
            print(curr_encoder_h.shape)
            
            y_hat, curr_decoder_h = self.decoder(curr_encoder_h, edge_index, edge_attr, prev_decoder_h)
            prev_decoder_h = curr_decoder_h
            # print('*******************')
            # print(curr_decoder_h.shape)
            # print(y_hat.shape)
        return y_hat

In [54]:
generator = GraphSeqGenerator(node_features_dim=4, in_edge_channels=2, seq_length=5, 
                  hidden_dim_encoder=64, input_recurrent_dim_encoder=4, output_recurrent_dim_encoder=16,
                  input_recurrent_dim_decoder=64, output_recurrent_dim_decoder=16)
device = torch.device('cpu')

generator.to(device)

Bruh 4 16 64
Brah 4 64 16


GraphSeqGenerator(
  (encoder): Encoder(
    (recurrent): GConvGRU(
      (conv_x_z): CustomChebConv(4, 16, K=2, normalization=sym)
      (conv_h_z): CustomChebConv(16, 16, K=2, normalization=sym)
      (conv_x_r): CustomChebConv(4, 16, K=2, normalization=sym)
      (conv_h_r): CustomChebConv(16, 16, K=2, normalization=sym)
      (conv_x_h): CustomChebConv(4, 16, K=2, normalization=sym)
      (conv_h_h): CustomChebConv(16, 16, K=2, normalization=sym)
    )
    (linear): Linear(in_features=16, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (recurrent): GConvGRU(
      (conv_x_z): CustomChebConv(64, 16, K=2, normalization=sym)
      (conv_h_z): CustomChebConv(16, 16, K=2, normalization=sym)
      (conv_x_r): CustomChebConv(64, 16, K=2, normalization=sym)
      (conv_h_r): CustomChebConv(16, 16, K=2, normalization=sym)
      (conv_x_h): CustomChebConv(64, 16, K=2, normalization=sym)
      (conv_h_h): CustomChebConv(16, 16, K=2, normalization=sym)
    )
    (linear): Linear(in_f

In [55]:
optimizer = torch.optim.Adam(generator.parameters(), lr=0.01)

generator.train()

for epoch in range(3):
       for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
              y_hat = generator(snapshot, device)
              cost = F.mse_loss(y_hat, snapshot.y[0].to(device))
              cost.backward()
              optimizer.step()
              optimizer.zero_grad()

Train:   0%|          | 0/9995 [00:00<?, ?it/s]


torch.Size([100, 4]) torch.Size([2, 4950]) torch.Size([4950, 2])
Bruh torch.Size([100, 16])
Bruh2 torch.Size([100, 64])
torch.Size([100, 64])
torch.Size([100, 4]) torch.Size([2, 4950]) torch.Size([4950, 2])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (100x64 and 16x16)

In [None]:
train_data.snapshot_count

In [None]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 4)

    def forward(self, x, edge_index, edge_weight, H=None):
        #print(x.shape, edge_index.shape, edge_weight.shape)
        h = self.recurrent(x, edge_index, edge_weight, H)
        x = F.relu(h)
        x = self.linear(x)
        return x, h

In [None]:
model = RecurrentGCN(node_features=4, filters=32)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
train_data[0].x.shape[0]

In [None]:
train_data[0].x.shape

In [None]:
# note to self chebconv supports only one dimension edge_attr

model.train()

for epoch in range(3):
    for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
        h_t_prev = None
        for seq_num in range(train_data[0].x.shape[0]):
            #print(snapshot.x[seq_num].shape, snapshot.edge_index.shape, snapshot.edge_attr[seq_num][:, 1].shape,snapshot.y.shape)
            y_hat, h_t = model(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_t_prev)
            # print(y_hat.shape, h_t.shape)
            h_t_prev = h_t
            
        cost = torch.mean((y_hat-snapshot.y[0])**2)
        print('Loss', cost)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:



class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 4)

    def forward(self, x, edge_index, edge_weight, H=None):
        h = self.recurrent(x, edge_index, edge_weight, H)
        x = F.relu(h)
        x = self.linear(x)
        return x, h


model = RecurrentGCN(node_features=4, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in range(3):
    for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
        h_t_prev = None
        for seq_num in range(train_data[0].x.shape[0]):
            y_hat, h_t = model(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_t_prev)
            h_t_prev = h_t

        cost = F.mse_loss(y_hat, snapshot.y[0])
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()


In [None]:
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from PIL import Image
import numpy as np
from tqdm import tqdm

# Assuming undo_minmax_scale, final_min_node_feat, final_max_node_feat, and train_data are defined

model.eval()
cost = 0
frames = []

for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
    h_t_prev = None
    for seq_num in range(train_data[0].x.shape[0]):
        y_hat, h_t = model(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_t_prev)
        h_t_prev = h_t
        
    y_hat_scaled = undo_minmax_scale([y_hat.numpy(force=True)], final_min_node_feat, final_max_node_feat)[0]
    y_actual_scaled = undo_minmax_scale([snapshot.y[0].numpy(force=True)], final_min_node_feat, final_max_node_feat)[0]

    y_hat_scaled_x = y_hat_scaled[:, 0]
    y_hat_scaled_y = y_hat_scaled[:, 1]
    y_actual_scaled_x = y_actual_scaled[:, 0]
    y_actual_scaled_y = y_actual_scaled[:, 1]
    
    # Plot the predicted vs actual values
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(list(y_hat_scaled_x), list(y_hat_scaled_y), label='Predicted', alpha=0.6)
    ax.scatter(list(y_actual_scaled_x), list(y_actual_scaled_y), label='Actual', alpha=0.6)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend()
    ax.set_title('Predicted vs Actual Values')
    
    # Save the current plot as a frame
    fig.canvas.draw()
    frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    frames.append(Image.fromarray(frame))
    plt.close(fig)
    
    cost = cost + torch.mean((y_hat - snapshot.y[0])**2)
    cost = cost / (time + 1)
    cost = cost.item()
    #print("MSE: {:.4f}".format(cost))

# Save frames as a GIF
frames[0].save('predicted_vs_actual.gif', save_all=True, append_images=frames[1:], duration=500, loop=0)

In [None]:
model.eval()
cost = 0
for time, snapshot in tqdm(enumerate(train_data), desc='Train', total=train_data.snapshot_count):
    h_t_prev = None
    for seq_num in range(train_data[0].x.shape[0]):
        y_hat, h_t = model(snapshot.x[seq_num], snapshot.edge_index, snapshot.edge_attr[seq_num], h_t_prev)
        # print(y_hat.shape, h_t.shape)
        h_t_prev = h_t
        
    # y_hat_x = y_hat[:, 0].numpy(force=True)
    # y_hat_y = y_hat[:, 1].numpy(force=True)
    # actual_x = snapshot.y[0][:, 0].numpy(force=True)
    # actual_y = snapshot.y[0][:, 1].numpy(force=True)
    y_hat_scaled = undo_minmax_scale([y_hat.numpy(force=True)], final_min_node_feat, final_max_node_feat)[0]
    y_actual_scaled = undo_minmax_scale([snapshot.y[0].numpy(force=True)], final_min_node_feat, final_max_node_feat)[0]

    y_hat_scaled_x = y_hat_scaled[:, 0]
    y_hat_scaled_y = y_hat_scaled[:, 1]
    y_actual_scaled_x = y_actual_scaled[:, 0]
    y_actual_scaled_y = y_actual_scaled[:, 1]
    #print(type(y_hat_x))
    
    # Plot the predicted vs actual values
    plt.figure(figsize=(10, 5))
    plt.scatter(list(y_hat_scaled_x), list(y_hat_scaled_y), label='Predicted', alpha=0.6)
    plt.scatter(list(y_actual_scaled_x), list(y_actual_scaled_y), label='Actual', alpha=0.6)
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.title('Predicted vs Actual Values')
    plt.show()
    cost = cost + torch.mean((y_hat-snapshot.y[0])**2)
    cost = cost / (time+1)
    cost = cost.item()
    print("MSE: {:.4f}".format(cost))