In [1]:
import pennylane as qml
from pennylane import numpy as np

import torch

In [None]:
# Seed PyTorch's global RNG before any random initialisation or sampling so that the
# classical layer's weights, later layer initialisations and the torch.rand dummy data
# are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Classical pre-processing layer mapping 2 input features to 2 outputs, and a softmax
# taken over dim=1, i.e. the feature dimension of a (batch, features) tensor.
clayer = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)

In [None]:
n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)


@qml.qnode(dev)
def qnode(inputs, weights):
    """Small variational circuit used as the quantum layer of the model.

    Args:
        inputs: classical features, angle-embedded onto wires ``0 .. n_qubits - 1``.
            qml.qnn.TorchLayer passes the data through the argument named
            ``inputs``, so that name must not change.
        weights: trainable parameters; ``weights[0]`` is the RX angle on wire 0 and
            ``weights[1]`` the RY angle on wire 1.

    Returns:
        One Pauli-Z expectation value per wire.
    """
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.RX(weights[0], wires=0)
    qml.RY(weights[1], wires=1)

    per_wire_expvals = [
        qml.expval(qml.PauliZ(wires=wire)) for wire in range(n_qubits)
    ]
    return per_wire_expvals

In [None]:
# One trainable parameter per qubit. The "weights" key has to match the name of the
# qnode's trainable argument so TorchLayer can bind it.
weight_shapes = dict(weights=(n_qubits,))
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes=weight_shapes)

In [None]:
# Hybrid pipeline, applied in order:
#   classical Linear -> quantum TorchLayer -> Softmax(dim=1)
layers = [
    clayer,
    qlayer,
    softmax,
]
model = torch.nn.Sequential(*layers)

In [None]:
num_points = 50_000
# Two features per sample, matching the input size of the classical Linear(2, 2) layer.
dummy_data = torch.rand(num_points, 2)

# PyTorch needs to see a leading batch dimension, so the unbatched baseline feeds the
# model one (1, 2) slice at a time. split(1) slices along dim 0 and returns views,
# so no data is copied.
dummy_data_no_broadcast = list(dummy_data.split(1))

In [None]:
import time

# Baseline: run the model on one single-sample batch at a time (no broadcasting).
# time.process_time() measures CPU time consumed by this process.
t_begin = time.process_time()
for single_sample in dummy_data_no_broadcast:
    model(single_sample)
no_broadcast_time = time.process_time() - t_begin

In [None]:
# Batch sizes to benchmark. Batch size 1 (no broadcasting) is timed separately.
batch_sizes = [10, 20, 30, 40, 50, 100]
broadcast_times = []

# Ceiling division: when batch_size does not divide num_points (e.g. 30 into 50_000),
# the final, smaller batch is still counted, so each configuration covers every sample.
num_batches = [(num_points + batch_size - 1) // batch_size for batch_size in batch_sizes]

In [None]:
# Time one full pass over dummy_data for each batch size, using CPU time
# (time.process_time) like the unbatched baseline. The list is rebuilt here so that
# re-running this cell does not append duplicate timings.
broadcast_times = []
for batch_size in batch_sizes:
    start = time.process_time()

    # Stepping by batch_size visits every sample. When batch_size does not divide the
    # dataset size, the last slice is simply shorter (slicing clamps at the end).
    for batch_start in range(0, len(dummy_data), batch_size):
        model(dummy_data[batch_start : batch_start + batch_size])

    broadcast_times.append(time.process_time() - start)

In [None]:
from matplotlib import pyplot as plt

# Style sheet addressed by a dotted package name, presumably the one shipped with
# PennyLane. NOTE(review): package-style lookup needs a recent matplotlib; confirm.
plt.style.use("pennylane.drawer.plot")

In [None]:
# Combine the broadcast timings with the unbatched (batch size 1) baseline. New lists
# are built instead of appending to batch_sizes / broadcast_times, so this cell (and
# the timing loop above) can be re-run without accumulating duplicate points.
plot_batch_sizes = [1, *batch_sizes]
plot_times = [no_broadcast_time, *broadcast_times]

fig, ax = plt.subplots()
ax.scatter(plot_batch_sizes, plot_times)
ax.set_ylabel('Time (s)')
ax.set_xlabel('Batch size')
ax.set_title("Time taken for one full pass through data")
plt.show()