
### Dataset

In [None]:
# Atari environment whose raw dataset ('aad.raw.<env>') is loaded below.
env = "Breakout"
# Number of recorded episodes to load; the last one is held out as a test episode.
# TODO: decide whether to use every available episode.
num_episodes = 25

In [None]:
import anomapy.train.sssn as sssn
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.torchutils as tu

def distance(model, episode):
    """Run ``sssn.distance`` on ``episode`` and convert both outputs to numpy.

    Returns:
        (z, d) as numpy arrays, in the order ``sssn.distance`` yields them.
        ``d`` is used as the anomaly score elsewhere in the notebook (it is
        compared against per-transition labels in the ROC section).
    """
    latent, dist = sssn.distance(model, episode)
    return tu.to_numpy(latent), tu.to_numpy(dist)
    
def plot_latent(model, episode):
    """Scatter the first two latent dimensions of ``episode`` with frame thumbnails.

    Every point is drawn in blue; consecutive points are joined by a light-blue line.
    Returns whatever ``J.scatter_image`` returns.
    """
    latent = tu.to_numpy(sssn.encode(model, episode))
    frames = T.HWC(tu.to_numpy(episode))
    xs, ys = latent[:, 0], latent[:, 1]
    return J.scatter_image(xs, ys, frames,
                           scatter_colour='blue',
                           line_colour='#b9d1fa',
                           scale=1.5)

def order_by(model, episode):
    """Rank ``episode`` by anomaly score, highest score first.

    Returns:
        (scores, frames): the scores from ``distance`` sorted in descending
        order, and the HWC frames of ``episode`` reordered with the same
        permutation. NOTE(review): if there is one score per transition
        (N-1 scores for N frames), frame k is the first frame of transition k.
    """
    # distance() returns (z, d). Unpack in that order so we sort by the score,
    # not by the latent code.
    _, scores = distance(model, episode)
    frames = T.HWC(tu.to_numpy(episode))
    order = np.argsort(-scores)
    return scores[order], frames[order]

In [None]:
import datasets
from pprint import pprint
import numpy as np

dataset = datasets.dataset('aad.raw.{0}'.format(env))
# Register the state transforms (to float, CHW layout, torch tensor) — presumably applied on load.
dataset.state.transform.to_float().CHW().torch()
episodes = list(dataset.state.load(num_episodes))
for index, ep in enumerate(episodes):
    print("episode:", index, ep.shape)

# Total number of states across all loaded episodes.
print(np.sum([ep.shape[0] for ep in episodes]))

# Hold out the final episode for testing and train on the remaining ones.
episode_test = episodes[-1]
episodes = episodes[:-1]


# Train


In [None]:
from types import SimpleNamespace

import pyworld.toolkit.tools.wbutils as wbu
import pyworld.toolkit.tools.fileutils as fu
import pyworld.toolkit.tools.torchutils as tu
import pyworld.toolkit.tools.visutils.jupyter as J

import anomapy.train.sssn as sssn

# Run configuration: sssn defaults, then the dataset's metadata, then the
# hyper-parameters set explicitly below (later assignments win).
CONFIG = sssn.default_config()
CONFIG.update(dataset.meta.to_dict())
CONFIG['latent_shape'] = 256
CONFIG['epochs'] = 12
CONFIG['batch_size'] = 128
# Bookkeeping about the training split (the held-out test episode is not counted).
CONFIG['total_states'] = sum([e.shape[0] for e in episodes])
CONFIG['episodes'] = len(episodes)
# Switch to attribute access (CONFIG.model, CONFIG.epochs, ...) for the rest of the cell.
CONFIG = SimpleNamespace(**CONFIG)

print(CONFIG)

# DRYRUN is passed to both wbu.dryrun and sssn.new; the "Load Model" section
# later reloads the most recent entry returned by wbu.dryruns().
DRYRUN = True
PROJECT = "anomaly-detection"
# Run id: <model>-<name>-<timestamp from fu.file_datetime()>.
RUN_ID = "{0}-{1}-{2}".format(CONFIG.model, CONFIG.name, fu.file_datetime())
RUN_TAGS = [CONFIG.model, CONFIG.name]

wbu.dryrun(DRYRUN)

optimiser = sssn.new(dryrun=DRYRUN, **CONFIG.__dict__)
model = optimiser.model

wb = wbu.WB(PROJECT, model, id=RUN_ID, tags=RUN_TAGS, config=CONFIG.__dict__)
#loss_plot = J.dynamic_plot(update_after=10)

#z = tu.to_numpy(sssn.encode(model, episodes[0]))
#plot = J.plot(z[:,0], z[:,1], mode=J.line_mode.both)

# Train for CONFIG.epochs passes over every training episode inside the wb context.
with wb:
    for i in range(CONFIG.epochs):
        print("---- epoch: ", i)
        for episode in episodes:
            # Iterating sssn.epoch presumably performs the mini-batch updates;
            # the yielded per-batch losses are not used here.
            for loss in sssn.epoch(optimiser, episode, CONFIG.batch_size):
                pass #loss_plot.update(None, loss['loss'])
            #z = tu.to_numpy(sssn.encode(model, episodes[0]))
            #plot.set_data(z[:,0], z[:,1])
            # Loss summary reported by the optimiser (cma — presumably a cumulative moving average).
            print("loss:", optimiser.cma())


# Load Model

In [None]:
from types import SimpleNamespace
import anomapy.train.sssn as sssn
import pyworld.toolkit.tools.wbutils as wbu


def distance(model, episode):
    """Return ``sssn.distance(model, episode)`` with both outputs as numpy arrays.

    Re-declared so the "Load Model" path can run without the Dataset section.
    ``tu`` is imported locally because this cell's own imports do not provide it.

    Returns:
        (z, d): the two outputs of ``sssn.distance``. ``d`` is used as the
        anomaly score in the ROC and histogram sections.
    """
    import pyworld.toolkit.tools.torchutils as tu
    z, d = sssn.distance(model, episode)
    return tu.to_numpy(z), tu.to_numpy(d)


env = "Breakout"
# Dry-run ids mentioning this environment, sorted so the last one is taken as
# the most recent. NOTE(review): run ids look like <model>-<name>-<timestamp>,
# so this order is only chronological if all matches share the same prefix.
dryruns = sorted(r for r in wbu.dryruns() if env in r)
if not dryruns:
    raise FileNotFoundError("no dry-run found for environment {0!r}".format(env))
print(dryruns[-1])

models, config = wbu.load(dryruns[-1]) #load the most recently trained model
model = models['model.pt'].load(sssn.model(**config))
config = SimpleNamespace(**config)

# Load Anomalies


In [None]:
import datasets

dataset = datasets.dataset('aad.anomaly.{0}'.format(env))
dataset.state.transform.to_float().CHW().torch()
# (anomaly name, file id) pairs from the dataset metadata; the second element
# is presumably a file identifier (it is what load_files receives).
anoms = [(name, entry[0]) for name, entry in dataset.meta.anomaly.items()]
anomaly_files = [file_id for _, file_id in anoms]
a_episodes = list(dataset.state.load_files(*anomaly_files))
a_labels = list(dataset.label.load_files(*anomaly_files))

In [None]:
import numpy as np
# Transition labels: transition t joins states t and t+1 and is marked
# anomalous (1) when either of the two states is.
a_tlabels = []
for state_labels in a_labels:
    either = np.logical_or(state_labels[:-1], state_labels[1:])
    a_tlabels.append(either.astype(np.uint8))

# ROC

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc


import plotly
import plotly.graph_objects as go

import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.visutils.plot as vplot
import pyworld.toolkit.tools.torchutils as tu
import pyworld.toolkit.tools.fileutils as fu

import numpy as n
from pprint import pprint

def roc(label, score):
    """Return the ROC curve ``(fpr, tpr)`` of ``score`` against binary ``label``.

    ``label`` and ``score`` must have the same length along axis 0 (asserted).
    The thresholds produced by ``roc_curve`` are discarded.
    """
    assert label.shape[0] == score.shape[0]
    false_pos, true_pos, _thresholds = roc_curve(label, score)
    return false_pos, true_pos

fprs = []
tprs = []
aucs = {}
legend = []
# The last anomaly episode is left out, as in the original loop; the histogram
# grid below also only covers anoms[:-1]. NOTE(review): document why.
for episode, (anomaly, _), label in zip(a_episodes[:-1], anoms, a_tlabels):
    _, score = distance(model, episode)
    # roc() also asserts there is exactly one score per transition label.
    fpr, tpr = roc(label, score)
    aucs[anomaly] = auc(fpr, tpr)
    fprs.append(fpr)
    tprs.append(tpr)
    legend.append(anomaly)

pprint(aucs)
# Display names for the first two anomaly types (hard-coded by position).
legend[0] = 'flicker'
legend[1] = 'visual artefact'
legend = [l.replace('_', ' ') for l in legend]
print(legend)

plot = J.plot(fprs, tprs, legend=legend, show=False)
plot.fig.update_layout(showlegend=True, autosize=False, width=300, height=280, margin=dict(l=5,b=5,r=5,t=5))
plot.fig.update_layout(dict(legend=dict(xanchor='center', x=0.5, orientation='h')))

#path = "/home/ben/Downloads/rocs/"
#plot.fig.write_image(path + "{0}.png".format(env))
#fu.save(path + "{0}.json".format(env), aucs)



# Histograms

In [None]:
# Give the first two anomaly types display names and turn underscores into
# spaces. Only the name (element 0) is rewritten and the file id (element 1)
# is kept, so re-running this cell leaves `anoms` unchanged.
anoms[0] = ('flicker', anoms[0][1])
anoms[1] = ('visual artefact', anoms[1][1])
anoms = [(name.replace('_', ' '), file_id) for name, file_id in anoms]
for a in anoms:
    print(a)

In [None]:
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.visutils.plot as vplot
import pyworld.toolkit.tools.torchutils as tu
import pyworld.toolkit.tools.datautils as du
import numpy as np


import plotly
import plotly.graph_objects as go


# 3x2 grid of score histograms; one titled panel per anomaly type except the last.
panel_titles = [entry[0] for entry in anoms[:-1]]
subplot = plotly.subplots.make_subplots(
    rows=3,
    cols=2,
    vertical_spacing=0.08,
    horizontal_spacing=0.05,
    subplot_titles=panel_titles,
)

# Module-level flag read by histogram() to decide whether its traces join the legend.
showlegend = True

def histogram(model, episode, labels, title, i, j, bins=50, legend=None):
    """Add anomaly/normal score histograms for one episode to panel (i, j) of ``subplot``.

    Args:
        model: model passed to ``distance``.
        episode: episode to score.
        labels: per-score labels; entries equal to 1 go to the red "anomaly"
            histogram and entries equal to 0 go to the blue "normal" one.
        title: unused (panel titles are set when ``subplot`` is created);
            kept for call compatibility.
        i, j: 1-based row and column of the target panel.
        bins: number of equal-width bins spanning the normalised score range.
        legend: whether these traces appear in the legend. Defaults to the
            module-level ``showlegend`` flag.
    """
    show = showlegend if legend is None else legend
    _, d = distance(model, episode)
    d = du.normalise(d)

    d_a = d[labels == 1]
    d_n = d[labels == 0]
    print(d.shape, d_a.shape, d_n.shape)
    span = np.max(d) - np.min(d)
    # A constant (or NaN) score vector would give a zero/NaN bin width; fall
    # back to plotly's automatic binning in that case.
    xbins = dict(size=span / bins) if span > 0 else None
    print("anomaly count: ", d_a.shape[0], "normal count: ", d_n.shape[0])
    subplot.add_trace(go.Histogram(x=d_a, marker=dict(color='red'), name='anomaly', showlegend=show, xbins=xbins), row=i, col=j)
    subplot.add_trace(go.Histogram(x=d_n, marker=dict(color='blue'), name='normal', showlegend=show, xbins=xbins), row=i, col=j)

# One panel per anomaly type (the first six), filling the 3x2 grid row by row.
# histogram() reads the global `showlegend`, so only the first panel adds
# "anomaly"/"normal" legend entries.
for k in range(6):
    showlegend = k == 0
    row, col = divmod(k, 2)
    histogram(model, a_episodes[k], a_tlabels[k], anoms[k], row + 1, col + 1)

# Log-scale the y axis of every panel (same effect as setting yaxis..yaxis6_type).
subplot.update_yaxes(type="log")
subplot.update_layout(margin=dict(l=5,b=0,r=5,t=20))
subplot.update_layout(dict(legend=dict(xanchor='center', y = -0.03, x=0.5, orientation='h')))


# Visualise Latent Space

In [None]:
# Guard for the latent-space plots: plot_latent only draws z[:, 0] and z[:, 1],
# so the scatter shows the whole latent space only when it is 2-dimensional.
assert config.latent_shape == 2

# Distance Graphs
Distance graphs are interactive. They show the anomaly score of each transition in its natural (temporal) order; hovering over a point displays the two frames that make up that transition.

In [None]:
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.torchutils as tu
import numpy as np
import datasets

def plot_distance(model, episode):
    """Interactive plot of the scores of ``episode`` with a frame viewer.

    Point ``i`` is the i-th score returned by ``distance``. Hovering over it
    shows frames ``i`` and ``i + 1`` side by side.
    """
    _, d = distance(model, episode)
    images = T.HWC(tu.to_numpy(episode))

    plot = J.SimplePlot(np.arange(len(d)), d)
    image1 = J.SimpleImage(images[0])
    image2 = J.SimpleImage(images[1])
    last = len(images) - 1

    def on_hover(trace, points, state):
        # Hover callbacks may fire with an empty point list; nothing to show then.
        if not points.point_inds:
            return
        i = points.point_inds[0]
        image1.set_image(images[i])
        # Clamp so a score aligned with the final frame cannot index past the end.
        image2.set_image(images[min(i + 1, last)])

    plot.on_hover(on_hover)
    d_images = J.layout_horizontal(image1.fig, image2.fig)
    d_plot = J.layout_horizontal(plot.fig)
    J.display(J.layout_vertical(d_plot, d_images))

# NOTE(review): this scores a *Pong* episode with `model`, which was trained
# or loaded for Breakout earlier in the notebook. Confirm the cross-environment
# check is intended. It also overwrites the globals `env` and `dataset`
# (previously the Breakout anomaly dataset).
env = "Pong"
dataset = datasets.dataset('aad.raw.{0}'.format(env))
dataset.state.transform.to_float().CHW().torch()
# First episode of the raw Pong dataset.
episode = [x for x in dataset.state.load(1)][0]
plot_distance(model, episode)

# Disabled per-anomaly distance plots, kept as a string literal so they do not
# run. As the cell's last expression, the string is also echoed as output.
"""
for i in range(len(a_episodes)):
    a = anoms[i][0]
    print("========================== {0} ==========================".format(a))
    plot_distance(model, a_episodes[i]) 
""" 

# Visualise Anomalies

In [None]:
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.torchutils as tu
import numpy as np

# Print each anomaly's name, then render its episode as HWC frames.
for idx, a_episode in enumerate(a_episodes):
    a = anoms[idx][0]
    print(a)
    frames = T.HWC(tu.to_numpy(a_episode))
    J.images(frames)
    

#### Load and show old 2D visualisation

In [None]:
from pprint import pprint
from types import SimpleNamespace
import anomapy.train.sssn as sssn
import pyworld.toolkit.tools.wbutils as wbu
import datasets
import numpy as np

#load model
# Reload a specific earlier Breakout run by its wbu run path. This overwrites
# the globals env, model, config, dataset and episodes.
env = "Breakout"
run = "benedict-wilkins/anomapy/sssn-Breakout-20200209131829"
models, config = wbu.load(run)
# This older config presumably stores only 'state_shape'; build the 'state'
# entry from it before constructing the model with sssn.model(**config).
config['state'] = dict(shape=config['state_shape'])
model = models['model.pt'].load(sssn.model(**config))
config = SimpleNamespace(**config)

#load data: one raw episode. NOTE(review): the next cell visualises a_episodes, not these episodes.
num_episodes = 1
dataset = datasets.dataset('aad.raw.{0}'.format(env))
dataset.state.transform.to_float().CHW().torch()
episodes = [x for x in dataset.state.load(num_episodes)]



In [None]:
#visualise
import anomapy.train.sssn as sssn
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.torchutils as tu

import numpy as np

a = 5  # index of the anomaly episode to visualise
episode = a_episodes[a]
# Per-state labels (not transition labels): one colour per encoded state.
labels = a_labels[a].astype(np.uint8)
print(labels)

def plot_latent(model, episode, labels=None):
    """Scatter the first two latent dimensions of ``episode`` with frame thumbnails.

    Args:
        model: model passed to ``sssn.encode``.
        episode: episode to encode and display.
        labels: optional 0/1 array with one entry per encoded point; 0 is drawn
            blue and 1 red. When omitted, every point is blue.

    Returns whatever ``J.scatter_image`` returns (unpacked as ``fig, image`` below).
    """
    z = tu.to_numpy(sssn.encode(model, episode))
    images = T.HWC(tu.to_numpy(episode))
    if labels is None:
        colour = 'blue'
    else:
        colour = np.array(['blue', 'red'])[labels]
    return J.scatter_image(z[:,0], z[:,1], images, scatter_colour=colour, line_colour='#b9d1fa', scale=1.5)

from ipywidgets import Image, Layout, VBox, HBox, interact, IntSlider, IntProgress, HTML, Output
import ipywidgets as widgets

fig, image = plot_latent(model, episode, labels)

fig.update_layout(width=600, height=500)
#box_layout = widgets.Layout(display='flex',flex_flow='row',align_items='center',width='100%',height='100%')
#display(HBox([fig], layout=box_layout))

#box_layout = widgets.Layout(display='flex',flex_flow='row',align_items='center',width='100%',height='100%')
#display(HBox([fig, image_widget], layout=box_layout)) #basically... this needs to be done in jupyter..?!]


In [None]:
# Display the `image` object returned by plot_latent in the previous cell (bare last expression).
image

-----------------
-----------------
-----------------
-----------------

# OTHER DEMOS


#### dynamic plot example

In [None]:
import pyworld.toolkit.tools.visutils.jupyter as J
import time
import numpy as np

loss_plot = J.dynamic_plot(update_after=100)

# Stream 1000 points of a slow sine wave, pausing 10 ms before each update.
for step in range(1000):
    time.sleep(0.01)
    value = np.sin(step / (np.pi * 2))
    loss_plot.update(step, value)
    

In [None]:
# Latent codes and HWC frames of the current `episode`, for ad-hoc inspection.
# Encode with sssn.encode(model, ...) like the rest of the notebook so this also
# works when the model was loaded instead of trained (`optimiser` only exists
# after the Train cell has run).
z = tu.to_numpy(sssn.encode(model, episode))
images = T.HWC(tu.to_numpy(episode))

#### Interactive plot example

In [None]:
import pyworld.toolkit.tools.visutils.jupyter as J
import numpy as np

x = np.arange(1000)
y = np.sin(x)
plot = J.SimplePlot(x, y)

def on_hover(trace, points, state):
    """Print the index of the first point under the cursor."""
    print(points.point_inds[0])

plot.on_hover(on_hover)
plot.display()

#### Interactive plot with images

In [None]:
import pyworld.toolkit.tools.visutils.jupyter as J
import numpy as np

N = 100
x = np.arange(N)
y = np.sin(x)
# Random RGB noise frames. randint's upper bound is exclusive, so 256 yields the full 0-255 range.
images = np.random.randint(0, 256, size=(N, 100, 100, 3))
image = J.SimpleImage(images[0])

plot = J.SimplePlot(x, y)
def on_hover(trace, points, state):
    # Show the frame for the hovered point; ignore callbacks without a point.
    if not points.point_inds:
        return
    i = points.point_inds[0]
    image.set_image(images[i])

plot.on_hover(on_hover)

J.display(J.layout_horizontal(plot.fig, image.fig))



#### Sliding histogram

In [None]:
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.visutils.plot as vplot
import numpy as np

N = 100
x = np.arange(N)
y = np.sin(x)
# Uniform noise in [0, 1), one sample per point of y.
yc = np.random.uniform(0, 1, size=len(y))

fig = J.histogram([y, yc], show=False)
# Candidate bin sizes for the slider: from 0.01 in steps of 0.01, stopping below 0.5.
bin_sizes = np.arange(0.01, 0.5, 0.01)
vplot.histogram_slider(fig, sizes=bin_sizes)



#### subplots

In [None]:
import numpy as np

# Demo data: a sine sampled at integer points, plus uniform noise of the same length.
N = 100
x = np.arange(N)
y = np.sin(x)
yc = np.random.uniform(low=0, high=1, size=len(y))

import plotly
import plotly.graph_objects as go

def hist(x, color, name, showlegend=False):
    """Build a single-colour plotly histogram trace named ``name``."""
    marker = {'color': color}
    return go.Histogram(x=x, marker=marker, name=name, showlegend=showlegend)

c1 = 'red'
c2 = 'blue'

subplot = plotly.subplots.make_subplots(rows=3, cols=2)
# add_trace returns the figure itself (for chaining), not the new trace, so
# there is nothing worth binding from these calls.
subplot.add_trace(hist(x, c1, 'anomaly', showlegend=True), row=1, col=1)
subplot.add_trace(hist(y, c2, 'normal', showlegend=True), row=1, col=1)
subplot.add_trace(hist(yc, c1, 'anomaly'), row=1, col=2)

subplot


#### Change scatter colour

In [None]:
import pyworld.toolkit.tools.visutils.jupyter as J
import numpy as np

n = 10
x = np.arange(n)
y = np.sin(x)

plot = J.SimplePlot(x, y, mode=J.line_mode.marker)

def on_hover(trace, points, state):
    """Print the index of the first marker under the cursor."""
    print(points.point_inds[0])

plot.on_hover(on_hover)
# Give every marker a random colour value drawn from {0, 1, 2}.
marker_colours = np.random.randint(0, 3, size=n).tolist()
plot.fig.data[0]['marker'] = dict(color=marker_colours)
plot.display()

