In [None]:
import os
import sys
sys.path.append("../framework/")
sys.path.append("../sl/")
sys.path.append("../wae/")

import numpy as np
import matplotlib.pylab as plt
from wae_environment import WaeEnvironment
from wae_environment_factory import WaeEnvironmentFactory
from wae_build_parameter import WaeBuildParameter
from wae_build_parameter_factory import WaeBuildParameterFactory
from builder import Builder
from store import Store
from wae_agent_factory import WaeAgentFactory
from wae_trainer_factory import WaeTrainerFactory
from mylogger import MyLogger

from loader import Loader

In [None]:
dbPath = "testDb.sqlite"
# Start every run from an empty store: delete any database left over from a previous run.
try:
    os.remove(dbPath)
except FileNotFoundError:
    pass  # nothing to clean up on a first run

## S100: Define the concrete environment to be modeled by deriving from WaeEnvironment.

### SS110: define ConcEnv

In [None]:
class ConcEnv(WaeEnvironment):
    """Concrete toy environment: two Gaussian clusters in the observed space with one-hot labels.

    Class attributes:
        nX: dimension of the observed variable X.
        nZ: width of the one-hot label Z. loadData fills columns 0 and 1 only,
            so it requires nZ >= 2.
    """
    nX = 2
    nZ = 2

    def loadData(self):
        """Generate the dataset and store it on ``self.dataX`` / ``self.dataZ``.

        ``self.dataX`` has shape (2**10, nX): the first half of the rows is drawn from a
        standard normal shifted by +2 in every coordinate, the second half from one shifted
        by -2. ``self.dataZ`` has shape (2**10, nZ) and one-hot encodes which half each row
        came from (column 0 -> first half, column 1 -> second half).
        """
        nSample = 2**10
        nHalf = nSample // 2
        # Use the class attributes rather than re-declaring local copies, so the sizes the
        # rest of the notebook reads (e.g. environment.nZ) always match the generated data.
        nX = self.nX
        nZ = self.nZ
        # Build the centre offset with length nX instead of a literal [2, 2], which only
        # broadcasts when nX == 2.
        offset = np.full(nX, 2.0)

        X0 = np.random.randn(nHalf, nX) + offset  # (nHalf, nX)
        X1 = np.random.randn(nHalf, nX) - offset  # (nHalf, nX)
        X = np.concatenate((X0, X1), axis=0)  # (nSample, nX)

        Z = np.zeros((nSample, nZ))  # (nSample, nZ)
        Z[:nHalf, 0] = 1
        Z[nHalf:, 1] = 1

        self.dataX = X
        self.dataZ = Z

### SS120: define ConcEnvFactory

In [None]:
class ConcEnvFactory(WaeEnvironmentFactory):
    """Environment factory that builds a ConcEnv sized by the build parameter's batch size."""

    def create(self, buildParameter):
        batchSize = buildParameter.nBatch
        return ConcEnv(batchSize)

## S200: Define the concrete build parameter

### SS210: define ConcBuildParameter

In [None]:
class ConcBuildParameter(WaeBuildParameter):
    """Concrete build parameter for ConcEnv; adds nothing on top of WaeBuildParameter."""

### SS220: define ConcBuildParameterFactory

In [None]:
class ConcBuildParameterFactory(WaeBuildParameterFactory):
    """Build-parameter factory that returns a default-constructed ConcBuildParameter."""

    def create(self):
        freshParameter = ConcBuildParameter()
        return freshParameter

## S300: Run learning agents

### SS310: define an instance of the concrete build parameter

In [None]:
# Train for 2**7 = 128 epochs; the label is the string the loading step queries by later.
buildParameter = ConcBuildParameter(
    label="test case #1",
    nEpoch=2**7,
    reg_param=1.0,
)

### SS320: initialize an instance of builder

In [None]:
agentFactory = WaeAgentFactory()
environmentFactory = ConcEnvFactory()
trainerFactory = WaeTrainerFactory()

logger = MyLogger(console_print=True)
# Reuse dbPath from the setup cell instead of repeating the file name, so the cleanup
# step and the store can never point at different files.
store = Store(dbPath)

builder = Builder(
    agentFactory=agentFactory,
    environmentFactory=environmentFactory,
    trainerFactory=trainerFactory,
    store=store,
    logger=logger,
)

### SS330: run build

In [None]:
# Run training for the build parameter defined above. The Builder was constructed with
# `store`, so results are presumably persisted to the SQLite file (NOTE(review): confirm
# what Builder.build writes).
builder.build(buildParameter)

## S400: Load trained agents to analyze them

### SS410: choose a trained agent

In [None]:
buildParameterFactory = ConcBuildParameterFactory()
# The Loader gets the same agent/environment factories and the same store object as the Builder.
loader = Loader(
    store=store,
    buildParameterFactory=buildParameterFactory,
    environmentFactory=environmentFactory,
    agentFactory=agentFactory,
)

In [None]:
# Keep only the last (agent, buildParameter, epoch) tuple yielded by the loader, instead of
# materialising every loaded agent in a list just to index [-1].
# NOTE(review): "%" looks like a SQL LIKE wildcard, so this pattern would also match labels
# such as "test case #10" — confirm this is intended.
lastLoaded = None
for lastLoaded in loader.load("test case #1%"):
    pass
if lastLoaded is None:
    raise LookupError('no trained agent found for label pattern "test case #1%"')
agent, buildParameter, epoch = lastLoaded
print("build parameter label = ", buildParameter.label)
print("epoch = %d" % epoch)

### SS420: check the distributions of the observed and latent variables (one figure each).

In [None]:
# Rebuild the environment from the loaded build parameter (ConcEnvFactory uses its nBatch).
environment = environmentFactory.create(buildParameter=buildParameter)

In [None]:
# Figure 1 (observed space): X from the environment (blue) vs. the agent's XHat (red).
# Figure 2 (latent space):   the agent's Xi (blue) vs. its XiHat (red).
# Marker shape encodes the one-hot label Z; only the first two coordinates are plotted.
fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()

ax1.set_title("Observed variable distribution")
ax2.set_title("Latent variable distribution")
for ax in (ax1, ax2):
    ax.set_xlabel("dim 0")
    ax.set_ylabel("dim 1")

# Hoisted out of the batch loop. Cycling through markers keeps the plot working when
# nZ > 2 (a fixed {0: "o", 1: "^"} mapping raises KeyError for k1 >= 2); classes 0 and 1
# still get "o" and "^".
markerCycle = ["o", "^", "s", "v", "D", "x"]
lineStyle = dict(markerfacecolor="None", linestyle="")

for iBatch, dataBatchEnv in enumerate(environment.generateBatchDataIterator()):
    dataBatchAg = agent(dataBatchEnv)

    # NOTE(review): .data.numpy() suggests torch tensors; a CUDA-resident agent would need .cpu() first.
    XHat = dataBatchAg._XHat.data.numpy()  # (*, nX)
    XiHat = dataBatchAg._XiHat.data.numpy()  # (*, nXi)
    Xi = dataBatchAg._Xi.data.numpy()  # (*, nXi)

    X = dataBatchEnv._X.data.numpy()  # (*, nX)
    Z = dataBatchEnv._Z.data.numpy()  # (*, nZ)

    for k1 in range(environment.nZ):
        idx = Z[:, k1] == 1
        marker = markerCycle[k1 % len(markerCycle)]
        # Label only the first batch so each (series, class) pair appears once in the legend;
        # matplotlib leaves "_nolegend_" artists out of the legend.
        if iBatch == 0:
            labels = ["%s, Z=%d" % (name, k1) for name in ("X", "XHat", "Xi", "XiHat")]
        else:
            labels = ["_nolegend_"] * 4

        ax1.plot(X[idx, 0], X[idx, 1], 'b', marker=marker, label=labels[0], **lineStyle)
        ax1.plot(XHat[idx, 0], XHat[idx, 1], 'r', marker=marker, label=labels[1], **lineStyle)

        ax2.plot(Xi[idx, 0], Xi[idx, 1], 'b', marker=marker, label=labels[2], **lineStyle)
        ax2.plot(XiHat[idx, 0], XiHat[idx, 1], 'r', marker=marker, label=labels[3], **lineStyle)

ax1.legend()
ax2.legend()
plt.show()