In [1]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
from tokenizer import Tokenizer

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from enums import PAD_ID, MAX_LEN, CLS_ID
from sent_transformer import SentenceTransformer

In [2]:
from datasets import load_dataset, concatenate_datasets
from collections import defaultdict
import random
from tqdm import tqdm

# Load and merge STSB train + validation datasets
sts = load_dataset("sentence-transformers/stsb")
sts_trainval = concatenate_datasets([sts["train"], sts["validation"]])


# Create associated sentences map
pairs = defaultdict(list)
for row in sts_trainval:
    s1, s2, score = row["sentence1"], row["sentence2"], row["score"]
    pairs[s1].append((s2, score))
    pairs[s2].append((s1, score))

# Create a sentence pool to randomly sample negatives.
sent_pool = list(set(sts_trainval["sentence1"] + sts_trainval["sentence2"]))

# Set thresholds for positive, negative, and num triplets for each anchor 
POS_T, NEG_T, K = 0.6, 0.5, 3

# Generate triplets for STSB dataset by randomly sampling extra negatives from sent_pool.
all_triplets = set()
for anchor, lst in tqdm(pairs.items(), desc="STSB trainval triplets"):
    pos = [s for s, sc in lst if sc >= POS_T]
    neg = [s for s, sc in lst if sc <= NEG_T]
    if not pos:
        continue
    for p in pos:
        for _ in range(K):
            if neg:
                n = random.choice(neg)
            else:
                n = random.choice(sent_pool)
                while n in (anchor, p):
                    n = random.choice(sent_pool)
            all_triplets.add((anchor, p, n))


# Load QQP_triplets dataset
qqp = load_dataset("embedding-data/QQP_triplets")
for split in qqp:
    for row in tqdm(qqp[split], desc=f"QQP {split}"):
        query = row["set"]["query"]
        positives = row["set"]["pos"]
        negatives = row["set"]["neg"]
        for p in positives:
            for n in negatives:
                all_triplets.add((query, p, n))

# Shuffle
all_triplets = list(all_triplets)
random.seed(42)
random.shuffle(all_triplets)

# 70/20/10 Train/Val/Test split
n = len(all_triplets)
n_train = int(0.7 * n)
n_val   = int(0.2 * n)

train_data = all_triplets[:n_train]
val_data   = all_triplets[n_train : n_train + n_val]
test_data  = all_triplets[n_train + n_val : ]


print(f"Total triplets: {n}")
print(f" Train: {len(train_data)}")
print(f" Val:   {len(val_data)}")
print(f" Test:  {len(test_data)}")


STSB trainval triplets: 100%|█████████| 13227/13227 [00:00<00:00, 727609.73it/s]
QQP train: 100%|█████████████████████| 101762/101762 [00:03<00:00, 29732.52it/s]


Total triplets: 2808290
 Train: 1965802
 Val:   561658
 Test:  280830


In [3]:
class TripletDataset(Dataset):
    def __init__(self, triplets, tokenizer, max_len=MAX_LEN):
        self.triplets = triplets
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        a,p,n = self.triplets[idx]
        return [self.tok.tokenize(s).tolist() for s in (a,p,n)]


def collate(batch):
    a, p, n = zip(*batch)

    def pad(seq):
        L = max(len(s) for s in seq)
        batch_ids = [s + [PAD_ID] * (L - len(s)) for s in seq]
        attn_mask = [[1] * len(s) + [0] * (L - len(s)) for s in seq]
        return (
            torch.tensor(batch_ids, dtype=torch.long),
            torch.tensor(attn_mask, dtype=torch.long)
        )

    a_ids, a_mask = pad(a)
    p_ids, p_mask = pad(p)
    n_ids, n_mask = pad(n)

    return (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask)

In [4]:
tokenizer = Tokenizer("bpe_merged.json")
train_dataset = TripletDataset(train_data, tokenizer)
val_dataset = TripletDataset(val_data, tokenizer)
test_dataset   = TripletDataset(test_data, tokenizer)


train_loader = DataLoader(
    TripletDataset(train_data, tokenizer),
    batch_size=32, shuffle=True,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)
val_loader = DataLoader(
    TripletDataset(val_data, tokenizer),
    batch_size=32, shuffle=False,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)
test_loader = DataLoader(
    TripletDataset(test_data, tokenizer),
    batch_size=32, shuffle=False,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)

In [7]:
# !pip install transformers

In [15]:
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F

# 1) Choose a smaller BERT
teacher_name     = "google/bert_uncased_L-2_H-128_A-2"   # hidden_size=256
teacher_tokenizer= AutoTokenizer.from_pretrained(teacher_name)
teacher_model    = AutoModel.from_pretrained(teacher_name)
teacher_model.eval()

# # 2) In your distillation loop
# raw_texts = [ "This is an example.", "Another sentence." ]
# teach_inputs = teacher_tokenizer(
#     raw_texts, padding=True, truncation=True, return_tensors="pt"
# )
# with torch.no_grad():
#     out = teacher_model(**teach_inputs).last_hidden_state  # (B, L, 256)
#     teacher_cls = out[:,0]                                 # (B, 256)

# # 3) Your student still runs on BPE + custom encoder → student_cls: (B, D)
# #    Then distillation: e.g. MSE:
# loss = F.mse_loss(student_cls, teacher_cls)


config.json:   0%|          | 0.00/382 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


In [38]:
raw_texts = ["Hi, How are you", "How are you, Hi", "Hi", "Wonder ful saar"]       # e.g. list[str] you collate


In [43]:
t_out.last_hidden_state.mean(dim=1).tolist()

[[-0.8823256492614746,
  0.4203495383262634,
  -0.7122400999069214,
  -1.6795321702957153,
  -0.1700838953256607,
  0.018848402425646782,
  0.11936243623495102,
  0.6452649235725403,
  -1.188864827156067,
  -0.3175344169139862,
  -0.022278079763054848,
  0.27063825726509094,
  -0.1428239643573761,
  0.6041097640991211,
  1.7108962535858154,
  -1.1337039470672607,
  -0.7393425107002258,
  0.47491124272346497,
  -1.1764651536941528,
  1.3366515636444092,
  -0.47980427742004395,
  0.5489088892936707,
  1.4174871444702148,
  0.19421899318695068,
  1.1056174039840698,
  -0.3610375225543976,
  0.5508883595466614,
  0.5334464311599731,
  0.4596266746520996,
  -0.2989594638347626,
  -1.8627790212631226,
  -2.7980494499206543,
  -0.5149377584457397,
  0.5198287963867188,
  0.5482035875320435,
  -0.22754310071468353,
  0.4483339488506317,
  0.3482033908367157,
  -1.9715547561645508,
  -0.82817143201828,
  1.2674686908721924,
  -1.0244967937469482,
  0.508489191532135,
  -1.007757306098938,
  -0.