# Phase 5: Gate-Enhanced Neural Layers

**Quantum-Enhanced Simulation Learning for Reinforcement Learning**

Author: Saurabh Jalendra  
Institution: BITS Pilani (WILP Division)  
Date: November 2025

---

## Overview

This notebook implements **quantum gate-inspired neural network layers** that transform
classical neural operations using principles from quantum computing gates:

### Quantum Gates Implemented

1. **Hadamard Gate (H)**: Creates superposition-like feature mixing
2. **Rotation Gates (Rx, Ry, Rz)**: Parameterized rotations in feature space
3. **CNOT Gate**: Controlled operations creating entanglement-like correlations
4. **Phase Gate (S, T)**: Phase shifts for feature modulation

### Key Concepts

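- **Unitary-inspired transformations**: Orthogonal (real-valued) mixing preserves vector norms exactly; the added scaling, bias, and normalization make the full layers only approximately information-preserving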
- **Parameterized rotations**: Learnable angles for flexible transformations
- **Entanglement-like correlations**: Feature dependencies through controlled ops

---

## 5.1 Setup and Imports

In [None]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import math
from typing import Dict, List, Tuple, Optional, Union, Callable
from dataclasses import dataclass, field
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import matplotlib.pyplot as plt

from src.utils import set_seed, get_device, MetricLogger, Timer, COLORS

# Set seed for reproducibility
set_seed(42)
device = get_device()
print(f"Using device: {device}")

## 5.2 Quantum Gate Mathematical Foundations

### Classical Quantum Gates

**Hadamard Gate:**
$$H = \frac{1}{\sqrt{2}} \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix}$$

**Rotation Gates:**
$$R_x(\theta) = \begin{pmatrix} \cos(\theta/2) & -i\sin(\theta/2) \\ -i\sin(\theta/2) & \cos(\theta/2) \end{pmatrix}$$

$$R_y(\theta) = \begin{pmatrix} \cos(\theta/2) & -\sin(\theta/2) \\ \sin(\theta/2) & \cos(\theta/2) \end{pmatrix}$$

$$R_z(\theta) = \begin{pmatrix} e^{-i\theta/2} & 0 \\ 0 & e^{i\theta/2} \end{pmatrix}$$
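
For reference, these are the phase and CNOT gates listed in the overview. Here $P(\phi) = \mathrm{diag}(1, e^{i\phi})$, and the two-qubit basis is ordered $|00\rangle, |01\rangle, |10\rangle, |11\rangle$ with the first qubit as control:

$$S = P(\pi/2) = \begin{pmatrix} 1 & 0 \\ 0 & i \end{pmatrix}, \qquad T = P(\pi/4) = \begin{pmatrix} 1 & 0 \\ 0 & e^{i\pi/4} \end{pmatrix}$$

$$\mathrm{CNOT} = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 1 & 0 \end{pmatrix}$$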

### Classical Adaptations

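We adapt these to real-valued neural network operations:
- Complex exponentials become sinusoidal transformations
- 2×2 gate matrices generalize to mixing or rotating features in arbitrary dimensions
- Learnable parameters replace fixed angles
- The half-angle convention is dropped: the rotation layer below rotates each feature pair by its full learned angle $\theta$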
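
The real-valued $R_y$ is already an ordinary 2D rotation, just by the half angle. The NumPy check below is a sketch; the helper names `ry` and `rot2d` are ours and do not come from the notebook. It confirms three facts the layers below rely on: $R_y(\theta)$ equals a rotation by $\theta/2$, rotations are orthogonal, and successive rotations in the same plane add their angles.

```python
import numpy as np

def ry(theta: float) -> np.ndarray:
    """Quantum R_y gate from Section 5.2 (real-valued, uses the half angle)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

def rot2d(theta: float) -> np.ndarray:
    """Plain 2D rotation by the full angle theta (the RotationLayer convention)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

theta = 0.7
# R_y(theta) is exactly a 2D rotation by theta / 2
assert np.allclose(ry(theta), rot2d(theta / 2))

# Rotations are orthogonal (R R^T = I), so they preserve vector norms
R = rot2d(theta)
assert np.allclose(R @ R.T, np.eye(2))

# Rotations in the same plane compose additively: three angles act like one
a, b, c = 0.1, -0.4, 0.25
assert np.allclose(rot2d(c) @ rot2d(b) @ rot2d(a), rot2d(a + b + c))
print("all rotation identities hold")
```

The last identity means that stacking several learnable angles on the same feature pair adds no expressiveness over a single angle. The `'all'` mode of `RotationLayer` does exactly this kind of stacking.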

## 5.3 Hadamard-Inspired Layer

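The Hadamard gate maps each basis state to an equal superposition. We adapt this by mixing features with a fixed orthogonal matrix. When `dim` is a power of 2, this is the normalized Sylvester Hadamard matrix. Otherwise, it is a random orthogonal matrix obtained from a QR decomposition.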
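
For power-of-2 dimensions, the dense product `x @ H.T` costs O(n²) per vector. The Sylvester structure, however, admits a fast Walsh-Hadamard transform in O(n log n). The NumPy sketch below is an optional alternative and is not what `HadamardLayer` uses; `sylvester_hadamard` and `fwht` are illustrative names. It checks that the fast transform matches the dense normalized matrix, which is built the same way as in `_create_hadamard_matrix`.

```python
import math
import numpy as np

def sylvester_hadamard(n: int) -> np.ndarray:
    """Normalized Sylvester Hadamard matrix (n must be a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / math.sqrt(n)

def fwht(x: np.ndarray) -> np.ndarray:
    """Normalized fast Walsh-Hadamard transform along the last axis, O(n log n)."""
    y = np.asarray(x, dtype=float)
    lead, n = y.shape[:-1], y.shape[-1]
    if n == 0 or n & (n - 1):
        raise ValueError("last dimension must be a power of 2")
    h = 1
    while h < n:
        # Butterfly step: within each block of width 2h, replace halves (a, b) by (a + b, a - b)
        blocks = y.reshape(*lead, n // (2 * h), 2, h)
        a, b = blocks[..., 0, :], blocks[..., 1, :]
        y = np.stack([a + b, a - b], axis=-2).reshape(*lead, n)
        h *= 2
    return y / math.sqrt(n)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64))  # a batch of 8 feature vectors
H = sylvester_hadamard(64)
# HadamardLayer computes F.linear(x, H), i.e. x @ H.T
assert np.allclose(fwht(x), x @ H.T)
```

The normalized matrix is symmetric and orthogonal, so it is its own inverse: applying `fwht` twice returns the input.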

In [None]:
class HadamardLayer(nn.Module):
    """
    Hadamard-inspired neural network layer.
    
    Creates superposition-like mixing of features using Hadamard-like
    transformations extended to arbitrary dimensions.
    
    Parameters
    ----------
    dim : int
        Input/output dimension. A true (Sylvester) Hadamard matrix is used only
        when dim is a power of 2; otherwise a random orthogonal matrix is used.
    learnable_scale : bool
        Whether to learn scaling factors
    normalize : bool
        Whether to normalize output
    """
    
    def __init__(
        self,
        dim: int,
        learnable_scale: bool = True,
        normalize: bool = True
    ):
        super().__init__()
        self.dim = dim
        self.normalize = normalize
        
        # Create Hadamard-like matrix
        H = self._create_hadamard_matrix(dim)
        self.register_buffer('hadamard', H)
        
        # Learnable scaling
        if learnable_scale:
            self.scale = nn.Parameter(torch.ones(dim))
        else:
            self.register_buffer('scale', torch.ones(dim))
        
        # Learnable bias
        self.bias = nn.Parameter(torch.zeros(dim))
    
    def _create_hadamard_matrix(self, n: int) -> Tensor:
        """
        Create a Hadamard-like orthogonal matrix.
        
        For dimensions that are powers of 2, uses the Sylvester construction.
        For other dimensions, falls back to a random orthogonal matrix from the
        QR decomposition of a Gaussian matrix (orthogonal, but not a Hadamard matrix).
        """
        # Check if n is power of 2
        if n > 0 and (n & (n - 1)) == 0:
            # True Hadamard construction via Sylvester's method
            H = torch.tensor([[1.0]])
            while H.shape[0] < n:
                H = torch.cat([
                    torch.cat([H, H], dim=1),
                    torch.cat([H, -H], dim=1)
                ], dim=0)
            H = H / math.sqrt(n)
        else:
            # Approximate with random orthogonal matrix
            random_matrix = torch.randn(n, n)
            Q, _ = torch.linalg.qr(random_matrix)
            H = Q
        
        return H
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Apply Hadamard-like transformation.
        
        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, ..., dim)
        
        Returns
        -------
        Tensor
            Transformed tensor of same shape
        """
        # Apply the orthogonal mixing; F.linear computes x @ H.T
        y = F.linear(x, self.hadamard)
        
        # Scale and bias
        y = y * self.scale + self.bias
        
        # Optional normalization
        if self.normalize:
            y = F.layer_norm(y, (self.dim,))
        
        return y
    
    def extra_repr(self) -> str:
        return f"dim={self.dim}, normalize={self.normalize}"

In [None]:
# Test Hadamard layer
print("Testing HadamardLayer...")

# Power of 2 dimension
hadamard_64 = HadamardLayer(64).to(device)
x = torch.randn(32, 64, device=device)
y = hadamard_64(x)
print(f"Input shape: {x.shape}, Output shape: {y.shape}")

# Verify Hadamard matrix properties
H = hadamard_64.hadamard
HHT = H @ H.T
identity_error = torch.norm(HHT - torch.eye(64, device=device)).item()
print(f"Orthogonality error (should be ~0): {identity_error:.6f}")

# Non-power of 2 dimension
hadamard_100 = HadamardLayer(100).to(device)
x2 = torch.randn(32, 100, device=device)
y2 = hadamard_100(x2)
print(f"Non-power-of-2: Input shape: {x2.shape}, Output shape: {y2.shape}")

## 5.4 Rotation Gate Layers

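Rotation gates perform parameterized rotations. The real-valued analogue used here rotates disjoint, randomly chosen pairs of features by learnable angles (Givens rotations). Because the pairs are disjoint, the layer is an orthogonal map and preserves the norm of every input vector.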
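
Because the feature pairs are disjoint, the whole layer is equivalent to multiplying by one sparse orthogonal matrix made of 2×2 Givens blocks. The NumPy sketch below builds that matrix explicitly and checks orthogonality and norm preservation. Its helper names are illustrative, and it mirrors a single-angle `RotationLayer` rather than importing it.

```python
import numpy as np

def givens_matrix(dim, idx1, idx2, theta):
    """Dense matrix of simultaneous rotations of the disjoint pairs (idx1[k], idx2[k])."""
    G = np.eye(dim)
    c, s = np.cos(theta), np.sin(theta)
    G[idx1, idx1] = c
    G[idx1, idx2] = -s
    G[idx2, idx1] = s
    G[idx2, idx2] = c
    return G

def rotate_pairs(x, idx1, idx2, theta):
    """Pairwise rotation as in RotationLayer.forward (single-angle mode)."""
    y = x.copy()
    x1, x2 = x[:, idx1], x[:, idx2]
    y[:, idx1] = np.cos(theta) * x1 - np.sin(theta) * x2
    y[:, idx2] = np.sin(theta) * x1 + np.cos(theta) * x2
    return y

rng = np.random.default_rng(0)
dim, k = 10, 4
perm = rng.permutation(dim)[: 2 * k]  # disjoint pair indices, as with torch.randperm
idx1, idx2 = perm[:k], perm[k:]
theta = rng.normal(scale=0.1, size=k)
x = rng.standard_normal((5, dim))

G = givens_matrix(dim, idx1, idx2, theta)
# Row-vector convention: y = x @ G.T
assert np.allclose(rotate_pairs(x, idx1, idx2, theta), x @ G.T)
assert np.allclose(G @ G.T, np.eye(dim))
assert np.allclose(np.linalg.norm(x @ G.T, axis=1), np.linalg.norm(x, axis=1))
```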

In [None]:
class RotationLayer(nn.Module):
    """
    Rotation gate-inspired neural network layer.
    
    Implements learnable rotations in feature space inspired by
    quantum rotation gates (Rx, Ry, Rz).
    
    Parameters
    ----------
    dim : int
        Feature dimension
    num_rotations : int
        Number of rotation pairs (rotations applied to pairs of features)
    rotation_type : str
        'all' applies three learnable angles per pair in sequence; any other
        value (e.g. 'xy', 'xz', 'yz') uses a single learnable angle per pair
    """
    
    def __init__(
        self,
        dim: int,
        num_rotations: Optional[int] = None,
        rotation_type: str = 'all'
    ):
        super().__init__()
        self.dim = dim
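        self.num_rotations = num_rotations or (dim // 2)
        if 2 * self.num_rotations > dim:
            raise ValueError("num_rotations must be at most dim // 2 (pairs must be disjoint)")
        self.rotation_type = rotation_type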
        
        # Learnable rotation angles
        if rotation_type == 'all':
            # Three angles per rotation (Rx, Ry, Rz)
            self.angles = nn.Parameter(
                torch.randn(self.num_rotations, 3) * 0.1
            )
        else:
            # Single angle per rotation
            self.angles = nn.Parameter(
                torch.randn(self.num_rotations) * 0.1
            )
        
        # Indices for rotation pairs
        indices = torch.randperm(dim)[:self.num_rotations * 2]
        self.register_buffer('idx1', indices[:self.num_rotations])
        self.register_buffer('idx2', indices[self.num_rotations:])
    
    def _apply_rotation_2d(
        self,
        x1: Tensor,
        x2: Tensor,
        theta: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply 2D rotation to feature pairs.
        
        Parameters
        ----------
        x1, x2 : Tensor
            Feature pairs to rotate (batch, num_rotations)
        theta : Tensor
            Rotation angles (num_rotations,)
        
        Returns
        -------
        Tuple[Tensor, Tensor]
            Rotated feature pairs
        """
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        
        y1 = cos_theta * x1 - sin_theta * x2
        y2 = sin_theta * x1 + cos_theta * x2
        
        return y1, y2
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Apply rotation transformations.
        
        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, dim)
        
        Returns
        -------
        Tensor
            Rotated tensor of same shape
        """
        # Clone to avoid in-place modification
        y = x.clone()
        
        # Get feature pairs
        x1 = x[:, self.idx1]  # (batch, num_rotations)
        x2 = x[:, self.idx2]  # (batch, num_rotations)
        
        if self.rotation_type == 'all':
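            # Apply three learnable rotations in sequence (nominally Rz, Ry, Rx).
            # All act in the same 2D plane, so together they equal one rotation
            # by the sum of the three angles.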
            for i in range(3):
                x1, x2 = self._apply_rotation_2d(x1, x2, self.angles[:, i])
        else:
            # Single rotation
            x1, x2 = self._apply_rotation_2d(x1, x2, self.angles)
        
        # Update features
        y = y.scatter(1, self.idx1.unsqueeze(0).expand(x.shape[0], -1), x1)
        y = y.scatter(1, self.idx2.unsqueeze(0).expand(x.shape[0], -1), x2)
        
        return y
    
    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_rotations={self.num_rotations}, type={self.rotation_type}"

In [None]:
# Test Rotation layer
print("Testing RotationLayer...")

rotation_layer = RotationLayer(64, rotation_type='all').to(device)
x = torch.randn(32, 64, device=device)
y = rotation_layer(x)

print(f"Input shape: {x.shape}, Output shape: {y.shape}")
print(f"Number of rotation angles: {rotation_layer.angles.shape}")

# Verify norm preservation: rotating disjoint feature pairs is an orthogonal map,
# so the difference should be ~0 up to floating-point error
x_norms = torch.norm(x, dim=1)
y_norms = torch.norm(y, dim=1)
norm_diff = torch.abs(x_norms - y_norms).mean().item()
print(f"Average norm difference: {norm_diff:.6f}")

## 5.5 CNOT-Inspired Layer

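The CNOT (controlled-NOT) gate flips a target qubit when its control qubit is $|1\rangle$, and entangles the two qubits when the control is in superposition.
We adapt this into a soft, differentiable control: each target feature is rescaled by an amount gated by a sigmoid of its control feature.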
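
To make the analogy concrete, the NumPy sketch below first checks the exact CNOT on the four basis states. It then evaluates the soft relaxation used by `CNOTLayer`, in which the target is scaled by $1 + w\,\sigma((c - \tau)\,T)$. This sketch is illustrative only; `basis` and `soft_cnot` are hypothetical helpers that are not part of the notebook. The soft version approaches a sign flip only when the control is strongly active and $w \to -2$.

```python
import numpy as np

# Exact CNOT on basis states |control, target>, ordered |00>, |01>, |10>, |11>
CNOT = np.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0],
])

def basis(bits):
    """One-hot vector for a two-qubit basis state given as (control, target)."""
    v = np.zeros(4)
    v[2 * bits[0] + bits[1]] = 1.0
    return v

# The target flips only when the control is 1: classically, target <- target XOR control
for c in (0, 1):
    for t in (0, 1):
        assert np.array_equal(CNOT @ basis((c, t)), basis((c, t ^ c)))

def soft_cnot(control, target, weight, threshold=0.0, temperature=1.0):
    """Soft, differentiable relaxation in the style of CNOTLayer.forward."""
    gate = 1.0 / (1.0 + np.exp(-(control - threshold) * temperature))
    return target + gate * weight * target

# With a strongly "on" control and weight = -2, the target's sign is almost flipped
print(soft_cnot(control=10.0, target=1.5, weight=-2.0))
```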

In [None]:
class CNOTLayer(nn.Module):
    """
    CNOT-inspired neural network layer.
    
    Creates entanglement-like correlations between features through
    controlled operations where one feature controls the transformation
    of another.
    
    Parameters
    ----------
    dim : int
        Feature dimension
    num_controls : int
        Number of control-target pairs
    threshold : float
        Activation threshold for control feature
    learnable_threshold : bool
        Whether threshold is learnable
    """
    
    def __init__(
        self,
        dim: int,
        num_controls: Optional[int] = None,
        threshold: float = 0.0,
        learnable_threshold: bool = True
    ):
        super().__init__()
        self.dim = dim
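        self.num_controls = num_controls or (dim // 2)
        if 2 * self.num_controls > dim:
            raise ValueError("num_controls must be at most dim // 2 (control and target indices must be disjoint)")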
        
        # Learnable thresholds
        if learnable_threshold:
            self.threshold = nn.Parameter(
                torch.full((self.num_controls,), threshold)
            )
        else:
            self.register_buffer(
                'threshold',
                torch.full((self.num_controls,), threshold)
            )
        
        # Learnable transformation weights for target
        self.transform_weight = nn.Parameter(
            torch.randn(self.num_controls) * 0.1
        )
        
        # Temperature for soft thresholding
        self.temperature = nn.Parameter(torch.tensor(1.0))
        
        # Control and target indices
        indices = torch.randperm(dim)[:self.num_controls * 2]
        self.register_buffer('control_idx', indices[:self.num_controls])
        self.register_buffer('target_idx', indices[self.num_controls:])
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Apply CNOT-like controlled transformations.
        
        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, dim)
        
        Returns
        -------
        Tensor
            Transformed tensor
        """
        y = x.clone()
        
        # Get control and target features
        control = x[:, self.control_idx]  # (batch, num_controls)
        target = x[:, self.target_idx]    # (batch, num_controls)
        
        # Soft control activation using sigmoid
        control_activation = torch.sigmoid(
            (control - self.threshold) * self.temperature
        )
        
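        # Controlled transformation: a quantum CNOT flips the target when the control is |1>.
        # Here the target is instead rescaled by (1 + w * a), where a in (0, 1) is the soft
        # control activation; this approaches a sign flip only if a -> 1 and w -> -2.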
        transformed_target = target + control_activation * self.transform_weight * target
        
        # Update targets
        y = y.scatter(
            1,
            self.target_idx.unsqueeze(0).expand(x.shape[0], -1),
            transformed_target
        )
        
        return y
    
    def get_entanglement_strength(self) -> Tensor:
        """
        Compute a heuristic proxy for "entanglement" strength (not a quantum entanglement measure).
        
        Returns
        -------
        Tensor
            Average absolute transformation weight
        """
        return torch.abs(self.transform_weight).mean()
    
    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_controls={self.num_controls}"

In [None]:
# Test CNOT layer
print("Testing CNOTLayer...")

cnot_layer = CNOTLayer(64).to(device)
x = torch.randn(32, 64, device=device)
y = cnot_layer(x)

print(f"Input shape: {x.shape}, Output shape: {y.shape}")
print(f"Entanglement strength: {cnot_layer.get_entanglement_strength().item():.6f}")

# Show that targets change based on controls
# When control is high, target should change more
x_high_control = torch.randn(32, 64, device=device)
x_high_control[:, cnot_layer.control_idx] = 5.0  # High control values
y_high = cnot_layer(x_high_control)

x_low_control = x_high_control.clone()
x_low_control[:, cnot_layer.control_idx] = -5.0  # Low control values
y_low = cnot_layer(x_low_control)

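# Inputs differ only in the control features, so any difference in the targets comes
# from the control gate; it scales with |transform_weight|, which is small at initialization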
target_diff = (y_high[:, cnot_layer.target_idx] - y_low[:, cnot_layer.target_idx]).abs().mean()
print(f"Target difference with high vs low control: {target_diff.item():.4f}")

## 5.6 Phase Gate Layer

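Phase gates multiply the $|1\rangle$ amplitude by a phase $e^{i\phi}$ ($\phi = \pi/2$ for S, $\pi/4$ for T). Without complex numbers, we adapt this as a learnable modulation $y = a\,(x\cos\phi + |x|\sin\phi)$ with per-feature phases $\phi$ and amplitudes $a$.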
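
Because $|x| = x$ for $x \ge 0$ and $|x| = -x$ for $x < 0$, the phase modulation is a piecewise-linear activation. Its slope is $a(\cos\phi + \sin\phi)$ for positive inputs and $a(\cos\phi - \sin\phi)$ for negative ones. With the fixed S phase ($\phi = \pi/2$) it reduces to $a|x|$, and with the T phase ($\phi = \pi/4$) it reduces to $\sqrt{2}\,a\,\mathrm{ReLU}(x)$. The NumPy sketch below checks this reading with amplitude $a = 1$; `phase_modulation` is an illustrative name.

```python
import math
import numpy as np

def phase_modulation(x, phase, amplitude=1.0):
    """Same formula as PhaseLayer.forward: a * (x cos(phi) + |x| sin(phi))."""
    return amplitude * (x * np.cos(phase) + np.abs(x) * np.sin(phase))

x = np.linspace(-3.0, 3.0, 13)

# S phase (pi/2): the layer becomes |x|, discarding the sign of the input
assert np.allclose(phase_modulation(x, math.pi / 2), np.abs(x))

# T phase (pi/4): the layer becomes sqrt(2) * ReLU(x)
assert np.allclose(phase_modulation(x, math.pi / 4), math.sqrt(2) * np.maximum(x, 0.0))

# Small learnable phases (initialized near 0) give a kink at 0 with slopes of about 1.09 and 0.90
phi = 0.1
pos_slope, neg_slope = math.cos(phi) + math.sin(phi), math.cos(phi) - math.sin(phi)
assert np.allclose(phase_modulation(x, phi), np.where(x >= 0, pos_slope * x, neg_slope * x))
```

One consequence is that the fixed `'S'` setting discards sign information entirely.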

In [None]:
class PhaseLayer(nn.Module):
    """
    Phase gate-inspired neural network layer.
    
    Applies learnable phase-like modulations to features using
    sinusoidal transformations.
    
    Parameters
    ----------
    dim : int
        Feature dimension
    phase_type : str
        Type of phase gate: 'S' (pi/2), 'T' (pi/4), or 'learnable'
        (any other value also falls back to learnable phases)
    """
    
    def __init__(
        self,
        dim: int,
        phase_type: str = 'learnable'
    ):
        super().__init__()
        self.dim = dim
        self.phase_type = phase_type
        
        if phase_type == 'S':
            # S gate: pi/2 phase
            self.register_buffer('phases', torch.full((dim,), math.pi / 2))
        elif phase_type == 'T':
            # T gate: pi/4 phase
            self.register_buffer('phases', torch.full((dim,), math.pi / 4))
        else:
            # Learnable phases
            self.phases = nn.Parameter(torch.randn(dim) * 0.1)
        
        # Learnable amplitude modulation
        self.amplitude = nn.Parameter(torch.ones(dim))
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Apply phase modulation.
        
        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, dim)
        
        Returns
        -------
        Tensor
            Phase-modulated tensor
        """
        # Apply phase shift through sinusoidal modulation
        # y = amplitude * (x * cos(phase) + |x| * sin(phase))
        cos_phase = torch.cos(self.phases)
        sin_phase = torch.sin(self.phases)
        
        y = self.amplitude * (x * cos_phase + torch.abs(x) * sin_phase)
        
        return y
    
    def extra_repr(self) -> str:
        return f"dim={self.dim}, type={self.phase_type}"

In [None]:
# Test Phase layer
print("Testing PhaseLayer...")

for phase_type in ['S', 'T', 'learnable']:
    phase_layer = PhaseLayer(64, phase_type=phase_type).to(device)
    x = torch.randn(32, 64, device=device)
    y = phase_layer(x)
    print(f"Phase type '{phase_type}': Input shape {x.shape}, Output shape {y.shape}")

## 5.7 Composite Quantum Gate Block

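Combine multiple gate layers into a single quantum-inspired block. Each layer applies the Hadamard, Rotation, CNOT, and Phase gates in that order, and any of them can be disabled. This is analogous to one layer of a parameterized quantum circuit. Layers are stacked with optional residual connections.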
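
As a worked example of what one circuit layer costs, the sketch below counts trainable parameters per gate. It assumes all four gates are enabled, `learnable_scale=True`, learnable CNOT thresholds, `phase_type='learnable'`, and the default `num_rotations` and `num_controls` of `dim // 2`. `gate_block_param_count` is an illustrative helper that is not part of the notebook. For `dim=64` and three layers it gives 1635, which should match the parameter count printed by the block test cell further down.

```python
def gate_block_param_count(dim: int, num_layers: int, residual: bool = True) -> int:
    """Trainable parameters of QuantumGateBlock with all four gates and default options."""
    pairs = dim // 2
    hadamard = 2 * dim                        # learnable scale + bias (the matrix is a buffer)
    rotation = 3 * pairs                      # three angles per pair in 'all' mode
    cnot = 2 * pairs + 1                      # thresholds + transform weights + one temperature
    phase = 2 * dim                           # learnable phases + amplitudes
    layer_norm = 2 * dim if residual else 0   # LayerNorm weight + bias per layer
    return num_layers * (hadamard + rotation + cnot + phase + layer_norm)

print(gate_block_param_count(dim=64, num_layers=3))  # 1635
```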

In [None]:
class QuantumGateBlock(nn.Module):
    """
    Composite quantum gate-inspired neural network block.
    
    Combines Hadamard, Rotation, CNOT, and Phase layers in a
    configurable sequence similar to quantum circuits.
    
    Parameters
    ----------
    dim : int
        Feature dimension
    num_layers : int
        Number of circuit layers; each layer applies every enabled gate once
    use_hadamard : bool
        Include Hadamard layers
    use_rotation : bool
        Include Rotation layers
    use_cnot : bool
        Include CNOT layers
    use_phase : bool
        Include Phase layers
    residual : bool
        Use residual connections
    """
    
    def __init__(
        self,
        dim: int,
        num_layers: int = 2,
        use_hadamard: bool = True,
        use_rotation: bool = True,
        use_cnot: bool = True,
        use_phase: bool = True,
        residual: bool = True
    ):
        super().__init__()
        self.dim = dim
        self.residual = residual
        
        layers = []
        for i in range(num_layers):
            layer_gates = []
            
            if use_hadamard:
                layer_gates.append(HadamardLayer(dim, normalize=True))
            
            if use_rotation:
                layer_gates.append(RotationLayer(dim, rotation_type='all'))
            
            if use_cnot:
                layer_gates.append(CNOTLayer(dim))
            
            if use_phase:
                layer_gates.append(PhaseLayer(dim, phase_type='learnable'))
            
            layers.append(nn.Sequential(*layer_gates))
        
        self.layers = nn.ModuleList(layers)
        
        # Layer norm for residual
        if residual:
            self.layer_norms = nn.ModuleList([
                nn.LayerNorm(dim) for _ in range(num_layers)
            ])
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Apply quantum gate block.
        
        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, dim)
        
        Returns
        -------
        Tensor
            Transformed tensor
        """
        for i, layer in enumerate(self.layers):
            if self.residual:
                x = x + self.layer_norms[i](layer(x))
            else:
                x = layer(x)
        
        return x
    
    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_layers={len(self.layers)}, residual={self.residual}"

In [None]:
# Test Quantum Gate Block
print("Testing QuantumGateBlock...")

gate_block = QuantumGateBlock(
    dim=64,
    num_layers=3,
    use_hadamard=True,
    use_rotation=True,
    use_cnot=True,
    use_phase=True,
    residual=True
).to(device)

x = torch.randn(32, 64, device=device)
y = gate_block(x)

print(f"Input shape: {x.shape}, Output shape: {y.shape}")
print(f"Number of parameters: {sum(p.numel() for p in gate_block.parameters())}")
print(f"\nBlock structure:")
print(gate_block)

## 5.8 Gate-Enhanced World Model Components

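We now integrate the quantum gate blocks into an RSSM-based world model. The encoder, the decoder, and the prior and posterior networks of the transition model each wrap a `QuantumGateBlock`.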

In [None]:
class GateEnhancedEncoder(nn.Module):
    """
    Encoder with quantum gate-enhanced layers.
    
    Parameters
    ----------
    obs_dim : int
        Observation dimension
    hidden_dim : int
        Hidden layer dimension
    embed_dim : int
        Output embedding dimension
    num_gate_layers : int
        Number of gate layers in the encoder's QuantumGateBlock
    """
    
    def __init__(
        self,
        obs_dim: int,
        hidden_dim: int = 256,
        embed_dim: int = 64,
        num_gate_layers: int = 2
    ):
        super().__init__()
        
        # Input projection
        self.input_proj = nn.Linear(obs_dim, hidden_dim)
        
        # Quantum gate block
        self.gate_block = QuantumGateBlock(
            dim=hidden_dim,
            num_layers=num_gate_layers,
            residual=True
        )
        
        # Output projection
        self.output_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, embed_dim),
            nn.ELU()
        )
    
    def forward(self, obs: Tensor) -> Tensor:
        """
        Encode observations.
        
        Parameters
        ----------
        obs : Tensor
            Observations of shape (batch, obs_dim)
        
        Returns
        -------
        Tensor
            Embeddings of shape (batch, embed_dim)
        """
        x = F.elu(self.input_proj(obs))
        x = self.gate_block(x)
        return self.output_proj(x)

In [None]:
class GateEnhancedDecoder(nn.Module):
    """
    Decoder with quantum gate-enhanced layers.
    
    Parameters
    ----------
    state_dim : int
        State dimension (deterministic + stochastic)
    hidden_dim : int
        Hidden layer dimension
    obs_dim : int
        Output observation dimension
    num_gate_layers : int
        Number of gate layers in the decoder's QuantumGateBlock
    """
    
    def __init__(
        self,
        state_dim: int,
        hidden_dim: int = 256,
        obs_dim: int = 4,
        num_gate_layers: int = 2
    ):
        super().__init__()
        
        # Input projection
        self.input_proj = nn.Linear(state_dim, hidden_dim)
        
        # Quantum gate block
        self.gate_block = QuantumGateBlock(
            dim=hidden_dim,
            num_layers=num_gate_layers,
            residual=True
        )
        
        # Output layers (mean and log_std)
        self.output_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ELU()
        )
        self.mean = nn.Linear(hidden_dim // 2, obs_dim)
        self.log_std = nn.Linear(hidden_dim // 2, obs_dim)
    
    def forward(self, state: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Decode state to observation distribution.
        
        Parameters
        ----------
        state : Tensor
            State of shape (batch, state_dim)
        
        Returns
        -------
        Tuple[Tensor, Tensor]
            Mean and log_std of predicted observation distribution
        """
        x = F.elu(self.input_proj(state))
        x = self.gate_block(x)
        x = self.output_layer(x)
        
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-10, 2)
        
        return mean, log_std

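
The decoder returns a mean and a clamped log standard deviation per observation dimension. This section does not define the loss for this head. Assuming a likelihood-based reconstruction loss, a common choice is the diagonal Gaussian negative log-likelihood $\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2 + \log\sigma + \frac{1}{2}\log 2\pi$, summed over dimensions. A NumPy sketch follows; `gaussian_nll` is an illustrative name.

```python
import math
import numpy as np

def gaussian_nll(x, mean, log_std):
    """Negative log-likelihood of x under a diagonal Gaussian, summed over the last axis."""
    z = (x - mean) * np.exp(-log_std)
    per_dim = 0.5 * z ** 2 + log_std + 0.5 * math.log(2 * math.pi)
    return per_dim.sum(axis=-1)

# Perfect prediction with unit std: only the normalizing constant remains (4 dims each)
x = np.zeros((2, 4))
print(gaussian_nll(x, mean=np.zeros((2, 4)), log_std=np.zeros((2, 4))))
```

In PyTorch, `torch.nn.GaussianNLLLoss` implements a similar loss. It takes a variance rather than a log standard deviation and omits the constant term unless `full=True`.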
In [None]:
class GateEnhancedRSSM(nn.Module):
    """
    RSSM with quantum gate-enhanced transition model.
    
    Parameters
    ----------
    embed_dim : int
        Embedding dimension from encoder
    action_dim : int
        Action dimension
    deter_dim : int
        Deterministic state dimension
    stoch_dim : int
        Stochastic state dimension
    hidden_dim : int
        Hidden layer dimension
    num_gate_layers : int
        Number of gate layers in each of the prior and posterior QuantumGateBlocks
    """
    
    def __init__(
        self,
        embed_dim: int = 64,
        action_dim: int = 1,
        deter_dim: int = 128,
        stoch_dim: int = 32,
        hidden_dim: int = 256,
        num_gate_layers: int = 2
    ):
        super().__init__()
        self.deter_dim = deter_dim
        self.stoch_dim = stoch_dim
        
        # GRU cell for deterministic state
        self.gru = nn.GRUCell(stoch_dim + action_dim, deter_dim)
        
        # Gate-enhanced prior (imagination)
        self.prior_gate = QuantumGateBlock(
            dim=hidden_dim,
            num_layers=num_gate_layers,
            residual=True
        )
        self.prior_input = nn.Linear(deter_dim, hidden_dim)
        self.prior_output = nn.Linear(hidden_dim, stoch_dim * 2)
        
        # Gate-enhanced posterior (with observation)
        self.posterior_gate = QuantumGateBlock(
            dim=hidden_dim,
            num_layers=num_gate_layers,
            residual=True
        )
        self.posterior_input = nn.Linear(deter_dim + embed_dim, hidden_dim)
        self.posterior_output = nn.Linear(hidden_dim, stoch_dim * 2)
    
    def initial_state(self, batch_size: int, device: torch.device) -> Dict[str, Tensor]:
        """Get initial state."""
        return {
            'deter': torch.zeros(batch_size, self.deter_dim, device=device),
            'stoch': torch.zeros(batch_size, self.stoch_dim, device=device)
        }
    
    def get_full_state(self, state: Dict[str, Tensor]) -> Tensor:
        """Concatenate deterministic and stochastic states."""
        return torch.cat([state['deter'], state['stoch']], dim=-1)
    
    def prior(self, deter: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute prior distribution (for imagination).
        
        Parameters
        ----------
        deter : Tensor
            Deterministic state
        
        Returns
        -------
        Tuple[Tensor, Tensor]
            Mean and std of prior distribution
        """
        x = F.elu(self.prior_input(deter))
        x = self.prior_gate(x)
        stats = self.prior_output(x)
        mean, raw_std = torch.chunk(stats, 2, dim=-1)
        # softplus keeps the scale positive; +0.1 floors the std at 0.1
        std = F.softplus(raw_std) + 0.1
        return mean, std
    
    def posterior(self, deter: Tensor, embed: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute posterior distribution (with observation).
        
        Parameters
        ----------
        deter : Tensor
            Deterministic state
        embed : Tensor
            Observation embedding
        
        Returns
        -------
        Tuple[Tensor, Tensor]
            Mean and std of posterior distribution
        """
        x = torch.cat([deter, embed], dim=-1)
        x = F.elu(self.posterior_input(x))
        x = self.posterior_gate(x)
        stats = self.posterior_output(x)
        mean, raw_std = torch.chunk(stats, 2, dim=-1)
        # softplus keeps the scale positive; +0.1 floors the std at 0.1
        std = F.softplus(raw_std) + 0.1
        return mean, std
    
    def step(
        self,
        prev_state: Dict[str, Tensor],
        action: Tensor,
        embed: Optional[Tensor] = None
    ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        """
        Single step of RSSM.
        
        Parameters
        ----------
        prev_state : Dict[str, Tensor]
            Previous state with 'deter' and 'stoch'
        action : Tensor
            Action taken
        embed : Optional[Tensor]
            Observation embedding (None for imagination)
        
        Returns
        -------
        Tuple[Dict[str, Tensor], Dict[str, Tensor]]
            New state and distribution stats
        """
        # Update deterministic state
        gru_input = torch.cat([prev_state['stoch'], action], dim=-1)
        deter = self.gru(gru_input, prev_state['deter'])
        
        # Get prior
        prior_mean, prior_std = self.prior(deter)
        
        # Get posterior if embed available, otherwise use prior
        if embed is not None:
            post_mean, post_std = self.posterior(deter, embed)
            # Sample from posterior
            stoch = post_mean + post_std * torch.randn_like(post_std)
        else:
            post_mean, post_std = prior_mean, prior_std
            # Sample from prior
            stoch = prior_mean + prior_std * torch.randn_like(prior_std)
        
        new_state = {'deter': deter, 'stoch': stoch}
        stats = {
            'prior_mean': prior_mean,
            'prior_std': prior_std,
            'post_mean': post_mean,
            'post_std': post_std
        }
        
        return new_state, stats

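
Two details of `prior` and `posterior` are easy to miss. First, the second half of the output is passed through `softplus` and offset by 0.1, so every standard deviation is at least 0.1. Second, `step` samples with the reparameterization trick (`mean + std * eps` with `eps ~ N(0, 1)`). In an autodiff framework this keeps the sample differentiable with respect to `mean` and `std`. The NumPy sketch below illustrates both; the helper names are illustrative.

```python
import numpy as np

def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return np.logaddexp(0.0, v)

def gaussian_head(raw_mean, raw_scale, min_std=0.1):
    """Map raw network outputs to (mean, std) as in GateEnhancedRSSM.prior/posterior."""
    return raw_mean, softplus(raw_scale) + min_std

def reparameterized_sample(mean, std, rng):
    """Draw mean + std * eps; with autodiff, gradients flow through mean and std, not eps."""
    eps = rng.standard_normal(np.shape(mean))
    return mean + std * eps

rng = np.random.default_rng(0)
mean, std = gaussian_head(np.full(200_000, 2.0), np.full(200_000, -50.0))
assert np.all(std >= 0.1)  # even a very negative raw scale is floored at 0.1
samples = reparameterized_sample(mean, std, rng)
print(samples.mean(), samples.std())  # close to 2.0 and 0.1
```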
In [None]:
class GateEnhancedWorldModel(nn.Module):
    """
    Complete world model with quantum gate-enhanced components.
    
    Parameters
    ----------
    obs_dim : int
        Observation dimension
    action_dim : int
        Action dimension
    config : Optional[Dict]
        Configuration dictionary
    """
    
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        config: Optional[Dict] = None
    ):
        super().__init__()
        config = config or {}
        
        # Dimensions
        hidden_dim = config.get('hidden_dim', 256)
        embed_dim = config.get('embed_dim', 64)
        deter_dim = config.get('deter_dim', 128)
        stoch_dim = config.get('stoch_dim', 32)
        num_gate_layers = config.get('num_gate_layers', 2)
        
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.state_dim = deter_dim + stoch_dim
        
        # Components
        self.encoder = GateEnhancedEncoder(
            obs_dim, hidden_dim, embed_dim, num_gate_layers
        )
        self.decoder = GateEnhancedDecoder(
            self.state_dim, hidden_dim, obs_dim, num_gate_layers
        )
        self.rssm = GateEnhancedRSSM(
            embed_dim, action_dim, deter_dim, stoch_dim, hidden_dim, num_gate_layers
        )
        
        # Reward and continue predictors
        self.reward_pred = nn.Sequential(
            nn.Linear(self.state_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.continue_pred = nn.Sequential(
            nn.Linear(self.state_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def initial_state(self, batch_size: int, device: torch.device) -> Dict[str, Tensor]:
        """Get initial RSSM state."""
        return self.rssm.initial_state(batch_size, device)
    
    def encode(self, obs: Tensor) -> Tensor:
        """Encode observation."""
        return self.encoder(obs)
    
    def decode(self, state: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
        """Decode state to observation distribution."""
        full_state = self.rssm.get_full_state(state)
        return self.decoder(full_state)
    
    def predict_reward(self, state: Dict[str, Tensor]) -> Tensor:
        """Predict reward from state."""
        full_state = self.rssm.get_full_state(state)
        return self.reward_pred(full_state).squeeze(-1)
    
    def predict_continue(self, state: Dict[str, Tensor]) -> Tensor:
        """Predict continue probability from state."""
        full_state = self.rssm.get_full_state(state)
        return torch.sigmoid(self.continue_pred(full_state)).squeeze(-1)
    
    def step(
        self,
        prev_state: Dict[str, Tensor],
        action: Tensor,
        obs: Optional[Tensor] = None
    ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        """
        Single step of world model.
        
        Parameters
        ----------
        prev_state : Dict[str, Tensor]
            Previous RSSM state
        action : Tensor
            Action taken
        obs : Optional[Tensor]
            Observation (None for imagination)
        
        Returns
        -------
        Tuple[Dict[str, Tensor], Dict[str, Tensor]]
            New state and distribution stats
        """
        embed = self.encode(obs) if obs is not None else None
        return self.rssm.step(prev_state, action, embed)
    
    def forward(
        self,
        obs_seq: Tensor,
        action_seq: Tensor
    ) -> Dict[str, Tensor]:
        """
        Process a sequence through the world model.
        
        Parameters
        ----------
        obs_seq : Tensor
            Observations (batch, seq_len, obs_dim)
        action_seq : Tensor
            Actions (batch, seq_len, action_dim)
        
        Returns
        -------
        Dict[str, Tensor]
            Dictionary with states, predictions, and distribution stats
        """
        batch_size, seq_len = obs_seq.shape[:2]
        device = obs_seq.device
        
        # Initialize
        state = self.initial_state(batch_size, device)
        
        # Collect outputs
        states = []
        prior_means, prior_stds = [], []
        post_means, post_stds = [], []
        
        for t in range(seq_len):
            state, stats = self.step(state, action_seq[:, t], obs_seq[:, t])
            
            states.append(self.rssm.get_full_state(state))
            prior_means.append(stats['prior_mean'])
            prior_stds.append(stats['prior_std'])
            post_means.append(stats['post_mean'])
            post_stds.append(stats['post_std'])
        
        # Stack
        states = torch.stack(states, dim=1)
        
        # Decode all states
        flat_states = states.reshape(-1, states.shape[-1])
        obs_mean, obs_log_std = self.decoder(flat_states)
        obs_mean = obs_mean.reshape(batch_size, seq_len, -1)
        obs_log_std = obs_log_std.reshape(batch_size, seq_len, -1)
        
        # Predict rewards
        reward_pred = self.reward_pred(flat_states).reshape(batch_size, seq_len)
        
        return {
            'states': states,
            'obs_mean': obs_mean,
            'obs_log_std': obs_log_std,
            'reward_pred': reward_pred,
            'prior_mean': torch.stack(prior_means, dim=1),
            'prior_std': torch.stack(prior_stds, dim=1),
            'post_mean': torch.stack(post_means, dim=1),
            'post_std': torch.stack(post_stds, dim=1)
        }

In [None]:
# Test Gate-Enhanced World Model
print("Testing GateEnhancedWorldModel...")

model = GateEnhancedWorldModel(
    obs_dim=4,
    action_dim=1,
    config={'num_gate_layers': 2}
).to(device)

# Test forward pass
batch_size, seq_len = 16, 20
obs_seq = torch.randn(batch_size, seq_len, 4, device=device)
action_seq = torch.randn(batch_size, seq_len, 1, device=device)

outputs = model(obs_seq, action_seq)

print(f"States shape: {outputs['states'].shape}")
print(f"Obs mean shape: {outputs['obs_mean'].shape}")
print(f"Reward pred shape: {outputs['reward_pred'].shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
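
`forward` decodes every time step in one call: it flattens `(batch, seq_len, state_dim)` to `(batch * seq_len, state_dim)`, applies the head, and reshapes the result back. The sketch below checks that this matches a per-time-step loop. The stand-in `nn.Linear` head and the tensor sizes are arbitrary assumptions and are not the notebook's decoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, seq_len, state_dim, out_dim = 3, 5, 8, 2
head = nn.Linear(state_dim, out_dim)  # stand-in for a decoder or reward head
states = torch.randn(batch, seq_len, state_dim)

# Batched: merge batch and time, apply the head once, then split them again
flat = states.reshape(-1, state_dim)  # (batch * seq_len, state_dim)
batched = head(flat).reshape(batch, seq_len, out_dim)

# Reference: apply the head one time step at a time and stack along dim 1
looped = torch.stack([head(states[:, t]) for t in range(seq_len)], dim=1)

print(torch.allclose(batched, looped, atol=1e-6))
```

`nn.Linear` acts on the last dimension and accepts any number of leading dimensions, so for a purely linear head the explicit flatten is optional. It matters for modules that expect exactly `(batch, features)` input. One example is `nn.BatchNorm1d`, which would treat the second axis of a 3-D input as channels.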

## 5.9 Gate-Enhanced Training

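The `GateEnhancedTrainer` below optimizes the world model with AdamW on the sum of three terms:

- a Gaussian reconstruction negative log-likelihood;
- a KL divergence between posterior and prior, weighted by `kl_weight` and floored at `free_nats`;
- a mean-squared-error reward loss.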
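
Written out, the objective computed by `compute_loss` is shown below. The symbols are:

- $B$, $T$, $D_o$ and $D_z$: batch size, sequence length, observation dimension and stochastic-state dimension.
- $\beta$: `kl_weight`.
- $F$: `free_nats`.
- $\sigma^{o}$: the exponential of the decoder's `obs_log_std`.
- $q_{btk} = \mathcal{N}(\mu^{q}_{btk}, \sigma^{q}_{btk})$ and $p_{btk} = \mathcal{N}(\mu^{p}_{btk}, \sigma^{p}_{btk})$: the posterior and the prior.

Every sum is normalized by its element count because the code reduces each term with `.mean()`.

$$\mathcal{L} = -\frac{1}{B T D_o}\sum_{b,t,d}\log\mathcal{N}\!\left(o_{btd};\,\mu^{o}_{btd},\,\sigma^{o}_{btd}\right) + \beta\,\max\!\left(\frac{1}{B T D_z}\sum_{b,t,k}\mathrm{KL}\!\left(q_{btk}\,\middle\|\,p_{btk}\right),\; F\right) + \frac{1}{B T}\sum_{b,t}\left(\hat{r}_{bt} - r_{bt}\right)^2$$

From left to right, the three terms are reported as `recon_loss`, `kl_loss` (before weighting by $\beta$) and `reward_loss`.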

In [None]:
class GateEnhancedTrainer:
    """
    Trainer for gate-enhanced world model.
    
    Parameters
    ----------
    model : GateEnhancedWorldModel
        The world model to train
    learning_rate : float
        Learning rate
    kl_weight : float
        Weight for KL divergence loss
    free_nats : float
        Lower bound on the KL loss; KL values below it are clamped and contribute no gradient
    """
    
    def __init__(
        self,
        model: GateEnhancedWorldModel,
        learning_rate: float = 1e-4,
        kl_weight: float = 1.0,
        free_nats: float = 3.0
    ):
        self.model = model
        self.kl_weight = kl_weight
        self.free_nats = free_nats
        
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )
        
        self.logger = MetricLogger(name='gate_enhanced')
    
    def compute_loss(
        self,
        obs_seq: Tensor,
        action_seq: Tensor,
        reward_seq: Tensor
    ) -> Tuple[Tensor, Dict[str, float]]:
        """
        Compute training loss.
        
        Parameters
        ----------
        obs_seq : Tensor
            Observations (batch, seq_len, obs_dim)
        action_seq : Tensor
            Actions (batch, seq_len, action_dim)
        reward_seq : Tensor
            Rewards (batch, seq_len)
        
        Returns
        -------
        Tuple[Tensor, Dict[str, float]]
            Total loss and individual loss components
        """
        outputs = self.model(obs_seq, action_seq)
        
        # Reconstruction loss
        obs_dist = torch.distributions.Normal(
            outputs['obs_mean'],
            torch.exp(outputs['obs_log_std'])
        )
        recon_loss = -obs_dist.log_prob(obs_seq).mean()
        
        # KL divergence loss with free nats
        prior_dist = torch.distributions.Normal(
            outputs['prior_mean'],
            outputs['prior_std']
        )
        post_dist = torch.distributions.Normal(
            outputs['post_mean'],
            outputs['post_std']
        )
        kl_div = torch.distributions.kl_divergence(post_dist, prior_dist)
        kl_loss = torch.maximum(
            kl_div.mean(),
            torch.tensor(self.free_nats, device=kl_div.device)
        )
        
        # Reward loss
        reward_loss = F.mse_loss(outputs['reward_pred'], reward_seq)
        
        # Total loss
        total_loss = recon_loss + self.kl_weight * kl_loss + reward_loss
        
        metrics = {
            'recon_loss': recon_loss.item(),
            'kl_loss': kl_loss.item(),
            'reward_loss': reward_loss.item(),
            'total_loss': total_loss.item()
        }
        
        return total_loss, metrics
    
    def train_step(
        self,
        obs_seq: Tensor,
        action_seq: Tensor,
        reward_seq: Tensor
    ) -> Dict[str, float]:
        """
        Single training step.
        
        Returns
        -------
        Dict[str, float]
            Loss metrics
        """
        self.model.train()
        self.optimizer.zero_grad()
        
        loss, metrics = self.compute_loss(obs_seq, action_seq, reward_seq)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 100.0)
        self.optimizer.step()
        
        # Log metrics
        for key, value in metrics.items():
            self.logger.log(**{key: value})
        
        return metrics
    
    def evaluate(
        self,
        obs_seq: Tensor,
        action_seq: Tensor,
        reward_seq: Tensor
    ) -> Dict[str, float]:
        """
        Evaluate model.
        
        Returns
        -------
        Dict[str, float]
            Loss metrics
        """
        self.model.eval()
        with torch.no_grad():
            _, metrics = self.compute_loss(obs_seq, action_seq, reward_seq)
        return metrics

In [None]:
# Test Gate-Enhanced Trainer
print("Testing GateEnhancedTrainer...")

model = GateEnhancedWorldModel(obs_dim=4, action_dim=1).to(device)
trainer = GateEnhancedTrainer(model)

# Generate synthetic data
obs_seq = torch.randn(16, 20, 4, device=device)
action_seq = torch.randn(16, 20, 1, device=device)
reward_seq = torch.randn(16, 20, device=device)

# Run training step
metrics = trainer.train_step(obs_seq, action_seq, reward_seq)
print("Training metrics:")
for key, value in metrics.items():
    print(f"  {key}: {value:.4f}")
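
`compute_loss` clamps `kl_div.mean()`, which averages the KL over the batch, time *and* latent dimensions. The free-nats floor of 3.0 is therefore compared against a per-dimension KL. With `stoch_dim=32` (the default in `StandardWorldModel`), the average KL per latent dimension must exceed 3 nats before the term contributes any gradient. PlaNet/Dreamer-style implementations typically compute the KL of the full diagonal Gaussian, which is summed over latent dimensions, before applying the floor.

The sketch below shows how the choice of reduction decides whether the KL term receives gradient at all. It uses made-up Gaussian statistics, assumed for illustration only.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, seq_len, stoch_dim = 4, 10, 32
free_nats = 3.0

# Hypothetical posterior/prior statistics (unit std, posterior means ~ N(0, 1))
post_mean = torch.randn(batch, seq_len, stoch_dim, requires_grad=True)
prior_mean = torch.zeros(batch, seq_len, stoch_dim)
std = torch.ones(batch, seq_len, stoch_dim)

kl = kl_divergence(Normal(post_mean, std), Normal(prior_mean, std))  # (B, T, D_z)

kl_mean_all = kl.mean()            # average over every element, as in compute_loss
kl_sum_latent = kl.sum(-1).mean()  # sum over latents, then average over batch/time

floor = torch.tensor(free_nats)
loss_mean_all = torch.maximum(kl_mean_all, floor)
loss_sum_latent = torch.maximum(kl_sum_latent, floor)

print(f"mean over all dims:   {kl_mean_all.item():.3f}")
print(f"sum over latent dims: {kl_sum_latent.item():.3f}")

# Gradient that reaches the posterior mean under each reduction
grad_mean_all = torch.autograd.grad(loss_mean_all, post_mean, retain_graph=True)[0]
grad_sum_latent = torch.autograd.grad(loss_sum_latent, post_mean)[0]
print(grad_mean_all.abs().sum().item(), grad_sum_latent.abs().sum().item())
```

Switching to the summed form changes the effective scale of `kl_weight`, so it is a modelling decision rather than a drop-in fix. If adopted, it should be applied to both trainers so that the comparison in Section 5.10 stays fair.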

## 5.10 Comparison: Gate-Enhanced vs Standard

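Compare the gate-enhanced world model against the standard (non-gate) baseline from Phase 2. Both models are trained on identical batches collected with a random policy, using the same loss, optimizer settings and gradient clipping.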

In [None]:
import gymnasium as gym

def collect_episodes(
    env_name: str,
    num_episodes: int = 10,
    max_steps: int = 200
) -> List[Dict[str, np.ndarray]]:
    """
    Collect episodes from environment.
    
    Parameters
    ----------
    env_name : str
        Environment name
    num_episodes : int
        Number of episodes to collect
    max_steps : int
        Maximum steps per episode
    
    Returns
    -------
    List[Dict]
        List of episode dictionaries
    """
    env = gym.make(env_name)
    episodes = []
    
    for _ in range(num_episodes):
        obs_list, action_list, reward_list = [], [], []
        
        obs, _ = env.reset()
        obs_list.append(obs)
        
        for _ in range(max_steps):
            action = env.action_space.sample()
            next_obs, reward, terminated, truncated, _ = env.step(action)
            
            # Discrete spaces sample NumPy integers (not Python ints), so store every
            # action as a 1-D float array: shape (1,) for Discrete, (k,) for Box
            action_list.append(np.atleast_1d(np.asarray(action, dtype=np.float32)))
            reward_list.append(reward)
            obs_list.append(next_obs)
            
            if terminated or truncated:
                break
        
        # Remove last obs (not paired with action)
        obs_list = obs_list[:-1]
        
        if len(obs_list) > 10:  # Minimum episode length
            episodes.append({
                'obs': np.array(obs_list, dtype=np.float32),
                'actions': np.array(action_list, dtype=np.float32),
                'rewards': np.array(reward_list, dtype=np.float32)
            })
    
    env.close()
    return episodes


def create_batches(
    episodes: List[Dict],
    batch_size: int = 16,
    seq_len: int = 20,
    device: torch.device = device
) -> List[Tuple[Tensor, Tensor, Tensor]]:
    """
    Create training batches from episodes.
    
    Parameters
    ----------
    episodes : List[Dict]
        List of episode dictionaries
    batch_size : int
        Batch size
    seq_len : int
        Sequence length
    device : torch.device
        Device to put tensors on
    
    Returns
    -------
    List[Tuple[Tensor, Tensor, Tensor]]
        List of (obs, actions, rewards) batches
    """
    # Collect all valid sequences
    sequences = []
    for ep in episodes:
        ep_len = len(ep['obs'])
        # Every full window; the last valid start is ep_len - seq_len (inclusive)
        for start in range(0, ep_len - seq_len + 1, max(1, seq_len // 2)):
            sequences.append({
                'obs': ep['obs'][start:start+seq_len],
                'actions': ep['actions'][start:start+seq_len],
                'rewards': ep['rewards'][start:start+seq_len]
            })
    
    # Shuffle and batch
    np.random.shuffle(sequences)
    batches = []
    
    # Only full batches are kept; a trailing partial batch is dropped
    for i in range(0, len(sequences) - batch_size + 1, batch_size):
        batch_seqs = sequences[i:i+batch_size]
        
        obs = torch.tensor(
            np.stack([s['obs'] for s in batch_seqs]),
            dtype=torch.float32, device=device
        )
        actions = torch.tensor(
            np.stack([s['actions'] for s in batch_seqs]),
            dtype=torch.float32, device=device
        )
        rewards = torch.tensor(
            np.stack([s['rewards'] for s in batch_seqs]),
            dtype=torch.float32, device=device
        )
        
        batches.append((obs, actions, rewards))
    
    return batches
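
`create_batches` cuts each episode into overlapping windows of `seq_len` steps with a stride of `seq_len // 2`. The helper below is hypothetical and not part of the notebook. It computes the start indices of every full window so that the number of windows per episode can be checked. A window starting at `start` is valid whenever `start + seq_len <= ep_len`, so the range must run up to and including `ep_len - seq_len`.

```python
from typing import List


def window_starts(ep_len: int, seq_len: int, stride: int) -> List[int]:
    """Start indices of all full windows of length seq_len, taken every `stride` steps."""
    # The last valid start is ep_len - seq_len, hence the + 1 in the range end
    return list(range(0, ep_len - seq_len + 1, stride))


print(window_starts(40, 20, 10))  # a 40-step episode
print(window_starts(20, 20, 10))  # an episode exactly seq_len long
print(window_starts(15, 20, 10))  # an episode shorter than seq_len
```

Random-policy CartPole episodes are often only a few dozen steps long, so each episode contributes few windows. The number of full batches can therefore be small for modest `num_episodes`.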

In [None]:
# Standard World Model for comparison (from Phase 2)
class StandardWorldModel(nn.Module):
    """
    Standard world model without quantum gate enhancements.
    """
    
    def __init__(self, obs_dim: int, action_dim: int, config: Optional[Dict] = None):
        super().__init__()
        config = config or {}
        
        hidden_dim = config.get('hidden_dim', 256)
        embed_dim = config.get('embed_dim', 64)
        deter_dim = config.get('deter_dim', 128)
        stoch_dim = config.get('stoch_dim', 32)
        
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.state_dim = deter_dim + stoch_dim
        self.deter_dim = deter_dim
        self.stoch_dim = stoch_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.ELU()
        )
        
        # Decoder
        self.decoder_net = nn.Sequential(
            nn.Linear(self.state_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ELU()
        )
        self.decoder_mean = nn.Linear(hidden_dim // 2, obs_dim)
        self.decoder_log_std = nn.Linear(hidden_dim // 2, obs_dim)
        
        # RSSM
        self.gru = nn.GRUCell(stoch_dim + action_dim, deter_dim)
        
        self.prior_net = nn.Sequential(
            nn.Linear(deter_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, stoch_dim * 2)
        )
        
        self.posterior_net = nn.Sequential(
            nn.Linear(deter_dim + embed_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, stoch_dim * 2)
        )
        
        # Reward predictor
        self.reward_pred = nn.Sequential(
            nn.Linear(self.state_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def initial_state(self, batch_size: int, device: torch.device) -> Dict[str, Tensor]:
        return {
            'deter': torch.zeros(batch_size, self.deter_dim, device=device),
            'stoch': torch.zeros(batch_size, self.stoch_dim, device=device)
        }
    
    def get_full_state(self, state: Dict[str, Tensor]) -> Tensor:
        return torch.cat([state['deter'], state['stoch']], dim=-1)
    
    def forward(self, obs_seq: Tensor, action_seq: Tensor) -> Dict[str, Tensor]:
        batch_size, seq_len = obs_seq.shape[:2]
        device = obs_seq.device
        
        state = self.initial_state(batch_size, device)
        
        states = []
        prior_means, prior_stds = [], []
        post_means, post_stds = [], []
        
        for t in range(seq_len):
            # Encode
            embed = self.encoder(obs_seq[:, t])
            
            # Update deterministic
            gru_input = torch.cat([state['stoch'], action_seq[:, t]], dim=-1)
            deter = self.gru(gru_input, state['deter'])
            
            # Prior (the second half of the head output is an unconstrained
            # pre-activation, mapped to a std >= 0.1 with softplus)
            prior_stats = self.prior_net(deter)
            prior_mean, prior_std_raw = torch.chunk(prior_stats, 2, dim=-1)
            prior_std = F.softplus(prior_std_raw) + 0.1
            
            # Posterior
            post_input = torch.cat([deter, embed], dim=-1)
            post_stats = self.posterior_net(post_input)
            post_mean, post_std_raw = torch.chunk(post_stats, 2, dim=-1)
            post_std = F.softplus(post_std_raw) + 0.1
            
            # Sample
            stoch = post_mean + post_std * torch.randn_like(post_std)
            
            state = {'deter': deter, 'stoch': stoch}
            
            states.append(self.get_full_state(state))
            prior_means.append(prior_mean)
            prior_stds.append(prior_std)
            post_means.append(post_mean)
            post_stds.append(post_std)
        
        states = torch.stack(states, dim=1)
        
        # Decode
        flat_states = states.reshape(-1, states.shape[-1])
        dec_hidden = self.decoder_net(flat_states)
        obs_mean = self.decoder_mean(dec_hidden).reshape(batch_size, seq_len, -1)
        obs_log_std = self.decoder_log_std(dec_hidden).clamp(-10, 2).reshape(batch_size, seq_len, -1)
        
        reward_pred = self.reward_pred(flat_states).reshape(batch_size, seq_len)
        
        return {
            'states': states,
            'obs_mean': obs_mean,
            'obs_log_std': obs_log_std,
            'reward_pred': reward_pred,
            'prior_mean': torch.stack(prior_means, dim=1),
            'prior_std': torch.stack(prior_stds, dim=1),
            'post_mean': torch.stack(post_means, dim=1),
            'post_std': torch.stack(post_stds, dim=1)
        }
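
The prior and posterior heads in `StandardWorldModel` handle their outputs in two steps:

1. They map an unconstrained network output to a standard deviation with `softplus(x) + 0.1`.
2. They draw `stoch` with the reparameterization trick (`mean + std * noise`).

The sketch below isolates both steps on random inputs, whose sizes are arbitrary assumptions. It shows that the standard deviation never falls below the 0.1 floor and that gradients flow through the sample to both the mean and the std halves of the head output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
min_std = 0.1
latent = 4  # illustrative stochastic-state size

# Unconstrained head output, scaled up so it includes large negative values
stats = (torch.randn(8, 2 * latent) * 5).requires_grad_()
mean, std_raw = torch.chunk(stats, 2, dim=-1)
std = F.softplus(std_raw) + min_std  # strictly positive, never below min_std

# Reparameterized sample: the randomness lives in `noise`, so the sample is differentiable
noise = torch.randn_like(std)
sample = mean + std * noise
sample.sum().backward()

print(std.min().item() >= min_std)             # floor respected
print(stats.grad[:, :latent].abs().sum() > 0)  # gradient reaches the mean half
print(stats.grad[:, latent:].abs().sum() > 0)  # and the std half
```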

In [None]:
class StandardTrainer:
    """Baseline trainer with the same loss, optimizer, and gradient clipping as GateEnhancedTrainer."""
    
    def __init__(self, model, learning_rate=1e-4, kl_weight=1.0, free_nats=3.0):
        self.model = model
        self.kl_weight = kl_weight
        self.free_nats = free_nats
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
        self.logger = MetricLogger(name='standard')
    
    def train_step(self, obs_seq, action_seq, reward_seq):
        self.model.train()
        self.optimizer.zero_grad()
        
        outputs = self.model(obs_seq, action_seq)
        
        # Reconstruction loss
        obs_dist = torch.distributions.Normal(outputs['obs_mean'], torch.exp(outputs['obs_log_std']))
        recon_loss = -obs_dist.log_prob(obs_seq).mean()
        
        # KL loss
        prior_dist = torch.distributions.Normal(outputs['prior_mean'], outputs['prior_std'])
        post_dist = torch.distributions.Normal(outputs['post_mean'], outputs['post_std'])
        kl_div = torch.distributions.kl_divergence(post_dist, prior_dist)
        kl_loss = torch.maximum(kl_div.mean(), torch.tensor(self.free_nats, device=kl_div.device))
        
        # Reward loss
        reward_loss = F.mse_loss(outputs['reward_pred'], reward_seq)
        
        total_loss = recon_loss + self.kl_weight * kl_loss + reward_loss
        
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 100.0)
        self.optimizer.step()
        
        metrics = {
            'recon_loss': recon_loss.item(),
            'kl_loss': kl_loss.item(),
            'reward_loss': reward_loss.item(),
            'total_loss': total_loss.item()
        }
        
        for key, value in metrics.items():
            self.logger.log(**{key: value})
        
        return metrics
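
The reconstruction term in both trainers is the negative log-likelihood of the observation under a diagonal Gaussian with `std = exp(log_std)`, averaged over every element. As a sanity check on that formula, the sketch below compares `torch.distributions.Normal.log_prob` against the closed form $\tfrac{1}{2}\big((x-\mu)/\sigma\big)^2 + \log\sigma + \tfrac{1}{2}\log 2\pi$ and against `F.gaussian_nll_loss`. The tensors are random and their sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal

torch.manual_seed(0)
obs = torch.randn(4, 10, 3)  # (batch, seq_len, obs_dim)
mean = torch.randn(4, 10, 3)
log_std = 0.5 * torch.randn(4, 10, 3)
std = log_std.exp()

# 1) As written in the trainers
nll_dist = -Normal(mean, std).log_prob(obs).mean()

# 2) Closed form of the Gaussian negative log-likelihood
nll_closed = (0.5 * ((obs - mean) / std) ** 2 + log_std + 0.5 * math.log(2 * math.pi)).mean()

# 3) PyTorch's built-in Gaussian NLL (full=True keeps the constant term)
nll_builtin = F.gaussian_nll_loss(mean, obs, std ** 2, full=True)

print(nll_dist.item(), nll_closed.item(), nll_builtin.item())
```

The constant $\tfrac{1}{2}\log 2\pi \approx 0.919$ shifts the reported `recon_loss` without affecting its gradients. Because $\log\sigma$ becomes negative once the predicted standard deviations drop below 1, negative reconstruction losses are expected late in training.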

In [None]:
def run_comparison(
    env_name: str = 'CartPole-v1',
    num_episodes: int = 20,
    num_epochs: int = 50,
    batch_size: int = 16,
    seq_len: int = 20,
    seed: int = 42
) -> Dict[str, List[float]]:
    """
    Run comparison between gate-enhanced and standard world models.
    
    Parameters
    ----------
    env_name : str
        Environment name
    num_episodes : int
        Number of episodes to collect
    num_epochs : int
        Number of training epochs
    batch_size : int
        Batch size
    seq_len : int
        Sequence length
    seed : int
        Random seed
    
    Returns
    -------
    Dict[str, List[float]]
        Training histories for both models
    """
    set_seed(seed)
    
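    print(f"Collecting {num_episodes} episodes from {env_name}...")
    episodes = collect_episodes(env_name, num_episodes)
    if not episodes:
        raise ValueError("No episodes passed the minimum-length filter; increase num_episodes.")
    
    # Get dimensions
    obs_dim = episodes[0]['obs'].shape[1]
    action_dim = episodes[0]['actions'].shape[1]
    
    print(f"Observation dim: {obs_dim}, Action dim: {action_dim}")
    print("Creating training batches...")
    
    batches = create_batches(episodes, batch_size, seq_len)
    print(f"Created {len(batches)} batches")
    if not batches:
        # Without this check every epoch average below would silently be NaN
        raise ValueError(
            "No full batches could be formed; collect more episodes "
            "or reduce batch_size / seq_len."
        )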
    
    # Create models
    print("\nInitializing models...")
    
    gate_model = GateEnhancedWorldModel(obs_dim, action_dim).to(device)
    standard_model = StandardWorldModel(obs_dim, action_dim).to(device)
    
    gate_params = sum(p.numel() for p in gate_model.parameters())
    standard_params = sum(p.numel() for p in standard_model.parameters())
    print(f"Gate-enhanced parameters: {gate_params:,}")
    print(f"Standard parameters: {standard_params:,}")
    
    # Create trainers
    gate_trainer = GateEnhancedTrainer(gate_model)
    standard_trainer = StandardTrainer(standard_model)
    
    # Training histories
    histories = {
        'gate_loss': [],
        'standard_loss': [],
        'gate_recon': [],
        'standard_recon': []
    }
    
    print(f"\nTraining for {num_epochs} epochs...")
    timer = Timer().start()
    
    for epoch in range(num_epochs):
        gate_losses, standard_losses = [], []
        gate_recons, standard_recons = [], []
        
        for obs, actions, rewards in batches:
            # Train gate-enhanced
            gate_metrics = gate_trainer.train_step(obs, actions, rewards)
            gate_losses.append(gate_metrics['total_loss'])
            gate_recons.append(gate_metrics['recon_loss'])
            
            # Train standard
            standard_metrics = standard_trainer.train_step(obs, actions, rewards)
            standard_losses.append(standard_metrics['total_loss'])
            standard_recons.append(standard_metrics['recon_loss'])
        
        # Record epoch averages
        histories['gate_loss'].append(np.mean(gate_losses))
        histories['standard_loss'].append(np.mean(standard_losses))
        histories['gate_recon'].append(np.mean(gate_recons))
        histories['standard_recon'].append(np.mean(standard_recons))
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"  Gate-enhanced loss: {histories['gate_loss'][-1]:.4f}")
            print(f"  Standard loss: {histories['standard_loss'][-1]:.4f}")
    
    elapsed = timer.stop()
    print(f"\nTraining completed in {elapsed:.2f}s")
    
    return histories

In [None]:
# Run comparison
histories = run_comparison(
    env_name='CartPole-v1',
    num_episodes=20,
    num_epochs=50,
    seed=42
)

In [None]:
# Visualize comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Total loss
ax = axes[0]
ax.plot(histories['gate_loss'], label='Gate-Enhanced', color=COLORS['gates'], linewidth=2)
ax.plot(histories['standard_loss'], label='Standard', color=COLORS['baseline'], linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Total Loss')
ax.set_title('Total Loss Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

# Reconstruction loss
ax = axes[1]
ax.plot(histories['gate_recon'], label='Gate-Enhanced', color=COLORS['gates'], linewidth=2)
ax.plot(histories['standard_recon'], label='Standard', color=COLORS['baseline'], linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Reconstruction Loss')
ax.set_title('Reconstruction Loss Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

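plt.tight_layout()
fig_dir = Path('../results/figures')
fig_dir.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(fig_dir / 'gate_enhanced_comparison.png', dpi=150, bbox_inches='tight')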
plt.show()

# Print statistics
print("\n" + "="*60)
print("Final Comparison (last 10 epochs average):")
print("="*60)
print(f"Gate-enhanced total loss: {np.mean(histories['gate_loss'][-10:]):.4f}")
print(f"Standard total loss: {np.mean(histories['standard_loss'][-10:]):.4f}")
print(f"Gate-enhanced recon loss: {np.mean(histories['gate_recon'][-10:]):.4f}")
print(f"Standard recon loss: {np.mean(histories['standard_recon'][-10:]):.4f}")

## 5.11 Gate Layer Analysis

Inspect the learned gate parameters (rotation angles, CNOT entanglement strengths, phase angles, and Hadamard scales) in the encoder and RSSM prior gate blocks.

In [None]:
def analyze_gate_layers(model: GateEnhancedWorldModel):
    """
    Analyze the learned parameters in quantum gate layers.
    
    Parameters
    ----------
    model : GateEnhancedWorldModel
        The trained model
    """
    print("Gate Layer Analysis")
    print("=" * 60)
    
    # Analyze encoder gate block
    print("\nEncoder Gate Block:")
    for i, layer in enumerate(model.encoder.gate_block.layers):
        print(f"\n  Layer {i+1}:")
        for sublayer in layer:
            if isinstance(sublayer, RotationLayer):
                angles = sublayer.angles.detach().cpu().numpy()
                print(f"    Rotation angles - mean: {angles.mean():.4f}, std: {angles.std():.4f}")
            elif isinstance(sublayer, CNOTLayer):
                entanglement = sublayer.get_entanglement_strength().item()
                print(f"    CNOT entanglement strength: {entanglement:.4f}")
            elif isinstance(sublayer, PhaseLayer):
                phases = sublayer.phases.detach().cpu().numpy()
                print(f"    Phase angles - mean: {phases.mean():.4f}, std: {phases.std():.4f}")
            elif isinstance(sublayer, HadamardLayer):
                scale = sublayer.scale.detach().cpu().numpy()
                print(f"    Hadamard scale - mean: {scale.mean():.4f}, std: {scale.std():.4f}")
    
    # Analyze RSSM gate blocks
    print("\nRSSM Prior Gate Block:")
    for i, layer in enumerate(model.rssm.prior_gate.layers):
        print(f"  Layer {i+1}:")
        for sublayer in layer:
            if isinstance(sublayer, CNOTLayer):
                entanglement = sublayer.get_entanglement_strength().item()
                print(f"    CNOT entanglement: {entanglement:.4f}")


# Run the analysis on a briefly trained model (10 steps on random synthetic data,
# so the statistics are illustrative only)
print("Creating and training model for analysis...")
analysis_model = GateEnhancedWorldModel(obs_dim=4, action_dim=1).to(device)

# Quick training
trainer = GateEnhancedTrainer(analysis_model)
for _ in range(10):
    obs = torch.randn(16, 20, 4, device=device)
    actions = torch.randn(16, 20, 1, device=device)
    rewards = torch.randn(16, 20, device=device)
    trainer.train_step(obs, actions, rewards)

analyze_gate_layers(analysis_model)
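
`analyze_gate_layers` relies on the specific attribute layout of the gate blocks (`gate_block.layers`, `prior_gate.layers`). A more generic alternative, sketched below as a hypothetical helper, walks `named_parameters()` and reports statistics for every parameter whose name contains one of a few keywords. The default keywords `angles`, `phases` and `scale` mirror the attribute names read in `analyze_gate_layers`. The `ToyGate` module is invented only to exercise the helper.

```python
from typing import Dict, Sequence

import numpy as np
import torch
import torch.nn as nn


def summarize_parameters(
    module: nn.Module,
    keywords: Sequence[str] = ("angles", "phases", "scale"),
) -> Dict[str, Dict[str, float]]:
    """Mean, std, min and max of every parameter whose dotted name contains a keyword."""
    summary: Dict[str, Dict[str, float]] = {}
    for name, param in module.named_parameters():
        if any(key in name for key in keywords):
            values = param.detach().cpu().numpy().ravel()
            summary[name] = {
                "mean": float(values.mean()),
                "std": float(values.std()),  # population std, like ndarray.std() above
                "min": float(values.min()),
                "max": float(values.max()),
            }
    return summary


class ToyGate(nn.Module):
    """Hypothetical stand-in for a gate layer with learnable angles."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.angles = nn.Parameter(torch.linspace(-1.0, 1.0, dim))
        self.linear = nn.Linear(dim, dim)


toy = nn.Sequential(ToyGate(), ToyGate())
for name, stats in summarize_parameters(toy).items():
    print(name, {k: round(v, 4) for k, v in stats.items()})
```

Keyword matching is a plain substring test, so a name such as `layer_scale` would also be picked up. Pass a narrower `keywords` tuple if that matters.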

## 5.12 Summary

### Key Components

1. **Hadamard Layers**: Mix features with a Hadamard-style orthogonal transform, approximately preserving information while enabling richer transformations

2. **Rotation Layers**: Apply learnable rotations in feature space whose angles adapt to the task

3. **CNOT Layers**: Introduce controlled dependencies between features, producing entanglement-like correlations

4. **Phase Layers**: Apply learnable phase modulations for additional expressivity
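
The claim that Hadamard mixing preserves information rests on orthogonality. The normalized Hadamard matrix $H_n$ satisfies $H_n H_n^\top = I$, so it preserves Euclidean norms. The Sylvester construction also makes it symmetric and therefore its own inverse. The sketch below builds $H_n$ with $H_{2n} = \frac{1}{\sqrt{2}}\begin{pmatrix} H_n & H_n \\ H_n & -H_n \end{pmatrix}$ and checks these properties numerically for an 8-dimensional feature vector. This is an illustration only. The notebook's `HadamardLayer` may construct or scale its matrix differently, for example through its learnable `scale`.

```python
import numpy as np


def sylvester_hadamard(n: int) -> np.ndarray:
    """Normalized n x n Hadamard matrix via Sylvester's construction (n a power of 2)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of 2")
    base = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2.0)  # single-qubit H gate
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.kron(base, h)  # [[h, h], [h, -h]] / sqrt(2)
    return h


H = sylvester_hadamard(8)
x = np.random.default_rng(0).normal(size=8)

print(np.allclose(H @ H.T, np.eye(8)))                       # orthogonal
print(np.isclose(np.linalg.norm(H @ x), np.linalg.norm(x)))  # norm-preserving
print(np.allclose(H @ (H @ x), x))                           # self-inverse
```

Suppose a learnable `scale` multiplies the mixed features element-wise (an assumption about `HadamardLayer`). The combined map is then orthogonal only when every scale has magnitude 1, which is one reason the preservation is only approximate.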

### Implementation Notes

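- All layers preserve the feature dimension, so they can be stacked freely
- Residual connections are included to improve training stability
- The gate-enhanced model has more parameters than the standard baseline; whether this yields richer representations should be judged from the loss comparison in Section 5.10 rather than assumed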

### Next Steps

- Phase 6: Error Correction Ensemble
- Phase 7: Comprehensive Comparison
- Phase 8: Ablation Studies

In [None]:
print("\n" + "="*60)
print("Phase 5: Gate-Enhanced Neural Layers - COMPLETE")
print("="*60)
print("\nImplemented:")
print("  - HadamardLayer: Orthogonal feature mixing")
print("  - RotationLayer: Parameterized rotations (Rx, Ry, Rz)")
print("  - CNOTLayer: Controlled operations for correlations")
print("  - PhaseLayer: Phase modulations")
print("  - QuantumGateBlock: Composite quantum circuit-like block")
print("  - GateEnhancedWorldModel: Full world model integration")
print("  - Comparison with standard world model")
print("\nReady for Phase 6: Error Correction Ensemble")