In [1]:
# Google Colab Only
try:
    import google.colab  # noqa: F401

    %pip install -q daml[torch]
except Exception:
    pass

import os

from pytest import approx

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
import torch
import torchvision.transforms.v2 as v2
from torch.utils.data import Subset
from torchvision.datasets import MNIST

In [3]:
# Convert each image to a tensor, then cast to float32 with values scaled into
# [0, 1] (ToDtype with scale=True rescales integer pixel values).
to_tensor = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Both MNIST splits share the same root directory, transform and download
# behaviour; only the `train` flag differs between them.
mnist_kwargs = dict(root="./data/", transform=to_tensor, download=True)
training_dataset = MNIST(train=True, **mnist_kwargs)
testing_dataset = MNIST(train=False, **mnist_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz


Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

  1%|          | 98304/9912422 [00:00<00:12, 784982.68it/s]

  3%|▎         | 294912/9912422 [00:00<00:07, 1236859.60it/s]

 13%|█▎        | 1277952/9912422 [00:00<00:02, 4216254.79it/s]

 53%|█████▎    | 5242880/9912422 [00:00<00:00, 14866376.73it/s]

100%|██████████| 9912422/9912422 [00:00<00:00, 17273781.21it/s]




Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw



Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz


Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

100%|██████████| 28881/28881 [00:00<00:00, 454021.63it/s]




Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz


Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

  4%|▍         | 65536/1648877 [00:00<00:03, 508140.15it/s]

 16%|█▌        | 262144/1648877 [00:00<00:01, 1100983.61it/s]

 62%|██████▏   | 1015808/1648877 [00:00<00:00, 3318150.90it/s]

100%|██████████| 1648877/1648877 [00:00<00:00, 4224288.70it/s]




Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz


Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

100%|██████████| 4542/4542 [00:00<00:00, 14111502.79it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [4]:
# Show the shapes of the raw training image tensor and its label vector.
print(f"Training data size: {training_dataset.data.shape}")
print(f"Training labels size: {training_dataset.targets.shape}")

Training data size: torch.Size([60000, 28, 28])
Training labels size: torch.Size([60000])


In [5]:
from daml.models.ae import AETrainer, AriaAutoencoder

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# MNIST images have a single channel, so the autoencoder is built with channels=1.
# NOTE(review): no RNG seed is set before model construction/training, yet the
# exact loss assertions later in the notebook expect reproducible values —
# confirm whether AETrainer/AriaAutoencoder seed internally.
model = AriaAutoencoder(channels=1)
trainer = AETrainer(model, device=device, batch_size=32)

In [7]:
# Train on the first 6,000 training images (10% of the 60,000-image split)
# for 10 epochs to keep the example run short.
subset_indices = range(6000)
training_subset = Subset(training_dataset, subset_indices)
training_loss = trainer.train(training_subset, epochs=10)

# Report the loss recorded for the final epoch.
final_loss = training_loss[-1]
print(final_loss)

0.11283737020765214


In [8]:
# Score the trained model on the full held-out test split and show the loss.
print(eval_loss := trainer.eval(testing_dataset))

0.1140080429018496


In [9]:
### TEST ASSERTION ###
# Regression check: compare the final training loss and the evaluation loss
# with the values recorded from a reference run, within an absolute tolerance.
# NOTE(review): matching these values presumably requires deterministic
# training (seeded init/shuffling, same device) — confirm.
EXPECTED_TRAIN_LOSS = 0.112837
EXPECTED_EVAL_LOSS = 0.114008
LOSS_TOLERANCE = 1e-4

final_training_loss = training_loss[-1]
print(final_training_loss)
print(eval_loss)

# Explicit messages: outside pytest's assertion rewriting, a bare `assert`
# would fail with an empty AssertionError.
assert final_training_loss == approx(EXPECTED_TRAIN_LOSS, abs=LOSS_TOLERANCE), (
    f"final training loss {final_training_loss} is not within "
    f"{LOSS_TOLERANCE} of expected {EXPECTED_TRAIN_LOSS}"
)
assert eval_loss == approx(EXPECTED_EVAL_LOSS, abs=LOSS_TOLERANCE), (
    f"evaluation loss {eval_loss} is not within "
    f"{LOSS_TOLERANCE} of expected {EXPECTED_EVAL_LOSS}"
)

0.11283737020765214
0.1140080429018496


In [10]:
# Encode every image in the training subset with the trained model.
# NOTE(review): presumably returns the encoder's latent feature maps; the
# recorded output shape for the 6,000-image subset is [6000, 64, 6, 6].
embeddings = trainer.encode(training_subset)

In [11]:
### TEST ASSERTION ###
# Expect one embedding per image in the training subset, each of shape
# 64 x 6 x 6 (as recorded from a reference run). The leading dimension is
# derived from the subset so it stays in sync if the subset size changes.
print(embeddings.shape)
expected_shape = torch.Size([len(training_subset), 64, 6, 6])
assert embeddings.shape == expected_shape, (
    f"expected embeddings of shape {expected_shape}, got {embeddings.shape}"
)

torch.Size([6000, 64, 6, 6])


In [12]:
# Summarize the shape of the encoded training subset.
print(f"Embedded image shape: {embeddings.shape}")

Embedded image shape: torch.Size([6000, 64, 6, 6])
