In [None]:
"""
MNIST CNN + Dimensionality Reduction (PCA vs Random Projection vs QEB) + MLP
Train / Validation / Test feature extraction 
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms
import numpy as np
from sklearn.decomposition import PCA
import pennylane as qml
from pennylane import numpy as np

# ========== Config ==========
class Config:
    """Central hyper-parameters, paths and runtime settings for the MNIST pipeline.

    Everything is a plain class attribute, so values can be read from the class
    itself or from an instance (``main`` builds one with ``Config()``).
    """

    # Runtime / reproducibility
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    RANDOM_SEED = 42
    DATA_PATH = "./data"

    # Data loading
    BATCH_SIZE = 64
    TEST_BATCH_SIZE = 1000
    VAL_RATIO = 0.1  # fraction of the MNIST training set held out for validation

    # Model sizes
    CNN_HIDDEN_DIM = 128  # width of the CNN feature layer (fc1)
    PROJECTION_DIM = 7  # reduced dim for PCA / random projection; also the qubit count for QEB
    MLP_HIDDEN_DIM = 64
    NUM_CLASSES = 10

    # Optimisation / logging
    LEARNING_RATE = 0.001
    CNN_EPOCHS = 2
    MLP_EPOCHS = 5
    LOG_INTERVAL = 100  # batches between loss log lines in train_mlp

# Models
class CNN(nn.Module):
    """Two-conv-layer classifier whose first fully-connected layer doubles as a feature extractor.

    The flattened size 9216 = 64 * 12 * 12 is fixed by the layers below: two
    unpadded 3x3 convolutions followed by a 2x2 max-pool turn a 28x28 single-channel
    image into a 64 x 12 x 12 map (other input sizes would not match ``fc1``).

    Args:
        hidden_dim: width of the feature layer ``fc1``.
        num_classes: number of output logits.
    """

    def __init__(self, hidden_dim=128, num_classes=10):
        super().__init__()
        # Creation order is kept stable: it determines how parameters consume the RNG.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(64 * 12 * 12, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, return_features=False):
        """Return class logits, plus the pre-ReLU ``fc1`` activations when ``return_features`` is set."""
        conv_out = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = F.max_pool2d(conv_out, 2)
        # Features are taken *before* the ReLU that feeds the classifier head.
        features = self.fc1(torch.flatten(pooled, 1))
        logits = self.fc2(F.relu(features))
        if not return_features:
            return logits
        return logits, features

    def extract_features(self, x):
        """Return only the pre-ReLU ``fc1`` activations for ``x``."""
        return self.forward(x, return_features=True)[1]

class MLP(nn.Module):
    """Small classifier head: Linear -> ReLU -> Linear, producing class logits.

    Args:
        input_dim: number of input features per row.
        hidden_dim: width of the single hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim=64, num_classes=10):
        super().__init__()
        # Build the hidden layer first, then the output layer (same RNG order as before).
        hidden = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        head = nn.Linear(hidden_dim, num_classes)
        self.layers = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Return the logits for the batch ``x``."""
        return self.layers(x)


# Data
def get_mnist_loaders(config):
    """Build train / validation / test DataLoaders for MNIST.

    The official MNIST training set is split at random into train and validation
    parts (``config.VAL_RATIO`` goes to validation); the official test set is used
    as-is. The split uses the global torch RNG, so it is reproducible only if the
    caller has seeded it (``main`` does).

    Args:
        config: supplies DATA_PATH, VAL_RATIO, BATCH_SIZE and TEST_BATCH_SIZE.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader shuffles.
    """
    # Normalisation constants, presumably the MNIST training-set mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def load_split(is_train):
        return datasets.MNIST(config.DATA_PATH, train=is_train, download=True, transform=transform)

    full_train = load_split(True)
    test_set = load_split(False)

    n_total = len(full_train)
    n_val = int(n_total * config.VAL_RATIO)
    train_set, val_set = random_split(full_train, [n_total - n_val, n_val])

    eval_kwargs = dict(batch_size=config.TEST_BATCH_SIZE, shuffle=False)
    return (
        DataLoader(train_set, batch_size=config.BATCH_SIZE, shuffle=True),
        DataLoader(val_set, **eval_kwargs),
        DataLoader(test_set, **eval_kwargs),
    )

# Feature Extraction
def extract_hidden_representations(model, dataloader, device, name="set"):
    """Collect ``model.extract_features`` outputs and labels for every batch of ``dataloader``.

    The model is switched to eval mode and run without gradient tracking. Features
    and labels are gathered on the CPU in the same batch order, so row ``i`` of the
    features corresponds to entry ``i`` of the labels.

    Args:
        model: module exposing ``extract_features(inputs)``.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward pass.
        name: tag used in the printed summary line.

    Returns:
        ``(features, labels)`` concatenated along the batch dimension.
    """
    model.eval()
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, batch_labels in dataloader:
            batch_features = model.extract_features(inputs.to(device))
            feature_chunks.append(batch_features.cpu())
            # Labels are only needed on the CPU; no round trip through `device`.
            label_chunks.append(batch_labels.cpu())
    features = torch.cat(feature_chunks)
    labels = torch.cat(label_chunks)
    print(f"[Extract] {name}: {features.shape}")
    return features, labels

def apply_pca(train_features, val_features, test_features, n_components=7):
    """Reduce feature splits to ``n_components`` dimensions with PCA.

    PCA is fitted on the training features only; the validation and test features
    are projected with that same fit. The total explained-variance ratio of the
    kept components is printed.

    Args:
        train_features, val_features, test_features: CPU tensors convertible with ``.numpy()``.
        n_components: number of principal components to keep.

    Returns:
        ``(train_pca, val_pca, test_pca)`` as float32 torch tensors.
    """
    pca = PCA(n_components=n_components, random_state=42)
    reduced = [pca.fit_transform(train_features.numpy())]
    reduced.extend(pca.transform(split.numpy()) for split in (val_features, test_features))
    print(f"PCA reduced to {n_components} dims, variance explained {pca.explained_variance_ratio_.sum():.3f}")
    return tuple(torch.tensor(arr, dtype=torch.float32) for arr in reduced)

def apply_random_projection(train_features, val_features, test_features, target_dim=7, seed=42):
    """Project feature splits onto ``target_dim`` random unit-norm directions.

    A single Gaussian matrix of shape ``(input_dim, target_dim)`` is drawn from
    ``seed``. Each column is scaled to unit L2 norm, and the same matrix is applied
    to all three splits.

    Args:
        train_features, val_features, test_features: 2-D tensors whose second
            dimension is ``input_dim`` (taken from ``train_features``).
        target_dim: number of output dimensions.
        seed: seed for the projection matrix. Only a private generator is seeded;
            the global torch RNG is left untouched.

    Returns:
        ``(train_proj, val_proj, test_proj)``, each with ``target_dim`` columns.
    """
    input_dim = train_features.shape[1]
    # A private generator makes the projection reproducible without resetting the
    # global torch RNG, which would also reseed every later weight init and
    # DataLoader shuffle. On CPU it yields the same numbers as
    # torch.manual_seed(seed) followed by torch.randn(...).
    generator = torch.Generator().manual_seed(seed)
    proj_matrix = torch.randn(input_dim, target_dim, generator=generator)
    proj_matrix = proj_matrix / torch.norm(proj_matrix, dim=0, keepdim=True)

    # `a @ b` is equivalent to torch.matmul(a, b).
    train_proj = train_features @ proj_matrix
    val_proj = val_features @ proj_matrix
    test_proj = test_features @ proj_matrix
    print(f"Random Projection: {input_dim} → {target_dim}")
    return train_proj, val_proj, test_proj

# Quantum Embedding
n_qubits = 7
dev = qml.device("default.qubit", wires=n_qubits)


@qml.qnode(dev, interface="torch", diff_method="backprop")
def quantum_embedding(x, weights):
    """Amplitude-embed ``x``, apply one RY + CNOT layer, and return each qubit's <Z>.

    Runs on the module-level ``dev`` with ``n_qubits`` wires. ``x`` is normalised by
    AmplitudeEmbedding; since no ``pad_with`` is given, PennyLane requires it to have
    exactly ``2**n_qubits`` entries. ``weights[i]`` is the RY angle applied to wire ``i``.
    """
    wires = range(n_qubits)
    # Amplitude embedding of the (normalised) input vector.
    qml.AmplitudeEmbedding(x, wires=wires, normalize=True)

    # Parameterised layer: one RY rotation per wire ...
    for wire in wires:
        qml.RY(weights[wire], wires=wire)
    # ... followed by a nearest-neighbour CNOT chain 0->1->...->(n_qubits-1).
    for control, target in zip(wires[:-1], wires[1:]):
        qml.CNOT(wires=[control, target])

    # Pauli-Z expectation value of every qubit.
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]

class QuantumEmbeddingLayer(nn.Module):
    """Quantum feature map: amplitude embedding plus one trainable RY / CNOT layer.

    Each input row is amplitude-embedded into ``n_qubits`` qubits. It is then rotated
    by a trainable RY angle per qubit, entangled with a nearest-neighbour CNOT chain,
    and read out as the Pauli-Z expectation of every qubit. The output therefore has
    ``n_qubits`` columns.

    Args:
        n_qubits: number of qubits; also the output width.
        n_features: expected length of each input row. It must not exceed
            ``2**n_qubits``; shorter rows are zero-padded before normalisation.

    Raises:
        ValueError: if ``n_features > 2**n_qubits``.
    """

    def __init__(self, n_qubits=7, n_features=128):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_features = n_features
        self.embed_dim = 2**n_qubits
        if n_features > self.embed_dim:
            raise ValueError(
                f"n_features={n_features} cannot be amplitude-embedded into "
                f"{n_qubits} qubits (at most {self.embed_dim} amplitudes)"
            )
        # Trainable parameters: one RY rotation angle per qubit, initialised near 0.
        self.weights = nn.Parameter(torch.randn(n_qubits) * 0.01)

    def _build_circuit(self):
        """Return a torch-interface QNode whose width is this layer's ``n_qubits``.

        A per-layer circuit is built instead of calling a module-level QNode
        hard-wired to a fixed wire count. Otherwise the ``n_qubits`` constructor
        argument would be silently ignored, and ``weights`` would be mis-indexed.
        It is created on demand in ``forward`` rather than stored, so the module's
        state stays limited to plain attributes and ``weights``.
        """
        n_wires = self.n_qubits
        # Zero-pad short rows; None keeps PennyLane's exact-length check.
        pad_with = 0.0 if self.n_features < self.embed_dim else None
        device = qml.device("default.qubit", wires=n_wires)

        @qml.qnode(device, interface="torch", diff_method="backprop")
        def circuit(x, weights):
            qml.AmplitudeEmbedding(x, wires=range(n_wires), normalize=True, pad_with=pad_with)
            for i in range(n_wires):
                qml.RY(weights[i], wires=i)
            for i in range(n_wires - 1):
                qml.CNOT(wires=[i, i + 1])
            return [qml.expval(qml.PauliZ(i)) for i in range(n_wires)]

        return circuit

    def forward(self, x):
        """Embed every row of ``x``; return a float32 ``(len(x), n_qubits)`` tensor on ``x.device``."""
        circuit = self._build_circuit()
        rows = []
        for sample in x:
            # Each row is simulated separately on the CPU. The input is detached,
            # so gradients (if any) reach only self.weights.
            expvals = circuit(sample.detach().cpu(), self.weights)
            # Depending on the PennyLane version this is a tensor or a tuple of
            # 0-d tensors. Stacking handles both, and unlike torch.as_tensor on a
            # sequence of tensors it keeps the autograd link to self.weights.
            rows.append(torch.stack(list(expvals)).to(device=x.device, dtype=torch.float32))
        return torch.stack(rows)


def apply_quantum_embedding(train_features, val_features, test_features, n_qubits=7):
    """Embed all three feature splits with one freshly initialised QuantumEmbeddingLayer.

    The layer's RY angles come from the global torch RNG and are not trained here.
    The layer serves purely as a fixed feature map, evaluated under ``torch.no_grad()``.

    Args:
        train_features, val_features, test_features: feature tensors, assumed
            ``(n_samples, n_features)``; the column count of ``train_features`` is
            passed to the layer as ``n_features``.
        n_qubits: number of qubits used by the embedding layer.

    Returns:
        ``(train_q, val_q, test_q)`` as produced by the embedding layer.
    """
    q_layer = QuantumEmbeddingLayer(n_qubits=n_qubits, n_features=train_features.shape[1])
    q_layer.eval()  # used only as a fixed embedding, not trained here
    with torch.no_grad():
        train_q = q_layer(train_features)
        val_q = q_layer(val_features)
        test_q = q_layer(test_features)
    print(f"Quantum Embedding: {train_features.shape[1]} → {n_qubits} qubits (output {train_q.shape[1]} dims)")
    return train_q, val_q, test_q


# Train/Eval
def train_cnn(model, train_loader, val_loader, config):
    """Train ``model`` with Adam and cross-entropy, printing validation accuracy after each epoch.

    Args:
        model: classifier returning logits; updated in place.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: loader passed to ``evaluate_model`` after every epoch.
        config: supplies DEVICE, LEARNING_RATE and CNN_EPOCHS.

    Returns:
        The same ``model`` object after training.
    """
    device = config.DEVICE
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    for epoch_idx in range(config.CNN_EPOCHS):
        model.train()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
        print(f"[CNN] Epoch {epoch_idx + 1} finished")
        evaluate_model(model, val_loader, device, "[CNN-VAL]")
    return model

def train_mlp(model, train_feats, train_labels, val_feats, val_labels, test_feats, test_labels, config, name="MLP", epochs=None):
    """Train a classifier head on pre-extracted features, reporting accuracy after every epoch.

    Args:
        model: module mapping a batch of feature rows to class logits; trained in place.
        train_feats, train_labels: training features and class-index labels.
        val_feats, val_labels: validation split, evaluated after every epoch.
        test_feats, test_labels: test split, also evaluated after every epoch
            (for reporting only; nothing here selects a model on it).
        config: supplies DEVICE, BATCH_SIZE, TEST_BATCH_SIZE, LEARNING_RATE,
            LOG_INTERVAL and, when ``epochs`` is None, MLP_EPOCHS.
        name: tag used in log lines.
        epochs: number of full passes over the training data. ``None`` means
            ``config.MLP_EPOCHS``.

    Returns:
        The same ``model`` object after training.
    """
    if epochs is None:
        epochs = config.MLP_EPOCHS

    def make_loader(feats, labels, batch_size, shuffle):
        # Wrap an in-memory (features, labels) pair in a DataLoader.
        return DataLoader(TensorDataset(feats, labels), batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(train_feats, train_labels, config.BATCH_SIZE, True)
    val_loader = make_loader(val_feats, val_labels, config.TEST_BATCH_SIZE, False)
    test_loader = make_loader(test_feats, test_labels, config.TEST_BATCH_SIZE, False)

    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    # Epochs are numbered from 1 in the logs; `epochs + 1` makes the loop run
    # exactly `epochs` times (range(1, epochs) would stop one short).
    for epoch in range(1, epochs + 1):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(config.DEVICE), target.to(config.DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            # Log the loss of every LOG_INTERVAL-th batch.
            if batch_idx % config.LOG_INTERVAL == 0:
                print(f"[{name}] Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] "
                      f"Loss: {loss.item():.6f}")

        # After each epoch, report validation and test accuracy.
        print(f"[{name}] Epoch {epoch} finished")
        evaluate_model(model, val_loader, config.DEVICE, f"[{name}-VAL]")
        evaluate_model(model, test_loader, config.DEVICE, f"[{name}-TEST]")

    return model

def evaluate_model(model, dataloader, device, prefix=""):
    """Compute and print top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking. Accuracy is
    the number of correct argmax predictions divided by ``len(dataloader.dataset)``.

    Args:
        model: module returning per-class scores with classes on dim 1.
        dataloader: loader of ``(inputs, labels)`` batches exposing ``.dataset``.
        device: device inputs and labels are moved to.
        prefix: text printed before the accuracy.

    Returns:
        Accuracy as a float percentage.
    """
    model.eval()
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
    acc = 100. * n_correct / len(dataloader.dataset)
    print(f"{prefix} Accuracy: {acc:.2f}%")
    return acc


# Main
def main():
    """Run the full experiment.

    Train the CNN and extract its hidden features for train / val / test. Then
    compare PCA, random projection and quantum embedding (QEB) reductions, each
    followed by an MLP classifier head.
    """
    config = Config()
    torch.manual_seed(config.RANDOM_SEED)

    def print_section(title):
        # Banner separating the three reduction experiments in the log.
        print("\n" + "=" * 50)
        print(title)
        print("=" * 50)

    train_loader, val_loader, test_loader = get_mnist_loaders(config)
    cnn = CNN(config.CNN_HIDDEN_DIM, config.NUM_CLASSES).to(config.DEVICE)
    cnn = train_cnn(cnn, train_loader, val_loader, config)

    # Feature extraction: hidden features and labels for every split, in matching order.
    train_feats, train_labels = extract_hidden_representations(cnn, train_loader, config.DEVICE, "train")
    val_feats, val_labels = extract_hidden_representations(cnn, val_loader, config.DEVICE, "val")
    test_feats, test_labels = extract_hidden_representations(cnn, test_loader, config.DEVICE, "test")

    # PCA + MLP
    print_section("PCA")
    train_pca, val_pca, test_pca = apply_pca(train_feats, val_feats, test_feats, config.PROJECTION_DIM)
    pca_mlp = MLP(train_pca.shape[1], config.MLP_HIDDEN_DIM, config.NUM_CLASSES).to(config.DEVICE)
    train_mlp(pca_mlp, train_pca, train_labels, val_pca, val_labels, test_pca, test_labels, config, "PCA-MLP",
              epochs=config.MLP_EPOCHS)

    # Random Projection + MLP
    print_section("Random Projection")
    train_rp, val_rp, test_rp = apply_random_projection(train_feats, val_feats, test_feats, config.PROJECTION_DIM)
    rp_mlp = MLP(train_rp.shape[1], config.MLP_HIDDEN_DIM, config.NUM_CLASSES).to(config.DEVICE)
    train_mlp(rp_mlp, train_rp, train_labels, val_rp, val_labels, test_rp, test_labels, config, "RP-MLP",
              epochs=config.MLP_EPOCHS)

    # Quantum Embedding + MLP (this head is deliberately given 30 epochs, more than the others)
    print_section("QEB")
    train_q, val_q, test_q = apply_quantum_embedding(train_feats, val_feats, test_feats, n_qubits=config.PROJECTION_DIM)
    qeb_mlp = MLP(train_q.shape[1], config.MLP_HIDDEN_DIM, config.NUM_CLASSES).to(config.DEVICE)
    train_mlp(qeb_mlp, train_q, train_labels, val_q, val_labels, test_q, test_labels, config, "QEB-MLP", epochs=30)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


[CNN] Epoch 1 finished
[CNN-VAL] Accuracy: 98.28%
[CNN] Epoch 2 finished
[CNN-VAL] Accuracy: 98.17%
[Extract] train: torch.Size([54000, 128])
[Extract] val: torch.Size([6000, 128])
[Extract] test: torch.Size([10000, 128])

PCA
PCA reduced to 7 dims, variance explained 0.797
[PCA-MLP] Epoch: 1 [0/54000] Loss: 5.926607
[PCA-MLP] Epoch: 1 [6400/54000] Loss: 0.386174
[PCA-MLP] Epoch: 1 [12800/54000] Loss: 0.264391
[PCA-MLP] Epoch: 1 [19200/54000] Loss: 0.077076
[PCA-MLP] Epoch: 1 [25600/54000] Loss: 0.022681
[PCA-MLP] Epoch: 1 [32000/54000] Loss: 0.011295
[PCA-MLP] Epoch: 1 [38400/54000] Loss: 0.175794
[PCA-MLP] Epoch: 1 [44800/54000] Loss: 0.110930
[PCA-MLP] Epoch: 1 [51200/54000] Loss: 0.141165
[PCA-MLP] Epoch 1 finished
[PCA-MLP-VAL] Accuracy: 97.17%
[PCA-MLP-TEST] Accuracy: 97.62%
[PCA-MLP] Epoch: 2 [0/54000] Loss: 0.117383
[PCA-MLP] Epoch: 2 [6400/54000] Loss: 0.032892
[PCA-MLP] Epoch: 2 [12800/54000] Loss: 0.014607
[PCA-MLP] Epoch: 2 [19200/54000] Loss: 0.079382
[PCA-MLP] Epoch: 2 [2