In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pennylane as qml

In [None]:
# Load CIFAR-10 sample: fetch one shuffled batch of 4 training images
# (the dataset is downloaded to ./data on first run).
# ToTensor scales pixel values to [0, 1]; Normalize with mean = std = 0.5 per
# channel then maps them to [-1, 1].
SEED = 0  # fixes which images the shuffled loader yields, so re-runs are reproducible

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# A dedicated, seeded generator makes shuffle=True deterministic across
# Restart & Run All without depending on the global torch RNG state.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
images, labels = next(iter(trainloader))

In [None]:
# Define quantum device: the PennyLane Amazon Braket plugin ("braket.aws.qubit")
# targeting the SV1 state-vector simulator, with 4 wires (0-3) for the circuit.
# NOTE(review): SV1 is a remote AWS-managed simulator — presumably this needs
# AWS credentials and each executed circuit is billed; "default.qubit" is a
# local alternative. Confirm before running the whole notebook.
dev = qml.device(
    "braket.aws.qubit",
    wires=4,
    device_arn="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
)

In [None]:
# Quantum circuit
@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, weights):
    """Variational 4-qubit circuit used as the quantum layer of HybridModel.

    Args:
        inputs: four rotation angles, one per wire, indexed along the LAST axis.
            Depending on the PennyLane version, TorchLayer passes either a single
            sample of shape (4,) or a whole batch of shape (batch, 4)
            (parameter broadcasting); indexing the last axis handles both.
        weights: trainable tensor of shape (8,) (the TorchLayer in HybridModel
            declares {"weights": (8,)}); weights[0:4] are the RZ angles and
            weights[4:8] the final RY angles.

    Returns:
        List of 4 Pauli-Z expectation values, one per wire.
    """
    # Angle-encode the features with RY. `inputs[..., i]` selects feature i for
    # every sample; `inputs[i]` would instead pick row i of a batched input.
    for i in range(4):
        qml.RY(inputs[..., i], wires=i)
    # Trainable single-qubit RZ layer.
    for i in range(4):
        qml.RZ(weights[i], wires=i)
    # Linear chain of CNOTs entangles neighbouring qubits 0-1, 1-2, 2-3.
    for i in range(3):
        qml.CNOT(wires=[i, i + 1])
    # Trainable single-qubit RY layer using the second half of the weights.
    for i in range(4):
        qml.RY(weights[i + 4], wires=i)
    return [qml.expval(qml.PauliZ(i)) for i in range(4)]

In [None]:
# Hybrid CNN + Quantum model
class HybridModel(nn.Module):
    """Small CNN whose 4-feature bottleneck is passed through a quantum layer.

    Pipeline: Conv2d(3->8, k=5) -> ReLU -> MaxPool(2) -> Conv2d(8->16, k=5)
    -> ReLU -> MaxPool(2) -> flatten -> Linear(400->4) -> tanh
    -> quantum layer (4 -> 4) -> Linear(4->10).
    The 16*5*5 = 400 flatten size implies 3x32x32 input images
    (32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5).

    Args:
        quantum_layer: optional module mapping (batch, 4) -> (batch, 4). When
            omitted (the default), a PennyLane TorchLayer wrapping
            ``quantum_circuit`` with 8 trainable weights is built, exactly as
            before. Passing a stand-in such as ``nn.Identity()`` lets the
            classical part run without a quantum device.
    """

    def __init__(self, quantum_layer=None):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 4)
        if quantum_layer is None:
            quantum_layer = qml.qnn.TorchLayer(quantum_circuit, {"weights": (8,)})
        self.quantum_layer = quantum_layer
        self.fc2 = nn.Linear(4, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image batch to (batch, 10) outputs.

        The 10 outputs are presumably CIFAR-10 class scores (no softmax is applied).
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension; view(-1, 400) would silently
        # re-batch inputs of the wrong spatial size instead of failing in fc1.
        x = torch.flatten(x, 1)
        # tanh bounds the 4 features to (-1, 1) before they become RY angles.
        x = torch.tanh(self.fc1(x))
        x = self.quantum_layer(x)
        x = self.fc2(x)
        return x

In [None]:
# Forward pass: run the sampled CIFAR-10 batch through a freshly built model.
# Seeding first makes the randomly initialised weights — and therefore the
# printed output — reproducible (presumably this also covers TorchLayer's
# default weight init, which draws from torch's RNG; verify for your version).
torch.manual_seed(0)
model = HybridModel()
model.eval()
# Inference only: no backward pass follows, so skip building the autograd graph.
with torch.no_grad():
    output = model(images)

print(f"Input shape: {images.shape}")
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")