In [1]:
%load_ext autoreload
%autoreload 2

### Define the DataLoader

In [2]:
import torch
from torch.utils.data import DataLoader
from src.datasets import FSDKaggle2018Dataset, collate_fn_audio
from torch.utils.data import Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = FSDKaggle2018Dataset("../2552860")
dl = DataLoader(Subset(dataset, range(len(dataset))), batch_size=32, shuffle=False, collate_fn=collate_fn_audio)

### Define the Model

In [3]:
from src.model import ALMTokenizer
expansion_factor = 4

encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6, "dim_feedforward": 128*expansion_factor}
decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6, "dim_feedforward": 128*expansion_factor}

mae_decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}
mae_encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}

patchify_args = {"device": device}
unpatchify_args = {"device": device}

model = ALMTokenizer(
    from_raw_audio=True,
    encoder_args=encoder_args,
    decoder_args=decoder_args,
    mae_decoder_args=mae_decoder_args,
    mae_encoder_args=mae_encoder_args,
    patchify_args=patchify_args,
    unpatchify_args=unpatchify_args,
    window_size=6,
).to(device)

for p in model.patchify.parameters: p.requires_grad = False
for p in model.unpatchify.parameters: p.requires_grad = False

print(model)

  WeightNorm.apply(module, name, dim)


ALMTokenizer(
  (query_encoder): QueryEncoder(
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (pos_encoder): PositionalEncoding()
  )
  (query_decoder): QueryDecoder(
    (transformer): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn): M

### Define the Discriminators

In [None]:
from src.discriminator import Discriminator
import torch.nn as nn
import torchaudio.transforms as T

# 1) Define mel/log-mel transforms with each hop_length
hop_lengths = [32, 64, 128, 256, 512, 1024]

class MelLogMel(nn.Module):
    """
    Compute mel- and log-mel- spectrograms and stack them as two channels.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 128,
        win_length: int = 1024,
        n_mels: int = 128,
        top_db: float = 80.0,
    ):
        super().__init__()
        # Mel spectrogram: power spectrogram -> mel bins
        self.mel_spec = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mels,
            power=2.0,                # power spectrogram
        )
        # Convert power spectrogram to decibels
        self.to_db = T.AmplitudeToDB(stype="power", top_db=top_db)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, 1, N) or (B, N) waveform in [-1, 1]
        returns: (B, 2, n_mels, T) where channel 0 = mel, 1 = log-mel
        """
        # collapse channel if present
        if x.dim() == 3 and x.size(1) == 1:
            x = x.squeeze(1)  # -> (B, N)
        # compute mel
        mel = self.mel_spec(x)       # (B, n_mels, T)
        log_mel = self.to_db(mel)    # (B, n_mels, T)
        # stack as two-channel feature
        return torch.stack([mel, log_mel], dim=1)

mel_transforms = nn.ModuleList([
    T.MelSpectrogram(sample_rate=24000, n_fft=1024, hop_length=h, win_length=1024)
    for h in hop_lengths
])

# 2) Instantiate the 6 discriminators
discriminators = nn.ModuleList([
    Discriminator(
        in_channels=128, 
        hidden_dims=[64,128,256,512,512,512], 
        mel_transform=m
        ).to(device)
    for m in mel_transforms
])

### Train the Model

In [None]:
# Change the weighting of each loss
lambdas = {
    "L_time": 1.0,
    "L_freq": 1.0,
    "L_adv": 1.0,
    "L_feat": 1.0,
    "L_mae": 1.0,
}


model.train_model(
    discriminators=discriminators,
    dl=dl,
    lambdas=lambdas,
    num_epochs=1000,
    writer_dir="runs/alm_tokenizer",
    checkpoint_dir="checkpoints/alm_tokenizer_30_expfactor_4"
)

  1%|          | 10/1000 [33:16<54:07:31, 196.82s/it]

Model saved at epoch 10


  2%|▏         | 20/1000 [1:05:06<51:47:14, 190.24s/it]

Model saved at epoch 20


  3%|▎         | 30/1000 [1:37:08<51:22:44, 190.69s/it]

Model saved at epoch 30


  4%|▎         | 37/1000 [2:00:14<52:46:10, 197.27s/it]

In [18]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable_params)

4628736


In [None]:
from tqdm import trange, tqdm
from torch.utils.tensorboard.writer import SummaryWriter
import os
from itertools import chain


writer_dir = "runs/patchify_unpatchify"
checkpoint_dir = "checkpoints/patchify_unpatchify"
checkpoint_freq = 10

os.makedirs(writer_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

writer = SummaryWriter(log_dir="runs/alm_tokenizer")

lr_p          = 1e-4
weight_decay  = 1e-5
num_epochs    = 200

import torch.optim as optim
optim_p = optim.AdamW(
    params=chain(*[model.patchify.parameters, model.unpatchify.parameters]),
    lr=lr_p,
    weight_decay=weight_decay
)

from src.losses import compute_mae_loss

for epoch in trange(num_epochs):
    
    mae_loss = 0.0
    
    for wavs in dl:
        
        wavs = wavs.to(device)

        res = model(wavs)

        x_hat = res["x_hat"]
        x = res["orig_waveform"]
        mae_pred = res["mae_pred"]
        mae_target = res["mae_target"]
        mask_idx = res["mask_indices"]


        # Generator training
        loss = compute_mae_loss(
            mae_pred=mae_pred,
            mae_target=mae_target,
            mask_idx=mask_idx,
        )

        mae_loss += loss.item()

        optim_p.zero_grad()
        loss.backward()
        optim_p.step()
    
    # Save the model every n epochs
    if (epoch + 1) % checkpoint_freq == 0:
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"patchify_unpatchify_epoch_{epoch + 1}.pth"))
        print(f"Model saved at epoch {epoch + 1}")    

    mae_loss /= len(dl)
    writer.add_scalar(f"patchify/mae_loss", mae_loss, epoch)
    
    torch.cuda.empty_cache()



  0%|          | 1/200 [02:38<8:44:10, 158.04s/it]


KeyboardInterrupt: 

In [None]:
from typing import Iterable
import torch.optim as optim
from src.losses import compute_generator_loss, compute_discriminator_loss
from itertools import chain
from typing import Optional

def load_model(self, path):
    self.load_state_dict(torch.load(path, map_location=self.device))
    return model

def train_model(
        self,
        lambdas: dict,
        num_epochs: int = 200,
        checkpoint_freq: int = 10,
        start_checkpoint: Optional[int] = None,
        discriminator_train_freq: int = 30,
        writer_dir: str = "writer",
        checkpoint_dir: str = "checkpoints",
        lr_g: float = 1e-4,
        weight_decay: float = 1e-5,
        lr_d: float = 2e-5,
        betas: Iterable[float] = (0.5, 0.9),
        ):
    
    os.makedirs(writer_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    writer = SummaryWriter(log_dir="runs/alm_tokenizer")

    optim_g = optim.AdamW(
        filter(lambda p: p.requires_grad, self.parameters()),
        lr=lr_g,
        weight_decay=weight_decay
    )

    optim_d = optim.Adam(
        params=chain(*[D.parameters() for D in discriminators]),
        lr=lr_d,
        betas=betas
    )

    if start_checkpoint:
        self.load_model(os.path.join(checkpoint_dir, f"alm_tokenizer_epoch_{start_checkpoint}.pth"))

    self.train()

    for epoch in trange(num_epochs):
        
        losses = {
            "L_time": 0.0,
            "L_freq": 0.0,
            "L_adv": 0.0,
            "L_feat": 0.0,
            "L_mae": 0.0,
            "L_total": 0.0
            }
        
        for wavs in dl:
            
            wavs = wavs.to(device)

            res = self(wavs)

            x_hat = res["x_hat"]
            x = res["orig_waveform"]
            mae_pred = res["mae_pred"]
            mae_target = res["mae_target"]
            mask_idx = res["mask_indices"]


            # Generator training
            generator_loss = compute_generator_loss(
                x_hat=x_hat,
                x=x,
                discriminators=discriminators,
                mae_pred=mae_pred,
                mae_target=mae_target,
                lambdas = lambdas,
                mask_idx=mask_idx
            )

            for loss_type, loss_value in generator_loss.items():
                losses[loss_type] = losses[loss_type] + loss_value.item()
            
            total_gen_loss = generator_loss["L_total"]

            optim_g.zero_grad()
            total_gen_loss.backward()
            optim_g.step()
        
            if epoch % discriminator_train_freq == 0:
                x_disc = x.detach()
                x_hat_disc = x_hat.detach()
                discriminator_loss = compute_discriminator_loss(discriminators, x_disc, x_hat_disc)

                optim_d.zero_grad()
                discriminator_loss.backward()
                optim_d.step()

                writer.add_scalar(f"losses/discriminators", discriminator_loss, epoch)
        
        # Save the model every n epochs
        if (epoch + 1) % checkpoint_freq == 0:
            torch.save(self.state_dict(), os.path.join(checkpoint_dir, f"alm_tokenizer_epoch_{epoch + 1}.pth"))
            print(f"Model saved at epoch {epoch + 1}")    

        # Log losses
        for loss_type, loss_value in losses.items():
            losses[loss_type] /= len(dl)
            writer.add_scalar(f"losses/{loss_type}", losses[loss_type], epoch)

        torch.cuda.empty_cache()


In [5]:
for p in model.patchify.parameters: p.requires_grad = False
for p in model.unpatchify.parameters: p.requires_grad = False

from tqdm import trange, tqdm
from torch.utils.tensorboard.writer import SummaryWriter
import os

writer_dir = "runs/alm_tokenizer"
checkpoint_dir = "checkpoints/alm_tokenizer_30"
checkpoint_freq = 10
discriminator_train_freq = 30

os.makedirs(writer_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

writer = SummaryWriter(log_dir="runs/alm_tokenizer")

lr_g          = 1e-4
weight_decay  = 1e-5
num_epochs    = 1000

import torch.optim as optim
optim_g = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr_g,
    weight_decay=weight_decay
)

from src.losses import compute_generator_loss, compute_discriminator_loss
from itertools import chain

lr_d = 2e-5
betas = (0.5, 0.9)
optim_d = optim.Adam(
    params=chain(*[D.parameters() for D in discriminators]),
    lr=lr_d,
    betas=betas
)

lambdas = {
    "L_time": 1.0,
    "L_freq": 1.0,
    "L_adv": 1.0,
    "L_feat": 1.0,
    "L_mae": 1.0,
}

model.train()

for epoch in trange(num_epochs):
    
    losses = {
        "L_time": 0.0,
        "L_freq": 0.0,
        "L_adv": 0.0,
        "L_feat": 0.0,
        "L_mae": 0.0,
        "L_total": 0.0
        }
    
    for wavs in dl:
        
        wavs = wavs.to(device)

        res = model(wavs)

        x_hat = res["x_hat"]
        x = res["orig_waveform"]
        mae_pred = res["mae_pred"]
        mae_target = res["mae_target"]
        mask_idx = res["mask_indices"]


        # Generator training
        generator_loss = compute_generator_loss(
            x_hat=x_hat,
            x=x,
            discriminators=discriminators,
            mae_pred=mae_pred,
            mae_target=mae_target,
            lambdas = lambdas,
            mask_idx=mask_idx
        )

        for loss_type, loss_value in generator_loss.items():
            losses[loss_type] = losses[loss_type] + loss_value.item()
        
        total_gen_loss = generator_loss["L_total"]

        optim_g.zero_grad()
        total_gen_loss.backward()
        optim_g.step()
    
        if epoch % discriminator_train_freq == 0:
            x_disc = x.detach()
            x_hat_disc = x_hat.detach()
            discriminator_loss = compute_discriminator_loss(discriminators, x_disc, x_hat_disc)

            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

            writer.add_scalar(f"losses/discriminators", discriminator_loss, epoch)
    
    # Save the model every n epochs
    if (epoch + 1) % checkpoint_freq == 0:
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"alm_tokenizer_epoch_{epoch + 1}.pth"))
        print(f"Model saved at epoch {epoch + 1}")    

    # Log losses
    for loss_type, loss_value in losses.items():
        losses[loss_type] /= len(dl)
        writer.add_scalar(f"losses/{loss_type}", losses[loss_type], epoch)

    torch.cuda.empty_cache()


  1%|          | 10/1000 [30:57<49:46:22, 180.99s/it]

Model saved at epoch 10


  2%|▏         | 20/1000 [1:00:47<48:46:40, 179.18s/it]

Model saved at epoch 20


  3%|▎         | 30/1000 [1:30:36<48:10:59, 178.82s/it]

Model saved at epoch 30


  4%|▍         | 40/1000 [2:01:15<47:55:20, 179.71s/it]

Model saved at epoch 40


  5%|▌         | 50/1000 [2:31:09<47:15:54, 179.11s/it]

Model saved at epoch 50


  6%|▌         | 60/1000 [3:01:05<46:49:29, 179.33s/it]

Model saved at epoch 60


  7%|▋         | 70/1000 [3:31:52<46:25:43, 179.72s/it]

Model saved at epoch 70


  8%|▊         | 80/1000 [4:01:54<45:49:22, 179.31s/it]

Model saved at epoch 80


  9%|▉         | 90/1000 [4:31:38<45:07:14, 178.50s/it]

Model saved at epoch 90


 10%|█         | 100/1000 [5:02:10<44:48:05, 179.21s/it]

Model saved at epoch 100


 11%|█         | 110/1000 [5:32:00<44:25:45, 179.71s/it]

Model saved at epoch 110


 12%|█▏        | 120/1000 [6:01:53<43:40:46, 178.69s/it]

Model saved at epoch 120


 13%|█▎        | 130/1000 [6:32:36<43:21:49, 179.44s/it]

Model saved at epoch 130


 14%|█▍        | 140/1000 [7:02:34<43:26:13, 181.83s/it]

Model saved at epoch 140


 15%|█▌        | 150/1000 [7:32:24<42:15:58, 179.01s/it]

Model saved at epoch 150


 15%|█▌        | 151/1000 [7:36:19<42:45:44, 181.32s/it]


KeyboardInterrupt: 

In [20]:
from torch.nn.functional import l1_loss

l1_loss(x_hat, x, reduction="sum") / x_hat.shape[2]
l1_loss(x_hat, x, reduction="mean")

tensor(0.0605, device='cuda:0', grad_fn=<MeanBackward0>)

In [29]:
for audio in x:
    print(audio.max(), audio.min())

tensor(0.1898, device='cuda:0') tensor(-0.1893, device='cuda:0')
tensor(0.4501, device='cuda:0') tensor(-0.4693, device='cuda:0')


In [None]:



chkpt = 100
model = load_model(model, os.path.join(checkpoint_dir, f"alm_tokenizer_epoch_{chkpt}.pth"))

for wavs in dl:
    model = model.eval().to(device)
    res = model(wavs)

    import IPython.display as ipd
    for n in range(len(wavs)):
        ipd.display(ipd.Audio(res[n].cpu().detach().numpy(), rate=24000))
        ipd.display(ipd.Audio(wavs[n].cpu().numpy(), rate=24000))

In [16]:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [7]:
def load_model(model, path):
    model.load_state_dict(torch.load(path, map_location=device))
    return model

model = load_model(model, "checkpoints/alm_tokenizer/alm_tokenizer_epoch_80.pth")
model.train()
model = model.cpu()

In [24]:
for batch in dl:
    wavs = batch.to("cpu")
    break
res = model(wavs)
compute_mae_loss(res["mae_pred"], res["mae_target"], res["mask_indices"])

tensor(14.4472, device='cuda:0', grad_fn=<MseLossBackward0>)

In [None]:
from src.utils import interleave_cls_tokens, mask_frames, retrieve_cls_tokens

frames = model.patchify.encode(wavs)
frames = frames.permute(0, 2, 1)  # (B, D, T) -> (B, T, D)
x, positions = interleave_cls_tokens(frames, cls_token=model.cls_token)
x = model.query_encoder(x)
h = retrieve_cls_tokens(x, positions)


In [25]:

masked_frames, masked_idx = mask_frames(frames, mask_rate=0.3, mask_token=model.mask_token)
mae_enc = model.mae_encoder(masked_frames)
mae_dec = model.mae_decoder(mae_enc, frames)
mae_pred = mae_dec[:, masked_idx, :]
mae_target = frames[:, masked_idx, :]
compute_mae_loss(mae_pred, mae_target, masked_idx)

tensor(13.9098, device='cuda:0', grad_fn=<MseLossBackward0>)

In [20]:
mse_loss(res["mae_pred"], mae_pred)

tensor(0.0902, device='cuda:0', grad_fn=<MseLossBackward0>)

In [22]:
mae_pred

tensor([[[-2.5894e-02,  2.8888e+00, -8.8254e-01,  ..., -1.5188e-01,
          -5.1809e-01,  7.2524e-01],
         [ 1.3230e-01,  2.9165e+00, -8.5840e-01,  ..., -3.0685e-01,
          -3.8989e-01,  5.5002e-01],
         [ 9.8439e-02,  2.8141e+00, -9.4399e-01,  ...,  2.0244e-01,
          -5.2359e-01,  5.5056e-01],
         ...,
         [ 1.5142e-01,  2.8619e+00, -8.7963e-01,  ..., -7.3022e-02,
          -4.3664e-01,  4.9169e-01],
         [ 2.4628e-01,  1.4577e+00, -4.2145e-01,  ...,  1.8048e-01,
          -4.3068e-01,  5.8467e-01],
         [ 1.0014e-01,  2.7176e+00, -8.6006e-01,  ...,  3.8413e-01,
          -4.9463e-01,  5.3098e-01]],

        [[ 2.1442e-01,  2.5414e+00, -7.0353e-01,  ...,  4.8555e-02,
          -3.6524e-01,  6.4000e-01],
         [ 2.3384e-01,  1.4476e+00, -4.0368e-01,  ...,  1.0151e-01,
          -5.9046e-01,  6.5376e-01],
         [ 2.4063e-01,  2.8415e+00, -8.5182e-01,  ..., -1.4841e-01,
          -4.1789e-01,  6.0152e-01],
         ...,
         [ 1.4505e-01,  2

In [21]:
res["mae_pred"]

tensor([[[-0.0457,  2.6762, -0.4075,  ...,  0.0211, -0.4550,  0.6094],
         [ 0.1423,  2.2095, -1.0006,  ...,  0.0520, -0.3873,  0.1313],
         [ 0.2777,  3.0498, -0.8560,  ...,  0.0858, -0.6472,  0.5685],
         ...,
         [ 0.1339,  2.6303, -1.0071,  ..., -0.1534, -0.2888,  0.5270],
         [ 0.2377,  2.3136, -0.8446,  ...,  0.1463, -0.4880,  0.6480],
         [ 0.2188,  2.2330, -0.9494,  ...,  0.0840, -0.4270,  0.4441]],

        [[ 0.3146,  2.8417, -0.9893,  ...,  0.1220, -0.4503,  0.7498],
         [ 0.2332,  2.7906, -0.8471,  ...,  0.1076, -0.4237,  0.8211],
         [ 0.1685,  2.6217, -0.8271,  ...,  0.0181, -0.4133,  0.5113],
         ...,
         [ 0.2381,  1.0129, -0.7738,  ...,  0.0347, -0.3107,  0.6183],
         [ 0.2162,  2.6900, -0.3379,  ...,  0.1344, -0.5644,  0.5185],
         [ 0.2872,  2.7244, -0.7717,  ..., -0.0604, -0.5694,  0.5678]],

        [[ 0.5386,  2.6072, -0.6916,  ...,  0.0949, -0.1852,  0.7429],
         [ 0.4679,  2.5086, -0.5822,  ..., -0

In [29]:
model(wavs)

tensor([[[ 3.8272e-03, -1.5208e-03, -1.0253e-03,  ...,  2.0808e-04,
           9.3603e-05,  8.8284e-05]],

        [[ 3.8272e-03, -1.5208e-03, -1.0253e-03,  ...,  2.0808e-04,
           9.3603e-05,  8.8284e-05]],

        [[ 3.8272e-03, -1.5208e-03, -1.0253e-03,  ...,  2.0808e-04,
           9.3603e-05,  8.8284e-05]],

        ...,

        [[ 3.8272e-03, -1.5208e-03, -1.0253e-03,  ...,  2.0808e-04,
           9.3603e-05,  8.8284e-05]],

        [[ 3.8272e-03, -1.5208e-03, -1.0253e-03,  ...,  2.0808e-04,
           9.3603e-05,  8.8284e-05]],

        [[ 3.8272e-03, -1.5208e-03, -1.0253e-03,  ...,  2.0808e-04,
           9.3603e-05,  8.8284e-05]]], device='cuda:0',
       grad_fn=<SliceBackward0>)