In [None]:
# To mount
from google.colab import drive
drive.mount('/content/drive')

# # To initialize the work environment
# %cd /content/drive/My Drive/
# !git clone https://github.com/allnightlight/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance/casestudies -b casestudies

# # To update the work environment
# %cd /content/drive/My Drive/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance
# !git pull

%cd /content/drive/My Drive/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance/casestudies

In [None]:
import os
import sys
sys.path.append("../framework/")
sys.path.append("../sl/")
sys.path.append("../wae/")

import torch
import numpy as np
import matplotlib 
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D  
import pandas as pd
import itertools

from conc_environment_factory import ConcEnvironmentFactory
from conc_build_parameter import ConcBuildParameter
from conc_build_parameter_factory import ConcBuildParameterFactory
from conc_agent_factory import ConcAgentFactory
from wae_trainer_factory import WaeTrainerFactory

from wae_batch_data_agent import WaeBatchDataAgent
from wae_batch_data_environment import WaeBatchDataEnvironment

from builder import Builder
from store import Store
from mylogger import MyLogger

from loader import Loader

In [None]:
# Training-log SQLite file, relative to the current working directory
# (the case-study directory selected by %cd); it is passed to Store below.
dbPath = "training_log.sqlite"

In [None]:
# Case-study prefix: used as the pattern `target_casestudy + "%"` when loading
# trained agents, and as the prefix of the output file `<target_casestudy>_score.csv`.
target_casestudy = "cs03a"

In [None]:
def evaluate_err(buildParameter, agent, environment_factory=None, trainer_factory=None):
    """Score a trained agent on the environment's test batch.

    Args:
        buildParameter: build parameter of the run; used to recreate the
            environment and the trainer for ``agent``.
        agent: trained agent; called on the test batch and expected to return
            a batch exposing ``_XHat``, ``_Xi`` and ``_XiHat``.
        environment_factory: object with ``create(buildParameter)``. Defaults to
            the notebook-level ``environmentFactory`` (looked up at call time,
            because that global is defined in a later cell).
        trainer_factory: object with ``create(buildParameter, agent, environment)``.
            Defaults to the notebook-level ``trainerFactory``.

    Returns:
        ``(err_observable, err_latent)`` as 0-d numpy arrays:
        the mean absolute error between ``_XHat`` and the observation ``_X``,
        and the first value returned by ``trainer.measure_distance`` applied to
        ``_Xi`` and ``_XiHat``.
    """
    if environment_factory is None:
        environment_factory = environmentFactory
    if trainer_factory is None:
        trainer_factory = trainerFactory

    environment = environment_factory.create(buildParameter)
    trainer = trainer_factory.create(buildParameter, agent, environment)

    # Pure evaluation: no autograd graph is needed, so don't build one.
    with torch.no_grad():
        dataBatchEnv = environment.getTestData()
        dataBatchAg = agent(dataBatchEnv)

        _err_observable = torch.mean(torch.abs(dataBatchAg._XHat - dataBatchEnv._X))
        _err_latent, _ = trainer.measure_distance(dataBatchAg._Xi, dataBatchAg._XiHat)

    # .cpu() makes the conversion work for GPU tensors too (.data.numpy() raises there).
    err_observable = _err_observable.detach().cpu().numpy()
    err_latent = _err_latent.detach().cpu().numpy()
    return err_observable, err_latent

## S400: Load trained agents to analyze them

### SS410: initialize a loader of trained agents

In [None]:
# Factories handed to the Loader below; evaluate_err also reads
# environmentFactory and trainerFactory from this cell as globals.
agentFactory = ConcAgentFactory()
environmentFactory = ConcEnvironmentFactory()
trainerFactory = WaeTrainerFactory()

# Store backed by the training-log file configured in dbPath.
store = Store(dbPath)

buildParameterFactory = ConcBuildParameterFactory()

# Loader whose load() is iterated later as (agent, buildParameter, epoch) tuples.
loader = Loader(
    agentFactory=agentFactory,
    environmentFactory=environmentFactory,
    buildParameterFactory=buildParameterFactory,
    store=store,
)

### SS420: evaluate trained agents:

Each trained agent is scored by two criteria:
* the representative error of the observable variables (mean absolute reconstruction error), and
* the discrepancy between the reference latent distribution and the distribution projected by the trained encoder.

In [None]:
# One record per (trained agent, criterion): the score plus every build parameter,
# so results can later be grouped by hyper-parameter.
records = []
for agent, buildParameter, epoch in loader.load(target_casestudy + "%"):
    representative_error, latent_distribution_discrepancy = evaluate_err(buildParameter, agent)

    for score, criteria in [
        (representative_error, "Representative Error"),
        (latent_distribution_discrepancy, "Latent Distributions Discrepancy"),
    ]:
        record = {"criteria": criteria, "score": score, "epoch": epoch}
        # Build parameters fill the remaining columns; a parameter that happens to be
        # named criteria/score/epoch must not overwrite the evaluation result.
        for key, value in buildParameter.__dict__.items():
            record.setdefault(key, value)
        records.append(record)

# Fixed column order: criteria, score, epoch, then parameters in first-seen order.
# Building from records (instead of per-column lists) means runs whose build
# parameters differ get NaN for missing keys instead of a length-mismatch error,
# and an empty result still has the three score columns.
columns = list(dict.fromkeys(
    ["criteria", "score", "epoch"] + [key for record in records for key in record]
))
tbl = pd.DataFrame(records, columns=columns)
tbl.to_csv(target_casestudy + "_score.csv")