# In [None]:
# mr_rf_tensor_fusion.py
# PyTorch implementation of a simple MRRF + Tensor Fusion prototype.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import math

# -------------------------
# Helper: projection utils
# -------------------------
def batch_project_onto(u, v, eps=1e-8):
    """
    Project vector(s) u onto vector(s) v along the last dimension (batched).

    Computes proj_v(u) = (u·v / (v·v + eps)) * v independently for every
    leading index, so any number of leading batch dimensions is supported
    (including plain 1-D vectors).

    Args:
        u: tensor of shape (..., D) -- typically (B, D).
        v: tensor broadcastable against u with shape (..., D) -- can be the
           same vector repeated for each batch element.
        eps: small constant added to v·v for numeric stability; a zero v
           then yields a zero projection instead of NaN. Defaults to 1e-8.

    Returns:
        Tensor with the broadcast shape of u and v: the component of u along v.
    """
    # Reduce over the feature (last) axis; keepdim keeps the coefficient
    # broadcastable back against v.
    dot = torch.sum(u * v, dim=-1, keepdim=True)          # (..., 1)
    norm2 = torch.sum(v * v, dim=-1, keepdim=True) + eps  # (..., 1)
    return (dot / norm2) * v                              # (..., D)

# -------------------------
# Modality encoders
# -------------------------
class ModalityEncoder(nn.Module):
    """Per-modality two-layer MLP encoder.

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        # Stored as a Sequential named `net`, so parameters are net.0.* / net.2.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode x of shape (B, input_dim) into a (B, hidden_dim) embedding."""
        hidden = self.net(x)
        return hidden

# -------------------------
# MRRF + Tensor Fusion Model
# -------------------------
class MRRF_TensorFusion(nn.Module):
    """Multimodal classifier: redundancy removal followed by pairwise tensor fusion.

    Per forward pass:
      1. encode each modality with its own ModalityEncoder;
      2. estimate a per-sample "shared" vector from the mean of the encodings,
         refined by a small MLP;
      3. subtract from each encoding its projection onto the shared vector,
         leaving a residual (approximately) orthogonal to it;
      4. project each residual to ``fused_hidden`` dims;
      5. fuse consecutive modality pairs with outer products, flatten, concat;
      6. classify the fused vector with a small MLP.
    """

    def __init__(self, input_dims, modality_hidden=128, fused_hidden=128, out_dim=2,
                 classifier_hidden=512, dropout=0.3):
        """
        input_dims: list of int, input dimension for each modality (e.g. [300, 50, 128]);
            must contain at least one entry.
        modality_hidden: encoded feature dim for each modality after its encoder.
        fused_hidden: dim F each residual is projected to before fusion; each
            pairwise outer product contributes F*F features.
        out_dim: number of classes (or 1 for regression).
        classifier_hidden: width of the classifier's hidden layer (default 512).
        dropout: dropout probability inside the classifier (default 0.3).

        Raises:
            ValueError: if input_dims is empty.
        """
        super().__init__()
        if len(input_dims) == 0:
            raise ValueError("input_dims must list at least one modality")
        self.num_modalities = len(input_dims)
        self.encoders = nn.ModuleList([
            ModalityEncoder(d, modality_hidden) for d in input_dims
        ])
        # Produces the per-sample 'shared' (redundant) vector from the mean encoding.
        self.shared_mlp = nn.Sequential(
            nn.Linear(modality_hidden, modality_hidden),
            nn.ReLU(),
            nn.Linear(modality_hidden, modality_hidden)
        )
        # Compress each residual to fused_hidden dims before tensor fusion.
        self.residual_projectors = nn.ModuleList([
            nn.Linear(modality_hidden, fused_hidden) for _ in input_dims
        ])
        self.fused_hidden = fused_hidden
        # A full N-way outer product grows as F**N, so only consecutive pairs are
        # fused: (r1 ⊗ r2), (r2 ⊗ r3), ...  A single modality is fused with
        # itself, hence at least one pair.
        pairwise_count = max(1, self.num_modalities - 1)
        fused_vector_len = pairwise_count * (fused_hidden * fused_hidden)
        self.classifier = nn.Sequential(
            nn.Linear(fused_vector_len, classifier_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(classifier_hidden, out_dim)
        )

    def _pairwise_outer_fusion(self, rp):
        """Concatenate flattened outer products of consecutive residual pairs.

        rp: list of (B, F) tensors. Returns (B, pairs * F * F); for a single
        modality the lone residual is paired with itself so the length matches
        the classifier input size.
        """
        pairs = list(zip(rp[:-1], rp[1:])) or [(rp[0], rp[0])]
        flat = [
            # (B, F, 1) @ (B, 1, F) -> (B, F, F) -> (B, F*F)
            torch.matmul(a.unsqueeze(2), b.unsqueeze(1)).flatten(start_dim=1)
            for a, b in pairs
        ]
        return torch.cat(flat, dim=1)

    def forward(self, *modal_inputs):
        """
        Args:
            *modal_inputs: one tensor per modality, passed positionally in the
                same order as ``input_dims``; tensor i has shape (B, input_dims[i]).

        Returns:
            (out, aux): ``out`` has shape (B, out_dim); ``aux`` is a dict with the
            intermediates "encoded", "shared", "residuals", "rp", "fused_vec".

        Raises:
            ValueError: if the number of inputs differs from the number of modalities.
        """
        if len(modal_inputs) != self.num_modalities:
            raise ValueError(
                f"expected {self.num_modalities} modality tensors, got {len(modal_inputs)}"
            )
        # 1) encode modalities -> list of (B, H)
        encoded = [enc(x) for enc, x in zip(self.encoders, modal_inputs)]
        # 2) shared/redundant component: mean of encodings refined by an MLP
        mean_enc = torch.stack(encoded, dim=0).mean(dim=0)  # (B, H)
        shared = self.shared_mlp(mean_enc)                  # (B, H)
        # 3) residuals: remove each encoding's projection onto the shared vector
        residuals = [e - batch_project_onto(e, shared) for e in encoded]
        # 4) compress residuals -> list of (B, F)
        rp = [projector(res) for projector, res in zip(self.residual_projectors, residuals)]
        # 5) pairwise tensor fusion -> (B, pairwise_count * F * F)
        fused_vec = self._pairwise_outer_fusion(rp)
        # 6) classification -> (B, out_dim)
        out = self.classifier(fused_vec)
        return out, {
            "encoded": encoded,
            "shared": shared,
            "residuals": residuals,
            "rp": rp,
            "fused_vec": fused_vec
        }

# -------------------------
# Toy training loop (synthetic data)
# -------------------------
def synthetic_data(num_samples=2000, dims=(300, 50, 128), num_classes=2, seed=0):
    """
    Generate a toy multimodal classification dataset.

    Every modality is i.i.d. standard-normal noise; the label depends only on
    the first modality via score = mean(Xs[0], dim=1) + 0.1 * N(0, 1).

    Args:
        num_samples: number of rows per modality.
        dims: feature dimension of each modality; must be non-empty.
        num_classes: 2 -> label is (score > 0); k > 2 -> score is split into
            k roughly equal-frequency bins at its empirical quantiles.
        seed: passed to torch.manual_seed. NOTE: this reseeds the *global*
            torch RNG, which also affects later random ops (e.g. DataLoader
            shuffling) in the caller.

    Returns:
        (Xs, y): list of (num_samples, d) float tensors for d in dims, and
        int64 labels of shape (num_samples,).

    Raises:
        ValueError: if num_classes < 2 or dims is empty.
    """
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2, got {num_classes}")
    if len(dims) == 0:
        raise ValueError("dims must contain at least one modality")
    torch.manual_seed(seed)
    Xs = [torch.randn(num_samples, d) for d in dims]
    score = Xs[0].mean(dim=1) + 0.1 * torch.randn(num_samples)
    if num_classes == 2:
        y = (score > 0).long()
    else:
        # num_classes - 1 interior quantiles as boundaries -> labels 0..num_classes-1.
        qs = torch.linspace(0, 1, num_classes + 1)[1:-1]
        y = torch.bucketize(score, torch.quantile(score, qs))
    return Xs, y

def train_example(epochs=8, batch_size=64, lr=1e-3):
    """
    Train MRRF_TensorFusion on synthetic data and print per-epoch metrics,
    then print the logits and intermediate shapes for 5 samples.

    Batches are unpacked generically (one tensor per modality, labels last),
    so changing ``input_dims`` needs no other edits.

    Args:
        epochs: number of passes over the dataset (default 8).
        batch_size: DataLoader mini-batch size (default 64).
        lr: Adam learning rate (default 1e-3).
    """
    input_dims = [300, 50, 128]
    model = MRRF_TensorFusion(input_dims, modality_hidden=128, fused_hidden=64, out_dim=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    Xs, y = synthetic_data(num_samples=2000, dims=input_dims, num_classes=2)
    # One dataset column per modality, labels last.
    loader = DataLoader(TensorDataset(*Xs, y), batch_size=batch_size, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for *inputs, lbl in loader:
            inputs = [x.to(device) for x in inputs]
            lbl = lbl.to(device)
            n = lbl.size(0)
            opt.zero_grad()
            logits, _ = model(*inputs)
            loss = loss_fn(logits, lbl)
            loss.backward()
            opt.step()
            # loss is a batch mean; weight by batch size for an exact epoch mean
            # (the last batch may be smaller).
            total_loss += loss.item() * n
            correct += (logits.argmax(dim=1) == lbl).sum().item()
            total += n
        print(f"Epoch {epoch+1}: loss={total_loss/total:.4f} acc={correct/total:.4f}")

    # example inference
    model.eval()
    with torch.no_grad():
        sample = [x[:5].to(device) for x in Xs]
        logits, aux = model(*sample)
        print("Logits (sample):", logits)
        print("Shared vector shape:", aux["shared"].shape)
        print("Residual shape (per modality):", aux["residuals"][0].shape)

if __name__ == "__main__":
    # Run the synthetic-data training demo only when executed as a script,
    # not when this module is imported.
    train_example()
