In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Load the Config File

In [2]:
import yaml
# Load the YAML experiment configuration into `cfg`.
# The encoding is pinned to UTF-8 so that parsing does not depend on the
# platform's default locale encoding.
with open("config_medium.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

### Define the DataLoader

In [3]:
import torch
from torch.utils.data import DataLoader
from src.datasets import FSD50K, collate_fn_audio
from torch.utils.data import Subset

from functools import partial

# Pull every dataloader setting out of the config once, so the rest of the
# cell (and later cells) can refer to plain variables.
dl_cfg = cfg["dataloader"]
batch_size = dl_cfg["batch_size"]
num_workers = dl_cfg["num_workers"]
nsecs = dl_cfg["nsecs"]
shuffle = dl_cfg["shuffle"]
train_subset_size = dl_cfg["train_subset_size"]
test_subset_size = dl_cfg["test_subset_size"]
dataset_path = dl_cfg["dataset_path"]

train_dataset = FSD50K(dataset_path, split="train")
test_dataset = FSD50K(dataset_path, split="test")

# A subset size of None means "use the whole split". A requested size larger
# than the split is clamped so that Subset never holds out-of-range indices
# (which would only fail later, in the middle of iteration).
if train_subset_size is None:
    train_subset_size = len(train_dataset)
train_subset_size = min(train_subset_size, len(train_dataset))

if test_subset_size is None:
    test_subset_size = len(test_dataset)
test_subset_size = min(test_subset_size, len(test_dataset))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed for PyTorch's global RNG.
torch.manual_seed(123)

# functools.partial instead of a lambda: a partial of an importable function
# can be pickled, which worker processes require when num_workers > 0 and the
# multiprocessing start method is "spawn".
collate_audio = partial(collate_fn_audio, nsecs=nsecs)

train_dl = DataLoader(
    Subset(train_dataset, range(train_subset_size)),
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    collate_fn=collate_audio,
)

# NOTE(review): the test loader reuses the same `shuffle` flag as training;
# confirm that shuffled evaluation is intended.
test_dl = DataLoader(
    Subset(test_dataset, range(test_subset_size)),
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    collate_fn=collate_audio,
)

### Define the Model and Discriminators

In [None]:
import yaml
from src.model import ALMTokenizer
from encodec.msstftd import MultiScaleSTFTDiscriminator



def load_model_from_config(cfg):
    """Build the ALMTokenizer and its multi-scale STFT discriminator from a config.

    Args:
        cfg: Parsed configuration mapping. Reads ``cfg["device"]``, the
            ``"model"`` section (``base_args``, ``mae_args``,
            ``patchify_args``, ``unpatchify_args``, ``window_size``) and the
            ``"discriminator"`` section (``hop_lengths``, ``n_fft``,
            ``win_lengths``, ``n_mels``).

    Returns:
        tuple: ``(model, discriminators)``, both moved to ``cfg["device"]``.

    Raises:
        KeyError: If any of the config keys listed above is missing.
    """
    import copy

    # NOTE(review): the device comes from the config only; there is no
    # fallback if the configured device is unavailable on this machine.
    device = torch.device(cfg["device"])

    model_cfg = cfg["model"]

    # Encoder and decoder use the same base hyper-parameters. Each receives
    # its own deep copy so that the two argument dicts (and `cfg` itself) do
    # not alias a single mutable object.
    encoder_args      = copy.deepcopy(model_cfg["base_args"])
    decoder_args      = copy.deepcopy(model_cfg["base_args"])
    mae_decoder_args  = model_cfg["mae_args"]
    patchify_args     = model_cfg["patchify_args"]
    unpatchify_args   = model_cfg["unpatchify_args"]

    model = ALMTokenizer(
        from_raw_audio   = True,
        encoder_args     = encoder_args,
        decoder_args     = decoder_args,
        mae_decoder_args = mae_decoder_args,
        patchify_args    = patchify_args,
        unpatchify_args  = unpatchify_args,
        window_size      = model_cfg["window_size"],
    ).to(device)

    cfg_disc = cfg["discriminator"]
    hop_lengths = cfg_disc["hop_lengths"]
    n_fft = cfg_disc["n_fft"]
    win_lengths = cfg_disc["win_lengths"]
    n_mels = cfg_disc["n_mels"]

    # NOTE(review): the config value named "n_mels" is passed as the
    # discriminator's `filters` argument; confirm the key name is intended.
    discriminators = MultiScaleSTFTDiscriminator(
        filters = n_mels,
        n_ffts = n_fft,
        hop_lengths = hop_lengths,
        win_lengths = win_lengths
    ).to(device)

    return model, discriminators

# Instantiate the tokenizer and the discriminators described by the config.
model, discriminators = load_model_from_config(cfg)

# Parameter counts: number of scalar elements summed over every parameter
# tensor of each module.
model_params = sum(tensor.numel() for tensor in model.parameters())
disc_params = sum(tensor.numel() for tensor in discriminators.parameters())

# Report the two counts and their sum with thousands separators.
print(f"Model parameters: {model_params:,}")
print(f"Discriminators parameters: {disc_params:,}")
print(f"Total parameters: {model_params + disc_params:,}")

  WeightNorm.apply(module, name, dim)


Model parameters: 6,346,240
Discriminators parameters: 1,493,000
Total parameters: 7,839,240
Estimated memory: 7.48 GB


### Train the Model

In [None]:
# Unpack the "training" section of the config into plain variables so the
# call below reads as a flat list of keyword arguments.
training_cfg = cfg["training"]
num_epochs = training_cfg["num_epochs"]
discriminator_train_freq = training_cfg["discriminator_train_freq"]
d_train_prob = training_cfg["d_train_prob"]
checkpoint_freq = training_cfg["checkpoint_freq"]
eval_freq = training_cfg["eval_freq"]
start_checkpoint = training_cfg["start_checkpoint"]
writer_dir = training_cfg["writer_dir"]
checkpoint_dir = training_cfg["checkpoint_dir"]
lr_generator = training_cfg["lr_generator"]
lr_discriminator = training_cfg["lr_discriminator"]
weight_decay = training_cfg["weight_decay"]
betas = training_cfg["betas"]
lambdas = training_cfg["lambdas"]

# Launch training. The generator and discriminator learning rates are passed
# under the shorter keyword names `lr_g` and `lr_d`.
model.train_model(train_dl=train_dl,
                    test_dl=test_dl,
                    discriminators=discriminators,
                    num_epochs=num_epochs,
                    discriminator_train_freq=discriminator_train_freq,
                    d_train_prob=d_train_prob,
                    checkpoint_freq=checkpoint_freq,
                    start_checkpoint=start_checkpoint,
                    eval_freq=eval_freq,
                    writer_dir=writer_dir,
                    checkpoint_dir=checkpoint_dir,
                    lr_g=lr_generator,
                    weight_decay=weight_decay,
                    lr_d=lr_discriminator,
                    betas=betas,
                    lambdas=lambdas)

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch progress:   0%|          | 0/32 [00:00<?, ?it/s]

In [None]:
# Inspection cell: materialise the parameters of the sub-discriminator at
# index 1 into a list so the notebook displays them.
list(discriminators.discriminators[1].parameters())

[Parameter containing:
 tensor([[[[ 7.8257e-02,  1.2198e-02, -2.6733e-02,  ..., -5.7011e-02,
            -1.8401e-01,  4.2589e-02],
           [ 7.6581e-02,  1.2796e-01,  4.1820e-02,  ...,  9.5402e-02,
            -4.7279e-02, -8.6919e-02],
           [ 1.1125e-02, -3.1267e-02, -7.4786e-02,  ...,  1.5737e-02,
             9.6646e-02,  5.0586e-02]],
 
          [[ 1.4459e-01, -7.7240e-05,  1.3664e-01,  ..., -4.1520e-02,
            -3.9226e-02,  1.5649e-01],
           [ 1.6049e-02,  4.7256e-02,  1.4906e-01,  ...,  6.7259e-02,
            -9.1659e-02, -7.9693e-02],
           [-2.0297e-02, -1.5751e-01,  1.9621e-02,  ..., -3.7993e-02,
            -1.7449e-01,  6.5505e-02]]],
 
 
         [[[-8.9079e-02,  9.5457e-02,  6.0435e-02,  ..., -1.0784e-01,
             5.1653e-03,  9.5588e-02],
           [-1.7102e-01, -1.1771e-01, -5.3930e-02,  ...,  3.9789e-02,
            -1.9824e-01,  7.8612e-02],
           [-1.6027e-01,  1.0421e-01,  8.7587e-02,  ..., -1.4430e-01,
             9.8414e-02,  

In [None]:
# Inspection cell: as a bare last expression this displays the generator
# object returned by Module.parameters(); nothing is iterated here.
discriminators.parameters()

<generator object Module.parameters at 0x710338b1f530>