# Train VAE

### Imports

In [1]:
#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.VAE import *
from Libraries.Utils import *
from Conf import *

### Config

In [2]:
# --- Optimisation hyper-parameters -------------------------------------
batch_size: int = 32
epochs: int = 100
learning_rate: float = 1e-5
lr_decay: int = 40      # used as StepLR step_size in the model-creation cell
lr_gamma: float = 0.1   # used as StepLR gamma in the model-creation cell
reprod_loss_weight: float = 20000  # forwarded to train_VAE as reprod_loss_weight

# --- Data and model artefacts ------------------------------------------
training_data_name: str = "training_1280"
model_name: str = "conv_VAE_v2"
model_path: str = f"{MODEL_PATH}/{model_name}.pth"
checkpoint_freq: int = 5  # 0 for no checkpoint saving

# --- Logging -----------------------------------------------------------
logging_level: int = logging.INFO
logging.basicConfig(
    level=logging_level,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger: logging.Logger = logging.getLogger(__name__)

### Data Loading

In [3]:
# Build the location of the training array from the configured data
# directory and dataset name, then hand it to the project loader.
training_data_file = f"{DATA_PATH}/{training_data_name}.npy"
file = load_training_data(training_data_file)

In [4]:
# Wrap the loaded array in the project's dataset type and build the loader
# with the configured batch size.
audio_dataset = Audio_Data(file)
data_loader = create_dataloader(audio_dataset, batch_size)

# Report the shape of the raw array that was loaded.
logger.info(f"Data loaded with shape: {file.shape}")

2025-02-23 20:30:33,221 - INFO - Data loaded with shape: (1280, 1024, 672)


### Model Creation

In [5]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the VAE. The last two entries of input_shape are taken from the
# trailing dimensions of the loaded training array; the leading zeros are
# placeholders whose meaning is defined by the VAE class (not visible here).
model = VAE(
    in_channels=1,
    latent_dim=256,
    device=device,
    input_shape=[0, 0, file.shape[-2], file.shape[-1]],
    n_conv_blocks=1,
    n_starting_filters=64,
    lin_bottleneck=False,
).to(device)
print(model)

# Resume from previously saved weights when a file exists at model_path.
if os.path.exists(model_path):
    # The training cell writes this file with torch.save(model.state_dict(), ...),
    # i.e. a plain mapping of tensors, so the restricted unpickler
    # (weights_only=True) is sufficient and avoids executing arbitrary
    # pickled code from the file.
    # NOTE(review): assumes checkpoints written by train_VAE to the same path
    # are plain state_dicts as well — confirm.
    state_dict = torch.load(model_path, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    logger.info(f"Model {model_name} loaded with {count_parameters(model)} Parameters")
else:
    logger.info(f"Model {model_name} created with {count_parameters(model)} Parameters")

# AdamW with decoupled weight decay; StepLR multiplies the learning rate by
# lr_gamma every lr_decay scheduler steps.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay, gamma=lr_gamma)

2025-02-23 20:30:34,433 - INFO - Model conv_VAE_v2 loaded with 363393 Parameters


VAE(
  (activation): LeakyReLU(negative_slope=0.3)
  (encoder): Sequential(
    (0): ConvDown(
      (conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.3)
        (2): MaxPool2d(kernel_size=4, stride=2, padding=1, dilation=1, ceil_mode=False)
      )
    )
    (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): LeakyReLU(negative_slope=0.3)
  )
  (bottleneck_encoder): ConvBottleneckEncoder(
    (conv_mean): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
    (conv_logvar): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (bottleneck_decoder): ConvBottleneckDecoder(
    (expand): ConvTranspose2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
  )
  (decoder): Sequential(
    (0): ConvUp(
      (conv): Sequential(
        (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.3)
        (2): Upsampl

### Training

In [6]:
# Run the training loop with the configured hyper-parameters.
# NOTE(review): the StepLR `scheduler` built in the model-creation cell is not
# passed to train_VAE here; unless train_VAE creates its own, the learning-rate
# decay is never applied — confirm against train_VAE's signature.
x = train_VAE(
    model,
    data_loader,
    optimizer,
    loss_VAE,
    epochs=epochs,
    device=device,
    reprod_loss_weight=reprod_loss_weight,
    checkpoint_freq=checkpoint_freq,
    model_path=model_path,
)

# Visualise whatever train_VAE returned (presumably per-epoch or latent
# values — verify against train_VAE).
scatter_plot(x)

# Persist the final weights as a state_dict at the configured path.
torch.save(model.state_dict(), model_path)
logger.info("Model saved successfully.")

2025-02-23 20:31:01,378 - INFO - Training started on cpu


KeyboardInterrupt: 