# 🧠 Vision Transformer vs Quantum-Inspired Vision Transformer
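This Colab notebook compares a small classical "ViT" baseline with a simulated quantum-inspired "ViT" (a 4-qubit variational circuit run on PennyLane's simulator), using PyTorch on a subset of CIFAR-10. Note: despite the names, neither model has patch embeddings or self-attention — both are small fully-connected networks.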

In [None]:
# 📦 Install required packages
!pip install -q pennylane torchvision timm

In [None]:
# 🔍 Import libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import pennylane as qml
import numpy as np
from tqdm import tqdm

## 📥 Load CIFAR-10 (small subset for fast training)

In [None]:
# ToTensor() yields values in [0, 1]; Normalize with mean=std=0.5 maps each
# channel to [-1, 1]. The three-element tuples match CIFAR-10's RGB channels
# explicitly instead of relying on a single value being broadcast.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Subset sizes and batch size, kept small for fast experimentation.
N_TRAIN = 1000
N_TEST = 200
BATCH_SIZE = 32

# download=True on both splits so neither call depends on the other having
# fetched the archive first; torchvision skips the download when the files
# are already present under root.
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# Use the first N images of each split (a fixed prefix, not a random sample).
train_subset = Subset(train_dataset, list(range(N_TRAIN)))
test_subset = Subset(test_dataset, list(range(N_TEST)))

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False)

## ✅ Simple ViT Baseline (Tiny MLP — no patches or attention)

In [None]:
class SimpleViT(nn.Module):
    """Small fully-connected baseline classifier.

    Despite the name, this is not a Vision Transformer: there are no patch
    embeddings or attention layers. Each image is flattened and passed
    through one ReLU hidden layer, then a linear layer producing logits.

    Args:
        in_features: Number of values in one flattened input image. Defaults
            to 3*32*32, the size of a CIFAR-10 image tensor.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 3 * 32 * 32, hidden_dim: int = 256,
                 num_classes: int = 10):
        super().__init__()
        # Attribute names are kept stable so existing state_dict keys still load.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits of shape ``(batch, num_classes)``.

        ``x`` is flattened from dim 1 onward, so any per-sample shape whose
        element count equals ``in_features`` is accepted.
        """
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

## ⚛️ Quantum Layer for Quantum-Inspired ViT

In [None]:
n_qubits = 4
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(w)) for w in range(n_qubits)]

In [None]:
class QuantumLayer(nn.Module):
    """PyTorch module wrapping ``quantum_circuit`` with trainable weights.

    ``qml.qnn.TorchLayer`` registers the circuit's ``weights`` argument as a
    trainable parameter of shape ``(n_layers, n_qubits)``. That is the layout
    ``BasicEntanglerLayers`` expects: one rotation angle per qubit per layer.

    Args:
        n_layers: Number of entangling layers. Defaults to 1, which matches
            the original single-layer circuit.

    Raises:
        ValueError: If ``n_layers`` is less than 1.
    """

    def __init__(self, n_layers: int = 1):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        weight_shapes = {"weights": (n_layers, n_qubits)}
        self.q_layer = qml.qnn.TorchLayer(quantum_circuit, weight_shapes)

    def forward(self, x):
        """Run the circuit on ``x`` and return the PauliZ expectations.

        The result has one value per qubit. ``x`` presumably has
        ``n_qubits`` features in its last dim (one angle per wire). TODO
        confirm batching behaviour against the installed PennyLane version.
        """
        return self.q_layer(x)

## 🔗 Hybrid Quantum ViT

In [None]:
class HybridViT(nn.Module):
    """Hybrid classical/quantum classifier.

    Despite the name, this is not a Vision Transformer (no patches or
    attention). Pipeline: flatten -> linear projection to ``n_qubits``
    features -> tanh -> ``QuantumLayer`` -> linear head producing logits.

    Args:
        in_features: Number of values in one flattened input image. Defaults
            to 3*32*32, the size of a CIFAR-10 image tensor.
        num_classes: Number of output logits.
        angle_scale: Multiplier applied to the ``tanh`` output before it is
            used as the circuit's embedding angles. ``tanh`` alone bounds the
            angles to [-1, 1] rad; e.g. ``math.pi`` would span [-pi, pi].
            The default of 1.0 reproduces the original behaviour.
    """

    def __init__(self, in_features: int = 3 * 32 * 32, num_classes: int = 10,
                 angle_scale: float = 1.0):
        super().__init__()
        self.angle_scale = angle_scale
        # Submodules are created in the original order so parameter
        # initialization consumes the RNG identically.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, n_qubits)
        self.q_layer = QuantumLayer()
        self.fc2 = nn.Linear(n_qubits, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.flatten(x)
        # Bounded values in [-angle_scale, angle_scale], fed to AngleEmbedding.
        angles = torch.tanh(self.fc1(x)) * self.angle_scale
        features = self.q_layer(angles)
        return self.fc2(features)