In [1]:
from datasets import load_dataset

# TODO: also check the other all-nli configurations ('pair-class', 'pair-score', 'pair');
# the 'triplet' configuration is the one loaded here.
dataset = load_dataset(path="sentence-transformers/all-nli", name="triplet")
print(dataset)


  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 557850
    })
    dev: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 6584
    })
    test: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 6609
    })
})


In [2]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

class BertSentenceEmbedder(nn.Module):
    """Sentence encoder: a pretrained BERT model followed by masked mean pooling.

    Each sentence embedding is the average of BERT's last-layer token vectors over
    the positions where ``attention_mask`` is non-zero, so padding positions do not
    contribute.
    """

    def __init__(self, model_name="bert-base-uncased"):
        """Load pretrained BERT weights via ``BertModel.from_pretrained(model_name)``."""
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return one mean-pooled embedding per input sequence.

        Args:
            input_ids: token ids, passed straight through to BERT
                (presumably (batch, seq_len) as produced by the tokenizer).
            attention_mask: (batch, seq_len) mask; non-zero (1) marks real tokens,
                0 marks padding.

        Returns:
            (batch, hidden) tensor in the dtype of BERT's hidden states. A sequence
            whose mask is all zeros yields an all-zero embedding.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # The mask is typically an integer tensor; cast it explicitly so the masked
        # sum and the token count are floating point in the embeddings' dtype instead
        # of relying on implicit int/float type promotion.
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq_len, 1)

        sum_embeddings = (token_embeddings * mask).sum(dim=1)  # (batch, hidden)
        token_counts = mask.sum(dim=1)  # (batch, 1), broadcasts over hidden
        # Counts are whole numbers, so clamping at 1 only changes empty sequences
        # (whose masked sum is 0, giving a zero embedding instead of 0/0). Unlike 1e-9,
        # 1 is also exactly representable in fp16.
        token_counts = token_counts.clamp(min=1.0)
        return sum_embeddings / token_counts  # masked mean



In [3]:
# Use the default CUDA device when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [None]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from torch.nn import TripletMarginLoss
# Triplet objective on Euclidean (p=2) distances: a triplet contributes loss until the
# negative is at least `margin` farther from the anchor than the positive is.
loss_fn = TripletMarginLoss(p=2, margin=1.0)

# Tokenizer for the same "bert-base-uncased" checkpoint the embedder loads by default.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def collate_fn(batch, max_length=128):
    """Tokenize a list of triplet rows into anchor / positive / negative encodings.

    Args:
        batch: list of dataset rows, each a mapping with "anchor", "positive" and
            "negative" string fields.
        max_length: maximum tokens per sentence after truncation. Bounding it caps
            the padded length, and hence the activation memory, of the
            3 * len(batch) sequences the training loop sends through BERT in one
            forward pass (the unbounded version ran out of memory on a 6 GiB GPU).
            Pass None to truncate only at the tokenizer's model maximum, as before.

    Returns:
        (enc_anchor, enc_pos, enc_neg): three dicts with the tokenizer's output keys
        (e.g. 'input_ids', 'token_type_ids', 'attention_mask'), each holding tensors
        with len(batch) rows and a shared padded length.
    """
    n = len(batch)
    sentences = (
        [row["anchor"] for row in batch]
        + [row["positive"] for row in batch]
        + [row["negative"] for row in batch]
    )
    # A single tokenizer call pads all 3n sentences to one common length, which is
    # what lets the training loop concatenate the three parts for one forward pass.
    enc_all = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Split back into the three contiguous n-row groups built above.
    enc_anchor = {k: v[:n] for k, v in enc_all.items()}
    enc_pos = {k: v[n:2 * n] for k, v in enc_all.items()}
    enc_neg = {k: v[2 * n:] for k, v in enc_all.items()}
    return enc_anchor, enc_pos, enc_neg

# Seed torch's RNG (on all devices) so DataLoader shuffling and dropout are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Each step embeds 3 * batch_size sentences (anchor + positive + negative) in one
# forward pass. The recorded run ran out of memory on a 6 GiB GPU; lower this if
# that recurs.
batch_size = 32
epochs = 10

train_loader = DataLoader(
    dataset["train"], batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)

model = BertSentenceEmbedder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for epoch in range(epochs):
    # Accumulate on-device so loss bookkeeping does not force a GPU sync every step.
    epoch_loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for enc_anchor, enc_pos, enc_neg in train_loader:
        # Size of *this* batch. The final batch of an epoch is smaller than
        # `batch_size` whenever the split size is not a multiple of it (e.g. the
        # 557850-row train split with batch_size 32 ends in a 26-triplet batch), so
        # slicing by the global `batch_size` would misalign anchors/positives/negatives.
        n = enc_anchor["input_ids"].size(0)

        # collate_fn tokenizes all three parts in one call, so they share a padded
        # length and can be stacked for a single forward pass.
        all_input_ids = torch.cat(
            [enc_anchor["input_ids"], enc_pos["input_ids"], enc_neg["input_ids"]], dim=0
        ).to(device)
        all_attention_mask = torch.cat(
            [enc_anchor["attention_mask"], enc_pos["attention_mask"], enc_neg["attention_mask"]],
            dim=0,
        ).to(device)

        all_embeddings = model(all_input_ids, all_attention_mask)
        # Three contiguous n-row chunks; unpacking fails loudly if the count is off.
        emb_a, emb_p, emb_n = all_embeddings.split(n, dim=0)

        loss = loss_fn(emb_a, emb_p, emb_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.detach()
        num_batches += 1

    # Report the epoch mean rather than the last (possibly partial, noisy) batch's loss;
    # max(..., 1) keeps an empty loader from dividing by zero.
    mean_loss = epoch_loss_sum.item() / max(num_batches, 1)
    print(f"Epoch {epoch}: mean loss = {mean_loss:.4f}")


  return forward_call(*args, **kwargs)


tensor(0.4623, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4518, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9787, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9346, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5480, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5303, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.8784, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6794, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9181, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.8520, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4811, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5418, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5589, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6934, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.7532, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6858, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.7359, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5390, device='cuda:0',

OutOfMemoryError: CUDA out of memory. Tried to allocate 44.00 MiB. GPU 0 has a total capacty of 6.00 GiB of which 0 bytes is free. Of the allocated memory 10.33 GiB is allocated by PyTorch, and 1.38 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Persist only the trained parameters (not the module object); restore them into a
# fresh BertSentenceEmbedder() with load_state_dict.
torch.save(obj=model.state_dict(), f="bert_triplet_state_dict.pt")