# GPT Implementation in DGL

### Problem 1: Attention score computation does not involve edge features:
Mr GPT says to: <br>
Option A: Incorporate into Attention Mechanism
Concatenate or combine edge features with the node features when computing the attention scores. For example, you could modify the edge_attention method to include the trade volume or political score (after appropriate transformation).


Option B: Use a Separate Aggregation for Trade Volumes
Since your final output is the total outgoing trade volume per country, you can leave the TGAT message passing largely as is and perform an extra aggregation step on the edge features. For example:

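Aggregate the trade volumes of the edges attached to each country. Note that `update_all` reduces messages at the *destination* node, so on `g` itself the call below sums each country's incoming edges; to total the outgoing edges instead, run the same call on the reversed graph, e.g. `dgl.reverse(g, copy_edata=True)`. <br>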
g.update_all(fn.copy_e('trade_vol', 'm'), fn.sum('m', 'total_trade_vol'))
total_trade_vol = g.ndata['total_trade_vol']

In [None]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
import dgl.function as fn
import numpy as np

# Define the Temporal Attention Layer
class TGATLayer(nn.Module):
    """Multi-head attention-weighted message-passing layer for a DGL graph.

    Node features are projected to ``num_heads`` vectors of size
    ``out_feats``; every edge gets one scalar score per head computed from
    its two endpoint projections; the source projections are scaled by
    those scores and averaged per node.

    NOTE(review): ``self.dropout`` is registered but never applied in
    ``forward``, and the edge scores are used raw (no normalisation across
    a node's edges). Edge timestamps are not read by this layer either.
    """

    def __init__(self, in_feats, out_feats, num_heads, dropout=0.1):
        """Build the projection and scoring sub-modules.

        Args:
            in_feats: Size of each incoming node feature vector.
            out_feats: Size of each per-head projected vector.
            num_heads: Number of attention heads.
            dropout: Probability handed to ``nn.Dropout``.
        """
        super().__init__()
        self.num_heads, self.out_feats = num_heads, out_feats
        # Scoring layer: maps a concatenated (source, destination) pair of
        # per-head vectors to a single scalar per head.
        self.attn_fc = nn.Linear(out_feats + out_feats, 1, bias=False)
        # Shared projection producing all heads at once.
        self.fc = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def edge_attention(self, edges):
        """Return ``{'e': score}`` for a batch of edges.

        ``edges.src['z']`` and ``edges.dst['z']`` are joined on the last
        axis (giving ``2 * out_feats`` values per head), scored by
        ``attn_fc`` and passed through a leaky ReLU, so the result keeps
        the leading axes and ends in a size-1 axis.
        """
        endpoint_pair = torch.cat((edges.src['z'], edges.dst['z']), dim=-1)
        raw_score = self.attn_fc(endpoint_pair)
        return {'e': nn.functional.leaky_relu(raw_score)}

    def forward(self, g, h):
        """Run one round of message passing and return flattened features.

        Args:
            g: DGL graph to operate on; scratch fields are written inside
               ``g.local_scope()``.
            h: Node feature tensor whose last axis has size ``in_feats``.

        Returns:
            Tensor of shape ``(N, num_heads * out_feats)``.
        """
        with g.local_scope():
            # Project, then split the last axis into (num_heads, out_feats).
            projected = self.fc(h)
            g.ndata['z'] = projected.view(projected.shape[0], self.num_heads, self.out_feats)

            # One score per edge and head, stored under 'e'.
            g.apply_edges(self.edge_attention)

            # Scale each source projection by its edge score and average the
            # resulting messages per node.
            g.update_all(fn.u_mul_e('z', 'e', 'm'), fn.mean('m', 'h_new'))

            # Merge the head axis back into the feature axis.
            aggregated = g.ndata['h_new']
            return aggregated.reshape(aggregated.shape[0], self.num_heads * self.out_feats)

# Define the TGAT Model
class TGAT(nn.Module):
    """Stack of ``TGATLayer`` blocks followed by a linear output head."""

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads, num_layers):
        """Assemble the layer stack.

        Args:
            in_feats: Input node feature size consumed by the first layer.
            hidden_feats: Per-head feature size of every layer.
            out_feats: Size of the final prediction per node (one neuron for
                regression, or more if predicting per sector).
            num_heads: Number of heads in every layer.
            num_layers: Total layer count; the first layer is always created,
                then ``num_layers - 1`` further layers are appended.
        """
        super().__init__()
        # Each layer emits its heads flattened, so later layers (and the
        # output head) consume hidden_feats * num_heads values per node.
        flat_width = hidden_feats * num_heads
        stack = [TGATLayer(in_feats, hidden_feats, num_heads)]
        stack.extend(
            TGATLayer(flat_width, hidden_feats, num_heads)
            for _ in range(num_layers - 1)
        )
        self.layers = nn.ModuleList(stack)
        self.fc_out = nn.Linear(flat_width, out_feats)

    def forward(self, g, h):
        """Pass ``h`` through every layer on graph ``g`` and apply the head."""
        hidden = h
        for block in self.layers:
            hidden = block(g, hidden)
        return self.fc_out(hidden)

# Create a small temporal graph dataset
def create_tg():
    """Build a toy 5-node directed ring graph with timestamps and features.

    Returns:
        A DGL graph with
        * ``edata['timestamp']``: one float32 timestamp per edge (1..5), and
        * ``ndata['feat']``: random node features with 10 columns.
    """
    src_nodes = [0, 1, 2, 3, 4]
    dst_nodes = [1, 2, 3, 4, 0]
    timestamps = [1, 2, 3, 4, 5]  # Simulated edge times

    g = dgl.graph((src_nodes, dst_nodes))
    g.edata['timestamp'] = torch.tensor(timestamps, dtype=torch.float32)
    # Size the feature matrix by the graph's node count. The edge list length
    # only happens to equal it for this ring and would break on other graphs.
    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)  # Random node features
    return g

# Train the TGAT Model
def train_tgat():
    """Train a 2-layer, 2-head TGAT regressor on the toy graph from ``create_tg``.

    Runs 100 Adam steps (lr=0.01) minimising MSE between the per-node model
    output and five hard-coded target values, printing the loss every 10
    epochs. Returns nothing.
    """
    g = create_tg()  # Make sure to update create_tg() to use your country and edge features
    features = g.ndata['feat']
    # Take the input width from the data itself: a hard-coded width that does
    # not match the feature matrix fails in the first linear projection.
    model = TGAT(in_feats=features.shape[1], hidden_feats=16, out_feats=1, num_heads=2, num_layers=2)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Assuming you have ground truth total trade volumes for each country
    # (one value per node, shaped (N, 1) to match the model output).
    labels = torch.tensor([200, 300, 250, 400, 150], dtype=torch.float32).unsqueeze(1)

    for epoch in range(100):
        preds = model(g, features)
        loss = loss_fn(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    # Optionally, after model training, perform edge aggregation to get trade volumes directly.
    # NOTE(review): update_all reduces at the destination node, so on `g` this sums
    # incoming edges; use the reversed graph if outgoing totals are wanted.
    # g.update_all(fn.copy_e('trade_vol', 'm'), fn.sum('m', 'total_trade_vol'))
    # print("Aggregated Trade Volumes:", g.ndata['total_trade_vol'])

# Run the training process (executes immediately when this cell is run)
train_tgat()



Epoch 0 | Loss: 0.6800
Epoch 10 | Loss: 0.4489
Epoch 20 | Loss: 0.1027
Epoch 30 | Loss: 0.0203
Epoch 40 | Loss: 0.0000
Epoch 50 | Loss: 0.0004
Epoch 60 | Loss: 0.0009
Epoch 70 | Loss: 0.0012
Epoch 80 | Loss: 0.0014
Epoch 90 | Loss: 0.0014


### Masked TGAT Implementation 
Suppose we want to switch off certain nodes at prediction time (e.g., a trade analyst wants to account only for the interactions within a certain region). Training may then have to account for nodes being randomly switched off at inference. This is one way we can do it.

In [None]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
import dgl.function as fn

# Define a Masked TGAT Layer where we apply masks on edges and nodes
class MaskedTGATLayer(nn.Module):
    """Attention-weighted message-passing layer with edge and node masks.

    Works like a multi-head attention layer on a DGL graph, except that each
    edge score is multiplied by ``edata['mask']`` and the layer output can be
    multiplied by an optional ``node_mask``.

    NOTE(review): ``self.dropout`` is registered but never applied in
    ``forward``. Masked edges contribute a zero message yet still count in
    the per-node mean, and ``node_mask`` is applied only to this layer's
    output, not to ``h`` before projection.
    """

    def __init__(self, in_feats, out_feats, num_heads, dropout=0.1):
        """Build the projection and scoring sub-modules.

        Args:
            in_feats: Size of each incoming node feature vector.
            out_feats: Size of each per-head projected vector.
            num_heads: Number of attention heads.
            dropout: Probability handed to ``nn.Dropout``.
        """
        super().__init__()
        self.num_heads, self.out_feats = num_heads, out_feats
        # Scoring layer: one scalar per head from a concatenated endpoint pair.
        self.attn_fc = nn.Linear(out_feats + out_feats, 1, bias=False)
        self.fc = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def edge_attention(self, edges):
        """Return ``{'e': score}`` for a batch of edges, masked if possible.

        The score is a leaky ReLU of ``attn_fc`` applied to the concatenated
        source/destination projections. When ``edges.data`` holds a ``'mask'``
        entry (assumed shape ``(E, num_heads, 1)``), the score is multiplied
        by it.
        """
        endpoint_pair = torch.cat((edges.src['z'], edges.dst['z']), dim=-1)
        score = nn.functional.leaky_relu(self.attn_fc(endpoint_pair))
        if 'mask' in edges.data:
            score = score * edges.data['mask']
        return {'e': score}

    def forward(self, g, h, node_mask=None):
        """Run one masked message-passing round.

        Args:
            g: DGL graph; scratch fields are written inside
               ``g.local_scope()``.
            h: Node feature tensor whose last axis has size ``in_feats``.
            node_mask: Optional tensor broadcastable to the flattened output;
               multiplied into the result when given.

        Returns:
            Tensor of shape ``(N, num_heads * out_feats)``.
        """
        with g.local_scope():
            # Project, then split the last axis into (num_heads, out_feats).
            projected = self.fc(h)
            g.ndata['z'] = projected.view(projected.shape[0], self.num_heads, self.out_feats)

            # Graphs that carry no edge mask get an all-ones one, i.e. every
            # edge is treated as active.
            if 'mask' not in g.edata:
                g.edata['mask'] = torch.ones(
                    g.number_of_edges(), self.num_heads, 1, device=projected.device
                )

            # Masked edge scores, then score-weighted mean of source projections.
            g.apply_edges(self.edge_attention)
            g.update_all(fn.u_mul_e('z', 'e', 'm'), fn.mean('m', 'h_new'))

            # Merge the head axis back into the feature axis.
            pooled = g.ndata['h_new']
            flat = pooled.reshape(pooled.shape[0], self.num_heads * self.out_feats)
            return flat if node_mask is None else flat * node_mask

# Define a simple TGAT model using the masked layer
class MaskedTGAT(nn.Module):
    """Stack of ``MaskedTGATLayer`` blocks followed by a linear output head."""

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads, num_layers):
        """Assemble the layer stack.

        Args:
            in_feats: Input node feature size consumed by the first layer.
            hidden_feats: Per-head feature size of every layer.
            out_feats: Size of the final prediction per node.
            num_heads: Number of heads in every layer.
            num_layers: Total layer count; the first layer is always created,
                then ``num_layers - 1`` further layers are appended.
        """
        super().__init__()
        # Layers emit flattened heads, so every later layer and the output
        # head take hidden_feats * num_heads values per node.
        flat_width = hidden_feats * num_heads
        stack = [MaskedTGATLayer(in_feats, hidden_feats, num_heads)]
        stack.extend(
            MaskedTGATLayer(flat_width, hidden_feats, num_heads)
            for _ in range(num_layers - 1)
        )
        self.layers = nn.ModuleList(stack)
        self.fc_out = nn.Linear(flat_width, out_feats)

    def forward(self, g, h, node_mask=None):
        """Apply every layer with the same ``node_mask``, then the head."""
        hidden = h
        for block in self.layers:
            hidden = block(g, hidden, node_mask)
        return self.fc_out(hidden)

# Create a graph and set up masks for nodes and edges
def create_masked_graph(num_heads=2):
    """Build a toy 5-node ring graph plus example edge and node masks.

    Args:
        num_heads: Number of attention heads the edge mask is expanded for.
            Must match the ``num_heads`` of the model the graph is fed to.
            Defaults to 2.

    Returns:
        A tuple ``(g, features, node_mask)`` where
        * ``g`` is the DGL graph with ``ndata['feat']`` (random, 10 columns)
          and ``edata['mask']`` of shape ``(E, num_heads, 1)``,
        * ``features`` is ``g.ndata['feat']``, and
        * ``node_mask`` has shape ``(N, 1)`` so it broadcasts over a
          flattened ``(N, D)`` layer output.
    """
    # Create a simple graph with 5 nodes and 5 directed edges
    src_nodes = [0, 1, 2, 3, 4]
    dst_nodes = [1, 2, 3, 4, 0]
    g = dgl.graph((src_nodes, dst_nodes))

    # Example node features (e.g., representing country indicators).
    # Sized by the graph's node count rather than the edge list length, which
    # only coincides with it for this particular ring.
    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)

    # Edge mask: 1 keeps the edge active, 0 switches it off.
    # Here the second edge (index 1) is switched off.
    edge_active = torch.tensor([1, 0, 1, 1, 1], dtype=torch.float32)
    # Expand to (E, num_heads, 1) so every head sees the same on/off flag.
    g.edata['mask'] = edge_active.view(-1, 1, 1).repeat(1, num_heads, 1)

    # Node mask: node 2 is switched off (0), all others stay active (1).
    node_mask = torch.tensor([[1], [1], [0], [1], [1]], dtype=torch.float32)

    return g, g.ndata['feat'], node_mask

# Example training loop using the masked TGAT model
def train_masked_tgat():
    """Train a 2-layer, 2-head ``MaskedTGAT`` regressor on the masked toy graph.

    Runs 100 Adam steps (lr=0.01) minimising MSE against five demonstration
    targets, passing the node mask on every forward call and printing the
    loss every 10 epochs. Returns nothing.
    """
    graph, feats, keep_nodes = create_masked_graph()
    net = MaskedTGAT(in_feats=10, hidden_feats=16, out_feats=1, num_heads=2, num_layers=2)
    opt = optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    # Fake regression targets for demonstration (e.g., total trade volume per country)
    targets = torch.tensor([[0.5], [0.7], [0.3], [0.9], [0.4]], dtype=torch.float32)

    for epoch in range(100):
        # Clear old gradients, then forward with features and node mask.
        opt.zero_grad()
        loss = criterion(net(graph, feats, keep_nodes), targets)
        loss.backward()
        opt.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

# Run the masked training loop (executes immediately when this cell is run)
train_masked_tgat()

## PyTorch Geometric Temporal Implementation (using TGN; I think we should deprioritise this for now)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric_temporal.nn.tgn import TGN  # Ensure you have the correct import for TGN
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal
import numpy as np

# Define a TGN-based model that accepts edge features
class TradeTGN(nn.Module):
    """TGN encoder followed by a linear read-out head.

    NOTE(review): ``TGN`` is whatever the surrounding cell imports; confirm
    that the import path exists in the installed package and that the
    constructor really accepts the keyword arguments used below.
    """

    def __init__(self, node_in_feats, edge_in_feats, memory_dim, hidden_dim, out_feats, num_heads):
        """Build the encoder and the output head.

        Args:
            node_in_feats: Passed to ``TGN`` as ``node_in_channels``.
            edge_in_feats: Passed to ``TGN`` as ``edge_in_channels``.
            memory_dim: Passed to ``TGN`` as ``memory_dim``.
            hidden_dim: Passed to ``TGN`` as ``out_channels``; also the input
                width of the linear head.
            out_feats: Output width of the linear head.
            num_heads: Passed to ``TGN`` as ``num_heads``.
        """
        super().__init__()
        encoder_kwargs = dict(
            node_in_channels=node_in_feats,
            edge_in_channels=edge_in_feats,
            memory_dim=memory_dim,
            out_channels=hidden_dim,
            num_heads=num_heads,
        )
        self.tgn = TGN(**encoder_kwargs)
        self.fc = nn.Linear(hidden_dim, out_feats)

    def forward(self, x, edge_index, edge_attr, timestamps):
        """Encode the inputs with ``self.tgn`` and project with ``self.fc``."""
        node_embeddings = self.tgn(x, edge_index, edge_attr, timestamps)
        return self.fc(node_embeddings)

# Create a temporal graph dataset for your trade use case.
def create_trade_temporal_graph():
    """Simulate a small temporal trade dataset.

    For each of 5 snapshots, 20 random directed edges are drawn between 10
    nodes, together with random node features (5 columns) and random edge
    features (4 columns). The per-node label is the sum, over the edges that
    leave that node, of the first three edge-feature columns.

    Returns:
        A ``DynamicGraphTemporalSignal`` whose edge-weight slot holds the
        per-edge timestamps. The timestamps and the edge features are also
        attached as the extra per-snapshot attributes ``edge_time`` and
        ``edge_feat``, so the data the labels were derived from is not lost.
    """
    num_nodes = 10          # For example, 10 countries
    num_snapshots = 5       # e.g., quarterly snapshots
    num_edges = 20          # Bilateral trade relationships per snapshot
    edge_indices = []
    edge_times = []
    node_features = []
    edge_features = []
    labels = []

    for t in range(num_snapshots):
        # Simulate the bilateral trade relationships at time step t
        src_nodes = torch.randint(0, num_nodes, (num_edges,))
        dst_nodes = torch.randint(0, num_nodes, (num_edges,))
        edge_time = torch.full((num_edges,), float(t), dtype=torch.float32)
        edge_index = torch.stack([src_nodes, dst_nodes], dim=0)

        # Node features: [need_diversion, impediment_diversion, GDP_agriculture, GDP_industry, GDP_services]
        node_feat = torch.randn(num_nodes, 5)  # Replace with your actual data

        # Edge features: [trade_agriculture, trade_industry, trade_services, political_score]
        edge_feat = torch.randn(num_edges, 4)  # Replace with your real bilateral trade & political data

        # Label: total outgoing trade volume per country, i.e. the three
        # sector volumes of every edge summed onto that edge's source node.
        trade_volumes = edge_feat[:, :3].sum(dim=1)
        country_trade_volume = torch.zeros(num_nodes).index_add_(0, src_nodes, trade_volumes)

        # The signal class documents its inputs as sequences of NumPy arrays,
        # so convert before storing.
        edge_indices.append(edge_index.numpy())
        edge_times.append(edge_time.numpy())
        node_features.append(node_feat.numpy())
        edge_features.append(edge_feat.numpy())
        labels.append(country_trade_volume.numpy())

    # Extra keyword sequences become additional per-snapshot attributes.
    # NOTE(review): confirm the installed torch_geometric_temporal version
    # supports these extra keyword attributes.
    return DynamicGraphTemporalSignal(
        edge_indices,
        edge_times,
        node_features,
        labels,
        edge_time=edge_times,
        edge_feat=edge_features,
    )

# Train the TGN model
def train_trade_tgn():
    """Train ``TradeTGN`` on the simulated temporal trade dataset.

    Runs 100 epochs; each epoch takes one Adam step (lr=0.01) per snapshot,
    minimising MSE between the per-node prediction and the snapshot target,
    and prints the mean snapshot loss every 10 epochs. Returns nothing.
    """
    dataset = create_trade_temporal_graph()
    model = TradeTGN(node_in_feats=5, edge_in_feats=4, memory_dim=16, hidden_dim=32, out_feats=1, num_heads=2)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    model.train()
    for epoch in range(100):
        loss_sum = 0.0
        snapshot_count = 0
        for snapshot in dataset:
            x, edge_index = snapshot.x, snapshot.edge_index

            # Timestamps: prefer a dedicated `edge_time` attribute. When the
            # dataset only carries them in its edge-weight slot, they surface
            # as `edge_attr`, so fall back to that.
            edge_time = getattr(snapshot, 'edge_time', None)
            if edge_time is None:
                edge_time = snapshot.edge_attr

            # Edge features: use the snapshot's own `edge_feat` when present;
            # otherwise simulate them for demonstration (replace with real data).
            edge_attr = getattr(snapshot, 'edge_feat', None)
            if edge_attr is None:
                edge_attr = torch.randn(edge_index.size(1), 4)

            # Targets are one value per node; add a trailing axis to match
            # the model's (N, 1) output.
            y_true = snapshot.y.unsqueeze(1)

            optimizer.zero_grad()
            y_pred = model(x, edge_index, edge_attr, edge_time)
            loss = loss_fn(y_pred, y_true)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            snapshot_count += 1
        if epoch % 10 == 0:
            # Average over the snapshots actually seen; this does not rely on
            # the dataset defining __len__ and is safe for an empty dataset.
            print(f"Epoch {epoch} | Loss: {loss_sum / max(snapshot_count, 1):.4f}")

# Run training (executes immediately when this cell is run)
train_trade_tgn()


## FROM THE PAPER
Link: https://github.com/dmlc/dgl/tree/0.9.x/examples/pytorch/tgn

I think the DGL TGN is for when we need more advanced memory mechanisms, which we might not need for now.

In [None]:
import copy
import torch.nn as nn
import dgl
from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode #find modules from the github link


class TGN(nn.Module):
    """Temporal graph network wiring together memory, attention and a link predictor.

    The class owns a per-node memory (``MemoryModule``), an operation that
    produces updated memory from a subgraph (``MemoryOperation``), a temporal
    attention embedding module (``TemporalTransformerConv``) and a link
    predictor (``MsgLinkPredictor``). All four, plus ``TimeEncode``, come from
    the external ``modules`` file.
    """

    def __init__(self,
                 edge_feat_dim,
                 memory_dim,
                 temporal_dim,
                 embedding_dim,
                 num_heads,
                 num_nodes,
                 n_neighbors=10,
                 memory_updater_type='gru',
                 layers=1):
        """Store the hyper-parameters and build the sub-modules.

        Args:
            edge_feat_dim: Passed to the memory operation and the attention
                module.
            memory_dim: Passed to the memory module and the attention module.
            temporal_dim: Passed to ``TimeEncode``.
            embedding_dim: Passed to the attention module and the link
                predictor.
            num_heads: Passed to the attention module.
            num_nodes: Passed to the memory module.
            n_neighbors: Stored on the instance; not otherwise used here.
            memory_updater_type: Passed to the memory operation.
            layers: Passed to the attention module as ``layers``.
        """
        super().__init__()
        self.memory_dim = memory_dim
        self.edge_feat_dim = edge_feat_dim
        self.temporal_dim = temporal_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.memory_updater_type = memory_updater_type
        self.num_nodes = num_nodes
        self.layers = layers

        # One time encoder instance is shared by the memory operation and the
        # attention module.
        self.temporal_encoder = TimeEncode(self.temporal_dim)

        self.memory = MemoryModule(self.num_nodes,
                                   self.memory_dim)

        self.memory_ops = MemoryOperation(self.memory_updater_type,
                                          self.memory,
                                          self.edge_feat_dim,
                                          self.temporal_encoder)

        self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim,
                                                      self.memory_dim,
                                                      self.temporal_encoder,
                                                      self.embedding_dim,
                                                      self.num_heads,
                                                      layers=self.layers,
                                                      allow_zero_in_degree=True)

        self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)

    def embed(self, postive_graph, negative_graph, blocks):
        """Compute link predictions for the positive and negative graphs.

        Args:
            postive_graph: Graph of positive edges; its ``ndata[dgl.NID]``
                selects which embeddings are used.
            negative_graph: Graph of negative edges, forwarded to the link
                predictor.
            blocks: Sequence whose first element is the graph the embeddings
                are computed on.

        Returns:
            ``(pred_pos, pred_neg)`` as produced by the link predictor.
        """
        emb_graph = blocks[0]
        # Memory rows of the nodes present in the embedding graph.
        emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
        emb_t = emb_graph.ndata['timestamp']
        embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
        # Map each original node id to its local index in emb_graph.
        emb2pred = dict(
            zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist()))
        # Only the positive graph's ids are looked up: the positive and
        # negative graphs share the same id mapping.
        feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
        feat = embedding[feat_id]
        pred_pos, pred_neg = self.msg_linkpredictor(
            feat, postive_graph, negative_graph)
        return pred_pos, pred_neg

    def update_memory(self, subg):
        """Run the memory operation on ``subg`` and write the results back.

        Both the memory values and the last-update timestamps of the nodes in
        the returned graph are stored, keyed by ``ndata[dgl.NID]``.
        """
        new_g = self.memory_ops(subg)
        self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata['memory'])
        self.memory.set_last_update_t(
            new_g.ndata[dgl.NID], new_g.ndata['timestamp'])

    # Some memory operation wrappers
    def detach_memory(self):
        """Delegate to ``self.memory.detach_memory()``."""
        self.memory.detach_memory()

    def reset_memory(self):
        """Delegate to ``self.memory.reset_memory()``."""
        self.memory.reset_memory()

    def store_memory(self):
        """Return a deep-copied snapshot of the memory state.

        Returns:
            Dict with keys ``'memory'`` and ``'last_t'`` holding copies of
            ``self.memory.memory`` and ``self.memory.last_update_t``.
        """
        return {
            'memory': copy.deepcopy(self.memory.memory),
            'last_t': copy.deepcopy(self.memory.last_update_t),
        }

    def restore_memory(self, memory_checkpoint):
        """Restore the state captured by ``store_memory``.

        Args:
            memory_checkpoint: Dict with ``'memory'`` and ``'last_t'`` entries
                as returned by ``store_memory``.
        """
        # Install copies so later in-place memory updates cannot corrupt the
        # checkpoint, which may be restored again.
        self.memory.memory = copy.deepcopy(memory_checkpoint['memory'])
        # The timestamps live in `last_update_t`, the attribute store_memory
        # reads. Writing to any other name would leave the live timestamps
        # untouched.
        self.memory.last_update_t = copy.deepcopy(memory_checkpoint['last_t'])

Package                   Version
------------------------- --------------
anyio                     4.8.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 3.0.0
async-lru                 2.0.4
attrs                     25.1.0
babel                     2.17.0
beautifulsoup4            4.13.3
bleach                    6.2.0
blinker                   1.9.0
certifi                   2025.1.31
cffi                      1.17.1
charset-normalizer        3.4.1
click                     8.1.8
colorama                  0.4.6
comm                      0.2.2
contourpy                 1.3.1
cycler                    0.12.1
dash                      2.18.2
dash-core-components      2.0.0
dash-html-components      2.0.0
dash-table                5.0.0
debugpy                   1.8.12
decorator                 5.2.1
defusedxml                0.7.1
dgl                       2.2.1
executing                 2.2.0
fastjsonschema  


[notice] A new release of pip is available: 23.2.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip
