In [None]:
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F

import tensorflow_datasets as tfds

import numpy as np

import matplotlib.pyplot as plt

from torch.utils.data import TensorDataset, DataLoader
from torch import optim 

# Use the first GPU when one is available; fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Make cuda:0 the current CUDA device for tensors created without an explicit device.
    torch.cuda.set_device(device)

# Seed torch's global RNG so that model weight initialisation, the torch.randn
# noise images and DataLoader shuffling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Loading Data Sets

In [None]:
# Fetch the complete MNIST training split (batch_size=-1 returns the whole split
# at once) as (image, label) pairs and convert it to NumPy arrays.
training_images, training_labels = tfds.as_numpy(
    tfds.load('mnist', split='train', batch_size=-1, as_supervised=True)
)

In [None]:
# Same as the training cell, but for the held-out MNIST test split.
testing_images, testing_labels = tfds.as_numpy(
    tfds.load('mnist', split='test', batch_size=-1, as_supervised=True)
)

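# Initialise a random model
Here we let the randomly initialised (untrained) model choose the labels. The assumption is that if the data set changes, the labels the model assigns will change too. This in turn changes the per-class label frequencies (the rate at which images are assigned to each label), leading to a change in the detection signal.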

In [None]:
# Randomly initialised (never trained) CNN whose argmax output is used as the
# image "label". The layers must be 2-D convolutions: the inputs are
# (N, 1, H, W) image batches, which nn.Conv1d rejects. Assuming 1x28x28 MNIST
# inputs, two 3x3 convolutions give 28 -> 26 -> 24, so the flattened feature
# size is 100 * 24 * 24 = 57600, matching the Linear layer.
model = nn.Sequential(
                nn.Conv2d(1, 10, kernel_size=3),
                nn.ReLU(),
                nn.Conv2d(10, 100, kernel_size=3),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(57600, 10),
            )

model.to(device)

In [None]:
# 10x10 one-hot label matrix: row i is the one-hot encoding of class i.
# NOTE(review): not referenced by any later cell shown here — confirm it is still needed.
labels = np.eye(10)

# Build Data Loaders

In [None]:
# TFDS yields MNIST images channels-last, (N, H, W, C); torch convolutions
# expect (N, C, H, W). permute(0, 3, 1, 2) keeps H and W in order, whereas
# transpose(1, 3) would produce (N, C, W, H), i.e. every image transposed.
training_images_pytorch = torch.Tensor(training_images).permute(0, 3, 1, 2)
training_labels_pytorch = torch.Tensor(training_labels).long()

testing_images_pytorch = torch.Tensor(testing_images).permute(0, 3, 1, 2)
testing_labels_pytorch = torch.Tensor(testing_labels).long()

training_dataset = TensorDataset(training_images_pytorch, training_labels_pytorch)
testing_dataset  = TensorDataset(testing_images_pytorch, testing_labels_pytorch)

# batch_size=1: the downstream cells process one image per step (they .squeeze()
# a single output row and update the filter once per batch).
training_loader = DataLoader(training_dataset, batch_size=1, shuffle=True, pin_memory=True, drop_last=True)
testing_loader  = DataLoader(testing_dataset, batch_size=1, shuffle=True, pin_memory=True, drop_last=True)

# Helper Functions

In [None]:
def calculate_lambda(labels,label_vals):
    """Return the fraction of `labels` equal to each entry of `label_vals`.

    Args:
        labels: 1-D NumPy array or torch tensor of assigned labels.
        label_vals: sequence (NumPy array, torch tensor, or list) of the label
            values to count.

    Returns:
        float64 array of length len(label_vals); entry i is the share of
        `labels` equal to label_vals[i]. All zeros if `labels` is empty.
    """
    lambda_vals = np.zeros([len(label_vals)])
    n_labels = len(labels)
    if n_labels == 0:
        # Avoid dividing by zero; an empty batch contributes no frequency.
        return lambda_vals

    for i, label_val in enumerate(label_vals):
        # Compare against the actual label value (not its index). Unwrap 0-d
        # tensors / NumPy scalars so the comparison also works when `labels`
        # lives on a different device.
        if hasattr(label_val, "item"):
            label_val = label_val.item()
        lambda_vals[i] = (labels == label_val).sum().item() / n_labels

    return lambda_vals

def LambdaPredictionTransition(x,lambda_vals):
    """Prediction step: advance the state `x` by the expected per-class rates.

    Returns `x + lambda_vals` as a new array; `x` itself is not modified.
    """
    return x + lambda_vals

def LambdaFilterTransition(x,y,gain=0.001):
    """Correction step: move the predicted state `x` toward the observation `y`.

    Args:
        x: predicted state.
        y: observed state (same shape as `x`).
        gain: fraction of the innovation (y - x) applied. The default of
            0.001 reproduces the original fixed update; 0 ignores the
            observation and 1 replaces the prediction with it.

    Returns:
        x + gain * (y - x), as a new array.
    """
    return x + gain*(y-x)

def LambdaObservation(y,label_pred):
    """Observation step: add the per-class frequencies of `label_pred` to `y`.

    Frequencies are computed by calculate_lambda over the class values 0..9.
    """
    class_ids = torch.arange(10, dtype=torch.int32)
    return y + calculate_lambda(label_pred, class_ids)
    
def Residual(x,y):
    """Return the norm (2-norm for 1-D vectors) of the difference between prediction `x` and observation `y`."""
    difference = x - y
    return np.linalg.norm(difference)


In [None]:
# Record the random model's argmax class for every training image. These
# "model labels" define the reference per-class rates used by the filter.
# x[:, n] holds the 10 output scores for image n (assumes batch_size=1, as set
# on the loader).
x = np.zeros([10, len(training_loader)])
model_labels = np.zeros([len(training_loader)])

model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph per image
    for n, (image_batch, _) in enumerate(training_loader):
        x[:, n] = model(image_batch.to(device)).cpu().numpy().squeeze()
        model_labels[n] = x[:, n].argmax()

In [None]:
# Same procedure, but every image is replaced by standard-normal noise of the
# same size, giving the label rates the random model produces on pure noise.
x = np.zeros([10, len(training_loader)])
detection_labels = np.zeros([len(training_loader)])

model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph per image
    for n, (image_batch, _) in enumerate(training_loader):
        # Drawn from torch's global RNG; reproducible only if it has been seeded.
        noise_batch = torch.randn(image_batch.size())
        x[:, n] = model(noise_batch.to(device)).cpu().numpy().squeeze()
        detection_labels[n] = x[:, n].argmax()

In [None]:
# Per-class label frequencies over classes 0..9: first for the model labels on
# real training images, then for the labels assigned to pure-noise images.
lambda_vals_model, lambda_vals_detection = (
    calculate_lambda(assigned_labels, torch.arange(10, dtype=torch.int32))
    for assigned_labels in (model_labels, detection_labels)
)

In [None]:
def run_lambda_filter(loader, lambda_vals, noise_std=0.0):
    """Run the label-rate filter over `loader` and return one residual per batch.

    For each batch:
      * the prediction state x advances by `lambda_vals` (LambdaPredictionTransition);
      * the observation state y advances by the per-class frequencies of the
        model's argmax predictions on that batch (LambdaObservation);
      * the residual is Residual(x, y), computed before x is nudged toward y
        by LambdaFilterTransition.

    Args:
        loader: iterable of (image_batch, label_batch) pairs with a len();
            the true labels are ignored.
        lambda_vals: reference per-class rates (one per model output class).
        noise_std: if non-zero, Gaussian noise with this standard deviation is
            added to every image batch before it reaches the model.

    Returns:
        float64 array of length len(loader) with the residual for each batch.
    """
    n_batches = len(loader)
    n_classes = len(lambda_vals)
    # One more state column than batches, so state k+1 exists for every batch k.
    # Stopping one batch early instead would leave the last residual at 0 and
    # skew the residual histogram.
    x = np.zeros([n_classes, n_batches + 1])
    y = np.zeros([n_classes, n_batches + 1])
    r = np.zeros([n_batches])

    model.eval()
    with torch.no_grad():
        for k, (image_batch, _) in enumerate(loader):
            if noise_std:
                image_batch = image_batch + noise_std * torch.randn(image_batch.size())

            output = model(image_batch.to(device))
            _, predicted = torch.max(output, 1)

            x[:, k + 1] = LambdaPredictionTransition(x[:, k], lambda_vals)
            y[:, k + 1] = LambdaObservation(y[:, k], predicted)
            r[k] = Residual(x[:, k + 1], y[:, k + 1])
            x[:, k + 1] = LambdaFilterTransition(x[:, k + 1], y[:, k + 1])

    return r


# Clean test images: the label rates should match the reference rates.
r_test = run_lambda_filter(testing_loader, lambda_vals_model)

# Heavily noised test images: the model's label rates shift, so residuals should grow.
r_detect = run_lambda_filter(testing_loader, lambda_vals_model, noise_std=100.0)

# Plot Figures

In [None]:
# Residual trajectories for clean vs. noised test images.
fig, ax = plt.subplots()
ax.plot(r_test, label='No Label Shift')
ax.plot(r_detect, label='Label Shift')
ax.set(xlabel='Test batch index', ylabel='Residual', title='Filter residual per test batch')
ax.legend();

In [None]:
# Compare the residual distributions with and without the induced label shift.
# np.histogram ignores values outside the bin edges, which would distort the
# density, so extend the upper edge past 70 whenever a residual exceeds it.
upper_edge = max([70.0] + [float(res.max()) for res in (r_test, r_detect) if res.size])
bins = np.linspace(0, upper_edge, 100)

hist_no_shift, _ = np.histogram(r_test, bins=bins, density=True)
hist_shift, _    = np.histogram(r_detect, bins=bins, density=True)

fig, ax = plt.subplots()
# np.histogram returns len(bins) - 1 values; plot each at its bin's left edge.
ax.plot(bins[:-1], hist_no_shift, label='No Label Shift')
ax.plot(bins[:-1], hist_shift, label='Label Shift')
ax.set(xlabel='Residual', ylabel='Density', title='Residual distribution')
ax.legend();