# Configuration

In [None]:
!pip install -U datasets huggingface_hub fsspec

In [6]:
# Shared hyper-parameters for the tokenizer and the encoder built below.
vocab_size = 30000   # passed to the WordPiece trainer and to build_transformer
# Bracketed names, listed in the same order the WordPiece trainer cell uses,
# so this list matches tokens the tokenizer can actually resolve.
special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
n_segments = 2       # number of rows in the segment-embedding table
max_len = 350        # number of positions in the positional-encoding table
embedd_dim = 768     # model width (must be divisible by attn_heads)
n_layers = 8         # number of encoder blocks
attn_heads = 12
dropout = 0.1
d_ff = 1600          # hidden width of the feed-forward sub-layers



# Embeddings

In [7]:
import torch
from torch import nn
import math
class InputEmbedding(nn.Module):
  """Learned token embedding whose output is multiplied by sqrt(d_model).

  Args:
    vocab_size: number of rows in the lookup table.
    d_model: width of each embedding vector.
  """

  def __init__(self, vocab_size: int, d_model: int) -> None:
    super().__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.embedd = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    """Look up every id in `x` and scale the resulting vectors."""
    # The lookup appends one trailing axis of width d_model to the id tensor.
    scale = math.sqrt(self.d_model)
    token_vectors = self.embedd(x)
    return token_vectors * scale

class SegmentEmbedding(nn.Module):
  """Learned embedding table indexed by segment id.

  Args:
    n_segments: number of distinct segment ids.
    d_model: width of each embedding vector.
  """

  def __init__(self, n_segments: int, d_model: int) -> None:
    super().__init__()
    table = nn.Embedding(n_segments, d_model)
    self.segment_embedd = table

  def forward(self, x):
    """Return one d_model-wide vector per segment id in `x`."""
    segment_vectors = self.segment_embedd(x)
    return segment_vectors

class PositionalEmbedding(nn.Module):
  """Adds a fixed sinusoidal position table to its input, then applies dropout.

  The table is precomputed once for `seq_len` positions and stored as the
  buffer `pe` with shape (1, seq_len, d_model). Even columns hold sines and
  odd columns hold cosines. Both even and odd `d_model` are supported.

  Args:
    seq_len: number of positions to precompute.
    d_model: width of each position vector.
    dropout: dropout probability applied after the addition.
  """

  def __init__(self,seq_len:int,d_model:int,dropout:float)->None:
    super().__init__()
    self.seq_len=seq_len
    self.d_model=d_model
    self.drop=nn.Dropout(dropout)
    pe=torch.zeros(seq_len,d_model)
    position=torch.arange(0,seq_len,dtype=torch.float).unsqueeze(1)
    div_term=torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
    # One angle per (position, even column index).
    angles=position*div_term
    pe[:,0::2]=torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even columns,
    # so trim the angle matrix to the number of odd columns.
    pe[:,1::2]=torch.cos(angles[:,:d_model//2])

    pe=pe.unsqueeze(0)  # leading axis of size 1 so the table broadcasts over the batch
    self.register_buffer("pe",pe)

  def forward(self,x):
    """Add the first x.shape[1] rows of the table to `x` and apply dropout."""
    # `pe` is a buffer, not a parameter, so it carries no gradient.
    x=x+self.pe[:,:x.shape[1],:]
    return self.drop(x)



In [8]:
class full_embeddings(nn.Module):
  """Builds the encoder input from token, segment and positional embeddings.

  Args:
    src_emb: module mapping token ids to vectors.
    pe_emb: module that adds position information (and applies its own dropout).
    se_emb: module mapping segment ids to vectors.
    sep_input_id: stored on the instance; not read anywhere in this class.
  """

  def __init__(self, src_emb: InputEmbedding, pe_emb: PositionalEmbedding, se_emb: SegmentEmbedding, sep_input_id):
    super().__init__()
    self.src_emb = src_emb
    self.pe_emb = pe_emb
    self.se_emb = se_emb
    self.sep_input_id = sep_input_id

  def forward(self, input_ids, segment_ids):
    """Sum token and segment vectors, then pass the sum through `pe_emb`."""
    token_part = self.src_emb(input_ids)
    segment_part = self.se_emb(segment_ids)
    return self.pe_emb(token_part + segment_part)


# Normalization block, residual block, feed-forward block

In [9]:
class LayerNormalization(nn.Module):
  """Normalises the last axis and applies a learned scale (`alpha`) and shift (`bias`).

  The statistics are `x.mean` and `x.std` over the last axis with torch's
  default settings; `eps` is added to the standard deviation (not the
  variance) before dividing.

  Args:
    d_model: size of the last axis.
    eps: constant added to the standard deviation to avoid division by zero.
  """

  def __init__(self, d_model, eps: float = 1e-6) -> None:
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(d_model))
    self.bias = nn.Parameter(torch.zeros(d_model))

  def forward(self, x):
    """Return alpha * (x - mean) / (std + eps) + bias along the last axis."""
    centered = x - x.mean(dim=-1, keepdim=True)
    spread = x.std(dim=-1, keepdim=True) + self.eps
    return self.alpha * centered / spread + self.bias


class FeedForwardNetwork(nn.Module):
  """Two-layer position-wise MLP: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output width.
    d_ff: hidden width.
    dropout: dropout probability applied to the hidden activations.
  """

  def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.linear_2 = nn.Linear(d_ff, d_model)
    self.drop = nn.Dropout(dropout)

  def forward(self, x):
    """Expand to d_ff, apply ReLU and dropout, project back to d_model."""
    hidden = self.linear_1(x).relu()
    return self.linear_2(self.drop(hidden))


class ResidualConnection(nn.Module):
  """Pre-norm residual wrapper: x + dropout(sublayer(norm(x))).

  Args:
    d_model: width handed to the internal LayerNormalization.
    drop: dropout probability applied to the sublayer output.
  """

  def __init__(self, d_model, drop: float) -> None:
    super().__init__()
    self.norm = LayerNormalization(d_model)
    self.drop = nn.Dropout(drop)

  def forward(self, x, sublayer):
    """Apply `sublayer` to the normalised input and add the result back to `x`."""
    normed = self.norm(x)
    update = self.drop(sublayer(normed))
    return x + update



# Attention Block

In [10]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with learned q/k/v/output projections.

  Args:
    d_model: model width; must be divisible by `heads`.
    heads: number of attention heads.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, d_model: int, heads: int, dropout: float):
    super().__init__()
    assert d_model % heads == 0, "d_model not divisible by heads"
    self.heads = heads
    self.d_k = d_model // heads  # width of a single head
    self.q = nn.Linear(d_model, d_model)
    self.k = nn.Linear(d_model, d_model)
    self.v = nn.Linear(d_model, d_model)
    self.w_o = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  def attention(self, q, k, v, mask, dropout):
    """Scaled dot-product attention.

    Positions where `mask` is zero receive the most negative value of the
    score dtype before the softmax. Returns (weighted values, attention weights).
    """
    scale = math.sqrt(q.shape[-1])
    weights = q @ k.transpose(-2, -1) / scale
    if mask is not None:
      # Two singleton axes are inserted after the first axis so the mask
      # broadcasts against the per-head score tensor built in forward().
      keep = mask.unsqueeze(1).unsqueeze(2).to(dtype=torch.bool, device=weights.device)
      weights = weights.masked_fill(~keep, torch.finfo(weights.dtype).min)
    weights = torch.softmax(weights, dim=-1)
    if dropout is not None:
      weights = dropout(weights)
    return weights @ v, weights

  def forward(self, q, k, v, mask):
    """Project q/k/v, attend per head, merge the heads and apply `w_o`."""

    def split_heads(t):
      # (batch, length, heads*d_k) -> (batch, heads, length, d_k)
      batch, length = t.shape[0], t.shape[1]
      return t.view(batch, length, self.heads, self.d_k).transpose(1, 2)

    q = split_heads(self.q(q))
    k = split_heads(self.k(k))
    v = split_heads(self.v(v))

    context, _ = self.attention(q, k, v, mask, self.dropout)
    # (batch, heads, length, d_k) -> (batch, length, heads*d_k)
    merged = context.transpose(1, 2).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
    return self.w_o(merged)



# Encoder

In [11]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention then feed-forward, each in a residual wrapper.

  Args:
    d_model: model width.
    d_ff: hidden width of the feed-forward network.
    dropout: dropout probability shared by all sub-modules.
    heads: number of attention heads.
  """

  def __init__(self, d_model, d_ff, dropout, heads):
    super().__init__()
    self.feedfwd = FeedForwardNetwork(d_model, d_ff, dropout)

    self.residual = nn.ModuleList([ResidualConnection(d_model, dropout) for _ in range(2)])
    self.attention = MultiHeadAttention(d_model, heads, dropout)

  def forward(self, x, mask):
    """Run `x` through the attention residual, then the feed-forward residual."""

    def self_attend(normed):
      # Queries, keys and values all come from the same normalised input.
      return self.attention(normed, normed, normed, mask)

    attn_residual, ffn_residual = self.residual
    x = attn_residual(x, self_attend)
    return ffn_residual(x, self.feedfwd)


class Encoder(nn.Module):
  """Stack of `n_layers` EncoderBlock modules applied one after another.

  Args:
    d_model, d_ff, dropout, heads: forwarded unchanged to every EncoderBlock.
    n_layers: number of blocks in the stack.
  """

  def __init__(self, d_model, d_ff, dropout, heads, n_layers):
    super().__init__()
    blocks = []
    for _ in range(n_layers):
      blocks.append(EncoderBlock(d_model, d_ff, dropout, heads))
    self.encoders = nn.ModuleList(blocks)

  def forward(self, x, mask):
    """Feed `x` through every block in order, passing the same `mask` to each."""
    hidden = x
    for block in self.encoders:
      hidden = block(hidden, mask)
    return hidden



# Classifier

In [12]:
class Classifier(nn.Module):
  """Three-layer MLP head producing raw (unnormalised) logits.

  Layer layout: Linear(d_model, d_ff) -> ReLU -> Dropout ->
  Linear(d_ff, 1024) -> ReLU -> Dropout -> Linear(1024, output_size).

  Args:
    d_model: input width.
    d_ff: width of the first hidden layer.
    dropout: dropout probability after each ReLU.
    output_size: number of logits returned per input row.
  """

  def __init__(self, d_model: int, d_ff: int, dropout: float, output_size: int = 3) -> None:
    super().__init__()
    layers = [
      nn.Linear(d_model, d_ff),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(d_ff, 1024),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(1024, output_size),
    ]
    self.classifier = nn.Sequential(*layers)

  def forward(self, x):
    """Return the logits for `x`; no softmax or sigmoid is applied."""
    logits = self.classifier(x)
    return logits



# Transformer

In [13]:
class Transformer(nn.Module):
  """Encoder-only model that returns the hidden state at sequence position 0.

  Position 0 is the [CLS] slot under the tokenizer template used in this
  notebook, so the returned vector serves as the sequence representation.

  Args:
    encoder: the encoder stack.
    embeddings: module that turns (token ids, segment ids) into encoder input.
  """

  def __init__(self,encoder:Encoder,embeddings:full_embeddings)->None:
    super().__init__()
    self.encoder=encoder
    self.emb=embeddings

  def forward(self, x,segment_ids,mask):
    """Embed `x`, encode it, and return the encoder output at position 0."""
    x=self.emb(x,segment_ids)
    output=self.encoder(x,mask)
    # Integer indexing on axis 1 already removes that axis. A further
    # squeeze(1) would only act when the feature axis has size 1, and would
    # then wrongly collapse it, so none is applied.
    return output[:,0,:]
def _xavier_init_(module):
  """Xavier-uniform initialise every parameter of `module` with more than one dimension."""
  for p in module.parameters():
    if p.dim()>1:
      nn.init.xavier_uniform_(p)


def build_transformer(vocab_size:int,n_segments:int,embedd_dim:int,max_len:int,n_layers:int,attn_heads:int,dropout:float,d_ff:int,c_d_ff:int,nli_pretrain:bool,sep_input_id:int=2)->"Transformer | tuple[Transformer, Classifier]":
  """Assemble the encoder model and, optionally, a classification head.

  Args:
    vocab_size: size of the token embedding table.
    n_segments: size of the segment embedding table.
    embedd_dim: model width.
    max_len: number of positions in the positional-encoding table.
    n_layers: number of encoder blocks.
    attn_heads: attention heads per block.
    dropout: dropout probability used throughout.
    d_ff: hidden width of the encoder feed-forward layers.
    c_d_ff: width of the classifier's first hidden layer.
    nli_pretrain: when True, also build and return a Classifier.
    sep_input_id: forwarded to full_embeddings.

  Returns:
    The Transformer alone, or `(transformer, classifier)` when
    `nli_pretrain` is True.
  """
  input_emb=InputEmbedding(vocab_size,embedd_dim)
  seg_emb=SegmentEmbedding(n_segments,embedd_dim)
  pe_emb=PositionalEmbedding(max_len,embedd_dim,dropout)

  full_emb=full_embeddings(input_emb,pe_emb,seg_emb,sep_input_id)

  encoder=Encoder(embedd_dim,d_ff,dropout,attn_heads,n_layers)

  transformer=Transformer(encoder,full_emb)
  _xavier_init_(transformer)

  if not nli_pretrain:
    return transformer

  # The head takes its own width from c_d_ff; previously this argument was
  # accepted but ignored in favour of d_ff.
  classifier=Classifier(embedd_dim,c_d_ff,dropout)
  _xavier_init_(classifier)
  return transformer, classifier


# Tokenizer

In [14]:
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer

# Fresh (untrained) WordPiece tokenizer that splits on whitespace first.
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Trainer configuration; the special tokens are listed in the order given here.
trainer = WordPieceTrainer(
    vocab_size=vocab_size,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]"],
)

# files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
# files=[f"/content/train-00000-of-00002.raw","/content/train-00001-of-00002.raw"]
# tokenizer.train(files, trainer)


In [15]:
# tokenizer.save("/content/tokenizer-wiki.json")

In [16]:
# Load the previously trained tokenizer from disk. This rebinds `tokenizer`,
# so every later cell uses the loaded tokenizer.
tokenizer = Tokenizer.from_file("/kaggle/input/my-tokenizer/tokenizer-wiki.json")

In [17]:
from tokenizers.processors import TemplateProcessing

# Wrap single sequences as "[CLS] A [SEP]" and pairs as "[CLS] A [SEP] B [SEP]".
# The ":1" suffix assigns type id 1 to the second sequence and its closing
# [SEP]; those type ids are what the datasets feed to the model as segment ids.
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)

# Pad and truncate every encoding to exactly max_len tokens. Using the config
# value (instead of a second hard-coded 350) keeps the tokenizer in step with
# the positional-encoding table, which is also sized by max_len.
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=max_len)
tokenizer.enable_truncation(max_length=max_len)


# Dataset

In [18]:

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

class NLIDataset(Dataset):
    """Premise/hypothesis pairs from `sentence-transformers/all-nli` (`pair-class` config).

    Each item is tokenised on access with the notebook-level `tokenizer`.

    Args:
        split: dataset split name passed to `load_dataset`.
    """

    def __init__(self, split='test'):
        self.data = load_dataset('sentence-transformers/all-nli', 'pair-class', split=split)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with keys "input", "segment_id", "mask" and "label".

        The first three tensors keep a leading axis of size 1 (one encoded
        pair); the training loop removes it with `squeeze(1)` after batching.
        """
        item = self.data[idx]
        # Kept in a local variable: storing the encodings on `self` would leave
        # per-item scratch state on the shared dataset object.
        encodings = tokenizer.encode_batch([[item['premise'], item['hypothesis']]])

        x = torch.tensor([enc.ids for enc in encodings], dtype=torch.long)
        segment_id = torch.tensor([enc.type_ids for enc in encodings], dtype=torch.long)
        mask = torch.tensor([enc.attention_mask for enc in encodings], dtype=torch.float)
        label = torch.tensor(item["label"], dtype=torch.long)

        return {"input": x, "segment_id": segment_id, "mask": mask, "label": label}





In [15]:
# NOTE(review): this loader feeds the phase-1 training loop, yet it is built
# from the 'test' split. Confirm that training on this split is intended and
# that the same split is not also used for evaluation.
dataset = NLIDataset(split="test")
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
)

README.md: 0.00B [00:00, ?B/s]

pair-class/train-00000-of-00001.parquet:   0%|          | 0.00/69.5M [00:00<?, ?B/s]

pair-class/dev-00000-of-00001.parquet:   0%|          | 0.00/1.57M [00:00<?, ?B/s]

pair-class/test-00000-of-00001.parquet:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/942069 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/19657 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/19656 [00:00<?, ? examples/s]

# Model setup

In [19]:

from torch.amp import GradScaler, autocast

# Optimisation settings.
epochs = 10
batch_size = 32
lr = 3e-4  # karpathy_constant
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cudnn.benchmark = True

# With nli_pretrain=True the builder returns (encoder model, classifier head).
model, classifier = build_transformer(
    vocab_size,
    n_segments,
    embedd_dim,
    max_len,
    n_layers,
    attn_heads,
    dropout,
    d_ff,
    d_ff,
    True,
    2,
)
model = model.to(device)
classifier = classifier.to(device)

# One gradient scaler shared by both optimizers in the mixed-precision loop.
scaler = GradScaler()

# Encoder and head each get their own AdamW so the encoder's optimizer can be
# reused on its own in the later contrastive phase.
model_optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
cls_optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()



# Training Loop

In [41]:
import gc
import os

from tqdm import tqdm

# Phase-1 checkpoints are written here. Create the directory up front so that
# torch.save does not fail on a fresh session.
phase1_ckpt_dir = "/kaggle/working/checkpoints_phase1"
os.makedirs(phase1_ckpt_dir, exist_ok=True)

model.train()
classifier.train()
best_loss = 1e9
for epoch in range(epochs):
    train_loss = 0.0
    correct = 0
    total = 0

    loop = tqdm(data_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in loop:
        # Each dataset item carries a leading axis of size 1; after batching it
        # sits at axis 1 and is removed here.
        x = batch["input"].squeeze(1).to(device)
        segment_id = batch["segment_id"].squeeze(1).to(device)
        mask = batch["mask"].squeeze(1).to(device)
        labels = batch["label"].to(device)  # one class index per example

        with autocast(device):
            cls_repr = model(x, segment_id, mask)   # position-0 encoder output
            logits = classifier(cls_repr)           # one row of class logits per example
            loss = loss_fn(logits, labels)

        model_optimizer.zero_grad()
        cls_optimizer.zero_grad()
        scaler.scale(loss).backward()

        # Shared scaler: step every optimizer, then update the scale once.
        scaler.step(model_optimizer)
        scaler.step(cls_optimizer)
        scaler.update()

        batch_loss = loss.item()
        train_loss += batch_loss

        # Metrics (kept as plain Python numbers rather than device tensors)
        preds = torch.argmax(logits, dim=-1)
        batch_correct = (preds == labels).sum().item()
        correct += batch_correct
        total += labels.size(0)

        loop.set_postfix(Loss=batch_loss, Accuracy=batch_correct / labels.size(0))
    gc.collect()
    torch.cuda.empty_cache()
    avg_loss = train_loss / len(data_loader)
    checkpoint = {"model": model.state_dict(), "classifier": classifier.state_dict()}
    torch.save(checkpoint, f"{phase1_ckpt_dir}/phase_1_epoch{epoch+1}.pt")
    if avg_loss < best_loss:
        torch.save(checkpoint, f"{phase1_ckpt_dir}/phase_1_best.pt")
        best_loss = avg_loss
    accuracy = correct / total
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")


Epoch 1/10: 100%|██████████| 615/615 [04:53<00:00,  2.09it/s, Accuracy=0.5, Loss=0.983]  


Epoch 1 - Loss: 1.4626, Accuracy: 0.3347


Epoch 2/10: 100%|██████████| 615/615 [04:52<00:00,  2.10it/s, Accuracy=0.25, Loss=1.13]  


Epoch 2 - Loss: 1.1322, Accuracy: 0.3543


Epoch 3/10: 100%|██████████| 615/615 [04:53<00:00,  2.10it/s, Accuracy=0.375, Loss=1.11] 


Epoch 3 - Loss: 1.0866, Accuracy: 0.4046


Epoch 4/10: 100%|██████████| 615/615 [04:53<00:00,  2.09it/s, Accuracy=0.75, Loss=0.884] 


Epoch 4 - Loss: 1.0616, Accuracy: 0.4393


Epoch 5/10: 100%|██████████| 615/615 [04:53<00:00,  2.10it/s, Accuracy=0.625, Loss=0.807]


Epoch 5 - Loss: 1.0203, Accuracy: 0.4894


Epoch 6/10: 100%|██████████| 615/615 [04:53<00:00,  2.10it/s, Accuracy=0.25, Loss=1.06]  


Epoch 6 - Loss: 0.9818, Accuracy: 0.5320


Epoch 7/10: 100%|██████████| 615/615 [04:52<00:00,  2.10it/s, Accuracy=0.625, Loss=1.16] 


Epoch 7 - Loss: 0.9266, Accuracy: 0.5677


Epoch 8/10: 100%|██████████| 615/615 [04:52<00:00,  2.10it/s, Accuracy=0.375, Loss=0.771]


Epoch 8 - Loss: 0.8845, Accuracy: 0.5957


Epoch 9/10: 100%|██████████| 615/615 [04:52<00:00,  2.11it/s, Accuracy=0.75, Loss=0.813] 


Epoch 9 - Loss: 0.8510, Accuracy: 0.6166


Epoch 10/10: 100%|██████████| 615/615 [04:52<00:00,  2.10it/s, Accuracy=0.75, Loss=0.651] 


Epoch 10 - Loss: 0.8158, Accuracy: 0.6396


# MNRL Loss

In [20]:

import torch
import torch.nn.functional as F

def mnr_loss(q_emb: torch.Tensor, p_emb: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """Multiple-negatives ranking loss over in-batch negatives.

    Row i of `q_emb` is treated as matching row i of `p_emb`; every other row
    of `p_emb` acts as a negative for it.

    Args:
        q_emb: 2-D tensor of query embeddings, one row per example.
        p_emb: 2-D tensor of positive embeddings, aligned row-by-row with `q_emb`.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar cross-entropy loss.
    """
    # Unit-length rows turn the dot product into a cosine similarity.
    queries = F.normalize(q_emb, p=2, dim=1)
    passages = F.normalize(p_emb, p=2, dim=1)

    similarity = queries @ passages.T
    scaled = similarity / temperature

    # The correct "class" for row i is column i (the diagonal).
    targets = torch.arange(scaled.size(0), device=scaled.device)
    return F.cross_entropy(scaled, targets)


In [21]:
import os
# os.makedirs("/kaggle/working/checkpoints_phase1",exist_ok=True)
# Create the directory that the MNRL training loop saves its checkpoints to;
# exist_ok=True makes re-running this cell harmless.
os.makedirs("/kaggle/working/checkpoints_phase2",exist_ok=True)

# NQ Dataset

In [22]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
class NQDataset(Dataset):
  """Query/answer pairs from `sentence-transformers/natural-questions`.

  Both texts of an item are tokenised on access with the notebook-level
  `tokenizer`.

  Args:
    split: dataset split name passed to `load_dataset`.
  """

  def __init__(self, split="train"):
    self.ds = load_dataset("sentence-transformers/natural-questions",split=split)

  def __len__(self):
    return len(self.ds)

  @staticmethod
  def _to_tensors(text):
    """Tokenise one string into (ids, type_ids, attention_mask) long tensors.

    Each tensor keeps a leading axis of size 1 because the text is encoded
    as a one-element batch.
    """
    encodings = tokenizer.encode_batch([text])
    ids = torch.tensor([enc.ids for enc in encodings], dtype=torch.long)
    type_ids = torch.tensor([enc.type_ids for enc in encodings], dtype=torch.long)
    attention = torch.tensor([enc.attention_mask for enc in encodings], dtype=torch.long)
    return ids, type_ids, attention

  def __getitem__(self, idx):
    """Return the encoded query and encoded answer of item `idx` as a dict."""
    record = self.ds[idx]
    q_ids, q_types, q_mask = self._to_tensors(record["query"])
    p_ids, p_types, p_mask = self._to_tensors(record["answer"])

    return {
      "token": q_ids,
      "segment_id": q_types,
      "mask": q_mask,
      "passage_token": p_ids,
      "passage_segment_id": p_types,
      "passage_mask": p_mask,
    }





In [23]:
# Natural Questions pairs (default split), shuffled into batches of 32 for
# the contrastive phase.
nq_dataset = NQDataset()
nq_data_loader = DataLoader(
    nq_dataset,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
)

README.md: 0.00B [00:00, ?B/s]

pair/train-00000-of-00001.parquet:   0%|          | 0.00/44.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100231 [00:00<?, ? examples/s]

In [26]:
# Resume from saved weights and optimizer state. map_location=device lets a
# checkpoint written on a GPU be loaded in a CPU-only session as well.
check_model=torch.load("/kaggle/input/my-tokenizer/phase_2_best (1).pt",map_location=device)
check_optimizer=torch.load("/kaggle/input/my-tokenizer/phase_2_optimizer.pt",map_location=device)
model.load_state_dict(check_model["model"])
model_optimizer.load_state_dict(check_optimizer["optimizer"])

# MNRL Training Loop

In [None]:
import gc

from tqdm import tqdm

num_epochs = 5   # single source of truth for the loop and the progress-bar label
accum_steps = 2  # Simulate batch size 64 (32 × 2)
phase2_ckpt_dir = "/kaggle/working/checkpoints_phase2"
num_batches = len(nq_data_loader)

model.train()
best_loss = 1e9

for epoch in range(num_epochs):
    train_loss = 0.0
    loop = tqdm(enumerate(nq_data_loader), total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}")

    model_optimizer.zero_grad()  # Important: start clean

    for step, batch in loop:
        # Each dataset item carries a leading axis of size 1; after batching it
        # sits at axis 1 and is removed here.
        query_token = batch["token"].squeeze(1).to(device)
        query_segment_id = batch["segment_id"].squeeze(1).to(device)
        query_mask = batch["mask"].squeeze(1).to(device)

        passage_token = batch["passage_token"].squeeze(1).to(device)
        passage_segment_id = batch["passage_segment_id"].squeeze(1).to(device)
        passage_mask = batch["passage_mask"].squeeze(1).to(device)

        with autocast(device):
            # The same encoder embeds both queries and passages.
            q_emb = model(query_token, query_segment_id, query_mask)
            p_emb = model(passage_token, passage_segment_id, passage_mask)

            loss = mnr_loss(q_emb, p_emb) / accum_steps  # Normalize loss

        scaler.scale(loss).backward()

        # Step only every accum_steps micro-batches (and on the final batch so
        # no gradients are left unapplied at the end of the epoch).
        if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
            scaler.step(model_optimizer)
            scaler.update()
            model_optimizer.zero_grad()

        step_loss = loss.item() * accum_steps  # Reverse normalization for logging
        train_loss += step_loss
        loop.set_postfix(Loss=step_loss)

    gc.collect()
    torch.cuda.empty_cache()

    avg_loss = train_loss / num_batches
    torch.save({"model": model.state_dict()}, f"{phase2_ckpt_dir}/phase_2_last.pt")
    # The resume cell loads an optimizer checkpoint under the key "optimizer";
    # write one here so that a later session can actually resume from this run.
    torch.save({"optimizer": model_optimizer.state_dict()}, f"{phase2_ckpt_dir}/phase_2_optimizer.pt")
    if avg_loss < best_loss:
        torch.save({"model": model.state_dict()}, f"{phase2_ckpt_dir}/phase_2_best.pt")
        best_loss = avg_loss

    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")


Epoch 1/10: 100%|██████████| 3133/3133 [50:00<00:00,  1.04it/s, Loss=0.628]


Epoch 1 - Loss: 1.3811


Epoch 2/10: 100%|██████████| 3133/3133 [49:55<00:00,  1.05it/s, Loss=1.18] 


Epoch 2 - Loss: 1.3520


Epoch 3/10:  54%|█████▍    | 1696/3133 [27:02<22:50,  1.05it/s, Loss=1.06] 

# Save the model