In [None]:
%%capture
# Install notebook dependencies into the current kernel's environment (%pip, not !pip).
# %%capture above swallows this cell's stdout/stderr -- including pip install errors;
# remove it temporarily when an install appears to have failed.
# NOTE(review): only diffusers is pinned; pin the remaining packages for reproducible runs.
%pip install einops pytorch_lightning diffusers==0.12.1 kornia librosa accelerate ipympl

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import pytorch_lightning as pl

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from src import *

# Global configuration.
mpl.rcParams['figure.figsize'] = (8, 8)
# Use the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training and diffusion sampling are stochastic: seed Python, NumPy and torch RNGs
# so a Restart & Run All gives repeatable results (GPU kernels may still add noise).
SEED = 42
pl.seed_everything(SEED)

### Creating datasets

In [None]:
from dataset import SpectrogramDataset

# Root of the randomMIDI piano/violin WAV dataset; each split has 'ins3' and 'mix' sub-folders.
DATA_ROOT = 'datasets/randomMIDI/PianoViolin11025/WAV'


def make_split_dataset(split, target='ins3', condition='mix', out_channels=3):
    """Build the paired spectrogram dataset for one data split.

    Args:
        split: sub-directory of DATA_ROOT ('train', 'val' or 'test').
        target: sub-directory passed as ``target_dir``.
        condition: sub-directory passed as ``condition_dir``.
        out_channels: forwarded to SpectrogramDataset.

    Returns:
        A SpectrogramDataset built with ``return_pair=True``; indexing it yields two tensors.
        NOTE(review): element order is ambiguous -- the autoencoder cell unpacks it as
        (img, cond) while the sampling cell unpacks it as (input, ground truth). Confirm.
    """
    return SpectrogramDataset(target_dir=f'{DATA_ROOT}/{split}/{target}',
                              condition_dir=f'{DATA_ROOT}/{split}/{condition}',
                              return_pair=True,
                              out_channels=out_channels)


train_ds = make_split_dataset('train')
valid_ds = make_split_dataset('val')
test_ds = make_split_dataset('test')

# Quick visual check of the first training pair.
img1, img2 = train_ds[0]

plt.subplot(1, 2, 1)
plt.imshow(img1.permute(1, 2, 0))
plt.title('train_ds[0][0]')
plt.subplot(1, 2, 2)
plt.imshow(img2.permute(1, 2, 0))
plt.title('train_ds[0][1]')
plt.show()

### Testing autoencoder

In [None]:
# Pretrained VAE weights from the Hugging Face Hub (repo id below), wrapped in the
# project's Autoencoder module and placed on the selected device.
vae_repo_id = "stabilityai/sd-vae-ft-ema"
autoencoder = AutoencoderKL.from_pretrained(vae_repo_id)
pl_ae_model = Autoencoder(autoencoder)
pl_ae_model = pl_ae_model.to(device)

In [None]:
# Sanity-check the pretrained autoencoder on one test example: compare the forward pass
# with an explicit encode -> decode round trip.
# NOTE(review): this cell feeds test_ds[0][1] to the autoencoder and titles it 'Input', while
# the sampling cell uses test_ds[0][0] as the model input -- confirm the dataset's pair order.
img, cond = test_ds[0]
cond_batch = cond.unsqueeze(0).to(device)  # add a batch dimension, move to the AE's device

# Inference only: no autograd graph is needed.
with torch.no_grad():
    prod_img = pl_ae_model(cond_batch)
    latent = pl_ae_model.encode(cond_batch)  # computed once, reused for decode and shape print
    prod_img2 = pl_ae_model.decode(latent)

plt.subplot(1, 3, 1)
plt.imshow(cond.permute(1, 2, 0))
plt.title('Input')
plt.subplot(1, 3, 2)
plt.imshow(prod_img[0].detach().cpu().permute(1, 2, 0))
plt.title('AutoEncoder Reconstruction')
plt.subplot(1, 3, 3)
plt.imshow(prod_img2[0].detach().cpu().permute(1, 2, 0))
plt.title('AutoEncoder Reconstruction encode/decode')
plt.show()

print(latent.shape)
print(prod_img.shape, prod_img2.shape)


### Model setup

In [None]:
# Conditional latent diffusion model (project class) built on the pretrained autoencoder.
# Train on the training split: passing test_ds here would train on the held-out test data
# that the sampling cells below evaluate on, and leave train_ds unused.
model = LatentDiffusionConditional(train_dataset=train_ds,
                                   autoencoder=pl_ae_model,
                                   valid_dataset=valid_ds,
                                   lr=1e-4,
                                   batch_size=8,
                                   schedule='linear',
                                   warm_up_steps=10000,
                                   num_timesteps=1000,
                                   loss_fn=F.l1_loss)

### Training

In [None]:
# Lightning trainer. With no explicit logger/checkpoint callback, runs are written under
# <default_root_dir>/lightning_logs/version_*/.
# NOTE(review): the checkpoint paths used in the cells below point at
# trained_models/latent_diffusion/lightning_logs/, not under this default_root_dir -- confirm.
ema_callback = EMA(0.9999)  # project EMA callback; 0.9999 is presumably the decay rate
trainer = pl.Trainer(
    accelerator='auto',
    callbacks=[ema_callback],
    default_root_dir="trained_models/latent_diffusion/test",
    max_epochs=3000,
)

In [None]:
# Train from scratch for up to the trainer's max_epochs (long-running).
# NOTE(review): Restart & Run All also executes the resume cell below; run one or the other.
trainer.fit(model=model)

### Resume training or load a checkpoint

In [None]:
# Resume training from a saved Lightning checkpoint (ckpt_path restores the full training state).
# NOTE(review): this run (version_3) is outside the trainer's default_root_dir and differs from
# the checkpoint loaded in the next cell (version_0) -- confirm the intended run.
resume_ckpt_path = 'trained_models/latent_diffusion/lightning_logs/version_3/checkpoints/epoch=613-step=12000.ckpt'
trainer.fit(model, ckpt_path=resume_ckpt_path)

In [None]:
# Rebuild the model from saved weights. Keyword arguments are passed to the constructor and
# override saved hyperparameters (here a lower lr and larger batch size than the setup cell).
ckpt_path = 'trained_models/latent_diffusion/lightning_logs/version_0/checkpoints/epoch=477-step=7648.ckpt'
model = LatentDiffusionConditional.load_from_checkpoint(
    ckpt_path,
    train_dataset=train_ds,
    autoencoder=pl_ae_model,
    valid_dataset=valid_ds,
    lr=1e-5,
    batch_size=16,
    schedule='linear',
)

print(model.lr)

### Generate samples

In [None]:
# Draw several diffusion samples conditioned on the same test example.
n_samples = 4
# NOTE(review): `input` shadows the builtin; kept because the following cells use these names.
input, output = test_ds[0]
batch_input = torch.stack(n_samples * [input], 0)

# Put model and input on one device: after trainer.fit / load_from_checkpoint the model may
# sit on CPU while pl_ae_model was moved to `device` earlier, which causes a device mismatch.
model = model.to(device)
model.eval()  # inference mode (disables dropout etc.)
with torch.no_grad():
    out = model(batch_input.to(device), verbose=True)
# Keep samples on CPU so the plotting and audio-export cells work regardless of device.
out = out.detach().cpu()

In [None]:
# One row of panels: conditioning input, each generated sample, then the ground truth.
n_panels = len(out) + 2
fig, axes = plt.subplots(1, n_panels, figsize=(2.5 * n_panels, 3))

axes[0].imshow(input.permute(1, 2, 0))
axes[0].set_title('Input')

for idx, sample in enumerate(out):
    axes[idx + 1].imshow(sample.detach().cpu().permute(1, 2, 0))
    axes[idx + 1].set_title(f'Sample {idx}')

axes[-1].imshow(output.permute(1, 2, 0))
axes[-1].set_title('Ground Truth')

for ax in axes:
    ax.axis('off')
plt.show()

In [None]:
# Export the first generated sample as audio, reusing the phase of the test example it was
# conditioned on (presumably the sample is a magnitude spectrogram -- confirm in dataset.py).
example_idx = 0  # dataset index used by the sampling cell (test_ds[0])
phase = test_ds.get_phase(example_idx)
print(phase.shape)
print(out[0].shape)
name = test_ds.files[example_idx]
# Pass a detached CPU tensor: `out` may live on the GPU or still carry autograd history.
# NOTE(review): save_audio's output location is not visible here -- confirm it cannot
# overwrite the source recording `name`.
test_ds.save_audio(out[0].detach().cpu(), phase, name=name)