In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Cell 1: Import required libraries
import os
import yaml
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
from dotenv import load_dotenv
warnings.filterwarnings('ignore')

# load env variables
load_dotenv('../../../.env')

# Import your model and data loading components
from dataloader.dataset_wrapper import create_wrapper_from_dataframe

print("Libraries imported successfully!")

Libraries imported successfully!


In [3]:
# First install ClearML if not already installed: pip install clearml
import clearml
from clearml import Task, Logger

# Initialize ClearML Task
task = Task.init(project_name='CSMP_thesis_project', task_name='CSMP_traning_experiment_5', reuse_last_task_id=False)
logger = Logger.current_logger()

ClearML Task: created new task id=2dba1462897e474e98fad19f45e54c4d
ClearML results page: https://app.clear.ml/projects/0fec81950d384f0294d2c713df3887db/experiments/2dba1462897e474e98fad19f45e54c4d/output/log


ClearML Monitor: GPU monitoring failed getting GPU reading, switching off GPU monitoring
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start


In [4]:
# Cell 2: Configuration and paths setup
CONFIG_PATH = "../../../configs/config.yaml"
TRAIN_CSV_PATH = "../../../data/traning_and_validation/train_deduplicated.csv"

In [5]:
# Device selection with automatic CUDA device selection based on free memory
if torch.backends.mps.is_available():
    DEVICE = 'mps'
    print("Using MPS (Metal Performance Shaders) for GPU acceleration")
elif torch.cuda.is_available():
    # Find CUDA device with most free memory
    max_free_memory = 0
    best_device = 0
    
    print("Available CUDA devices:")
    for i in range(torch.cuda.device_count()):
        torch.cuda.set_device(i)
        torch.cuda.empty_cache()  # Clear cache to get accurate memory info
        
        total_memory = torch.cuda.get_device_properties(i).total_memory
        allocated_memory = torch.cuda.memory_allocated(i)
        free_memory = total_memory - allocated_memory
        
        print(f"  Device {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Total memory: {total_memory / 1024**3:.2f} GB")
        print(f"    Allocated memory: {allocated_memory / 1024**3:.2f} GB")
        print(f"    Free memory: {free_memory / 1024**3:.2f} GB")
        
        if free_memory > max_free_memory:
            max_free_memory = free_memory
            best_device = i
    
    DEVICE = f'cuda:{best_device}'
    print(f"\nSelected device {best_device} with {max_free_memory / 1024**3:.2f} GB free memory")
    print("Using CUDA for GPU acceleration")
else:
    DEVICE = 'cpu'
    print("Using CPU")

print(f"Using device: {DEVICE}")
print(f"Validation data path: {TRAIN_CSV_PATH}")

Available CUDA devices:
  Device 0: NVIDIA A100 80GB PCIe
    Total memory: 79.25 GB
    Allocated memory: 0.00 GB
    Free memory: 79.25 GB
  Device 1: NVIDIA A100 80GB PCIe
    Total memory: 79.25 GB
    Allocated memory: 0.00 GB
    Free memory: 79.25 GB

Selected device 0 with 79.25 GB free memory
Using CUDA for GPU acceleration
Using device: cuda:0
Validation data path: ../../../data/traning_and_validation/train_deduplicated.csv


In [6]:
DEVICE = f'cuda:1'

In [7]:
# Cell 3: Load configuration
print("Loading configuration...")
config = yaml.load(open(CONFIG_PATH, "r"), Loader=yaml.FullLoader)

print("Configuration loaded:")
print(f"- Batch size: {config.get('batch_size', 64)}")
print(f"- Model config keys: {list(config.get('model', {}).keys())}")
print(f"- Loss config: {config.get('loss', {})}")

Loading configuration...
Configuration loaded:
- Batch size: 256
- Model config keys: []
- Loss config: {'temperature': 0.3, 'use_cosine_similarity': True, 'alpha_weight': 0.75}


In [8]:
task.connect(config)

{'batch_size': 256,
 'epochs': 100,
 'eval_every_n_epochs': 5,
 'log_every_n_steps': 10,
 'learning_rate': '1e-05',
 'weight_decay': '1e-6',
 'valid_size': 0.2,
 'fp16_precision': True,
 'truncation': True,
 'sample_size': 10000,
 'model_config': {'emb_dim': 256,
  'spec_embed_dim': 256,
  'embed_dim': 128,
  'feat_dim': 512,
  'num_layer': 5,
  'layers': 5,
  'drop_ratio': 0.3,
  'dropout': 0.1,
  'pool': 'mean'},
 'loss': {'temperature': 0.3,
  'use_cosine_similarity': True,
  'alpha_weight': 0.75}}

In [9]:
# Cell 4: Load and explore validation data
print("Loading train data...")
df_train = pd.read_csv(TRAIN_CSV_PATH)

print(f"Validation dataset shape: {df_train.shape}")
print(f"Columns: {list(df_train.columns)}")
print(f"Sample data:")
df_train.head()

Loading train data...
Validation dataset shape: (798444, 10)
Columns: ['peaks_json', 'ion_source', 'compound_source', 'instrument', 'adduct', 'precursor_mz', 'smiles', 'inchikey', 'ion_mode', 'molecular_formula']
Sample data:


Unnamed: 0,peaks_json,ion_source,compound_source,instrument,adduct,precursor_mz,smiles,inchikey,ion_mode,molecular_formula
0,"[[42.014248, 0.10199999999999998], [42.26601, ...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
1,"[[49.01717, 0.155], [49.020023, 0.253], [67.05...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
2,"[[49.017338, 0.242], [49.020237, 0.181], [67.0...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
3,"[[49.01701, 0.144], [49.019947, 0.244], [139.0...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
4,"[[49.017166, 0.155], [49.020008, 0.253], [139....",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3


In [19]:
df_train_sample = df_train.sample(n=config.get('sample_size'),random_state=42).reset_index(drop=True)

In [20]:
# Cell 6: Prepare validation data loader
print("Preparing data loaders")

# Create data wrapper from DataFrame
wrapper = create_wrapper_from_dataframe(
    df=df_train_sample,
    batch_size=config.get('batch_size'),  
    num_workers=8,
    valid_size=config.get('valid_size'),  
    use_ddp=False,
    output_dir="../../../data/train_feature/",
    recompute=True
)

# Get the data loader
train_loader, val_loader = wrapper.get_data_loaders()

Preparing data loaders
Converting DataFrame to compatible files...
Processed 10000 valid spectra out of 10000 total entries.
Create data wrapper
Total spectra: 10000
Unique molecules: 7805
Training spectra: 7983 from 6244 molecules
Validation spectra: 2017 from 1561 molecules
✓ No molecule overlap between train and validation sets
calculating molecular graphs


  7%|▋         | 567/7983 [00:01<00:29, 253.46it/s]

SMILES [Cl-].O=C1C2=CC=C(O)C(=C2OC(=C1C=3C=CC=4OCCCOC4C3)C)C[NH+](C)C calculation failure


 14%|█▍        | 1141/7983 [00:02<00:11, 601.89it/s]

SMILES CCC1=C(C2=NC1=CC3=C(C4=C([N-]3)C(=C5[C@H]([C@@H](C(=N5)C=C6C(=C(C(=C2)[N-]6)C=C)C)C)CCC(=O)OC/C=C(\C)/CCC[C@H](C)CCC[C@H](C)CCCC(C)C)[C@H](C4=O)C(=O)OC)C)C=O.[Mg+2] calculation failure
SMILES [Na+].O=P([O-])(O)OCC1OC(N2C=NC=3C(=NC=NC32)N)C(O)C1O calculation failure


 21%|██        | 1662/7983 [00:03<00:09, 638.03it/s]

SMILES [K+].O=S(=O)([O-])ON=C(SC1OC(CO)C(O)C(O)C1O)CC=C.O calculation failure


 28%|██▊       | 2236/7983 [00:04<00:09, 627.45it/s]

SMILES [Cl-].O=C1C(=COC2=C1C=C(C(O)=C2C[NH+](C)C)CC)C=3C=CC=4OCCOC4C3 calculation failure


 43%|████▎     | 3446/7983 [00:06<00:07, 605.06it/s]

SMILES C1C(N(C2=C(N1)N=C(NC2=O)N)C=O)CNC3=CC=C(C=C3)C(=O)N[C@@H](CCC(=O)[O-])C(=O)[O-].[Ca+2] calculation failure
SMILES [Cl-].O=C1C2=CC=C(O)C(=C2OC(=C1C=3C=CC=4OCCCOC4C3)C)C[NH+](C)C calculation failure


 47%|████▋     | 3764/7983 [00:06<00:06, 625.10it/s]

SMILES [Cl-].OC=1C=C(O)C=2C=C(OC3OC(CO)C(O)C(O)C3O)C(=[O+]C2C1)C=4C=CC(O)=C(O)C4 calculation failure


 51%|█████     | 4081/7983 [00:07<00:06, 621.98it/s]

SMILES [K+].[K+].O=C([O-])C1OC(OC2C(OC(C(=O)[O-])C(O)C2O)OC3CCC4(C)C5C(=O)C=C6C7CC(C(=O)O)(C)CCC7(C)CCC6(C)C5(C)CCC4C3(C)C)C(O)C(O)C1O calculation failure
SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure


 63%|██████▎   | 5042/7983 [00:09<00:04, 615.32it/s]

SMILES [Na+].O=C([O-])C(CC)C1OC(C(=CC=CC2C=CC3CCCC3C2C(=O)C4=CC=CN4)CC)C(C)CC1 calculation failure


 66%|██████▌   | 5233/7983 [00:09<00:04, 624.61it/s]

SMILES [K+].O=S(=O)([O-])ON=C(SC1OC(CO)C(O)C(O)C1O)CC=C.O calculation failure
SMILES [Cl-].O=C1C2=CC=C(O)C(=C2OC(=C1C=3C=CC=4OCCCOC4C3)C)C[NH+](C)C calculation failure


 68%|██████▊   | 5424/7983 [00:09<00:04, 621.19it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure


 71%|███████   | 5682/7983 [00:10<00:03, 619.18it/s]

SMILES [Cl-].O=C(O)C=1C=CC=CC1C=2C=3C=CC(=CC3OC4=CC(C=CC42)=[N+](CC)CC)N(CC)CC calculation failure


 74%|███████▎  | 5877/7983 [00:10<00:03, 629.09it/s]

SMILES [Na+].O=C(CCCCCCCCCCC)CC(O)S(=O)(=O)[O-] calculation failure


 78%|███████▊  | 6194/7983 [00:10<00:02, 621.73it/s]

SMILES [Na+].O=C(CCCCCCCCCCC)CC(O)S(=O)(=O)[O-] calculation failure


 81%|████████  | 6447/7983 [00:11<00:02, 622.14it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure
SMILES [Na+].O=C(CCCCCCCCCCC)CC(O)S(=O)(=O)[O-] calculation failure


 86%|████████▋ | 6887/7983 [00:12<00:01, 607.59it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure


 90%|█████████ | 7198/7983 [00:12<00:01, 613.89it/s]

SMILES [Cl-].OC=1C=C(O)C=2C=C(OC3OC(CO)C(O)C(O)C3O)C(=[O+]C2C1)C=4C=CC(O)=C(O)C4 calculation failure


100%|██████████| 7983/7983 [00:13<00:00, 578.66it/s]


Calculated 6850 molecular graph-mass spectrometry pairs
calculating molecular graphs


 26%|██▌       | 516/2017 [00:00<00:02, 642.42it/s]

SMILES [Cl-].O=C1C=2C=C(C(O)=C(C2OC(=C1C=3C=CC=4OCCOC4C3)C)C[NH+](C)C)CCC calculation failure


 45%|████▌     | 908/2017 [00:01<00:01, 638.74it/s]

SMILES [Br-].O=C(OC1CC2C3OC3C(C1)[N+]2(C)CCCC)C(C=4C=CC=CC4)CO calculation failure


 77%|███████▋  | 1562/2017 [00:02<00:00, 635.21it/s]

SMILES [K+].O=C([O-])C12CCC(C(=C)C)C2C3CCC4C5(C)CCC(=O)C(C)(C)C5CCC4(C)C3(C)CC1 calculation failure


100%|██████████| 2017/2017 [00:03<00:00, 633.25it/s]

Calculated 1722 molecular graph-mass spectrometry pairs





In [21]:
from model import ModelCLR

# Initialize model architecture
model = ModelCLR(**config["model_config"]).to(DEVICE)

In [22]:
model

ModelCLR(
  (Smiles_model): SmilesModel(
    (x_embedding1): Embedding(119, 256)
    (x_embedding2): Embedding(4, 256)
    (x_embedding3): Embedding(8, 256)
    (x_embedding4): Embedding(6, 256)
    (x_embedding5): Embedding(5, 256)
    (gnns): ModuleList(
      (0-4): 5 x GINEConv()
    )
    (batch_norms): ModuleList(
      (0-4): 5 x BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (feat_lin): Linear(in_features=256, out_features=512, bias=True)
    (out_lin): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (MS_model): MSModel(
    (mz_embedder): FourierEmbedder()
    (input_compress): Linear(in_features=257, out_features=256, bias=True)
    (peak_attn_layers): ModuleList(
      (0-4): 5 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in

In [23]:
# Parameter Count
print("PARAMETER ANALYSIS:")
print("-" * 40)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Trainable Parameters: {trainable_params:,}")


PARAMETER ANALYSIS:
----------------------------------------
Trainable Parameters: 6,102,784


In [24]:
# Memory Usage (approximate)
print("MEMORY ANALYSIS:")
print("-" * 40)
param_size = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
model_size_mb = (param_size + buffer_size) / 1024 / 1024

print(f"Model Size: {model_size_mb:.2f} MB")
print(f"Parameter Memory: {param_size / 1024 / 1024:.2f} MB")
print(f"Buffer Memory: {buffer_size / 1024 / 1024:.2f} MB")


MEMORY ANALYSIS:
----------------------------------------
Model Size: 23.29 MB
Parameter Memory: 23.28 MB
Buffer Memory: 0.01 MB


In [25]:
from loss.nt_xent import NTXentLoss

# Initialize loss function
temperature = config.get('loss', {}).get('temperature', 0.1)
batch_size = config.get('batch_size', 512)
use_cosine_similarity = config.get('loss', {}).get('use_cosine_similarity', True)
alpha_weight = config.get('loss', {}).get('alpha_weight', 1.0)

criterion = NTXentLoss(
    device=DEVICE, 
    batch_size=batch_size, 
    temperature=temperature, 
    use_cosine_similarity=use_cosine_similarity, 
    alpha_weight=alpha_weight
)

In [26]:
# Cell 15: Training Setup and Optimizer
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

print("Setting up training components...")
OUTPUT_DIR = "../../../models/models_experiments/candidate_v1"

# Initialize optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=float(config.get('learning_rate', 1e-05)),
    weight_decay=float(config.get('weight_decay', 1e-05))
)

# Training configuration
epochs = config.get('epochs', 100)
eval_every_n_epochs = config.get('eval_every_n_epochs', 5)
log_every_n_steps = config.get('log_every_n_steps', 2)

# Add ReduceLROnPlateau scheduler for representation learning
scheduler = ReduceLROnPlateau(
    optimizer, 
    mode='min',           # minimize validation loss
    factor=0.5,           # reduce LR by half
    patience=3,             
    verbose=True,
    min_lr=1e-6,
    threshold=1e-4
)

# Create checkpoint directory
checkpoint_dir = os.path.join(OUTPUT_DIR, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

print(f"Training for {epochs} epochs")
print(f"Optimizer: AdamW with lr={config.get('learning_rate', 5e-6)}")
print(f"Checkpoint directory: {checkpoint_dir}")

Setting up training components...
Training for 100 epochs
Optimizer: AdamW with lr=1e-05
Checkpoint directory: ../../../models/models_experiments/candidate_v1/checkpoints


In [27]:
# Cell 16: Training and Evaluation Functions
def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    model.train()
    total_loss = 0
    num_batches = 0
    batch_losses = []
    
    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch}")
    
    for batch_idx, (graphs, mzs, intensities, num_peaks) in enumerate(progress_bar):
        
        # Move data to device
        graphs = graphs.to(device)
        mzs = mzs.to(device)
        intensities = intensities.to(device)
        num_peaks = num_peaks.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        mol_features, spec_features = model(graphs, mzs, intensities, num_peaks)
        loss = criterion(mol_features, spec_features)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Track loss
        batch_loss = loss.item()
        total_loss += batch_loss
        batch_losses.append(batch_loss)
        num_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Avg Loss': f'{total_loss/num_batches:.4f}'
        })
        
        # Log every n steps
        if batch_idx % log_every_n_steps == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss:.4f}")
    
    avg_loss = total_loss / num_batches
    return avg_loss, batch_losses

In [28]:
def evaluate_model(model, val_loader, criterion, device, epoch):
    model.eval()
    total_loss = 0
    num_batches = 0
    
    molecular_features_list = []
    spectral_features_list = []
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Evaluating Epoch {epoch}")
        
        for batch_idx, (graphs, mzs, intensities, num_peaks) in enumerate(progress_bar):
            # Move data to device
            graphs = graphs.to(device)
            mzs = mzs.to(device)
            intensities = intensities.to(device)
            num_peaks = num_peaks.to(device)
            
            # Forward pass
            mol_features, spec_features = model(graphs, mzs, intensities, num_peaks)
            loss = criterion(mol_features, spec_features)
            
            total_loss += loss.item()
            num_batches += 1
            
            # Normalize features before storing 
            mol_features_norm = torch.nn.functional.normalize(mol_features, p=2, dim=1)
            spec_features_norm = torch.nn.functional.normalize(spec_features, p=2, dim=1)
            
            # Store normalized features for retrieval metrics
            molecular_features_list.append(mol_features_norm.cpu().numpy())
            spectral_features_list.append(spec_features_norm.cpu().numpy())
            
            progress_bar.set_postfix({'Val Loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / num_batches
    
    # Compute retrieval metrics
    all_mol_features = np.vstack(molecular_features_list)
    all_spec_features = np.vstack(spectral_features_list)
        
    # Compute cosine similarities
    cosine_similarities = np.sum(all_mol_features * all_spec_features, axis=1)
    mean_similarity = np.mean(cosine_similarities)
        
    return avg_loss, mean_similarity

In [29]:
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, checkpoint_dir):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': config
    }
    
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")
    
    # Save best model
    best_checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
    if not os.path.exists(best_checkpoint_path):
        torch.save(checkpoint, best_checkpoint_path)
        print(f"Best model saved: {best_checkpoint_path}")
    else:
        best_checkpoint = torch.load(best_checkpoint_path)
        if val_loss < best_checkpoint['val_loss']:
            torch.save(checkpoint, best_checkpoint_path)
            print(f"New best model saved: {best_checkpoint_path}")

In [None]:
import time

# Training history for plotting
train_history = {
    'epochs': [],
    'train_losses': [],
    'val_losses': [],
    'val_similarities': [],
    'learning_rates': [],  
}

best_val_loss = float('inf')
start_time = time.time()

print("Starting training...")
print("=" * 60)

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    
    # Training phase
    train_loss, batch_losses = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE, epoch
    )    

    # Get current learning rate after scheduler step
    current_lr = optimizer.param_groups[0]['lr']
    
    # Log training loss every epoch
    logger.report_scalar("Loss", "Train", iteration=epoch, value=train_loss)
    
    # Validation phase (every n epochs)
    if epoch % eval_every_n_epochs == 0:
        val_loss, val_similarity = evaluate_model(
            model, val_loader, criterion, DEVICE, epoch
        )
        
    
        # # Step scheduler every epoch
        # scheduler.step(val_loss)
        
        # Log results
        epoch_time = time.time() - epoch_start_time
        total_time = time.time() - start_time
        
        # Log validation metrics to ClearML
        logger.report_scalar("Loss", "Validation", iteration=epoch, value=val_loss)
        logger.report_scalar("Similarity", "Cosine Similarity", iteration=epoch, value=val_similarity)
        
        print(f"\nEpoch {epoch}/{epochs}")
        print(f"Train Loss: {train_loss:.6f}")
        print(f"Val Loss: {val_loss:.6f}")
        print(f"Val mean similarity: {val_similarity:.4f}")
        print(f"Learning Rate: {current_lr:.2e}")
        print(f"Epoch Time: {epoch_time:.2f}s, Total Time: {total_time:.2f}s")
        print("-" * 60)
        
        # Store history
        train_history['epochs'].append(int(epoch))
        train_history['train_losses'].append(float(train_loss))
        train_history['val_losses'].append(float(val_loss))
        train_history['val_similarities'].append(float(val_similarity))
        train_history['learning_rates'].append(float(current_lr))
        
        # Save checkpoint
        # save_checkpoint(
        #     model, optimizer, epoch, 
        #     train_loss, val_loss, checkpoint_dir
        # )
        
        if epoch % (eval_every_n_epochs * 2) == 0 and len(train_history['epochs']) > 1:
            import matplotlib.pyplot as plt
            
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
            
            # Loss plot
            epochs_list = train_history['epochs']
            ax1.plot(epochs_list, train_history['train_losses'], 'b-', label='Train Loss', linewidth=2)
            ax1.plot(epochs_list, train_history['val_losses'], 'r-', label='Val Loss', linewidth=2)
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            
            # Similarity plot
            ax2.plot(epochs_list, train_history['val_similarities'], 'g-', linewidth=2)
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Cosine Similarity')
            ax2.set_title('Validation Cosine Similarity')
            ax2.grid(True, alpha=0.3)
            
            # Learning rate plot
            ax3.plot(epochs_list, train_history['learning_rates'], 'purple', linewidth=2)
            ax3.set_xlabel('Epoch')
            ax3.set_ylabel('Learning Rate')
            ax3.set_title('Learning Rate Schedule')
            ax3.set_yscale('log') 
            ax3.grid(True, alpha=0.3)
            
            plt.tight_layout()
            
            logger.report_matplotlib_figure("Training Progress", "Loss, Similarity and LR", iteration=epoch, figure=plt)
            plt.close()
    
    else:
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.6f}, LR: {current_lr:.2e}, Time: {epoch_time:.2f}s")

# Final summary logging
print("\nTraining completed!")
total_training_time = (time.time() - start_time)/3600
print(f"Total training time: {total_training_time:.2f} hours")

Starting training...


Training Epoch 1:   0%|          | 0/26 [00:00<?, ?it/s]

Training Epoch 1:   4%|▍         | 1/26 [00:01<00:25,  1.00s/it, Loss=5.5561, Avg Loss=5.5561]

Epoch 1, Batch 0, Loss: 5.5561


Training Epoch 1:  42%|████▏     | 11/26 [00:04<00:04,  3.36it/s, Loss=5.5449, Avg Loss=5.5500]

Epoch 1, Batch 10, Loss: 5.5449


Training Epoch 1:  81%|████████  | 21/26 [00:07<00:01,  3.22it/s, Loss=5.5469, Avg Loss=5.5489]

Epoch 1, Batch 20, Loss: 5.5469


Training Epoch 1: 100%|██████████| 26/26 [00:08<00:00,  2.92it/s, Loss=5.5432, Avg Loss=5.5484]


Epoch 1/100 - Train Loss: 5.548429, LR: 1.00e-05, Time: 8.91s


Training Epoch 2:   4%|▍         | 1/26 [00:01<00:26,  1.06s/it, Loss=5.5470, Avg Loss=5.5470]

Epoch 2, Batch 0, Loss: 5.5470


Training Epoch 2:  42%|████▏     | 11/26 [00:04<00:04,  3.24it/s, Loss=5.5462, Avg Loss=5.5460]

Epoch 2, Batch 10, Loss: 5.5462


Training Epoch 2:  81%|████████  | 21/26 [00:07<00:01,  3.20it/s, Loss=5.5487, Avg Loss=5.5458]

Epoch 2, Batch 20, Loss: 5.5487


Training Epoch 2: 100%|██████████| 26/26 [00:08<00:00,  2.98it/s, Loss=5.5469, Avg Loss=5.5460]


Epoch 2/100 - Train Loss: 5.546022, LR: 1.00e-05, Time: 8.72s


Training Epoch 3:   4%|▍         | 1/26 [00:00<00:22,  1.10it/s, Loss=5.5482, Avg Loss=5.5482]

Epoch 3, Batch 0, Loss: 5.5482


Training Epoch 3:  42%|████▏     | 11/26 [00:03<00:04,  3.15it/s, Loss=5.5475, Avg Loss=5.5456]

Epoch 3, Batch 10, Loss: 5.5475


Training Epoch 3:  81%|████████  | 21/26 [00:06<00:01,  3.28it/s, Loss=5.5472, Avg Loss=5.5456]

Epoch 3, Batch 20, Loss: 5.5472


Training Epoch 3: 100%|██████████| 26/26 [00:08<00:00,  3.01it/s, Loss=5.5451, Avg Loss=5.5453]


Epoch 3/100 - Train Loss: 5.545349, LR: 1.00e-05, Time: 8.65s


Training Epoch 4:   4%|▍         | 1/26 [00:00<00:22,  1.10it/s, Loss=5.5450, Avg Loss=5.5450]

Epoch 4, Batch 0, Loss: 5.5450


Training Epoch 4:  42%|████▏     | 11/26 [00:03<00:04,  3.42it/s, Loss=5.5448, Avg Loss=5.5449]

Epoch 4, Batch 10, Loss: 5.5448


Training Epoch 4:  81%|████████  | 21/26 [00:07<00:01,  3.09it/s, Loss=5.5436, Avg Loss=5.5448]

Epoch 4, Batch 20, Loss: 5.5436


Training Epoch 4: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5466, Avg Loss=5.5448]


Epoch 4/100 - Train Loss: 5.544811, LR: 1.00e-05, Time: 8.67s


Training Epoch 5:   4%|▍         | 1/26 [00:00<00:20,  1.20it/s, Loss=5.5470, Avg Loss=5.5470]

Epoch 5, Batch 0, Loss: 5.5470


Training Epoch 5:  42%|████▏     | 11/26 [00:03<00:04,  3.19it/s, Loss=5.5450, Avg Loss=5.5453]

Epoch 5, Batch 10, Loss: 5.5450


Training Epoch 5:  81%|████████  | 21/26 [00:06<00:01,  3.16it/s, Loss=5.5445, Avg Loss=5.5447]

Epoch 5, Batch 20, Loss: 5.5445


Training Epoch 5: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5460, Avg Loss=5.5449]
Evaluating Epoch 5: 100%|██████████| 7/7 [00:01<00:00,  4.50it/s, Val Loss=5.2212]



Epoch 5/100
Train Loss: 5.544893
Val Loss: 5.500609
Val mean similarity: -0.0277
Learning Rate: 1.00e-05
Epoch Time: 10.22s, Total Time: 45.18s
------------------------------------------------------------


Training Epoch 6:   4%|▍         | 1/26 [00:00<00:22,  1.13it/s, Loss=5.5394, Avg Loss=5.5394]

Epoch 6, Batch 0, Loss: 5.5394


Training Epoch 6:  42%|████▏     | 11/26 [00:03<00:04,  3.05it/s, Loss=5.5489, Avg Loss=5.5452]

Epoch 6, Batch 10, Loss: 5.5489


Training Epoch 6:  81%|████████  | 21/26 [00:07<00:01,  3.17it/s, Loss=5.5444, Avg Loss=5.5450]

Epoch 6, Batch 20, Loss: 5.5444


Training Epoch 6: 100%|██████████| 26/26 [00:08<00:00,  2.96it/s, Loss=5.5423, Avg Loss=5.5449]


Epoch 6/100 - Train Loss: 5.544889, LR: 1.00e-05, Time: 8.78s


Training Epoch 7:   4%|▍         | 1/26 [00:00<00:22,  1.12it/s, Loss=5.5422, Avg Loss=5.5422]

Epoch 7, Batch 0, Loss: 5.5422


Training Epoch 7:  42%|████▏     | 11/26 [00:04<00:04,  3.03it/s, Loss=5.5460, Avg Loss=5.5449]

Epoch 7, Batch 10, Loss: 5.5460


Training Epoch 7:  81%|████████  | 21/26 [00:07<00:01,  3.10it/s, Loss=5.5490, Avg Loss=5.5453]

Epoch 7, Batch 20, Loss: 5.5490


Training Epoch 7: 100%|██████████| 26/26 [00:08<00:00,  2.92it/s, Loss=5.5466, Avg Loss=5.5452]


Epoch 7/100 - Train Loss: 5.545202, LR: 1.00e-05, Time: 8.90s


Training Epoch 8:   4%|▍         | 1/26 [00:00<00:19,  1.26it/s, Loss=5.5456, Avg Loss=5.5456]

Epoch 8, Batch 0, Loss: 5.5456


Training Epoch 8:  42%|████▏     | 11/26 [00:03<00:04,  3.32it/s, Loss=5.5432, Avg Loss=5.5443]

Epoch 8, Batch 10, Loss: 5.5432


Training Epoch 8:  81%|████████  | 21/26 [00:07<00:01,  3.20it/s, Loss=5.5445, Avg Loss=5.5448]

Epoch 8, Batch 20, Loss: 5.5445


Training Epoch 8: 100%|██████████| 26/26 [00:08<00:00,  3.03it/s, Loss=5.5457, Avg Loss=5.5450]


Epoch 8/100 - Train Loss: 5.545001, LR: 1.00e-05, Time: 8.60s


Training Epoch 9:   4%|▍         | 1/26 [00:00<00:20,  1.24it/s, Loss=5.5470, Avg Loss=5.5470]

Epoch 9, Batch 0, Loss: 5.5470


Training Epoch 9:  42%|████▏     | 11/26 [00:03<00:04,  3.28it/s, Loss=5.5430, Avg Loss=5.5458]

Epoch 9, Batch 10, Loss: 5.5430


Training Epoch 9:  81%|████████  | 21/26 [00:06<00:01,  3.45it/s, Loss=5.5405, Avg Loss=5.5448]

Epoch 9, Batch 20, Loss: 5.5405


Training Epoch 9: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5448, Avg Loss=5.5446]


Epoch 9/100 - Train Loss: 5.544621, LR: 1.00e-05, Time: 8.68s


Training Epoch 10:   4%|▍         | 1/26 [00:00<00:21,  1.16it/s, Loss=5.5416, Avg Loss=5.5416]

Epoch 10, Batch 0, Loss: 5.5416


Training Epoch 10:  42%|████▏     | 11/26 [00:03<00:04,  3.20it/s, Loss=5.5458, Avg Loss=5.5444]

Epoch 10, Batch 10, Loss: 5.5458


Training Epoch 10:  81%|████████  | 21/26 [00:07<00:01,  3.33it/s, Loss=5.5384, Avg Loss=5.5436]

Epoch 10, Batch 20, Loss: 5.5384


Training Epoch 10: 100%|██████████| 26/26 [00:08<00:00,  2.98it/s, Loss=5.5426, Avg Loss=5.5437]
Evaluating Epoch 10: 100%|██████████| 7/7 [00:01<00:00,  5.17it/s, Val Loss=5.2213]



Epoch 10/100
Train Loss: 5.543704
Val Loss: 5.501530
Val mean similarity: -0.0163
Learning Rate: 1.00e-05
Epoch Time: 10.10s, Total Time: 90.24s
------------------------------------------------------------


Training Epoch 11:   4%|▍         | 1/26 [00:00<00:20,  1.22it/s, Loss=5.5410, Avg Loss=5.5410]

Epoch 11, Batch 0, Loss: 5.5410


Training Epoch 11:  42%|████▏     | 11/26 [00:03<00:04,  3.31it/s, Loss=5.5404, Avg Loss=5.5434]

Epoch 11, Batch 10, Loss: 5.5404


Training Epoch 11:  81%|████████  | 21/26 [00:07<00:01,  3.23it/s, Loss=5.5450, Avg Loss=5.5443]

Epoch 11, Batch 20, Loss: 5.5450


Training Epoch 11: 100%|██████████| 26/26 [00:08<00:00,  2.97it/s, Loss=5.5460, Avg Loss=5.5448]


Epoch 11/100 - Train Loss: 5.544779, LR: 1.00e-05, Time: 8.76s


Training Epoch 12:   4%|▍         | 1/26 [00:00<00:18,  1.35it/s, Loss=5.5426, Avg Loss=5.5426]

Epoch 12, Batch 0, Loss: 5.5426


Training Epoch 12:  42%|████▏     | 11/26 [00:03<00:04,  3.68it/s, Loss=5.5455, Avg Loss=5.5457]

Epoch 12, Batch 10, Loss: 5.5455


Training Epoch 12:  81%|████████  | 21/26 [00:06<00:01,  3.24it/s, Loss=5.5438, Avg Loss=5.5441]

Epoch 12, Batch 20, Loss: 5.5438


Training Epoch 12: 100%|██████████| 26/26 [00:08<00:00,  3.04it/s, Loss=5.5428, Avg Loss=5.5441]


Epoch 12/100 - Train Loss: 5.544077, LR: 1.00e-05, Time: 8.55s


Training Epoch 13:   4%|▍         | 1/26 [00:01<00:27,  1.10s/it, Loss=5.5496, Avg Loss=5.5496]

Epoch 13, Batch 0, Loss: 5.5496


Training Epoch 13:  42%|████▏     | 11/26 [00:04<00:04,  3.11it/s, Loss=5.5434, Avg Loss=5.5438]

Epoch 13, Batch 10, Loss: 5.5434


Training Epoch 13:  81%|████████  | 21/26 [00:08<00:01,  2.77it/s, Loss=5.5451, Avg Loss=5.5439]

Epoch 13, Batch 20, Loss: 5.5451


Training Epoch 13: 100%|██████████| 26/26 [00:09<00:00,  2.69it/s, Loss=5.5404, Avg Loss=5.5439]


Epoch 13/100 - Train Loss: 5.543907, LR: 1.00e-05, Time: 9.67s


Training Epoch 14:   4%|▍         | 1/26 [00:01<00:27,  1.09s/it, Loss=5.5402, Avg Loss=5.5402]

Epoch 14, Batch 0, Loss: 5.5402


Training Epoch 14:  42%|████▏     | 11/26 [00:05<00:06,  2.18it/s, Loss=5.5413, Avg Loss=5.5417]

Epoch 14, Batch 10, Loss: 5.5413


Training Epoch 14:  81%|████████  | 21/26 [00:07<00:01,  3.50it/s, Loss=5.5418, Avg Loss=5.5427]

Epoch 14, Batch 20, Loss: 5.5418


Training Epoch 14: 100%|██████████| 26/26 [00:09<00:00,  2.70it/s, Loss=5.5457, Avg Loss=5.5429]


Epoch 14/100 - Train Loss: 5.542893, LR: 1.00e-05, Time: 9.65s


Training Epoch 15:   4%|▍         | 1/26 [00:00<00:24,  1.03it/s, Loss=5.5448, Avg Loss=5.5448]

Epoch 15, Batch 0, Loss: 5.5448


Training Epoch 15:  42%|████▏     | 11/26 [00:04<00:04,  3.14it/s, Loss=5.5394, Avg Loss=5.5437]

Epoch 15, Batch 10, Loss: 5.5394


Training Epoch 15:  81%|████████  | 21/26 [00:07<00:01,  3.43it/s, Loss=5.5363, Avg Loss=5.5434]

Epoch 15, Batch 20, Loss: 5.5363


Training Epoch 15: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5456, Avg Loss=5.5432]
Evaluating Epoch 15: 100%|██████████| 7/7 [00:01<00:00,  4.88it/s, Val Loss=5.2214]



Epoch 15/100
Train Loss: 5.543209
Val Loss: 5.501758
Val mean similarity: -0.0106
Learning Rate: 1.00e-05
Epoch Time: 10.11s, Total Time: 137.87s
------------------------------------------------------------


Training Epoch 16:   4%|▍         | 1/26 [00:00<00:22,  1.11it/s, Loss=5.5460, Avg Loss=5.5460]

Epoch 16, Batch 0, Loss: 5.5460


Training Epoch 16:  42%|████▏     | 11/26 [00:03<00:04,  3.15it/s, Loss=5.5433, Avg Loss=5.5438]

Epoch 16, Batch 10, Loss: 5.5433


Training Epoch 16:  81%|████████  | 21/26 [00:07<00:01,  3.46it/s, Loss=5.5325, Avg Loss=5.5425]

Epoch 16, Batch 20, Loss: 5.5325


Training Epoch 16: 100%|██████████| 26/26 [00:08<00:00,  3.01it/s, Loss=5.5436, Avg Loss=5.5424]


Epoch 16/100 - Train Loss: 5.542415, LR: 1.00e-05, Time: 8.65s


Training Epoch 17:   4%|▍         | 1/26 [00:00<00:18,  1.37it/s, Loss=5.5438, Avg Loss=5.5438]

Epoch 17, Batch 0, Loss: 5.5438


Training Epoch 17:  42%|████▏     | 11/26 [00:03<00:04,  3.15it/s, Loss=5.5498, Avg Loss=5.5420]

Epoch 17, Batch 10, Loss: 5.5498


Training Epoch 17:  81%|████████  | 21/26 [00:06<00:01,  3.00it/s, Loss=5.5482, Avg Loss=5.5428]

Epoch 17, Batch 20, Loss: 5.5482


Training Epoch 17: 100%|██████████| 26/26 [00:08<00:00,  3.04it/s, Loss=5.5398, Avg Loss=5.5427]


Epoch 17/100 - Train Loss: 5.542673, LR: 1.00e-05, Time: 8.56s


Training Epoch 18:   4%|▍         | 1/26 [00:00<00:22,  1.11it/s, Loss=5.5472, Avg Loss=5.5472]

Epoch 18, Batch 0, Loss: 5.5472


Training Epoch 18:  42%|████▏     | 11/26 [00:04<00:05,  2.88it/s, Loss=5.5301, Avg Loss=5.5409]

Epoch 18, Batch 10, Loss: 5.5301


Training Epoch 18:  81%|████████  | 21/26 [00:07<00:01,  2.92it/s, Loss=5.5469, Avg Loss=5.5415]

Epoch 18, Batch 20, Loss: 5.5469


Training Epoch 18: 100%|██████████| 26/26 [00:09<00:00,  2.73it/s, Loss=5.5414, Avg Loss=5.5421]


Epoch 18/100 - Train Loss: 5.542150, LR: 1.00e-05, Time: 9.54s


Training Epoch 19:   4%|▍         | 1/26 [00:00<00:21,  1.14it/s, Loss=5.5346, Avg Loss=5.5346]

Epoch 19, Batch 0, Loss: 5.5346


Training Epoch 19:  42%|████▏     | 11/26 [00:03<00:04,  3.35it/s, Loss=5.5403, Avg Loss=5.5402]

Epoch 19, Batch 10, Loss: 5.5403


Training Epoch 19:  81%|████████  | 21/26 [00:07<00:01,  3.05it/s, Loss=5.5468, Avg Loss=5.5421]

Epoch 19, Batch 20, Loss: 5.5468


Training Epoch 19: 100%|██████████| 26/26 [00:08<00:00,  2.96it/s, Loss=5.5393, Avg Loss=5.5423]


Epoch 19/100 - Train Loss: 5.542293, LR: 1.00e-05, Time: 8.79s


Training Epoch 20:   4%|▍         | 1/26 [00:01<00:27,  1.08s/it, Loss=5.5479, Avg Loss=5.5479]

Epoch 20, Batch 0, Loss: 5.5479


Training Epoch 20:  42%|████▏     | 11/26 [00:04<00:04,  3.53it/s, Loss=5.5375, Avg Loss=5.5405]

Epoch 20, Batch 10, Loss: 5.5375


Training Epoch 20:  81%|████████  | 21/26 [00:07<00:01,  3.46it/s, Loss=5.5368, Avg Loss=5.5401]

Epoch 20, Batch 20, Loss: 5.5368


Training Epoch 20: 100%|██████████| 26/26 [00:08<00:00,  2.95it/s, Loss=5.5422, Avg Loss=5.5403]
Evaluating Epoch 20: 100%|██████████| 7/7 [00:01<00:00,  4.77it/s, Val Loss=5.2296]



Epoch 20/100
Train Loss: 5.540296
Val Loss: 5.505402
Val mean similarity: 0.0031
Learning Rate: 1.00e-05
Epoch Time: 10.30s, Total Time: 183.72s
------------------------------------------------------------


Training Epoch 21:   4%|▍         | 1/26 [00:00<00:23,  1.08it/s, Loss=5.5304, Avg Loss=5.5304]

Epoch 21, Batch 0, Loss: 5.5304


Training Epoch 21:  42%|████▏     | 11/26 [00:04<00:04,  3.17it/s, Loss=5.5399, Avg Loss=5.5394]

Epoch 21, Batch 10, Loss: 5.5399


Training Epoch 21:  81%|████████  | 21/26 [00:06<00:01,  3.10it/s, Loss=5.5304, Avg Loss=5.5408]

Epoch 21, Batch 20, Loss: 5.5304


Training Epoch 21: 100%|██████████| 26/26 [00:08<00:00,  2.89it/s, Loss=5.5427, Avg Loss=5.5416]


Epoch 21/100 - Train Loss: 5.541597, LR: 1.00e-05, Time: 8.99s


Training Epoch 22:   4%|▍         | 1/26 [00:01<00:43,  1.72s/it, Loss=5.5344, Avg Loss=5.5344]

Epoch 22, Batch 0, Loss: 5.5344


Training Epoch 22:  42%|████▏     | 11/26 [00:04<00:04,  3.07it/s, Loss=5.5448, Avg Loss=5.5410]

Epoch 22, Batch 10, Loss: 5.5448


Training Epoch 22:  81%|████████  | 21/26 [00:07<00:01,  3.43it/s, Loss=5.5436, Avg Loss=5.5415]

Epoch 22, Batch 20, Loss: 5.5436


Training Epoch 22: 100%|██████████| 26/26 [00:09<00:00,  2.76it/s, Loss=5.5396, Avg Loss=5.5416]


Epoch 22/100 - Train Loss: 5.541568, LR: 1.00e-05, Time: 9.44s


Training Epoch 23:   4%|▍         | 1/26 [00:01<00:26,  1.05s/it, Loss=5.5431, Avg Loss=5.5431]

Epoch 23, Batch 0, Loss: 5.5431


Training Epoch 23:  42%|████▏     | 11/26 [00:03<00:04,  3.20it/s, Loss=5.5388, Avg Loss=5.5387]

Epoch 23, Batch 10, Loss: 5.5388


Training Epoch 23:  81%|████████  | 21/26 [00:07<00:01,  2.98it/s, Loss=5.5373, Avg Loss=5.5399]

Epoch 23, Batch 20, Loss: 5.5373


Training Epoch 23: 100%|██████████| 26/26 [00:08<00:00,  2.92it/s, Loss=5.5340, Avg Loss=5.5403]


Epoch 23/100 - Train Loss: 5.540270, LR: 1.00e-05, Time: 8.91s


Training Epoch 24:   4%|▍         | 1/26 [00:00<00:22,  1.13it/s, Loss=5.5415, Avg Loss=5.5415]

Epoch 24, Batch 0, Loss: 5.5415


Training Epoch 24:  42%|████▏     | 11/26 [00:03<00:04,  3.57it/s, Loss=5.5411, Avg Loss=5.5392]

Epoch 24, Batch 10, Loss: 5.5411


Training Epoch 24:  81%|████████  | 21/26 [00:06<00:01,  3.20it/s, Loss=5.5484, Avg Loss=5.5414]

Epoch 24, Batch 20, Loss: 5.5484


Training Epoch 24: 100%|██████████| 26/26 [00:08<00:00,  3.07it/s, Loss=5.5349, Avg Loss=5.5411]


Epoch 24/100 - Train Loss: 5.541052, LR: 1.00e-05, Time: 8.47s


Training Epoch 25:   4%|▍         | 1/26 [00:00<00:22,  1.11it/s, Loss=5.5499, Avg Loss=5.5499]

Epoch 25, Batch 0, Loss: 5.5499


Training Epoch 25:  42%|████▏     | 11/26 [00:03<00:04,  3.17it/s, Loss=5.5451, Avg Loss=5.5414]

Epoch 25, Batch 10, Loss: 5.5451


Training Epoch 25:  81%|████████  | 21/26 [00:06<00:01,  2.79it/s, Loss=5.5462, Avg Loss=5.5389]

Epoch 25, Batch 20, Loss: 5.5462


Training Epoch 25: 100%|██████████| 26/26 [00:09<00:00,  2.70it/s, Loss=5.5377, Avg Loss=5.5395]
Evaluating Epoch 25: 100%|██████████| 7/7 [00:01<00:00,  4.71it/s, Val Loss=5.2280]



Epoch 25/100
Train Loss: 5.539466
Val Loss: 5.505472
Val mean similarity: 0.0109
Learning Rate: 1.00e-05
Epoch Time: 11.13s, Total Time: 231.53s
------------------------------------------------------------


Training Epoch 26:   4%|▍         | 1/26 [00:01<00:25,  1.00s/it, Loss=5.5380, Avg Loss=5.5380]

Epoch 26, Batch 0, Loss: 5.5380


Training Epoch 26:  42%|████▏     | 11/26 [00:03<00:04,  3.38it/s, Loss=5.5353, Avg Loss=5.5364]

Epoch 26, Batch 10, Loss: 5.5353


Training Epoch 26:  81%|████████  | 21/26 [00:07<00:01,  3.04it/s, Loss=5.5347, Avg Loss=5.5379]

Epoch 26, Batch 20, Loss: 5.5347


Training Epoch 26: 100%|██████████| 26/26 [00:08<00:00,  2.97it/s, Loss=5.5330, Avg Loss=5.5382]


Epoch 26/100 - Train Loss: 5.538245, LR: 5.00e-06, Time: 8.76s


Training Epoch 27:   4%|▍         | 1/26 [00:00<00:21,  1.18it/s, Loss=5.5384, Avg Loss=5.5384]

Epoch 27, Batch 0, Loss: 5.5384


Training Epoch 27:  42%|████▏     | 11/26 [00:03<00:04,  3.63it/s, Loss=5.5415, Avg Loss=5.5365]

Epoch 27, Batch 10, Loss: 5.5415


Training Epoch 27:  81%|████████  | 21/26 [00:06<00:01,  3.30it/s, Loss=5.5406, Avg Loss=5.5370]

Epoch 27, Batch 20, Loss: 5.5406


Training Epoch 27: 100%|██████████| 26/26 [00:08<00:00,  2.99it/s, Loss=5.5358, Avg Loss=5.5370]


Epoch 27/100 - Train Loss: 5.537008, LR: 5.00e-06, Time: 8.69s


Training Epoch 28:   4%|▍         | 1/26 [00:01<00:27,  1.12s/it, Loss=5.5446, Avg Loss=5.5446]

Epoch 28, Batch 0, Loss: 5.5446


Training Epoch 28:  42%|████▏     | 11/26 [00:04<00:04,  3.37it/s, Loss=5.5414, Avg Loss=5.5379]

Epoch 28, Batch 10, Loss: 5.5414


Training Epoch 28:  81%|████████  | 21/26 [00:07<00:01,  3.06it/s, Loss=5.5331, Avg Loss=5.5392]

Epoch 28, Batch 20, Loss: 5.5331


Training Epoch 28: 100%|██████████| 26/26 [00:08<00:00,  2.92it/s, Loss=5.5403, Avg Loss=5.5385]


Epoch 28/100 - Train Loss: 5.538472, LR: 5.00e-06, Time: 8.91s


Training Epoch 29:   4%|▍         | 1/26 [00:00<00:22,  1.13it/s, Loss=5.5400, Avg Loss=5.5400]

Epoch 29, Batch 0, Loss: 5.5400


Training Epoch 29:  42%|████▏     | 11/26 [00:03<00:04,  3.39it/s, Loss=5.5296, Avg Loss=5.5325]

Epoch 29, Batch 10, Loss: 5.5296


Training Epoch 29:  81%|████████  | 21/26 [00:07<00:02,  2.29it/s, Loss=5.5239, Avg Loss=5.5356]

Epoch 29, Batch 20, Loss: 5.5239


Training Epoch 29: 100%|██████████| 26/26 [00:09<00:00,  2.68it/s, Loss=5.5467, Avg Loss=5.5359]


Epoch 29/100 - Train Loss: 5.535864, LR: 5.00e-06, Time: 9.71s


Training Epoch 30:   4%|▍         | 1/26 [00:01<00:25,  1.02s/it, Loss=5.5441, Avg Loss=5.5441]

Epoch 30, Batch 0, Loss: 5.5441


Training Epoch 30:  42%|████▏     | 11/26 [00:04<00:04,  3.30it/s, Loss=5.5376, Avg Loss=5.5361]

Epoch 30, Batch 10, Loss: 5.5376


Training Epoch 30:  81%|████████  | 21/26 [00:07<00:01,  2.97it/s, Loss=5.5334, Avg Loss=5.5366]

Epoch 30, Batch 20, Loss: 5.5334


Training Epoch 30: 100%|██████████| 26/26 [00:09<00:00,  2.88it/s, Loss=5.5357, Avg Loss=5.5365]
Evaluating Epoch 30: 100%|██████████| 7/7 [00:01<00:00,  4.69it/s, Val Loss=5.2354]



Epoch 30/100
Train Loss: 5.536543
Val Loss: 5.507898
Val mean similarity: 0.0071
Learning Rate: 5.00e-06
Epoch Time: 10.52s, Total Time: 278.13s
------------------------------------------------------------


Training Epoch 31:   4%|▍         | 1/26 [00:00<00:22,  1.13it/s, Loss=5.5225, Avg Loss=5.5225]

Epoch 31, Batch 0, Loss: 5.5225


Training Epoch 31:  42%|████▏     | 11/26 [00:04<00:04,  3.23it/s, Loss=5.5418, Avg Loss=5.5378]

Epoch 31, Batch 10, Loss: 5.5418


Training Epoch 31:  81%|████████  | 21/26 [00:07<00:01,  3.00it/s, Loss=5.5388, Avg Loss=5.5368]

Epoch 31, Batch 20, Loss: 5.5388


Training Epoch 31: 100%|██████████| 26/26 [00:08<00:00,  2.92it/s, Loss=5.5416, Avg Loss=5.5370]


Epoch 31/100 - Train Loss: 5.537034, LR: 5.00e-06, Time: 8.90s


Training Epoch 32:   4%|▍         | 1/26 [00:00<00:24,  1.03it/s, Loss=5.5268, Avg Loss=5.5268]

Epoch 32, Batch 0, Loss: 5.5268


Training Epoch 32:  42%|████▏     | 11/26 [00:03<00:04,  3.52it/s, Loss=5.5432, Avg Loss=5.5377]

Epoch 32, Batch 10, Loss: 5.5432


Training Epoch 32:  81%|████████  | 21/26 [00:07<00:01,  3.27it/s, Loss=5.5422, Avg Loss=5.5365]

Epoch 32, Batch 20, Loss: 5.5422


Training Epoch 32: 100%|██████████| 26/26 [00:08<00:00,  2.96it/s, Loss=5.5410, Avg Loss=5.5359]


Epoch 32/100 - Train Loss: 5.535860, LR: 5.00e-06, Time: 8.80s


Training Epoch 33:   4%|▍         | 1/26 [00:01<00:25,  1.00s/it, Loss=5.5393, Avg Loss=5.5393]

Epoch 33, Batch 0, Loss: 5.5393


Training Epoch 33:  42%|████▏     | 11/26 [00:04<00:05,  2.85it/s, Loss=5.5393, Avg Loss=5.5389]

Epoch 33, Batch 10, Loss: 5.5393


Training Epoch 33:  81%|████████  | 21/26 [00:07<00:01,  3.14it/s, Loss=5.5406, Avg Loss=5.5381]

Epoch 33, Batch 20, Loss: 5.5406


Training Epoch 33: 100%|██████████| 26/26 [00:09<00:00,  2.69it/s, Loss=5.5300, Avg Loss=5.5377]


Epoch 33/100 - Train Loss: 5.537709, LR: 5.00e-06, Time: 9.67s


Training Epoch 34:   4%|▍         | 1/26 [00:00<00:22,  1.10it/s, Loss=5.5359, Avg Loss=5.5359]

Epoch 34, Batch 0, Loss: 5.5359


Training Epoch 34:  42%|████▏     | 11/26 [00:04<00:05,  2.89it/s, Loss=5.5283, Avg Loss=5.5334]

Epoch 34, Batch 10, Loss: 5.5283


Training Epoch 34:  81%|████████  | 21/26 [00:07<00:01,  3.41it/s, Loss=5.5557, Avg Loss=5.5366]

Epoch 34, Batch 20, Loss: 5.5557


Training Epoch 34: 100%|██████████| 26/26 [00:08<00:00,  2.95it/s, Loss=5.5400, Avg Loss=5.5373]


Epoch 34/100 - Train Loss: 5.537255, LR: 5.00e-06, Time: 8.81s


Training Epoch 35:   4%|▍         | 1/26 [00:00<00:19,  1.28it/s, Loss=5.5295, Avg Loss=5.5295]

Epoch 35, Batch 0, Loss: 5.5295


Training Epoch 35:  42%|████▏     | 11/26 [00:03<00:04,  3.43it/s, Loss=5.5328, Avg Loss=5.5333]

Epoch 35, Batch 10, Loss: 5.5328


Training Epoch 35:  81%|████████  | 21/26 [00:06<00:01,  3.47it/s, Loss=5.5402, Avg Loss=5.5358]

Epoch 35, Batch 20, Loss: 5.5402


Training Epoch 35: 100%|██████████| 26/26 [00:08<00:00,  3.01it/s, Loss=5.5238, Avg Loss=5.5358]
Evaluating Epoch 35: 100%|██████████| 7/7 [00:01<00:00,  4.94it/s, Val Loss=5.2281]



Epoch 35/100
Train Loss: 5.535844
Val Loss: 5.506888
Val mean similarity: -0.0019
Learning Rate: 5.00e-06
Epoch Time: 10.07s, Total Time: 325.23s
------------------------------------------------------------


Training Epoch 36:   4%|▍         | 1/26 [00:00<00:19,  1.31it/s, Loss=5.5228, Avg Loss=5.5228]

Epoch 36, Batch 0, Loss: 5.5228


Training Epoch 36:  42%|████▏     | 11/26 [00:03<00:05,  2.94it/s, Loss=5.5417, Avg Loss=5.5350]

Epoch 36, Batch 10, Loss: 5.5417


Training Epoch 36:  81%|████████  | 21/26 [00:06<00:01,  3.29it/s, Loss=5.5210, Avg Loss=5.5352]

Epoch 36, Batch 20, Loss: 5.5210


Training Epoch 36: 100%|██████████| 26/26 [00:09<00:00,  2.88it/s, Loss=5.5409, Avg Loss=5.5353]


Epoch 36/100 - Train Loss: 5.535338, LR: 5.00e-06, Time: 9.06s


Training Epoch 37:   4%|▍         | 1/26 [00:01<00:47,  1.88s/it, Loss=5.5331, Avg Loss=5.5331]

Epoch 37, Batch 0, Loss: 5.5331


Training Epoch 37:  42%|████▏     | 11/26 [00:04<00:05,  2.91it/s, Loss=5.5512, Avg Loss=5.5339]

Epoch 37, Batch 10, Loss: 5.5512


Training Epoch 37:  81%|████████  | 21/26 [00:07<00:01,  3.22it/s, Loss=5.5247, Avg Loss=5.5363]

Epoch 37, Batch 20, Loss: 5.5247


Training Epoch 37: 100%|██████████| 26/26 [00:09<00:00,  2.70it/s, Loss=5.5234, Avg Loss=5.5359]


Epoch 37/100 - Train Loss: 5.535950, LR: 5.00e-06, Time: 9.64s


Training Epoch 38:   4%|▍         | 1/26 [00:00<00:23,  1.08it/s, Loss=5.5311, Avg Loss=5.5311]

Epoch 38, Batch 0, Loss: 5.5311


Training Epoch 38:  42%|████▏     | 11/26 [00:03<00:04,  3.07it/s, Loss=5.5375, Avg Loss=5.5371]

Epoch 38, Batch 10, Loss: 5.5375


Training Epoch 38:  81%|████████  | 21/26 [00:06<00:01,  3.01it/s, Loss=5.5276, Avg Loss=5.5372]

Epoch 38, Batch 20, Loss: 5.5276


Training Epoch 38: 100%|██████████| 26/26 [00:08<00:00,  3.01it/s, Loss=5.5179, Avg Loss=5.5368]


Epoch 38/100 - Train Loss: 5.536771, LR: 5.00e-06, Time: 8.64s


Training Epoch 39:   4%|▍         | 1/26 [00:00<00:20,  1.24it/s, Loss=5.5262, Avg Loss=5.5262]

Epoch 39, Batch 0, Loss: 5.5262


Training Epoch 39:  42%|████▏     | 11/26 [00:03<00:04,  3.36it/s, Loss=5.5419, Avg Loss=5.5331]

Epoch 39, Batch 10, Loss: 5.5419


Training Epoch 39:  81%|████████  | 21/26 [00:06<00:01,  3.59it/s, Loss=5.5478, Avg Loss=5.5342]

Epoch 39, Batch 20, Loss: 5.5478


Training Epoch 39: 100%|██████████| 26/26 [00:08<00:00,  3.05it/s, Loss=5.5384, Avg Loss=5.5341]


Epoch 39/100 - Train Loss: 5.534134, LR: 5.00e-06, Time: 8.52s


Training Epoch 40:   4%|▍         | 1/26 [00:00<00:21,  1.19it/s, Loss=5.5173, Avg Loss=5.5173]

Epoch 40, Batch 0, Loss: 5.5173


Training Epoch 40:  42%|████▏     | 11/26 [00:03<00:04,  3.34it/s, Loss=5.5362, Avg Loss=5.5332]

Epoch 40, Batch 10, Loss: 5.5362


Training Epoch 40:  81%|████████  | 21/26 [00:07<00:01,  2.55it/s, Loss=5.5304, Avg Loss=5.5339]

Epoch 40, Batch 20, Loss: 5.5304


Training Epoch 40: 100%|██████████| 26/26 [00:09<00:00,  2.66it/s, Loss=5.5367, Avg Loss=5.5339]
Evaluating Epoch 40: 100%|██████████| 7/7 [00:01<00:00,  4.73it/s, Val Loss=5.2279]



Epoch 40/100
Train Loss: 5.533920
Val Loss: 5.509496
Val mean similarity: 0.0002
Learning Rate: 5.00e-06
Epoch Time: 11.26s, Total Time: 372.40s
------------------------------------------------------------


Training Epoch 41:   4%|▍         | 1/26 [00:00<00:20,  1.21it/s, Loss=5.5445, Avg Loss=5.5445]

Epoch 41, Batch 0, Loss: 5.5445


Training Epoch 41:  42%|████▏     | 11/26 [00:04<00:04,  3.42it/s, Loss=5.5199, Avg Loss=5.5293]

Epoch 41, Batch 10, Loss: 5.5199


Training Epoch 41:  81%|████████  | 21/26 [00:07<00:01,  3.50it/s, Loss=5.5357, Avg Loss=5.5335]

Epoch 41, Batch 20, Loss: 5.5357


Training Epoch 41: 100%|██████████| 26/26 [00:08<00:00,  3.02it/s, Loss=5.5467, Avg Loss=5.5348]


Epoch 41/100 - Train Loss: 5.534832, LR: 5.00e-06, Time: 8.62s


Training Epoch 42:   4%|▍         | 1/26 [00:01<00:27,  1.08s/it, Loss=5.5325, Avg Loss=5.5325]

Epoch 42, Batch 0, Loss: 5.5325


Training Epoch 42:  42%|████▏     | 11/26 [00:04<00:04,  3.53it/s, Loss=5.5371, Avg Loss=5.5348]

Epoch 42, Batch 10, Loss: 5.5371


Training Epoch 42:  81%|████████  | 21/26 [00:07<00:01,  3.64it/s, Loss=5.5218, Avg Loss=5.5349]

Epoch 42, Batch 20, Loss: 5.5218


Training Epoch 42: 100%|██████████| 26/26 [00:08<00:00,  2.96it/s, Loss=5.5420, Avg Loss=5.5346]


Epoch 42/100 - Train Loss: 5.534605, LR: 5.00e-06, Time: 8.78s


Training Epoch 43:   4%|▍         | 1/26 [00:00<00:20,  1.20it/s, Loss=5.5368, Avg Loss=5.5368]

Epoch 43, Batch 0, Loss: 5.5368


Training Epoch 43:  42%|████▏     | 11/26 [00:04<00:05,  2.95it/s, Loss=5.5282, Avg Loss=5.5347]

Epoch 43, Batch 10, Loss: 5.5282


Training Epoch 43:  81%|████████  | 21/26 [00:06<00:01,  3.45it/s, Loss=5.5350, Avg Loss=5.5342]

Epoch 43, Batch 20, Loss: 5.5350


Training Epoch 43: 100%|██████████| 26/26 [00:08<00:00,  3.03it/s, Loss=5.5430, Avg Loss=5.5338]


Epoch 43/100 - Train Loss: 5.533814, LR: 5.00e-06, Time: 8.58s


Training Epoch 44:   4%|▍         | 1/26 [00:00<00:22,  1.12it/s, Loss=5.5290, Avg Loss=5.5290]

Epoch 44, Batch 0, Loss: 5.5290


Training Epoch 44:  42%|████▏     | 11/26 [00:04<00:05,  2.66it/s, Loss=5.5300, Avg Loss=5.5312]

Epoch 44, Batch 10, Loss: 5.5300


Training Epoch 44:  81%|████████  | 21/26 [00:08<00:01,  3.33it/s, Loss=5.5462, Avg Loss=5.5326]

Epoch 44, Batch 20, Loss: 5.5462


Training Epoch 44: 100%|██████████| 26/26 [00:09<00:00,  2.65it/s, Loss=5.5332, Avg Loss=5.5335]


Epoch 44/100 - Train Loss: 5.533471, LR: 5.00e-06, Time: 9.82s


Training Epoch 45:   4%|▍         | 1/26 [00:00<00:20,  1.24it/s, Loss=5.5330, Avg Loss=5.5330]

Epoch 45, Batch 0, Loss: 5.5330


Training Epoch 45:  42%|████▏     | 11/26 [00:04<00:04,  3.06it/s, Loss=5.5342, Avg Loss=5.5349]

Epoch 45, Batch 10, Loss: 5.5342


Training Epoch 45:  81%|████████  | 21/26 [00:06<00:01,  3.13it/s, Loss=5.5324, Avg Loss=5.5359]

Epoch 45, Batch 20, Loss: 5.5324


Training Epoch 45: 100%|██████████| 26/26 [00:08<00:00,  3.07it/s, Loss=5.5280, Avg Loss=5.5348]
Evaluating Epoch 45: 100%|██████████| 7/7 [00:01<00:00,  5.05it/s, Val Loss=5.2278]



Epoch 45/100
Train Loss: 5.534780
Val Loss: 5.508954
Val mean similarity: -0.0040
Learning Rate: 5.00e-06
Epoch Time: 9.87s, Total Time: 418.94s
------------------------------------------------------------


Training Epoch 46:   4%|▍         | 1/26 [00:00<00:19,  1.26it/s, Loss=5.5287, Avg Loss=5.5287]

Epoch 46, Batch 0, Loss: 5.5287


Training Epoch 46:  42%|████▏     | 11/26 [00:04<00:04,  3.12it/s, Loss=5.5324, Avg Loss=5.5323]

Epoch 46, Batch 10, Loss: 5.5324


Training Epoch 46:  81%|████████  | 21/26 [00:06<00:01,  3.42it/s, Loss=5.5397, Avg Loss=5.5338]

Epoch 46, Batch 20, Loss: 5.5397


Training Epoch 46: 100%|██████████| 26/26 [00:08<00:00,  3.04it/s, Loss=5.5346, Avg Loss=5.5326]


Epoch 46/100 - Train Loss: 5.532583, LR: 2.50e-06, Time: 8.56s


Training Epoch 47:   4%|▍         | 1/26 [00:00<00:20,  1.22it/s, Loss=5.5328, Avg Loss=5.5328]

Epoch 47, Batch 0, Loss: 5.5328


Training Epoch 47:  42%|████▏     | 11/26 [00:03<00:04,  3.66it/s, Loss=5.5143, Avg Loss=5.5302]

Epoch 47, Batch 10, Loss: 5.5143


Training Epoch 47:  81%|████████  | 21/26 [00:06<00:01,  3.02it/s, Loss=5.5266, Avg Loss=5.5317]

Epoch 47, Batch 20, Loss: 5.5266


Training Epoch 47: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5299, Avg Loss=5.5317]


Epoch 47/100 - Train Loss: 5.531744, LR: 2.50e-06, Time: 8.67s


Training Epoch 48:   4%|▍         | 1/26 [00:00<00:21,  1.15it/s, Loss=5.5407, Avg Loss=5.5407]

Epoch 48, Batch 0, Loss: 5.5407


Training Epoch 48:  42%|████▏     | 11/26 [00:07<00:07,  2.07it/s, Loss=5.5295, Avg Loss=5.5363]

Epoch 48, Batch 10, Loss: 5.5295


Training Epoch 48:  81%|████████  | 21/26 [00:10<00:01,  3.55it/s, Loss=5.5331, Avg Loss=5.5347]

Epoch 48, Batch 20, Loss: 5.5331


Training Epoch 48: 100%|██████████| 26/26 [00:12<00:00,  2.08it/s, Loss=5.5205, Avg Loss=5.5338]


Epoch 48/100 - Train Loss: 5.533798, LR: 2.50e-06, Time: 12.52s


Training Epoch 49:   4%|▍         | 1/26 [00:00<00:22,  1.11it/s, Loss=5.5294, Avg Loss=5.5294]

Epoch 49, Batch 0, Loss: 5.5294


Training Epoch 49:  42%|████▏     | 11/26 [00:03<00:04,  3.38it/s, Loss=5.5187, Avg Loss=5.5332]

Epoch 49, Batch 10, Loss: 5.5187


Training Epoch 49:  81%|████████  | 21/26 [00:06<00:01,  3.52it/s, Loss=5.5306, Avg Loss=5.5326]

Epoch 49, Batch 20, Loss: 5.5306


Training Epoch 49: 100%|██████████| 26/26 [00:08<00:00,  3.03it/s, Loss=5.5346, Avg Loss=5.5328]


Epoch 49/100 - Train Loss: 5.532756, LR: 2.50e-06, Time: 8.59s


Training Epoch 50:   4%|▍         | 1/26 [00:01<00:25,  1.04s/it, Loss=5.5234, Avg Loss=5.5234]

Epoch 50, Batch 0, Loss: 5.5234


Training Epoch 50:  42%|████▏     | 11/26 [00:04<00:05,  2.96it/s, Loss=5.5268, Avg Loss=5.5304]

Epoch 50, Batch 10, Loss: 5.5268


Training Epoch 50:  81%|████████  | 21/26 [00:07<00:01,  3.47it/s, Loss=5.5389, Avg Loss=5.5310]

Epoch 50, Batch 20, Loss: 5.5389


Training Epoch 50: 100%|██████████| 26/26 [00:08<00:00,  3.02it/s, Loss=5.5337, Avg Loss=5.5306]
Evaluating Epoch 50: 100%|██████████| 7/7 [00:01<00:00,  4.68it/s, Val Loss=5.2250]



Epoch 50/100
Train Loss: 5.530555
Val Loss: 5.510104
Val mean similarity: 0.0118
Learning Rate: 2.50e-06
Epoch Time: 10.12s, Total Time: 467.40s
------------------------------------------------------------


Training Epoch 51:   4%|▍         | 1/26 [00:00<00:24,  1.04it/s, Loss=5.5326, Avg Loss=5.5326]

Epoch 51, Batch 0, Loss: 5.5326


Training Epoch 51:  42%|████▏     | 11/26 [00:03<00:04,  3.39it/s, Loss=5.5304, Avg Loss=5.5329]

Epoch 51, Batch 10, Loss: 5.5304


Training Epoch 51:  81%|████████  | 21/26 [00:07<00:01,  2.88it/s, Loss=5.5269, Avg Loss=5.5319]

Epoch 51, Batch 20, Loss: 5.5269


Training Epoch 51: 100%|██████████| 26/26 [00:10<00:00,  2.58it/s, Loss=5.5222, Avg Loss=5.5319]


Epoch 51/100 - Train Loss: 5.531884, LR: 2.50e-06, Time: 10.09s


Training Epoch 52:   4%|▍         | 1/26 [00:00<00:24,  1.03it/s, Loss=5.5192, Avg Loss=5.5192]

Epoch 52, Batch 0, Loss: 5.5192


Training Epoch 52:  42%|████▏     | 11/26 [00:04<00:04,  3.25it/s, Loss=5.5340, Avg Loss=5.5342]

Epoch 52, Batch 10, Loss: 5.5340


Training Epoch 52:  81%|████████  | 21/26 [00:06<00:01,  3.72it/s, Loss=5.5210, Avg Loss=5.5314]

Epoch 52, Batch 20, Loss: 5.5210


Training Epoch 52: 100%|██████████| 26/26 [00:08<00:00,  3.04it/s, Loss=5.5313, Avg Loss=5.5321]


Epoch 52/100 - Train Loss: 5.532111, LR: 2.50e-06, Time: 8.54s


Training Epoch 53:   4%|▍         | 1/26 [00:00<00:21,  1.17it/s, Loss=5.5246, Avg Loss=5.5246]

Epoch 53, Batch 0, Loss: 5.5246


Training Epoch 53:  42%|████▏     | 11/26 [00:03<00:04,  3.23it/s, Loss=5.5280, Avg Loss=5.5285]

Epoch 53, Batch 10, Loss: 5.5280


Training Epoch 53:  81%|████████  | 21/26 [00:06<00:01,  3.44it/s, Loss=5.5424, Avg Loss=5.5297]

Epoch 53, Batch 20, Loss: 5.5424


Training Epoch 53: 100%|██████████| 26/26 [00:08<00:00,  3.10it/s, Loss=5.5500, Avg Loss=5.5302]


Epoch 53/100 - Train Loss: 5.530188, LR: 2.50e-06, Time: 8.40s


Training Epoch 54:   4%|▍         | 1/26 [00:00<00:21,  1.14it/s, Loss=5.5277, Avg Loss=5.5277]

Epoch 54, Batch 0, Loss: 5.5277


Training Epoch 54:  42%|████▏     | 11/26 [00:04<00:05,  2.86it/s, Loss=5.5308, Avg Loss=5.5296]

Epoch 54, Batch 10, Loss: 5.5308


Training Epoch 54:  81%|████████  | 21/26 [00:07<00:01,  3.65it/s, Loss=5.5157, Avg Loss=5.5293]

Epoch 54, Batch 20, Loss: 5.5157


Training Epoch 54: 100%|██████████| 26/26 [00:08<00:00,  2.95it/s, Loss=5.5220, Avg Loss=5.5302]


Epoch 54/100 - Train Loss: 5.530225, LR: 2.50e-06, Time: 8.81s


Training Epoch 55:   4%|▍         | 1/26 [00:00<00:22,  1.13it/s, Loss=5.5199, Avg Loss=5.5199]

Epoch 55, Batch 0, Loss: 5.5199


Training Epoch 55:  42%|████▏     | 11/26 [00:04<00:04,  3.02it/s, Loss=5.5286, Avg Loss=5.5330]

Epoch 55, Batch 10, Loss: 5.5286


Training Epoch 55:  81%|████████  | 21/26 [00:07<00:01,  3.47it/s, Loss=5.5298, Avg Loss=5.5333]

Epoch 55, Batch 20, Loss: 5.5298


Training Epoch 55: 100%|██████████| 26/26 [00:09<00:00,  2.66it/s, Loss=5.5325, Avg Loss=5.5322]
Evaluating Epoch 55: 100%|██████████| 7/7 [00:01<00:00,  4.68it/s, Val Loss=5.2248]



Epoch 55/100
Train Loss: 5.532163
Val Loss: 5.511070
Val mean similarity: 0.0050
Learning Rate: 2.50e-06
Epoch Time: 11.27s, Total Time: 515.39s
------------------------------------------------------------


Training Epoch 56:   4%|▍         | 1/26 [00:00<00:20,  1.21it/s, Loss=5.5460, Avg Loss=5.5460]

Epoch 56, Batch 0, Loss: 5.5460


Training Epoch 56:  42%|████▏     | 11/26 [00:03<00:04,  3.07it/s, Loss=5.5467, Avg Loss=5.5302]

Epoch 56, Batch 10, Loss: 5.5467


Training Epoch 56:  81%|████████  | 21/26 [00:06<00:01,  3.38it/s, Loss=5.5247, Avg Loss=5.5319]

Epoch 56, Batch 20, Loss: 5.5247


Training Epoch 56: 100%|██████████| 26/26 [00:08<00:00,  3.08it/s, Loss=5.5411, Avg Loss=5.5321]


Epoch 56/100 - Train Loss: 5.532070, LR: 2.50e-06, Time: 8.44s


Training Epoch 57:   4%|▍         | 1/26 [00:00<00:21,  1.19it/s, Loss=5.5215, Avg Loss=5.5215]

Epoch 57, Batch 0, Loss: 5.5215


Training Epoch 57:  42%|████▏     | 11/26 [00:03<00:04,  3.30it/s, Loss=5.5079, Avg Loss=5.5273]

Epoch 57, Batch 10, Loss: 5.5079


Training Epoch 57:  81%|████████  | 21/26 [00:06<00:01,  3.07it/s, Loss=5.5146, Avg Loss=5.5299]

Epoch 57, Batch 20, Loss: 5.5146


Training Epoch 57: 100%|██████████| 26/26 [00:08<00:00,  3.03it/s, Loss=5.5406, Avg Loss=5.5317]


Epoch 57/100 - Train Loss: 5.531710, LR: 2.50e-06, Time: 8.59s


Training Epoch 58:   4%|▍         | 1/26 [00:01<00:26,  1.08s/it, Loss=5.5280, Avg Loss=5.5280]

Epoch 58, Batch 0, Loss: 5.5280


Training Epoch 58:  42%|████▏     | 11/26 [00:04<00:04,  3.57it/s, Loss=5.5457, Avg Loss=5.5308]

Epoch 58, Batch 10, Loss: 5.5457


Training Epoch 58:  81%|████████  | 21/26 [00:07<00:01,  3.41it/s, Loss=5.5259, Avg Loss=5.5291]

Epoch 58, Batch 20, Loss: 5.5259


Training Epoch 58: 100%|██████████| 26/26 [00:08<00:00,  2.91it/s, Loss=5.5106, Avg Loss=5.5286]


Epoch 58/100 - Train Loss: 5.528578, LR: 2.50e-06, Time: 8.94s


Training Epoch 59:   4%|▍         | 1/26 [00:00<00:20,  1.21it/s, Loss=5.5179, Avg Loss=5.5179]

Epoch 59, Batch 0, Loss: 5.5179


Training Epoch 59:  42%|████▏     | 11/26 [00:04<00:05,  2.82it/s, Loss=5.5257, Avg Loss=5.5286]

Epoch 59, Batch 10, Loss: 5.5257


Training Epoch 59:  81%|████████  | 21/26 [00:08<00:01,  3.25it/s, Loss=5.5506, Avg Loss=5.5327]

Epoch 59, Batch 20, Loss: 5.5506


Training Epoch 59: 100%|██████████| 26/26 [00:09<00:00,  2.68it/s, Loss=5.5228, Avg Loss=5.5322]


Epoch 59/100 - Train Loss: 5.532243, LR: 2.50e-06, Time: 9.72s


Training Epoch 60:   4%|▍         | 1/26 [00:00<00:21,  1.14it/s, Loss=5.5432, Avg Loss=5.5432]

Epoch 60, Batch 0, Loss: 5.5432


Training Epoch 60:  42%|████▏     | 11/26 [00:03<00:04,  3.35it/s, Loss=5.5560, Avg Loss=5.5302]

Epoch 60, Batch 10, Loss: 5.5560


Training Epoch 60:  81%|████████  | 21/26 [00:06<00:01,  3.33it/s, Loss=5.5188, Avg Loss=5.5308]

Epoch 60, Batch 20, Loss: 5.5188


Training Epoch 60: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5371, Avg Loss=5.5308]
Evaluating Epoch 60: 100%|██████████| 7/7 [00:01<00:00,  5.05it/s, Val Loss=5.2236]



Epoch 60/100
Train Loss: 5.530762
Val Loss: 5.512265
Val mean similarity: 0.0007
Learning Rate: 2.50e-06
Epoch Time: 10.05s, Total Time: 561.14s
------------------------------------------------------------


Training Epoch 61:   4%|▍         | 1/26 [00:00<00:20,  1.19it/s, Loss=5.5354, Avg Loss=5.5354]

Epoch 61, Batch 0, Loss: 5.5354


Training Epoch 61:  42%|████▏     | 11/26 [00:04<00:04,  3.21it/s, Loss=5.5306, Avg Loss=5.5291]

Epoch 61, Batch 10, Loss: 5.5306


Training Epoch 61:  81%|████████  | 21/26 [00:07<00:01,  3.27it/s, Loss=5.5386, Avg Loss=5.5322]

Epoch 61, Batch 20, Loss: 5.5386


Training Epoch 61: 100%|██████████| 26/26 [00:08<00:00,  3.00it/s, Loss=5.5183, Avg Loss=5.5311]


Epoch 61/100 - Train Loss: 5.531056, LR: 2.50e-06, Time: 8.68s


Training Epoch 62:   4%|▍         | 1/26 [00:01<00:26,  1.08s/it, Loss=5.5277, Avg Loss=5.5277]

Epoch 62, Batch 0, Loss: 5.5277


Training Epoch 62:  42%|████▏     | 11/26 [00:03<00:04,  3.42it/s, Loss=5.5345, Avg Loss=5.5282]

Epoch 62, Batch 10, Loss: 5.5345


Training Epoch 62:  81%|████████  | 21/26 [00:06<00:01,  3.26it/s, Loss=5.5173, Avg Loss=5.5282]

Epoch 62, Batch 20, Loss: 5.5173


Training Epoch 62: 100%|██████████| 26/26 [00:08<00:00,  2.96it/s, Loss=5.5382, Avg Loss=5.5295]


Epoch 62/100 - Train Loss: 5.529541, LR: 2.50e-06, Time: 8.92s


Training Epoch 63:   4%|▍         | 1/26 [00:01<00:25,  1.01s/it, Loss=5.5234, Avg Loss=5.5234]

Epoch 63, Batch 0, Loss: 5.5234


Training Epoch 63:  42%|████▏     | 11/26 [00:05<00:05,  2.75it/s, Loss=5.5253, Avg Loss=5.5273]

Epoch 63, Batch 10, Loss: 5.5253


Training Epoch 63:  81%|████████  | 21/26 [00:08<00:01,  3.59it/s, Loss=5.5242, Avg Loss=5.5289]

Epoch 63, Batch 20, Loss: 5.5242


Training Epoch 63:  85%|████████▍ | 22/26 [00:08<00:01,  2.51it/s, Loss=5.5305, Avg Loss=5.5290]


KeyboardInterrupt: 

In [24]:
# Log final metrics
logger.report_single_value("Total Training Time (hours)", total_training_time)
logger.report_single_value("Best Validation Loss", min(train_history['val_losses']) if train_history['val_losses'] else float('inf'))
logger.report_single_value("Best Cosine Similarity", max(train_history['val_similarities']) if train_history['val_similarities'] else 0.0)
logger.report_single_value("Total Epochs", epochs)

In [31]:
# Cell 0: Emergency GPU cleanup and memory monitoring
import torch
import gc
import os
import psutil

def emergency_gpu_cleanup():
    """Emergency cleanup to free all GPU memory"""
    try:
        # Clear all cached tensors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            
        # Force garbage collection
        gc.collect()
        
        # Try to reset CUDA context (nuclear option)
        if torch.cuda.is_available():
            try:
                torch.cuda.reset_peak_memory_stats()
                print("CUDA memory stats reset")
            except:
                print("Could not reset CUDA memory stats")
                
    except Exception as e:
        print(f"Error during cleanup: {e}")

def print_gpu_memory_usage():
    """Print comprehensive GPU memory usage"""
    print("=" * 60)
    print("GPU MEMORY ANALYSIS")
    print("=" * 60)
    
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print(f"CUDA available: {device_count} device(s)")
        
        for i in range(device_count):
            try:
                torch.cuda.set_device(i)
                device_name = torch.cuda.get_device_name(i)
                
                # Memory in bytes
                allocated = torch.cuda.memory_allocated(i)
                reserved = torch.cuda.memory_reserved(i)
                max_allocated = torch.cuda.max_memory_allocated(i)
                max_reserved = torch.cuda.max_memory_reserved(i)
                
                # Get total GPU memory
                total_memory = torch.cuda.get_device_properties(i).total_memory
                
                print(f"\nGPU {i}: {device_name}")
                print("-" * 40)
                print(f"Total Memory:     {total_memory / 1024**3:.2f} GB")
                print(f"Currently Allocated: {allocated / 1024**3:.2f} GB ({allocated/total_memory*100:.1f}%)")
                print(f"Currently Reserved:  {reserved / 1024**3:.2f} GB ({reserved/total_memory*100:.1f}%)")
                print(f"Peak Allocated:   {max_allocated / 1024**3:.2f} GB ({max_allocated/total_memory*100:.1f}%)")
                print(f"Peak Reserved:    {max_reserved / 1024**3:.2f} GB ({max_reserved/total_memory*100:.1f}%)")
                print(f"Free Memory:      {(total_memory - reserved) / 1024**3:.2f} GB")
                
                # Memory fragmentation analysis
                if reserved > 0:
                    fragmentation = (reserved - allocated) / reserved * 100
                    print(f"Memory Fragmentation: {fragmentation:.1f}%")
                
            except Exception as e:
                print(f"Error reading GPU {i} memory: {e}")
                
    elif torch.backends.mps.is_available():
        print("MPS (Metal Performance Shaders) available")
        print("Note: MPS memory monitoring not directly available")
        
        # Get system memory as proxy
        mem = psutil.virtual_memory()
        print(f"System Memory: {mem.total / 1024**3:.2f} GB")
        print(f"Available Memory: {mem.available / 1024**3:.2f} GB")
        
    else:
        print("No GPU acceleration available - using CPU")
        
        # Show CPU memory usage
        mem = psutil.virtual_memory()
        print(f"System Memory: {mem.total / 1024**3:.2f} GB")
        print(f"Available Memory: {mem.available / 1024**3:.2f} GB")
        print(f"Used Memory: {mem.used / 1024**3:.2f} GB ({mem.percent:.1f}%)")

def print_model_memory_usage(model):
    """Calculate and print model memory usage"""
    print("\n" + "=" * 60)
    print("MODEL MEMORY ANALYSIS")
    print("=" * 60)
    
    # Parameter memory
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    
    # Count parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    
    print(f"Model Parameters:")
    print(f"  Trainable: {trainable_params:,}")
    print(f"  Total: {total_params:,}")
    print(f"  Non-trainable: {total_params - trainable_params:,}")
    
    print(f"\nModel Memory:")
    print(f"  Parameters: {param_size / 1024**2:.2f} MB")
    print(f"  Buffers: {buffer_size / 1024**2:.2f} MB")
    print(f"  Total Model: {(param_size + buffer_size) / 1024**2:.2f} MB")
    
    # Estimate training memory (rough approximation)
    # Forward pass ≈ model size, backward pass ≈ 2x model size, optimizer ≈ 2x model size
    estimated_training_memory = (param_size + buffer_size) * 5  # Conservative estimate
    print(f"  Estimated Training Memory: {estimated_training_memory / 1024**2:.2f} MB")

def monitor_batch_memory(device, batch_size):
    """Monitor memory usage during a training step"""
    if torch.cuda.is_available() and device == 'cuda':
        print(f"\nBatch Memory Analysis (Batch Size: {batch_size}):")
        print("-" * 40)
        
        # Memory before
        mem_before = torch.cuda.memory_allocated() / 1024**2
        
        # Simulate memory usage estimation
        print(f"Memory before batch: {mem_before:.2f} MB")
        
        return mem_before
    return 0

# Run emergency cleanup first
emergency_gpu_cleanup()

# Print initial memory status
print_gpu_memory_usage()

CUDA memory stats reset
GPU MEMORY ANALYSIS
CUDA available: 2 device(s)

GPU 0: NVIDIA A100 80GB PCIe
----------------------------------------
Total Memory:     79.25 GB
Currently Allocated: 0.00 GB (0.0%)
Currently Reserved:  0.00 GB (0.0%)
Peak Allocated:   0.00 GB (0.0%)
Peak Reserved:    0.00 GB (0.0%)
Free Memory:      79.25 GB

GPU 1: NVIDIA A100 80GB PCIe
----------------------------------------
Total Memory:     79.25 GB
Currently Allocated: 0.10 GB (0.1%)
Currently Reserved:  0.26 GB (0.3%)
Peak Allocated:   0.10 GB (0.1%)
Peak Reserved:    0.26 GB (0.3%)
Free Memory:      79.00 GB
Memory Fragmentation: 60.0%


In [32]:
# Close the task
task.close()