In [1]:
import unittest
import ml_collections
import copy
import numpy as np
import torch
import wandb
import yaml
import os
import sys
from src.baselines.models import get_trainer
from unit_tests.test_utils import *
from src.utils import loader_to_tensor, set_seed
from src.datasets.dataloader import get_dataset
from src.baselines.networks.TimeVAE import VariationalAutoencoderConvInterpretable
from src.baselines.TimeVAE import TimeVAETrainer
from src.evaluations.loss import get_standard_test_metrics
from src.evaluations.evaluations import full_evaluation
from src.evaluations.summary import full_evaluation_latest
from unit_tests.test_utils import test_init
from src.utils import combine_dls
from src.utils import to_numpy
from src.evaluations.loss import CrossCorrelLoss
from src.evaluations.summary import EvaluationComponent, EvaluationSummary


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the run configuration with the test helper imported from unit_tests.test_utils.
config = test_init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [3]:
# Display the 'Evaluation' section of the config (per-metric settings and enabled metrics).
config['Evaluation']

TestMetrics:
  acf_loss:
    stationary: true
  cov_loss: None
  cross_corr: None
  discriminative_score:
    dscore_batch_size: 128
    dscore_epochs: 10
    dscore_hidden_size: 32
    dscore_num_layers: 1
  hist_loss:
    n_bins: 50
  permutation_test:
    n_permutation: 5
  predictive_score:
    pscore_batch_size: 128
    pscore_epochs: 10
    pscore_hidden_size: 32
    pscore_num_layers: 2
  sig_mmd:
    depth: 5
  sigw1_loss:
    depth: 2
batch_size: 256
metrics_enabled:
- cross_corr
n_eval: 5
sample_size: 1000
test_ratio: 0.2

## Full evaluation latest

#### Step 1: Load data and model

In [46]:
def fn(filename):
    """Return the path of `filename` inside `config.data_dir`."""
    return pt.join(config.data_dir, filename)


# Load the pre-trained model. Despite the file name, the loaded object is used
# as a full model below (it exposes `.encoder`, `.decoder` and `.eval()`); the
# encoder / decoder weights are then loaded from their own state-dict files.
# NOTE(review): torch.load unpickles its input -- only load files you trust.
vae = torch.load(fn('vae_model_state_dict.pt'))
encoder_state = torch.load(fn('vae_encoder_state_dict.pt'))
vae.encoder.load_state_dict(encoder_state, strict=True)
decoder_state = torch.load(fn('vae_decoder_state_dict.pt'))
vae.decoder.load_state_dict(decoder_state, strict=True)
vae.eval()

# Load the saved real train / test data from the same directory.
real_train_dl = torch.load(fn('X_train.pt'))
real_test_dl = torch.load(fn('X_test.pt'))

#### Step 2: Call `full_evaluation_latest`, which returns an `EvaluationSummary` result object

In [47]:
# Run the full evaluation for the loaded model on the real train / test data;
# the printed result below is an EvaluationSummary.
eval_summary = full_evaluation_latest(vae, real_train_dl, real_test_dl, config, algo='TimeVAE')       
print(eval_summary)

---- evaluation metric = cross_corr in group = stylized_fact_scores ----


100%|████████████████████████████████████████████| 5/5 [00:00<00:00, 15.23it/s]

 No metrics enabled in group = implicit_scores
 No metrics enabled in group = sig_scores
 No metrics enabled in group = permutation_test
EvaluationSummary(cross_corr_mean=0.22479782, cross_corr_std=0.03149462, hist_loss_mean=nan, hist_loss_std=nan, cov_loss_mean=nan, cov_loss_std=nan, acf_loss_mean=nan, acf_loss_std=nan, sigw1_mean=nan, sigw1_std=nan, sig_mmd_mean=nan, sig_mmd_std=nan, discriminative_score_mean=nan, discriminative_score_std=nan, predictive_score_mean=nan, predictive_score_std=nan, permutation_test_power=nan, permutation_test_type1_error=nan)





#### Zoom in on `full_evaluation_latest`

In [11]:
# Step 1: build the evaluation component and collect its summary (`summary_dict`).
ec = EvaluationComponent(config, vae, real_train_dl, real_test_dl, algo='TimeVAE')
summary_dict = ec.eval_summary()

# Step 2: copy those values into a fresh EvaluationSummary and display it.
eval_summary = EvaluationSummary()
eval_summary.set_values(summary_dict)
eval_summary

---- evaluation metric = cross_corr in group = stylized_fact_scores ----


100%|████████████████████████████████████████████| 5/5 [00:00<00:00, 28.78it/s]

 No metrics enabled in group = implicit_scores
 No metrics enabled in group = sig_scores
 No metrics enabled in group = permutation_test





EvaluationSummary(cross_corr_mean=0.24386013, cross_corr_std=0.020164223, hist_loss_mean=nan, hist_loss_std=nan, cov_loss_mean=nan, cov_loss_std=nan, acf_loss_mean=nan, acf_loss_std=nan, sigw1_mean=nan, sigw1_std=nan, sig_mmd_mean=nan, sig_mmd_std=nan, discriminative_score_mean=nan, discriminative_score_std=nan, predictive_score_mean=nan, predictive_score_std=nan, permutation_test_power=nan, permutation_test_type1_error=nan)

In [49]:
# Print each attribute of the summary next to its value, one per line.
for attr_name in eval_summary.get_attrs():
    attr_value = getattr(eval_summary, attr_name)
    print(attr_name, attr_value)

cross_corr_mean 0.22479782
cross_corr_std 0.03149462
hist_loss_mean nan
hist_loss_std nan
cov_loss_mean nan
cov_loss_std nan
acf_loss_mean nan
acf_loss_std nan
sigw1_mean nan
sigw1_std nan
sig_mmd_mean nan
sig_mmd_std nan
discriminative_score_mean nan
discriminative_score_std nan
predictive_score_mean nan
predictive_score_std nan
permutation_test_power nan
permutation_test_type1_error nan


## Metric Computation Details in `EvaluationComponent`

In [51]:
# Create the evaluation component from (config, model, train_data, test_data).
eval_comp = EvaluationComponent(config, vae, real_train_dl, real_test_dl, algo='TimeVAE')

# Take the first entry of the component's data sets and combine the train and
# test loaders (via combine_dls) for the real and the fake data respectively.
data_map = eval_comp.data_set[0]
print('data includes ', data_map.keys())
real_parts = [data_map['real_train_dl'], data_map['real_test_dl']]
fake_parts = [data_map['fake_train_dl'], data_map['fake_test_dl']]
real = combine_dls(real_parts)
fake = combine_dls(fake_parts)

#### Metrics

## Method 1: look the metric up by name on the evaluation component.
metric = 'cross_corr'
eval_func = getattr(eval_comp, metric)
score = eval_func(real, fake)
print(f'{metric} = {score}')


## Method 2: call the component's metric method directly.
scores2 = eval_comp.cross_corr(real, fake)
print(f'{metric} = {scores2}')

## Method 3: build the loss object from the real data and apply it to the fake data.
cross_corr_loss = CrossCorrelLoss(real, name='cross_corr')
scores3 = to_numpy(cross_corr_loss(fake))
print(f'{metric} = {scores3}')


data includes  dict_keys(['real_train_dl', 'real_test_dl', 'fake_train_dl', 'fake_test_dl'])
cross_corr = 0.25324398279190063
cross_corr = 0.25324398279190063
cross_corr = 0.25324398279190063


In [53]:
# Reference only: the CrossCorrelLoss source is quoted inside a string literal so
# that it is not executed in this notebook.
# NOTE(review): presumably copied from src.evaluations.loss -- verify it is still
# in sync with the real implementation.
'''
class CrossCorrelLoss(Loss):
    def __init__(self, x_real, max_lag=64, **kwargs):
        super(CrossCorrelLoss, self).__init__(norm_foo=cc_diff, **kwargs)
        self.lags = max_lag
        self.metric = CrossCorrelationMetric(self.transform)
        self.cross_correl_real = self.metric.measure(x_real,self.lags).mean(0)[0]
        self.max_lag = max_lag

    def compute(self, x_fake):
        cross_correl_fake = self.metric.measure(x_fake,lags=self.lags).mean(0)[0]
        loss = self.norm_foo(
            cross_correl_fake - self.cross_correl_real.to(x_fake.device)).unsqueeze(0)
        return loss
'''
# Disabled sketch: appears to recompute the cross-correlation difference by hand
# with CrossCorrelationMetric (lags=64) on `real` and `fake`. Left commented out.
# from src.evaluations.metrics import CrossCorrelationMetric
# m = CrossCorrelationMetric()
# cc_real = m.measure(real,lags=64).mean(0)[0]
# cc_fake = m.measure(fake,lags=64).mean(0)[0]
# cc = torch.abs((cc_fake - cc_real.to(cc_fake.device))).sum(0).unsqueeze(0)
# cc