# Training Script

### Imports

In [9]:
# Append the parent of the current working directory to sys.path so that the
# project-local imports further down (Libraries.*, Conf) can be resolved.
# NOTE(review): os.path.abspath('..') is relative to the working directory,
# so this presumably expects the notebook to be run from its own folder — confirm.
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.VAE import *
from Libraries.Utils import *
from Conf import *

### Config

In [None]:
# --- Run identity / file locations ------------------------------------------
model_name: str = "conv_VAE_v1"
model_path: str = f"{MODEL_PATH}/{model_name}.pth"  # weights are loaded from / saved to this file
training_data_name: str = "training_v2"             # stem of the .npy file read in the data-loading cell

# --- Optimisation hyper-parameters ------------------------------------------
batch_size: int = 16
epochs: int = 100
learning_rate: float = 1e-5
lr_decay: int = 40      # used as StepLR step_size in the model-creation cell
lr_gamma: float = 0.1   # used as StepLR gamma in the model-creation cell
reprod_loss_weight: float = 20000  # forwarded to train_VAE

# --- Checkpointing ----------------------------------------------------------
checkpoint_freq: int = 5  # 0 for no checkpoint saving

# --- Logging ----------------------------------------------------------------
logging_level: int = logging.INFO
logging.basicConfig(
    level=logging_level,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger: logging.Logger = logging.getLogger(__name__)

### Data Loading

In [None]:
# Read the training array from DATA_PATH. Slicing and batching happen in the
# following cell, so this read does not have to be repeated when they change.
file = load_training_data(
    f"{DATA_PATH}/{training_data_name}.npy"
)

In [14]:
# Only the first `max_training_samples` entries along axis 0 are fed to the
# dataloader. This limit used to be a bare `16` inside the slice; it is named
# here so the restriction is visible and easy to change (it could equally live
# with the other settings in the Config cell).
max_training_samples: int = 16
training_subset = file[:max_training_samples, ...]

data_loader = create_dataloader(Audio_Data(training_subset), batch_size)
logger.info(f"Data loaded with shape: {file.shape}")
# The line above reports the whole file; log what is actually trained on too,
# so the run log does not suggest the full dataset is in use.
logger.info(f"Training on subset with shape: {training_subset.shape}")

2025-02-21 21:15:34,035 - INFO - Data loaded with shape: (4251, 1024, 672)


### Model Creation

In [17]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the VAE. The last two axes of the training array are forwarded as the
# trailing entries of input_shape; the two leading entries are passed as 0.
model = VAE(
    in_channels=1,
    latent_dim=128,
    device=device,
    input_shape=[0, 0, file.shape[-2], file.shape[-1]],
    n_conv_blocks=1,
    n_starting_filters=32,
    lin_bottleneck=False,
).to(device)

# Resume from existing weights when a file is already present at model_path.
if os.path.exists(model_path):
    # map_location remaps the stored tensors onto the device selected above, so
    # a checkpoint written on a CUDA machine still loads on a CPU-only one.
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only load checkpoints from a trusted source, or switch to
    # weights_only=True if the file holds a plain state_dict — confirm.
    model.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=False)
    )
    logger.info(f"Model {model_name} loaded with {count_parameters(model)} Parameters")
else:
    logger.info(f"Model {model_name} created with {count_parameters(model)} Parameters")

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): this scheduler is not passed to train_VAE in the training cell,
# and nothing in this notebook calls scheduler.step(); unless it is stepped
# elsewhere the learning rate will stay at `learning_rate` — confirm.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay, gamma=lr_gamma)

2025-02-21 21:16:09,212 - INFO - Model conv_VAE_v1 loaded with 91585 Parameters


### Training

In [18]:
# Run training; whatever train_VAE returns is handed straight to scatter_plot.
# NOTE(review): the StepLR scheduler created in the model-creation cell is not
# among the arguments here — confirm whether learning-rate decay is intended.
x = train_VAE(
    model,
    data_loader,
    optimizer,
    loss_VAE,
    epochs=epochs,
    device=device,
    reprod_loss_weight=reprod_loss_weight,
    checkpoint_freq=checkpoint_freq,
    model_path=model_path,
)
scatter_plot(x)

# Write the final weights to model_path. This line is only reached when
# train_VAE returns normally, so an interrupted run does not get here.
torch.save(model.state_dict(), model_path)
logger.info("Model saved successfully.")

2025-02-21 21:16:13,940 - INFO - Training started on cpu
2025-02-21 21:16:26,401 - INFO - Epoch 01: Avg. Loss: 3.24241e+05 Avg. Reprod: 3.05681e+05 Avg. KL: 1.85600e+04 Remaining Time: 00h 20min 33s
2025-02-21 21:16:40,806 - INFO - Epoch 02: Avg. Loss: 3.24100e+05 Avg. Reprod: 3.05540e+05 Avg. KL: 1.85600e+04 Remaining Time: 00h 21min 56s
2025-02-21 21:16:54,475 - INFO - Epoch 03: Avg. Loss: 3.24327e+05 Avg. Reprod: 3.05639e+05 Avg. KL: 1.86880e+04 Remaining Time: 00h 21min 50s
2025-02-21 21:17:10,031 - INFO - Epoch 04: Avg. Loss: 3.24045e+05 Avg. Reprod: 3.05357e+05 Avg. KL: 1.86880e+04 Remaining Time: 00h 22min 26s
2025-02-21 21:17:32,163 - INFO - Epoch 05: Avg. Loss: 3.24086e+05 Avg. Reprod: 3.05270e+05 Avg. KL: 1.88160e+04 Remaining Time: 00h 24min 46s
2025-02-21 21:17:53,968 - INFO - Epoch 06: Avg. Loss: 3.24034e+05 Avg. Reprod: 3.05218e+05 Avg. KL: 1.88160e+04 Remaining Time: 00h 26min 06s
2025-02-21 21:18:06,763 - INFO - Epoch 07: Avg. Loss: 3.23890e+05 Avg. Reprod: 3.05074e+05 

KeyboardInterrupt: 