In [1]:
from types import SimpleNamespace
from pprint import pprint

import numpy as np

import anomapy.train.sssn as sssn
import pyworld.toolkit.tools.wbutils as wbu
import pyworld.toolkit.tools.torchutils as tu
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J

import datasets

envs = ['BeamRider','Breakout','Enduro','Pong','Qbert','Seaquest','SpaceInvaders']
envs = ['Breakout','Enduro','Pong','Qbert','Seaquest','SpaceInvaders']



def load_anomaly(env):
    dataset = datasets.dataset('aad.anomaly.{0}'.format(env))
    dataset.state.transform.to_float().CHW().torch()
    #anoms = [(a,e) for a,e in dataset.meta.anomaly.items()]
    for a, files in dataset.meta.anomaly.items():
        yield a, [d for d in dataset.load_files(*files, close=True)]

def load_clean(env):
    dataset = datasets.dataset('aad.clean.{0}'.format(env))
    dataset.state.transform.to_float().CHW().torch()
    files = dataset.files()
    for i,d in enumerate(dataset.load_files(*files, close=True)):
        #print(files[i])
        yield d

def distance(model, episode):
    z, d = sssn.distance(model, episode)
    return tu.to_numpy(z), tu.to_numpy(d)

def encode(model, episode):
    return tu.to_numpy(sssn.encode(model, episode))

def load_model(env):
    dryruns = sorted([r for r in wbu.dryruns() if env in r])
    models, config = wbu.load(dryruns[-1]) #load the most recently trained model
    model = models['model.pt'].load(sssn.model(**config))
    config = SimpleNamespace(**config)
    print(" \n------- INFO ------- ")
    print(" env: {0}, latent_shape: {1}, margin: {2}".format(env, config.latent_shape, config.optim_margin))
    return model, config

USING DEVICE: cuda


# AUC

In [2]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

ENV_AUCS = {}
AUCS = {}
FT = {}

for env in envs:
    model, config = load_model(env)
    print("\n-------------- {0} --------------".format(env))
    for a, episodes in load_anomaly(env):
        fprs, tprs, aucs = [], [], []
        for episode in episodes:
            _, score = distance(model, episode.state)
            fpr, tpr, _ = roc_curve(episode.tlabel, score)
            #fprs.append(fpr)
            #tprs.append(tpr)
            aucs.append(auc(fpr, tpr))
        #FT[a] = (fprs, tprs)
        AUCS[a] = np.mean(aucs)
        print("{0}: auc {1}".format(a, np.mean(aucs)))
        print(AUCS)
    pprint(AUCS)
    ENV_AUCS = {env:AUCS}

        
        
        
        

 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200512_131359-sssn-Breakout-256-20200512141356
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Breakout, latent_shape: 256, margin: 0.2

-------------- Breakout --------------
loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=17)

block: auc 0.9884270882835728
{'block': 0.9884270882835728}
loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=17)

flicker: auc 1.0
{'block': 0.9884270882835728, 'flicker': 1.0}
loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=17)

freeze: auc 0.0019053132305935593
{'block': 0.9884270882835728, 'flicker': 1.0, 'freeze': 0.0019053132305935593}
loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=17)

freeze_skip: auc 0.9646884002356922
{'block': 0.9884270882835728, 'flicker': 1.0, 'freeze': 0.0019053132305935593, 'freeze_skip': 0.9646884002356922}
loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=17)

split_horizontal: auc 0.9791326165873362
{'block': 0.9884270882835728, 'flicker': 1.0, 'freeze': 0.0019053132305935593, 'freeze_skip': 0.9646884002356922, 'split_horizontal': 0.9791326165873362}
loading dataset aad.anomaly.Breakout...


IntProgress(value=0, max=17)

split_vertical: auc 0.9907644390307093
{'block': 0.9884270882835728, 'flicker': 1.0, 'freeze': 0.0019053132305935593, 'freeze_skip': 0.9646884002356922, 'split_horizontal': 0.9791326165873362, 'split_vertical': 0.9907644390307093}
{'block': 0.9884270882835728,
 'flicker': 1.0,
 'freeze': 0.0019053132305935593,
 'freeze_skip': 0.9646884002356922,
 'split_horizontal': 0.9791326165873362,
 'split_vertical': 0.9907644390307093}
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200504_153755-sssn-Enduro-256-20200504163751
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Enduro, latent_shape: 256, margin: 0.2

-------------- Enduro --------------
loading dataset aad.anomaly.Enduro...


IntProgress(value=0, max=5)

block: auc 0.8536652109555745
{'block': 0.8536652109555745, 'flicker': 1.0, 'freeze': 0.0019053132305935593, 'freeze_skip': 0.9646884002356922, 'split_horizontal': 0.9791326165873362, 'split_vertical': 0.9907644390307093}
loading dataset aad.anomaly.Enduro...


IntProgress(value=0, max=5)

flicker: auc 1.0
{'block': 0.8536652109555745, 'flicker': 1.0, 'freeze': 0.0019053132305935593, 'freeze_skip': 0.9646884002356922, 'split_horizontal': 0.9791326165873362, 'split_vertical': 0.9907644390307093}
loading dataset aad.anomaly.Enduro...


IntProgress(value=0, max=5)

freeze: auc 9.033521987300225e-05
{'block': 0.8536652109555745, 'flicker': 1.0, 'freeze': 9.033521987300225e-05, 'freeze_skip': 0.9646884002356922, 'split_horizontal': 0.9791326165873362, 'split_vertical': 0.9907644390307093}
loading dataset aad.anomaly.Enduro...


IntProgress(value=0, max=5)

freeze_skip: auc 0.9986305961970208
{'block': 0.8536652109555745, 'flicker': 1.0, 'freeze': 9.033521987300225e-05, 'freeze_skip': 0.9986305961970208, 'split_horizontal': 0.9791326165873362, 'split_vertical': 0.9907644390307093}
loading dataset aad.anomaly.Enduro...


IntProgress(value=0, max=5)

split_horizontal: auc 0.9828059695012723
{'block': 0.8536652109555745, 'flicker': 1.0, 'freeze': 9.033521987300225e-05, 'freeze_skip': 0.9986305961970208, 'split_horizontal': 0.9828059695012723, 'split_vertical': 0.9907644390307093}
loading dataset aad.anomaly.Enduro...


IntProgress(value=0, max=5)

split_vertical: auc 0.9826143164224351
{'block': 0.8536652109555745, 'flicker': 1.0, 'freeze': 9.033521987300225e-05, 'freeze_skip': 0.9986305961970208, 'split_horizontal': 0.9828059695012723, 'split_vertical': 0.9826143164224351}
{'block': 0.8536652109555745,
 'flicker': 1.0,
 'freeze': 9.033521987300225e-05,
 'freeze_skip': 0.9986305961970208,
 'split_horizontal': 0.9828059695012723,
 'split_vertical': 0.9826143164224351}
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200512_122614-sssn-Pong-256-20200512132611
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Pong, latent_shape: 256, margin: 0.2

-------------- Pong --------------
loading dataset aad.anomaly.Pong...


IntProgress(value=0, max=6)

block: auc 0.9914347378564624
{'block': 0.9914347378564624, 'flicker': 1.0, 'freeze': 9.033521987300225e-05, 'freeze_skip': 0.9986305961970208, 'split_horizontal': 0.9828059695012723, 'split_vertical': 0.9826143164224351}
loading dataset aad.anomaly.Pong...


IntProgress(value=0, max=6)

flicker: auc 1.0
{'block': 0.9914347378564624, 'flicker': 1.0, 'freeze': 9.033521987300225e-05, 'freeze_skip': 0.9986305961970208, 'split_horizontal': 0.9828059695012723, 'split_vertical': 0.9826143164224351}
loading dataset aad.anomaly.Pong...


IntProgress(value=0, max=6)

freeze: auc 0.004387562430288742
{'block': 0.9914347378564624, 'flicker': 1.0, 'freeze': 0.004387562430288742, 'freeze_skip': 0.9986305961970208, 'split_horizontal': 0.9828059695012723, 'split_vertical': 0.9826143164224351}
loading dataset aad.anomaly.Pong...


IntProgress(value=0, max=6)

freeze_skip: auc 0.9381244737345188
{'block': 0.9914347378564624, 'flicker': 1.0, 'freeze': 0.004387562430288742, 'freeze_skip': 0.9381244737345188, 'split_horizontal': 0.9828059695012723, 'split_vertical': 0.9826143164224351}
loading dataset aad.anomaly.Pong...


IntProgress(value=0, max=6)

split_horizontal: auc 0.967066867595678
{'block': 0.9914347378564624, 'flicker': 1.0, 'freeze': 0.004387562430288742, 'freeze_skip': 0.9381244737345188, 'split_horizontal': 0.967066867595678, 'split_vertical': 0.9826143164224351}
loading dataset aad.anomaly.Pong...


IntProgress(value=0, max=6)

split_vertical: auc 0.9697092877596257
{'block': 0.9914347378564624, 'flicker': 1.0, 'freeze': 0.004387562430288742, 'freeze_skip': 0.9381244737345188, 'split_horizontal': 0.967066867595678, 'split_vertical': 0.9697092877596257}
{'block': 0.9914347378564624,
 'flicker': 1.0,
 'freeze': 0.004387562430288742,
 'freeze_skip': 0.9381244737345188,
 'split_horizontal': 0.967066867595678,
 'split_vertical': 0.9697092877596257}
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200429_171611-sssn-Qbert-20200429181610
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Qbert, latent_shape: 64, margin: 0.2

-------------- Qbert --------------
loading dataset aad.anomaly.Qbert...


IntProgress(value=0, max=29)

block: auc 0.9313483070667725
{'block': 0.9313483070667725, 'flicker': 1.0, 'freeze': 0.004387562430288742, 'freeze_skip': 0.9381244737345188, 'split_horizontal': 0.967066867595678, 'split_vertical': 0.9697092877596257}
loading dataset aad.anomaly.Qbert...


IntProgress(value=0, max=29)

flicker: auc 0.9999761928158771
{'block': 0.9313483070667725, 'flicker': 0.9999761928158771, 'freeze': 0.004387562430288742, 'freeze_skip': 0.9381244737345188, 'split_horizontal': 0.967066867595678, 'split_vertical': 0.9697092877596257}
loading dataset aad.anomaly.Qbert...


IntProgress(value=0, max=29)

freeze: auc 0.004766654536971042
{'block': 0.9313483070667725, 'flicker': 0.9999761928158771, 'freeze': 0.004766654536971042, 'freeze_skip': 0.9381244737345188, 'split_horizontal': 0.967066867595678, 'split_vertical': 0.9697092877596257}
loading dataset aad.anomaly.Qbert...


IntProgress(value=0, max=29)

freeze_skip: auc 0.9847884228845487
{'block': 0.9313483070667725, 'flicker': 0.9999761928158771, 'freeze': 0.004766654536971042, 'freeze_skip': 0.9847884228845487, 'split_horizontal': 0.967066867595678, 'split_vertical': 0.9697092877596257}
loading dataset aad.anomaly.Qbert...


IntProgress(value=0, max=29)

split_horizontal: auc 0.9900019621221009
{'block': 0.9313483070667725, 'flicker': 0.9999761928158771, 'freeze': 0.004766654536971042, 'freeze_skip': 0.9847884228845487, 'split_horizontal': 0.9900019621221009, 'split_vertical': 0.9697092877596257}
loading dataset aad.anomaly.Qbert...


IntProgress(value=0, max=29)

split_vertical: auc 0.981957648466138
{'block': 0.9313483070667725, 'flicker': 0.9999761928158771, 'freeze': 0.004766654536971042, 'freeze_skip': 0.9847884228845487, 'split_horizontal': 0.9900019621221009, 'split_vertical': 0.981957648466138}
{'block': 0.9313483070667725,
 'flicker': 0.9999761928158771,
 'freeze': 0.004766654536971042,
 'freeze_skip': 0.9847884228845487,
 'split_horizontal': 0.9900019621221009,
 'split_vertical': 0.981957648466138}
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200429_174221-sssn-Seaquest-20200429184220
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Seaquest, latent_shape: 64, margin: 0.2

-------------- Seaquest --------------
loading dataset aad.anomaly.Seaquest...


IntProgress(value=0, max=10)

block: auc 0.9682516634889227
{'block': 0.9682516634889227, 'flicker': 0.9999761928158771, 'freeze': 0.004766654536971042, 'freeze_skip': 0.9847884228845487, 'split_horizontal': 0.9900019621221009, 'split_vertical': 0.981957648466138}
loading dataset aad.anomaly.Seaquest...


IntProgress(value=0, max=10)

flicker: auc 1.0
{'block': 0.9682516634889227, 'flicker': 1.0, 'freeze': 0.004766654536971042, 'freeze_skip': 0.9847884228845487, 'split_horizontal': 0.9900019621221009, 'split_vertical': 0.981957648466138}
loading dataset aad.anomaly.Seaquest...


IntProgress(value=0, max=10)

freeze: auc 0.0005194014878721607
{'block': 0.9682516634889227, 'flicker': 1.0, 'freeze': 0.0005194014878721607, 'freeze_skip': 0.9847884228845487, 'split_horizontal': 0.9900019621221009, 'split_vertical': 0.981957648466138}
loading dataset aad.anomaly.Seaquest...


IntProgress(value=0, max=10)

freeze_skip: auc 0.9961548392087088
{'block': 0.9682516634889227, 'flicker': 1.0, 'freeze': 0.0005194014878721607, 'freeze_skip': 0.9961548392087088, 'split_horizontal': 0.9900019621221009, 'split_vertical': 0.981957648466138}
loading dataset aad.anomaly.Seaquest...


IntProgress(value=0, max=10)

split_horizontal: auc 0.9928991755116001
{'block': 0.9682516634889227, 'flicker': 1.0, 'freeze': 0.0005194014878721607, 'freeze_skip': 0.9961548392087088, 'split_horizontal': 0.9928991755116001, 'split_vertical': 0.981957648466138}
loading dataset aad.anomaly.Seaquest...


IntProgress(value=0, max=10)

split_vertical: auc 0.9949183888365146
{'block': 0.9682516634889227, 'flicker': 1.0, 'freeze': 0.0005194014878721607, 'freeze_skip': 0.9961548392087088, 'split_horizontal': 0.9928991755116001, 'split_vertical': 0.9949183888365146}
{'block': 0.9682516634889227,
 'flicker': 1.0,
 'freeze': 0.0005194014878721607,
 'freeze_skip': 0.9961548392087088,
 'split_horizontal': 0.9928991755116001,
 'split_vertical': 0.9949183888365146}
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200501_093836-sssn-SpaceInvaders-20200501103833
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: SpaceInvaders, latent_shape: 64, margin: 0.2

-------------- SpaceInvaders --------------
loading dataset aad.anomaly.SpaceInvaders...


IntProgress(value=0, max=17)

block: auc 0.9834330422596319
{'block': 0.9834330422596319, 'flicker': 1.0, 'freeze': 0.0005194014878721607, 'freeze_skip': 0.9961548392087088, 'split_horizontal': 0.9928991755116001, 'split_vertical': 0.9949183888365146}
loading dataset aad.anomaly.SpaceInvaders...


IntProgress(value=0, max=17)

flicker: auc 1.0
{'block': 0.9834330422596319, 'flicker': 1.0, 'freeze': 0.0005194014878721607, 'freeze_skip': 0.9961548392087088, 'split_horizontal': 0.9928991755116001, 'split_vertical': 0.9949183888365146}
loading dataset aad.anomaly.SpaceInvaders...


IntProgress(value=0, max=17)

freeze: auc 0.017912028700705917
{'block': 0.9834330422596319, 'flicker': 1.0, 'freeze': 0.017912028700705917, 'freeze_skip': 0.9961548392087088, 'split_horizontal': 0.9928991755116001, 'split_vertical': 0.9949183888365146}
loading dataset aad.anomaly.SpaceInvaders...


IntProgress(value=0, max=17)

freeze_skip: auc 0.974968887435008
{'block': 0.9834330422596319, 'flicker': 1.0, 'freeze': 0.017912028700705917, 'freeze_skip': 0.974968887435008, 'split_horizontal': 0.9928991755116001, 'split_vertical': 0.9949183888365146}
loading dataset aad.anomaly.SpaceInvaders...


IntProgress(value=0, max=17)

split_horizontal: auc 0.9949124105997612
{'block': 0.9834330422596319, 'flicker': 1.0, 'freeze': 0.017912028700705917, 'freeze_skip': 0.974968887435008, 'split_horizontal': 0.9949124105997612, 'split_vertical': 0.9949183888365146}
loading dataset aad.anomaly.SpaceInvaders...


IntProgress(value=0, max=17)

split_vertical: auc 0.9950918768872543
{'block': 0.9834330422596319, 'flicker': 1.0, 'freeze': 0.017912028700705917, 'freeze_skip': 0.974968887435008, 'split_horizontal': 0.9949124105997612, 'split_vertical': 0.9950918768872543}
{'block': 0.9834330422596319,
 'flicker': 1.0,
 'freeze': 0.017912028700705917,
 'freeze_skip': 0.974968887435008,
 'split_horizontal': 0.9949124105997612,
 'split_vertical': 0.9950918768872543}


In [4]:
pprint(ENV_AUCS)

# AUC RESULTS
# BEAM RIDER
{'block': 0.9346528465961829,
 'flicker': 0.9996786591135681,
 'freeze': 0.004777733739095263,
 'freeze_skip': 0.9877708653393901,
 'split_horizontal': 0.9904982172081528,
 'split_vertical': 0.992672783259697}

#BREAKOUT
{'block': 0.9884270882835728,
 'flicker': 1.0,
 'freeze': 0.0019053132305935593,
 'freeze_skip': 0.9646884002356922,
 'split_horizontal': 0.9791326165873362,
 'split_vertical': 0.9907644390307093}

#ENDURO
{'block': 0.8536652109555745,
 'flicker': 1.0,
 'freeze': 9.033521987300225e-05,
 'freeze_skip': 0.9986305961970208,
 'split_horizontal': 0.9828059695012723,
 'split_vertical': 0.9826143164224351}

#PONG
{'block': 0.9914347378564624,
 'flicker': 1.0,
 'freeze': 0.004387562430288742,
 'freeze_skip': 0.9381244737345188,
 'split_horizontal': 0.967066867595678,
 'split_vertical': 0.9697092877596257}

#QBERT
{'block': 0.9313483070667725,
 'flicker': 0.9999761928158771,
 'freeze': 0.004766654536971042,
 'freeze_skip': 0.9847884228845487,
 'split_horizontal': 0.9900019621221009,
 'split_vertical': 0.981957648466138}

#SEAQUEST
{'block': 0.9682516634889227,
 'flicker': 1.0,
 'freeze': 0.0005194014878721607,
 'freeze_skip': 0.9961548392087088,
 'split_horizontal': 0.9928991755116001,
 'split_vertical': 0.9949183888365146}

#SPACE INVADERS
{'block': 0.9834330422596319,
 'flicker': 1.0,
 'freeze': 0.017912028700705917,
 'freeze_skip': 0.974968887435008,
 'split_horizontal': 0.9949124105997612,
 'split_vertical': 0.9950918768872543}


{'SpaceInvaders': {'block': 0.9834330422596319,
                   'flicker': 1.0,
                   'freeze': 0.017912028700705917,
                   'freeze_skip': 0.974968887435008,
                   'split_horizontal': 0.9949124105997612,
                   'split_vertical': 0.9950918768872543}}


{'block': 0.9884270882835728,
 'flicker': 1.0,
 'freeze': 0.0019053132305935593,
 'freeze_skip': 0.9646884002356922,
 'split_horizontal': 0.9791326165873362,
 'split_vertical': 0.9907644390307093}

In [5]:
def displacement_summary(z): # TODO complexity can be cut in half... 
    # count of D^t_t+1 >= D^t_j j > 1
    z = z[...,np.newaxis]  
    d = ((z.T - z)**2).sum(1)
    print(d.shape)
    s1 = np.diagonal(d,offset=1)[:,np.newaxis]
    sj = d[:-1,1:]
    d = np.triu(np.greater(s1, sj).astype(np.uint8), k=1)
    ds = d.sum(1) # number of non-conforming distances for each state
    return ds, ds.sum() * (2 / ((z.shape[0]-1)*(z.shape[0]-2))) #statistic

def real_displacement_summary(z):
    z = z[...,np.newaxis]  
    d = ((z.T - z)**2).sum(1)
    print(d.shape)
    s1 = np.diagonal(d,offset=1)[:,np.newaxis]
    sj = d[:-1,1:]
    d = np.triu(np.maximum(s1 - sj, 0), k=1)
    
    # mean displacement
    dmu = d.sum(1)
    dmu[:-1] = dmu[:-1] / np.arange(dmu.size-1,0,-1) 
    #max displacment
    dmax = np.max(d, axis=1)
    #min displacement
    dmin = np.min(d, axis=1)
    
    return dmu, dmax, dmin


for env in envs:
    model, config = load_model(env)
    print("\n-------------- {0} --------------".format(env))
    for data in load_clean(env):
        z = encode(model, data.state)
        NDS = displacement_summary(z)
        print(NDS[1])
        break



 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200512_131359-sssn-Breakout-256-20200512141356
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Breakout, latent_shape: 256, margin: 0.2

-------------- Breakout --------------
loading dataset aad.clean.Breakout...


IntProgress(value=0, max=102)

(604, 604)
0.0019173236806003205
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200504_153755-sssn-Enduro-256-20200504163751
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Enduro, latent_shape: 256, margin: 0.2

-------------- Enduro --------------
loading dataset aad.clean.Enduro...


IntProgress(value=0, max=30)

(3326, 3326)
0.0001074889389538829
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200512_122614-sssn-Pong-256-20200512132611
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Pong, latent_shape: 256, margin: 0.2

-------------- Pong --------------
loading dataset aad.clean.Pong...


IntProgress(value=0, max=38)

(2662, 2662)
0.004190577910390407
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200429_171611-sssn-Qbert-20200429181610
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Qbert, latent_shape: 64, margin: 0.2

-------------- Qbert --------------
loading dataset aad.clean.Qbert...


IntProgress(value=0, max=174)

(1081, 1081)
0.0002540074829231456
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200429_174221-sssn-Seaquest-20200429184220
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: Seaquest, latent_shape: 64, margin: 0.2

-------------- Seaquest --------------
loading dataset aad.clean.Seaquest...


IntProgress(value=0, max=61)

(1357, 1357)
8.816902328315319e-05
 -- found local run at /home/ben/Documents/repos/anomaly-detection/wandb/dryrun-20200501_093836-sssn-SpaceInvaders-20200501103833
 -- found config file: 
 -- found 1 model(s): 
 ---- model.pt
 
------- INFO ------- 
 env: SpaceInvaders, latent_shape: 64, margin: 0.2

-------------- SpaceInvaders --------------
loading dataset aad.clean.SpaceInvaders...


IntProgress(value=0, max=104)

(1353, 1353)
0.0005726636854576272


## Uniform Distance Statistic

In [None]:
global i
i = 0

def stats(x):
    _min, _max, _mean, _var = np.min(x), np.max(x), np.mean(x), np.var(x)
    global i
    print(i, "min: {0:8.5f}, max: {1:8.5f}, mean: {2:8.5f}, var: {3:8.5f}".format(_min, _max, _mean, _var))
    i += 1
    return _min, _max, _mean, _var

UDS = {env:[] for env in envs}

for env in envs:
    model, config = load_model(env)
    print("\n-------------- {0} --------------".format(env))
    for data in load_clean(env):
        _, d = distance(model, data.state) #data.state[:1] for pong - there is a problem with the dataset
        rd = np.maximum(d - config.optim_margin, 0)
        _, _, _, var = stats(rd)
        UDS[env].append(var)
        
_UDS = {k:np.mean(v) for k,v in UDS.items()}
pprint(_UDS)

In [None]:
env = "Pong"
model, config = load_model(env)
print("\n-------------- {0} --------------".format(env))
for data in load_clean(env):
    _, d = distance(model, data.state)
    J.plot(np.arange(0,d.shape[0]), d)
    break



In [None]:

#UDS results
{'BeamRider(64)': 0.0038250058,
 'Breakout(256)': 10.5363865,
 'Enduro(256)': 0.00034870935,
 'Pong(256)': 0.0020768184,
 'Qbert(64)': 0.0039075827,
 'Seaquest(64)': 0.0019275966,
 'SpaceInvaders(64)': 0.00080786174}


In [None]:
#FIX PONG FIRST STATE


import pyworld.toolkit.tools.fileutils as fu
import pyworld.toolkit.tools.visutils.transform as T
import pyworld.toolkit.tools.visutils.jupyter as J
import pyworld.toolkit.tools.torchutils as tu
import numpy as np
import datasets
import os

env = "Pong"
dataset = datasets.dataset('aad.clean.{0}'.format(env))
files = dataset.files()
print(files)
files = [files[12],files[29]]


from types import SimpleNamespace
import anomapy.train.sssn as sssn
import pyworld.toolkit.tools.wbutils as wbu


def distance(model, episode):
    z, d = sssn.distance(model, episode)
    return tu.to_numpy(z), tu.to_numpy(d)

dryruns = sorted([r for r in wbu.dryruns() if env in r])
print(dryruns[-1])

models, config = wbu.load(dryruns[-1]) #load the most recently trained model
model = models['model.pt'].load(sssn.model(**config))
config = SimpleNamespace(**config)

def plot_distance(model, episode):
    z, d = distance(model, episode)
    images = T.HWC(tu.to_numpy(episode))

    plot = J.SimplePlot(np.arange(len(d)), d)
    image1 = J.SimpleImage(images[0])
    image2 = J.SimpleImage(images[1])

    def on_hover(trace, points, state):
        i = points.point_inds[0]
        image1.set_image(images[i])
        image2.set_image(images[i+1])

    plot.on_hover(on_hover)
    d_images = J.layout_horizontal(image1.fig, image2.fig)
    d_plot = J.layout_horizontal(plot.fig)
    J.display(J.layout_vertical(d_plot, d_images))

env = "Pong"
dataset = datasets.dataset('aad.clean.{0}'.format(env))
#dataset.state.transform.to_float().CHW().torch()
episodes = [x for x in dataset.load_files(*files)]
for i, episode in enumerate(episodes):
    #print(episode.state.dtype, episode.state.shape)
    file = os.path.join(dataset.path, "raw/" + files[i])
    #print(episode.state.shape[0])
    #episode.state = episode.state[1:]
    print(episode.action.shape[0])
    episode.action = episode.action[1:]
    
    print(episode.action.shape[0], episode.state.shape[0])
    #print(episode.state.shape[0])
    fu.save(file, episode.__dict__, overwrite=True)
    
    #plot_distance(model, episode.state)
