In [35]:
import torch
from pathlib import Path
from util.util import load_audio, crop_audio
from util.platform import get_torch_device_type
from dance_diffusion.api import RequestHandler, Request, RequestType, ModelType
from diffusion_library.sampler import SamplerType
from diffusion_library.scheduler import SchedulerType

import wandb

In [44]:
def log_audio_tensor(run, name, audio_tensor, sample_rate):
    """Log a batch of audio clips to W&B as a table with one row per clip.

    Args:
        run: W&B run (as returned by ``wandb.init``) that the table is logged to.
        name: Key under which the table is logged.
        audio_tensor: Batched torch tensor of audio. Each batch item is either a
            1-D waveform or a multi-channel array whose first axis is channels,
            in which case only the first channel is logged.
            NOTE(review): shape assumed to be (batch, channels, samples) or
            (batch, samples) -- confirm against the RequestHandler output.
        sample_rate: Sample rate in Hz, passed through to ``wandb.Audio``.
    """
    # detach(): tensors that still track gradients cannot be turned into numpy.
    # float(): output produced under autocast may be half precision; convert to
    # float32, which the audio writer behind wandb.Audio presumably requires.
    clips = audio_tensor.detach().float().cpu().numpy()

    table = wandb.Table(columns=['Index', 'Audio'])
    for index, clip in enumerate(clips):
        # wandb.Audio is given a single channel here ("has to be mono").
        # 1-D items are already mono and must not be indexed again, otherwise
        # clip[0] would be a single scalar sample.
        mono = clip if clip.ndim == 1 else clip[0]
        table.add_data(index, wandb.Audio(mono, sample_rate=sample_rate))
    run.log({name: table})

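## Set up model

Start a W&B run, download the pretrained `maestro_16000_65536` checkpoint, and (when `debug` is set) generate a small test batch to sanity-check it.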

In [37]:
# W&B project that owns this run's downloaded artifacts and logged media.
WANDB_PROJECT = 'loraw_dev'
run = wandb.init(project=WANDB_PROJECT)

In [38]:
# Notebook configuration: tune values here rather than further down.
debug = True          # when True, run a quick test generation and log it
seed = 1234           # seed passed to the test generation request
sample_rate = 16_000  # Hz; should match the checkpoint named in model_name
chunk_size = 2 ** 16  # samples per generated clip (65536, ~4.1 s at 16 kHz)

In [39]:
# Download the pretrained model checkpoint from W&B.
# The artifact name embeds the sample rate and chunk size
# (maestro_16000_65536), so build it from the config values instead of
# hard-coding it; the name and the config can then never drift apart.
model_version = 'v0'
model_name = f'maestro_{sample_rate}_{chunk_size}'
model_artifact = run.use_artifact(f'{model_name}:{model_version}', type='model')
checkpoint = Path(model_artifact.download()) / f'{model_name}.ckpt'

# Fail fast here, with a clear message, rather than deep inside the request
# handler if the artifact does not contain the expected checkpoint file.
if not checkpoint.is_file():
    raise FileNotFoundError(f'expected checkpoint not found: {checkpoint}')

[34m[1mwandb[0m: Downloading large artifact maestro_16000_65536:v0, 845.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:6.8


In [40]:
# Build the request handler on the device reported by get_torch_device_type().
device_type_accelerator = get_torch_device_type()
device_accelerator = torch.device(device_type_accelerator)

request_handler = RequestHandler(
    device_accelerator,
    optimize_memory_use=False,
    use_autocast=True,
)

# Test generation: sample a small batch with the downloaded checkpoint and log
# it to W&B so the output can be inspected in the run's UI.
# NOTE(review): the saved output of this cell reports the checkpoint as
# invalid and says a default model config is used instead -- results may not
# be representative until the checkpoint is converted.
if debug:
    crash_schedule = dict(sigma_min=0.1, sigma_max=50.0, rho=1.0)
    request = Request(
        request_type=RequestType.Generation,
        model_type=ModelType.DD,
        model_path=checkpoint,
        model_chunk_size=chunk_size,
        model_sample_rate=sample_rate,
        sampler_type=SamplerType.V_IPLMS,
        sampler_args=dict(use_tqdm=True),
        scheduler_type=SchedulerType.V_CRASH,
        scheduler_args=crash_schedule,
        seed=seed,
        batch_size=8,
        steps=50,
    )
    response = request_handler.process_request(request)
    output = response.result
    log_audio_tensor(run, 'model_test', output, sample_rate)

Model file artifacts\maestro_16000_65536-v0\maestro_16000_65536.ckpt is invalid. Please run the conversion script.
 - Default model config will be used, which may be inaccurate.


In [46]:
# Mark the W&B run as finished so its logged data (e.g. the model_test table)
# is uploaded and the run is closed.
run.finish()

In [None]:
from lora_diffusion import monkeypatch_replace_lora, monkeypatch_lora, tune_lora_scale
