In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torchmetrics
import torchvision as tv
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataeval._internal.metrics.ber import BER, _knn

In [None]:
torch.manual_seed(42)
np.random.seed(42)
random.seed(0)

In [None]:
from torchmetrics.utilities.data import dim_zero_cat


class BERMetric(torchmetrics.Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("images", default=[], dist_reduce_fx="cat")
        self.add_state("labels", default=[], dist_reduce_fx="cat")

    def update(self, images: torch.Tensor, labels: torch.Tensor) -> None:
        self.images.append(images)
        self.labels.append(labels)

    def compute(self) -> torch.Tensor:
        images = dim_zero_cat(self.images).detach().cpu().numpy()
        labels = dim_zero_cat(self.labels).detach().cpu().numpy()
        return torch.tensor(_knn(images, labels, 1))

In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
trainset = tv.datasets.MNIST(root="./data", train=True, download=True, transform=transform)

num_epochs = 250
batch_size = 128

In [None]:
subset = torch.utils.data.Subset(trainset, range(2000))
dataloader = DataLoader(subset, batch_size=batch_size)

In [None]:
# Define model architecture
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(1, 4, kernel_size=5),
            # 4 x 24 x 24
            nn.ReLU(True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.ReLU(True),
            # 8 x 20 x 20 = 3200
            nn.Flatten(),
            nn.Linear(3200, 10),
            # 10
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            # 10
            nn.Linear(10, 400),
            # 400
            nn.ReLU(True),
            nn.Linear(400, 4000),
            # 4000
            nn.ReLU(True),
            nn.Unflatten(1, (10, 20, 20)),
            # 10 x 20 x 20
            nn.ConvTranspose2d(10, 10, kernel_size=5),
            # 24 x 24
            nn.ConvTranspose2d(10, 1, kernel_size=5),
            # 28 x 28
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def encode(self, x):
        x = self.encoder(x)
        return x

In [None]:
# Initialize model and train
model = Autoencoder()
distance = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.5)
losses = []
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = torch.tensor(img)
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().numpy())
    print(f"epoch [{epoch + 1}/{num_epochs}], loss: {loss.item():.4f}")

In [114]:
encoding_list = []
labels_list = []
imgs_list = []

ber = BERMetric()

with torch.no_grad():
    for batch in dataloader:
        imgs, labels = batch
        encodings = model.encode(imgs)
        x = ber(encodings, labels)

        encoding_list.append(encodings.detach().cpu())
        labels_list.append(labels.detach().cpu())
        imgs_list.append(imgs.detach().cpu())

metric_ber = ber.compute()
print("Metric BER:", metric_ber)

np_encodings = dim_zero_cat(encoding_list).numpy()
np_labels = dim_zero_cat(labels_list).numpy()
np_imgs = dim_zero_cat(imgs_list).numpy()

b = BER(np_encodings, np_labels)
class_ber = b.evaluate()
print("Class BER:", class_ber)

Metric BER: tensor([0.3010, 0.1658], dtype=torch.float64)
Class BER: {'ber': 0.301, 'ber_lower': 0.165765704968775}
