In [None]:
# Mount Google Drive so that the repository kept under "My Drive" is visible to this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# # To initialize the work environment (first run only): clone the `casestudies` branch into My Drive.
# # NOTE(review): the URL previously ended in "/casestudies"; a GitHub clone URL names only the
# # repository, and the branch is already selected by `-b casestudies`.
# %cd /content/drive/My Drive/
# !git clone https://github.com/allnightlight/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance -b casestudies

# # To update an existing work environment
# %cd /content/drive/My Drive/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance
# !git pull

# Work from the casestudies directory: the relative paths used later in this notebook
# (../framework, ../sl, ../wae, training_log.sqlite, ./img) are resolved against it.
%cd /content/drive/My Drive/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance/casestudies

In [None]:
import os
import sys
sys.path.append("../framework/")
sys.path.append("../sl/")
sys.path.append("../wae/")

import torch
import numpy as np
import matplotlib 
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D  
import pandas as pd
import itertools

from conc_environment_factory import ConcEnvironmentFactory
from conc_build_parameter import ConcBuildParameter
from conc_build_parameter_factory import ConcBuildParameterFactory
from conc_agent_factory import ConcAgentFactory
from wae_trainer_factory import WaeTrainerFactory

from wae_batch_data_agent import WaeBatchDataAgent
from wae_batch_data_environment import WaeBatchDataEnvironment

from builder import Builder
from store import Store
from mylogger import MyLogger

from loader import Loader

In [None]:
# Path of the SQLite training-log database opened by `Store` below
# (relative to the current working directory set by the first cell).
dbPath = "training_log.sqlite"

In [None]:
# Case-study prefix. It is passed to loader.load as `target_casestudy + "%"`
# (presumably an SQL LIKE pattern matching every build of this case study — confirm in Loader),
# and it is also used as the file-name prefix of the score CSV and the saved images.
target_casestudy = "cs03b"

In [None]:
def evaluate_err(buildParameter, agent, environment_factory=None, trainer_factory=None):
    """Score a trained agent on its environment's test data.

    Parameters
    ----------
    buildParameter : build parameter the agent was trained with; used to
        recreate the matching environment and trainer.
    agent : callable mapping an environment data batch to an agent data batch
        that exposes ``_XHat``, ``_Xi`` and ``_XiHat`` tensors.
    environment_factory, trainer_factory : optional factories. They default to
        the notebook-level ``environmentFactory`` / ``trainerFactory``, so
        existing calls are unchanged, but they can be supplied explicitly.

    Returns
    -------
    (err_observable, err_latent) : numpy arrays
        ``err_observable`` is the mean absolute error between ``_XHat`` and the
        test data ``_X``. ``err_latent`` is the first value returned by
        ``trainer.measure_distance(_Xi, _XiHat)``.
    """
    if environment_factory is None:
        environment_factory = environmentFactory
    if trainer_factory is None:
        trainer_factory = trainerFactory

    environment = environment_factory.create(buildParameter)
    trainer = trainer_factory.create(buildParameter, agent, environment)

    # Pure evaluation: no autograd graph is needed, which saves memory when many
    # agents are scored in a loop.
    # NOTE(review): the agent is not switched to eval mode here; if it contains
    # dropout / batch-norm layers, confirm that is intended.
    with torch.no_grad():
        dataBatchEnv = environment.getTestData()
        dataBatchAg = agent(dataBatchEnv)

        _err_observable = torch.mean(torch.abs(dataBatchAg._XHat - dataBatchEnv._X))
        _err_latent, _ = trainer.measure_distance(dataBatchAg._Xi, dataBatchAg._XiHat)

    # detach().cpu() makes the numpy conversion work regardless of grad state or device.
    err_observable = _err_observable.detach().cpu().numpy()
    err_latent = _err_latent.detach().cpu().numpy()
    return err_observable, err_latent

In [None]:
def approximate_observable_image(th):
    """Map curve parameter(s) ``th`` onto a closed 3-D curve in observable space.

    The curve is the (2, 3) torus knot (trefoil). Presumably it approximates the
    environment's observable data, but that is not confirmed here:

        x = (2 + cos 3th) cos 2th,  y = (2 + cos 3th) sin 2th,  z = sin 3th

    Parameters
    ----------
    th : array_like of any shape (...).

    Returns
    -------
    ndarray of shape (..., 3) with the (x, y, z) coordinates on the last axis.
    """
    radius = 2 + np.cos(3*th)
    coords = (
        radius * np.cos(2*th),
        radius * np.sin(2*th),
        np.sin(3*th),
    )
    return np.stack(coords, axis=-1)

def approximate_latent_image(th):
    """Reference latent curve: the unit circle in the (xi1, xi2) plane.

    Parameters
    ----------
    th : scalar or array_like of angles in radians, any shape (...).

    Returns
    -------
    ndarray of shape (..., 3) holding (cos th, sin th, 0) on the last axis.
    """
    xi1 = np.cos(th)
    xi2 = np.sin(th)
    # np.shape also accepts Python scalars and lists, which have no `.shape`
    # attribute; this matches approximate_observable_image, which accepts them too.
    xi3 = np.zeros(np.shape(th))

    Xi = np.stack((xi1, xi2, xi3), axis=-1)  # (..., nXi = 3)

    return Xi  # (..., 3)

In [None]:
def plot_encoder_projection_image(agent, environment, azim, fig=None, n_points=2**10):
    """Draw the reference observable curve and its encoder projection side by side.

    Left panel: the closed curve from approximate_observable_image (blue).
    Right panel: the agent's ``_XiHat`` output for that curve (blue), drawn
    over the reference latent circle from approximate_latent_image (light grey).

    Parameters
    ----------
    agent : callable taking a WaeBatchDataEnvironment and returning a batch
        that exposes ``_XiHat``.
    environment : only ``environment.nZ`` is read; it sizes the all-zero
        second input of the batch.
    azim : azimuth in degrees passed to ``view_init``; elevation is fixed at 30.
    fig : figure to draw into; defaults to ``plt.gcf()`` (previous behaviour).
    n_points : number of samples of the curve parameter over [0, 2*pi].

    Returns
    -------
    The figure that was drawn into.
    """
    if fig is None:
        fig = plt.gcf()

    Th = np.linspace(0, 1, n_points) * np.pi * 2
    XYZ = approximate_observable_image(Th)  # (n_points, 3)
    Xi = approximate_latent_image(Th)  # (n_points, 3)

    _XYZ = torch.from_numpy(XYZ.astype(np.float32))  # (n_points, 3)
    # Second batch input held at zero (presumably the conditioning variable — confirm).
    _Z = torch.zeros(n_points, environment.nZ)

    dataBatchEnv = WaeBatchDataEnvironment(_XYZ, _Z)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        dataBatchAgent = agent(dataBatchEnv)
    # Assumes 3 latent dimensions, as plotted below.
    XiHat = dataBatchAgent._XiHat.detach().cpu().numpy()  # (n_points, 3)

    # Left: encoder input on the observable space.
    ax = fig.add_subplot(1, 2, 1, projection="3d")
    ax.plot(XYZ[..., 0], XYZ[..., 1], XYZ[..., 2], color="blue", linewidth=1.0)
    ax.set_title('Input image \non the observable variables space')
    ax.axis('off')
    ax.view_init(30, azim)
    ax.set_xlim(-3.10, 3.10)
    ax.set_ylim(-3.10, 3.10)
    ax.set_zlim(-1.10, 1.10)

    # Right: encoder output over the reference latent circle.
    ax = fig.add_subplot(1, 2, 2, projection="3d")
    ax.plot(Xi[..., 0], Xi[..., 1], Xi[..., 2], color='lightgrey', linewidth=.5)
    ax.plot(XiHat[..., 0], XiHat[..., 1], XiHat[..., 2], color="blue", linewidth=1.0)
    ax.set_title('Output image \non the latent variables space')
    ax.axis('off')
    ax.view_init(30, azim)
    ax.set_xlim(-1.10, 1.10)
    ax.set_ylim(-1.10, 1.10)
    ax.set_zlim(-1.10, 1.10)

    fig.tight_layout()
    return fig

In [None]:
def plot_decoder_projection_image(agent, environment, azim, fig=None, n_points=2**10):
    """Draw the reference latent circle and its decoder image side by side.

    Left panel: the latent curve from approximate_latent_image (red).
    Right panel: ``agent.dec`` applied to that curve (red), drawn over the
    reference observable curve from approximate_observable_image (light grey).

    Parameters
    ----------
    agent : object with a ``dec`` callable mapping an (n_points, 3) float32
        tensor to a tensor whose first three columns are plotted.
    environment : not read; kept so both plot_*_projection_image helpers share
        the same call signature.
    azim : azimuth in degrees passed to ``view_init``; elevation is fixed at 30.
    fig : figure to draw into; defaults to ``plt.gcf()`` (previous behaviour).
    n_points : number of samples of the curve parameter over [0, 2*pi].

    Returns
    -------
    The figure that was drawn into.
    """
    if fig is None:
        fig = plt.gcf()

    Th = np.linspace(0, 1, n_points) * np.pi * 2
    XYZ = approximate_observable_image(Th)  # (n_points, 3)
    Xi = approximate_latent_image(Th)  # (n_points, 3)

    _Xi = torch.from_numpy(Xi.astype(np.float32))  # (n_points, 3)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        _XYZHat = agent.dec(_Xi)  # (n_points, 3)
    XYZHat = _XYZHat.detach().cpu().numpy()  # (n_points, 3)

    # Left: decoder input on the latent space.
    ax = fig.add_subplot(1, 2, 1, projection="3d")
    ax.plot(Xi[..., 0], Xi[..., 1], Xi[..., 2], color="red", linewidth=1.)
    ax.set_title('Input image \non the latent variables space')
    ax.axis('off')
    ax.view_init(30, azim)
    ax.set_xlim(-1.10, 1.10)
    ax.set_ylim(-1.10, 1.10)
    ax.set_zlim(-1.10, 1.10)

    # Right: decoder output over the reference observable curve.
    ax = fig.add_subplot(1, 2, 2, projection="3d")
    ax.plot(XYZ[..., 0], XYZ[..., 1], XYZ[..., 2], color='lightgray', linewidth=1)
    ax.plot(XYZHat[..., 0], XYZHat[..., 1], XYZHat[..., 2], color="red", linewidth=1)
    ax.set_title('Output image \non the observable variables space')
    ax.axis('off')
    ax.view_init(30, azim)
    ax.set_xlim(-3.10, 3.10)
    ax.set_ylim(-3.10, 3.10)
    ax.set_zlim(-1.10, 1.10)

    fig.tight_layout()
    return fig

## S400: Load trained agents to analyze them

### SS410: Initialize a loader of trained agents

In [None]:
# Factories used to rebuild agents, environments and trainers from stored build parameters.
agentFactory = ConcAgentFactory()
environmentFactory = ConcEnvironmentFactory()
trainerFactory = WaeTrainerFactory()

# Store opened on the training-log database file named by dbPath.
store = Store(dbPath)

# Loader that yields (agent, buildParameter, epoch) tuples from the store.
buildParameterFactory = ConcBuildParameterFactory()
loader = Loader(
    agentFactory=agentFactory,
    environmentFactory=environmentFactory,
    buildParameterFactory=buildParameterFactory,
    store=store,
)

### SS420: Evaluate trained agents

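Two evaluation criteria are computed for every agent loaded for the case study:
* the representative error of the observable variables (the mean absolute reconstruction error on the test data), and
* the discrepancy between the reference latent distribution and the one projected by the trained encoder.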

In [None]:
# Build one row per (agent, epoch, criterion); every build parameter becomes an extra column.
# Rows are collected as records, so builds whose parameter sets differ still give aligned
# columns (missing values become NaN) instead of lists of unequal length.
records = []
for agent, buildParameter, epoch in loader.load(target_casestudy + "%"):
    representative_error, latent_distribution_discrepancy = evaluate_err(buildParameter, agent)

    for score, criteria in [
        (representative_error, "Representative Error"),
        (latent_distribution_discrepancy, "Latent Distributions Discrepancy"),
    ]:
        # float() gives a numeric column; assumes both scores are scalars
        # (the latent one comes from trainer.measure_distance — TODO confirm it is scalar).
        record = {"criteria": criteria, "score": float(score), "epoch": epoch}
        for key, value in vars(buildParameter).items():
            # setdefault: a build parameter named like a field above must not overwrite it.
            record.setdefault(key, value)
        records.append(record)

# Keep the fixed columns (and a CSV header) even when nothing matched the case study.
tbl = pd.DataFrame(records) if records else pd.DataFrame(columns=["criteria", "score", "epoch"])
tbl.to_csv(target_casestudy + "_score.csv")

### SS430: Save the encoder's projection images

In [None]:
# Save the encoder projection images of the first agent found at its final epoch,
# one PNG per azimuth (0, 30, ..., 330 degrees).
os.makedirs("./img", exist_ok=True)  # savefig does not create missing directories
for agent, buildParameter, epoch in loader.load(target_casestudy + "%", buildParameterKey=None):

    if epoch == buildParameter.nEpoch:
        # The environment depends only on buildParameter, so create it once, not per azimuth.
        environment = environmentFactory.create(buildParameter)
        for azim in np.arange(0, 360, 30):
            fig = plt.figure(figsize=[12, 6])
            try:
                plot_encoder_projection_image(agent, environment, azim)
                fig.savefig("./img/%s_encoder_projection_%s_azim=%03d.png" % (target_casestudy, buildParameter.key, azim))
            finally:
                plt.close(fig)  # release the figure even if plotting or saving fails
        # NOTE(review): stops after the first final-epoch agent; remove to render every build.
        break

### SS440: Save the decoder's projection images

In [None]:
# Save the decoder projection images of the first agent found at its final epoch,
# one PNG per azimuth (0, 30, ..., 330 degrees).
os.makedirs("./img", exist_ok=True)  # savefig does not create missing directories
for agent, buildParameter, epoch in loader.load(target_casestudy + "%", buildParameterKey=None):

    if epoch == buildParameter.nEpoch:
        # The environment depends only on buildParameter, so create it once, not per azimuth.
        environment = environmentFactory.create(buildParameter)
        for azim in np.arange(0, 360, 30):
            fig = plt.figure(figsize=[12, 6])
            try:
                plot_decoder_projection_image(agent, environment, azim)
                fig.savefig("./img/%s_decoder_projection_%s_azim=%03d.png" % (target_casestudy, buildParameter.key, azim))
            finally:
                plt.close(fig)  # release the figure even if plotting or saving fails
        # NOTE(review): stops after the first final-epoch agent; remove to render every build.
        break