In [None]:
# ============================================================
# Environment Setup (with kernel caching)
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

exec(open("/content/drive/MyDrive/MambaCompression/setup_colab.py").read())
!pip install compressai --quiet 2>/dev/null

Mounted at /content/drive
=== 1. Core Dependencies ===

=== 2. VMamba CUDA Kernel (ss2d) ===
Current GPU: Tesla T4 (sm_75)
Cache arch matches current GPU (sm_75) ✓
Cache found! Restoring 1 kernel files...
  Restored: selective_scan_cuda_oflex.cpython-312-x86_64-linux-gnu.so -> /usr/local/lib/python3.12/dist-packages/selective_scan_cuda_oflex.cpython-312-x86_64-linux-gnu.so
selective_scan_cuda_oflex imported OK (sm_75)

=== Setup Complete ===
Project: /content/drive/MyDrive/MambaCompression
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m444.5/444.5 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m80.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K  

In [None]:
# ============================================================
# Cell 2: Full RP-MPQ Analysis (INT8 + Quantization)
#
# Runs train_ae.py on the encoded_dim=512 baseline checkpoint with
# --epochs 0 and --analyze_all, using asymmetric quantization
# (--quant_type asym), --pq INT8, --aq 8 and --act_quant 16.
# NOTE(review): --epochs 0 presumably means "no training, analysis only",
#   in which case --learning-rate is unused here — confirm in train_ae.py.
# NOTE(review): this is the only train_ae.py cell that passes --quant_type
#   and the only one with --epochs 0 that omits --test-batch-size, so the
#   script defaults apply — confirm that is intended.
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_L2_dim512_baseline/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 512 \
  --train_path data/DATA_Htrainout.mat \
  --test_path data/DATA_Htestout.mat \
  --epochs 0 \
  --learning-rate 1e-3 \
  --decoder_layers 2 \
  --encoder_layers 2 \
  --batch-size 200 \
  --num-workers 4 \
  --quant_type asym \
  --pq INT8 \
  --aq 8 \
  --act_quant 16 \
  --analyze_all

In [None]:
# ============================================================
# Cell 3: FP32 Baseline Inference (No Quantization)
# Expected NMSE: -15.34 dB (outdoor, mamba+transnet, dim=512)
#
# Runs train_ae.py on the same baseline checkpoint with --epochs 0 and
# quantization switched off via --pq FP32 --aq 0 --act_quant 32.
# NOTE(review): the meaning of --aq 0 ("no latent quantization") is
#   inferred from the cell title — confirm in train_ae.py.
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_L2_dim512_baseline/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 512 \
  --train_path data/DATA_Htrainout.mat \
  --test_path data/DATA_Htestout.mat \
  --epochs 0 \
  --decoder_layers 2 \
  --encoder_layers 2 \
  --batch-size 200 \
  --test-batch-size 200 \
  --num-workers 4 \
  --pq FP32 \
  --aq 0 \
  --act_quant 32

In [None]:
# ============================================================
# Cell 4: Mamba-Transformer AE CR=1/4 Uniform Quantization Sweep
# Weights: INT16 / INT8 / INT4 / INT2,  Activations: INT16 (fixed)
# Latent: 8-bit (--aq 8, CsiNet/CLNet과 동일 조건)
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

BASE = (
    "python train_ae.py"
    " --checkpoint saved_models/mamba_transnet_L2_dim512_baseline/best.pth"
    " --encoder mamba --decoder transnet --encoded_dim 512"
    " --train_path data/DATA_Htrainout.mat --test_path data/DATA_Htestout.mat"
    " --epochs 0 --decoder_layers 2 --encoder_layers 2"
    " --batch-size 200 --test-batch-size 200 --num-workers 2"
    " --aq 8 --act_quant 16"
)

for prec in ["INT16", "INT8", "INT4", "INT2"]:
    print(f"\n{'#'*60}")
    print(f"# Running: W={prec}  A=INT16  Latent=8bit")
    print(f"{'#'*60}")
    !{BASE} --pq {prec}

In [None]:
# ============================================================
# Cell 5: RP-MPQ Offline Policy Search (ILP + KL Refinement)
# Figure 2: ILP vs KL-refined policy evaluation
# Range: 75% ~ 95% BOPs Saving | Step: 0.5% | KL Candidates: 10
# Output: results/csv/mp_policy_lut_mamba_raw.csv      ← raw (새로 추가)
#         results/csv/mp_policy_lut_mamba_pruned.csv   ← monotonic smoothed
#         results/csv/fitting_raw_data_mamba.csv
#         results/plots/exp1_pareto_accuracy_mamba_raw.png       ← raw 그림 (별도)
#         results/plots/exp1_pareto_accuracy_mamba_monotonic.png ← monotonic 그림 (별도)
#         ../figures/kl_vs_ilp_raw.pdf
#         ../figures/kl_vs_ilp_monotonic.pdf
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

import os
csv_dir = "results/csv"
for f in ["mp_policy_lut_mamba_pruned.csv",
          "mp_policy_lut_mamba_raw.csv",
          "fitting_raw_data_mamba.csv"]:
    p = os.path.join(csv_dir, f)
    if os.path.exists(p):
        os.remove(p)
        print(f"Removed: {p}")

!pip install pulp -q

!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_L2_dim512_baseline/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 512 \
  --train_path data/DATA_Htrainout.mat \
  --test_path data/DATA_Htestout.mat \
  --epochs 0 \
  --decoder_layers 2 \
  --encoder_layers 2 \
  --batch-size 200 \
  --test-batch-size 200 \
  --num-workers 2 \
  --pq INT8 \
  --aq 8 \
  --act_quant 16 \
  --analyze_all

In [None]:
# ============================================================
# Cell 6: Mamba-Transformer AE  CR=1/16  Training  (1000 epochs, resumable)
#
# Resume logic:
#   - best.pth가 존재하면 저장된 절대 epoch을 읽어 남은 epoch만 학습
#   - Colab이 꺼져도 다시 이 셀을 실행하면 정확히 이어서 학습
#   - 모델 저장 경로: saved_models/mamba_transnet_dim128_cr16/best.pth
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

import os, torch

# ── Config ──────────────────────────────────────────────────────────────
SAVE_DIR      = "saved_models/mamba_transnet_dim128_cr16"
TARGET_EPOCHS = 1000
ENCODED_DIM   = 128          # CR = 1/16  (2048 / 16)
LR            = 1e-3
# ────────────────────────────────────────────────────────────────────────

os.makedirs(SAVE_DIR, exist_ok=True)
ckpt = f"{SAVE_DIR}/best.pth"

# ── Resume 포인트 결정 (절대 epoch 기준) ─────────────────────────────────
if os.path.isfile(ckpt):
    state = torch.load(ckpt, map_location="cpu")
    done      = int(state.get("epoch", -1)) + 1   # 절대 epoch (누적)
    remaining = TARGET_EPOCHS - done
    print(f"[Resume] Checkpoint found  →  epoch {done}/{TARGET_EPOCHS} done, {remaining} remaining")
else:
    done      = 0
    remaining = TARGET_EPOCHS
    print(f"[Fresh ] No checkpoint     →  training from scratch for {remaining} epochs")

if remaining <= 0:
    print(f"[Done  ] Already reached {TARGET_EPOCHS} epochs. Run the inference cell below.")
else:
    print(f"[Train ] Running {remaining} epoch(s)  (best.pth → {SAVE_DIR})\n")
    !python train_ae.py \
      --encoder mamba \
      --decoder transnet \
      --encoded_dim {ENCODED_DIM} \
      --save_dir    {SAVE_DIR} \
      --checkpoint  {ckpt} \
      --start_epoch {done} \
      --train_path  data/DATA_Htrainout.mat \
      --test_path   data/DATA_Htestout.mat \
      --epochs      {remaining} \
      --learning-rate {LR} \
      --decoder_layers 2 \
      --encoder_layers 2 \
      --batch-size      200 \
      --test-batch-size 200 \
      --num-workers 4 \
      --pq FP32 --aq 0 --act_quant 32

In [None]:
# ============================================================
# Cell 7: Mamba-Transformer AE  CR=1/16  FP32 Inference
# Expected: NMSE (dB), Enc FLOPs, Total FLOPs
#
# Evaluates the checkpoint written by Cell 6 (encoded_dim=128) with
# --epochs 0 and quantization switched off (--pq FP32 --aq 0 --act_quant 32).
# Prerequisite: saved_models/mamba_transnet_dim128_cr16/best.pth exists.
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_dim128_cr16/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 128 \
  --train_path data/DATA_Htrainout.mat \
  --test_path  data/DATA_Htestout.mat \
  --epochs 0 \
  --decoder_layers 2 \
  --encoder_layers 2 \
  --batch-size      200 \
  --test-batch-size 200 \
  --num-workers 4 \
  --pq FP32 --aq 0 --act_quant 32

In [None]:
# ============================================================
# Cell 8: Budget Consistency Validation  ← Table III in paper
# "Budget Consistency under Online RP-MPQ"
#
# Prerequisites:
#   - Cell 5 완료 (mp_policy_lut_mamba_pruned.csv 존재)
#   - best.pth 체크포인트 존재
#
# 소요시간: ~20-30min (GPU)
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC
import os

LOG = "results/csv/exp4_budget_run.log"

# ── Step 1: Run (stdout+stderr → log file so nothing is truncated) ──────
!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_L2_dim512_baseline/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 512 \
  --train_path data/DATA_Htrainout.mat \
  --test_path  data/DATA_Htestout.mat \
  --epochs 0 \
  --decoder_layers 2 --encoder_layers 2 \
  --batch-size 200 --test-batch-size 200 \
  --num-workers 2 \
  --aq 8 --act_quant 16 \
  --analyze_all \
  > {LOG} 2>&1 && echo "✅ Done" || echo "❌ Script failed — see log"

# ── Step 2: Check last 60 lines of log (errors are usually at the end) ──
print("\n─── Last 60 lines of log ────────────────────────────────────────")
with open(LOG) as f:
    lines = f.readlines()
for l in lines[-60:]:
    print(l, end="")

# ── Step 3: Display Budget Consistency Table (if CSV updated) ────────────
print("\n\n─── Budget Consistency Table ────────────────────────────────────")
import pandas as pd
import numpy as np

csv_path = "results/csv/ranc_simulation_results_mamba.csv"
df = pd.read_csv(csv_path)
print("Columns:", df.columns.tolist())

if 'Target_Saving' not in df.columns:
    print("\n[WARNING] Old CSV format — exp4 failed to overwrite.")
    print("Scroll up or open the log file for the traceback:")
    print(f"  → {os.path.abspath(LOG)}")
else:
    sub = (df[(df['SNR_Context'] == 20) & (df['QoS_Target'] == 0.99)]
             [['Target_Saving', 'Realized_Saving', 'Lambda']]
             .drop_duplicates('Target_Saving')
             .sort_values('Target_Saving')
             .reset_index(drop=True))

    sub['Deviation (%)'] = (
        (sub['Realized_Saving'] - sub['Target_Saving']).abs()
        / sub['Target_Saving'] * 100
    ).round(3)

    rep = [87.5, 90.0, 92.5]
    mask = sub['Target_Saving'].round(1).isin(rep)
    table = sub[mask].copy() if mask.any() else sub.iloc[len(sub)//4 : 3*len(sub)//4 : max(1, len(sub)//4)]

    print(table[['Target_Saving', 'Realized_Saving', 'Deviation (%)']].to_string(index=False))
    print(f"\n  Max deviation: {table['Deviation (%)'].max():.3f}%")

    print("\n─── LaTeX rows ─────────────────────────────────────────────────")
    for _, row in table.iterrows():
        t, r, d = row['Target_Saving'], row['Realized_Saving'], row['Deviation (%)']
        print(f"  ${t:.1f}\\%$ & ${r:.2f}\\%$ & ${d:.3f}$ \\\\")


In [None]:
# ============================================================
# Cell 9: RP-MPQ Offline Wide Sweep (85–98%, 0.05% Step)
#
# 목적: ILP prediction vs KL refinement 비교 (raw + monotonic smoothed)
# Prerequisites:
#   - Cell 1 (환경 설정) 완료
#   - HAWQ results 존재 (Cell 5 한 번 이상 실행 필요)
#
# Output:
#   results/csv/mp_policy_lut_mamba_wide_raw.csv      ← raw (smoothing 전)
#   results/csv/mp_policy_lut_mamba_wide_pruned.csv   ← monotonic smoothed
#   results/plots/exp1_pareto_accuracy_mamba_raw.png
#   results/plots/exp1_pareto_accuracy_mamba_monotonic.png
#
# 소요시간: ~4-6h (GPU, 261 policies)
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

import os
csv_dir = "results/csv"
for f in ["mp_policy_lut_mamba_wide_raw.csv",
          "mp_policy_lut_mamba_wide_pruned.csv"]:
    p = os.path.join(csv_dir, f)
    if os.path.exists(p):
        os.remove(p)
        print(f"Removed: {p}")

!pip install pulp -q

!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_L2_dim512_baseline/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 512 \
  --train_path data/DATA_Htrainout.mat \
  --test_path data/DATA_Htestout.mat \
  --epochs 0 \
  --decoder_layers 2 \
  --encoder_layers 2 \
  --batch-size 200 \
  --test-batch-size 200 \
  --num-workers 2 \
  --pq INT8 \
  --aq 8 \
  --act_quant 16 \
  --analyze_all \
  --wide_sweep \
  --wide_step 0.05

/content/drive/MyDrive/MambaCompression/MambaIC
--- Start: 2026-02-24 07:23:23.689991 ---
[INFO] Config: W:[INT8] A:[INT16] FB:[8-bit]
    Hybrid: False | Chunking: False
[INFO] Device: CUDA
  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd
  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd
  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd
[INFO] Building: UE Encoder [mamba-L2] + BS Decoder [transnet-L2]
  scaler = GradScaler(enabled=(device == 'cuda'))
[INFO] Loaded existing HAWQ results: /content/drive/MyDrive/MambaCompression/MambaIC/results/csv/hawq_importance_split.csv

[INFO] Starting unified scan... (85-98% | 0.05% step)

[INFO] Offline Policy Search: Range 85-95% | Step 0.1% | Points: 261
Scanning Pareto & Calibration: 100% 261/261 [42:41<00:00,  9.82s/it]

[INFO] Offline policy set + calibration data saved.
[INFO] Saved: /content/drive/MyDrive/MambaCompression/MambaIC/results/plots/exp1_pareto_accuracy_mamba_raw.png
Figure(700x450)
[INFO] Saved: /content

In [None]:
# ============================================================
# Cell 10: Fig. 2 — Offline Policy Refinement Ablation (CR=1/4, outdoor)
#
# (a) ILP-predicted vs KL-refined NMSE (monotonic-smoothed Pareto frontier)
# (b) Discrepancy |NMSE_ILP - NMSE_KL-Ref|
#
# Prerequisites: Cell 9 완료 (wide CSV 존재)
# ============================================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

%cd /content/drive/MyDrive/MambaCompression/MambaIC

CSV_DIR = "results/csv"
pruned_path = os.path.join(CSV_DIR, "mp_policy_lut_mamba_wide_pruned.csv")
assert os.path.exists(pruned_path), f"Pruned CSV not found: {pruned_path}  <- Cell 9 먼저 실행"

df = pd.read_csv(pruned_path).sort_values("Actual_Saving").reset_index(drop=True)
x = df["Actual_Saving"].values
nmse_ilp = df["NMSE_ILP"].values
nmse_kl  = df["NMSE_KL"].values
disc     = np.abs(nmse_ilp - nmse_kl)

# ── Figure 2: (a) + (b)  — solid lines + fill_between style ─────────────
plt.rcParams.update({
    'font.size': 11, 'font.family': 'serif',
    'mathtext.fontset': 'stix',
    'axes.labelsize': 12, 'axes.titlesize': 13,
})
fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(12, 4.5))

# ── (a) ILP vs KL-Refined — solid lines ──────────────────────────────────
ax_a.plot(x, nmse_ilp, 'b-s', label='ILP-predicted', markersize=3, linewidth=1.8, alpha=0.85)
ax_a.plot(x, nmse_kl,  'r-o', label='KL-refined',    markersize=3, linewidth=1.8, alpha=0.85)
ax_a.set_xlabel('BOPs Saving vs. FP32 (%)')
ax_a.set_ylabel('NMSE (dB)')
ax_a.set_title('(a) ILP vs. KL-Refined Policy')
ax_a.legend(fontsize=10, loc='upper left')
ax_a.grid(True, linestyle='--', alpha=0.35)

# ── (b) Discrepancy — fill_between + line ─────────────────────────────────
ax_b.fill_between(x, 0, disc, color='steelblue', alpha=0.3)
ax_b.plot(x, disc, 'b-', linewidth=1.8, alpha=0.85)
ax_b.set_xlabel('BOPs Saving vs. FP32 (%)')
ax_b.set_ylabel('|NMSE_ILP - NMSE_KL-Ref| (dB)')
ax_b.set_title('(b) ILP vs. KL-Refined Discrepancy')
ax_b.grid(True, linestyle='--', alpha=0.35)

fig.tight_layout()

out_png = "results/plots/fig2_offline_ablation_mamba.png"
out_pdf = "../figures/fig2_offline_ablation_mamba.pdf"
os.makedirs(os.path.dirname(out_png), exist_ok=True)
os.makedirs(os.path.dirname(out_pdf), exist_ok=True)
fig.savefig(out_png, dpi=300, bbox_inches='tight')
fig.savefig(out_pdf, dpi=300, bbox_inches='tight')
print(f"Saved: {out_png}")
print(f"Saved: {out_pdf}")
plt.show()

# ── Summary stats ─────────────────────────────────────────────────────────
idx_max = np.argmax(disc)
print(f"\n[Summary] {len(df)} policies (step ~{np.median(np.diff(x)):.2f}%)")
print(f"  Max discrepancy: {disc.max():.2f} dB  @ {x[idx_max]:.1f}% saving")
print(f"  Mean discrepancy: {disc.mean():.3f} dB")
print(f"  Discrepancy > 0.5 dB: {(disc > 0.5).sum()} points")
print(f"  Range where KL improves: {x[disc > 0.01][0]:.1f}% ~ {x[disc > 0.01][-1]:.1f}%")

In [None]:
# ============================================================
# Cell 11: Baselines (CRNet + CLNet) 0.05% Wide Sweep
#
# Runs rpmpq_baselines.py from the project root with --wide_step 0.05.
#
# Prerequisites: Cell 1 completed
# Output:
#   MambaIC/results/csv/mp_policy_lut_crnet_cr4_out.csv  (overwritten)
#   MambaIC/results/csv/mp_policy_lut_clnet_cr4_out.csv  (overwritten)
#
# Runtime: ~1-2h (GPU, 2 models × ~261 policies)
# ============================================================
%cd /content/drive/MyDrive/MambaCompression

!pip install pulp -q

!python rpmpq_baselines.py --wide_step 0.05

In [None]:
# ============================================================
# Cell 12: CsiNet 0.05% Wide Sweep (Keras/TF)
#
# Runs csinet_onlytest.py for the outdoor environment with --analyze_all
# and --wide_step 0.05.
#
# Prerequisites: Cell 1 completed
# Output:
#   MambaIC/results/csv/mp_policy_lut_csinet_cr4_out.csv  (overwritten)
#
# Runtime: ~30-60min (GPU, 1 model × ~261 policies)
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/Python_CsiNet-master

!python csinet_onlytest.py --env outdoor --analyze_all --wide_step 0.05

In [None]:
# ============================================================
# Cell 13: Ablation — ILP Granularity (num_chunks=8)
#
# FC layer를 8개 chunk로 분할 → ILP 변수 감소 → 계단식 staircase
# KL refinement 효과가 극대화되는지 확인
#
# 기존 파일 덮어쓰지 않음 (별도 suffix: _nc8)
# Output:
#   results/csv/hawq_importance_split_nc8.csv
#   results/csv/mp_policy_lut_mamba_wide_nc8_raw.csv
#   results/csv/mp_policy_lut_mamba_wide_nc8_pruned.csv
#
# 소요시간: ~20-30min (GPU, ~131 policies at 0.1% step)
# ============================================================
%cd /content/drive/MyDrive/MambaCompression/MambaIC

import os
csv_dir = "results/csv"
for f in ["mp_policy_lut_mamba_wide_nc8_raw.csv",
          "mp_policy_lut_mamba_wide_nc8_pruned.csv"]:
    p = os.path.join(csv_dir, f)
    if os.path.exists(p):
        os.remove(p)
        print(f"Removed: {p}")

!pip install pulp -q

!python train_ae.py \
  --checkpoint "saved_models/mamba_transnet_L2_dim512_baseline/best.pth" \
  --encoder mamba \
  --decoder transnet \
  --encoded_dim 512 \
  --train_path data/DATA_Htrainout.mat \
  --test_path data/DATA_Htestout.mat \
  --epochs 0 \
  --decoder_layers 2 \
  --encoder_layers 2 \
  --batch-size 200 \
  --test-batch-size 200 \
  --num-workers 2 \
  --pq INT8 \
  --aq 8 \
  --act_quant 16 \
  --analyze_all \
  --wide_sweep \
  --wide_step 0.1 \
  --num_chunks 8

In [None]:
# ============================================================
# Cell 14: Plot — num_chunks=8 vs 32 Ablation Comparison
#
# Prerequisites: Cell 9 (nc=32) + Cell 13 (nc=8) 완료
# ============================================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

%cd /content/drive/MyDrive/MambaCompression/MambaIC

CSV_DIR = "results/csv"
configs = {
    "nc=32 (default)": os.path.join(CSV_DIR, "mp_policy_lut_mamba_wide_pruned.csv"),
    "nc=8  (coarse)":  os.path.join(CSV_DIR, "mp_policy_lut_mamba_wide_nc8_pruned.csv"),
}

plt.rcParams.update({
    'font.size': 11, 'font.family': 'serif',
    'mathtext.fontset': 'stix',
    'axes.labelsize': 12, 'axes.titlesize': 13,
})
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for label, path in configs.items():
    if not os.path.exists(path):
        print(f"[SKIP] {label}: {path} not found")
        continue
    df = pd.read_csv(path).sort_values("Actual_Saving").reset_index(drop=True)
    x = df["Actual_Saving"].values
    disc = np.abs(df["NMSE_ILP"].values - df["NMSE_KL"].values)

    # (a) ILP vs KL
    ax = axes[0]
    style = '-' if 'default' in label else '--'
    ax.plot(x, df["NMSE_ILP"].values, f'b{style}', linewidth=1.5, alpha=0.7,
            label=f'ILP ({label})')
    ax.plot(x, df["NMSE_KL"].values, f'r{style}', linewidth=1.5, alpha=0.7,
            label=f'KL ({label})')

    # (b) Discrepancy
    ax2 = axes[1]
    ax2.plot(x, disc, style, linewidth=1.8, label=label)
    ax2.fill_between(x, 0, disc, alpha=0.15)

    print(f"[{label}] {len(df)} pts | Max disc: {disc.max():.2f} dB | Mean: {disc.mean():.3f} dB")

axes[0].set_xlabel('BOPs Saving vs. FP32 (%)')
axes[0].set_ylabel('NMSE (dB)')
axes[0].set_title('(a) ILP vs. KL-Refined Policy')
axes[0].legend(fontsize=8, loc='upper left')
axes[0].grid(True, linestyle='--', alpha=0.35)

axes[1].set_xlabel('BOPs Saving vs. FP32 (%)')
axes[1].set_ylabel('|NMSE_ILP - NMSE_KL-Ref| (dB)')
axes[1].set_title('(b) Discrepancy: nc=32 vs nc=8')
axes[1].legend(fontsize=10)
axes[1].grid(True, linestyle='--', alpha=0.35)

fig.tight_layout()
out = "results/plots/fig2_ablation_nc8_vs_nc32.png"
fig.savefig(out, dpi=300, bbox_inches='tight')
print(f"\nSaved: {out}")
plt.show()

In [None]:
# ============================================================
# Cell 15: Empirical Validation of Lemma 1
# "Contractive SSM bounds quantization-induced state error"
#
# Validates that per-token SSM state error ||e_t|| saturates
# as token position t increases, confirming the bounded-error
# guarantee under contractive state recursion (rho < 1).
#
# Uses the TRAINED model weights (A_logs, dt_projs, x_proj)
# with random input — the contractivity is a weight property.
#
# Output: results/plots/lemma1_ssm_state_error.png
#         ../figures/lemma1_ssm_state_error.pdf
#
# Prerequisites: Cell 1 (setup), trained model checkpoint
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os, sys, types

%cd /content/drive/MyDrive/MambaCompression/MambaIC
if '.' not in sys.path:
    sys.path.insert(0, '.')

# ── Workaround: Pre-load modules, bypassing models/__init__.py ────────
# Problem: models/__init__.py imports MambaIC → compressai → torch_geometric → crash
# Solution: Pre-load VSS_module and MambaAE directly into sys.modules
#           so Python never needs to execute models/__init__.py

# 1. Clear ALL cached models.* and related modules from previous runs
for _k in list(sys.modules.keys()):
    if _k == 'models' or _k.startswith('models.'):
        del sys.modules[_k]
for _k in ['ModularModels', 'MambaAE', 'VSS_module', 'csm_triton']:
    sys.modules.pop(_k, None)

# 2. Create minimal stub for models package
_models_dir = os.path.join(os.getcwd(), 'models')
_stub = types.ModuleType('models')
_stub.__path__ = [_models_dir]
_stub.__package__ = 'models'
sys.modules['models'] = _stub

# 3. Pre-load VSS_module from models/ directory (bypasses __init__.py)
_orig_syspath = sys.path[:]
sys.path.insert(0, _models_dir)
try:
    import VSS_module
    sys.modules['models.VSS_module'] = VSS_module
    _stub.VSS_module = VSS_module
finally:
    sys.path[:] = _orig_syspath

# 4. Pre-load MambaAE (its 'from models.VSS_module' finds pre-loaded module)
import MambaAE
sys.modules['models.MambaAE'] = MambaAE

# ── 1. Load model ──────────────────────────────────────────────────────
from ModularModels import ModularAE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ModularAE(encoder_type='mamba', decoder_type='transnet',
                  encoded_dim=512, encoder_layers=2, decoder_layers=2)
ckpt = torch.load('saved_models/mamba_transnet_L2_dim512_baseline/best.pth',
                   map_location='cpu')
state = ckpt.get('model_state_dict', ckpt.get('state_dict', ckpt))
model.load_state_dict(state, strict=False)
model = model.to(device).eval()
print(f"[OK] Model loaded on {device}")

# ── 2. Random input (Lemma depends on weights, not data) ──────────────
# Seeded uniform-random tensor of shape (N_SAMPLES, 2, 32, 32) on `device`.
N_SAMPLES = 10
torch.manual_seed(42)
x_test = torch.rand(N_SAMPLES, 2, 32, 32, device=device)
print(f"[OK] Random input: {x_test.shape}")

# ── 3. Forward through encoder stem to get SS2D input ─────────────────
# Replays, step by step, the path from the encoder stem to the tensor that
# is cross-scanned later in this cell.
# NOTE(review): the sequence norm → act → chunk → in_proj → conv2d → act is
#   assumed to mirror the project's block / SS2D forward — confirm against
#   models/VSS_module.py.
from einops import rearrange

mamba_blk = model.encoder.layers[0]   # first ChunkedResidualMambaBlock
ss2d      = mamba_blk.vss[1]          # SS2D module (index 1 in Sequential)

with torch.no_grad():
    x_stem = model.encoder.stem(x_test)
    x_norm = mamba_blk.norm(x_stem)
    x_act  = mamba_blk.act(x_norm)
    cs = mamba_blk.chunk_size
    # Cut each feature map into non-overlapping cs x cs tiles and fold the
    # tile grid (h, w) into the batch axis.
    x_chunked = rearrange(x_act,
        'b c (h cs_h) (w cs_w) -> (b h w) c cs_h cs_w',
        cs_h=cs, cs_w=cs)
    # Channels-last layout for in_proj.
    x_bhwc = x_chunked.permute(0, 2, 3, 1).contiguous()
    xz = ss2d.in_proj(x_bhwc)
    # in_proj output is split in half along the last axis: first half feeds
    # the conv branch, second half is the gate.
    d_inner = xz.shape[-1] // 2
    x_proj = xz[..., :d_inner]
    # z_gate is computed here but not referenced again in the visible notebook.
    z_gate = ss2d.act(xz[..., d_inner:])
    # Back to channels-first for conv2d.
    x_conv = x_proj.permute(0, 3, 1, 2).contiguous()
    x_conv = ss2d.act(ss2d.conv2d(x_conv))

# Sizes reused by the following steps of this cell.
B_sz, D_dim, H, W = x_conv.shape
L      = H * W                          # tokens per tile
K      = 4                              # scan directions (matches the view(B, K, D, L) used below)
N_s    = ss2d.A_logs.shape[1]           # presumably the SSM state size — TODO confirm
R      = ss2d.dt_projs_weight.shape[2]  # presumably the dt projection rank — TODO confirm
print(f"[OK] SS2D input: B={B_sz}, D={D_dim}, H={H}, W={W}, L={L}, K={K}, N={N_s}, R={R}")

# ── 4. Helper: symmetric quantization ─────────────────────────────────
def quantize_sym(x, bits):
    """Symmetric per-tensor fake quantization of ``x`` to ``bits`` bits.

    The tensor is scaled by its maximum absolute value so that the largest
    magnitude maps to ``qmax = 2**(bits-1) - 1``, rounded to the nearest
    integer level, clamped to ``[-qmax, qmax]`` and scaled back.

    Args:
        x: Floating-point tensor to quantize.
        bits: Bit width. ``bits >= 32`` means "no quantization".

    Returns:
        A new tensor with the same shape as ``x``. For ``bits >= 32``, an
        all-zero ``x`` or an empty ``x`` this is an unmodified copy.

    Raises:
        ValueError: If ``bits < 2``. A symmetric signed grid needs at least
            one non-zero level; with ``bits <= 1`` ``qmax`` is not positive,
            so the scale is infinite or negative and the result is NaN or
            sign-flipped.
    """
    if bits >= 32:
        return x.clone()
    if bits < 2:
        raise ValueError(f"quantize_sym requires bits >= 2, got {bits}")
    # max() is undefined on an empty tensor; there is nothing to quantize.
    if x.numel() == 0:
        return x.clone()
    abs_max = x.abs().max()
    # An all-zero tensor would give scale == 0 and a 0/0 division below.
    if abs_max == 0:
        return x.clone()
    qmax = 2 ** (bits - 1) - 1
    scale = abs_max / qmax
    return torch.round(x / scale).clamp(-qmax, qmax) * scale

# ── 5. Compute projected SSM tensors (FP32 or quantized weights) ──────
from models.VSS_module import CrossScan

def compute_ssm_tensors(ss2d, x_conv, w_bits=32):
    """Build the per-direction SSM inputs from ``x_conv`` and the SS2D weights.

    When ``w_bits < 32`` the ``x_proj_weight``, ``dt_projs_weight`` and
    ``A_logs`` tensors are passed through ``quantize_sym`` first; the dt bias
    and ``Ds`` are always used unquantized.

    Args:
        ss2d: Module exposing ``x_proj_weight``, ``dt_projs_weight``,
            ``dt_projs_bias``, ``A_logs`` and ``Ds``.
        x_conv: 4-D tensor unpacked as ``(B, D, H, W)``.
        w_bits: Weight bit width; 32 (default) means no quantization.

    Returns:
        Tuple ``(xs, dts, As, Bs, Cs, Ds, dt_bias)`` where ``xs`` and ``dts``
        are reshaped to ``(B, K*D, L)`` with ``K = 4`` and ``L = H*W``,
        ``As = -exp(A_logs)``, and ``dt_bias`` is flattened to 1-D.
    """
    batch, chans, height, width = x_conv.shape
    seq_len = height * width
    n_dirs = 4
    n_state = ss2d.A_logs.shape[1]
    dt_rank = ss2d.dt_projs_weight.shape[2]

    def _maybe_quantize(weight):
        # Quantize only when a reduced bit width was requested.
        return quantize_sym(weight, w_bits) if w_bits < 32 else weight

    x_proj_w = _maybe_quantize(ss2d.x_proj_weight.data.float())
    dt_proj_w = _maybe_quantize(ss2d.dt_projs_weight.data.float())
    a_logs = _maybe_quantize(ss2d.A_logs.data.float())
    dt_bias = ss2d.dt_projs_bias.data.float()
    d_skip = ss2d.Ds.data.float()

    # Cross-scan the feature map into n_dirs sequences of length seq_len.
    scanned = CrossScan.apply(x_conv).view(batch, n_dirs, chans, seq_len).float()

    # Project every token to (dt_rank + 2 * n_state) channels, then split
    # into the low-rank dt input and the B / C sequences.
    projected = torch.einsum("bkdl, kcd -> bkcl", scanned, x_proj_w)
    dt_low_rank, b_seq, c_seq = projected.split([dt_rank, n_state, n_state], dim=2)
    dt_full = torch.einsum("bkrl, kdr -> bkdl", dt_low_rank, dt_proj_w)

    a_neg = -torch.exp(a_logs)
    flat_bias = dt_bias.view(-1)

    return (
        scanned.reshape(batch, n_dirs * chans, seq_len),
        dt_full.reshape(batch, n_dirs * chans, seq_len),
        a_neg,
        b_seq.contiguous(),
        c_seq.contiguous(),
        d_skip,
        flat_bias,
    )

# ── 6. Manual selective scan (records all intermediate states) ─────────
def manual_ssm_dir(u, delta, A, B, C, D_skip, db, k, D):
    """Run the selective-scan recurrence for direction ``k`` token by token.

    For each token ``t`` the step size is ``softplus(delta_t + db)``, the
    state is updated as ``s = exp(step * A) * s + step * B_t * u_t``, and the
    output is ``sum_n(C_t * s) + D_skip * u_t``. Every intermediate state is
    recorded.

    Args:
        u, delta: Tensors indexed as ``[batch, K*D, L]``; rows
            ``k*D:(k+1)*D`` belong to direction ``k``.
        A: Tensor indexed as ``[K*D, N]``.
        B, C: Tensors indexed as ``[batch, K, N, L]``.
        D_skip, db: 1-D tensors of length ``K*D`` (skip weight, dt bias).
        k: Direction index.
        D: Number of channels per direction.

    Returns:
        ``(y, states)`` with ``y`` of shape ``(batch, D, L)`` and ``states``
        of shape ``(batch, D, L, N)``.
    """
    n_batch, _, n_tokens = u.shape
    n_state = A.shape[1]
    lo, hi = k * D, (k + 1) * D

    # Views restricted to the channels / sequences of direction k.
    u_dir = u[:, lo:hi, :]
    delta_dir = delta[:, lo:hi, :]
    A_dir = A[lo:hi, :]
    B_dir = B[:, k, :, :]
    C_dir = C[:, k, :, :]
    skip_dir = D_skip[lo:hi]
    bias_dir = db[lo:hi]

    dev = u.device
    state_hist = torch.zeros(n_batch, D, n_tokens, n_state, device=dev)
    out_hist = torch.zeros(n_batch, D, n_tokens, device=dev)
    hidden = torch.zeros(n_batch, D, n_state, device=dev)

    for t in range(n_tokens):
        u_t = u_dir[:, :, t]
        # Positive step size, broadcast over the state axis.
        step = F.softplus(delta_dir[:, :, t] + bias_dir).unsqueeze(-1)
        decay = torch.exp(step * A_dir)
        drive = step * B_dir[:, :, t].unsqueeze(1) * u_t.unsqueeze(-1)
        hidden = decay * hidden + drive
        state_hist[:, :, t, :] = hidden
        readout = (C_dir[:, :, t].unsqueeze(1) * hidden).sum(-1)
        out_hist[:, :, t] = readout + skip_dir * u_t
    return out_hist, state_hist

# ── 7. Run FP32 baseline ──────────────────────────────────────────────
# Reference states / outputs per scan direction, computed with unquantized
# weights.
print("\nRunning FP32 baseline SSM loop (4 dirs x 64 tokens)...")
with torch.no_grad():
    _fp_tensors = compute_ssm_tensors(ss2d, x_conv, w_bits=32)
xs_fp, dts_fp, As_fp, Bs_fp, Cs_fp, Ds_fp, db_fp = _fp_tensors

fp32_states, fp32_y = {}, {}
for k in range(K):
    fp32_y[k], fp32_states[k] = manual_ssm_dir(*_fp_tensors, k, D_dim)
print("  Done.")

# ── 8. Run quantized versions ─────────────────────────────────────────
# For every bit width: per-token state error (L2 norm over the state axis)
# and absolute output error, each averaged over batch and channels, then
# over the scan directions.
bit_configs = [16, 8, 4, 2]
results = {}

for bits in bit_configs:
    print(f"  W{bits}...", end="", flush=True)
    with torch.no_grad():
        _q_tensors = compute_ssm_tensors(ss2d, x_conv, w_bits=bits)

    state_errs, output_errs = [], []
    for k in range(K):
        y_q, s_q = manual_ssm_dir(*_q_tensors, k, D_dim)
        state_errs.append((fp32_states[k] - s_q).norm(dim=-1).mean(dim=(0, 1)))
        output_errs.append((fp32_y[k] - y_q).abs().mean(dim=(0, 1)))

    state_curve = torch.stack(state_errs).mean(0).cpu().numpy()
    output_curve = torch.stack(output_errs).mean(0).cpu().numpy()
    results[bits] = {'state_err': state_curve, 'output_err': output_curve}

    # "Plateau" = mean error over the last five token positions.
    se_plat = state_curve[-5:].mean()
    oe_plat = output_curve[-5:].mean()
    print(f" state_plateau={se_plat:.4f}, output_plateau={oe_plat:.4f}")

# ── 9. Contractivity check ────────────────────────────────────────────
# Evaluates the per-token decay factor exp(Delta * A) for direction k=0
# (first D_dim rows) and reports its max / 99th percentile / mean.

# NOTE(review): torch.quantile rejects very large inputs (around 16M elements
# in current PyTorch releases — confirm for the installed version). Above
# this size the percentile is computed with NumPy on the CPU instead; smaller
# tensors take exactly the original path.
QUANTILE_MAX_NUMEL = 16_000_000

with torch.no_grad():
    dt_vals = F.softplus(dts_fp[:, :D_dim, :] + db_fp[:D_dim][None, :, None])
    dA_vals = torch.exp(dt_vals.unsqueeze(-1) * As_fp[:D_dim][None, :, None, :])
    rho_max  = dA_vals.max().item()
    if dA_vals.numel() <= QUANTILE_MAX_NUMEL:
        rho_99 = dA_vals.quantile(0.99).item()
    else:
        rho_99 = float(np.quantile(dA_vals.flatten().cpu().numpy(), 0.99))
    rho_mean = dA_vals.mean().item()

print(f"\n{'='*55}")
print(f" Contractivity Verification  (direction k=0)")
print(f"{'='*55}")
print(f"  max  exp(Delta*A) = {rho_max:.6f}   {'< 1 OK' if rho_max < 1 else '>= 1 WARNING'}")
print(f"  99th exp(Delta*A) = {rho_99:.6f}")
print(f"  mean exp(Delta*A) = {rho_mean:.6f}")

print(f"\n{'='*55}")
print(f" Error Saturation Analysis")
print(f"{'='*55}")
for bits in bit_configs:
    se = results[bits]['state_err']
    # Plateau = mean of the last five token positions; t_stable is the first
    # token whose error reaches 95% of that plateau.
    plateau = se[-5:].mean()
    thresh = 0.95 * plateau
    t_stable = np.argmax(se >= thresh) if plateau > 1e-10 else 0
    print(f"  W{bits:2d}: plateau = {plateau:.4f} | 95% reached at t = {t_stable}")

# ── 10. Plot ───────────────────────────────────────────────────────────
# Two panels over token position: (a) state error, (b) output error, with
# one curve per weight bit width.
plt.rcParams.update({
    'font.family': 'serif', 'font.size': 11,
    'mathtext.fontset': 'stix',
    'axes.labelsize': 12, 'axes.titlesize': 13,
})

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))
colors = {16: '#2196F3', 8: '#FF9800', 4: '#E91E63', 2: '#9C27B0'}
tokens = np.arange(L)

for bits in bit_configs:
    # Same line style on both panels so the legends match.
    line_kw = dict(color=colors[bits], linewidth=1.8, label=f'W{bits}', alpha=0.85)
    ax1.plot(tokens, results[bits]['state_err'], **line_kw)
    ax2.plot(tokens, results[bits]['output_err'], **line_kw)

state_ylabel = (r'$\Vert \mathbf{e}_t \Vert_2 = '
                r'\Vert \mathbf{s}_t^{\mathrm{FP32}} '
                r'- \mathbf{s}_t^{\mathrm{Q}} \Vert_2$')
output_ylabel = r'$| y_t^{\mathrm{FP32}} - y_t^{\mathrm{Q}} |$'

panels = (
    (ax1, state_ylabel, '(a) SSM State Error'),
    (ax2, output_ylabel, '(b) SSM Output Error'),
)
for panel_ax, panel_ylabel, panel_title in panels:
    panel_ax.set_xlabel('Token position $t$')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.legend(fontsize=10)
    panel_ax.grid(True, ls='--', alpha=0.3)

# Annotate panel (a) with the measured decay factors.
ax1.text(0.97, 0.05,
         f'$\\rho_{{\\max}}={rho_max:.4f}$\n$\\bar{{\\rho}}={rho_mean:.4f}$',
         transform=ax1.transAxes, ha='right', va='bottom', fontsize=9,
         bbox=dict(boxstyle='round,pad=0.3', fc='wheat', alpha=0.5))

fig.suptitle('Lemma 1 Validation: Contractive SSM bounds '
             'quantization error propagation',
             fontsize=12, y=1.02)
fig.tight_layout()

out_png = "results/plots/lemma1_ssm_state_error.png"
out_pdf = "../figures/lemma1_ssm_state_error.pdf"
for target in (out_png, out_pdf):
    os.makedirs(os.path.dirname(target), exist_ok=True)
for target in (out_png, out_pdf):
    fig.savefig(target, dpi=300, bbox_inches='tight')
print(f"\nSaved: {out_png}")
print(f"Saved: {out_pdf}")
plt.show()