In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import pennylane as qml
import math
import matplotlib.pyplot as plt
from torchsummary import summary

# KANLinear definition. Source: https://github.com/Blealtan/efficient-kan/blob/f39e5146af34299ad3a581d2106eb667ba0fa6fa/src/efficient_kan/kan.py#L6
class KANLinear(torch.nn.Module):
    """Linear-style layer made of a base branch plus a learnable B-spline branch.

    ``forward`` adds ``F.linear(base_activation(x), base_weight)`` to a linear
    map of the B-spline bases of ``x``, so every (input, output) pair owns its
    own set of spline coefficients next to a conventional weight.

    Args:
        in_features: Size of the last dimension of the 2-D input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of grid intervals spanning ``grid_range``.
        spline_order: Order of the B-splines; the grid is extended by this
            many knots on each side.
        scale_noise: Amplitude of the random curve used to initialise the
            spline coefficients.
        scale_base: Multiplies the ``a`` argument of the Kaiming initialiser
            used for ``base_weight``.
        scale_spline: Multiplies the initial spline coefficients when the
            standalone scaler is disabled; otherwise multiplies the ``a``
            argument of the Kaiming initialiser used for ``spline_scaler``.
        enable_standalone_scale_spline: When true, a learnable
            ``spline_scaler`` of shape (out_features, in_features) multiplies
            the spline coefficients.
        base_activation: Module class instantiated without arguments for the
            base branch.
        grid_eps: Weight of the uniform grid when ``update_grid`` blends the
            uniform and the data-driven grid.
        grid_range: ``(low, high)`` interval covered by the initial grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),  # immutable default: a list here would be shared by every call
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector, padded with `spline_order` knots on both sides
        # and replicated for every input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # Buffer, not Parameter: it follows .to()/state_dict but is not trained.
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        # The tensors above are uninitialised memory until this call.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and, if present, ``spline_scaler``."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small random curve sampled at the grid_size + 1 interior knots.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Fit spline coefficients that reproduce the random curve.
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis function.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are the least-squares solution of ``bases(x) @ coeff = y``,
        solved independently for every input feature.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients, multiplied by ``spline_scaler`` when that scaler is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Return base branch plus spline branch for ``x`` of shape (batch_size, in_features).

        The result has shape (batch_size, out_features).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Flatten (in_features, coeff) so the spline branch is one matmul.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the distribution of ``x`` and refit the spline weights.

        The spline outputs on ``x`` are recorded with the old grid, the grid is
        replaced by a blend of a uniform grid and the per-feature sample
        quantiles (weighted by ``grid_eps``), and ``spline_weight`` is then
        re-solved so the layer reproduces the recorded outputs on the new grid.

        Args:
            x (torch.Tensor): Samples of shape (batch_size, in_features).
            margin: Padding added on both ends of the observed range for the
                uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Pad `spline_order` extra knots on both ends, spaced by the uniform step.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want a memory-efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.

        Entries whose mean absolute weight is exactly zero contribute zero
        entropy (the limit of ``p * log(p)``) instead of turning the loss into NaN.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # Smallest positive normal number of the dtype: guards both 0/0 and log(0)
        # without changing the value for any ordinary (non-zero) weight.
        tiny = torch.finfo(l1_fake.dtype).tiny
        p = l1_fake / regularization_loss_activation.clamp_min(tiny)
        regularization_loss_entropy = -torch.sum(p * p.clamp_min(tiny).log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

# Quantum layer using PennyLane
class QuantumLayer(nn.Module):
    """Angle-encoded variational circuit exposed as a torch layer via PennyLane.

    The first ``n_qubits`` input features are used as RX and RY rotation
    angles (one feature per wire), followed by ``BasicEntanglerLayers`` with
    trainable weights; the output is the Pauli-Z expectation of every wire.

    Args:
        n_qubits: Number of wires, and therefore of consumed input features
            and produced output features.
        n_features: Number of features the caller provides. Stored but not
            otherwise used by this class.
        n_layers: Number of entangler layers (first dimension of the trainable
            weights). Defaults to 3.
    """

    def __init__(self, n_qubits, n_features, n_layers=3):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_features = n_features
        self.n_layers = n_layers

        # Define a Pennylane device
        self.device = qml.device("default.qubit", wires=n_qubits)

        # Define a simple quantum circuit
        def circuit(inputs, weights):
            # Index the LAST axis and loop over wires, not over len(inputs):
            # `inputs` is (n_qubits,) for one sample but may arrive as
            # (batch, n_qubits) when TorchLayer forwards the whole batch, in
            # which case len(inputs) is the batch size and inputs[i] is a row.
            # NOTE(review): batched hand-off depends on the PennyLane version - confirm.
            for wire in range(n_qubits):
                angle = inputs[..., wire]
                qml.RX(angle, wires=wire)
                qml.RY(angle, wires=wire)

            qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
            return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

        # Create a quantum node
        weight_shapes = {"weights": (n_layers, n_qubits)}  # one row of angles per entangler layer
        self.qnode = qml.QNode(circuit, self.device, interface="torch")

        # Convert the quantum node to a torch layer
        self.qlayer = qml.qnn.TorchLayer(self.qnode, weight_shapes)

    def forward(self, x):
        """Run the circuit on the first ``n_qubits`` features of ``x``.

        Args:
            x: 2-D tensor whose second dimension has at least ``n_qubits`` entries.

        Returns:
            The output of the wrapped ``TorchLayer`` for the truncated input.

        Raises:
            ValueError: If ``x`` has fewer than ``n_qubits`` features.
        """
        # Ensure input size matches the number of qubits
        if x.size(1) < self.n_qubits:
            raise ValueError(
                f"Input features ({x.size(1)}) must be at least the number of qubits ({self.n_qubits})."
            )
        x = x[:, :self.n_qubits]  # Use only the required features
        return self.qlayer(x)


# CNN model with KAN and Quantum Layer
class CNNKANQuantum(nn.Module):
    """Two conv/pool stages, a KAN layer, an 8-qubit quantum layer and a KAN classifier head.

    The flattened size ``64 * 8 * 8`` assumes 32x32 inputs (each pooling stage
    halves the spatial size) - TODO confirm against the data pipeline.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 3 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Dense part: KAN -> quantum circuit -> KAN producing 10 logits.
        self.kan1 = KANLinear(in_features=64 * 8 * 8, out_features=256)
        self.quantum_layer = QuantumLayer(n_qubits=8, n_features=256)
        # The quantum layer emits one value per qubit, hence 8 inputs here.
        self.kan2 = KANLinear(in_features=8, out_features=10)

    def forward(self, x):
        """Map a batch of images to 10 class scores."""
        features = self.pool1(F.selu(self.conv1(x)))
        features = self.pool2(F.selu(self.conv2(features)))
        flat = features.view(features.size(0), -1)
        hidden = self.kan1(flat)
        measured = self.quantum_layer(hidden)
        return self.kan2(measured)

# Model instantiation
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = CNNKANQuantum()
model = model.to(device)
print(model)

# Print parameter details
def print_parameter_details(model):
    """Print the element count of every trainable parameter, then their total.

    Parameters with ``requires_grad`` set to False are neither listed nor counted.
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for name, count in trainable:
        print(f"{name}: {count}")
    grand_total = sum(count for _, count in trainable)
    print(f"Total trainable parameters: {grand_total}")

# Per-parameter element counts followed by the trainable total.
print_parameter_details(model)
# Layer-by-layer summary for a 3x32x32 input.
# NOTE(review): torchsummary's `summary` takes its own device argument - confirm it matches `device`.
summary(model, input_size=(3, 32, 32))

# Data preparation
# Convert images to tensors, then normalise each channel with fixed mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.247, 0.243, 0.261),
        ),
    ]
)

# Only the training split requests a download; the test split reads from the same root.
train_dataset = datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True
)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform)

# Shuffle the training batches; keep evaluation order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=500, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

# Training and testing functions
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, the mean cross-entropy loss is
    back-propagated and ``optimizer`` takes one step. Progress is printed on
    every tenth batch (starting with the first); ``epoch`` is used for that
    log line only.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    batches_per_epoch = len(train_loader)
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % 10 != 0:
            continue
        seen = step * len(inputs)
        percent = 100. * step / batches_per_epoch
        print(f'Train Epoch: {epoch} [{seen}/{dataset_size} ({percent:.0f}%)]\tLoss: {loss.item():.6f}')

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Runs in eval mode without gradient tracking. The loss is summed per sample
    and divided by the dataset size, so unequal batch sizes do not bias it.
    """
    model.eval()
    dataset_size = len(test_loader.dataset)
    summed_loss = 0
    hits = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            summed_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            predicted = logits.argmax(dim=1)
            hits += (predicted == labels.view_as(predicted)).sum().item()
    test_loss = summed_loss / dataset_size
    accuracy = 100. * hits / dataset_size
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {hits}/{dataset_size} ({accuracy:.0f}%)\n')

# Optimizer
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-3)

# Ten epochs (numbered 1..10), each followed by an evaluation on the test split.
for epoch in range(1, 11):
    train(
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
    )
    test(model=model, device=device, test_loader=test_loader)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import math
import matplotlib.pyplot as plt
from torchsummary import summary
import pennylane as qml
from pennylane import numpy as np
from torch.nn import Module

# KANLinear definition. Source: https://github.com/Blealtan/efficient-kan/blob/f39e5146af34299ad3a581d2106eb667ba0fa6fa/src/efficient_kan/kan.py#L6
class KANLinear(torch.nn.Module):
    """Linear-style layer made of a base branch plus a learnable B-spline branch.

    ``forward`` adds ``F.linear(base_activation(x), base_weight)`` to a linear
    map of the B-spline bases of ``x``, so every (input, output) pair owns its
    own set of spline coefficients next to a conventional weight.

    Args:
        in_features: Size of the last dimension of the 2-D input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of grid intervals spanning ``grid_range``.
        spline_order: Order of the B-splines; the grid is extended by this
            many knots on each side.
        scale_noise: Amplitude of the random curve used to initialise the
            spline coefficients.
        scale_base: Multiplies the ``a`` argument of the Kaiming initialiser
            used for ``base_weight``.
        scale_spline: Multiplies the initial spline coefficients when the
            standalone scaler is disabled; otherwise multiplies the ``a``
            argument of the Kaiming initialiser used for ``spline_scaler``.
        enable_standalone_scale_spline: When true, a learnable
            ``spline_scaler`` of shape (out_features, in_features) multiplies
            the spline coefficients.
        base_activation: Module class instantiated without arguments for the
            base branch.
        grid_eps: Weight of the uniform grid when ``update_grid`` blends the
            uniform and the data-driven grid.
        grid_range: ``(low, high)`` interval covered by the initial grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),  # immutable default: a list here would be shared by every call
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector, padded with `spline_order` knots on both sides
        # and replicated for every input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # Buffer, not Parameter: it follows .to()/state_dict but is not trained.
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        # The tensors above are uninitialised memory until this call.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and, if present, ``spline_scaler``."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small random curve sampled at the grid_size + 1 interior knots.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Fit spline coefficients that reproduce the random curve.
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis function.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are the least-squares solution of ``bases(x) @ coeff = y``,
        solved independently for every input feature.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients, multiplied by ``spline_scaler`` when that scaler is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Return base branch plus spline branch for ``x`` of shape (batch_size, in_features).

        The result has shape (batch_size, out_features).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Flatten (in_features, coeff) so the spline branch is one matmul.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the distribution of ``x`` and refit the spline weights.

        The spline outputs on ``x`` are recorded with the old grid, the grid is
        replaced by a blend of a uniform grid and the per-feature sample
        quantiles (weighted by ``grid_eps``), and ``spline_weight`` is then
        re-solved so the layer reproduces the recorded outputs on the new grid.

        Args:
            x (torch.Tensor): Samples of shape (batch_size, in_features).
            margin: Padding added on both ends of the observed range for the
                uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Pad `spline_order` extra knots on both ends, spaced by the uniform step.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want a memory-efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.

        Entries whose mean absolute weight is exactly zero contribute zero
        entropy (the limit of ``p * log(p)``) instead of turning the loss into NaN.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # Smallest positive normal number of the dtype: guards both 0/0 and log(0)
        # without changing the value for any ordinary (non-zero) weight.
        tiny = torch.finfo(l1_fake.dtype).tiny
        p = l1_fake / regularization_loss_activation.clamp_min(tiny)
        regularization_loss_entropy = -torch.sum(p * p.clamp_min(tiny).log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )



import torch
from torch import nn
import pennylane as qml
import numpy as np
from torch.nn import functional as F
from typing import List, Optional, Tuple

# class QuantumLayer(nn.Module):
#     def __init__(self, n_qubits, n_layers, seed=None):
#         super(QuantumLayer, self).__init__()
        
#         # Random seed for reproducibility if provided
#         if seed is not None:
#             torch.manual_seed(seed)
#             np.random.seed(seed)

#         self.n_qubits = n_qubits
#         self.n_layers = n_layers
        
#         # Quantum device initialization
#         self.dev = qml.device("default.qubit", wires=self.n_qubits)

#         # Define the quantum circuit
#         @qml.qnode(self.dev, interface="torch")
#         def circuit(inputs, weights):
#             # Encode classical data
#             for i in range(self.n_qubits):
#                 qml.RX(inputs[i], wires=i)

#             # Variational layer with entangling gates
#             qml.StronglyEntanglingLayers(weights, wires=range(self.n_qubits))

#             # Measurement of all qubits using PauliZ
#             return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

#         self.circuit = circuit

#         # Initialize quantum weights with He initialization for stability
#         self.weights = nn.Parameter(torch.randn(self.n_layers, self.n_qubits, 3, dtype=torch.float32) * np.sqrt(2. / self.n_qubits))

#     def forward(self, x):
#         # Flatten the input for quantum encoding
#         batch_size = x.size(0)
#         x = x.view(batch_size, -1)
        
#         results = []
    
#         for i in range(batch_size):
#             # Normalize inputs to [0, π], ensuring that the data fits the quantum circuit expectations
#             inputs = x[i, :self.n_qubits]  
#             inputs = torch.tanh(inputs) * np.pi  # Normalizing inputs within [0, π]
            
#             # Convert quantum circuit output to torch.Tensor and match device
#             circuit_output = torch.tensor(
#                 self.circuit(inputs, self.weights), dtype=torch.float32
#             ).to(x.device)
            
#             # Post-processing: apply activation function (like tanh) to quantum outputs
#             results.append(torch.tanh(circuit_output))  # Activation to regularize output
    
#         return torch.stack(results)

class QuantumLayer(nn.Module):
    """Variational quantum circuit wrapped as a differentiable torch module.

    Each sample's first ``n_qubits`` features are squashed with ``2 * arctan``,
    angle-encoded with RY/RZ rotations, passed through ``n_layers`` blocks of
    ``StronglyEntanglingLayers`` plus a CNOT ladder, and measured as Pauli-Z
    expectation values. The measurements are passed through ``activation``.

    Args:
        n_qubits: Number of wires; also the number of input features consumed
            and output features produced per sample.
        n_layers: Number of variational blocks.
        activation: One of ``'tanh'``, ``'relu'``, ``'sigmoid'``. Any other
            value silently falls back to tanh.
        device_type: PennyLane device name.
        seed: When given, seeds the global torch and numpy RNGs before the
            weights are drawn.
        noise_strength: When positive, a ``DepolarizingChannel`` with this
            strength is applied to every wire before measurement.
            NOTE(review): channels generally need a mixed-state device (for
            example ``default.mixed``) - confirm before enabling.

    Raises:
        ValueError: If ``n_qubits`` or ``n_layers`` is not positive.
        RuntimeError: If the PennyLane device cannot be created.
    """

    def __init__(
        self,
        n_qubits: int,
        n_layers: int,
        activation: str = 'tanh',
        device_type: str = 'default.qubit',
        seed: Optional[int] = None,
        noise_strength: float = 0.0
    ):
        super().__init__()

        if n_qubits <= 0 or n_layers <= 0:
            raise ValueError("n_qubits and n_layers must be positive integers")

        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.noise_strength = noise_strength

        # Set random seeds if provided (note: this reseeds the GLOBAL generators)
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Activation function selection; unknown names fall back to tanh
        self.activation = {
            'tanh': torch.tanh,
            'relu': F.relu,
            'sigmoid': torch.sigmoid
        }.get(activation, torch.tanh)

        # Initialize quantum device, re-raising failures with the cause attached
        try:
            self.dev = qml.device(device_type, wires=self.n_qubits)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize quantum device: {str(e)}") from e

        # Define the quantum circuit; the helper methods queue the gates
        @qml.qnode(self.dev, interface="torch", diff_method="parameter-shift")
        def circuit(inputs: torch.Tensor, weights: torch.Tensor) -> List[float]:
            # Input encoding layer
            self._encode_inputs(inputs)

            # Variational layers
            self._apply_variational_layers(weights)

            # Add noise if specified (for robustness training)
            if self.noise_strength > 0:
                self._apply_noise()

            # Measurement layer
            return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

        self.circuit = circuit

        # Trainable rotation angles: three per qubit per variational block
        weight_shape = (self.n_layers, self.n_qubits, 3)
        self.weights = nn.Parameter(
            self._initialize_weights(weight_shape)
        )

    def _initialize_weights(self, shape: Tuple) -> torch.Tensor:
        """Draw normal weights scaled by ``sqrt(2 / (n_qubits * n_layers))``."""
        scale = np.sqrt(2.0 / (self.n_qubits * self.n_layers))
        return torch.randn(shape, dtype=torch.float32) * scale

    def _encode_inputs(self, inputs: torch.Tensor) -> None:
        """Angle-encode the inputs: feature ``i`` drives an RY then an RZ rotation on wire ``i``."""
        for i in range(self.n_qubits):
            qml.RY(inputs[i], wires=i)
            qml.RZ(inputs[i], wires=i)

    def _apply_variational_layers(self, weights: torch.Tensor) -> None:
        """Apply, per block, one ``StronglyEntanglingLayers`` layer followed by a CNOT ladder."""
        for layer in range(self.n_layers):
            # StronglyEntanglingLayers expects a leading layer axis, hence the reshape
            qml.StronglyEntanglingLayers(
                weights[layer].reshape(1, self.n_qubits, 3),
                wires=range(self.n_qubits)
            )
            # Extra nearest-neighbour CNOTs on top of the template's own entanglers
            for i in range(self.n_qubits - 1):
                qml.CNOT(wires=[i, i + 1])

    def _apply_noise(self) -> None:
        """Apply a depolarizing channel of strength ``noise_strength`` to every wire."""
        for i in range(self.n_qubits):
            qml.DepolarizingChannel(self.noise_strength, wires=i)

    @staticmethod
    def _stack_measurements(measurements) -> torch.Tensor:
        """Turn the circuit's return value into one tensor without leaving the autograd graph."""
        if torch.is_tensor(measurements):
            return measurements
        # as_tensor returns existing tensors unchanged, so gradients are preserved.
        return torch.stack([torch.as_tensor(m) for m in measurements])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the circuit sample by sample.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and the first ``n_qubits`` values of
                each sample are used.

        Returns:
            float32 tensor of shape (batch_size, n_qubits), differentiable
            with respect to both ``x`` and the circuit weights.

        Raises:
            RuntimeError: If preprocessing or executing the circuit fails for
                any sample; the original exception is attached as the cause.
        """
        batch_size = x.size(0)
        if batch_size == 0:
            # Nothing to simulate; torch.stack on an empty list would fail.
            return x.new_zeros((0, self.n_qubits), dtype=torch.float32)
        x = x.view(batch_size, -1)

        results = []
        for i in range(batch_size):
            try:
                inputs = self._preprocess_inputs(x[i, :self.n_qubits])

                # Do NOT rebuild the result with torch.tensor(...): that copies
                # the values and detaches them, so no gradient would reach the
                # circuit weights or any layer feeding this one.
                measurements = self._stack_measurements(
                    self.circuit(inputs, self.weights)
                )
                circuit_output = measurements.to(device=x.device, dtype=torch.float32)

                results.append(self._postprocess_output(circuit_output))

            except Exception as e:
                raise RuntimeError(f"Error in quantum circuit execution: {str(e)}") from e

        return torch.stack(results)

    def _preprocess_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
        """Squash inputs into the open interval (-pi, pi) via ``2 * arctan``."""
        return torch.arctan(inputs) * 2

    def _postprocess_output(self, output: torch.Tensor) -> torch.Tensor:
        """Apply the configured activation to the measured expectation values."""
        return self.activation(output)

# class CNNKANWithQuantum(nn.Module):
#     def __init__(self, n_qubits=8, n_layers=1):
#         super(CNNKANWithQuantum, self).__init__()
#         self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
#         self.pool1 = nn.MaxPool2d(2)
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
#         self.pool2 = nn.MaxPool2d(2)
#         self.kan1 = KANLinear(64 * 8 * 8, 256)
#         self.quantum = QuantumLayer(n_qubits=n_qubits, n_layers=n_layers)
#         self.fc = nn.Linear(n_qubits, 10)  # Map quantum outputs to class predictions
        
#     def forward(self, x):
#         x = F.selu(self.conv1(x))
#         x = self.pool1(x)
#         x = F.selu(self.conv2(x))
#         x = self.pool2(x)
#         x = x.view(x.size(0), -1)
#         x = self.kan1(x)
#         x = self.quantum(x)
#         x = self.fc(x)
#         return x

class CNNKANWithQuantum(nn.Module):
    """CNN feature extractor, KAN layer, linear bottleneck, stacked quantum layers, linear head.

    Args:
        n_qubits: Width of the bottleneck and of every quantum layer.
        n_layers: Used twice: as the number of stacked ``QuantumLayer``
            modules and as the ``n_layers`` argument of each of them.
    """

    def __init__(self, n_qubits=4, n_layers=2):
        super().__init__()
        # Convolutional stages for 3-channel input.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Flattened conv features -> 256 -> n_qubits.
        self.kan1 = KANLinear(in_features=64 * 8 * 8, out_features=256)
        self.fc = nn.Linear(in_features=256, out_features=n_qubits)
        # Quantum layers applied one after another in forward().
        quantum_stack = []
        for _ in range(n_layers):
            quantum_stack.append(QuantumLayer(n_qubits=n_qubits, n_layers=n_layers))
        self.quantum_layers = nn.ModuleList(quantum_stack)
        # Quantum measurements -> 10 class scores.
        self.output = nn.Linear(in_features=n_qubits, out_features=10)

    def forward(self, x):
        """Map a batch of images to 10 class scores."""
        features = self.pool1(F.selu(self.conv1(x)))
        features = self.pool2(F.selu(self.conv2(features)))
        hidden = self.kan1(features.view(features.size(0), -1))
        hidden = F.selu(self.fc(hidden))
        for quantum_layer in self.quantum_layers:
            hidden = quantum_layer(hidden)
        return self.output(hidden)

import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F


class CNN(nn.Module):
    """Plain CNN baseline: two conv/pool stages followed by two linear layers.

    The first linear layer takes 64 * 8 * 8 features, so the network expects
    3-channel 32x32 inputs: the padded 3x3 convolutions keep the spatial size
    and each 2x2 pool halves it (32 -> 16 -> 8).

    Args:
        num_classes: Number of output logits. Defaults to 2, the previous
            hard-coded value. ``nn.CrossEntropyLoss`` needs integer targets in
            ``[0, num_classes)``, so when training on raw digit labels (e.g.
            3 and 6) either pass ``num_classes=10`` or remap the labels first.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Fully connected layers
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)  # Final output layer (logits)

    def forward(self, x):
        """Return ``num_classes`` logits per sample for a batch of 3x32x32 images."""
        # Convolutional layers
        x = self.pool1(F.selu(self.conv1(x)))
        x = self.pool2(F.selu(self.conv2(x)))

        # Flatten to (batch, 64 * 8 * 8) for the fully connected layers
        x = x.view(x.size(0), -1)

        # Fully connected layers; no activation on the final logits
        x = F.selu(self.fc1(x))
        return self.fc2(x)

def print_parameter_details(model):
    """Print the element count of every trainable parameter, then the total.

    Parameters with ``requires_grad == False`` are skipped and do not count
    towards the total.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
    """
    running_total = 0
    trainable = (
        (label, tensor)
        for label, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for label, tensor in trainable:
        count = tensor.numel()  # Number of elements in the tensor
        print(f"{label}: {count}")
        running_total += count
    print(f"Total trainable parameters: {running_total}")


# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = CNN().to(device)

# Active model: CNN + KAN + quantum layers. To use the plain CNN baseline
# instead, uncomment the line above and comment out the line below.
model = CNNKANWithQuantum().to(device)
print(model)
print_parameter_details(model)
# Per-layer summary for a single 3-channel 32x32 input.
summary(model,  input_size=(3, 32, 32))

# Note: this is just a rough demo for visualization; it needs modification.
def visualize_kan_parameters(kan_layer, layer_name):
    """Show histograms of a KAN layer's weight values.

    Always plots ``kan_layer.base_weight``; additionally plots
    ``kan_layer.spline_weight`` when the layer has that attribute. Each
    histogram uses 50 bins and is displayed with ``plt.show()``.

    Args:
        kan_layer: Layer exposing ``base_weight`` (and optionally
            ``spline_weight``) as tensors.
        layer_name: Label appended to each plot title.
    """

    def _show_histogram(weight, title):
        # Move to host memory and flatten so all entries share one histogram.
        values = weight.data.cpu().numpy()
        plt.hist(values.ravel(), bins=50)
        plt.title(title)
        plt.xlabel("Weight Value")
        plt.ylabel("Frequency")
        plt.show()

    _show_histogram(
        kan_layer.base_weight,
        f"Distribution of Base Weights - {layer_name}",
    )
    if not hasattr(kan_layer, 'spline_weight'):
        return
    _show_histogram(
        kan_layer.spline_weight,
        f"Distribution of Spline Weights - {layer_name}",
    )

# List every parameter tensor with its shape and whether it is trainable.
for name, param in model.named_parameters():
    print(f"{name}: {param.size()} {'requires_grad' if param.requires_grad else 'frozen'}")

# TODO: Explore other optimizers and tune the learning rate.
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
# ])
# train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=500, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# # Transformations to reshape MNIST to (3, 32, 32)
# transform = transforms.Compose([
#     transforms.Resize((32, 32)),  # Resize to 32x32
#     transforms.Grayscale(num_output_channels=3),  # Convert to 3 channels
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for 3 channels
# ])

# from torch.utils.data import DataLoader, Dataset
# from torchvision import datasets, transforms

# # Custom Dataset Wrapper to Filter Specific Labels
# class FilteredMNIST(Dataset):
#     def __init__(self, dataset, labels_to_include):
#         """
#         Filters the given dataset to include only specified labels.
        
#         Args:
#             dataset (Dataset): The dataset to filter.
#             labels_to_include (list): List of labels to include in the filtered dataset.
#         """
#         self.dataset = dataset
#         self.labels_to_include = labels_to_include
#         self.filtered_indices = [
#             idx for idx, (_, label) in enumerate(dataset) if label in labels_to_include
#         ]

#     def __len__(self):
#         return len(self.filtered_indices)

#     def __getitem__(self, idx):
#         original_idx = self.filtered_indices[idx]
#         return self.dataset[original_idx]

# # # Updated Dataset and Dataloader
# # train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# # test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# # train_loader = DataLoader(train_dataset, batch_size=500, shuffle=True)
# # test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# # Load the MNIST dataset
# original_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# original_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# # Filter the datasets for labels 3 and 6
# filtered_train_dataset = FilteredMNIST(original_train_dataset, labels_to_include=[3, 6])
# filtered_test_dataset = FilteredMNIST(original_test_dataset, labels_to_include=[3, 6])

# # Create DataLoaders for the filtered datasets
# train_loader = DataLoader(filtered_train_dataset, batch_size=500, shuffle=True)
# test_loader = DataLoader(filtered_test_dataset, batch_size=256, shuffle=False)

###############
# Quantum Data Augmentation Class
class QuantumDataAugmentation(nn.Module):
    """Random intensity-scaling augmentation parameterised by a rotation angle.

    For each item, an angle ``theta`` is drawn uniformly from ``[0, 2*pi]``
    with Python's ``random`` module and the item is multiplied by
    ``cos(theta)`` (the top-left entry of the 2x2 rotation matrix for
    ``theta``). No quantum circuit is executed.

    Args:
        n_qubits: Stored on the instance; not used by any method of this class.
    """

    def __init__(self, n_qubits):
        super().__init__()
        self.n_qubits = n_qubits

    def apply_random_quantum_rotation(self, image):
        """Return ``image`` scaled by ``cos(theta)`` for a freshly drawn ``theta``."""
        # Random angle for the (simplified) single-qubit rotation
        angle = random.uniform(0, 2 * np.pi)

        # Only the [0, 0] entry of the rotation matrix, cos(angle), is applied,
        # so it is computed directly and cast to a float32 scalar tensor.
        scale = torch.tensor(np.cos(angle), dtype=torch.float32)
        return image * scale

    def forward(self, x):
        """Augment every item along dim 0 of ``x`` independently and re-stack them."""
        return torch.stack(
            [self.apply_random_quantum_rotation(item) for item in x]
        )

# Preprocessing pipeline: turns each dataset image into a normalized
# 3-channel 32x32 tensor, the input size the models above are built for.
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to 32x32
    transforms.Grayscale(num_output_channels=3),  # Convert to 3 channels
    transforms.ToTensor(),  # Convert the image to a tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # (x - 0.5) / 0.5 on each of the 3 channels
])

# Custom Dataset Wrapper to Filter Specific Labels
class FilteredMNIST(Dataset):
    """View of a dataset restricted to chosen labels, with optional augmentation.

    Labels are passed through unchanged (they are not remapped to 0..k-1).
    """

    def __init__(self, dataset, labels_to_include, quantum_aug=None):
        """
        Filters the given dataset to include only specified labels and applies quantum data augmentation.

        Args:
            dataset (Dataset): The dataset to filter; must yield ``(image, label)`` pairs.
            labels_to_include (list): List of labels to include in the filtered dataset.
            quantum_aug (callable, optional): Batch-wise augmentation (e.g. a
                QuantumDataAugmentation module) applied to each returned image.
                ``None`` (the default) disables augmentation.
        """
        self.dataset = dataset
        self.labels_to_include = labels_to_include
        self.quantum_aug = quantum_aug
        # NOTE: this walks the whole dataset once, so every sample is loaded
        # (and transformed, if the dataset has a transform) at construction time.
        self.filtered_indices = [
            idx for idx, (_, label) in enumerate(dataset) if label in labels_to_include
        ]

    def __len__(self):
        """Return the number of samples whose label passed the filter."""
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th kept ``(image, label)`` pair, augmented if configured."""
        image, label = self.dataset[self.filtered_indices[idx]]

        if self.quantum_aug is not None:
            # The augmenter iterates over dim 0 as a batch. Passing the bare
            # (C, H, W) image would make it treat every channel as a separate
            # sample, so wrap the image as a batch of one and unwrap afterwards.
            image = self.quantum_aug(image.unsqueeze(0)).squeeze(0)

        return image, label

# Initialize Quantum Data Augmentation (used for the training split only)
quantum_aug = QuantumDataAugmentation(n_qubits=8)

# Load the MNIST dataset (the train call downloads it into ./data if missing)
original_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
original_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Filter the datasets for labels 3 and 6.
# Random augmentation is applied to the training split only. The test split
# gets a pass-through module so evaluation is deterministic and measures
# performance on unmodified images.
filtered_train_dataset = FilteredMNIST(original_train_dataset, labels_to_include=[3, 6], quantum_aug=quantum_aug)
filtered_test_dataset = FilteredMNIST(original_test_dataset, labels_to_include=[3, 6], quantum_aug=nn.Identity())

# Create DataLoaders for the filtered datasets
train_loader = DataLoader(filtered_train_dataset, batch_size=500, shuffle=True)
test_loader = DataLoader(filtered_test_dataset, batch_size=256, shuffle=False)

###############

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch with cross-entropy loss.

    Puts ``model`` in training mode, performs one optimizer step per batch and
    prints a progress line (epoch, samples seen, percent done, batch loss) on
    every 10th batch, starting with batch 0.

    Args:
        model: Network producing class logits.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches; ``len()`` and a
            ``.dataset`` attribute are needed for the progress line.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number, used only in the progress line.
    """
    criterion = nn.CrossEntropyLoss()  # stateless, so one instance serves every batch
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(data), labels)
        loss.backward()
        optimizer.step()

        if batch_idx % 10 != 0:
            continue
        print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def evaluate(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Puts ``model`` in eval mode and runs without gradient tracking. The
    reported loss is the true per-sample average: batch losses are summed
    (``reduction='sum'``) and divided by the dataset size. Accumulating the
    default per-batch *mean* and then dividing by the number of samples would
    under-report the loss by roughly the batch size.

    Args:
        model: Network producing class logits.
        device: Device each batch is moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches with a ``.dataset``
            attribute whose length is the number of evaluated samples.
    """
    criterion = nn.CrossEntropyLoss(reduction='sum')
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) over the batch so the final division is exact.
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest logit
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Train for 6 epochs, evaluating on the test loader after every epoch.
for epoch in range(6):
    train(model, device, train_loader, optimizer, epoch)
    evaluate(model, device, test_loader)
# Save only the model's state_dict (its parameters and buffers) to disk.
torch.save(model.state_dict(), 'model_weights_KAN.pth')

CNNKANWithQuantum(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (kan1): KANLinear(
    (base_activation): SiLU()
  )
  (fc): Linear(in_features=256, out_features=4, bias=True)
  (quantum_layers): ModuleList(
    (0-1): 2 x QuantumLayer()
  )
  (output): Linear(in_features=4, out_features=10, bias=True)
)
conv1.weight: 864
conv1.bias: 32
conv2.weight: 18432
conv2.bias: 64
kan1.base_weight: 1048576
kan1.spline_weight: 8388608
kan1.spline_scaler: 1048576
fc.weight: 1024
fc.bias: 4
quantum_layers.0.weights: 24
quantum_layers.1.weights: 24
output.weight: 40
output.bias: 10
Total trainable parameters: 10506278
----------------------------------------------------------------
        Layer (type)               Output