In [None]:
import torch
import numpy as np

import pyworld.toolkit.tools.torchutils as tu
import pyworld.toolkit.tools.visutils as vu
import pyworld.toolkit.tools.datautils as du
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J

from pyworld.algorithms.optimise.TripletOptimiser import SASTripletOptimiser
from pyworld.toolkit.nn.CNet import CNet2
from pyworld.toolkit.nn.MLP import MLP
from anomapy import load
from anomapy.evaluate import evaluate

import datasets

import sys


In [None]:
# Environment whose raw and anomalous datasets are explored in this notebook.
env = "Breakout"
dataset_name = "aad.raw.{0}".format(env)

# evaluate.initialise() appears to take its options from the command line, so a
# fake argv is installed first to select the trained run ("foo" fills argv[0]).
# NOTE(review): the run name hard-codes "Breakout" -- keep it in sync with `env`.
sys.argv = ["foo", "-run", "sssn-Breakout-2D-VIS"]
model, kwargs = evaluate.initialise()

def transform(episode):
    """Return an episode's frames as a float torch tensor in CHW layout.

    ``episode['state']`` is read in full with ``[...]``, converted by
    ``T.to_float`` and reordered by ``T.CHW`` (channel-first, per its name),
    then wrapped with ``torch.from_numpy``. A later cell redefines
    ``transform`` to return a dict instead.
    """
    frames = T.to_float(episode['state'][...])
    return torch.from_numpy(T.CHW(frames))

dataset = datasets.dataset(dataset_name)
# The dataset's state_shape in reverse axis order.
# NOTE(review): if state_shape is (H, W, C) this yields (C, W, H), not the CHW
# layout produced by `transform` -- confirm which order is intended.
state_shape = tuple(reversed(dataset.meta.state_shape))
latent_shape = (2,)

# NOTE(review): the argument 1 presumably limits how much of the dataset is
# loaded -- confirm against datasets.dataset.load.
states = [transform(e) for e in dataset.load(1)]
if not states:
    raise ValueError("dataset {0!r} yielded no episodes".format(dataset_name))

# Keep the longest episode (most frames). max() is O(n) instead of sorting, and
# scanning the reversed list preserves the old sorted(...)[-1] tie-break: among
# equally long episodes, the last one loaded wins.
state = max(reversed(states), key=lambda s: s.shape[0])

# Alternative loading path via anomapy (currently unused):
#au = load.anomaly_utils(**kwargs)
#a_episode, n_episode = au.to_torch(au.load_both(au.meta()[0]))


In [None]:
# Dataset identifiers for the anomalous and the clean recordings of `env`.
dataset_anomaly, dataset_raw = (
    "aad.anomaly.{0}".format(env),
    "aad.raw.{0}".format(env),
)

def transform(episode):
    """Load one episode into a dict of model-ready arrays.

    NOTE: this redefines the ``transform`` from the previous cell, which
    returned only the state tensor; from here on ``transform`` returns a dict.

    Args:
        episode: mapping with ``'state'`` and ``'action'`` entries and an
            optional ``'label'`` entry, each read in full with ``[...]``
            (presumably h5py datasets or numpy arrays -- confirm).

    Returns:
        ``{'state': ..., 'action': ...}`` plus ``'label'`` when present, where
        ``state`` is the frames converted by ``T.to_float``/``T.CHW`` and
        wrapped with ``torch.from_numpy``, ``action`` is an int64 copy of the
        actions with its last entry set to 0, and ``label`` is the labels as read.
    """
    state = torch.from_numpy(T.CHW(T.to_float(episode['state'][...])))

    # TODO(review): does load.remove_redundant_actions(episode['action'], env)
    # make any difference here?
    # astype() always returns a new array, so zeroing the last action below can
    # never write through to `episode` (for a numpy array, `[...]` is a view).
    action = episode['action'][...].astype(np.int64)
    # NOTE(review): why the final action is cleared is not evident here --
    # presumably a terminal/no-op marker; confirm.
    if action.size:  # an empty episode has no last action to clear
        action[-1] = 0
    # Actions are returned as a numpy array, not a torch tensor.

    if 'label' in episode:
        return {'state': state, 'action': action, 'label': episode['label'][...]}
    return {'state': state, 'action': action}

def plot_latent(state, colour=None, scale=1.5, line_colour='#b9d1fa'):
    """Scatter-plot the 2-D latent embedding of a sequence of frames.

    Each frame of ``state`` is encoded with the global ``model`` via
    ``tu.collect``. The first two latent dimensions become the x/y coordinates,
    and the frames themselves (converted with ``vu.transform.HWC``) are passed
    as the per-point images.

    Args:
        state: torch tensor of frames, presumably (N, C, H, W) as produced by
            ``transform`` -- confirm.
        colour: optional per-point scatter colours.
        scale: image scale forwarded to ``scatter_image`` (default 1.5, the
            previously hard-coded value).
        line_colour: line colour forwarded to ``scatter_image`` (default
            ``'#b9d1fa'``, the previously hard-coded value).

    Returns:
        Whatever ``vu.jupyter.scatter_image`` returns (callers unpack it as a pair).
    """
    latent = tu.to_numpy(tu.collect(model, state))
    images = vu.transform.HWC(tu.to_numpy(state))
    return vu.jupyter.scatter_image(latent[:, 0], latent[:, 1], images,
                                    scatter_colour=colour,
                                    line_colour=line_colour,
                                    scale=scale)
    
dataset_a = datasets.dataset(dataset_anomaly)
dataset_r = datasets.dataset(dataset_raw)


def _episodes_by_stem(ds):
    """Key every episode of ``ds`` by its file name up to the first '.'."""
    return {fname.split('.')[0]: ep for fname, ep in ds.load(file_names=True)}


episodes_a = _episodes_by_stem(dataset_a)
episodes_r = _episodes_by_stem(dataset_r)
meta = dataset_a.meta
# Anomaly types listed in the anomaly dataset's metadata.
anomalies = list(meta.anomaly)

def plot(anom, index=0):
    """Plot the latent trajectory of one recorded episode of an anomaly type.

    Prints ``anom`` as a heading, loads episode ``meta.anomaly[anom][index]``
    from ``episodes_a`` via ``transform``, and scatters it with ``plot_latent``.
    When the episode carries labels, label 0 points are drawn ``#636efa`` and
    label 1 points ``#ef553b``.

    Args:
        anom: anomaly type, a key of ``meta.anomaly`` (e.g. ``'fill'``).
        index: which recorded episode of that type to show; defaults to the
            first, as before.
    """
    print(anom)
    episode = transform(episodes_a[meta.anomaly[anom][index]])
    colour = None
    if 'label' in episode:
        # Labels index into a two-colour palette; presumably 0 = normal frame,
        # 1 = anomalous frame -- confirm against the dataset.
        palette = np.array(['#636efa', '#ef553b'])
        colour = palette[episode['label'].astype(np.uint8)]
    # The figure returned by plot_latent was never used, so it is not bound;
    # scatter_image presumably renders the figure itself.
    plot_latent(episode['state'], colour=colour)

# One latent plot for the first recorded episode of each anomaly type.
for _anomaly_type in ('fill', 'block', 'freeze', 'freeze_skip',
                      'split_horizontal', 'split_vertical'):
    plot(_anomaly_type)



In [None]:
# Latent trajectory of the longest raw episode (`state`, selected earlier),
# drawn with one frame thumbnail per point.
z_n = tu.to_numpy(tu.collect(model, state))
x, y = z_n[:, 0], z_n[:, 1]
print(x.shape, y.shape)

images = vu.transform.HWC(tu.to_numpy(state))
# Blank-line spacer printed ahead of the figure (kept as-is).
print("\n\n\n\n\n\n\n\n\n\n\n")

vu.jupyter.scatter_image(x, y, images, scale=2)