# GPT Implementation in DGL

### Problem 1: Attention score computation does not involve edge features
Mr GPT says to: <br>
Option A: Incorporate into Attention Mechanism
Concatenate or combine edge features with the node features when computing the attention scores. For example, you could modify the edge_attention method to include the trade volume or political score (after appropriate transformation).


Option B: Use a Separate Aggregation for Trade Volumes
Since your final output is the total outgoing trade volume per country, you can leave the TGAT message passing largely as is and perform an extra aggregation step on the edge features. For example:

Aggregate the trade volumes of outgoing edges for each country. <br>
g.update_all(fn.copy_e('trade_vol', 'm'), fn.sum('m', 'total_trade_vol'))
total_trade_vol = g.ndata['total_trade_vol']

### Masked TGAT Implementation 
Since we want to switch off certain nodes during prediction (e.g., a trade analyst may want to account only for the interactions within a certain region), training may need to account for nodes being randomly switched off, so that the model still behaves sensibly when they are masked at inference. This is one way we can do it.

In [None]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
import dgl.function as fn

# Define a Masked TGAT Layer where we apply masks on edges and nodes
class MaskedTGATLayer(nn.Module):
    """Multi-head attention layer with optional edge and node masks.

    Node features are projected to ``num_heads`` blocks of ``out_feats``
    values. Each edge gets one scalar score per head from the concatenated
    source/destination projections; the score is multiplied by the edge's
    ``'mask'`` entry, so a mask of 0 silences that edge's message. Messages
    are averaged per destination node.

    Note: a ``Dropout`` module is constructed but is not applied in
    ``forward``.
    """

    def __init__(self, in_feats, out_feats, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.out_feats = out_feats
        # Maps one (source, destination) feature pair to a single score.
        self.attn_fc = nn.Linear(2 * out_feats, 1, bias=False)
        # Shared projection; its output is later split into heads.
        self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
        self.dropout = nn.Dropout(dropout)

    def edge_attention(self, edges):
        """Return ``{'e': score}`` for a batch of edges.

        The score is ``leaky_relu(attn_fc([z_src, z_dst]))``, multiplied by
        ``edges.data['mask']`` when that key is present.
        """
        endpoints = torch.cat((edges.src['z'], edges.dst['z']), dim=-1)
        score = torch.nn.functional.leaky_relu(self.attn_fc(endpoints))
        if 'mask' in edges.data:
            score = score * edges.data['mask']
        return {'e': score}

    def forward(self, g, h, node_mask=None):
        """Run one round of masked message passing.

        Args:
            g: DGL graph; may carry an edge ``'mask'`` feature.
            h: node feature tensor fed to the linear projection.
            node_mask: optional tensor multiplied into the flattened output
                (must broadcast against it).

        Returns:
            Tensor with ``num_heads * out_feats`` columns per node.
        """
        with g.local_scope():
            projected = self.fc(h)
            # Split the projection into heads: (N, num_heads, out_feats).
            g.ndata['z'] = projected.view(projected.shape[0], self.num_heads, self.out_feats)

            # Graphs without an explicit edge mask keep every edge active.
            if 'mask' not in g.edata:
                g.edata['mask'] = torch.ones(
                    g.number_of_edges(), self.num_heads, 1, device=projected.device
                )

            g.apply_edges(self.edge_attention)
            # Weight each source projection by its edge score, then average.
            g.update_all(fn.u_mul_e('z', 'e', 'm'), fn.mean('m', 'h_new'))

            aggregated = g.ndata['h_new']
            flat = aggregated.reshape(aggregated.shape[0], self.num_heads * self.out_feats)
            return flat if node_mask is None else flat * node_mask

# Define a simple TGAT model using the masked layer
class MaskedTGAT(nn.Module):
    """Stack of ``MaskedTGATLayer`` modules followed by a linear read-out.

    The first layer consumes ``in_feats`` columns; every further layer
    consumes the flattened ``hidden_feats * num_heads`` output of the
    previous one.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads, num_layers):
        super().__init__()
        flat_width = hidden_feats * num_heads
        input_widths = [in_feats] + [flat_width] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [MaskedTGATLayer(width, hidden_feats, num_heads) for width in input_widths]
        )
        self.fc_out = nn.Linear(flat_width, out_feats)

    def forward(self, g, h, node_mask=None):
        """Apply every layer in turn (same ``node_mask`` each time), then the read-out."""
        hidden = h
        for layer in self.layers:
            hidden = layer(g, hidden, node_mask)
        return self.fc_out(hidden)

# Create a graph and set up masks for nodes and edges
def create_masked_graph(num_heads=2):
    """Build a 5-node directed ring with example edge and node masks.

    Args:
        num_heads: number of attention heads the edge mask is replicated
            for; must match the model that consumes the graph. Defaults
            to 2, the value used by the training example.

    Returns:
        Tuple ``(g, features, node_mask)``: the DGL graph (with
        ``ndata['feat']`` and ``edata['mask']`` set), the node feature
        tensor, and a ``(5, 1)`` node mask in which node 2 is switched off.
    """
    # Simple graph with 5 nodes and 5 directed edges forming a ring.
    src_nodes = [0, 1, 2, 3, 4]
    dst_nodes = [1, 2, 3, 4, 0]
    g = dgl.graph((src_nodes, dst_nodes))

    # Example node features (e.g., country indicators). Size by the node
    # count, not the edge-list length, which only coincidentally matched.
    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)

    # Edge mask: 1 keeps the edge active, 0 switches it off.
    # Here the second edge (index 1) is switched off.
    edge_active = torch.tensor([1, 0, 1, 1, 1], dtype=torch.float32)
    # Replicate per head so the mask has shape (E, num_heads, 1).
    g.edata['mask'] = edge_active.view(-1, 1, 1).repeat(1, num_heads, 1)

    # Node mask: node 2 is switched off; the trailing dimension of 1 lets it
    # broadcast over the flattened feature columns.
    node_mask = torch.tensor([[1], [1], [0], [1], [1]], dtype=torch.float32)

    return g, g.ndata['feat'], node_mask

# Example training loop using the masked TGAT model
def train_masked_tgat():
    """Fit a two-layer ``MaskedTGAT`` on the toy masked graph for 100 epochs.

    Uses Adam (lr=0.01) with an MSE loss against fixed demonstration labels
    and prints the loss every 10th epoch.
    """
    graph, feats, node_mask = create_masked_graph()
    net = MaskedTGAT(in_feats=10, hidden_feats=16, out_feats=1, num_heads=2, num_layers=2)
    opt = optim.Adam(net.parameters(), lr=0.01)
    mse = nn.MSELoss()

    # Fake regression targets (e.g., total trade volume per country), one row per node.
    targets = torch.tensor([0.5, 0.7, 0.3, 0.9, 0.4], dtype=torch.float32).unsqueeze(1)

    for epoch in range(100):
        # The node mask is passed to the model and applied by every layer.
        loss = mse(net(graph, feats, node_mask), targets)
        opt.zero_grad()
        loss.backward()
        opt.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")


train_masked_tgat()

included edge features 

Time periods are supposedly included through the timestamp -- but I'm not fully sure how this is supposed to be represented in the data. 
GPT mentions that the timestamp embedding `time_enc` will be appended to the edge features, which will serve as the time component.

Also worth noting: TGAT does not do layers -- so all the time periods / timestamps are supposedly included and processed within this one-layer model.

[also for note possibly explosive amount of edges]

-> next action should probably be to figure out how to include time-based data that will be conceptually similar to what we want to do & try training it on a small scale set of (1) time period (2) node feature (3) edge feature (4) ?output but probably not since training without predicting won't involve it yet. 
[by Wed night]

not sure if i missed anything but this setup should be sufficient at the super base level of what the model needs data-wise. 

In [None]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
import dgl.function as fn
import numpy as np

# Temporal Encoding Function
def temporal_encoding(timestamps, dim):
    """Encode timestamps with fixed sinusoidal features.

    Args:
        timestamps: tensor whose last dimension has size 1 (e.g. ``(E, 1)``)
            so it broadcasts against the frequency vector.
        dim: requested encoding width. ``dim // 2`` frequencies are used, so
            the output has ``2 * (dim // 2)`` columns (``dim`` when even).

    Returns:
        Tensor with the sine features followed by the cosine features along
        the last dimension.
    """
    # Create the frequencies on the same device as the timestamps so the
    # product below also works for non-CPU inputs.
    freqs = torch.arange(0, dim // 2, dtype=torch.float32, device=timestamps.device)
    freqs = 1.0 / (10000 ** (2 * freqs / dim))
    angles = timestamps * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

# Define the Temporal Attention Layer
class TGATLayer(nn.Module):
    """Multi-head attention layer whose edge scores also see edge features and time.

    Each edge score is computed from the source and destination projections,
    four scalar edge features (``dummy1`` .. ``dummy4``) and a sinusoidal
    encoding of the edge ``timestamp``.

    Note: a ``Dropout`` module is constructed but is not applied in
    ``forward``.
    """

    def __init__(self, in_feats, out_feats, num_heads, time_dim=4, dropout=0.1):
        super(TGATLayer, self).__init__()
        self.num_heads = num_heads
        self.out_feats = out_feats
        self.time_dim = time_dim
        # Input width: 2 * out_feats (source + destination) + 4 scalar edge
        # features + time_dim encoding columns.
        self.attn_fc = nn.Linear(2 * out_feats + 4 + time_dim, 1, bias=False)
        self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _per_head(self, values):
        """Replicate per-edge ``values`` of shape (E, k) to (E, num_heads, k)."""
        return values.unsqueeze(1).repeat(1, self.num_heads, 1)

    def edge_attention(self, edges):
        """Return ``{'e': score}`` with one score per edge and head.

        The node projections are 3-D ``(E, num_heads, out_feats)`` while the
        edge features are stored per edge, so the edge features and the time
        encoding are replicated across heads before concatenation;
        concatenating them as 2-D tensors raises a dimension-mismatch error.
        Assumes each ``dummy*`` / ``timestamp`` entry is a 1-D ``(E,)`` tensor.
        """
        edge_feats = [
            self._per_head(edges.data[name].unsqueeze(-1))
            for name in ('dummy1', 'dummy2', 'dummy3', 'dummy4')
        ]
        time_enc = self._per_head(
            temporal_encoding(edges.data['timestamp'].unsqueeze(-1), self.time_dim)
        )
        z_cat = torch.cat(
            [edges.src['z'], edges.dst['z'], *edge_feats, time_enc], dim=-1
        )  # (E, num_heads, 2 * out_feats + 4 + time_dim)
        a = self.attn_fc(z_cat)
        a = torch.nn.functional.leaky_relu(a)
        return {'e': a}

    def forward(self, g, h):
        """Project ``h``, score the edges and average the weighted messages.

        Returns a tensor with ``num_heads * out_feats`` columns per node.
        """
        with g.local_scope():
            z = self.fc(h)
            # Split the projection into heads: (N, num_heads, out_feats).
            z = z.view(z.shape[0], self.num_heads, self.out_feats)
            g.ndata['z'] = z
            g.apply_edges(self.edge_attention)
            g.update_all(fn.u_mul_e('z', 'e', 'm'), fn.mean('m', 'h_new'))
            h_new = g.ndata['h_new'].reshape(g.ndata['h_new'].shape[0], self.num_heads * self.out_feats)
            return h_new

# Define the TGAT Model
class TGAT(nn.Module):
    """Stack of ``TGATLayer`` modules followed by a linear output layer.

    The first layer consumes ``in_feats`` columns; every further layer
    consumes the flattened ``hidden_feats * num_heads`` output of the
    previous one.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads, num_layers, time_dim=4):
        super().__init__()
        flat_width = hidden_feats * num_heads
        input_widths = [in_feats] + [flat_width] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [TGATLayer(width, hidden_feats, num_heads, time_dim) for width in input_widths]
        )
        self.fc_out = nn.Linear(flat_width, out_feats)

    def forward(self, g, h):
        """Apply every layer to the same graph in turn, then the output layer."""
        hidden = h
        for layer in self.layers:
            hidden = layer(g, hidden)
        return self.fc_out(hidden)

# Create a small temporal graph dataset with edge features
def create_tg():
    """Return a 5-node DGL graph with 10 timestamped, featured edges.

    Edge data: ``timestamp``, ``trade_vol`` and ``dummy1`` .. ``dummy4``
    (all float32, one value per edge). Node data: ``feat`` with 10 random
    columns per node.
    """
    # A directed ring plus the reverse of every ring edge.
    src_nodes = [0, 1, 2, 3, 4, 1, 2, 3, 4, 0]
    dst_nodes = [1, 2, 3, 4, 0, 0, 1, 2, 3, 4]
    g = dgl.graph((src_nodes, dst_nodes))

    # One column per edge feature, listed in edge order.
    edge_columns = {
        'timestamp': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'trade_vol': [100, 200, 150, 300, 250, 180, 220, 170, 280, 240],
        'dummy1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'dummy2': [5, 4, 3, 2, 1, 0, 1, 2, 3, 4],
        'dummy3': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
        'dummy4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    }
    for name, values in edge_columns.items():
        g.edata[name] = torch.tensor(values, dtype=torch.float32)

    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)
    return g

# Train the TGAT Model
def train_tgat():
    """Fit a two-layer ``TGAT`` on the toy temporal graph for 100 epochs.

    Uses Adam (lr=0.01) with an MSE loss against fixed per-node labels and
    prints the loss every 10th epoch.
    """
    graph = create_tg()
    net = TGAT(in_feats=10, hidden_feats=16, out_feats=1, num_heads=2, num_layers=2, time_dim=4)
    opt = optim.Adam(net.parameters(), lr=0.01)
    mse = nn.MSELoss()

    # One regression target per node, as a column vector.
    targets = torch.tensor([[200], [300], [250], [400], [150]], dtype=torch.float32)

    for epoch in range(100):
        loss = mse(net(graph, graph.ndata['feat']), targets)
        opt.zero_grad()
        loss.backward()
        opt.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

train_tgat()


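Running the cell above as first written raised: `RuntimeError: Tensors must have same number of dimensions: got 3 and 2` (`z` is 3-D, `(E, num_heads, out_feats)`, while the edge features and the time encoding were concatenated as 2-D tensors).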

## Revised Implementation

Todo: reference the following url to see how to implement temporal component. 
https://www.kaggle.com/code/dipanjandas96/temporal-graph-attention-network-tgan/notebook#INDUCTIVE-REPRESENTATION-LEARNING-ON-TEMPORAL-GRAPHS

Also, see if the stacking temporal implementation works/need to modify aggregation mechanism.

In [4]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
import dgl.function as fn

# Temporal Encoding Function inspired by the TGAT paper
def temporal_encoding(timestamps, dim):
    """Sinusoidal time encoding.

    ``timestamps`` must have a trailing dimension of size 1 (shape
    ``(..., 1)``). ``dim // 2`` frequencies are generated on the same device
    as ``timestamps``; the result holds the sine features followed by the
    cosine features along the last dimension.
    """
    half = dim // 2
    index = torch.arange(0, half, dtype=torch.float32, device=timestamps.device)
    inv_freq = 1.0 / (10000 ** (2 * index / dim))
    phase = timestamps * inv_freq
    return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)

# Define the Temporal Attention Layer (TGATLayer)
class TGATLayer(nn.Module):
    """Temporal multi-head attention layer.

    Edge scores are computed from the source/destination projections, four
    scalar edge features (``dummy1`` .. ``dummy4``) and a sinusoidal encoding
    of the edge ``timestamp``; the edge-level inputs are replicated across
    heads so they can be concatenated with the per-head node projections.
    Dropout is applied to the flattened output.
    """

    def __init__(self, in_feats, out_feats, num_heads, time_dim=4, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.out_feats = out_feats
        self.time_dim = time_dim
        # Attention input width: source + destination projections
        # (2 * out_feats), four 1-column edge features, and the time encoding.
        attn_in = 2 * out_feats + 4 + time_dim
        self.attn_fc = nn.Linear(attn_in, 1, bias=False)
        self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
        self.dropout = nn.Dropout(dropout)

    def edge_attention(self, edges):
        """Return ``{'e': score}`` with shape ``(E, num_heads, 1)``."""
        heads = self.num_heads

        def across_heads(column):
            # (E,) -> (E, 1, 1) -> (E, num_heads, 1)
            return column.unsqueeze(1).unsqueeze(-1).repeat(1, heads, 1)

        scalar_feats = [
            across_heads(edges.data[key])
            for key in ('dummy1', 'dummy2', 'dummy3', 'dummy4')
        ]

        # (E, time_dim) -> (E, num_heads, time_dim)
        encoded_time = temporal_encoding(edges.data['timestamp'].unsqueeze(-1), self.time_dim)
        encoded_time = encoded_time.unsqueeze(1).repeat(1, heads, 1)

        # Last dimension: 2 * out_feats + 4 + time_dim.
        attn_input = torch.cat(
            [edges.src['z'], edges.dst['z'], *scalar_feats, encoded_time], dim=-1
        )
        score = torch.nn.functional.leaky_relu(self.attn_fc(attn_input))
        return {'e': score}

    def forward(self, g, h):
        """Project, attend, average incoming messages, flatten, then apply dropout."""
        with g.local_scope():
            projected = self.fc(h)  # (N, out_feats * num_heads)
            g.ndata['z'] = projected.view(projected.shape[0], self.num_heads, self.out_feats)
            g.apply_edges(self.edge_attention)
            # Scale each source projection by its edge score, then average per destination.
            g.update_all(fn.u_mul_e('z', 'e', 'm'), fn.mean('m', 'h_new'))
            aggregated = g.ndata['h_new']
            flat = aggregated.reshape(aggregated.shape[0], self.num_heads * self.out_feats)
            return self.dropout(flat)

# Define the stacked TGAT model
class TGAT(nn.Module):
    """Stacked temporal attention model with a linear prediction head.

    The first layer maps ``in_feats`` to the hidden space; each further
    layer consumes the flattened ``hidden_feats * num_heads`` output of the
    one before it.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads, num_layers, time_dim=4):
        super().__init__()
        flat_width = hidden_feats * num_heads
        # Input width of each layer, first to last.
        input_widths = [in_feats] + [flat_width] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [TGATLayer(width, hidden_feats, num_heads, time_dim) for width in input_widths]
        )
        # Final linear layer for output prediction.
        self.fc_out = nn.Linear(flat_width, out_feats)

    def forward(self, g, h):
        """Run the layers in order on ``g`` and return the head's prediction."""
        hidden = h
        for layer in self.layers:
            hidden = layer(g, hidden)
        return self.fc_out(hidden)

# Create a small temporal graph dataset with edge features
def create_tg():
    """Build the toy temporal graph: 5 nodes, 10 timestamped edges.

    Every edge carries ``timestamp``, ``trade_vol`` and ``dummy1`` ..
    ``dummy4`` as float32 scalars; every node carries a random 10-column
    ``feat`` vector.
    """
    # A directed ring plus the reverse of every ring edge.
    src_nodes = [0, 1, 2, 3, 4, 1, 2, 3, 4, 0]
    dst_nodes = [1, 2, 3, 4, 0, 0, 1, 2, 3, 4]
    g = dgl.graph((src_nodes, dst_nodes))

    # (feature name, per-edge values) in the order they are attached.
    edge_columns = (
        ('timestamp', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
        ('trade_vol', [100, 200, 150, 300, 250, 180, 220, 170, 280, 240]),
        ('dummy1', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
        ('dummy2', [5, 4, 3, 2, 1, 0, 1, 2, 3, 4]),
        ('dummy3', [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]),
        ('dummy4', [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]),
    )
    for key, column in edge_columns:
        g.edata[key] = torch.tensor(column, dtype=torch.float32)

    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)
    return g

# Train the stacked TGAT Model
def train_tgat():
    """Train a two-layer stacked ``TGAT`` on the toy temporal graph.

    Runs 100 epochs of Adam (lr=0.01) on an MSE loss against fixed per-node
    labels, printing the loss every 10th epoch.
    """
    graph = create_tg()
    # Two stacked TGAT layers (experiment with additional layers if desired).
    net = TGAT(in_feats=10, hidden_feats=16, out_feats=1, num_heads=2, num_layers=2, time_dim=4)
    opt = optim.Adam(net.parameters(), lr=0.01)
    mse = nn.MSELoss()

    # One regression target per node, as a column vector.
    targets = torch.tensor([[200], [300], [250], [400], [150]], dtype=torch.float32)

    for epoch in range(100):
        loss = mse(net(graph, graph.ndata['feat']), targets)
        opt.zero_grad()
        loss.backward()
        opt.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

if __name__ == "__main__":
    train_tgat()


Epoch 0 | Loss: 74965.0703
Epoch 10 | Loss: 73102.9844
Epoch 20 | Loss: 26777.5215
Epoch 30 | Loss: 17917.1855
Epoch 40 | Loss: 18315.6816
Epoch 50 | Loss: 8199.2002
Epoch 60 | Loss: 4419.8477
Epoch 70 | Loss: 472.5968
Epoch 80 | Loss: 1095.7872
Epoch 90 | Loss: 1613.3792


## FROM THE PAPER
Link: https://github.com/dmlc/dgl/tree/0.9.x/examples/pytorch/tgn
Can refer to data_preprocessing.py to see how we need to pre-process the data and fit model.


In [None]:
import argparse
import traceback
import time
import copy

import numpy as np
import dgl
import torch

from tgn import TGN
# from data_preprocess import TemporalWikipediaDataset, TemporalRedditDataset, TemporalDataset
from dataloading_tgnn import (FastTemporalEdgeCollator, FastTemporalSampler,
                         SimpleTemporalEdgeCollator, SimpleTemporalSampler,
                         TemporalEdgeDataLoader, TemporalSampler, TemporalEdgeCollator)

from sklearn.metrics import average_precision_score, roc_auc_score


# Split points as fractions of the edge list: the script below takes edge ids
# up to TRAIN_SPLIT * num_edges as training seeds and uses
# VALID_SPLIT * num_edges as the train/validation vs. test boundary.
TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.85

# Seed NumPy and PyTorch so random choices in this script are repeatable.
np.random.seed(2021)
torch.manual_seed(2021)


def train(model, dataloader, sampler, criterion, optimizer, args):
    """Run one training pass over ``dataloader`` and return the summed loss.

    Each batch is expected to unpack as ``(_, positive_pair_g, _, blocks)``.
    The model's ``embed`` output is compared with
    ``positive_pair_g.edata['feat']`` using ``criterion``; the returned total
    adds ``float(loss) * args.batch_size`` for every batch.

    Args:
        model: exposes ``embed``, ``detach_memory``, ``update_memory`` and,
            in fast mode, ``memory.last_update_t``.
        dataloader: iterable of 4-tuples as described above.
        sampler: only used when ``args.fast_mode`` is set.
        criterion: loss callable.
        optimizer: optimizer stepped once per batch.
        args: needs ``batch_size``, ``fast_mode`` and ``not_use_memory``.
    """
    model.train()
    running_loss = 0
    tick = time.time()

    for batch_cnt, (_, positive_pair_g, _, blocks) in enumerate(dataloader):
        optimizer.zero_grad()

        # Only the positive graph is used: predict its edge features.
        predicted = model.embed(positive_pair_g, blocks)
        # Adjust the key to the actual edge feature field name if needed.
        target = positive_pair_g.edata['feat']

        loss = criterion(predicted, target)
        running_loss += float(loss) * args.batch_size

        # The autograd graph is retained only for the first batch, and only
        # outside fast mode.
        loss.backward(retain_graph=(batch_cnt == 0 and not args.fast_mode))
        optimizer.step()

        model.detach_memory()

        if not args.not_use_memory:
            # Update memory based on the batch graph.
            model.update_memory(positive_pair_g)

        if args.fast_mode:
            sampler.attach_last_update(model.memory.last_update_t)

        print("Batch: ", batch_cnt, "Time: ", time.time()-tick)
        tick = time.time()

    return running_loss


def test_val(model, dataloader, sampler, criterion, args):
    """Evaluate ``model`` over ``dataloader`` and return the mean per-batch MSE.

    Runs under ``torch.no_grad()`` with the model in eval mode. Each batch is
    expected to unpack as ``(_, positive_pair_g, _, blocks)``; predictions
    from ``model.embed`` are compared with ``positive_pair_g.edata['feat']``.
    Memory is still updated per batch unless ``args.not_use_memory`` is set.
    The ``criterion`` loss is accumulated locally but is not part of the
    return value.
    """
    model.eval()
    batch_size = args.batch_size
    running_loss = 0
    per_batch_mse = []

    with torch.no_grad():
        for _, positive_pair_g, _, blocks in dataloader:
            predicted = model.embed(positive_pair_g, blocks)
            # Adjust the key to the actual edge feature field name if needed.
            target = positive_pair_g.edata['feat']

            running_loss += float(criterion(predicted, target)) * batch_size

            # Per-batch mean squared error, used for reporting.
            residual = predicted - target
            per_batch_mse.append((residual ** 2).mean().item())

            if not args.not_use_memory:
                model.update_memory(positive_pair_g)

            if args.fast_mode:
                sampler.attach_last_update(model.memory.last_update_t)

    # Unweighted average of the batch MSEs.
    return float(torch.tensor(per_batch_mse).mean())
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()

#     parser.add_argument("--epochs", type=int, default=50,
#                         help='epochs for training on entire dataset')
#     parser.add_argument("--batch_size", type=int,
#                         default=200, help="Size of each batch")
#     parser.add_argument("--embedding_dim", type=int, default=100,
#                         help="Embedding dim for link prediction")
#     parser.add_argument("--memory_dim", type=int, default=100,
#                         help="dimension of memory")
#     parser.add_argument("--temporal_dim", type=int, default=100,
#                         help="Temporal dimension for time encoding")
#     parser.add_argument("--memory_updater", type=str, default='gru',
#                         help="Recurrent unit for memory update")
#     parser.add_argument("--aggregator", type=str, default='last',
#                         help="Aggregation method for memory update")
#     parser.add_argument("--n_neighbors", type=int, default=10,
#                         help="number of neighbors while doing embedding")
#     parser.add_argument("--sampling_method", type=str, default='topk',
#                         help="In embedding how node aggregate from its neighor")
#     parser.add_argument("--num_heads", type=int, default=8,
#                         help="Number of heads for multihead attention mechanism")
#     parser.add_argument("--fast_mode", action="store_true", default=False,
#                         help="Fast Mode uses batch temporal sampling, history within same batch cannot be obtained")
#     parser.add_argument("--simple_mode", action="store_true", default=False,
#                         help="Simple Mode directly delete the temporal edges from the original static graph")
#     parser.add_argument("--num_negative_samples", type=int, default=1,
#                         help="number of negative samplers per positive samples")
#     parser.add_argument("--dataset", type=str, default="wikipedia",
#                         help="dataset selection wikipedia/reddit")
#     parser.add_argument("--k_hop", type=int, default=1,
#                         help="sampling k-hop neighborhood")
#     parser.add_argument("--not_use_memory", action="store_true", default=False,
#                         help="Enable memory for TGN Model disable memory for TGN Model")
class Inputs:
    """Run configuration, replacing the argparse options commented out above.

    Every setting is a keyword-only constructor argument whose default is the
    value previously hard-coded here, so ``Inputs()`` yields the same
    configuration as before and e.g. ``Inputs(epochs=5)`` overrides a single
    setting.
    """

    def __init__(
        self,
        *,
        epochs=50,
        batch_size=200,
        embedding_dim=100,
        memory_dim=100,
        temporal_dim=100,
        memory_updater='gru',
        aggregator='last',
        n_neighbors=10,
        sampling_method='topk',
        num_heads=8,
        fast_mode=False,
        simple_mode=False,
        num_negative_samples=1,
        k_hop=1,
        not_use_memory=False,
    ):
        self.epochs = epochs
        # Number of edges processed together in a single batch.
        self.batch_size = batch_size
        # Embedding dimension (output features of each layer).
        self.embedding_dim = embedding_dim
        # Dimension of the memory. Tuning note: larger values add complexity,
        # compute cost and overfitting risk; tune with cross-validation.
        self.memory_dim = memory_dim
        # Dimension of the time encoding; can be lowered for quarterly/yearly data.
        self.temporal_dim = temporal_dim
        # Recurrent unit for the memory update: 'gru' or 'lstm'.
        self.memory_updater = memory_updater
        # Aggregation method for the memory update.
        self.aggregator = aggregator
        # Number of neighbours considered while embedding.
        self.n_neighbors = n_neighbors
        # 'topk' or 'uniform'; see dataloading_tgnn.py.
        self.sampling_method = sampling_method
        # Number of heads of the multi-head attention mechanism.
        self.num_heads = num_heads
        # Fast mode uses batch temporal sampling.
        self.fast_mode = fast_mode
        # Simple mode; not needed for now.
        self.simple_mode = simple_mode
        # Negative samples per positive sample; not needed for now.
        self.num_negative_samples = num_negative_samples
        # k-hop neighbourhood sampling; controls how far indirect effects reach.
        self.k_hop = k_hop
        # Set to True to disable the memory module.
        self.not_use_memory = not_use_memory


args = Inputs()

# The two sampling modes are mutually exclusive, and k_hop != 1 is only
# supported in simple mode.
assert not (
    args.fast_mode and args.simple_mode), "you can only choose one sampling mode"
if args.k_hop != 1:
    assert args.simple_mode, "this k-hop parameter only support simple mode"

# Load Dataset
# NOTE(review): placeholder -- `data` must be replaced with a DGL graph that
# carries an edge 'timestamp' feature; while it is None the next statement
# raises AttributeError.
data=None

# Pre-process data, mask new node in test set from original graph
num_nodes = data.num_nodes()
num_edges = data.num_edges()

# Edge index separating train+validation edges from test edges.
# Presumably edges are stored in chronological order -- TODO confirm.
trainval_div = int(VALID_SPLIT*num_edges)

# Select new nodes from the test set and remove them from the entire graph:
# take every node touched by a test edge and randomly pick 10% of them
# (without replacement) as "new" nodes.
test_split_ts = data.edata['timestamp'][trainval_div]
test_nodes = torch.cat([data.edges()[0][trainval_div:], data.edges()[
                        1][trainval_div:]]).unique().numpy()
test_new_nodes = np.random.choice(
    test_nodes, int(0.1*len(test_nodes)), replace=False)

# Subgraphs holding the incoming / outgoing edges of the chosen new nodes.
in_subg = dgl.in_subgraph(data, test_new_nodes)
out_subg = dgl.out_subgraph(data, test_new_nodes)
# Remove the edges that happen before the test split so the model cannot
# learn the connection info of the new nodes
new_node_in_eid_delete = in_subg.edata[dgl.EID][in_subg.edata['timestamp'] < test_split_ts]
new_node_out_eid_delete = out_subg.edata[dgl.EID][out_subg.edata['timestamp'] < test_split_ts]
new_node_eid_delete = torch.cat(
    [new_node_in_eid_delete, new_node_out_eid_delete]).unique()

# Graph that keeps the new nodes' edges from the test period only.
graph_new_node = copy.deepcopy(data)
# relative order preserved
graph_new_node.remove_edges(new_node_eid_delete)

# For the graph without new nodes, every edge touching a new node is removed
in_eid_delete = in_subg.edata[dgl.EID]
out_eid_delete = out_subg.edata[dgl.EID]
eid_delete = torch.cat([in_eid_delete, out_eid_delete]).unique()

graph_no_new_node = copy.deepcopy(data)
graph_no_new_node.remove_edges(eid_delete)

# graph_no_new_node and graph_new_node should have the same set of node ids

# Sampler initialisation: pick the sampler/collator pair for the chosen mode.
if not args.simple_mode and not args.fast_mode:
    # Default mode: no separate sampler is created for the new-node graph.
    sampler = TemporalSampler(k=args.n_neighbors)
    edge_collator = TemporalEdgeCollator
elif args.simple_mode:
    # One fan-out entry per hop.
    fan_out = [args.n_neighbors] * args.k_hop
    sampler = SimpleTemporalSampler(graph_no_new_node, fan_out)
    new_node_sampler = SimpleTemporalSampler(data, fan_out)
    edge_collator = SimpleTemporalEdgeCollator
else:
    # Fast mode.
    sampler = FastTemporalSampler(graph_no_new_node, k=args.n_neighbors)
    new_node_sampler = FastTemporalSampler(data, k=args.n_neighbors)
    edge_collator = FastTemporalEdgeCollator

# No negative sampler is used.
neg_sampler = None

# Set train, validation, test and new-node test edge ids.
# The validation/test boundary is trainval_div shifted down by the number of
# pre-split new-node edges deleted above.
# NOTE(review): graph_no_new_node had eid_delete (a superset of
# new_node_eid_delete) removed, so this offset presumably only approximates
# the boundary for that graph -- confirm.
train_seed = torch.arange(int(TRAIN_SPLIT*graph_no_new_node.num_edges()))
valid_seed = torch.arange(int(
    TRAIN_SPLIT*graph_no_new_node.num_edges()), trainval_div-new_node_eid_delete.size(0))
test_seed = torch.arange(
    trainval_div-new_node_eid_delete.size(0), graph_no_new_node.num_edges())
test_new_node_seed = torch.arange(
    trainval_div-new_node_eid_delete.size(0), graph_new_node.num_edges())

# Outside fast mode, neighbours are sampled from copies of the graphs with
# reverse edges added (edge data copied onto the reverse edges).
g_sampling = None if args.fast_mode else dgl.add_reverse_edges(
    graph_no_new_node, copy_edata=True)
new_node_g_sampling = None if args.fast_mode else dgl.add_reverse_edges(
    graph_new_node, copy_edata=True)
if not args.fast_mode:
    # Both sampling graphs get the same node-id tensor, taken from
    # new_node_g_sampling; this relies on the two graphs having the same
    # node set.
    new_node_g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
    g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()

# Options shared by all four loaders. Keep num_workers=0: with worker
# processes the sampled subgraph may not be correct. shuffle=False keeps the
# edges in their stored order.
_loader_options = dict(
    batch_size=args.batch_size,
    negative_sampler=neg_sampler,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collator=edge_collator,
)

train_dataloader = TemporalEdgeDataLoader(
    graph_no_new_node, train_seed, sampler,
    g_sampling=g_sampling, **_loader_options)

valid_dataloader = TemporalEdgeDataLoader(
    graph_no_new_node, valid_seed, sampler,
    g_sampling=g_sampling, **_loader_options)

test_dataloader = TemporalEdgeDataLoader(
    graph_no_new_node, test_seed, sampler,
    g_sampling=g_sampling, **_loader_options)

# The new-node loader reads from graph_new_node; in fast mode it also uses
# its own sampler.
test_new_node_dataloader = TemporalEdgeDataLoader(
    graph_new_node, test_new_node_seed,
    new_node_sampler if args.fast_mode else sampler,
    g_sampling=new_node_g_sampling, **_loader_options)

# NOTE(review): this reads the edge feature key 'feats', while train() and
# test_val() read 'feat' -- confirm which key the dataset actually provides.
edge_dim = data.edata['feats'].shape[1]
num_node = data.num_nodes()

model = TGN(edge_feat_dim=edge_dim,
            memory_dim=args.memory_dim,
            temporal_dim=args.temporal_dim,
            embedding_dim=args.embedding_dim,
            num_heads=args.num_heads,
            num_nodes=num_node,
            n_neighbors=args.n_neighbors,
            memory_updater_type=args.memory_updater,
            layers=args.k_hop)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Logging: the context manager closes the log file on every exit path
# (normal completion, KeyboardInterrupt, or any other exception).
with open("logging.txt", 'w') as f:
    if args.fast_mode:
        sampler.reset()
    try:
        for i in range(args.epochs):
            train_loss = train(model, train_dataloader, sampler,
                                criterion, optimizer, args)

            val_mse = test_val(
                model, valid_dataloader, sampler, criterion, args)
            # Checkpoint the memory after validation so the new-node test
            # below starts from the same state as the regular test.
            memory_checkpoint = model.store_memory()
            if args.fast_mode:
                new_node_sampler.sync(sampler)
            test_mse = test_val(
                model, test_dataloader, sampler, criterion, args)
            model.restore_memory(memory_checkpoint)
            if args.fast_mode:
                sample_nn = new_node_sampler
            else:
                sample_nn = sampler
            nn_test_mse = test_val(
                model, test_new_node_dataloader, sample_nn, criterion, args)
            log_content = []
            log_content.append(f"Epoch: {i}; Training Loss: {train_loss} | "+ f"Validation MSE: {val_mse}")
            log_content.append(
                f"Epoch: {i}; Test MSE: {test_mse}")
            log_content.append(f"Epoch: {i}; New Node Test MSE: {nn_test_mse}")

            # writelines() adds no separators, so terminate each entry
            # explicitly to keep one log entry per line.
            f.writelines(line + "\n" for line in log_content)
            # Memory is reset so the next epoch starts from a clean state.
            model.reset_memory()
            if i < args.epochs-1 and args.fast_mode:
                sampler.reset()
            print(log_content[0], log_content[1], log_content[2])
    except KeyboardInterrupt:
        traceback.print_exc()
        error_content = "Training Interreputed!"
        f.write(error_content + "\n")
print("========Training is Done========")

In [None]:
## Creating Sample Graph 


import torch
import dgl
import random

# Countries acting as graph nodes; a country's integer id is its list position.
nodes = ["USA", "China", "Germany", "India", "Brazil"]
node_ids = dict(zip(nodes, range(len(nodes))))

# Node features (e.g., GDP, population, trade openness) are currently disabled:
#node_features = torch.rand(len(nodes), 5)  # 5D feature vector per node

# Every ordered pair of distinct node ids, in (source, destination) order.
directed_pairs = [(src, dst)
                  for src in range(len(nodes))
                  for dst in range(len(nodes))
                  if src != dst]

# Accumulators for the temporal edge list.
edges_src = []
edges_dst = []
timestamps = []
edge_features = []

# One fully connected snapshot per year; the running snapshot counter
# (not the calendar year) is used as the edge timestamp.
time_step = 0
for _year in range(2000, 2005):
    edges_src.extend(src for src, _ in directed_pairs)
    edges_dst.extend(dst for _, dst in directed_pairs)
    timestamps.extend([time_step] * len(directed_pairs))
    # A fresh random 3D feature vector is drawn for each edge.
    edge_features.extend(torch.rand(3) for _ in directed_pairs)
    time_step += 1

# Turn the Python lists into tensors (features are stacked row-wise).
edges_src = torch.tensor(edges_src)
edges_dst = torch.tensor(edges_dst)
timestamps = torch.tensor(timestamps, dtype=torch.float32)
edge_features = torch.stack(edge_features, dim=0)

# Build the DGL graph from the source/destination id tensors, then attach the
# per-edge timestamp and feature tensors as edge data.
g = dgl.graph((edges_src, edges_dst))
for _field, _tensor in (('timestamp', timestamps), ('feat', edge_features)):
    g.edata[_field] = _tensor
# Node features are not attached for now:
#g.ndata['feat'] = node_features  # ✅ Add node features



In [None]:
def train_tgn(g, epochs=50, batch_size=200, n_neighbors=10, args=None):
    """Train a Temporal GNN (TGN) model on a DGL graph and save its weights.

    Edge ids are split positionally into 70% train / 15% validation / 15%
    test. NOTE(review): this is a chronological split only if the edge ids
    of ``g`` are ordered by time -- confirm for the graph being passed in.

    The trained weights are written to ``tgn_model.pth`` in the current
    working directory.

    Args:
        g: DGL graph; ``g.edata['feat']`` must be a 2-D tensor whose second
            dimension is the edge-feature size.
        epochs: Number of passes over the training split.
        batch_size: Batch size handed to every ``TemporalEdgeDataLoader``.
        n_neighbors: Neighbour count given to both ``TGN`` and
            ``TemporalSampler``.
        args: Options object forwarded unchanged to ``train`` and
            ``test_val``. When omitted, the module-level ``args`` is used,
            which is what earlier versions of this function read implicitly.

    Returns:
        The trained ``TGN`` model.

    Raises:
        NameError: If ``args`` is not given and no module-level ``args``
            exists.
    """
    from tgn import TGN
    from dataloading_tgnn import TemporalEdgeDataLoader, TemporalSampler, TemporalEdgeCollator
    import torch.nn as nn
    import torch.optim as optim

    if args is None:
        # Backward compatibility: make the former hidden dependency on a
        # global `args` explicit, and fail early with a clear message.
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "train_tgn() needs `args`: pass it explicitly or define a "
                "module-level `args` before calling"
            ) from None

    # Model Parameters
    num_nodes = g.num_nodes()
    edge_dim = g.edata['feat'].shape[1]

    model = TGN(
        edge_feat_dim=edge_dim,
        memory_dim=100,
        temporal_dim=100,
        embedding_dim=100,
        num_heads=8,
        num_nodes=num_nodes,
        n_neighbors=n_neighbors, # need to review this value
        memory_updater_type='gru', # or lstm iirc
        layers=1
    )

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Create DataLoaders: cumulative split boundaries expressed as edge ids.
    TRAIN_SPLIT = 0.7
    VALID_SPLIT = 0.85
    num_edges = g.num_edges()
    train_div = int(TRAIN_SPLIT * num_edges)
    trainval_div = int(VALID_SPLIT * num_edges)

    sampler = TemporalSampler(k=n_neighbors)
    edge_collator = TemporalEdgeCollator

    train_seed = torch.arange(train_div)
    valid_seed = torch.arange(train_div, trainval_div)
    test_seed = torch.arange(trainval_div, num_edges)

    train_dataloader = TemporalEdgeDataLoader(g, train_seed, sampler, batch_size=batch_size, collator=edge_collator)
    valid_dataloader = TemporalEdgeDataLoader(g, valid_seed, sampler, batch_size=batch_size, collator=edge_collator)
    test_dataloader = TemporalEdgeDataLoader(g, test_seed, sampler, batch_size=batch_size, collator=edge_collator)

    # Training Loop
    from train import train, test_val
    for epoch in range(epochs):
        train_loss = train(model, train_dataloader, sampler, criterion, optimizer, args)
        val_mse = test_val(model, valid_dataloader, sampler, criterion, args)
        test_mse = test_val(model, test_dataloader, sampler, criterion, args)
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val MSE = {val_mse:.4f}, Test MSE = {test_mse:.4f}")
        if epoch < epochs - 1:
            # Start the next epoch from a clean memory, as the script-style
            # epoch loop in this file does; otherwise training would begin
            # with memory already updated by this epoch's validation/test
            # edges. The final epoch's memory is kept for the saved model.
            model.reset_memory()

    # Save the model
    torch.save(model.state_dict(), "tgn_model.pth")

    return model  # Return trained model
