# IMDb Reviews (Quantum)

This notebook trains and evaluates a quantum transformer on the IMDb Reviews sentiment classification task: deciding whether a movie review is positive or negative, which makes it a binary text classification problem.
You can find information about the dataset at https://www.tensorflow.org/datasets/catalog/imdb_reviews.

In [1]:
import jax

from quantum_transformers.datasets import get_imdb_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'


The model is trained on the following devices:

In [2]:
for d in jax.devices():
    print(d, d.device_kind)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


TFRT_CPU_0 cpu


Let's check how big the vocabulary is and look at one review, in both tokenized and raw form.
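To make the role of `max_vocab_size` and `max_seq_len` (set in the next cell) concrete, here is a toy sketch of that kind of preprocessing: keep only the most frequent tokens, map everything else to an unknown id, then truncate and pad each review to a fixed length. This is not the tokenizer used by `get_imdb_dataloaders`; the whitespace splitting, the `[PAD]`/`[UNK]` special tokens and the helpers `build_vocab` and `encode` are all hypothetical, chosen only for illustration.

```python
from collections import Counter


def build_vocab(texts, max_vocab_size):
    """Hypothetical helper: keep the most frequent whitespace-separated tokens."""
    counts = Counter(tok for text in texts for tok in text.lower().split())
    # Assumption of this sketch: ids 0 and 1 are reserved for padding and unknown tokens.
    specials = ["[PAD]", "[UNK]"]
    most_common = [tok for tok, _ in counts.most_common(max_vocab_size - len(specials))]
    return {tok: i for i, tok in enumerate(specials + most_common)}


def encode(text, vocab, max_seq_len):
    """Hypothetical helper: map tokens to ids, truncate to max_seq_len, right-pad with [PAD]."""
    ids = [vocab.get(tok, vocab["[UNK]"]) for tok in text.lower().split()]
    ids = ids[:max_seq_len]
    return ids + [vocab["[PAD]"]] * (max_seq_len - len(ids))


reviews = ["a great movie", "a terrible movie", "great acting"]
toy_vocab = build_vocab(reviews, max_vocab_size=5)
print(toy_vocab)  # {'[PAD]': 0, '[UNK]': 1, 'a': 2, 'great': 3, 'movie': 4}
print(encode("a great film", toy_vocab, max_seq_len=5))  # [2, 3, 1, 0, 0]
```

With `max_vocab_size=5`, the rarer words "terrible" and "acting" are dropped, so "film" (never seen) and any dropped word both map to the `[UNK]` id.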

In [3]:
# Define the parameters
batch_size = 32
max_vocab_size = 20_000  # upper bound on the vocabulary size
max_seq_len = 512        # maximum number of tokens kept per review

# Load the data
(train_dataloader, val_dataloader, test_dataloader), vocab, tokenizer = get_imdb_dataloaders(
    data_dir=data_dir,
    batch_size=batch_size,
    max_vocab_size=max_vocab_size,
    max_seq_len=max_seq_len,
)

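# Print the vocabulary size and the first training review (token ids, then decoded text)
print(f"Vocabulary size: {len(vocab)}")
first_batch = next(iter(train_dataloader))
print(first_batch[0][0])  # token ids of the first review in the batch
print(' '.join(map(bytes.decode, tokenizer.detokenize(first_batch[0])[0].numpy().tolist())))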

Cardinalities (train, val, test): 22500 2500 25000
Vocabulary size: 19769
[  150   905  1336  2244   105    42  1114   369   106   163   106    42
  1114  2058    17   191   121    15   146   263   246  3020   142   167
    42   224   133   298   102   120   284    15    96   104   105   277
   165  2244    10    60 10135 10476   138  1554  4074   575   113   361
    17    31   100    18    33    31   100    18    33    50   327   104
   106    95  4751    10    60  4788   113    15    95   584    97    95
   215    10    60   198   156   742    98  5364    17    95  1230  5861
  1342   288   105   256   439    97   358    98  7101   108  2715  1564
    10    60  2438 11147   288   153   478   135   122    95   262   148
    17   471   149    42   668  1674  3898    15   110   256   156  1034
   107  1336   408   225   243    17   117  1047   203   102 12557    15
   117   318 12061   135    17    17    17  2244  1582   143   129    42
 12230    98    95  3026    10    60    17    95  
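The training and test progress bars further down report 703 and 781 batches. A quick arithmetic check, assuming the dataloaders drop the final incomplete batch (an inference from these counts, not something confirmed from the `get_imdb_dataloaders` source), reproduces both numbers from the cardinalities printed above, whereas keeping the partial batch would give 704 and 782. The train and validation cardinalities also sum to 25,000, the size of the original IMDb training split, which is consistent with the validation set being carved out of it.

```python
import math

batch_size = 32  # value used in this notebook
cardinalities = {"train": 22_500, "val": 2_500, "test": 25_000}  # printed above

# Assumption: the last, incomplete batch is dropped (floor division).
full_batches = {split: n // batch_size for split, n in cardinalities.items()}
# Alternative: the last, partial batch is kept (ceiling division).
with_remainder = {split: math.ceil(n / batch_size) for split, n in cardinalities.items()}

print(full_batches)    # {'train': 703, 'val': 78, 'test': 781}
print(with_remainder)  # {'train': 704, 'val': 79, 'test': 782}
print(cardinalities["train"] + cardinalities["val"])  # 25000
```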

In [4]:
# Transformer whose attention and MLP parts both use the circuit returned by get_circuit()
model = Transformer(num_tokens=len(vocab), max_seq_len=max_seq_len, num_classes=2, hidden_size=6, num_heads=2,
                    num_transformer_blocks=4, mlp_hidden_size=3,
                    quantum_attn_circuit=get_circuit(), quantum_mlp_circuit=get_circuit())
train_and_evaluate(model, train_dataloader, val_dataloader, test_dataloader, num_classes=2,
                   num_epochs=1)  # change num_epochs as needed

Number of parameters = 122096


Epoch   1/1: 100%|██████████| 703/703 [1:06:03<00:00,  5.64s/batch, Loss = 0.6951, AUC = 0.496, Train time = 2901.19s]


Best validation AUC = 0.496 at epoch 1
Total training time = 2901.19s, total time (including evaluations) = 3915.11s


Testing: 100%|██████████| 781/781 [18:30<00:00,  1.42s/batch, Loss = 0.6973, AUC = 0.485]


{'train_losses': Array([0.6963565], dtype=float32),
 'val_losses': Array([0.6950757], dtype=float32),
 'train_aucs': Array([0.4909439], dtype=float32),
 'val_aucs': Array([0.49571323], dtype=float32),
 'test_loss': Array(0.6972679, dtype=float32),
 'test_auc': 0.48519720574853364,
 'test_fpr': array([0.00000000e+00, 8.00128020e-05, 5.60089614e-04, ...,
        9.99679949e-01, 9.99679949e-01, 1.00000000e+00]),
 'test_tpr': array([0.        , 0.        , 0.        , ..., 0.99983992, 1.        ,
        1.        ])}
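The validation and test AUCs reported above (0.496 and 0.485) are close to 0.5, the value expected when the model's scores carry no information about the labels. To make that reading concrete, here is a minimal sketch of ROC AUC as the probability that a randomly chosen positive review scores higher than a randomly chosen negative one, with ties counted as half (the normalised Mann-Whitney U statistic). `roc_auc` is a hypothetical helper for illustration, not the routine `train_and_evaluate` uses internally.

```python
import numpy as np


def roc_auc(labels, scores):
    """Hypothetical helper: P(score of random positive > score of random negative).

    Ties count as half a win, which matches the area under the ROC curve.
    """
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / (len(pos) * len(neg))


y = [0, 0, 1, 1]
print(roc_auc(y, [0.1, 0.2, 0.8, 0.9]))  # 1.0 (perfect ranking)
print(roc_auc(y, [0.9, 0.8, 0.2, 0.1]))  # 0.0 (perfectly inverted ranking)
print(roc_auc(y, [0.5, 0.5, 0.5, 0.5]))  # 0.5 (uninformative scores)
```

The pairwise comparison needs memory proportional to the number of positive times negative examples, which is fine for illustration but large for the full 25,000-review test set; a rank-based formulation or scikit-learn's `roc_auc_score` scales better.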