In [1]:
import torch
import numpy as np


import pyworld.toolkit.tools.torchutils as tu
import pyworld.toolkit.tools.visutils as vu
import pyworld.toolkit.tools.datautils as du
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J

from pyworld.algorithms.optimise.TripletOptimiser import SASTripletOptimiser
from pyworld.toolkit.nn.CNet import CNet2
from pyworld.toolkit.nn.MLP import MLP
from anomapy import load
from anomapy.train import initialise

import datasets


# --- Experiment configuration ---------------------------------------------
# Environment name; it is the suffix of every dataset identifier below.
env = "Breakout"

# Dataset identifiers follow the pattern "aad.<kind>.<env>".
dataset_clean = "aad.clean." + env
dataset_anomaly = "aad.anomaly." + env
dataset_raw = "aad.raw." + env


# Training hyper-parameters.
batch_size = 128
epochs = 10

# Device the networks are moved to, as chosen by tu.device().
device = tu.device()

# Placeholder; rebound to the optimiser's model once the networks are built.
model = None

USING DEVICE: cuda


In [2]:
def transform(episode):
    """Convert one raw episode into a state tensor and an int64 action array.

    Args:
        episode: mapping with 'state' and 'action' entries (and optionally
            'label'); each entry is read in full with ``[...]``.

    Returns:
        dict with
          * 'state'  -- torch tensor built from the states after T.to_float
                        and T.CHW,
          * 'action' -- int64 numpy array whose last entry is set to 0,
          * 'label'  -- only when the input has it; read with ``[...]``.

    The input episode is never modified.
    """
    state = T.to_float(episode['state'][...])
    state = T.CHW(state)
    state = torch.from_numpy(state)

    # TODO: check whether load.remove_redundant_actions(episode['action'], env)
    # has any effect here.
    # astype() returns a fresh array. Copying *before* the write below matters:
    # ``[...]`` on an ndarray is a view, so writing first would silently change
    # the caller's episode.
    action = episode['action'][...].astype(np.int64)
    if action.size:
        # The final action is overwritten with 0 -- presumably because no
        # successor state follows it; TODO confirm.
        action[-1] = 0

    result = {'state': state, 'action': action}
    if 'label' in episode:
        result['label'] = episode['label'][...]
    return result

def transform_iterate(episode):
    """Turn an episode into aligned (state, action, next_state) sequences.

    The episode is first passed through ``transform``. Index ``i`` of the
    three returned sequences then holds state ``i``, action ``i`` and
    state ``i + 1``.
    """
    converted = transform(episode)
    states, actions = converted['state'], converted['action']
    # Drop the last step from the inputs and the first state from the targets
    # so that all three sequences have the same length.
    return states[:-1], actions[:-1], states[1:]

# Open the clean dataset and read the shapes used to size the networks.
dataset = datasets.dataset(dataset_clean)

# The stored state shape is reversed before it is handed to the encoder.
# NOTE(review): presumably this turns (H, W, C) into channel-first order to
# match T.CHW; a plain reversal gives (C, W, H), which only equals (C, H, W)
# for square frames -- confirm.
state_shape = tuple(reversed(dataset.meta.state_shape))
action_shape = tuple(dataset.meta.action_shape)
# Two latent dimensions: plot_latent reads exactly two latent columns.
latent_shape = (2,)

# Load 10 episodes and convert each into (state, action, next_state) form.
episodes = list(map(transform_iterate, dataset.load(count=10)))





loading dataset aad.clean.Breakout...


IntProgress(value=0, max=10)

In [3]:
# State network: CNet2 built from the state shape and the latent shape.
state_model = CNet2(state_shape, latent_shape).to(device)

# Action network: an MLP whose input width is two latent vectors plus one
# action vector, and whose output width is one latent vector.
action_model = MLP(
    2 * latent_shape[0] + action_shape[0],
    latent_shape[0],
).to(device)

# The optimiser is built from both networks; its `model` attribute is exposed
# under the notebook-level name `model`.
optim = SASTripletOptimiser(state_model, action_model)
model = optim.model



In [4]:
# Train for `epochs` passes over the pre-loaded clean episodes.
# NOTE: experiment logging through initialise.WB(optim.model, model='sassn',
# env=env) is disabled in this cell.
for epoch in range(epochs):
    for episode in episodes:
        # Each episode is a (state, action, next_state) triple; it is fed to
        # the optimiser in shuffled mini-batches of `batch_size`.
        for s1, a, s2 in du.batch_iterator(*episode, batch_size=batch_size, shuffle=True):
            optim(s1, a, s2)

        # Report after every episode (not once per epoch): the epoch index
        # followed by the value of optim.cma().
        print(epoch, optim.cma())

    

0 [0.25271611]
0 [0.21766914]
0 [0.21140797]
0 [0.20913328]
0 [0.2065443]
0 [0.2042705]
0 [0.19181388]
0 [0.18132863]
0 [0.17496928]
0 [0.16919411]
1 [0.16562073]
1 [0.15931369]
1 [0.15374784]
1 [0.14941942]
1 [0.14564559]
1 [0.14179331]
1 [0.13681149]
1 [0.13187728]
1 [0.12883682]
1 [0.12531925]
2 [0.12335443]
2 [0.11893332]
2 [0.11510748]
2 [0.11266962]
2 [0.11016269]
2 [0.10798824]
2 [0.10458586]
2 [0.10136422]
2 [0.09919926]
2 [0.09688085]
3 [0.09550488]
3 [0.09176607]
3 [0.08905282]
3 [0.0872434]
3 [0.0855225]
3 [0.08398537]
3 [0.08197601]
3 [0.08010833]
3 [0.0787977]
3 [0.07732007]
4 [0.07647365]
4 [0.0741141]
4 [0.07238404]
4 [0.07124362]
4 [0.07015119]
4 [0.06914674]
4 [0.06769544]
4 [0.06628686]
4 [0.06536551]
4 [0.06438845]
5 [0.0638175]
5 [0.06217391]
5 [0.06099653]
5 [0.06021083]
5 [0.05947131]
5 [0.05877028]
5 [0.0577362]
5 [0.05677284]
5 [0.05612787]
5 [0.0554394]
6 [0.05502071]
6 [0.05386502]
6 [0.05300871]
6 [0.05244719]
6 [0.05190485]
6 [0.05140329]
6 [0.05061411]
6 [0

In [12]:

from anomapy.evaluate import score, evaluate

#if model is None:
#    RUN = "sassn-Breakout-20200212142752" #change me to load a specific model!
#    model, kwargs = evaluate.initialise(run=RUN) #load it from the run...



#pyo.init_notebook_mode()
def roc(label, score):
    """Compute the ROC curve of `score` against the ground-truth `label`.

    Args:
        label: array of ground-truth labels, one per sample along axis 0.
        score: array of scores with the same length along axis 0.

    Returns:
        (fpr, tpr) as returned by sklearn.metrics.roc_curve; the thresholds
        are discarded.

    Raises:
        AssertionError: if `label` and `score` differ in length.
    """
    # Validate before importing sklearn so a bad call fails immediately.
    # Raised explicitly (rather than via `assert`) so the check and its
    # message survive `python -O`.
    if label.shape[0] != score.shape[0]:
        raise AssertionError(
            "label and score must have the same length, got {0} and {1}".format(
                label.shape[0], score.shape[0]))

    # Local import: sklearn is only needed by this function.
    from sklearn.metrics import roc_curve

    fpr, tpr, _ = roc_curve(label, score)
    return fpr, tpr

def plot_latent(state, action):
    """Build an image scatter plot of the first two latent coordinates.

    `state` is run through the notebook-level `state_model` via tu.collect.
    Columns 0 and 1 of the result are the x/y positions, and the states
    themselves (converted with vu.transform.HWC) are passed as the images.
    `action` is accepted but not used.

    Returns whatever vu.jupyter.scatter_image returns.
    """
    latent = tu.to_numpy(tu.collect(state_model, state))
    thumbnails = vu.transform.HWC(tu.to_numpy(state))
    return vu.jupyter.scatter_image(latent[:, 0], latent[:, 1], thumbnails)
    
# Open the anomaly and raw datasets.
dataset_a = datasets.dataset(dataset_anomaly)
dataset_r = datasets.dataset(dataset_raw)

# Index the episodes by the part of their file name before the first '.'.
episodes_a = {name.split('.')[0]: ep for name, ep in dataset_a.load(file_names=True)}
episodes_r = {name.split('.')[0]: ep for name, ep in dataset_r.load(file_names=True)}

# The anomaly metadata drives the evaluation: iterating meta.anomaly yields
# one name per anomaly type, and those names label the ROC curves.
meta = dataset_a.meta
rocs = []
legend = list(meta.anomaly)


# Score one episode per anomaly type and collect its ROC curve in `rocs`.
# The curve must be appended here: the figure drawn after this loop reads
# `rocs` and pairs its entries with `legend`, so an empty list leaves the
# plot without any data.
for k in meta.anomaly:
    print(k)
    # meta.anomaly[k] is indexed with [0]: only the first episode listed for
    # each anomaly type is evaluated.
    episode = transform(episodes_a[meta.anomaly[k][0]])
    # Optional latent-space view of the episode:
    # plot_latent(episode['state'], episode['action'])

    # score.sassn.score is expected to return the (label, score) pair that
    # roc() takes -- TODO confirm against anomapy.evaluate.score.
    rocs.append(roc(*score.sassn.score(model, episode)))

#J.plot([r[0] for r in rocs], [r[1] for r in rocs],legend=legend)

import plotly.offline as pyo
import plotly.graph_objs as go
from IPython.display import display, clear_output

# Draw one line per collected ROC curve (x = first element, y = second
# element of each entry in `rocs`), labelled with the anomaly names.
fig = vu.plot.plot(
    [curve[0] for curve in rocs],
    [curve[1] for curve in rocs],
    mode='lines',
    legend=legend,
)
display(go.FigureWidget(fig))





loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=28)

loading dataset aad.raw.Breakout...


IntProgress(value=0, max=41)

fill
---- subsequences:  (157,)
---- anomaly:  96
---- normal:   61
block
---- subsequences:  (188,)
---- anomaly:  107
---- normal:   81
freeze
---- subsequences:  (223,)
---- anomaly:  148
---- normal:   75
freeze_skip
---- subsequences:  (137,)
---- anomaly:  100
---- normal:   37
split_horizontal
---- subsequences:  (213,)
---- anomaly:  126
---- normal:   87
split_vertical
---- subsequences:  (110,)
---- anomaly:  67
---- normal:   43
action
---- subsequences:  (186,)
---- anomaly:  104
---- normal:   82


FigureWidget({
    'data': [{'mode': 'lines',
              'name': 'fill',
              'type': 'scatter',
 …

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00d\x00\x00\x00d\x08\x02\x00\x00\x00\xff\x80\x02\x0…