<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/train_multihop_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# train_multihop.py
import argparse
import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from multihop_reasoner import MultiHopReasoner, set_seed


@torch.no_grad()
def generate_batch(
    true_facts: torch.Tensor,         # (N, D)
    true_relations: torch.Tensor,     # (R, D, D)
    hops: int,
    batch_size: int,
    noise_std: float = 0.0,
    device: str = "cpu",
):
    """
    Sample a batch of (query, target_fact_index) pairs.

    For every sample in the batch:
      - Choose a target index t uniformly from [0, N)
      - Choose a chain of relation indices r_1..r_K uniformly from [0, R)
      - Query = facts[t] @ R_{r1} @ ... @ R_{rK} (+ Gaussian noise),
        then divided by its L2 norm (plus 1e-8 to avoid division by zero)

    Args:
        true_facts: Fact embeddings, shape (N, D).
        true_relations: One (D, D) matrix per relation, shape (R, D, D).
        hops: Chain length K. With ``hops == 0`` the query is just the
            normalized target fact.
        batch_size: Number of samples B to draw.
        noise_std: Standard deviation of additive Gaussian noise applied
            before normalization; no noise is added unless it is > 0.
        device: Device on which indices and queries are created.

    Returns:
        Tuple ``(q, t_idx)`` where ``q`` has shape (B, D) in the default
        floating dtype and ``t_idx`` has shape (B,) holding the int64 index
        of the target fact for each query.
    """
    num_facts, _ = true_facts.shape
    num_relations = true_relations.size(0)

    # Draw order (targets first, then chains) is kept fixed so that seeded
    # runs keep producing the same indices.
    t_idx = torch.randint(low=0, high=num_facts, size=(batch_size,), device=device)  # (B,)
    r_chain = torch.randint(low=0, high=num_relations, size=(batch_size, hops), device=device)  # (B, K)

    # Tensor indexing requires the indices and the indexed tensor to share a
    # device; .to() is a no-op when they already do.
    facts = true_facts.to(device)
    relations = true_relations.to(device)

    # Apply the whole batch's k-th hop at once instead of looping over samples:
    # gather each sample's relation matrix and do one batched vector-matrix product.
    v = facts[t_idx]  # (B, D)
    for k in range(hops):
        hop_mats = relations[r_chain[:, k]]  # (B, D, D)
        v = torch.bmm(v.unsqueeze(1), hop_mats).squeeze(1)  # (B, D)

    # Queries are returned in the default dtype regardless of the input dtype.
    q = v.to(torch.get_default_dtype())

    if noise_std > 0:
        q = q + noise_std * torch.randn_like(q)

    # Normalize queries (optional, matches cosine-based scoring)
    q = q / (q.norm(dim=-1, keepdim=True) + 1e-8)

    return q, t_idx


def train(
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_facts=64,
    embedding_dim=64,
    num_relations=4,
    hops=2,
    batch_size=128,
    steps=5000,
    lr=3e-4,
    noise_std=0.0,
    seed=42,
    log_every=100,
):
    """
    Train a MultiHopReasoner to recover the target fact of synthetic queries.

    A hidden ground-truth world (unit-norm fact embeddings and one orthogonal
    matrix per relation) is sampled once; every step draws a fresh batch from
    it with ``generate_batch`` and minimizes the negative log-likelihood of
    the target fact index under the model's output probabilities.

    Args:
        device: Device used for the ground-truth tensors and the model.
        num_facts: Number of facts N in the synthetic world.
        embedding_dim: Embedding dimension D.
        num_relations: Number of relation matrices R.
        hops: Relation-chain length used for data generation and passed to the model.
        batch_size: Samples per training step.
        steps: Number of optimizer steps.
        lr: Adam learning rate.
        noise_std: Std of Gaussian noise added to generated queries.
        seed: Seed passed to ``set_seed``.
        log_every: Print metrics every this many steps (step 1 is always
            logged); 0 disables the periodic logging.

    Returns:
        None. Progress and a few sample predictions are printed to stdout.
    """
    set_seed(seed)

    # Hidden ground-truth world to generate supervision
    true_facts = torch.randn(num_facts, embedding_dim, device=device)
    true_facts = true_facts / (true_facts.norm(dim=-1, keepdim=True) + 1e-8)
    true_relations = torch.empty(num_relations, embedding_dim, embedding_dim, device=device)
    # orthogonal_ flattens the trailing dimensions of tensors with more than
    # two dims, so applying it to the whole (R, D, D) stack would NOT make the
    # individual D x D matrices orthogonal. Initialize each relation on its own
    # so that chained products preserve vector norms.
    for rel_idx in range(num_relations):
        nn.init.orthogonal_(true_relations[rel_idx])

    # Model (can choose to train facts or treat them as a separate memory)
    model = MultiHopReasoner(
        num_facts=num_facts,
        embedding_dim=embedding_dim,
        num_relations=num_relations,
        hops=hops,
        temperature=0.2,
        train_facts=True,
        layer_norm=True,
        dropout_p=0.0,
    ).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    best_acc = 0.0

    for step in range(1, steps + 1):
        model.train()
        q, targets = generate_batch(true_facts, true_relations, hops, batch_size, noise_std, device)

        # NOTE(review): the model is assumed to return per-fact probabilities
        # of shape (B, num_facts) plus a dict of intermediates — confirm
        # against multihop_reasoner.
        probs, extras = model(q, return_intermediates=True)
        # The small epsilon keeps log() finite when a probability is exactly zero.
        loss = F.nll_loss((probs + 1e-8).log(), targets)

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        opt.step()

        # Guard the modulo so that log_every == 0 does not raise ZeroDivisionError.
        if step == 1 or (log_every != 0 and step % log_every == 0):
            with torch.no_grad():
                preds = probs.argmax(dim=-1)
                acc = (preds == targets).float().mean().item()
            best_acc = max(best_acc, acc)
            # Inspect latest hop relation weights (mean over batch).
            # Detach first: these come from the training forward pass and
            # .numpy() refuses tensors that still require grad.
            rel_mean = torch.stack(extras["rel_weights"])[-1].mean(dim=0).detach()  # (R,)
            print(
                f"[{step:05d}] loss={loss.item():.4f} acc={acc:.3f} best={best_acc:.3f} "
                f"rel_w={rel_mean.cpu().numpy()}"
            )

    # Quick qualitative check (in eval mode so any train-only layers are disabled)
    model.eval()
    with torch.no_grad():
        q, targets = generate_batch(true_facts, true_relations, hops, batch_size=4, noise_std=noise_std, device=device)
        probs, extras = model(q, return_intermediates=True)
        preds = probs.argmax(dim=-1)
        # topk raises if k exceeds the number of candidates (e.g. num_facts < 5).
        top_k = min(5, probs.size(-1))
        print("\nSample predictions:")
        for i in range(q.size(0)):
            print(f"target={targets[i].item():3d} pred={preds[i].item():3d} top5={probs[i].topk(top_k).indices.tolist()}")


if __name__ == "__main__":
    # Command-line entry point: every flag maps one-to-one onto a keyword
    # argument of train() (argparse turns "--num-facts" into "num_facts").
    cli_options = (
        ("--device", str, "cuda" if torch.cuda.is_available() else "cpu"),
        ("--num-facts", int, 64),
        ("--embedding-dim", int, 64),
        ("--num-relations", int, 4),
        ("--hops", int, 2),
        ("--batch-size", int, 128),
        ("--steps", int, 5000),
        ("--lr", float, 3e-4),
        ("--noise-std", float, 0.0),
        ("--seed", int, 42),
        ("--log-every", int, 100),
    )

    parser = argparse.ArgumentParser(
        description="Train a multi-hop symbolic reasoner on synthetic chained-relations."
    )
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()

    # The parsed namespace holds exactly the keyword arguments train() accepts.
    train(**vars(args))