In [None]:
import torch
import pandas as pd
import numpy as np
import datasets
from load_models_and_data import load_vocabulary, load_embeddings, text_to_embeddings, calc_cosine_sim, calculate_similarities
from tqdm import tqdm
tqdm.pandas()
from TwoTowerNN import QryTower, DocTower, TripletEmbeddingDataset, run_hyperparameter_tuning
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader,  SubsetRandomSampler
from sklearn.model_selection import KFold, train_test_split
import os
import wandb
from dotenv import load_dotenv

API key loaded successfully


[34m[1mwandb[0m: Currently logged in as: [33mnnamdi-odozi[0m ([33mnnamdi-odozi-ave-actuaries[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Use the GPU when PyTorch reports CUDA as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from datasets import load_dataset

# Fetch both triplet datasets from Hugging Face in one statement:
# the soft-negatives variant first, then the hard-negatives variant.
ds_soft_neg, ds_hard_neg = (
    load_dataset("cocoritzy/week_2_triplet_dataset_soft_negatives"),
    load_dataset("cocoritzy/week_2_triplet_dataset_hard_negatives"),
)


In [None]:
# Paths to the saved embedding tensor and the vocabulary CSV
embeddings_path = "./downloaded_model/embeddings.pt"
vocab_path = "./downloaded_model/tkn_ids_to_words.csv"

# Load embeddings and vocabulary
print("Loading embeddings and vocabulary...")
embeddings = load_embeddings(embeddings_path)
word_to_idx = load_vocabulary(vocab_path)

print(f"Loaded embeddings with shape: {embeddings.shape}")
print(f"Loaded vocabulary with {len(word_to_idx)} tokens")

# Smoke test: embed one short sentence and report the shape of the result
sample_text = "This is a test sentence"
embeddings_result = text_to_embeddings(sample_text, word_to_idx, embeddings)
print(f"Embedded text shape: {embeddings_result.shape}")

# Compact numpy printing: 4 decimals, no scientific notation, and arrays with
# more than 10 elements are summarised. This changes numpy's global print
# options for the rest of the session.
np.set_printoptions(precision=4, suppress=True, threshold=10)
# Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when the tensor is
# already on the CPU and keeps this cell working if the embeddings live on a GPU.
numpy_array = embeddings_result.detach().cpu().numpy()
print("Embedding array with custom formatting:")
print(numpy_array)


In [None]:
# Show the soft-negatives dataset object returned by load_dataset as the cell output.
ds_soft_neg

In [None]:
# Turn the 'train' split of each dataset into a pandas DataFrame
# (soft negatives first, then hard negatives).
df_soft_neg, df_hard_neg = (
    pd.DataFrame(hf_dataset['train'])
    for hf_dataset in (ds_soft_neg, ds_hard_neg)
)

In [None]:
# Embed the query, positive passage and negative passage stored at label 0
# of the soft-negative frame, in that order.
embedded_query, embedded_positive, embedded_negative = (
    text_to_embeddings(df_soft_neg[column][0], word_to_idx, embeddings)
    for column in ('query', 'positive_passage', 'negative_passage')
)

# Cell output: shape of the embedded query
embedded_query.shape

In [None]:
# Average each embedded text over its first dimension:
#   a <- query, b <- positive passage, c <- negative passage
a, b, c = (
    embedded_text.mean(dim=0)
    for embedded_text in (embedded_query, embedded_positive, embedded_negative)
)

# Cell output: shape of the averaged query embedding
a.shape


In [None]:
import torch.nn.functional as F

# a, b and c are the averaged query, positive-passage and negative-passage
# embeddings from the previous cell. The original compared `a` with `c` (the
# negative) while labelling the result "positive passage"; both comparisons
# are now computed and each is printed under its correct label.
cosine_similarity = F.cosine_similarity(a, b, dim=0)
cosine_similarity_negative = F.cosine_similarity(a, c, dim=0)
print(f"Cosine similarity between query and positive passage: {cosine_similarity.item()}")
print(f"Cosine similarity between query and negative passage: {cosine_similarity_negative.item()}")

In [None]:

# # Process the dataframe using apply just for first five rows
# print("Calculating similarities... This may take a while depending on dataframe size.")
# similarities = df_soft_neg[0:5].progress_apply(
#     lambda row: calculate_similarities(row, word_to_idx, embeddings), 
#     axis=1
# )

# # Join the similarities to the dataframe
# df_soft_neg_ext = pd.concat([df_soft_neg[0:5], similarities], axis=1)

# # Show a sample of the results
# #print(df_soft_neg_ext[['query_pos_sim', 'query_neg_sim', 'pos_neg_sim']].head())
#print(df_soft_neg_ext.head())
#print(df_soft_neg_ext.columns)

In [None]:

def _row_similarities(row):
    """Call calculate_similarities on one row with the shared vocabulary and embeddings."""
    return calculate_similarities(row, word_to_idx, embeddings)


# Row-wise pass over the soft-negative frame, with a tqdm progress bar
print("Calculating similarities... This may take a while depending on dataframe size.")
similarities = df_soft_neg.progress_apply(_row_similarities, axis=1)

# Put the per-row results next to the original columns and preview the result
df_soft_neg_ext = pd.concat([df_soft_neg, similarities], axis=1)
print(df_soft_neg_ext.head())

# Optional diagnostics, left disabled:
#print(df_soft_neg_ext[['query_pos_sim', 'query_neg_sim', 'pos_neg_sim']].head())
#print(df_soft_neg_ext[['query_pos_sim', 'query_neg_sim', 'pos_neg_sim']].mean())

# How often the positive passage is ranked higher than the negative one:
#higher_count = (df_soft_neg_ext['query_pos_sim'] > df_soft_neg_ext['query_neg_sim']).sum()
#total = len(df_soft_neg_ext)
#print(f"\nPositive passage ranked higher than negative: {higher_count} out of {total} ({higher_count/total:.2%})")



In [None]:
def _row_similarities(row):
    """Call calculate_similarities on one row with the shared vocabulary and embeddings."""
    return calculate_similarities(row, word_to_idx, embeddings)


# Row-wise pass over the hard-negative frame, with a tqdm progress bar
print("Calculating similarities... This may take a while depending on dataframe size.")
similarities = df_hard_neg.progress_apply(_row_similarities, axis=1)

# Put the per-row results next to the original columns and preview the result
df_hard_neg_ext = pd.concat([df_hard_neg, similarities], axis=1)
print(df_hard_neg_ext.head())

# Optional diagnostics, left disabled:
#print(df_hard_neg_ext[['query_pos_sim', 'query_neg_sim', 'pos_neg_sim']].head())
#print(df_hard_neg_ext[['query_pos_sim', 'query_neg_sim', 'pos_neg_sim']].mean())

# How often the positive passage is ranked higher than the negative one:
#higher_count = (df_hard_neg_ext['query_pos_sim'] > df_hard_neg_ext['query_neg_sim']).sum()
#total = len(df_hard_neg_ext)
#print(f"\nPositive passage ranked higher than negative: {higher_count} out of {total} ({higher_count/total:.2%})")



In [None]:
# Stack the soft- and hard-negative frames into a single frame.
# Each input keeps its own row labels under a plain concat, so labels would
# repeat in the result and label-based lookups would become ambiguous.
# ignore_index=True gives the combined frame a fresh 0..n-1 index instead.
df_all_neg_ext = pd.concat([df_soft_neg_ext, df_hard_neg_ext], ignore_index=True)
df_all_neg_ext.head()

In [None]:
# Save DataFrames to pickle format
# to_pickle does not create missing parent directories, so make sure the
# target folder exists before writing (no-op when it is already there).
os.makedirs("data", exist_ok=True)
df_soft_neg_ext.to_pickle("data/df_soft_neg_ext.pkl")
df_hard_neg_ext.to_pickle("data/df_hard_neg_ext.pkl")
df_all_neg_ext.to_pickle("data/df_all_neg_ext.pkl")

In [4]:
def load_df_if_exists(file_path):
    """Read the pickle at ``file_path`` with pandas, or return ``None`` if the path is absent.

    Args:
        file_path: Location of the pickle file to read.

    Returns:
        Whatever ``pd.read_pickle`` yields for the file, or ``None`` when
        ``file_path`` does not exist (a message is printed in that case).
    """
    # Guard clause: nothing to read, so report it and hand back None.
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None
    return pd.read_pickle(file_path)

# Load DataFrames
# Each name receives the result of load_df_if_exists for its pickle, in the
# order soft negatives, hard negatives, combined.
df_soft_neg_ext, df_hard_neg_ext, df_all_neg_ext = (
    load_df_if_exists("data/df_soft_neg_ext.pkl"),
    load_df_if_exists("data/df_hard_neg_ext.pkl"),
    load_df_if_exists("data/df_all_neg_ext.pkl"),
)


In [8]:
# Preview the first rows of df_soft_neg_ext as the cell output.
df_soft_neg_ext.head()

Unnamed: 0,query_id,query,positive_passage,negative_passage,negative_from_query_id,avg_query_embedding,avg_pos_embedding,avg_neg_embedding
0,19699,what is rba,Results-Based Accountability® (also known as R...,I finally found some real salary data for phys...,86595,"[0.6579812, 0.24213153, 0.057250064, -0.825741...","[0.39086032, 0.3319433, 0.1275278, -0.80645, 0...","[0.569893, 0.18935415, 0.1920344, -0.7171183, ..."
1,19700,was ronald reagan a democrat,"From Wikipedia, the free encyclopedia. A Reaga...",The Pacific Ocean lies to the east while the S...,66360,"[-0.6998242, -0.24631366, -0.20571017, 0.24202...","[0.27046937, 0.2619914, 0.049588773, -0.618945...","[0.17404862, 0.21760696, -0.10469024, -0.23737..."
2,19701,how long do you need for sydney and surroundin...,Sydney is the capital city of the Australian s...,"Probiotics are found in foods such as yogurt, ...",88507,"[0.16817716, 0.29739928, -0.36492547, 0.064426...","[0.39110944, 0.23566554, 0.063871, -0.36585316...","[0.61134595, 0.36615297, 0.28972, -0.6924668, ..."
3,19702,price to install tile in shower,1 Install ceramic tile floor to match shower-A...,Iodine is critical to thyroid health and funct...,87550,"[-0.06541735, 0.2755244, 0.19580394, 0.4023429...","[0.69151133, 0.5770993, 0.22074986, -0.7754023...","[0.3590759, -0.036869664, 0.17250647, -0.53339..."
4,19703,why conversion observed in body,Conversion disorder is a type of somatoform di...,The answer to the question how much does it co...,61479,"[-0.13369766, -0.30740747, 0.5450557, 0.391294...","[0.42539537, 0.13814452, 0.37000972, -0.632320...","[0.5729694, 0.314426, 0.13929352, -0.9086552, ..."


### Two-Tower Network

In [None]:
# Seed PyTorch's global RNG before anything random happens, so that re-running
# this cell starts from the same random state every time.
# NOTE(review): this makes tower initialisation repeatable only if QryTower /
# DocTower draw their initial weights from torch's global RNG — confirm.
SEED = 42
torch.manual_seed(SEED)

# Create tower instances
qryTower = QryTower()
docTower = DocTower()


# Define hyperparameters
batch_size = 128
num_epochs = 1 # adjust num of epochs here
dataset_size = len(df_soft_neg_ext)  # or len(df_hard_neg_ext) depending on the dataset you want to use
steps_per_epoch = dataset_size // batch_size  # full batches only; a trailing partial batch is not counted
total_steps = steps_per_epoch * num_epochs
learning_rate = 1e-3
embedding_dim = 128
margin = 0.2  # margin used by the triplet loss in the training loop

In [None]:
# Create the dataset
# Built from df_soft_neg_ext, the same frame that dataset_size is computed from.
dataset = TripletEmbeddingDataset(df_soft_neg_ext)

In [None]:
# Batches of `batch_size` items drawn from `dataset`, reshuffled every epoch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,  # Adjust based on your machine's capabilities
    # Pinned (page-locked) host memory only speeds up host-to-GPU copies, and
    # requesting it on a CPU-only machine makes PyTorch emit a warning, so it
    # is enabled only when CUDA is actually available.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Random stand-in tensors of shape (batch_size, embedding_dim).
# NOTE(review): none of the cells shown in this notebook read these three
# tensors — confirm they are unused before deleting them.
qry = torch.randn(batch_size, embedding_dim)  # Query embeddings
pos = torch.randn(batch_size, embedding_dim)  # Positive doc embeddings
neg = torch.randn(batch_size, embedding_dim)  # Negative doc embeddings


# Set up the AdamW optimizer: one parameter group per tower, both using the
# same learning rate.
optimizer = torch.optim.AdamW([
    {'params': qryTower.parameters()},
    {'params': docTower.parameters()}
], lr=learning_rate)

# Add learning rate scheduler (ReduceLROnPlateau).
# The `verbose` argument is deprecated in recent PyTorch releases and is no
# longer passed; the training loop already prints the current learning rate
# after every epoch, so reductions remain visible.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',       # Reduce LR when monitored value stops decreasing
    factor=0.5,       # Multiply LR by this factor when reducing
    patience=2,       # Number of epochs with no improvement after which LR will be reduced
)



In [None]:
# Training loop (simplified example)
# NOTE(review): `device` is defined earlier in the notebook, but nothing in this
# loop moves the towers or the batches onto it — confirm whether GPU training
# was intended here.
for epoch in range(num_epochs):
    # Put both towers in training mode
    qryTower.train()
    docTower.train()

    total_loss = 0
    for batch in dataloader:
        # Get embeddings from batch
        query_emb = batch['query']
        pos_emb = batch['positive']
        neg_emb = batch['negative']

        # Forward pass: queries go through the query tower, both passages
        # through the same document tower
        query_encoded = qryTower(query_emb)
        pos_encoded = docTower(pos_emb)
        neg_encoded = docTower(neg_emb)

        # Calculate similarities
        pos_sim = torch.nn.functional.cosine_similarity(query_encoded, pos_encoded)
        neg_sim = torch.nn.functional.cosine_similarity(query_encoded, neg_encoded)

        # Triplet loss: zero once the positive similarity exceeds the negative
        # one by at least `margin` (set in the hyperparameter cell); the former
        # no-op `margin = margin` re-assignment has been removed.
        loss = torch.clamp(margin - pos_sim + neg_sim, min=0).mean()

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch mean by the batch size so a smaller final batch
        # contributes proportionally to the epoch average
        total_loss += loss.item() * len(query_emb)

    # Calculate average loss per sample over the epoch
    avg_loss = total_loss / len(dataset)

    # Update scheduler with the epoch's average training loss
    scheduler.step(avg_loss)

    # Print epoch results
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, "
          f"LR: {optimizer.param_groups[0]['lr']:.6f}")

In [None]:
# Run the hyperparameter tuning with your dataframe
# Searches one output dimension (128) against two batch sizes (256 and 512) on
# the combined soft+hard negatives frame, and unpacks the returned best
# parameters and the two final towers.
# NOTE(review): this call passes epochs=10, but the output recorded below this
# cell reports "Epoch n/15" for every fold — confirm whether the saved output
# is stale or the `epochs` argument is not honoured by run_hyperparameter_tuning.
best_params, final_qry_tower, final_doc_tower = run_hyperparameter_tuning(
    df_all_neg_ext,
    output_dims=[128],
    batch_sizes=[256, 512],
    epochs=10
)

# Print the best parameters found
# The value labelled "Best validation loss" is read from the 'avg_cv_loss' key.
print(f"Best output dimension: {best_params['output_dim']}")
print(f"Best batch size: {best_params['batch_size']}")
print(f"Best validation loss: {best_params['avg_cv_loss']:.4f}")





--------------------------------------------------
Training with output_dim=128, batch_size=256
--------------------------------------------------

Fold 1/5


Epoch 1/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 64.79it/s]
Epoch 1/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.12it/s]


Epoch 1/15, Train Loss: 0.0961, Val Loss: 0.0215, LR: 0.001000
New best model saved with validation loss: 0.0215


Epoch 2/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 58.04it/s]
Epoch 2/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.23it/s]


Epoch 2/15, Train Loss: 0.0811, Val Loss: 0.0209, LR: 0.001000
New best model saved with validation loss: 0.0209


Epoch 3/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 53.62it/s]
Epoch 3/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.29it/s]


Epoch 3/15, Train Loss: 0.0764, Val Loss: 0.0206, LR: 0.001000
New best model saved with validation loss: 0.0206


Epoch 4/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.12it/s]
Epoch 4/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.27it/s]


Epoch 4/15, Train Loss: 0.0731, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 5/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 58.92it/s]
Epoch 5/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.09it/s]


Epoch 5/15, Train Loss: 0.0698, Val Loss: 0.0202, LR: 0.001000


Epoch 6/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.13it/s]
Epoch 6/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.98it/s]


Epoch 6/15, Train Loss: 0.0669, Val Loss: 0.0202, LR: 0.001000


Epoch 7/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.42it/s]
Epoch 7/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 28.74it/s]


Epoch 7/15, Train Loss: 0.0641, Val Loss: 0.0201, LR: 0.000500


Epoch 8/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 52.27it/s]
Epoch 8/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.40it/s]


Epoch 8/15, Train Loss: 0.0579, Val Loss: 0.0200, LR: 0.000500


Epoch 9/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.18it/s]
Epoch 9/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 28.79it/s]


Epoch 9/15, Train Loss: 0.0555, Val Loss: 0.0199, LR: 0.000500
New best model saved with validation loss: 0.0199


Epoch 10/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 57.21it/s]
Epoch 10/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.94it/s]


Epoch 10/15, Train Loss: 0.0535, Val Loss: 0.0202, LR: 0.000500


Epoch 11/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 59.66it/s]
Epoch 11/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.19it/s]


Epoch 11/15, Train Loss: 0.0517, Val Loss: 0.0202, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 59.83it/s]
Epoch 12/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.44it/s]


Epoch 12/15, Train Loss: 0.0500, Val Loss: 0.0203, LR: 0.000250


Epoch 13/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 50.75it/s]
Epoch 13/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.12it/s]


Epoch 13/15, Train Loss: 0.0459, Val Loss: 0.0203, LR: 0.000250


Epoch 14/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 59.88it/s]
Epoch 14/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.41it/s]


Epoch 14/15, Train Loss: 0.0446, Val Loss: 0.0203, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 63.23it/s]
Epoch 15/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.63it/s]


Epoch 15/15, Train Loss: 0.0435, Val Loss: 0.0204, LR: 0.000125

Fold 2/5


Epoch 1/15 (Train): 100%|██████████| 399/399 [00:05<00:00, 66.57it/s]
Epoch 1/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.55it/s]


Epoch 1/15, Train Loss: 0.0940, Val Loss: 0.0218, LR: 0.001000
New best model saved with validation loss: 0.0218


Epoch 2/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 64.66it/s]
Epoch 2/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.14it/s]


Epoch 2/15, Train Loss: 0.0806, Val Loss: 0.0210, LR: 0.001000
New best model saved with validation loss: 0.0210


Epoch 3/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 56.24it/s]
Epoch 3/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.48it/s]


Epoch 3/15, Train Loss: 0.0764, Val Loss: 0.0204, LR: 0.001000
New best model saved with validation loss: 0.0204


Epoch 4/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 64.39it/s]
Epoch 4/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.26it/s]


Epoch 4/15, Train Loss: 0.0732, Val Loss: 0.0204, LR: 0.001000


Epoch 5/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.44it/s]
Epoch 5/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.38it/s]


Epoch 5/15, Train Loss: 0.0701, Val Loss: 0.0203, LR: 0.001000
New best model saved with validation loss: 0.0203


Epoch 6/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.75it/s]
Epoch 6/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.78it/s]


Epoch 6/15, Train Loss: 0.0673, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 7/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.84it/s]
Epoch 7/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.73it/s]


Epoch 7/15, Train Loss: 0.0645, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 8/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 52.04it/s]
Epoch 8/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.96it/s]


Epoch 8/15, Train Loss: 0.0623, Val Loss: 0.0201, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.48it/s]
Epoch 9/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.61it/s]


Epoch 9/15, Train Loss: 0.0598, Val Loss: 0.0201, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.90it/s]
Epoch 10/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.07it/s]


Epoch 10/15, Train Loss: 0.0576, Val Loss: 0.0203, LR: 0.000500


Epoch 11/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.18it/s]
Epoch 11/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.97it/s]


Epoch 11/15, Train Loss: 0.0513, Val Loss: 0.0201, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.95it/s]
Epoch 12/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 28.90it/s]


Epoch 12/15, Train Loss: 0.0491, Val Loss: 0.0202, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 65.82it/s]
Epoch 13/15 (Val): 100%|██████████| 100/100 [00:04<00:00, 23.70it/s]


Epoch 13/15, Train Loss: 0.0475, Val Loss: 0.0204, LR: 0.000250


Epoch 14/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 58.57it/s]
Epoch 14/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.12it/s]


Epoch 14/15, Train Loss: 0.0435, Val Loss: 0.0203, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.08it/s]
Epoch 15/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 28.25it/s]


Epoch 15/15, Train Loss: 0.0422, Val Loss: 0.0204, LR: 0.000250

Fold 3/5


Epoch 1/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 56.48it/s]
Epoch 1/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.53it/s]


Epoch 1/15, Train Loss: 0.0954, Val Loss: 0.0212, LR: 0.001000
New best model saved with validation loss: 0.0212


Epoch 2/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.88it/s]
Epoch 2/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.37it/s]


Epoch 2/15, Train Loss: 0.0803, Val Loss: 0.0205, LR: 0.001000
New best model saved with validation loss: 0.0205


Epoch 3/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.09it/s]
Epoch 3/15 (Val): 100%|██████████| 100/100 [00:04<00:00, 23.58it/s]


Epoch 3/15, Train Loss: 0.0760, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 4/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.23it/s]
Epoch 4/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 28.90it/s]


Epoch 4/15, Train Loss: 0.0728, Val Loss: 0.0201, LR: 0.001000


Epoch 5/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.20it/s]
Epoch 5/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.84it/s]


Epoch 5/15, Train Loss: 0.0696, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 6/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.83it/s]
Epoch 6/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.05it/s]


Epoch 6/15, Train Loss: 0.0672, Val Loss: 0.0201, LR: 0.001000


Epoch 7/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.95it/s]
Epoch 7/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.41it/s]


Epoch 7/15, Train Loss: 0.0643, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 8/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 57.88it/s]
Epoch 8/15 (Val): 100%|██████████| 100/100 [00:04<00:00, 23.13it/s]


Epoch 8/15, Train Loss: 0.0619, Val Loss: 0.0200, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.63it/s]
Epoch 9/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.79it/s]


Epoch 9/15, Train Loss: 0.0594, Val Loss: 0.0200, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.47it/s]
Epoch 10/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.68it/s]


Epoch 10/15, Train Loss: 0.0571, Val Loss: 0.0202, LR: 0.000500


Epoch 11/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 59.86it/s]
Epoch 11/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.49it/s]


Epoch 11/15, Train Loss: 0.0511, Val Loss: 0.0200, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.31it/s]
Epoch 12/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.47it/s]


Epoch 12/15, Train Loss: 0.0488, Val Loss: 0.0202, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.63it/s]
Epoch 13/15 (Val): 100%|██████████| 100/100 [00:04<00:00, 24.14it/s]


Epoch 13/15, Train Loss: 0.0471, Val Loss: 0.0202, LR: 0.000250


Epoch 14/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 65.52it/s]
Epoch 14/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.90it/s]


Epoch 14/15, Train Loss: 0.0433, Val Loss: 0.0202, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.55it/s]
Epoch 15/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.23it/s]


Epoch 15/15, Train Loss: 0.0420, Val Loss: 0.0203, LR: 0.000250

Fold 4/5


Epoch 1/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.88it/s]
Epoch 1/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.60it/s]


Epoch 1/15, Train Loss: 0.0964, Val Loss: 0.0213, LR: 0.001000
New best model saved with validation loss: 0.0213


Epoch 2/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.94it/s]
Epoch 2/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.57it/s]


Epoch 2/15, Train Loss: 0.0803, Val Loss: 0.0211, LR: 0.001000
New best model saved with validation loss: 0.0211


Epoch 3/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.28it/s]
Epoch 3/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.34it/s]


Epoch 3/15, Train Loss: 0.0761, Val Loss: 0.0204, LR: 0.001000
New best model saved with validation loss: 0.0204


Epoch 4/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 54.08it/s]
Epoch 4/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.60it/s]


Epoch 4/15, Train Loss: 0.0726, Val Loss: 0.0201, LR: 0.001000
New best model saved with validation loss: 0.0201


Epoch 5/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.03it/s]
Epoch 5/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.99it/s]


Epoch 5/15, Train Loss: 0.0696, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 6/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 57.38it/s]
Epoch 6/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.82it/s]


Epoch 6/15, Train Loss: 0.0668, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 7/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 57.71it/s]
Epoch 7/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.53it/s]


Epoch 7/15, Train Loss: 0.0643, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 8/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.74it/s]
Epoch 8/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.83it/s]


Epoch 8/15, Train Loss: 0.0618, Val Loss: 0.0204, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 54.10it/s]
Epoch 9/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.36it/s]


Epoch 9/15, Train Loss: 0.0594, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 10/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 57.86it/s]
Epoch 10/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.53it/s]


Epoch 10/15, Train Loss: 0.0569, Val Loss: 0.0201, LR: 0.001000


Epoch 11/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.47it/s]
Epoch 11/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.32it/s]


Epoch 11/15, Train Loss: 0.0550, Val Loss: 0.0200, LR: 0.001000


Epoch 12/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 59.65it/s]
Epoch 12/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.77it/s]


Epoch 12/15, Train Loss: 0.0529, Val Loss: 0.0203, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.60it/s]
Epoch 13/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 29.58it/s]


Epoch 13/15, Train Loss: 0.0467, Val Loss: 0.0201, LR: 0.000500


Epoch 14/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 51.86it/s]
Epoch 14/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.37it/s]


Epoch 14/15, Train Loss: 0.0445, Val Loss: 0.0203, LR: 0.000500


Epoch 15/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.96it/s]
Epoch 15/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.98it/s]


Epoch 15/15, Train Loss: 0.0429, Val Loss: 0.0202, LR: 0.000250

Fold 5/5


Epoch 1/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.14it/s]
Epoch 1/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.93it/s]


Epoch 1/15, Train Loss: 0.0962, Val Loss: 0.0210, LR: 0.001000
New best model saved with validation loss: 0.0210


Epoch 2/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 65.07it/s]
Epoch 2/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.09it/s]


Epoch 2/15, Train Loss: 0.0816, Val Loss: 0.0205, LR: 0.001000
New best model saved with validation loss: 0.0205


Epoch 3/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.71it/s]
Epoch 3/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.93it/s]


Epoch 3/15, Train Loss: 0.0770, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 4/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 55.18it/s]
Epoch 4/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.00it/s]


Epoch 4/15, Train Loss: 0.0734, Val Loss: 0.0202, LR: 0.001000


Epoch 5/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 64.38it/s]
Epoch 5/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.47it/s]


Epoch 5/15, Train Loss: 0.0703, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 6/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 57.56it/s]
Epoch 6/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.50it/s]


Epoch 6/15, Train Loss: 0.0675, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 7/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.93it/s]
Epoch 7/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.41it/s]


Epoch 7/15, Train Loss: 0.0645, Val Loss: 0.0199, LR: 0.001000


Epoch 8/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.19it/s]
Epoch 8/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 30.95it/s]


Epoch 8/15, Train Loss: 0.0619, Val Loss: 0.0200, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 399/399 [00:07<00:00, 54.23it/s]
Epoch 9/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.22it/s]


Epoch 9/15, Train Loss: 0.0593, Val Loss: 0.0201, LR: 0.000500


Epoch 10/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 64.40it/s]
Epoch 10/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.37it/s]


Epoch 10/15, Train Loss: 0.0530, Val Loss: 0.0198, LR: 0.000500
New best model saved with validation loss: 0.0198


Epoch 11/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.74it/s]
Epoch 11/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.30it/s]


Epoch 11/15, Train Loss: 0.0506, Val Loss: 0.0198, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 61.90it/s]
Epoch 12/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.31it/s]


Epoch 12/15, Train Loss: 0.0487, Val Loss: 0.0201, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 62.69it/s]
Epoch 13/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.13it/s]


Epoch 13/15, Train Loss: 0.0472, Val Loss: 0.0202, LR: 0.000250


Epoch 14/15 (Train): 100%|██████████| 399/399 [00:06<00:00, 60.65it/s]
Epoch 14/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 32.01it/s]


Epoch 14/15, Train Loss: 0.0432, Val Loss: 0.0201, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 399/399 [00:08<00:00, 46.07it/s]
Epoch 15/15 (Val): 100%|██████████| 100/100 [00:03<00:00, 31.47it/s]


Epoch 15/15, Train Loss: 0.0419, Val Loss: 0.0202, LR: 0.000250

Average CV loss for output_dim=128, batch_size=256: 0.0199


--------------------------------------------------
Training with output_dim=128, batch_size=512
--------------------------------------------------

Fold 1/5


Epoch 1/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.05it/s]
Epoch 1/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 15.91it/s]


Epoch 1/15, Train Loss: 0.0989, Val Loss: 0.0213, LR: 0.001000
New best model saved with validation loss: 0.0213


Epoch 2/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 36.15it/s]
Epoch 2/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.80it/s]


Epoch 2/15, Train Loss: 0.0810, Val Loss: 0.0205, LR: 0.001000
New best model saved with validation loss: 0.0205


Epoch 3/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.44it/s]
Epoch 3/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.52it/s]


Epoch 3/15, Train Loss: 0.0766, Val Loss: 0.0202, LR: 0.001000
New best model saved with validation loss: 0.0202


Epoch 4/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.92it/s]
Epoch 4/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.16it/s]


Epoch 4/15, Train Loss: 0.0726, Val Loss: 0.0201, LR: 0.001000
New best model saved with validation loss: 0.0201


Epoch 5/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 34.02it/s]
Epoch 5/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.07it/s]


Epoch 5/15, Train Loss: 0.0699, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 6/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 37.84it/s]
Epoch 6/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.18it/s]


Epoch 6/15, Train Loss: 0.0672, Val Loss: 0.0198, LR: 0.001000
New best model saved with validation loss: 0.0198


Epoch 7/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.04it/s]
Epoch 7/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.19it/s]


Epoch 7/15, Train Loss: 0.0644, Val Loss: 0.0199, LR: 0.001000


Epoch 8/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 43.74it/s]
Epoch 8/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.67it/s]


Epoch 8/15, Train Loss: 0.0619, Val Loss: 0.0198, LR: 0.001000
New best model saved with validation loss: 0.0198


Epoch 9/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.88it/s]
Epoch 9/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.80it/s]


Epoch 9/15, Train Loss: 0.0598, Val Loss: 0.0200, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 34.05it/s]
Epoch 10/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.28it/s]


Epoch 10/15, Train Loss: 0.0571, Val Loss: 0.0202, LR: 0.001000


Epoch 11/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.34it/s]
Epoch 11/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.88it/s]


Epoch 11/15, Train Loss: 0.0551, Val Loss: 0.0201, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.37it/s]
Epoch 12/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.97it/s]


Epoch 12/15, Train Loss: 0.0498, Val Loss: 0.0198, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.84it/s]
Epoch 13/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.15it/s]


Epoch 13/15, Train Loss: 0.0478, Val Loss: 0.0200, LR: 0.000500


Epoch 14/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.34it/s]
Epoch 14/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.03it/s]


Epoch 14/15, Train Loss: 0.0463, Val Loss: 0.0201, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 200/200 [00:06<00:00, 32.20it/s]
Epoch 15/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.41it/s]


Epoch 15/15, Train Loss: 0.0431, Val Loss: 0.0201, LR: 0.000250

Fold 2/5


Epoch 1/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 37.07it/s]
Epoch 1/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.06it/s]


Epoch 1/15, Train Loss: 0.0986, Val Loss: 0.0215, LR: 0.001000
New best model saved with validation loss: 0.0215


Epoch 2/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.37it/s]
Epoch 2/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.36it/s]


Epoch 2/15, Train Loss: 0.0813, Val Loss: 0.0206, LR: 0.001000
New best model saved with validation loss: 0.0206


Epoch 3/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 36.75it/s]
Epoch 3/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.26it/s]


Epoch 3/15, Train Loss: 0.0759, Val Loss: 0.0203, LR: 0.001000
New best model saved with validation loss: 0.0203


Epoch 4/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.78it/s]
Epoch 4/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.79it/s]


Epoch 4/15, Train Loss: 0.0724, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 5/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 33.42it/s]
Epoch 5/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.08it/s]


Epoch 5/15, Train Loss: 0.0696, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 6/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.78it/s]
Epoch 6/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.99it/s]


Epoch 6/15, Train Loss: 0.0663, Val Loss: 0.0199, LR: 0.001000


Epoch 7/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.31it/s]
Epoch 7/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.26it/s]


Epoch 7/15, Train Loss: 0.0639, Val Loss: 0.0198, LR: 0.001000
New best model saved with validation loss: 0.0198


Epoch 8/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.99it/s]
Epoch 8/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.96it/s]


Epoch 8/15, Train Loss: 0.0614, Val Loss: 0.0199, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.34it/s]
Epoch 9/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.30it/s]


Epoch 9/15, Train Loss: 0.0592, Val Loss: 0.0199, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.76it/s]
Epoch 10/15 (Val): 100%|██████████| 50/50 [00:05<00:00,  8.60it/s]


Epoch 10/15, Train Loss: 0.0571, Val Loss: 0.0200, LR: 0.000500


Epoch 11/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 37.43it/s]
Epoch 11/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.85it/s]


Epoch 11/15, Train Loss: 0.0515, Val Loss: 0.0199, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.59it/s]
Epoch 12/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.20it/s]


Epoch 12/15, Train Loss: 0.0495, Val Loss: 0.0200, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.48it/s]
Epoch 13/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.32it/s]


Epoch 13/15, Train Loss: 0.0479, Val Loss: 0.0201, LR: 0.000250


Epoch 14/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.32it/s]
Epoch 14/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.58it/s]


Epoch 14/15, Train Loss: 0.0446, Val Loss: 0.0201, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.74it/s]
Epoch 15/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 12.65it/s]


Epoch 15/15, Train Loss: 0.0435, Val Loss: 0.0202, LR: 0.000250

Fold 3/5


Epoch 1/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.54it/s]
Epoch 1/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.71it/s]


Epoch 1/15, Train Loss: 0.1000, Val Loss: 0.0219, LR: 0.001000
New best model saved with validation loss: 0.0219


Epoch 2/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.13it/s]
Epoch 2/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.64it/s]


Epoch 2/15, Train Loss: 0.0821, Val Loss: 0.0209, LR: 0.001000
New best model saved with validation loss: 0.0209


Epoch 3/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.26it/s]
Epoch 3/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.43it/s]


Epoch 3/15, Train Loss: 0.0772, Val Loss: 0.0205, LR: 0.001000
New best model saved with validation loss: 0.0205


Epoch 4/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.27it/s]
Epoch 4/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.80it/s]


Epoch 4/15, Train Loss: 0.0738, Val Loss: 0.0201, LR: 0.001000
New best model saved with validation loss: 0.0201


Epoch 5/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.14it/s]
Epoch 5/15 (Val): 100%|██████████| 50/50 [00:04<00:00, 12.26it/s]


Epoch 5/15, Train Loss: 0.0708, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 6/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.33it/s]
Epoch 6/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.14it/s]


Epoch 6/15, Train Loss: 0.0677, Val Loss: 0.0200, LR: 0.001000


Epoch 7/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.89it/s]
Epoch 7/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.90it/s]


Epoch 7/15, Train Loss: 0.0652, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 8/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.39it/s]
Epoch 8/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.90it/s]


Epoch 8/15, Train Loss: 0.0626, Val Loss: 0.0200, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.84it/s]
Epoch 9/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.94it/s]


Epoch 9/15, Train Loss: 0.0604, Val Loss: 0.0200, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.91it/s]
Epoch 10/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.98it/s]


Epoch 10/15, Train Loss: 0.0580, Val Loss: 0.0200, LR: 0.000500


Epoch 11/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 33.92it/s]
Epoch 11/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.75it/s]


Epoch 11/15, Train Loss: 0.0523, Val Loss: 0.0201, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.57it/s]
Epoch 12/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.13it/s]


Epoch 12/15, Train Loss: 0.0501, Val Loss: 0.0199, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.21it/s]
Epoch 13/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.44it/s]


Epoch 13/15, Train Loss: 0.0486, Val Loss: 0.0201, LR: 0.000250


Epoch 14/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.06it/s]
Epoch 14/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 15.99it/s]


Epoch 14/15, Train Loss: 0.0451, Val Loss: 0.0201, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.44it/s]
Epoch 15/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.93it/s]


Epoch 15/15, Train Loss: 0.0439, Val Loss: 0.0202, LR: 0.000250

Fold 4/5


Epoch 1/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 35.76it/s]
Epoch 1/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.55it/s]


Epoch 1/15, Train Loss: 0.0973, Val Loss: 0.0215, LR: 0.001000
New best model saved with validation loss: 0.0215


Epoch 2/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 44.83it/s]
Epoch 2/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.95it/s]


Epoch 2/15, Train Loss: 0.0814, Val Loss: 0.0206, LR: 0.001000
New best model saved with validation loss: 0.0206


Epoch 3/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 44.47it/s]
Epoch 3/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.94it/s]


Epoch 3/15, Train Loss: 0.0768, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 4/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 43.04it/s]
Epoch 4/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.24it/s]


Epoch 4/15, Train Loss: 0.0730, Val Loss: 0.0200, LR: 0.001000
New best model saved with validation loss: 0.0200


Epoch 5/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.90it/s]
Epoch 5/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.46it/s]


Epoch 5/15, Train Loss: 0.0704, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 6/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 35.73it/s]
Epoch 6/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.69it/s]


Epoch 6/15, Train Loss: 0.0677, Val Loss: 0.0199, LR: 0.001000
New best model saved with validation loss: 0.0199


Epoch 7/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.49it/s]
Epoch 7/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.86it/s]


Epoch 7/15, Train Loss: 0.0650, Val Loss: 0.0198, LR: 0.001000
New best model saved with validation loss: 0.0198


Epoch 8/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 43.82it/s]
Epoch 8/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.93it/s]


Epoch 8/15, Train Loss: 0.0628, Val Loss: 0.0197, LR: 0.001000
New best model saved with validation loss: 0.0197


Epoch 9/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.85it/s]
Epoch 9/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.35it/s]


Epoch 9/15, Train Loss: 0.0603, Val Loss: 0.0199, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.91it/s]
Epoch 10/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.63it/s]


Epoch 10/15, Train Loss: 0.0578, Val Loss: 0.0197, LR: 0.001000


Epoch 11/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 34.99it/s]
Epoch 11/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.97it/s]


Epoch 11/15, Train Loss: 0.0561, Val Loss: 0.0200, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 43.84it/s]
Epoch 12/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.79it/s]


Epoch 12/15, Train Loss: 0.0502, Val Loss: 0.0196, LR: 0.000500
New best model saved with validation loss: 0.0196


Epoch 13/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.55it/s]
Epoch 13/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.19it/s]


Epoch 13/15, Train Loss: 0.0481, Val Loss: 0.0197, LR: 0.000500


Epoch 14/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 44.94it/s]
Epoch 14/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.02it/s]


Epoch 14/15, Train Loss: 0.0466, Val Loss: 0.0198, LR: 0.000500


Epoch 15/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 44.69it/s]
Epoch 15/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.95it/s]


Epoch 15/15, Train Loss: 0.0453, Val Loss: 0.0199, LR: 0.000250

Fold 5/5


Epoch 1/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 35.87it/s]
Epoch 1/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]


Epoch 1/15, Train Loss: 0.0999, Val Loss: 0.0211, LR: 0.001000
New best model saved with validation loss: 0.0211


Epoch 2/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.98it/s]
Epoch 2/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.99it/s]


Epoch 2/15, Train Loss: 0.0812, Val Loss: 0.0205, LR: 0.001000
New best model saved with validation loss: 0.0205


Epoch 3/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 42.59it/s]
Epoch 3/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.44it/s]


Epoch 3/15, Train Loss: 0.0762, Val Loss: 0.0198, LR: 0.001000
New best model saved with validation loss: 0.0198


Epoch 4/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 44.63it/s]
Epoch 4/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 17.17it/s]


Epoch 4/15, Train Loss: 0.0728, Val Loss: 0.0197, LR: 0.001000
New best model saved with validation loss: 0.0197


Epoch 5/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 45.12it/s]
Epoch 5/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 18.21it/s]


Epoch 5/15, Train Loss: 0.0695, Val Loss: 0.0196, LR: 0.001000
New best model saved with validation loss: 0.0196


Epoch 6/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 43.43it/s]
Epoch 6/15 (Val): 100%|██████████| 50/50 [00:04<00:00, 10.28it/s]


Epoch 6/15, Train Loss: 0.0669, Val Loss: 0.0196, LR: 0.001000


Epoch 7/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 40.58it/s]
Epoch 7/15 (Val): 100%|██████████| 50/50 [00:02<00:00, 16.67it/s]


Epoch 7/15, Train Loss: 0.0639, Val Loss: 0.0197, LR: 0.001000


Epoch 8/15 (Train): 100%|██████████| 200/200 [00:04<00:00, 41.13it/s]
Epoch 8/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.35it/s]


Epoch 8/15, Train Loss: 0.0614, Val Loss: 0.0195, LR: 0.001000
New best model saved with validation loss: 0.0195


Epoch 9/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.25it/s]
Epoch 9/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 15.85it/s]


Epoch 9/15, Train Loss: 0.0590, Val Loss: 0.0198, LR: 0.001000


Epoch 10/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 39.21it/s]
Epoch 10/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.44it/s]


Epoch 10/15, Train Loss: 0.0569, Val Loss: 0.0196, LR: 0.001000


Epoch 11/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.69it/s]
Epoch 11/15 (Val): 100%|██████████| 50/50 [00:05<00:00,  8.63it/s]


Epoch 11/15, Train Loss: 0.0550, Val Loss: 0.0199, LR: 0.000500


Epoch 12/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.79it/s]
Epoch 12/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.36it/s]


Epoch 12/15, Train Loss: 0.0494, Val Loss: 0.0197, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 38.19it/s]
Epoch 13/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.04it/s]


Epoch 13/15, Train Loss: 0.0474, Val Loss: 0.0198, LR: 0.000500


Epoch 14/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 37.62it/s]
Epoch 14/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.12it/s]


Epoch 14/15, Train Loss: 0.0461, Val Loss: 0.0198, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 200/200 [00:05<00:00, 37.36it/s]
Epoch 15/15 (Val): 100%|██████████| 50/50 [00:03<00:00, 16.24it/s]


Epoch 15/15, Train Loss: 0.0425, Val Loss: 0.0200, LR: 0.000250

Average CV loss for output_dim=128, batch_size=512: 0.0197


Best hyperparameters:
Output dimension: 128
Batch size: 512
Average CV loss: 0.0197


Training final model with best hyperparameters...


Epoch 1/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.26it/s]
Epoch 1/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.47it/s]


Epoch 1/15, Train Loss: 0.1181, Val Loss: 0.1063, LR: 0.001000
New best model saved with validation loss: 0.1063


Epoch 2/15 (Train): 100%|██████████| 250/250 [00:13<00:00, 19.14it/s]
Epoch 2/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 17.08it/s]


Epoch 2/15, Train Loss: 0.1003, Val Loss: 0.1024, LR: 0.001000
New best model saved with validation loss: 0.1024


Epoch 3/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.17it/s]
Epoch 3/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.43it/s]


Epoch 3/15, Train Loss: 0.0950, Val Loss: 0.1022, LR: 0.001000
New best model saved with validation loss: 0.1022


Epoch 4/15 (Train): 100%|██████████| 250/250 [00:13<00:00, 18.66it/s]
Epoch 4/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.11it/s]


Epoch 4/15, Train Loss: 0.0911, Val Loss: 0.1005, LR: 0.001000
New best model saved with validation loss: 0.1005


Epoch 5/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.27it/s]
Epoch 5/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.51it/s]


Epoch 5/15, Train Loss: 0.0874, Val Loss: 0.0997, LR: 0.001000
New best model saved with validation loss: 0.0997


Epoch 6/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.22it/s]
Epoch 6/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.17it/s]


Epoch 6/15, Train Loss: 0.0843, Val Loss: 0.0995, LR: 0.001000
New best model saved with validation loss: 0.0995


Epoch 7/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.82it/s]
Epoch 7/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.20it/s]


Epoch 7/15, Train Loss: 0.0815, Val Loss: 0.1005, LR: 0.001000


Epoch 8/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.21it/s]
Epoch 8/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.20it/s]


Epoch 8/15, Train Loss: 0.0788, Val Loss: 0.0999, LR: 0.001000


Epoch 9/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.98it/s]
Epoch 9/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.23it/s]


Epoch 9/15, Train Loss: 0.0765, Val Loss: 0.0996, LR: 0.000500


Epoch 10/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.85it/s]
Epoch 10/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.17it/s]


Epoch 10/15, Train Loss: 0.0696, Val Loss: 0.0992, LR: 0.000500
New best model saved with validation loss: 0.0992


Epoch 11/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.02it/s]
Epoch 11/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.19it/s]


Epoch 11/15, Train Loss: 0.0669, Val Loss: 0.0989, LR: 0.000500
New best model saved with validation loss: 0.0989


Epoch 12/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.06it/s]
Epoch 12/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.41it/s]


Epoch 12/15, Train Loss: 0.0652, Val Loss: 0.0992, LR: 0.000500


Epoch 13/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.97it/s]
Epoch 13/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.33it/s]


Epoch 13/15, Train Loss: 0.0635, Val Loss: 0.0995, LR: 0.000500


Epoch 14/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 20.17it/s]
Epoch 14/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.30it/s]


Epoch 14/15, Train Loss: 0.0620, Val Loss: 0.0998, LR: 0.000250


Epoch 15/15 (Train): 100%|██████████| 250/250 [00:12<00:00, 19.90it/s]
Epoch 15/15 (Val): 100%|██████████| 63/63 [00:03<00:00, 18.31it/s]

Epoch 15/15, Train Loss: 0.0578, Val Loss: 0.0999, LR: 0.000250





0,1
avg_cv_loss,█▁
batch_size,▁█
epoch,▁▁▄▇█▃▅▇▇█▅▅▆█▁▄▅▅▇▂▅▅▆▁▁▇▁▂▄▇▂▃▄▅▂▅▇▇▅▇
learning_rate,█▃███▃▁███████▃█▃▃▃▁▃▁███▃▁▁███▃▃██▃▁█▃▁
output_dim,▁▁
train_loss,▅▅▃▂▂▆▅▅▄▃▇▆▄▂▂▁▄▄▃▇▄▃▃█▅▅▄▁█▅▂█▃▂▂▅▄▄▇▅
val_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██

0,1
avg_cv_loss,0.01975
batch_size,512.0
epoch,15.0
learning_rate,0.00025
output_dim,128.0
train_loss,0.05783
val_loss,0.09991


Best output dimension: 128
Best batch size: 512
Best validation loss: 0.0197


In [11]:
# Code to upload final model to wandb
import wandb
import os
import time
from dotenv import load_dotenv

# Load your API key from config.txt
def load_api_key_from_config(config_path="config.txt"):
    """Read an API key from the first line of a config file.

    The first line may be either the bare key or a ``NAME=value`` pair. For a
    pair, everything after the *first* ``=`` is returned, so a key that itself
    contains ``=`` characters is preserved intact.

    Args:
        config_path: Path of the config file to read. Defaults to
            ``"config.txt"`` in the current working directory.

    Returns:
        The key with surrounding whitespace removed, or ``None`` when the
        file does not exist (a message is printed in that case). An empty
        first line yields an empty string.
    """
    try:
        with open(config_path, "r") as f:
            first_line = f.readline().strip()
    except FileNotFoundError:
        print(f"Config file not found at {config_path}")
        return None

    # Split on the first "=" only: splitting on every "=" and taking element 1
    # would truncate any key that contains "=" itself.
    _, separator, value = first_line.partition("=")
    if separator:
        return value.strip()
    return first_line

# Credentials are only loaded when no W&B run is currently active.
if wandb.run is None:
    api_key = load_api_key_from_config()
    if not api_key:
        # Covers both a missing config file (None) and an empty key ("").
        print("Failed to load API key")
    else:
        # Expose the key through the environment before calling wandb.login().
        os.environ["WANDB_API_KEY"] = api_key
        wandb.login()
        print("Successfully logged in to Weights & Biases")

# Reuse the active run when there is one; otherwise open a dedicated summary
# run. Ownership is recorded explicitly so that cleanup below does not have to
# guess from the run's name.
created_run = wandb.run is None
if created_run:
    run = wandb.init(
        project="twin-tower-model",
        name="final-model-summary",  # marks this run as a summary/upload run
        config={
            "output_dim": best_params["output_dim"],
            "batch_size": best_params["batch_size"],
            "architecture": "Twin Tower Network",
            "dataset": "MS MARCO"
        }
    )
else:
    run = wandb.run
    # Attach the extra metadata to the run that is already in progress.
    run.config.update({
        "architecture": "Twin Tower Network",
        "dataset": "MS MARCO"
    })

final_model_path = "checkpoints/final_model/final_model.pt"

upload_succeeded = False
try:
    # Timestamp the artifact name so repeated uploads do not conflict.
    timestamp = int(time.time())
    artifact_name = f"twin-tower-final-model-{timestamp}"

    model_artifact = wandb.Artifact(
        name=artifact_name,
        type="model",
        description="Twin Tower model trained on full training data with optimal hyperparameters"
    )

    model_artifact.add_file(final_model_path)
    run.log_artifact(model_artifact)

    print(f"Final model uploaded to Weights & Biases project: {run.project}")
    upload_succeeded = True
finally:
    # Only finish a run that this cell created, and do it even when the upload
    # raised so that a half-initialised run does not leak into later cells.
    if created_run:
        if upload_succeeded:
            wandb.finish()
        else:
            wandb.finish(exit_code=1)  # record the run as failed

Final model uploaded to Weights & Biases project: twin-tower-model
