In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
from src.datasets import FSDKaggle2018Dataset, collate_fn_audio
from torch.utils.data import Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = FSDKaggle2018Dataset("../2552860")
dl = DataLoader(Subset(dataset, range(2048)), batch_size=12, shuffle=False, collate_fn=collate_fn_audio)

In [3]:
from src.model import ALMTokenizer

encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6}
decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 6}

mae_decoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}
mae_encoder_args = {"embed_dim": 128, "n_heads": 8, "n_layers": 4}

patchify_args = {"device": "cuda"}
unpatchify_args = {"device": "cuda"}

model = ALMTokenizer(
    from_raw_audio=True,
    encoder_args=encoder_args,
    decoder_args=decoder_args,
    mae_decoder_args=mae_decoder_args,
    mae_encoder_args=mae_encoder_args,
    patchify_args=patchify_args,
    unpatchify_args=unpatchify_args,
    window_size=2,
).to(device)

print(model)

  WeightNorm.apply(module, name, dim)


ALMTokenizer(
  (query_encoder): QueryEncoder(
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (pos_encoder): PositionalEncoding()
  )
  (query_decoder): QueryDecoder(
    (transformer): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn):



In [4]:
from src.discriminator import Discriminator
import torch.nn as nn
import torchaudio.transforms as T

# 1) Define mel/log-mel transforms with each hop_length
hop_lengths = [32, 64, 128, 256, 512, 1024]

mel_transforms = nn.ModuleList([
    T.MelSpectrogram(sample_rate=24000, n_fft=1024, hop_length=h, win_length=1024)
    for h in hop_lengths
])

# 2) Instantiate the 6 discriminators
discriminators = nn.ModuleList([
    Discriminator(
        in_channels=128, 
        hidden_dims=[64,128,256,512,512,512], 
        mel_transform=m
        ).to(device)
    for m in mel_transforms
])

In [5]:
from tqdm import trange, tqdm
from torch.utils.tensorboard.writer import SummaryWriter
import os

writer_dir = "runs/alm_tokenizer"
checkpoint_dir = "checkpoints/alm_tokenizer"
checkpoint_freq = 10

os.makedirs(writer_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

writer = SummaryWriter(log_dir="runs/alm_tokenizer")

lr_g          = 1e-4
weight_decay  = 1e-2
num_epochs    = 200

import torch.optim as optim
optim_g = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr_g,
    weight_decay=weight_decay
)

from src.losses import compute_generator_loss, compute_discriminator_loss
from itertools import chain

lr_d = 2e-4
betas = (0.5, 0.9)
optim_d = optim.Adam(
    params=chain(*[D.parameters() for D in discriminators]),
    lr=lr_d,
    betas=betas
)

for epoch in trange(num_epochs):
    
    losses = {
        "L_time": 0.0,
        "L_freq": 0.0,
        "L_adv": 0.0,
        "L_feat": 0.0,
        "L_mae": 0.0,
        "L_total": 0.0
        }
    
    for wavs in dl:
        
        wavs = wavs.to(device)

        res = model(wavs)

        x_hat = res["x_hat"]
        x = res["orig_waveform"]
        mae_pred = res["mae_pred"]
        mae_target = res["mae_target"]
        mask_idx = res["mask_indices"]

        # Late discriminator training
        if epoch >= 50:
            discriminator_loss = compute_discriminator_loss(discriminators, x, x_hat)
            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

        # Generator training
        generator_loss = compute_generator_loss(
            x_hat=x_hat,
            x=x,
            discriminators=discriminators,
            mae_pred=mae_pred,
            mae_target=mae_target,
            mask_idx=mask_idx
        )

        for loss_type, loss_value in generator_loss.items():
            losses[loss_type] = losses[loss_type] + loss_value.item()
        
        total_gen_loss = generator_loss["L_total"]
        
        optim_g.zero_grad()
        total_gen_loss.backward()
        optim_g.step()

    # Save the model every n epochs
    if epoch % checkpoint_freq == 0:
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"alm_tokenizer_epoch_{epoch}.pth"))
        print(f"Model saved at epoch {epoch}")    

    # Log losses
    for loss_type, loss_value in losses.items():
        losses[loss_type] /= len(wavs)
        writer.add_scalar(f"losses/{loss_type}", losses[loss_type], epoch)
    
    print(f"Epoch {epoch + 1:2d} | Average Loss: {losses['L_total']:.4f}")
    torch.cuda.empty_cache()


  0%|          | 1/200 [01:20<4:28:04, 80.83s/it]

Model saved at epoch 0
Epoch  1 | Average Loss: 248.4803


  1%|          | 2/200 [02:42<4:29:04, 81.54s/it]

Epoch  2 | Average Loss: 242.0003


  2%|▏         | 3/200 [04:05<4:29:50, 82.19s/it]

Epoch  3 | Average Loss: 239.5903


  2%|▏         | 4/200 [05:29<4:30:03, 82.67s/it]

Epoch  4 | Average Loss: 237.7300


  2%|▎         | 5/200 [06:52<4:29:15, 82.85s/it]

Epoch  5 | Average Loss: 236.2550


  3%|▎         | 6/200 [08:15<4:28:23, 83.01s/it]

Epoch  6 | Average Loss: 234.7717


  4%|▎         | 7/200 [09:38<4:27:11, 83.06s/it]

Epoch  7 | Average Loss: 232.8170


  4%|▍         | 8/200 [11:02<4:25:59, 83.12s/it]

Epoch  8 | Average Loss: 231.1425


  4%|▍         | 9/200 [12:25<4:24:30, 83.09s/it]

Epoch  9 | Average Loss: 229.2674


  5%|▌         | 10/200 [13:46<4:21:02, 82.43s/it]

Epoch 10 | Average Loss: 227.8231


  6%|▌         | 11/200 [15:06<4:17:41, 81.81s/it]

Model saved at epoch 10
Epoch 11 | Average Loss: 226.0542


  6%|▌         | 12/200 [16:26<4:14:58, 81.37s/it]

Epoch 12 | Average Loss: 224.8255


  6%|▋         | 13/200 [17:49<4:14:54, 81.79s/it]

Epoch 13 | Average Loss: 223.4280


  7%|▋         | 14/200 [19:12<4:14:21, 82.05s/it]

Epoch 14 | Average Loss: 221.4942


  8%|▊         | 15/200 [20:35<4:13:41, 82.28s/it]

Epoch 15 | Average Loss: 220.3755


  8%|▊         | 16/200 [21:57<4:12:52, 82.46s/it]

Epoch 16 | Average Loss: 218.8735


  8%|▊         | 17/200 [23:19<4:10:12, 82.03s/it]

Epoch 17 | Average Loss: 217.1961


  9%|▉         | 18/200 [24:39<4:07:27, 81.58s/it]

Epoch 18 | Average Loss: 215.7541


 10%|▉         | 19/200 [26:00<4:05:20, 81.33s/it]

Epoch 19 | Average Loss: 214.1405


 10%|█         | 20/200 [27:20<4:03:15, 81.09s/it]

Epoch 20 | Average Loss: 213.2747


 10%|█         | 21/200 [28:41<4:01:39, 81.00s/it]

Model saved at epoch 20
Epoch 21 | Average Loss: 211.9050


 11%|█         | 22/200 [30:02<3:59:55, 80.87s/it]

Epoch 22 | Average Loss: 210.2128


 12%|█▏        | 23/200 [31:22<3:58:25, 80.82s/it]

Epoch 23 | Average Loss: 209.2810


 12%|█▏        | 24/200 [32:43<3:56:52, 80.75s/it]

Epoch 24 | Average Loss: 207.5802


 12%|█▎        | 25/200 [34:04<3:55:22, 80.70s/it]

Epoch 25 | Average Loss: 206.3331


 13%|█▎        | 26/200 [35:24<3:54:05, 80.72s/it]

Epoch 26 | Average Loss: 205.0960


 14%|█▎        | 27/200 [36:45<3:52:32, 80.65s/it]

Epoch 27 | Average Loss: 204.1463


 14%|█▍        | 28/200 [38:05<3:51:09, 80.64s/it]

Epoch 28 | Average Loss: 202.6756


 14%|█▍        | 29/200 [39:26<3:49:49, 80.64s/it]

Epoch 29 | Average Loss: 201.4112


 15%|█▌        | 30/200 [40:47<3:48:22, 80.60s/it]

Epoch 30 | Average Loss: 200.2220


 16%|█▌        | 31/200 [42:07<3:47:00, 80.59s/it]

Model saved at epoch 30
Epoch 31 | Average Loss: 199.3158


 16%|█▌        | 32/200 [43:28<3:45:43, 80.61s/it]

Epoch 32 | Average Loss: 198.2685


 16%|█▋        | 33/200 [44:48<3:44:24, 80.62s/it]

Epoch 33 | Average Loss: 196.6678


 17%|█▋        | 34/200 [46:09<3:42:53, 80.56s/it]

Epoch 34 | Average Loss: 195.4832


 18%|█▊        | 35/200 [47:29<3:41:30, 80.55s/it]

Epoch 35 | Average Loss: 194.6501


 18%|█▊        | 36/200 [48:50<3:40:11, 80.56s/it]

Epoch 36 | Average Loss: 193.7098


 18%|█▊        | 37/200 [50:11<3:38:56, 80.59s/it]

Epoch 37 | Average Loss: 192.3709


 19%|█▉        | 38/200 [51:31<3:37:36, 80.60s/it]

Epoch 38 | Average Loss: 191.3959


 20%|█▉        | 39/200 [52:52<3:36:21, 80.63s/it]

Epoch 39 | Average Loss: 190.4375


 20%|██        | 40/200 [54:13<3:34:56, 80.61s/it]

Epoch 40 | Average Loss: 189.3899


 20%|██        | 41/200 [55:33<3:33:42, 80.64s/it]

Model saved at epoch 40
Epoch 41 | Average Loss: 188.5487


 21%|██        | 42/200 [56:54<3:32:25, 80.67s/it]

Epoch 42 | Average Loss: 187.4033


 22%|██▏       | 43/200 [58:15<3:31:06, 80.68s/it]

Epoch 43 | Average Loss: 186.3815


 22%|██▏       | 44/200 [59:35<3:29:43, 80.67s/it]

Epoch 44 | Average Loss: 185.4175


 22%|██▎       | 45/200 [1:00:56<3:28:14, 80.61s/it]

Epoch 45 | Average Loss: 184.2711


 23%|██▎       | 46/200 [1:02:17<3:27:00, 80.65s/it]

Epoch 46 | Average Loss: 183.7648


 24%|██▎       | 47/200 [1:03:37<3:25:36, 80.63s/it]

Epoch 47 | Average Loss: 182.5276


 24%|██▍       | 48/200 [1:04:58<3:24:13, 80.62s/it]

Epoch 48 | Average Loss: 182.0026


 24%|██▍       | 49/200 [1:06:19<3:23:03, 80.68s/it]

Epoch 49 | Average Loss: 180.5740


 25%|██▌       | 50/200 [1:07:39<3:21:41, 80.68s/it]

Epoch 50 | Average Loss: 179.9082


 26%|██▌       | 51/200 [1:09:13<3:30:19, 84.70s/it]

Model saved at epoch 50
Epoch 51 | Average Loss: 565.9153


 26%|██▌       | 52/200 [1:10:47<3:35:56, 87.54s/it]

Epoch 52 | Average Loss: 736.0941


 26%|██▋       | 53/200 [1:12:24<3:41:25, 90.38s/it]

Epoch 53 | Average Loss: 788.3795


 27%|██▋       | 54/200 [1:14:01<3:44:43, 92.35s/it]

Epoch 54 | Average Loss: 842.9887


 28%|██▊       | 55/200 [1:15:38<3:46:21, 93.66s/it]

Epoch 55 | Average Loss: 930.0933


 28%|██▊       | 56/200 [1:17:15<3:46:48, 94.50s/it]

Epoch 56 | Average Loss: 956.8872


 28%|██▊       | 57/200 [1:18:49<3:45:19, 94.54s/it]

Epoch 57 | Average Loss: 915.5637


 29%|██▉       | 58/200 [1:20:23<3:43:04, 94.26s/it]

Epoch 58 | Average Loss: 1044.6233


 30%|██▉       | 59/200 [1:21:57<3:41:05, 94.08s/it]

Epoch 59 | Average Loss: 1043.5514


 30%|███       | 60/200 [1:23:30<3:39:19, 94.00s/it]

Epoch 60 | Average Loss: 1142.4351


 30%|███       | 61/200 [1:25:04<3:37:42, 93.98s/it]

Model saved at epoch 60
Epoch 61 | Average Loss: 1156.7800


 31%|███       | 62/200 [1:26:38<3:35:54, 93.87s/it]

Epoch 62 | Average Loss: 1152.5017


 32%|███▏      | 63/200 [1:28:11<3:34:09, 93.79s/it]

Epoch 63 | Average Loss: 1220.4083


 32%|███▏      | 64/200 [1:29:45<3:32:33, 93.77s/it]

Epoch 64 | Average Loss: 1283.2130


 32%|███▎      | 65/200 [1:31:19<3:30:52, 93.72s/it]

Epoch 65 | Average Loss: 1318.1096


 33%|███▎      | 66/200 [1:32:53<3:29:18, 93.72s/it]

Epoch 66 | Average Loss: 1302.7643


 34%|███▎      | 67/200 [1:34:26<3:27:41, 93.69s/it]

Epoch 67 | Average Loss: 1305.3916


 34%|███▍      | 68/200 [1:36:00<3:26:14, 93.74s/it]

Epoch 68 | Average Loss: 1292.2196


 34%|███▍      | 69/200 [1:37:34<3:24:40, 93.74s/it]

Epoch 69 | Average Loss: 1307.7562


 35%|███▌      | 70/200 [1:39:07<3:22:58, 93.68s/it]

Epoch 70 | Average Loss: 1318.2578


 36%|███▌      | 71/200 [1:40:41<3:21:23, 93.67s/it]

Model saved at epoch 70
Epoch 71 | Average Loss: 1429.3736


 36%|███▌      | 72/200 [1:42:15<3:19:47, 93.66s/it]

Epoch 72 | Average Loss: 1350.4330


 36%|███▋      | 73/200 [1:43:48<3:18:15, 93.66s/it]

Epoch 73 | Average Loss: 1410.6902


 37%|███▋      | 74/200 [1:45:22<3:16:29, 93.57s/it]

Epoch 74 | Average Loss: 1265.7533


 38%|███▊      | 75/200 [1:46:55<3:15:05, 93.65s/it]

Epoch 75 | Average Loss: 1313.2156


 38%|███▊      | 76/200 [1:48:29<3:13:27, 93.61s/it]

Epoch 76 | Average Loss: 1318.1702


 38%|███▊      | 77/200 [1:50:03<3:11:57, 93.64s/it]

Epoch 77 | Average Loss: 1373.8905


 39%|███▉      | 78/200 [1:51:36<3:10:21, 93.62s/it]

Epoch 78 | Average Loss: 1404.7575


 40%|███▉      | 79/200 [1:53:10<3:08:54, 93.68s/it]

Epoch 79 | Average Loss: 1386.2251


 40%|████      | 80/200 [1:54:43<3:07:12, 93.61s/it]

Epoch 80 | Average Loss: 1388.6422


 40%|████      | 81/200 [1:56:17<3:05:41, 93.63s/it]

Model saved at epoch 80
Epoch 81 | Average Loss: 1394.7897


 41%|████      | 82/200 [1:57:51<3:04:05, 93.61s/it]

Epoch 82 | Average Loss: 1467.5144


 42%|████▏     | 83/200 [1:59:24<3:02:26, 93.56s/it]

Epoch 83 | Average Loss: 1415.4720


 42%|████▏     | 84/200 [2:00:58<3:00:51, 93.55s/it]

Epoch 84 | Average Loss: 1442.8066


 42%|████▎     | 85/200 [2:02:31<2:59:14, 93.52s/it]

Epoch 85 | Average Loss: 1503.2080


 43%|████▎     | 86/200 [2:04:05<2:57:37, 93.48s/it]

Epoch 86 | Average Loss: 1505.8797


 44%|████▎     | 87/200 [2:05:38<2:56:01, 93.47s/it]

Epoch 87 | Average Loss: 1505.1325


 44%|████▍     | 88/200 [2:07:11<2:54:30, 93.48s/it]

Epoch 88 | Average Loss: 1498.8279


 44%|████▍     | 89/200 [2:08:45<2:52:58, 93.50s/it]

Epoch 89 | Average Loss: 1531.6924


 45%|████▌     | 90/200 [2:10:18<2:51:24, 93.49s/it]

Epoch 90 | Average Loss: 1557.5906


 46%|████▌     | 91/200 [2:11:52<2:49:54, 93.53s/it]

Model saved at epoch 90
Epoch 91 | Average Loss: 1558.4831


 46%|████▌     | 92/200 [2:13:26<2:48:20, 93.52s/it]

Epoch 92 | Average Loss: 1480.4445


 46%|████▋     | 93/200 [2:14:59<2:46:46, 93.52s/it]

Epoch 93 | Average Loss: 1465.0996


 47%|████▋     | 94/200 [2:16:33<2:45:21, 93.60s/it]

Epoch 94 | Average Loss: 1465.4296


 48%|████▊     | 95/200 [2:18:07<2:43:53, 93.66s/it]

Epoch 95 | Average Loss: 1465.0234


 48%|████▊     | 96/200 [2:19:40<2:42:12, 93.58s/it]

Epoch 96 | Average Loss: 1464.3908


 48%|████▊     | 97/200 [2:21:14<2:40:39, 93.59s/it]

Epoch 97 | Average Loss: 1464.4853


 49%|████▉     | 98/200 [2:22:47<2:39:05, 93.59s/it]

Epoch 98 | Average Loss: 1463.6292


 50%|████▉     | 99/200 [2:24:21<2:37:29, 93.56s/it]

Epoch 99 | Average Loss: 1463.3696


 50%|█████     | 100/200 [2:25:54<2:35:53, 93.54s/it]

Epoch 100 | Average Loss: 1482.7472


 50%|█████     | 101/200 [2:27:28<2:34:25, 93.59s/it]

Model saved at epoch 100
Epoch 101 | Average Loss: 1543.0617


 51%|█████     | 102/200 [2:29:02<2:32:53, 93.60s/it]

Epoch 102 | Average Loss: 1506.1310


 52%|█████▏    | 103/200 [2:30:35<2:31:10, 93.51s/it]

Epoch 103 | Average Loss: 1554.7788


 52%|█████▏    | 104/200 [2:32:08<2:29:37, 93.51s/it]

Epoch 104 | Average Loss: 1574.0610


 52%|█████▎    | 105/200 [2:33:42<2:28:04, 93.52s/it]

Epoch 105 | Average Loss: 1574.2756


 53%|█████▎    | 106/200 [2:35:15<2:26:31, 93.52s/it]

Epoch 106 | Average Loss: 1573.8990


 54%|█████▎    | 107/200 [2:36:49<2:24:52, 93.47s/it]

Epoch 107 | Average Loss: 1573.8309


 54%|█████▍    | 108/200 [2:38:22<2:23:24, 93.53s/it]

Epoch 108 | Average Loss: 1573.5232


 55%|█████▍    | 109/200 [2:39:56<2:21:49, 93.51s/it]

Epoch 109 | Average Loss: 1573.3015


 55%|█████▌    | 110/200 [2:41:29<2:20:14, 93.49s/it]

Epoch 110 | Average Loss: 1573.6285


 56%|█████▌    | 111/200 [2:43:03<2:18:46, 93.56s/it]

Model saved at epoch 110
Epoch 111 | Average Loss: 1572.9313


 56%|█████▌    | 112/200 [2:44:37<2:17:09, 93.52s/it]

Epoch 112 | Average Loss: 1573.7790


 56%|█████▋    | 113/200 [2:46:10<2:15:39, 93.56s/it]

Epoch 113 | Average Loss: 1573.2284


 57%|█████▋    | 114/200 [2:47:46<2:14:52, 94.10s/it]

Epoch 114 | Average Loss: 1572.9739


 57%|█████▊    | 115/200 [2:49:20<2:13:22, 94.14s/it]

Epoch 115 | Average Loss: 1573.1067


 58%|█████▊    | 116/200 [2:50:53<2:11:28, 93.91s/it]

Epoch 116 | Average Loss: 1572.7399


 58%|█████▊    | 117/200 [2:52:27<2:09:41, 93.75s/it]

Epoch 117 | Average Loss: 1572.4314


 59%|█████▉    | 118/200 [2:54:00<2:08:05, 93.72s/it]

Epoch 118 | Average Loss: 1573.2501


 60%|█████▉    | 119/200 [2:55:34<2:06:22, 93.61s/it]

Epoch 119 | Average Loss: 1572.9399


 60%|██████    | 120/200 [2:57:07<2:04:47, 93.60s/it]

Epoch 120 | Average Loss: 1572.5766


 60%|██████    | 121/200 [2:58:41<2:03:15, 93.61s/it]

Model saved at epoch 120
Epoch 121 | Average Loss: 1572.2053


 61%|██████    | 122/200 [3:00:14<2:01:39, 93.58s/it]

Epoch 122 | Average Loss: 1572.5490


 62%|██████▏   | 123/200 [3:01:48<2:00:03, 93.55s/it]

Epoch 123 | Average Loss: 1572.5165


 62%|██████▏   | 124/200 [3:03:21<1:58:29, 93.54s/it]

Epoch 124 | Average Loss: 1572.5742


 62%|██████▎   | 125/200 [3:04:55<1:56:49, 93.46s/it]

Epoch 125 | Average Loss: 1571.9718


 63%|██████▎   | 126/200 [3:06:28<1:55:13, 93.42s/it]

Epoch 126 | Average Loss: 1572.2810


 64%|██████▎   | 127/200 [3:08:01<1:53:40, 93.43s/it]

Epoch 127 | Average Loss: 1572.6063


 64%|██████▍   | 128/200 [3:09:35<1:52:06, 93.43s/it]

Epoch 128 | Average Loss: 1572.5586


 64%|██████▍   | 129/200 [3:11:08<1:50:33, 93.43s/it]

Epoch 129 | Average Loss: 1571.6485


 65%|██████▌   | 130/200 [3:12:42<1:49:00, 93.44s/it]

Epoch 130 | Average Loss: 1572.0832


 66%|██████▌   | 131/200 [3:14:15<1:47:28, 93.46s/it]

Model saved at epoch 130
Epoch 131 | Average Loss: 1572.8929


 66%|██████▌   | 132/200 [3:15:48<1:45:51, 93.41s/it]

Epoch 132 | Average Loss: 1572.4347


 66%|██████▋   | 133/200 [3:17:22<1:44:19, 93.43s/it]

Epoch 133 | Average Loss: 1571.2622


 67%|██████▋   | 134/200 [3:18:55<1:42:45, 93.41s/it]

Epoch 134 | Average Loss: 1572.1882


 68%|██████▊   | 135/200 [3:20:29<1:41:12, 93.43s/it]

Epoch 135 | Average Loss: 1572.1657


 68%|██████▊   | 136/200 [3:22:02<1:39:36, 93.39s/it]

Epoch 136 | Average Loss: 1571.6199


 68%|██████▊   | 137/200 [3:23:35<1:38:02, 93.37s/it]

Epoch 137 | Average Loss: 1571.9658


 69%|██████▉   | 138/200 [3:25:09<1:36:32, 93.43s/it]

Epoch 138 | Average Loss: 1571.9262


 70%|██████▉   | 139/200 [3:26:42<1:35:00, 93.44s/it]

Epoch 139 | Average Loss: 1572.1290


 70%|███████   | 140/200 [3:28:16<1:33:26, 93.44s/it]

Epoch 140 | Average Loss: 1571.9877


 70%|███████   | 141/200 [3:29:49<1:31:53, 93.45s/it]

Model saved at epoch 140
Epoch 141 | Average Loss: 1571.9681


 71%|███████   | 142/200 [3:31:23<1:30:17, 93.41s/it]

Epoch 142 | Average Loss: 1571.4222


 72%|███████▏  | 143/200 [3:32:56<1:28:44, 93.41s/it]

Epoch 143 | Average Loss: 1571.7735


 72%|███████▏  | 144/200 [3:34:30<1:27:13, 93.46s/it]

Epoch 144 | Average Loss: 1571.7178


 72%|███████▎  | 145/200 [3:36:03<1:25:40, 93.46s/it]

Epoch 145 | Average Loss: 1571.6758


 73%|███████▎  | 146/200 [3:37:37<1:24:06, 93.45s/it]

Epoch 146 | Average Loss: 1571.8399


 74%|███████▎  | 147/200 [3:39:10<1:22:30, 93.40s/it]

Epoch 147 | Average Loss: 1571.0325


 74%|███████▍  | 148/200 [3:40:43<1:20:57, 93.41s/it]

Epoch 148 | Average Loss: 1571.8946


 74%|███████▍  | 149/200 [3:42:17<1:19:23, 93.40s/it]

Epoch 149 | Average Loss: 1571.6789


 75%|███████▌  | 150/200 [3:43:50<1:17:45, 93.30s/it]

Epoch 150 | Average Loss: 1571.4135


 76%|███████▌  | 151/200 [3:45:23<1:16:12, 93.32s/it]

Model saved at epoch 150
Epoch 151 | Average Loss: 1571.7591


 76%|███████▌  | 152/200 [3:46:57<1:14:41, 93.36s/it]

Epoch 152 | Average Loss: 1571.3932


 76%|███████▋  | 153/200 [3:48:30<1:13:06, 93.33s/it]

Epoch 153 | Average Loss: 1571.4144


 77%|███████▋  | 154/200 [3:50:03<1:11:34, 93.37s/it]

Epoch 154 | Average Loss: 1571.6619


 78%|███████▊  | 155/200 [3:51:36<1:09:59, 93.33s/it]

Epoch 155 | Average Loss: 1571.5706


 78%|███████▊  | 156/200 [3:53:10<1:08:28, 93.38s/it]

Epoch 156 | Average Loss: 1570.8925


 78%|███████▊  | 157/200 [3:54:44<1:06:57, 93.43s/it]

Epoch 157 | Average Loss: 1571.8042


 79%|███████▉  | 158/200 [3:56:17<1:05:23, 93.41s/it]

Epoch 158 | Average Loss: 1570.9850


 80%|███████▉  | 159/200 [3:57:50<1:03:48, 93.39s/it]

Epoch 159 | Average Loss: 1571.9297


 80%|████████  | 160/200 [3:59:24<1:02:17, 93.44s/it]

Epoch 160 | Average Loss: 1571.3001


 80%|████████  | 161/200 [4:00:57<1:00:46, 93.50s/it]

Model saved at epoch 160
Epoch 161 | Average Loss: 1572.2995


 81%|████████  | 162/200 [4:02:31<59:16, 93.59s/it]  

Epoch 162 | Average Loss: 1571.5287


 82%|████████▏ | 163/200 [4:04:05<57:40, 93.52s/it]

Epoch 163 | Average Loss: 1571.7740


 82%|████████▏ | 164/200 [4:05:38<56:06, 93.52s/it]

Epoch 164 | Average Loss: 1571.6414


 82%|████████▎ | 165/200 [4:07:11<54:30, 93.45s/it]

Epoch 165 | Average Loss: 1571.8380


 83%|████████▎ | 166/200 [4:08:45<52:57, 93.47s/it]

Epoch 166 | Average Loss: 1570.8040


 84%|████████▎ | 167/200 [4:10:18<51:23, 93.45s/it]

Epoch 167 | Average Loss: 1571.6221


 84%|████████▍ | 168/200 [4:11:51<49:46, 93.34s/it]

Epoch 168 | Average Loss: 1571.8563


 84%|████████▍ | 169/200 [4:13:25<48:13, 93.35s/it]

Epoch 169 | Average Loss: 1571.9095


 85%|████████▌ | 170/200 [4:14:58<46:43, 93.46s/it]

Epoch 170 | Average Loss: 1571.6281


 86%|████████▌ | 171/200 [4:16:32<45:09, 93.44s/it]

Model saved at epoch 170
Epoch 171 | Average Loss: 1571.7866


 86%|████████▌ | 172/200 [4:18:05<43:35, 93.39s/it]

Epoch 172 | Average Loss: 1571.8843


 86%|████████▋ | 173/200 [4:19:39<42:01, 93.39s/it]

Epoch 173 | Average Loss: 1571.4476


 87%|████████▋ | 174/200 [4:21:12<40:27, 93.36s/it]

Epoch 174 | Average Loss: 1571.0516


 88%|████████▊ | 175/200 [4:22:45<38:54, 93.39s/it]

Epoch 175 | Average Loss: 1571.1705


 88%|████████▊ | 176/200 [4:24:19<37:20, 93.36s/it]

Epoch 176 | Average Loss: 1571.6656


 88%|████████▊ | 177/200 [4:25:52<35:47, 93.35s/it]

Epoch 177 | Average Loss: 1570.6174


 89%|████████▉ | 178/200 [4:27:25<34:13, 93.35s/it]

Epoch 178 | Average Loss: 1571.6634


 90%|████████▉ | 179/200 [4:28:59<32:41, 93.41s/it]

Epoch 179 | Average Loss: 1572.0333


 90%|█████████ | 180/200 [4:30:32<31:07, 93.36s/it]

Epoch 180 | Average Loss: 1571.2805


 90%|█████████ | 181/200 [4:32:06<29:34, 93.41s/it]

Model saved at epoch 180
Epoch 181 | Average Loss: 1571.2829


 91%|█████████ | 182/200 [4:33:39<28:01, 93.41s/it]

Epoch 182 | Average Loss: 1571.2317


 92%|█████████▏| 183/200 [4:35:12<26:27, 93.40s/it]

Epoch 183 | Average Loss: 1571.8544


 92%|█████████▏| 184/200 [4:36:46<24:54, 93.40s/it]

Epoch 184 | Average Loss: 1570.6277


 92%|█████████▎| 185/200 [4:38:19<23:20, 93.39s/it]

Epoch 185 | Average Loss: 1571.2521


 93%|█████████▎| 186/200 [4:39:52<21:47, 93.36s/it]

Epoch 186 | Average Loss: 1571.1527


 94%|█████████▎| 187/200 [4:41:26<20:13, 93.32s/it]

Epoch 187 | Average Loss: 1571.5033


 94%|█████████▍| 188/200 [4:42:59<18:39, 93.30s/it]

Epoch 188 | Average Loss: 1571.3259


 94%|█████████▍| 189/200 [4:44:32<17:06, 93.32s/it]

Epoch 189 | Average Loss: 1571.0480


 95%|█████████▌| 190/200 [4:46:06<15:33, 93.33s/it]

Epoch 190 | Average Loss: 1571.2595


 96%|█████████▌| 191/200 [4:47:39<14:00, 93.37s/it]

Model saved at epoch 190
Epoch 191 | Average Loss: 1571.5148


 96%|█████████▌| 192/200 [4:49:13<12:27, 93.38s/it]

Epoch 192 | Average Loss: 1571.3367


 96%|█████████▋| 193/200 [4:50:46<10:53, 93.37s/it]

Epoch 193 | Average Loss: 1571.3053


 97%|█████████▋| 194/200 [4:52:19<09:20, 93.37s/it]

Epoch 194 | Average Loss: 1571.7974


 98%|█████████▊| 195/200 [4:53:52<07:46, 93.34s/it]

Epoch 195 | Average Loss: 1571.8360


 98%|█████████▊| 196/200 [4:55:26<06:13, 93.31s/it]

Epoch 196 | Average Loss: 1571.0288


 98%|█████████▊| 197/200 [4:56:59<04:39, 93.28s/it]

Epoch 197 | Average Loss: 1571.5688


 99%|█████████▉| 198/200 [4:58:32<03:06, 93.34s/it]

Epoch 198 | Average Loss: 1571.5719


100%|█████████▉| 199/200 [5:00:06<01:33, 93.32s/it]

Epoch 199 | Average Loss: 1571.1002


100%|██████████| 200/200 [5:01:39<00:00, 90.50s/it]

Epoch 200 | Average Loss: 1571.2805





In [7]:
res = model(wavs)

import IPython.display as ipd
for n in range(8):
    ipd.display(ipd.Audio(res["x_hat"][n].cpu().numpy(), rate=24000))
    ipd.display(ipd.Audio(res["orig_waveform"][n].cpu().numpy(), rate=24000))