In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalFusionModule(nn.Module):
    """Local Fusion Module (LFM): sliding-window outer-product fusion of three modalities.

    Each modality vector is cut into overlapping windows of length ``window_size``
    taken every ``stride`` positions. Inside each window the three chunks are each
    extended with a constant 1 and combined by a 3-way outer product. Because of
    the appended 1, the result contains the unimodal, bimodal and trimodal
    product terms of the window.
    """

    def __init__(self, window_size, stride):
        super(LocalFusionModule, self).__init__()
        self.d = window_size
        self.s = stride

    def _pad_length(self, k):
        """Number of zeros to append to a length-``k`` vector so it splits into whole windows."""
        if k < self.d:
            # Too short for even one window: pad up to exactly one window.
            return self.d - k
        # Smallest pad that makes (k + pad - d) a multiple of the stride.
        return (-(k - self.d)) % self.s

    def forward(self, l, v, a):
        """Fuse three modality vectors window by window.

        Args:
            l, v, a: 2-D tensors of shape (batch, k). ``k`` is read from ``l``
                and the same right-padding is applied to all three, so they are
                expected to share that shape.

        Returns:
            Tensor of shape (batch, n, (d + 1) ** 3), where n is the number of
            windows after padding.
        """
        batch_size, k = l.shape

        # --- Divide (Sliding Window) ---
        pad_len = self._pad_length(k)
        if pad_len:
            l = F.pad(l, (0, pad_len))
            v = F.pad(v, (0, pad_len))
            a = F.pad(a, (0, pad_len))

        # Create windows: (Batch, Num_Chunks, Window_Size)
        # unfold dimension 1 (time/feature dim)
        l_chunks = l.unfold(1, self.d, self.s)
        v_chunks = v.unfold(1, self.d, self.s)
        a_chunks = a.unfold(1, self.d, self.s)
        n = l_chunks.size(1)

        # --- Conquer (Fusion) ---
        # Pad local portion with 1s -> Shape (Batch, n, d+1).
        # new_ones keeps each chunk's dtype and device and allocates directly there.
        l_pad = torch.cat([l_chunks, l_chunks.new_ones((batch_size, n, 1))], dim=2)
        v_pad = torch.cat([v_chunks, v_chunks.new_ones((batch_size, n, 1))], dim=2)
        a_pad = torch.cat([a_chunks, a_chunks.new_ones((batch_size, n, 1))], dim=2)

        # Outer Product: (d+1) x (d+1) x (d+1)
        # Output: (Batch, n, d+1, d+1, d+1)
        fusion_tensor = torch.einsum('bni,bnj,bnk->bnijk', l_pad, v_pad, a_pad)

        # Flatten feature dims: (Batch, n, (d+1)^3)
        X_f = fusion_tensor.reshape(batch_size, n, -1)

        return X_f


class ABSLSTMCell(nn.Module):
    """One step of the ABS-LSTM used by the global fusion stage.

    The cell does not condition only on the immediately previous state. It
    attends over a list of earlier hidden and cell states (the RIA look-back)
    and runs LSTM-style gating on the attended summaries.

    Args:
        input_dim: size of the per-step input ``x_curr``.
        hidden_dim: size of the hidden and cell states.
        lookback_t: history length, stored as ``self.t``. The cell itself
            attends over however many states it is handed.
    """

    def __init__(self, input_dim, hidden_dim, lookback_t=3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.t = lookback_t

        joint_dim = hidden_dim + input_dim

        # Look-back scoring projections: W_c scores [c_past, x], W_h scores [h_past, x].
        self.W_c = nn.Linear(joint_dim, hidden_dim, bias=True)
        self.W_h = nn.Linear(joint_dim, hidden_dim, bias=True)
        # Activation 'a' applied to the attended summaries.
        self.act_a = nn.ReLU()

        # Gate projections. *1 reads the input (with bias) and *2 reads the
        # attended hidden state (no bias). Creation order is fixed so parameter
        # order and seeded initialisation stay the same.
        self.W_f1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.W_f2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_i1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.W_i2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_m1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.W_m2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_o1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.W_o2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_curr, past_h_list, past_c_list):
        """Advance the cell by one step.

        Args:
            x_curr: current input, concatenated with each past state along
                dim 1. Expected 2-D (batch, input_dim).
            past_h_list: list of earlier hidden states.
            past_c_list: list of earlier cell states, same length as ``past_h_list``.

        Returns:
            (h_l, c_l): the new hidden state and the new cell state.
        """
        steps = len(past_h_list)

        def look_back(history, score_proj):
            # RIA: score each stored state against x_curr (L2 norm of a tanh
            # projection), softmax the scores over the look-back axis (Eq 7),
            # take the weighted sum of the states (Eq 8), then apply 'a'.
            raw_scores = [
                torch.norm(
                    torch.tanh(score_proj(torch.cat([history[idx], x_curr], dim=1))),
                    p=2, dim=1, keepdim=True,
                )
                for idx in range(steps)
            ]
            weights = F.softmax(torch.cat(raw_scores, dim=1), dim=1).unsqueeze(2)
            summary = torch.sum(weights * torch.stack(history, dim=1), dim=1)
            return self.act_a(summary)

        c_tilde = look_back(past_c_list, self.W_c)
        h_tilde = look_back(past_h_list, self.W_h)

        def pre_activation(in_proj, rec_proj):
            return in_proj(x_curr) + rec_proj(h_tilde)

        # LSTM update driven by the attended summaries (Eq 9-12).
        forget_g = torch.sigmoid(pre_activation(self.W_f1, self.W_f2))
        input_g = torch.sigmoid(pre_activation(self.W_i1, self.W_i2))
        candidate = torch.tanh(pre_activation(self.W_m1, self.W_m2))
        output_g = torch.sigmoid(pre_activation(self.W_o1, self.W_o2))

        c_next = forget_g * c_tilde + input_g * candidate
        h_next = output_g * torch.tanh(c_next)
        return h_next, c_next

class GlobalFusionModule(nn.Module):
    def __init__(self, input_dim, hidden_dim_o, lookback_t=3):
        """
        Bidirectional ABS-LSTM with Global Interaction Attention (GIA).

        Args:
            input_dim: size of each step of ``x_f`` (its last dimension).
            hidden_dim_o: hidden size ``o`` of the ABS-LSTM cell. The
                bidirectional output has ``2 * o`` features per step.
            lookback_t: number of previous (h, c) states each step attends over.

        Raises:
            ValueError: if ``lookback_t`` < 1. The look-back softmax needs at
                least one stored state.
        """
        super(GlobalFusionModule, self).__init__()
        if lookback_t < 1:
            raise ValueError(f"lookback_t must be >= 1, got {lookback_t}")
        self.hidden_dim = hidden_dim_o
        self.t = lookback_t

        # Shared Cell for both directions (the same weights run forward and backward)
        self.abs_lstm_cell = ABSLSTMCell(input_dim, hidden_dim_o, lookback_t)

        # GIA Weights (Eq 13-15)
        self.bi_hidden_dim = 2 * hidden_dim_o
        self.W_h_prime = nn.Linear(self.bi_hidden_dim, hidden_dim_o, bias=True)
        self.W_x = nn.Linear(input_dim, hidden_dim_o, bias=True)
        self.W_h2 = nn.Linear(hidden_dim_o, 1, bias=False)
        self.W_x2 = nn.Linear(hidden_dim_o, 1, bias=False)

    def _run_direction(self, x_seq):
        """Run the shared ABS-LSTM cell left-to-right over ``x_seq``.

        Args:
            x_seq: (batch, seq_len, input_dim) tensor.

        Returns:
            (batch, seq_len, hidden_dim) tensor with one hidden state per step.
        """
        batch_size, seq_len, _ = x_seq.size()

        # Zero-initialised history of the last 't' states. new_zeros matches
        # x_seq's dtype and device and allocates directly there, with no host
        # allocation followed by a copy.
        past_h = [x_seq.new_zeros((batch_size, self.hidden_dim)) for _ in range(self.t)]
        past_c = [x_seq.new_zeros((batch_size, self.hidden_dim)) for _ in range(self.t)]

        h_outputs = []
        for step in range(seq_len):
            h_l, c_l = self.abs_lstm_cell(x_seq[:, step, :], past_h, past_c)
            h_outputs.append(h_l)

            # Slide window: drop the oldest state and append the newest.
            past_h.pop(0)
            past_h.append(h_l)
            past_c.pop(0)
            past_c.append(c_l)

        return torch.stack(h_outputs, dim=1)

    def forward(self, x_f):
        """Bidirectional pass followed by GIA re-weighting.

        Args:
            x_f: (batch, n, input_dim) sequence of fused local vectors.

        Returns:
            (batch, n, 2 * hidden_dim_o) tensor. Along the last dimension the
            features are ordered [backward, forward].
        """
        h_fwd = self._run_direction(x_f)

        # Backward direction: run on the time-reversed sequence, then flip the
        # outputs back so step i of both directions lines up.
        h_bwd = torch.flip(self._run_direction(torch.flip(x_f, dims=[1])), dims=[1])

        # Concatenate: (Batch, n, 2o)
        h_l = torch.cat([h_bwd, h_fwd], dim=2)

        # Global Interaction Attention (GIA)
        omega_h = F.relu(self.W_h_prime(h_l))
        omega_x = F.relu(self.W_x(x_f))

        scalar_h = self.W_h2(omega_h)  # (batch, n, 1)
        scalar_x = self.W_x2(omega_x)  # (batch, n, 1)

        # Apply GIA weights (Eq 15): a per-step scale on h_l plus a per-step offset.
        h_l_att = torch.tanh(scalar_h * h_l + scalar_x)

        return h_l_att


class EmotionInferenceModule(nn.Module):
    """Emotion Inference Module (EIM): classification head over every GFM step.

    Args:
        num_chunks: number of time steps ``n`` in the GFM output.
        hidden_dim_o: GFM hidden size ``o``; each step carries ``2 * o`` features.
        num_classes: number of output classes.
        intermediate_dim: width of the hidden layer (Eq 16).
        dropout: dropout probability applied after the hidden layer
            (default 0.5, the previously hard-coded value).
    """

    def __init__(self, num_chunks, hidden_dim_o, num_classes, intermediate_dim=50, dropout=0.5):
        super(EmotionInferenceModule, self).__init__()

        # Input size: n * 2o (all timesteps concatenated)
        self.input_dim = num_chunks * (2 * hidden_dim_o)

        self.W_e1 = nn.Linear(self.input_dim, intermediate_dim, bias=True)
        self.dropout = nn.Dropout(p=dropout)
        self.W_e2 = nn.Linear(intermediate_dim, num_classes, bias=False)

    def forward(self, x_g, return_logits=False):
        """Map GFM features to class scores.

        Args:
            x_g: tensor whose non-batch dimensions flatten to ``n * 2o`` values
                per example, e.g. (batch, n, 2o).
            return_logits: if True, return the raw pre-softmax scores. Use this
                with losses that apply their own (log-)softmax, such as
                ``nn.CrossEntropyLoss``; feeding them probabilities would
                apply softmax twice.

        Returns:
            (batch, num_classes) tensor of probabilities (default) or logits.
        """
        # Flatten: (Batch, n * 2o). Unlike view, reshape also accepts non-contiguous input.
        x_flat = x_g.reshape(x_g.size(0), -1)

        # Eq 16
        e = torch.tanh(self.W_e1(x_flat))
        e = self.dropout(e)

        # Eq 17
        logits = self.W_e2(e)
        if return_logits:
            return logits
        return F.softmax(logits, dim=1)


class HFFN(nn.Module):
    def __init__(self, k, d, s, o, num_classes, lookback_t=3):
        """
        Hierarchical feature fusion network: local fusion -> global fusion -> inference.

        Args:
            k: Feature vector length per modality.
            d: Window size.
            s: Stride.
            o: LSTM hidden dim.
            num_classes: Output classes.
            lookback_t: Number of past states the ABS-LSTM attends over
                (default 3, the previously hard-coded value).

        Raises:
            ValueError: if ``k``, ``d`` or ``s`` is smaller than 1.
        """
        super(HFFN, self).__init__()
        for name, value in (("k", k), ("d", d), ("s", s)):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        # Number of windows n. It must equal the chunk count produced by
        # self.lfm, because the EIM's first layer is sized from it. An input
        # shorter than one window counts as padded up to d. Otherwise the length
        # is padded up to the next value where (length - d) is a multiple of s.
        if k < d:
            k_padded = d
        else:
            k_padded = k + (-(k - d)) % s
        self.n = ((k_padded - d) // s) + 1

        self.lfm = LocalFusionModule(d, s)

        # GFM input dim = (d+1)^3 (outer product of three (d+1)-vectors)
        lfm_out_dim = (d + 1) ** 3
        self.gfm = GlobalFusionModule(lfm_out_dim, o, lookback_t=lookback_t)

        self.eim = EmotionInferenceModule(self.n, o, num_classes)

    def forward(self, l, v, a):
        """Run the full pipeline on one batch.

        Args:
            l, v, a: per-modality tensors of shape (batch, k).

        Returns:
            (batch, num_classes) class probabilities from the inference module.
        """
        x_f = self.lfm(l, v, a)
        x_g = self.gfm(x_f)
        output = self.eim(x_g)
        return output


In [None]:


if __name__ == "__main__":
    # --- Configuration ---
    SEED = 0         # Fixes weight init and the dummy inputs
    BATCH_SIZE = 8
    K_LEN = 50       # Feature length (k)
    D_WIN = 5        # Window size (d)
    S_STRIDE = 2     # Stride (s)
    O_HIDDEN = 32    # LSTM Hidden (o)
    N_CLASSES = 6    # Emotion classes

    torch.manual_seed(SEED)

    # --- Instantiate Model ---
    model = HFFN(K_LEN, D_WIN, S_STRIDE, O_HIDDEN, N_CLASSES)
    print("Model initialized successfully.")

    # --- Dummy Input Data ---
    # Random Tensors for Language, Visual, Acoustic
    l_in = torch.randn(BATCH_SIZE, K_LEN)
    v_in = torch.randn(BATCH_SIZE, K_LEN)
    a_in = torch.randn(BATCH_SIZE, K_LEN)

    # --- Forward Pass ---
    # eval() switches off the inference-module dropout, which makes the sample
    # prediction deterministic. no_grad() skips building an autograd graph for
    # this demo. A failure raises with its full traceback instead of being
    # reduced to a one-line printed message.
    model.eval()
    with torch.no_grad():
        preds = model(l_in, v_in, a_in)

    assert preds.shape == (BATCH_SIZE, N_CLASSES)
    assert torch.allclose(preds.sum(dim=1), torch.ones(BATCH_SIZE))

    print("\n--- Forward Pass Successful ---")
    print(f"Input Shape (per modality): {(BATCH_SIZE, K_LEN)}")
    print(f"Output Shape (Probabilities): {preds.shape}")
    print(f"Sample Prediction (First Item): \n{preds[0]}")