# MNIST Digits (Quantum)

This notebook trains and evaluates quantum vision transformers on the MNIST Digits classification task.
You can find information about the dataset at https://www.tensorflow.org/datasets/catalog/mnist.

In [1]:
import jax

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_mnist_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'

Please first ``pip install -U qiskit`` to enable related functionality in translation module


The models are trained using the following devices:

In [2]:
for d in jax.devices():
    print(d, d.device_kind)

TFRT_CPU_0 cpu


Let's check how many samples the dataset has, the shape of the input data, and what one sample looks like.

In [3]:
mnist_train_dataloader, mnist_valid_dataloader, mnist_test_dataloader = get_mnist_dataloaders(batch_size=64, data_dir=data_dir)
first_image = next(iter(mnist_train_dataloader))[0][0]
print(first_image.shape)
plot_image(first_image)

2024-02-19 18:16:01.092563: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /global/cfs/cdirs/m4392/salcc/data/mnist/3.0.1...


OSError: [Errno 30] Read-only file system: '/global'
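
The `OSError` above indicates that `data_dir` is not writable in the environment where this cell was run, so the dataset could not be downloaded there. One possible workaround is to fall back to a writable location. The following is a minimal sketch; `pick_writable_dir` and the `tfds_data` fallback folder are hypothetical names introduced here for illustration and are not part of `quantum_transformers`.

```python
import os
import tempfile


def pick_writable_dir(preferred, fallback=None):
    """Return `preferred` if it can be created and written to, else a fallback directory.

    Hypothetical helper for illustration; not part of quantum_transformers.
    """
    try:
        os.makedirs(preferred, exist_ok=True)
        if os.access(preferred, os.W_OK):
            return preferred
    except OSError:
        # E.g. a read-only file system or missing permissions.
        pass
    # Assumption: a folder inside the system temp directory is acceptable.
    fallback = fallback or os.path.join(tempfile.gettempdir(), "tfds_data")
    os.makedirs(fallback, exist_ok=True)
    return fallback


data_dir = pick_writable_dir('/global/cfs/cdirs/m4392/salcc/data')
print(data_dir)
```

The resulting `data_dir` can then be passed to `get_mnist_dataloaders` as in the cell above. Note that a temporary directory may be cleared between sessions, in which case the dataset would be downloaded again.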

Now let's train the quantum vision transformer with the best hyperparameters found by random hyperparameter search.
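
Before training, it may help to see what the `patch_size=14` setting used in the next cell implies for the sequence length. MNIST images are 28×28 pixels with one channel, so 14×14 patches yield 4 patches of 196 values each. The sketch below shows one common way to split an image into non-overlapping patches with NumPy. It is an illustration only: the `to_patches` helper is hypothetical, and the internal patching in `VisionTransformer` may differ.

```python
import numpy as np


def to_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Hypothetical helper; returns an array of shape
    (num_patches, patch_size * patch_size * C).
    """
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must tile evenly"
    rows, cols = h // patch_size, w // patch_size
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    x = image.reshape(rows, patch_size, cols, patch_size, c)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(rows * cols, patch_size * patch_size * c)


# Dummy image with the MNIST shape (28, 28, 1).
image = np.arange(28 * 28, dtype=np.float32).reshape(28, 28, 1)
patches = to_patches(image, patch_size=14)
print(patches.shape)
```

With only 4 patches per image, the transformer processes very short sequences, which is consistent with the small model sizes (`hidden_size=6`, `mlp_hidden_size=3`) used here.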

In [None]:
model = VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3,
                          quantum_attn_circuit=get_circuit(), quantum_mlp_circuit=get_circuit())
train_and_evaluate(model, mnist_train_dataloader, mnist_valid_dataloader, mnist_test_dataloader, num_classes=10, num_epochs=30)

2023-10-09 15:09:25.510151: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Number of parameters = 1672


Epoch   1/30: 100%|██████████| 843/843 [00:46<00:00, 18.32batch/s, Loss = 1.8906, AUC = 79.17%] 
Epoch   2/30: 100%|██████████| 843/843 [00:10<00:00, 83.93batch/s, Loss = 1.4022, AUC = 92.00%] 
Epoch   3/30: 100%|██████████| 843/843 [00:09<00:00, 85.68batch/s, Loss = 0.9424, AUC = 95.62%] 
Epoch   4/30: 100%|██████████| 843/843 [00:10<00:00, 82.76batch/s, Loss = 0.7218, AUC = 96.88%] 
Epoch   5/30: 100%|██████████| 843/843 [00:10<00:00, 83.92batch/s, Loss = 0.6324, AUC = 97.40%] 
Epoch   6/30: 100%|██████████| 843/843 [00:10<00:00, 82.51batch/s, Loss = 0.5749, AUC = 97.80%] 
Epoch   7/30: 100%|██████████| 843/843 [00:10<00:00, 82.99batch/s, Loss = 0.5245, AUC = 98.17%] 
Epoch   8/30: 100%|██████████| 843/843 [00:09<00:00, 85.52batch/s, Loss = 0.4905, AUC = 98.38%] 
Epoch   9/30: 100%|██████████| 843/843 [00:10<00:00, 84.11batch/s, Loss = 0.4669, AUC = 98.53%] 
Epoch  10/30: 100%|██████████| 843/843 [00:10<00:00, 84.09batch/s, Loss = 0.4496, AUC = 98.60%] 
Epoch  11/30: 100%|██████████|

Total training time = 338.45s, best validation AUC = 98.92% at epoch 29


Testing: 100%|██████████| 156/156 [00:04<00:00, 38.56batch/s, Loss = 0.3841, AUC = 98.94%] 


(Array(0.38410008, dtype=float32), 98.94037686108435, [], [])
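
The first two entries of the returned tuple match the loss and AUC shown in the testing progress bar above. A minimal sketch of unpacking them follows, using the printed values as plain Python floats. The meaning of the two trailing empty lists is not stated in this notebook, so they are kept under the placeholder name `extras`.

```python
# Values copied from the output above; the Array is replaced by a plain float.
result = (0.38410008, 98.94037686108435, [], [])

# Assumption: the first entry is the test loss and the second the test AUC in percent.
test_loss, test_auc, *extras = result
summary = f"Test loss = {test_loss:.4f}, test AUC = {test_auc:.2f}%"
print(summary)
```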