In [None]:
import re
import sys
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import warnings
import gc
from accelerate import Accelerator

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.simplefilter("ignore")
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoConfig
from tqdm.notebook import tqdm

In [2]:
!pip install einops



In [None]:
from einops import rearrange, reduce
from typing import Optional, Tuple, Union, List
from dataclasses import dataclass


class AbsoluteEncoding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.pos_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.max_size = config.max_position_embeddings

    def forward(self, size: int) -> torch.Tensor:
        if self.max_size < size:
            raise ValueError(
                f"The hidden size ({size }) is more than the config max_position_embeddings {self.max_size}"
            )
        return self.pos_embeddings(self.position_ids[:, :size])


class SinusoidalEncoding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        if config.hidden_size % 2 != 0:
            raise ValueError(
                f"Cannot use SinusoidalEncoding with "
                "odd hidden dim got dim {config.hidden_size}"
            )
        self.positional_encoding = torch.zeros(
            1, config.max_position_embeddings, config.hidden_size
        )
        self.position = torch.arange(0, config.max_position_embeddings).unsqueeze(1)
        self.div_term = torch.exp(
            (
                torch.arange(0, config.hidden_size, 2, dtype=torch.float)
                * -(torch.log(torch.tensor(10000.0)) / config.hidden_size)
            )
        )

        self.positional_encoding[:, :, 0::2] = torch.sin(
            self.position.float() * self.div_term
        )
        self.positional_encoding[:, :, 1::2] = torch.cos(
            self.position.float() * self.div_term
        )

    def forward(self, seq_len: int) -> torch.Tensor:

        return self.positional_encoding[:, :seq_len]


class RotaryEmbedding(nn.Module):
    """
    RotaryEmbedding is a PyTorch module that implements rotary positional embeddings for attention mechanisms.
    Args:
        config (object): Configuration object containing the following attributes:
            hidden_size (int): The hidden size of the model.
            num_attention_heads (int): The number of attention heads.
    Attributes:
        inv_freq (torch.Tensor): A tensor containing the inverse frequencies for the rotary embeddings.
    Methods:
        forward(seq_len):
            Computes the rotary positional embeddings for a given sequence length.
            Args:
                seq_len (int): The length of the input sequence.
            Returns:
                torch.Tensor: A tensor containing the rotary positional embeddings with shape (1, seq_len, dim).
    """

    def __init__(self, config):
        super().__init__()
        dim = int(config.hidden_size // config.num_attention_heads)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i, j -> i j", t, self.inv_freq)

        return freqs[None, :, :]


def rotate_half(x):
    """
    Rotates half the hidden dimensions of the input tensor.

    Args:
        x (torch.Tensor): The input tensor to be rotated.

    Returns:
        torch.Tensor: The tensor with half of its hidden dimensions rotated.
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(
    q, k, freqs, only_q: bool = False, unsqueeze_dim=1
) -> Tuple[torch.Tensor]:
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        freqs: precalculated frqs for sin cos
        only_q: bool = False for encoder decoder
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    #     print(cos.size(),sin.size(),q.size(),k.size())
    if only_q:
        q_embed = (q * cos) + (rotate_half(q) * sin)
    else:

        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed


# To do :  Alibi

In [None]:
class AttentionSelfOutput(nn.Module):
    def __init__(
        self, config, bias: Optional[bool] = True, out_features: Optional[int] = None
    ):
        super().__init__()
        self.dense = nn.Linear(
            config.hidden_size,
            config.hidden_size if out_features is None else out_features,
            bias=bias,
        )
        self.layernorm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-6)
        )
        self.dropout = nn.Dropout(getattr(config, "attention_dropout", 0.0))

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: torch.FloatTensor of shape (batch, seq_len, embed_dim)`
            input_tensor: torch.FloatTensor of shape (batch, seq_len, embed_dim)`

        return:
               hidden_states: torch.FloatTensor of shape (batch, seq_len, embed_dim)

        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layernorm(hidden_states + input_tensor)
        return hidden_states


class DecoderAttention(nn.Module):
    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.attention_bias = getattr(config, "attention_bias", True)
        self.layer_idx = layer_idx
        # self.qkv = nn.Linear(config.hidden_size,3*config.hidden_size)
        self.query = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.key = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.value = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.out = AttentionSelfOutput(config=config, bias=self.attention_bias)
        self.num_attention_heads = config.num_attention_heads
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash and self.layer_idx == 0:  # avoid to print m times:
            print("WARNING: Flash Attention requires PyTorch >= 2.0")

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        kv_cache: List[torch.FloatTensor] = None,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)`
            Attention_mask: torch.Tensor of shape (batch,1, seq_len, seqlen)`
            freqs: Positional freqs in case of RoPE embedding
            use_cace: Optional to use kvCache
            start_pos: in case of kvCache to get store kv-cache at start_pos
        return:
               hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)

        """
        q = self.query(hidden_state)
        k = self.key(hidden_state)
        v = self.value(hidden_state)
        # transform it into batch_size x no_of_heads x seqlen x head_dim for Multihead Attention
        q = rearrange(q, "b l (h d) -> b h l d", h=self.num_attention_heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.num_attention_heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.num_attention_heads)

        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)  # apply RoPE if freqs is available

        if use_cache:
            if kv_cache is None:
                raise ValueError("you need to pass kv_cache")
            k, v = kv_cache.update(self.layer_idx, k, v, start_pos)

        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=attention_mask
        )
        # transform it back into batch_size x seqlen x hidden_dim
        out = rearrange(out, "b h l d -> b l (h d)")

        return self.out(out, hidden_state), kv_cache


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class DecoderAttentionGqa(nn.Module):
    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash and self.layer_idx == 0:  # avoid to print m times
            print("WARNING: Flash Attention requires PyTorch >= 2.0")
        self.layer_idx = layer_idx
        self.head_dim = int(config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = getattr(config, "num_key_value_heads", 4)
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        if (
            self.num_attention_heads % self.num_key_value_heads != 0
            or self.num_attention_heads < self.num_key_value_heads
        ):
            raise ValueError(
                f"num_key_value_heads {self.num_key_value_heads }  should be less than equal num_attention_heads {config.num_attention_heads} and  multiple of num_attention_heads {config.num_attention_heads} "
            )
        self.attention_bias = getattr(config, "attention_bias", True)
        self.out = AttentionSelfOutput(config=config, bias=self.attention_bias)
        self.query = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.key = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=self.attention_bias,
        )
        self.value = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=self.attention_bias,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        kv_cache: List[torch.FloatTensor] = None,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)`
            Attention_mask: torch.Tensor of shape (batch,1, seq_len, seqlen)`
            freqs: Positional freqs in case of RoPE embedding
            use_cace: Optional to use kvCache
            start_pos: in case of kvCache to get store kv-cache at start_pos
        return:
               hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)

        """
        q = self.query(hidden_state)
        k = self.key(hidden_state)
        v = self.value(hidden_state)
        # transform it into batch_size x no_of_heads x seqlen x head_dim for Multihead Attention
        q = rearrange(q, "b l (h d) -> b h l d", d=self.head_dim)
        k = rearrange(k, "b l (h d) -> b h l d", d=self.head_dim)
        v = rearrange(v, "b l (h d) -> b h l d", d=self.head_dim)
        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)  # apply RoPE if freqs is available

        if use_cache:
            if kv_cache is None:
                raise ValueError("you need to pass kv_cache")
            k, v = kv_cache.update(self.layer_idx, k, v, start_pos)

        k = repeat_kv(
            k, n_rep=self.num_key_value_groups
        )  # in case of GQA repeat k,v to make it same as q
        v = repeat_kv(v, n_rep=self.num_key_value_groups)

        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=attention_mask
        )
        # transform it back into batch_size x seqlen x hidden_dim
        out = rearrange(out, "b h l d -> b l (h d)")

        return self.out(out, hidden_state), kv_cache

In [None]:
_ACT_ = {
    "gelu": nn.GELU(),
    "leaky_relu": nn.LeakyReLU(),
    "relu6": nn.ReLU6(),
    "sigmoid": nn.Sigmoid(),
    "silu": nn.SiLU(),
    "swish": nn.SiLU(),
    "tanh": nn.Tanh(),
}


class FeedForward(nn.Module):
    def __init__(self, config, multiplier: Union[int, float] = 4) -> None:
        super().__init__()
        self.intermediate = nn.Linear(
            config.hidden_size, int(multiplier) * config.hidden_size
        )
        self.dropout = nn.Dropout(getattr(config, "attention_dropout", 0.0))
        self.layerNorm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-6)
        )
        if _ACT_.get(getattr(config, "hidden_act", None), None):
            self.act_fn = _ACT_[config.hidden_act]
        else:
            self.act_fn = nn.GELU()
        self.out = nn.Linear(int(multiplier) * config.hidden_size, config.hidden_size)

    def forward(
        self, hidden_state: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        output = self.intermediate(hidden_state)
        output = self.act_fn(output)
        output = self.out(output)
        output = self.dropout(output)
        output = self.layerNorm(output + input_tensor)
        return output

In [None]:
# mainly 2way to do one keep it into the model init like llama https://github.com/meta-llama/llama/blob/main/llama/model.py
# every attention layer have its own kv-cache storage
# or keep all attention layer kv-cache into single storage like Huggingface Transformer


from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Tuple
import torch


class DynamicCache:
    """
    A cache that grows dynamically as more tokens are generated.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self, config, is_gqa: bool = False) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = False

        self.layers = config.num_hidden_layers
        for _ in range(self.layers):
            self.key_cache.append([])
            self.value_cache.append([])

    def __len__(self) -> int:
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        if len(self.key_cache) == 0:
            return 0
        return self.key_cache[0].shape[-2]

    def update(
        self,
        index: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        start_pos: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """

        # Update the cache first iteration'

        if len(self.key_cache[index]) == 0:
            self._seen_tokens = True
            self.key_cache[index] = key_states.clone()
            self.value_cache[index] = value_states.clone()
        else:
            self.key_cache[index] = torch.cat(
                [self.key_cache[index], key_states], dim=-2
            )
            self.value_cache[index] = torch.cat(
                [self.value_cache[index], value_states], dim=-2
            )

        return self.key_cache[index], self.value_cache[index]

    def get(self, index: int) -> Tuple[torch.Tensor]:
        if self._seen_tokens:
            return self.key_cache[index], self.value_cache[index]
        else:
            raise ValueError("there is no token available in kv-cache")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if self.key_cache is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return self.max_cache_len


class StaticCache:
    """
    A cache that grows dynamically as more tokens are generated.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(
        self,
        config,
        max_cache_len: int = None,
        dtype: torch.dtype = torch.float32,
        batch_size: int = 1,
        is_gqa: bool = False,
    ) -> None:
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.heads = None
        self.batch_size = batch_size
        # if is_gqa:
        self.heads = getattr(config, "num_key_value_heads", None)
        # if self.heads is None:
        #     raise ValueError(
        #         "you are using is_gqa=True and config.num_key_value_heads is not available"
        #     )
        if self.heads is None:

            self.heads = config.num_attention_heads

        self.max_cache_len = (
            config.max_position_embeddings if max_cache_len is None else max_cache_len
        )

        self.dtype = dtype

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        self.cache_shape = (
            self.batch_size,
            self.heads,
            self.max_cache_len,
            self.head_size,
        )

        self._seen_tokens = False
        self.layers = config.num_hidden_layers
        for _ in range(self.layers):
            blank_key_cache = torch.zeros(
                self.cache_shape, dtype=self.dtype, device=self.device
            )
            blank_value_cache = torch.zeros(
                self.cache_shape, dtype=self.dtype, device=self.device
            )
            self.key_cache.append(blank_key_cache)
            self.value_cache.append(blank_value_cache)

    def __len__(self) -> int:
        if self.key_cache is None:
            return 0
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return self.key_cache.shape[-2]

    def update(
        self,
        index: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        start_pos: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """

        # Update the cache first iteration'

        bsz, head, seqlen, _ = key_states.shape
        if seqlen > self.key_cache[index].size()[2]:
            raise ValueError(
                f"{k.shape} is more than init k_cache size {self.key_cache}"
            )

        self.key_cache[index][:bsz, :, start_pos : start_pos + seqlen] = key_states
        self.value_cache[index][:bsz, :, start_pos : start_pos + seqlen] = value_states

        k = self.key_cache[index][:bsz, :, : start_pos + seqlen]
        v = self.value_cache[index][:bsz, :, : start_pos + seqlen]

        return k, v

    def get(self, index: int) -> Tuple[torch.Tensor]:
        if self._seen_tokens:
            return self.key_cache[index], self.value_cache[index]
        else:
            raise ValueError("there is no token available in kv-cache")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if self.key_cache is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

In [None]:
# GPT style model Casual language modeling

_position_embeddings = {
    "absolute": AbsoluteEncoding,
    "sinusoidal": SinusoidalEncoding,
}


@dataclass
class DecoderOutput(object):
    logits: torch.Tensor
    past_key_value: Optional[object]


@dataclass
class CLMOutput(object):
    hidden_state: torch.Tensor
    logits: torch.Tensor
    kv_cache: List[torch.FloatTensor] = None


class DecoderLayer(nn.Module):

    def __init__(self, config, layer_idx: int, attention_type: str = None) -> None:
        super().__init__()
        self.attention = (
            DecoderAttentionGqa(config, layer_idx=layer_idx)
            if attention_type == "gqa"
            else DecoderAttention(config, layer_idx=layer_idx)
        )
        if attention_type == "gqa" and layer_idx == 0:  # avoid to print m times
            print("Decoder Using GQA Attention")
        self.feed_forward = FeedForward(config)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        kv_cache: List[torch.FloatTensor] = None,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        out, kv_cache = self.attention(
            hidden_state=hidden_state,
            attention_mask=attention_mask,
            freqs=freqs,
            use_cache=use_cache,
            kv_cache=kv_cache,
            start_pos=start_pos,
        )
        out = self.feed_forward(out, hidden_state)
        return out, kv_cache


class LMHead(nn.Module):
    """Head for masked language modelling"""

    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layerNorm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-6)
        )

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_state)
        x = nn.GELU()(x)
        x = self.layerNorm(x)

        # project back to size of vocabulary with bias
        x = self.decoder(x)

        return x


class DecoderModel(nn.Module):

    def __init__(
        self,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=getattr(config, "pad_token_id", None),
        )
        if _position_embeddings.get(pos_embedding_type, None) is not None:
            self.position_embeddings = _position_embeddings.get(pos_embedding_type)(
                config
            )
        else:
            self.position_embeddings = None
        if pos_embedding_type == "rope":
            self.emb_freq = RotaryEmbedding(config)(config.max_position_embeddings)
            print(
                "Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable"
            )
        self.all_layer = nn.ModuleList(
            [
                DecoderLayer(config, layer_idx, attention_type)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.lm_head = LMHead(config=config)
        self.config = config

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / torch.sqrt(2 * len(self.all_layer))
            )
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / torch.sqrt(2 * len(self.all_layer))
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        kv_cache: List[torch.FloatTensor] = None,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        _bsz, seqlen = input_ids.shape
        hidden_state = self.word_embeddings(input_ids)
        freqs = None
        if self.position_embeddings is not None:
            pos_info = self.position_embeddings(start_pos + seqlen)[
                :, start_pos : start_pos + seqlen, :
            ].to(input_ids.device)
            hidden_state = hidden_state + pos_info
        else:
            freqs = self.emb_freq[:, start_pos : start_pos + seqlen].to(
                input_ids.device
            )
        mask = None
        if seqlen > 1:
            mask = self.create_mask_for_decoder(
                input_ids=input_ids, attention_mask=attention_mask, start_pos=start_pos
            )
            mask = (1.0 - mask) * torch.finfo(
                hidden_state.dtype
            ).min  # invert it to to add directly to attention score

        for layer in self.all_layer:
            hidden_state, kv_cache = layer(
                hidden_state,
                mask,
                freqs=freqs,
                use_cache=use_cache,
                kv_cache=kv_cache,
                start_pos=start_pos,
            )
        logits = self.lm_head(hidden_state)
        return CLMOutput(hidden_state=hidden_state, logits=logits, kv_cache=kv_cache)

    def create_mask_for_decoder(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        device = input_ids.device
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = (
                torch.ones(seq_length + start_pos).repeat(batch_size, 1).to(device)
            )
        seq_ids = torch.arange(seq_length).to(device)
        causal_mask = (
            seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
            <= seq_ids[None, :, None]
        )  # 1x1xl repeat bxlxl compare to 1xlx1

        causal_mask = causal_mask.to(attention_mask.dtype)

        if start_pos > 0:  # correct the attention mask  for kv-cache operation
            causal_mask = torch.cat(
                [
                    torch.ones(
                        (batch_size, seq_length, start_pos),
                        device=device,
                        dtype=causal_mask.dtype,
                    ),
                    causal_mask,
                ],
                axis=-1,
            )

        extended_attention_mask = (
            causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        )  # # this is mainly if batch contains <PAD> tokens. stop casual procees before <PAD>
        return extended_attention_mask

    @classmethod
    def from_config(
        cls,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: Optional[str] = None,
    ) -> nn.Module:
        return cls(config, pos_embedding_type, attention_type)

    def _setup_cache(self, config, cls: Optional[object] = StaticCache) -> None:
        for layer in self.all_layer:
            layer.attention.cache = cls(config)

    def _clean_cache(self) -> None:
        for layer in self.all_layer:
            layer.attention.cache = None

In [8]:
from einops import rearrange

In [None]:
m = AutoModelForCausalLM.from_pretrained(
    "../input/transformer-distilation-gpt-2/gpt2_6L"
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [11]:
tokenizer.model_max_length

1024

In [None]:
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Set reasonable default for models without max length
if tokenizer.model_max_length > 512:
    tokenizer.model_max_length = 512

In [13]:
state_dict = m.state_dict()

In [None]:
string = open(
    "/kaggle/input/mark-twain-books/Combine.txt", encoding="utf8", errors="ignore"
).read()

new_str = re.sub("�", "", string)

open("Train.txt", "w").write(new_str)

6588596

In [15]:
model_ckpt = "roberta-base"
config = AutoConfig.from_pretrained(model_ckpt)

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [16]:
config

RobertaConfig {
  "_name_or_path": "roberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.47.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

In [None]:
from types import SimpleNamespace
from collections import namedtuple

In [18]:
config = SimpleNamespace(**config.__dict__)
config.vocab_size = len(tokenizer)
config.num_hidden_layers = 6

In [None]:
model = DecoderModel.from_config(config, pos_embedding_type="rope")

Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable


In [20]:
model.word_embeddings = nn.Embedding.from_pretrained(
    state_dict["transformer.wte.weight"], freeze=False
)

**Data Source**

https://www.kaggle.com/datasets/msinger007/mark-twain-books

In [None]:
train_path = "Train.txt"

In [None]:
class TextDataset(Dataset):

    def __init__(self, tokenizer, file_path: str, block_size: int):
        if os.path.isfile(file_path) is False:
            raise ValueError(f"Input file path {file_path} not found")

        block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
        saved = False
        cache_dir = None
        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else directory,
            f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
        )

        if os.path.exists(cached_features_file) and saved:
            start = time.time()
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        #                 logger.info(
        #                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
        #                 )

        else:
            #                 logger.info(f"Creating features from dataset file at {directory}")

            self.examples = []
            with open(file_path, encoding="utf-8") as f:
                text = f.read()

            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

            for i in range(
                0, len(tokenized_text) - block_size + 1, block_size
            ):  # Truncate in block of block_size
                self.examples.append(
                    tokenizer.build_inputs_with_special_tokens(
                        tokenized_text[i : i + block_size]
                    )
                )
                # )
            # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
            # If your dataset is small, first you should look for a bigger one :-) and second you
            # can change this behavior by adding (model specific) padding.

            start = time.time()
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
                saved = True

    #                 logger.info(
    #                     f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
    #                 )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        return {"input_ids": torch.tensor(self.examples[i], dtype=torch.long)}

In [23]:
def collate(batch):
    labels = batch["input_ids"].clone()
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch

In [None]:
train_loader = torch.utils.data.DataLoader(
    TextDataset(tokenizer, train_path, 128), batch_size=16, shuffle=True, num_workers=2
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1580900 > 512). Running this sequence through the model will result in indexing errors


In [25]:
# for data in train_loader:
#     break

In [26]:
# out = model(**data)

In [None]:
lr = 5e-5
# train it fully
no_decay = ["bias", "layerNorm.weight", "layerNorm.bias"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        "params": [
            p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

In [None]:
EPOCHS = 5
accumulation_steps = 2
num_train_optimization_steps = int(EPOCHS * len(train_loader) / accumulation_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0.05 * num_train_optimization_steps,
    num_training_steps=num_train_optimization_steps,
)

In [None]:
def loss_fn(labels, prediction_scores):
    shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    lm_loss = loss_fct(
        shifted_prediction_scores.view(-1, config.vocab_size), labels.view(-1)
    )
    return lm_loss

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def train(train_loader=train_loader, model=model):
    best_epoch_loss = np.inf
    model.to(device)
    accelerator = Accelerator()
    Config = {
        "num_epoch": EPOCHS,
        "learning_rate": lr,
        "loss_function": str(torch.nn.CrossEntropyLoss),
    }
    model.train()

    for epoch in range(EPOCHS):
        start_time = time.time()
        avg_loss = 0.0
        model.train()
        tbar = tqdm(train_loader, file=sys.stdout)
        loss_list = []
        tbar.set_description(f"Epoch {epoch+1}")
        for step, data in enumerate(tbar):
            data = collate(data)
            x = data["input_ids"].to(device)
            y = data["labels"].to(device)
            optimizer.zero_grad()
            pred = model(input_ids=x).logits
            loss = loss_fn(y, pred)
            loss.backward()
            optimizer.step()
            scheduler.step()

            tbar.set_postfix(loss=loss.item())
            loss_list.append(loss.detach().cpu().item())

        avg_loss = np.round(np.mean(loss_list), 4)

        print(f"Epoch--{epoch+1} ### Train loss---{avg_loss}")

    PATH = f"decoder__{epoch}.pth"
    torch.save(model.state_dict(), PATH)

    del train_loader
    gc.collect()

In [32]:
train()

  0%|          | 0/772 [00:00<?, ?it/s]

Epoch--1 ### Train loss---6.3507


  0%|          | 0/772 [00:00<?, ?it/s]

Epoch--2 ### Train loss---5.2408


  0%|          | 0/772 [00:00<?, ?it/s]

Epoch--3 ### Train loss---5.0446


  0%|          | 0/772 [00:00<?, ?it/s]

Epoch--4 ### Train loss---5.0258


  0%|          | 0/772 [00:00<?, ?it/s]

Epoch--5 ### Train loss---5.0258


In [33]:
model.eval()

DecoderModel(
  (word_embeddings): Embedding(50257, 768)
  (all_layer): ModuleList(
    (0-5): 6 x DecoderLayer(
      (attention): DecoderAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (out): AttentionSelfOutput(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (feed_forward): FeedForward(
        (intermediate): Linear(in_features=768, out_features=3072, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (layerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (act_fn): GELU(approximate='none')
        (out): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (lm_head): LMHead(


In [None]:
@torch.no_grad()
def generate(
    model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_len: int = 20,
    temperature: float = 1.0,
    use_cache: bool = True,
    do_sample: bool = False,
    use_static_cache: bool = False,
) -> torch.Tensor:

    device = input_ids.device

    all_prompt_size = [t.size()[0] for t in input_ids]

    min_prompt_len = min(all_prompt_size)
    max_prompt_len = max(all_prompt_size)

    max_len = (
        max_len + max_prompt_len
    )  # get  max len (prompt + to be generated token combined)

    pad_id = getattr(model.config, "pad_token_id", 50256)
    bsz, _ = input_ids.size()
    tokens = torch.full((bsz, max_len), pad_id, dtype=torch.long, device=device)

    kv_cache = None
    if use_cache:
        if use_static_cache:
            kv_cache = StaticCache(model.config, max_cache_len=max_len, batch_size=bsz)
        else:
            kv_cache = DynamicCache(model.config)

    for k, t in enumerate(input_ids):
        tokens[k, : t.size()[0]] = t

    prev_pos = torch.tensor(0, device=device)
    eos_reached = torch.tensor([False] * bsz, device=device)
    # to break generation if eos reached for all  prompt

    input_text_mask = tokens != pad_id  # mask to fill generated values into batch

    stop_tokens = torch.tensor(
        getattr(model.config, "eos_token_id", 50256), device=device
    )
    for cur_pos in range(min_prompt_len, max_len):

        # Get the model output
        with torch.no_grad():
            outputs = model(
                input_ids=tokens[:, prev_pos:cur_pos],
                attention_mask=attention_mask,
                use_cache=use_cache,
                kv_cache=kv_cache,
                start_pos=prev_pos,
            )
        kv_cache = outputs.kv_cache
        next_token_logits = outputs.logits[:, -1] / temperature

        if do_sample:
            next_token = torch.multinomial(next_token_logits, num_samples=1)
        else:
            _, next_token = torch.topk(next_token_logits, k=1, dim=-1)

        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
            torch.isin(next_token, stop_tokens)
        )

        if use_cache:
            prev_pos = cur_pos

        attention_mask = torch.cat(
            [attention_mask, torch.ones((bsz, 1), device=device)], dim=-1
        )
        if all(eos_reached):
            break
    return tokens

In [None]:
text = tokenizer(
    ["this is a test, blue", "Well, sir, you could"], return_tensors="pt", padding=True
)

In [43]:
text

{'input_ids': tensor([[ 5661,   318,   257,  1332,    11,  4171],
        [ 5779,    11, 15967,    11,   345,   714]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1]])}

In [44]:
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [None]:
input_ids, attention_mask = text["input_ids"], text["attention_mask"]
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)

In [None]:
out = generate(
    model, input_ids=input_ids, attention_mask=attention_mask, use_cache=False
)

tokenizer.batch_decode(out)

['this is a test, blue-day, and the  first time the same.  The boys were not a good deal.',
 'Well, sir, you couldnt  be aint no, but I was a good deal. I was a good deal of']

In [None]:
out = generate(
    model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
)

tokenizer.batch_decode(out)

['this is a test, blue-day, and the  first time the same.  The boys were not a good deal.',
 'Well, sir, you couldnt  be aint no, but I was a good deal. I was a good deal of']

In [None]:
out = generate(
    model,
    input_ids=input_ids,
    attention_mask=attention_mask,
    use_cache=True,
    use_static_cache=True,
)

tokenizer.batch_decode(out)

['this is a test, blue-day, and the  first time the same.  The boys were not a good deal.',
 'Well, sir, you couldnt  be aint no, but I was a good deal. I was a good deal of']