In [1]:
import re
import sys
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import warnings
import gc
from accelerate import Accelerator

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.simplefilter("ignore")
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer, AutoModelWithLMHead, AutoConfig
from tqdm.notebook import tqdm



In [2]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
[0m

In [3]:
from einops import rearrange, reduce
from typing import Optional, Tuple, Union, List
from dataclasses import dataclass


class AbsoluteEncoding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.pos_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.max_size = config.max_position_embeddings

    def forward(self, size: int) -> torch.Tensor:
        if self.max_size < size:
            raise ValueError(
                f"The hidden size ({size }) is more than the config max_position_embeddings {self.max_size}"
            )
        return self.pos_embeddings(self.position_ids[:, :size])


class SinusoidalEncoding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        if config.hidden_size % 2 != 0:
            raise ValueError(
                f"Cannot use SinusoidalEncoding with "
                "odd hidden dim got dim {config.hidden_size}"
            )
        self.positional_encoding = torch.zeros(
            1, config.max_position_embeddings, config.hidden_size
        )
        self.position = torch.arange(0, config.max_position_embeddings).unsqueeze(1)
        self.div_term = torch.exp(
            (
                torch.arange(0, config.hidden_size, 2, dtype=torch.float)
                * -(torch.log(torch.tensor(10000.0)) / config.hidden_size)
            )
        )

        self.positional_encoding[:, :, 0::2] = torch.sin(
            self.position.float() * self.div_term
        )
        self.positional_encoding[:, :, 1::2] = torch.cos(
            self.position.float() * self.div_term
        )

    def forward(self, seq_len: int) -> torch.Tensor:

        return self.positional_encoding[:, :seq_len]


# copied from transformer/models/gemma
class RotaryEmbedding(nn.Module):
    def __init__(self, config, base=10000, device=None):
        super().__init__()

        self.dim = int(config.hidden_size // config.num_attention_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.base = base
        self.register_buffer(
            "inv_freq",
            1.0
            / (
                self.base
                ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
            ),
            persistent=False,
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    @torch.no_grad()
    def forward(self, seq_len: int = None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # size = x.size()[2]
        position_ids = torch.arange(seq_len).unsqueeze(0)
        # position_ids = self.position_ids[:, :size].float()

        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        return freqs


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# def rotate_half(x):
#     x1, x2 = x.chunk(2, dim=-1)
#     return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(
    q, k, freqs, only_q: bool = False, unsqueeze_dim=1
) -> Tuple[torch.Tensor]:
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        freqs: precalculated frqs for sin cos
        only_q: bool = False for encoder decoder
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    #     print(cos.size(),sin.size(),q.size(),k.size())
    if only_q:
        q_embed = (q * cos) + (rotate_half(q) * sin)
    else:

        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed


# To do :  Alibi

In [4]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def repeat_kv_einops(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = repeat(
        hidden_states,
        "batch num_key_value_heads slen head_dim -> batch num_key_value_heads n_rep slen head_dim",
        n_rep=n_rep,
    )  # hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    # return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return rearrange(
        hidden_states,
        "batch num_key_value_heads n_rep slen head_dim -> batch (num_key_value_heads n_rep) slen head_dim",
    )


class DecoderAttention(nn.Module):
    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.attention_bias = getattr(config, "attention_bias", True)
        self.layer_idx = layer_idx
        # self.qkv = nn.Linear(config.hidden_size,3*config.hidden_size)
        self.q = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.k = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.v = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.out = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.num_attention_heads = config.num_attention_heads
        self.rotary_emb = (
            RotaryEmbedding(config=config) if getattr(config, "is_rope", None) else None
        )
        if self.rotary_emb != None and self.layer_idx == 0:  # avoid to print m times:
            print("Decoder Using Rotatry Embedding")
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash and self.layer_idx == 0:  # avoid to print m times:
            print("WARNING: Flash Attention requires PyTorch >= 2.0")

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> Tuple[torch.Tensor, object]:
        q = self.q(hidden_state)
        k = self.k(hidden_state)
        v = self.v(hidden_state)
        # q,k,v = self.qkv(hidden_state).chunk(3, dim = -1) #b X l X d dim =-1 or 2
        # place holder for RoPe operation
        q = rearrange(q, "b l (h d) -> b h l d", h=self.num_attention_heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.num_attention_heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.num_attention_heads)

        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)

        if use_cache:
            cache = getattr(self, "cache", None)
            if cache is None:
                raise ValueError(
                    "you need to setup cache for every attention layer with model.setup_cache()"
                )
            k, v = cache.update(k, v, start_pos)

        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=attention_mask
        )
        out = rearrange(out, "b h l d -> b l (h d)")
        out = self.out(out)
        return out


class DecoderAttentionGqa(nn.Module):
    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash and self.layer_idx == 0:  # avoid to print m times
            print("WARNING: Flash Attention requires PyTorch >= 2.0")
        self.layer_idx = layer_idx
        self.head_dim = int(config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = getattr(config, "num_key_value_heads", 4)
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        if (
            self.num_attention_heads % self.num_key_value_heads != 0
            or self.num_attention_heads < self.num_key_value_heads
        ):
            raise ValueError(
                f"num_key_value_heads {self.num_key_value_heads }  should be less than equal num_attention_heads {config.num_attention_heads} and  multiple of num_attention_heads {config.num_attention_heads} "
            )
        self.attention_bias = getattr(config, "attention_bias", True)
        self.out = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.q = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.k = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=self.attention_bias,
        )
        self.v = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=self.attention_bias,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        q = self.q(hidden_state)
        k = self.k(hidden_state)
        v = self.v(hidden_state)
        q = rearrange(q, "b l (h d) -> b h l d", d=self.head_dim)
        k = rearrange(k, "b l (h d) -> b h l d", d=self.head_dim)
        v = rearrange(v, "b l (h d) -> b h l d", d=self.head_dim)
        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)

        if use_cache:
            cache = getattr(self, "cache", None)
            if cache is None:
                raise ValueError(
                    "you need to setup cache for every attention layer with model._setup_cache() before using it"
                )
            k, v = cache.update(k, v, start_pos)

        k = repeat_kv(k, n_rep=self.num_key_value_groups)
        v = repeat_kv(v, n_rep=self.num_key_value_groups)

        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=attention_mask
        )
        out = rearrange(out, "b h l d -> b l (h d)")
        out = self.out(out)
        return out

In [5]:
_ACT_ = {
    "gelu": nn.GELU(),
    "leaky_relu": nn.LeakyReLU(),
    "relu6": nn.ReLU6(),
    "sigmoid": nn.Sigmoid(),
    "silu": nn.SiLU(),
    "swish": nn.SiLU(),
    "tanh": nn.Tanh(),
}


class FeedForward(nn.Module):
    def __init__(self, config, multiplier: Union[int, float] = 4) -> None:
        super().__init__()
        self.intermediate = nn.Linear(
            config.hidden_size, int(multiplier) * config.hidden_size
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        if _ACT_.get(getattr(config, "hidden_act", None), None):
            self.act_fn = _ACT_[config.hidden_act]
        else:
            self.act_fn = nn.GELU()
        self.out = nn.Linear(int(multiplier) * config.hidden_size, config.hidden_size)

    def forward(
        self, hidden_state: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        output = self.intermediate(hidden_state)
        output = self.act_fn(output)
        output = self.out(output)
        output = self.dropout(output)
        output = self.layerNorm(output + input_tensor)
        return output

In [6]:
class DynamicCache:
    """
    A cache that grows dynamically as more tokens are generated.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self, config) -> None:
        self.key_cache: torch.Tensor = None
        self.value_cache: torch.Tensor = None
        self._seen_tokens = False

    def __len__(self) -> int:
        if self.key_cache is None:
            return 0
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return self.key_cache.shape[-2]

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, start_pos: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """

        # Update the cache first iteration
        if self.key_cache is None:
            self._seen_tokens = True
            self.key_cache = key_states.clone()
            self.value_cache = value_states.clone()
        else:
            self.key_cache = torch.cat([self.key_cache, key_states], dim=-2)
            self.value_cache = torch.cat([self.value_cache, value_states], dim=-2)

        return self.key_cache, self.value_cache

    def get(self) -> Tuple[torch.Tensor]:
        if self._seen_tokens:
            return self.key_cache, self.value_cache
        else:
            raise ValueError("there is no token available in kv-cache")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if self.key_cache is None:
            return 0
        return self.key_cache.shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None


class StaticCache:
    """
    A cache that is size fixed suitable for torch.compile

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self, config) -> None:
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        self.key_cache: torch.Tensor = torch.zeros(
            1,
            self.heads,
            config.max_position_embeddings,
            self.head_size,
        )
        self.value_cache: torch.Tensor = torch.zeros(
            1,
            self.heads,
            config.max_position_embeddings,
            self.head_size,
        )
        self._seen_tokens = False

    def update(
        self, k: torch.Tensor, v: torch.Tensor, start_pos: int = 0
    ) -> Tuple[torch.Tensor]:
        self._seen_tokens = True
        bsz, head, seqlen, _ = k.shape
        assert bsz == 1, "Only support batch size 1"

        self.key_cache = self.key_cache.to(k)
        self.value_cache = self.value_cache.to(v)

        self.key_cache[:bsz, :, start_pos : start_pos + seqlen] = k
        self.value_cache[:bsz, :, start_pos : start_pos + seqlen] = v

        k = self.key_cache[:bsz, :, : start_pos + seqlen]
        v = self.value_cache[:bsz, :, : start_pos + seqlen]

        return k, v

    def get(self) -> Tuple[torch.Tensor]:
        if self._seen_tokens:
            return self.key_cache, self.value_cache
        else:
            raise ValueError("there is no token available in kv-cache")

    def __len__(self) -> int:
        if self._seen_tokens == False:
            return 0
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return self.key_cache.shape[2]

In [7]:
# GPT style model Casual language modeling

_position_embeddings = {
    "absolute": AbsoluteEncoding,
    "sinusoidal": SinusoidalEncoding,
}


@dataclass
class DecoderOutput(object):
    logits: torch.Tensor
    past_key_value: Optional[object]


@dataclass
class CLMOutput(object):
    hidden_state: torch.Tensor
    logits: torch.Tensor


class DecoderLayer(nn.Module):

    def __init__(self, config, layer_idx: int, attention_type: str = None) -> None:
        super().__init__()
        self.attention = (
            DecoderAttentionGqa(config, layer_idx=layer_idx)
            if attention_type == "gqa"
            else DecoderAttention(config, layer_idx=layer_idx)
        )
        if attention_type == "gqa" and layer_idx == 0:  # avoid to print m times
            print("Decoder Using GQA Attention")
        self.feed_forward = FeedForward(config)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        out = self.attention(
            hidden_state=hidden_state,
            attention_mask=attention_mask,
            freqs=freqs,
            use_cache=use_cache,
            start_pos=start_pos,
        )
        out = self.feed_forward(out, hidden_state)
        return out


class LMHead(nn.Module):
    """Head for masked language modelling"""

    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_state)
        x = nn.GELU()(x)
        x = self.layerNorm(x)

        # project back to size of vocabulary with bias
        x = self.decoder(x)

        return x


class DecoderModel(nn.Module):

    def __init__(
        self,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=getattr(config, "pad_token_id", None),
        )
        if _position_embeddings.get(pos_embedding_type, None) is not None:
            self.position_embeddings = _position_embeddings.get(pos_embedding_type)(
                config
            )
        else:
            self.position_embeddings = None
        if pos_embedding_type == "rope":
            self.emb_freq = RotaryEmbedding(config)(config.max_position_embeddings)
            print(
                "Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable"
            )
        self.all_layer = nn.ModuleList(
            [
                DecoderLayer(config, layer_idx, attention_type)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.lm_head = LMHead(config=config)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / torch.sqrt(2 * len(self.all_layer))
            )
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / torch.sqrt(2 * len(self.all_layer))
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        _bsz, seqlen = input_ids.shape
        hidden_state = self.word_embeddings(input_ids)
        freqs = None
        if self.position_embeddings is not None:
            pos_info = self.position_embeddings(start_pos + seqlen)[
                :, start_pos : start_pos + seqlen, :
            ].to(input_ids.device)
            hidden_state = hidden_state + pos_info
        else:
            freqs = self.emb_freq[:, start_pos : start_pos + seqlen].to(
                input_ids.device
            )
        mask = None
        if seqlen > 1:
            mask = self.create_mask_for_decoder(
                input_ids=input_ids, attention_mask=attention_mask, start_pos=start_pos
            )
            mask = (1.0 - mask) * torch.finfo(
                hidden_state.dtype
            ).min  # invert it to to add directly to attention score

        for layer in self.all_layer:
            hidden_state = layer(
                hidden_state,
                mask,
                freqs=freqs,
                use_cache=use_cache,
                start_pos=start_pos,
            )
        logits = self.lm_head(hidden_state)
        return CLMOutput(hidden_state=hidden_state, logits=logits)

    def create_mask_for_decoder(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        start_pos: Optional[int] = 0,
    ) -> torch.Tensor:
        device = input_ids.device
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = (
                torch.ones(seq_length + start_pos).repeat(batch_size, 1).to(device)
            )
        seq_ids = torch.arange(seq_length).to(device)
        causal_mask = (
            seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
            <= seq_ids[None, :, None]
        )  # 1x1xl repeat bxlxl compare to 1xlx1

        causal_mask = causal_mask.to(attention_mask.dtype)

        if start_pos > 0:  # correct the attention mask  for kv-cache operation
            causal_mask = torch.cat(
                [
                    torch.ones(
                        (batch_size, seq_length, start_pos),
                        device=device,
                        dtype=causal_mask.dtype,
                    ),
                    causal_mask,
                ],
                axis=-1,
            )

        extended_attention_mask = (
            causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        )  # # this is mainly if batch contains <PAD> tokens. stop casual procees before <PAD>
        return extended_attention_mask

    @classmethod
    def from_config(
        cls,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: Optional[str] = None,
    ) -> nn.Module:
        return cls(config, pos_embedding_type, attention_type)

    def _setup_cache(self, config, cls: Optional[object] = StaticCache) -> None:
        for layer in self.all_layer:
            layer.attention.cache = cls(config)

    def _clean_cache(self) -> None:
        for layer in self.all_layer:
            layer.attention.cache = None

In [8]:
from einops import rearrange

model_ckpt = "roberta-base"
config = AutoConfig.from_pretrained(model_ckpt)

Downloading config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [9]:
m = AutoModelWithLMHead.from_pretrained(
    "../input/transformer-distilation-gpt-2/gpt2_6L"
)
# config = AutoConfig.from_pretrained('../input/transformer-distilation-gpt-2/gpt2_6L')

In [10]:
state_dict = m.state_dict()

In [11]:
string = open(
    "/kaggle/input/mark-twain-books/Combine.txt", encoding="utf8", errors="ignore"
).read()
new_str = re.sub("�", "", string)
open("Train.txt", "w").write(new_str)

6588596

In [12]:
config.vocab_size = 50257
config.num_hidden_layers = 6
# config.is_rope = True

In [13]:
model = DecoderModel.from_config(config, pos_embedding_type="rope")

Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable


In [14]:
model.word_embeddings = nn.Embedding.from_pretrained(
    state_dict["transformer.wte.weight"], freeze=False
)

**Copy Embedding for faster convergence**

**Data Source**

https://www.kaggle.com/datasets/msinger007/mark-twain-books

In [15]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

train_path = "Train.txt"

Downloading tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [16]:
class TextDataset(Dataset):

    def __init__(self, tokenizer, file_path: str, block_size: int):
        if os.path.isfile(file_path) is False:
            raise ValueError(f"Input file path {file_path} not found")

        block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
        saved = False
        cache_dir = None
        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else directory,
            f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
        )

        if os.path.exists(cached_features_file) and saved:
            start = time.time()
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        #                 logger.info(
        #                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
        #                 )

        else:
            #                 logger.info(f"Creating features from dataset file at {directory}")

            self.examples = []
            with open(file_path, encoding="utf-8") as f:
                text = f.read()

            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

            for i in range(
                0, len(tokenized_text) - block_size + 1, block_size
            ):  # Truncate in block of block_size
                self.examples.append(
                    tokenizer.build_inputs_with_special_tokens(
                        tokenized_text[i : i + block_size]
                    )
                )
            # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
            # If your dataset is small, first you should look for a bigger one :-) and second you
            # can change this behavior by adding (model specific) padding.

            start = time.time()
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
                saved = True

    #                 logger.info(
    #                     f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
    #                 )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        return {"input_ids": torch.tensor(self.examples[i], dtype=torch.long)}

In [17]:
def collate(batch):
    labels = batch["input_ids"].clone()
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch

In [18]:
train_loader = torch.utils.data.DataLoader(
    TextDataset(tokenizer, train_path, 128), batch_size=24, shuffle=True, num_workers=2
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1580900 > 1024). Running this sequence through the model will result in indexing errors


In [20]:
lr = 5e-5
# train it fully
no_decay = ["bias", "layerNorm.weight", "layerNorm.bias"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        "params": [
            p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

In [21]:
EPOCHS = 7
accumulation_steps = 2
num_train_optimization_steps = int(EPOCHS * len(train_loader) / accumulation_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0.05 * num_train_optimization_steps,
    num_training_steps=num_train_optimization_steps,
)

In [22]:
def loss_fn(labels, prediction_scores):
    shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    lm_loss = loss_fct(
        shifted_prediction_scores.view(-1, config.vocab_size), labels.view(-1)
    )
    return lm_loss

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
# for step, data in enumerate(train_loader):
#     data =  collate(data)
#     x = data["input_ids"].to(device)
#     y = data['labels'].to(device)
#     break

In [25]:
# model()

In [26]:
# out = model(input_ids = x)

In [27]:
# out.logits.size()

In [28]:
def train(train_loader=train_loader, model=model):
    best_epoch_loss = np.inf
    model.to(device)
    accelerator = accelerator = Accelerator(log_with="tensorboard", logging_dir=".")
    Config = {
        "num_epoch": EPOCHS,
        "learning_rate": lr,
        "loss_function": str(torch.nn.CrossEntropyLoss),
    }
    model.train()
    accelerator.init_trackers(f"CLM_project", config=Config)
    for epoch in range(EPOCHS):
        start_time = time.time()
        avg_loss = 0.0
        model.train()
        tbar = tqdm(train_loader, file=sys.stdout)
        loss_list = []
        tbar.set_description(f"Epoch {epoch+1}")
        for step, data in enumerate(tbar):
            data = collate(data)
            x = data["input_ids"].to(device)
            y = data["labels"].to(device)
            optimizer.zero_grad()
            pred = model(input_ids=x).logits
            loss = loss_fn(y, pred)
            loss.backward()
            optimizer.step()
            scheduler.step()
            accelerator.log({"train_step": loss.item()}, step=step)
            tbar.set_postfix(loss=loss.item())
            loss_list.append(loss.detach().cpu().item())

        avg_loss = np.round(np.mean(loss_list), 4)
        accelerator.log({"train_step_epoch": avg_loss}, step=epoch + 1)

        print(f"Epoch--{epoch+1} ### Train loss---{avg_loss}")

    PATH = f"decoder__{epoch}.pth"
    torch.save(model.state_dict(), PATH)
    accelerator.end_training()
    del train_loader
    gc.collect()

In [29]:
train()

  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--1 ### Train loss---5.8968


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--2 ### Train loss---4.5792


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--3 ### Train loss---4.1407


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--4 ### Train loss---3.8287


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--5 ### Train loss---3.6911


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--6 ### Train loss---3.6913


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--7 ### Train loss---3.6913


  0%|          | 0/515 [00:00<?, ?it/s]

Epoch--8 ### Train loss---3.6911


In [None]:
# model = DecoderModel.from_config(config)
# model.load_state_dict(torch.load('/kaggle/input/update-gpt2-scratch-kv-cache-sdpa/gpt2_epoch__9.pth'))

In [30]:
model.eval()

DecoderModel(
  (word_embeddings): Embedding(50257, 768)
  (all_layer): ModuleList(
    (0-5): 6 x DecoderLayer(
      (attention): DecoderAttention(
        (q): Linear(in_features=768, out_features=768, bias=True)
        (k): Linear(in_features=768, out_features=768, bias=True)
        (v): Linear(in_features=768, out_features=768, bias=True)
        (out): Linear(in_features=768, out_features=768, bias=True)
      )
      (feed_forward): FeedForward(
        (intermediate): Linear(in_features=768, out_features=3072, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (layerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (act_fn): GELU(approximate='none')
        (out): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (lm_head): LMHead(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (layerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (decoder): Linear(in_features=768, out_features

In [31]:
def generate(
    model: nn.Module,
    tokenize_text: torch.Tensor,
    max_new_tokens: Optional[int] = 128,
    temperature: Optional[float] = 1.0,
    do_sample: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> torch.Tensor:
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (1,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    #     text = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    idx = tokenize_text
    idx_next = idx
    index = 0
    take = -1
    #     for cur_pos in range(min_promp, total_len)
    for _ in range(max_new_tokens):
        if use_cache == False:
            with torch.no_grad():
                logits = model(input_ids=idx).logits
        else:
            with torch.no_grad():
                logits = model(
                    input_ids=idx_next, start_pos=index, use_cache=use_cache
                ).logits

        if take != 0:
            logits = logits[:, take, :] / temperature
            if use_cache == True:
                take = 0
        else:
            logits = logits[:, -1] / temperature
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)

        idx = torch.cat((idx, idx_next), dim=1)
        index = idx.size()[1] - 1  # model already have idx-1 kv-cache stored

    return idx

In [32]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**This just for demonstrated purpose you need more data for fluent model**

In [33]:
text = tokenizer.encode(
    "this of the test", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text)
tokenizer.decode(out[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)

'this of the test.  The result was not a good deal. He was a good deal of a good deal of  the matter, and he was a good deal of it. He was a good deal of  course, and he was a good deal of a thing which he had been to  be a good deal of time. He had been a good deal of time, and he had  been a good deal of a good deal of time. He had been a good deal of  talk, and he had a good deal of time to get a chance to get a  chance. He had been a good deal of time, and he had'

In [34]:
text = tokenizer.encode(
    "long before that I was", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text)
tokenizer.decode(out[0])

'long before that I was  in the midst of a great many of the time, and I was so glad I had  to get a chance to get it out of my mind.I had to go to the  place, and I was going to be a good deal of my own.I had to  get a chance to get my mind to be a good deal of a time.  I was a good deal of a good deal, and I was so sorry I had been  in the way of the country.I had been a good deal of a good  time, but I had no time to get my mind to be so much more than'

**what is KV-Cache**

https://www.dipkumar.dev/posts/gpt-kvcache/

In [35]:
model._clean_cache()
model._setup_cache(config)
text = tokenizer.encode(
    "this of the test", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True)
tokenizer.decode(out[0])

'this of the test.  The result was not a good deal. He was a good deal of a good deal of  the matter, and he was a good deal of it. He was a good deal of  course, and he was a good deal of a thing which he had been to  be a good deal of time. He had been a good deal of time, and he had  been a good deal of a good deal of time. He had been a good deal of  talk, and he had a good deal of time to get a chance to get a  chance. He had been a good deal of time, and he had'

In [36]:
model._clean_cache()
model._setup_cache(config)
text = text = tokenizer.encode(
    "long before that I was", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True)
tokenizer.decode(out[0])

'long before that I was  in the midst of a great many of the time, and I was so glad I had  to get a chance to get it out of my mind.I had to go to the  place, and I was going to be a good deal of my own.I had to  get a chance to get my mind to be a good deal of a time.  I was a good deal of a good deal, and I was so sorry I had been  in the way of the country.I had been a good deal of a good  time, but I had no time to get my mind to be so much more than'

In [37]:
model._clean_cache()
model._setup_cache(config, cls=DynamicCache)
text = tokenizer.encode(
    "this of the test", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True)
tokenizer.decode(out[0])

'this of the test.  The result was not a good deal. He was a good deal of a good deal of  the matter, and he was a good deal of it. He was a good deal of  course, and he was a good deal of a thing which he had been to  be a good deal of time. He had been a good deal of time, and he had  been a good deal of a good deal of time. He had been a good deal of  talk, and he had a good deal of time to get a chance to get a  chance. He had been a good deal of time, and he had'

In [38]:
model._clean_cache()
model._setup_cache(config, cls=DynamicCache)
text = text = tokenizer.encode(
    "long before that I was", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True)
tokenizer.decode(out[0])

'long before that I was  in the midst of a great many of the time, and I was so glad I had  to get a chance to get it out of my mind.I had to go to the  place, and I was going to be a good deal of my own.I had to  get a chance to get my mind to be a good deal of a time.  I was a good deal of a good deal, and I was so sorry I had been  in the way of the country.I had been a good deal of a good  time, but I had no time to get my mind to be so much more than'

**Enable Sampling**

In [39]:
model._clean_cache()
model._setup_cache(config)
text = text = tokenizer.encode(
    "this of the test", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True, do_sample=True)
tokenizer.decode(out[0])

'this of the test.  We can only tell us more about as a mile and an open calf and pass  through the lower left of the woods, and as he is in the most  frenzied blood.I honestly said anything.We were glad it was  emptied.We could not use how much night when those were coming  and lonesome anywhere in the house but we were dead, not the  stricken haggard.I thanked him to a dog, and was done at all.  Mars Tom, I thought, open to that box, and I wish we lent out a sigh  from the doors Secret dressed the old pilot'

In [40]:
model._clean_cache()
model._setup_cache(config)
text = text = tokenizer.encode(
    "long before that I was", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True, do_sample=True)
tokenizer.decode(out[0])

"long before that I was watching him. When he said I was  tall--naulded:  No, I was not out of it. Let us kill us. I had to the rest of  the other.  He was up toiled and set at Quixby way into Peterse and returned. He aclysmed to come after the humanity, but he spoken here in the  people who smelled like their dust in it, and looked at him the others  entirely, as pleased, as he remembers the messenger of his oldhematically  grievousness. Won't you'll ask you for him? Are in the life of a  young"

In [41]:
model._clean_cache()
model._setup_cache(config)
text = text = tokenizer.encode(
    "Well, sir, you could", add_special_tokens=False, return_tensors="pt"
).to(device)
out = generate(model=model, tokenize_text=text, use_cache=True, do_sample=True)
tokenizer.decode(out[0])

"Well, sir, you could seem to  wait a-lock three hours for four Jake. And it lacksenment that he  hurts Beckys concerned only of foreign accord. trulers loved him,  and killed him for distinguished countless others in the river along,  and he shook his head and stood rigid, anyway.And yet he and  pretty soon he came up and didnt get it. Satan found it, and we got  out for lack of music.  Don't you know it before him.  Why, thats anything? Get out and I dont know it.Don favor, yes.  Well, suggested Huck, its one ben for the summer"