In [None]:
%%capture
%pip install einops pytorch_lightning diffusers==0.12.1 kornia librosa accelerate ipympl nussl pandas==1.5.2 accelerate

In [None]:
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import nussl
import nussl.evaluation as ne
import torch

from src import *

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

### Dataset

In [None]:
# Root folder that holds the train / val / test WAV splits.
wav_root = 'datasets/randomMIDI/PianoViolin11025/WAV'


def _make_split(split):
    """Build the paired dataset for one split folder.

    The `ins3` sub-folder is passed as the target directory and the `mix`
    sub-folder as the condition directory; `return_pair=True` is forwarded
    unchanged to `SpectrogramDataset`.
    """
    return SpectrogramDataset(
        target_dir=f'{wav_root}/{split}/ins3',
        condition_dir=f'{wav_root}/{split}/mix',
        return_pair=True,
    )


# Note: the validation data lives in the folder named 'val'.
train_ds, valid_ds, test_ds = (_make_split(split) for split in ('train', 'val', 'test'))

# Preview the first test item: condition on the left, target on the right.
condition, target = test_ds[0]

# permute(1, 2, 0) moves the first axis to the end before plotting
# (presumably channel-first -> channel-last for imshow; confirm against the dataset).
plt.subplot(1, 2, 1)
plt.imshow(condition.permute(1, 2, 0))
plt.subplot(1, 2, 2)
plt.imshow(target.permute(1, 2, 0))

### SSDM

In [None]:
# SSDM setup: every split yields single-channel items with masks disabled.
# This cell overwrites `model` and `model_path`, which later cells read.
for split_ds in (train_ds, valid_ds, test_ds):
    split_ds.out_channels = 1
    split_ds.return_mask = False

model_path = 'trained_models/ssdm/lightning_logs/version_0/'

# The test split is what gets handed over as `train_dataset` when restoring the checkpoint.
model = PixelDiffusionConditional.load_from_checkpoint(
    model_path + 'checkpoints/epoch=1999-step=126620.ckpt',
    train_dataset=test_ds,
).to(device)

### SLDM

In [None]:
# SLDM setup: every split yields three-channel items with masks disabled.
# This cell overwrites `model` and `model_path`, which later cells read.
for split_ds in (train_ds, valid_ds, test_ds):
    split_ds.out_channels = 3
    split_ds.return_mask = False

# Pretrained VAE wrapped in the project's Autoencoder module and moved to the device.
autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
pl_ae_model = Autoencoder(autoencoder).to(device)

model_path = 'trained_models/sldm/lightning_logs/version_0/'

# The test split is what gets handed over as `train_dataset` when restoring the checkpoint.
model = LatentDiffusionConditional.load_from_checkpoint(
    model_path + 'checkpoints/epoch=1999-step=126000.ckpt',
    train_dataset=test_ds,
    autoencoder=pl_ae_model,
).to(device)

### MLDM

In [None]:
# MLDM setup: every split yields three-channel items with masks enabled.
# This cell overwrites `model` and `model_path`, which later cells read.
for split_ds in (train_ds, valid_ds, test_ds):
    split_ds.out_channels = 3
    split_ds.return_mask = True

# Pretrained VAE wrapped in the project's Autoencoder module and moved to the device.
autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
pl_ae_model = Autoencoder(autoencoder).to(device)

model_path = 'trained_models/mldm/lightning_logs/version_0/'

# The test split is what gets handed over as `train_dataset` when restoring the checkpoint.
model = LatentDiffusionConditional.load_from_checkpoint(
    model_path + 'checkpoints/epoch=1999-step=126000.ckpt',
    train_dataset=test_ds,
    autoencoder=pl_ae_model,
).to(device)

### Simple UNet

In [None]:
# Simple UNet setup: every split yields single-channel items with masks enabled.
# This cell overwrites `model` and `model_path`, which later cells read.
for split_ds in (train_ds, valid_ds, test_ds):
    split_ds.out_channels = 1
    split_ds.return_mask = True

model_path = 'trained_models/unet/lightning_logs/version_0/'

# The test split is what gets handed over as `train_dataset` when restoring the checkpoint.
model = plUnet.load_from_checkpoint(
    model_path + 'checkpoints/epoch=1999-step=126000.ckpt',
    train_dataset=test_ds,
).to(device)

### Evaluation

In [None]:
# Score the currently loaded model on every item of `ds` and write one JSON
# score file per item under <model_path><scores_folder>.
# `model` and `model_path` come from whichever model cell was run last.
ds = test_ds
scores_folder = 'scores/test/'
true_source_folder = 'datasets/randomMIDI/PianoViolin11025/WAV/test/'
SAMPLE_RATE = 11025  # sample rate used for every AudioSignal below

# The output directory does not depend on the item, so create it once up front.
scores_dir = model_path + scores_folder
os.makedirs(scores_dir, exist_ok=True)

for i in range(len(ds)):
    # `target` is the ground truth and must not take part in the estimate.
    condition, target = ds[i]

    # Inference only: skip autograd bookkeeping while running the model.
    with torch.no_grad():
        out = model(condition.to(device).unsqueeze(0), verbose=True).detach().cpu()[0]
    phase = ds.get_phase(i)
    name = ds.files[i]

    if ds.return_mask:
        # In mask mode the prediction is multiplied with the condition.
        # The earlier code multiplied the ground-truth `target` instead, which
        # threw the model output away and scored the ground truth.
        out = out * condition

    estimated_source = ds.to_audio(out, phase)
    estimated_source = nussl.AudioSignal(audio_data_array=estimated_source, sample_rate=SAMPLE_RATE)

    # Trim the references to the estimate's duration so the signals line up.
    duration = estimated_source.signal_duration
    target_source = nussl.AudioSignal(
        true_source_folder + 'ins3/' + name, sample_rate=SAMPLE_RATE
    ).truncate_seconds(duration)
    mix = nussl.AudioSignal(
        true_source_folder + 'mix/' + name, sample_rate=SAMPLE_RATE
    ).truncate_seconds(duration)

    # Everything in the mixture that is not the target / the estimate.
    target_rest = mix - target_source
    estimated_rest = mix - estimated_source

    estimates = [estimated_source, estimated_rest]
    targets = [target_source, target_rest]

    evaluator = ne.BSSEvalScale(targets, estimates, ['ins3', 'rest'])
    scores = evaluator.evaluate()

    # Swap only the file extension; str.replace('wav', 'json') would also
    # rewrite a "wav" that appears elsewhere in the file name.
    output_file = os.path.join(scores_dir, os.path.splitext(name)[0] + '.json')
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)

In [None]:
# Gather every per-item score file stored under <model_path>scores/test/.
folder = 'scores/test/'
score_pattern = f'{model_path}{folder}*.json'
json_files = glob.glob(score_pattern)

# Aggregate the files with a NaN-ignoring median, then build the report card
# (report_each_source=True is forwarded to nussl's report_card).
df = ne.aggregate_score_files(json_files, aggregator=np.nanmedian)
report_card = ne.report_card(df, report_each_source=True)
print(report_card)

# Keep a copy of the report next to the model's logs.
report_path = model_path + 'report_card.json'
with open(report_path, 'w') as f:
    json.dump(report_card, f, indent=4)