# Main

# Train Fraud

## Conv VAE

### Fraud

In [2]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.conv_vae_model import ConvVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/conv_vae/fraud_conv_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
conv_vae_config = config_parser["Conv_VAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = ConvVae(conv_vae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log to WandB (if enabled)
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

Initializing WandB run: conv-vae-fraud_20250313_173404 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/conv_vae/fraud_conv_vae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 62727
Models will be saved to: saved_models/conv_vae/fraud_conv_vae/20250313_173415


                                                                           

Epoch 1/150: Train Loss = 704.4488, Val Loss = 1022.8286
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.8286


                                                                           

Epoch 2/150: Train Loss = 679.0434, Val Loss = 1022.5550
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5550


                                                                           

Epoch 3/150: Train Loss = 678.9731, Val Loss = 1022.5265
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5265


                                                                           

Epoch 4/150: Train Loss = 678.9738, Val Loss = 1022.5521


                                                                           

Epoch 5/150: Train Loss = 678.9614, Val Loss = 1022.5346


                                                                           

Epoch 6/150: Train Loss = 678.9674, Val Loss = 1022.5279


                                                                           

Epoch 7/150: Train Loss = 678.9592, Val Loss = 1022.5274


                                                                           

Epoch 8/150: Train Loss = 678.9613, Val Loss = 1022.5326


                                                                           

Epoch 9/150: Train Loss = 678.9659, Val Loss = 1022.5347


                                                                           

Epoch 10/150: Train Loss = 678.9616, Val Loss = 1022.5378


                                                                           

Epoch 11/150: Train Loss = 678.9632, Val Loss = 1022.5346


                                                                           

Epoch 12/150: Train Loss = 678.9609, Val Loss = 1022.5292


                                                                           

Epoch 13/150: Train Loss = 678.9595, Val Loss = 1022.5272


                                                                           

Epoch 14/150: Train Loss = 678.9634, Val Loss = 1022.5356


                                                                           

Epoch 15/150: Train Loss = 678.9615, Val Loss = 1022.5340


                                                                           

Epoch 16/150: Train Loss = 678.9606, Val Loss = 1022.5270


                                                                           

Epoch 17/150: Train Loss = 678.9582, Val Loss = 1022.5342


                                                                           

Epoch 18/150: Train Loss = 678.9600, Val Loss = 1022.5373


                                                                           

Epoch 19/150: Train Loss = 678.9630, Val Loss = 1022.5355


                                                                           

Epoch 20/150: Train Loss = 678.9585, Val Loss = 1022.5329


                                                                           

Epoch 21/150: Train Loss = 678.9605, Val Loss = 1022.5376


                                                                           

Epoch 22/150: Train Loss = 678.9619, Val Loss = 1022.5338


                                                                           

Epoch 23/150: Train Loss = 678.9579, Val Loss = 1022.5380


                                                                           

Epoch 24/150: Train Loss = 678.9638, Val Loss = 1022.5348


                                                                           

Epoch 25/150: Train Loss = 678.9619, Val Loss = 1022.5290


                                                                           

Epoch 26/150: Train Loss = 678.9622, Val Loss = 1022.5253
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5253


                                                                           

Epoch 27/150: Train Loss = 678.9598, Val Loss = 1022.5322


                                                                           

Epoch 28/150: Train Loss = 678.9614, Val Loss = 1022.5257


                                                                           

Epoch 29/150: Train Loss = 678.9601, Val Loss = 1022.5303


                                                                           

Epoch 30/150: Train Loss = 678.9610, Val Loss = 1022.5331


                                                                           

Epoch 31/150: Train Loss = 678.9617, Val Loss = 1022.5251
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1022.5251


                                                                           

Epoch 32/150: Train Loss = 678.9612, Val Loss = 1022.5330


                                                                           

Epoch 33/150: Train Loss = 678.9610, Val Loss = 1022.5312


                                                                           

Epoch 34/150: Train Loss = 678.9616, Val Loss = 1022.5319


                                                                           

Epoch 35/150: Train Loss = 678.9608, Val Loss = 1022.5320


                                                                           

Epoch 36/150: Train Loss = 678.9605, Val Loss = 1022.5306


                                                                           

Epoch 37/150: Train Loss = 678.9602, Val Loss = 1022.5325


                                                                           

Epoch 38/150: Train Loss = 678.9607, Val Loss = 1022.5325


                                                                           

Epoch 39/150: Train Loss = 678.9605, Val Loss = 1022.5303


                                                                           

Epoch 40/150: Train Loss = 678.9611, Val Loss = 1022.5281


                                                                           

Epoch 41/150: Train Loss = 678.9596, Val Loss = 1022.5290


                                                                           

Epoch 42/150: Train Loss = 678.9612, Val Loss = 1022.5331


                                                                           

Epoch 43/150: Train Loss = 678.9608, Val Loss = 1022.5286


                                                                           

Epoch 44/150: Train Loss = 678.9612, Val Loss = 1022.5280


                                                                           

Epoch 45/150: Train Loss = 678.9613, Val Loss = 1022.5333


                                                                           

Epoch 46/150: Train Loss = 678.9593, Val Loss = 1022.5317


                                                                           

Epoch 47/150: Train Loss = 678.9613, Val Loss = 1022.5313


                                                                           

Epoch 48/150: Train Loss = 678.9618, Val Loss = 1022.5286


                                                                           

Epoch 49/150: Train Loss = 678.9620, Val Loss = 1022.5313


                                                                           

Epoch 50/150: Train Loss = 678.9610, Val Loss = 1022.5315


                                                                           

Epoch 51/150: Train Loss = 678.9607, Val Loss = 1022.5331


                                                                           

Epoch 52/150: Train Loss = 678.9609, Val Loss = 1022.5324


                                                                           

Epoch 53/150: Train Loss = 678.9611, Val Loss = 1022.5312


                                                                           

Epoch 54/150: Train Loss = 678.9603, Val Loss = 1022.5318


                                                                           

Epoch 55/150: Train Loss = 678.9613, Val Loss = 1022.5318


                                                                           

Epoch 56/150: Train Loss = 678.9609, Val Loss = 1022.5315


                                                                           

Epoch 57/150: Train Loss = 678.9604, Val Loss = 1022.5309


                                                                           

Epoch 58/150: Train Loss = 678.9604, Val Loss = 1022.5290


                                                                           

Epoch 59/150: Train Loss = 678.9603, Val Loss = 1022.5314


                                                                           

Epoch 60/150: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 61/150: Train Loss = 678.9609, Val Loss = 1022.5329


                                                                           

Epoch 62/150: Train Loss = 678.9610, Val Loss = 1022.5316


                                                                           

Epoch 63/150: Train Loss = 678.9608, Val Loss = 1022.5316


                                                                           

Epoch 64/150: Train Loss = 678.9608, Val Loss = 1022.5321


                                                                           

Epoch 65/150: Train Loss = 678.9609, Val Loss = 1022.5323


                                                                           

Epoch 66/150: Train Loss = 678.9607, Val Loss = 1022.5312


                                                                           

Epoch 67/150: Train Loss = 678.9605, Val Loss = 1022.5317


                                                                           

Epoch 68/150: Train Loss = 678.9603, Val Loss = 1022.5306


                                                                           

Epoch 69/150: Train Loss = 678.9601, Val Loss = 1022.5305


                                                                           

Epoch 70/150: Train Loss = 678.9615, Val Loss = 1022.5304


                                                                           

Epoch 71/150: Train Loss = 678.9589, Val Loss = 1022.5275


                                                                           

Epoch 72/150: Train Loss = 671.6301, Val Loss = 1006.9674
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/best_model.pt
New best validation loss: 1006.9674


                                                                           

Epoch 73/150: Train Loss = 665.9213, Val Loss = 1006.9708


                                                                           

Epoch 74/150: Train Loss = 665.9213, Val Loss = 1006.9710


                                                                           

Epoch 75/150: Train Loss = 665.9195, Val Loss = 1006.9802


                                                                           

Epoch 76/150: Train Loss = 665.9198, Val Loss = 1006.9707


                                                                           

Epoch 77/150: Train Loss = 665.9196, Val Loss = 1007.1178


                                                                           

Epoch 78/150: Train Loss = 665.9198, Val Loss = 1006.9691


                                                                           

Epoch 79/150: Train Loss = 665.9195, Val Loss = 1006.9674


                                                                           

Epoch 80/150: Train Loss = 665.9188, Val Loss = 1006.9704


                                                                           

Epoch 81/150: Train Loss = 665.9202, Val Loss = 1006.9702


                                                                           

Epoch 82/150: Train Loss = 665.9192, Val Loss = 1006.9699


                                                                           

Epoch 83/150: Train Loss = 665.9195, Val Loss = 1006.9697


                                                                           

Epoch 84/150: Train Loss = 665.9199, Val Loss = 1006.9696


                                                                           

Epoch 85/150: Train Loss = 665.9195, Val Loss = 1006.9710


                                                                           

Epoch 86/150: Train Loss = 665.9356, Val Loss = 1006.9724


                                                                           

Epoch 87/150: Train Loss = 665.9226, Val Loss = 1006.9698


                                                                           

Epoch 88/150: Train Loss = 665.9193, Val Loss = 1006.9711


                                                                           

Epoch 89/150: Train Loss = 665.9196, Val Loss = 1006.9689


                                                                           

Epoch 90/150: Train Loss = 665.9197, Val Loss = 1006.9696


                                                                           

Epoch 91/150: Train Loss = 665.9192, Val Loss = 1006.9713


                                                                           

Epoch 92/150: Train Loss = 665.9197, Val Loss = 1006.9709


                                                                           

Epoch 93/150: Train Loss = 665.9194, Val Loss = 1006.9685


                                                                           

Epoch 94/150: Train Loss = 665.9196, Val Loss = 1006.9702


                                                                           

Epoch 95/150: Train Loss = 665.9197, Val Loss = 1006.9694


                                                                           

Epoch 96/150: Train Loss = 665.9200, Val Loss = 1006.9695


                                                                           

Epoch 97/150: Train Loss = 665.9195, Val Loss = 1006.9704


                                                                           

Epoch 98/150: Train Loss = 665.9193, Val Loss = 1006.9701


                                                                           

Epoch 99/150: Train Loss = 665.9199, Val Loss = 1006.9702


                                                                           

Epoch 100/150: Train Loss = 665.9192, Val Loss = 1006.9706


                                                                           

Epoch 101/150: Train Loss = 665.9195, Val Loss = 1006.9701


                                                                           

Epoch 102/150: Train Loss = 665.9194, Val Loss = 1006.9707


                                                                           

Epoch 103/150: Train Loss = 665.9196, Val Loss = 1006.9698


                                                                           

Epoch 104/150: Train Loss = 665.9193, Val Loss = 1006.9697


                                                                           

Epoch 105/150: Train Loss = 665.9192, Val Loss = 1006.9699


                                                                           

Epoch 106/150: Train Loss = 665.9192, Val Loss = 1006.9694


                                                                           

Epoch 107/150: Train Loss = 665.9197, Val Loss = 1006.9702


                                                                           

Epoch 108/150: Train Loss = 665.9195, Val Loss = 1006.9695


                                                                           

Epoch 109/150: Train Loss = 665.9195, Val Loss = 1006.9697


                                                                           

Epoch 110/150: Train Loss = 665.9197, Val Loss = 1006.9704


                                                                           

Epoch 111/150: Train Loss = 665.9190, Val Loss = 1006.9696


                                                                           

Epoch 112/150: Train Loss = 665.9196, Val Loss = 1006.9696


                                                                           

Epoch 113/150: Train Loss = 665.9193, Val Loss = 1006.9692


                                                                           

Epoch 114/150: Train Loss = 665.9196, Val Loss = 1006.9699


                                                                           

Epoch 115/150: Train Loss = 665.9197, Val Loss = 1006.9702


                                                                           

Epoch 116/150: Train Loss = 665.9192, Val Loss = 1006.9700


                                                                           

Epoch 117/150: Train Loss = 665.9196, Val Loss = 1006.9697


                                                                           

Epoch 118/150: Train Loss = 665.9195, Val Loss = 1006.9698


                                                                           

Epoch 119/150: Train Loss = 665.9194, Val Loss = 1006.9695


                                                                           

Epoch 120/150: Train Loss = 665.9193, Val Loss = 1006.9700


                                                                           

Epoch 121/150: Train Loss = 665.9194, Val Loss = 1006.9697


                                                                           

Epoch 122/150: Train Loss = 665.9189, Val Loss = 1006.9681


                                                                           

Epoch 123/150: Train Loss = 665.9222, Val Loss = 1006.9705


                                                                           

Epoch 124/150: Train Loss = 665.9194, Val Loss = 1006.9702


                                                                           

Epoch 125/150: Train Loss = 665.9191, Val Loss = 1006.9697


                                                                           

Epoch 126/150: Train Loss = 665.9199, Val Loss = 1006.9695


                                                                           

Epoch 127/150: Train Loss = 665.9197, Val Loss = 1006.9706


                                                                           

Epoch 128/150: Train Loss = 665.9198, Val Loss = 1006.9698


                                                                           

Epoch 129/150: Train Loss = 665.9190, Val Loss = 1006.9698


                                                                           

Epoch 130/150: Train Loss = 665.9191, Val Loss = 1006.9701


                                                                           

Epoch 131/150: Train Loss = 665.9199, Val Loss = 1006.9701


                                                                           

Epoch 132/150: Train Loss = 665.9194, Val Loss = 1006.9697


                                                                           

Epoch 133/150: Train Loss = 665.9199, Val Loss = 1006.9691


                                                                           

Epoch 134/150: Train Loss = 665.9196, Val Loss = 1006.9699


                                                                           

Epoch 135/150: Train Loss = 665.9195, Val Loss = 1006.9694


                                                                           

Epoch 136/150: Train Loss = 665.9193, Val Loss = 1006.9699


                                                                           

Epoch 137/150: Train Loss = 665.9192, Val Loss = 1006.9693


                                                                           

Epoch 138/150: Train Loss = 665.9194, Val Loss = 1006.9704


                                                                           

Epoch 139/150: Train Loss = 665.9193, Val Loss = 1006.9701


                                                                           

Epoch 140/150: Train Loss = 665.9196, Val Loss = 1006.9702


                                                                           

Epoch 141/150: Train Loss = 665.9194, Val Loss = 1006.9698


                                                                           

Epoch 142/150: Train Loss = 665.9193, Val Loss = 1006.9702


                                                                           

Epoch 143/150: Train Loss = 665.9194, Val Loss = 1006.9699


                                                                           

Epoch 144/150: Train Loss = 665.9193, Val Loss = 1006.9700


                                                                           

Epoch 145/150: Train Loss = 665.9191, Val Loss = 1006.9692


                                                                           

Epoch 146/150: Train Loss = 665.9194, Val Loss = 1006.9695


                                                                           

Epoch 147/150: Train Loss = 665.9191, Val Loss = 1006.9704


                                                                           

Epoch 148/150: Train Loss = 665.9193, Val Loss = 1006.9692


                                                                           

Epoch 149/150: Train Loss = 665.9197, Val Loss = 1006.9706


                                                                           

Epoch 150/150: Train Loss = 665.9197, Val Loss = 1006.9702
Model saved to saved_models/conv_vae/fraud_conv_vae/20250313_173415/final_model.pt
Training complete. Best validation loss: 1006.9674


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/recon_loss,▄▄▃▄▂█▂▁▃▃▄▄▃▃▃▃▃▇▆▁▁▂▄▅▄▄▂▁▄▄▄▄▃▅▃▅▂▆▆▁
train/total_loss,█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,150.0
train/recon_loss,4001.76562
train/total_loss,665.91969
val/recon_loss,10876.7793
val/total_loss,1006.97016


Training session complete.


In [1]:
import os
import torch
import torch.optim as optim
import numpy as np
import random
import configparser

# Import our modules
from models.conv_vae_model import ConvVae, vae_loss_function, print_num_params
from trainer.test_trainer_vae import VAETrainer  # Updated trainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger_lr import WandBLogger
from utils.evaluation_utils import extract_recon_loss
from utils.lr_scheduler import create_scheduler  # New scheduler utility

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Build the config path
config_path = "configs/conv_vae/fraud_conv_vae_test.config"

# Load configuration
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
conv_vae_config = config_parser["Conv_VAE"]
train_config = config_parser["Trainer"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = ConvVae(conv_vae_config)
print_num_params(model)

# Training parameters
lr = train_config.getfloat("lr")
num_epochs = train_config.getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)

# Create learning rate scheduler
scheduler = create_scheduler(train_config, optimizer)

# Define loss function
loss_fn = vae_loss_function

# Create trainer with scheduler
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer, scheduler)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        
        # Log to WandB (if enabled)
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            
            # Add learning rate to logging
            wandb_logger.log_epoch(
                epoch, train_loss, val_loss, 
                train_recon_loss, val_recon_loss,
                learning_rate=current_lr
            )
        except Exception as e:
            print(f"WandB logging error: {e}")
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        
        # Print progress with learning rate
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, LR = {current_lr:.1e}")
        
        # Update scheduler based on validation loss
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {
                'epoch': epoch, 
                'train_loss': train_loss, 
                'val_loss': val_loss, 
                'learning_rate': current_lr
            }
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {
            'epoch': num_epochs, 
            'train_loss': train_losses[-1], 
            'val_loss': val_losses[-1],
            'learning_rate': optimizer.param_groups[0]['lr']
        }
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    # Save interrupted model logic...

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Initializing WandB run: conv-vae-fraud_20250313_231435 (Project: fraud-classification, Entity: alexkstern)


[34m[1mwandb[0m: Currently logged in as: [33malexkstern[0m ([33malexksternteam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loaded configuration from configs/conv_vae/fraud_conv_vae_test.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 62727




Models will be saved to: saved_models/conv_vae/fraud_conv_vae_test/20250313_231447


                                                                           

Epoch 1/150: Train Loss = 684.0239, Val Loss = 1022.5321, LR = 1.0e-02
Model saved to saved_models/conv_vae/fraud_conv_vae_test/20250313_231447/best_model.pt
New best validation loss: 1022.5321


                                                                           

Epoch 2/150: Train Loss = 678.9608, Val Loss = 1022.5317, LR = 1.0e-02
Model saved to saved_models/conv_vae/fraud_conv_vae_test/20250313_231447/best_model.pt
New best validation loss: 1022.5317


                                                                           

Epoch 3/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-02
Model saved to saved_models/conv_vae/fraud_conv_vae_test/20250313_231447/best_model.pt
New best validation loss: 1022.5317


                                                                           

Epoch 4/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-02


                                                                           

Epoch 5/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-02


                                                                           

Epoch 6/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-02
Model saved to saved_models/conv_vae/fraud_conv_vae_test/20250313_231447/best_model.pt
New best validation loss: 1022.5317


                                                                           

Epoch 7/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-02


                                                                           

Epoch 8/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 5.0e-03


                                                                           

Epoch 9/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 5.0e-03


                                                                           

Epoch 10/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 5.0e-03


                                                                           

Epoch 11/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 5.0e-03


                                                                           

Epoch 12/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 5.0e-03


                                                                           

Epoch 13/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 5.0e-03


                                                                           

Epoch 14/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.5e-03


                                                                           

Epoch 15/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.5e-03


                                                                           

Epoch 16/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.5e-03


                                                                           

Epoch 17/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.5e-03


                                                                           

Epoch 18/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.5e-03


                                                                           

Epoch 19/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.5e-03


                                                                           

Epoch 20/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.3e-03


                                                                           

Epoch 21/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.3e-03


                                                                           

Epoch 22/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.3e-03


                                                                           

Epoch 23/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.3e-03


                                                                           

Epoch 24/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.3e-03


                                                                           

Epoch 25/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.3e-03


                                                                           

Epoch 26/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 6.3e-04


                                                                           

Epoch 27/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 6.3e-04


                                                                           

Epoch 28/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 6.3e-04


                                                                           

Epoch 29/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 6.3e-04


                                                                           

Epoch 30/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 6.3e-04


                                                                           

Epoch 31/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 6.3e-04


                                                                           

Epoch 32/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.1e-04


                                                                           

Epoch 33/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.1e-04


                                                                           

Epoch 34/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.1e-04


                                                                           

Epoch 35/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.1e-04


                                                                           

Epoch 36/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.1e-04


                                                                           

Epoch 37/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.1e-04


                                                                           

Epoch 38/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.6e-04


                                                                           

Epoch 39/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.6e-04


                                                                           

Epoch 40/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.6e-04


                                                                           

Epoch 41/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.6e-04


                                                                           

Epoch 42/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.6e-04


                                                                           

Epoch 43/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.6e-04


                                                                           

Epoch 44/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 7.8e-05


                                                                           

Epoch 45/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 7.8e-05


                                                                           

Epoch 46/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 7.8e-05


                                                                           

Epoch 47/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 7.8e-05


                                                                           

Epoch 48/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 7.8e-05


                                                                           

Epoch 49/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 7.8e-05


                                                                           

Epoch 50/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.9e-05


                                                                           

Epoch 51/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.9e-05


                                                                           

Epoch 52/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.9e-05


                                                                           

Epoch 53/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.9e-05


                                                                           

Epoch 54/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.9e-05


                                                                           

Epoch 55/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 3.9e-05


                                                                           

Epoch 56/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.0e-05


                                                                           

Epoch 57/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.0e-05


                                                                           

Epoch 58/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.0e-05


                                                                           

Epoch 59/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.0e-05


                                                                           

Epoch 60/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.0e-05


                                                                           

Epoch 61/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.0e-05


                                                                           

Epoch 62/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 9.8e-06


                                                                           

Epoch 63/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 9.8e-06


                                                                           

Epoch 64/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 9.8e-06


                                                                           

Epoch 65/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 9.8e-06


                                                                           

Epoch 66/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 9.8e-06


                                                                           

Epoch 67/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 9.8e-06


                                                                           

Epoch 68/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 4.9e-06


                                                                           

Epoch 69/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 4.9e-06


                                                                           

Epoch 70/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 4.9e-06


                                                                           

Epoch 71/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 4.9e-06


                                                                           

Epoch 72/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 4.9e-06


                                                                           

Epoch 73/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 4.9e-06


                                                                           

Epoch 74/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.4e-06


                                                                           

Epoch 75/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.4e-06


                                                                           

Epoch 76/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.4e-06


                                                                           

Epoch 77/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.4e-06


                                                                           

Epoch 78/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.4e-06


                                                                           

Epoch 79/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 2.4e-06


                                                                           

Epoch 80/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.2e-06


                                                                           

Epoch 81/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.2e-06


                                                                           

Epoch 82/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.2e-06


                                                                           

Epoch 83/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.2e-06


                                                                           

Epoch 84/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.2e-06


                                                                           

Epoch 85/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.2e-06


                                                                           

Epoch 86/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 87/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 88/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 89/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 90/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 91/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 92/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 93/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 94/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 95/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 96/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 97/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 98/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 99/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 100/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 101/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 102/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 103/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 104/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 105/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 106/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 107/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 108/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 109/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 110/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 111/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 112/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 113/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 114/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 115/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 116/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 117/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 118/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 119/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 120/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 121/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 122/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 123/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 124/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 125/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 126/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 127/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 128/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 129/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 130/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 131/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 132/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 133/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 134/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 135/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 136/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 137/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 138/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 139/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 140/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 141/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 142/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 143/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 144/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 145/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 146/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 147/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 148/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 149/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06


                                                                           

Epoch 150/150: Train Loss = 678.9607, Val Loss = 1022.5317, LR = 1.0e-06
Training complete. Best validation loss: 1022.5317


0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█████
learning_rate,██▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/recon_loss,▆▂▃▄▁▄▃▅▃▂▃▅▅▆▇▅▁▂▄▆▄▂█▁▅▄▁▄▃▃▂▆▃▄▃▄▅▅▇▂
train/total_loss,██████████████████████████████▁█████████
val/recon_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,150.0
learning_rate,0.0
train/recon_loss,4085.08521
train/total_loss,678.96068
val/recon_loss,11025.65723
val/total_loss,1022.53168


Training session complete.


### Normal

In [1]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.conv_vae_model import ConvVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/conv_vae/normal_conv_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
conv_vae_config = config_parser["Conv_VAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = ConvVae(conv_vae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log to WandB (if enabled)
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Initializing WandB run: conv-vae-normal_20250313_175704 (Project: fraud-classification, Entity: alexkstern)


[34m[1mwandb[0m: Currently logged in as: [33malexkstern[0m ([33malexksternteam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loaded configuration from configs/conv_vae/normal_conv_vae.config
Filtered dataset to class 0: 12000 samples
Normalization statistics (calculated from class 0): {'Time': {'mean': 94364.65358333333, 'std': 47365.815157589255}, 'Amount': {'mean': 87.60478666666666, 'std': 240.59403081682598}}
Filtered dataset to class 0: 12000 samples
Filtered dataset to class 0: 1500 samples
Filtered dataset to class 0: 1500 samples
Total number of trainable parameters: 62727
Models will be saved to: saved_models/conv_vae/normal_conv_vae/20250313_175713


                                                                               

Epoch 1/30: Train Loss = 31.6793, Val Loss = 32.5413
Model saved to saved_models/conv_vae/normal_conv_vae/20250313_175713/best_model.pt
New best validation loss: 32.5413


                                                                               

Epoch 2/30: Train Loss = 31.5506, Val Loss = 32.5407
Model saved to saved_models/conv_vae/normal_conv_vae/20250313_175713/best_model.pt
New best validation loss: 32.5407


                                                                               

Epoch 3/30: Train Loss = 31.5506, Val Loss = 32.5412


                                                                               

Epoch 4/30: Train Loss = 31.5506, Val Loss = 32.5413


                                                                               

Epoch 5/30: Train Loss = 31.5506, Val Loss = 32.5413


                                                                               

Epoch 6/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 7/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 8/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 9/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 10/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 11/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 12/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 13/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 14/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 15/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 16/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 17/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 18/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 19/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 20/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 21/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 22/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 23/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 24/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 25/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 26/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 27/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 28/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 29/30: Train Loss = 31.5505, Val Loss = 32.5413


                                                                               

Epoch 30/30: Train Loss = 31.5505, Val Loss = 32.5413
Model saved to saved_models/conv_vae/normal_conv_vae/20250313_175713/final_model.pt
Training complete. Best validation loss: 32.5407


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train/recon_loss,▃█▂▂▂▂▂▃▃▃▃▂▄▁▂▁▂▂▂▁▁▁▂▁▂▃▂▂▂▂
train/total_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,█▁▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
val/total_loss,█▁▇███████████████████████████

0,1
epoch,30.0
train/recon_loss,170.05019
train/total_loss,31.55054
val/recon_loss,199.38646
val/total_loss,32.54126


Training session complete.


## Transformer VAE

### Fraud

In [1]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.transformer_vae_model import TransformerVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/transformer_vae/fraud_transformer_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
transformer_vae_config = config_parser["Transformer_VAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = TransformerVae(transformer_vae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log to WandB (if enabled)
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Initializing WandB run: transformer-vae-fraud_20250313_190718 (Project: fraud-classification, Entity: alexkstern)


[34m[1mwandb[0m: Currently logged in as: [33malexkstern[0m ([33malexksternteam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loaded configuration from configs/transformer_vae/fraud_transformer_vae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 147879
Models will be saved to: saved_models/transformer_vae/fraud_transformer_vae/20250313_190737


                                                                           

Epoch 1/100: Train Loss = 682.0477, Val Loss = 1022.8600
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.8600


                                                                           

Epoch 2/100: Train Loss = 679.1918, Val Loss = 1022.7154
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.7154


                                                                           

Epoch 3/100: Train Loss = 679.1025, Val Loss = 1022.6512
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.6512


                                                                           

Epoch 4/100: Train Loss = 679.0561, Val Loss = 1022.6143
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.6143


                                                                           

Epoch 5/100: Train Loss = 679.0293, Val Loss = 1022.5921
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5921


                                                                           

Epoch 6/100: Train Loss = 679.0118, Val Loss = 1022.5776
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5776


                                                                           

Epoch 7/100: Train Loss = 679.0006, Val Loss = 1022.5677
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5677


                                                                           

Epoch 8/100: Train Loss = 678.9925, Val Loss = 1022.5607
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5607


                                                                           

Epoch 9/100: Train Loss = 678.9867, Val Loss = 1022.5558
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5558


                                                                           

Epoch 10/100: Train Loss = 678.9825, Val Loss = 1022.5517
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5517


                                                                           

Epoch 11/100: Train Loss = 678.9791, Val Loss = 1022.5487
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5487


                                                                           

Epoch 12/100: Train Loss = 678.9764, Val Loss = 1022.5463
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5463


                                                                           

Epoch 13/100: Train Loss = 678.9744, Val Loss = 1022.5444
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5444


                                                                           

Epoch 14/100: Train Loss = 678.9726, Val Loss = 1022.5428
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5428


                                                                           

Epoch 15/100: Train Loss = 678.9712, Val Loss = 1022.5415
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5415


                                                                           

Epoch 16/100: Train Loss = 678.9701, Val Loss = 1022.5404
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5404


                                                                           

Epoch 17/100: Train Loss = 678.9690, Val Loss = 1022.5395
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5395


                                                                           

Epoch 18/100: Train Loss = 678.9682, Val Loss = 1022.5387
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5387


                                                                           

Epoch 19/100: Train Loss = 678.9675, Val Loss = 1022.5380
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5380


                                                                           

Epoch 20/100: Train Loss = 678.9669, Val Loss = 1022.5374
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5374


                                                                           

Epoch 21/100: Train Loss = 678.9663, Val Loss = 1022.5368
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5368


                                                                           

Epoch 22/100: Train Loss = 678.9658, Val Loss = 1022.5365
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5365


                                                                           

Epoch 23/100: Train Loss = 678.9654, Val Loss = 1022.5360
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5360


                                                                           

Epoch 24/100: Train Loss = 678.9650, Val Loss = 1022.5357
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5357


                                                                           

Epoch 25/100: Train Loss = 678.9647, Val Loss = 1022.5354
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5354


                                                                           

Epoch 26/100: Train Loss = 678.9644, Val Loss = 1022.5351
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5351


                                                                           

Epoch 27/100: Train Loss = 678.9641, Val Loss = 1022.5349
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5349


                                                                           

Epoch 28/100: Train Loss = 678.9639, Val Loss = 1022.5346
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5346


                                                                           

Epoch 29/100: Train Loss = 678.9637, Val Loss = 1022.5344
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5344


                                                                           

Epoch 30/100: Train Loss = 678.9634, Val Loss = 1022.5342
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5342


                                                                           

Epoch 31/100: Train Loss = 678.9633, Val Loss = 1022.5340
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5340


                                                                           

Epoch 32/100: Train Loss = 678.9631, Val Loss = 1022.5339
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5339


                                                                           

Epoch 33/100: Train Loss = 678.9630, Val Loss = 1022.5337
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5337


                                                                           

Epoch 34/100: Train Loss = 678.9628, Val Loss = 1022.5336
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5336


                                                                           

Epoch 35/100: Train Loss = 678.9627, Val Loss = 1022.5335
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5335


                                                                           

Epoch 36/100: Train Loss = 678.9626, Val Loss = 1022.5334
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5334


                                                                           

Epoch 37/100: Train Loss = 678.9625, Val Loss = 1022.5333
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5333


                                                                           

Epoch 38/100: Train Loss = 678.9624, Val Loss = 1022.5333
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5333


                                                                           

Epoch 39/100: Train Loss = 678.9623, Val Loss = 1022.5331
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5331


                                                                           

Epoch 40/100: Train Loss = 678.9622, Val Loss = 1022.5330
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5330


                                                                           

Epoch 41/100: Train Loss = 678.9621, Val Loss = 1022.5329
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5329


                                                                           

Epoch 42/100: Train Loss = 678.9620, Val Loss = 1022.5330


                                                                           

Epoch 43/100: Train Loss = 678.9620, Val Loss = 1022.5328
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5328


                                                                           

Epoch 44/100: Train Loss = 678.9619, Val Loss = 1022.5328
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5328


                                                                           

Epoch 45/100: Train Loss = 678.9619, Val Loss = 1022.5328


                                                                           

Epoch 46/100: Train Loss = 678.9618, Val Loss = 1022.5327
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5327


                                                                           

Epoch 47/100: Train Loss = 678.9617, Val Loss = 1022.5326
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5326


                                                                           

Epoch 48/100: Train Loss = 678.9617, Val Loss = 1022.5326
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5326


                                                                           

Epoch 49/100: Train Loss = 678.9616, Val Loss = 1022.5326
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5326


                                                                           

Epoch 50/100: Train Loss = 678.9616, Val Loss = 1022.5325
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5325


                                                                           

Epoch 51/100: Train Loss = 678.9615, Val Loss = 1022.5325


                                                                           

Epoch 52/100: Train Loss = 678.9615, Val Loss = 1022.5324
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5324


                                                                           

Epoch 53/100: Train Loss = 678.9615, Val Loss = 1022.5324


                                                                           

Epoch 54/100: Train Loss = 678.9614, Val Loss = 1022.5324


                                                                           

Epoch 55/100: Train Loss = 678.9614, Val Loss = 1022.5324
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5324


                                                                           

Epoch 56/100: Train Loss = 678.9614, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 57/100: Train Loss = 678.9613, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 58/100: Train Loss = 678.9613, Val Loss = 1022.5323


                                                                           

Epoch 59/100: Train Loss = 678.9613, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 60/100: Train Loss = 678.9613, Val Loss = 1022.5322
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5322


                                                                           

Epoch 61/100: Train Loss = 678.9612, Val Loss = 1022.5321
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5321


                                                                           

Epoch 62/100: Train Loss = 678.9612, Val Loss = 1022.5321


                                                                           

Epoch 63/100: Train Loss = 678.9612, Val Loss = 1022.5321
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5321


                                                                           

Epoch 64/100: Train Loss = 678.9612, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 65/100: Train Loss = 678.9612, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 66/100: Train Loss = 678.9611, Val Loss = 1022.5321


                                                                           

Epoch 67/100: Train Loss = 678.9611, Val Loss = 1022.5320


                                                                           

Epoch 68/100: Train Loss = 678.9611, Val Loss = 1022.5321


                                                                           

Epoch 69/100: Train Loss = 678.9611, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 70/100: Train Loss = 678.9611, Val Loss = 1022.5320


                                                                           

Epoch 71/100: Train Loss = 678.9611, Val Loss = 1022.5320


                                                                           

Epoch 72/100: Train Loss = 678.9610, Val Loss = 1022.5320
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5320


                                                                           

Epoch 73/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 74/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 75/100: Train Loss = 678.9610, Val Loss = 1022.5319
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5319


                                                                           

Epoch 76/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 77/100: Train Loss = 678.9610, Val Loss = 1022.5320


                                                                           

Epoch 78/100: Train Loss = 678.9610, Val Loss = 1022.5319
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5319


                                                                           

Epoch 79/100: Train Loss = 678.9610, Val Loss = 1022.5319


                                                                           

Epoch 80/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 81/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 82/100: Train Loss = 678.9609, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 83/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 84/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 85/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 86/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 87/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 88/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 89/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 90/100: Train Loss = 678.9609, Val Loss = 1022.5319


                                                                           

Epoch 91/100: Train Loss = 678.9609, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 92/100: Train Loss = 678.9609, Val Loss = 1022.5318


                                                                           

Epoch 93/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 94/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 95/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 96/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 97/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 98/100: Train Loss = 678.9608, Val Loss = 1022.5318


                                                                           

Epoch 99/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/best_model.pt
New best validation loss: 1022.5318


                                                                           

Epoch 100/100: Train Loss = 678.9608, Val Loss = 1022.5318
Model saved to saved_models/transformer_vae/fraud_transformer_vae/20250313_190737/final_model.pt
Training complete. Best validation loss: 1022.5318


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train/recon_loss,▂▃▃▂▃▃▄▅▃▁▆▆▄▃▇█▂▄▂▂▅▁▅▁▅▃▃▄▃▄▃▄▅▄▄▄▃▂▃▅
train/total_loss,█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
train/recon_loss,9307.76562
train/total_loss,678.96082
val/recon_loss,11025.65723
val/total_loss,1022.5318


Training session complete.


### Normal

In [2]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.transformer_vae_model import TransformerVae, vae_loss_function, print_num_params
from trainer.trainer_vae import VAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/transformer_vae/normal_transformer_vae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
transformer_vae_config = config_parser["Transformer_VAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = TransformerVae(transformer_vae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vae_loss_function

# Create trainer
trainer = VAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log to WandB (if enabled)
        try:
            train_batch = next(iter(dataloaders['train']))
            val_batch = next(iter(dataloaders['val']))
            train_recon_loss = extract_recon_loss(model, train_batch, trainer.device)
            val_recon_loss = extract_recon_loss(model, val_batch, trainer.device)
            wandb_logger.log_epoch(epoch, train_loss, val_loss, train_recon_loss, val_recon_loss)
        except Exception as e:
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

Initializing WandB run: transformer-vae-normal_20250313_232840 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/transformer_vae/normal_transformer_vae.config
Filtered dataset to class 0: 12000 samples
Normalization statistics (calculated from class 0): {'Time': {'mean': 94364.65358333333, 'std': 47365.815157589255}, 'Amount': {'mean': 87.60478666666666, 'std': 240.59403081682598}}
Filtered dataset to class 0: 12000 samples
Filtered dataset to class 0: 1500 samples
Filtered dataset to class 0: 1500 samples
Total number of trainable parameters: 147879
Models will be saved to: saved_models/transformer_vae/normal_transformer_vae/20250313_232852


                                                                               

Epoch 1/30: Train Loss = 31.5605, Val Loss = 32.5408
Model saved to saved_models/transformer_vae/normal_transformer_vae/20250313_232852/best_model.pt
New best validation loss: 32.5408


                                                                               

Epoch 2/30: Train Loss = 31.5503, Val Loss = 32.5409


                                                                               

Epoch 3/30: Train Loss = 31.5504, Val Loss = 32.5407
Model saved to saved_models/transformer_vae/normal_transformer_vae/20250313_232852/best_model.pt
New best validation loss: 32.5407


                                                                               

Epoch 4/30: Train Loss = 31.5504, Val Loss = 32.5409


                                                                               

Epoch 5/30: Train Loss = 31.5503, Val Loss = 32.5410


                                                                               

Epoch 6/30: Train Loss = 31.5505, Val Loss = 32.5410


                                                                               

Epoch 7/30: Train Loss = 31.5503, Val Loss = 32.5407


                                                                               

Epoch 8/30: Train Loss = 31.5504, Val Loss = 32.5411


                                                                               

Epoch 9/30: Train Loss = 31.5504, Val Loss = 32.5411


                                                                               

Epoch 10/30: Train Loss = 31.5505, Val Loss = 32.5411


                                                                               

Epoch 11/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                               

Epoch 12/30: Train Loss = 31.5503, Val Loss = 32.5410


                                                                               

Epoch 13/30: Train Loss = 31.5503, Val Loss = 32.5408


                                                                               

Epoch 14/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                               

Epoch 15/30: Train Loss = 31.5503, Val Loss = 32.5410


                                                                               

Epoch 16/30: Train Loss = 31.5504, Val Loss = 32.5411


                                                                               

Epoch 17/30: Train Loss = 31.5504, Val Loss = 32.5411


                                                                                

Epoch 18/30: Train Loss = 31.5505, Val Loss = 32.5409


                                                                               

Epoch 19/30: Train Loss = 31.5505, Val Loss = 32.5411


                                                                               

Epoch 20/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                               

Epoch 21/30: Train Loss = 31.5504, Val Loss = 32.5409


                                                                               

Epoch 22/30: Train Loss = 31.5503, Val Loss = 32.5410


                                                                               

Epoch 23/30: Train Loss = 31.5504, Val Loss = 32.5409


                                                                               

Epoch 24/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                               

Epoch 25/30: Train Loss = 31.5503, Val Loss = 32.5408


                                                                               

Epoch 26/30: Train Loss = 31.5505, Val Loss = 32.5410


                                                                               

Epoch 27/30: Train Loss = 31.5503, Val Loss = 32.5409


                                                                               

Epoch 28/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                               

Epoch 29/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                               

Epoch 30/30: Train Loss = 31.5502, Val Loss = 32.5410
Model saved to saved_models/transformer_vae/normal_transformer_vae/20250313_232852/final_model.pt
Training complete. Best validation loss: 32.5407


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train/recon_loss,▂▁▂▂▁▁▂▃▇▃▁▁▂▂▁▂▄▂▄█▂▂▃▁▂▂▂▂▂▂
train/total_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/recon_loss,▅▃█▃▁▁▆▁▁▁▂▂▅▂▂▁▁▃▁▂▃▂▃▂▅▂▃▂▁▁
val/total_loss,▂▅▁▅▇▇▁█▇█▆▆▂▆▆██▄█▆▄▆▅▆▂▆▅▆▇▇

0,1
epoch,30.0
train/recon_loss,194.85416
train/total_loss,31.5502
val/recon_loss,199.39651
val/total_loss,32.54104


Training session complete.


## Conv VQ VAE

### Fraud

In [3]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.conv_vqvae_model import ConvVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger
from utils.evaluation_utils import extract_recon_loss

# Build the config path
config_path = "configs/conv_vqvae/fraud_conv_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
conv_vqvae_config = config_parser["Conv_VQVAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = ConvVQVAE(conv_vqvae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log metrics to WandB
        try:
            # For VQVAE, we only log the total losses since we don't have separate recon_loss extraction
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

Initializing WandB run: conv-vqvae-fraud_20250313_235557 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/conv_vqvae/fraud_conv_vqvae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 64164
Models will be saved to: saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605


                                                                             

Epoch 1/100: Train Loss = 789.5891, Val Loss = 1094.8434
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 1094.8434


                                                                             

Epoch 2/100: Train Loss = 687.4258, Val Loss = 1024.3413
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 1024.3413


                                                                             

Epoch 3/100: Train Loss = 679.2138, Val Loss = 1022.5116
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 1022.5116


                                                                             

Epoch 4/100: Train Loss = 678.5124, Val Loss = 1014.5305
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 1014.5305


                                                                             

Epoch 5/100: Train Loss = 666.8460, Val Loss = 1006.3647
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 1006.3647


                                                                             

Epoch 6/100: Train Loss = 664.9659, Val Loss = 1006.4459


                                                                             

Epoch 7/100: Train Loss = 664.9075, Val Loss = 1005.8581
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 1005.8581


                                                                             

Epoch 8/100: Train Loss = 659.2394, Val Loss = 998.6946
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 998.6946


                                                                             

Epoch 9/100: Train Loss = 658.1904, Val Loss = 999.5041


                                                                             

Epoch 10/100: Train Loss = 658.1658, Val Loss = 999.1953


                                                                             

Epoch 11/100: Train Loss = 658.2933, Val Loss = 999.6391


                                                                             

Epoch 12/100: Train Loss = 658.2448, Val Loss = 999.2041


                                                                             

Epoch 13/100: Train Loss = 658.3359, Val Loss = 999.6279


                                                                             

Epoch 14/100: Train Loss = 658.1899, Val Loss = 999.3504


                                                                             

Epoch 15/100: Train Loss = 658.2206, Val Loss = 998.9550


                                                                             

Epoch 16/100: Train Loss = 658.1138, Val Loss = 998.9332


                                                                             

Epoch 17/100: Train Loss = 658.1868, Val Loss = 998.9254


                                                                             

Epoch 18/100: Train Loss = 658.1533, Val Loss = 998.9335


                                                                             

Epoch 19/100: Train Loss = 658.0802, Val Loss = 998.9409


                                                                             

Epoch 20/100: Train Loss = 658.1270, Val Loss = 998.8804


                                                                             

Epoch 21/100: Train Loss = 658.1445, Val Loss = 998.9299


                                                                             

Epoch 22/100: Train Loss = 658.1111, Val Loss = 998.6067
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 998.6067


                                                                             

Epoch 23/100: Train Loss = 657.9859, Val Loss = 998.5200
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 998.5200


                                                                             

Epoch 24/100: Train Loss = 658.0306, Val Loss = 998.4305
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 998.4305


                                                                             

Epoch 25/100: Train Loss = 658.0364, Val Loss = 998.5711


                                                                             

Epoch 26/100: Train Loss = 658.0216, Val Loss = 998.6831


                                                                             

Epoch 27/100: Train Loss = 657.9549, Val Loss = 998.5988


                                                                             

Epoch 28/100: Train Loss = 657.9662, Val Loss = 998.4462


                                                                             

Epoch 29/100: Train Loss = 657.9193, Val Loss = 998.5684


                                                                             

Epoch 30/100: Train Loss = 657.9319, Val Loss = 998.6018


                                                                             

Epoch 31/100: Train Loss = 657.9533, Val Loss = 998.5944


                                                                             

Epoch 32/100: Train Loss = 657.9301, Val Loss = 998.5278


                                                                             

Epoch 33/100: Train Loss = 657.9744, Val Loss = 998.6046


                                                                             

Epoch 34/100: Train Loss = 657.9675, Val Loss = 998.5977


                                                                             

Epoch 35/100: Train Loss = 657.9539, Val Loss = 998.5346


                                                                             

Epoch 36/100: Train Loss = 657.9313, Val Loss = 998.4389


                                                                             

Epoch 37/100: Train Loss = 657.9430, Val Loss = 998.4895


                                                                             

Epoch 38/100: Train Loss = 657.9794, Val Loss = 998.5955


                                                                             

Epoch 39/100: Train Loss = 657.9515, Val Loss = 998.5428


                                                                             

Epoch 40/100: Train Loss = 657.9263, Val Loss = 998.5304


                                                                             

Epoch 41/100: Train Loss = 657.9502, Val Loss = 998.4788


                                                                             

Epoch 42/100: Train Loss = 657.9433, Val Loss = 998.5076


                                                                             

Epoch 43/100: Train Loss = 657.9420, Val Loss = 998.5131


                                                                             

Epoch 44/100: Train Loss = 657.9432, Val Loss = 998.5226


                                                                             

Epoch 45/100: Train Loss = 657.9403, Val Loss = 998.4326


                                                                             

Epoch 46/100: Train Loss = 657.9348, Val Loss = 998.5024


                                                                             

Epoch 47/100: Train Loss = 657.9226, Val Loss = 998.5036


                                                                             

Epoch 48/100: Train Loss = 657.9798, Val Loss = 998.4244
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/best_model.pt
New best validation loss: 998.4244


                                                                             

Epoch 49/100: Train Loss = 657.9372, Val Loss = 998.4995


                                                                             

Epoch 50/100: Train Loss = 657.9309, Val Loss = 998.5877


                                                                             

Epoch 51/100: Train Loss = 657.9189, Val Loss = 998.4793


                                                                             

Epoch 52/100: Train Loss = 657.9155, Val Loss = 998.4975


                                                                             

Epoch 53/100: Train Loss = 657.9036, Val Loss = 998.5140


                                                                             

Epoch 54/100: Train Loss = 657.9425, Val Loss = 998.4353


                                                                             

Epoch 55/100: Train Loss = 657.9394, Val Loss = 998.5009


                                                                             

Epoch 56/100: Train Loss = 657.9275, Val Loss = 998.5574


                                                                             

Epoch 57/100: Train Loss = 657.9282, Val Loss = 998.5053


                                                                             

Epoch 58/100: Train Loss = 657.9177, Val Loss = 998.4841


                                                                             

Epoch 59/100: Train Loss = 657.9145, Val Loss = 998.4622


                                                                             

Epoch 60/100: Train Loss = 657.9127, Val Loss = 998.4680


                                                                             

Epoch 61/100: Train Loss = 657.9189, Val Loss = 998.5383


                                                                             

Epoch 62/100: Train Loss = 657.9126, Val Loss = 998.4842


                                                                             

Epoch 63/100: Train Loss = 657.9185, Val Loss = 998.5681


                                                                             

Epoch 64/100: Train Loss = 657.9426, Val Loss = 998.4751


                                                                             

Epoch 65/100: Train Loss = 657.9125, Val Loss = 998.5215


                                                                             

Epoch 66/100: Train Loss = 657.9093, Val Loss = 998.4908


                                                                             

Epoch 67/100: Train Loss = 657.9186, Val Loss = 998.4965


                                                                             

Epoch 68/100: Train Loss = 657.9148, Val Loss = 998.5162


                                                                             

Epoch 69/100: Train Loss = 657.9142, Val Loss = 998.5408


                                                                             

Epoch 70/100: Train Loss = 657.9504, Val Loss = 998.4680


                                                                             

Epoch 71/100: Train Loss = 657.9185, Val Loss = 998.4440


                                                                             

Epoch 72/100: Train Loss = 657.9329, Val Loss = 998.4844


                                                                             

Epoch 73/100: Train Loss = 657.9035, Val Loss = 998.4855


                                                                             

Epoch 74/100: Train Loss = 657.9244, Val Loss = 998.4765


                                                                             

Epoch 75/100: Train Loss = 657.9042, Val Loss = 998.4996


                                                                             

Epoch 76/100: Train Loss = 657.9028, Val Loss = 998.4794


                                                                             

Epoch 77/100: Train Loss = 657.9108, Val Loss = 998.4582


                                                                             

Epoch 78/100: Train Loss = 657.9229, Val Loss = 998.5428


                                                                             

Epoch 79/100: Train Loss = 657.9209, Val Loss = 998.5125


                                                                             

Epoch 80/100: Train Loss = 657.9114, Val Loss = 998.4926


                                                                             

Epoch 81/100: Train Loss = 657.9075, Val Loss = 998.4866


                                                                             

Epoch 82/100: Train Loss = 657.9155, Val Loss = 998.4777


                                                                             

Epoch 83/100: Train Loss = 657.9218, Val Loss = 998.5151


                                                                             

Epoch 84/100: Train Loss = 657.9071, Val Loss = 998.4612


                                                                             

Epoch 85/100: Train Loss = 657.9095, Val Loss = 998.5120


                                                                             

Epoch 86/100: Train Loss = 657.9162, Val Loss = 998.4674


                                                                             

Epoch 87/100: Train Loss = 657.9242, Val Loss = 998.4832


                                                                             

Epoch 88/100: Train Loss = 657.8995, Val Loss = 998.5081


                                                                             

Epoch 89/100: Train Loss = 657.8958, Val Loss = 998.4973


                                                                             

Epoch 90/100: Train Loss = 657.8984, Val Loss = 998.4970


                                                                             

Epoch 91/100: Train Loss = 657.9064, Val Loss = 998.4984


                                                                             

Epoch 92/100: Train Loss = 657.9026, Val Loss = 998.4972


                                                                             

Epoch 93/100: Train Loss = 657.9074, Val Loss = 998.4826


                                                                             

Epoch 94/100: Train Loss = 657.9037, Val Loss = 998.4617


                                                                             

Epoch 95/100: Train Loss = 657.9033, Val Loss = 998.4612


                                                                             

Epoch 96/100: Train Loss = 657.8975, Val Loss = 998.4899


                                                                             

Epoch 97/100: Train Loss = 657.9173, Val Loss = 998.4830


                                                                             

Epoch 98/100: Train Loss = 657.9011, Val Loss = 998.4700


                                                                             

Epoch 99/100: Train Loss = 657.8975, Val Loss = 998.5024


                                                                             

Epoch 100/100: Train Loss = 657.9041, Val Loss = 998.5084
Model saved to saved_models/conv_vqvae/fraud_conv_vqvae/20250313_235605/final_model.pt
Training complete. Best validation loss: 998.4244




0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██
train/total_loss,█▆▆▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
train/total_loss,657.90409
val/total_loss,998.50837


Training session complete.


### Normal


In [4]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.conv_vqvae_model import ConvVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger

# Build the config path
config_path = "configs/conv_vqvae/normal_conv_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
conv_vqvae_config = config_parser["Conv_VQVAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = ConvVQVAE(conv_vqvae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log metrics to WandB
        try:
            # For VQVAE, we only log the total losses since we don't have separate recon_loss extraction
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

Initializing WandB run: conv-vqvae-normal_20250313_235814 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/conv_vqvae/normal_conv_vqvae.config
Filtered dataset to class 0: 12000 samples
Normalization statistics (calculated from class 0): {'Time': {'mean': 94364.65358333333, 'std': 47365.815157589255}, 'Amount': {'mean': 87.60478666666666, 'std': 240.59403081682598}}
Filtered dataset to class 0: 12000 samples
Filtered dataset to class 0: 1500 samples
Filtered dataset to class 0: 1500 samples
Total number of trainable parameters: 64164
Models will be saved to: saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821


                                                                                  

Epoch 1/30: Train Loss = 34.8762, Val Loss = 31.9475
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 31.9475


                                                                                  

Epoch 2/30: Train Loss = 30.2766, Val Loss = 31.5840
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 31.5840


                                                                                 

Epoch 3/30: Train Loss = 30.1573, Val Loss = 30.8868
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 30.8868


                                                                                  

Epoch 4/30: Train Loss = 29.9350, Val Loss = 31.3080


                                                                                 

Epoch 5/30: Train Loss = 29.9361, Val Loss = 30.8306
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 30.8306


                                                                                 

Epoch 6/30: Train Loss = 29.6313, Val Loss = 30.6516
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 30.6516


                                                                                 

Epoch 7/30: Train Loss = 29.8023, Val Loss = 31.1184


                                                                                 

Epoch 8/30: Train Loss = 29.6932, Val Loss = 30.5924
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 30.5924


                                                                                 

Epoch 9/30: Train Loss = 29.4954, Val Loss = 30.5075
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 30.5075


                                                                                 

Epoch 10/30: Train Loss = 29.4116, Val Loss = 30.2268
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 30.2268


                                                                                  

Epoch 11/30: Train Loss = 29.4478, Val Loss = 30.3042


                                                                                 

Epoch 12/30: Train Loss = 29.0861, Val Loss = 29.7262
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 29.7262


                                                                                 

Epoch 13/30: Train Loss = 28.8436, Val Loss = 29.6217
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 29.6217


                                                                                 

Epoch 14/30: Train Loss = 28.8971, Val Loss = 29.5758
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 29.5758


                                                                                 

Epoch 15/30: Train Loss = 29.0012, Val Loss = 29.7518


                                                                                 

Epoch 16/30: Train Loss = 28.9728, Val Loss = 29.4276
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 29.4276


                                                                                  

Epoch 17/30: Train Loss = 28.6322, Val Loss = 29.4860


                                                                                 

Epoch 18/30: Train Loss = 28.2609, Val Loss = 29.1322
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 29.1322


                                                                                  

Epoch 19/30: Train Loss = 28.2884, Val Loss = 29.2113


                                                                                 

Epoch 20/30: Train Loss = 28.3572, Val Loss = 28.9754
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 28.9754


                                                                                  

Epoch 21/30: Train Loss = 28.2050, Val Loss = 29.3145


                                                                                  

Epoch 22/30: Train Loss = 28.3093, Val Loss = 29.2697


                                                                                 

Epoch 23/30: Train Loss = 28.1953, Val Loss = 28.8144
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 28.8144


                                                                                 

Epoch 24/30: Train Loss = 28.0105, Val Loss = 29.0959


                                                                                  

Epoch 25/30: Train Loss = 28.0362, Val Loss = 29.0171


                                                                                 

Epoch 26/30: Train Loss = 28.2812, Val Loss = 29.2156


                                                                                  

Epoch 27/30: Train Loss = 28.3124, Val Loss = 29.2723


                                                                                 

Epoch 28/30: Train Loss = 28.1570, Val Loss = 28.9158


                                                                                 

Epoch 29/30: Train Loss = 28.0012, Val Loss = 28.7007
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/best_model.pt
New best validation loss: 28.7007


                                                                                  

Epoch 30/30: Train Loss = 27.8918, Val Loss = 28.7631
Model saved to saved_models/conv_vqvae/normal_conv_vqvae/20250313_235821/final_model.pt
Training complete. Best validation loss: 28.7007


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train/total_loss,█▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val/total_loss,█▇▆▇▆▅▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▁▂▂▂▂▁▁▁

0,1
epoch,30.0
train/total_loss,27.89178
val/total_loss,28.76313


Training session complete.


## Transformer VQVAE

### Fraud

In [5]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.transformer_vqvae_model import TransformerVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger

# Build the config path
config_path = "configs/transformer_vqvae/fraud_transformer_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
transformer_vqvae_config = config_parser["Transformer_VQVAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = TransformerVQVAE(transformer_vqvae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log to WandB (if enabled)
        try:
            # For VQVAE, we only log the total losses 
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

Initializing WandB run: transformer-vqvae-fraud_20250314_001319 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/transformer_vqvae/fraud_transformer_vqvae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 170689
Models will be saved to: saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329


                                                                             

Epoch 1/100: Train Loss = 654.7753, Val Loss = 984.5373
Model saved to saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329/best_model.pt
New best validation loss: 984.5373


                                                                             

Epoch 2/100: Train Loss = 647.1583, Val Loss = 984.1514
Model saved to saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329/best_model.pt
New best validation loss: 984.1514


                                                                             

Epoch 3/100: Train Loss = 646.9973, Val Loss = 984.0501
Model saved to saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329/best_model.pt
New best validation loss: 984.0501


                                                                             

Epoch 4/100: Train Loss = 646.9363, Val Loss = 983.9916
Model saved to saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329/best_model.pt
New best validation loss: 983.9916


                                                                             

Epoch 5/100: Train Loss = 646.8360, Val Loss = 983.7335
Model saved to saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329/best_model.pt
New best validation loss: 983.7335


                                                                             

Epoch 6/100: Train Loss = 646.6988, Val Loss = 983.9249


                                                                             

Epoch 7/100: Train Loss = 646.8555, Val Loss = 984.0608


                                                                             

Epoch 8/100: Train Loss = 646.9518, Val Loss = 984.1817


                                                                             

Epoch 9/100: Train Loss = 647.0312, Val Loss = 984.2847


                                                                             

Epoch 10/100: Train Loss = 647.0945, Val Loss = 984.2881


                                                                             

Epoch 11/100: Train Loss = 647.1388, Val Loss = 984.3364


                                                                             

Epoch 12/100: Train Loss = 647.1827, Val Loss = 984.4300


                                                                             

Epoch 13/100: Train Loss = 647.2211, Val Loss = 984.4509


                                                                             

Epoch 14/100: Train Loss = 647.2443, Val Loss = 984.4628


                                                                             

Epoch 15/100: Train Loss = 647.2726, Val Loss = 984.4706


                                                                             

Epoch 16/100: Train Loss = 647.2916, Val Loss = 984.4808


                                                                             

Epoch 17/100: Train Loss = 647.3152, Val Loss = 984.4765


                                                                             

Epoch 18/100: Train Loss = 647.2978, Val Loss = 984.4869


                                                                             

Epoch 19/100: Train Loss = 647.2891, Val Loss = 984.4532


                                                                             

Epoch 20/100: Train Loss = 647.3108, Val Loss = 984.4898


                                                                             

Epoch 21/100: Train Loss = 647.3312, Val Loss = 984.4534


                                                                             

Epoch 22/100: Train Loss = 647.3248, Val Loss = 984.5269


                                                                             

Epoch 23/100: Train Loss = 647.3703, Val Loss = 984.5367


                                                                             

Epoch 24/100: Train Loss = 647.3537, Val Loss = 984.5304


                                                                             

Epoch 25/100: Train Loss = 647.3904, Val Loss = 984.5603


                                                                             

Epoch 26/100: Train Loss = 647.3546, Val Loss = 984.5444


                                                                             

Epoch 27/100: Train Loss = 647.3570, Val Loss = 984.5488


                                                                             

Epoch 28/100: Train Loss = 647.4122, Val Loss = 984.6010


                                                                             

Epoch 29/100: Train Loss = 647.5025, Val Loss = 984.7829


                                                                             

Epoch 30/100: Train Loss = 647.5091, Val Loss = 984.8086


                                                                             

Epoch 31/100: Train Loss = 647.5106, Val Loss = 984.7037


                                                                             

Epoch 32/100: Train Loss = 647.5011, Val Loss = 984.7728


                                                                             

Epoch 33/100: Train Loss = 647.3886, Val Loss = 984.5510


                                                                             

Epoch 34/100: Train Loss = 647.3741, Val Loss = 984.5634


                                                                             

Epoch 35/100: Train Loss = 647.4152, Val Loss = 984.6036


                                                                             

Epoch 36/100: Train Loss = 647.4515, Val Loss = 984.6848


                                                                             

Epoch 37/100: Train Loss = 647.4683, Val Loss = 984.6235


                                                                             

Epoch 38/100: Train Loss = 647.4429, Val Loss = 984.6227


                                                                             

Epoch 39/100: Train Loss = 647.4850, Val Loss = 984.6780


                                                                             

Epoch 40/100: Train Loss = 647.5229, Val Loss = 984.7837


                                                                             

Epoch 41/100: Train Loss = 647.5624, Val Loss = 984.8693


                                                                             

Epoch 42/100: Train Loss = 647.5925, Val Loss = 984.8953


                                                                             

Epoch 43/100: Train Loss = 647.5966, Val Loss = 984.8786


                                                                             

Epoch 44/100: Train Loss = 647.6100, Val Loss = 984.9418


                                                                             

Epoch 45/100: Train Loss = 647.6170, Val Loss = 984.8947


                                                                             

Epoch 46/100: Train Loss = 647.6655, Val Loss = 984.9535


                                                                             

Epoch 47/100: Train Loss = 647.6545, Val Loss = 984.9999


                                                                             

Epoch 48/100: Train Loss = 647.6301, Val Loss = 984.9526


                                                                             

Epoch 49/100: Train Loss = 647.7078, Val Loss = 985.0489


                                                                             

Epoch 50/100: Train Loss = 647.6576, Val Loss = 984.8740


                                                                             

Epoch 51/100: Train Loss = 647.6471, Val Loss = 984.9807


                                                                             

Epoch 52/100: Train Loss = 647.7202, Val Loss = 985.0622


                                                                             

Epoch 53/100: Train Loss = 647.7299, Val Loss = 985.0538


                                                                             

Epoch 54/100: Train Loss = 647.7250, Val Loss = 985.0193


                                                                             

Epoch 55/100: Train Loss = 647.7294, Val Loss = 985.0361


                                                                             

Epoch 56/100: Train Loss = 647.7391, Val Loss = 985.1041


                                                                             

Epoch 57/100: Train Loss = 647.8014, Val Loss = 985.1157


                                                                             

Epoch 58/100: Train Loss = 647.7987, Val Loss = 985.1155


                                                                             

Epoch 59/100: Train Loss = 647.7890, Val Loss = 985.1153


                                                                             

Epoch 60/100: Train Loss = 647.7987, Val Loss = 985.0801


                                                                             

Epoch 61/100: Train Loss = 647.7859, Val Loss = 985.1276


                                                                             

Epoch 62/100: Train Loss = 647.8094, Val Loss = 985.1379


                                                                             

Epoch 63/100: Train Loss = 647.8624, Val Loss = 985.1551


                                                                             

Epoch 64/100: Train Loss = 647.8359, Val Loss = 985.1650


                                                                             

Epoch 65/100: Train Loss = 647.8286, Val Loss = 985.1376


                                                                             

Epoch 66/100: Train Loss = 647.8608, Val Loss = 985.1638


                                                                             

Epoch 67/100: Train Loss = 647.8817, Val Loss = 985.1634


                                                                             

Epoch 68/100: Train Loss = 647.9167, Val Loss = 985.1789


                                                                             

Epoch 69/100: Train Loss = 647.9231, Val Loss = 985.1785


                                                                             

Epoch 70/100: Train Loss = 647.9257, Val Loss = 985.2220


                                                                             

Epoch 71/100: Train Loss = 647.9984, Val Loss = 985.2522


                                                                             

Epoch 72/100: Train Loss = 647.9709, Val Loss = 985.2521


                                                                             

Epoch 73/100: Train Loss = 647.9783, Val Loss = 985.2340


                                                                             

Epoch 74/100: Train Loss = 647.9733, Val Loss = 985.2522


                                                                             

Epoch 75/100: Train Loss = 647.9710, Val Loss = 985.2337


                                                                             

Epoch 76/100: Train Loss = 647.9814, Val Loss = 985.2520


                                                                             

Epoch 77/100: Train Loss = 648.0142, Val Loss = 985.2517


                                                                             

Epoch 78/100: Train Loss = 648.0582, Val Loss = 985.2520


                                                                             

Epoch 79/100: Train Loss = 648.0715, Val Loss = 985.2692


                                                                             

Epoch 80/100: Train Loss = 648.0738, Val Loss = 985.2877


                                                                             

Epoch 81/100: Train Loss = 648.1231, Val Loss = 985.2874


                                                                             

Epoch 82/100: Train Loss = 648.1520, Val Loss = 985.2874


                                                                             

Epoch 83/100: Train Loss = 648.1538, Val Loss = 985.2879


                                                                             

Epoch 84/100: Train Loss = 648.1421, Val Loss = 985.2984


                                                                             

Epoch 85/100: Train Loss = 648.1871, Val Loss = 985.2981


                                                                             

Epoch 86/100: Train Loss = 648.1646, Val Loss = 985.3175


                                                                             

Epoch 87/100: Train Loss = 648.1789, Val Loss = 985.2870


                                                                             

Epoch 88/100: Train Loss = 648.1181, Val Loss = 985.2873


                                                                             

Epoch 89/100: Train Loss = 648.1001, Val Loss = 985.2499


                                                                             

Epoch 90/100: Train Loss = 647.9334, Val Loss = 985.1591


                                                                             

Epoch 91/100: Train Loss = 647.9225, Val Loss = 985.1574


                                                                             

Epoch 92/100: Train Loss = 647.7942, Val Loss = 985.0831


                                                                             

Epoch 93/100: Train Loss = 647.7746, Val Loss = 985.0066


                                                                             

Epoch 94/100: Train Loss = 647.5898, Val Loss = 984.7108


                                                                             

Epoch 95/100: Train Loss = 647.4860, Val Loss = 984.6573


                                                                             

Epoch 96/100: Train Loss = 647.5037, Val Loss = 984.8962


                                                                             

Epoch 97/100: Train Loss = 647.6290, Val Loss = 984.8023


                                                                             

Epoch 98/100: Train Loss = 647.5746, Val Loss = 984.8035


                                                                             

Epoch 99/100: Train Loss = 647.4827, Val Loss = 984.8172


                                                                             

Epoch 100/100: Train Loss = 647.6183, Val Loss = 984.8314
Model saved to saved_models/transformer_vqvae/fraud_transformer_vqvae/20250314_001329/final_model.pt
Training complete. Best validation loss: 983.7335




0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train/total_loss,▃▁▁▂▃▃▃▄▃▄▅▅▅▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████▇▇▆▅▅
val/total_loss,▂▁▁▂▂▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇█████████▇▅▆▆

0,1
epoch,100.0
train/total_loss,647.61834
val/total_loss,984.83138


Training session complete.


### Normal

In [6]:
import os
import torch
import torch.optim as optim
import numpy as np
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import required modules
from models.transformer_vqvae_model import TransformerVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config
from utils.model_saver import save_model, get_save_directory
from utils.wandb_logger import WandBLogger

# Build the config path
config_path = "configs/transformer_vqvae/normal_transformer_vqvae.config"

# Load configuration
import configparser
config_parser = configparser.ConfigParser()
config_parser.read(config_path)
transformer_vqvae_config = config_parser["Transformer_VQVAE"]

# Create config dictionary for WandB
config_dict = {}
for section in config_parser.sections():
    config_dict[section] = dict(config_parser[section])

# Initialize WandB logger
wandb_logger = WandBLogger(config_dict)

# Load data
config_dict = load_config(config_path)
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# Create model
model = TransformerVQVAE(transformer_vqvae_config)
print_num_params(model)

# Training parameters
lr = config_parser["Trainer"].getfloat("lr")
num_epochs = config_parser["Trainer"].getint("num_epochs")

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create trainer
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Create save directory
save_dir = get_save_directory(config_path)
print(f"Models will be saved to: {save_dir}")

# Track best validation loss
best_val_loss = float('inf')
train_losses, val_losses = [], []

try:
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = trainer.train_epoch()
        val_loss = trainer.validate_epoch()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Log to WandB (if enabled)
        try:
            # For VQVAE, we only log the total losses 
            wandb_logger.log_epoch(epoch, train_loss, val_loss)
        except Exception as e:
            print(f"Error logging to WandB: {e}")
        
        # Print progress
        print(f"Epoch {epoch}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            metadata = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
            model_path = save_model(model, save_dir, 'best_model.pt', metadata)
            print(f"New best validation loss: {val_loss:.4f}")
            wandb_logger.log_model(model_path, metadata)

    # Save final model if different from best
    if val_losses[-1] > best_val_loss:
        metadata = {'epoch': num_epochs, 'train_loss': train_losses[-1], 'val_loss': val_losses[-1]}
        save_model(model, save_dir, 'final_model.pt', metadata)

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user. Saving current state...")
    if len(train_losses) > 0:
        metadata = {
            'interrupted': True,
            'epoch': len(train_losses), 
            'train_loss': train_losses[-1],
            'val_loss': val_losses[-1] if len(val_losses) > 0 else None
        }
        save_model(model, save_dir, 'interrupted_model.pt', metadata)
        print(f"Interrupted model saved to {save_dir}")

except Exception as e:
    print(f"\nTraining failed with error: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    # Clean up WandB
    try:
        wandb_logger.finish()
    except:
        pass
    print("Training session complete.")

Initializing WandB run: transformer-vqvae-normal_20250314_001744 (Project: fraud-classification, Entity: alexkstern)


Loaded configuration from configs/transformer_vqvae/normal_transformer_vqvae.config
Filtered dataset to class 0: 12000 samples
Normalization statistics (calculated from class 0): {'Time': {'mean': 94364.65358333333, 'std': 47365.815157589255}, 'Amount': {'mean': 87.60478666666666, 'std': 240.59403081682598}}
Filtered dataset to class 0: 12000 samples
Filtered dataset to class 0: 1500 samples
Filtered dataset to class 0: 1500 samples
Total number of trainable parameters: 170689
Models will be saved to: saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750


                                                                                 

Epoch 1/30: Train Loss = 22.8098, Val Loss = 23.9801
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/best_model.pt
New best validation loss: 23.9801


                                                                                 

Epoch 2/30: Train Loss = 23.4542, Val Loss = 24.0075


                                                                                 

Epoch 3/30: Train Loss = 22.9211, Val Loss = 23.4038
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/best_model.pt
New best validation loss: 23.4038


                                                                                 

Epoch 4/30: Train Loss = 22.7282, Val Loss = 23.3471
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/best_model.pt
New best validation loss: 23.3471


                                                                                 

Epoch 5/30: Train Loss = 22.6406, Val Loss = 23.5659


                                                                                 

Epoch 6/30: Train Loss = 22.8517, Val Loss = 23.1709
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/best_model.pt
New best validation loss: 23.1709


                                                                                 

Epoch 7/30: Train Loss = 24.1452, Val Loss = 22.7688
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/best_model.pt
New best validation loss: 22.7688


                                                                                 

Epoch 8/30: Train Loss = 22.1396, Val Loss = 22.7759


                                                                                 

Epoch 9/30: Train Loss = 22.1022, Val Loss = 22.7421
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/best_model.pt
New best validation loss: 22.7421


                                                                                 

Epoch 10/30: Train Loss = 22.4386, Val Loss = 23.1618


                                                                                 

Epoch 11/30: Train Loss = 30.4875, Val Loss = 32.5416


                                                                                 

Epoch 12/30: Train Loss = 31.5614, Val Loss = 32.5415


                                                                                 

Epoch 13/30: Train Loss = 31.5584, Val Loss = 32.5414


                                                                                 

Epoch 14/30: Train Loss = 31.5556, Val Loss = 32.5413


                                                                                 

Epoch 15/30: Train Loss = 31.5529, Val Loss = 32.5413


                                                                                 

Epoch 16/30: Train Loss = 31.5513, Val Loss = 32.5413


                                                                                 

Epoch 17/30: Train Loss = 31.5509, Val Loss = 32.5412


                                                                                 

Epoch 18/30: Train Loss = 31.5506, Val Loss = 32.5410


                                                                                 

Epoch 19/30: Train Loss = 31.5507, Val Loss = 32.5412


                                                                                 

Epoch 20/30: Train Loss = 31.5505, Val Loss = 32.5411


                                                                                 

Epoch 21/30: Train Loss = 31.5505, Val Loss = 32.5410


                                                                                 

Epoch 22/30: Train Loss = 31.5505, Val Loss = 32.5411


                                                                                 

Epoch 23/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                                 

Epoch 24/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                                 

Epoch 25/30: Train Loss = 31.5504, Val Loss = 32.5407


                                                                                 

Epoch 26/30: Train Loss = 31.5505, Val Loss = 32.5411


                                                                                 

Epoch 27/30: Train Loss = 31.5504, Val Loss = 32.5407


                                                                                 

Epoch 28/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                                 

Epoch 29/30: Train Loss = 31.5504, Val Loss = 32.5410


                                                                                 

Epoch 30/30: Train Loss = 31.5504, Val Loss = 32.5410
Model saved to saved_models/transformer_vqvae/normal_transformer_vqvae/20250314_001750/final_model.pt
Training complete. Best validation loss: 22.7421


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train/total_loss,▂▂▂▁▁▂▃▁▁▁▇███████████████████
val/total_loss,▂▂▁▁▂▁▁▁▁▁████████████████████

0,1
epoch,30.0
train/total_loss,31.55042
val/total_loss,32.541


Training session complete.
