In [None]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt

from discriminative_metrics import discriminative_score_metrics
from predictive_metrics import predictive_score_metrics
from context_fid import Context_FID
from cross_correlation import CrossCorrelLoss
from metric_utils import display_scores
from dtw import dtw_js_divergence_distance

### Load Real and Generated Data

In [None]:
# Paths to the arrays saved for this experiment: the real reference samples and the DDPM-generated samples.
experiment_name = "default_experiment"
real_data_path = f"../../outputs/{experiment_name}/real_samples.npy"
gen_data_path = f"../../outputs/{experiment_name}/ddpm_samples.npy"
real_data = np.load(real_data_path)
generated_data = np.load(gen_data_path)

# Later cells index samples as [idx, :, :] and apply the real data's per-channel
# min/max (reduced over axes 0 and 1) to the generated data. So both arrays must be
# 3-D and agree on the size of the last axis. Fail here with a clear message rather
# than deep inside a metric.
assert real_data.ndim == 3 and generated_data.ndim == 3, (
    f"expected 3-D arrays, got {real_data.shape} and {generated_data.shape}"
)
assert real_data.shape[-1] == generated_data.shape[-1], (
    f"real and generated data differ on the last axis: {real_data.shape} vs {generated_data.shape}"
)
real_data.shape, generated_data.shape

In [None]:
# Evaluate on equally sized sets: draw `num_samples` rows (without replacement)
# from each of the real and generated arrays.
SEED = 42  # fixes which rows are selected; the metric functions below may still use their own randomness
rng = np.random.default_rng(SEED)

num_samples = min(real_data.shape[0], generated_data.shape[0])
if real_data.shape[0] > num_samples:
    print(f"WARNING: Generated data only has {generated_data.shape[0]} samples, less than real data's {real_data.shape[0]} samples. Using all {num_samples} generated samples for evaluation.")
else:
    print(f"number of samples: {num_samples}")

random_indices = rng.choice(len(real_data), num_samples, replace=False)
real_data = real_data[random_indices]
random_indices = rng.choice(len(generated_data), num_samples, replace=False)
generated_data = generated_data[random_indices]

In [None]:
# Min-max scale both sets with the *real* data's statistics so they are compared on
# the same scale. The reduction is over axes (0, 1), giving one min/max per last-axis channel.
data_min = np.min(real_data, axis=(0, 1), keepdims=True)
data_max = np.max(real_data, axis=(0, 1), keepdims=True)

# A channel that is constant in the real data has zero range. Dividing by it would
# produce NaN/inf and poison every downstream metric, so use a range of 1 there instead.
data_range = data_max - data_min
data_range = np.where(data_range == 0, 1.0, data_range)

real_data = (real_data - data_min) / data_range
generated_data = (generated_data - data_min) / data_range

### Discriminative Score

In [None]:
# Repeat the discriminative metric several times and summarize the collected scores.
iterations = 5
discriminative_score = []

for i in range(iterations):
    # The metric returns a 3-tuple. Only the first element (the score) is collected;
    # the other two are unpacked as fake_acc / real_acc (presumably per-class accuracies).
    disc_outputs = discriminative_score_metrics(real_data, generated_data)
    temp_disc, fake_acc, real_acc = disc_outputs

    discriminative_score.append(temp_disc)
    print(f'Iter {i}: ', temp_disc, '\n')

display_scores(discriminative_score)
print()

### Predictive Score

In [None]:
# Repeat the predictive metric several times and summarize the scores with display_scores.
iterations = 5
predictive_score = []
for i in range(iterations):
    temp_pred = predictive_score_metrics(real_data, generated_data)
    predictive_score.append(temp_pred)
    # Each pass is one independent repetition of the metric, not a training epoch.
    # Label it the same way as the other metric cells.
    print(f'Iter {i}: ', temp_pred, '\n')

display_scores(predictive_score)
print()

### Context-FID Score

In [None]:
# Number of repeated Context-FID evaluations. It is defined here, like the other metric
# cells do, so this cell does not silently inherit whatever `iterations` an earlier cell set.
iterations = 5
context_fid_score = []

for i in range(iterations):
    context_fid = Context_FID(real_data, generated_data)
    context_fid_score.append(context_fid)
    print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')

display_scores(context_fid_score)

### Correlational Score

In [None]:
def random_choice(size, num_select=100):
    """Return `num_select` random integer indices drawn from [0, size).

    Draws are independent (with replacement) and come from NumPy's global
    RNG, so they follow any prior ``np.random.seed`` call.
    """
    return np.random.randint(0, size, size=num_select)

x_real = torch.from_numpy(real_data)
x_fake = torch.from_numpy(generated_data)

# Number of repeated evaluations. It is defined here so the cell does not depend on an
# earlier cell's value.
iterations = 5
# Samples drawn per evaluation. random_choice samples with replacement, so this
# may exceed the number of available samples.
size = 1000
correlational_score = []

for i in range(iterations):
    real_idx = random_choice(x_real.shape[0], size)
    fake_idx = random_choice(x_fake.shape[0], size)
    corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
    corr_value = corr.compute(x_fake[fake_idx, :, :]).item()
    correlational_score.append(corr_value)
    print(f'Iter {i}: ', 'cross-correlation =', corr_value, '\n')

display_scores(correlational_score)

### DTW Distance (JS Divergence)

In [None]:
# Repeat the DTW-based comparison and summarize the JS-divergence values it reports.
iterations = 5
js_results = []
for run in range(iterations):
    dtw_result = dtw_js_divergence_distance(real_data, generated_data, n_samples=100)
    js_dist = dtw_result['js_divergence']
    print("js_dist: ", round(js_dist, 4))
    js_results.append(js_dist)
display_scores(js_results)