# Phase 4: Superposition-Enhanced Experience Replay

**Notebook:** `04_superposition_replay.ipynb`  
**Phase:** 4 of 9  
**Purpose:** Implement quantum superposition-inspired experience replay for improved sample efficiency  
**Author:** Saurabh Jalendra  
**Institution:** BITS Pilani (WILP Division)  
**Date:** November 2025

---

## Table of Contents

1. [Setup & Imports](#1-setup--imports)
2. [Superposition Concept](#2-superposition-concept)
3. [Amplitude-Based Weighting](#3-amplitude-based-weighting)
4. [Interference Effects](#4-interference-effects)
5. [Superposition Replay Buffer](#5-superposition-replay-buffer)
6. [Integration with World Model](#6-integration-with-world-model)
7. [Experiments](#7-experiments)
8. [Comparison with Standard Replay](#8-comparison-with-standard-replay)
9. [Visualizations](#9-visualizations)
10. [Summary](#10-summary)

---
## 1. Setup & Imports

In [None]:
"""
Cell: Imports and Configuration
Purpose: Import packages and set up environment
"""

import os
import sys
import math
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Union, Any, NamedTuple
from dataclasses import dataclass, field
from collections import deque
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from scipy import stats

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

import gymnasium as gym

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

from tqdm.notebook import tqdm

# Make the project's `src/` directory importable. Assumes the notebook is run from
# a direct subdirectory of the project root (NOTE(review): confirm the layout).
PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))

# Project helpers from src/utils. COLORS is the palette indexed by the plots below
# (e.g. COLORS['baseline'], COLORS['superposition']).
from utils import set_seed, get_device, MetricLogger, COLORS

# Global seed. NOTE(review): confirm set_seed covers numpy, torch and random.
SEED = 42
set_seed(SEED)
DEVICE = get_device()

print(f"Device: {DEVICE}")
print(f"PyTorch: {torch.__version__}")

In [None]:
"""
Cell: World Model Components (from Phase 2)
Purpose: Import world model architecture
"""

# Core world model components (condensed from previous phases)
class RSSMState(NamedTuple):
    """Latent RSSM state: a deterministic recurrent part plus a stochastic sample."""

    deter: torch.Tensor  # deterministic (GRU) component
    stoch: torch.Tensor  # stochastic (sampled) component

    @property
    def combined(self):
        """Concatenation of ``deter`` then ``stoch`` along the last axis."""
        parts = (self.deter, self.stoch)
        return torch.cat(parts, dim=-1)

class MLPEncoder(nn.Module):
    """Feed-forward encoder: obs -> (Linear, ELU) per hidden width -> Linear -> latent."""

    def __init__(self, obs_dim, hidden_dims, latent_dim):
        super().__init__()
        widths = [obs_dim, *hidden_dims]
        modules = []
        # One Linear+ELU pair per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ELU()]
        # Final projection has no activation.
        modules.append(nn.Linear(widths[-1], latent_dim))
        self.network = nn.Sequential(*modules)

    def forward(self, obs):
        """Embed observations into the latent space."""
        embedding = self.network(obs)
        return embedding

class MLPDecoder(nn.Module):
    """Feed-forward decoder: latent -> (Linear, ELU) per hidden width -> Linear -> obs."""

    def __init__(self, latent_dim, hidden_dims, obs_dim):
        super().__init__()
        blocks = []
        prev_width = latent_dim
        for width in hidden_dims:
            blocks.append(nn.Linear(prev_width, width))
            blocks.append(nn.ELU())
            prev_width = width
        # Output layer maps back to observation size, no activation.
        blocks.append(nn.Linear(prev_width, obs_dim))
        self.network = nn.Sequential(*blocks)

    def forward(self, latent):
        """Reconstruct observations from latent features."""
        reconstruction = self.network(latent)
        return reconstruction

class RSSM(nn.Module):
    """
    Recurrent state-space model with Gaussian prior and posterior heads.

    Transition: (prev stoch, action) -> input_proj -> GRUCell -> deter, then a prior
    head on ``deter``. The posterior head additionally sees an observation embedding.
    Standard deviations are ``softplus(raw) + min_std``.
    """

    def __init__(self, stoch_dim, deter_dim, hidden_dim, action_dim, embed_dim, min_std=0.1):
        super().__init__()
        self.stoch_dim = stoch_dim
        self.deter_dim = deter_dim
        self.min_std = min_std
        # Submodules are created in this order: input_proj, gru, prior_net, posterior_net.
        self.input_proj = nn.Sequential(
            nn.Linear(stoch_dim + action_dim, hidden_dim),
            nn.ELU(),
        )
        self.gru = nn.GRUCell(hidden_dim, deter_dim)
        self.prior_net = nn.Sequential(
            nn.Linear(deter_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 2 * stoch_dim),
        )
        self.posterior_net = nn.Sequential(
            nn.Linear(deter_dim + embed_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 2 * stoch_dim),
        )

    def initial_state(self, batch_size, device):
        """All-zeros state for ``batch_size`` sequences on ``device``."""
        deter = torch.zeros(batch_size, self.deter_dim, device=device)
        stoch = torch.zeros(batch_size, self.stoch_dim, device=device)
        return RSSMState(deter, stoch)

    def _get_dist(self, params):
        """Split ``params`` into (mean, raw std) halves and build a Normal."""
        mean, raw_std = params.chunk(2, dim=-1)
        std = F.softplus(raw_std) + self.min_std
        return Normal(mean, std)

    def imagine_step(self, prev_state, action):
        """One prior-only transition; returns (new state, prior distribution)."""
        gru_input = self.input_proj(torch.cat([prev_state.stoch, action], dim=-1))
        deter = self.gru(gru_input, prev_state.deter)
        prior = self._get_dist(self.prior_net(deter))
        stoch = prior.rsample()
        return RSSMState(deter, stoch), prior

    def observe_step(self, prev_state, action, embed):
        """Prior transition followed by a posterior update from ``embed``.

        Returns (posterior state, prior distribution, posterior distribution).
        """
        prior_state, prior = self.imagine_step(prev_state, action)
        posterior_input = torch.cat([prior_state.deter, embed], dim=-1)
        posterior = self._get_dist(self.posterior_net(posterior_input))
        return RSSMState(prior_state.deter, posterior.rsample()), prior, posterior

class RewardPredictor(nn.Module):
    """MLP head mapping a latent state vector to one scalar per batch element."""

    def __init__(self, state_dim, hidden_dims):
        super().__init__()
        dims = (state_dim, *hidden_dims)
        # Linear is constructed before its ELU for each consecutive pair of dims.
        hidden = [
            module
            for d_in, d_out in zip(dims, dims[1:])
            for module in (nn.Linear(d_in, d_out), nn.ELU())
        ]
        self.network = nn.Sequential(*hidden, nn.Linear(dims[-1], 1))

    def forward(self, state):
        """Return predictions with the trailing size-1 axis removed."""
        prediction = self.network(state)
        return prediction.squeeze(-1)

class ContinuePredictor(nn.Module):
    """
    MLP head predicting episode continuation from a latent state.

    Outputs raw logits, one per batch element. In this notebook they are scored with
    ``binary_cross_entropy_with_logits``.
    """

    def __init__(self, state_dim, hidden_dims):
        super().__init__()

        def dense_elu(n_in, n_out):
            # Hidden layer followed by its activation.
            return [nn.Linear(n_in, n_out), nn.ELU()]

        stack = []
        n_in = state_dim
        for n_out in hidden_dims:
            stack.extend(dense_elu(n_in, n_out))
            n_in = n_out
        stack.append(nn.Linear(n_in, 1))
        self.network = nn.Sequential(*stack)

    def forward(self, state):
        """Return continuation logits with the trailing size-1 axis removed."""
        logits = self.network(state)
        return logits.squeeze(-1)

class WorldModel(nn.Module):
    """
    RSSM world model: MLP encoder -> RSSM latent dynamics -> decoder, reward and
    continue heads.

    Parameters
    ----------
    obs_dim : int
        Observation size.
    action_dim : int
        Size of the action vector fed to the RSSM at each step.
    stoch_dim, deter_dim : int
        Sizes of the stochastic and deterministic latent parts. The heads consume
        their concatenation (``state_dim = stoch_dim + deter_dim``).
    hidden_dim : int
        Width of the RSSM hidden layers; also the encoder's embedding size.
    encoder_hidden, decoder_hidden, predictor_hidden : sequence of int
        Hidden-layer widths of the respective MLPs. Defaults are tuples so that
        they are immutable and cannot be shared and mutated across instances.
    """

    def __init__(self, obs_dim, action_dim, stoch_dim=32, deter_dim=256, hidden_dim=256,
                 encoder_hidden=(256, 256), decoder_hidden=(256, 256),
                 predictor_hidden=(256, 256)):
        super().__init__()
        self.state_dim = stoch_dim + deter_dim
        self.encoder = MLPEncoder(obs_dim, encoder_hidden, hidden_dim)
        self.rssm = RSSM(stoch_dim, deter_dim, hidden_dim, action_dim, hidden_dim)
        self.reward_predictor = RewardPredictor(self.state_dim, predictor_hidden)
        self.continue_predictor = ContinuePredictor(self.state_dim, predictor_hidden)
        self.decoder = MLPDecoder(self.state_dim, decoder_hidden, obs_dim)

    def initial_state(self, batch_size):
        """Zero RSSM state on the same device as the model's parameters."""
        return self.rssm.initial_state(batch_size, next(self.parameters()).device)

    def observe(self, obs, action, prev_state):
        """Encode ``obs`` and run one posterior RSSM step.

        Returns (state, prior, posterior).
        """
        return self.rssm.observe_step(prev_state, action, self.encoder(obs))

    def predict(self, state):
        """Decode a latent state.

        Returns (reconstructed obs, reward prediction, continue logits).
        """
        combined = state.combined
        return (
            self.decoder(combined),
            self.reward_predictor(combined),
            self.continue_predictor(combined),
        )

    def forward(self, obs_seq, action_seq):
        """
        Filter a batch of sequences through the model.

        Parameters
        ----------
        obs_seq : torch.Tensor
            (B, T, obs_dim) observations.
        action_seq : torch.Tensor
            Indexed as ``action_seq[:, t]``. The action at index t is paired with
            ``obs_seq[:, t]`` in the same observe step.

        Returns
        -------
        dict
            ``recon_obs`` (B, T, obs_dim), ``pred_rewards`` (B, T) and
            ``pred_continues`` (B, T), stacked on the time axis. Also ``priors`` and
            ``posteriors``, which are length-T lists of per-step distributions.
        """
        B, T, _ = obs_seq.shape
        state = self.initial_state(B)
        priors, posteriors, recon_obs, pred_rewards, pred_continues = [], [], [], [], []
        for t in range(T):
            state, prior, posterior = self.observe(obs_seq[:, t], action_seq[:, t], state)
            recon, reward, cont = self.predict(state)
            priors.append(prior)
            posteriors.append(posterior)
            recon_obs.append(recon)
            pred_rewards.append(reward)
            pred_continues.append(cont)
        return {
            'recon_obs': torch.stack(recon_obs, 1),
            'pred_rewards': torch.stack(pred_rewards, 1),
            'pred_continues': torch.stack(pred_continues, 1),
            'priors': priors,
            'posteriors': posteriors,
        }

print("World model components loaded.")

---
## 2. Superposition Concept

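In quantum mechanics, a system in superposition is described by a linear combination of basis states, each weighted by a (complex) amplitude. Here the idea is used only as an analogy.
For experience replay, we implement it as:

1. **Parallel sampling**: For each batch element, sample several sequences from the buffer at once
2. **Amplitude weighting**: Assign each time step a quantum-inspired amplitude (a softmax-normalized importance score)
3. **Interference**: Raise or lower each sample's weight depending on how similar it is to the other parallel samples and how their rewards correlate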

In [None]:
"""
Cell: Superposition Concept Diagram
Purpose: Visualize the superposition-enhanced replay concept
"""

def create_superposition_diagram(figsize=(14, 8)):
    """
    Draw a schematic comparing standard replay with superposition replay.

    Top-left: three trajectories feeding one batch (standard replay).
    Bottom-left: five amplitude-shaded samples combined via "interference" into a
    weighted sum (superposition replay). Right: a short list of benefits.

    Parameters
    ----------
    figsize : tuple
        Passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlim(0, 14)
    ax.set_ylim(0, 8)
    ax.axis('off')

    def add_box(x, y, width, color, alpha):
        # Rounded rectangle of height 0.8 anchored at its lower-left corner.
        ax.add_patch(mpatches.FancyBboxPatch((x, y), width, 0.8,
                                             boxstyle="round,pad=0.02",
                                             facecolor=color, alpha=alpha))

    # --- Standard replay: three trajectories -> one batch ---
    ax.text(3, 7.5, 'Standard Replay', ha='center', fontsize=12, fontweight='bold')
    for k in range(3):
        add_box(1 + k * 1.5, 6, 1.2, COLORS['baseline'], 0.6)
        ax.text(1.6 + k * 1.5, 6.4, f'τ{k + 1}', ha='center', va='center',
                fontsize=10, color='white')
    ax.annotate('', xy=(5.5, 6.4), xytext=(4.8, 6.4),
                arrowprops=dict(arrowstyle='->', lw=2))
    add_box(5.6, 6, 1.2, COLORS['baseline'], 0.9)
    ax.text(6.2, 6.4, 'Batch', ha='center', va='center', fontsize=10,
            color='white', fontweight='bold')

    # --- Superposition replay: amplitude-shaded parallel samples ---
    ax.text(3, 4.5, 'Superposition Replay', ha='center', fontsize=12,
            fontweight='bold', color=COLORS['superposition'])
    for k, amp in enumerate([0.5, 0.3, 0.7, 0.4, 0.6]):
        # Box opacity encodes the sample's amplitude.
        add_box(0.5 + k * 1.3, 3, 1.0, COLORS['superposition'], amp)
        ax.text(1.0 + k * 1.3, 3.4, f'α{k + 1}', ha='center', va='center',
                fontsize=9, color='white')
        ax.text(1.0 + k * 1.3, 2.6, f'{amp}', ha='center', va='center', fontsize=8)

    # Interference arrow into the weighted combination.
    ax.annotate('', xy=(8, 3.4), xytext=(7, 3.4),
                arrowprops=dict(arrowstyle='->', lw=2, color=COLORS['superposition']))
    ax.text(7.5, 3.8, 'Interference', ha='center', fontsize=9, style='italic')

    add_box(8.2, 3, 2, COLORS['superposition'], 0.9)
    ax.text(9.2, 3.4, 'Σ αᵢ|τᵢ⟩', ha='center', va='center', fontsize=11,
            color='white', fontweight='bold')

    # --- Benefits list ---
    ax.text(11.5, 6.5, 'Benefits:', fontsize=11, fontweight='bold')
    benefit_lines = ('• Better exploration', '• Sample efficiency',
                     '• Rare event focus', '• Smooth gradients')
    for row, line in enumerate(benefit_lines):
        ax.text(11.5, 5.8 - row * 0.5, line, fontsize=9)

    ax.set_title('Superposition-Enhanced Experience Replay', fontsize=14, fontweight='bold')
    return fig

# Render the concept diagram inline.
fig = create_superposition_diagram()
plt.tight_layout()
plt.show()

---
## 3. Amplitude-Based Weighting

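Assign quantum-like amplitudes to experiences based on their importance. The importance score combines a TD-error proxy (absolute reward change), reward magnitude, recency, and observation variance, and is normalized with a temperature-scaled softmax.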

In [None]:
"""
Cell: Amplitude Calculator
Purpose: Compute quantum-inspired amplitudes for experiences
"""

class AmplitudeCalculator:
    """
    Calculate quantum-inspired amplitudes for experience weighting.

    Each step of a trajectory gets a score built from four components, each
    scaled to roughly [0, 1]:

    - TD-error proxy: absolute one-step reward change
    - Reward magnitude: |reward|
    - Recency: later steps in the trajectory score higher
    - State "entropy" proxy: variance across the observation vector

    The weighted sum of the components becomes a distribution over steps through a
    temperature-scaled softmax.

    Parameters
    ----------
    td_weight : float
        Weight for TD error component
    reward_weight : float
        Weight for reward magnitude
    recency_weight : float
        Weight for recency (newer = higher)
    entropy_weight : float
        Weight for state entropy
    temperature : float
        Softmax temperature for amplitude normalization (lower = peakier)
    """

    def __init__(
        self,
        td_weight: float = 0.4,
        reward_weight: float = 0.3,
        recency_weight: float = 0.2,
        entropy_weight: float = 0.1,
        temperature: float = 1.0
    ):
        self.td_weight = td_weight
        self.reward_weight = reward_weight
        self.recency_weight = recency_weight
        self.entropy_weight = entropy_weight
        self.temperature = temperature

    def compute_td_priority(
        self,
        rewards: np.ndarray,
        gamma: float = 0.99
    ) -> np.ndarray:
        """
        Compute a TD-error proxy priority in [0, 1].

        Uses the absolute one-step reward change ``|r_t - r_{t-1}|`` (0 for the
        first step), divided by its maximum. ``gamma`` is accepted for interface
        compatibility but is not used by this proxy.
        """
        if len(rewards) == 0:
            return np.zeros(0)
        reward_diff = np.abs(np.diff(rewards, prepend=rewards[0]))
        return reward_diff / (reward_diff.max() + 1e-8)

    def compute_reward_priority(self, rewards: np.ndarray) -> np.ndarray:
        """
        Compute reward magnitude priority in [0, 1].

        Larger absolute rewards get higher priority (|r_t| / max |r|).
        """
        abs_rewards = np.abs(rewards)
        if abs_rewards.size == 0:
            return abs_rewards
        return abs_rewards / (abs_rewards.max() + 1e-8)

    def compute_recency_priority(self, length: int, decay: float = 0.99) -> np.ndarray:
        """
        Compute recency-based priority ``decay ** (length - 1 - t)``.

        The last step gets 1 and earlier steps decay geometrically.
        """
        indices = np.arange(length)
        return decay ** (length - 1 - indices)

    def compute_entropy_priority(self, observations: np.ndarray) -> np.ndarray:
        """
        Compute a state-diversity priority in [0, 1].

        Uses the variance over the last observation axis as a cheap proxy for
        entropy, divided by its maximum.
        """
        obs_var = np.var(observations, axis=-1)
        if obs_var.size == 0:
            return obs_var
        return obs_var / (obs_var.max() + 1e-8)

    def compute_amplitudes(
        self,
        observations: np.ndarray,
        rewards: np.ndarray
    ) -> np.ndarray:
        """
        Compute quantum-inspired amplitudes for a trajectory.

        Parameters
        ----------
        observations : np.ndarray
            Trajectory observations (T, obs_dim)
        rewards : np.ndarray
            Trajectory rewards (T,)

        Returns
        -------
        np.ndarray
            Amplitudes (T,) that sum to 1 (like probability amplitudes squared).
            An empty array is returned for an empty trajectory.
        """
        length = len(rewards)
        if length == 0:
            return np.zeros(0)

        # Compute individual priorities
        td_priority = self.compute_td_priority(rewards)
        reward_priority = self.compute_reward_priority(rewards)
        recency_priority = self.compute_recency_priority(length)
        entropy_priority = self.compute_entropy_priority(observations)

        # Weighted combination
        combined = (
            self.td_weight * td_priority +
            self.reward_weight * reward_priority +
            self.recency_weight * recency_priority +
            self.entropy_weight * entropy_priority
        )

        # Temperature-scaled softmax. Subtracting the max logit leaves the result
        # mathematically unchanged. It prevents exp() overflow (inf / inf -> NaN)
        # when the temperature is small.
        logits = combined / self.temperature
        amplitudes = np.exp(logits - logits.max())
        return amplitudes / amplitudes.sum()


# Test amplitude calculator
# Synthetic 50-step trajectory: 4-D standard-normal observations and a noisy
# two-period sine wave as rewards, giving both positive and negative rewards.
amp_calc = AmplitudeCalculator()

test_obs = np.random.randn(50, 4)
test_rewards = np.sin(np.linspace(0, 4*np.pi, 50)) + np.random.randn(50) * 0.1

# test_rewards and amplitudes are reused by the visualisation cell below.
amplitudes = amp_calc.compute_amplitudes(test_obs, test_rewards)

# compute_amplitudes softmax-normalises, so the sum should print as 1.0000.
print(f"Amplitudes shape: {amplitudes.shape}")
print(f"Amplitudes sum: {amplitudes.sum():.4f}")
print(f"Max amplitude: {amplitudes.max():.4f}")
print(f"Min amplitude: {amplitudes.min():.4f}")

In [None]:
"""
Cell: Visualize Amplitudes
Purpose: Plot amplitude distribution for a trajectory
"""

# 2x2 diagnostic view of the amplitudes computed in the previous cell
# (uses test_rewards and amplitudes from there).
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Rewards
ax = axes[0, 0]
ax.plot(test_rewards, color=COLORS['baseline'], linewidth=2)
ax.set_xlabel('Step')
ax.set_ylabel('Reward')
ax.set_title('Trajectory Rewards')
ax.grid(True, alpha=0.3)

# Amplitudes
ax = axes[0, 1]
ax.bar(range(len(amplitudes)), amplitudes, color=COLORS['superposition'], alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('Amplitude')
ax.set_title('Quantum-Inspired Amplitudes')
ax.grid(True, alpha=0.3)

# Amplitude vs Reward
# Points are coloured by step index so their position in the trajectory is visible.
ax = axes[1, 0]
ax.scatter(test_rewards, amplitudes, c=range(len(amplitudes)), cmap='viridis', alpha=0.7)
ax.set_xlabel('Reward')
ax.set_ylabel('Amplitude')
ax.set_title('Amplitude vs Reward')
ax.grid(True, alpha=0.3)

# Cumulative amplitude
# The amplitudes sum to 1, so this is the CDF of step sampling; the dashed line
# marks the median.
ax = axes[1, 1]
ax.plot(np.cumsum(amplitudes), color=COLORS['superposition'], linewidth=2)
ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Step')
ax.set_ylabel('Cumulative Amplitude')
ax.set_title('Cumulative Amplitude (Sampling CDF)')
ax.grid(True, alpha=0.3)

fig.suptitle('Amplitude-Based Experience Weighting', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---
## 4. Interference Effects

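Implement constructive and destructive interference between parallel samples. Similar samples whose rewards are positively correlated reinforce each other, while similar samples whose rewards are negatively correlated damp each other.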

In [None]:
"""
Cell: Interference Calculator
Purpose: Implement quantum interference-like effects
"""

class InterferenceCalculator:
    """
    Calculate interference effects between parallel samples.

    Implements:
    - Constructive interference: similar samples with positively correlated
      rewards boost each other's amplitude.
    - Destructive interference: similar samples with negatively correlated
      rewards damp each other's amplitude.

    The interfered amplitudes are then used to mix the parallel samples into one
    sequence. At every timestep the mixing weights over samples sum to 1 (a convex
    combination). Mixed observations, actions, rewards and continue flags therefore
    stay on the same scale as the inputs.

    Parameters
    ----------
    interference_strength : float
        Strength of interference effects (0 = none, 1 = full)
    similarity_threshold : float
        Cosine-similarity threshold above which two samples interfere
    """

    def __init__(
        self,
        interference_strength: float = 0.3,
        similarity_threshold: float = 0.7
    ):
        self.interference_strength = interference_strength
        self.similarity_threshold = similarity_threshold

    def compute_similarity(
        self,
        obs1: np.ndarray,
        obs2: np.ndarray
    ) -> float:
        """
        Cosine similarity of the flattened inputs.

        Returns 0.0 if either input has (near-)zero norm.
        """
        obs1_flat = obs1.flatten()
        obs2_flat = obs2.flatten()
        norm1 = np.linalg.norm(obs1_flat)
        norm2 = np.linalg.norm(obs2_flat)
        if norm1 < 1e-8 or norm2 < 1e-8:
            return 0.0
        return np.dot(obs1_flat, obs2_flat) / (norm1 * norm2)

    def compute_phase(
        self,
        rewards: np.ndarray,
        actions: np.ndarray
    ) -> np.ndarray:
        """
        Compute phase angles for samples from the action-reward relationship.

        ``phase = pi * (1 + action_signal * sign(reward)) / 2``. Here
        ``action_signal`` is the mean over the last action axis when actions have
        more than one dimension. For ``|action_signal| <= 1`` the phase lies in
        [0, pi]. Not used by :meth:`apply_interference`.
        """
        action_signal = np.mean(actions, axis=-1) if actions.ndim > 1 else actions
        reward_sign = np.sign(rewards)
        return np.pi * (1 + action_signal * reward_sign) / 2

    def apply_interference(
        self,
        samples: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]],
        amplitudes: List[np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Apply interference effects to parallel samples and mix them.

        Parameters
        ----------
        samples : List[Tuple]
            List of (obs (T, obs_dim), actions (T, act_dim), rewards (T,),
            continues (T,)) samples, all of the same length T.
        amplitudes : List[np.ndarray]
            Per-step amplitudes (T,) for each sample.

        Returns
        -------
        Tuple[np.ndarray, ...]
            The mixed (obs, actions, rewards, continues) and per-step loss
            ``weights`` (T,). Each mixed quantity is a per-timestep convex
            combination of the samples. ``weights`` is proportional to the
            interfered amplitude mass at each step and rescaled to mean 1.

        Raises
        ------
        ValueError
            If ``samples`` is empty or its length differs from ``amplitudes``.
        """
        num_samples = len(samples)
        if num_samples == 0:
            raise ValueError("apply_interference requires at least one sample")
        if len(amplitudes) != num_samples:
            raise ValueError(
                f"Got {num_samples} samples but {len(amplitudes)} amplitude arrays"
            )

        # Pairwise interference factors; the diagonal (self-interaction) stays 1.
        interference_matrix = np.ones((num_samples, num_samples))
        for i in range(num_samples):
            obs_i, _, rewards_i, _ = samples[i]
            for j in range(i + 1, num_samples):
                obs_j, _, rewards_j, _ = samples[j]
                sim = self.compute_similarity(obs_i.mean(axis=0), obs_j.mean(axis=0))
                if sim <= self.similarity_threshold:
                    continue  # dissimilar samples do not interfere

                # Constant reward sequences give NaN correlation, treated as 0.
                with np.errstate(invalid='ignore', divide='ignore'):
                    reward_corr = np.corrcoef(rewards_i, rewards_j)[0, 1]
                if np.isnan(reward_corr):
                    reward_corr = 0.0

                if reward_corr > 0:
                    # Constructive: similar states, rewards move together.
                    interference = 1.0 + self.interference_strength * sim
                else:
                    # Destructive: similar states, rewards move apart.
                    interference = 1.0 - self.interference_strength * sim * abs(reward_corr)
                interference_matrix[i, j] = interference
                interference_matrix[j, i] = interference

        # Scale each sample's amplitudes by its mean interference with all
        # samples (itself included, with factor 1).
        factors = interference_matrix.mean(axis=1)
        modified = np.stack([
            np.asarray(amp, dtype=np.float64) * factor
            for amp, factor in zip(amplitudes, factors)
        ])  # (num_samples, T)

        # Normalise per timestep so the mixing weights over samples sum to 1.
        # Normalising over all samples *and* timesteps at once would shrink every
        # mixed quantity by roughly a factor of T.
        step_mass = modified.sum(axis=0)  # (T,)
        mix = np.divide(
            modified, step_mass,
            out=np.full_like(modified, 1.0 / num_samples),
            where=step_mass > 0,
        )

        obs_combined = sum(w[:, None] * s[0] for w, s in zip(mix, samples))
        act_combined = sum(w[:, None] * s[1] for w, s in zip(mix, samples))
        rew_combined = sum(w * s[2] for w, s in zip(mix, samples))
        cont_combined = sum(w * s[3] for w, s in zip(mix, samples))

        # Per-timestep loss weights: the interfered amplitude mass at each step,
        # rescaled to mean 1. This keeps relative emphasis without shrinking the
        # weighted loss terms relative to unweighted ones.
        total_mass = step_mass.sum()
        if total_mass > 0:
            final_weights = step_mass * (step_mass.size / total_mass)
        else:
            final_weights = np.ones_like(step_mass)

        return obs_combined, act_combined, rew_combined, cont_combined, final_weights


print("Interference calculator defined.")

---
## 5. Superposition Replay Buffer

In [None]:
"""
Cell: Superposition Replay Buffer
Purpose: Implement the complete superposition-enhanced replay buffer
"""

@dataclass
class Episode:
    """
    A single recorded episode held as aligned per-step arrays.

    ``amplitudes`` defaults to ``None``. The replay buffer in this notebook assigns
    it when the episode is added.
    """

    observations: np.ndarray  # one entry per step
    actions: np.ndarray
    rewards: np.ndarray
    dones: np.ndarray
    amplitudes: Optional[np.ndarray] = None

    def __len__(self) -> int:
        """Episode length, taken as the number of stored observations."""
        return len(self.observations)


class SuperpositionReplayBuffer:
    """
    Quantum superposition-inspired experience replay buffer.

    Features:
    - Amplitude-based sampling (prioritized by quantum-inspired weights)
    - Parallel trajectory sampling (superposition of experiences)
    - Interference effects between similar experiences

    Episodes are stored whole. The oldest episode is evicted once ``capacity`` is
    reached. Per-step amplitudes are computed once, when an episode is added.

    Parameters
    ----------
    capacity : int
        Maximum number of episodes to store (>= 1)
    parallel_samples : int
        Number of parallel trajectories to sample (superposition width, >= 1)
    interference_strength : float
        Strength of interference effects
    amplitude_temperature : float
        Temperature for amplitude calculation
    """

    def __init__(
        self,
        capacity: int = 1000,
        parallel_samples: int = 4,
        interference_strength: float = 0.2,
        amplitude_temperature: float = 1.0
    ):
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        if parallel_samples < 1:
            raise ValueError(f"parallel_samples must be >= 1, got {parallel_samples}")
        self.capacity = capacity
        self.parallel_samples = parallel_samples

        self.episodes: List[Episode] = []
        self.total_steps = 0

        self.amplitude_calc = AmplitudeCalculator(
            temperature=amplitude_temperature
        )
        self.interference_calc = InterferenceCalculator(
            interference_strength=interference_strength
        )

    def add_episode(self, episode: Episode) -> None:
        """
        Compute and attach per-step amplitudes, then store the episode.

        The passed ``episode`` is modified in place (``episode.amplitudes`` is
        set). When the buffer is full, the oldest episode is evicted first.
        """
        episode.amplitudes = self.amplitude_calc.compute_amplitudes(
            episode.observations, episode.rewards
        )

        if len(self.episodes) >= self.capacity:
            removed = self.episodes.pop(0)
            self.total_steps -= len(removed)

        self.episodes.append(episode)
        self.total_steps += len(episode)

    def _valid_episodes(self, seq_len: int) -> List[Episode]:
        """Return stored episodes with at least ``seq_len`` steps; raise if none."""
        valid = [ep for ep in self.episodes if len(ep) >= seq_len]
        if not valid:
            raise ValueError(
                f"No stored episode has at least seq_len={seq_len} steps "
                f"({len(self.episodes)} episodes in buffer)"
            )
        return valid

    def _episode_distribution(self, seq_len: int) -> Tuple[List[Episode], np.ndarray]:
        """
        Return (valid_episodes, selection probabilities) for ``seq_len``.

        Each valid episode is weighted by the sum of its per-step amplitudes.
        NOTE: AmplitudeCalculator.compute_amplitudes normalises each episode's
        amplitudes to sum to 1. With that calculator, this choice is effectively
        uniform over the valid episodes.
        """
        valid = self._valid_episodes(seq_len)
        episode_weights = np.array([ep.amplitudes.sum() for ep in valid])
        return valid, episode_weights / episode_weights.sum()

    def _sample_single_sequence(
        self,
        seq_len: int,
        distribution: Optional[Tuple[List[Episode], np.ndarray]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Sample one sequence with amplitude-based prioritization.

        ``distribution`` may be a precomputed ``_episode_distribution(seq_len)``
        result, which avoids recomputing it for every sample.

        Returns (obs, actions, rewards, continues, amplitudes) slices of length
        ``seq_len``. ``continues`` is ``1 - dones``.
        """
        if distribution is None:
            distribution = self._episode_distribution(seq_len)
        valid_episodes, episode_probs = distribution

        ep_idx = np.random.choice(len(valid_episodes), p=episode_probs)
        episode = valid_episodes[ep_idx]

        # Start index weighted by the amplitude at that step.
        max_start = len(episode) - seq_len
        start_weights = episode.amplitudes[:max_start + 1]
        start = np.random.choice(max_start + 1, p=start_weights / start_weights.sum())
        end = start + seq_len

        return (
            episode.observations[start:end],
            episode.actions[start:end],
            episode.rewards[start:end],
            1.0 - episode.dones[start:end],
            episode.amplitudes[start:end],
        )

    @staticmethod
    def _mix_samples(samples, amplitudes):
        """
        Amplitude-weighted, per-timestep convex combination of parallel samples.

        At each step the sample weights are ``a_k[t] / sum_j a_j[t]``, which sum to
        1, so mixed quantities keep the inputs' scale. Returns (obs, actions,
        rewards, continues, weights). ``weights`` is the per-step amplitude mass
        rescaled to mean 1.
        """
        amps = np.stack([np.asarray(a, dtype=np.float64) for a in amplitudes])  # (K, T)
        step_mass = amps.sum(axis=0)  # (T,)
        mix = np.divide(
            amps, step_mass,
            out=np.full_like(amps, 1.0 / len(amps)),
            where=step_mass > 0,
        )

        obs = sum(w[:, None] * s[0] for w, s in zip(mix, samples))
        act = sum(w[:, None] * s[1] for w, s in zip(mix, samples))
        rew = sum(w * s[2] for w, s in zip(mix, samples))
        cont = sum(w * s[3] for w, s in zip(mix, samples))

        total_mass = step_mass.sum()
        if total_mass > 0:
            weights = step_mass * (step_mass.size / total_mass)
        else:
            weights = np.ones_like(step_mass)
        return obs, act, rew, cont, weights

    def sample_superposition(
        self,
        batch_size: int,
        seq_len: int,
        apply_interference: bool = True
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Sample a batch with superposition-enhanced replay.

        For each batch element, ``parallel_samples`` sequences are drawn and mixed
        into one. The mix goes through ``interference_calc`` when
        ``apply_interference`` is True and more than one sequence is drawn;
        otherwise it is a plain per-timestep amplitude-weighted average.

        Parameters
        ----------
        batch_size : int
            Number of sequences in batch
        seq_len : int
            Length of each sequence
        apply_interference : bool
            Whether to apply interference effects

        Returns
        -------
        Tuple[np.ndarray, ...]
            (observations, actions, rewards, continues, weights), each stacked
            along a leading batch axis.

        Raises
        ------
        ValueError
            If no stored episode has at least ``seq_len`` steps.
        """
        # Hoisted: the eligible-episode distribution is the same for every draw.
        distribution = self._episode_distribution(seq_len)
        use_interference = apply_interference and self.parallel_samples > 1

        columns = ([], [], [], [], [])  # obs, act, rew, cont, weights
        for _ in range(batch_size):
            parallel_samples = []
            parallel_amplitudes = []
            for _ in range(self.parallel_samples):
                obs, act, rew, cont, amp = self._sample_single_sequence(seq_len, distribution)
                parallel_samples.append((obs, act, rew, cont))
                parallel_amplitudes.append(amp)

            if use_interference:
                mixed = self.interference_calc.apply_interference(
                    parallel_samples, parallel_amplitudes
                )
            else:
                mixed = self._mix_samples(parallel_samples, parallel_amplitudes)

            for column, value in zip(columns, mixed):
                column.append(value)

        return tuple(np.stack(column) for column in columns)

    def sample_standard(
        self,
        batch_size: int,
        seq_len: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Standard uniform sampling (no superposition) for comparison.

        Returns (observations, actions, rewards, continues) stacked on a batch
        axis. Raises ValueError if no stored episode has ``seq_len`` steps.
        """
        valid = self._valid_episodes(seq_len)
        obs_b, act_b, rew_b, cont_b = [], [], [], []

        for _ in range(batch_size):
            ep = valid[np.random.randint(len(valid))]
            start = np.random.randint(0, len(ep) - seq_len + 1)
            obs_b.append(ep.observations[start:start+seq_len])
            act_b.append(ep.actions[start:start+seq_len])
            rew_b.append(ep.rewards[start:start+seq_len])
            cont_b.append(1.0 - ep.dones[start:start+seq_len])

        return np.stack(obs_b), np.stack(act_b), np.stack(rew_b), np.stack(cont_b)

    def __len__(self) -> int:
        """Number of stored episodes (not steps; see ``total_steps``)."""
        return len(self.episodes)


print("Superposition replay buffer defined.")

In [None]:
"""
Cell: Data Collection
Purpose: Collect episodes for testing
"""

def collect_episodes(env_name: str, num_episodes: int, seed: int = 42):
    """
    Collect episodes from a Gymnasium environment with a uniformly random policy.

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    num_episodes : int
        Number of episodes to roll out.
    seed : int
        Episode ``i`` is reset with ``seed + i``. The action space is also seeded
        with ``seed`` so that the random actions are reproducible.

    Returns
    -------
    List[Episode]
        One Episode per rollout, with float32 arrays. Discrete actions are stored
        one-hot. ``dones`` marks true termination only: a time-limit truncation is
        stored as 0, because the replay buffer turns ``dones`` into the continue
        target ``1 - dones``.
    """
    env = gym.make(env_name)
    episodes = []
    try:
        # env.reset(seed=...) seeds the environment but not action_space.sample(),
        # which has its own RNG.
        env.action_space.seed(seed)
        is_discrete = isinstance(env.action_space, gym.spaces.Discrete)

        for ep_idx in tqdm(range(num_episodes), desc="Collecting"):
            obs_l, act_l, rew_l, done_l = [], [], [], []
            obs, _ = env.reset(seed=seed + ep_idx)
            done = False
            while not done:
                action = env.action_space.sample()
                obs_l.append(obs)
                if is_discrete:
                    act_oh = np.zeros(env.action_space.n)
                    act_oh[action] = 1.0
                    act_l.append(act_oh)
                else:
                    act_l.append(action)
                obs, reward, terminated, truncated, _ = env.step(action)
                # Stop on either termination or truncation...
                done = terminated or truncated
                rew_l.append(reward)
                # ...but only record true termination as a terminal step.
                done_l.append(float(terminated))
            episodes.append(Episode(
                np.array(obs_l, np.float32), np.array(act_l, np.float32),
                np.array(rew_l, np.float32), np.array(done_l, np.float32)
            ))
    finally:
        # Release the environment even if a rollout raises.
        env.close()
    return episodes

# Collect and add to buffer
# 100 CartPole episodes rolled out with random actions (see collect_episodes).
print("Collecting training data...")
episodes = collect_episodes("CartPole-v1", num_episodes=100, seed=SEED)

# Create superposition buffer
super_buffer = SuperpositionReplayBuffer(
    capacity=1000,
    parallel_samples=4,
    interference_strength=0.2,
    amplitude_temperature=1.0
)

# add_episode also computes and attaches each episode's amplitudes.
for ep in episodes:
    super_buffer.add_episode(ep)

print(f"Superposition buffer: {len(super_buffer)} episodes, {super_buffer.total_steps} steps")

# Test sampling
# Only episodes with at least seq_len=20 steps are eligible for sampling.
# weights holds one loss weight per (batch element, step).
obs, act, rew, cont, weights = super_buffer.sample_superposition(batch_size=8, seq_len=20)
print(f"Sample shapes: obs={obs.shape}, weights={weights.shape}")

---
## 6. Integration with World Model

In [None]:
"""
Cell: Loss Functions
Purpose: World model loss computation
"""

def compute_world_model_loss(output, obs_seq, reward_seq, continue_seq,
                             kl_weight=1.0, free_nats=1.0, sample_weights=None):
    """
    Compute the world-model training loss.

    Parameters
    ----------
    output : dict
        Model outputs with ``recon_obs`` (B, T, obs_dim), ``pred_rewards`` (B, T),
        ``pred_continues`` (B, T) logits, and per-step lists ``priors`` /
        ``posteriors`` of distributions (as produced by ``WorldModel.forward``).
    obs_seq, reward_seq, continue_seq : torch.Tensor
        Targets matching the shapes above. ``continue_seq`` may hold soft targets
        in [0, 1].
    kl_weight : float
        Multiplier on the KL term.
    free_nats : float
        Per-sample KL values below this are clamped up to it (free bits).
    sample_weights : array-like or torch.Tensor, optional
        Either per-element weights of shape (B, T), or per-sequence weights of
        shape (B,) broadcast over time. They apply to the reconstruction, reward
        and continue terms only; the KL term is unweighted. Any other shape is
        ignored and unweighted means are used.

    Returns
    -------
    dict
        Scalar tensors ``total``, ``kl``, ``recon``, ``reward``, ``continue``.
    """
    # Per-(batch, time) losses; reconstruction error is averaged over features.
    recon_loss = F.mse_loss(output['recon_obs'], obs_seq, reduction='none').mean(-1)
    reward_loss = F.mse_loss(output['pred_rewards'], reward_seq, reduction='none')
    continue_loss = F.binary_cross_entropy_with_logits(
        output['pred_continues'], continue_seq, reduction='none'
    )

    # Resolve sample weights (superposition weighting) to a (B, T)-broadcastable tensor.
    weights = None
    if sample_weights is not None:
        # as_tensor accepts numpy arrays and tensors without a copy warning.
        weights = torch.as_tensor(sample_weights, dtype=torch.float32, device=obs_seq.device)
        if weights.dim() == 1 and weights.shape[0] == reward_loss.shape[0]:
            weights = weights.unsqueeze(-1)  # (B,) -> (B, 1): one weight per sequence
        elif weights.dim() != 2:
            weights = None  # unsupported shape: fall back to unweighted means

    if weights is not None:
        recon_loss = (recon_loss * weights).mean()
        reward_loss = (reward_loss * weights).mean()
        continue_loss = (continue_loss * weights).mean()
    else:
        recon_loss = recon_loss.mean()
        reward_loss = reward_loss.mean()
        continue_loss = continue_loss.mean()

    # KL(posterior || prior) per step, summed over latent dims, with free nats.
    kl_losses = []
    for prior, posterior in zip(output['priors'], output['posteriors']):
        kl = torch.distributions.kl_divergence(posterior, prior).sum(-1)
        kl = torch.clamp(kl, min=free_nats).mean()
        kl_losses.append(kl)
    kl_loss = torch.stack(kl_losses).mean()

    total = kl_weight * kl_loss + recon_loss + reward_loss + continue_loss
    return {'total': total, 'kl': kl_loss, 'recon': recon_loss,
            'reward': reward_loss, 'continue': continue_loss}

In [None]:
"""
Cell: Superposition Trainer
Purpose: Training loop with superposition replay
"""

class SuperpositionTrainer:
    """Optimise a world model on batches drawn from a replay buffer.

    One trainer runs both arms of the experiment. With
    ``use_superposition=True``, batches come from
    ``buffer.sample_superposition`` and its weights are passed to
    ``compute_world_model_loss``. Otherwise ``buffer.sample_standard`` is used
    and the loss is unweighted.

    Args:
        model: World model, called as ``model(obs, act)``.
        buffer: Replay buffer providing ``sample_superposition(batch_size,
            seq_len)`` -> (obs, act, rew, cont, weights) and
            ``sample_standard(batch_size, seq_len)`` -> (obs, act, rew, cont).
        lr: Adam learning rate.
        device: Device that sampled batches are moved to.
    """

    def __init__(self, model, buffer, lr=3e-4, device=DEVICE):
        self.model = model
        self.buffer = buffer
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        # NOTE(review): the logger is named "superposition" even when this
        # trainer runs the standard-replay baseline.
        self.logger = MetricLogger(name="superposition")

    def _to_tensor(self, batch):
        """Return ``batch`` as a float32 tensor on ``self.device``."""
        # torch.as_tensor skips the extra copy (and the copy-construct
        # warning) that torch.tensor makes when the buffer already yields tensors.
        return torch.as_tensor(batch, dtype=torch.float32, device=self.device)

    def train_step(self, batch_size=32, seq_len=20, kl_weight=1.0, use_superposition=True):
        """Run one optimisation step.

        Args:
            batch_size: Number of sequences to sample.
            seq_len: Length of each sampled sequence.
            kl_weight: Forwarded to ``compute_world_model_loss``.
            use_superposition: Sample with superposition weights if True,
                otherwise sample uniformly without weights.

        Returns:
            Dict of loss floats with keys 'total', 'kl', 'recon', 'reward',
            'continue'. The same values are also sent to ``self.logger``.
        """
        self.model.train()

        if use_superposition:
            obs, act, rew, cont, weights = self.buffer.sample_superposition(batch_size, seq_len)
        else:
            obs, act, rew, cont = self.buffer.sample_standard(batch_size, seq_len)
            weights = None

        obs = self._to_tensor(obs)
        act = self._to_tensor(act)
        rew = self._to_tensor(rew)
        cont = self._to_tensor(cont)

        self.optimizer.zero_grad()
        output = self.model(obs, act)
        losses = compute_world_model_loss(output, obs, rew, cont, kl_weight, sample_weights=weights)
        losses['total'].backward()
        # Guard against occasional exploding gradients.
        nn.utils.clip_grad_norm_(self.model.parameters(), 100.0)
        self.optimizer.step()

        result = {k: v.item() for k, v in losses.items()}
        self.logger.log(**result)
        return result

    def train(self, num_steps=5000, batch_size=32, seq_len=20, log_every=100,
              kl_weight=1.0, use_superposition=True):
        """Run ``num_steps`` training steps with a tqdm progress bar.

        Every ``log_every`` steps, the bar's postfix shows the mean of the
        last 100 logged 'total' and 'recon' values.

        Returns:
            ``self.logger.to_dataframe()``. The logger persists across calls,
            so repeated calls presumably accumulate history (depends on
            MetricLogger; verify).
        """
        desc = "Superposition Training" if use_superposition else "Standard Training"
        pbar = tqdm(range(num_steps), desc=desc)
        for step in pbar:
            self.train_step(batch_size, seq_len, kl_weight, use_superposition)
            if step % log_every == 0:
                pbar.set_postfix({
                    'total': f"{self.logger.get_mean('total', 100):.4f}",
                    'recon': f"{self.logger.get_mean('recon', 100):.4f}"
                })
        return self.logger.to_dataframe()

print("Superposition trainer defined.")

---
## 7. Experiments

In [None]:
"""
Cell: Train with Superposition Replay
Purpose: Train world model with superposition-enhanced replay
"""

# Build the superposition-arm model right after reseeding; the baseline cell
# reseeds the same way before building its model.
set_seed(SEED)
super_model = WorldModel(
    obs_dim=4,
    action_dim=2,
    stoch_dim=32,
    deter_dim=128,
    hidden_dim=128,
    encoder_hidden=[128, 128],
    decoder_hidden=[128, 128],
    predictor_hidden=[128, 128],
).to(DEVICE)

print(f"Model Parameters: {sum(param.numel() for param in super_model.parameters()):,}")

# Train with superposition-weighted sampling.
super_trainer = SuperpositionTrainer(super_model, super_buffer, lr=3e-4, device=DEVICE)

print("\nTraining with superposition replay...")
super_history = super_trainer.train(
    num_steps=5000,
    batch_size=32,
    seq_len=20,
    log_every=100,
    kl_weight=1.0,
    use_superposition=True,
)

print("\nSuperposition Training complete!")
print("  Final loss: {:.4f}".format(super_history['total'].tail(100).mean()))

In [None]:
"""
Cell: Train Standard Baseline
Purpose: Train with standard replay for comparison
"""

# Baseline arm: identical architecture and seed, uniform (unweighted) sampling
# from the same buffer.
set_seed(SEED)
standard_model = WorldModel(
    obs_dim=4,
    action_dim=2,
    stoch_dim=32,
    deter_dim=128,
    hidden_dim=128,
    encoder_hidden=[128, 128],
    decoder_hidden=[128, 128],
    predictor_hidden=[128, 128],
).to(DEVICE)

standard_trainer = SuperpositionTrainer(standard_model, super_buffer, lr=3e-4, device=DEVICE)

print("\nTraining with standard replay...")
standard_history = standard_trainer.train(
    num_steps=5000,
    batch_size=32,
    seq_len=20,
    log_every=100,
    kl_weight=1.0,
    use_superposition=False,
)

print("\nStandard Training complete!")
print("  Final loss: {:.4f}".format(standard_history['total'].tail(100).mean()))

---
## 8. Comparison with Standard Replay

In [None]:
"""
Cell: Statistical Comparison
Purpose: Compare superposition vs standard replay
"""

def _sample_std(values):
    """Return the sample standard deviation (ddof=1), or 0.0 for fewer than two values."""
    return float(values.std(ddof=1)) if values.size > 1 else 0.0


def compare_methods(hist1, hist2, name1, name2, window=100, metric='total', alpha=0.05):
    """Compare the tails of two training-loss histories.

    Takes the last ``window`` values of column ``metric`` from each history and
    runs a two-sided Mann-Whitney U test. It also computes Cohen's d using the
    pooled *sample* standard deviation (ddof=1), which handles unequal lengths.

    NOTE(review): consecutive minibatch losses from one run are autocorrelated,
    not independent samples, so the p-value is only a rough indicator.
    Independent seeds/runs would be the proper unit of replication. Also,
    ``compute_world_model_loss`` multiplies the recon/reward/continue terms by
    the superposition weights, so weighted and unweighted losses need not be
    on the same scale.

    Args:
        hist1, hist2: DataFrames containing a ``metric`` column.
        name1, name2: Labels copied into the result.
        window: Number of trailing rows to compare; must be positive.
        metric: Column to compare (default ``'total'``).
        alpha: Significance threshold for the ``'significant'`` flag.

    Returns:
        Dict with means, sample stds, p-value, Cohen's d and significance flag.

    Raises:
        ValueError: If ``window`` is not positive (``iloc[-0:]`` would
            otherwise silently select the entire history), or (from SciPy) if
            either tail is empty.
    """
    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")

    final1 = hist1[metric].iloc[-window:].to_numpy(dtype=float)
    final2 = hist2[metric].iloc[-window:].to_numpy(dtype=float)

    _, p_value = stats.mannwhitneyu(final1, final2, alternative='two-sided')

    std1, std2 = _sample_std(final1), _sample_std(final2)
    n1, n2 = final1.size, final2.size
    dof = n1 + n2 - 2
    pooled_std = (
        np.sqrt(((n1 - 1) * std1 ** 2 + (n2 - 1) * std2 ** 2) / dof) if dof > 0 else 0.0
    )
    mean1, mean2 = final1.mean(), final2.mean()
    cohens_d = (mean1 - mean2) / pooled_std if pooled_std > 0 else 0.0

    return {
        'method1': name1, 'method2': name2,
        'mean_loss_1': mean1, 'mean_loss_2': mean2,
        'std_loss_1': std1, 'std_loss_2': std2,
        'p_value': p_value, 'cohens_d': cohens_d,
        'significant': bool(p_value < alpha)
    }

comparison = compare_methods(super_history, standard_history, "Superposition", "Standard")

# Emit the whole report with one print; joining on "\n" yields exactly the
# same stdout as printing each line separately.
print("\n".join([
    "\n" + "=" * 60,
    "COMPARISON: Superposition vs Standard Replay",
    "=" * 60,
    "\nFinal Loss:",
    f"  Superposition: {comparison['mean_loss_1']:.4f} +/- {comparison['std_loss_1']:.4f}",
    f"  Standard:      {comparison['mean_loss_2']:.4f} +/- {comparison['std_loss_2']:.4f}",
    "\nStatistics:",
    f"  p-value: {comparison['p_value']:.4f}",
    f"  Cohen's d: {comparison['cohens_d']:.3f}",
    f"  Significant: {comparison['significant']}",
]))

---
## 9. Visualizations

In [None]:
"""
Cell: Comparison Plots
Purpose: Visualize superposition vs standard
"""

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
window = 50

# Panels (0,0), (0,1) and (1,0) share one layout: rolling-mean curves of a
# single loss column for both replay methods.
curve_panels = (
    (axes[0, 0], 'total', 'Total Loss'),
    (axes[0, 1], 'recon', 'Reconstruction Loss'),
    (axes[1, 0], 'kl', 'KL Loss'),
)
for ax, column, title in curve_panels:
    ax.plot(super_history[column].rolling(window).mean(),
            color=COLORS['superposition'], label='Superposition', linewidth=2)
    ax.plot(standard_history[column].rolling(window).mean(),
            color=COLORS['baseline'], label='Standard', linewidth=2)
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Panel (1,1): histogram of the last 200 total-loss values per method.
ax = axes[1, 1]
for history, color_key, label in (
    (super_history, 'superposition', 'Superposition'),
    (standard_history, 'baseline', 'Standard'),
):
    ax.hist(history['total'].iloc[-200:], bins=30, alpha=0.6,
            color=COLORS[color_key], label=label)
ax.set_xlabel('Loss')
ax.set_ylabel('Frequency')
ax.set_title('Final Loss Distribution')
ax.legend()
ax.grid(True, alpha=0.3)

fig.suptitle('Superposition vs Standard Replay Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---
## 10. Summary

In [None]:
"""
Cell: Save and Summary
Purpose: Save results and display summary
"""

# Save results
results_dir = PROJECT_ROOT / "experiments" / "results" / "superposition"
results_dir.mkdir(parents=True, exist_ok=True)

# Single source for the values saved with the checkpoint and printed below.
# NOTE(review): these are hard-coded and must be kept in sync with how
# super_buffer was constructed (not visible in this cell).
summary_config = {'parallel_samples': 4, 'interference_strength': 0.2}

torch.save({'model_state_dict': super_model.state_dict(), 'config': summary_config},
           results_dir / "cartpole_superposition.pt")
super_history.to_csv(results_dir / "superposition_history.csv", index=False)
# Persist the baseline run and the statistical comparison too, so the
# reported numbers can be reproduced or re-plotted without retraining.
standard_history.to_csv(results_dir / "standard_history.csv", index=False)
pd.DataFrame([comparison]).to_csv(results_dir / "comparison.csv", index=False)

print("="*60)
print("PHASE 4 COMPLETE: SUPERPOSITION-ENHANCED REPLAY")
print("="*60)
print("\n[1] Components Implemented")
print("    - AmplitudeCalculator: Quantum-inspired experience weighting")
print("    - InterferenceCalculator: Constructive/destructive interference")
print("    - SuperpositionReplayBuffer: Full quantum-inspired replay")
print("\n[2] Configuration")
print(f"    - Parallel samples: {summary_config['parallel_samples']}")
print(f"    - Interference strength: {summary_config['interference_strength']}")
print("\n[3] Results")
print(f"    - Superposition loss: {comparison['mean_loss_1']:.4f}")
print(f"    - Standard loss: {comparison['mean_loss_2']:.4f}")
print(f"    - p-value: {comparison['p_value']:.4f}")
print("\n" + "="*60)
print("NEXT: Phase 5 - Gate-Enhanced Neural Layers")
print("="*60)