In [1]:
import matplotlib.pyplot as plt
from icecream import ic
import time
from collections import OrderedDict

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from copy import deepcopy

from src.utils.audio_utils import playAudio
from src.model import get_magnet_model, MAGNET
from src.preprocess_ops import PreProOps
from src.music_bench import (
    MAX_SEC, split_ds,
    shuffle_preserve_order,
    QCODING_LEN,
)
from train import MagnetTrainer
from src.utils.lr_scheduler import CosineDecayWithWarmup
from src.music_bench import AUDIO_TXT_PATH, ioPathTextDs, PreProDataset

# Allow TF32 in both cuBLAS matmuls and cuDNN kernels.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# Ask cuDNN for deterministic algorithms so repeated debug runs are comparable.
torch.backends.cudnn.deterministic = True

from train import tonfig
# Prefer the GPU, but fall back to CPU so the notebook can still be stepped through
# on a machine without CUDA (later cells already branch on DEVICE.type for this case).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Autocast context handed to the preprocessing ops and entered around each training forward pass.
# The lookup key is hard-coded to "float32"; the other entries are kept so the precision
# can be switched by editing that single key.
# NOTE(review): the GradScaler cell keys off `tonfig.dtype` instead of this literal - confirm
# the two are meant to be configured independently.
ctx = torch.autocast(
    device_type="cuda" if "cuda" in DEVICE.type else "cpu",
    dtype={
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }["float32"],
)

In [3]:
# Preprocessing helper shared by the dataset iterator and the generation cells below.
# With print_info=True this cell's output reports the parameter counts of an Encodec
# model and a T5 model.
# NOTE(review): `max_sec` is given QCODING_LEN although MAX_SEC is also imported from
# src.music_bench - confirm this is the intended value.
preprocess_ops = PreProOps(
    max_sec=QCODING_LEN,
    print_info=True,
    device=DEVICE.type,
    compile=False,
    autocast=ctx
)




Number of Parameters in Encodec Model: 14.85181 Million Parameters


Number of Parameters in T5 Model (google-t5/t5-small): 35.330816 Million Parameters



In [4]:
import random as r

def iterator(X, y, split):
    """Build a batch iterator over one split of the dataset in online mode.

    Args:
        X: wav paths, forwarded to `PreProDataset` as `wav_paths`.
        y: texts paired with `X`, forwarded as `texts`.
        split: split name forwarded to `PreProDataset` ("train" or "val").

    Returns:
        An iterator over `PreProDataset(...).iter_batches()`. A fresh `random.Random`
        seeded with `tonfig.seed` is created on every call, so each iterator starts
        from the same shuffle state.
    """
    dataset = PreProDataset(
        split=split,
        randgen=r.Random(tonfig.seed),
        audio_pad_id=tonfig.mask_id,
        qcoding_len=tonfig.seqlen,
        device=DEVICE.type,
        pre_computed_tensors_dirpath=tonfig.PRECOMPUTED_TENSORS_DIRPATH,
        online=True,
        wav_paths=X, texts=y,
        preprocess_ops=preprocess_ops,
    )
    return iter(dataset.iter_batches())

# Load the (wav path, text) pairs from AUDIO_TXT_PATH with a batch size of 64.
# NOTE(review): split_float=0.9 is passed here and PreProDataset is also handed a split
# name - confirm the data is not being split twice.
paths, texts = ioPathTextDs(
    save_path=AUDIO_TXT_PATH,
    batch_size=64,
    split_float=0.9,
    return_ds=True
)
# Batch iterator over the "train" split; consumed with next() in the cells below.
train_iterator = iterator(paths, texts, "train")

In [5]:
# Uncompiled MAGNET model moved to the target device.
magnet_model:MAGNET = get_magnet_model(compile=False).to(DEVICE)
# Pull one batch from the training iterator; it is reused as the fixed debugging batch below.
debug_input = next(train_iterator)

# Test Training:

# Try overfitting a small batch

In [6]:
# Shapes of the audio codes and the conditioning tensor for the full batch.
(debug_input[0]["qcode"].shape, debug_input[1].shape)

(torch.Size([64, 4, 750]), torch.Size([64, 196, 512]))

In [7]:
# Keep only the first two examples of the batch and move every tensor onto DEVICE.
debug_input = (
    {field: debug_input[0][field][:2].to(DEVICE) for field in ("qcode", "mask")},
    debug_input[1][:2].to(DEVICE),
)

In [8]:
# Shapes again after slicing the batch down to two examples.
(debug_input[0]["qcode"].shape, debug_input[1].shape)

(torch.Size([2, 4, 750]), torch.Size([2, 196, 512]))

In [9]:
# Number of trainable parameters in the MAGNET model, in millions.
print(
    sum(weight.numel() for weight in magnet_model.parameters() if weight.requires_grad) / 1e6,
    "Million Parameters",
)

37.833728 Million Parameters


In [10]:
# Trainer wrapper; the training loop below calls its `mini_train_step`.
magnet_trainer = MagnetTrainer(
    magnet_model=magnet_model,
    config=tonfig
)

# Cosine-decay-with-warmup schedule built from the config.
# NOTE(review): `get_lr` is not called anywhere in this notebook - the overfitting
# loop below uses a constant learning rate instead.
get_lr = CosineDecayWithWarmup(
    warmup_steps=tonfig.warmup_steps,
    max_learning_rate=tonfig.max_learning_rate,
    decay_steps=tonfig.decay_steps,
    min_learning_rate=tonfig.min_learning_rate
)

# Gradient scaler, enabled only when the configured dtype is float16.
# NOTE(review): the autocast context `ctx` is built with a hard-coded "float32" key
# rather than tonfig.dtype - confirm the two cannot disagree.
scaler = torch.cuda.amp.GradScaler(enabled=(tonfig.dtype=="float16"))

# Optimizer created by the model itself; the learning rate is the literal 5e-4
# rather than a value taken from `tonfig`.
optimizer = magnet_model.configure_optimizers(
    weight_decay=tonfig.weight_decay,
    learning_rate=5e-4,
    betas=(tonfig.beta1, tonfig.beta2),
    device_type="cuda" if "cuda" in DEVICE.type else "cpu"
)

@torch.no_grad()
def update_ema(ema_model:MAGNET, model:MAGNET, decay:float):
    """Blend the live model's parameters into the EMA model, in place.

    Every parameter of `model` is looked up by name in `ema_model` and updated as
    ``ema = decay * ema + (1 - decay) * param``. A name that is missing from
    `ema_model` raises ``KeyError``. Only parameters are touched, not buffers.

    Args:
        ema_model: model holding the running average; modified in place.
        model: model whose current parameters are folded into the average.
        decay: weight kept from the previous average (0.0 copies `model` outright).
    """
    averaged = dict(ema_model.named_parameters())
    for key, live_weight in model.named_parameters():
        averaged[key].mul_(decay).add_(live_weight.data, alpha=1-decay)

def requires_grad(model:nn.Module, requires_grad:bool):
    """Set the `requires_grad` flag of every parameter in `model` to the given value."""
    for weight in model.parameters():
        weight.requires_grad = requires_grad

In [11]:
# Independent copy of the model that will hold the exponential moving average of the weights.
ema = deepcopy(magnet_model).to(DEVICE) # sampling with ema model
# The EMA copy is never optimized directly, so turn off gradients for all its parameters.
requires_grad(ema, requires_grad=False)
magnet_model.train()
ema.eval()
# decay=0.0 overwrites every EMA parameter with the current model parameter.
update_ema(ema, magnet_model, decay=0.0) # ema_model weights are in sync with magnet_model

In [None]:
class cl:
    """Scratch namespace holding a one-shot generator over the cached debug batch."""

    @staticmethod
    def iterrr():
        """Yield the `debug_input` batch once, with its tensors moved to `DEVICE`.

        Declared as a staticmethod because it takes no `self`; without the decorator
        calling it on an instance raises ``TypeError``. `cl.iterrr()` keeps working.
        """
        yield {"qcode": debug_input[0]["qcode"].to(DEVICE), "mask": debug_input[0]["mask"].to(DEVICE)}, debug_input[1].to(DEVICE)

In [12]:
losses, accuracies = [], []
def test_train(num_steps:int=2000, lr:float=5e-4):
    """Sanity-check training by repeatedly fitting the single cached `debug_input` batch.

    Relies on notebook globals defined in earlier cells: `debug_input`, `optimizer`,
    `scaler`, `ctx`, `magnet_trainer`, `magnet_model`, `ema` and `tonfig`.

    Args:
        num_steps: number of optimizer steps to run (default keeps the original 2000).
        lr: constant learning rate written into every param group each step
            (default keeps the original 5e-4).

    Returns:
        The module-level `losses` and `accuracies` lists, which are appended to every
        `tonfig.log_interval` steps.
    """
    audio_input, cond_text = debug_input

    print("Training about to start...")
    t0 = time.time()
    for step in range(num_steps):
        # set learning rate for all params (constant here; the `get_lr` schedule is not used)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # gradient accumulation step; the same debug batch is reused for every mini step
        for mini_step in range(tonfig.num_grad_accumalation_steps):
            with ctx:
                loss, accuracy = magnet_trainer.mini_train_step(
                    audio_input=audio_input, cond_tensor=cond_text
                )
                loss /= tonfig.num_grad_accumalation_steps

            # scale the loss and accumulate gradients across mini steps; this call was
            # missing, so no gradients were ever produced for the optimizer.
            # NOTE(review): assumes `mini_train_step` only returns the loss and does not
            # call backward itself - confirm against train.py.
            scaler.scale(loss).backward()

        # stays None when clipping is disabled so the log line below never hits an unbound name
        grad_norm = None
        if tonfig.clipnorm is not None:
            # unscale the gradients
            scaler.unscale_(optimizer)
            # clips gradients in-place to grad norm
            grad_norm = nn.utils.clip_grad_norm_(magnet_model.parameters(), max_norm=tonfig.clipnorm)

        # calls unscale to the optimizer unless already called, checks for infs and nans as a part of unscale_
        # calls optimizer.step on unscaled grads if no infs and nans else optimizer.step is skipped
        scaler.step(optimizer)
        # Update the scale factor
        scaler.update()

        # flush grads to save memory
        optimizer.zero_grad(set_to_none=True)

        # update ema model
        update_ema(ema, magnet_model, decay=tonfig.ema_momentum)

        # some logging
        t1 = time.time()
        dt = t1-t0
        t0 = t1
        if step % tonfig.log_interval == 0:
            # multiply as loss was scaled for gradient accumulation
            lossf = loss.item() * tonfig.num_grad_accumalation_steps
            print(
                f"| Step: {step} || Loss: {lossf:.4f} || Masked Accuracy: {accuracy[1]:.4f} | Accuracy: {accuracy[0]:.4f} |"
                f"| LR: {lr:e} || dt: {dt*1000:.2f}ms || Norm: {grad_norm} ||"
            )
            losses.append(lossf); accuracies.append(accuracy)
    return losses, accuracies

losses, accuracies = test_train() # Clear Output

Training about to start...


RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 512 n 1500 k 512 mat1_ld 512 mat2_ld 512 result_ld 512 abcType 0 computeType 77 scaleType 0

In [None]:
# Plot the values logged by the overfitting run against the logging step index.
plt.plot(losses, label="Loss")
# The backslash is escaped explicitly: a bare "\ " is an invalid escape sequence that
# newer Python versions warn about. The rendered label is unchanged.
plt.plot(accuracies, label="Accuracy \\ Masked Accuracy")
plt.grid(True)
plt.ylabel("Loss/Accuracy")
plt.xlabel("Steps")
plt.legend()
plt.show()

In [None]:
# Summary of the run: best loss, best accuracy entry, the log indices where they occurred,
# and the position and value of the worst loss from log entry 1000 onwards.
# NOTE(review): `losses[1000:]` is empty when fewer than 1001 entries were logged, in which
# case max() raises ValueError - confirm tonfig.log_interval makes this slice non-empty.
min(losses), max(accuracies), losses.index(min(losses)), accuracies.index(max(accuracies)), losses[1000:].index(max(losses[1000:]))+1000, max(losses[1000:])

# Test Generate

In [None]:
# Replace the two-example debug batch with the next batch from the training iterator.
debug_input = next(train_iterator)
# Displays the whole batch as the cell output.
debug_input

In [None]:
# Generate audio tokens with the trained model, conditioned on the batch's conditioning tensor.
magnet_model.eval()
gen_tok = magnet_model.generate(
    prompt=debug_input[1],
    preprocess_ops=preprocess_ops,
    device=DEVICE,
    top_p=0.9,
    decoding_steps=[20, 10, 10, 10]
) # author's shape note: (2, 4, 750) - NOTE(review): written for a batch of 2, confirm for the current batch
# Decode the generated tokens back to waveforms.
gen_wav = preprocess_ops.getAudioFromCodings(gen_tok) # author's shape note: (2, 1, 240000)
# Play the first two generated clips.
playAudio(tensor=gen_wav[0].squeeze())
playAudio(tensor=gen_wav[1].squeeze())

In [None]:
# Same generation as the previous cell, but sampling from the EMA copy of the weights.
ema.eval()
gen_tok = ema.generate(
    prompt=debug_input[1],
    preprocess_ops=preprocess_ops,
    device=DEVICE,
    top_p=0.9,
    decoding_steps=[20, 10, 10, 10]
)
# Decode the generated tokens back to waveforms and play the first two clips.
gen_wav = preprocess_ops.getAudioFromCodings(gen_tok)
playAudio(tensor=gen_wav[0].squeeze())
playAudio(tensor=gen_wav[1].squeeze())

In [None]:
# Reference clip: decode the dataset's own codes for the second example of the batch
# ([None] adds a leading batch axis) and play it for comparison with the generations.
gen_tok = debug_input[0]["qcode"][1][None]
gen_wav = preprocess_ops.getAudioFromCodings(gen_tok)
playAudio(tensor=gen_wav[0].squeeze())

# Nothing...

In [None]:
import matplotlib.pyplot as plt
import math

In [None]:
def plot(num_decoding_steps):
    """Plot the cosine mask-probability schedule over `num_decoding_steps` steps.

    For each step t in 0 .. num_decoding_steps-1 the plotted value is
    cos(pi * t / (2 * num_decoding_steps)).
    """
    steps = list(range(num_decoding_steps))
    mask_probs = [
        math.cos((math.pi * step)/(2*num_decoding_steps))
        for step in range(num_decoding_steps)
    ]
    plt.plot(steps, mask_probs, label="mask_p")
    plt.xlabel("Decoding Steps")
    plt.ylabel("Mask Probability (in inference)")
    plt.grid(True)
    plt.legend()
    plt.show()

In [None]:
# Schedule for 20 decoding steps (the first entry of decoding_steps=[20, 10, 10, 10] in the generate cells).
plot(20)

In [None]:
# Schedule for 10 decoding steps (the remaining entries of decoding_steps=[20, 10, 10, 10] in the generate cells).
plot(10)

In [None]:
import torch
import typing as tp
import glob
import os
import random as r

class PreProDataset:
    """Endless source of (qcodings, cond_tensor) batches for one dataset split.

    Two modes are supported:

    * online (`online=True`): batches of wav paths and texts are preprocessed on the
      fly through `preprocess_ops`.
    * offline (`online=False`): pre-computed shards matching ``musicbench*.pt`` inside
      `pre_computed_tensors_dirpath` are loaded with `torch.load`.

    In both modes the inputs are passed through `split_ds(..., split_float=0.9)` and
    the entry for `split` is kept.
    """

    def __init__(
        self, *,
        split:str, # 'train' or 'val'
        audio_pad_id:int,
        qcoding_len:int,
        device:str,
        randgen:r.Random,
        pre_computed_tensors_dirpath:tp.Optional[str]=None,
        online:bool=False,
        # Optionally None if online is False
        wav_paths:tp.Optional[list[str]]=None,
        texts:tp.Optional[list[str]]=None,
        preprocess_ops:tp.Optional[tp.Any]=None
    ):
        """Validate the arguments for the chosen mode and select the requested split.

        Args:
            split: "train" or "val".
            audio_pad_id: stored on the instance; not read by any method in this class.
            qcoding_len: forwarded to `preprocess_ops.get_qcodings` in online mode.
            device: device string, stored as a `torch.device`; not read by any method
                in this class.
            randgen: random generator used to shuffle at the start of every epoch.
            pre_computed_tensors_dirpath: shard directory; required in offline mode.
            online: selects online (True) or offline (False) mode.
            wav_paths: batched wav paths; required in online mode
                (author's shape note: (N//B, B)).
            texts: batched texts; required in online mode
                (author's shape note: (N//B, B)).
            preprocess_ops: object providing `get_qcodings` and `get_cond_tensor`;
                required in online mode.
        """
        assert split in ["train", "val"]
        self.online = online

        if not self.online:
            assert pre_computed_tensors_dirpath is not None
            shard_pattern = os.path.join(pre_computed_tensors_dirpath, "musicbench*.pt")
            found_shards = sorted(glob.glob(shard_pattern))
            assert len(found_shards) > 0
            # NOTE(review): the online branch unpacks the split entry into two values while
            # this one stores it whole - confirm what split_ds returns when texts is None.
            self.shard_filenames = split_ds(found_shards, None, split_float=0.9)[split]
        else:
            assert wav_paths is not None and texts is not None and preprocess_ops is not None
            self.preprocess_ops = preprocess_ops
            self.wav_paths, self.texts = split_ds(wav_paths, texts, split_float=0.9)[split]

        self.audio_pad_id = audio_pad_id
        self.qcoding_len = qcoding_len
        self.randgen = randgen
        self.device = torch.device(device)

    def _online_epoch(self):
        """Yield one reshuffled pass over the wav/text batches, preprocessing each on the fly."""
        self.wav_paths, self.texts = shuffle_preserve_order(self.wav_paths, self.texts, randgen=self.randgen)
        for wav_batch, text_batch in zip(self.wav_paths, self.texts):
            audio_codes = self.preprocess_ops.get_qcodings(
                wav_batch, qcoding_len=self.qcoding_len
            )
            text_cond = self.preprocess_ops.get_cond_tensor(text_batch)
            yield audio_codes, text_cond

    def _offline_epoch(self):
        """Yield one pass over the pre-computed shards, shuffled in place beforehand."""
        self.randgen.shuffle(self.shard_filenames)
        for shard_path in self.shard_filenames:
            # torch.load unpickles the file; only point this at shards you produced yourself.
            audio_codes, text_cond = torch.load(shard_path)
            yield audio_codes, text_cond

    def iter_batches(self):
        """Generate (qcodings, cond_tensor) batches forever, reshuffling at the start of every epoch."""
        while True:
            if self.online:
                yield from self._online_epoch()
            else:
                yield from self._offline_epoch()
