# Modular TRM Training for Community Detection

This notebook demonstrates a modular approach to training and evaluating a TRM neural network for community detection on synthetic graphs using PyTorch. The code is organized for easy adaptation to other graph-based problems.

## 1. Import Libraries and Set Up Environment
Import all required libraries, set random seeds, and configure device (CPU/GPU).

In [75]:
import os, math, random
from dataclasses import dataclass
from typing import Tuple, Any, List
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sys
sys.path.append(os.path.join("..", "src"))
from exploretinyrm.trm import TRM, TRMConfig
def set_seed(seed: int = 123):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

import networkx as nx

Device: cuda


## 2. AMP and EMA Utilities
Define automatic mixed precision (AMP) and exponential moving average (EMA) utility functions and classes.

In [76]:
try:
    from torch.amp import autocast as _autocast, GradScaler as _GradScaler
    _USE_TORCH_AMP = True
except ImportError:
    from torch.cuda.amp import autocast as _autocast, GradScaler as _GradScaler
    _USE_TORCH_AMP = False

def make_grad_scaler(is_cuda: bool):
    if _USE_TORCH_AMP:
        try:
            return _GradScaler("cuda", enabled=is_cuda)
        except TypeError:
            return _GradScaler(enabled=is_cuda)
    else:
        return _GradScaler(enabled=is_cuda)

def amp_autocast(is_cuda: bool, use_amp: bool):
    if _USE_TORCH_AMP:
        try:
            return _autocast(device_type="cuda", enabled=(is_cuda and use_amp))
        except TypeError:
            return _autocast(enabled=(is_cuda and use_amp))
    else:
        return _autocast(enabled=(is_cuda and use_amp))

class EMA:
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model: torch.nn.Module) -> None:
        d = self.decay
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                self.shadow[name].mul_(d).add_(param.detach(), alpha=1.0 - d)

    def copy_to(self, model: torch.nn.Module) -> None:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in self.shadow:
                    param.copy_(self.shadow[name])

from contextlib import contextmanager

@contextmanager
def use_ema_weights(model: torch.nn.Module, ema: EMA):
    backup = {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    ema.copy_to(model)
    try:
        yield
    finally:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in backup:
                    param.copy_(backup[name])

## 3. Community Detection Dataset Preparation
Synthetic dataset for community detection

In [77]:
class GameDataset(Dataset):
    """Base class for game datasets. Subclass and implement _generate_sample."""
    def __init__(self, n_samples: int, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        self.samples = [self._generate_sample() for _ in range(n_samples)]
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]
    def _generate_sample(self): raise NotImplementedError()

class CommunityDetectionDataset(GameDataset):
    """Synthetic SBM community detection puzzles."""
    def __init__(self, n_samples: int, n_nodes: int = 30, n_communities: int = 3, p_in: float = 0.6, p_out: float = 0.05, seed: int = 0):
        self.n_nodes = n_nodes
        self.n_communities = n_communities
        self.p_in = p_in
        self.p_out = p_out
        self.rng = np.random.default_rng(seed)
        self.samples = [self._generate_sample() for _ in range(n_samples)]
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]
    def _generate_sample(self):
        sizes = [self.n_nodes // self.n_communities] * self.n_communities
        for i in range(self.n_nodes % self.n_communities):
            sizes[i] += 1
        probs = np.full((self.n_communities, self.n_communities), self.p_out)
        np.fill_diagonal(probs, self.p_in)
        G = nx.stochastic_block_model(sizes, probs, seed=int(self.rng.integers(1e9)))
        labels = []
        for idx, size in enumerate(sizes):
            labels.extend([idx] * size)
        labels = np.array(labels)
        adj = nx.to_numpy_array(G)
        x_tokens = torch.from_numpy(adj.astype(np.int64))  # [n_nodes, n_nodes]
        y_tokens = torch.from_numpy(labels.astype(np.int64))  # [n_nodes]
        return x_tokens, y_tokens

def get_gc_loaders(n_train=512, n_val=128, batch_size=16, n_nodes=30, n_communities=3, p_in=0.6, p_out=0.05, seed=42):
    ds_tr = CommunityDetectionDataset(n_samples=n_train, n_nodes=n_nodes, n_communities=n_communities, p_in=p_in, p_out=p_out, seed=seed)
    ds_va = CommunityDetectionDataset(n_samples=n_val, n_nodes=n_nodes, n_communities=n_communities, p_in=p_in, p_out=p_out, seed=seed+1)
    return (
        DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True),
        DataLoader(ds_va, batch_size=batch_size, shuffle=False, pin_memory=True)
    )

train_loader, val_loader = get_gc_loaders(
    n_train=2048,
    n_val=512,
    batch_size=16,
    n_nodes=4, #N_NODES
    n_communities=3,
    p_in=0.6,
    p_out=0.05,
    seed=123
)

In [78]:
# show some examples of the dataset
for i in range(2):
    x, y = train_loader.dataset[i]
    print(f"Example {i}:")
    print(x.shape, y.shape)

Example 0:
torch.Size([4, 4]) torch.Size([4])
Example 1:
torch.Size([4, 4]) torch.Size([4])


In [79]:
N_NODES = 4
N_COMMUNITIES = 3
INPUT_TOKENS = 2  # adjacency values: 0 or 1 (float)
OUTPUT_TOKENS = N_COMMUNITIES
SEQ_LEN = N_NODES

D_MODEL = 128
N_SUP = 16
N = 6
T = 3
USE_ATT = False

In [80]:
# show some examples of the dataset
for i in range(2):
    x, y = train_loader.dataset[i]
    print(f"Example {i}:")
    print("Input Adjacency Matrix:")
    adj_matrix = x.numpy().reshape(N_NODES, N_NODES)
    print(adj_matrix)
    print("Node Colors:")
    print(y.numpy())
    print()

Example 0:
Input Adjacency Matrix:
[[0 1 0 0]
 [1 0 0 0]
 [0 0 0 1]
 [0 0 1 0]]
Node Colors:
[0 0 1 2]

Example 1:
Input Adjacency Matrix:
[[0 1 0 0]
 [1 0 0 0]
 [0 0 0 0]
 [0 0 0 0]]
Node Colors:
[0 0 1 2]



## 4. Model Configuration for Community Detection
Define the input/output vocabularies, sequence encoding, and instantiate the TRM model for node classification.

In [81]:

cfg = TRMConfig(
    input_vocab_size=INPUT_TOKENS,
    output_vocab_size=OUTPUT_TOKENS,
    seq_len=SEQ_LEN,
    d_model=D_MODEL,
    n_layers=2,
    use_attention=USE_ATT,
    n_heads=8,
    dropout=0.0,
    mlp_ratio=4.0,
    token_mlp_ratio=2.0,
    n=N,
    T=T,
    k_last_ops=None,
    stabilize_input_sums=True
)

model = TRM(cfg).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, weight_decay=0.0, betas=(0.9, 0.95)
)

scaler = make_grad_scaler(device.type == "cuda")
ema = EMA(model, decay=0.999)

Params (M): 0.39488


## 5. Training Loop
Train the TRM model to predict community labels from the adjacency matrix, using permutation-invariant loss and accuracy metrics.

In [None]:
# --- Training Loop ---
def token_ce_loss(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    B, L, V = logits.shape
    return F.cross_entropy(logits.reshape(B*L, V), y_true.reshape(B*L))

def train_one_epoch(model, loader, optimizer, scaler, epoch, use_amp=True, ema=None):
    model.train()
    total_ce, total_acc, total_steps = 0.0, 0.0, 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device, non_blocking=True).long()
        y_true   = y_true.to(device,   non_blocking=True)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            optimizer.zero_grad(set_to_none=True)
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
            loss_ce = token_ce_loss(logits.float(), y_true)
            with torch.no_grad():
                preds = logits.argmax(dim=-1)
                acc = (preds == y_true).float().mean().item()
            loss = loss_ce
            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                if ema is not None:
                    ema.update(model)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if ema is not None:
                    ema.update(model)
            total_ce   += loss_ce.detach().item()
            total_acc  += acc
            total_steps += 1
    print(f"Epoch {epoch:02d} | CE {total_ce/max(1,total_steps):.4f} | Accuracy {total_acc/max(1,total_steps):.3f}")

EPOCHS = 2
node_acc_history = []
for epoch in range(1, EPOCHS+1):
    train_one_epoch(model, train_loader, optimizer, scaler, epoch, use_amp=False)
    # Optionally add validation here

RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 1

## 5. Evaluation and Visualization
Evaluate the trained TRM model on synthetic graphs and visualize the detected communities.

In [None]:
# --- Evaluation and Visualization ---
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    acc_list = []
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device).long()
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
        preds = logits.argmax(dim=-1)
        acc = (preds == y_true).float().mean().item()
        acc_list.append(acc)
    avg_acc = np.mean(acc_list)
    print(f"Validation | Accuracy {avg_acc:.3f}")
    return avg_acc

avg_acc = evaluate(model, val_loader)
print("Node accuracy history:", node_acc_history)

# Visualization
import matplotlib.pyplot as plt
G, labels = CommunityDetectionDataset(n_samples=1, n_nodes=N_NODES, n_communities=N_COMMUNITIES, seed=999)[0]
adj = G.numpy().reshape(N_NODES, N_NODES)
G_nx = nx.from_numpy_array(adj)
pos = nx.spring_layout(G_nx, seed=42)
preds = model(torch.from_numpy(adj.flatten()).unsqueeze(0).to(device).long()).argmax(dim=-1).cpu().numpy()[0]
nx.draw_networkx_nodes(G_nx, pos, node_color=preds, cmap=plt.cm.Set1, node_size=100)
nx.draw_networkx_edges(G_nx, pos, alpha=0.5)
plt.title("TRM Detected Communities")
plt.axis('off')
plt.show()