In [1]:
%matplotlib notebook
#%pprint
import notebook
import os
import torchsummary
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
from types import SimpleNamespace
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import sklearn
import sklearn.metrics

from tqdm.auto import tqdm

import utils
from SSIM import SSIM # structural similarity loss...

import jnu as J

# Evaluation configuration: target device and the auto-encoder's input / latent shapes.
config = dict(
    device = "cuda:0",
    state_shape = (3,84,84),
    latent_shape = (256,),
)
config = SimpleNamespace(**config)

# Auto-encoder under evaluation.
model = utils.AE(config.state_shape, config.latent_shape)
# map_location: materialise the checkpoint tensors directly on the target device instead of
# whatever device they were saved from.
model.load_state_dict(torch.load("./AE-SSIM-256-3.pt", map_location=config.device))
# eval(): this notebook only runs inference, so put any dropout / batch-norm layers into
# evaluation mode (a no-op if the model has none).
model = model.to(config.device).eval()
#criterion = nn.BCEWithLogitsLoss(reduction='none')
# Per-frame anomaly score used by the cells below.
# NOTE(review): SSIM comes from the local SSIM module (described at import as a "structural
# similarity loss"); confirm it returns a dissimilarity (higher = worse reconstruction),
# otherwise the ROC / PR curves computed from it are inverted.
criterion = SSIM(size_average=False)

# Label thresholds τ in descending order: [400, 200, 100, 50, 10, 0].
# Presumably a frame counts as positive when its bug-pixel count exceeds τ (as in the local
# `roc`/`pr` helpers further down) -- TODO confirm for utils.roc / utils.pr.
label_thresholds = list(reversed([0, 10, 50, 100, 200, 400]))


In [None]:
# ROC and precision/recall curves for every test file: one ROC panel and one PR panel per
# file on a 5x4 grid (room for 10 files). AUCs are collected into `metrics` for the table.
fig, axes = plt.subplots(5,4, sharex=True, sharey=True, figsize=(10,10), 
                         gridspec_kw=dict(wspace=-0.2, hspace=0.2, left=0.05, right=.95, top=.95, bottom=0.05))
import matplotlib as mpl
cmap = mpl.colormaps['viridis']

axes = axes.ravel()
n_used = 2 * len(utils.FILES_TEST)
assert n_used <= len(axes), f"{len(utils.FILES_TEST)} test files need {n_used} panels, grid has {len(axes)}"
for ax in axes:
    ax.set_aspect('equal')
    # one viridis colour per label threshold, so a colour means the same τ in every panel
    ax.set_prop_cycle('color',cmap(1 - np.linspace(0,1,len(label_thresholds))))
# switch off the panels that no file needs
for ax in axes[n_used:]:
    ax.axis("off")

font = {'fontsize':10}
plt.set_cmap("viridis")

# Results table: per metric a header row, the τ row, and one row of AUCs per file.
# NOTE: a later cell defines a function that is also named `metrics` and shadows this dict.
metrics = dict(auc_pr={"Precision/Recall AUC":[], "τ":label_thresholds}, 
               auc_roc={"ROC AUC":[], "τ":label_thresholds})

for i, file in enumerate(utils.FILES_TEST):
    name = os.path.splitext(os.path.split(file)[1])[0].replace("Geometry", "Geom").replace("Texture", "Tex")
    print(name)
    # the two panels of this file (`i` itself is not reassigned)
    ax_roc, ax_pr = axes[2 * i], axes[2 * i + 1]

    # Score every frame of the file; inference only, so no autograd graph is needed.
    ys, ys_ = [], []
    with torch.no_grad():
        for x, y in utils.load(file, ['observation', 'bugmask']):
            x = torch.from_numpy(x).to(config.device)
            x_ = utils.forward(model, x, device=config.device)
            ys_.append(criterion(x, x_).cpu().numpy()) # score (1D)
            y = (y[...].sum(1, keepdims=True) > 0.)     # pixel is a bug if any channel of the mask is set
            ys.append(y.sum(-1).sum(-1).sum(-1)) # label (1D): bug pixels per frame
    y, y_ = np.concatenate(ys), np.concatenate(ys_)

    # ROC
    thresholds, fprtpr, aucs = utils.roc(y, y_, label_thresholds)
    for lt, (fpr, tpr) in zip(thresholds, fprtpr):
        ax_roc.plot(fpr, tpr, alpha=0.8, label=f"τ={lt}")
    ax_roc.set_title(name + " (ROC)", fontdict=font)
    metrics['auc_roc'][name] = aucs

    # Precision / recall
    thresholds, rp, aucs = utils.pr(y, y_, label_thresholds)
    for lt, (r, p) in zip(thresholds, rp):
        ax_pr.plot(r, p, alpha=0.8, label=f"τ={lt}")
    ax_pr.set_title(name + " (PR)", fontdict=font)
    metrics['auc_pr'][name] = aucs

# Colours are identical across panels, so one legend suffices: put it on the last panel
# that was actually drawn (axes[-1] is empty when there are fewer than 10 files).
if n_used:
    axes[n_used - 1].legend()

def format_x(x):
    """Format one cell of the LaTeX results table.

    Integers (e.g. the τ thresholds) are rendered verbatim; other real numbers
    (e.g. AUC values) with three decimals. numpy scalars are accepted as well:
    numpy registers its integer / floating types with the ``numbers`` ABCs, so
    values such as ``np.int64`` or ``np.float32`` (which are not subclasses of
    ``int`` / ``float``) no longer raise.

    Raises
    ------
    ValueError
        If ``x`` is not a real number.
    """
    import numbers

    # Integral must be tested first: every Integral is also a Real.
    if isinstance(x, numbers.Integral):
        return str(x)
    if isinstance(x, numbers.Real):
        return "{0:0.3f}".format(x)
    raise ValueError(f"cannot format {x!r} (type {type(x).__name__}) as a table cell")

# Dump the AUC tables as LaTeX rows: "<row name> & c1 & c2 ... \\". Cell values are
# emitted in reverse order, so the τ row reads from the smallest threshold to the largest.
print("Auto-Encoder (SSIM)")
for k, v in metrics.items():
    for n, x in v.items():
        print(" & ".join([n, *[format_x(i) for i in x][::-1]]) + " \\\\")

In [None]:
# Write the *current* pyplot figure to disk at 200 dpi; run this right after the ROC/PR grid
# cell so that the grid is still the current figure.
plt.savefig(fname="./AE-SIM-256-ROC-PR.png", dpi=200)

In [None]:
# EXPLORE SINGLE: reconstruct the first batch of one test file, show it next to the input and
# the bug mask, and plot the per-frame score against a binary "frame contains a bug" label.
file_id = 6 #2 # 5
file = utils.FILES_TEST[file_id]

name = os.path.splitext(os.path.split(file)[1])[0]
print(name)
ys, ys_ = [], []
with torch.no_grad():  # inference only
    for x, y in utils.load(file, ['observation', 'bugmask']):
        x = torch.from_numpy(x).to(config.device)
        x_ = utils.forward(model, x, device=config.device)

        # input | reconstruction | raw bug mask, concatenated along axis 3
        # NOTE(review): assumes the raw bug mask has the same number of channels as the
        # observation (the other inspection cell repeats a 1-channel mask to 3) -- confirm.
        J.images(np.concatenate([x.cpu().numpy(), x_.cpu().numpy(), y], axis=3), scale=3)

        ys_.append(criterion(x, x_).cpu().numpy()) # score (1D)
        y = (y[...].sum(1, keepdims=True) > 0.)
        ys.append(y.sum(-1).sum(-1).sum(-1)) # label (1D): bug pixels per frame
        break  # only the first batch is explored

y, y_ = np.concatenate(ys), np.concatenate(ys_)
y = (y > 0).astype(np.float32)  # 1 where the frame contains at least one bug pixel
fig, ax = plt.subplots(figsize=(10,5))
ax.plot(np.arange(y.shape[0]), y, alpha=0.8, label="label (bug present)")
ax.plot(np.arange(y_.shape[0]), y_, alpha=0.8, label="score")
ax.set(title=name, xlabel="frame")
ax.legend()


In [None]:
def model_forward(model, obs, batch_size=512):
    """Run ``model`` over ``obs`` in mini-batches and return the outputs clipped to [0, 1].

    Parameters
    ----------
    model : callable
        Maps a batch tensor to an output batch (e.g. the auto-encoder).
    obs : torch.Tensor or numpy.ndarray
        Stacked inputs; batches are taken along the first dimension. A numpy array
        is converted with ``torch.as_tensor``, as DataLoader's default collate used to do.
    batch_size : int
        Number of samples per forward pass (default 512, as before).

    Returns
    -------
    torch.Tensor
        Model outputs concatenated along dim 0 and clipped to [0, 1].
    """
    obs = torch.as_tensor(obs)
    with torch.no_grad():
        # torch.split yields views of consecutive chunks along dim 0: the same batches a
        # DataLoader would produce, without indexing every sample and re-stacking it.
        result = [model(batch) for batch in torch.split(obs, batch_size)]
        return torch.clip(torch.cat(result), 0, 1)


    
def metrics(label, score, thresholds=np.arange(1,20) * 0.05, normalize=None, beta=1.0):
    """Sweep thresholds over a min-max normalised score and plot the confusion counts.

    NOTE: this function shares its name with the ``metrics`` results dict built in the
    ROC/PR grid cell and shadows it once this cell has run.

    Parameters
    ----------
    label : np.ndarray
        Per-frame ground truth; any value > 0 counts as positive.
    score : np.ndarray
        Per-frame score with the same length and rank as ``label``. It is rescaled
        to [0, 1] (min -> 0, max -> 1) before thresholding.
    thresholds : array-like
        Thresholds applied to the rescaled score (``score > threshold`` -> positive).
    normalize : {None, 'true', 'pred', 'all'}
        Forwarded to ``sklearn.metrics.confusion_matrix``.
    beta : float
        F-beta weight forwarded to ``sklearn.metrics.precision_recall_fscore_support``.

    Returns
    -------
    cms : np.ndarray, shape (len(thresholds), 4)
        One (TN, FP, FN, TP) row per threshold.
    prf : list
        ``(precision, recall, fbeta, support)`` per threshold.
    """
    assert label.shape[0] == score.shape[0]
    assert len(label.shape) == len(score.shape)
    score = np.interp(score, (score.min(), score.max()), (0, 1)).astype(np.float32)
    label = (label > 0).astype(np.float32)

    prf = []
    cms = []
    for threshold in thresholds:
        _score = (score > threshold).astype(np.float32)
        # Pin both classes so ravel() always yields the 4 entries (TN, FP, FN, TP), even
        # when only one class occurs; otherwise np.stack below fails on ragged rows.
        cms.append(sklearn.metrics.confusion_matrix(label, _score, labels=[0.0, 1.0], normalize=normalize).ravel())
        prf.append(sklearn.metrics.precision_recall_fscore_support(label, _score, beta=beta, pos_label=1, average='binary'))
        print(threshold, prf[-1])

    cms = np.stack(cms)
    fig, ax = plt.subplots()
    for i, l in enumerate(["TN", "FP", "FN", "TP"]):
        ax.plot(thresholds, cms[:, i], label=l)
    ax.set_xlabel("threshold on normalised score")
    ax.set_ylabel("count" if normalize is None else f"fraction (normalize={normalize!r})")
    ax.legend()
    return cms, prf
    
def roc(label, score, label_thresholds=(0, 10, 50, 100, 200), alpha=0.5, ax=None):
    """ROC curves of ``score`` against ``label`` binarised at several label thresholds.

    For each ``lt`` in ``label_thresholds``, a frame is positive when ``label > lt``.

    Parameters
    ----------
    label, score : np.ndarray
        Per-frame values with the same length and rank.
    label_thresholds : sequence of numbers
        Label cut-offs. The default is an immutable tuple because this argument is
        returned to the caller, and a returned default list could be mutated and leak
        into later calls.
    alpha : float
        Line transparency when plotting.
    ax : matplotlib Axes, optional
        If given, one curve per threshold is drawn, with its AUC in the legend.

    Returns
    -------
    results : list of (fpr, tpr, thresholds)
        ``sklearn.metrics.roc_curve`` output for each label threshold.
    label_thresholds
        The ``label_thresholds`` argument, unchanged.
    """
    assert label.shape[0] == score.shape[0]
    assert len(label.shape) == len(score.shape)
    results = []
    for lt in label_thresholds:
        _label = (label > lt).astype(np.float32)
        fpr, tpr, thresholds = sklearn.metrics.roc_curve(_label, score)
        results.append((fpr, tpr, thresholds))
        if ax is not None:
            auc = sklearn.metrics.auc(fpr, tpr)
            ax.plot(fpr, tpr, alpha=alpha, label=f"auc={auc:.3f}, τ={lt}")
    if ax is not None:
        ax.set_xlabel("false positive rate")
        ax.set_ylabel("true positive rate")
        ax.legend()
    return results, label_thresholds


def pr(label, score, label_thresholds=(0, 10, 50, 100, 200), alpha=0.5, ax=None):
    """Precision/recall curves of ``score`` against ``label`` at several label thresholds.

    For each ``lt`` in ``label_thresholds``, a frame is positive when ``label > lt``.

    Parameters
    ----------
    label, score : np.ndarray
        Per-frame values with the same length and rank.
    label_thresholds : sequence of numbers
        Label cut-offs. The default is an immutable tuple because this argument is
        returned to the caller, and a returned default list could be mutated and leak
        into later calls.
    alpha : float
        Line transparency when plotting.
    ax : matplotlib Axes, optional
        If given, one curve per threshold is drawn, with the area under it in the legend.

    Returns
    -------
    results : list of (precision, recall, thresholds)
        ``sklearn.metrics.precision_recall_curve`` output for each label threshold.
    label_thresholds
        The ``label_thresholds`` argument, unchanged.
    """
    assert label.shape[0] == score.shape[0]
    assert len(label.shape) == len(score.shape)
    results = []
    for lt in label_thresholds:
        _label = (label > lt).astype(np.float32)
        p, r, thresholds = sklearn.metrics.precision_recall_curve(_label, score)
        results.append((p, r, thresholds))
        if ax is not None:
            # trapezoidal area under the PR curve (recall on the x axis)
            auc = sklearn.metrics.auc(r, p)
            ax.plot(r, p, alpha=alpha, label=f"auc={auc:.3f}, τ={lt}")
    if ax is not None:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend()
    return results, label_thresholds
    
# Per-episode inspection. For each test file, the selected batches are shown as
# input | reconstruction | bug mask, together with score vs label over time and ROC/PR curves.
select = [0]  # indices of the batches yielded by utils.load to inspect, per file
plot = True
LABEL_THRESHOLDS = [0,10,50,100,200,400]
assert select, "select at least one batch index"
last_selected = max(select)

# `FILES` / `load` are not defined anywhere in this notebook (NameError on a fresh kernel).
# Use the same test-file list and loader as the evaluation cell.
# NOTE(review): the bug name is now derived from the file name -- confirm this matches the old FILES.
for file in utils.FILES_TEST:
    bug = os.path.splitext(os.path.split(file)[1])[0]
    print(bug)
    for i, (obs, label) in enumerate(utils.load(file, ["observation", "bugmask"])):
        if i not in select:
            continue
        obs = torch.from_numpy(obs[...]).to(config.device)
        pred = model_forward(model, obs)
        label = (label[...].sum(1, keepdims=True) > 0.)  # pixel is a bug if any mask channel is set

        # input | reconstruction | mask (mask repeated to 3 channels), concatenated along axis 3
        J.images(np.concatenate([obs.cpu().numpy(), pred.cpu().numpy(), label.repeat(3, axis=1)], axis=3), scale=3)

        score = criterion(pred, obs).cpu().numpy()
        label = label.sum(-1).sum(-1).sum(-1)  # bug pixels per frame
        assert label.sum() > 0 # there must be some bugs present in the episode!

        if plot:
            _score = np.interp(score, (score.min(), score.max()), (0, 1)).astype(np.float32)
            _label = np.interp(label, (label.min(), label.max()), (0, 1)).astype(np.float32)
            fig, ax = plt.subplots(figsize=(10,3))
            ax.plot(np.arange(_score.shape[0]), _score, label="score")
            ax.plot(np.arange(_label.shape[0]), _label, label="label", alpha=0.5)
            ax.set(title=f"{bug} (batch {i})", xlabel="frame", ylabel="min-max normalised")
            ax.legend()

        fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(8, 4))
        roc(label, score, ax=ax_roc, label_thresholds=LABEL_THRESHOLDS)
        pr(label, score, ax=ax_pr, label_thresholds=LABEL_THRESHOLDS)
        fig.tight_layout(pad=2)  # after drawing, so labels and legends are taken into account

        # stop reading the file once every selected batch has been shown
        if i >= last_selected:
            break

    print("----------------------------------------------------------------------")

In [None]:
# Print the distinct values found in each bug mask of `test_data`.
# NOTE(review): `test_data` is not defined anywhere in this notebook (presumably left over from
# a deleted cell), so this cell raises NameError on a fresh kernel; `data` is never used.
data = {}
for obs, label in test_data:
    print(np.unique(label))

In [None]:
# Swap in another checkpoint (AE-SSIM-256.pt), built and kept on the CPU, for the exploration below.
# NOTE: this rebinds the global `model` and `criterion` used by the earlier cells; re-run the
# first cell before re-running those.
model = utils.AE(config.state_shape, config.latent_shape)  # bare `AE` is not defined in this notebook
# map_location="cpu": the model was constructed on the CPU, so load the tensors there as well.
model.load_state_dict(torch.load("./AE-SSIM-256.pt", map_location="cpu"))
model.eval()  # inference only: any dropout / batch-norm layers go into evaluation mode
#criterion = nn.BCEWithLogitsLoss(reduction='none')
criterion = SSIM(size_average=False)

In [None]:
# Score the first `n` frames of every `test_data` episode on the CPU, then overlay the
# min-max normalised score with the min-max normalised bug-mask sum.
# NOTE(review): `test_data` is not defined anywhere in this notebook, so this cell only works in a
# kernel that still holds it. It also moves the global `model` to the CPU.
with torch.no_grad():
    n = 1024
    model = model.cpu()
    for obs, mask in test_data:
        mask = mask[:n]
        obs = obs[:n].cpu()
        pred = model(obs).cpu()

        score = criterion(pred, obs)
        print(score.shape)
        # collapse any trailing dims into a single value per frame, then rescale to [0, 1]
        score = score.reshape(score.shape[0], -1).sum(-1).cpu().numpy()
        score = np.interp(score, (score.min(), score.max()), (0, 1))

        label = mask.reshape(mask.shape[0], -1).sum(-1).cpu().numpy()
        label = np.interp(label, (label.min(), label.max()), (0, 1))

        # preview: clipped reconstruction | input | mask, concatenated along dim 3
        J.images(torch.cat([torch.clip(pred, 0, 1), obs, mask], dim=3))
        fig = plt.figure(figsize=(10, 5))
        plt.plot(np.arange(score.shape[0]), score, label="score")
        plt.plot(np.arange(score.shape[0]), label, label="label")
        plt.legend()