In [2]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
from tokenizer import Tokenizer

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from enums import PAD_ID, MAX_LEN, CLS_ID
from sent_transformer import SentenceTransformer

# Senetence Transformer Overview
## Datasets:
* STS-B (Semantic Textual Similarity Benchmark)
  
> Sentences are paired with similarity scores. We use high-score pairs as positives (score ≥ 0.6) and low-score pairs as negatives (score ≤ 0.5).
* QQP (Quora Question Pairs)

> Provided triplet dataset with queries, positive matches, and negative examples.
    

#### The two datasets are merged to create a diverse training pool.
#### We build triplets (anchor, positive, negative) and split into train (70%), validation (20%), and test (10%) sets.

> #### Total triplets: ~323,000.


## Senetence Transformer Architecture

### The sentence transformer has a minimalist architecture.

> Embedding Layer: Converts token IDs to dense vectors.

> Sinusoidal Positional Encoding: Adds positional information without training.

> Transformer Encoder: Stacks num_layers layers of basic self-attention blocks.

> CLS Pooling: Extracts the representation of the [CLS] token (first token). Can be used for training Classification Tasks Downstream.

> Projection Head: Represents the learned Transformer Encoder using a MLP layer.


## Pretraining Architecture

### We pretrain the sentence transformer using triplet loss.

> Inputs: Each batch contains (anchor, positive, negative) sentence triplets.

> Forward pass: Encode anchor, positive, and negative independently and Get their embeddings.

> Minimize triplet loss to pull positive pairs closer and push negative pairs farther apart in embedding space.

## Training enhancements:

> Mixed Precision (AMP) for faster and memory-efficient training.

> Gradient Clipping to avoid exploding gradients.

> Cosine Annealing Scheduler to gradually reduce the learning rate.

> Early Stopping based on validation loss.

## Inference
> Compute similarity of anchor-positive sentences.
> Compute similarity of anchor-negative sentences.


### Important Notes:

* Sentence Transformer is defined in `sent_transformer.py` because we will be calling it for Multi-Task Expansion
* Lightweight pretraining dataset and architecture, so maximal accuracy cannot be acheived on downstream tasks.
* Positional encodings are fixed, not learned.
* All attention operations are standard PyTorch Transformer blocks.
* Further reading:
  
  * Attention Is All You Need: https://arxiv.org/abs/1706.03762
  * Sentence-BERT: https://arxiv.org/abs/1908.10084

In [3]:
from datasets import load_dataset, concatenate_datasets
from collections import defaultdict
import random
from tqdm import tqdm

# Load and merge STSB train + validation datasets
sts = load_dataset("sentence-transformers/stsb")
sts_trainval = concatenate_datasets([sts["train"], sts["validation"]])


# Create associated sentences map
pairs = defaultdict(list)
for row in sts_trainval:
    s1, s2, score = row["sentence1"], row["sentence2"], row["score"]
    pairs[s1].append((s2, score))
    pairs[s2].append((s1, score))

# Create a sentence pool to randomly sample negatives.
sent_pool = list(set(sts_trainval["sentence1"] + sts_trainval["sentence2"]))

# Set thresholds for positive, negative, and num triplets for each anchor 
POS_T, NEG_T, K = 0.6, 0.5, 3

# Generate triplets for STSB dataset by randomly sampling extra negatives from sent_pool.
all_triplets = set()
for anchor, lst in tqdm(pairs.items(), desc="STSB trainval triplets"):
    pos = [s for s, sc in lst if sc >= POS_T]
    neg = [s for s, sc in lst if sc <= NEG_T]
    if not pos:
        continue
    for p in pos:
        for _ in range(K):
            if neg:
                n = random.choice(neg)
            else:
                n = random.choice(sent_pool)
                while n in (anchor, p):
                    n = random.choice(sent_pool)
            all_triplets.add((anchor, p, n))


# Load QQP_triplets dataset
qqp = load_dataset("embedding-data/QQP_triplets")
for split in qqp:
    for row in tqdm(qqp[split], desc=f"QQP {split}"):
        query = row["set"]["query"]
        positives = row["set"]["pos"]
        negatives = row["set"]["neg"]
        for p in positives:
            for n in negatives:
                all_triplets.add((query, p, n))

# Shuffle
all_triplets = list(all_triplets)
random.seed(42)
random.shuffle(all_triplets)

# 70/20/10 Train/Val/Test split
n = len(all_triplets)
n_train = int(0.7 * n)
n_val   = int(0.2 * n)

train_data = all_triplets[:n_train]
val_data   = all_triplets[n_train : n_train + n_val]
test_data  = all_triplets[n_train + n_val : ]


print(f"Total triplets: {n}")
print(f" Train: {len(train_data)}")
print(f" Val:   {len(val_data)}")
print(f" Test:  {len(test_data)}")


STS trainval triplets: 100%|██████████| 13227/13227 [00:00<00:00, 839851.33it/s]
QQP train: 100%|█████████████████████| 101762/101762 [00:02<00:00, 42287.27it/s]


Total triplets: 323435
 Train: 226404
 Val:   64687
 Test:  32344


In [4]:
class TripletDataset(Dataset):
    def __init__(self, triplets, tokenizer, max_len=MAX_LEN):
        self.triplets = triplets
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        a,p,n = self.triplets[idx]
        return [self.tok.tokenize(s).tolist() for s in (a,p,n)]


def collate(batch):
    a, p, n = zip(*batch)

    def pad(seq):
        L = max(len(s) for s in seq)
        batch_ids = [s + [PAD_ID] * (L - len(s)) for s in seq]
        attn_mask = [[1] * len(s) + [0] * (L - len(s)) for s in seq]
        return (
            torch.tensor(batch_ids, dtype=torch.long),
            torch.tensor(attn_mask, dtype=torch.long)
        )

    a_ids, a_mask = pad(a)
    p_ids, p_mask = pad(p)
    n_ids, n_mask = pad(n)

    return (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask)

In [5]:
tokenizer = Tokenizer("bpe_merged.json")
train_dataset = TripletDataset(train_data, tokenizer)
val_dataset = TripletDataset(val_data, tokenizer)
test_dataset   = TripletDataset(test_data, tokenizer)


train_loader = DataLoader(
    TripletDataset(train_data, tokenizer),
    batch_size=32, shuffle=True,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)
val_loader = DataLoader(
    TripletDataset(val_data, tokenizer),
    batch_size=32, shuffle=False,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)
test_loader = DataLoader(
    TripletDataset(test_data, tokenizer),
    batch_size=32, shuffle=False,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)

In [6]:
from enums import D_MODEL, NHEAD, NUM_LAYERS, PROJ_DIM, DROPOUT, MARGIN


class PretrainSentenceTransformer(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model=D_MODEL,
                 nhead=NHEAD,
                 num_layers=NUM_LAYERS,
                 proj_dim=PROJ_DIM,
                 margin=MARGIN,
                 dropout=DROPOUT):
        super().__init__()
        self.encoder = SentenceTransformer(
            vocab_size, d_model, nhead, num_layers, proj_dim, dropout
        )
        # cosine-distance triplet loss
        cos_sim = nn.CosineSimilarity(dim=1)
        self.loss_fn = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1 - cos_sim(x, y),
            margin=margin
        )

    def forward(self, a, a_mask, p, p_mask, n, n_mask):
        e_a = self.encoder(a, attention_mask=a_mask)
        e_p = self.encoder(p, attention_mask=p_mask)
        e_n = self.encoder(n, attention_mask=n_mask)
        return e_a, e_p, e_n

    def training_step(self, batch):
        (a, a_mask), (p, p_mask), (n, n_mask) = batch
        e_a, e_p, e_n = self(a, a_mask, p, p_mask, n, n_mask)
        return self.loss_fn(e_a, e_p, e_n)

In [7]:
import os


device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_size = max(tokenizer.rev_merged.keys()) + 3

model     = PretrainSentenceTransformer(vocab_size).to(device)
opt       = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs    = 50
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# AMP setup
use_amp = torch.cuda.is_available()
scaler  = torch.cuda.amp.GradScaler() if use_amp else None

# Early stopping & checkpointing params
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

best_val_loss = float('inf')
patience      = 3
no_improve    = 0

# Training
for epoch in range(1, epochs+1):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        # unpack tuples
        (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask) = batch
    
        # move each tensor explicitly to device
        a_ids, a_mask = a_ids.to(device), a_mask.to(device)
        p_ids, p_mask = p_ids.to(device), p_mask.to(device)
        n_ids, n_mask = n_ids.to(device), n_mask.to(device)
    
        if use_amp:
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                loss = model.training_step(((a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask)))
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            loss = model.training_step(((a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask)))
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
    
        train_loss += loss.item()

    train_loss /= len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask) = batch
            a_ids, a_mask = a_ids.to(device), a_mask.to(device)
            p_ids, p_mask = p_ids.to(device), p_mask.to(device)
            n_ids, n_mask = n_ids.to(device), n_mask.to(device)
        
            with torch.no_grad():
                if use_amp:
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        loss = model.training_step(((a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask)))
                else:
                    loss = model.training_step(((a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask)))
                val_loss += loss.item()

    val_loss /= len(val_loader)

    # Learning Rate Schedule
    scheduler.step()
    lr = scheduler.get_last_lr()[0]

    print(f"Epoch {epoch}/{epochs} — "
          f"train_loss={train_loss:.5f}  val_loss={val_loss:.5f}  lr={lr:.1e}")

    # Checkpoint after each epoch.
    # Notice we are only checkpointing the model encoder as it is is the sentence transformer.
    enc_state = model.encoder.state_dict()
    torch.save({
        'epoch': epoch,
        'encoder_state_dict': enc_state,
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, os.path.join(checkpoint_dir, f"encoder_epoch_{epoch}.pt"))

    torch.save(enc_state, os.path.join(checkpoint_dir, f"epoch_{epoch}.pt"))

    # Early stopping to avoid overfitting
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0
        torch.save(model.encoder.state_dict(), "best_encoder.pt")
        print("  ↳ New best Sentence Transformer saved.")
    else:
        no_improve += 1
        print(f"  ↳ No improvement for {no_improve} epoch(s).")
        if no_improve >= patience:
            print(f"Stopping early after {patience} epochs without improvement.")
            break

Epoch 1 Train: 100%|████████████████████████| 7076/7076 [38:52<00:00,  3.03it/s]
  output = torch._nested_tensor_from_mask(
Epoch 1 Val: 100%|██████████████████████████| 2022/2022 [04:24<00:00,  7.63it/s]


Epoch 1/50 — train_loss=0.05900  val_loss=0.04115  lr=1.0e-03
  ↳ New best Sentence Transformer saved.


Epoch 2 Train: 100%|████████████████████████| 7076/7076 [39:03<00:00,  3.02it/s]
Epoch 2 Val: 100%|██████████████████████████| 2022/2022 [04:22<00:00,  7.72it/s]


Epoch 2/50 — train_loss=0.04117  val_loss=0.03261  lr=1.0e-03
  ↳ New best Sentence Transformer saved.


Epoch 3 Train: 100%|████████████████████████| 7076/7076 [43:46<00:00,  2.69it/s]
Epoch 3 Val: 100%|██████████████████████████| 2022/2022 [05:48<00:00,  5.80it/s]


Epoch 3/50 — train_loss=0.03570  val_loss=0.03046  lr=9.9e-04
  ↳ New best Sentence Transformer saved.


Epoch 4 Train: 100%|████████████████████████| 7076/7076 [40:09<00:00,  2.94it/s]
Epoch 4 Val: 100%|██████████████████████████| 2022/2022 [04:26<00:00,  7.58it/s]


Epoch 4/50 — train_loss=0.03216  val_loss=0.02902  lr=9.8e-04
  ↳ New best Sentence Transformer saved.


Epoch 5 Train: 100%|████████████████████████| 7076/7076 [38:32<00:00,  3.06it/s]
Epoch 5 Val: 100%|██████████████████████████| 2022/2022 [04:22<00:00,  7.71it/s]


Epoch 5/50 — train_loss=0.02954  val_loss=0.02693  lr=9.8e-04
  ↳ New best Sentence Transformer saved.


Epoch 6 Train: 100%|████████████████████████| 7076/7076 [38:23<00:00,  3.07it/s]
Epoch 6 Val: 100%|██████████████████████████| 2022/2022 [04:20<00:00,  7.75it/s]


Epoch 6/50 — train_loss=0.02720  val_loss=0.02508  lr=9.6e-04
  ↳ New best Sentence Transformer saved.


Epoch 7 Train: 100%|████████████████████████| 7076/7076 [38:14<00:00,  3.08it/s]
Epoch 7 Val: 100%|██████████████████████████| 2022/2022 [04:20<00:00,  7.76it/s]


Epoch 7/50 — train_loss=0.02587  val_loss=0.02257  lr=9.5e-04
  ↳ New best Sentence Transformer saved.


Epoch 8 Train: 100%|████████████████████████| 7076/7076 [38:28<00:00,  3.07it/s]
Epoch 8 Val: 100%|██████████████████████████| 2022/2022 [04:31<00:00,  7.45it/s]


Epoch 8/50 — train_loss=0.02393  val_loss=0.02203  lr=9.4e-04
  ↳ New best Sentence Transformer saved.


Epoch 9 Train: 100%|████████████████████████| 7076/7076 [38:12<00:00,  3.09it/s]
Epoch 9 Val: 100%|██████████████████████████| 2022/2022 [04:19<00:00,  7.79it/s]


Epoch 9/50 — train_loss=0.02259  val_loss=0.02184  lr=9.2e-04
  ↳ New best Sentence Transformer saved.


Epoch 10 Train: 100%|███████████████████████| 7076/7076 [38:00<00:00,  3.10it/s]
Epoch 10 Val: 100%|█████████████████████████| 2022/2022 [04:30<00:00,  7.48it/s]


Epoch 10/50 — train_loss=0.02133  val_loss=0.01982  lr=9.0e-04
  ↳ New best Sentence Transformer saved.


Epoch 11 Train: 100%|███████████████████████| 7076/7076 [38:03<00:00,  3.10it/s]
Epoch 11 Val: 100%|█████████████████████████| 2022/2022 [04:23<00:00,  7.67it/s]


Epoch 11/50 — train_loss=0.02016  val_loss=0.01929  lr=8.9e-04
  ↳ New best Sentence Transformer saved.


Epoch 12 Train: 100%|███████████████████████| 7076/7076 [37:47<00:00,  3.12it/s]
Epoch 12 Val: 100%|█████████████████████████| 2022/2022 [04:20<00:00,  7.76it/s]


Epoch 12/50 — train_loss=0.01902  val_loss=0.01957  lr=8.6e-04
  ↳ No improvement for 1 epoch(s).


Epoch 13 Train: 100%|███████████████████████| 7076/7076 [38:00<00:00,  3.10it/s]
Epoch 13 Val: 100%|█████████████████████████| 2022/2022 [04:28<00:00,  7.52it/s]


Epoch 13/50 — train_loss=0.01801  val_loss=0.01860  lr=8.4e-04
  ↳ New best Sentence Transformer saved.


Epoch 14 Train: 100%|███████████████████████| 7076/7076 [38:23<00:00,  3.07it/s]
Epoch 14 Val: 100%|█████████████████████████| 2022/2022 [04:26<00:00,  7.60it/s]


Epoch 14/50 — train_loss=0.01711  val_loss=0.01724  lr=8.2e-04
  ↳ New best Sentence Transformer saved.


Epoch 15 Train: 100%|███████████████████████| 7076/7076 [38:29<00:00,  3.06it/s]
Epoch 15 Val: 100%|█████████████████████████| 2022/2022 [04:18<00:00,  7.81it/s]


Epoch 15/50 — train_loss=0.01635  val_loss=0.01648  lr=7.9e-04
  ↳ New best Sentence Transformer saved.


Epoch 16 Train: 100%|███████████████████████| 7076/7076 [38:56<00:00,  3.03it/s]
Epoch 16 Val: 100%|█████████████████████████| 2022/2022 [04:35<00:00,  7.33it/s]


Epoch 16/50 — train_loss=0.01542  val_loss=0.01631  lr=7.7e-04
  ↳ New best Sentence Transformer saved.


Epoch 17 Train: 100%|███████████████████████| 7076/7076 [38:51<00:00,  3.03it/s]
Epoch 17 Val: 100%|█████████████████████████| 2022/2022 [04:27<00:00,  7.57it/s]


Epoch 17/50 — train_loss=0.01477  val_loss=0.01574  lr=7.4e-04
  ↳ New best Sentence Transformer saved.


Epoch 18 Train: 100%|███████████████████████| 7076/7076 [38:22<00:00,  3.07it/s]
Epoch 18 Val: 100%|█████████████████████████| 2022/2022 [04:27<00:00,  7.55it/s]


Epoch 18/50 — train_loss=0.01409  val_loss=0.01615  lr=7.1e-04
  ↳ No improvement for 1 epoch(s).


Epoch 19 Train: 100%|███████████████████████| 7076/7076 [45:43<00:00,  2.58it/s]
Epoch 19 Val: 100%|█████████████████████████| 2022/2022 [04:40<00:00,  7.20it/s]


Epoch 19/50 — train_loss=0.01350  val_loss=0.01524  lr=6.8e-04
  ↳ New best Sentence Transformer saved.


Epoch 20 Train: 100%|███████████████████████| 7076/7076 [48:35<00:00,  2.43it/s]
Epoch 20 Val: 100%|█████████████████████████| 2022/2022 [04:37<00:00,  7.29it/s]


Epoch 20/50 — train_loss=0.01269  val_loss=0.01463  lr=6.5e-04
  ↳ New best Sentence Transformer saved.


Epoch 21 Train: 100%|███████████████████████| 7076/7076 [43:25<00:00,  2.72it/s]
Epoch 21 Val: 100%|█████████████████████████| 2022/2022 [06:13<00:00,  5.42it/s]


Epoch 21/50 — train_loss=0.01203  val_loss=0.01432  lr=6.2e-04
  ↳ New best Sentence Transformer saved.


Epoch 22 Train: 100%|███████████████████████| 7076/7076 [48:49<00:00,  2.42it/s]
Epoch 22 Val: 100%|█████████████████████████| 2022/2022 [06:30<00:00,  5.17it/s]


Epoch 22/50 — train_loss=0.01149  val_loss=0.01405  lr=5.9e-04
  ↳ New best Sentence Transformer saved.


Epoch 23 Train: 100%|███████████████████████| 7076/7076 [50:07<00:00,  2.35it/s]
Epoch 23 Val: 100%|█████████████████████████| 2022/2022 [06:26<00:00,  5.23it/s]


Epoch 23/50 — train_loss=0.01085  val_loss=0.01387  lr=5.6e-04
  ↳ New best Sentence Transformer saved.


Epoch 24 Train: 100%|███████████████████████| 7076/7076 [51:59<00:00,  2.27it/s]
Epoch 24 Val: 100%|█████████████████████████| 2022/2022 [06:31<00:00,  5.17it/s]


Epoch 24/50 — train_loss=0.01022  val_loss=0.01327  lr=5.3e-04
  ↳ New best Sentence Transformer saved.


Epoch 25 Train: 100%|███████████████████████| 7076/7076 [52:28<00:00,  2.25it/s]
Epoch 25 Val: 100%|█████████████████████████| 2022/2022 [06:08<00:00,  5.49it/s]


Epoch 25/50 — train_loss=0.00982  val_loss=0.01328  lr=5.0e-04
  ↳ No improvement for 1 epoch(s).


Epoch 26 Train: 100%|███████████████████████| 7076/7076 [48:11<00:00,  2.45it/s]
Epoch 26 Val: 100%|█████████████████████████| 2022/2022 [06:31<00:00,  5.16it/s]


Epoch 26/50 — train_loss=0.00913  val_loss=0.01254  lr=4.7e-04
  ↳ New best Sentence Transformer saved.


Epoch 27 Train: 100%|███████████████████████| 7076/7076 [44:14<00:00,  2.67it/s]
Epoch 27 Val: 100%|█████████████████████████| 2022/2022 [04:38<00:00,  7.26it/s]


Epoch 27/50 — train_loss=0.00868  val_loss=0.01251  lr=4.4e-04
  ↳ New best Sentence Transformer saved.


Epoch 28 Train: 100%|███████████████████████| 7076/7076 [46:43<00:00,  2.52it/s]
Epoch 28 Val: 100%|█████████████████████████| 2022/2022 [07:53<00:00,  4.27it/s]


Epoch 28/50 — train_loss=0.00824  val_loss=0.01233  lr=4.1e-04
  ↳ New best Sentence Transformer saved.


Epoch 29 Train: 100%|███████████████████████| 7076/7076 [58:18<00:00,  2.02it/s]
Epoch 29 Val: 100%|█████████████████████████| 2022/2022 [07:08<00:00,  4.72it/s]


Epoch 29/50 — train_loss=0.00784  val_loss=0.01198  lr=3.8e-04
  ↳ New best Sentence Transformer saved.


Epoch 30 Train: 100%|███████████████████████| 7076/7076 [58:48<00:00,  2.01it/s]
Epoch 30 Val: 100%|█████████████████████████| 2022/2022 [06:04<00:00,  5.55it/s]


Epoch 30/50 — train_loss=0.00734  val_loss=0.01137  lr=3.5e-04
  ↳ New best Sentence Transformer saved.


Epoch 31 Train: 100%|███████████████████████| 7076/7076 [45:47<00:00,  2.58it/s]
Epoch 31 Val: 100%|█████████████████████████| 2022/2022 [04:39<00:00,  7.24it/s]


Epoch 31/50 — train_loss=0.00688  val_loss=0.01147  lr=3.2e-04
  ↳ No improvement for 1 epoch(s).


Epoch 32 Train: 100%|███████████████████████| 7076/7076 [39:44<00:00,  2.97it/s]
Epoch 32 Val: 100%|█████████████████████████| 2022/2022 [04:38<00:00,  7.27it/s]


Epoch 32/50 — train_loss=0.00655  val_loss=0.01112  lr=2.9e-04
  ↳ New best Sentence Transformer saved.


Epoch 33 Train: 100%|███████████████████████| 7076/7076 [42:07<00:00,  2.80it/s]
Epoch 33 Val: 100%|█████████████████████████| 2022/2022 [05:29<00:00,  6.14it/s]


Epoch 33/50 — train_loss=0.00622  val_loss=0.01088  lr=2.6e-04
  ↳ New best Sentence Transformer saved.


Epoch 34 Train: 100%|███████████████████████| 7076/7076 [45:25<00:00,  2.60it/s]
Epoch 34 Val: 100%|█████████████████████████| 2022/2022 [05:31<00:00,  6.11it/s]


Epoch 34/50 — train_loss=0.00592  val_loss=0.01075  lr=2.3e-04
  ↳ New best Sentence Transformer saved.


Epoch 35 Train: 100%|███████████████████████| 7076/7076 [45:19<00:00,  2.60it/s]
Epoch 35 Val: 100%|█████████████████████████| 2022/2022 [05:43<00:00,  5.89it/s]


Epoch 35/50 — train_loss=0.00551  val_loss=0.01059  lr=2.1e-04
  ↳ New best Sentence Transformer saved.


Epoch 36 Train: 100%|███████████████████████| 7076/7076 [43:00<00:00,  2.74it/s]
Epoch 36 Val: 100%|█████████████████████████| 2022/2022 [04:37<00:00,  7.30it/s]


Epoch 36/50 — train_loss=0.00519  val_loss=0.01039  lr=1.8e-04
  ↳ New best Sentence Transformer saved.


Epoch 37 Train: 100%|███████████████████████| 7076/7076 [39:48<00:00,  2.96it/s]
Epoch 37 Val: 100%|█████████████████████████| 2022/2022 [04:39<00:00,  7.23it/s]


Epoch 37/50 — train_loss=0.00504  val_loss=0.01024  lr=1.6e-04
  ↳ New best Sentence Transformer saved.


Epoch 38 Train: 100%|███████████████████████| 7076/7076 [39:45<00:00,  2.97it/s]
Epoch 38 Val: 100%|█████████████████████████| 2022/2022 [04:20<00:00,  7.75it/s]


Epoch 38/50 — train_loss=0.00474  val_loss=0.01012  lr=1.4e-04
  ↳ New best Sentence Transformer saved.


Epoch 39 Train: 100%|███████████████████████| 7076/7076 [39:40<00:00,  2.97it/s]
Epoch 39 Val: 100%|█████████████████████████| 2022/2022 [04:37<00:00,  7.28it/s]


Epoch 39/50 — train_loss=0.00455  val_loss=0.00999  lr=1.1e-04
  ↳ New best Sentence Transformer saved.


Epoch 40 Train: 100%|███████████████████████| 7076/7076 [39:40<00:00,  2.97it/s]
Epoch 40 Val: 100%|█████████████████████████| 2022/2022 [04:37<00:00,  7.28it/s]


Epoch 40/50 — train_loss=0.00433  val_loss=0.00985  lr=9.5e-05
  ↳ New best Sentence Transformer saved.


Epoch 41 Train: 100%|███████████████████████| 7076/7076 [39:15<00:00,  3.00it/s]
Epoch 41 Val: 100%|█████████████████████████| 2022/2022 [04:34<00:00,  7.36it/s]


Epoch 41/50 — train_loss=0.00412  val_loss=0.00975  lr=7.8e-05
  ↳ New best Sentence Transformer saved.


Epoch 42 Train: 100%|███████████████████████| 7076/7076 [39:24<00:00,  2.99it/s]
Epoch 42 Val: 100%|█████████████████████████| 2022/2022 [04:36<00:00,  7.32it/s]


Epoch 42/50 — train_loss=0.00391  val_loss=0.00969  lr=6.2e-05
  ↳ New best Sentence Transformer saved.


Epoch 43 Train: 100%|███████████████████████| 7076/7076 [39:23<00:00,  2.99it/s]
Epoch 43 Val: 100%|█████████████████████████| 2022/2022 [04:34<00:00,  7.36it/s]


Epoch 43/50 — train_loss=0.00381  val_loss=0.00968  lr=4.8e-05
  ↳ New best Sentence Transformer saved.


Epoch 44 Train: 100%|███████████████████████| 7076/7076 [39:07<00:00,  3.01it/s]
Epoch 44 Val: 100%|█████████████████████████| 2022/2022 [04:36<00:00,  7.32it/s]


Epoch 44/50 — train_loss=0.00373  val_loss=0.00949  lr=3.5e-05
  ↳ New best Sentence Transformer saved.


Epoch 45 Train: 100%|███████████████████████| 7076/7076 [39:16<00:00,  3.00it/s]
Epoch 45 Val: 100%|█████████████████████████| 2022/2022 [04:34<00:00,  7.38it/s]


Epoch 45/50 — train_loss=0.00358  val_loss=0.00946  lr=2.4e-05
  ↳ New best Sentence Transformer saved.


Epoch 46 Train: 100%|███████████████████████| 7076/7076 [39:20<00:00,  3.00it/s]
Epoch 46 Val: 100%|█████████████████████████| 2022/2022 [04:30<00:00,  7.47it/s]


Epoch 46/50 — train_loss=0.00347  val_loss=0.00945  lr=1.6e-05
  ↳ New best Sentence Transformer saved.


Epoch 47 Train: 100%|███████████████████████| 7076/7076 [39:01<00:00,  3.02it/s]
Epoch 47 Val: 100%|█████████████████████████| 2022/2022 [04:37<00:00,  7.30it/s]


Epoch 47/50 — train_loss=0.00339  val_loss=0.00943  lr=8.9e-06
  ↳ New best Sentence Transformer saved.


Epoch 48 Train: 100%|███████████████████████| 7076/7076 [39:19<00:00,  3.00it/s]
Epoch 48 Val: 100%|█████████████████████████| 2022/2022 [04:33<00:00,  7.41it/s]


Epoch 48/50 — train_loss=0.00339  val_loss=0.00940  lr=3.9e-06
  ↳ New best Sentence Transformer saved.


Epoch 49 Train: 100%|███████████████████████| 7076/7076 [38:57<00:00,  3.03it/s]
Epoch 49 Val: 100%|█████████████████████████| 2022/2022 [04:37<00:00,  7.30it/s]


Epoch 49/50 — train_loss=0.00345  val_loss=0.00940  lr=9.9e-07
  ↳ No improvement for 1 epoch(s).


Epoch 50 Train: 100%|███████████████████████| 7076/7076 [39:33<00:00,  2.98it/s]
Epoch 50 Val: 100%|█████████████████████████| 2022/2022 [04:38<00:00,  7.25it/s]


Epoch 50/50 — train_loss=0.00331  val_loss=0.00940  lr=0.0e+00
  ↳ New best Sentence Transformer saved.


In [9]:
# Sanity check on test data.
# Notice the positive sentences have a greater similarity than the negative sentances when compared to the anchor sentences.
for idx in np.random.randint(1,100,10):
    (a, a_mask), (p, p_mask), (n, n_mask) = collate([test_dataset[idx]])

    a_ids, a_mask = a.to(device), a_mask.to(device)
    p_ids, p_mask = p.to(device), p_mask.to(device)
    n_ids, n_mask = n.to(device), n_mask.to(device)

    with torch.no_grad():
        e_a, e_p, e_n = model(a, a_mask, p, p_mask, n, n_mask)
        cos = F.cosine_similarity(e_a, e_p).item()
        cos_neg = F.cosine_similarity(e_a, e_n).item()
    
    print(f"cos(+): {cos:.3f}  cos(-): {cos_neg:.3f}")


cos(+): 0.713  cos(-): -0.246
cos(+): 0.730  cos(-): 0.385
cos(+): 0.837  cos(-): 0.271
cos(+): 0.684  cos(-): 0.211
cos(+): 0.613  cos(-): -0.029
cos(+): 0.886  cos(-): 0.088
cos(+): 0.951  cos(-): 0.111
cos(+): 0.935  cos(-): 0.150
cos(+): 0.806  cos(-): -0.005
cos(+): 0.947  cos(-): 0.580
