In [None]:
!pip install einops pytorch_lightning diffusers==0.12.1 kornia librosa accelerate

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import pytorch_lightning as pl

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from src import *

# Default size (in inches) for every figure drawn in this notebook.
mpl.rcParams.update({'figure.figsize': (8, 8)})

### Creating datasets

In [None]:
from dataset import SpectrogramDataset

# Root folder that contains the train/val/test splits used below.
DATA_ROOT = 'datasets/randomMIDI/PianoViolin11025/WAV'


def _make_split(split):
    """Build the paired spectrogram dataset for one split.

    Args:
        split: sub-folder name under DATA_ROOT ('train', 'val' or 'test').

    Returns:
        A SpectrogramDataset whose targets come from ``<split>/ins3`` and whose
        conditioning inputs come from ``<split>/mix``.
    """
    return SpectrogramDataset(target_dir=f'{DATA_ROOT}/{split}/ins3',
                              condition_dir=f'{DATA_ROOT}/{split}/mix',
                              return_pair=True,
                              out_channels=1)


train_ds = _make_split('train')
valid_ds = _make_split('val')
test_ds = _make_split('test')

# Quick visual sanity check on the first training pair.
img1, img2 = train_ds[0]
print(img1.shape)

# imshow expects channels last, hence the permute.
# NOTE(review): panel labels assume the pair is (input, ground truth), matching
# how test_ds[0] is unpacked and labelled later in this notebook - confirm.
plt.subplot(1, 2, 1)
plt.imshow(img1.permute(1, 2, 0))
plt.title('Input')
plt.subplot(1, 2, 2)
plt.imshow(img2.permute(1, 2, 0))
plt.title('Ground Truth')

### Model Setup

In [None]:
# Single-channel in/out autoencoder with four down blocks and four up blocks.
autoencoder = AutoencoderKL(
    in_channels=1,
    out_channels=1,
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    sample_size=256,
)

In [None]:
# Keyword settings handed to the conditional latent diffusion wrapper.
diffusion_kwargs = dict(
    valid_dataset=valid_ds,
    lr=1e-5,
    batch_size=2,
    schedule='linear',
)

# The model receives the training set and the autoencoder positionally.
model = LatentDiffusionConditional(train_ds, autoencoder, **diffusion_kwargs)

### Training

In [None]:
# NOTE(review): max_steps=1 stops training after a single step - this looks
# like a smoke-test value; raise it for a real run.
ema_callback = EMA(0.9999)
trainer = pl.Trainer(max_steps=1, callbacks=[ema_callback])

In [None]:
# Run the training loop configured above on the diffusion model.
trainer.fit(model=model)

### Load checkpoint

In [None]:
# Optional: restore previously trained weights instead of relying on the run above.
# NOTE(review): the path looks like a Lightning-written checkpoint. Per the
# Lightning docs such files hold a dict with the weights under 'state_dict',
# so assigning torch.load(...) to `model` would replace the module with a dict.
# Presumably the intent is closer to:
#   model.load_state_dict(torch.load(<path>)['state_dict'])  # TODO confirm
#model = torch.load('lightning_logs/version_6/checkpoints/epoch=1177-step=37725.ckpt')

### Testing autoencoder

In [None]:
# Pass the first training image through the model's autoencoder and show it
# next to the original.
# Inference only: no_grad avoids building an autograd graph that would be
# thrown away, and the batch is moved to whatever device the model is on.
with torch.no_grad():
    reconstruction = model.ae(img1.unsqueeze(0).to(model.device))[0].cpu()

# imshow expects channels last, hence the permute.
plt.subplot(1,2,1)
plt.imshow(img1.permute(1,2,0))
plt.title('Input')
plt.subplot(1,2,2)
plt.imshow(reconstruction.permute(1,2,0))
plt.title('AutoEncoder Reconstruction')

### Create sample

In [None]:
# Number of independent samples to draw for the same conditioning input.
N_SAMPLES = 4

# The names `input`/`output` are kept (even though `input` shadows the builtin)
# because the plotting cell that follows reads them.
input,output=test_ds[0]
# Repeat the conditioning image along a new leading batch axis.
batch_input=torch.stack(N_SAMPLES*[input],0)

#model.cuda()
# Sampling is inference only: skip autograd bookkeeping, and feed the batch on
# the model's current device so this also works after enabling model.cuda().
with torch.no_grad():
    out=model(batch_input.to(model.device), verbose=True)

In [None]:
# One row of panels: the input, every generated sample, then the ground truth.
n_panels = len(out) + 2

plt.subplot(1, n_panels, 1)
plt.imshow(input.permute(1, 2, 0))
plt.title('Input')
plt.axis('off')

# Generated samples occupy panels 2 .. n_panels - 1.
for panel, sample in enumerate(out, start=2):
    plt.subplot(1, n_panels, panel)
    plt.imshow(sample.detach().cpu().permute(1, 2, 0))
    plt.axis('off')

plt.subplot(1, n_panels, n_panels)
plt.imshow(output.permute(1, 2, 0))
plt.title('Ground Truth')
plt.axis('off')

In [None]:
# Save the first generated sample as audio, using the phase and file name the
# dataset reports for test item 0.
phase = test_ds.get_phase(0)
first_sample = out[0]

# Print both shapes for a quick visual check before saving.
print(phase.shape)
print(first_sample.shape)

name = test_ds.files[0]
test_ds.save_audio(first_sample, phase, name=name)