In [None]:
import torch
from torch.utils import data
import pytorch_lightning as pl

from pathlib import Path
import IPython.display as ipd

from util.platform import get_torch_device_type
from diffusion_library.sampler import SamplerType
from diffusion_library.scheduler import SchedulerType

from train_uncond_lora import DiffusionUncondLora, ExceptionCallback, DemoCallback
from dataset.dataset import SampleDataset

In [None]:
debug = False  # when True, the debug cells below generate and play preview audio


class Object(object):
    """Minimal attribute container holding the run options (used like an argparse Namespace)."""
    pass


args = Object()

# Run identity, checkpointing and demo cadence
args.name = 'loraw_dev'  # also used as the W&B project name
args.checkpoint_every = 400  # ModelCheckpoint saves every N training steps
args.demo_every = 25
args.num_demos = 1
args.demo_samples = 65536
args.demo_steps = 50
args.accum_batches = 4  # gradient accumulation steps passed to the Trainer

# Audio / data / optimisation
args.sample_size = 32768
args.sample_rate = 16000
args.latent_dim = 0
args.seed = 0
args.batch_size = 1
args.max_epochs = 10
args.lora_rank = 4

args.ema_decay = 0.95
args.random_crop = False
args.cache_training_data = False

# Hardware layout
args.num_gpus = 1
# Read by the multi-GPU Trainer; it was previously undefined, so num_gpus > 1 raised AttributeError.
args.num_nodes = 1

# Resolve the compute backend once. Later cells use both the raw value (as the Trainer/sampler
# device type) and the torch.device object (for generators, tensors and checkpoint loading).
# NOTE(review): presumably returns a string such as 'cuda'/'mps'/'cpu' — confirm in util.platform.
device_type_accelerator = get_torch_device_type()
device_accelerator = torch.device(device_type_accelerator)

In [None]:
wandb_logger = pl.loggers.WandbLogger(project=args.name)
# seed_everything seeds Python's `random`, NumPy and torch (CPU + CUDA) in one call. The previous
# torch.manual_seed only covered torch, so any `random`/NumPy usage (e.g. inside the dataset or
# callbacks, not visible here) stayed unseeded. It still performs the same torch.manual_seed(seed).
pl.seed_everything(args.seed)

In [None]:
def test_generate(
        model,
        batch_size=1,
        steps=50,
        scheduler=SchedulerType.V_CRASH,
        scheduler_args={
            'sigma_min': 0.1,
            'sigma_max': 50.0,
            'rho': 1.0
        },
        sampler=SamplerType.V_IPLMS,
        sampler_args={'use_tqdm': True},
        callback=None,
        sample_size=None,
        seed=None,
):
    """Sample a batch of audio from ``model`` starting from seeded Gaussian noise.

    A fresh generator is seeded on every call, so repeated calls with the same arguments
    start from identical noise. This makes before/after-LoRA comparisons meaningful.

    Args:
        model: network passed straight to ``sampler.sample``.
        batch_size: number of samples to draw.
        steps: number of diffusion steps requested from the scheduler.
        scheduler: scheduler whose ``get_step_list`` builds the step schedule.
        scheduler_args: extra keyword arguments for ``get_step_list``. The default dict is only
            unpacked, never mutated, so sharing it between calls is safe.
        sampler: sampler used to run the reverse process. Only v-samplers are supported.
        sampler_args: extra keyword arguments for ``sampler.sample`` (only unpacked).
        callback: forwarded unchanged to ``sampler.sample``.
        sample_size: length of the noise tensor's last axis; defaults to ``args.sample_size``.
        seed: seed for the noise generator; defaults to ``args.seed``.

    Returns:
        The sampler output cast to float32. It presumably has the noise shape
        ``(batch_size, 2, sample_size)``; confirm against the sampler implementation.

    Raises:
        NotImplementedError: if ``sampler`` is not a v-sampler. Previously this surfaced as an
            UnboundLocalError because the initial noise was never created.
    """
    if not SamplerType.is_v_sampler(sampler):
        raise NotImplementedError(
            f'test_generate only builds initial noise for v-samplers, got {sampler!r}'
        )

    if sample_size is None:
        sample_size = args.sample_size
    if seed is None:
        seed = args.seed

    generator = torch.Generator(device_accelerator)
    generator.manual_seed(seed)

    step_list = scheduler.get_step_list(steps, device_accelerator.type, **scheduler_args)

    # Second axis is 2 channels (presumably stereo — confirm against the training data).
    x_T = torch.randn([batch_size, 2, sample_size], generator=generator, device=device_accelerator)

    return sampler.sample(
        model,
        x_T,
        step_list,
        callback,
        **sampler_args
    ).float()

def preview_batch(generated, sample_rate=None):
    """Render every sample of a generated batch as an inline notebook audio player.

    Args:
        generated: iterable of per-sample tensors (e.g. the batch returned by ``test_generate``).
        sample_rate: playback rate in Hz; defaults to ``args.sample_rate``.
    """
    rate = args.sample_rate if sample_rate is None else sample_rate
    for ix, gen_sample in enumerate(generated):
        print(f'sample #{ix + 1}')
        # detach() first: the tensor may still be attached to the autograd graph (e.g. through
        # trainable LoRA weights), and converting a tensor that requires grad to numpy raises.
        # ipd.display is used explicitly instead of relying on IPython's injected `display` global.
        ipd.display(ipd.Audio(gen_sample.detach().cpu().numpy(), rate=rate))


## Set up model

In [None]:
# Fetch the pretrained checkpoint from the W&B artifact store and locate the .ckpt inside it.
model_name = 'maestro_16000_65536'
artifact_ref = f'{model_name}:v0'
model_artifact = wandb_logger.use_artifact(artifact_ref)
artifact_dir = Path(model_artifact.download())
checkpoint_path = artifact_dir.joinpath(f'{model_name}.ckpt')

In [None]:
# To load a local checkpoint instead of the downloaded W&B artifact, override the path first, e.g.:
# checkpoint_path = 'models/maestro_16000_65536.ckpt'

model = DiffusionUncondLora.load_from_checkpoint(
    checkpoint_path,
    map_location=device_accelerator,
    global_args=args,
)

In [None]:
# Show the loaded model's module tree (LoRA layers are injected in a later cell).
model_summary = str(model)
print(model_summary)

In [None]:
# Reference samples from the unmodified pretrained model; only generated when `debug` is on.
if debug:
    batch_baseline = test_generate(model.diffusion, batch_size=2)
    preview_batch(generated=batch_baseline)

## Inject LoRA

In [None]:
# Add new LoRA layers to the model, using args.lora_rank as the LoRA dimension.
# NOTE(review): re-running this cell presumably injects a second set of LoRA layers on top of the
# first; reload the checkpoint before re-running. Confirm against DiffusionUncondLora.inject_new_lora.
model.inject_new_lora(lora_dim=args.lora_rank)

In [None]:
# Samples right after injecting an untrained ("blank") LoRA. In theory these should sound the same
# as the baseline batch; test_generate reseeds its noise on every call, so the comparison is like-for-like.
if debug:
    batch_empty = test_generate(model.diffusion, batch_size=2)
    preview_batch(generated=batch_empty)

## Train

In [None]:
# Build the fine-tuning dataset and a shuffled loader over it.
training_dir = 'input/ivq_16000_65536'
train_set = SampleDataset([training_dir], args)
train_dl = data.DataLoader(
    dataset=train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,  # load batches in the main process
)

In [None]:
exc_callback = ExceptionCallback()
ckpt_callback = pl.callbacks.ModelCheckpoint(
    every_n_train_steps=args.checkpoint_every,
    save_top_k=-1,  # keep every saved checkpoint rather than only the best k
    dirpath='output',
)
demo_callback = DemoCallback(args)

wandb_logger.watch(model)
# Previously `wandb_logger.config = args`, which only set a plain Python attribute on the logger
# object and never sent anything to W&B. log_hyperparams writes the options into the run config.
wandb_logger.log_hyperparams(vars(args))

In [None]:
# Build the Lightning Trainer. Single- and multi-GPU runs share every option except the device
# layout, so the common settings live in one dict instead of two copy-pasted constructor calls.
trainer_kwargs = dict(
    accelerator="gpu",
    precision=16,
    accumulate_grad_batches=args.accum_batches,
    callbacks=[ckpt_callback, demo_callback, exc_callback],
    logger=wandb_logger,
    log_every_n_steps=1,
    max_epochs=args.max_epochs,
)

if args.num_gpus > 1:
    trainer_kwargs.update(
        devices=args.num_gpus,
        # `args` may not define num_nodes; default to a single node instead of raising AttributeError.
        num_nodes=getattr(args, 'num_nodes', 1),
        # NOTE(review): Lightning rejects strategy='ddp' in interactive sessions such as Jupyter;
        # 'ddp_notebook' (Lightning >= 1.8) is the notebook-compatible variant. Confirm against
        # the installed version before running multi-GPU from this notebook.
        strategy='ddp',
    )
else:
    trainer_kwargs.update(devices=1)

diffusion_trainer = pl.Trainer(**trainer_kwargs)

diffusion_trainer.fit(model, train_dl)