In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `src` package used below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
from src.datasets import FSDKaggle2018Dataset, collate_fn_audio
from torch.utils.data import Subset

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset rooted at the local download directory.
dataset = FSDKaggle2018Dataset("../2552860")

# Only the first 2048 items are used, served in their stored order (no
# shuffling) in batches of 16; `collate_fn_audio` assembles each batch.
train_subset = Subset(dataset, range(2048))
dl = DataLoader(
    train_subset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn_audio,
)

In [3]:
from src.model import ALMTokenizer

# Keyword arguments forwarded to ALMTokenizer for its encoder / decoder stacks.
encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6}
decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6}

# Keyword arguments for the (shallower) MAE encoder / decoder.
mae_decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}
mae_encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}

# Use the device selected for the notebook instead of a hard-coded "cuda", so
# the patchify / unpatchify stages agree with `.to(device)` below and the
# notebook also runs on CPU-only machines. `str(device)` yields the same
# "cuda" string as before when a GPU is available.
patchify_args = {"device": str(device)}
unpatchify_args = {"device": str(device)}

model = ALMTokenizer(
    from_raw_audio=True,
    encoder_args=encoder_args,
    decoder_args=decoder_args,
    mae_decoder_args=mae_decoder_args,
    mae_encoder_args=mae_encoder_args,
    patchify_args=patchify_args,
    unpatchify_args=unpatchify_args,
    window_size=2,
).to(device)

# Dump the module tree as a quick sanity check of the architecture.
print(model)

  WeightNorm.apply(module, name, dim)


ALMTokenizer(
  (query_encoder): QueryEncoder(
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (pos_encoder): PositionalEncoding()
  )
  (query_decoder): QueryDecoder(
    (transformer): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn):



In [4]:
from src.discriminator import Discriminator
import torch.nn as nn
import torchaudio.transforms as T

# Hop lengths of the mel front-ends; one discriminator is built per entry.
hop_lengths = [32, 64, 128, 256, 512, 1024]

# 1) One MelSpectrogram transform per hop length (24 kHz audio, 1024-point
#    FFT and window).
mel_transforms = nn.ModuleList()
for hop in hop_lengths:
    mel_transforms.append(
        T.MelSpectrogram(
            sample_rate=24000,
            n_fft=1024,
            hop_length=hop,
            win_length=1024,
        )
    )

# 2) One discriminator per mel transform, each moved to the training device.
discriminators = nn.ModuleList()
for mel in mel_transforms:
    disc = Discriminator(
        in_channels=128,
        hidden_dims=[64, 128, 256, 512, 512, 512],
        mel_transform=mel,
    )
    discriminators.append(disc.to(device))

In [None]:
from tqdm import trange, tqdm
from torch.utils.tensorboard.writer import SummaryWriter
import os

import torch.optim as optim
from itertools import chain

from src.losses import compute_generator_loss, compute_discriminator_loss

# --- Output locations ------------------------------------------------------
writer_dir = "runs/alm_tokenizer"              # TensorBoard event files
checkpoint_dir = "checkpoints/alm_tokenizer"   # model state_dict snapshots
checkpoint_freq = 10                           # save every N epochs

os.makedirs(writer_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Log to `writer_dir` itself (previously a duplicated literal), so changing the
# variable above cannot leave the writer pointing at a different directory.
writer = SummaryWriter(log_dir=writer_dir)

# --- Generator (tokenizer) optimiser ---------------------------------------
lr_g          = 1e-4
weight_decay  = 1e-2
num_epochs    = 200

# Only parameters that are trainable are handed to the optimiser.
optim_g = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=lr_g,
    weight_decay=weight_decay,
)

# --- Discriminator optimiser -----------------------------------------------
lr_d = 2e-4
betas = (0.5, 0.9)

# A single Adam instance drives the parameters of every discriminator.
optim_d = optim.Adam(
    params=chain.from_iterable(D.parameters() for D in discriminators),
    lr=lr_d,
    betas=betas,
)

# First epoch (0-based) in which the discriminators are optimised; before
# that only the generator is updated.
# NOTE(review): `compute_generator_loss` still receives the discriminators
# before this epoch, so any adversarial / feature terms it computes are
# measured against untrained discriminators — confirm this is intended.
disc_start_epoch = 10

for epoch in trange(num_epochs):

    # Running sums of every generator loss term over this epoch's batches.
    losses = {
        "L_time": 0.0,
        "L_freq": 0.0,
        "L_adv": 0.0,
        "L_feat": 0.0,
        "L_mae": 0.0,
        "L_total": 0.0,
    }

    for wavs in dl:

        wavs = wavs.to(device)

        res = model(wavs)

        x_hat = res["x_hat"]
        x = res["orig_waveform"]
        mae_pred = res["mae_pred"]
        mae_target = res["mae_target"]
        mask_idx = res["mask_indices"]

        # Late discriminator training
        if epoch >= disc_start_epoch:
            # Detach the reconstruction: the discriminator update must not
            # back-propagate into the generator. Without this the backward
            # pass walks (and may free) the generator graph that the generator
            # step below still needs.
            discriminator_loss = compute_discriminator_loss(
                discriminators, x, x_hat.detach()
            )
            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

        # Generator training
        generator_loss = compute_generator_loss(
            x_hat=x_hat,
            x=x,
            discriminators=discriminators,
            mae_pred=mae_pred,
            mae_target=mae_target,
            mask_idx=mask_idx,
        )

        # Accumulate plain Python floats so no autograd graph is kept alive.
        for loss_type, loss_value in generator_loss.items():
            losses[loss_type] = losses[loss_type] + loss_value.item()

        total_gen_loss = generator_loss["L_total"]

        optim_g.zero_grad()
        total_gen_loss.backward()
        optim_g.step()

    # Save the model every n epochs
    if epoch % checkpoint_freq == 0:
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"alm_tokenizer_epoch_{epoch}.pth"))
        print(f"Model saved at epoch {epoch}")

    # Log losses, averaged over the number of batches in the epoch. The
    # previous divisor, len(wavs), was the size of the *last batch*, not the
    # number of accumulated terms. max(..., 1) guards an empty loader.
    num_batches = max(len(dl), 1)
    for loss_type in losses:
        losses[loss_type] /= num_batches
        writer.add_scalar(f"losses/{loss_type}", losses[loss_type], epoch)

    print(f"Epoch {epoch:2d} | Average Loss: {losses['L_total']:.4f}")

# Make sure everything logged so far is written to disk.
writer.flush()


  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
import IPython.display as ipd

# Re-run the model on the last batch left over from the training loop, purely
# for listening. No gradients are needed here, so autograd is disabled to
# avoid building a graph.
# NOTE(review): the model is left in whatever train/eval mode it is currently
# in; call model.eval() first if dropout should be disabled for this check.
with torch.no_grad():
    res = model(wavs)

# `.detach()` guarantees the tensors can be converted with `.numpy()`, which
# raises for tensors that still require grad.
reconstructed = res["x_hat"][0].detach().cpu().numpy()
original = res["orig_waveform"][0].detach().cpu().numpy()

# First item of the batch: reconstruction, then the reference waveform.
ipd.display(ipd.Audio(reconstructed, rate=24000))
ipd.display(ipd.Audio(original, rate=24000))