# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleEncoder(nn.Module):
    """Two-layer MLP that turns one modality's flat features into an embedding.

    Layout: Linear(in_dim -> hidden_dim) -> ReLU -> Dropout -> Linear(hidden_dim -> out_dim) -> ReLU.
    The trailing ReLU means every output element is non-negative.
    Swap in a CNN/RNN/Transformer here if a modality needs one.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(inplace=True),
        ]
        # A single Sequential named `net` keeps state_dict keys as net.0.* / net.3.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map features (batch, in_dim) to embeddings (batch, out_dim)."""
        encoded = self.net(x)
        return encoded


class CrossAttentionBlock(nn.Module):
    """Bidirectional multi-head cross-attention between two feature sets A and B.

    A attends to B (queries from A, keys/values from B) and B attends to A
    (queries from B, keys/values from A). Both directions share one LayerNorm,
    the same Q/K/V projections and the same output projection, and each
    updated tensor is returned with a residual connection to its own input.

    Args:
        dim: feature size of both A and B; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: probability used on the attention weights and on the projected
            outputs before the residual add.

    Raises:
        ValueError: if num_heads is not positive or does not divide dim.
    """
    def __init__(self, dim, num_heads=4, dropout=0.1):
        super().__init__()
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # projections for queries/keys/values, shared by both directions
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        self.out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

    def _split_heads(self, x):
        # x: (batch, seq_len, dim) or (batch, dim) -> treat as seq_len=1
        if x.dim() == 2:
            x = x.unsqueeze(1)
        b, s, d = x.shape
        x = x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)  # (b, heads, s, head_dim)
        return x

    def _merge_heads(self, x):
        # x: (b, heads, s, head_dim)
        x = x.transpose(1, 2).contiguous()  # (b, s, heads, head_dim)
        b, s, h, hd = x.shape
        return x.view(b, s, h * hd)

    def _attend(self, q_h, k_h, v_h):
        """Scaled dot-product attention over split heads.

        q_h: (b, heads, s_q, head_dim); k_h, v_h: (b, heads, s_k, head_dim).
        Returns the merged result of shape (b, s_q, dim).
        """
        scale = self.head_dim ** -0.5
        logits = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale  # (b, heads, s_q, s_k)
        weights = self.dropout(torch.softmax(logits, dim=-1))
        return self._merge_heads(torch.matmul(weights, v_h))

    def forward(self, A, B):
        """A, B: tensors of shape (batch, dim) or (batch, seq_len, dim).

        Returns updated (A', B'); each output has the same shape as its own input,
        even when A and B have different ranks.
        """
        # pre-norm: one LayerNorm shared by both inputs
        A_ln = self.norm(A)
        B_ln = self.norm(B)

        # project (shared weights) and split into heads; 2-D inputs become seq_len=1
        qA = self._split_heads(self.q_proj(A_ln))
        kA = self._split_heads(self.k_proj(A_ln))
        vA = self._split_heads(self.v_proj(A_ln))

        qB = self._split_heads(self.q_proj(B_ln))
        kB = self._split_heads(self.k_proj(B_ln))
        vB = self._split_heads(self.v_proj(B_ln))

        # A -> B first, then B -> A (keeps the dropout RNG call order stable)
        attended_for_A = self._attend(qA, kB, vB)  # (b, sA, dim)
        attended_for_B = self._attend(qB, kA, vA)  # (b, sB, dim)

        outA = self.out(attended_for_A)
        outB = self.out(attended_for_B)

        # Restore each output's rank from ITS OWN input. Using A's rank for both
        # would make the residual add below broadcast to (b, b, dim) when A and
        # B have different ranks.
        if A.dim() == 2:
            outA = outA.squeeze(1)
        if B.dim() == 2:
            outB = outB.squeeze(1)

        A_up = A + self.dropout(outA)
        B_up = B + self.dropout(outB)

        return A_up, B_up


class HierarchicalFeatureFusionNetwork(nn.Module):
    """A compact HFFN example supporting N modalities.

    Design:
      1) Per-modality encoder -> modality embeddings (batch, shared_dim)
      2) Early fusion: concatenate all modality embeddings and project the
         result back to shared_dim
      3) Mid-level: cross-attention between every modality pair, repeated over
         several stages (each stage owns its own blocks)
      4) Late fusion: mean-pool the refined embeddings, concatenate with the
         early-fusion vector, and classify

    This is intentionally modular so you can replace encoder/attention with your own blocks.
    """

    def __init__(self, modalities_dims, hidden_dim=256, shared_dim=128, num_heads=4, depth_per_stage=1, stages=2, num_classes=3, dropout=0.1):
        """
        modalities_dims: dict -> {"text": dim_text, "audio": dim_audio, ...};
            its insertion order fixes the modality order used everywhere.
        hidden_dim: internal encoder hidden size
        shared_dim: dimension to project all modalities into (embedding dim)
        num_heads: attention heads in every cross-attention block
        depth_per_stage: number of cross-attention blocks per pair in each stage
        stages: number of hierarchical stages
        num_classes: number of output logits
        dropout: dropout probability for encoders, attention and classifier

        Raises:
            ValueError: if modalities_dims is empty, or if two different
                modality pairs would map to the same internal "a__b" key.
        """
        super().__init__()
        if not modalities_dims:
            raise ValueError("modalities_dims must contain at least one modality")
        self.modalities = list(modalities_dims.keys())
        self.num_modalities = len(self.modalities)
        self.shared_dim = shared_dim

        # encoders per modality
        self.encoders = nn.ModuleDict()
        for name, d in modalities_dims.items():
            self.encoders[name] = SimpleEncoder(d, hidden_dim, shared_dim, dropout=dropout)

        # early fusion projection (projects concatenated features to shared_dim)
        self.early_proj = nn.Sequential(
            nn.Linear(shared_dim * self.num_modalities, shared_dim),
            nn.ReLU(inplace=True),
            nn.LayerNorm(shared_dim)
        )

        self.stages = stages
        self.depth_per_stage = depth_per_stage

        # Every unordered modality pair (i < j), in a fixed order. These are kept
        # as explicit tuples so forward() never has to parse names back out of
        # ModuleDict keys; parsing breaks when a modality name contains "__".
        self._pairs = [
            (self.modalities[i], self.modalities[j])
            for i in range(self.num_modalities)
            for j in range(i + 1, self.num_modalities)
        ]

        # One ModuleDict per stage: pair key -> stack of cross-attention blocks.
        # Blocks are NOT shared across stages.
        self.hierarchy = nn.ModuleList()
        for _ in range(stages):
            pair_blocks = nn.ModuleDict()
            for a_name, b_name in self._pairs:
                key = self._pair_key(a_name, b_name)
                if key in pair_blocks:
                    # e.g. ("a", "b__c") and ("a__b", "c") both give "a__b__c"
                    raise ValueError(f"modality pair key collision: {key!r}")
                pair_blocks[key] = nn.ModuleList(
                    [CrossAttentionBlock(shared_dim, num_heads=num_heads, dropout=dropout) for _ in range(depth_per_stage)]
                )
            self.hierarchy.append(pair_blocks)

        # fusion head (after hierarchical stages)
        self.fusion_norm = nn.LayerNorm(shared_dim)
        self.classifier = nn.Sequential(
            nn.Linear(shared_dim * 2, shared_dim),  # use global pooled + early fused
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(shared_dim, num_classes)
        )

    @staticmethod
    def _pair_key(a_name, b_name):
        """Return the ModuleDict key for a modality pair ("<a>__<b>", as used in state_dict keys)."""
        return f"{a_name}__{b_name}"

    def forward(self, inputs):
        """
        inputs: dict of modality_name -> tensor (batch, features); keys not in
            self.modalities are ignored, a missing modality raises KeyError.
        returns logits (batch, num_classes)
        """
        # 1) per-modality encoding, in self.modalities order
        embeddings = {name: self.encoders[name](inputs[name]) for name in self.modalities}  # each (batch, shared_dim)

        # 2) early fusion (concatenate all embeddings)
        concat = torch.cat([embeddings[m] for m in self.modalities], dim=-1)  # (batch, shared_dim * M)
        early = self.early_proj(concat)  # (batch, shared_dim)

        # 3) hierarchical cross-attention stages. Within a stage, pairs are
        # processed in order and each pair sees the updates made by earlier pairs.
        working = {k: v.unsqueeze(1) for k, v in embeddings.items()}  # (batch, seq=1, dim)
        for pair_blocks in self.hierarchy:
            for a_name, b_name in self._pairs:
                A = working[a_name]
                B = working[b_name]
                for block in pair_blocks[self._pair_key(a_name, b_name)]:
                    A, B = block(A, B)
                working[a_name] = A
                working[b_name] = B

        # squeeze sequence dim
        final_embeddings = {k: v.squeeze(1) for k, v in working.items()}  # (batch, dim)

        # 4) global pooling/aggregation and classification
        # simple strategy: mean of modality embeddings
        stacked = torch.stack([final_embeddings[m] for m in self.modalities], dim=1)  # (batch, M, dim)
        global_pool = stacked.mean(dim=1)  # (batch, dim)

        # combine early fusion + hierarchical global
        combined = torch.cat([self.fusion_norm(global_pool), early], dim=-1)  # (batch, 2*dim)
        logits = self.classifier(combined)
        return logits


if __name__ == "__main__":
    # Quick synthetic smoke test: random features for each modality, one
    # forward pass, then one backward pass through a cross-entropy loss.
    modalities = {"text": 300, "audio": 74, "video": 512}
    num_classes = 5
    model = HierarchicalFeatureFusionNetwork(
        modalities_dims=modalities,
        hidden_dim=256,
        shared_dim=128,
        stages=2,
        depth_per_stage=1,
        num_classes=num_classes,
    )

    batch_size = 8
    # Inputs come from the same dims dict given to the model, so the two cannot drift apart.
    inputs = {name: torch.randn(batch_size, dim) for name, dim in modalities.items()}
    logits = model(inputs)
    print("logits shape:", logits.shape)  # (batch_size, num_classes)

    # simple loss to test backward; labels share num_classes with the model head
    labels = torch.randint(0, num_classes, (batch_size,))
    loss = F.cross_entropy(logits, labels)
    loss.backward()
    print("forward+backward OK, loss:", loss.item())
