# Train diffusion model

In [None]:
!pip install -q matplotlib
!pip install -q datasets
!pip uninstall -q pytorch_lightning -y
!pip install -q pytorch_lightning==1.7.0

!pip uninstall -q torchmetrics -y
!pip install -q torchmetrics==0.7.0
!pip install -q torchinfo
!pip install -q transformers

from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/path/to/root/

import torch
import json

device = "cuda" if torch.cuda.is_available() else "cpu"

## Load data

In [None]:
from data_generation.nsynth import get_nsynth_dataloader

# Batch size for the loader; the training cell passes the same value on.
BATCH_SIZE = 8
# Make sure to use your actual path
training_dataset_path = 'data/NSynth/nsynth-STFT-train-52.hdf5'

# Shuffled loader over the training file, configured for the "STFT" task with
# latent representations and metadata enabled and timbre embeddings disabled.
training_dataloader = get_nsynth_dataloader(
    training_dataset_path,
    batch_size=BATCH_SIZE,
    shuffle=True,
    get_latent_representation=True,
    with_meta_data=True,
    with_timbre_emb=False,
    task="STFT",
)

## Load models

In [None]:
from model.multimodal_model import get_multi_modal_model
from model.timbre_encoder_pretrain import get_timbre_encoder
from model.VQGAN import get_VQGAN
from transformers import AutoTokenizer, ClapModel

# --- VQ-GAN: pretrained weights are loaded onto `device`. ---
VAE_model_name = "VQ-GAN_name"
modelConfig = {
    "in_channels": 3,
    "hidden_channels": [80, 160],
    "embedding_dim": 4,
    "out_channels": 3,
    "block_depth": 2,
    "attn_pos": [80, 160],
    "attn_with_skip": True,
    "num_embeddings": 8192,
    "commitment_cost": 0.25,
    "decay": 0.99,
    "norm_type": "groupnorm",
    "act_type": "swish",
    "num_groups": 16,
}
VAE = get_VQGAN(modelConfig, load_pretrain=True, model_name=VAE_model_name, device=device)

# --- CLAP model and its tokenizer, both from the same checkpoint. ---
CLAP = ClapModel.from_pretrained("laion/clap-htsat-unfused")  # 153,492,890
CLAP_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

# --- Timbre encoder and multimodal model. ---
timbre_encoder_name = "timbre_encoder_name"
timbre_encoder_Config = {
    "input_dim": 512,
    "feature_dim": 512,
    "hidden_dim": 1024,
    "num_instrument_classes": 1006,
    "num_instrument_family_classes": 11,
    "num_velocity_classes": 128,
    "num_qualities": 10,
    "num_layers": 3,
}
mmm_name = "multimodal_model_name"
MMM_config = {
    "text_feature_dim": 512,
    "spectrogram_feature_dim": 1024,
    "multi_modal_emb_dim": 512,
    "num_projection_layers": 2,
    "temperature": 1.0,
    "dropout": 0.1,
    "freeze_text_encoder": False,
    "freeze_spectrogram_encoder": False,
}

# The multimodal model is assembled on the CPU from a CPU copy of the timbre
# encoder; `timbre_encoder` is then rebound to a second copy loaded on `device`.
timbre_encoder = get_timbre_encoder(
    timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name, device="cpu"
)
dataset_text_encoder = get_multi_modal_model(
    timbre_encoder, CLAP, MMM_config, load_pretrain=True, model_name=mmm_name, device="cpu"
)
timbre_encoder = get_timbre_encoder(
    timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name, device=device
)

# Read the JSON mapping and turn every value into a tensor placed on `device`.
with open("data/NSynth/GPT/encodes2embeddings_mapping_TE_STFT.json", "r") as mapping_file:
    encodes2embeddings_mapping = {
        encode: torch.tensor(embedding).to(device)
        for encode, embedding in json.load(mapping_file).items()
    }

## Train

In [None]:
from model.diffusion import train_diffusion_model

# Text features of the empty prompt, moved to `device`; handed to the trainer
# as `unconditional_condition`.
empty_prompt_tokens = CLAP_tokenizer([""], padding=True, return_tensors="pt")
unconditional_condition = dataset_text_encoder.get_text_features(**empty_prompt_tokens)[0].to(device)

# Specify model name
# NOTE(review): `model_name` is only referenced by the commented-out line below;
# the training call passes a separate hard-coded `save_model_name`.
model_name = "your_model_name"
init_model_name = "history/init_model_name"
# init_model_name = model_name    # uncomment this line for training from scratch
unetConfig = {
    "in_dim": 4,
    "down_dims": [96, 96, 192, 384],
    "up_dims": [384, 384, 192, 96],
    "attn_type": "linear_add",
    "condition_type": "natural_language_prompt",
    "label_emb_dim": 512,
}

# NOTE(review): save_steps (100000) is larger than max_iter (40000) below;
# confirm that this is intended.
save_steps = 100000
model, optimizer = train_diffusion_model(
    VAE,
    dataset_text_encoder,
    CLAP_tokenizer,
    timbre_encoder,
    device,
    init_model_name,
    unetConfig,
    BATCH_SIZE,
    timesteps=1000,
    lr=1e-4,
    uncondition_rate=0.1,
    max_iter=40000,
    iterator=training_dataloader,
    encodes2embeddings_mapping=encodes2embeddings_mapping,
    load_pretrain=False,
    save_steps=save_steps,
    unconditional_condition=unconditional_condition,
    init_loss=0.5,
    save_model_name="28_1_2024_TE_STFT",
    n_IS_batches=200,
)