In [None]:
import torch
import pandas as pd
import numpy as np
import datasets
from load_models_and_data import load_vocabulary, load_embeddings, text_to_embeddings, calc_cosine_sim, calculate_similarities
from tqdm import tqdm
tqdm.pandas()
from TwoTowerNN import QryTower, DocTower, TripletEmbeddingDataset, run_hyperparameter_tuning
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader,  SubsetRandomSampler
from sklearn.model_selection import KFold, train_test_split
import os
import wandb
from dotenv import load_dotenv

API key loaded successfully


[34m[1mwandb[0m: Currently logged in as: [33mnnamdi-odozi[0m ([33mnnamdi-odozi-ave-actuaries[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Triplet datasets (query / positive passage / negative passage) from the
# Hugging Face Hub. `datasets` is already imported in the first cell, so the
# loader is called through the module rather than re-imported here.
ds_soft_neg = datasets.load_dataset("cocoritzy/week_2_triplet_dataset_soft_negatives")
ds_hard_neg = datasets.load_dataset("cocoritzy/week_2_triplet_dataset_hard_negatives")


In [None]:
# Paths to the embedding weights and the token-id -> word vocabulary file
embeddings_path = "./downloaded_model/embeddings.pt"
vocab_path = "./downloaded_model/tkn_ids_to_words.csv"

# Load embeddings and vocabulary
print("Loading embeddings and vocabulary...")
embeddings = load_embeddings(embeddings_path)
word_to_idx = load_vocabulary(vocab_path)

print(f"Loaded embeddings with shape: {embeddings.shape}")
print(f"Loaded vocabulary with {len(word_to_idx)} tokens")

# Smoke test: embed a short sentence to check the vocabulary and the embedding
# matrix line up.
sample_text = "This is a test sentence"
embeddings_result = text_to_embeddings(sample_text, word_to_idx, embeddings)
print(f"Embedded text shape: {embeddings_result.shape}")

# .cpu() makes the numpy conversion safe even if the tensor lives on a GPU.
numpy_array = embeddings_result.detach().cpu().numpy()
# Compact formatting scoped to this preview only; np.set_printoptions would
# silently change numpy printing for every later cell in the notebook.
with np.printoptions(precision=4, suppress=True, threshold=10):  # threshold limits number of elements shown
    print("Embedding array with custom formatting:")
    print(numpy_array)


In [None]:
# Inspect the loaded soft-negative dataset (its splits; the 'train' split is used below)
ds_soft_neg

In [None]:
# Materialise the 'train' split of each triplet dataset as a pandas DataFrame.
df_soft_neg, df_hard_neg = (
    pd.DataFrame(hf_dataset['train']) for hf_dataset in (ds_soft_neg, ds_hard_neg)
)

In [None]:
# Embed the three texts of the first soft-negative triplet (row label 0).
first_triplet = df_soft_neg.loc[0]
embedded_query, embedded_positive, embedded_negative = (
    text_to_embeddings(first_triplet[column], word_to_idx, embeddings)
    for column in ('query', 'positive_passage', 'negative_passage')
)

embedded_query.shape

In [None]:
# Collapse each embedding matrix to a single vector by averaging over dim 0
# (presumably the token axis — confirm against text_to_embeddings).
a, b, c = (
    token_matrix.mean(dim=0)
    for token_matrix in (embedded_query, embedded_positive, embedded_negative)
)
a.shape


In [None]:
import torch.nn.functional as F

# a, b, c are the averaged query, positive-passage and negative-passage vectors.
# Compare the query against BOTH passages so each printed label matches the
# pair that was actually compared (the query-vs-negative score must not be
# reported as the positive one).
cosine_similarity = F.cosine_similarity(a, b, dim=0)
neg_cosine_similarity = F.cosine_similarity(a, c, dim=0)
print(f"Cosine similarity between query and positive passage: {cosine_similarity.item()}")
print(f"Cosine similarity between query and negative passage: {neg_cosine_similarity.item()}")

In [None]:

# Disabled quick check: the same computation as the full-dataset cell below,
# restricted to the first five rows. Uncomment to sanity-check on a small sample.
# # Process the dataframe using apply just for first five rows
# print("Calculating similarities... This may take a while depending on dataframe size.")
# similarities = df_soft_neg[0:5].progress_apply(
#     lambda row: calculate_similarities(row, word_to_idx, embeddings), 
#     axis=1
# )

# # Join the similarities to the dataframe
# df_soft_neg_ext = pd.concat([df_soft_neg[0:5], similarities], axis=1)

# # Show a sample of the results
# #print(df_soft_neg_ext[['query_pos_sim', 'query_neg_sim', 'pos_neg_sim']].head())
#print(df_soft_neg_ext.head())
#print(df_soft_neg_ext.columns)

In [None]:

# Compute per-row features for every soft-negative triplet.
# NOTE(review): judging by the saved preview of df_soft_neg_ext, the columns
# calculate_similarities adds are avg_query_embedding / avg_pos_embedding /
# avg_neg_embedding (averaged embeddings), not similarity scores — confirm in
# load_models_and_data.
print("Calculating similarities... This may take a while depending on dataframe size.")
similarities = df_soft_neg.progress_apply(
    lambda row: calculate_similarities(row, word_to_idx, embeddings),
    axis=1
)

# Join the new columns onto the original rows (both frames share the same index)
df_soft_neg_ext = pd.concat([df_soft_neg, similarities], axis=1)
# Guard against an index mismatch silently adding rows / NaNs during the join
assert len(df_soft_neg_ext) == len(df_soft_neg), "row count changed after joining features"
df_soft_neg_ext.head()



In [None]:
# Compute per-row features for every hard-negative triplet.
# NOTE(review): judging by the saved preview of the soft-negative table, the
# columns calculate_similarities adds are avg_query_embedding /
# avg_pos_embedding / avg_neg_embedding (averaged embeddings), not similarity
# scores — confirm in load_models_and_data.
print("Calculating similarities... This may take a while depending on dataframe size.")
similarities = df_hard_neg.progress_apply(
    lambda row: calculate_similarities(row, word_to_idx, embeddings),
    axis=1
)

# Join the new columns onto the original rows (both frames share the same index)
df_hard_neg_ext = pd.concat([df_hard_neg, similarities], axis=1)
# Guard against an index mismatch silently adding rows / NaNs during the join
assert len(df_hard_neg_ext) == len(df_hard_neg), "row count changed after joining features"
df_hard_neg_ext.head()



In [None]:
# Stack soft- and hard-negative triplets into one table.
# ignore_index=True assigns a fresh 0..n-1 index. Without it, both frames'
# RangeIndexes overlap and every label appears twice, which can break
# label-based lookups (.loc) on the combined frame downstream.
df_all_neg_ext = pd.concat([df_soft_neg_ext, df_hard_neg_ext], ignore_index=True)
df_all_neg_ext.head()

In [None]:
# Cache the feature-augmented DataFrames so later sessions can skip the slow
# feature pass. Create the target directory first, because to_pickle does not
# create missing parent directories.
os.makedirs("data", exist_ok=True)
df_soft_neg_ext.to_pickle("data/df_soft_neg_ext.pkl")
df_hard_neg_ext.to_pickle("data/df_hard_neg_ext.pkl")
df_all_neg_ext.to_pickle("data/df_all_neg_ext.pkl")

In [4]:
def load_df_if_exists(file_path):
    """Return the DataFrame pickled at ``file_path``, or ``None`` if it is absent.

    A missing file is reported by printing a message rather than raising.
    Only point this at pickles you produced yourself: unpickling an untrusted
    file can execute arbitrary code.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None
    return pd.read_pickle(file_path)

# Load DataFrames cached by the save cell. Every file is checked before any
# name is rebound: a missing pickle used to rebind the variable to None,
# clobbering frames already computed in this session, and only failed later
# with a confusing AttributeError.
_cache_paths = {
    "soft": "data/df_soft_neg_ext.pkl",
    "hard": "data/df_hard_neg_ext.pkl",
    "all": "data/df_all_neg_ext.pkl",
}
_cached = {key: load_df_if_exists(path) for key, path in _cache_paths.items()}
_missing = [_cache_paths[key] for key, frame in _cached.items() if frame is None]
if _missing:
    raise FileNotFoundError(
        f"Cached DataFrame(s) not found: {_missing}. "
        "Run the feature-extraction and save cells first."
    )
df_soft_neg_ext = _cached["soft"]
df_hard_neg_ext = _cached["hard"]
df_all_neg_ext = _cached["all"]


In [8]:
# Preview the reloaded soft-negative table, including the added avg_*_embedding columns
df_soft_neg_ext.head()

Unnamed: 0,query_id,query,positive_passage,negative_passage,negative_from_query_id,avg_query_embedding,avg_pos_embedding,avg_neg_embedding
0,19699,what is rba,Results-Based Accountability® (also known as R...,I finally found some real salary data for phys...,86595,"[0.6579812, 0.24213153, 0.057250064, -0.825741...","[0.39086032, 0.3319433, 0.1275278, -0.80645, 0...","[0.569893, 0.18935415, 0.1920344, -0.7171183, ..."
1,19700,was ronald reagan a democrat,"From Wikipedia, the free encyclopedia. A Reaga...",The Pacific Ocean lies to the east while the S...,66360,"[-0.6998242, -0.24631366, -0.20571017, 0.24202...","[0.27046937, 0.2619914, 0.049588773, -0.618945...","[0.17404862, 0.21760696, -0.10469024, -0.23737..."
2,19701,how long do you need for sydney and surroundin...,Sydney is the capital city of the Australian s...,"Probiotics are found in foods such as yogurt, ...",88507,"[0.16817716, 0.29739928, -0.36492547, 0.064426...","[0.39110944, 0.23566554, 0.063871, -0.36585316...","[0.61134595, 0.36615297, 0.28972, -0.6924668, ..."
3,19702,price to install tile in shower,1 Install ceramic tile floor to match shower-A...,Iodine is critical to thyroid health and funct...,87550,"[-0.06541735, 0.2755244, 0.19580394, 0.4023429...","[0.69151133, 0.5770993, 0.22074986, -0.7754023...","[0.3590759, -0.036869664, 0.17250647, -0.53339..."
4,19703,why conversion observed in body,Conversion disorder is a type of somatoform di...,The answer to the question how much does it co...,61479,"[-0.13369766, -0.30740747, 0.5450557, 0.391294...","[0.42539537, 0.13814452, 0.37000972, -0.632320...","[0.5729694, 0.314426, 0.13929352, -0.9086552, ..."


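### Twin Tower Network

A query tower and a document tower each encode their input. Training pushes the query's cosine similarity with the positive passage above its similarity with the negative passage by at least a fixed margin (a triplet hinge loss).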

In [None]:
# --- Training hyperparameters ---
batch_size = 128
num_epochs = 1  # adjust num of epochs here
learning_rate = 1e-3
embedding_dim = 128
margin = 0.2  # hinge margin applied to (positive sim - negative sim)

# Derived schedule sizes; use len(df_hard_neg_ext) instead to size them for the hard-negative set
dataset_size = len(df_soft_neg_ext)
steps_per_epoch = dataset_size // batch_size
total_steps = steps_per_epoch * num_epochs

# --- Model: one tower encodes queries, the other encodes passages ---
qryTower = QryTower()
docTower = DocTower()

In [None]:
# Create the dataset from the soft-negative triplets.
# NOTE(review): the training loop reads batch['query'] / batch['positive'] /
# batch['negative'], so items are presumably dicts of tensors under those
# keys — confirm in TwoTowerNN.TripletEmbeddingDataset.
dataset = TripletEmbeddingDataset(df_soft_neg_ext)

In [None]:
# Shuffled mini-batches of triplets for training.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    # num_workers=2,  # Adjust based on your machine's capabilities
    # Pinned host memory only speeds up host->GPU copies, so enable it only when
    # a GPU is in use; on a CPU-only machine PyTorch just warns about it.
    pin_memory=(device.type == "cuda"),
)

In [None]:
# One AdamW optimizer over the parameters of both towers.
# (The unused random qry/pos/neg placeholder tensors were removed; the training
# loop takes its inputs from the DataLoader.)
optimizer = torch.optim.AdamW([
    {'params': qryTower.parameters()},
    {'params': docTower.parameters()}
], lr=learning_rate)

# Halve the LR when the value passed to scheduler.step() (the epoch loss) has
# not decreased for `patience` consecutive epochs. The training loop prints the
# current LR each epoch, so the deprecated `verbose` flag is not needed.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',       # Reduce LR when monitored value stops decreasing
    factor=0.5,       # Multiply LR by this factor when reducing
    patience=2,       # Number of epochs with no improvement after which LR will be reduced
)



In [None]:
# Training loop (simplified example): triplet hinge loss on the cosine
# similarity between the encoded query and the encoded positive / negative
# passages.
# Batches are moved to whatever device the towers live on, so the loop keeps
# working if the towers are placed on `device` (e.g. a GPU).
tower_device = next(qryTower.parameters()).device

for epoch in range(num_epochs):
    qryTower.train()
    docTower.train()

    total_loss = 0.0
    for batch in dataloader:
        # Get embeddings from batch and align them with the towers' device
        query_emb = batch['query'].to(tower_device)
        pos_emb = batch['positive'].to(tower_device)
        neg_emb = batch['negative'].to(tower_device)

        # Forward pass; the same document tower encodes both passages
        query_encoded = qryTower(query_emb)
        pos_encoded = docTower(pos_emb)
        neg_encoded = docTower(neg_emb)

        # Per-example cosine similarities (cosine_similarity defaults to dim=1)
        pos_sim = torch.nn.functional.cosine_similarity(query_encoded, pos_encoded)
        neg_sim = torch.nn.functional.cosine_similarity(query_encoded, neg_encoded)

        # Triplet loss: zero once pos_sim exceeds neg_sim by at least `margin`
        loss = torch.clamp(margin - pos_sim + neg_sim, min=0).mean()

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per-example even when
        # the last batch is smaller.
        total_loss += loss.item() * len(query_emb)

    # Calculate average loss
    avg_loss = total_loss / len(dataset)

    # Update scheduler on the epoch's training loss
    scheduler.step(avg_loss)

    # Print epoch results
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, "
          f"LR: {optimizer.param_groups[0]['lr']:.6f}")

In [13]:
# Grid over output dimension x batch size with 2-fold CV on all triplets
# (soft + hard negatives). Per the logged output, the helper then retrains a
# final model with the winning settings; it returns the chosen params plus the
# two final towers.
tuning_outcome = run_hyperparameter_tuning(
    df_all_neg_ext,
    output_dims=[128],
    batch_sizes=[256, 512],
    n_folds=2,
    epochs=20,
)
best_params, final_qry_tower, final_doc_tower = tuning_outcome

# Report the winning configuration. 'avg_cv_loss' is the validation loss
# averaged over the CV folds.
# NOTE(review): the final retrain logs val loss ~0.10 vs ~0.05 during CV —
# check whether TwoTowerNN normalises loss differently for the final run.
print(f"Best output dimension: {best_params['output_dim']}")
print(f"Best batch size: {best_params['batch_size']}")
print(f"Best validation loss: {best_params['avg_cv_loss']:.4f}")





--------------------------------------------------
Training with output_dim=128, batch_size=256
--------------------------------------------------

Fold 1/2


Epoch 1/20 (Train): 100%|██████████| 250/250 [00:05<00:00, 48.08it/s]
Epoch 1/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 31.58it/s]


Epoch 1/20, Train Loss: 0.0618, Val Loss: 0.0561, LR: 0.001000
New best model saved with validation loss: 0.0561


Epoch 2/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.18it/s]
Epoch 2/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 31.36it/s]


Epoch 2/20, Train Loss: 0.0514, Val Loss: 0.0531, LR: 0.001000
New best model saved with validation loss: 0.0531


Epoch 3/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 57.05it/s]
Epoch 3/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.40it/s]


Epoch 3/20, Train Loss: 0.0480, Val Loss: 0.0521, LR: 0.001000
New best model saved with validation loss: 0.0521


Epoch 4/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 55.21it/s]
Epoch 4/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.57it/s]


Epoch 4/20, Train Loss: 0.0459, Val Loss: 0.0521, LR: 0.001000
New best model saved with validation loss: 0.0521


Epoch 5/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 55.89it/s]
Epoch 5/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 31.75it/s]


Epoch 5/20, Train Loss: 0.0435, Val Loss: 0.0517, LR: 0.001000
New best model saved with validation loss: 0.0517


Epoch 6/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 58.19it/s]
Epoch 6/20 (Val): 100%|██████████| 250/250 [00:10<00:00, 24.76it/s]


Epoch 6/20, Train Loss: 0.0415, Val Loss: 0.0516, LR: 0.001000
New best model saved with validation loss: 0.0516


Epoch 7/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 58.67it/s]
Epoch 7/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.70it/s]


Epoch 7/20, Train Loss: 0.0397, Val Loss: 0.0516, LR: 0.001000


Epoch 8/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.50it/s]
Epoch 8/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.88it/s]


Epoch 8/20, Train Loss: 0.0381, Val Loss: 0.0517, LR: 0.001000


Epoch 9/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 55.78it/s]
Epoch 9/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.61it/s]


Epoch 9/20, Train Loss: 0.0362, Val Loss: 0.0518, LR: 0.000500


Epoch 10/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.96it/s]
Epoch 10/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.64it/s]


Epoch 10/20, Train Loss: 0.0320, Val Loss: 0.0518, LR: 0.000500


Epoch 11/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 60.06it/s]
Epoch 11/20 (Val): 100%|██████████| 250/250 [00:08<00:00, 29.57it/s]


Epoch 11/20, Train Loss: 0.0301, Val Loss: 0.0522, LR: 0.000500


Epoch 12/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.51it/s]
Epoch 12/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.49it/s]


Epoch 12/20, Train Loss: 0.0289, Val Loss: 0.0526, LR: 0.000250


Epoch 13/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 60.52it/s]
Epoch 13/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.87it/s]


Epoch 13/20, Train Loss: 0.0260, Val Loss: 0.0528, LR: 0.000250


Epoch 14/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 57.28it/s]
Epoch 14/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.69it/s]


Epoch 14/20, Train Loss: 0.0250, Val Loss: 0.0530, LR: 0.000250


Epoch 15/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 59.28it/s]
Epoch 15/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.50it/s]


Epoch 15/20, Train Loss: 0.0242, Val Loss: 0.0533, LR: 0.000125


Epoch 16/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.61it/s]
Epoch 16/20 (Val): 100%|██████████| 250/250 [00:08<00:00, 28.50it/s]


Epoch 16/20, Train Loss: 0.0226, Val Loss: 0.0535, LR: 0.000125


Epoch 17/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 55.44it/s]
Epoch 17/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 31.76it/s]


Epoch 17/20, Train Loss: 0.0221, Val Loss: 0.0537, LR: 0.000125


Epoch 18/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 54.87it/s]
Epoch 18/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.99it/s]


Epoch 18/20, Train Loss: 0.0217, Val Loss: 0.0538, LR: 0.000063


Epoch 19/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.08it/s]
Epoch 19/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.46it/s]


Epoch 19/20, Train Loss: 0.0208, Val Loss: 0.0539, LR: 0.000063


Epoch 20/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 57.74it/s]
Epoch 20/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.28it/s]


Epoch 20/20, Train Loss: 0.0205, Val Loss: 0.0540, LR: 0.000063

Fold 2/2


Epoch 1/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 58.63it/s]
Epoch 1/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.32it/s]


Epoch 1/20, Train Loss: 0.0628, Val Loss: 0.0534, LR: 0.001000
New best model saved with validation loss: 0.0534


Epoch 2/20 (Train): 100%|██████████| 250/250 [00:05<00:00, 47.53it/s]
Epoch 2/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.08it/s]


Epoch 2/20, Train Loss: 0.0516, Val Loss: 0.0521, LR: 0.001000
New best model saved with validation loss: 0.0521


Epoch 3/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 55.00it/s]
Epoch 3/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.49it/s]


Epoch 3/20, Train Loss: 0.0488, Val Loss: 0.0518, LR: 0.001000
New best model saved with validation loss: 0.0518


Epoch 4/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 58.37it/s]
Epoch 4/20 (Val): 100%|██████████| 250/250 [00:08<00:00, 31.10it/s]


Epoch 4/20, Train Loss: 0.0462, Val Loss: 0.0535, LR: 0.001000


Epoch 5/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.46it/s]
Epoch 5/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.83it/s]


Epoch 5/20, Train Loss: 0.0439, Val Loss: 0.0511, LR: 0.001000
New best model saved with validation loss: 0.0511


Epoch 6/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 53.49it/s]
Epoch 6/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.46it/s]


Epoch 6/20, Train Loss: 0.0420, Val Loss: 0.0513, LR: 0.001000


Epoch 7/20 (Train): 100%|██████████| 250/250 [00:05<00:00, 46.21it/s]
Epoch 7/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 34.13it/s]


Epoch 7/20, Train Loss: 0.0399, Val Loss: 0.0517, LR: 0.001000


Epoch 8/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 60.43it/s]
Epoch 8/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.95it/s]


Epoch 8/20, Train Loss: 0.0380, Val Loss: 0.0517, LR: 0.000500


Epoch 9/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 59.22it/s]
Epoch 9/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.75it/s]


Epoch 9/20, Train Loss: 0.0335, Val Loss: 0.0514, LR: 0.000500


Epoch 10/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 57.70it/s]
Epoch 10/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.40it/s]


Epoch 10/20, Train Loss: 0.0319, Val Loss: 0.0522, LR: 0.000500


Epoch 11/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 61.29it/s]
Epoch 11/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 34.06it/s]


Epoch 11/20, Train Loss: 0.0303, Val Loss: 0.0524, LR: 0.000250


Epoch 12/20 (Train): 100%|██████████| 250/250 [00:05<00:00, 46.51it/s]
Epoch 12/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.82it/s]


Epoch 12/20, Train Loss: 0.0275, Val Loss: 0.0525, LR: 0.000250


Epoch 13/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 60.17it/s]
Epoch 13/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.37it/s]


Epoch 13/20, Train Loss: 0.0264, Val Loss: 0.0528, LR: 0.000250


Epoch 14/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 57.25it/s]
Epoch 14/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 34.23it/s]


Epoch 14/20, Train Loss: 0.0255, Val Loss: 0.0531, LR: 0.000125


Epoch 15/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 60.09it/s]
Epoch 15/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 34.57it/s]


Epoch 15/20, Train Loss: 0.0239, Val Loss: 0.0532, LR: 0.000125


Epoch 16/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.48it/s]
Epoch 16/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 34.16it/s]


Epoch 16/20, Train Loss: 0.0233, Val Loss: 0.0533, LR: 0.000125


Epoch 17/20 (Train): 100%|██████████| 250/250 [00:05<00:00, 47.89it/s]
Epoch 17/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.34it/s]


Epoch 17/20, Train Loss: 0.0229, Val Loss: 0.0535, LR: 0.000063


Epoch 18/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.21it/s]
Epoch 18/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 32.39it/s]


Epoch 18/20, Train Loss: 0.0220, Val Loss: 0.0536, LR: 0.000063


Epoch 19/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.33it/s]
Epoch 19/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.54it/s]


Epoch 19/20, Train Loss: 0.0217, Val Loss: 0.0538, LR: 0.000063


Epoch 20/20 (Train): 100%|██████████| 250/250 [00:04<00:00, 56.48it/s]
Epoch 20/20 (Val): 100%|██████████| 250/250 [00:07<00:00, 33.52it/s]


Epoch 20/20, Train Loss: 0.0215, Val Loss: 0.0538, LR: 0.000031

Average CV loss for output_dim=128, batch_size=256: 0.0514


--------------------------------------------------
Training with output_dim=128, batch_size=512
--------------------------------------------------

Fold 1/2


Epoch 1/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.05it/s]
Epoch 1/20 (Val): 100%|██████████| 125/125 [00:07<00:00, 17.78it/s]


Epoch 1/20, Train Loss: 0.0670, Val Loss: 0.0557, LR: 0.001000
New best model saved with validation loss: 0.0557


Epoch 2/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 33.75it/s]
Epoch 2/20 (Val): 100%|██████████| 125/125 [00:09<00:00, 12.61it/s]


Epoch 2/20, Train Loss: 0.0523, Val Loss: 0.0534, LR: 0.001000
New best model saved with validation loss: 0.0534


Epoch 3/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.36it/s]
Epoch 3/20 (Val): 100%|██████████| 125/125 [00:07<00:00, 17.73it/s]


Epoch 3/20, Train Loss: 0.0485, Val Loss: 0.0527, LR: 0.001000
New best model saved with validation loss: 0.0527


Epoch 4/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.59it/s]
Epoch 4/20 (Val): 100%|██████████| 125/125 [00:07<00:00, 17.77it/s]


Epoch 4/20, Train Loss: 0.0463, Val Loss: 0.0521, LR: 0.001000
New best model saved with validation loss: 0.0521


Epoch 5/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 34.47it/s]
Epoch 5/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.12it/s]


Epoch 5/20, Train Loss: 0.0438, Val Loss: 0.0529, LR: 0.001000


Epoch 6/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 37.10it/s]
Epoch 6/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.29it/s]


Epoch 6/20, Train Loss: 0.0416, Val Loss: 0.0515, LR: 0.001000
New best model saved with validation loss: 0.0515


Epoch 7/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 37.20it/s]
Epoch 7/20 (Val): 100%|██████████| 125/125 [00:07<00:00, 15.75it/s]


Epoch 7/20, Train Loss: 0.0398, Val Loss: 0.0521, LR: 0.001000


Epoch 8/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.16it/s]
Epoch 8/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.35it/s]


Epoch 8/20, Train Loss: 0.0375, Val Loss: 0.0525, LR: 0.001000


Epoch 9/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.92it/s]
Epoch 9/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.14it/s]


Epoch 9/20, Train Loss: 0.0357, Val Loss: 0.0523, LR: 0.000500


Epoch 10/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.82it/s]
Epoch 10/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.18it/s]


Epoch 10/20, Train Loss: 0.0316, Val Loss: 0.0519, LR: 0.000500


Epoch 11/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.58it/s]
Epoch 11/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.19it/s]


Epoch 11/20, Train Loss: 0.0299, Val Loss: 0.0525, LR: 0.000500


Epoch 12/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 33.92it/s]
Epoch 12/20 (Val): 100%|██████████| 125/125 [00:07<00:00, 15.80it/s]


Epoch 12/20, Train Loss: 0.0287, Val Loss: 0.0526, LR: 0.000250


Epoch 13/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.84it/s]
Epoch 13/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.24it/s]


Epoch 13/20, Train Loss: 0.0263, Val Loss: 0.0526, LR: 0.000250


Epoch 14/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.47it/s]
Epoch 14/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.00it/s]


Epoch 14/20, Train Loss: 0.0253, Val Loss: 0.0528, LR: 0.000250


Epoch 15/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 34.48it/s]
Epoch 15/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.28it/s]


Epoch 15/20, Train Loss: 0.0246, Val Loss: 0.0530, LR: 0.000125


Epoch 16/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.16it/s]
Epoch 16/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.33it/s]


Epoch 16/20, Train Loss: 0.0232, Val Loss: 0.0530, LR: 0.000125


Epoch 17/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.04it/s]
Epoch 17/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.20it/s]


Epoch 17/20, Train Loss: 0.0228, Val Loss: 0.0530, LR: 0.000125


Epoch 18/20 (Train): 100%|██████████| 125/125 [00:05<00:00, 22.16it/s]
Epoch 18/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.16it/s]


Epoch 18/20, Train Loss: 0.0224, Val Loss: 0.0532, LR: 0.000063


Epoch 19/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 37.22it/s]
Epoch 19/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.15it/s]


Epoch 19/20, Train Loss: 0.0217, Val Loss: 0.0533, LR: 0.000063


Epoch 20/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.59it/s]
Epoch 20/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.30it/s]


Epoch 20/20, Train Loss: 0.0215, Val Loss: 0.0534, LR: 0.000063

Fold 2/2


Epoch 1/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 37.79it/s]
Epoch 1/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.13it/s]


Epoch 1/20, Train Loss: 0.0669, Val Loss: 0.0560, LR: 0.001000
New best model saved with validation loss: 0.0560


Epoch 2/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.77it/s]
Epoch 2/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.34it/s]


Epoch 2/20, Train Loss: 0.0524, Val Loss: 0.0538, LR: 0.001000
New best model saved with validation loss: 0.0538


Epoch 3/20 (Train): 100%|██████████| 125/125 [00:04<00:00, 29.09it/s]
Epoch 3/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.29it/s]


Epoch 3/20, Train Loss: 0.0490, Val Loss: 0.0510, LR: 0.001000
New best model saved with validation loss: 0.0510


Epoch 4/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.66it/s]
Epoch 4/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.14it/s]


Epoch 4/20, Train Loss: 0.0462, Val Loss: 0.0513, LR: 0.001000


Epoch 5/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.25it/s]
Epoch 5/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.05it/s]


Epoch 5/20, Train Loss: 0.0439, Val Loss: 0.0509, LR: 0.001000
New best model saved with validation loss: 0.0509


Epoch 6/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.51it/s]
Epoch 6/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.33it/s]


Epoch 6/20, Train Loss: 0.0418, Val Loss: 0.0507, LR: 0.001000
New best model saved with validation loss: 0.0507


Epoch 7/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.66it/s]
Epoch 7/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.41it/s]


Epoch 7/20, Train Loss: 0.0393, Val Loss: 0.0516, LR: 0.001000


Epoch 8/20 (Train): 100%|██████████| 125/125 [00:04<00:00, 28.93it/s]
Epoch 8/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.42it/s]


Epoch 8/20, Train Loss: 0.0378, Val Loss: 0.0515, LR: 0.001000


Epoch 9/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 38.07it/s]
Epoch 9/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.19it/s]


Epoch 9/20, Train Loss: 0.0357, Val Loss: 0.0516, LR: 0.000500


Epoch 10/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 37.66it/s]
Epoch 10/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.21it/s]


Epoch 10/20, Train Loss: 0.0320, Val Loss: 0.0515, LR: 0.000500


Epoch 11/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 38.19it/s]
Epoch 11/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.49it/s]


Epoch 11/20, Train Loss: 0.0305, Val Loss: 0.0515, LR: 0.000500


Epoch 12/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.07it/s]
Epoch 12/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.32it/s]


Epoch 12/20, Train Loss: 0.0293, Val Loss: 0.0518, LR: 0.000250


Epoch 13/20 (Train): 100%|██████████| 125/125 [00:04<00:00, 27.85it/s]
Epoch 13/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.35it/s]


Epoch 13/20, Train Loss: 0.0269, Val Loss: 0.0520, LR: 0.000250


Epoch 14/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.83it/s]
Epoch 14/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.32it/s]


Epoch 14/20, Train Loss: 0.0260, Val Loss: 0.0521, LR: 0.000250


Epoch 15/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 35.09it/s]
Epoch 15/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.46it/s]


Epoch 15/20, Train Loss: 0.0253, Val Loss: 0.0524, LR: 0.000125


Epoch 16/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.35it/s]
Epoch 16/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.14it/s]


Epoch 16/20, Train Loss: 0.0241, Val Loss: 0.0524, LR: 0.000125


Epoch 17/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 34.59it/s]
Epoch 17/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.41it/s]


Epoch 17/20, Train Loss: 0.0236, Val Loss: 0.0525, LR: 0.000125


Epoch 18/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.20it/s]
Epoch 18/20 (Val): 100%|██████████| 125/125 [00:08<00:00, 14.13it/s]


Epoch 18/20, Train Loss: 0.0233, Val Loss: 0.0526, LR: 0.000063


Epoch 19/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 34.84it/s]
Epoch 19/20 (Val): 100%|██████████| 125/125 [00:07<00:00, 17.72it/s]


Epoch 19/20, Train Loss: 0.0226, Val Loss: 0.0527, LR: 0.000063


Epoch 20/20 (Train): 100%|██████████| 125/125 [00:03<00:00, 36.96it/s]
Epoch 20/20 (Val): 100%|██████████| 125/125 [00:06<00:00, 18.34it/s]


Epoch 20/20, Train Loss: 0.0224, Val Loss: 0.0527, LR: 0.000063

Average CV loss for output_dim=128, batch_size=512: 0.0511


Best hyperparameters:
Output dimension: 128
Batch size: 512
Average CV loss: 0.0511


Training final model with best hyperparameters...


Epoch 1/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.23it/s]
Epoch 1/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 17.03it/s]


Epoch 1/20, Train Loss: 0.1191, Val Loss: 0.1076, LR: 0.001000
New best model saved with validation loss: 0.1076


Epoch 2/20 (Train): 100%|██████████| 250/250 [00:13<00:00, 18.83it/s]
Epoch 2/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.59it/s]


Epoch 2/20, Train Loss: 0.1005, Val Loss: 0.1024, LR: 0.001000
New best model saved with validation loss: 0.1024


Epoch 3/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.86it/s]
Epoch 3/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.68it/s]


Epoch 3/20, Train Loss: 0.0947, Val Loss: 0.1010, LR: 0.001000
New best model saved with validation loss: 0.1010


Epoch 4/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.18it/s]
Epoch 4/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.40it/s]


Epoch 4/20, Train Loss: 0.0906, Val Loss: 0.1025, LR: 0.001000


Epoch 5/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.27it/s]
Epoch 5/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.62it/s]


Epoch 5/20, Train Loss: 0.0878, Val Loss: 0.1005, LR: 0.001000
New best model saved with validation loss: 0.1005


Epoch 6/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.97it/s]
Epoch 6/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.58it/s]


Epoch 6/20, Train Loss: 0.0842, Val Loss: 0.0997, LR: 0.001000
New best model saved with validation loss: 0.0997


Epoch 7/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.70it/s]
Epoch 7/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.47it/s]


Epoch 7/20, Train Loss: 0.0814, Val Loss: 0.0988, LR: 0.001000
New best model saved with validation loss: 0.0988


Epoch 8/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.37it/s]
Epoch 8/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.53it/s]


Epoch 8/20, Train Loss: 0.0785, Val Loss: 0.0996, LR: 0.001000


Epoch 9/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.38it/s]
Epoch 9/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.60it/s]


Epoch 9/20, Train Loss: 0.0766, Val Loss: 0.0989, LR: 0.001000


Epoch 10/20 (Train): 100%|██████████| 250/250 [00:13<00:00, 18.85it/s]
Epoch 10/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.52it/s]


Epoch 10/20, Train Loss: 0.0735, Val Loss: 0.0994, LR: 0.000500


Epoch 11/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.83it/s]
Epoch 11/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.44it/s]


Epoch 11/20, Train Loss: 0.0671, Val Loss: 0.0987, LR: 0.000500
New best model saved with validation loss: 0.0987


Epoch 12/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.27it/s]
Epoch 12/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.42it/s]


Epoch 12/20, Train Loss: 0.0650, Val Loss: 0.0987, LR: 0.000500


Epoch 13/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.23it/s]
Epoch 13/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.56it/s]


Epoch 13/20, Train Loss: 0.0633, Val Loss: 0.0995, LR: 0.000500


Epoch 14/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.35it/s]
Epoch 14/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.45it/s]


Epoch 14/20, Train Loss: 0.0619, Val Loss: 0.0998, LR: 0.000250


Epoch 15/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.29it/s]
Epoch 15/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.64it/s]


Epoch 15/20, Train Loss: 0.0578, Val Loss: 0.0996, LR: 0.000250


Epoch 16/20 (Train): 100%|██████████| 250/250 [00:14<00:00, 17.66it/s]
Epoch 16/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.94it/s]


Epoch 16/20, Train Loss: 0.0566, Val Loss: 0.1000, LR: 0.000250


Epoch 17/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.23it/s]
Epoch 17/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.60it/s]


Epoch 17/20, Train Loss: 0.0557, Val Loss: 0.1002, LR: 0.000125


Epoch 18/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.36it/s]
Epoch 18/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.57it/s]


Epoch 18/20, Train Loss: 0.0535, Val Loss: 0.1002, LR: 0.000125


Epoch 19/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.34it/s]
Epoch 19/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.76it/s]


Epoch 19/20, Train Loss: 0.0528, Val Loss: 0.1004, LR: 0.000125


Epoch 20/20 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.45it/s]
Epoch 20/20 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.80it/s]

Epoch 20/20, Train Loss: 0.0524, Val Loss: 0.1003, LR: 0.000063





0,1
avg_cv_loss,█▁
batch_size,▁█
epoch,▁▁▃▄▅▇▇▇▁▂▄▅▅▆▇▂▂▂▃▅▅▆▆▇█▁▁▂▂▃▄▅▅▇▇█▁▁▃▅
learning_rate,█████▃▂▂█████▄▄▃▂▁▁▁██▄▃▁█████▂▁▁██▄▄▄▃▂
output_dim,▁▁
train_loss,▄▃▃▁▁▁▄▃▃▃▂▂▂▁▁▁▁▃▂▂▁▅▄▃▃▂▂▁▁▁▁▁█▇▇▅▅▄▄▄
val_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▇▇▇▇▇

0,1
avg_cv_loss,0.0511
batch_size,512.0
epoch,20.0
learning_rate,6e-05
output_dim,128.0
train_loss,0.05238
val_loss,0.10028


Best output dimension: 128
Best batch size: 512
Best validation loss: 0.0511


In [11]:
# Code to upload final model to wandb
import wandb
import os
import time
from dotenv import load_dotenv

# Load the W&B API key from config.txt
def load_api_key_from_config(config_path="config.txt"):
    """Read an API key from the first line of a plain-text config file.

    The first line may hold either the bare key (``abc123``) or an
    assignment (``WANDB_API_KEY=abc123``, optionally quoted as in a .env
    file). Surrounding whitespace and matching quotes are removed.

    Args:
        config_path: Path to the config file. Defaults to ``"config.txt"``.

    Returns:
        The key as a string, or ``None`` if the file does not exist or its
        first line contains no key.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            first_line = f.readline().strip()
    except FileNotFoundError:
        print(f"Config file not found at {config_path}")
        return None

    # Split on the FIRST "=" only, so a key that itself contains "="
    # (e.g. base64 padding) is not truncated.
    name, sep, value = first_line.partition("=")
    api_key = value if sep else name
    api_key = api_key.strip().strip("\"'")
    # An empty line means "no key"; return None rather than "".
    return api_key or None

# Set up wandb - only set API key if wandb.run doesn't exist yet
# (an active run implies this kernel is already authenticated).
if wandb.run is None:
    api_key = load_api_key_from_config()
    if api_key:
        os.environ["WANDB_API_KEY"] = api_key
        wandb.login()
        print("Successfully logged in to Weights & Biases")
    else:
        # Best effort: wandb.init below may still succeed via cached credentials.
        print("Failed to load API key")

# Fail fast, before any W&B run is created, if the trained model is missing;
# otherwise the run would be started and then left dangling by the error.
final_model_path = "checkpoints/final_model/final_model.pt"
if not os.path.isfile(final_model_path):
    raise FileNotFoundError(f"Final model checkpoint not found at {final_model_path}")

# Record whether this cell owns the run, so that only runs started here are
# finished below (matching on run.name is fragile: another run may share it).
created_run = wandb.run is None
if created_run:
    run = wandb.init(
        project="twin-tower-model",
        name="final-model-summary",
        config={
            # NOTE(review): assumes `best_params` (dict with "output_dim" and
            # "batch_size") was produced by the earlier tuning cell — this cell
            # fails on a fresh kernel otherwise.
            "output_dim": best_params["output_dim"],
            "batch_size": best_params["batch_size"],
            "architecture": "Twin Tower Network",
            "dataset": "MS MARCO"
        }
    )
else:
    run = wandb.run
    # Attach metadata to the already-active run; allow_val_change avoids an
    # error if these keys were previously set to different values.
    run.config.update({
        "architecture": "Twin Tower Network",
        "dataset": "MS MARCO"
    }, allow_val_change=True)

try:
    # Upload the model with a timestamp to avoid artifact-name conflicts.
    timestamp = int(time.time())
    artifact_name = f"twin-tower-final-model-{timestamp}"

    model_artifact = wandb.Artifact(
        name=artifact_name,
        type="model",
        description="Twin Tower model trained on full training data with optimal hyperparameters"
    )
    model_artifact.add_file(final_model_path)
    run.log_artifact(model_artifact)

    print(f"Final model uploaded to Weights & Biases project: {run.project}")
finally:
    # Close the run only if this cell started it, even when the upload fails,
    # so a failed upload never leaves an open run behind.
    if created_run:
        wandb.finish()

Final model uploaded to Weights & Biases project: twin-tower-model
