# Main

# Train Fraud

## Conv VAE

### Fraud

In [2]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, torch CPU and CUDA) with one value
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

# cuDNN: disable the benchmark autotuner and request deterministic algorithms
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.conv_vae_model import ConvVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/conv_vae/fraud_conv_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() skips files it cannot open and returns the list of files
# it actually parsed, so an empty result means the config was never loaded.
# Fail here with a clear message instead of a later KeyError on "Conv_VAE".
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Could not read config file: {config_path}")
conv_vae_config = config_parser["Conv_VAE"]

# Snapshot every config section as a plain dict for the WandB logger
config_dict = {name: dict(config_parser[name]) for name in config_parser.sections()}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict, replacing the WandB snapshot built
# above; the new value is not read again in this cell. Kept because
# load_config may have side effects -- confirm before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
# NOTE(review): input_dim is not read again in this cell; kept for later cells.
input_dim = data['input_dim']

# Create model from the [Conv_VAE] section and report its size
model = ConvVae(conv_vae_config)
print_num_params(model)

# Training parameters come from the [Trainer] section
trainer_section = config_parser["Trainer"]
lr = trainer_section.getfloat("lr")
num_epochs = trainer_section.getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Run the full training session: train/validate each epoch, log to WandB,
# checkpoint on every new best validation loss, and always close the WandB run.
try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log to WandB (if enabled). The reconstruction losses are taken from
        # a single batch of each loader; if anything in that path fails, fall
        # back to logging only the epoch totals, but say so instead of hiding it.
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            print(f"Warning: reconstruction-loss logging failed at epoch {epoch}: {e}")
            wandb_logger.log_epoch(epoch, train_loss, val_loss)

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best. The emptiness guard covers
    # num_epochs == 0, where there is no last epoch to index.
    if val_losses and val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only checkpoint if at least one epoch's training loss was recorded
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None,
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    # Report the failure with a full traceback; the error is not re-raised,
    # so the finally block below still runs and the cell ends normally.
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit are not swallowed, and report the failure.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Warning: WandB cleanup failed: {e}")
    print("Training session complete.")

Initializing WandB run: conv-vae-fraud_20250313_173404 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/conv_vae/fraud_conv_vae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 62727
Models will be saved to: saved_models/conv_vae/fraud_conv_vae/20250313_173415


                                                                           

Epoch 1/150: Train Loss = 704.4488, Val Loss = 1022.8286
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.8286


                                                                           

Epoch 2/150: Train Loss = 679.0434, Val Loss = 1022.5550
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5550


                                                                           

Epoch 3/150: Train Loss = 678.9731, Val Loss = 1022.5265
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5265


                                                                           

Epoch 4/150: Train Loss = 678.9738, Val Loss = 1022.5521


                                                                           

Epoch 5/150: Train Loss = 678.9614, Val Loss = 1022.5346


                                                                           

Epoch 6/150: Train Loss = 678.9674, Val Loss = 1022.5279


                                                                           

Epoch 7/150: Train Loss = 678.9592, Val Loss = 1022.5274


                                                                           

Epoch 8/150: Train Loss = 678.9613, Val Loss = 1022.5326


                                                                           

Epoch 9/150: Train Loss = 678.9659, Val Loss = 1022.5347


                                                                           

Epoch 10/150: Train Loss = 678.9616, Val Loss = 1022.5378


                                                                           

Epoch 11/150: Train Loss = 678.9632, Val Loss = 1022.5346


                                                                           

Epoch 12/150: Train Loss = 678.9609, Val Loss = 1022.5292


                                                                           

Epoch 13/150: Train Loss = 678.9595, Val Loss = 1022.5272


                                                                           

Epoch 14/150: Train Loss = 678.9634, Val Loss = 1022.5356


                                                                           

Epoch 15/150: Train Loss = 678.9615, Val Loss = 1022.5340


                                                                           

Epoch 16/150: Train Loss = 678.9606, Val Loss = 1022.5270


                                                                           

Epoch 17/150: Train Loss = 678.9582, Val Loss = 1022.5342


                                                                           

Epoch 18/150: Train Loss = 678.9600, Val Loss = 1022.5373


                                                                           

Epoch 19/150: Train Loss = 678.9630, Val Loss = 1022.5355


                                                                           

Epoch 20/150: Train Loss = 678.9585, Val Loss = 1022.5329


                                                                           

Epoch 21/150: Train Loss = 678.9605, Val Loss = 1022.5376


                                                                           

Epoch 22/150: Train Loss = 678.9619, Val Loss = 1022.5338


                                                                           

Epoch 23/150: Train Loss = 678.9579, Val Loss = 1022.5380


                                                                           

Epoch 24/150: Train Loss = 678.9638, Val Loss = 1022.5348


                                                                           

Epoch 25/150: Train Loss = 678.9619, Val Loss = 1022.5290


                                                                           

Epoch 26/150: Train Loss = 678.9622, Val Loss = 1022.5253
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5253


                                                                           

Epoch 27/150: Train Loss = 678.9598, Val Loss = 1022.5322


                                                                           

Epoch 28/150: Train Loss = 678.9614, Val Loss = 1022.5257


                                                                           

Epoch 29/150: Train Loss = 678.9601, Val Loss = 1022.5303


                                                                           

Epoch 30/150: Train Loss = 678.9610, Val Loss = 1022.5331


                                                                           

Epoch 31/150: Train Loss = 678.9617, Val Loss = 1022.5251
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5251


                                                                           

Epoch 32/150: Train Loss = 678.9612, Val Loss = 1022.5330


                                                                           

Epoch 33/150: Train Loss = 678.9610, Val Loss = 1022.5312


                                                                           

Epoch 34/150: Train Loss = 678.9616, Val Loss = 1022.5319


                                                                           

Epoch 35/150: Train Loss = 678.9608, Val Loss = 1022.5320


                                                                           

Epoch 36/150: Train Loss = 678.9605, Val Loss = 1022.5306


                                                                           

Epoch 37/150: Train Loss = 678.9602, Val Loss = 1022.5325


                                                                           

Epoch 38/150: Train Loss = 678.9607, Val Loss = 1022.5325


                                                                           

Epoch 39/150: Train Loss = 678.9605, Val Loss = 1022.5303


                                                                           

Epoch 40/150: Train Loss = 678.9611, Val Loss = 1022.5281


                                                                           

Epoch 41/150: Train Loss = 678.9596, Val Loss = 1022.5290


                                                                           

Epoch 42/150: Train Loss = 678.9612, Val Loss = 1022.5331


                                                                           

Epoch 43/150: Train Loss = 678.9608, Val Loss = 1022.5286


                                                                           

Epoch 44/150: Train Loss = 678.9612, Val Loss = 1022.5280


                                                                           

Epoch 45/150: Train Loss = 678.9613, Val Loss = 1022.5333


                                                                           

Epoch 46/150: Train Loss = 678.9593, Val Loss = 1022.5317


                                                                           

Epoch 47/150: Train Loss = 678.9613, Val Loss = 1022.5313


                                                                           

Epoch 48/150: Train Loss = 678.9618, Val Loss = 1022.5286


                                                                           

Epoch 49/150: Train Loss = 678.9620, Val Loss = 1022.5313


                                                                           

Epoch 50/150: Train Loss = 678.9610, Val Loss = 1022.5315


                                                                           

Epoch 51/150: Train Loss = 678.9607, Val Loss = 1022.5331


                                                                           

Epoch 52/150: Train Loss = 678.9609, Val Loss = 1022.5324


                                                                           

Epoch 53/150: Train Loss = 678.9611, Val Loss = 1022.5312


                                                                           

Epoch 54/150: Train Loss = 678.9603, Val Loss = 1022.5318


                                                                           

Epoch 55/150: Train Loss = 678.9613, Val Loss = 1022.5318


                                                                           

Epoch 56/150: Train Loss = 678.9609, Val Loss = 1022.5315


                                                                           

Epoch 57/150: Train Loss = 678.9604, Val Loss = 1022.5309


                                                                           

Epoch 58/150: Train Loss = 678.9604, Val Loss = 1022.5290


                                                                           

Epoch 59/150: Train Loss = 678.9603, Val Loss = 1022.5314


                                                                           

Epoch 60/150: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 61/150: Train Loss = 678.9609, Val Loss = 1022.5329


                                                                           

Epoch 62/150: Train Loss = 678.9610, Val Loss = 1022.5316


                                                                           

Epoch 63/150: Train Loss = 678.9608, Val Loss = 1022.5316


                                                                           

Epoch 64/150: Train Loss = 678.9608, Val Loss = 1022.5321


                                                                           

Epoch 65/150: Train Loss = 678.9609, Val Loss = 1022.5323


                                                                           

Epoch 66/150: Train Loss = 678.9607, Val Loss = 1022.5312


                                                                           

Epoch 67/150: Train Loss = 678.9605, Val Loss = 1022.5317


                                                                           

Epoch 68/150: Train Loss = 678.9603, Val Loss = 1022.5306


                                                                           

Epoch 69/150: Train Loss = 678.9601, Val Loss = 1022.5305


                                                                           

Epoch 70/150: Train Loss = 678.9615, Val Loss = 1022.5304


                                                                           

Epoch 71/150: Train Loss = 678.9589, Val Loss = 1022.5275


                                                                           

Epoch 72/150: Train Loss = 671.6301, Val Loss = 1006.9674
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1006.9674


                                                                           

Epoch 73/150: Train Loss = 665.9213, Val Loss = 1006.9708


                                                                           

Epoch 74/150: Train Loss = 665.9213, Val Loss = 1006.9710


                                                                           

Epoch 75/150: Train Loss = 665.9195, Val Loss = 1006.9802


                                                                           

Epoch 76/150: Train Loss = 665.9198, Val Loss = 1006.9707


                                                                           

Epoch 77/150: Train Loss = 665.9196, Val Loss = 1007.1178


                                                                           

Epoch 78/150: Train Loss = 665.9198, Val Loss = 1006.9691


                                                                           

Epoch 79/150: Train Loss = 665.9195, Val Loss = 1006.9674


                                                                           

Epoch 80/150: Train Loss = 665.9188, Val Loss = 1006.9704


                                                                           

Epoch 81/150: Train Loss = 665.9202, Val Loss = 1006.9702


                                                                           

Epoch 82/150: Train Loss = 665.9192, Val Loss = 1006.9699


                                                                           

Epoch 83/150: Train Loss = 665.9195, Val Loss = 1006.9697


                                                                           

Epoch 84/150: Train Loss = 665.9199, Val Loss = 1006.9696


                                                                           

Epoch 85/150: Train Loss = 665.9195, Val Loss = 1006.9710


                                                                           

Epoch 86/150: Train Loss = 665.9356, Val Loss = 1006.9724


                                                                           

Epoch 87/150: Train Loss = 665.9226, Val Loss = 1006.9698


                                                                           

Epoch 88/150: Train Loss = 665.9193, Val Loss = 1006.9711


                                                                           

Epoch 89/150: Train Loss = 665.9196, Val Loss = 1006.9689


                                                                           

Epoch 90/150: Train Loss = 665.9197, Val Loss = 1006.9696


                                                                           

Epoch 91/150: Train Loss = 665.9192, Val Loss = 1006.9713


                                                                           

Epoch 92/150: Train Loss = 665.9197, Val Loss = 1006.9709


                                                                           

Epoch 93/150: Train Loss = 665.9194, Val Loss = 1006.9685


                                                                           

Epoch 94/150: Train Loss = 665.9196, Val Loss = 1006.9702


                                                                           

Epoch 95/150: Train Loss = 665.9197, Val Loss = 1006.9694


                                                                           

Epoch 96/150: Train Loss = 665.9200, Val Loss = 1006.9695


                                                                           

Epoch 97/150: Train Loss = 665.9195, Val Loss = 1006.9704


                                                                           

Epoch 98/150: Train Loss = 665.9193, Val Loss = 1006.9701


                                                                           

Epoch 99/150: Train Loss = 665.9199, Val Loss = 1006.9702


                                                                           

Epoch 100/150: Train Loss = 665.9192, Val Loss = 1006.9706


                                                                           

Epoch 101/150: Train Loss = 665.9195, Val Loss = 1006.9701


                                                                           

Epoch 102/150: Train Loss = 665.9194, Val Loss = 1006.9707


                                                                           

Epoch 103/150: Train Loss = 665.9196, Val Loss = 1006.9698


                                                                           

Epoch 104/150: Train Loss = 665.9193, Val Loss = 1006.9697


                                                                           

Epoch 105/150: Train Loss = 665.9192, Val Loss = 1006.9699


                                                                           

Epoch 106/150: Train Loss = 665.9192, Val Loss = 1006.9694


                                                                           

Epoch 107/150: Train Loss = 665.9197, Val Loss = 1006.9702


                                                                           

Epoch 108/150: Train Loss = 665.9195, Val Loss = 1006.9695


                                                                           

Epoch 109/150: Train Loss = 665.9195, Val Loss = 1006.9697


                                                                           

Epoch 110/150: Train Loss = 665.9197, Val Loss = 1006.9704


                                                                           

Epoch 111/150: Train Loss = 665.9190, Val Loss = 1006.9696


                                                                           

Epoch 112/150: Train Loss = 665.9196, Val Loss = 1006.9696


                                                                           

Epoch 113/150: Train Loss = 665.9193, Val Loss = 1006.9692


                                                                           

Epoch 114/150: Train Loss = 665.9196, Val Loss = 1006.9699


                                                                           

Epoch 115/150: Train Loss = 665.9197, Val Loss = 1006.9702


                                                                           

Epoch 116/150: Train Loss = 665.9192, Val Loss = 1006.9700


                                                                           

Epoch 117/150: Train Loss = 665.9196, Val Loss = 1006.9697


                                                                           

Epoch 118/150: Train Loss = 665.9195, Val Loss = 1006.9698


                                                                           

Epoch 119/150: Train Loss = 665.9194, Val Loss = 1006.9695


                                                                           

Epoch 120/150: Train Loss = 665.9193, Val Loss = 1006.9700


                                                                           

Epoch 121/150: Train Loss = 665.9194, Val Loss = 1006.9697


                                                                           

Epoch 122/150: Train Loss = 665.9189, Val Loss = 1006.9681


                                                                           

Epoch 123/150: Train Loss = 665.9222, Val Loss = 1006.9705


                                                                           

Epoch 124/150: Train Loss = 665.9194, Val Loss = 1006.9702


                                                                           

Epoch 125/150: Train Loss = 665.9191, Val Loss = 1006.9697


                                                                           

Epoch 126/150: Train Loss = 665.9199, Val Loss = 1006.9695


                                                                           

Epoch 127/150: Train Loss = 665.9197, Val Loss = 1006.9706


                                                                           

Epoch 128/150: Train Loss = 665.9198, Val Loss = 1006.9698


                                                                           

Epoch 129/150: Train Loss = 665.9190, Val Loss = 1006.9698


                                                                           

Epoch 130/150: Train Loss = 665.9191, Val Loss = 1006.9701


                                                                           

Epoch 131/150: Train Loss = 665.9199, Val Loss = 1006.9701


                                                                           

Epoch 132/150: Train Loss = 665.9194, Val Loss = 1006.9697


                                                                           

Epoch 133/150: Train Loss = 665.9199, Val Loss = 1006.9691


                                                                           

Epoch 134/150: Train Loss = 665.9196, Val Loss = 1006.9699


                                                                           

Epoch 135/150: Train Loss = 665.9195, Val Loss = 1006.9694


                                                                           

Epoch 136/150: Train Loss = 665.9193, Val Loss = 1006.9699


                                                                           

Epoch 137/150: Train Loss = 665.9192, Val Loss = 1006.9693


                                                                           

Epoch 138/150: Train Loss = 665.9194, Val Loss = 1006.9704


                                                                           

Epoch 139/150: Train Loss = 665.9193, Val Loss = 1006.9701


                                                                           

Epoch 140/150: Train Loss = 665.9196, Val Loss = 1006.9702


                                                                           

Epoch 141/150: Train Loss = 665.9194, Val Loss = 1006.9698


                                                                           

Epoch 142/150: Train Loss = 665.9193, Val Loss = 1006.9702


                                                                           

Epoch 143/150: Train Loss = 665.9194, Val Loss = 1006.9699


                                                                           

Epoch 144/150: Train Loss = 665.9193, Val Loss = 1006.9700


                                                                           

Epoch 145/150: Train Loss = 665.9191, Val Loss = 1006.9692


                                                                           

Epoch 146/150: Train Loss = 665.9194, Val Loss = 1006.9695


                                                                           

Epoch 147/150: Train Loss = 665.9191, Val Loss = 1006.9704


                                                                           

Epoch 148/150: Train Loss = 665.9193, Val Loss = 1006.9692


                                                                           

Epoch 149/150: Train Loss = 665.9197, Val Loss = 1006.9706


                                                                           

Epoch 150/150: Train Loss = 665.9197, Val Loss = 1006.9702
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/final_model.pt
Training complete. Best validation loss: 1006.9674


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/recon_loss,▄▄▃▄▂█▂▁▃▃▄▄▃▃▃▃▃▇▆▁▁▂▄▅▄▄▂▁▄▄▄▄▃▅▃▅▂▆▆▁
train/total_loss,█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,150.0
train/recon_loss,4001.76562
train/total_loss,665.91969
val/recon_loss,10876.7793
val/total_loss,1006.97016


Training session complete.


### Normal

In [1]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, torch CPU and CUDA) with one value
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

# cuDNN: disable the benchmark autotuner and request deterministic algorithms
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.conv_vae_model import ConvVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/conv_vae/normal_conv_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() skips files it cannot open and returns the list of files
# it actually parsed, so an empty result means the config was never loaded.
# Fail here with a clear message instead of a later KeyError on "Conv_VAE".
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Could not read config file: {config_path}")
conv_vae_config = config_parser["Conv_VAE"]

# Snapshot every config section as a plain dict for the WandB logger
config_dict = {name: dict(config_parser[name]) for name in config_parser.sections()}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict, replacing the WandB snapshot built
# above; the new value is not read again in the visible part of this cell.
# Kept because load_config may have side effects -- confirm before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
# NOTE(review): input_dim is not read again in the visible part of this cell.
input_dim = data['input_dim']

# Create model from the [Conv_VAE] section and report its size
model = ConvVae(conv_vae_config)
print_num_params(model)

# Training parameters come from the [Trainer] section
trainer_section = config_parser["Trainer"]
lr = trainer_section.getfloat("lr")
num_epochs = trainer_section.getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Run the training loop. Whatever happens (normal finish, Ctrl-C, or an
# error), the `finally` clause below still closes the WandB run.
try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log to WandB (if enabled). The reconstruction losses are computed on
        # only the first batch of a fresh iterator over each loader, so they
        # are single-batch values rather than epoch aggregates.
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            # Best effort: still log the total losses, but report why the
            # reconstruction losses are missing instead of hiding the error.
            print(f"Warning: reconstruction-loss logging failed at epoch {epoch}: {e}")
            wandb_logger.log_epoch(epoch, train_loss, val_loss)

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best. The emptiness check avoids an
    # IndexError when num_epochs <= 0 and no epoch ran.
    if val_losses and val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only save if at least one full epoch finished
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. A failure here must not mask the training outcome, but
    # it is reported rather than silently dropped; catching Exception (not a
    # bare except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Warning: failed to close WandB run: {e}")
    print("Training session complete.")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Initializing WandB run: conv-vae-normal_20250313_175704 (Project: fraud-classification, Entity: alexkstern)


[34m[1mwandb[0m: Currently logged in as: [33malexkstern[0m ([33malexksternteam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loaded configuration from configs/conv_vae/normal_conv_vae.config
Filtered dataset to class 0: 12000 samples
Normalization statistics (calculated from class 0): {'Time': {'mean': 94364.65358333333, 'std': 47365.815157589255}, 'Amount': {'mean': 87.60478666666666, 'std': 240.59403081682598}}
Filtered dataset to class 0: 12000 samples
Filtered dataset to class 0: 1500 samples
Filtered dataset to class 0: 1500 samples
Total number of trainable parameters: 62727
Models will be saved to: saved_models/conv_vae/normal_conv_vae/20250313_175713


                                                                               

Epoch 1/30: Train Loss = 31.6793, Val Loss = 32.5413
Model saved to saved_models/conv_vae/normal_conv_vae/20250313_175713/best_model.pt
New best validation loss: 32.5413


                                                                               

Epoch 2/30: Train Loss = 31.5506, Val Loss = 32.5407
Model saved to saved_models/conv_vae/normal_conv_vae/20250313_175713/best_model.pt
New best validation loss: 32.5407


                                                                               

Epoch 3/30: Train Loss = 31.5506, Val Loss = 32.5412


                                                                               

Epoch 4/30: Train Loss = 31.5506, Val Loss = 32.5413


                                                                               

Epoch 5/30: Train Loss = 31.5506, Val Loss = 32.5413


                                                                               

Epoch 6/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 7/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 8/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 9/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 10/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 11/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 12/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 13/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 14/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 15/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 16/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 17/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 18/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 19/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 20/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 21/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 22/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 23/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 24/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 25/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 26/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 27/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 28/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 29/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 30/30: Train Loss = 31.5505, Val Loss = 32.5413
Model saved to saved_models/conv_vae/normal_conv_vae/20250313_175713/final_model.pt
Training complete. Best validation loss: 32.5407


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train/recon_loss,▃█▂▂▂▂▂▃▃▃▃▂▄▁▂▁▂▂▂▁▁▁▂▁▂▃▂▂▂▂
train/total_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,█▁▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
val/total_loss,█▁▇███████████████████████████

0,1
epoch,30.0
train/recon_loss,170.05019
train/total_loss,31.55054
val/recon_loss,199.38646
val/total_loss,32.54126


Training session complete.


## Transformer VAE

### Fraud

In [1]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Pin every random number generator used by the run to a single seed
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# cuDNN: disable the benchmark autotuner and request deterministic algorithms
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.transformer_vae_model import TransformerVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/transformer_vae/fraud_transformer_vae.config"

# Load configuration.
# ConfigParser.read() skips files it cannot open and returns the list of files
# it actually parsed, so check that result and fail with a clear message
# rather than an opaque KeyError on the section lookup below.
import configparser
config_parser = configparser.ConfigParser()
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Could not read config file: {config_path}")
transformer_vae_config = config_parser["Transformer_VAE"]

# Create config dictionary for WandB: one plain dict per config section
config_dict = {
    section: dict(config_parser[section])
    for section in config_parser.sections()
}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data.
# NOTE(review): config_dict is rebound here after the logger was created, and
# neither it nor input_dim is read again in this cell - confirm whether later
# cells rely on them before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = TransformerVae(transformer_vae_config)
print_num_params(model)

# Training parameters (read from the [Trainer] section of the config file)
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Run the training loop. Whatever happens (normal finish, Ctrl-C, or an
# error), the `finally` clause below still closes the WandB run.
try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log to WandB (if enabled). The reconstruction losses are computed on
        # only the first batch of a fresh iterator over each loader, so they
        # are single-batch values rather than epoch aggregates.
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            # Best effort: still log the total losses, but report why the
            # reconstruction losses are missing instead of hiding the error.
            print(f"Warning: reconstruction-loss logging failed at epoch {epoch}: {e}")
            wandb_logger.log_epoch(epoch, train_loss, val_loss)

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best. The emptiness check avoids an
    # IndexError when num_epochs <= 0 and no epoch ran.
    if val_losses and val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only save if at least one full epoch finished
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. A failure here must not mask the training outcome, but
    # it is reported rather than silently dropped; catching Exception (not a
    # bare except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Warning: failed to close WandB run: {e}")
    print("Training session complete.")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Initializing WandB run: transformer-vae-fraud_20250313_190718 (Project: fraud-classification, Entity: alexkstern)


[34m[1mwandb[0m: Currently logged in as: [33malexkstern[0m ([33malexksternteam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loaded configuration from configs/transformer_vae/fraud_transformer_vae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 147879
Models will be saved to: saved_models/transformer_vae/fraud_transformer_vae/20250313_190737


                                                                           

Epoch 1/100: Train Loss = 682.0477, Val Loss = 1022.8600
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.8600


                                                                           

Epoch 2/100: Train Loss = 679.1918, Val Loss = 1022.7154
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.7154


                                                                           

Epoch 3/100: Train Loss = 679.1025, Val Loss = 1022.6512
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.6512


                                                                           

Epoch 4/100: Train Loss = 679.0561, Val Loss = 1022.6143
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.6143


                                                                           

Epoch 5/100: Train Loss = 679.0293, Val Loss = 1022.5921
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5921


                                                                           

Epoch 6/100: Train Loss = 679.0118, Val Loss = 1022.5776
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5776


                                                                           

Epoch 7/100: Train Loss = 679.0006, Val Loss = 1022.5677
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5677


                                                                           

Epoch 8/100: Train Loss = 678.9925, Val Loss = 1022.5607
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5607


                                                                           

Epoch 9/100: Train Loss = 678.9867, Val Loss = 1022.5558
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5558


                                                                           

Epoch 10/100: Train Loss = 678.9825, Val Loss = 1022.5517
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5517


                                                                           

Epoch 11/100: Train Loss = 678.9791, Val Loss = 1022.5487
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5487


                                                                           

Epoch 12/100: Train Loss = 678.9764, Val Loss = 1022.5463
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5463


                                                                           

Epoch 13/100: Train Loss = 678.9744, Val Loss = 1022.5444
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5444


                                                                           

Epoch 14/100: Train Loss = 678.9726, Val Loss = 1022.5428
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5428


                                                                           

Epoch 15/100: Train Loss = 678.9712, Val Loss = 1022.5415
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5415


                                                                           

Epoch 16/100: Train Loss = 678.9701, Val Loss = 1022.5404
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5404


                                                                           

Epoch 17/100: Train Loss = 678.9690, Val Loss = 1022.5395
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5395


                                                                           

Epoch 18/100: Train Loss = 678.9682, Val Loss = 1022.5387
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5387


                                                                           

Epoch 19/100: Train Loss = 678.9675, Val Loss = 1022.5380
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5380


                                                                           

Epoch 20/100: Train Loss = 678.9669, Val Loss = 1022.5374
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5374


                                                                           

Epoch 21/100: Train Loss = 678.9663, Val Loss = 1022.5368
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5368


                                                                           

Epoch 22/100: Train Loss = 678.9658, Val Loss = 1022.5365
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5365


                                                                           

Epoch 23/100: Train Loss = 678.9654, Val Loss = 1022.5360
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5360


                                                                           

Epoch 24/100: Train Loss = 678.9650, Val Loss = 1022.5357
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5357


                                                                           

Epoch 25/100: Train Loss = 678.9647, Val Loss = 1022.5354
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5354


                                                                           

Epoch 26/100: Train Loss = 678.9644, Val Loss = 1022.5351
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5351


                                                                           

Epoch 27/100: Train Loss = 678.9641, Val Loss = 1022.5349
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5349


                                                                           

Epoch 28/100: Train Loss = 678.9639, Val Loss = 1022.5346
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5346


                                                                           

Epoch 29/100: Train Loss = 678.9637, Val Loss = 1022.5344
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5344


                                                                           

Epoch 30/100: Train Loss = 678.9634, Val Loss = 1022.5342
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5342


                                                                           

Epoch 31/100: Train Loss = 678.9633, Val Loss = 1022.5340
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5340


                                                                           

Epoch 32/100: Train Loss = 678.9631, Val Loss = 1022.5339
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5339


                                                                           

Epoch 33/100: Train Loss = 678.9630, Val Loss = 1022.5337
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5337


                                                                           

Epoch 34/100: Train Loss = 678.9628, Val Loss = 1022.5336
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5336


                                                                           

Epoch 35/100: Train Loss = 678.9627, Val Loss = 1022.5335
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5335


                                                                           

Epoch 36/100: Train Loss = 678.9626, Val Loss = 1022.5334
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5334


                                                                           

Epoch 37/100: Train Loss = 678.9625, Val Loss = 1022.5333
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5333


                                                                           

Epoch 38/100: Train Loss = 678.9624, Val Loss = 1022.5333
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5333


                                                                           

Epoch 39/100: Train Loss = 678.9623, Val Loss = 1022.5331
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5331


                                                                           

Epoch 40/100: Train Loss = 678.9622, Val Loss = 1022.5330
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5330


                                                                           

Epoch 41/100: Train Loss = 678.9621, Val Loss = 1022.5329
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5329


                                                                           

Epoch 42/100: Train Loss = 678.9620, Val Loss = 1022.5330


                                                                           

Epoch 43/100: Train Loss = 678.9620, Val Loss = 1022.5328
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5328


                                                                           

Epoch 44/100: Train Loss = 678.9619, Val Loss = 1022.5328
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5328


                                                                           

Epoch 45/100: Train Loss = 678.9619, Val Loss = 1022.5328


                                                                           

Epoch 46/100: Train Loss = 678.9618, Val Loss = 1022.5327
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5327


                                                                           

Epoch 47/100: Train Loss = 678.9617, Val Loss = 1022.5326
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5326


                                                                           

Epoch 48/100: Train Loss = 678.9617, Val Loss = 1022.5326
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5326


                                                                           

Epoch 49/100: Train Loss = 678.9616, Val Loss = 1022.5326
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5326


                                                                           

Epoch 50/100: Train Loss = 678.9616, Val Loss = 1022.5325
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5325


                                                                           

Epoch 51/100: Train Loss = 678.9615, Val Loss = 1022.5325


                                                                           

Epoch 52/100: Train Loss = 678.9615, Val Loss = 1022.5324
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5324


                                                                           

Epoch 53/100: Train Loss = 678.9615, Val Loss = 1022.5324


                                                                           

Epoch 54/100: Train Loss = 678.9614, Val Loss = 1022.5324


                                                                           

Epoch 55/100: Train Loss = 678.9614, Val Loss = 1022.5324
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5324


                                                                           

Epoch 56/100: Train Loss = 678.9614, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 57/100: Train Loss = 678.9613, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 58/100: Train Loss = 678.9613, Val Loss = 1022.5323


                                                                           

Epoch 59/100: Train Loss = 678.9613, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 60/100: Train Loss = 678.9613, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 61/100: Train Loss = 678.9612, Val Loss = 1022.5321
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5321


                                                                           

Epoch 62/100: Train Loss = 678.9612, Val Loss = 1022.5321


                                                                           

Epoch 63/100: Train Loss = 678.9612, Val Loss = 1022.5321
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5321


                                                                           

Epoch 64/100: Train Loss = 678.9612, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 65/100: Train Loss = 678.9612, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 66/100: Train Loss = 678.9611, Val Loss = 1022.5321


                                                                           

Epoch 67/100: Train Loss = 678.9611, Val Loss = 1022.5320


                                                                           

Epoch 68/100: Train Loss = 678.9611, Val Loss = 1022.5321


                                                                           

Epoch 69/100: Train Loss = 678.9611, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 70/100: Train Loss = 678.9611, Val Loss = 1022.5320


                                                                           

Epoch 71/100: Train Loss = 678.9611, Val Loss = 1022.5320


                                                                           

Epoch 72/100: Train Loss = 678.9610, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 73/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 74/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 75/100: Train Loss = 678.9610, Val Loss = 1022.5319
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5319


                                                                           

Epoch 76/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 77/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 78/100: Train Loss = 678.9610, Val Loss = 1022.5319
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5319


                                                                           

Epoch 79/100: Train Loss = 678.9610, Val Loss = 1022.5319


                                                                           

Epoch 80/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 81/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 82/100: Train Loss = 678.9609, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 83/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 84/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 85/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 86/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 87/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 88/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 89/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 90/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 91/100: Train Loss = 678.9609, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 92/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 93/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 94/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 95/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 96/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 97/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 98/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 99/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 100/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/final_model.pt
Training complete. Best validation loss: 1022.5318


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train/recon_loss,▂▃▃▂▃▃▄▅▃▁▆▆▄▃▇█▂▄▂▂▅▁▅▁▅▃▃▄▃▄▃▄▅▄▄▄▃▂▃▅
train/total_loss,█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
train/recon_loss,9307.76562
train/total_loss,678.96082
val/recon_loss,11025.65723
val/total_loss,1022.5318


Training session complete.


### Normal

In [None]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, PyTorch incl. CUDA) with one value
# so repeated runs of this cell start from the same random state.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.transformer_vae_model import TransformerVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Path of the INI-style experiment configuration used throughout this cell
config_path = "configs/transformer_vae/normal_transformer_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so fail fast here instead of hitting a
# misleading KeyError on the section lookup below.
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Config file not found or unreadable: {config_path}")
transformer_vae_config = config_parser["Transformer_VAE"]

# Snapshot every config section as a plain dict; this is what WandBLogger receives
config_dict = {
    section: dict(config_parser[section])
    for section in config_parser.sections()
}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict to load_config()'s result, replacing
# the snapshot above; nothing else in this cell reads it -- confirm whether
# later cells depend on it before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model from the [Transformer_VAE] section and report its size
model = TransformerVae(transformer_vae_config)
print_num_params(model)

# Training hyper-parameters come from the [Trainer] section
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Epoch at which best_model.pt was last written (None until the first save).
# Used after the loop to decide whether the last epoch needs its own checkpoint.
best_epoch = None

try:
    # Training loop: one training pass and one validation pass per epoch
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log to WandB (if enabled). The reconstruction losses are computed on
        # a single batch drawn from a fresh iterator of each loader.
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            # Report the problem instead of hiding it, then fall back to
            # logging the total losses only. Logging is best-effort and must
            # never abort training.
            print(f"Could not log reconstruction losses ({e}); logging total losses only.")
            try:
                wandb_logger.log_epoch(epoch, train_loss, val_loss)
            except Exception as log_err:
                print(f"Error logging to WandB: {log_err}")

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            try:
                wandb_logger.log_model(model_path, metadata)
            except Exception as e:
                print(f"Error logging model to WandB: {e}")

    # Save the final model whenever the last epoch is not the one stored in
    # best_model.pt. Comparing epochs (rather than loss values) also covers
    # ties with an earlier best and NaN losses; the emptiness check covers
    # num_epochs <= 0.
    if val_losses and best_epoch != num_epochs:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only checkpoint if at least one full epoch was recorded
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit are not swallowed, and report what failed.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Error finishing WandB run: {e}")
    print("Training session complete.")

## Conv VQVAE

### Fraud

In [None]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, PyTorch incl. CUDA) with one value
# so repeated runs of this cell start from the same random state.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.conv_vqvae_model import ConvVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Path of the INI-style experiment configuration used throughout this cell
config_path = "configs/conv_vqvae/fraud_conv_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so fail fast here instead of hitting a
# misleading KeyError on the section lookup below.
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Config file not found or unreadable: {config_path}")
conv_vqvae_config = config_parser["Conv_VQVAE"]

# Snapshot every config section as a plain dict; this is what WandBLogger receives
config_dict = {
    section: dict(config_parser[section])
    for section in config_parser.sections()
}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict to load_config()'s result, replacing
# the snapshot above; nothing else in this cell reads it -- confirm whether
# later cells depend on it before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model from the [Conv_VQVAE] section and report its size
model = ConvVQVAE(conv_vqvae_config)
print_num_params(model)

# Training hyper-parameters come from the [Trainer] section
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Epoch at which best_model.pt was last written (None until the first save).
# Used after the loop to decide whether the last epoch needs its own checkpoint.
best_epoch = None

try:
    # Training loop: one training pass and one validation pass per epoch
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log metrics to WandB (best-effort: a logging failure must not abort training)
        try:
            # For VQVAE, we only log the total losses since we don't have separate recon_loss extraction
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            try:
                wandb_logger.log_model(model_path, metadata)
            except Exception as e:
                print(f"Error logging model to WandB: {e}")

    # Save the final model whenever the last epoch is not the one stored in
    # best_model.pt. Comparing epochs (rather than loss values) also covers
    # ties with an earlier best and NaN losses; the emptiness check covers
    # num_epochs <= 0.
    if val_losses and best_epoch != num_epochs:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only checkpoint if at least one full epoch was recorded
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit are not swallowed, and report what failed.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Error finishing WandB run: {e}")
    print("Training session complete.")

### Normal


In [None]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, PyTorch incl. CUDA) with one value
# so repeated runs of this cell start from the same random state.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.conv_vqvae_model import ConvVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger

# Path of the INI-style experiment configuration used throughout this cell
config_path = "configs/conv_vqvae/normal_conv_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so fail fast here instead of hitting a
# misleading KeyError on the section lookup below.
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Config file not found or unreadable: {config_path}")
conv_vqvae_config = config_parser["Conv_VQVAE"]

# Snapshot every config section as a plain dict; this is what WandBLogger receives
config_dict = {
    section: dict(config_parser[section])
    for section in config_parser.sections()
}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict to load_config()'s result, replacing
# the snapshot above; nothing else in this cell reads it -- confirm whether
# later cells depend on it before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model from the [Conv_VQVAE] section and report its size
model = ConvVQVAE(conv_vqvae_config)
print_num_params(model)

# Training hyper-parameters come from the [Trainer] section
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Epoch at which best_model.pt was last written (None until the first save).
# Used after the loop to decide whether the last epoch needs its own checkpoint.
best_epoch = None

try:
    # Training loop: one training pass and one validation pass per epoch
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log metrics to WandB (best-effort: a logging failure must not abort training)
        try:
            # For VQVAE, we only log the total losses since we don't have separate recon_loss extraction
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            try:
                wandb_logger.log_model(model_path, metadata)
            except Exception as e:
                print(f"Error logging model to WandB: {e}")

    # Save the final model whenever the last epoch is not the one stored in
    # best_model.pt. Comparing epochs (rather than loss values) also covers
    # ties with an earlier best and NaN losses; the emptiness check covers
    # num_epochs <= 0.
    if val_losses and best_epoch != num_epochs:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only checkpoint if at least one full epoch was recorded
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit are not swallowed, and report what failed.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Error finishing WandB run: {e}")
    print("Training session complete.")

## Transformer VQVAE

### Fraud

In [None]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, PyTorch incl. CUDA) with one value
# so repeated runs of this cell start from the same random state.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.transformer_vqvae_model import TransformerVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger

# Path of the INI-style experiment configuration used throughout this cell
config_path = "configs/transformer_vqvae/fraud_transformer_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so fail fast here instead of hitting a
# misleading KeyError on the section lookup below.
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Config file not found or unreadable: {config_path}")
transformer_vqvae_config = config_parser["Transformer_VQVAE"]

# Snapshot every config section as a plain dict; this is what WandBLogger receives
config_dict = {
    section: dict(config_parser[section])
    for section in config_parser.sections()
}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict to load_config()'s result, replacing
# the snapshot above; nothing else in this cell reads it -- confirm whether
# later cells depend on it before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model from the [Transformer_VQVAE] section and report its size
model = TransformerVQVAE(transformer_vqvae_config)
print_num_params(model)

# Training hyper-parameters come from the [Trainer] section
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Epoch at which best_model.pt was last written (None until the first save).
# Used after the loop to decide whether the last epoch needs its own checkpoint.
best_epoch = None

try:
    # Training loop: one training pass and one validation pass per epoch
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log to WandB (if enabled); best-effort, a logging failure must not abort training
        try:
            # For VQVAE, we only log the total losses
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            try:
                wandb_logger.log_model(model_path, metadata)
            except Exception as e:
                print(f"Error logging model to WandB: {e}")

    # Save the final model whenever the last epoch is not the one stored in
    # best_model.pt. Comparing epochs (rather than loss values) also covers
    # ties with an earlier best and NaN losses; the emptiness check covers
    # num_epochs <= 0.
    if val_losses and best_epoch != num_epochs:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only checkpoint if at least one full epoch was recorded
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit are not swallowed, and report what failed.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Error finishing WandB run: {e}")
    print("Training session complete.")

### Normal

In [None]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Seed every RNG used here (Python, NumPy, PyTorch incl. CUDA) with one value
# so repeated runs of this cell start from the same random state.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Import required modules
from models.transformer_vqvae_model import TransformerVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger

# Path of the INI-style experiment configuration used throughout this cell
config_path = "configs/transformer_vqvae/normal_transformer_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so fail fast here instead of hitting a
# misleading KeyError on the section lookup below.
if not config_parser.read(config_path):
    raise FileNotFoundError(f"Config file not found or unreadable: {config_path}")
transformer_vqvae_config = config_parser["Transformer_VQVAE"]

# Snapshot every config section as a plain dict; this is what WandBLogger receives
config_dict = {
    section: dict(config_parser[section])
    for section in config_parser.sections()
}

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
# NOTE(review): this rebinds config_dict to load_config()'s result, replacing
# the snapshot above; nothing else in this cell reads it -- confirm whether
# later cells depend on it before removing.
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model from the [Transformer_VQVAE] section and report its size
model = TransformerVQVAE(transformer_vqvae_config)
print_num_params(model)

# Training hyper-parameters come from the [Trainer] section
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss and the per-epoch loss history
best_val_loss = float('inf')
train_losses, val_losses = [], []

# Epoch at which best_model.pt was last written (None until the first save).
# Used after the loop to decide whether the last epoch needs its own checkpoint.
best_epoch = None

try:
    # Training loop: one training pass and one validation pass per epoch
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Log to WandB (if enabled); best-effort, a logging failure must not abort training
        try:
            # For VQVAE, we only log the total losses
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")

        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            try:
                wandb_logger.log_model(model_path, metadata)
            except Exception as e:
                print(f"Error logging model to WandB: {e}")

    # Save the final model whenever the last epoch is not the one stored in
    # best_model.pt. Comparing epochs (rather than loss values) also covers
    # ties with an earlier best and NaN losses; the emptiness check covers
    # num_epochs <= 0.
    if val_losses and best_epoch != num_epochs:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Only checkpoint if at least one full epoch was recorded
    if train_losses:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses),
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if val_losses else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up WandB. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit are not swallowed, and report what failed.
    try:
        wandb_logger.finish()
    except Exception as e:
        print(f"Error finishing WandB run: {e}")
    print("Training session complete.")