In [1]:
import torch
from torch.cuda.amp import autocast
from pathlib import Path
import IPython.display as ipd

from util.util import load_audio, crop_audio
from util.platform import get_torch_device_type
from dance_diffusion.api import RequestHandler, Request, RequestType, ModelType
from diffusion_library.sampler import SamplerType
from diffusion_library.scheduler import SchedulerType

import wandb

import pytorch_lightning as pl
from audio_diffusion.models import DiffusionAttnUnet1D
from train_uncond import DiffusionUncond
from audio_diffusion.audio_lora import AudioLoRAModule, AudioLoRANetwork

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgcpage[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
debug = True
seed = 0
sample_rate = 16000
chunk_size = 65536

device_type_accelerator = get_torch_device_type()
device_accelerator = torch.device(device_type_accelerator)

In [4]:
def log_audio_tensor(run, name, audio_tensor, sample_rate):
    table = wandb.Table(columns=['Index', 'Audio'])
    for index, sample in enumerate(audio_tensor.cpu().numpy()):
        audio = wandb.Audio(sample[0], sample_rate=sample_rate)  # Has to be mono
        table.add_data(index, audio)
    run.log({name: table})

In [18]:
def test_sample(
        model,
        batch_size=1,
        steps=50,
        scheduler=SchedulerType.V_CRASH,
        scheduler_args={
            'sigma_min': 0.1,
            'sigma_max': 50.0,
            'rho': 1.0
        },
        sampler=SamplerType.V_IPLMS,
        sampler_args={'use_tqdm': True},
        callback=None
):
    generator = torch.Generator(device_accelerator)
    generator.manual_seed(seed)
    
    step_list = scheduler.get_step_list(steps, device_accelerator.type, **scheduler_args)#step_list = step_list[:-1] if sampler in [SamplerType.V_PRK, SamplerType.V_PLMS, SamplerType.V_PIE, SamplerType.V_PLMS2, SamplerType.V_IPLMS] else step_list
    
    if SamplerType.is_v_sampler(sampler):
        x_T = torch.randn([batch_size, 2, chunk_size], generator=generator, device=device_accelerator)

    return sampler.sample(
        model,
        x_T,
        step_list,
        callback,
        **sampler_args
    ).float()

def preview_batch(generated):
    for ix, gen_sample in enumerate(generated):
        print(f'sample #{ix + 1}')
        display(ipd.Audio(gen_sample.cpu(), rate=sample_rate))


## Set up model

In [6]:
run = wandb.init(project='loraw_dev')

In [7]:
# Download model
model_name = 'maestro_16000_65536'
model_artifact = run.use_artifact(f'{model_name}:v0', type='model')
checkpoint_path = Path(model_artifact.download()) / f'{model_name}.ckpt'

[34m[1mwandb[0m: Downloading large artifact maestro_16000_65536:v0, 845.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.8


In [23]:
class Object(object):
    pass

args = Object()
args.sample_size = 65536
args.sample_rate = 16000
args.latent_dim = 0
args.seed = 0
args.ema_decay = 0.95

# checkpoint_path = 'models/maestro_16000_65536.ckpt'

model = DiffusionUncond.load_from_checkpoint(checkpoint_path, map_location=device_accelerator, global_args=args)

Lightning automatically upgraded your loaded checkpoint from v1.6.4 to v2.0.4. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file c:\Users\Griffin\Documents\Github\LoRAW\artifacts\maestro_16000_65536-v0\maestro_16000_65536.ckpt`


In [24]:
batch_baseline = test_sample(model.diffusion)
preview_batch(batch_baseline)

sample #1


In [None]:
run.finish()

In [25]:
lora_model = AudioLoRANetwork(model.diffusion)
lora_model.to(device_accelerator)
lora_model.apply_to()


create LoRA for U-Net1D: 72 modules.


In [26]:
batch_empty = test_sample(model.diffusion)
preview_batch(batch_empty)

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same