In [73]:
import math
import random
import pygame
import sys
import numpy as np
import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset, download_url, TemporalData
import torch.nn as nn
from torch_geometric.nn import GCNConv
from tqdm import tqdm
from torch_geometric.nn import ChebConv
import torch.optim as optim


## Converting CSV To Input For Our Model

### Grabbing Our CSV And Converting To DataFrame

In [74]:
path_to_sim = '../data/simulation.csv'
sim_df = pd.read_csv(path_to_sim)

sim_df.head(5)

Unnamed: 0,x,y,dx,dy,Boids,Simulation,Timestep
0,192.032076,413.277323,-1.768924,-0.825558,0,0,0
1,266.236092,98.829753,1.131848,2.60699,1,0,0
2,62.612704,129.962456,-4.415965,-2.354997,2,0,0
3,536.605412,33.303273,-3.42423,4.601926,3,0,0
4,679.053022,882.292871,1.094493,0.031586,4,0,0


In [75]:
path_to_sim_edges = '../data/simulation_edges.csv'
sim_edges_df = pd.read_csv(path_to_sim_edges)

sim_edges_df.head(5)

Unnamed: 0,Boid_i,Boid_j,Timestep,Simulation
0,0,39,0,0
1,0,57,0,0
2,1,32,0,0
3,1,34,0,0
4,1,83,0,0


### EDA Of Dataset

In [76]:
# TODO

### Converting DataFrame To Data Object From Pytorch Geometric

In [77]:
def toDataGraph(sim_df, sim_edges_df, node_features_names):
    """
    Converts simulation data into a PyTorch Geometric Data object.

    Parameters:
    - sim_df (DataFrame): DataFrame containing node features for a specific simulation and timestep.
    - sim_edges_df (DataFrame): DataFrame containing edge information for the simulation.
    - node_features_names (list of str): Names of the columns in sim_df that are node features.

    Returns:
    - Data: A PyTorch Geometric Data object representing the graph for the simulation.
    """
    # Convert node features and edge information into tensors
    node_features = torch.tensor(sim_df[node_features_names].to_numpy(), dtype=torch.float)
    edge_index = torch.tensor(sim_edges_df[['Boid_i', 'Boid_j']].to_numpy().T, dtype=torch.long)
    edge_attributes = torch.tensor(np.ones((sim_edges_df.shape[0], 1)), dtype=torch.float)

    # Create and return the Data object
    graph = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attributes)
    return graph

def allDataGraph(sim_df, sim_edges_df):
    """
    Generates a list of PyTorch Geometric Data objects for each simulation and timestep.

    Parameters:
    - sim_df (DataFrame): DataFrame containing node features for all simulations and timesteps.
    - sim_edges_df (DataFrame): DataFrame containing edge information for all simulations and timesteps.

    Returns:
    - list of Data: A list of PyTorch Geometric Data objects, one for each simulation and timestep.
    """
    # Group the data by simulation and timestep
    sim_gb_df = sim_df.groupby(['Timestep', 'Simulation'])
    sim_edges_gb_df = sim_edges_df.groupby(['Timestep', 'Simulation'])

    graphs = []
    # Iterate over each group and convert to a Data object
    for key, _ in sim_gb_df:
        curr_sim_df = sim_gb_df.get_group(key)
        curr_sim_edges_df = sim_edges_gb_df.get_group(key)
        curr_graph = toDataGraph(curr_sim_df, curr_sim_edges_df, ['x', 'y', 'dx', 'dy'])
        graphs.append(curr_graph)

    return graphs

# Example usage
graphs = allDataGraph(sim_df, sim_edges_df)

In [78]:
## TODO: NEXT STEP MAKES CLASS THAT GIVEN THE SIMULATION DATAFRAME AND SIMULATION EDGES DATAFRAME CREATES A DATASET OBJECT

class CustomDataset(Dataset):
    def __init__(self, sim_df, sim_edges_df):
        super(CustomDataset).__init__()
        self.all_graphs = allDataGraph(sim_df, sim_edges_df)
        self.sequences = [graphs[i-5:i-1] for i in range(5, len(self.all_graphs)+1)]
        self.labels = [graphs[i-1] for i in range(5, len(self.all_graphs)+1)]
        self.len = len(self.labels)
    def __getitem__(self, index):
        return self.sequences[index], self.labels[index]
    def __len__(self):
        return self.len

dataset = CustomDataset(sim_df, sim_edges_df)

In [79]:
class GConvGRU(torch.nn.Module):
    r"""An implementation of the Chebyshev Graph Convolutional Gated Recurrent Unit
    Cell. For details see this paper: `"Structured Sequence Modeling with Graph
    Convolutional Recurrent Networks." <https://arxiv.org/abs/1612.07659>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super(GConvGRU, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalization = normalization
        self.bias = bias
        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):

        self.conv_x_z = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_z = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

    def _create_reset_gate_parameters_and_layers(self):

        self.conv_x_r = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_r = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

    def _create_candidate_state_parameters_and_layers(self):

        self.conv_x_h = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_h = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H, lambda_max):
        Z = self.conv_x_z(X, edge_index, edge_weight, lambda_max=lambda_max)
        Z = Z + self.conv_h_z(H, edge_index, edge_weight, lambda_max=lambda_max)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H, lambda_max):
        R = self.conv_x_r(X, edge_index, edge_weight, lambda_max=lambda_max)
        R = R + self.conv_h_r(H, edge_index, edge_weight, lambda_max=lambda_max)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R, lambda_max):
        H_tilde = self.conv_x_h(X, edge_index, edge_weight, lambda_max=lambda_max)
        H_tilde = H_tilde + self.conv_h_h(H * R, edge_index, edge_weight, lambda_max=lambda_max)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
        lambda_max: torch.Tensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
            * **lambda_max** *(PyTorch Tensor, optional but mandatory if normalization is not sym)* - Largest eigenvalue of Laplacian.


        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H, lambda_max)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H, lambda_max)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R, lambda_max)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H

In [80]:
class Encoder(nn.Module):
    def __init__(self, node_feature_dim, hidden_dim, recurrent_dim, output_dim, k=2):
        super(Encoder, self).__init__()
        self.conv1 = GCNConv(node_feature_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, recurrent_dim)
        self.reccurent = GConvGRU(recurrent_dim, output_dim, k)
    
    def forward(self,x,edge_index,edge_weight,H=None):
        x = self.conv1(x, edge_index, edge_weight)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_weight)
        x = torch.relu(x)
        encoder_h = self.reccurent(x,edge_index,edge_weight,H)
        return encoder_h

class Decoder(nn.Module):
    def __init__(self, node_feature_dim, hidden_dim, recurrent_dim, output_dim, k=2):
        super(Decoder, self).__init__()
        self.recurrent = GConvGRU(output_dim, recurrent_dim, k)
        self.conv2 = GCNConv(recurrent_dim, hidden_dim)
        self.conv1 = GCNConv(hidden_dim, node_feature_dim)

    def forward(self, h, edge_index, edge_weight, H=None):
        decoder_h = self.recurrent(h, edge_index, edge_weight, H)
        x = torch.relu(decoder_h)
        x = self.conv2(x, edge_index, edge_weight)
        x = torch.relu(x)
        x = self.conv1(x, edge_index, edge_weight)
        return x, decoder_h

In [83]:
class GraphSeqGenerator(nn.Module):
    def __init__(self, obs_len, pred_len, 
                 node_feature_dim, 
                 encoder_hidden_dim, encoder_recurrent_dim, enconder_output_dim,
                 decoder_hidden_dim, decoder_recurrent_dim, decoder_output_dim,
                 device,k,
                 noise_dim=(0, ), noise_type='gaussian', noise_mix_type='ped', 
                ):
        
        self.obs_len = obs_len
        self.pred_len = pred_len
        self.node_feature_dim = node_feature_dim

        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_recurrent_dim = encoder_recurrent_dim
        self.enconder_output_dim = enconder_output_dim

        self.decoder_hidden_dim = decoder_hidden_dim
        self.decoder_recurrent_dim = decoder_recurrent_dim
        self.decoder_output_dim = decoder_output_dim

        self.noise_dim = noise_dim
        self.noise_type = noise_type
        self.noise_mix_type = noise_mix_type

        self.k=k

        self.encoder = Encoder(
                                node_feature_dim=self.node_feature_dim,
                                hidden_dim=self.encoder_hidden_dim,
                                recurrent_dim=self.encoder_recurrent_dim,
                                output_dim=self.enconder_output_dim,
                                k=self.k
                                )
        
        self.decoder = Decoder(
                                node_feature_dim=self.node_feature_dim,
                                hidden_dim=self.decoder_hidden_dim,
                                recurrent_dim=self.decoder_recurrent_dim,
                                output_dim=self.decoder_output_dim,
                                k=self.k
                                )
        
        if self.noise_dim[0] == 0:
            self.noise_dim = None
        else:
            self.noise_first_dim = noise_dim[0]
        
        def get_noise(self, shape, noise_type):
            if noise_type == 'gaussian':
                return torch.randn(*shape).to(device)
            elif noise_type == 'uniform':
                return torch.rand(*shape).sub_(0.5).mul_(2.0).to(device)
            raise ValueError('Unrecognized noise type "%s"' % noise_type)

        def add_noise(self, _input, user_noise=None):
            """
            Inputs:
            - _input: Tensor of shape (_, decoder_h_dim - noise_first_dim)
            - user_noise: Generally used for inference when you want to see
            relation between different types of noise and outputs.
            Outputs:
            - decoder_h: Tensor of shape (_, decoder_h_dim)
            Example:

            Here _input.size(0) is the number of boids (weren't doing batches yet).
            Let's say 100. Lets say self.noise_dim is '(64,)', the noise shape
            will be (100, 64). So then we concat, (100, num_feat) with (100,64) tensor
            along dim=1, then we get a resulting vector, (100, num_feat + 64)

            """

            noise_shape = (_input.size(0), ) + self.noise_dim

            if user_noise is not None:
                z_decoder = user_noise
            else:
                z_decoder = self.get_noise(noise_shape, self.noise_type)

            decoder_h = torch.cat([_input, z_decoder], dim=1)

            return decoder_h

    def forward(self, seq):
        encoder_hidden_states = []
        prev_encoder_H = None

        # First get the hidden states from encoder
        for graph in seq:
            curr_encoder_h = self.encoder(graph.x, graph.edge_index, graph.edge_weight, prev_encoder_H)
            encoder_hidden_states.append(curr_encoder_h)
            prev_encoder_H = curr_encoder_h
            
        # Second add noise to the hidden states from encoder to feed it to decoder
        encoder_hidden_states = [self.add_noise(h) for h in encoder_hidden_states]
        
        # Third pass in noisy hidden states from encoder to decoder and get last output of decoder
        prev_decoder_H = None
        for i, graph in enumerate(seq):
            x, curr_decoder_h = self.decoder(encoder_hidden_states[i], graph.x, graph.edge_index, prev_decoder_H)
            prev_decoder_H = curr_decoder_h
        return x

In [89]:

class GraphSeqDiscriminator(nn.Module):
    def __init__(self, node_feature_dim, hidden_dim, recurrent_dim, enc_output_dim,k):
        self.node_feature_dim = node_feature_dim
        self.encoder_hidden_dim = hidden_dim
        self.encoder_recurrent_dim = recurrent_dim
        self.enconder_output_dim = enc_output_dim,
        self.k=k

        self.encoder = Encoder(
                                node_feature_dim=self.node_feature_dim,
                                hidden_dim=self.encoder_hidden_dim,
                                recurrent_dim=self.encoder_recurrent_dim,
                                output_dim=self.enconder_output_dim,
                                k=self.k
                                )
        self.linear = nn.Linear(self.enconder_output_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, seq):
        encoder_hidden_states = []
        prev_encoder_H = None

        # First get the hidden states from encoder
        for graph in seq:
            curr_encoder_h = self.encoder(graph.x, graph.edge_index, graph.edge_weight, prev_encoder_H)
            encoder_hidden_states.append(curr_encoder_h)
            prev_encoder_H = curr_encoder_h
            
        x = self.relu(self.linear(curr_encoder_h))
        return x

In [85]:
def get_edges_tensor(data, threshold):
    """
    Calculates edges based on a distance threshold for a tensor where the first two columns represent 'x' and 'y' coordinates,
    and formats them in COO (Coordinate List) format with shape [2, num_edges].

    Parameters:
    - data: A tensor where each row is a point in 2D space, with the first two columns being 'x' and 'y' coordinates.
    - threshold: The distance threshold to consider two points as connected.

    Returns:
    - edges_coo: A tensor in COO format with shape [2, num_edges], where the first row contains the source nodes and the second row contains the target nodes.
    """
    
    # Calculate pairwise distances only for 'x' and 'y'
    x_y = data[:, :2]  # Extract 'x' and 'y' columns
    distances = torch.cdist(x_y, x_y)  # Compute pairwise distances

    # Identify pairs within the threshold distance
    close_pairs = distances < threshold

    # Extract indices of close pairs
    edges = torch.nonzero(close_pairs, as_tuple=False).type(torch.long)

    # Filter out upper triangle including diagonal to avoid duplicates and self-connections
    edges_filtered = edges[edges[:, 0] < edges[:, 1]]

    # Transpose to get shape [2, num_edges]
    edges_coo = edges_filtered.t()

    return edges_coo

In [86]:
def train_loop(dataset, generator_model, discriminator_model, criterion_g, criterion_d, optimizer_g, optimizer_d, device):
    '''
    Loops through the entire dataset for training.

    Parameters:
    - dataset: The dataset to train on.
    - criterion_g: The loss function for the generator.
    - criterion_d: The loss function for the discriminator.
    - optimizer_g: The optimizer for the generator.
    - optimizer_d: The optimizer for the discriminator.

    '''
    g_losses = []
    d_losses = []
    err_d_total, err_g_total = 0, 0
    for i, (seq, next_graph_of_seq) in tqdm(enumerate(dataset), desc='Train'):
        # Putting sequences in device
        seq, real_next_graph_of_seq = [graph.to(device) for graph in seq], next_graph_of_seq.to(device)
        # Get output of generator which is just node features
        fake_graph_node_feats = generator_model(seq)
        # Get the edges of the fake_graph_node_feats
        edge_index = get_edges_tensor(fake_next_graph_of_seq, threshold=75)
        # Set up edge attributes too
        edge_attr = torch.ones((edge_index.size(dim=1), 1))
        # Set up a Data Object from Pytorch Geometric
        fake_next_graph_of_seq = Data(fake_graph_node_feats, edge_index, edge_attr).to(device)
        # Creating fake sequence and real sequence where the first couple are real and the last is either predictied or real
        real_seq = seq
        fake_seq = seq

        real_seq.append(real_next_graph_of_seq)
        fake_seq.append(fake_next_graph_of_seq)
        discriminator_model(real_seq)

        """         
        Part 1 Train Discriminator
        1. Pass in real sequence to discriminator. Calculate loss: loss(log(D(x))) (backward pass) | loss(prob. its real, itsreal) ex. loss(0.7, 1)
        2. Pass in fake sequence from the current generator to discriminator. Calculate loss: loss(log(1-D(G(z)))) (backward pass) | loss(prob its. real, its fake) ex. loss(0.2, 0)
        3. Step for optimizer of discriminator 
        """
    
        """  
        Part 2 Train Generator
        1. Pass in fake sequence from current generator to discriminator. Calculate loss (using real labels for loss) | loss(its real, its real)
        2. Update using backward pass and step into optimizer
        """
        # Part 1 
        discriminator_model.zero_grad()
        output_d_real = discriminator_model(real_seq)
        err_d_real = criterion_d(output_d_real, 1)
        err_d_real.backward()

        output_d_fake = discriminator_model(fake_seq)
        err_d_fake = criterion_d(output_d_fake, 0)
        err_d_fake.backward()

        err_d_total += (err_d_real.item() + err_d_fake.item())

        optimizer_d.step()

        # Part 2
        generator_model.zero_grad()
        output_d_fake = discriminator_model(fake_seq)
        err_g = criterion_g(output_d_fake, 1)
        err_g.backward()

        err_g_total += err_g

        optimizer_g.step()

        # For plotting
        
        # For user
        if i % 32 == 0:
            err_d_total /= 32
            err_g_total /= 32
            print(f'Discriminator Loss: {err_d_total} | Generator Loss: {err_g_total}')
            d_losses.append(err_d_total)
            g_losses.append(err_g_total)
            err_d_total = 0
            err_g_total = 0 

    return g_losses, d_losses     
        

In [92]:
obs_len = 4
pred_len= 5
node_feature_dim = 4
encoder_hidden_dim =32
encoder_recurrent_dim =32
enconder_output_dim = 32
decoder_hidden_dim = 32
decoder_recurrent_dim =32
decoder_output_dim =4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dim = 32
recurrent_dim = 32
output_dim = 4
k=2


Discriminator = GraphSeqDiscriminator(node_feature_dim=node_feature_dim, 
                                      hidden_dim=hidden_dim, 
                                      recurrent_dim=recurrent_dim, 
                                      enc_output_dim=enconder_output_dim, 
                                      k=k
                                      )

generator_model = GraphSeqGenerator(obs_len=obs_len, 
                                    pred_len=pred_len,
                                    node_feature_dim=node_feature_dim,
                                    encoder_hidden_dim=encoder_hidden_dim, 
                                    encoder_recurrent_dim=encoder_recurrent_dim, 
                                    enconder_output_dim=enconder_output_dim,
                                    decoder_hidden_dim=decoder_hidden_dim, 
                                    decoder_recurrent_dim=decoder_recurrent_dim, 
                                    decoder_output_dim=decoder_output_dim,
                                    device=device,k=k
                                    )
optimizer_G = optim.Adam(generator_model.parameters(), lr=0.001)
optimizer_D = optim.Adam(Discriminator.parameters(), lr=0.001)
criterion_g = nn.BCELoss()
criterion_d = nn.BCELoss()
Discriminator.train()
generator_model.train()
epochs = 100
for i in range(epochs):
    train_loop(dataset,generator_model,Discriminator,criterion_g,criterion_d,
               optimizer_G,optimizer_D,device
               )

TypeError: empty() received an invalid combination of arguments - got (tuple, int), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [None]:

# # Version 2 easier to read?

# def train_generator(generator_model, discriminator_model, criterion_g, optimizer_g, fake_seq):
#     """Train the generator model."""
#     generator_model.zero_grad()
#     output_d_fake = discriminator_model(fake_seq)
#     err_g = criterion_g(output_d_fake, torch.ones_like(output_d_fake))
#     err_g.backward()
#     optimizer_g.step()
#     return err_g.item()

# def train_discriminator(discriminator_model, criterion_d, optimizer_d, real_seq, fake_seq):
#     """Train the discriminator model."""
#     discriminator_model.zero_grad()

#     # Real sequence
#     output_d_real = discriminator_model(real_seq)
#     err_d_real = criterion_d(output_d_real, torch.ones_like(output_d_real))
#     err_d_real.backward()

#     # Fake sequence
#     output_d_fake = discriminator_model(fake_seq)
#     err_d_fake = criterion_d(output_d_fake, torch.zeros_like(output_d_fake))
#     err_d_fake.backward()

#     optimizer_d.step()
#     err_d_total = err_d_real.item() + err_d_fake.item()
#     return err_d_total

# def create_fake_sequence(generator_model, seq, device):
#     """Generate a fake sequence using the generator model."""
#     fake_graph_node_feats = generator_model(seq)
#     edge_index = get_edges_tensor(fake_graph_node_feats, threshold=75)
#     edge_attr = torch.ones((edge_index.size(dim=1), 1))
#     fake_next_graph_of_seq = Data(fake_graph_node_feats, edge_index, edge_attr).to(device)
#     fake_seq = seq + [fake_next_graph_of_seq]
#     return fake_seq

# def prepare_sequences(seq, next_graph_of_seq, device):
#     """Prepare real and fake sequences for training."""
#     seq = [graph.to(device) for graph in seq]
#     real_next_graph_of_seq = next_graph_of_seq.to(device)
#     real_seq = seq + [real_next_graph_of_seq]
#     return seq, real_seq

# def train_loop(dataset, generator_model, discriminator_model, criterion_g, criterion_d, optimizer_g, optimizer_d, device):
#     """Loop through the entire dataset for training."""
#     g_losses = []
#     d_losses = []
#     err_d_total, err_g_total = 0, 0

#     for i, (seq, next_graph_of_seq) in tqdm(enumerate(dataset), desc='Train'):
#         seq, real_seq = prepare_sequences(seq, next_graph_of_seq, device)
#         fake_seq = create_fake_sequence(generator_model, seq, device)

#         # Train discriminator
#         err_d_total += train_discriminator(discriminator_model, criterion_d, optimizer_d, real_seq, fake_seq)

#         # Train generator
#         err_g_total += train_generator(generator_model, discriminator_model, criterion_g, optimizer_g, fake_seq)

#         # Log losses for user
#         if (i + 1) % 32 == 0:
#             avg_err_d = err_d_total / 32
#             avg_err_g = err_g_total / 32
#             print(f'Iteration {i+1}, Discriminator Loss: {avg_err_d:.4f}, Generator Loss: {avg_err_g:.4f}')
#             d_losses.append(avg_err_d)
#             g_losses.append(avg_err_g)
#             err_d_total, err_g_total = 0, 0

#     return g_losses, d_losses
