In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as transforms
from drift_metric import DriftMetric
from torch.utils.data import DataLoader
from torchmetrics.utilities.data import dim_zero_cat

from daml.detectors import DriftKS

In [None]:
# Seed every RNG the notebook touches with the same value so that
# Restart Kernel -> Run All reproduces training and the random corruption.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)  # keep Python's `random` consistent with torch/numpy

In [None]:
# Per-channel (mean, std) used to standardize the single grayscale channel
# after ToTensor().
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(MNIST_MEAN, MNIST_STD),
    ]
)
# Downloads MNIST into ./data on first run; training split only.
trainset = tv.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Training configuration.
num_epochs = 100
batch_size = 128
subset_size = 2000  # number of leading training images actually used

# A fixed prefix of the training set. DataLoader's default shuffle=False means
# batches are served in the same order every epoch.
subset = torch.utils.data.Subset(trainset, range(subset_size))
dataloader = DataLoader(dataset=subset, batch_size=batch_size)

In [None]:
# Define model architecture
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel 28x28 images.

    ``encode`` maps an (N, 1, 28, 28) batch to an (N, 10) code squashed into
    (0, 1) by a sigmoid; ``forward`` decodes that code back to
    (N, 1, 28, 28), again sigmoid-activated.
    """

    def __init__(self):
        super().__init__()
        # Layer construction order fixes parameter initialization under a
        # seeded torch RNG (and the state_dict keys), so keep it stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),  # (1, 28, 28) -> (4, 24, 24)
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 8, kernel_size=5),  # (4, 24, 24) -> (8, 20, 20)
            nn.ReLU(inplace=True),
            nn.Flatten(),  # (8, 20, 20) -> 3200
            nn.Linear(8 * 20 * 20, 10),  # 3200 -> 10
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(10, 400),  # 10 -> 400
            nn.ReLU(inplace=True),
            nn.Linear(400, 10 * 20 * 20),  # 400 -> 4000
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (10, 20, 20)),  # 4000 -> (10, 20, 20)
            nn.ConvTranspose2d(10, 10, kernel_size=5),  # -> (10, 24, 24)
            # NOTE(review): no activation between the two transposed convs, so
            # together they compose to a single linear map — confirm intended.
            nn.ConvTranspose2d(10, 1, kernel_size=5),  # -> (1, 28, 28)
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the latent code."""
        return self.decoder(self.encode(x))

    def encode(self, x):
        """Return the 10-dimensional latent code for the image batch ``x``."""
        return self.encoder(x)

In [None]:
# Initialize model and train
model = Autoencoder()
distance = nn.MSELoss()  # pixel-wise reconstruction loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.5)
losses = []  # per-batch training loss, stored as Python floats

# NOTE(review): the inputs are standardized with Normalize((0.1307,), (0.3081,)),
# so target pixels can fall outside [0, 1], while the decoder ends in
# nn.Sigmoid() and only emits values in (0, 1). The reconstruction loss
# therefore cannot reach zero — confirm whether this is intended.

# Explicit training mode: the encoding cell below switches the model to eval(),
# so re-running this cell afterwards must switch it back.
model.train()
for epoch in range(num_epochs):
    for img, _ in dataloader:
        # DataLoader already yields tensors; no torch.tensor() copy is needed
        # (copy-constructing from a tensor also emits a UserWarning).
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"epoch [{epoch + 1}/{num_epochs}], loss: {loss.item():.4f}")

In [None]:
# Encode the clean subset to form the drift reference, then encode a corrupted
# copy (every pixel multiplied by uniform noise from torch.rand_like) and
# compare it against that reference with two drift detectors.
mnist_encodings = []
corr_encodings = []

model.eval()
with torch.no_grad():
    # First, preprocess MNIST (reference distribution).
    for imgs, _ in dataloader:
        # Under no_grad the outputs carry no autograd history, so .detach() is unnecessary.
        mnist_encodings.append(model.encode(imgs).cpu())
mnist_encodings = dim_zero_cat(mnist_encodings).numpy()

drift = DriftMetric(mnist_encodings)
with torch.no_grad():
    # Then, preprocess corrupt MNIST: scale each pixel by U[0, 1) noise.
    for imgs, _ in dataloader:
        encodings = model.encode(imgs * torch.rand_like(imgs))
        corr_encodings.append(encodings.cpu())
        # NOTE(review): DriftMetric was built from a NumPy array but update()
        # receives a torch tensor — confirm DriftMetric accepts both.
        drift.update(encodings)

metric_ber = drift.compute()
print("Metric BER:", metric_ber)

np_corr_encodings = dim_zero_cat(corr_encodings).numpy()

# DriftKS (presumably a Kolmogorov-Smirnov based detector — confirm) fitted on
# the clean encodings, then applied to the corrupted ones.
ksdrift = DriftKS(mnist_encodings)
x = ksdrift.predict(np_corr_encodings)

print("Class KSDrift:", x)