In [None]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
from tokenizer import Tokenizer

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from enums import PAD_ID, MAX_LEN, CLS_ID
from sent_transformer import SentenceTransformer

# Sentence Transformer Overview
## Datasets:
* STS-B (Semantic Textual Similarity Benchmark)
  
> Sentences are paired with similarity scores. We use high-score pairs as positives (score ≥ 0.6) and low-score pairs as negatives (score ≤ 0.5).
* QQP (Quora Question Pairs)

> Provided triplet dataset with queries, positive matches, and negative examples.
    

#### The two datasets are merged to create a diverse training pool.
#### We build triplets (anchor, positive, negative) and split into train (70%), validation (20%), and test (10%) sets.

> #### Total triplets: ~323,000.


## Sentence Transformer Architecture

### The sentence transformer has a minimalist architecture.

> Embedding Layer: Converts token IDs to dense vectors.

> Sinusoidal Positional Encoding: Adds positional information without training.

> Transformer Encoder: Stacks num_layers layers of basic self-attention blocks.

> CLS Pooling: Extracts the representation of the [CLS] token (first token). Can be used for training downstream classification tasks.

> Projection Head: Projects the output of the Transformer Encoder through an MLP layer.


## Pretraining Architecture

### We pretrain the sentence transformer using triplet loss.

> Inputs: Each batch contains (anchor, positive, negative) sentence triplets.

> Forward pass: Encode anchor, positive, and negative independently to get their embeddings.

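> Minimize triplet loss to pull positive pairs closer and push negative pairs farther apart in embedding space. The training step also adds a cosine-distance distillation term between each sentence's [CLS] output and a precomputed embedding from a small pretrained BERT teacher.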

## Training enhancements:

> Mixed Precision (AMP) for faster and memory-efficient training.

> Gradient Clipping to avoid exploding gradients.

> Cosine Annealing Scheduler to gradually reduce the learning rate.

> Early Stopping based on validation loss.

## Inference
> Compute similarity of anchor-positive sentences.
> Compute similarity of anchor-negative sentences.


### Important Notes:

* Sentence Transformer is defined in `sent_transformer.py` because we will be calling it for Multi-Task Expansion
* Lightweight pretraining dataset and architecture, so maximal accuracy cannot be achieved on downstream tasks.
* Positional encodings are fixed, not learned.
* All attention operations are standard PyTorch Transformer blocks.
* Further reading:
  
  * Attention Is All You Need: https://arxiv.org/abs/1706.03762
  * Sentence-BERT: https://arxiv.org/abs/1908.10084

In [None]:
from datasets import load_dataset, concatenate_datasets
from collections import defaultdict
import random
from tqdm import tqdm

# Load and merge STSB train + validation datasets
sts = load_dataset("sentence-transformers/stsb")
sts_trainval = concatenate_datasets([sts["train"], sts["validation"]])

# Seed BEFORE any sampling happens; otherwise the negatives drawn below (and
# therefore the whole triplet set) differ on every run.
random.seed(42)

# Create associated sentences map: sentence -> [(paired sentence, score), ...]
pairs = defaultdict(list)
for row in sts_trainval:
    s1, s2, score = row["sentence1"], row["sentence2"], row["score"]
    pairs[s1].append((s2, score))
    pairs[s2].append((s1, score))

# Create a sentence pool to randomly sample negatives.
# Every sentence1/sentence2 is a key of `pairs`, so the keys are exactly the
# unique sentences. Sorting gives a stable order: a list built from a set would
# be ordered by randomized string hashes and defeat the seed above.
sent_pool = sorted(pairs)

# Set thresholds for positive, negative, and num triplets for each anchor
POS_T, NEG_T, K = 0.6, 0.5, 3

# Generate triplets for STSB dataset by randomly sampling extra negatives from sent_pool.
all_triplets = set()
for anchor, lst in tqdm(pairs.items(), desc="STSB trainval triplets"):
    pos = [s for s, sc in lst if sc >= POS_T]
    neg = [s for s, sc in lst if sc <= NEG_T]
    if not pos:
        continue

    # When there are no explicit negatives, every sentence paired with the
    # anchor scored above NEG_T, so none of them may be used as a random
    # negative (the original only excluded the anchor and the current positive).
    excluded = {anchor}
    excluded.update(s for s, _ in lst)
    if not neg and len(excluded) >= len(sent_pool):
        # No admissible random negative exists; avoid spinning forever.
        continue

    for p in pos:
        for _ in range(K):
            if neg:
                n = random.choice(neg)
            else:
                n = random.choice(sent_pool)
                while n in excluded:
                    n = random.choice(sent_pool)
            # The set de-duplicates repeated (anchor, positive, negative) draws.
            all_triplets.add((anchor, p, n))


# Load QQP_triplets dataset
# Each row carries one query with lists of positives and negatives; expand the
# full positive x negative cross product into triplets.
qqp = load_dataset("embedding-data/QQP_triplets")
for split in qqp:
    for row in tqdm(qqp[split], desc=f"QQP {split}"):
        query = row["set"]["query"]
        positives = row["set"]["pos"]
        negatives = row["set"]["neg"]
        for p in positives:
            for n in negatives:
                all_triplets.add((query, p, n))

In [None]:
# Shuffle
# `all_triplets` arrives as a set, whose iteration order depends on Python's
# per-process string-hash randomization. Sorting first gives the seeded shuffle
# a fixed starting order, so the train/val/test split is reproducible across
# kernel restarts (and re-running this cell is idempotent).
all_triplets = sorted(all_triplets)
random.seed(42)
random.shuffle(all_triplets)

In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm

# Small pretrained BERT used as the distillation teacher.
teacher_name = "google/bert_uncased_L-2_H-128_A-2"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher_model = AutoModel.from_pretrained(teacher_name)
teacher_model = teacher_model.eval().to(device)

# Every distinct sentence that appears anywhere in a triplet.
all_texts = list({sentence for triplet in all_triplets for sentence in triplet})

# sentence -> pooled teacher embedding (kept on CPU)
teacher_embeds = {}
batch_size = 512
for start in tqdm(range(0, len(all_texts), batch_size), desc="Embedding teacher"):
    chunk = all_texts[start:start + batch_size]
    encoded = teacher_tokenizer(
        chunk,
        padding=True,
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    ).to(device)
    # Teacher is frozen: no autograd graph is needed for its outputs.
    with torch.no_grad():
        pooled = teacher_model(**encoded).pooler_output.cpu()   # (B, hidden_dim)
    # Iterating the 2-D tensor yields one row per sentence of the chunk.
    teacher_embeds.update(zip(chunk, pooled))

# Persist the lookup table so it can be reused outside this kernel.
torch.save(teacher_embeds, "teacher_embeddings.pt")

In [None]:
# 70/20/10 Train/Val/Test split
# Train and val sizes are floored; the test split takes whatever remains.
n = len(all_triplets)
n_train, n_val = int(0.7 * n), int(0.2 * n)
val_end = n_train + n_val

train_data, val_data, test_data = (
    all_triplets[:n_train],
    all_triplets[n_train:val_end],
    all_triplets[val_end:],
)

print(f"Total triplets: {n}")
for label, split_rows in (
    (" Train: ", train_data),
    (" Val:   ", val_data),
    (" Test:  ", test_data),
):
    print(f"{label}{len(split_rows)}")

In [None]:
from torch.nn.utils.rnn import pad_sequence


class TripletDataset(Dataset):
    """Map-style dataset over (anchor, positive, negative) sentence triplets.

    Each item pairs the tokenized form of a sentence with that sentence's
    teacher embedding, looked up in the notebook-level `teacher_embeds` dict.

    Args:
        triplets: indexable collection of 3-tuples of sentences.
        tokenizer: object exposing `tokenize(sentence)`.
            NOTE(review): `collate` slices the result and feeds it to
            `pad_sequence`, so it presumably returns a 1-D tensor of token
            ids — confirm against the Tokenizer implementation.
        max_len: stored on the instance; this class itself never reads it
            (truncation is performed in `collate`).
    """

    def __init__(self, triplets, tokenizer, max_len=MAX_LEN):
        self.max_len = max_len
        self.tok = tokenizer
        self.triplets = triplets

    def __len__(self):
        """Number of triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return `[(tokens, teacher_emb)]` for anchor, positive, negative (in that order)."""
        anchor, positive, negative = self.triplets[idx]
        item = []
        for sentence in (anchor, positive, negative):
            # Relies on the global `teacher_embeds` having an entry for every sentence.
            item.append((self.tok.tokenize(sentence), teacher_embeds[sentence]))
        return item


def collate(batch):
    """Collate `TripletDataset` items into padded id/mask tensors plus teacher targets.

    Args:
        batch: list of items, each a sequence of three `(token_seq, teacher_emb)`
            pairs ordered anchor, positive, negative.

    Returns:
        `(a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask), t_a, t_p, t_n`
        where each `*_ids` is padded with PAD_ID (batch first), each `*_mask`
        is a long tensor that is 1 wherever the id differs from PAD_ID, and
        each `t_*` is the teacher embeddings stacked along dim 0.
    """
    def _pad_with_mask(sequences):
        # Clip every sequence to MAX_LEN first; the padded width therefore
        # can never exceed MAX_LEN and needs no second clipping pass.
        clipped = [seq[:MAX_LEN] for seq in sequences]
        padded = pad_sequence(clipped, batch_first=True, padding_value=PAD_ID)
        return padded, padded.ne(PAD_ID).long()

    padded_inputs = []
    teacher_targets = []
    # Transpose the batch: one group per role (anchors, positives, negatives).
    for role_items in zip(*batch):
        token_seqs, teacher_vecs = zip(*role_items)
        padded_inputs.append(_pad_with_mask(token_seqs))
        teacher_targets.append(torch.stack(teacher_vecs, dim=0))

    anchor_in, positive_in, negative_in = padded_inputs
    t_a, t_p, t_n = teacher_targets
    return anchor_in, positive_in, negative_in, t_a, t_p, t_n



In [None]:
tokenizer = Tokenizer("bpe_merged.json")
train_dataset = TripletDataset(train_data, tokenizer)
val_dataset = TripletDataset(val_data, tokenizer)
test_dataset = TripletDataset(test_data, tokenizer)


def make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook's shared settings.

    Args:
        dataset: a `TripletDataset`.
        shuffle: whether to reshuffle the data every epoch.

    Returns:
        A `DataLoader` with batch size 16 that batches through `collate`.
    """
    return DataLoader(
        dataset,
        batch_size=16,
        shuffle=shuffle,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just triggers a warning.
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate,
    )


# Reuse the dataset objects built above instead of constructing a second copy
# of each one for its loader.
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

In [None]:
from enums import D_MODEL, NHEAD, NUM_LAYERS, PROJ_DIM, DROPOUT, MARGIN


class PretrainSentenceTransformer(nn.Module):
    """Pretraining wrapper around `SentenceTransformer`.

    Combines a cosine-distance triplet loss over the encoder's first output
    with a cosine-distance distillation term against teacher embeddings.

    Args:
        vocab_size: forwarded to `SentenceTransformer`.
        d_model, nhead, num_layers, proj_dim, dropout: forwarded positionally
            to `SentenceTransformer` (defaults come from `enums`).
        margin: margin of the triplet loss.
    """

    def __init__(self,
                 vocab_size,
                 d_model=D_MODEL,
                 nhead=NHEAD,
                 num_layers=NUM_LAYERS,
                 proj_dim=PROJ_DIM,
                 margin=MARGIN,
                 dropout=DROPOUT):
        super().__init__()
        self.encoder = SentenceTransformer(
            vocab_size, d_model, nhead, num_layers, proj_dim, dropout
        )
        # cosine-distance triplet loss
        self.cos = nn.CosineSimilarity(dim=1)
        self.triplet = nn.TripletMarginWithDistanceLoss(
            distance_function=self._cosine_distance,
            margin=margin,
        )

    def _cosine_distance(self, x, y):
        """Cosine distance, i.e. one minus cosine similarity along dim 1."""
        return 1 - self.cos(x, y)

    def forward(self, a, a_mask, p, p_mask, n, n_mask):
        """Encode anchor, positive and negative independently.

        Returns:
            A 6-tuple `(cls_a, e_a, cls_p, e_p, cls_n, e_n)`: the two values the
            encoder yields with `return_all=True`, for each of the three inputs.
        """
        encoded = []
        for ids, mask in ((a, a_mask), (p, p_mask), (n, n_mask)):
            cls_vec, emb = self.encoder(ids, attention_mask=mask, return_all=True)
            encoded += [cls_vec, emb]
        return tuple(encoded)

    def training_step(self, batch):
        """Compute the total loss for one collated batch.

        Args:
            batch: `((a, a_mask), (p, p_mask), (n, n_mask), t_a, t_p, t_n)`,
                where `t_*` are the teacher embeddings for each input.

        Returns:
            Triplet loss plus the mean cosine distance between each `cls_*`
            output and its teacher embedding (anchor, positive, negative).
        """
        (a, a_mask), (p, p_mask), (n, n_mask), t_a, t_p, t_n = batch
        cls_a, _, cls_p, _, cls_n, _ = self(a, a_mask, p, p_mask, n, n_mask)

        total = self.triplet(cls_a, cls_p, cls_n)
        # Distillation: pull each student vector toward its teacher embedding.
        for student, teacher in ((cls_a, t_a), (cls_p, t_p), (cls_n, t_n)):
            total = total + (1 - self.cos(student, teacher)).mean()
        return total


In [None]:
import os


device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_size = max(tokenizer.rev_merged.keys()) + 3

model     = PretrainSentenceTransformer(vocab_size).to(device)
opt       = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs    = 50
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# AMP setup
# NOTE(review): autocast below runs in bfloat16; gradient scaling is normally
# only required for float16. The scaler is kept to preserve the existing setup.
use_amp = torch.cuda.is_available()
scaler  = torch.cuda.amp.GradScaler() if use_amp else None

# Early stopping & checkpointing params
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

best_val_loss = float('inf')
patience      = 3
no_improve    = 0


def batch_to_device(batch, device):
    """Move every tensor of a collated triplet batch to `device`, keeping its structure."""
    (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask), t_as, t_ps, t_ns = batch
    return (
        (a_ids.to(device), a_mask.to(device)),
        (p_ids.to(device), p_mask.to(device)),
        (n_ids.to(device), n_mask.to(device)),
        t_as.to(device),
        t_ps.to(device),
        t_ns.to(device),
    )


def compute_loss(model, batch, use_amp):
    """Run `model.training_step` on `batch`, under bfloat16 autocast when `use_amp` is set."""
    if use_amp:
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            return model.training_step(batch)
    return model.training_step(batch)


# Training
for epoch in range(1, epochs + 1):
    model.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        batch = batch_to_device(batch, device)

        opt.zero_grad()
        loss = compute_loss(model, batch, use_amp)

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)

    # Validation (same loss as training, without gradient tracking)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            batch = batch_to_device(batch, device)
            loss = compute_loss(model, batch, use_amp)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    # Learning Rate Schedule
    # Read the rate BEFORE stepping so the log shows the value that was actually
    # used for this epoch rather than the one scheduled for the next epoch.
    lr = scheduler.get_last_lr()[0]
    scheduler.step()

    print(f"Epoch {epoch}/{epochs} — "
          f"train_loss={train_loss:.5f}  val_loss={val_loss:.5f}  lr={lr:.1e}")

    # Checkpoint after each epoch.
    # Notice we are only checkpointing the model encoder, as it is the sentence transformer.
    enc_state = model.encoder.state_dict()
    torch.save({
        'epoch': epoch,
        'encoder_state_dict': enc_state,
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, os.path.join(checkpoint_dir, f"encoder_epoch_{epoch}.pt"))

    # Weights-only copy of the encoder for this epoch.
    torch.save(enc_state, os.path.join(checkpoint_dir, f"epoch_{epoch}.pt"))

    # Early stopping to avoid overfitting
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0
        torch.save(model.encoder.state_dict(), "best_encoder.pt")
        print("  ↳ New best Sentence Transformer saved.")
    else:
        no_improve += 1
        print(f"  ↳ No improvement for {no_improve} epoch(s).")
        if no_improve >= patience:
            print(f"Stopping early after {patience} epochs without improvement.")
            break

In [None]:
# Sanity check on test data.
# Notice the positive sentences have a greater similarity than the negative sentences when compared to the anchor sentences.
model.eval()  # make sure dropout is disabled regardless of how training ended

# Sample from (at most) the first 100 test triplets without indexing past the end.
for idx in np.random.randint(0, min(100, len(test_dataset)), 10):
    # `collate` also returns the three teacher-embedding tensors; they are not needed here.
    (a, a_mask), (p, p_mask), (n, n_mask), _, _, _ = collate([test_dataset[idx]])

    a_ids, a_mask = a.to(device), a_mask.to(device)
    p_ids, p_mask = p.to(device), p_mask.to(device)
    n_ids, n_mask = n.to(device), n_mask.to(device)

    with torch.no_grad():
        # Pass the tensors that were moved to `device`, not the CPU originals.
        # forward() returns two outputs per sentence; compare the first of each
        # pair, which is what the triplet loss was trained on.
        cls_a, _, cls_p, _, cls_n, _ = model(a_ids, a_mask, p_ids, p_mask, n_ids, n_mask)
        cos = F.cosine_similarity(cls_a, cls_p).item()
        cos_neg = F.cosine_similarity(cls_a, cls_n).item()

    print(f"cos(+): {cos:.3f}  cos(-): {cos_neg:.3f}")
