In [None]:
import os
import sys
sys.path.append("../framework/")
sys.path.append("../sl/")
sys.path.append("../wae/")

import numpy as np
import matplotlib.pylab as plt
import pandas as pd

from conc_environment import ConcEnvironment
from conc_environment_factory import ConcEnvironmentFactory
from conc_build_parameter import ConcBuildParameter
from conc_build_parameter_factory import ConcBuildParameterFactory

from builder import Builder
from store import Store
from wae_agent_factory import WaeAgentFactory
from wae_trainer_factory import WaeTrainerFactory
from mylogger import MyLogger

from loader import Loader

In [None]:
# storeDbPath = "training_history.sqlite"
# if os.path.exists(storeDbPath):
#     os.remove(storeDbPath)

## S400: Load trained agents to analyze them

In [None]:
def count_error(agent, environment):
    """Count correct and incorrect cluster assignments per true class.

    For the train and the test data of ``environment`` the agent is applied
    to the data, and every sample is assigned to a cluster from the first
    coordinate of its estimated latent vector.  The assignment is then
    compared with the one-hot label ``Z``.

    Cluster centres are assumed to sit at ``j * agent.cluster_interval``
    (j = 0 .. nZ-1) along the first latent axis, so the decision boundaries
    are placed halfway between neighbouring centres.  For ``nZ == 2`` this
    is the single threshold ``cluster_interval / 2``; any other ``nZ`` is
    handled the same way instead of being hard-coded to two classes.

    Parameters
    ----------
    agent : callable
        Called as ``agent(data)``; the result must expose ``_XiHat`` and the
        agent must expose ``cluster_interval``.
    environment : object
        Must expose ``nZ``, ``getDataTrain()`` and ``getDataTest()``; the
        returned data must expose ``_Z`` (one-hot labels, shape (*, nZ)).

    Returns
    -------
    dict
        ``{"Train": cnt, "Test": cnt}`` where ``cnt[i]`` is the pair
        ``(number of Z[i] = 1 and Zhat[i] = 1,
           number of Z[i] = 1 and Zhat[i] = 0)`` for i in 0:nZ.
    """
    n_classes = environment.nZ

    # Boundary j separates cluster j from cluster j + 1.  With two classes
    # this is exactly [cluster_interval / 2].
    half_interval = agent.cluster_interval / 2
    boundaries = half_interval + agent.cluster_interval * np.arange(n_classes - 1)

    cnts = {}
    for segment, dataEnv in [
            ("Train", environment.getDataTrain())
            , ("Test", environment.getDataTest())
        ]:

        dataAgent = agent(dataEnv)

        XiHat = dataAgent._XiHat.data.numpy()
        Z = dataEnv._Z.data.numpy() # (*, nZ)

        first_axis = XiHat[:, 0]
        # Cluster index = number of boundaries at or below the sample.
        labels = np.sum(first_axis[:, None] >= boundaries[None, :], axis=1)

        Zhat = np.zeros((XiHat.shape[0], n_classes)) # (*, nZ)
        # Samples whose coordinate is NaN stay unassigned (all-zero row), so
        # they are counted as misses for their true class.
        rows = np.flatnonzero(~np.isnan(first_axis))
        Zhat[rows, labels[rows]] = 1

        cnt = []
        for k in range(n_classes):
            assigned = Zhat[Z[:, k] == 1, k]
            cnt.append((np.sum(assigned == 1), np.sum(assigned == 0)))

        cnts[segment] = cnt

    return cnts

In [None]:
def show_distribution(agent, environment):
    """Scatter-plot the estimated latent vectors, one subplot per true class.

    One figure is created for the train data and one for the test data.
    Subplot ``i`` shows the first two coordinates of the estimated latent
    vectors of the samples whose label satisfies ``Z[:, i] == 1``.  On every
    subplot a dashed circle of radius 2 is drawn around each cluster centre
    ``(agent.cluster_interval * j, 0)`` for j in 0:nZ.

    Classes 0 and 1 are drawn in red and blue; further classes use the
    matplotlib "CN" cycle colours, so ``nZ > 2`` no longer raises KeyError.

    Parameters
    ----------
    agent : callable
        Called as ``agent(data)``; the result must expose ``_XiHat`` and the
        agent must expose ``cluster_interval``.
    environment : object
        Must expose ``nZ``, ``getDataTrain()`` and ``getDataTest()``; the
        returned data must expose ``_Z`` (one-hot labels, shape (*, nZ)).

    Returns
    -------
    None
        The figures are left open in pyplot for the notebook to render.
    """
    n_classes = environment.nZ
    interval = agent.cluster_interval

    # The guide circle does not depend on the segment or the class, so it is
    # computed once here.
    theta = np.linspace(0, 1, 2**5) * np.pi * 2
    circle_x = np.cos(theta) * 2
    circle_y = np.sin(theta) * 2

    fixed_colors = ("r", "b")
    colors = [
        fixed_colors[k] if k < len(fixed_colors) else "C%d" % k
        for k in range(n_classes)
    ]

    for segment, dataEnv in [
            ("Train", environment.getDataTrain())
            , ("Test", environment.getDataTest())
        ]:

        dataAgent = agent(dataEnv)

        XiHat = dataAgent._XiHat.data.numpy()

        Z = dataEnv._Z.data.numpy() # (*, nZ)

        # NOTE(review): 2.57 looks like a cm-to-inch factor (2.54) - confirm.
        fig = plt.figure(figsize=[10/2.57, 10/2.57])
        for i in range(n_classes):
            XiGivenZ = XiHat[Z[:,i] == 1,:]

            ax = fig.add_subplot(n_classes, 1, i+1)
            ax.plot(XiGivenZ[:,0], XiGivenZ[:,1], '.', color=colors[i])
            for j in range(n_classes):
                ax.plot(circle_x + interval * j, circle_y, '--', color=colors[j])
            ax.grid()
            ax.set_xlim((-2, 2 + interval * (n_classes - 1)))
            ax.set_ylim((-2, 2))
            ax.set_ylabel('Z = %d@%s' % (i, segment))
        plt.tight_layout()

### SS410: define iterators of agents to be evaluated:

In [None]:
# Path handed to Store; presumably the SQLite file written during training.
storeDbPath = "training_history.sqlite"

# Collaborators the Loader needs to rebuild trained agents.
agentFactory = WaeAgentFactory()
environmentFactory = ConcEnvironmentFactory()
store = Store(storeDbPath)
buildParameterFactory = ConcBuildParameterFactory()

loader = Loader(
    agentFactory=agentFactory,
    environmentFactory=environmentFactory,
    buildParameterFactory=buildParameterFactory,
    store=store,
)

In [None]:
def genTrainedAgents(search_label="test case #1%"):
    """Yield every trained agent that ``loader`` finds for ``search_label``.

    Parameters
    ----------
    search_label : str, optional
        Label pattern passed unchanged to ``loader.load``.  The default
        keeps the previously hard-coded value; the trailing ``%`` is
        presumably a wildcard understood by the loader - TODO confirm.

    Yields
    ------
    tuple
        ``(agent, buildParameter, epoch)`` for each item produced by
        ``loader.load(search_label)``.
    """
    for agent, buildParameter, epoch in loader.load(search_label):
        yield agent, buildParameter, epoch

### SS420: check the discrimination error:

In [None]:
# One table row per (loaded agent, data segment).  Class 0 (z = (1,0)) plays
# the "negative" role and class 1 (z = (0,1)) the "positive" role.
tbl = {
    column: []
    for column in (
        "build.key", "build.label", "epoch", "segment",
        "TN", "FP", "FN", "TP",
    )
}

for agent, buildParameter, epoch in genTrainedAgents():
    environment = environmentFactory.create(buildParameter)
    cnts = count_error(agent, environment)

    for segment in ("Train", "Test"):
        # Each entry is (hits, misses) for the samples of that true class.
        class0 = cnts[segment][0]
        class1 = cnts[segment][1]
        row = {
            "build.key": buildParameter.key,
            "build.label": buildParameter.label,
            "epoch": epoch,
            "segment": segment,
            "TN": class0[0],  # z = (1,0) and zhat = (1,0)
            "FP": class0[1],  # z = (1,0) and zhat = (0,1)
            "FN": class1[1],  # z = (0,1) and zhat = (1,0)
            "TP": class1[0],  # z = (0,1) and zhat = (0,1)
        }
        for column, value in row.items():
            tbl[column].append(value)

pd.DataFrame(tbl)

### SS430: show the distribution on latent space:

In [None]:
# Plot the latent distribution of the first loaded agent only: the
# unconditional break leaves the loop after one iteration, and nothing is
# plotted when genTrainedAgents() yields no agent.
for agent, buildParameter, epoch in genTrainedAgents():
    environment = environmentFactory.create(buildParameter)
    show_distribution(agent, environment)
    break