## Simple implementation from GPT

In [1]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
import dgl.function as fn
import numpy as np

# Define the Temporal Attention Layer
class TGATLayer(nn.Module):
    """Multi-head attention-style message-passing layer over a DGL graph.

    Node features are projected to ``num_heads`` vectors of size
    ``out_feats``. Every edge gets one scalar score per head, computed from
    the concatenated source/destination projections, and each node's output
    is the mean of its incoming score-weighted source projections.

    NOTE(review): the layer never reads edge timestamps, and ``self.dropout``
    is constructed but not applied anywhere in ``forward`` -- confirm whether
    either omission is intentional.

    Args:
        in_feats: Size of the incoming per-node feature vector.
        out_feats: Per-head output size.
        num_heads: Number of heads.
        dropout: Probability handed to ``nn.Dropout``.
    """

    def __init__(self, in_feats, out_feats, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.out_feats = out_feats
        # Maps a concatenated (src, dst) pair of size 2 * out_feats to a
        # single score; the same weights are shared by every head.
        self.attn_fc = nn.Linear(2 * out_feats, 1, bias=False)
        # One projection produces the features of all heads at once.
        self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
        self.dropout = nn.Dropout(dropout)

    def edge_attention(self, edges):
        """Return the per-edge, per-head score under key ``'e'``.

        ``edges.src['z']`` and ``edges.dst['z']`` are joined on their last
        axis, passed through ``attn_fc`` (last axis becomes size 1) and then
        through a leaky ReLU.
        """
        src_z = edges.src['z']
        dst_z = edges.dst['z']
        pair = torch.cat((src_z, dst_z), dim=-1)
        score = nn.functional.leaky_relu(self.attn_fc(pair))
        return {'e': score}

    def forward(self, g, h):
        """Run one round of message passing and return flattened features.

        Args:
            g: DGL graph; all temporary fields are written inside a local
                scope so the caller's graph is left unchanged.
            h: Node features with ``in_feats`` columns.

        Returns:
            Tensor with ``num_heads * out_feats`` columns, one row per node.
        """
        with g.local_scope():
            projected = self.fc(h)
            node_count = projected.shape[0]
            # Split the projection into (N, num_heads, out_feats).
            g.ndata['z'] = projected.view(node_count, self.num_heads, self.out_feats)

            # Fill g.edata['e'] with the scores from edge_attention.
            g.apply_edges(self.edge_attention)

            # Message = source 'z' times edge 'e'; reduce = mean over the
            # messages arriving at each node.
            g.update_all(fn.u_mul_e('z', 'e', 'm'), fn.mean('m', 'h_new'))

            aggregated = g.ndata['h_new']
            # Merge the head axis back into the feature axis.
            return aggregated.reshape(aggregated.shape[0], self.num_heads * self.out_feats)

# Define the TGAT Model
class TGAT(nn.Module):
    """Stack of ``TGATLayer`` modules followed by a linear output head.

    Args:
        in_feats: Size of the raw per-node feature vector.
        hidden_feats: Per-head output size of every layer.
        out_feats: Size of the final per-node output.
        num_heads: Number of heads in every layer.
        num_layers: Requested layer count; the first layer is always
            created, so values below 1 still yield one layer.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads, num_layers):
        super().__init__()
        # Each layer returns its heads flattened together, so every layer
        # after the first consumes hidden_feats * num_heads features.
        stacked_width = hidden_feats * num_heads
        input_widths = [in_feats] + [stacked_width] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [TGATLayer(width, hidden_feats, num_heads) for width in input_widths]
        )
        self.fc_out = nn.Linear(stacked_width, out_feats)

    def forward(self, g, h):
        """Pass ``h`` through every layer on graph ``g``, then the output head."""
        hidden = h
        for block in self.layers:
            hidden = block(g, hidden)
        return self.fc_out(hidden)

# Create a small temporal graph dataset
def create_tg():
    """Build a tiny directed 5-cycle with edge timestamps and random features.

    Returns:
        A DGL graph with edges 0->1->2->3->4->0, a float32 ``'timestamp'``
        edge field holding 1..5, and a ``'feat'`` node field of 10 random
        normal values per node.
    """
    src_nodes = [0, 1, 2, 3, 4]
    dst_nodes = [1, 2, 3, 4, 0]
    timestamps = [1, 2, 3, 4, 5]  # Simulated time edges

    g = dgl.graph((src_nodes, dst_nodes))
    g.edata['timestamp'] = torch.tensor(timestamps, dtype=torch.float32)
    # Size the feature matrix by the graph's node count. The edge-list length
    # only coincides with it for this particular cycle and would break as
    # soon as the edge list is changed.
    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)  # Random node features
    return g

# Train the TGAT Model
def train_tgat():
    """Fit a 2-layer, 2-head TGAT on the toy graph for 100 epochs.

    Uses Adam (lr 0.01) with cross-entropy against fixed demonstration
    labels, and prints the loss on every 10th epoch. Returns nothing.
    """
    graph = create_tg()
    net = TGAT(in_feats=10, hidden_feats=16, out_feats=2, num_heads=2, num_layers=2)
    opt = optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Made-up class ids, one per node, purely for demonstration.
    targets = torch.tensor([0, 1, 0, 1, 0], dtype=torch.long)

    for epoch in range(100):
        # Clear stale gradients, then do one forward/backward/step cycle.
        opt.zero_grad()
        predictions = net(graph, graph.ndata['feat'])
        loss = criterion(predictions, targets)
        loss.backward()
        opt.step()

        should_log = epoch % 10 == 0
        if should_log:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

# Run the training process. This executes as soon as the cell/module runs and
# prints the loss every 10 epochs (see the captured output below).
train_tgat()


Epoch 0 | Loss: 0.6800
Epoch 10 | Loss: 0.4489
Epoch 20 | Loss: 0.1027
Epoch 30 | Loss: 0.0203
Epoch 40 | Loss: 0.0000
Epoch 50 | Loss: 0.0004
Epoch 60 | Loss: 0.0009
Epoch 70 | Loss: 0.0012
Epoch 80 | Loss: 0.0014
Epoch 90 | Loss: 0.0014


## FROM THE PAPER

In [None]:
import copy
import torch.nn as nn
import dgl
from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode


class TGN(nn.Module):
    """Temporal Graph Network assembled from the project-local ``modules``.

    The class wires together a time encoder, a per-node memory store, a
    memory-update operation, an attention-based embedding layer, and a link
    predictor. The behavior of those components is defined in ``modules``
    and is not visible here; the descriptions below cover only how this
    class connects them.

    Args:
        edge_feat_dim: Forwarded to ``MemoryOperation`` and
            ``TemporalTransformerConv``.
        memory_dim: Forwarded to ``MemoryModule`` and
            ``TemporalTransformerConv``.
        temporal_dim: Forwarded to ``TimeEncode``.
        embedding_dim: Forwarded to ``TemporalTransformerConv`` and
            ``MsgLinkPredictor``.
        num_heads: Forwarded to ``TemporalTransformerConv``.
        num_nodes: Forwarded to ``MemoryModule``.
        n_neighbors: Stored on the instance; not used by any method shown
            in this class.
        memory_updater_type: Forwarded to ``MemoryOperation``
            (default ``'gru'``).
        layers: Forwarded to ``TemporalTransformerConv`` as ``layers``.
    """

    def __init__(self,
                 edge_feat_dim,
                 memory_dim,
                 temporal_dim,
                 embedding_dim,
                 num_heads,
                 num_nodes,
                 n_neighbors=10,
                 memory_updater_type='gru',
                 layers=1):
        super(TGN, self).__init__()
        self.memory_dim = memory_dim
        self.edge_feat_dim = edge_feat_dim
        self.temporal_dim = temporal_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.memory_updater_type = memory_updater_type
        self.num_nodes = num_nodes
        self.layers = layers

        # The same encoder instance is shared by the memory update and the
        # embedding layer.
        self.temporal_encoder = TimeEncode(self.temporal_dim)

        self.memory = MemoryModule(self.num_nodes,
                                   self.memory_dim)

        self.memory_ops = MemoryOperation(self.memory_updater_type,
                                          self.memory,
                                          self.edge_feat_dim,
                                          self.temporal_encoder)

        self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim,
                                                      self.memory_dim,
                                                      self.temporal_encoder,
                                                      self.embedding_dim,
                                                      self.num_heads,
                                                      layers=self.layers,
                                                      allow_zero_in_degree=True)

        self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)

    def embed(self, postive_graph, negative_graph, blocks):
        """Embed the nodes of ``blocks[0]`` and score both graphs.

        Args:
            postive_graph: Graph whose ``ndata[dgl.NID]`` selects the rows
                of the embedding handed to the predictor. (The misspelled
                name is kept so keyword callers keep working.)
            negative_graph: Passed to the link predictor unchanged.
            blocks: Sequence whose first element is the graph to embed; it
                must carry ``ndata[dgl.NID]`` and ``ndata['timestamp']``.

        Returns:
            The ``(pred_pos, pred_neg)`` pair returned by the link predictor.
        """
        emb_graph = blocks[0]
        # Look up the stored memory rows for the nodes present in the block.
        emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
        emb_t = emb_graph.ndata['timestamp']
        embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
        # Map each original node id to its local index inside emb_graph.
        emb2pred = dict(
            zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist()))
        # Only the positive graph's ids are translated, since the positive
        # and negative graphs share the same id mapping.
        feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
        feat = embedding[feat_id]
        pred_pos, pred_neg = self.msg_linkpredictor(
            feat, postive_graph, negative_graph)
        return pred_pos, pred_neg

    def update_memory(self, subg):
        """Run the memory operation on ``subg`` and write the results back.

        The graph returned by ``memory_ops`` supplies the node ids
        (``ndata[dgl.NID]``), the new memory (``ndata['memory']``) and the
        new update times (``ndata['timestamp']``).
        """
        new_g = self.memory_ops(subg)
        self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata['memory'])
        self.memory.set_last_update_t(
            new_g.ndata[dgl.NID], new_g.ndata['timestamp'])

    # Some memory operation wrappers
    def detach_memory(self):
        """Delegate to ``self.memory.detach_memory()``."""
        self.memory.detach_memory()

    def reset_memory(self):
        """Delegate to ``self.memory.reset_memory()``."""
        self.memory.reset_memory()

    def store_memory(self):
        """Return deep copies of the memory state.

        Returns:
            Dict with keys ``'memory'`` (copy of ``self.memory.memory``) and
            ``'last_t'`` (copy of ``self.memory.last_update_t``).
        """
        memory_checkpoint = {}
        memory_checkpoint['memory'] = copy.deepcopy(self.memory.memory)
        memory_checkpoint['last_t'] = copy.deepcopy(self.memory.last_update_t)
        return memory_checkpoint

    def restore_memory(self, memory_checkpoint):
        """Restore the state produced by ``store_memory``.

        The checkpoint objects are assigned directly (not copied), so the
        memory module shares them with the checkpoint dict afterwards.
        """
        self.memory.memory = memory_checkpoint['memory']
        # Write back to the attribute store_memory() read from. Assigning
        # to `last_update_time` would only create an unused attribute and
        # leave the real timestamps unrestored.
        self.memory.last_update_t = memory_checkpoint['last_t']

Package                   Version
------------------------- --------------
anyio                     4.8.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 3.0.0
async-lru                 2.0.4
attrs                     25.1.0
babel                     2.17.0
beautifulsoup4            4.13.3
bleach                    6.2.0
blinker                   1.9.0
certifi                   2025.1.31
cffi                      1.17.1
charset-normalizer        3.4.1
click                     8.1.8
colorama                  0.4.6
comm                      0.2.2
contourpy                 1.3.1
cycler                    0.12.1
dash                      2.18.2
dash-core-components      2.0.0
dash-html-components      2.0.0
dash-table                5.0.0
debugpy                   1.8.12
decorator                 5.2.1
defusedxml                0.7.1
dgl                       2.2.1
executing                 2.2.0
fastjsonschema  


[notice] A new release of pip is available: 23.2.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip
