In [None]:
from typing import Tuple, Optional, Dict, List, Union, Any

import math 
import random

import numpy as np 

import JAX

import torch 
import torch.nn as nn 
from torch.nn.utils.rnn import pack_sequence, pad_sequence, pack_padded_sequence, pad_packed_sequence 
import torch.nn.functional as F 

from einops import rearrange 

from flash_attn.flash_attention import FlashMHA, FlashAttention 
from flash_attn.bert_padding import unpad_input, pad_input 
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 

In [None]:
class PositionalEncoding(nn.Module):
    '''
    Sinusoidal positional encoding for transformer models.

    Precomputes a (max_len, 1, d_model) table holding sines at even feature
    indices and cosines at odd feature indices, and adds it to a
    sequence-first input before applying dropout.
    '''

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1) -> None:
        '''
        Initialize the positional encoding layer.

        Args:
            d_model (int): The dimension of the model embeddings (odd values are supported)
            max_len (int): Maximum sequence length for positional encoding
            dropout (float): Dropout probability applied after adding the encoding

        Returns:
            None
        '''
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        # Column vector of positions 0..max_len-1, shape (max_len, 1).
        position = torch.arange(start=0, end=max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for each even feature index 2i, computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are floor(d_model / 2) odd columns but ceil(d_model / 2) frequencies;
        # trim so odd d_model does not raise a shape mismatch (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so the table broadcasts over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Registered as a buffer: follows .to()/.cuda() and is saved in state_dict,
        # but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Apply positional encoding to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (seq_len, batch_size, d_model)

        Returns:
            torch.Tensor: Tensor with the positional encoding added, after dropout

        Raises:
            ValueError: If seq_len is larger than the max_len the table was built with.
        '''
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        # Slicing past max_len would silently truncate the table and then fail
        # with an opaque broadcast error; report the real cause instead.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [None]:
class CustomFlashAttention(nn.Module):
    '''
    Custom Flash Attention optimized to work with Tesla GPUs.
    '''
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1) -> None:
        '''
        Initialize the custom flash attention module.

        Args:
            embed_dim (int): The embedding dimension 
            num_heads (int): Number of attention heads 
            dropout (float): Dropout probability for attention weights

        Returns: 
            None
        '''
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads 
        self.dropout_p = dropout
        self.softmax_scale = embed_dim ** -0.5

    def forward(
        self,
        q: torch.Tensor, 
        k: torch.Tensor, 
        v: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor],
        casual: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_s: Optional[int] = None,
        need_weights: bool = False
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        '''
        Implements the multihead softmax attention with separate query, key and value tensors.

        Args:
            q (torch.Tensor): Query tensor of shape (B, S_q, H, D)
            k (torch.Tensor): Key tensor of shape (B, S_k, H, D)
            v (torch.Tensor): Value tensor of shape (B, S_v, H, D)
            key_padding_mask (Optional[torch.Tensor]): Boolean mask for padding, shape (B, S_k)
            casual (bool): Whether to use causal attention
            cu_seqlens (Optional[torch.Tensor]): Cumulative sequence lengths for packed sequences
            max_s (Optional[int]): Maximum sequence length
            need_weights (bool): Whether to return attention weights

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: Output tensor and optional attention weights
        '''
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.dtype == k.dtype == v.dtype 
        assert q.is_cuda and k.is_cuda and v.is_cuda 

        batch_size = q.shape[0]
        seqlen_q = q.shape[1]
        seqlen_k = k.shape[1]