<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/multihop_reasoner_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# multihop_reasoner.py
import math
import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def set_seed(seed: int = 42):
    """Seed PyTorch's random number generators for reproducible runs.

    Calls ``torch.manual_seed`` and then ``torch.cuda.manual_seed_all`` with
    the same value. Only PyTorch generators are touched.

    Args:
        seed: Integer seed handed to both seeding calls.
    """
    # Same order as before: default generator first, then the CUDA devices.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


class MultiHopReasoner(nn.Module):
    """
    Differentiable multi-hop symbolic reasoner.

    - Facts: learnable embeddings of size (N, D)
    - Relations: learnable bank of matrices (R, D, D)
    - Reasoning: K hops; at hop t, compute a relation mixture from the context,
      transform the facts (keeping an identity path), attend with
      temperature-scaled cosine similarity, read a belief-weighted fact vector,
      and rebuild the context from ``[context, read]`` with a Linear+ReLU mixer.
      NOTE: the context is *replaced* by the mixer output; there is no skip
      connection around the mixer.

    Inputs:
      query: (B, D) query embeddings
      hops: int, number of hops (default: self.K, the value given at construction)
      return_intermediates: bool, if True returns dict with per-hop beliefs and relation weights

    Outputs:
      probs: (B, N) probability over facts (the belief of the last hop)
      extras (optional): {
        "beliefs": List[(B, N)],
        "rel_weights": List[(B, R)],
        "reads": List[(B, D)]
      }
    """

    def __init__(
        self,
        num_facts: int,
        embedding_dim: int,
        num_relations: int = 4,
        hops: int = 2,
        temperature: float = 0.2,
        train_facts: bool = True,
        layer_norm: bool = True,
        dropout_p: float = 0.0,
    ) -> None:
        """Build the fact table, relation bank and per-hop sub-networks.

        Args:
            num_facts: Number of fact embeddings N.
            embedding_dim: Embedding width D shared by facts, queries and context.
            num_relations: Number of relation matrices R in the bank.
            hops: Default number of reasoning hops K used by ``forward``.
            temperature: Softmax temperature applied to the cosine scores
                (clamped to at least 1e-6 when used).
            train_facts: If False, the fact embeddings do not receive gradients.
            layer_norm: If True, LayerNorm is applied to facts and context;
                otherwise those stages are identity.
            dropout_p: Dropout probability on the read vector; 0 disables it.
        """
        super().__init__()
        self.N = num_facts
        self.D = embedding_dim
        self.R = num_relations
        self.K = hops
        self.tau = temperature
        self.use_ln = layer_norm
        self.dropout = nn.Dropout(dropout_p) if dropout_p > 0.0 else nn.Identity()

        # Fact embeddings, scaled by 1/sqrt(D)
        facts = torch.randn(self.N, self.D) / math.sqrt(self.D)
        self.facts = nn.Parameter(facts, requires_grad=train_facts)

        # Relation matrices
        # Xavier init for stable transforms
        self.relations = nn.Parameter(torch.empty(self.R, self.D, self.D))
        nn.init.xavier_uniform_(self.relations)

        # Context-conditioned relation scorer: (B, D) -> (B, R) logits
        hidden = max(64, self.D)
        self.rel_scorer = nn.Sequential(
            nn.Linear(self.D, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.R),
        )

        # Context mixer: maps concat([ctx, read]) of width 2D back to width D
        self.context_mixer = nn.Sequential(
            nn.Linear(2 * self.D, self.D),
            nn.ReLU(),
        )

        # Layer norms for stability (ln_facts is shared by base and per-hop facts)
        self.ln_facts = nn.LayerNorm(self.D) if self.use_ln else nn.Identity()
        self.ln_ctx = nn.LayerNorm(self.D) if self.use_ln else nn.Identity()

    def _normalize(self, x: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
        """Scale ``x`` to (approximately) unit L2 norm along ``dim``.

        ``eps`` is added to the norm so all-zero vectors do not divide by zero.
        """
        return x / (x.norm(dim=dim, keepdim=True) + eps)

    def _apply_relation_mixture(
        self, F_emb: torch.Tensor, rel_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        F_emb: (B, N, D)
        rel_logits: (B, R)
        returns:
          F_trans: (B, N, D) where each batch uses a mixture of relation matrices
          w: (B, R) softmax weights over the relation bank
        """
        w = F.softmax(rel_logits, dim=-1)  # (B, R)
        # Weighted sum of relation matrices per batch
        # R_mix[b] = sum_r w[b,r] * R[r]  => (B, D, D)
        R_mix = torch.einsum("br,rde->bde", w, self.relations)
        # Apply to all facts in batch
        F_trans = torch.einsum("bnd,bde->bne", F_emb, R_mix)
        return F_trans, w

    def forward(
        self,
        query: torch.Tensor,
        hops: Optional[int] = None,
        return_intermediates: bool = False,
    ):
        """Run K reasoning hops and return the final belief over facts.

        Args:
            query: (B, D) query embeddings; D must equal ``embedding_dim``.
            hops: Number of hops for this call; ``None`` uses ``self.K``.
            return_intermediates: If True, also return the per-hop traces.

        Returns:
            ``probs`` of shape (B, N), the softmax belief of the last hop.
            With ``return_intermediates=True``, a tuple ``(probs, extras)``
            where ``extras`` has keys ``"beliefs"``, ``"rel_weights"`` and
            ``"reads"``, each a list with one tensor per hop. The stored reads
            are the post-dropout vectors.

        Raises:
            ValueError: If ``query`` is not (B, D) with D == ``embedding_dim``,
                or if the effective number of hops is less than 1.
        """
        if query.dim() != 2 or query.shape[-1] != self.D:
            raise ValueError(
                f"query must have shape (B, {self.D}), got {tuple(query.shape)}"
            )
        K = self.K if hops is None else hops
        if K < 1:
            # With zero hops no belief is ever produced, so there is nothing to return.
            raise ValueError(f"hops must be a positive integer, got {K}")

        B = query.shape[0]

        # Prepare facts for batch: (B, N, D); expand creates a view, not a copy
        F_emb = self.ln_facts(self.facts)  # (N, D)
        F_emb = F_emb.unsqueeze(0).expand(B, -1, -1)

        # Initialize context and containers
        ctx = self.ln_ctx(query)
        beliefs: List[torch.Tensor] = []
        rel_weights: List[torch.Tensor] = []
        reads: List[torch.Tensor] = []

        # Loop-invariant: guard against a zero/negative temperature once
        tau = max(self.tau, 1e-6)

        # Multi-hop reasoning loop
        for _ in range(K):
            # Relation mixture from context
            rel_logits = self.rel_scorer(ctx)  # (B, R)
            F_trans, w_rel = self._apply_relation_mixture(F_emb, rel_logits)  # (B,N,D), (B,R)

            # Add the untransformed facts back in (keeps identity paths available)
            F_hop = self.ln_facts(F_trans + F_emb)

            # Similarity to current context and belief update
            F_norm = self._normalize(F_hop, dim=-1)
            ctx_norm = self._normalize(ctx, dim=-1)
            scores = torch.einsum("bnd,bd->bn", F_norm, ctx_norm)  # cosine sim
            scores = scores / tau
            belief = F.softmax(scores, dim=-1)  # (B, N)

            # Read vector as belief-weighted sum of hop facts
            read = torch.einsum("bn,bnd->bd", belief, F_hop)  # (B, D)
            read = self.dropout(read)

            # Rebuild context from [ctx, read]; the mixer output replaces ctx
            ctx = self.context_mixer(torch.cat([ctx, read], dim=-1))
            ctx = self.ln_ctx(ctx)

            beliefs.append(belief)
            rel_weights.append(w_rel)
            reads.append(read)

        # Final probabilities over facts are the last hop's belief
        probs = beliefs[-1]
        if return_intermediates:
            return probs, {"beliefs": beliefs, "rel_weights": rel_weights, "reads": reads}
        return probs