In [1]:
#@title Mount Drive
# Mount Google Drive at /content/drive; the dataset and base-model paths
# configured in the Settings cell both live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@title Get & Import Requirements
# Fetch the LoRAW sources and make the checkout the working directory.
# The `train_uncond_lora` and `dataset` imports further down in this cell are
# presumably resolved from this checkout -- confirm if the repo layout changes.
!git clone https://github.com/Bikecicle/LoRAW
%cd LoRAW
# Third-party packages: sample-diffusion-lib (provides `sample_diffusion`),
# PyTorch Lightning, and prefigure (not imported in this notebook; presumably
# needed by the LoRAW training module -- verify).
# NOTE(review): none of these installs is version-pinned, so a fresh runtime
# may pick up newer, incompatible releases.
!pip install git+https://github.com/diontimmer/sample-diffusion-lib
!pip install pytorch_lightning
!pip install prefigure
import torch
from torch.utils import data
import pytorch_lightning as pl

from pathlib import Path
import IPython.display as ipd

from sample_diffusion.util.platform import get_torch_device_type
from sample_diffusion.diffusion_library.sampler import SamplerType
from sample_diffusion.diffusion_library.scheduler import SchedulerType

from train_uncond_lora import DiffusionUncondLora, ExceptionCallback, DemoCallback
from dataset.dataset import SampleDataset

class Object:
    """Empty attribute container.

    Instances carry no behaviour of their own; callers attach arbitrary
    attributes to them (for example ``obj.seed = 0``) and read them back.
    """


In [7]:
#@title Settings
# Colab form fields. Every `#@param` annotation turns the assignment on its
# line into an editable widget; the values are copied onto `args` below.

# Run identity, checkpoint/demo cadence, and input locations.
# `checkpoint_every` is counted in training steps (it is passed to
# ModelCheckpoint as `every_n_train_steps`); `accum_batches` is passed to the
# Trainer as `accumulate_grad_batches`. The demo_* values are consumed by
# DemoCallback via `args`.
run_name = "Rank 16" #@param {type:"string"}
project_name = 'loraw_test' #@param {type:"string"}
checkpoint_every = 100 #@param {type:"integer"}
demo_every = 25 #@param {type:"integer"}
num_demos = 1 #@param {type:"integer"}
demo_samples = 65536 #@param {type:"integer"}
demo_steps = 50 #@param {type:"integer"}
accum_batches = 1 #@param {type:"integer"}
training_dir = '/content/drive/MyDrive/AI/datasets/dub_neuro' #@param {type:"string"}
base_model_path = '/content/drive/MyDrive/AI/models/DanceDiffusion/dd/base_models/jmann-large-580k.ckpt' #@param {type:"string"}

# Audio/model shape, seeding, and optimisation settings.
# `lora_rank` is passed to `model.inject_new_lora(lora_dim=...)`.
sample_size = 65536 #@param {type:"integer"}
sample_rate = 48000 #@param {type:"integer"}
latent_dim = 0 #@param {type:"integer"}
seed = 0 #@param {type:"integer"}
batch_size = 4 #@param {type:"integer"}
max_epochs = 1000000 #@param {type:"integer"}
lora_rank = 16 #@param {type:"integer"}

# EMA, data handling, and hardware settings.
ema_decay = 0.95 #@param {type:"number"}
random_crop = False #@param {type:"boolean"}
num_gpus = 1 #@param {type:"integer"}
# Fixed annotation: the original `#@{type:"boolean"}` lacked the `param`
# keyword, so Colab did not expose this value as a form field.
cache_training_data = False #@param {type:"boolean"}

# Collect every form value on a single namespace object. `args` is what the
# model loader, SampleDataset and DemoCallback receive further down.
args = Object()
vars(args).update(
    # run identity, cadence, and input locations
    run_name=run_name,
    project_name=project_name,
    checkpoint_every=checkpoint_every,
    demo_every=demo_every,
    num_demos=num_demos,
    demo_samples=demo_samples,
    demo_steps=demo_steps,
    accum_batches=accum_batches,
    training_dir=training_dir,
    base_model_path=base_model_path,
    # audio/model shape, seeding, and optimisation
    sample_size=sample_size,
    sample_rate=sample_rate,
    latent_dim=latent_dim,
    seed=seed,
    batch_size=batch_size,
    max_epochs=max_epochs,
    lora_rank=lora_rank,
    ema_decay=ema_decay,
    random_crop=random_crop,
    num_gpus=num_gpus,
    cache_training_data=cache_training_data,
)

# Ask sample_diffusion's helper for the device type and wrap the result in a
# torch.device; `device_accelerator` is reused when loading the model and when
# saving LoRA weights.
device_type_accelerator = get_torch_device_type()
device_accelerator = torch.device(device_type_accelerator)



# Create the Weights & Biases logger for this project/run name. The recorded
# output of this cell shows an interactive API-key prompt.
wandb_logger = pl.loggers.WandbLogger(project=args.project_name, name=args.run_name)
# Seed torch's global RNG from the form value. This is the cell's last
# expression, so the returned torch Generator is echoed as the cell output.
torch.manual_seed(seed=args.seed)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


<torch._C.Generator at 0x7f7b701f2af0>

In [None]:
#@title Load Model & Inject
# Restore the base checkpoint onto the selected device. `global_args` hands the
# notebook settings to the model, and `strict=False` is forwarded to the
# checkpoint loader.
model = DiffusionUncondLora.load_from_checkpoint(
    args.base_model_path,
    map_location=device_accelerator,
    global_args=args,
    strict=False,
)

# Inject fresh LoRA layers, passing the configured rank as `lora_dim`.
model.inject_new_lora(lora_dim=args.lora_rank)

# Move the model, including the newly injected layers, to the device.
model.to(device_accelerator)


In [None]:
#@title Start Training

# Build the training set from the configured directory and wrap it in a
# shuffling loader; loading stays in the main process (num_workers=0).
train_set = SampleDataset([args.training_dir], args)
train_dl = data.DataLoader(
    dataset=train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,
)


from typing import Any

class HijackedModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """ModelCheckpoint subclass that writes only the LoRA weights.

    Only the `_save_checkpoint` hook is overridden; everything else
    (constructor arguments such as `every_n_train_steps`, `save_top_k` and
    `dirpath`) is inherited from `pl.callbacks.ModelCheckpoint`.
    """

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Save the model's LoRA weights to `filepath` and notify the loggers.

        Args:
            trainer: The running trainer; `trainer.model.lora` must expose
                `save_weights(path, device)`.
            filepath: Destination path chosen by the parent callback.
        """
        # `proxy` was referenced without ever being imported, which raised a
        # NameError on the first checkpoint save. It is imported locally so
        # this class does not depend on any other cell.
        from weakref import proxy

        # Uses the notebook-level `device_accelerator` defined in the Settings cell.
        trainer.model.lora.save_weights(filepath, device_accelerator)

        # Record the step of this save on the callback.
        self._last_global_step_saved = trainer.global_step

        # Notify loggers from global rank zero only. A weak proxy is passed so
        # loggers do not hold a strong reference to this callback.
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))

# Callbacks handed to the Trainer below.
exc_callback = ExceptionCallback()
# save_top_k=-1 keeps every checkpoint; one is written every
# `args.checkpoint_every` training steps into ./output (relative to the
# current working directory).
ckpt_callback = HijackedModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1, dirpath='output')
demo_callback = DemoCallback(args)

# Let W&B watch the model.
wandb_logger.watch(model)
# NOTE(review): this only sets a plain `config` attribute on the logger object;
# it is not evident from this notebook that WandbLogger reads it. Confirm the
# settings actually reach the W&B run.
wandb_logger.config = args

# Trainer configuration shared by the single-GPU and multi-GPU cases:
# fp16 precision, gradient accumulation from the form, W&B logging every step.
trainer_kwargs = dict(
    devices=1,
    accelerator="gpu",
    precision=16,
    accumulate_grad_batches=args.accum_batches,
    callbacks=[ckpt_callback, demo_callback, exc_callback],
    logger=wandb_logger,
    log_every_n_steps=1,
    max_epochs=args.max_epochs,
)

if args.num_gpus > 1:
    # Multi-GPU: use every requested device with DDP.
    # The Settings cell never defines `args.num_nodes`, so the original
    # `num_nodes=args.num_nodes` raised AttributeError on this path. It now
    # falls back to a single node unless the attribute has been set.
    # NOTE(review): Lightning may reject strategy='ddp' in an interactive
    # session; 'ddp_notebook' may be needed when run from Colab -- confirm
    # against the installed version.
    trainer_kwargs.update(
        devices=args.num_gpus,
        num_nodes=getattr(args, "num_nodes", 1),
        strategy='ddp',
    )

diffusion_trainer = pl.Trainer(**trainer_kwargs)

# Run training; with max_epochs as configured this normally ends only when
# the cell is interrupted or an error occurs.
diffusion_trainer.fit(model, train_dl)
# Close the active W&B run once fit() returns. This is not reached if fit()
# raises.
import wandb
wandb.finish()

