In [1]:
import pandas as pd

**Data used**


[https://www.kaggle.com/datasets/ajax0564/quotes-dataset](http://)

In [2]:
df = pd.read_csv("../input/transformer-dataprep/All_Quotes.csv")

In [3]:
df.shape

(526110, 3)

In [4]:
df.head(3)

Unnamed: 0,quote,index,text_length
0,If you do not take an interest in the affairs ...,,
1,A wise man speaks because he has something to ...,,
2,"Someday, in the distant future, our grand-chil...",,


In [5]:
df["quote_length"] = df["quote"].apply(lambda x: len(x.split()))

# Get the maximum and minimum text lengths
max_text_length = df["quote_length"].max()
min_text_length = df["quote_length"].min()

# Print the results
print("Maximum text length:", max_text_length)
print("Minimum text length:", min_text_length)

Maximum text length: 5459
Minimum text length: 1


In [6]:
dff = df[df["quote_length"] >= 5].reset_index(drop=True)

In [7]:
dff[dff["quote_length"] == 5]

Unnamed: 0,quote,index,text_length,quote_length
46900,Self-talk reflects your innermost feelings.,1789.0,43.0,5
47247,"Sometimes, remembering hurts too much.",2145.0,38.0,5
47530,Because...because...she came here with me.,2438.0,42.0,5
48532,Cynics are simply thwarted romantics.,3471.0,37.0,5
49073,"Forget injuries, never forget kindnesses.",4032.0,41.0,5
...,...,...,...,...
517685,Every relationship has its complications.,491243.0,41.0,5
518209,Success demands singleness of purpose.,491783.0,38.0,5
519378,Men aren't necessities. They're luxuries.,492984.0,41.0,5
519663,Getting angry doesn't solve anything.,493277.0,37.0,5


**Reading Material to understand MLM/Electra**


https://web.stanford.edu/~jurafsky/slp3/11.pdf](http://)


https://research.google/blog/more-efficient-nlp-model-pre-training-with-electra/

In [8]:
import os
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModel,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
)
import gc
import numpy as np
from transformers import AutoConfig
from tqdm.notebook import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings

warnings.simplefilter("ignore")

In [9]:
model_ckpt = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [10]:
import torch


class MLMDataset(Dataset):
    def __init__(self, text):
        self.text = text
        self.max_len = 128
        self.tokenizer = tokenizer
        self.num_examples = len(self.text)

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        text = str(self.text[idx])

        tokenized_text = self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_attention_mask=True,
        )

        ids = tokenized_text["input_ids"]
        mask = tokenized_text["attention_mask"]

        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(mask, dtype=torch.long),
        }

In [11]:
import torch
from typing import Optional, Tuple


def masked_language_modeling(
    input_ids: torch.Tensor,
    tokenizer=tokenizer,
    fraction: float = 0.15,
    ignore_index: int = -100,
) -> Tuple[torch.Tensor]:
    label = input_ids.clone()

    special_tokens_mask = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
            for val in label.tolist()
        ],
        dtype=torch.bool,
    )  # get all special token mask to ignore their selection in MLM
    probability_matrix = torch.full(
        label.shape, fraction
    )  # gen probability matrix to select 15% tokens for MLM
    probability_matrix.masked_fill_(
        special_tokens_mask, value=0.0
    )  # zero out the probability of special tokens so that they do not get selected
    # https://pytorch.org/docs/stable/generated/torch.bernoulli.html
    # Draws binary random numbers (0 or 1) from a Bernoulli distribution.
    masked_indices = torch.bernoulli(probability_matrix).bool()

    label[~masked_indices] = (
        ignore_index  # We only compute loss on masked tokens cross entropy ignore_index
    )

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(label.shape, 0.8)).bool() & masked_indices
    )  # rondomly select 80% tokens from 15% token
    input_ids[indices_replaced] = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token
    )  # replace it with [MASK]

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(label.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )  # from 15% ignore 80% select 10% from 20%
    random_words = torch.randint(
        len(tokenizer), label.shape, dtype=torch.long
    )  # get random token index from tokenizer
    input_ids[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged do nothing
    #     batch['input_ids_modified'] = input_ids
    #     batch['label'] = label
    #     batch['masked_indices'] = masked_indices
    return input_ids, label, masked_indices

In [12]:
def log(t, eps=1e-9) -> torch.Tensor:
    "get log of input with added epsilon"
    return torch.log(t + eps)


def noise(t) -> torch.Tensor:
    "generate noise for given tensor"
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def sample(t, temperature=1.0) -> torch.Tensor:
    "for sampling the tokens with added noise and tempreture"
    return ((t / temperature) + noise(t)).argmax(dim=-1)


def electra(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    tokenizer,
    masked_indices: torch.Tensor,
    temperature: Optional[int] = 3,
) -> Tuple[torch.Tensor]:
    sample_logits = logits[masked_indices]  # get token index from MLM step
    sampled = sample(
        sample_logits, temperature=temperature
    )  # temperature to control how many token to replace
    # scatter the sampled values back to the input
    discriminator_input = input_ids.clone()  # Original input
    discriminator_input[masked_indices] = sampled.detach()

    # generate discriminator labels, with replaced as True and original as False
    disc_labels = (input_ids != discriminator_input).float().detach()

    # get discriminator predictions of replaced / original
    non_padded_indices = torch.nonzero(
        input_ids != tokenizer.pad_token_id, as_tuple=True
    )  # needed to caluclate loss only on  non [PAD]  tokens  .to(
    #         logits.device
    #     )
    return discriminator_input, disc_labels, non_padded_indices

In [13]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [14]:
import torch
import torch.nn as nn
from einops import rearrange, reduce
from typing import Optional, Tuple


class AbsoluteEncoding(nn.Module):
    """Construct the Absolute embeddings from position"""

    def __init__(self, config) -> None:
        super().__init__()
        self.pos_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=getattr(config, "pad_token_id", None),
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.max_size = config.max_position_embeddings

    def forward(self, size: int) -> torch.Tensor:
        if self.max_size < size:
            raise ValueError(
                f"The hidden size ({size }) is more than the config max_position_embeddings {self.max_size}"
            )
        return self.pos_embeddings(self.position_ids[:, :size])


class SinusoidalEncoding(nn.Module):
    """Construct the Sinusoidal embeddings from word, position and token_type embeddings."""

    def __init__(self, config) -> None:
        super().__init__()
        if config.hidden_size % 2 != 0:
            raise ValueError(
                f"Cannot use SinusoidalEncoding with "
                "odd hidden dim got dim {config.hidden_size}"
            )
        self.positional_encoding = torch.zeros(
            1, config.max_position_embeddings, config.hidden_size
        )
        self.position = torch.arange(0, config.max_position_embeddings).unsqueeze(1)
        self.div_term = torch.exp(
            (
                torch.arange(0, config.hidden_size, 2, dtype=torch.float)
                * -(torch.log(torch.tensor(10000.0)) / config.hidden_size)
            )
        )

        self.positional_encoding[:, :, 0::2] = torch.sin(
            self.position.float() * self.div_term
        )
        self.positional_encoding[:, :, 1::2] = torch.cos(
            self.position.float() * self.div_term
        )

    def forward(self, seq_len: int) -> torch.Tensor:

        return self.positional_encoding[:, :seq_len]


# copied and modified from transformer/models/gemma
class RotaryEmbedding(nn.Module):
    """Construct the positionl frequencies for RoPE embedding"""

    def __init__(self, config, base=10000, device=None):
        super().__init__()

        self.dim = int(config.hidden_size // config.num_attention_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.base = base
        self.register_buffer(
            "inv_freq",
            1.0
            / (
                self.base
                ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
            ),
            persistent=False,
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    @torch.no_grad()
    def forward(self, seq_len: int = None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # size = x.size()[2]
        position_ids = torch.arange(seq_len).unsqueeze(0)
        # position_ids = self.position_ids[:, :size].float()

        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        return freqs


# Copied from transformers
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers
def apply_rotary_pos_emb(q, k, freqs, unsqueeze_dim=1) -> Tuple[torch.Tensor]:
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        freqs: precalculated frqs for sin cos
        only_q: bool = False for encoder decoder
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos().to(dtype=q.dtype)
    sin = emb.sin().to(dtype=q.dtype)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)

    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# To do :  Alibi

In [15]:
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from typing import Optional, Tuple, Union


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def repeat_kv_einops(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = repeat(
        hidden_states,
        "batch num_key_value_heads slen head_dim -> batch num_key_value_heads n_rep slen head_dim",
        n_rep=n_rep,
    )  # hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    # return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return rearrange(
        hidden_states,
        "batch num_key_value_heads n_rep slen head_dim -> batch (num_key_value_heads n_rep) slen head_dim",
    )


class AttentionSelfOutput(nn.Module):
    def __init__(
        self, config, bias: Optional[bool] = True, out_features: Optional[int] = None
    ):
        super().__init__()
        self.dense = nn.Linear(
            config.hidden_size,
            config.hidden_size if out_features is None else out_features,
            bias=bias,
        )
        self.layernorm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-6)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: torch.FloatTensor of shape (batch, seq_len, embed_dim)`
            input_tensor: torch.FloatTensor of shape (batch, seq_len, embed_dim)`

        return:
               hidden_states: torch.FloatTensor of shape (batch, seq_len, embed_dim)

        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layernorm(hidden_states + input_tensor)
        return hidden_states


class EncoderAttention(nn.Module):
    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.attention_bias = getattr(config, "attention_bias", True)
        self.layer_idx = layer_idx
        # self.qkv = nn.Linear(config.hidden_size,3*config.hidden_size)
        self.query = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.key = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.value = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.out = AttentionSelfOutput(config=config, bias=self.attention_bias)
        self.num_attention_heads = config.num_attention_heads

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)`
            attention_mask: torch.Tensor of shape (batch,1, seq_len, seqlen)`
            freqs: Positional freqs in case of RoPE embedding
        return:
               hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)

        """
        q = self.query(hidden_state)
        k = self.key(hidden_state)
        v = self.value(hidden_state)
        # transform it into batch_size x no_of_heads x seqlen x head_dim for Multihead Attention
        q = rearrange(q, "b l (h d) -> b h l d", h=self.num_attention_heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.num_attention_heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.num_attention_heads)
        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)  # apply RoPE if freqs is available

        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=attention_mask, is_causal=False
        )
        # transform it back into batch_size x seqlen x hidden_dim
        out = rearrange(out, "b h l d -> b l (h d)")
        return self.out(out, hidden_state)


class EncoderAttentionGqa(nn.Module):
    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash and self.layer_idx == 0:  # avoid to print m times
            print("WARNING: Flash Attention requires PyTorch >= 2.0")
        self.layer_idx = layer_idx
        self.head_dim = int(config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = getattr(config, "num_key_value_heads", 4)
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        if (
            self.num_attention_heads % self.num_key_value_heads != 0
            or self.num_attention_heads < self.num_key_value_heads
        ):
            raise ValueError(
                f"num_key_value_heads {self.num_key_value_heads }  should be less than equal num_attention_heads {config.num_attention_heads} and  multiple of num_attention_heads {config.num_attention_heads} "
            )
        self.attention_bias = getattr(config, "attention_bias", True)
        self.out = AttentionSelfOutput(config=config, bias=self.attention_bias)
        self.query = nn.Linear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.key = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=self.attention_bias,
        )
        self.value = nn.Linear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=self.attention_bias,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)`
            attention_mask: torch.Tensor of shape (batch,1, seq_len, seqlen)`
            freqs: Positional freqs in case of RoPE embedding
        return:
               hidden_states: torch.Tensor of shape (batch, seq_len, embed_dim)

        """
        q = self.query(hidden_state)
        k = self.key(hidden_state)
        v = self.value(hidden_state)
        # transform it into batch_size x no_of_heads x seqlen x head_dim for Multihead Attention
        q = rearrange(q, "b l (h d) -> b h l d", d=self.head_dim)
        k = rearrange(k, "b l (h d) -> b h l d", d=self.head_dim)
        v = rearrange(v, "b l (h d) -> b h l d", d=self.head_dim)

        if freqs is not None:
            q, k = apply_rotary_pos_emb(q, k, freqs)  # apply RoPE if freqs is available

        k = repeat_kv(
            k, n_rep=self.num_key_value_groups
        )  # in case of GQA repeat k,v to make it same as q
        v = repeat_kv(v, n_rep=self.num_key_value_groups)
        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=attention_mask, is_causal=False
        )
        # transform it back into batch_size x seqlen x hidden_dim
        out = rearrange(out, "b h l d -> b l (h d)")

        return self.out(out, hidden_state)

In [16]:
_ACT_ = {
    "gelu": nn.GELU(),
    "leaky_relu": nn.LeakyReLU(),
    "relu6": nn.ReLU6(),
    "sigmoid": nn.Sigmoid(),
    "silu": nn.SiLU(),
    "swish": nn.SiLU(),
    "tanh": nn.Tanh(),
}


class FeedForward(nn.Module):
    def __init__(self, config, multiplier: Union[int, float] = 4) -> None:
        super().__init__()
        self.intermediate = nn.Linear(
            config.hidden_size, int(multiplier) * config.hidden_size
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        if _ACT_.get(getattr(config, "hidden_act", None), None):
            self.act_fn = _ACT_[config.hidden_act]
        else:
            self.act_fn = nn.GELU()
        self.out = nn.Linear(int(multiplier) * config.hidden_size, config.hidden_size)

    def forward(
        self, hidden_state: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        output = self.intermediate(hidden_state)
        output = self.act_fn(output)
        output = self.out(output)
        output = self.dropout(output)
        output = self.layernorm(output + input_tensor)
        return output

In [17]:
import torch
import torch.nn as nn
from typing import Optional, Tuple
from dataclasses import dataclass

_position_embeddings = {
    "absolute": AbsoluteEncoding,
    "sinusoidal": SinusoidalEncoding,
}  #'relative':RelativePositionalEncoding


@dataclass
class EncoderOutput(object):
    logits: torch.Tensor


@dataclass
class MLMOutput(object):
    hidden_state: torch.Tensor
    logits: torch.Tensor


class EncoderLayer(nn.Module):
    "encoder layer for encoder model"

    def __init__(self, config, layer_idx: int, attention_type: str = None) -> None:
        super().__init__()
        self.attention = (
            EncoderAttentionGqa(config, layer_idx=layer_idx)
            if attention_type == "gqa"
            else EncoderAttention(config, layer_idx=layer_idx)
        )
        if attention_type == "gqa" and layer_idx == 0:  # avoid to print m times
            print("Encoder Using GQA Attention")
        self.feed_forward = FeedForward(config)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_state: torch.Tensor of shape (batch, seq_len,embd_dim)
            attention_mask: torch.Tensor of shape (batch,1,seqlen,seqlen)
            freqs: positionl information to use in RoPE
        return:
               hidden_state: torch.Tensor of shape (batch, seq_len, embed_dim) of last layer

        """
        out = self.attention(
            hidden_state=hidden_state, attention_mask=attention_mask, freqs=freqs
        )
        out = self.feed_forward(out, hidden_state)
        return out


class LMHead(nn.Module):
    """Head for masked language modelling"""

    def __init__(self, config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-6)
        )

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_state)
        x = nn.GELU()(x)
        x = self.layer_norm(x)

        # project back to size of vocabulary with bias
        x = self.decoder(x)

        return x


class EncoderModel(nn.Module):

    def __init__(
        self,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=getattr(config, "pad_token_id", None),
        )
        if _position_embeddings.get(pos_embedding_type, None) is not None:
            self.position_embeddings = _position_embeddings.get(pos_embedding_type)(
                config
            )
        else:
            self.position_embeddings = None
        if pos_embedding_type == "rope":
            self.emb_freq = RotaryEmbedding(config)(config.max_position_embeddings)
            print(
                "Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable"
            )
        self.all_layer = nn.ModuleList(
            [
                EncoderLayer(config, layer_idx, attention_type)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / torch.sqrt(2 * len(self.all_layer))
            )
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / torch.sqrt(2 * len(self.all_layer))
            )

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input_ids: torch.LongTensor of shape (batch, seq_len) for encoder`
            attention_mask: torch.Tensor of shape (batch,seqlen) for encoder
        return:
               logits: torch.Tensor of shape (batch, seq_len, embed_dim) of last layer

        """
        bsz, seqlen = input_ids.shape
        hidden_state = self.word_embeddings(input_ids)
        freqs = None
        if self.position_embeddings is not None:
            pos_info = self.position_embeddings(seqlen)[:, :seqlen, :].to(
                input_ids.device
            )
            hidden_state = hidden_state + pos_info
        else:
            freqs = self.emb_freq[:, :seqlen].to(input_ids.device)

        if attention_mask is None:
            encoder_batch_size, encoder_sequence_length = input_ids.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            attention_mask = torch.ones(encoder_hidden_shape, device=input_ids.device)

        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).type_as(hidden_state)
        attention_mask = (1.0 - attention_mask) * torch.finfo(
            hidden_state.dtype
        ).min  # invert it to to add directly to attention score

        for layer in self.all_layer:
            hidden_state = layer(hidden_state, attention_mask, freqs)
        return EncoderOutput(hidden_state)

    @classmethod
    def from_config(
        cls,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> nn.Module:
        return cls(config, pos_embedding_type, attention_type)


class EncoderForMaskedLM(nn.Module):

    def __init__(
        self,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> None:
        super().__init__()
        self.encoder = EncoderModel(
            config, pos_embedding_type=pos_embedding_type, attention_type=attention_type
        )
        self.lm_head = LMHead(config=config)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input_ids: torch.LongTensor of shape (batch, seq_len) for encoder`
            attention_mask: torch.Tensor of shape (batch,seqlen) for encoder
        return:
               hidden_state: torch.Tensor of shape (batch, seq_len, embed_dim) of last layer
               logits: torch.Tensor of shape (batch,seqlen, vocab_size)

        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.lm_head(out.logits)
        return MLMOutput(hidden_state=out.logits, logits=logits)

    @classmethod
    def from_config(
        cls,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> nn.Module:
        return cls(config, pos_embedding_type, attention_type)

In [18]:
config = AutoConfig.from_pretrained(model_ckpt)

In [19]:
config.num_hidden_layers = 4
generator = EncoderForMaskedLM(config, pos_embedding_type="rope")

Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable


In [20]:
class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.discriminator = EncoderModel(config, pos_embedding_type="rope")
        self.discriminator_head = nn.Linear(config.hidden_size, 1)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        out = self.discriminator(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.discriminator_head(out.logits)
        return logits


config.num_hidden_layers = 6
# according to paper discriminator should be larger than generator

In [21]:
class ElectraModel(nn.Module):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.discriminator_model = discriminator
        self.generator_model = generator

    def get_generator_output(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        return self.generator_model(input_ids, attention_mask)

    def get_discriminator_output(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        return self.discriminator_model(input_ids, attention_mask)

In [22]:
generator

EncoderForMaskedLM(
  (encoder): EncoderModel(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (all_layer): ModuleList(
      (0-3): 4 x EncoderLayer(
        (attention): EncoderAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): AttentionSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (feed_forward): FeedForward(
          (intermediate): Linear(in_features=768, out_features=3072, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (act_fn): GELU(approximate='none')
          (out): Linear(

In [23]:
discriminator = Discriminator(config)

Encoder Ignoring sinusoidal or absolute position embeddings because rope,is enable


In [24]:
discriminator

Discriminator(
  (discriminator): EncoderModel(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (all_layer): ModuleList(
      (0-5): 6 x EncoderLayer(
        (attention): EncoderAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): AttentionSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (feed_forward): FeedForward(
          (intermediate): Linear(in_features=768, out_features=3072, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (act_fn): GELU(approximate='none')
          (out): Linear

In [25]:
electra_model = ElectraModel(generator, discriminator)

In [35]:
class ElectraLoss:
    def __init__(self, config):
        self.config = config
        self.generator_loss = nn.CrossEntropyLoss()

    #         self.discriminator_loss = nn.BCELoss()

    def __call__(
        self,
        generator_logits,
        generator_label,
        disc_logits,
        disc_labels,
        non_padded_indices,
    ):
        generator_loss = self.generator_loss(
            generator_logits.view(-1, self.config.vocab_size), generator_label.view(-1)
        )
        disc_logits = disc_logits.reshape_as(disc_labels)
        discriminator_loss = nn.functional.binary_cross_entropy_with_logits(
            disc_logits[non_padded_indices], disc_labels[non_padded_indices]
        )
        #         discriminator_loss =  self.discriminator_loss(disc_logits[non_padded_indices],disc_labels[non_padded_indices])
        return generator_loss + discriminator_loss, generator_loss, discriminator_loss

**Electra  Embedding Disentanglement**

In [42]:
m = AutoModel.from_pretrained(model_ckpt)

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Electra](https://openreview.net/pdf?id=r1xMH1BtvB)

**Copy Embeddings for faster convergence**

**share the embeddings (both the token and positional embeddings in case of absolute embedding) of the generator and discriminator**

In [44]:
electra_model.generator_model.encoder.word_embeddings.weight = (
    m.embeddings.word_embeddings.weight
)
electra_model.discriminator_model.discriminator.word_embeddings.weight = (
    m.embeddings.word_embeddings.weight
)

In [45]:
dff = dff.sample(frac=0.08, random_state=42)
dff.shape

(42078, 4)

In [46]:
train_dataset = MLMDataset(text=dff["quote"].values)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

In [47]:
from accelerate import Accelerator
import sys


def single_gpu():
    electra_loss = ElectraLoss(config)
    accumulation_steps = 1  # batch size 128 large dont need gradient accumulation
    lr = 1e-3
    EPOCHS = 8
    Config = {"num_epoch": EPOCHS, "learning_rate": lr}

    accelerator = Accelerator(
        log_with="tensorboard",
        project_dir=".",
    )
    accelerator.init_trackers("Electra_project", config=Config)

    electra_model.cuda()

    no_decay = ["bias", "layernorm.weight", "layernorm.bias"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in electra_model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p
                for n, p in electra_model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

    #     optimizer = AdamW(electra_model.parameters(), lr = lr)
    num_train_optimization_steps = int(EPOCHS * len(train_loader) / accumulation_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0.05 * num_train_optimization_steps,
        num_training_steps=num_train_optimization_steps,
    )  # PyTorch scheduler

    epoch_check = len(train_loader)
    total_step = epoch_check * EPOCHS
    train_bar = tqdm(total=total_step, dynamic_ncols=True)
    t_step = 1
    k = 0
    electra_model.train()
    for epoch in range(EPOCHS):
        avg_loss = 0.0
        loss_list = []
        running_loss = 0.0
        for step, data in enumerate(train_loader):
            train_bar.update(1)
            attention_mask = data["attention_mask"].cuda()
            optimizer.zero_grad()
            generator_input_ids, generator_label, masked_indices = (
                masked_language_modeling(input_ids=data["input_ids"])
            )
            out = electra_model.get_generator_output(
                generator_input_ids.cuda(), attention_mask
            )
            discriminator_input, disc_labels, non_padded_indices = electra(
                out.logits, data["input_ids"].cuda(), tokenizer, masked_indices
            )
            disc_logits = electra_model.get_discriminator_output(
                discriminator_input, attention_mask
            )
            loss, gen_loss, disc_loss = electra_loss(
                out.logits,
                generator_label.cuda(),
                disc_logits,
                disc_labels,
                non_padded_indices,
            )
            loss.backward()
            optimizer.step()
            scheduler.step()
            accelerator.log({"total_loss_step": loss}, step=t_step)
            accelerator.log({"generator_loss_step": gen_loss}, step=t_step)
            accelerator.log({"discriminator_loss_step": disc_loss}, step=t_step)

            train_bar.set_description(
                f'epoch: {epoch+1} step: {t_step} loss: {"%.4f" % loss}'
            )
            t_step += 1

            loss_list.append(loss.detach().cpu().item())
        scheduler.step()
        avg_loss = np.round(np.mean(loss_list), 4)
        accelerator.log({"training_loss_epoch": avg_loss}, step=epoch + 1)
        print(f'Epoch: {epoch+1} loss: {"%.4f" % avg_loss }')
    accelerator.end_training()
    PATH = f"electra_pretraining_{epoch}.pth"
    torch.save(electra_model.state_dict(), PATH)

In [48]:
single_gpu()

2024-05-13 17:15:19.259115: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-13 17:15:19.259214: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-13 17:15:19.381896: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


  0%|          | 0/5264 [00:00<?, ?it/s]

Epoch: 1 loss: 6.3315
Epoch: 2 loss: 5.2709
Epoch: 3 loss: 5.0046
Epoch: 4 loss: 4.8197
Epoch: 5 loss: 4.6595
Epoch: 6 loss: 4.5186
Epoch: 7 loss: 4.4116
Epoch: 8 loss: 4.2999


**Total Loss**

![image.png](attachment:image.png)

**Generator Loss**

![image.png](attachment:image.png)

**Discriminator Loss**

![image.png](attachment:image.png)

**To Get a better model you need to train it on large data and increase generator model size and discriminator model size while keeping  discriminator model size  greater than generator model size like bert-small bert-base**

**masked token prediction**

In [49]:
check = "stress shows on <mask> face."
inputs = tokenizer(check, truncation=True, padding=True, return_tensors="pt")
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
input_ids = inputs["input_ids"].cuda()
attention_mask = inputs["attention_mask"].cuda()
with torch.no_grad():
    logits = electra_model.get_generator_output(input_ids, attention_mask).logits
mask_token_logits = logits[0, mask_token_index, :]
top_3 = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3:
    print(check.replace(tokenizer.mask_token, tokenizer.decode([token])))

stress shows on  your face.
stress shows on  the face.
stress shows on  my face.


In [50]:
check = "You are <mask> weak."
inputs = tokenizer(check, truncation=True, padding=True, return_tensors="pt")
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
input_ids = inputs["input_ids"].cuda()
attention_mask = inputs["attention_mask"].cuda()
with torch.no_grad():
    logits = electra_model.get_generator_output(input_ids, attention_mask).logits
mask_token_logits = logits[0, mask_token_index, :]
top_3 = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3:
    print(check.replace(tokenizer.mask_token, tokenizer.decode([token])))

You are  not weak.
You are  a weak.
You are  always weak.


In [51]:
check = "Get up and face your <mask> fears every day"
inputs = tokenizer(check, truncation=True, padding=True, return_tensors="pt")
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
input_ids = inputs["input_ids"].cuda()
attention_mask = inputs["attention_mask"].cuda()
with torch.no_grad():
    logits = electra_model.get_generator_output(input_ids, attention_mask).logits
mask_token_logits = logits[0, mask_token_index, :]
top_3 = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3:
    print(check.replace(tokenizer.mask_token, tokenizer.decode([token])))

Get up and face your  own fears every day
Get up and face your  best fears every day
Get up and face your  life fears every day


In [52]:
check = "There is something <mask> and something extremely profound, in owning a home."
inputs = tokenizer(check, truncation=True, padding=True, return_tensors="pt")
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
input_ids = inputs["input_ids"].cuda()
attention_mask = inputs["attention_mask"].cuda()
with torch.no_grad():
    logits = electra_model.get_generator_output(input_ids, attention_mask).logits
mask_token_logits = logits[0, mask_token_index, :]
top_3 = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3:
    print(check.replace(tokenizer.mask_token, tokenizer.decode([token])))

There is something  wrong and something extremely profound, in owning a home.
There is something  good and something extremely profound, in owning a home.
There is something  better and something extremely profound, in owning a home.


In [53]:
check = "what is your <mask> my friend."
inputs = tokenizer(check, truncation=True, padding=True, return_tensors="pt")
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
input_ids = inputs["input_ids"].cuda()
attention_mask = inputs["attention_mask"].cuda()
with torch.no_grad():
    logits = electra_model.get_generator_output(input_ids, attention_mask).logits
mask_token_logits = logits[0, mask_token_index, :]
top_3 = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3:
    print(check.replace(tokenizer.mask_token, tokenizer.decode([token])))

what is your  own my friend.
what is your  life my friend.
what is your  only my friend.
