In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Cell 1: Import required libraries
import os
import yaml
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
from dotenv import load_dotenv
warnings.filterwarnings('ignore')

# load env variables
load_dotenv('../../../.env')

# Import your model and data loading components
from dataloader.dataset_wrapper import create_wrapper_from_dataframe

print("Libraries imported successfully!")

Libraries imported successfully!


In [None]:
# First install ClearML if not already installed: pip install clearml
import clearml
from clearml import Task, Logger

# Initialize ClearML Task
task = Task.init(project_name='CSMP_thesis_project', task_name='CSMP_traning_experiment_3', reuse_last_task_id=False)
logger = Logger.current_logger()

ClearML Task: created new task id=d961f7865a744c788142c5e566900c75
ClearML results page: https://app.clear.ml/projects/0fec81950d384f0294d2c713df3887db/experiments/d961f7865a744c788142c5e566900c75/output/log


ClearML Monitor: GPU monitoring failed getting GPU reading, switching off GPU monitoring


In [5]:
# Cell 2: Configuration and paths setup
CONFIG_PATH = "../../../configs/config.yaml"
TRAIN_CSV_PATH = "../../../data/traning_and_validation/train_deduplicated.csv"

# Device selection with automatic CUDA device selection based on free memory
if torch.backends.mps.is_available():
    DEVICE = 'mps'
    print("Using MPS (Metal Performance Shaders) for GPU acceleration")
elif torch.cuda.is_available():
    # Find CUDA device with most free memory
    max_free_memory = 0
    best_device = 0
    
    print("Available CUDA devices:")
    for i in range(torch.cuda.device_count()):
        torch.cuda.set_device(i)
        torch.cuda.empty_cache()  # Clear cache to get accurate memory info
        
        total_memory = torch.cuda.get_device_properties(i).total_memory
        allocated_memory = torch.cuda.memory_allocated(i)
        free_memory = total_memory - allocated_memory
        
        print(f"  Device {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Total memory: {total_memory / 1024**3:.2f} GB")
        print(f"    Allocated memory: {allocated_memory / 1024**3:.2f} GB")
        print(f"    Free memory: {free_memory / 1024**3:.2f} GB")
        
        if free_memory > max_free_memory:
            max_free_memory = free_memory
            best_device = i
    
    DEVICE = f'cuda:{best_device}'
    print(f"\nSelected device {best_device} with {max_free_memory / 1024**3:.2f} GB free memory")
    print("Using CUDA for GPU acceleration")
else:
    DEVICE = 'cpu'
    print("Using CPU")

print(f"Using device: {DEVICE}")
print(f"Validation data path: {TRAIN_CSV_PATH}")

Available CUDA devices:


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [7]:
# Cell 3: Load configuration
print("Loading configuration...")
config = yaml.load(open(CONFIG_PATH, "r"), Loader=yaml.FullLoader)

print("Configuration loaded:")
print(f"- Batch size: {config.get('batch_size', 64)}")
print(f"- Model config keys: {list(config.get('model', {}).keys())}")
print(f"- Loss config: {config.get('loss', {})}")

Loading configuration...
Configuration loaded:
- Batch size: 256
- Model config keys: []
- Loss config: {'temperature': 0.1, 'use_cosine_similarity': True, 'alpha_weight': 0.75}


In [6]:
task.connect(config)

{'batch_size': 1024,
 'epochs': 100,
 'eval_every_n_epochs': 5,
 'log_every_n_steps': 10,
 'learning_rate': '1e-05',
 'weight_decay': 0.0001,
 'valid_size': 0.2,
 'fp16_precision': True,
 'truncation': True,
 'model_config': {'emb_dim': 512,
  'spec_embed_dim': 512,
  'embed_dim': 256,
  'feat_dim': 512,
  'num_layer': 5,
  'layers': 5,
  'drop_ratio': 0.2,
  'dropout': 0.1,
  'pool': 'mean'},
 'loss': {'temperature': 0.1,
  'use_cosine_similarity': True,
  'alpha_weight': 0.75}}

In [7]:
# Cell 4: Load and explore validation data
print("Loading train data...")
df_train = pd.read_csv(TRAIN_CSV_PATH)

print(f"Validation dataset shape: {df_train.shape}")
print(f"Columns: {list(df_train.columns)}")
print(f"Sample data:")
df_train.head()

Loading train data...
Validation dataset shape: (798444, 10)
Columns: ['peaks_json', 'ion_source', 'compound_source', 'instrument', 'adduct', 'precursor_mz', 'smiles', 'inchikey', 'ion_mode', 'molecular_formula']
Sample data:


Unnamed: 0,peaks_json,ion_source,compound_source,instrument,adduct,precursor_mz,smiles,inchikey,ion_mode,molecular_formula
0,"[[42.014248, 0.10199999999999998], [42.26601, ...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
1,"[[49.01717, 0.155], [49.020023, 0.253], [67.05...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
2,"[[49.017338, 0.242], [49.020237, 0.181], [67.0...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
3,"[[49.01701, 0.144], [49.019947, 0.244], [139.0...",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3
4,"[[49.017166, 0.155], [49.020008, 0.253], [139....",ESI,Crude,Orbitrap,[M+H]+,377.186,CC12CCC(C(=O)N(CNc3cc4c(cc3)c3ccccc3o4)C1=O)C2...,RNKMIWQDRWSWCD-UHFFFAOYSA-N,Positive,C23H24N2O3


In [8]:
df_train_sample = df_train.sample(n=10000,random_state=42).reset_index(drop=True)

In [9]:
# Cell 6: Prepare validation data loader
print("Preparing data loaders")

# Create data wrapper from DataFrame
wrapper = create_wrapper_from_dataframe(
    df=df_train_sample,
    batch_size=config.get('batch_size'),  
    num_workers=8,
    valid_size=config.get('valid_size'),  
    use_ddp=False,
    output_dir="../../../data/train_feature/",
    recompute=True
)

# Get the data loader
train_loader, val_loader = wrapper.get_data_loaders()

Preparing data loaders
Converting DataFrame to compatible files...
Processed 10000 valid spectra out of 10000 total entries.
Create data wrapper
calculating molecular graphs


  6%|▌         | 453/8000 [00:00<00:12, 624.92it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure


 11%|█         | 845/8000 [00:01<00:11, 650.06it/s]

SMILES [Cl-].O=C(O)C=1C=CC=CC1C=2C=3C=CC(=CC3OC4=CC(C=CC42)=[N+](CC)CC)N(CC)CC calculation failure


 17%|█▋        | 1363/8000 [00:02<00:10, 641.48it/s]

SMILES [Na+].O=C(CCCCCCCCCCC)CC(O)S(=O)(=O)[O-] calculation failure


 21%|██        | 1684/8000 [00:02<00:10, 627.93it/s]

SMILES [Cl-].O=C1C(=COC2=C1C=C(C(O)=C2C[NH+](C)C)CC)C=3C=CC=4OCCOC4C3 calculation failure


 28%|██▊       | 2272/8000 [00:03<00:08, 646.47it/s]

SMILES [Cl-].O=C1C=2C=C(C(O)=C(C2OC(=C1C=3C=CC=4OCCOC4C3)C)C[NH+](C)C)CCC calculation failure


 37%|███▋      | 2985/8000 [00:04<00:08, 626.46it/s]

SMILES [Cl-].O=C1C2=CC=C(O)C(=C2OC(=C1C=3C=CC=4OCCCOC4C3)C)C[NH+](C)C calculation failure


 40%|███▉      | 3178/8000 [00:05<00:07, 630.96it/s]

SMILES [Na+].O=P([O-])(O)OCC1OC(N2C=NC=3C(=NC=NC32)N)C(O)C1O calculation failure


 45%|████▌     | 3624/8000 [00:05<00:06, 632.03it/s]

SMILES [Cl-].O=C1C2=CC=C(O)C(=C2OC(=C1C=3C=CC=4OCCCOC4C3)C)C[NH+](C)C calculation failure


 52%|█████▏    | 4138/8000 [00:06<00:06, 628.56it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure
SMILES [Cl-].OC=1C=C(O)C=2C=C(OC3OC(CO)C(O)C(O)C3O)C(=[O+]C2C1)C=4C=CC(O)=C(O)C4 calculation failure


 57%|█████▋    | 4590/8000 [00:07<00:05, 627.54it/s]

SMILES [Na+].O=C(CCCCCCCCCCC)CC(O)S(=O)(=O)[O-] calculation failure


 62%|██████▏   | 4974/8000 [00:07<00:04, 626.19it/s]

SMILES CCC1=C(C2=NC1=CC3=C(C4=C([N-]3)C(=C5[C@H]([C@@H](C(=N5)C=C6C(=C(C(=C2)[N-]6)C=C)C)C)CCC(=O)OC/C=C(\C)/CCC[C@H](C)CCC[C@H](C)CCCC(C)C)[C@H](C4=O)C(=O)OC)C)C=O.[Mg+2] calculation failure


 65%|██████▍   | 5163/8000 [00:08<00:04, 619.85it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure


 70%|███████   | 5603/8000 [00:08<00:03, 615.41it/s]

SMILES C1C(N(C2=C(N1)N=C(NC2=O)N)C=O)CNC3=CC=C(C=C3)C(=O)N[C@@H](CCC(=O)[O-])C(=O)[O-].[Ca+2] calculation failure


 76%|███████▋  | 6112/8000 [00:09<00:03, 616.31it/s]

SMILES [Na+].O=C([O-])C(CC)C1OC(C(=CC=CC2C=CC3CCCC3C2C(=O)C4=CC=CN4)CC)C(C)CC1 calculation failure
SMILES [K+].[K+].O=C([O-])C1OC(OC2C(OC(C(=O)[O-])C(O)C2O)OC3CCC4(C)C5C(=O)C=C6C7CC(C(=O)O)(C)CCC7(C)CCC6(C)C5(C)CCC4C3(C)C)C(O)C(O)C1O calculation failure


 89%|████████▉ | 7137/8000 [00:11<00:01, 627.23it/s]

SMILES [K+].O=C([O-])C12CCC(C(=C)C)C2C3CCC4C5(C)CCC(=O)C(C)(C)C5CCC4(C)C3(C)CC1 calculation failure


100%|██████████| 8000/8000 [00:12<00:00, 628.72it/s]


Calculated 6875 molecular graph-mass spectrometry pairs
calculating molecular graphs


  3%|▎         | 67/2000 [00:00<00:02, 665.67it/s]

SMILES [Na+].O=C(CCCCCCCCCCC)CC(O)S(=O)(=O)[O-] calculation failure


 16%|█▋        | 329/2000 [00:00<00:02, 632.05it/s]

SMILES [Br-].O=C(OC1CC2C3OC3C(C1)[N+]2(C)CCCC)C(C=4C=CC=CC4)CO calculation failure


 33%|███▎      | 658/2000 [00:01<00:02, 641.11it/s]

SMILES [I-].O=C(OCC1=CC[N+]2(C)CCC(O)C12)C(O)(C(O)C)C(C)C calculation failure
SMILES [K+].O=S(=O)([O-])ON=C(SC1OC(CO)C(O)C(O)C1O)CC=C.O calculation failure


 43%|████▎     | 859/2000 [00:01<00:01, 648.82it/s]

SMILES [Cl-].OC=1C=C(O)C=2C=C(OC3OC(CO)C(O)C(O)C3O)C(=[O+]C2C1)C=4C=CC(O)=C(O)C4 calculation failure


 69%|██████▉   | 1383/2000 [00:02<00:00, 619.04it/s]

SMILES [Cl-].O=C1C2=CC=C(O)C(=C2OC(=C1C=3C=CC=4OCCCOC4C3)C)C[NH+](C)C calculation failure
SMILES [K+].O=S(=O)([O-])ON=C(SC1OC(CO)C(O)C(O)C1O)CC=C.O calculation failure


100%|██████████| 2000/2000 [00:03<00:00, 636.89it/s]

Calculated 1697 molecular graph-mass spectrometry pairs





In [8]:
from model import ModelCLR

# Initialize model architecture
model = ModelCLR(**config["model_config"]).to(DEVICE)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [11]:
model

ModelCLR(
  (Smiles_model): SmilesModel(
    (x_embedding1): Embedding(119, 512)
    (x_embedding2): Embedding(4, 512)
    (x_embedding3): Embedding(8, 512)
    (x_embedding4): Embedding(6, 512)
    (x_embedding5): Embedding(5, 512)
    (gnns): ModuleList(
      (0-4): 5 x GINEConv()
    )
    (batch_norms): ModuleList(
      (0-4): 5 x BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (feat_lin): Linear(in_features=512, out_features=512, bias=True)
    (out_lin): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (MS_model): MSModel(
    (mz_embedder): FourierEmbedder()
    (input_compress): Linear(in_features=513, out_features=512, bias=True)
    (peak_attn_layers): ModuleList(
      (0-4): 5 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in

In [12]:
# Parameter Count
print("PARAMETER ANALYSIS:")
print("-" * 40)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Trainable Parameters: {trainable_params:,}")


PARAMETER ANALYSIS:
----------------------------------------
Trainable Parameters: 22,821,120


In [13]:
# Memory Usage (approximate)
print("MEMORY ANALYSIS:")
print("-" * 40)
param_size = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
model_size_mb = (param_size + buffer_size) / 1024 / 1024

print(f"Model Size: {model_size_mb:.2f} MB")
print(f"Parameter Memory: {param_size / 1024 / 1024:.2f} MB")
print(f"Buffer Memory: {buffer_size / 1024 / 1024:.2f} MB")


MEMORY ANALYSIS:
----------------------------------------
Model Size: 87.08 MB
Parameter Memory: 87.06 MB
Buffer Memory: 0.02 MB


In [14]:
from loss.nt_xent import NTXentLoss

# Initialize loss function
temperature = config.get('loss', {}).get('temperature', 0.1)
batch_size = config.get('batch_size', 512)
use_cosine_similarity = config.get('loss', {}).get('use_cosine_similarity', True)
alpha_weight = config.get('loss', {}).get('alpha_weight', 1.0)

criterion = NTXentLoss(
    device=DEVICE, 
    batch_size=batch_size, 
    temperature=temperature, 
    use_cosine_similarity=use_cosine_similarity, 
    alpha_weight=alpha_weight
)

In [15]:
# Cell 15: Training Setup and Optimizer
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

print("Setting up training components...")
OUTPUT_DIR = "../../../models/models_experiments/candidate_v1"

# Initialize optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=float(config.get('learning_rate', 1e-4)),
    weight_decay=float(config.get('weight_decay', 1e-4))
)

# Training configuration
epochs = config.get('epochs', 100)
eval_every_n_epochs = config.get('eval_every_n_epochs', 5)
log_every_n_steps = config.get('log_every_n_steps', 2)

# Add ReduceLROnPlateau scheduler for representation learning
scheduler = ReduceLROnPlateau(
    optimizer, 
    mode='min',           # minimize validation loss
    factor=0.5,           # reduce LR by half
    patience=3,             
    verbose=True,
    min_lr=1e-6,
    threshold=1e-4
)


# Create checkpoint directory
checkpoint_dir = os.path.join(OUTPUT_DIR, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

print(f"Training for {epochs} epochs")
print(f"Optimizer: AdamW with lr={config.get('learning_rate', 5e-6)}")
print(f"Checkpoint directory: {checkpoint_dir}")

Setting up training components...
Training for 100 epochs
Optimizer: AdamW with lr=1e-05
Checkpoint directory: ../../../models/models_experiments/candidate_v1/checkpoints


In [16]:
# Cell 16: Training and Evaluation Functions
def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    model.train()
    total_loss = 0
    num_batches = 0
    batch_losses = []
    
    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch}")
    
    for batch_idx, (graphs, mzs, intensities, num_peaks) in enumerate(progress_bar):
        
        # Move data to device
        graphs = graphs.to(device)
        mzs = mzs.to(device)
        intensities = intensities.to(device)
        num_peaks = num_peaks.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        mol_features, spec_features = model(graphs, mzs, intensities, num_peaks)
        loss = criterion(mol_features, spec_features)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Track loss
        batch_loss = loss.item()
        total_loss += batch_loss
        batch_losses.append(batch_loss)
        num_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Avg Loss': f'{total_loss/num_batches:.4f}'
        })
        
        # Log every n steps
        if batch_idx % log_every_n_steps == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss:.4f}")
    
    avg_loss = total_loss / num_batches
    return avg_loss, batch_losses

In [17]:
def evaluate_model(model, val_loader, criterion, device, epoch):
    model.eval()
    total_loss = 0
    num_batches = 0
    
    molecular_features_list = []
    spectral_features_list = []
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Evaluating Epoch {epoch}")
        
        for batch_idx, (graphs, mzs, intensities, num_peaks) in enumerate(progress_bar):
            # Move data to device
            graphs = graphs.to(device)
            mzs = mzs.to(device)
            intensities = intensities.to(device)
            num_peaks = num_peaks.to(device)
            
            # Forward pass
            mol_features, spec_features = model(graphs, mzs, intensities, num_peaks)
            loss = criterion(mol_features, spec_features)
            
            total_loss += loss.item()
            num_batches += 1
            
            # Normalize features before storing 
            mol_features_norm = torch.nn.functional.normalize(mol_features, p=2, dim=1)
            spec_features_norm = torch.nn.functional.normalize(spec_features, p=2, dim=1)
            
            # Store normalized features for retrieval metrics
            molecular_features_list.append(mol_features_norm.cpu().numpy())
            spectral_features_list.append(spec_features_norm.cpu().numpy())
            
            progress_bar.set_postfix({'Val Loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / num_batches
    
    # Compute retrieval metrics
    all_mol_features = np.vstack(molecular_features_list)
    all_spec_features = np.vstack(spectral_features_list)
        
    # Compute cosine similarities
    cosine_similarities = np.sum(all_mol_features * all_spec_features, axis=1)
    mean_similarity = np.mean(cosine_similarities)
        
    return avg_loss, mean_similarity

In [18]:
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, checkpoint_dir):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': config
    }
    
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")
    
    # Save best model
    best_checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
    if not os.path.exists(best_checkpoint_path):
        torch.save(checkpoint, best_checkpoint_path)
        print(f"Best model saved: {best_checkpoint_path}")
    else:
        best_checkpoint = torch.load(best_checkpoint_path)
        if val_loss < best_checkpoint['val_loss']:
            torch.save(checkpoint, best_checkpoint_path)
            print(f"New best model saved: {best_checkpoint_path}")

In [19]:
import time

# Training history for plotting
train_history = {
    'epochs': [],
    'train_losses': [],
    'val_losses': [],
    'val_similarities': [],
    'learning_rates': [],  
}

best_val_loss = float('inf')
start_time = time.time()

print("Starting training...")
print("=" * 60)

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    
    # Training phase
    train_loss, batch_losses = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE, epoch
    )
    
    # Get current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    
    
    # Log training loss every epoch
    logger.report_scalar("Loss", "Train", iteration=epoch, value=train_loss)
    logger.report_scalar("Learning Rate", "LR", iteration=epoch, value=current_lr)
    
    
    # Validation phase (every n epochs)
    if epoch % eval_every_n_epochs == 0:
        val_loss, val_similarity = evaluate_model(
            model, val_loader, criterion, DEVICE, epoch
        )
        
        # Step scheduler with validation loss
        scheduler.step(val_loss)
        
        # Update current_lr after scheduler step
        current_lr = optimizer.param_groups[0]['lr']
        
        # Log results
        epoch_time = time.time() - epoch_start_time
        total_time = time.time() - start_time
        
        # Log validation metrics to ClearML
        logger.report_scalar("Loss", "Validation", iteration=epoch, value=val_loss)
        logger.report_scalar("Similarity", "Cosine Similarity", iteration=epoch, value=val_similarity)
        
        print(f"\nEpoch {epoch}/{epochs}")
        print(f"Train Loss: {train_loss:.6f}")
        print(f"Val Loss: {val_loss:.6f}")
        print(f"Val mean similarity: {val_similarity:.4f}")
        print(f"Learning Rate: {current_lr:.2e}")
        print(f"Epoch Time: {epoch_time:.2f}s, Total Time: {total_time:.2f}s")
        print("-" * 60)
        
        # Store history
        train_history['epochs'].append(int(epoch))
        train_history['train_losses'].append(float(train_loss))
        train_history['val_losses'].append(float(val_loss))
        train_history['val_similarities'].append(float(val_similarity))
        train_history['learning_rates'].append(float(current_lr))
        
        # Save checkpoint
        # save_checkpoint(
        #     model, optimizer, epoch, 
        #     train_loss, val_loss, checkpoint_dir
        # )
        
        if epoch % (eval_every_n_epochs * 2) == 0 and len(train_history['epochs']) > 1:
            import matplotlib.pyplot as plt
            
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
            
            # Loss plot
            epochs_list = train_history['epochs']
            ax1.plot(epochs_list, train_history['train_losses'], 'b-', label='Train Loss', linewidth=2)
            ax1.plot(epochs_list, train_history['val_losses'], 'r-', label='Val Loss', linewidth=2)
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            
            # Similarity plot
            ax2.plot(epochs_list, train_history['val_similarities'], 'g-', linewidth=2)
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Cosine Similarity')
            ax2.set_title('Validation Cosine Similarity')
            ax2.grid(True, alpha=0.3)
            
            # Learning rate plot
            ax3.plot(epochs_list, train_history['learning_rates'], 'purple', linewidth=2)
            ax3.set_xlabel('Epoch')
            ax3.set_ylabel('Learning Rate')
            ax3.set_title('Learning Rate Schedule')
            ax3.set_yscale('log') 
            ax3.grid(True, alpha=0.3)
            
            plt.tight_layout()
            
            logger.report_matplotlib_figure("Training Progress", "Loss, Similarity and LR", iteration=epoch, figure=plt)
            plt.close()
    
    else:
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.6f}, LR: {current_lr:.2e}, Time: {epoch_time:.2f}s")

# Final summary logging
print("\nTraining completed!")
total_training_time = (time.time() - start_time)/3600
print(f"Total training time: {total_training_time:.2f} hours")

Starting training...


Training Epoch 1:   0%|          | 0/6 [00:00<?, ?it/s]

Training Epoch 1:   0%|          | 0/6 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.34 GiB. GPU 1 has a total capacity of 79.25 GiB of which 1.79 GiB is free. Process 3080882 has 3.27 GiB memory in use. Process 3128536 has 2.51 GiB memory in use. Process 3206825 has 414.00 MiB memory in use. Process 3207471 has 414.00 MiB memory in use. Including non-PyTorch memory, this process has 70.85 GiB memory in use. Of the allocated memory 66.58 GiB is allocated by PyTorch, and 3.78 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Log final metrics
logger.report_single_value("Total Training Time (hours)", total_training_time)
logger.report_single_value("Best Validation Loss", min(train_history['val_losses']) if train_history['val_losses'] else float('inf'))
logger.report_single_value("Best Cosine Similarity", max(train_history['val_similarities']) if train_history['val_similarities'] else 0.0)
logger.report_single_value("Total Epochs", epochs)

In [None]:
# Cell 0: Emergency GPU cleanup and memory monitoring
import torch
import gc
import os
import psutil

def emergency_gpu_cleanup():
    """Emergency cleanup to free all GPU memory"""
    try:
        # Clear all cached tensors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            
        # Force garbage collection
        gc.collect()
        
        # Try to reset CUDA context (nuclear option)
        if torch.cuda.is_available():
            try:
                torch.cuda.reset_peak_memory_stats()
                print("CUDA memory stats reset")
            except:
                print("Could not reset CUDA memory stats")
                
    except Exception as e:
        print(f"Error during cleanup: {e}")

def print_gpu_memory_usage():
    """Print comprehensive GPU memory usage"""
    print("=" * 60)
    print("GPU MEMORY ANALYSIS")
    print("=" * 60)
    
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print(f"CUDA available: {device_count} device(s)")
        
        for i in range(device_count):
            try:
                torch.cuda.set_device(i)
                device_name = torch.cuda.get_device_name(i)
                
                # Memory in bytes
                allocated = torch.cuda.memory_allocated(i)
                reserved = torch.cuda.memory_reserved(i)
                max_allocated = torch.cuda.max_memory_allocated(i)
                max_reserved = torch.cuda.max_memory_reserved(i)
                
                # Get total GPU memory
                total_memory = torch.cuda.get_device_properties(i).total_memory
                
                print(f"\nGPU {i}: {device_name}")
                print("-" * 40)
                print(f"Total Memory:     {total_memory / 1024**3:.2f} GB")
                print(f"Currently Allocated: {allocated / 1024**3:.2f} GB ({allocated/total_memory*100:.1f}%)")
                print(f"Currently Reserved:  {reserved / 1024**3:.2f} GB ({reserved/total_memory*100:.1f}%)")
                print(f"Peak Allocated:   {max_allocated / 1024**3:.2f} GB ({max_allocated/total_memory*100:.1f}%)")
                print(f"Peak Reserved:    {max_reserved / 1024**3:.2f} GB ({max_reserved/total_memory*100:.1f}%)")
                print(f"Free Memory:      {(total_memory - reserved) / 1024**3:.2f} GB")
                
                # Memory fragmentation analysis
                if reserved > 0:
                    fragmentation = (reserved - allocated) / reserved * 100
                    print(f"Memory Fragmentation: {fragmentation:.1f}%")
                
            except Exception as e:
                print(f"Error reading GPU {i} memory: {e}")
                
    elif torch.backends.mps.is_available():
        print("MPS (Metal Performance Shaders) available")
        print("Note: MPS memory monitoring not directly available")
        
        # Get system memory as proxy
        mem = psutil.virtual_memory()
        print(f"System Memory: {mem.total / 1024**3:.2f} GB")
        print(f"Available Memory: {mem.available / 1024**3:.2f} GB")
        
    else:
        print("No GPU acceleration available - using CPU")
        
        # Show CPU memory usage
        mem = psutil.virtual_memory()
        print(f"System Memory: {mem.total / 1024**3:.2f} GB")
        print(f"Available Memory: {mem.available / 1024**3:.2f} GB")
        print(f"Used Memory: {mem.used / 1024**3:.2f} GB ({mem.percent:.1f}%)")

def print_model_memory_usage(model):
    """Calculate and print model memory usage"""
    print("\n" + "=" * 60)
    print("MODEL MEMORY ANALYSIS")
    print("=" * 60)
    
    # Parameter memory
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    
    # Count parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    
    print(f"Model Parameters:")
    print(f"  Trainable: {trainable_params:,}")
    print(f"  Total: {total_params:,}")
    print(f"  Non-trainable: {total_params - trainable_params:,}")
    
    print(f"\nModel Memory:")
    print(f"  Parameters: {param_size / 1024**2:.2f} MB")
    print(f"  Buffers: {buffer_size / 1024**2:.2f} MB")
    print(f"  Total Model: {(param_size + buffer_size) / 1024**2:.2f} MB")
    
    # Estimate training memory (rough approximation)
    # Forward pass ≈ model size, backward pass ≈ 2x model size, optimizer ≈ 2x model size
    estimated_training_memory = (param_size + buffer_size) * 5  # Conservative estimate
    print(f"  Estimated Training Memory: {estimated_training_memory / 1024**2:.2f} MB")

def monitor_batch_memory(device, batch_size):
    """Monitor memory usage during a training step"""
    if torch.cuda.is_available() and device == 'cuda':
        print(f"\nBatch Memory Analysis (Batch Size: {batch_size}):")
        print("-" * 40)
        
        # Memory before
        mem_before = torch.cuda.memory_allocated() / 1024**2
        
        # Simulate memory usage estimation
        print(f"Memory before batch: {mem_before:.2f} MB")
        
        return mem_before
    return 0

# Run emergency cleanup first
emergency_gpu_cleanup()

# Print initial memory status
print_gpu_memory_usage()

Error during cleanup: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

GPU MEMORY ANALYSIS
CUDA available: 2 device(s)

GPU 0: NVIDIA A100 80GB PCIe
----------------------------------------
Total Memory:     79.25 GB
Currently Allocated: 0.00 GB (0.0%)
Currently Reserved:  0.00 GB (0.0%)
Peak Allocated:   0.00 GB (0.0%)
Peak Reserved:    0.00 GB (0.0%)
Free Memory:      79.25 GB

GPU 1: NVIDIA A100 80GB PCIe
----------------------------------------
Total Memory:     79.25 GB
Currently Allocated: 70.24 GB (88.6%)
Currently Reserved:  72.71 GB (91.7%)
Peak Allocated:   70.24 GB (88.6%)
Peak Reserved:    72.71 GB (91.7%)
Free Memory:      6.55 GB
Memory Fragmentation: 3.4%


In [None]:
# Close the task
task.close()