# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Unimodal_Feature_Extraction_Network(nn.Module):
    """Encode one modality's sequence into a fixed-length feature vector.

    A bidirectional LSTM reads the sequence, its per-step outputs are
    averaged over the sequence dimension, and a linear layer projects the
    pooled vector to ``k`` features.

    Args:
        input_dim: Feature size of each sequence step.
        hidden_dim: Hidden size of each LSTM direction.
        k: Size of the projected output vector.
    """

    def __init__(self, input_dim, hidden_dim, k):
        # Zero-argument super(): the previous call named an undefined class
        # (``UFEN``) and raised NameError on construction.
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Bidirectional output concatenates both directions -> 2 * hidden_dim.
        self.fc = nn.Linear(hidden_dim * 2, k)

    def forward(self, x):
        """Return the pooled projection of ``x``.

        Args:
            x: Input of shape (batch, seq_len, input_dim) (``batch_first``).

        Returns:
            Tensor of shape (batch, k).
        """
        self.lstm.flatten_parameters()
        # Only the per-step outputs are needed; the final (h, c) is discarded.
        output, _ = self.lstm(x)

        avg_pool = torch.mean(output, dim=1)  # average over the sequence axis
        return self.fc(avg_pool)  # Shape: (batch, k)

class ABSLSTMCell(nn.Module):
    """One step of an attention-based LSTM cell.

    Two attention stages wrap the usual LSTM gates:

    * RIA - Recurrent Interaction Attention: the hidden/cell context fed to
      the gates is an attention-weighted mix of the most recent past states.
    * GIA - Global Interaction Attention: the gated hidden state is rescaled
      and shifted by scalars derived from itself and from the current input.

    Args:
        input_dim: Size of the input vector at each step.
        hidden_dim: Size of the hidden and cell states.
        attention_window_t_RIA: Number of most recent history entries RIA
            attends over. A non-positive value disables RIA (the previous
            state is used directly).
    """

    def __init__(self, input_dim, hidden_dim, attention_window_t_RIA):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.t = attention_window_t_RIA  # Window size for RIA

        # Gate parameters: one input projection (with bias) and one
        # recurrent projection (no bias) per gate.
        self.W_f1 = nn.Linear(input_dim, hidden_dim)
        self.W_f2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.W_i1 = nn.Linear(input_dim, hidden_dim)
        self.W_i2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.W_m1 = nn.Linear(input_dim, hidden_dim)
        self.W_m2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.W_o1 = nn.Linear(input_dim, hidden_dim)
        self.W_o2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # RIA scoring layers: act on [past state ; current input].
        self.W_c = nn.Linear(hidden_dim + input_dim, hidden_dim)
        self.W_h = nn.Linear(hidden_dim + input_dim, hidden_dim)

        # GIA projections.
        self.W_h_prime = nn.Linear(hidden_dim, hidden_dim)
        self.W_x_prime = nn.Linear(input_dim, hidden_dim)

        # GIA reductions to one scalar per sample.
        self.W_h2 = nn.Linear(hidden_dim, 1, bias=False)
        self.W_x2 = nn.Linear(hidden_dim, 1, bias=False)

    @staticmethod
    def _attend_over_history(past_states, x_f, score_layer):
        """Return the ReLU of the attention-weighted sum of ``past_states``.

        Args:
            past_states: Stacked history, shape (Batch, T, Hidden_Dim).
            x_f: Current input, shape (Batch, Input_Dim).
            score_layer: Linear layer mapping [state ; input] to Hidden_Dim.

        Returns:
            Tensor of shape (Batch, Hidden_Dim).
        """
        steps = past_states.size(1)
        # Eq 5: pair every past state with the CURRENT input.
        x_rep = x_f.unsqueeze(1).expand(-1, steps, -1)
        score_vecs = torch.tanh(score_layer(torch.cat([past_states, x_rep], dim=2)))
        # Eq 6: L2 norm gives one scalar score per history item -> (Batch, T).
        scores = torch.norm(score_vecs, p=2, dim=2)
        # Eq 7: normalise scores across the window.
        weights = F.softmax(scores, dim=1)
        # Eq 8: weighted sum over the window, then ReLU.
        return F.relu((weights.unsqueeze(2) * past_states).sum(dim=1))

    def forward(self, x_f, h_prev, c_prev, h_history, c_history):
        """Advance the cell by one step.

        Args:
            x_f: Current input chunk (Batch, Input_Dim).
            h_prev, c_prev: Previous step hidden/cell (Batch, Hidden_Dim).
                Used as the context only when the attention window is empty.
            h_history, c_history: Lists of previous hidden/cell tensors for
                RIA, oldest first.

        Returns:
            Tuple ``(h_l_attended, c_l)``, each (Batch, Hidden_Dim).
        """
        # --- 1. Recurrent Interaction Attention (RIA) ---
        # Keep only the most recent ``self.t`` entries. The window is empty
        # when there is no history yet or when the window size is not
        # positive; both cases fall back to the previous state instead of
        # concatenating an empty list of scores.
        start_idx = max(0, len(h_history) - self.t)
        window = list(zip(h_history[start_idx:], c_history[start_idx:]))
        if window:
            past_h = torch.stack([h_past for h_past, _ in window], dim=1)
            past_c = torch.stack([c_past for _, c_past in window], dim=1)
            c_tilde = self._attend_over_history(past_c, x_f, self.W_c)
            h_tilde = self._attend_over_history(past_h, x_f, self.W_h)
        else:
            c_tilde, h_tilde = c_prev, h_prev

        # --- 2. Standard LSTM Gates Updates ---
        # Eq 9, 10: forget and input gates
        f_l = torch.sigmoid(self.W_f1(x_f) + self.W_f2(h_tilde))
        i_l = torch.sigmoid(self.W_i1(x_f) + self.W_i2(h_tilde))

        # Eq 11: candidate cell
        c_candidate = torch.tanh(self.W_m1(x_f) + self.W_m2(h_tilde))

        # Update cell state from the attended context, not the raw c_prev.
        c_l = f_l * c_tilde + i_l * c_candidate

        # Eq 12: output gate and the hidden state before GIA
        o_l = torch.sigmoid(self.W_o1(x_f) + self.W_o2(h_tilde))
        h_l_initial = o_l * torch.tanh(c_l)

        # --- 3. Global Interaction Attention (GIA) ---
        omega_h = F.relu(self.W_h_prime(h_l_initial))  # Eq 13
        omega_x = F.relu(self.W_x_prime(x_f))  # Eq 14

        # Eq 15: W_h2 reduces omega_h to a per-sample scalar weight for the
        # hidden state; W_x2 reduces omega_x to a per-sample scalar bias.
        # Both are (Batch, 1) and broadcast over the hidden dimension.
        weight_h = self.W_h2(omega_h)
        bias_x = self.W_x2(omega_x)

        h_l_attended = torch.tanh(weight_h * h_l_initial + bias_x)

        return h_l_attended, c_l

class HFFN(nn.Module):
    def __init__(self, k, d, s, hidden_dim_o, num_classes):
        """Set up the divide / conquer / combine / inference stages.

        Args:
            k: Length of each unimodal feature vector.
            d: Sliding-window size used to cut the vectors into local chunks.
            s: Sliding-window stride.
            hidden_dim_o: Hidden size of the ABS-LSTM cell.
            num_classes: Number of output classes of the final classifier.
        """
        super().__init__()
        self.k = k  # Feature length
        self.d = d  # Window size
        self.s = s  # Stride
        self.hidden_dim = hidden_dim_o

        # Divide stage: the number of chunks n is not stored here; it is
        # determined by the slicing and padding in _get_local_chunks at
        # forward time.

        # Conquer stage: each local chunk is padded with a constant 1, so
        # the three-way outer product has (d + 1)^3 entries once flattened.
        self.fusion_dim = (d + 1) ** 3

        # Combine stage: ABS-LSTM with an RIA window of 5 steps. The keyword
        # must match ABSLSTMCell's parameter name (attention_window_t_RIA);
        # the previous spelling raised TypeError on construction.
        self.abs_lstm_cell = ABSLSTMCell(
            self.fusion_dim, hidden_dim_o, attention_window_t_RIA=5
        )
        self.dropout = nn.Dropout(0.5)

        # Inference stage: the first layer consumes the flattened sequence of
        # bidirectional states (n * 2 * hidden_dim_o features) and maps it to
        # 50 units. Its input size depends on n, so it is left as None here
        # and created on the first forward pass.
        # NOTE(review): parameters created inside forward are not seen by an
        # optimizer constructed before the first forward call - confirm the
        # training loop runs a forward pass (or rebuilds the optimizer) first.
        self.inference_fc1 = None
        self.inference_fc2 = nn.Linear(50, num_classes)
        
    def _get_local_chunks(self, vec, modality_name):
        """
        Implements Eq 1: Sliding Window
        Input: vec (Batch, k)
        Output: (Batch, n, d)
        """
        batch_size = vec.size(0)
        # Pad with zeros if necessary to make divisible
        # Text says: "feature vectors are padded with 0s to guarantee divisibility"
        # The logic: n = floor((k-d)/s) + 2
        # We use PyTorch unfold.
        
        # Reshape to (Batch, 1, 1, k) for unfold, or just use simple loop
        # Unfold on dimension 1
        # Input to unfold: (Batch, Channel, Length) -> (Batch, 1, k)
        x = vec.unsqueeze(1)
        
        # We need to pad 'x' to ensure we cover everything + the extra pad mentioned
        # Simple approach: Manual slicing to match Eq 1 exactly
        chunks = []
        i = 1
        while True:
            start_idx = self.s * (i - 1)
            end_idx = start_idx + self.d
            
            if end_idx > self.k:
                # Padding case (Last chunk)
                # Text says "padded with 0s"
                remainder = vec[:, start_idx:]
                pad_len = self.d - remainder.size(1)
                if pad_len > 0:
                    padding = torch.zeros(batch_size, pad_len).to(vec.device)
                    chunk = torch.cat([remainder, padding], dim=1)
                    chunks.append(chunk)
                break
            else:
                chunk = vec[:, start_idx:end_idx]
                chunks.append(chunk)
            
            i += 1
            
        return torch.stack(chunks, dim=1) # (Batch, n, d)

    def forward(self, l_vec, v_vec, a_vec):
        """
        l_vec, v_vec, a_vec: (Batch, k)
        """
        batch_size = l_vec.size(0)
        
        # --- 1. Divide Stage (Sliding Window) ---
        l_chunks = self._get_local_chunks(l_vec, 'l') # (Batch, n, d)
        v_chunks = self._get_local_chunks(v_vec, 'v')
        a_chunks = self._get_local_chunks(a_vec, 'a')
        
        n = l_chunks.size(1) # Number of chunks
        
        # --- 2. Conquer Stage (Tensor Fusion) ---
        # Eq 2: Pad with 1
        ones = torch.ones(batch_size, n, 1).to(l_vec.device)
        l_pad = torch.cat([l_chunks, ones], dim=2) # (Batch, n, d+1)
        v_pad = torch.cat([v_chunks, ones], dim=2)
        a_pad = torch.cat([a_chunks, ones], dim=2)
        
        # Eq 3: Outer Product
        # We process each chunk.
        # Result shape: (Batch, n, (d+1)^3)
        # Einstein summation is easiest: b=batch, n=seq, i,j,k = dims
        fused_seq = []
        for t in range(n):
            l_t = l_pad[:, t, :]
            v_t = v_pad[:, t, :]
            a_t = a_pad[:, t, :]
            
            # Outer product l (x) v (x) a
            # (Batch, d+1) -> (Batch, d+1, d+1, d+1)
            # using einsum: bi, bj, bk -> bijk
            out_prod = torch.einsum('bi,bj,bk->bijk', l_t, v_t, a_t)
            
            # Flatten to vector
            flat = out_prod.reshape(batch_size, -1) # (Batch, (d+1)^3)
            fused_seq.append(flat)
            
        X_f = torch.stack(fused_seq, dim=1) # (Batch, n, input_dim)
        
        # --- 3. Combine Stage (ABS-LSTM) ---
        # Bidirectional: Forward pass and Backward pass
        
        # -- Forward Pass --
        h_fwd_list = []
        c_prev = torch.zeros(batch_size, self.hidden_dim).to(l_vec.device)
        h_prev = torch.zeros(batch_size, self.hidden_dim).to(l_vec.device)
        h_history = []
        c_history = []
        
        for t in range(n):
            x_curr = X_f[:, t, :]
            h_curr, c_curr = self.abs_lstm_cell(x_curr, h_prev, c_prev, h_history, c_history)
            
            h_fwd_list.append(h_curr)
            
            # Update history (detach to prevent massive graph retention if needed, 
            # though standard BPTT keeps it. We keep it attached for gradients)
            h_history.append(h_curr)
            c_history.append(c_curr)
            h_prev, c_prev = h_curr, c_curr
            
        # -- Backward Pass --
        # Reverse input
        X_f_rev = torch.flip(X_f, [1])
        h_bwd_list = []
        c_prev = torch.zeros(batch_size, self.hidden_dim).to(l_vec.device)
        h_prev = torch.zeros(batch_size, self.hidden_dim).to(l_vec.device)
        h_history = []
        c_history = []
        
        for t in range(n):
            x_curr = X_f_rev[:, t, :]
            h_curr, c_curr = self.abs_lstm_cell(x_curr, h_prev, c_prev, h_history, c_history)
            h_bwd_list.append(h_curr)
            h_history.append(h_curr)
            c_history.append(c_curr)
            h_prev, c_prev = h_curr, c_curr
            
        # Reverse backward results to match time order
        h_bwd_list.reverse()
        
        # Concatenate: X^g = [h_fwd; h_bwd] per step
        # Shape: (Batch, n, 2*o)
        X_g_list = [torch.cat([f, b], dim=1) for f, b in zip(h_fwd_list, h_bwd_list)]
        X_g = torch.cat(X_g_list, dim=1) # Concatenate along feature dimension? 
        # Wait, Eq 4 says X^g in R^{n x 2o}. 
        # The inference Eq 17 uses W_e1 in R^{50 x n*2o}.
        # This implies we concatenate ALL steps into one giant vector.
        X_g_flat = torch.cat(X_g_list, dim=1) # (Batch, n * 2 * o)
        
        # --- 4. Inference Module ---
        if self.inference_fc1 is None:
            # Initialize dynamically on first run based on flattened size
            flattened_dim = X_g_flat.size(1)
            self.inference_fc1 = nn.Linear(flattened_dim, 50).to(l_vec.device)