In [4]:
import sys

# Make modules two directory levels above the current working directory
# importable (presumably the repository root, so that `model.long` resolves).
sys.path.extend(["../.."])

In [5]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from model.long import chunked_parallel_scan, recurrent_scan


In [6]:
class LongAttention(nn.Module):
    """Gated linear-recurrent attention layer with a causal depthwise convolution.

    The input is projected to per-head queries, keys and values. A short causal
    depthwise convolution of the input drives an input gate, per-head decay
    factors (``gamma``) and an output gate. Keys/values are accumulated by a
    scan: ``chunked_parallel_scan`` for a fresh sequence, ``recurrent_scan``
    when continuing from a previous recurrent state.

    Expected ``config`` attributes: ``hidden_size``, ``num_heads``,
    ``conv_kernel``, ``gate_init_bias`` and ``num_hidden_layers``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.d_model = config.hidden_size
        self.num_heads = config.num_heads
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.d_model}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        self.head_dim = self.d_model // self.num_heads

        # Projections
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        # Local depthwise convolution. Only its weight/bias are used:
        # _causal_conv applies F.conv1d with explicit causal left-padding, so
        # this module's own `padding` argument has no effect at runtime.
        self.conv = nn.Conv1d(
            self.d_model,
            self.d_model,
            kernel_size=config.conv_kernel,
            groups=self.d_model,
            padding=config.conv_kernel - 1,
        )

        # Gates
        self.input_gate_proj = nn.Linear(self.d_model, self.d_model, bias=True)
        self.output_gate_proj = nn.Linear(self.d_model, self.d_model, bias=True)
        # One decay logit per head.
        self.gamma_proj = nn.Linear(self.d_model, self.num_heads, bias=True)

        self.o_proj = nn.Linear(self.d_model, self.d_model)

        self.v_norm = nn.LayerNorm(self.head_dim)
        self.grp_norm = nn.GroupNorm(self.num_heads, self.d_model)
        self.mem_norm = nn.LayerNorm(self.head_dim)

        self.reset_parameters()

    def init_decays(self):
        """Initialize the per-head decay (gamma) projection.

        The weight is zeroed, so at initialization each head's decay
        ``sigmoid(gamma_proj(x))`` is a constant independent of the input.
        The bias is the logit of decays whose ``1 - decay`` values are spaced
        geometrically (linearly in log space) from 0.5 down to 1e-4, i.e.
        decays ranging from 0.5 (fast-forgetting head) to 0.9999
        (long-memory head).
        """
        with torch.no_grad():
            min_decay = 0.5
            max_decay = 0.9999

            target_decays = 1 - torch.exp(
                torch.linspace(
                    math.log(1 - min_decay),
                    math.log(1 - max_decay),
                    self.num_heads,
                )
            )

            # Inverse sigmoid, so sigmoid(bias) == target_decays.
            gamma_bias_init = torch.log(target_decays / (1 - target_decays))
            self.gamma_proj.bias.copy_(gamma_bias_init)
            nn.init.zeros_(self.gamma_proj.weight)

    def reset_parameters(self):
        """Apply the custom initialization scheme.

        - q/k/v projections: normal with std 0.02.
        - Input gate: zero weight and constant bias ``config.gate_init_bias``,
          so every input gate starts at ``sigmoid(gate_init_bias)`` to avoid
          early instability.
        - Decay projection: see ``init_decays``.
        - Output gate: bias 0.0 (neutral start); weight keeps the default init.
        - Output projection: normal with std ``0.02 / sqrt(num_hidden_layers)``,
          scaling each layer's contribution relative to model depth.
        """
        # 1. Projections
        nn.init.normal_(self.q_proj.weight, std=0.02)
        nn.init.normal_(self.k_proj.weight, std=0.02)
        nn.init.normal_(self.v_proj.weight, std=0.02)

        # 2. Input gate starts at a fixed, input-independent value.
        nn.init.zeros_(self.input_gate_proj.weight)
        nn.init.constant_(self.input_gate_proj.bias, self.config.gate_init_bias)

        # 3. Per-head decays.
        self.init_decays()

        # 4. Output gate: start neutral (bias 0.0).
        nn.init.constant_(self.output_gate_proj.bias, 0.0)

        # 5. Depth-scaled output projection (the only o_proj weight init).
        nn.init.normal_(
            self.o_proj.weight,
            std=0.02 / math.sqrt(self.config.num_hidden_layers),
        )

    def _causal_conv(self, x, conv_cache=None):
        """Apply the depthwise convolution causally, optionally continuing from a cache.

        Args:
            x: Input of shape (batch, seq_len, hidden_size).
            conv_cache: ``None`` (zero history) or a channel-first history of
                previous inputs, shape (batch, hidden_size, L). A cache shorter
                than ``conv_kernel - 1`` is zero-padded on the left. For a
                longer cache, only its last ``conv_kernel - 1`` steps affect
                the output.

        Returns:
            ``(x_conv, new_conv_cache)``:
            - ``x_conv`` is the convolution output of shape
              (batch, seq_len, hidden_size), one step per new token.
            - ``new_conv_cache`` holds the last ``conv_kernel - 1`` inputs of
              the (zero-padded) window, with shape
              (batch, hidden_size, conv_kernel - 1). It is ``None`` when
              ``conv_kernel == 1``.
        """
        seq_len = x.shape[1]
        history = self.config.conv_kernel - 1

        window = x.transpose(1, 2)
        if conv_cache is not None and history > 0:
            # Concatenate history: [old_cache, new_input]
            window = torch.cat([conv_cache, window], dim=2)

        # Left zero-pad so the window holds at least seq_len + kernel - 1 steps.
        # A valid (padding=0) convolution then yields one output per new token.
        missing = seq_len + history - window.shape[-1]
        if missing > 0:
            window = F.pad(window, (missing, 0))
        window = window.contiguous()

        conv_out = F.conv1d(
            window,
            self.conv.weight,
            bias=self.conv.bias,
            padding=0,
            groups=self.d_model,
        )
        # Outputs for the new tokens are the trailing seq_len positions.
        x_conv = conv_out[:, :, conv_out.shape[-1] - seq_len:]
        x_conv = x_conv.transpose(1, 2).contiguous()

        if history > 0:
            new_conv_cache = window[:, :, window.shape[-1] - history:].contiguous()
        else:
            # A kernel of size 1 needs no history.
            new_conv_cache = None
        return x_conv, new_conv_cache

    def forward(self, x, state=None):
        """Run the layer over an input sequence.

        Args:
            x: Input of shape (batch, seq_len, hidden_size).
            state: ``None`` to start a new sequence, or the
                ``(rnn_state, conv_cache)`` tuple returned by a previous call.
                If ``rnn_state`` is ``None``, the chunked parallel scan is
                used. If ``conv_cache`` is ``None``, zero history is assumed
                for the convolution.

        Returns:
            ``(output, (next_rnn_state, new_conv_cache))``:
            - ``output`` has shape (batch, seq_len, hidden_size).
            - ``new_conv_cache`` has shape (batch, hidden_size, conv_kernel - 1),
              or is ``None`` when ``conv_kernel == 1``.
            - ``next_rnn_state`` is produced by the scan helpers; its shape is
              defined by ``model.long`` (assumed per-head, TODO confirm).
        """
        B, T, C = x.shape

        # --- 1. Local causal convolution ---
        if state is not None:
            rnn_state, conv_cache = state
        else:
            rnn_state, conv_cache = None, None
        x_conv, new_conv_cache = self._causal_conv(x, conv_cache)
        x_conv = F.silu(x_conv)

        # --- 2. Projections & stability ---
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim)

        # Normalize Q, K for stable matching (cosine-attention style);
        # normalize V for stable accumulation.
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)
        v = self.v_norm(v)

        # --- 3. Gating: computed from the convolved input (q/k/v use raw x) ---
        i_gate = torch.sigmoid(self.input_gate_proj(x_conv)).view(
            B, T, self.num_heads, self.head_dim
        )

        # gamma_proj yields one decay per head, shape (B, T, H). Broadcast it
        # across head_dim to match the other scan inputs.
        # NOTE(review): confirm the scan helpers expect gamma as (B, T, H, head_dim).
        gamma = torch.sigmoid(self.gamma_proj(x_conv))
        gamma = (
            gamma.unsqueeze(-1)
            .expand(B, T, self.num_heads, self.head_dim)
            .contiguous()
        )

        # --- 4. Scan ---
        if rnn_state is None:
            mem = chunked_parallel_scan(k, v, i_gate, gamma)
            # Save the last step for potential continuation. It is detached so
            # a later call does not backpropagate into this sequence.
            next_rnn_state = mem[:, -1].detach().clone()
        else:
            mem, next_rnn_state = recurrent_scan(k, v, i_gate, gamma, rnn_state)

        # --- 5. Output projection & gating ---
        # Normalize memory before combining it with the query.
        mem_out = self.mem_norm(mem)

        # Element-wise Q * memory, scaled by 1/sqrt(head_dim).
        scale = 1.0 / math.sqrt(self.head_dim)
        out = (mem_out * q * scale).reshape(B, T, C)

        # Per-head group normalization over the combined output.
        out = self.grp_norm(out.reshape(B * T, C)).view(B, T, C)

        # Output gating
        out = out * torch.sigmoid(self.output_gate_proj(x_conv))

        return self.o_proj(out), (next_rnn_state, new_conv_cache)