In [None]:
import numpy as np
import torch
from metadata_utils import InstanceMNIST
from metadata_utils import collate_fn_2 as collate_fn

from torch.utils.data import DataLoader
from dataeval.utils.data.datasets import MNIST, ShipDataset
from dataeval.utils._array import as_numpy

import matplotlib.pyplot as plt


from vae_models import VAEcf, vae_loss, ConvVAE
from vae_models import normdot

from ood_detector import OODdetector
from torchvision import transforms

import torchvision.transforms.functional as F
from PIL import Image

from ood_detector import prepare_naruto


device = "cuda" if torch.cuda.is_available() else "cpu"

%load_ext autoreload
%autoreload 2

# MNIST


In [None]:
# Corruption names requested from InstanceMNIST; each one is read back below
# as an attribute of the resulting object.
corruption_list = ["identity", "translate", "shot_noise", "motion_blur", "scale", "shear", "rotate"]

mnist = InstanceMNIST(corruption_list, download=True, size=8000)
mnist_val = InstanceMNIST("identity", train=False, size=8000)
mnist_all = InstanceMNIST("identity", size=60000)

# Uncorrupted splits: reference (train) and validation.
refdata, valdata = mnist.identity, mnist_val.identity

# One dataset per corruption.
shiftdata, spikydata, blurdata = mnist.translate, mnist.shot_noise, mnist.motion_blur
scaledata, sheardata, rotatedata = mnist.scale, mnist.shear, mnist.rotate

# Uncorrupted data from the size=60000 instance.
alltrain = mnist_all.identity

In [None]:
# Load 2000 images each from the MNIST train and test splits, scaled to the
# unit interval and laid out channels-first.
_mnist_kwargs = dict(
    root="./data/",
    download=True,
    size=2000,
    unit_interval=True,
    channels="channels_first",
)
train_ds = MNIST(train=True, **_mnist_kwargs)
val_ds = MNIST(train=False, **_mnist_kwargs)

# Split out the images and labels (read from the datasets' private fields).
images = train_ds._data
labels = train_ds._targets
val_images = val_ds._data
val_labels = val_ds._targets

# Shape of a single image.
input_shape = images[0].shape

In [None]:
# from metadata_tools import ks_compare

big_batch_size = 2000


def make_loader(dataset, batch_size=big_batch_size):
    """Wrap `dataset` in a DataLoader that uses the notebook's `collate_fn`."""
    return DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size)


def grab(loader):
    """Return element 0 of the first batch produced by `loader`."""
    return next(iter(loader))[0]


# Uncorrupted reference / validation loaders.
refbb = make_loader(refdata)
valbb = make_loader(valdata)

# One loader per corruption (translate, shot noise, motion blur, scale, shear, rotate).
corr1bb = make_loader(shiftdata)
corr2bb = make_loader(spikydata)
corr3bb = make_loader(blurdata)
corr4bb = make_loader(scaledata)
corr5bb = make_loader(sheardata)
corr6bb = make_loader(rotatedata)

# Build this loader straight from `mnist_all.identity` rather than from
# `alltrain`: `alltrain` is rebound to a tensor below, so re-running the cell
# would otherwise hand a tensor to DataLoader.
trainload = make_loader(mnist_all.identity, batch_size=60000)

xtrain = grab(refbb)
xval = grab(valbb)
xcorr = grab(corr1bb)
alltrain = grab(trainload)

# Uniform-noise images with the same shape/dtype as the corrupted batch.
xnoise = torch.rand_like(xcorr)


In [None]:
# Show, side by side, the image at one random index from the reference batch,
# the translate-corrupted batch and the uniform-noise batch.
# The upper bound follows the actual batch sizes instead of a hard-coded 2000.
pick = np.random.randint(0, min(len(xtrain), len(xcorr), len(xnoise)))
plt.imshow(
    torch.concatenate(
        (xtrain[pick, 0, :, :], xcorr[pick, 0, :, :], xnoise[pick, 0, :, :]), 1
    )
)
plt.title(pick)

In [None]:
# newdet.show(xcorr)
from vae_models import VAE

olddet = OODdetector(VAE(latent_dim=20), xtrain, xval)

olddet.load_VAE()

methods = [
    olddet.manifold_distance,
    olddet.manifold_distance_full_svd,
    olddet.manifold_distance_normalized,
    olddet.manifold_distance_weighted,
]
for method in methods:
    olddet.show_dists(xcorr, method=method)
    plt.title(plt.gca().get_title())
    plt.show()

In [None]:
olddet.show_interpolated_images(xcorr[0], use_self=False)

In [None]:
olddet.train()


In [None]:
# Place the model on the notebook-level `device` rather than calling `.cuda()`
# unconditionally, so the cell also runs on a machine without a GPU.
mnmodel = ConvVAE(nx=28, ny=28, latent_dim=20, img_channels=1).to(device)
mndet = OODdetector(mnmodel, refbb, valbb, batch_size=1000)


In [None]:
mndet.train2()

In [None]:
test_image = grab(valbb)[0:10]

mndet.show_interpolated_images(test_image, use_self=True)

In [None]:
# norm=False
# plt.ecdf(psc:=olddet.p_xy(xcorr, normalize=norm))
# plt.ecdf(psv:=olddet.p_xy(xval, normalize=norm))
# xall, yall = plt.xlim(), plt.ylim()

# sens=0.9
# alpha = 0.02
# plt.plot(xall, [sens, sens])
# xpick = np.interp(sens,  np.linspace(0,1,len(xval)), np.sort(psc))
# plt.plot([xpick, xpick], yall)
# plt.title(f'val: {np.mean(psv < alpha):.2f}, test: {np.mean(psc < alpha):.2f}');

In [None]:
# norm= True
# plt.ecdf(psv:=olddet.one_nn_ecdf(xval, normalized=norm))
# plt.ecdf(psc:=olddet.one_nn_ecdf(xcorr, normalized=norm))
# xall, yall = plt.xlim(), plt.ylim()

# sens=0.9
# alpha = 0.1
# plt.plot(xall, [sens, sens])
# xpick = np.interp(sens,  np.linspace(0,1,len(xval)), np.sort(psc))
# plt.plot([xpick, xpick], yall)
# plt.title(f'val: {np.mean(psv < alpha):.2f}, test: {np.mean(psc < alpha):.2f}');

In [None]:
# norm=False
# plt.ecdf(psv:=newdet.one_nn_ecdf(xval, normalized=norm))
# plt.ecdf(psc:=newdet.one_nn_ecdf(xcorr, normalized=norm))
# xall, yall = plt.xlim(), plt.ylim()

# sens=0.9
# alpha = 0.1
# plt.plot(xall, [sens, sens])
# xpick = np.interp(sens,  np.linspace(0,1,len(xval)), np.sort(psc))
# plt.plot([xpick, xpick], yall)
# plt.title(f'val: {np.mean(psv < alpha):.2f}, test: {np.mean(psc < alpha):.2f}');

In [None]:
#

In [None]:
# k=4
# sens = 0.9
# alpha = 0.01
# plt.ecdf(psv:=olddet.p_xy(xval,k=k))
# plt.ecdf(psc:=olddet.p_xy(xcorr, k=k))
# xall, yall = plt.xlim(), plt.ylim()
# plt.plot(xall, [sens, sens])
# xpick = np.interp(sens,  np.linspace(0,1,2000), np.sort(psc))
# plt.plot([xpick, xpick], yall)
# plt.title(f'val: {np.mean(psv < alpha):.2f}, test: {np.mean(psc < alpha):.2f}')

In [None]:
# newdet = OODdetector(VAE(latent_dim=50), xtrain, xval)
# newdet.train()

In [None]:
# fulldet = OODdetector(VAE(latent_dim=20), training_data=alltrain, validation_data=xval)

In [None]:
# fulldet.train(epochs=100)

In [None]:
# fulldet.show_interpolated_images(xtrain[0:10])

In [None]:
# methods = [fulldet.manifold_distance, fulldet.manifold_distance_full_svd, fulldet.manifold_distance_normalized, fulldet.manifold_distance_weighted]
# for method in methods:
#     fulldet.show_dists(xcorr, method=method)
#     plt.title(plt.gca().get_title())
#     plt.show()

In [None]:
from itertools import combinations


def test_ood(detector, FPfrac=None):
    # Here are my null hypothesis tests for OOD
    keys1 = ["spread", "xy", "ID", "1NN", "normdist"]
    meth1 = [
        detector.p_spread,
        detector.p_xy,
        detector.p_ID,
        detector.one_nn_distance_pval,
        detector.p_dist_normalized,
    ]
    val1results = [meth(detector.test_example) for meth in meth1]
    val1dict = dict(zip(keys1, val1results))

    FPfrac = 0.1 if FPfrac is None else FPfrac

    corrnames = ["shift", "spiky", "blur", "scale", "shear", "rotate"]
    corrlist = [corr1bb, corr2bb, corr3bb, corr4bb, corr5bb, corr6bb]

    for icorr, corrloader in enumerate(corrlist):
        for xcorr, _, _ in corrloader:
            break
        print(f"==============={corrnames[icorr]}=========================")
        corr1results = [meth(xcorr) for meth in meth1]
        corr1dict = dict(zip(keys1, corr1results))
        rlist = []
        for ncombo in range(1, len(keys1)):
            klist = [k for k in combinations(keys1, ncombo)]
            vlist = [v for v in combinations(val1results, ncombo)]
            results = dict({})
            for kc, vc in zip(klist, vlist):
                pv, pc = np.ones(len(xval)), np.ones(len(xcorr))
                rkey = ""
                for k, v in zip(kc, vc):
                    rkey = rkey + " " + k
                    pv *= val1dict[k]
                    pc *= corr1dict[k]
                cutoff = np.interp(FPfrac, np.linspace(0, 1, len(pv)), np.sort(pv))
                tp05 = np.interp(cutoff, np.sort(pc), np.linspace(0, 1, len(pc)))
                results.update({rkey: tp05})
            rlist.append(results)

        for i, d in enumerate(rlist):
            print(f"combination of {i+1} tests:")
            for k, v in d.items():
                print(f"{k}: {v:.2f}")
            print("")

In [None]:
test_ood(mndet, FPfrac=0.01)

In [None]:
# test_ood(fulldet, FPfrac=0.01)

In [None]:
# allnoise = torch.rand(size=alltrain.shape)
# valnoise = torch.rand(size=xval.shape)
# noisedet = OODdetector(VAE(latent_dim=20), training_data=allnoise, validation_data=valnoise)


<!-- 32605902. -->


In [None]:
# noisedet.train(epochs=5000)

In [None]:
# noisedet.show_interpolated_images(xval[np.random.randint(0, 1000, 10)], use_self=True)

In [None]:
# pv = fulldet.p_dist_normalized(xval) *fulldet.p_xy(xval) # *olddet.p_ID(xval) #*olddet.p_xy(xval) #* olddet.p_xy(xval, k=k)* olddet.one_nn_ecdf(xval, normalized=True)
# pc = fulldet.p_dist_normalized(xcorr)* fulldet.p_xy(xcorr)# *olddet.p_ID(xcorr) #*olddet.p_ID(xcorr)#* olddet.p_xy(xcorr, k=k)* olddet.one_nn_ecdf(xcorr, normalized=True)


In [None]:
# # k=5
# vline = plt.ecdf(pv)
# cline = plt.ecdf(pc)
# plt.legend([vline, cline], ['val', 'corr'])
# maxdiff, alpha = fulldet.cdf_max_diff(pc, pv)
# plt.title(f'val: {np.mean(pv < alpha):.2f}, test: {np.mean(pc < alpha):.2f}');
# yall = plt.ylim()
# plt.plot([alpha, alpha], yall)
# print(maxdiff)

# FPfrac = 0.01
# cutoff = np.interp(FPfrac, np.linspace(0, 1, len(pv)), np.sort(pv) )
# tp05 = np.interp(cutoff, np.sort(pc), np.linspace(0, 1, len(pc)))
# print(f'TP@FP01: {tp05:.2f}')
# # plt.xlim([0,2*alpha])

In [None]:
# p0 = np.expand_dims(olddet.p_spread(xcorr), 1)
# p1 = np.expand_dims(olddet.p_xy(xcorr), 1)
# p2 = np.expand_dims(olddet.p_ID(xcorr), 1)
# p3 = np.expand_dims(olddet.one_nn_ecdf(xcorr), 1)
# p4 = np.expand_dims(olddet.p_dist_normalized(xcorr), 1)

# pmat = np.concatenate((p0, p1, p2, p3, p4), axis=1)
# pmat0 = np.mean(pmat, axis=0, keepdims=True)
# pmat -= pmat0
# pmat.shape
# cov = pmat.T @ pmat/(len(pmat)-1)

# std_vec = np.sqrt(np.diag(cov))

# std_inv = np.diag(1 / std_vec)
# corr_matrix = std_inv @ cov @ std_inv

# corr_matrix

In [None]:
# pvalsv = olddet.one_nn_distance_pval(xval)
# pvalsc = olddet.one_nn_distance_pval(xcorr)
# plt.hist((np.log10(pvalsv), np.log10(pvalsc)), 100, density=True, cumulative= True );
# alpha = 0.0005
# print(np.mean(pvalsv < alpha), np.mean(pvalsc < alpha))
# print(olddet.cdf_max_diff(pvalsv, pvalsc))

In [None]:
# k = 20
# plt.hist((olddet.p_ID(xval, k=k), olddet.p_ID(xcorr, k=k)), 50, cumulative=True, density=True);
# alpha = 0.01
# (np.mean(olddet.p_ID(xval, k=k)<alpha), np.mean(olddet.p_ID(xcorr, k=k) < alpha))

In [None]:
#

In [None]:
# for k in range(2, 20):
#     plt.ecdf(np.log10(olddet.p_ID(xval, k=k)))
#     plt.ecdf(np.log10(olddet.p_ID(xcorr, k=k)))
#     plt.title(str(k))
#     plt.show()


In [None]:
# muv, logvarv = olddet.model.encode(xval[0].view((1,-1)))
# stdv = torch.exp(0.5 * logvarv)
# mun, logvarn = olddet.model.encode(naruto[0].view((1,-1)))
# stdn = torch.exp(0.5 * logvarn)


In [None]:
# plt.plot(stdv.flatten().detach())
# plt.plot(stdn.flatten().detach())


In [None]:
# newdet.set_learning_rate(1e-6)
# newdet.train()

In [None]:
# olddet.show_rk(xcorr, 19)

In [None]:
# olddet.show_interpolated_images(xcorr[230:234])
# # newdet.show_interpolated_images(xcorr[230:231], use_self=True)

In [None]:
# for g in newdet.optimizer.param_groups:
#         print(g['lr'])


In [None]:
# rktrain = newdet.ood_knn(xtrain)
# rkval = newdet.ood_knn(xval)
# rkcorr = newdet.ood_knn(xcorr)

In [None]:
# Disabled: `newdet` is only created in a commented-out cell earlier in the
# notebook (`newdet = OODdetector(VAE(latent_dim=50), xtrain, xval)`), so this
# cell raises NameError on a fresh kernel. Re-enable it together with that
# definition.
# newdet.set_learning_rate(1e-5)
# newdet.train()

In [None]:
# newdet.load_normalized_knn_model(20)

In [None]:
# K=9
# plt.hist((rktrain[:,K], rkval[:,K], rkcorr[:,K]), 50, cumulative=True, density=True);

In [None]:
# K=8
# plt.hist((rktrain[:,K], rkval[:,K], rkcorr[:,K]), 50, cumulative=True, density=True);

In [None]:
naruto = plt.imread('naruto_mnist.png')
plt.imshow(naruto)
plt.show()
naruto = torch.tensor(naruto.reshape((1, 1, 28, 28)))



In [None]:
# fulldet.show_interpolated_images(naruto, use_self=True)
# fulldet.show_interpolated_images(naruto) # calls plt.show() which clears previous imshow
# rkn = newdet.ood_knn(naruto)
# print(rkn)

In [None]:
# newdet.show_interpolated_images(naruto, use_self=True)
# newdet.show_interpolated_images(naruto) # calls plt.show() which clears previous imshow


In [None]:
# newdet.p_ID(naruto)

In [None]:
# np.mean(newdet.detect_OOD(xcorr))

In [None]:
# plt.plot(rkv[0])
# plt.plot(rkn[0])

In [None]:
# rkv = newdet.ood_knn(xval[0:1])

In [None]:
# plt.plot(rkv.squeeze(), rkn.squeeze())
# xall = plt.xlim()
# plt.plot(xall, xall)

<!-- # END OF MNIST -->


In [None]:
# noisedet1 = OODdetector(VAE(latent_dim=50), torch.concatenate((xtrain, xnoise)), xval)
# noisedet1.train()

<!-- # START CIFAR -->


# CIFAR-10


In [None]:
import torch
import torchvision
from torchvision import transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

transform = transforms.Compose([transforms.ToTensor()])

# Options shared by the CIFAR-10 train and test datasets.
_cifar_kwargs = dict(root="./data", download=True, transform=transform)

trainset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=1000, shuffle=False, num_workers=2
)

testset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=200, shuffle=False, num_workers=2
)

# Keep only the images of the first batch of each loader, moved to `device`.
# (A per-batch min/max rescale was tried here and is left disabled.)
Xcf_train = next(iter(trainloader))[0].to(device)
Xcf_test = next(iter(testloader))[0].to(device)


In [None]:
X = np.array(Xcf_train.permute((0, 2, 3, 1)).cpu())

print(X.shape)
plt.imshow(X[333, :, :, :])

In [None]:
cfmodel = VAEcf(latent_dim=8).to(device)

In [None]:
print(cfmodel)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from vae_models import recon_loss_cf, reg_loss_sphere

cfdet = OODdetector(
    cfmodel, Xcf_train, Xcf_test, criterion=(recon_loss_cf, reg_loss_sphere), beta=0.0
)

In [None]:
cfdet.beta = 0.1
cfdet.set_learning_rate(1e-5)
cfdet.train(epochs=5000)

In [None]:
cfmodel.latent_dim

In [None]:
# torch.save(cfdet, 'first_hsphere_LD8.pth')

62051
60994
59995
59206
58566

39609
37746
36433
33719
32037
31105
30370
29946
29564
29282
29116


In [None]:
# Visualisation only: run the forward pass without building an autograd graph.
with torch.no_grad():
    recon = cfdet.model(Xcf_train)
print(type(recon[0]), recon[0].device)

# Channels-first -> channels-last so matplotlib can display the images.
out = recon[0].permute((0, 2, 3, 1)).detach().cpu()
orig = Xcf_train.permute((0, 2, 3, 1)).detach().cpu()

# (The unused min/max-rescaled copy `outscale` was dropped.)
out = torch.abs(out)
pick = 545  # np.random.randint(0, len(orig))
# pick =  np.random.randint(0, len(orig))

# Original (left) next to its reconstruction (right).
plt.imshow(np.concatenate((orig[pick], out[pick]), axis=1))
plt.title(pick)
plt.show()

# Per-pixel scatter of reconstruction vs. original with the identity line.
plt.plot(orig[pick].flatten(), out[pick].flatten(), ".")
plt.gca().set_aspect(1.0)
xall = plt.xlim()
plt.plot(xall, xall, "k:")
plt.xlabel("original")
plt.ylabel("recon")
plt.show()


# Same comparison as a 2-D histogram (50 bins), reusing the scatter's x-limits
# for the identity line.
plt.hist2d(orig[pick].flatten(), out[pick].flatten(), 50)
plt.gca().set_aspect(1.0)
plt.plot(xall, xall, "w:")

In [None]:
plt.hist2d(orig.flatten(), out.flatten(), 30)
plt.gca().set_aspect(1.0)
xall = plt.xlim()
plt.plot(xall, xall, "w:")

In [None]:
i0 = np.random.randint(len(Xcf_test))
cfdet.show_interpolated_images(Xcf_test[i0 : i0 + 1])

In [None]:
print(orig[pick].shape)
plt.ecdf(orig[pick].flatten())
plt.ecdf(out[pick].flatten())

In [None]:
from vae_models import ConvVAE

# Place the model on the notebook-level `device` rather than calling `.cuda()`
# unconditionally, so the cell also runs on a machine without a GPU.
optmodel = ConvVAE(nx=32, ny=32, latent_dim=200).to(device)
optdet = OODdetector(optmodel, trainset, testset, batch_size=1000)#  criterion=optmodel.loss_function)
# optdet.load_VAE("cfVAE20250225_35e.pth")

In [None]:
nr = optdet.prepare_naruto()

In [None]:
plt.imshow(nr.permute(1,2,0))

In [None]:
optdet.training_example.shape

In [None]:
optdet.set_learning_rate(1e-3)
optdet.train2(num_epochs=35)

In [None]:
test_ood(optdet, FPfrac=0.01)

In [None]:
p_ID = optdet.p_ID(Xcf_test)
p_naruto = optdet.p_ID(optdet.prepare_naruto())
plt.ecdf(p_ID)
yall = plt.ylim()
plt.plot([p_naruto, p_naruto], yall)
plt.title(p_naruto)

In [None]:
plt.ecdf(optdet.one_nn_ecdf(Xcf_test))
print(optdet.one_nn_ecdf(prepare_naruto((32,32))))

In [None]:
optdet.show_dists(optdet.prepare_naruto())

In [None]:
optdet.prepare_naruto().shape

In [None]:
plt.imshow(optdet.prepare_naruto().permute(1,2,0))

In [None]:
optdet.show_rk(Xcf_test[0:10])

In [None]:
optdet.show_dists(Xcf_test[0:10])

In [None]:
plt.ecdf(p_ID)

In [None]:
data2use = Xcf_test

# Visualisation only: run the forward pass without building an autograd graph.
with torch.no_grad():
    recon = optdet.model(data2use)
print(type(recon[0]), recon[0].device)

# Channels-first -> channels-last so matplotlib can display the images.
out = recon[0].permute((0, 2, 3, 1)).detach().cpu()
orig = data2use.permute((0, 2, 3, 1)).detach().cpu()

# (The unused min/max-rescaled copy `outscale` was dropped.)
out = torch.abs(out)
pick = np.random.randint(0, len(orig))
# pick =  np.random.randint(0, len(orig))

# Original (left) next to its reconstruction (right).
plt.imshow(np.concatenate((orig[pick], out[pick]), axis=1))
plt.title(pick)
plt.show()

# Per-pixel scatter of reconstruction vs. original with the identity line.
plt.plot(orig[pick].flatten(), out[pick].flatten(), ".")
plt.gca().set_aspect(1.0)
xall = plt.xlim()
plt.plot(xall, xall, "k:")
plt.xlabel("original")
plt.ylabel("recon")
plt.show()


# Same comparison as a 2-D histogram (20 bins), reusing the scatter's x-limits
# for the identity line.
plt.hist2d(orig[pick].flatten(), out[pick].flatten(), 20)
plt.gca().set_aspect(1.0)
plt.plot(xall, xall, "w:")

In [None]:
optdet.show_interpolated_images(Xcf_test[0:1])

In [None]:
optdet.detect_OOD(Xcf_test[0])

In [None]:
optdet.show_dists(Xcf_test.cpu())

In [None]:
nar = optdet.prepare_naruto()

In [None]:
optdet.manifold_distance(Xcf_test[33:39])

In [None]:
optdet.show_dists(Xcf_test[33:39])

In [None]:
optdet.show_interpolated_images(Xcf_train[pick : pick + 1])

# Ships


The first cell below is a minimal example that demonstrates the duplicate-images bug.


In [None]:
import torch
import numpy as np

from torch.utils.data import DataLoader
from dataeval.utils.data.datasets import ShipDataset

# Original construction, kept for reference:
# train_ship = ShipDataset(root="./data/", download=True, size=1800, unit_interval=True, channels="channels_first")
# test_ship = ShipDataset(root="./data/", download=True, size=200, slice_back=True, unit_interval=True, channels="channels_first")

# workaround from Ryan 7_Mar-2025: load 2000 balanced images from the back and
# keep only the last 200 of them as the test set.
train_ship = ShipDataset(root="./data/", download=True, size=1800, unit_interval=True, channels="channels_first")
test_ship = ShipDataset(root="./data/", download=True, size=2000, slice_back=True, unit_interval=True, channels="channels_first", balance=True)
test_ship._data = test_ship._data[-200:]
test_ship._targets = test_ship._targets[-200:]

# First batch of 100 images from each split.
trainbatch = next(iter(DataLoader(train_ship, batch_size=100)))[0]
testbatch = next(iter(DataLoader(test_ship, batch_size=100)))[0]

# For every odd index 1, 3, ..., 99: is the train image identical to the test image?
odd_pairs_equal = np.array(
    [(trainbatch[idx] == testbatch[idx]).all() for idx in range(1, 100, 2)]
)
print(f'All odd-indexed image pairs in train and test are two copies of same image: {odd_pairs_equal.all()}')
print(f'Some odd-indexed image pairs in train and test are two copies of same image: {odd_pairs_equal.any()}')

In [None]:
# Place the model on the notebook-level `device` rather than calling `.cuda()`
# unconditionally, so the cell also runs on a machine without a GPU.
shipmodel = ConvVAE(nx=80, ny=80, latent_dim=50).to(device)
shipdet = OODdetector(shipmodel, train_ship, test_ship, batch_size=100)

In [None]:
shipdet.set_learning_rate(1e-3)
shipdet.train2(60)

In [None]:
shipdet.manifold_distance(shipdet.prepare_naruto())

In [None]:
shipdet.show_dists(shipdet.prepare_naruto())

In [None]:
shipdet.show_dists(shipdet.training_example)

In [None]:
# First test batch as a float tensor on the detector's device.
data2use = torch.as_tensor(next(iter(shipdet.test_loader))[0]).to(torch.float).to(shipdet.device)

# Visualisation only: run the forward pass without building an autograd graph.
with torch.no_grad():
    recon = shipdet.model(data2use)
print(type(recon[0]), recon[0].device)

# Channels-first -> channels-last so matplotlib can display the images.
out = recon[0].permute((0, 2, 3, 1)).detach().cpu()
orig = data2use.permute((0, 2, 3, 1)).detach().cpu()

# (The unused min/max-rescaled copy `outscale` was dropped.)
out = torch.abs(out)
pick = np.random.randint(0, len(orig))
# pick =  np.random.randint(0, len(orig))

# Original (left) next to its reconstruction (right).
plt.imshow(np.concatenate((orig[pick], out[pick]), axis=1))
plt.title(pick)
plt.show()

# Per-pixel 2-D histogram (20 bins) of reconstruction vs. original with the identity line.
plt.hist2d(orig[pick].flatten(), out[pick].flatten(), 20)
plt.gca().set_aspect(1.0)
xall = plt.xlim()
plt.plot(xall, xall, "w:")
plt.show()

# Same comparison as a scatter plot, reusing the histogram's x-limits for the
# identity line.
plt.plot(orig[pick].flatten(), out[pick].flatten(), ".")
plt.gca().set_aspect(1.0)
plt.plot(xall, xall, "k:")
plt.xlabel("original")
plt.ylabel("recon")
plt.show()




In [None]:
shipdet.show_interpolated_images(prepare_naruto((80, 80)), use_self=True)


In [None]:
optdet.show_interpolated_images(prepare_naruto((32, 32)), use_self=True)


In [None]:
optdet.detect_OOD(prepare_naruto((32,32)))

In [None]:
from ood_detector import prepare_naruto
shipdet.detect_OOD(prepare_naruto((80,80)))

In [None]:
np.sum(optdet.detect_OOD(optdet.training_example))

In [None]:
pick = 545  # np.random.randint(0, len(orig))
plt.imshow(np.concatenate((orig[pick], out[pick]), axis=1))
plt.title(pick)

In [None]:
cfdet.show_interpolated_images(Xcf_train[pick : pick + 1], use_self=False)

In [None]:
dists = cfdet.manifold_distance(Xcf_test)

In [None]:
test = plt.ecdf(cfdet.manifold_distance(Xcf_test))
rand = plt.ecdf(cfdet.manifold_distance(torch.randn_like(Xcf_test) + Xcf_test))
plt.legend([test, rand], ["test", "rando"])

In [None]:
noisy = torch.randn_like(Xcf_test) * 0.1 + Xcf_test

In [None]:
plt.imshow(noisy[0].cpu().permute((1, 2, 0)))
plt.title(str(cfdet.manifold_distance(noisy[0:1])))

In [None]:
# ov = OutlierVAE(Xcf_train[0:5], latent_dim=400)

In [None]:
# cfdet = OODdetector(ov, Xcf_train, Xcf_test, criterion=vae_loss2)

In [None]:
# cfdet.train(epochs=30)

In [None]:
# cfdet.set_learning_rate(1e-5)

In [None]:
# cfdet.train(epochs=60)

In [None]:
# idx = 8
# X = Xcf_train[idx].reshape(1, 32, 32, 3)
# X = torch.permute(X, (0,3,1,2))
# X_recon, mu, logvar = cfdet.model(X)


In [None]:
# X_recon.max()

In [None]:
# torch.permute(X_recon*3.7, (0,2,3,1)).max()

In [None]:
# plt.imshow(torch.permute(X_recon, (0,2,3,1)).detach().squeeze().numpy())

In [None]:
cfdet.show_dists(Xcf_test)

In [None]:
# Xcf_test.shape

In [None]:
# cfdet.show_interpolated_images(Xcf_test[0:1]*3.7)

In [None]:
# plt.plot(ov.embed(Xcf_train).detach())

In [None]:
# embed = ov.embed(Xcf_train)

# # intrinsic_dimension(embed)

In [None]:
# print(ov(Xcf_train)[0].shape)

<!-- TODO: probably should handle ConvTranspose2d artifacts (approach not worked out yet). -->


In [None]:
# recon = ov(Xcf_train)

In [None]:
# img = recon[0][11,:,:, :]
# img = torch.permute(img, (1,2,0)).detach().numpy()
# plt.imshow(img)
# print(recon[0].min(), recon[0].max())

In [None]:
# x = np.random.uniform(0, 100, size=10000)
# (np.floor(x) + 1.0 == np.ceil(x)).all()

In [None]:
# Xcf_train.shape

In [None]:
# Legacy threshold sweep on the `normdot` score between inputs and their
# reconstructions, for a control batch (`xval`) and a corrupted batch.
# NOTE(review): `model` is not assigned in any visible cell of this notebook --
# presumably a VAE such as `olddet.model`; confirm before running on a fresh kernel.

# Take the first batch of the validation loader (the loop exits after one item).
# NOTE(review): despite its name, `Xtrain` therefore holds a validation batch.
for Xtrain, _, _ in valbb:
    break

# Sanity check that this batch differs from `xtrain`; a failure only prints a
# message and execution continues.
try:
    assert torch.sum(np.abs(Xtrain - xtrain)) > 0.0
except AssertionError:
    print("x and xval are the same, what happened?")

# Control scores: the model's first output compared with its input via `normdot`.
xval_pred, _, _ = model(xval)  # type: ignore
recon = normdot(xval, xval_pred)  # type: ignore

# First batch of the first corruption loader; this rebinds the notebook-level `xcorr`.
for xcorr, _, _ in corr1bb:
    break
xcorr_pred, _, _ = model(xcorr)  # type: ignore
# NOTE(review): argument order is swapped relative to the control call above --
# presumably `normdot` is symmetric; confirm.
corr_recon = normdot(xcorr_pred, xcorr)  # type: ignore

# Cumulative, normalised histograms of both score sets with horizontal guide
# lines at 0.9 and 0.01.
plt.hist((recon, corr_recon), 40, density=True, cumulative=True)
plt.legend(["control", "OOD"])
plt.plot(plt.xlim(), [0.9, 0.9], plt.xlim(), [0.01, 0.01])
plt.title("VAE reconstruction error - normdot")

cutoff = 0.6

# Sweep ten cutoffs starting at 0.60 in steps of 0.02; a sample counts as OOD
# when its score is below the cutoff. TP is the flagged fraction of the
# corrupted batch, FP the flagged fraction of the control batch.
for i in range(10):
    is_ood_test = corr_recon < cutoff
    is_ood_ctrl = recon < cutoff
    print(
        f"cutoff: {cutoff:0.2f}, TP: {np.mean(is_ood_test): 0.3f}, FP: {np.mean(is_ood_ctrl):.3f}"
    )
    cutoff += 0.02


In [None]:
#

In [None]:
# Take the first batch of the validation loader (the loop exits after one item);
# this rebinds the notebook-level `xval`.
for xval, _, _ in valbb:
    break

# Sanity check that the validation batch differs from `xtrain`; a failure only
# prints a message and execution continues.
try:
    assert torch.sum(np.abs(xval - xtrain)) > 0.0
except AssertionError:
    print("x and xval are the same, what happened?")

# Control scores: the model's first output compared with its input via `normdot`.
# NOTE(review): `model` is not assigned in any visible cell of this notebook --
# confirm where it comes from before running on a fresh kernel.
xval_pred, _, _ = model(xval)  # type: ignore
recon = normdot(xval, xval_pred)  # type: ignore


def detect_OOD(xcheck, xctrl, model, plot=False):
    """Flag samples whose reconstruction score is below the control set's 1st percentile.

    Both batches are passed through ``model`` and each input is compared with
    the first element of the model output via ``normdot``. A sample of
    ``xcheck`` is flagged when its score is lower than the 1st percentile of
    the control scores.

    Parameters
    ----------
    xcheck
        Batch to test.
    xctrl
        Control batch that defines the threshold.
    model
        Callable returning a 3-tuple whose first element is the reconstruction.
    plot : bool, optional
        When True, also draw cumulative histograms of the control and check
        scores. (This plotting code used to sit after ``return`` and was
        unreachable.) Defaults to False, which matches the previous behaviour.

    Returns
    -------
    Boolean array-like, one entry per sample of ``xcheck``; True means flagged.
    """
    xctrl_pred, _, _ = model(xctrl)  # type: ignore
    recon = normdot(xctrl, xctrl_pred)  # type: ignore

    xcheck_pred, _, _ = model(xcheck)  # type: ignore
    check_recon = normdot(xcheck_pred, xcheck)  # type: ignore

    # Threshold: 1st percentile of the control scores.
    min_nd = np.percentile(recon, 1)
    is_ood = check_recon < min_nd

    if plot:
        plt.hist((recon, check_recon), 40, density=True, cumulative=True)
        plt.legend(["control", "OOD"])
        plt.plot(plt.xlim(), [0.9, 0.9], plt.xlim(), [0.01, 0.01])
        plt.title("VAE reconstruction error - normdot")

    return is_ood


cutoff = 0.6

# Sweep ten cutoffs starting at 0.60 in steps of 0.02; a sample counts as OOD
# when its score is below the cutoff. TP is the flagged fraction of
# `corr_recon`, FP the flagged fraction of the control scores `recon`.
# NOTE(review): `corr_recon` is not assigned in this cell; it must already
# exist from an earlier cell.
for i in range(10):
    is_ood_test = corr_recon < cutoff
    is_ood_ctrl = recon < cutoff
    print(
        f"cutoff: {cutoff:0.2f}, TP: {np.mean(is_ood_test): 0.3f}, FP: {np.mean(is_ood_ctrl):.3f}"
    )
    cutoff += 0.02


In [None]:
model.eval()
np.mean(detect_OOD(xcorr, xval, model))

In [None]:
np.mean(detect_OOD(xval, xval, model))

In [None]:
verbose_loss(xtrain, *model(xtrain))

In [None]:
# ov = OutlierVAE(xval, image_channels=1)

In [None]:
#

In [None]:
# cfdet = OutlierVAE()

In [None]:
#

In [None]:
# newdet.show_dists(xcorr, method=newdet.manifold_distance_normalized)

In [None]:
#

In [None]:
# moddet = OODdetector(VAEmod(784, 400, 20), xtrain, xval)

In [None]:
# moddet.manifold_distance(xtrain)

In [None]:
# moddet.manifold_dimension()

In [None]:
# moddet.train()

In [None]:
#

In [None]:
# newdet.model.embed(xtrain)

In [None]:
# newdet.manifold_dimension()

In [None]:
# all_zeros = newdet.manifold_distance(newdet.training_data)

In [None]:
# plt.plot(all_zeros)

In [None]:
#

In [None]:
# newdet.show_dists(newdet.training_data)

In [None]:
# # newdet.load_knn_model(8)
# newdet.show_dists(xcorr)

In [None]:
# recon = newdet.model(xcorr[0])[0].detach().numpy().reshape((28,28))
# plt.imshow(np.concatenate((recon, xcorr[0].squeeze()), axis=1))

In [None]:
# xcorr[0].shape

In [None]:
# # mdim = 10
# # data4knn = newdet.model.embed(newdet.validation_data).detach().reshape((newdet.validation_data.shape[0], -1))
# # newdet.neighbors = NearestNeighbors(n_neighbors=mdim, algorithm='ball_tree').fit(data4knn)

# dtrain = newdet.manifold_distance(xtrain)
# dcorr = newdet.manifold_distance(xcorr)
# dval = newdet.manifold_distance(xval)

# plt.hist((dtrain, dval,  dcorr), 50, density=True, cumulative=True);


In [None]:
# dval =

In [None]:
# newdet.manifold_dimension()

In [None]:
# def ROC_curve(xref, xcorr):
#     xref, cref = ecdf(xref)
#     xcorr, ccorr = ecdf(xcorr)

#     x_all = np.sort(np.concatenate((xref, xcorr)))

#     FP = 1.0 - np.interp(x_all, xref, cref)
#     TP = 1.0 - np.interp(x_all, xcorr, ccorr)
#     iord = np.argsort(FP)

#     return FP[iord], TP[iord]

In [None]:
# FP, TP = ROC_curve(recon, corr_recon)

In [None]:
# FP[0:10]

In [None]:
# plt.plot(*ROC_curve(recon, corr_recon))
# plt.plot([0, 1], [0, 1], 'k:')
# plt.gca().set_aspect(1.0)