In [1]:
import sys

from lightning_fabric import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

sys.path.append('../src/')

from dataset.synth_datamodule import ModularSynthDataModule
from model.lit_module import LitModularSynth
from main import configure_experiment
from utils.train_utils import get_project_root

# Train a sound-matching model using the DiffMoog differentiable synthesizer and DDSP

In [2]:
# Configure experiment: experiment/data/config names, derived paths and runtime settings.

# Runtime settings
device = 'cuda:0'
random_seed = 42

# Names
# - dataset_name must point to the data generated by 'create_dataset.ipynb'.
# - cfg_name selects the training configuration YAML; see the sample configurations for alternatives.
exp_name = "basic_flow_example_experiment"
dataset_name = "example_basic_flow_dataset"
cfg_name = 'baseline.yaml'

# Paths, all resolved relative to the project root.
# Double-check that get_project_root() finds the correct directory, or set `root` manually.
root = get_project_root()
exp_dir = root.joinpath('experiments', 'current', exp_name)
data_dir = root.joinpath('data', dataset_name)
cfg_path = root.joinpath('configs', 'sample_configs', cfg_name)

In [None]:
# Init and run training. This can also be run directly from 'src/main.py'.
# Checkpoints and TensorBoard logs are written under exp_dir.

# Training knobs, named so they are easy to find and tune.
ACCUMULATE_GRAD_BATCHES = 4   # one optimizer step every 4 training batches
MAX_LOG_EVERY_N_STEPS = 50    # upper bound for Trainer(log_every_n_steps=...)
DETECT_ANOMALY = True         # autograd anomaly detection: helpful for debugging NaNs, but slows training

# Seed first, *before* building the datamodule and the model, so that weight
# initialization and any randomness during dataset setup are reproducible.
seed_everything(random_seed, workers=True)

# NOTE(review): the meaning of the positional `True` is defined by main.configure_experiment;
# consider passing it by keyword for readability.
cfg = configure_experiment(exp_dir, data_dir, cfg_path, True)

datamodule = ModularSynthDataModule(cfg.data_dir, cfg.model.batch_size, cfg.model.num_workers, cfg.loss.in_domain_epochs,
                                    added_noise_std=cfg.synth.added_noise_std)
# Set up explicitly so the training-set size is available below; Lightning calls setup() again inside fit().
datamodule.setup()

lit_module = LitModularSynth(cfg, device)

callbacks = [
    LearningRateMonitor(logging_interval='step'),
    # NOTE(review): ModelCheckpoint defaults to mode='min'. If pearson_stft is a correlation
    # (higher is better) rather than a loss, this keeps the worst checkpoints — confirm and
    # pass mode='max' if so.
    ModelCheckpoint(cfg.ckpts_dir, monitor='in_domain_validation_metrics/pearson_stft/dataloader_idx_0', save_top_k=2),
]

tb_logger = TensorBoardLogger(cfg.logs_dir, name=exp_name)
# Give the module direct access to the underlying TensorBoard writer
# (presumably for custom audio/figure logging — confirm in LitModularSynth).
lit_module.tb_logger = tb_logger.experiment

# The logging interval is measured in training steps; with gradient accumulation one step spans
# ACCUMULATE_GRAD_BATCHES batches. Cap the interval at the number of steps in one epoch so that
# small datasets still log every epoch. Floor division keeps the estimate conservative whether or
# not the last partial batch is dropped, and max(1, ...) guards against tiny or empty datasets.
# Assumes train_dataset.params has one entry per training sample.
num_train_samples = len(datamodule.train_dataset.params)
num_train_batches = max(1, num_train_samples // cfg.model.batch_size)
steps_per_epoch = max(1, num_train_batches // ACCUMULATE_GRAD_BATCHES)
log_every_n_steps = min(MAX_LOG_EVERY_N_STEPS, steps_per_epoch)

trainer = Trainer(devices=1,
                  accelerator="gpu",
                  logger=tb_logger,
                  callbacks=callbacks,
                  max_epochs=cfg.model.num_epochs,
                  detect_anomaly=DETECT_ANOMALY,
                  log_every_n_steps=log_every_n_steps,
                  check_val_every_n_epoch=1,
                  accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
                  # Rebuild dataloaders every epoch (the datamodule prints "Loading in-domain dataloader"
                  # each time), presumably so it can switch data after cfg.loss.in_domain_epochs — confirm.
                  reload_dataloaders_every_n_epochs=1)

trainer.fit(lit_module, datamodule=datamodule)

No OOD train data found. Running in-domain training only...
NSynth dataloader found 0 wav files in /home/ubuntu/almogelharar/aisynth/examples/../data/example_basic_flow_dataset/val_nsynth


Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type           | Params
----------------------------------------------------
0 | synth            | SynthModular   | 0     
1 | synth_net        | SynthNetwork   | 11.5 M
2 | signal_transform | MelSpectrogram | 0     
3 | params_loss      | ParametersLoss | 0     
----------------------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.005   

No OOD train data found. Running in-domain training only...


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Missing amp param in Oscillator module lfo. Assuming fixed amp. Please check Synth structure if this is unexpected.
Missing amp param in Oscillator module lfo. Assuming fixed amp. Please check Synth structure if this is unexpected.
Missing amp param in Oscillator module lfo. Assuming fixed amp. Please check Synth structure if this is unexpected.




Loading in-domain dataloader


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader


  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



Loading in-domain dataloader
