In [1]:
# What we do ---quantum-classical Hybrid ResNet-18 + Q-CNN as for O-RAN xApp
""
# This notebook demonstrates a quantum-classical hybrid model for spectrogram IQ classification in intelligent O-RAN.

# Highlights:
# - Classical ResNet-18 as a feature extractor
# - Quantum Convolutional Layers with 6-qubit circuits (PennyLane)
# - Attention-based quantum-classical fusion
# - Custom loss & quantum-inspired data augmentation

# - Updated for PennyLane, PyTorch, and xApp-oriented data pipeline
# Author: Jared A.Ergu
# License: CIS Lab--R301B, CCU, Taiwan
""

''

 # Part 2: Environment Setup and Imports

In [2]:
# my imports & Configuration
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import pennylane as qml
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List
import logging

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Part 3: Model Configuration

In [3]:
# to configure model
@dataclass
class ModelConfig:
    resnet_pretrained: bool = True
    resnet_freeze_layers: int = 4
    feature_dim: int = 512
    n_qubits: int = 6
    n_layers: int = 4
    quantum_backend: str = "lightning.qubit"
    quantum_shots: int = 1024
    num_classes: int = 2
    learning_rate: float = 1e-4
    batch_size: int = 32
    dropout_rate: float = 0.3
    use_quantum_data_encoding: bool = True
    angle_embedding_type: str = "amplitude"
    input_channels: int = 3
    image_size: int = 224


# QuantumLayer

In [4]:
# QuantumLayer - parameterized 6-qubit circuit
class QuantumLayer(nn.Module):
    def __init__(self, n_qubits: int = 6, n_layers: int = 4, backend: str = "lightning.qubit"):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.dev = qml.device(backend, wires=n_qubits)

        self.params = nn.Parameter(torch.randn(n_layers, n_qubits, 3) * 0.1)
        self.phase_params = nn.Parameter(torch.randn(n_layers, n_qubits) * 0.05)

        self.qnode = qml.QNode(self._circuit, self.dev, interface="torch")

    def _circuit(self, inputs, params, phase_params):
        # Input encoding and initial phase injection
        for i in range(self.n_qubits):
            qml.RY(inputs[i % len(inputs)], wires=i)
            qml.RZ(phase_params[0, i], wires=i)

        # Variational circuit with entanglement and phase layers
        for layer in range(self.n_layers):
            for q in range(self.n_qubits):
                qml.RX(params[layer, q, 0], wires=q)
                qml.RY(params[layer, q, 1], wires=q)
                qml.RZ(params[layer, q, 2], wires=q)

            for q in range(self.n_qubits - 1):
                qml.CNOT(wires=[q, q + 1])
            qml.CNOT(wires=[self.n_qubits - 1, 0])

            if layer % 2 == 1:
                for q in range(1, self.n_qubits):
                    qml.CZ(wires=[0, q])
                if self.n_qubits >= 4:
                    qml.CZ(wires=[1, 3])
                    qml.CZ(wires=[2, 4])
                if self.n_qubits == 6:
                    qml.CZ(wires=[1, 5])

            for q in range(self.n_qubits):
                qml.PhaseShift(phase_params[layer, q], wires=q)

        # Measurement in Pauli basis
        return [
            qml.expval(qml.PauliZ(wires=0)),
            qml.expval(qml.PauliX(wires=1)),
            qml.expval(qml.PauliY(wires=2)),
            qml.expval(qml.PauliZ(wires=3)),
            qml.expval(qml.PauliX(wires=4)),
            qml.expval(qml.PauliY(wires=5)),
        ]
        return torch.stack(outputs)
    def forward(self, x):
        batch_size = x.shape[0]
        x_norm = torch.tanh(x) * np.pi
        outputs = []

        for i in range(batch_size):
            xi = x_norm[i]
            xi = xi[:self.n_qubits] if xi.shape[0] > self.n_qubits else F.pad(xi, (0, self.n_qubits - xi.shape[0]))

            out = self.qnode(xi, self.params, self.phase_params)
            if isinstance(out, list):
                out_tensor = torch.stack(out)
            else:
                out_tensor = out

            outputs.append(out_tensor)

        return torch.stack(outputs).float()  # âœ… FIX: ensure float32 for torch.Linear compatibility


# QuantumConvolutionalLayer

In [5]:
# QuantumConvolutionalLayer - Spectrogram Feature to Quantum Patch Encoder
class QuantumConvolutionalLayer(nn.Module):
    def __init__(self, n_qubits: int = 6, n_layers: int = 4, kernel_size: int = 3, stride: int = 1):
        super().__init__()
        self.n_qubits = n_qubits
        self.kernel_size = kernel_size
        self.stride = stride

        # Core quantum processor
        self.quantum_layer = QuantumLayer(n_qubits, n_layers)

        # Classical post-processing after quantum circuit
        self.post_process = nn.Sequential(
            nn.Linear(n_qubits, n_qubits * 2),
            nn.LayerNorm(n_qubits * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(n_qubits * 2, n_qubits),
            nn.Tanh()
        )

        # Project high-dim input to match quantum input size if needed
        self.adaptive_pool = nn.AdaptiveAvgPool1d(n_qubits)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Spectrogram tensor (B, C, H, W)
        Returns:
            Quantum-processed features (B, n_qubits)
        """
        B, C, H, W = x.shape

        # Extract temporal features: Mean over frequency (height)
        temporal = torch.mean(x, dim=2).view(B, -1)

        # Extract spectral features: Mean over time (width)
        spectral = torch.mean(x, dim=3).view(B, -1)

        # Combine frequency-time features
        combined = torch.cat([temporal, spectral], dim=1)

        # Pool or pad to match quantum input size
        if combined.shape[1] > self.n_qubits:
            combined = self.adaptive_pool(combined.unsqueeze(1)).squeeze(1)
        elif combined.shape[1] < self.n_qubits:
            combined = F.pad(combined, (0, self.n_qubits - combined.shape[1]))

        # Run through quantum circuit
        quantum_out = self.quantum_layer(combined)

        # Classical post-quantum processing
        return self.post_process(quantum_out)
        


# ResNetQuantumHybrid --> the full model pipeline that ties ResNet + Quantum layers + attention + classifier?

In [6]:
# ResNetQuantumHybrid - Full xApp-Oriented Model with Dual Quantum Paths
class ResNetQuantumHybrid(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Backbone Feature Extractor (ResNet-18)
        self.resnet_backbone = self._build_resnet_backbone()

        #  Dimensionality Reduction Head for ResNet Output
        self.feature_reducer = nn.Sequential(
            nn.Linear(512, config.feature_dim),
            nn.BatchNorm1d(config.feature_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.feature_dim, config.n_qubits * 4),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate * 0.5)
        )

        # Dual Quantum Convolutional Pipelines
        self.quantum_conv1 = QuantumConvolutionalLayer(config.n_qubits, config.n_layers)
        self.quantum_conv2 = QuantumConvolutionalLayer(config.n_qubits, max(2, config.n_layers // 2))

        # Final Quantum Fusion Layer
        self.quantum_fusion = QuantumLayer(config.n_qubits, config.n_layers)

        # Multihead Attention Between Classical and Quantum
        self.attention = nn.MultiheadAttention(
            embed_dim=config.n_qubits,
            num_heads=2,
            dropout=config.dropout_rate,
            batch_first=True
        )

        # Learnable Feature Fusion Weights
        self.fusion_weights = nn.Parameter(torch.tensor([0.7, 0.3]))

        # Final Classification Head
        self.classifier_head = nn.Sequential(
            nn.Linear(config.n_qubits, config.n_qubits * 3),
            nn.LayerNorm(config.n_qubits * 3),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.n_qubits * 3, config.n_qubits * 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate * 0.5),
            nn.Linear(config.n_qubits * 2, config.num_classes)
        )

    def _build_resnet_backbone(self) -> nn.Module:
        resnet = models.resnet18(pretrained=self.config.resnet_pretrained)

        # Modify input channel if spectrograms have different format
        if self.config.input_channels != 3:
            resnet.conv1 = nn.Conv2d(
                self.config.input_channels, 64, kernel_size=7, stride=2,
                padding=3, bias=False
            )

        # Remove the final FC classifier
        resnet = nn.Sequential(*list(resnet.children())[:-1])

        # Freeze early layers
        for i, child in enumerate(resnet.children()):
            if i < self.config.resnet_freeze_layers:
                for param in child.parameters():
                    param.requires_grad = False

        return resnet

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        B = x.shape[0]

        # Step 1: Classical feature extraction
        classical_feat = self.resnet_backbone(x).view(B, -1)
        classical_feat = self.feature_reducer(classical_feat)

        # Step 2: Quantum pathway 1 and 2
        qfeat1 = self.quantum_conv1(x)
        qfeat2 = self.quantum_conv2(x)

        # Step 3: Learnable quantum combination
        alpha = torch.sigmoid(self.fusion_weights[0])
        beta = torch.sigmoid(self.fusion_weights[1])
        combined_qfeat = (alpha * qfeat1 + beta * qfeat2) / (alpha + beta)

        # Step 4: Quantum fusion
        fused_qfeat = self.quantum_fusion(combined_qfeat)

        # Step 5: Attention between classical and quantum
        q_input = fused_qfeat.unsqueeze(1)
        c_input = classical_feat[:, :self.config.n_qubits].unsqueeze(1)
        attended_q, attn_weights = self.attention(q_input, q_input, q_input)
        attended_q = attended_q.squeeze(1)

        # Step 6: Final fusion and classification
        final_feat = attended_q + 0.1 * c_input.squeeze(1)
        logits = self.classifier_head(final_feat)

        return {
            "logits": logits,
            "classical_features": classical_feat,
            "quantum_features": fused_qfeat,
            "attention_weights": attn_weights.squeeze(1),
            "combined_features": final_feat,
            "quantum_pathway1": qfeat1,
            "quantum_pathway2": qfeat2
        }


# QuantumLoss â€” combining CE loss + quantum regularization + pathway alignment

In [7]:
# QuantumLoss - Total Loss for Quantum-Classical Hybrid Model
class QuantumLoss(nn.Module):
    def __init__(self, alpha: float = 0.1, beta: float = 0.05, gamma: float = 0.02):
        super().__init__()
        self.alpha = alpha  # Regularization
        self.beta = beta    # Coherence penalty
        self.gamma = gamma  # Pathway alignment
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor,
                model: ResNetQuantumHybrid) -> Dict[str, torch.Tensor]:
        """
        Compute loss components and aggregate
        """
        # ðŸ”¹ 1. Classical cross-entropy
        ce = self.ce_loss(outputs['logits'], targets)

        # ðŸ”¹ 2. Quantum regularization (L2 norm on parameters)
        reg = 0
        n = 0
        for module in model.modules():
            if isinstance(module, QuantumLayer):
                reg += torch.norm(module.params, p=2)
                if hasattr(module, 'phase_params'):
                    reg += torch.norm(module.phase_params, p=2)
                n += 1
        if n > 0:
            reg /= n

        #  3. Quantum coherence loss (maximize feature std dev)
        q_feat = outputs['quantum_features']
        coherence = -torch.mean(torch.std(q_feat, dim=1))

        #  4. Pathway consistency (MSE between two Q-features)
        if 'quantum_pathway1' in outputs and 'quantum_pathway2' in outputs:
            align = F.mse_loss(outputs['quantum_pathway1'], outputs['quantum_pathway2'])
        else:
            align = torch.tensor(0.0, device=q_feat.device)

        #  5. Combine all
        total = ce + self.alpha * reg + self.beta * coherence + self.gamma * align

        return {
            "total_loss": total,
            "ce_loss": ce,
            "quantum_reg": reg,
            "coherence_loss": coherence,
            "pathway_consistency": align
        }


# QuantumDataAugmentation --> quantum-inspired transformations for spectrograms

In [8]:
# QuantumDataAugmentation - Quantum-inspired spectrogram augmenting
class QuantumDataAugmentation:
    @staticmethod
    def quantum_noise_injection(x: torch.Tensor, noise_level: float = 0.1) -> torch.Tensor:
        """
        Simulating quantum measurement noise and bit-flip errors
        """
        noise = torch.randn_like(x) * noise_level
        pauli_noise = torch.rand_like(x)
        mask = pauli_noise < 0.06  # 6% bit-flip probability
        noisy_x = x + noise
        noisy_x[mask] *= -1  # simulate Pauli-X flip
        return torch.clamp(noisy_x, 0, 1)

    @staticmethod
    def quantum_rotation_augmentation(x: torch.Tensor, max_angle: float = 0.1) -> torch.Tensor:
        """
        Simulates phase shifts by applying sine-cosine mixing with random angles
        """
        angles = (torch.rand_like(x) * 2 - 1) * max_angle  # Uniform in [-max_angle, +max_angle]
        rotated_x = x * torch.cos(angles) + torch.randn_like(x) * torch.sin(angles) * 0.1
        return torch.clamp(rotated_x, 0, 1)

    @staticmethod
    def spectrogram_quantum_transform(x: torch.Tensor, n_qubits: int = 6) -> torch.Tensor:
        """
        Applies frequency-specific quantum-like transforms to simulate entanglement variance
        """
        B, C, H, W = x.shape
        freq_weights = torch.randn(n_qubits, device=x.device) * 0.1
        freq_indices = torch.linspace(0, H - 1, n_qubits).long()

        for i, freq_idx in enumerate(freq_indices):
            if freq_idx < H:
                x[:, :, freq_idx, :] *= (1 + freq_weights[i])

        return torch.clamp(x, 0, 1)


# Utility Functions (create_hybrid_model, count_parameters)

In [9]:
# Utility: Model Factory + Parameter Counter
def create_hybrid_model(config: ModelConfig) -> ResNetQuantumHybrid:
    """
    Factory function for initializing the hybrid model with proper quantum parameter init
    """
    logger.info(f" Initializing hybrid model with {config.n_qubits} qubits and {config.n_layers} quantum layers...")
    model = ResNetQuantumHybrid(config)

    # Xavier (Glorot) initialization for trainable quantum parameters
    for module in model.modules():
        if isinstance(module, QuantumLayer):
            if module.params.data.ndim >= 2:
                nn.init.xavier_uniform_(module.params.data)
            else:
                module.params.data.uniform_(-0.1, 0.1)

            if hasattr(module, 'phase_params'):
                module.phase_params.data.uniform_(-0.05, 0.05)

    return model


def count_parameters(model: nn.Module) -> Dict[str, int]:
    """
    Counts classical and quantum parameters for reporting
    """
    classical = 0
    quantum = 0

    for name, param in model.named_parameters():
        if 'quantum' in name.lower() or 'params' in name:
            quantum += param.numel()
        else:
            classical += param.numel()

    return {
        'classical_parameters': classical,
        'quantum_parameters': quantum,
        'total_parameters': classical + quantum
    }


# Forward Pass checkingg

In [10]:
# Quick Check - Forward Pass
if __name__ == "__main__":
    config = ModelConfig(
        n_qubits=6,
        n_layers=4,
        num_classes=2,
        feature_dim=256,
        input_channels=3  # RGB spectrograms
    )

    model = create_hybrid_model(config)
    param_count = count_parameters(model)

    print("Hybrid Model Created Successfully!")
    print(f" Classical Params: {param_count['classical_parameters']:,}")
    print(f"Quantum Params: {param_count['quantum_parameters']:,}")
    print(f"Total Params:    {param_count['total_parameters']:,}")

    # Dummy spectrogram input (batch size 2)
    test_input = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        outputs = model(test_input)
        print(" Logits shape:", outputs['logits'].shape)
        print(" Quantum feature shape:", outputs['quantum_features'].shape)




Hybrid Model Created Successfully!
 Classical Params: 11,315,106
Quantum Params: 612
Total Params:    11,315,718
 Logits shape: torch.Size([2, 2])
 Quantum feature shape: torch.Size([2, 6])
