In [None]:
%%capture
%pip install einops pytorch_lightning diffusers==0.12.1 kornia librosa accelerate torchvision pandas==1.5.2 accelerate

In [None]:
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torchvision.utils import save_image

# NOTE(review): presumably provides SpectrogramDataset, plUnet and EMA used below;
# prefer an explicit import list over the star import.
from src import *

# Seed Python, NumPy and torch RNGs so weight initialisation, any data
# shuffling and sampling are repeatable across Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

mpl.rcParams['figure.figsize'] = (8, 8)
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Datasets

In [None]:
DATA_ROOT = 'datasets/randomMIDI/PianoViolin11025/WAV'


def make_spectrogram_ds(split, target='ins3', condition='mix', root=DATA_ROOT):
    """Build the paired spectrogram dataset for one split.

    Args:
        split: sub-directory under ``root`` ('train', 'val' or 'test' here).
        target: folder holding the separation targets (default 'ins3').
        condition: folder holding the conditioning inputs (default 'mix').
        root: dataset root containing one folder per split.

    Returns:
        A ``SpectrogramDataset`` built with ``return_pair=True``,
        ``out_channels=1`` and ``return_mask=True``, the settings used
        throughout this notebook.
    """
    return SpectrogramDataset(target_dir=f'{root}/{split}/{target}',
                              condition_dir=f'{root}/{split}/{condition}',
                              return_pair=True,
                              out_channels=1,
                              return_mask=True)


train_ds = make_spectrogram_ds('train')
valid_ds = make_spectrogram_ds('val')
test_ds = make_spectrogram_ds('test')

# Preview one training pair. Items are unpacked as (condition, target), the
# same order the Sampling section relies on.
condition_img, target_img = train_ds[0]
print(condition_img.shape)

fig, (ax_cond, ax_tgt) = plt.subplots(1, 2)
# permute (C, H, W) -> (H, W, C) for imshow
ax_cond.imshow(condition_img.permute(1, 2, 0))
ax_cond.set_title('Condition (mix)')
ax_tgt.imshow(target_img.permute(1, 2, 0))
ax_tgt.set_title('Target (ins3)')
plt.show()

### Model setup

In [None]:
# The LightningModule is handed the datasets directly (no dataloaders are
# built in this notebook); presumably plUnet builds its own loaders with
# this batch size — TODO confirm in src.
model = plUnet(
    batch_size=8,
    train_dataset=train_ds,
    valid_dataset=valid_ds,
)

### Training

In [None]:
# Lightning trainer. Outputs are rooted at trained_models/unet/; EMA comes
# from src and 0.9999 is presumably its decay factor (TODO confirm).
trainer = pl.Trainer(
    accelerator='auto',
    max_epochs=1000,
    default_root_dir="trained_models/unet/",
    callbacks=[EMA(0.9999)],
)

In [None]:
# No dataloaders are passed: the model was constructed with its own datasets.
trainer.fit(model=model)

### Sampling

In [None]:
# Pick one test clip and run N_SAMPLES copies of its conditioning input
# through the trained model as a single batch.
song_num = 0
N_SAMPLES = 1
condition, target = test_ds[song_num]
batch_input = torch.stack(N_SAMPLES * [condition], 0).to(device)

model.to(device)
model.eval()  # inference: disable train-mode layer behaviour
with torch.no_grad():  # no autograd graph needed for sampling
    out = model(batch_input)

if test_ds.return_mask:
    # With return_mask the output is presumably a mask over the mixture.
    # Multiply by batch_input (already on `device`): `condition` lives on the
    # CPU and would raise a device-mismatch error when running on GPU.
    out = out * batch_input

In [None]:
# One row: the mixture input, each generated sample, then the ground truth.
panels = (
    [('Input', condition)]
    + [(f'Sample {i + 1}', sample.detach().cpu()) for i, sample in enumerate(out)]
    + [('Ground Truth', target)]
)
fig, axes = plt.subplots(1, len(panels))
for ax, (title, img) in zip(axes, panels):
    ax.imshow(img.permute(1, 2, 0))  # (C, H, W) -> (H, W, C)
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Write audio (using the phase returned by test_ds.get_phase for this clip)
# and spectrogram images for the generated output, target and input mixture.
phase = test_ds.get_phase(song_num)

name = test_ds.files[song_num]
produced = out[0].detach().cpu()
test_ds.save_audio(produced, phase, name='produced_' + name)
test_ds.save_audio(target, phase, name='target_' + name)
test_ds.save_audio(condition, phase, name='condition_' + name)

results_dir = Path('results')
results_dir.mkdir(parents=True, exist_ok=True)  # save_image does not create directories
# Swap only the extension: str.replace('wav', 'png') would also rewrite any
# 'wav' inside the file name and miss upper-case '.WAV'.
stem = Path(name).stem
save_image(produced, results_dir / f'produced_{stem}.png')
save_image(condition, results_dir / f'condition_{stem}.png')
save_image(target, results_dir / f'target_{stem}.png')