In [1]:
import os
import configparser
import torch
import torch.optim as optim

# Import our model and trainer for VQ-VAE
from models.conv_vqvae_model import ConvVQVAE, vqvae_loss_function, print_num_params
from trainer.trainer_vqvae import VQVAETrainer
from dataloader.dataloader import load_fraud_data, load_config

# Relative paths are resolved against the kernel's current working directory
# (normally where Jupyter was launched), not this notebook's own location;
# `__file__` is not defined inside a notebook.
config_path = "configs/conv_vqvae/fraud_conv_vqvae.config"

# Seed torch before any data loading or model construction so weight
# initialisation and torch-driven shuffling are reproducible on Restart & Run All.
# NOTE(review): if load_fraud_data splits/shuffles with numpy or `random`,
# those RNGs need seeding as well — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Load the configuration. ConfigParser.read() silently skips files it cannot
# open and returns the list of files it did read, so an empty result means the
# path is wrong; fail here with a clear message instead of an opaque KeyError
# on the first section lookup below.
config_parser = configparser.ConfigParser()
if not config_parser.read(config_path):
    raise FileNotFoundError(
        f"Could not read config file {config_path!r} (cwd: {os.getcwd()!r})"
    )

# Model hyperparameters come from the "Conv_VQVAE" section.
conv_vqvae_config = config_parser["Conv_VQVAE"]

# Load data; load_fraud_data is given the same config file and (per the
# original author) overrides its defaults with the config's DataLoader section.
data = load_fraud_data(config_path=config_path)
dataloaders = data['dataloaders']
input_dim = data['input_dim']

# NOTE(review): the last run of this cell failed inside the model's first
# Conv1d (weight [16, 1, 3], i.e. 1 input channel) with input[1, 8, 30].
# That shape suggests the model received 2-D (batch, features) batches, which
# PyTorch treats as a single unbatched (channels, length) sample. The likely
# fix is to add a channel dimension in ConvVQVAE.forward (e.g. x.unsqueeze(1))
# — TODO confirm against models/conv_vqvae_model.py.
model = ConvVQVAE(conv_vqvae_config)
print_num_params(model)

# Training hyperparameters. Use ConfigParser.getfloat/getint(section, option)
# rather than the section proxy: the proxy returns None for a missing option,
# which would only fail later inside Adam or the training loop, whereas this
# form raises NoSectionError/NoOptionError naming the missing key.
lr = config_parser.getfloat("Trainer", "lr")
num_epochs = config_parser.getint("Trainer", "num_epochs")

# Optimizer and loss for the VQ-VAE.
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = vqvae_loss_function

trainer = VQVAETrainer(model, dataloaders, loss_fn, optimizer)

# Train (progress is reported via tqdm); train() returns the training and
# validation loss histories.
train_losses, val_losses = trainer.train(num_epochs=num_epochs)


  from .autonotebook import tqdm as notebook_tqdm


Loaded configuration from configs/conv_vqvae/fraud_conv_vqvae.config
Filtered dataset to class 1: 378 samples
Normalization statistics (calculated from class 1): {'Time': {'mean': 80790.48148148147, 'std': 48332.5139872635}, 'Amount': {'mean': 133.6764814814815, 'std': 276.3532237447719}}
Filtered dataset to class 1: 378 samples
Filtered dataset to class 1: 47 samples
Filtered dataset to class 1: 48 samples
Total number of trainable parameters: 64164


                                                      

RuntimeError: Given groups=1, weight of size [16, 1, 3], expected input[1, 8, 30] to have 1 channels, but got 8 channels instead