# In [None]:
import torch
import torch.nn as nn


class ModalityEncoder(nn.Module):
    """Project one modality's features with a linear layer followed by ReLU.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        output_dim: Size of the projected (post-ReLU) representation; also
            exposed as ``self.output_dim``.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.encoder = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()
        # Kept as a plain attribute so callers can size downstream layers.
        self.output_dim = output_dim

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return ``relu(encoder(features))``; the last dim becomes ``output_dim``."""
        projected = self.encoder(features)
        activated = self.relu(projected)
        return activated
       



class MRRF(nn.Module):
    """Fuse three modality vectors with a factorized outer-product (tensor) fusion.

    Each modality vector gets a constant 1 prepended, is projected to a
    modality-specific rank by a bias-free linear "factor", and the three
    projections are combined with a 3-way outer product. The flattened
    ``rank1 * rank2 * rank3`` tensor goes through dropout and a final linear
    layer (``core_layer``).

    Args:
        dim1, dim2, dim3: Feature sizes of the three input modalities.
        output_dim: Size of the fused output.
        rank1, rank2, rank3: Projection size of each modality's factor.
        dropout: Dropout probability applied to the flattened fused tensor.
    """

    def __init__(self, dim1, dim2, dim3, output_dim, rank1, rank2, rank3, dropout=0.0):
        super().__init__()

        self.ranks = [rank1, rank2, rank3]

        # Each factor takes one extra input feature: the constant 1 prepended in
        # forward(). Its weight column acts as the bias, so bias=False here.
        self.factor1 = nn.Linear(dim1 + 1, rank1, bias=False)
        self.factor2 = nn.Linear(dim2 + 1, rank2, bias=False)
        self.factor3 = nn.Linear(dim3 + 1, rank3, bias=False)

        self.dropout = nn.Dropout(dropout)

        fused_dim = rank1 * rank2 * rank3
        self.core_layer = nn.Linear(fused_dim, output_dim)

        nn.init.xavier_uniform_(self.factor1.weight)
        nn.init.xavier_uniform_(self.factor2.weight)
        nn.init.xavier_uniform_(self.factor3.weight)
        nn.init.xavier_uniform_(self.core_layer.weight)

    @staticmethod
    def _prepend_one(z):
        """Return ``z`` with a column of ones concatenated in front along dim 1."""
        # new_ones matches z's dtype and device. A default float32 column would
        # promote half/bfloat16 inputs to float32 and then clash with the
        # low-precision factor weights.
        return torch.cat([z.new_ones(z.size(0), 1), z], dim=1)

    def forward(self, z1, z2, z3):
        """Fuse the three modality inputs.

        Args:
            z1, z2, z3: 2-D tensors ``(batch, dim1/dim2/dim3)`` sharing the same
                batch size.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        batch_size = z1.size(0)

        p1 = self.factor1(self._prepend_one(z1))  # (batch, rank1)
        p2 = self.factor2(self._prepend_one(z2))  # (batch, rank2)
        p3 = self.factor3(self._prepend_one(z3))  # (batch, rank3)

        # Per-sample outer product p1 (x) p2 (x) p3 -> (batch, r1, r2, r3), then
        # flattened row-major: index (i, j, k) -> i*r2*r3 + j*r3 + k.
        z_fused = torch.einsum("bi,bj,bk->bijk", p1, p2, p3).reshape(batch_size, -1)

        z_fused = self.dropout(z_fused)
        return self.core_layer(z_fused)