In [100]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [101]:
%autoreload 2

In [158]:
import pennylane as qml
from quanvolution.quanv import TorchQuanvLayer
import torch
from torch import nn
import pennylane as qml
from pennylane import numpy as np
from pennylane.templates import RandomLayers
import torchvision
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torchvision import transforms
import time
from torch.utils.data import Subset


In [168]:
n_qubits = 4
dev2 = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev2)
def parameterised_qnode(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

In [169]:
n_layers=1
weight_shapes = {"weights": (n_layers, n_qubits)}
qnode_layer = qml.qnn.TorchLayer(parameterised_qnode, weight_shapes=weight_shapes)

In [170]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=transform)

print(len(dataset))

train_set, test_set = torch.utils.data.random_split(dataset, [50000, 10000])

train_indices = torch.randperm(len(train_set))[:50]
test_indices = torch.randperm(len(test_set))[:10]

train_set = Subset(train_set, train_indices)
test_set = Subset(test_set, test_indices)

train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=4)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=4)

60000


In [171]:
model = torch.nn.Sequential(
    TorchQuanvLayer(qnode_layer, kernel_size=(2,2), stride=2, out_channels=4),
    TorchQuanvLayer(qnode_layer, kernel_size=(2,2), stride=2, out_channels=16),
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=7*7*4*4, out_features=10)
)

In [172]:
optimiser = torch.optim.Adam(params=model.parameters(), lr=0.01)
loss_function = torch.nn.CrossEntropyLoss()

In [173]:
def run_experiment(model, optimiser, loss_function):
    
    start = time.time()

    for epoch in range(10):
        # initiliase epoch loss and predictions for accuracy
        cumulative_loss = 0.0
        correct_preds = 0
        model.train()
        for (x, y) in tqdm(train_loader, total=len(train_loader)):

            # do the business
            optimiser.zero_grad()
            outputs = model(x)
            # make predictions
            _, preds = torch.max(outputs, -1)
            loss = loss_function(outputs, y)  
            loss.backward()
            optimiser.step()

            # update loss and predictions
            cumulative_loss += loss.item() * x.size(0)
            correct_preds += torch.sum(preds == y.data).item()  
            
        train_epoch_loss = cumulative_loss / len(train_loader.dataset)
        train_epoch_acc = correct_preds / len(train_loader.dataset)
        
        print(f"Epoch {epoch}: Training Loss {train_epoch_loss} \n Training Accuracy {train_epoch_acc} \n Total time elapsed {time.time()-start}")
        
        model.eval()
        cumulative_val_loss = 0.
        correct_val_preds = 0
        for (x, y) in tqdm(test_loader, total=len(test_loader)):

            # do the business
            outputs = model(x)
            # make predictions
            _, preds = torch.max(outputs, -1)
            loss = loss_function(outputs, y)  

            # update loss and predictions
            cumulative_val_loss += loss.item() * x.size(0)
            correct_val_preds += torch.sum(preds == y.data).item()  
            
        val_epoch_loss = cumulative_val_loss / len(test_loader.dataset)
        val_epoch_acc = correct_val_preds / len(test_loader.dataset)
        
        print(f"Epoch {epoch}: Validation Loss {val_epoch_loss} \n Validation Accuracy {val_epoch_acc} \n Total time elapsed {time.time()-start}")

    return model

In [174]:
run_experiment(model, optimiser, loss_function)

  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 0: Training Loss 2.305353031158447 
 Training Accuracy 0.12 
 Total time elapsed 291.1890470981598


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 0: Validation Loss 2.298046112060547 
 Validation Accuracy 0.1 
 Total time elapsed 301.5173215866089


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: Training Loss 2.0745471572875975 
 Training Accuracy 0.58 
 Total time elapsed 597.7529082298279


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1: Validation Loss 2.2632441997528074 
 Validation Accuracy 0.1 
 Total time elapsed 605.9923114776611


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 2: Training Loss 1.8062255477905274 
 Training Accuracy 0.54 
 Total time elapsed 902.9067921638489


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2: Validation Loss 2.3081400871276854 
 Validation Accuracy 0.1 
 Total time elapsed 912.8048276901245


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 3: Training Loss 1.509759669303894 
 Training Accuracy 0.74 
 Total time elapsed 1199.2184793949127


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3: Validation Loss 2.358138847351074 
 Validation Accuracy 0.1 
 Total time elapsed 1207.6711556911469


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 4: Training Loss 1.166794981956482 
 Training Accuracy 0.8 
 Total time elapsed 1485.3968822956085


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4: Validation Loss 2.3496920585632326 
 Validation Accuracy 0.3 
 Total time elapsed 1493.4524700641632


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 5: Training Loss 0.7364475274085999 
 Training Accuracy 0.96 
 Total time elapsed 1758.165236234665


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5: Validation Loss 2.450683069229126 
 Validation Accuracy 0.3 
 Total time elapsed 1766.171458721161


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 6: Training Loss 0.3752254854142666 
 Training Accuracy 1.0 
 Total time elapsed 2028.7741787433624


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6: Validation Loss 2.672013854980469 
 Validation Accuracy 0.3 
 Total time elapsed 2036.6593129634857


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 7: Training Loss 0.17588809423148632 
 Training Accuracy 1.0 
 Total time elapsed 2299.609493494034


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7: Validation Loss 2.7957467555999758 
 Validation Accuracy 0.3 
 Total time elapsed 2307.567653656006


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 8: Training Loss 0.09078230263665318 
 Training Accuracy 1.0 
 Total time elapsed 2570.0426392555237


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8: Validation Loss 2.926401424407959 
 Validation Accuracy 0.4 
 Total time elapsed 2577.973256111145


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch 9: Training Loss 0.057003190610557794 
 Training Accuracy 1.0 
 Total time elapsed 2865.6566450595856


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9: Validation Loss 3.0200959205627442 
 Validation Accuracy 0.4 
 Total time elapsed 2874.017060995102


Sequential(
  (0): TorchQuanvLayer(
    (qnode): <Quantum Torch Layer: func=parameterised_qnode>
  )
  (1): TorchQuanvLayer(
    (qnode): <Quantum Torch Layer: func=parameterised_qnode>
  )
  (2): Flatten(start_dim=1, end_dim=-1)
  (3): Linear(in_features=784, out_features=10, bias=True)
)