In [1]:
# Google Colab Only
try:
    import google.colab  # noqa: F401

    # change the version (==vX.XX.X) in the statement below in order to get the latest version of dataeval.
    %pip install -q dataeval[torch]==v0.67.0
except Exception:
    pass

from pytest import approx

In [2]:
import torch
import torchvision.transforms.v2 as v2
from torch.utils.data import Subset
from torchvision.datasets import MNIST

In [3]:
# Transform pipeline applied to every sample: convert to an image tensor, then
# cast to float32 with value scaling enabled.
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# Build the train split first and the test split second, sharing the same root
# directory, transform and download setting.
training_dataset, testing_dataset = (
    MNIST(root="./data/", train=is_train_split, transform=to_tensor, download=True)
    for is_train_split in (True, False)
)

In [4]:
# Report the shapes of the raw training images and of their labels.
for caption, values in (
    ("Training data size:", training_dataset.data),
    ("Training labels size:", training_dataset.targets),
):
    print(caption, values.shape)

Training data size: torch.Size([60000, 28, 28])
Training labels size: torch.Size([60000])


In [5]:
from dataeval.torch.models import AriaAutoencoder
from dataeval.torch.trainer import AETrainer

In [6]:
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Single-channel autoencoder wrapped in a trainer that uses batches of 32.
model = AriaAutoencoder(channels=1)
trainer = AETrainer(model, device=device, batch_size=32)

In [7]:
# Train on only the first 6000 samples of the training split, for 10 epochs.
training_subset = Subset(training_dataset, range(6000))
# NOTE(review): the result is indexed with [-1] below, so it is presumably a
# sequence of loss values (one per epoch?) — confirm against the AETrainer docs.
training_loss = trainer.train(training_subset, epochs=10)
# Show the last recorded training loss.
print(training_loss[-1])

0.11283737020765214


In [8]:
# Evaluate with the same trainer on the MNIST test split (train=False) and
# print the loss it reports.
eval_loss = trainer.eval(testing_dataset)
print(eval_loss)

0.1140080429018496


In [9]:
### TEST ASSERTION ###
# Echo both losses first so they are visible even if an assertion fails.
final_training_loss = training_loss[-1]
print(final_training_loss, eval_loss, sep="\n")

# Both losses must match the recorded reference values within 1e-4 (absolute).
assert final_training_loss == approx(0.112837, abs=1e-4)
assert eval_loss == approx(0.114008, abs=1e-4)

0.11283737020765214
0.1140080429018496


In [10]:
# Encode the same 6000-sample training subset that was used for training; the
# cells below check and print the shape of the result.
embeddings = trainer.encode(training_subset)

In [11]:
### TEST ASSERTION ###
print(embeddings.shape)
assert embeddings.shape == torch.Size([6000, 64, 6, 6])

torch.Size([6000, 64, 6, 6])


In [12]:
# Display the shape of the encoded subset; the recorded run shows a leading
# dimension of 6000, matching the number of samples in training_subset.
print("Embedded image shape:", embeddings.shape)

Embedded image shape: torch.Size([6000, 64, 6, 6])
