In [5]:
import configparser
import os

import torch
import torch.optim as optim

# Import our model and trainer for VQ-VAE
from models.conv_vqvae_model import ConvVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config

# Train the convolutional VQ-VAE: read the experiment config, load the data,
# build model / optimizer / trainer, then run the training loop.

# Path to the experiment config. This is a plain relative path, so it is
# resolved against the kernel's current working directory (not the notebook file).
config_path = "configs/conv_vqvae/fraud_conv_vqvae.config"

# Seed PyTorch's global RNG before anything stochastic runs, so that
# "Restart Kernel -> Run All" reproduces the same run.
# NOTE(review): this only covers torch; if the project's data-loading code draws
# from NumPy or `random`, those need their own seeds.
SEED = 42
torch.manual_seed(SEED)

# Parse the config with configparser; the model is handed a section of this parser.
config_parser = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually parsed, so check it here instead of failing later with an
# unrelated KeyError on the first section lookup.
if not config_parser.read(config_path):
    raise FileNotFoundError(
        f"Could not read config file {config_path!r} (cwd: {os.getcwd()!r})"
    )

# Model configuration lives in the "Conv_VQVAE" section (KeyError if it is absent).
conv_vqvae_config = config_parser["Conv_VQVAE"]

# Training hyperparameters. The parser-level getters raise NoSectionError /
# NoOptionError when the section or key is missing; the section-proxy getters
# used previously return None in that case, which would only blow up later
# inside the optimizer or the trainer.
lr = config_parser.getfloat("Trainer", "lr")
num_epochs = config_parser.getint("Trainer", "num_epochs")

# Also load the config through the project's load_config helper.
# NOTE(review): config_dict is not consumed anywhere in this cell (the data loader
# below is given the path, not this object); it is kept in case later cells use it.
config_dict = load_config(config_path)

# Load the data from the same config file.
# NOTE(review): presumably the config's DataLoader section overrides the loader's
# defaults - confirm in dataloader.dataloader.
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']  # not used in this cell; kept for later cells

# Instantiate the VQ-VAE from its config section and report its size.
model = ConvVQVAE(conv_vqvae_config)
print_num_params(model)

# Optimizer and loss function for the VQ-VAE.
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

# Create the trainer and run training; it returns the train and validation losses.
trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)
train_losses, val_losses = trainer.train(num_epochs=num_epochs)

Loaded configuration from configs/conv_vqvae/fraud_conv_vqvae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 64164


                                                                             

Epoch 1/50: Train Loss = 696.5127, Val Loss = 1022.2650


                                                                             

Epoch 2/50: Train Loss = 675.0035, Val Loss = 1008.6584


                                                                             

Epoch 3/50: Train Loss = 690.4520, Val Loss = 1009.6366


                                                                             

Epoch 4/50: Train Loss = 663.4775, Val Loss = 1032.1196


                                                                             

Epoch 5/50: Train Loss = 675.1273, Val Loss = 1011.3932


                                                                             

Epoch 6/50: Train Loss = 673.3805, Val Loss = 1134.0840


                                                                             

Epoch 7/50: Train Loss = 670.4873, Val Loss = 998.8870


                                                                             

Epoch 8/50: Train Loss = 658.5239, Val Loss = 1000.7985


                                                                             

Epoch 9/50: Train Loss = 659.0288, Val Loss = 1000.8221


                                                                             

Epoch 10/50: Train Loss = 660.3430, Val Loss = 1006.8452


                                                                             

Epoch 11/50: Train Loss = 667.7788, Val Loss = 1011.1822


                                                                             

Epoch 12/50: Train Loss = 668.6575, Val Loss = 1013.8548


                                                                             

Epoch 13/50: Train Loss = 682.2960, Val Loss = 1109.6201


                                                                             

Epoch 14/50: Train Loss = 741.2791, Val Loss = 1136.1420


                                                                             

Epoch 15/50: Train Loss = 750.5163, Val Loss = 1274.8167


                                                                             

Epoch 16/50: Train Loss = 765.5583, Val Loss = 1245.9328


                                                                             

Epoch 17/50: Train Loss = 700.6611, Val Loss = 1139.6880


                                                                             

Epoch 18/50: Train Loss = 705.6671, Val Loss = 1358.4336


                                                                             

Epoch 19/50: Train Loss = 784.7133, Val Loss = 1279.7272


                                                                             

Epoch 20/50: Train Loss = 707.7576, Val Loss = 1103.3652


                                                                             

Epoch 21/50: Train Loss = 685.8796, Val Loss = 1083.3895


                                                                             

Epoch 22/50: Train Loss = 686.0165, Val Loss = 1105.2827


                                                                             

Epoch 23/50: Train Loss = 685.1847, Val Loss = 1145.2068


                                                                             

Epoch 24/50: Train Loss = 696.6067, Val Loss = 1088.5982


                                                                             

Epoch 25/50: Train Loss = 676.6622, Val Loss = 1074.6069


                                                                             

Epoch 26/50: Train Loss = 675.9409, Val Loss = 1047.9624


                                                                             

Epoch 27/50: Train Loss = 669.7487, Val Loss = 1038.3445


                                                                             

Epoch 28/50: Train Loss = 667.7861, Val Loss = 1029.1154


                                                                             

Epoch 29/50: Train Loss = 666.1185, Val Loss = 1024.2532


                                                                             

Epoch 30/50: Train Loss = 665.5925, Val Loss = 1024.3694


                                                                             

Epoch 31/50: Train Loss = 666.2888, Val Loss = 1028.3457


                                                                             

Epoch 32/50: Train Loss = 664.5187, Val Loss = 1032.0549


                                                                             

Epoch 33/50: Train Loss = 668.0398, Val Loss = 1021.9585


                                                                             

Epoch 34/50: Train Loss = 662.6735, Val Loss = 1012.7257


                                                                             

Epoch 35/50: Train Loss = 661.0667, Val Loss = 1008.5010


                                                                             

Epoch 36/50: Train Loss = 659.8139, Val Loss = 1007.2373


                                                                             

Epoch 37/50: Train Loss = 659.4984, Val Loss = 1006.1728


                                                                             

Epoch 38/50: Train Loss = 658.5691, Val Loss = 1002.9739


                                                                             

Epoch 39/50: Train Loss = 658.3240, Val Loss = 1002.8470


                                                                             

Epoch 40/50: Train Loss = 658.1094, Val Loss = 1002.9852


                                                                             

Epoch 41/50: Train Loss = 658.2869, Val Loss = 1005.1620


                                                                             

Epoch 42/50: Train Loss = 659.5203, Val Loss = 1019.6483


                                                                             

Epoch 43/50: Train Loss = 661.0569, Val Loss = 1013.0277


                                                                             

Epoch 44/50: Train Loss = 662.3330, Val Loss = 1043.4733


                                                                             

Epoch 45/50: Train Loss = 669.1828, Val Loss = 1031.4887


                                                                             

Epoch 46/50: Train Loss = 668.4558, Val Loss = 1052.4761


                                                                             

Epoch 47/50: Train Loss = 667.5610, Val Loss = 1023.2524


                                                                             

Epoch 48/50: Train Loss = 661.3401, Val Loss = 1006.8302


                                                                             

Epoch 49/50: Train Loss = 659.1544, Val Loss = 1011.3263


                                                                             

Epoch 50/50: Train Loss = 662.4793, Val Loss = 1026.8474


