# STEP 1: PREPARE THE DATASET

* get the dataset from MS Marco 
* extract queries and documents 
* generate triplets of queries, relevant(positive) documents, and irrelevant (negative) documents
* tokenise your generated data

In [1]:
from datasets import load_dataset

# Fetch both triplet datasets from the Hugging Face Hub and keep only the
# "train" split of each, converted to a pandas DataFrame.
# To iterate faster, a 10% subset can be requested instead with
# load_dataset(<repo>, split="train[:10%]"); that call returns the split
# directly, so the ["train"] lookup must then be dropped.
df_hn = load_dataset("cocoritzy/week_2_triplet_dataset_hard_negatives")["train"].to_pandas()
df_sn = load_dataset("cocoritzy/week_2_triplet_dataset_soft_negatives")["train"].to_pandas()

In [2]:
# Preview the first rows of the hard-negatives frame to check its columns
df_hn.head()

Unnamed: 0,query_id,query,positive_passage,negative_passage,negative_index_in_group
0,19699,what is rba,Results-Based Accountability® (also known as R...,vs. NetIQ Identity Manager. Risk-based authent...,8
1,19700,was ronald reagan a democrat,"From Wikipedia, the free encyclopedia. A Reaga...","1984 Re-Election. In November 1984, Ronald Rea...",7
2,19701,how long do you need for sydney and surroundin...,Sydney is the capital city of the Australian s...,"The Sydney central business district, Sydney h...",3
3,19702,price to install tile in shower,1 Install ceramic tile floor to match shower-A...,The national average for a new shower installa...,8
4,19703,why conversion observed in body,Conversion disorder is a type of somatoform di...,"Conclusions: In adult body CT, dose to an orga...",1


In [3]:
# Get the max length of all queries and passages, because you need the sequences in a batch to be padded to the same length

def get_max_length(df, column_name):
    """Print and return the length of the longest entry in one column.

    Length is ``len(text)``, so for string columns it is a number of
    characters, not a number of tokens.

    Args:
        df: Table indexable by column name (e.g. a pandas DataFrame) whose
            column yields sized values when iterated.
        column_name: Name of the column to measure.

    Returns:
        int: The largest ``len(text)`` in the column, or 0 if the column
        has no entries.
    """
    # default=0 keeps an empty column from raising ValueError in max().
    longest = max((len(text) for text in df[column_name]), default=0)
    print(f"max length in {column_name}:", longest)
    # Return the value so callers that assign the result get a number, not None.
    return longest


In [4]:
# Measure the longest entry of each text column in the hard-negatives frame
# (negative passages, positive passages, queries - in that order).
hn_negatives_length, hn_positives_length, hn_query_length = (
    get_max_length(df_hn, column)
    for column in ("negative_passage", "positive_passage", "query")
)

max length in negative_passage: 1039
max length in positive_passage: 1167
max length in query: 144


In [5]:
# Measure the longest entry of each text column in the soft-negatives frame
# (negative passages, positive passages, queries - in that order).
sn_negatives_length, sn_positives_length, sn_query_length = (
    get_max_length(df_sn, column)
    for column in ("negative_passage", "positive_passage", "query")
)


max length in negative_passage: 1128
max length in positive_passage: 1167
max length in query: 144


This is the dataset you'll be using for training

In [6]:
# Build the triplet frames: keep only the query / positive / negative text
# columns of the hard-negatives and soft-negatives data.
triplet_hn, triplet_sn = (
    frame[["query", "positive_passage", "negative_passage"]]
    for frame in (df_hn, df_sn)
)
triplet_hn.head()

Unnamed: 0,query,positive_passage,negative_passage
0,what is rba,Results-Based Accountability® (also known as R...,vs. NetIQ Identity Manager. Risk-based authent...
1,was ronald reagan a democrat,"From Wikipedia, the free encyclopedia. A Reaga...","1984 Re-Election. In November 1984, Ronald Rea..."
2,how long do you need for sydney and surroundin...,Sydney is the capital city of the Australian s...,"The Sydney central business district, Sydney h..."
3,price to install tile in shower,1 Install ceramic tile floor to match shower-A...,The national average for a new shower installa...
4,why conversion observed in body,Conversion disorder is a type of somatoform di...,"Conclusions: In adult body CT, dose to an orga..."


In [7]:
# Combine all columns into a single string: each row becomes
# "query positive_passage negative_passage", and all rows are then joined
# with single spaces. Only the hard-negatives frame (triplet_hn) is used.
marco_text = ' '.join(triplet_hn["query"] + " " + triplet_hn["positive_passage"] + " " + triplet_hn["negative_passage"])


In [8]:
# Tokenize the data 
import re
from collections import Counter

def tokenize_text(text, top_k=30000):
    """Build a frequency-ranked vocabulary from one raw text string.

    The text is cleaned (punctuation and digits removed, lower-cased), split
    on whitespace, and the ``top_k`` most frequent words are kept.

    Args:
        text (str): Raw corpus text.
        top_k (int): Maximum number of distinct words to keep. Defaults to
            30000, the value that used to be hard-coded.

    Returns:
        tuple: ``(vocab, word_to_id, id_to_word)`` where ``vocab`` maps each
        kept word to its count, ``word_to_id`` maps each word to its rank
        (0 = most frequent) and ``id_to_word`` is the inverse mapping.
    """
    # Drop every character that is neither a word character nor whitespace
    # (underscores count as word characters and survive), then drop digits.
    without_punctuation = re.sub(r'[^\w\s]', '', text)
    without_numbers = re.sub(r'\d+', '', without_punctuation)

    # str.split() with no argument never yields empty strings, so no extra
    # filtering is needed.
    words = without_numbers.lower().split()

    vocab = dict(Counter(words).most_common(top_k))
    word_to_id = {word: index for index, word in enumerate(vocab)}
    id_to_word = {index: word for word, index in word_to_id.items()}

    # Debugging aid: report how many words were kept.
    print("Vocabulary size:", len(vocab))

    return vocab, word_to_id, id_to_word



In [9]:
vocab, word_to_id, id_to_word = tokenize_text(marco_text)
# Show only the first few entries: printing every word -> id pair of the
# vocabulary floods the notebook output.
print(dict(list(word_to_id.items())[:10]))

Vocabulary size: 30000


# Step 2: Load in a pretrained word2vec
* Coline's word2vec model embeddings to embed the query and document text

In [10]:
# First make sure the GPU is being used 

import torch 
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from model import CBOW

# Choose the compute device first, then report on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Report which GPU is in use and how much memory it has.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    # Enable the cuDNN auto-tuner
    torch.backends.cudnn.benchmark = True

# Always state which device was selected
print(f"Using device: {device}")

GPU: NVIDIA RTX A4000
GPU Memory: 16.88 GB
Using device: cuda


In [11]:
# Download the CBOW checkpoint file "cbow_model.pt" from the Hugging Face Hub;
# model_path holds whatever hf_hub_download returns for it.
# Author's note: model with text8 only
model_path = hf_hub_download(repo_id="cocoritzy/cbow-upvotes_model", filename="cbow_model.pt")


In [12]:
# Retrieve checkpoint: load the downloaded file onto the selected device and
# pull out the vocabulary mapping and embedding size stored alongside the weights.
# NOTE(review): torch.load unpickles the file, and model_path comes from a Hub
# download - only load checkpoints from a repo you trust. Consider passing
# weights_only=True if the installed PyTorch version supports it - TODO confirm.
checkpoint = torch.load(model_path, map_location=device) #A checkpoint is a file that saves the state of your model
token_to_index = checkpoint["token_to_index"]
embedding_dim= checkpoint["embedding_dim"]
vocab_size = len(token_to_index)  # number of tokens in the pretrained vocabulary

print(f"Vocabulary size: {vocab_size}")
print(f"Embedding dimension: {embedding_dim}")



Vocabulary size: 30000
Embedding dimension: 100


In [13]:
# Build the CBOW model with the checkpoint's sizes
model = CBOW(vocab_size, embedding_dim)

# Restore the trained parameters stored in the checkpoint
model.load_state_dict(checkpoint["model_state_dict"])

# Move the model to the selected device and switch it to evaluation mode
model.to(device).eval()



CBOW(
  (embeddings): Embedding(30000, 100)
  (linear): Linear(in_features=100, out_features=30000, bias=True)
)

In [14]:
# Take an independent, gradient-free copy of the pretrained embedding matrix
pretrained_weights = model.embeddings.weight.detach().clone()

# Quick sanity check
print(type(model.embeddings))  # the embedding layer's class
print(model.embeddings.weight.shape)  # torch.Size([vocab_size, embedding_dim])



<class 'torch.nn.modules.sparse.Embedding'>
torch.Size([30000, 100])


# STEP 4: CREATE NEW EMBEDDINGS FOR THIS DATASET
Now we should align the pretrained embeddings with the new `word_to_id` vocabulary. We will do this by:
* creating an empty embedding matrix
* filling in the vectors from CBOW
* giving a random vector to any word that is not found in CBOW

In [15]:
import torch

# One row per word of the new vocabulary, one column per dimension of the
# pretrained vectors; every entry starts at zero.
new_vocab_size = len(word_to_id)
embedding_dim = pretrained_weights.size(1)
embedding_matrix = torch.zeros(new_vocab_size, embedding_dim)


In [16]:
# Fill each row of the new matrix: reuse the pretrained CBOW vector when the
# word exists in the pretrained vocabulary, otherwise draw a random vector.
for token, row in word_to_id.items():
    embedding_matrix[row] = (
        pretrained_weights[token_to_index[token]]
        if token in token_to_index
        else torch.randn(embedding_dim)
    )


sanity checker ⬇️

In [17]:
# Check the embedding matrix: overall shape, then the vector stored in row 0
print(embedding_matrix.shape)
print(embedding_matrix[0])

# Look at one fixed entry (id 100): the word mapped to that id and its vector
print(id_to_word[100])
print(embedding_matrix[100])

torch.Size([30000, 100])
tensor([-6.4979e-01, -5.7941e-01,  6.6106e-01,  1.3047e+00, -1.3084e+00,
         6.4363e-01,  1.4673e+00, -2.7732e-01,  9.4580e-01, -4.2063e-01,
        -1.8621e+00, -3.1817e-01, -5.4162e-01, -9.9318e-01, -1.0937e+00,
         1.3680e+00, -7.8084e-01,  2.3505e-02,  2.1359e-01, -3.9024e-01,
        -9.9998e-01, -8.7959e-01, -6.7738e-01, -5.4328e-01,  3.5241e-01,
         1.0581e+00,  1.0696e-01,  2.9949e-01, -8.2086e-01,  1.4881e-01,
         3.9067e-01,  1.0017e+00,  2.2398e-01,  3.5957e-01, -4.5835e-01,
         2.6158e-02, -1.4169e-01, -8.9599e-02, -1.6806e-01,  1.2347e+00,
        -1.7992e-01, -1.3034e-01,  8.3350e-01,  5.1791e-03,  1.2716e+00,
         7.7249e-02, -7.4727e-01, -2.3041e-01, -5.2555e-01,  1.0789e+00,
         1.9860e-01,  1.4728e+00,  7.6100e-01,  1.5356e+00,  6.5659e-01,
         1.0465e+00, -2.9056e-01,  2.9421e+00,  5.7205e-01,  1.7678e-01,
        -8.0052e-01,  5.4538e-01,  2.6560e-01, -6.1731e+00,  8.3839e-01,
        -9.4757e-02,  3.80

In [18]:
# Basic similarity check
def cosine_similarity(vec1, vec2, eps=1e-8):
    """Return the cosine similarity of two 1-D tensors.

    Args:
        vec1: First 1-D tensor.
        vec2: Second 1-D tensor of the same length.
        eps: Lower bound applied to the product of the two norms so that a
            zero vector yields 0 instead of a NaN from dividing by zero.

    Returns:
        A 0-dim tensor holding ``dot(vec1, vec2) / (|vec1| * |vec2|)``.
    """
    norm_product = (torch.norm(vec1) * torch.norm(vec2)).clamp_min(eps)
    return torch.dot(vec1, vec2) / norm_product

# Compare two related words; the check is skipped unless both are in the vocabulary.
word1, word2 = 'takeoff', 'airplane'  # Replace with actual words
if all(word in word_to_id for word in (word1, word2)):
    idx1 = word_to_id[word1]
    idx2 = word_to_id[word2]
    similarity = cosine_similarity(embedding_matrix[idx1], embedding_matrix[idx2])
    print(f"Cosine similarity between {word1} and {word2}: {similarity}")

Cosine similarity between takeoff and airplane: 0.18580347299575806


Looks sensible enough to proceed with creating the embedding layer

In [19]:
# Wrap the aligned matrix in a frozen embedding layer. from_pretrained takes
# the layer's sizes from the matrix itself, so they cannot drift from the
# actual vocabulary size, and no throwaway random weights are allocated.
embedding_layer = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)

# Test the embedding layer: look up one known word and print its vector
test_word = 'takeoff'
test_idx = word_to_id[test_word]
print(embedding_layer(torch.tensor([test_idx])).squeeze().tolist())

[2.5662574768066406, -5.550731658935547, -0.002386652398854494, -0.35136479139328003, -0.7953556180000305, -0.5258115530014038, 0.34085527062416077, 0.9197907447814941, -0.07610265165567398, 2.126612424850464, -0.8872002959251404, 1.048814058303833, -1.6404482126235962, -2.732374668121338, -3.480424642562866, 0.03068462386727333, -1.1853442192077637, 3.805248737335205, 5.398148536682129, 0.9104251861572266, 1.8317471742630005, -2.385714054107666, 0.7987986207008362, 1.7342525720596313, -2.6690051555633545, -0.5912424325942993, -2.4031965732574463, -0.7736563682556152, 0.5487457513809204, -0.7048475742340088, -1.7247153520584106, -2.624924898147583, 2.9168269634246826, -1.997927188873291, -2.707930088043213, 0.9594621658325195, -5.314846515655518, -1.594692349433899, 1.420822262763977, -1.721580147743225, 2.1901707649230957, 1.4288684129714966, 0.49167197942733765, 0.3397097885608673, -1.3827555179595947, 1.4688661098480225, 1.6840145587921143, 1.961722493171692, 2.1646041870117188, -2.

# STEP 5: USE EMBEDDINGS IN THE TWO TOWERS AND SET UP ALL THE CLASSES AND FUNCTIONS
* remember you have two datasets: triplet_hn (hard negatives) and triplet_sn (soft negatives)
* Use the embedding layer to embed the queries and documents
* for this, you are using __average pooling__ (not an RNN)


Define the model architecture. Both towers should return a pooled embedding for the query and document

In [20]:
# First define the model architecture aka the Two Towers with average pooling
class _MeanPoolTower(nn.Module):
    """Shared implementation of both towers: embed token ids, then average.

    Args:
        vocab_size: Accepted for interface compatibility; the table size is
            taken from the embedding layer actually used.
        embedding_dim: Accepted for interface compatibility (see above).
        embedding: Optional ``nn.Embedding`` to use. When omitted, the
            notebook-level ``embedding_layer`` is used, exactly as before.
        padding_idx: Optional token id to leave out of the average. With
            the default ``None`` every position is averaged, exactly as
            before. Note that real tokens sharing this id are skipped too.
    """

    def __init__(self, vocab_size, embedding_dim, embedding=None, padding_idx=None):
        super().__init__()
        # Only touch the global when no layer is passed in, so the towers
        # can also be built without notebook state.
        self.embedding = embedding_layer if embedding is None else embedding
        self.padding_idx = padding_idx

    def _pool(self, tokens):
        """Return one averaged embedding per row of ``tokens``."""
        embedded = self.embedding(tokens)
        if self.padding_idx is None:
            # Plain average over the sequence axis (dim=1).
            return torch.mean(embedded, dim=1)
        # Masked average: padded positions contribute nothing to the sum
        # and are not counted in the denominator.
        keep = (tokens != self.padding_idx).unsqueeze(-1).to(embedded.dtype)
        # clamp_min avoids dividing by zero for rows that are all padding.
        kept_count = keep.sum(dim=1).clamp_min(1.0)
        return (embedded * keep).sum(dim=1) / kept_count


class QueryTower(_MeanPoolTower):
    """Tower that turns a batch of query token ids into pooled embeddings."""

    def forward(self, query_tokens):
        # Embed the query tokens and average them (average pooling)
        return self._pool(query_tokens)


class DocumentTower(_MeanPoolTower):
    """Tower that turns a batch of document token ids into pooled embeddings."""

    def forward(self, doc_tokens):
        # Embed the document tokens and average them (average pooling)
        return self._pool(doc_tokens)



Create a Dataset class to make sure the dataset is loaded in properly

In [22]:
from torch.utils.data import Dataset, DataLoader

# Create a dataset class for the triplet
class TripletDataset(Dataset):
    """Map-style dataset of (query, positive passage, negative passage) triplets.

    Each item is returned as three 1-D ``torch.long`` tensors of vocabulary
    ids; words missing from ``word_to_id`` are dropped.
    """

    def __init__(self, all_queries, all_positive_passages, all_negative_passages, word_to_id):
        """
        Args:
            all_queries (iterable of str): Query texts.
            all_positive_passages (iterable of str): Positive passage texts.
            all_negative_passages (iterable of str): Negative passage texts.
            word_to_id (dict): Dictionary mapping words to their indices.
        """
        # Materialise as lists so that __getitem__ indexes by position. A
        # pandas Series is indexed by label, which breaks (or silently
        # returns other rows) once its index is not 0..n-1, e.g. after a
        # shuffle or a train/validation split.
        self.queries = list(all_queries)
        self.positive_passages = list(all_positive_passages)
        self.negative_passages = list(all_negative_passages)
        self.word_to_id = word_to_id

    def __len__(self):
        # Total number of triplets
        return len(self.queries)

    def tokenize_text(self, text):
        """Clean ``text`` and split it into lower-case word tokens."""
        text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
        text = re.sub(r'\d+', '', text)      # Remove numbers
        # split() with no argument never yields empty tokens
        return text.lower().split()

    def _encode(self, text):
        """Tokenize ``text`` and convert the known words to a tensor of ids."""
        ids = [self.word_to_id[word] for word in self.tokenize_text(text) if word in self.word_to_id]
        # Explicit dtype: an empty list would otherwise become a float32
        # tensor, which cannot be used as embedding indices.
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(query_ids, positive_ids, negative_ids)`` for position ``idx``."""
        return (
            self._encode(self.queries[idx]),
            self._encode(self.positive_passages[idx]),
            self._encode(self.negative_passages[idx]),
        )
    

In [23]:

def pad_to_length(tensor, length, padding_value=0):
    """Return ``tensor`` forced to exactly ``length`` entries along dim 0.

    Shorter tensors are right-padded with ``padding_value``; tensors that
    are already long enough are cut down to their first ``length`` entries.
    """
    missing = length - tensor.size(0)
    if missing <= 0:
        # Already at or beyond the target length: truncate
        return tensor[:length]
    # Append `missing` padding entries on the right-hand side
    return F.pad(tensor, (0, missing), value=padding_value)

def collate_fn(batch, max_query_length=144, max_positive_length=1167,
               max_negative_length=1039, padding_value=0):
    """Collate (query, positive, negative) id tensors into padded batches.

    Each group is padded to the longest sequence in the batch instead of a
    fixed width, so short batches are not filled with padding. The
    ``max_*`` values act as upper bounds: longer sequences are truncated.
    Their defaults are the lengths measured earlier in the notebook.

    Args:
        batch: List of ``(query_ids, positive_ids, negative_ids)`` tuples
            of 1-D tensors.
        max_query_length: Upper bound on the padded query width.
        max_positive_length: Upper bound on the padded positive width.
        max_negative_length: Upper bound on the padded negative width.
        padding_value: Id written into padded positions.

    Returns:
        tuple: Three ``torch.long`` tensors of shape ``(batch, width)``.
    """
    def _pad_group(sequences, cap):
        # Width = longest sequence in the batch, bounded by the cap. At
        # least 1 so that an all-empty group still has a position to pool.
        longest = max(seq.size(0) for seq in sequences)
        width = max(1, min(cap, longest))
        # Force long dtype: an empty id tensor may arrive as float32.
        padded = torch.full((len(sequences), width), padding_value, dtype=torch.long)
        for row, seq in enumerate(sequences):
            clipped = seq[:width]
            padded[row, :clipped.size(0)] = clipped
        return padded

    # Separate the batch into queries, positive passages, and negative passages
    queries, positives, negatives = zip(*batch)

    return (
        _pad_group(queries, max_query_length),
        _pad_group(positives, max_positive_length),
        _pad_group(negatives, max_negative_length),
    )

Train the architecture on the data using 
* a distance function : compare the query and document encodings, where high similarity indicates a relevant document
* a triplet loss function : train the neural network to make the distance between queries and relevant documents small, and the distance between queries and irrelevant documents large.

In [28]:
# Create a distance function to compare the similarity between query and document embeddings
def compute_similarity(query_embedding, doc_embedding, k=5):
    """Score query embeddings against document embeddings.

    Args:
        query_embedding: Tensor of query embeddings.
        doc_embedding: Tensor of document embeddings, compatible in shape
            with ``query_embedding``.
        k: Accepted but not used by this function.

    Returns:
        The result of ``F.cosine_similarity`` on the two inputs (called
        with its default arguments).
    """
    return F.cosine_similarity(query_embedding, doc_embedding)

In [29]:
# Define the Triplet Loss Function 

class TripletLoss(nn.Module):
    """Thin module wrapper around ``nn.TripletMarginLoss``.

    Args:
        margin: Margin handed to ``nn.TripletMarginLoss``; every other
            setting of that loss keeps its default.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin
        self.loss_fn = nn.TripletMarginLoss(margin=margin)

    def forward(self, query_embedding, positive_embedding, negative_embedding):
        """Return the triplet margin loss with the query as the anchor."""
        triplet = (query_embedding, positive_embedding, negative_embedding)
        return self.loss_fn(*triplet)



CHEEKY TESTING ZONE ⬇️

In [None]:
## cheeky test before we move on 

triplet_dataset = TripletDataset(triplet_hn["query"], triplet_hn["positive_passage"], triplet_hn["negative_passage"], word_to_id)

# Create DataLoader with the collate function
dataset_loader = DataLoader(triplet_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

query_tower = QueryTower(30000, 100)
doc_tower = DocumentTower(30000, 100)

# Smoke test on a single batch only: looping over the whole loader would run
# (and print) every batch of the dataset.
query_indices, positive_indices, negative_indices = next(iter(dataset_loader))

# No gradients are needed for a forward-only check
with torch.no_grad():
    query_embedding = query_tower(query_indices)
    positive_embedding = doc_tower(positive_indices)
    negative_embedding = doc_tower(negative_indices)

print("Query Embedding:", query_embedding)
print("Positive Embedding:", positive_embedding)
print("Negative Embedding:", negative_embedding)

The initial cheeky test showed that my tensors within queries and documents are different sizes. This is a problem because the torch.stack operation expects tensors in the batch to have the same size. So we'll fix it by padding the tensors based on the max lengths of the query and documents found in the earlier cells. 

# STEP 6: SETUP MODEL TRAINING

In [25]:
import datetime
import wandb
# Authenticate with Weights & Biases before any run is created
wandb.login()

# Initialize settings
# NOTE(review): this seed is set after the random vectors for words missing
# from CBOW were drawn with torch.randn earlier in the notebook, so those
# vectors are not covered by it.
torch.manual_seed(42)

# Timestamp string (year_month_day__hour_minute_second) taken at the moment this cell runs
timestamp = datetime.datetime.now().strftime('%Y_%m_%d__%H_%M_%S')

[34m[1mwandb[0m: Currently logged in as: [33mevelyntants[0m ([33mevelyntants-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
# Initialize Weights & Biases: start a run in the "two_tower_model" project,
# named after the timestamp string, with the hyperparameters below passed as
# the run config.
# NOTE(review): the values here are literals; embedding_dim and vocab_size
# also exist as notebook variables - keep them in sync.
wandb.init(project="two_tower_model",
           name=f"{timestamp}",
           config={
               # Model parameters
               "embedding_dim": 100,
               "vocab_size": 30000,
                
                # Training parameters
                "batch_size": 128,
                "learning_rate": 0.003,
                "num_epochs": 5,
                "train_split": 0.7,
                
                # Optimizer parameters
                "weight_decay": 1e-5,
                
                # DataLoader parameters
                "num_workers": 4,

                # Triplet Loss margin
                "triplet_margin": 1.0
            }
)