In [1]:
import AutoEncoder
import torch.nn as nn
import torch.nn.init as init
import os
from datetime import datetime

import torch
import torchaudio.transforms as T
import Dataset
from lib.AudioSet.transform import TimeSequenceLengthFixer, SoundTrackSelector
import yaml
import matplotlib.pyplot as plt
import numpy as np
import DataTransform

import lib.MuxkitDeepLearningTools.dataset_tools.CachableDataset as mk_cachedata

In [2]:
class ToDevice(nn.Module):
    """Pipeline stage that moves its input tensor to a fixed device.

    Args:
        device: target passed to ``Tensor.to`` (e.g. ``"cuda"``, ``"cpu"``,
            ``"cuda:1"``). Defaults to ``"cuda"``.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Honour the requested device. It used to be hard-coded to "cuda",
        # so ToDevice("cpu") still pushed data onto the GPU.
        self.device = device

    def forward(self, x):
        return x.to(self.device)

# Load training hyper-parameters; "device" decides where preprocessing and the model run.
with open("hyperpara.yml", "r") as f:
    hyper_parameter = yaml.safe_load(f)
Device = hyper_parameter["TrainProcessControl"]["device"]

# Preprocessing stages, applied in this order to each raw sample. The data is moved
# to Device right after track selection, so resampling and feature extraction run there.
_preprocess_stages = (
    SoundTrackSelector(hyper_parameter["SoundTrackSelector"]['mode']),
    ToDevice(Device),
    T.Resample(**hyper_parameter["Resample"]),
    TimeSequenceLengthFixer(**hyper_parameter["TimeSequenceLengthFixer"]),
    DataTransform.ToLogMelSpectrogram(**hyper_parameter["ToLogMelSpectrogram"]),
)
# .to(Device) also relocates any parameters/buffers held by the stages.
pipeline = nn.Sequential(*_preprocess_stages).to(Device)

def data_preprocess(x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through the module-level preprocessing ``pipeline``.

    Used as the dataset ``transform``. Preprocessing never needs gradients,
    so the whole call executes with autograd disabled.
    """
    with torch.no_grad():
        features = pipeline(x)
    return features
# Build the train/eval AudioSet splits from their YAML description.
with open("other_configs.yml", "r") as f:
    trainset, evalset = Dataset.AudioSet.from_yaml(yaml.safe_load(f))

# Both splits share the same no-grad preprocessing transform.
for dataset_split in (trainset, evalset):
    dataset_split.transform = data_preprocess

In [3]:
autoencoder_config = hyper_parameter["AutoEncoder"]
model = AutoEncoder.AutoEncoder(**autoencoder_config).to(Device)
# Optional architecture overview (requires the third-party torchinfo package):
#   from torchinfo import summary
#   summary(model, input_size=trainset[0].shape)

In [4]:
from torch.utils.data import random_split

# Optimiser and LR schedule. float() guards against numeric YAML values that
# PyYAML loads as strings (e.g. exponent-only floats such as `1e-4`).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=float(hyper_parameter["TrainProcessControl"]["LearningRate"]),
    weight_decay=float(hyper_parameter["TrainProcessControl"]["Optimizer"]["WeightDecay"]),
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=hyper_parameter["TrainProcessControl"]["Scheduler"]["patience"],
)

loss_function = nn.MSELoss()
# Growing 1-D histories of per-epoch mean losses, kept on the training device.
train_losses = torch.empty(0, device=Device)
validate_losses = torch.empty(0, device=Device)

# Validation/test sets: a random subset of the eval split, divided in two.
# A dedicated, seeded generator makes the split identical on every
# Restart & Run All instead of depending on the global RNG state.
SPLIT_SEED = 0
subset_size = min(1000, len(evalset))          # shrink gracefully for small eval splits
val_test_split = min(500, subset_size // 2)

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
subset_indices = torch.randperm(len(evalset), generator=split_generator)[:subset_size].tolist()
subset = torch.utils.data.Subset(evalset, subset_indices)
validate_set, test_set = random_split(
    subset, [val_test_split, subset_size - val_test_split], generator=split_generator
)

# Wrap with the caching layer. trainset is rebound in place, so guard against
# wrapping an already-wrapped dataset when this cell is re-run.
cache_kwargs = hyper_parameter["CacheableDataset"]
if not isinstance(trainset, mk_cachedata.CacheableDataset):
    trainset = mk_cachedata.CacheableDataset(trainset, **cache_kwargs)
validate_set = mk_cachedata.CacheableDataset(validate_set, **cache_kwargs)
test_set = mk_cachedata.CacheableDataset(test_set, **cache_kwargs)


In [6]:
def one_step_loss(x, x_hat):
    """Reconstruction loss between the input batch ``x`` and the model output ``x_hat``.

    Delegates to the module-level ``loss_function``.
    """
    reconstruction_error = loss_function(x, x_hat)
    return reconstruction_error
def one_epoch(dataloader, loss_save_to: torch.Tensor):
    """Run one training pass over ``dataloader`` and record the mean batch loss.

    Each batch is moved to ``Device`` and reconstructed by the global ``model``.
    It is then scored with ``one_step_loss`` (the batch is both input and target)
    and used for one ``optimizer`` step.

    Args:
        dataloader: iterable yielding input tensors.
        loss_save_to: 1-D tensor of previous epoch losses. The epoch's mean
            loss is appended to it *in place*, so the caller's variable
            (e.g. ``train_losses``) sees the new entry.

    Returns:
        The same ``loss_save_to`` tensor, now one element longer.

    Raises:
        ValueError: if ``dataloader`` yields no batches (mean is undefined).
    """
    epoch_losses = []
    for batch in dataloader:
        batch = batch.to(Device)
        optimizer.zero_grad()

        x_hat = model(batch)
        loss = one_step_loss(batch, x_hat)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    if not epoch_losses:
        raise ValueError("one_epoch: dataloader yielded no batches")

    # Log the mean loss for the epoch.
    mean_loss = sum(epoch_losses) / len(epoch_losses)
    loss_tensor = torch.tensor(
        [mean_loss], dtype=loss_save_to.dtype, device=loss_save_to.device
    )
    # torch.cat returns a new tensor; merely rebinding the parameter would be
    # invisible to the caller, so swap the caller's tensor contents in place.
    loss_save_to.set_(torch.cat([loss_save_to, loss_tensor]))
    return loss_save_to


In [7]:
def save_checkpoint(model, optimizer, epoch=None, save_dir="./checkpoints"):
    """Serialize model and optimizer state to a timestamped file in ``save_dir``.

    The file is a ``torch.save`` dict with keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: optional epoch number stored alongside the weights.
        save_dir: directory to write into; created if missing.

    Returns:
        Path of the written checkpoint file.
    """
    os.makedirs(save_dir, exist_ok=True)
    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join(save_dir, f"checkpoint_{now}.pt")
    # The timestamp has one-second resolution; without this, two saves in the
    # same second would silently overwrite each other.
    counter = 1
    while os.path.exists(save_path):
        save_path = os.path.join(save_dir, f"checkpoint_{now}_{counter}.pt")
        counter += 1

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, save_path)

    # Escape sequence keeps the check mark intact however this file is re-encoded
    # (the literal character had been mis-decoded as cp1252 mojibake).
    print(f"[\u2713] Checkpoint saved: {save_path}")
    return save_path