In [None]:
# Mount Google Drive so that the project checkout stored there is reachable from Colab.
from google.colab import drive
drive.mount('/content/drive')

# # One-time setup: clone the repository into Drive (uncomment to run)
# %cd /content/drive/My Drive/
# !git clone https://github.com/allnightlight/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance/casestudies -b casestudies

# # Refresh an existing checkout (uncomment to run)
# %cd /content/drive/My Drive/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance
# !git pull

# Work from the casestudies directory: the relative paths used later in this
# notebook ("../framework/", "training_log.sqlite", "./img/") are resolved from here.
%cd /content/drive/My Drive/ConditionalWassersteinAutoencoderPoweredBySinkhornDistance/casestudies

In [None]:
import os
import sys
sys.path.append("../framework/")
sys.path.append("../sl/")
sys.path.append("../wae/")

import torch
import numpy as np
import matplotlib.pylab as plt
import pandas as pd

from conc_environment_factory import ConcEnvironmentFactory
from conc_build_parameter import ConcBuildParameter
from conc_build_parameter_factory import ConcBuildParameterFactory
from conc_agent_factory import ConcAgentFactory
from wae_trainer_factory import WaeTrainerFactory

from wae_batch_data_agent import WaeBatchDataAgent
from wae_batch_data_environment import WaeBatchDataEnvironment

from builder import Builder
from store import Store
from mylogger import MyLogger

from loader import Loader

In [None]:
# SQLite file name handed to Store(...) when the loader of trained agents is initialized.
dbPath = "training_log.sqlite"

In [None]:
# Case-study identifier: "%" is appended to it to form the pattern passed to loader.load(),
# and it is the prefix of the CSV files written by the evaluation cells.
target_casestudy = "cs01c"

In [None]:
def evaluate_err(buildParameter, agent):
    """Average the observable and latent errors of ``agent`` over one pass of batches.

    A fresh environment and trainer are created from ``buildParameter`` with the
    notebook-level ``environmentFactory`` and ``trainerFactory``.

    Args:
        buildParameter: build parameter handed to both factories.
        agent: callable that maps an environment batch to an agent batch
            exposing ``_Xi``, ``_XiHat`` and ``_XHat``.

    Returns:
        Tuple ``(err_observable, err_latent)``:
        ``err_observable`` is the mean over batches of the per-batch mean absolute
        difference between the agent's ``_XHat`` and the environment's ``_X``;
        ``err_latent`` is the mean over batches of the first value returned by
        ``trainer.measure_distance(_Xi, _XiHat)``.
    """
    env = environmentFactory.create(buildParameter)
    distance_meter = trainerFactory.create(buildParameter, agent, env)

    latent_errors = []
    observable_errors = []
    for env_batch in env.generateBatchDataIterator():
        agent_batch = agent(env_batch)

        # measure_distance returns a pair; only its first element is used here.
        latent_gap, _ = distance_meter.measure_distance(agent_batch._Xi, agent_batch._XiHat)
        observable_gap = (agent_batch._XHat - env_batch._X).abs().mean()

        latent_errors.append(latent_gap.data.numpy())
        observable_errors.append(observable_gap.data.numpy())

    return np.mean(observable_errors), np.mean(latent_errors)

In [None]:
import itertools

In [None]:
def get_pair_of_xi_and_xihat(buildParameter, agent, nBatches):
    """Collect ``_Xi`` and ``_XiHat`` from the first ``nBatches`` batches as NumPy arrays.

    Args:
        buildParameter: build parameter handed to the notebook-level factories.
        agent: callable that maps an environment batch to an agent batch
            exposing ``_Xi`` and ``_XiHat``.
        nBatches: maximum number of batches taken from the environment iterator.

    Returns:
        Tuple ``(Xi, XiHat)``: the per-batch arrays concatenated along axis 0.
    """
    env = environmentFactory.create(buildParameter)
    # The trainer is not used below; its creation is kept exactly as in the
    # original flow in case the factory call has side effects.
    trainer = trainerFactory.create(buildParameter, agent, env)

    xi_chunks = []
    xihat_chunks = []
    first_batches = itertools.islice(env.generateBatchDataIterator(), nBatches)
    for env_batch in first_batches:
        agent_batch = agent(env_batch)

        xi_chunks.append(agent_batch._Xi.data.numpy())
        xihat_chunks.append(agent_batch._XiHat.data.numpy())

    # Stack the batches along the sample axis: (*, nXi) each.
    return np.concatenate(xi_chunks, axis=0), np.concatenate(xihat_chunks, axis=0)

In [None]:
def plot_encoder_projection_image(agent, environment):
    """Draw a ring-shaped grid in observable space next to its image under the encoder.

    The plot is added to the current matplotlib figure (``plt.gcf()``) as two
    side-by-side subplots: the input polar grid on the left, and on the right the
    ``_XiHat`` values returned by ``agent`` for those points, overlaid on a
    reference lattice of the unit square.

    Args:
        agent: callable taking a ``WaeBatchDataEnvironment`` and returning a batch
            that exposes ``_XiHat``.
        environment: object providing ``d_in`` and ``d_out``; half of each is used
            as the inner and outer radius of the grid.
    """
    outer_diameter = environment.d_out
    inner_diameter = environment.d_in

    # Reference lattice on the unit square, shown in the latent-space panel.
    n_cols = 2**4
    n_rows = 2**4
    lattice_x, lattice_y = np.meshgrid(
        np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))  # each (n_rows, n_cols)
    Xi1 = np.stack((lattice_x, lattice_y), axis=2)  # (n_rows, n_cols, 2)
    Xi2 = np.stack((lattice_x.T, lattice_y.T), axis=2)  # (n_cols, n_rows, 2)

    # Polar grid between the inner and outer radii.
    n_radii = 2**2
    n_angles = 2**6
    radii = np.linspace(inner_diameter/2, outer_diameter/2, n_radii)  # (n_radii)
    angles = np.linspace(0, 2*np.pi, n_angles)  # (n_angles)
    R, Theta = np.meshgrid(radii, angles)  # (n_angles, n_radii)
    X1 = np.stack((R * np.cos(Theta), R * np.sin(Theta)), axis=2)  # (n_angles, n_radii, 2)
    X2 = np.stack((R.T * np.cos(Theta.T), R.T * np.sin(Theta.T)), axis=2)  # (n_radii, n_angles, 2)

    # Condition input: a column of ones, one row per grid point.
    condition = torch.ones(n_radii * n_angles).reshape(-1, 1)  # (*, nZ = 1)

    def encode(points, leading_shape):
        # Flatten the grid to (*, 2), run the agent and restore the grid layout.
        flat_points = torch.from_numpy(points.astype(np.float32).reshape(-1, 2))
        agent_batch = agent(WaeBatchDataEnvironment(flat_points, condition))
        return agent_batch._XiHat.data.numpy().reshape(*leading_shape, -1)

    XiHat1 = encode(X1, (n_angles, n_radii))  # (n_angles, n_radii, nXi)
    XiHat2 = encode(X2, (n_radii, n_angles))  # (n_radii, n_angles, nXi)

    light = dict(color = "lightgray", linewidth = 0.5)
    dark = dict(color = "gray", linewidth = 0.5)
    blues = 'Blues'

    fig = plt.gcf()

    # Left panel: the input grid in the observable space.
    ax_observable = fig.add_subplot(1,2,1)
    ax_observable.plot(X1[...,0], X1[...,1], '-', **light)
    ax_observable.plot(X2[...,0], X2[...,1], '-', **light)
    ax_observable.contourf(X1[...,0], X1[...,1], R, cmap = plt.get_cmap(blues))
    ax_observable.set_title('Input image \non the observable variables space')
    ax_observable.axis('off')
    ax_observable.set_aspect('equal', 'datalim')
    ax_observable.set_xlim(-1,1)
    ax_observable.set_ylim(-1,1)
    plt.tight_layout()

    # Right panel: reference lattice plus the encoded grid in the latent space.
    ax_latent = fig.add_subplot(1,2,2)
    ax_latent.plot(Xi1[...,0], Xi1[...,1], '-', **dark)
    ax_latent.plot(Xi2[...,0], Xi2[...,1], '-', **dark)
    ax_latent.plot(XiHat1[...,0], XiHat1[...,1], '-', **light)
    ax_latent.plot(XiHat2[...,0], XiHat2[...,1], '-', **light)
    ax_latent.contourf(XiHat1[...,0], XiHat1[...,1], R, cmap = plt.get_cmap(blues))
    ax_latent.set_title('Projected image\non the latent variables space')
    ax_latent.set_aspect('equal', 'datalim')
    ax_latent.set_xlim(-0.10,1.10)
    ax_latent.set_ylim(-0.10,1.10)
    ax_latent.axis('off')

    fig.tight_layout()

In [None]:
def plot_decoder_projection_image(agent, environment):
    """Draw a unit-square lattice in latent space next to its image under the decoder.

    The plot is added to the current matplotlib figure (``plt.gcf()``) as two
    side-by-side subplots: the input lattice on the left, and on the right the
    output of ``agent.dec`` for those points, overlaid on a reference polar grid.

    Args:
        agent: object whose ``dec`` method maps a ``(*, 2)`` float32 tensor to a
            tensor that can be reshaped to ``(nY, nX, 2)``.
        environment: object providing ``d_in`` and ``d_out``; half of each is used
            as the inner and outer radius of the reference grid.
    """
    outer_diameter = environment.d_out
    inner_diameter = environment.d_in

    # Reference polar grid between the inner and outer radii.
    n_radii = 2**2
    n_angles = 2**6
    radii = np.linspace(inner_diameter/2, outer_diameter/2, n_radii)  # (n_radii)
    angles = np.linspace(0, 2*np.pi, n_angles)  # (n_angles)
    R, Theta = np.meshgrid(radii, angles)  # (n_angles, n_radii)
    X1 = np.stack((R * np.cos(Theta), R * np.sin(Theta)), axis=2)  # (n_angles, n_radii, 2)
    X2 = np.stack((R.T * np.cos(Theta.T), R.T * np.sin(Theta.T)), axis=2)  # (n_radii, n_angles, 2)

    # Input lattice on the unit square of the latent space.
    n_cols = 2**4
    n_rows = 2**4
    lattice_x, lattice_y = np.meshgrid(
        np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))  # each (n_rows, n_cols)
    Xi1 = np.stack((lattice_x, lattice_y), axis=2)  # (n_rows, n_cols, 2)
    Xi2 = np.stack((lattice_x.T, lattice_y.T), axis=2)  # (n_cols, n_rows, 2)

    def decode(points, grid_shape):
        # Flatten the lattice to (*, 2), decode it and restore the grid layout.
        flat_points = torch.from_numpy(points.astype(np.float32).reshape(-1, 2))
        return agent.dec(flat_points).data.numpy().reshape(*grid_shape, 2)

    XHat1 = decode(Xi1, (n_rows, n_cols))  # (n_rows, n_cols, 2)
    XHat2 = decode(Xi2, (n_cols, n_rows))  # (n_cols, n_rows, 2)

    light = dict(color = "lightgray", linewidth = 0.5)
    dark = dict(color = "gray", linewidth = 0.5)
    reds = 'Reds'

    fig = plt.gcf()

    # Left panel: the input lattice in the latent space.
    ax_latent = fig.add_subplot(1,2,1)
    ax_latent.plot(Xi1[...,0], Xi1[...,1], '-', **light)
    ax_latent.plot(Xi2[...,0], Xi2[...,1], '-', **light)
    ax_latent.contourf(Xi1[...,0], Xi1[...,1], Xi1[..., 0], cmap = plt.get_cmap(reds))
    ax_latent.set_title('Input image \non the latent variables space')
    ax_latent.axis('off')
    ax_latent.set_aspect('equal', 'datalim')
    ax_latent.set_xlim(0,1)
    ax_latent.set_ylim(0,1)
    plt.tight_layout()

    # Right panel: reference polar grid plus the decoded lattice.
    ax_observable = fig.add_subplot(1,2,2)
    ax_observable.plot(X1[...,0], X1[...,1], '-', **dark)
    ax_observable.plot(X2[...,0], X2[...,1], '-', **dark)

    ax_observable.plot(XHat1[...,0], XHat1[...,1], '-', **light)
    ax_observable.plot(XHat2[...,0], XHat2[...,1], '-', **light)
    ax_observable.contourf(XHat1[...,0], XHat1[...,1], Xi1[..., 0], cmap = plt.get_cmap(reds))

    ax_observable.set_title('Projected image \non the observable variables space')
    ax_observable.set_aspect('equal', 'datalim')
    ax_observable.set_xlim(-1.10,1.10)
    ax_observable.set_ylim(-1.10,1.10)
    ax_observable.axis('off')

    fig.tight_layout()

## S400: Load trained agents to analyze them

### SS410: initialize a loader of trained agents

In [None]:
# Factories; environmentFactory and trainerFactory are also read as globals by
# the helper functions evaluate_err and get_pair_of_xi_and_xihat.
agentFactory = ConcAgentFactory()
environmentFactory = ConcEnvironmentFactory()
trainerFactory = WaeTrainerFactory()

# Store built on the file named by dbPath.
store = Store(dbPath)

# The loader yields (agent, buildParameter, epoch) triples in the cells below.
buildParameterFactory = ConcBuildParameterFactory()
loader = Loader(
    agentFactory=agentFactory,
    environmentFactory=environmentFactory,
    buildParameterFactory=buildParameterFactory,
    store=store,
)

### SS420: evaluate trained agents:

The following evaluation errors are computed for each trained agent:
* the representative error of the observable variables
* the discrepancy between the reference latent distribution and the one projected by the trained encoder

In [None]:
# Evaluate every agent returned by the loader for this case study and collect
# one row per (build parameter, epoch): the two error measures, the epoch and
# all attributes of the build parameter.
tbl = {
    "representative_error": []
    , "latent_distribution_discrepancy": []
    , "epoch": []
      }
for agent, buildParameter, epoch in loader.load(target_casestudy + "%"):

    # Copy every attribute of the build parameter into its own column.
    for key, value in vars(buildParameter).items():
        tbl.setdefault(key, []).append(value)

    err_observable, err_latent = evaluate_err(buildParameter, agent)

    tbl["epoch"].append(epoch)
    tbl["representative_error"].append(err_observable)
    tbl["latent_distribution_discrepancy"].append(err_latent)

# Fail with a readable message instead of an obscure idxmin() error further down.
if not tbl["epoch"]:
    raise ValueError(
        "no trained agent found for case study %r in %r" % (target_casestudy, dbPath))

tbl = pd.DataFrame(tbl)
tbl.to_csv(target_casestudy +  "_error.csv")

# idxmin() returns an index *label*, so the row must be selected with the
# label-based .loc (not the position-based .iloc); selecting row and column in
# one call also avoids chained indexing.
buildParameterBest = tbl.loc[tbl["latent_distribution_discrepancy"].idxmin(), "key"]

In [None]:
# Dump pairs of latent samples: for every agent whose epoch equals its build
# parameter's nEpoch, write one row per sample of Xi and of XiHat, together with
# all attributes of the build parameter.
tbl = {"var": [], "x": [], "y": []}
for agent, buildParameter, epoch in loader.load(target_casestudy + "%"):

    # Skip every checkpoint except the one at epoch == nEpoch.
    if epoch != buildParameter.nEpoch:
        continue

    Xi, XiHat = get_pair_of_xi_and_xihat(buildParameter, agent, nBatches = 8)

    for label, var in (("Xi", Xi), ("XiHat", XiHat)):
        # Each row of var is unpacked into two coordinates.
        for x, y in var:

            for key, value in vars(buildParameter).items():
                tbl.setdefault(key, []).append(value)

            tbl["var"].append(label)
            tbl["x"].append(x)
            tbl["y"].append(y)

tbl = pd.DataFrame(tbl)
tbl.to_csv(target_casestudy +  "_xi_and_xihat.csv")

### SS430: plot encoder's projection

In [None]:
# Save the encoder projection of the best build parameter (selected above as
# buildParameterBest), using the agent whose epoch equals nEpoch.

# savefig() does not create missing directories, so make sure ./img exists.
os.makedirs("./img", exist_ok=True)

for agent, buildParameter, epoch in loader.load(target_casestudy + "%", buildParameterKey=buildParameterBest):

    if epoch == buildParameter.nEpoch:
        environment = environmentFactory.create(buildParameter)
        # The plotting helper draws on the current figure, i.e. the one created here.
        fig = plt.figure(figsize=[8, 6])
        try:
            plot_encoder_projection_image(agent, environment)
            fig.savefig("./img/encoder_projection_%s.png" % buildParameter.key)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

### SS440: plot decoder's projection

In [None]:
# Save the decoder projection of the best build parameter (selected above as
# buildParameterBest), using the agent whose epoch equals nEpoch.

# savefig() does not create missing directories, so make sure ./img exists.
os.makedirs("./img", exist_ok=True)

for agent, buildParameter, epoch in loader.load(target_casestudy + "%", buildParameterKey=buildParameterBest):

    if epoch == buildParameter.nEpoch:
        environment = environmentFactory.create(buildParameter)
        # The plotting helper draws on the current figure, i.e. the one created here.
        fig = plt.figure(figsize=[8, 6])
        try:
            plot_decoder_projection_image(agent, environment)
            # NOTE(review): the "deccoder" spelling is kept so that existing
            # references to these image files keep working.
            fig.savefig("./img/deccoder_projection_%s.png" % buildParameter.key)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)