# Two Towers Model finetuning 

- Finetuning
- based on word2vec embeddings from genism



In [1]:
# Import Libraries
import os
import sys

import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import faiss
import numpy as np

import wandb
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))




In [2]:
from utils.load_data import load_word2vec


[nltk_data] Downloading package stopwords to /home/g_byte/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/g_byte/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:

from utils.preprocess_str import str_to_tokens, preprocess_list

In [4]:
from models.core import DocumentDataset, TwoTowerModel, loss_fn

In [5]:
from utils.checkpoint import save_checkpoint

In [6]:
RANDOM_SEED = 42
FREEZE_EMBEDDINGS = True
VERBOSE = True
HIDDEN_DIM = 128
NUM_LAYERS = 1
MARGIN = 0.5
LEARNING_RATE = 0.000001
NUM_EPOCHS = 3
MODEL_NAME = "mlx-w2-two-tower-search-finetuned-duh"
PROJECTION_DIM = 64

In [7]:
df_validation = pd.read_parquet('./data/validation.parquet')

In [8]:
# Load embeddings
vocab,embeddings, word_to_idx = load_word2vec()
embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze=FREEZE_EMBEDDINGS)

EMBEDDING_DIM = embeddings.shape[1]
VOCAB_SIZE = len(vocab)

In [9]:
# Pull data from sql then parse it into a df based on the the format below
# df is in query| doc_relevant | url_relevant | doc_rel_tokens

In [10]:
# Load data to fine tune the model on user logs
#df_user_logs = pd.read_parquet('../data/user_logs.parquet')

df = pd.read_parquet('./df.parquet')
df.head(15)

Unnamed: 0,id,session_id,search_query,related_docs,selected_doc,created_at
0,1,55406478-18e5-458f-992b-4b4cb6eb726e,chicken,"{""5. Bake the fish. Place the baking dish in t...",4,2024-10-31 10:54:34.576127
1,2,55406478-18e5-458f-992b-4b4cb6eb726e,chicken,"{""5. Bake the fish. Place the baking dish in t...",2,2024-10-31 10:58:28.713607
2,3,55406478-18e5-458f-992b-4b4cb6eb726e,chicken,"{""5. Bake the fish. Place the baking dish in t...",4,2024-10-31 10:58:36.771682
3,4,55406478-18e5-458f-992b-4b4cb6eb726e,fish,"{""There are two main types of coffee plant. Th...",1,2024-10-31 10:59:00.611794
4,5,55406478-18e5-458f-992b-4b4cb6eb726e,flower,"{""Most are dioecious, meaning they have male a...",0,2024-10-31 10:59:41.916059
5,6,55406478-18e5-458f-992b-4b4cb6eb726e,flower,"{""Most are dioecious, meaning they have male a...",3,2024-10-31 10:59:49.191131
6,7,55406478-18e5-458f-992b-4b4cb6eb726e,flower,"{""Most are dioecious, meaning they have male a...",4,2024-10-31 10:59:53.885600
7,8,55406478-18e5-458f-992b-4b4cb6eb726e,basketball,"{""Jeffrey Alan Samardzija (/səˈmɑrdʒə/ ; born ...",2,2024-10-31 11:00:30.914590
8,9,55406478-18e5-458f-992b-4b4cb6eb726e,basketball,"{""Jeffrey Alan Samardzija (/səˈmɑrdʒə/ ; born ...",3,2024-10-31 11:00:37.085660
9,10,55406478-18e5-458f-992b-4b4cb6eb726e,tea,"{""One mineral not added back into white rice i...",1,2024-10-31 11:01:05.583204


In [11]:
df = df[['search_query', 'related_docs']].copy()

df.rename(columns={'search_query': 'query', 'related_docs': 'doc_relevant'}, inplace=True)
df['doc_irrelevant'] = ""
df.head()

Unnamed: 0,query,doc_relevant,doc_irrelevant
0,chicken,"{""5. Bake the fish. Place the baking dish in t...",
1,chicken,"{""5. Bake the fish. Place the baking dish in t...",
2,chicken,"{""5. Bake the fish. Place the baking dish in t...",
3,fish,"{""There are two main types of coffee plant. Th...",
4,flower,"{""Most are dioecious, meaning they have male a...",


In [12]:
def tokenize(df, word_to_idx):
    # Tokenize
    df.loc[:, 'doc_rel_tokens'] = df['doc_relevant'].apply(lambda x: str_to_tokens(x, word_to_idx))
    df.loc[:, 'doc_irr_tokens'] = df['doc_irrelevant'].apply(lambda x: str_to_tokens(x, word_to_idx))
    df.loc[:, 'query_tokens'] = df['query'].apply(lambda x: str_to_tokens(x, word_to_idx))
    return df

tokenize(df,word_to_idx)

Unnamed: 0,query,doc_relevant,doc_irrelevant,doc_rel_tokens,doc_irr_tokens,query_tokens
0,chicken,"{""5. Bake the fish. Place the baking dish in t...",,"[1133, 7436, 1130, 108, 7436, 2906, 10331, 182...","[1133, 1134]","[1133, 4501, 1134]"
1,chicken,"{""5. Bake the fish. Place the baking dish in t...",,"[1133, 7436, 1130, 108, 7436, 2906, 10331, 182...","[1133, 1134]","[1133, 4501, 1134]"
2,chicken,"{""5. Bake the fish. Place the baking dish in t...",,"[1133, 7436, 1130, 108, 7436, 2906, 10331, 182...","[1133, 1134]","[1133, 4501, 1134]"
3,fish,"{""There are two main types of coffee plant. Th...",,"[1133, 3, 197, 204, 3040, 706, 40061, 33872, 1...","[1133, 1134]","[1133, 1130, 1134]"
4,flower,"{""Most are dioecious, meaning they have male a...",,"[1133, 89, 527, 538, 2264, 59, 706, 1184, 313,...","[1133, 1134]","[1133, 2264, 1134]"
5,flower,"{""Most are dioecious, meaning they have male a...",,"[1133, 89, 527, 538, 2264, 59, 706, 1184, 313,...","[1133, 1134]","[1133, 2264, 1134]"
6,flower,"{""Most are dioecious, meaning they have male a...",,"[1133, 89, 527, 538, 2264, 59, 706, 1184, 313,...","[1133, 1134]","[1133, 2264, 1134]"
7,basketball,"{""Jeffrey Alan Samardzija (/səˈmɑrdʒə/ ; born ...",,"[1133, 6078, 2425, 265, 481, 2283, 6266, 19, 9...","[1133, 1134]","[1133, 2687, 1134]"
8,basketball,"{""Jeffrey Alan Samardzija (/səˈmɑrdʒə/ ; born ...",,"[1133, 6078, 2425, 265, 481, 2283, 6266, 19, 9...","[1133, 1134]","[1133, 2687, 1134]"
9,tea,"{""One mineral not added back into white rice i...",,"[1133, 0, 1854, 343, 235, 420, 2658, 6511, 0, ...","[1133, 1134]","[1133, 3507, 1134]"


In [13]:
df_full = df.copy()
df_full.head(20)

Unnamed: 0,query,doc_relevant,doc_irrelevant,doc_rel_tokens,doc_irr_tokens,query_tokens
0,chicken,"{""5. Bake the fish. Place the baking dish in t...",,"[1133, 7436, 1130, 108, 7436, 2906, 10331, 182...","[1133, 1134]","[1133, 4501, 1134]"
1,chicken,"{""5. Bake the fish. Place the baking dish in t...",,"[1133, 7436, 1130, 108, 7436, 2906, 10331, 182...","[1133, 1134]","[1133, 4501, 1134]"
2,chicken,"{""5. Bake the fish. Place the baking dish in t...",,"[1133, 7436, 1130, 108, 7436, 2906, 10331, 182...","[1133, 1134]","[1133, 4501, 1134]"
3,fish,"{""There are two main types of coffee plant. Th...",,"[1133, 3, 197, 204, 3040, 706, 40061, 33872, 1...","[1133, 1134]","[1133, 1130, 1134]"
4,flower,"{""Most are dioecious, meaning they have male a...",,"[1133, 89, 527, 538, 2264, 59, 706, 1184, 313,...","[1133, 1134]","[1133, 2264, 1134]"
5,flower,"{""Most are dioecious, meaning they have male a...",,"[1133, 89, 527, 538, 2264, 59, 706, 1184, 313,...","[1133, 1134]","[1133, 2264, 1134]"
6,flower,"{""Most are dioecious, meaning they have male a...",,"[1133, 89, 527, 538, 2264, 59, 706, 1184, 313,...","[1133, 1134]","[1133, 2264, 1134]"
7,basketball,"{""Jeffrey Alan Samardzija (/səˈmɑrdʒə/ ; born ...",,"[1133, 6078, 2425, 265, 481, 2283, 6266, 19, 9...","[1133, 1134]","[1133, 2687, 1134]"
8,basketball,"{""Jeffrey Alan Samardzija (/səˈmɑrdʒə/ ; born ...",,"[1133, 6078, 2425, 265, 481, 2283, 6266, 19, 9...","[1133, 1134]","[1133, 2687, 1134]"
9,tea,"{""One mineral not added back into white rice i...",,"[1133, 0, 1854, 343, 235, 420, 2658, 6511, 0, ...","[1133, 1134]","[1133, 3507, 1134]"


In [14]:
df = df.sample(n=10, random_state=RANDOM_SEED)

In [15]:
print(df_full.columns)


Index(['query', 'doc_relevant', 'doc_irrelevant', 'doc_rel_tokens',
       'doc_irr_tokens', 'query_tokens'],
      dtype='object')


In [16]:
dataset = DocumentDataset(df_full)

In [None]:
import importlib
import models.core

importlib.reload(models.core)
import utils.collate

importlib.reload(utils.collate)
from utils.collate import collate



dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate)

In [18]:

# Create model
model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)



In [19]:
projection_dim_sweep = [24, 48, 96, 192]
margin_sweep = [0.1, 0.4, 0.7, 1.0]
lr_sweep = [LEARNING_RATE * i for i in [0.0001, 0.001, 0.01, 0.1, 1]]

In [20]:
for projection_dim in projection_dim_sweep:
    run_name = f"avg_pooling_projection_dim_{projection_dim}_commit_3799989"
    model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, 
                          projection_dim=projection_dim, 
                          embedding_layer=embedding_layer, 
                          margin=MARGIN)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    wandb.init(project=MODEL_NAME, name=run_name)
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1} of {NUM_EPOCHS}")
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx % 1000 == 0:
                print(f"Batch {batch_idx + 1} of {len(dataloader)}")
            docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch

            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)

            loss = loss_fn(similarity_rel, similarity_irr, MARGIN)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})
    wandb.finish()


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgurleenvasir0[0m ([33mgurleenvasir0-me[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112635611142549, max=1.0…

Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.53134


Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▄▁

0,1
loss,0.4844


Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.49659


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115339355536788, max=1.0…

Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.38821


In [21]:
for margin in margin_sweep:
    run_name = f"avg_pooling_margin_{margin}_commit_3799989"
    model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=margin)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    wandb.init(project=MODEL_NAME, name=run_name)
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1} of {NUM_EPOCHS}")
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx % 1000 == 0:
                print(f"Batch {batch_idx + 1} of {len(dataloader)}")
            docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch

            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)

            loss = loss_fn(similarity_rel, similarity_irr, MARGIN)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})
        save_checkpoint(model, epoch, MODEL_NAME)
        
    wandb.finish()


Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.49616


Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▄▁

0,1
loss,0.58037


Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▄▁

0,1
loss,0.48992


Epoch 1 of 3
Batch 1 of 1
Epoch 2 of 3
Batch 1 of 1
Epoch 3 of 3
Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.53388


In [22]:
for lr in lr_sweep:
    run_name = f"avg_pooling_learning_rate_{lr}_commit_3799989"
    model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=MARGIN)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    wandb.init(project=MODEL_NAME, name=run_name)
    for epoch in range(NUM_EPOCHS):
        print('Learning Rate:', lr)
        print(f"Epoch {epoch + 1} of {NUM_EPOCHS}")
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx % 5000 == 0:
                print(f"E{epoch + 1}: Batch {batch_idx + 1} of {len(dataloader)}")
            docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch

            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)

            loss = loss_fn(similarity_rel, similarity_irr, MARGIN)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})
        save_checkpoint(model, epoch, run_name)

    wandb.finish()
    # torch.save(model, 'fine_tuned_model.pth')



Learning Rate: 1e-10
Epoch 1 of 3
E1: Batch 1 of 1
Learning Rate: 1e-10
Epoch 2 of 3
E2: Batch 1 of 1
Learning Rate: 1e-10
Epoch 3 of 3
E3: Batch 1 of 1


0,1
loss,▁▁▁

0,1
loss,0.47114


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113144077777784, max=1.0…

Learning Rate: 1e-09
Epoch 1 of 3
E1: Batch 1 of 1
Learning Rate: 1e-09
Epoch 2 of 3
E2: Batch 1 of 1
Learning Rate: 1e-09
Epoch 3 of 3
E3: Batch 1 of 1


0,1
loss,█▃▁

0,1
loss,0.4461


Learning Rate: 1e-08
Epoch 1 of 3
E1: Batch 1 of 1
Learning Rate: 1e-08
Epoch 2 of 3
E2: Batch 1 of 1
Learning Rate: 1e-08
Epoch 3 of 3
E3: Batch 1 of 1


0,1
loss,█▄▁

0,1
loss,0.47131


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112932699971781, max=1.0…

Learning Rate: 1e-07
Epoch 1 of 3
E1: Batch 1 of 1
Learning Rate: 1e-07
Epoch 2 of 3
E2: Batch 1 of 1
Learning Rate: 1e-07
Epoch 3 of 3
E3: Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.50482


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112856422227601, max=1.0…

Learning Rate: 1e-06
Epoch 1 of 3
E1: Batch 1 of 1
Learning Rate: 1e-06
Epoch 2 of 3
E2: Batch 1 of 1
Learning Rate: 1e-06
Epoch 3 of 3
E3: Batch 1 of 1


0,1
loss,█▅▁

0,1
loss,0.43358


In [23]:
model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=MARGIN)

In [24]:
# This needs to be updated
# model.load_state_dict(torch.load(f'fine_tuned_model'))

In [25]:
def validate(model, dataloader, margin):
    model.eval()
    total_loss = 0
    total_batches = 0
    
    with torch.no_grad():
        for batch in dataloader:
            docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch
            
            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)
            
            loss = loss_fn(similarity_rel, similarity_irr, margin)
            total_loss += loss.item()
            total_batches += 1
    
    avg_loss = total_loss / total_batches
    return avg_loss

In [26]:
df_validation_sample = tokenize(df_validation.sample(n=1000, random_state=RANDOM_SEED), word_to_idx).reset_index(drop=True)

# Create validation dataset and dataloader
validation_dataset = DocumentDataset(df_validation_sample)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False, collate_fn=collate)


In [27]:
df_validation_sample = tokenize(df_validation.sample(n=1000, random_state=RANDOM_SEED), word_to_idx).reset_index(drop=True)
# Create validation dataset and dataloader
validation_dataset = DocumentDataset(df_validation_sample)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False, collate_fn=collate)


In [28]:
validate(model, validation_dataloader, MARGIN)

0.5055236024782062

In [29]:
rel_doc = df_validation_sample.loc[0, 'doc_relevant']
irr_doc = df_validation_sample.loc[0, 'doc_irrelevant']
query = df_validation_sample.loc[0, 'query']

rel_doc_tokens = torch.tensor(df_validation_sample.loc[0, 'doc_rel_tokens'])
irr_doc_tokens = torch.tensor(df_validation_sample.loc[0, 'doc_irr_tokens'])
query_tokens = torch.tensor(df_validation_sample.loc[0, 'query_tokens'])

print(rel_doc_tokens.shape)
print(irr_doc_tokens.shape)
print(query_tokens.shape)


torch.Size([39])
torch.Size([52])
torch.Size([4])


In [30]:
model.eval()
with torch.no_grad():
    similarity_rel = model(rel_doc_tokens.unsqueeze(0), query_tokens.unsqueeze(0))
    similarity_irr = model(irr_doc_tokens.unsqueeze(0), query_tokens.unsqueeze(0))

similarity_rel, similarity_irr

(tensor([-0.0838]), tensor([0.0621]))

In [31]:
query = "What are the effects of climate change?"
documents = [
    "Climate change is causing rising sea levels and more frequent extreme weather events.",
    "The Earth orbits around the Sun in an elliptical path.",
    "Global warming is leading to the melting of polar ice caps and glaciers.",
    "Photosynthesis is the process by which plants convert sunlight into energy.",
    "Increased greenhouse gas emissions are a major contributor to global climate change.",
    "The recipe for a classic Margherita pizza includes fresh mozzarella, tomatoes, and basil.",
    "The history of the Roman Empire is marked by significant military conquests and cultural achievements.",
    "Quantum mechanics explores the behavior of particles at the atomic and subatomic levels.",
    "The rules of chess involve strategic movement of pieces like the knight, bishop, and rook.",
    "The process of photosynthesis in plants involves converting carbon dioxide and water into glucose and oxygen using sunlight."
]


In [32]:
model.eval()
with torch.no_grad():
    # Tokenize and prepare the query
    query_tokens = torch.tensor([str_to_tokens(query, word_to_idx)])
    query_mask = (query_tokens != 0).float()

    # Tokenize and prepare the documents
    doc_tokens = [torch.tensor([str_to_tokens(doc, word_to_idx)]) for doc in documents]
    doc_masks = [(doc != 0).float() for doc in doc_tokens]

    # Calculate similarities
    similarities = []
    for doc, mask in zip(doc_tokens, doc_masks):
        similarity = model(doc, query_tokens, doc_mask=mask, query_mask=query_mask)
        similarities.append(similarity.item())

    # Sort documents by similarity
    ranked_docs = sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)



In [33]:
df_ranked_docs = pd.DataFrame(ranked_docs, columns=['Document', 'Similarity'])
df_ranked_docs['Query'] = query
df_ranked_docs = df_ranked_docs[['Query', 'Document', 'Similarity']]
pd.set_option('display.max_colwidth', None)

styled_df = df_ranked_docs.style.set_table_styles(
    {
        'Query': [{'selector': '', 'props': [('width', '150px')]}],
        'Document': [{'selector': '', 'props': [('width', '600px')]}]
    }
)

styled_df


Unnamed: 0,Query,Document,Similarity
0,What are the effects of climate change?,Global warming is leading to the melting of polar ice caps and glaciers.,0.089823
1,What are the effects of climate change?,Climate change is causing rising sea levels and more frequent extreme weather events.,0.013515
2,What are the effects of climate change?,Increased greenhouse gas emissions are a major contributor to global climate change.,-0.019892
3,What are the effects of climate change?,"The recipe for a classic Margherita pizza includes fresh mozzarella, tomatoes, and basil.",-0.029182
4,What are the effects of climate change?,The process of photosynthesis in plants involves converting carbon dioxide and water into glucose and oxygen using sunlight.,-0.043987
5,What are the effects of climate change?,Photosynthesis is the process by which plants convert sunlight into energy.,-0.078965
6,What are the effects of climate change?,Quantum mechanics explores the behavior of particles at the atomic and subatomic levels.,-0.089775
7,What are the effects of climate change?,The Earth orbits around the Sun in an elliptical path.,-0.128343
8,What are the effects of climate change?,"The rules of chess involve strategic movement of pieces like the knight, bishop, and rook.",-0.170925
9,What are the effects of climate change?,The history of the Roman Empire is marked by significant military conquests and cultural achievements.,-0.239401


In [34]:
df = pd.read_parquet('./data/training-with-tokens.parquet')


In [None]:
df = df[['query', 'doc_relevant', 'url_relevant']]


Unnamed: 0,query,doc_relevant,url_relevant,doc_rel_tokens
0,what is rba,"Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",https://en.wikipedia.org/wiki/Reserve_Bank_of_Australia,"[1133, 69, 64417, 4756, 1988, 1082, 600, 3313, 64417, 5589, 373, 10630, 3035, 186, 633, 556, 623, 8483, 1266, 3704, 646, 18, 1146, 463, 1120, 1353, 633, 901, 1311, 2875, 1223, 1003, 64417, 2655, 27, 2373, 3625, 17, 149, 1898, 593, 19213, 324, 1134]"
1,what is rba,"The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",https://en.wikipedia.org/wiki/Reserve_Bank_of_Australia,"[1133, 1353, 646, 633, 64417, 468, 481, 633, 266, 646, 7616, 332, 171, 1353, 646, 220, 731, 266, 646, 194, 1953, 646, 3704, 646, 18, 1146, 463, 1120, 1353, 633, 901, 1311, 2875, 1223, 1003, 64417, 2655, 27, 2373, 3625, 17, 149, 1898, 593, 19213, 324, 1134]"


In [None]:
doc_dataset = models.core.DocDataset(df, word_to_idx)
doc_dataset = models.core.DocumentDataset(df)
doc_dataloader = DataLoader(doc_dataset, batch_size=32, shuffle=False, collate_fn=collate_docdataset)
    # Shuffle MUST be set to false to preserve the order of the documents
for tokens, mask, indices in doc_dataloader:
    print(tokens.shape)
    print(mask.shape)
    print(indices.shape)
    break


KeyError: 'doc_irrelevant'

In [None]:
model.eval()

doc_projections = []

with torch.no_grad():
    for batch_tokens, batch_mask, batch_indices in tqdm(doc_dataloader):

        doc_encodings = model.doc_encode(batch_tokens, batch_mask)
        batch_projections = model.doc_project(doc_encodings)

        doc_projections.append(batch_projections)
        doc_indices.append(batch_indices)


In [None]:
doc_projections = torch.cat(doc_projections, dim=0)


In [None]:

doc_projection_dim = doc_projections.shape[1] # same as PROJECTION_DIM
num_docs = doc_projections.shape[0] # same as len(df), len(doc_indices)
doc_embedding_matrix = nn.Embedding.from_pretrained(doc_projections, freeze=True)


In [None]:
torch.save(doc_embedding_matrix.weight.data, '../data/doc-embedding-matrix-64.pth')