In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import torch
from torch.utils.data import DataLoader
from src.datasets import CodesPtDataset, FSDKaggle2018Dataset, collate_fn_audio
from torch.utils.data import Subset

from torch.nn.utils.rnn import pad_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = FSDKaggle2018Dataset("../2552860")
dl = DataLoader(Subset(dataset, range(2048)), batch_size=16, shuffle=False, collate_fn=collate_fn_audio)

In [3]:
from src.model import ALMTokenizer

encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6}
decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6}

mae_decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}
mae_encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}

patchify_args = {"device": "cuda"}
unpatchify_args = {"device": "cuda"}

model = ALMTokenizer(
    from_raw_audio=True,
    encoder_args=encoder_args,
    decoder_args=decoder_args,
    mae_decoder_args=mae_decoder_args,
    mae_encoder_args=mae_encoder_args,
    patchify_args=patchify_args,
    unpatchify_args=unpatchify_args,
    window_size=2,
).to(device)

print(model)

  WeightNorm.apply(module, name, dim)


ALMTokenizer(
  (query_encoder): QueryEncoder(
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (pos_encoder): PositionalEncoding()
  )
  (query_decoder): QueryDecoder(
    (transformer): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn):



In [4]:
from src.discriminator import Discriminator
import torch.nn as nn
import torchaudio.transforms as T

# 1) Define mel/log-mel transforms with each hop_length
hop_lengths = [32, 64, 128, 256, 512, 1024]

mel_transforms = nn.ModuleList([
    T.MelSpectrogram(sample_rate=24000, n_fft=1024, hop_length=h, win_length=1024)
    for h in hop_lengths
])

# 2) Instantiate the 6 discriminators
discriminators = nn.ModuleList([
    Discriminator(
        in_channels=128, 
        hidden_dims=[64,128,256,512,512,512], 
        mel_transform=m
        ).to(device)
    for m in mel_transforms
])

In [5]:
from tqdm import trange, tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/alm_tokenizer")

lr_g          = 1e-4
weight_decay  = 1e-2
num_epochs    = 200

import torch.optim as optim
optim_g = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr_g,
    weight_decay=weight_decay
)

from src.losses import compute_generator_loss, compute_discriminator_loss
from itertools import chain

lr_d = 2e-4
betas = (0.5, 0.9)
optim_d = optim.Adam(
    params=chain(*[D.parameters() for D in discriminators]),
    lr=lr_d,
    betas=betas
)

for epoch in trange(num_epochs):
    
    losses = {
        "L_time": 0.0,
        "L_freq": 0.0,
        "L_adv": 0.0,
        "L_feat": 0.0,
        "L_mae": 0.0,
        "L_total": 0.0
        }
    
    for i, codes in enumerate(tqdm(dl)):
        
        codes = codes.to(device)

        res = model(codes)

        x_hat = res["x_hat"]
        x = res["orig_waveform"]
        mae_pred = res["mae_pred"]
        mae_target = res["mae_target"]
        mask_idx = res["mask_indices"]

        if epoch >= 10:
            discriminator_loss = compute_discriminator_loss(discriminators, x, x_hat)
            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

        generator_loss = compute_generator_loss(
            x_hat=x_hat,
            x=x,
            discriminators=discriminators,
            mae_pred=mae_pred,
            mae_target=mae_target,
            mask_idx=mask_idx
        )

        for loss_type, loss_value in generator_loss.items():
            losses[loss_type] = losses[loss_type] + loss_value.item()
        
        total_gen_loss = generator_loss["L_total"]
        
        optim_g.zero_grad()
        total_gen_loss.backward()
        optim_g.step()    

    for loss_type, loss_value in losses.items():
        losses[loss_type] /= len(codes)
        writer.add_scalar(f"losses/{loss_type}", losses[loss_type], epoch)
    print(f"Epoch {epoch:2d} | Average Loss: {losses['L_total']:.4f}")


100%|██████████| 128/128 [01:14<00:00,  1.71it/s]
  0%|          | 1/200 [01:14<4:07:59, 74.77s/it]

Epoch  0 | Average Loss: 92.1694


100%|██████████| 128/128 [01:15<00:00,  1.70it/s]
  1%|          | 2/200 [02:30<4:07:56, 75.13s/it]

Epoch  1 | Average Loss: 89.6682


100%|██████████| 128/128 [01:17<00:00,  1.65it/s]
  2%|▏         | 3/200 [03:47<4:10:13, 76.21s/it]

Epoch  2 | Average Loss: 89.0118


100%|██████████| 128/128 [01:18<00:00,  1.64it/s]
  2%|▏         | 4/200 [05:05<4:11:24, 76.96s/it]

Epoch  3 | Average Loss: 88.4635


100%|██████████| 128/128 [01:18<00:00,  1.64it/s]
  2%|▎         | 5/200 [06:23<4:11:29, 77.38s/it]

Epoch  4 | Average Loss: 87.8421


100%|██████████| 128/128 [01:18<00:00,  1.64it/s]
  3%|▎         | 6/200 [07:42<4:11:07, 77.67s/it]

Epoch  5 | Average Loss: 87.3540


100%|██████████| 128/128 [01:18<00:00,  1.63it/s]
  4%|▎         | 7/200 [09:00<4:10:34, 77.90s/it]

Epoch  6 | Average Loss: 86.9583


100%|██████████| 128/128 [01:18<00:00,  1.64it/s]
  4%|▍         | 8/200 [10:18<4:09:27, 77.96s/it]

Epoch  7 | Average Loss: 86.4643


100%|██████████| 128/128 [01:18<00:00,  1.63it/s]
  4%|▍         | 9/200 [11:37<4:08:39, 78.11s/it]

Epoch  8 | Average Loss: 86.0897


100%|██████████| 128/128 [01:19<00:00,  1.62it/s]
  5%|▌         | 10/200 [12:56<4:08:17, 78.41s/it]

Epoch  9 | Average Loss: 85.4141


  4%|▍         | 5/128 [00:03<01:37,  1.27it/s]
  5%|▌         | 10/200 [13:00<4:07:00, 78.00s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 236.00 MiB. GPU 0 has a total capacity of 5.68 GiB of which 98.44 MiB is free. Including non-PyTorch memory, this process has 5.51 GiB memory in use. Of the allocated memory 3.73 GiB is allocated by PyTorch, and 1.63 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
res = model(codes)

res["x_hat"]

import IPython.display as ipd

ipd.display(ipd.Audio(res["x_hat"][0].cpu().numpy(), rate=24000))
ipd.display(ipd.Audio(res["orig_waveform"][0].cpu().numpy(), rate=24000))