# Trustworthy AI for Open RAN: Demo Notebook

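This notebook demonstrates:
1. Loading a trained model (currently a placeholder: a `BaselinePolicy` stands in for the trained model)
2. Running inference in the simulated RAN environment, with uncertainty quantification (UQ is a placeholder, see Section 5)
3. Interpreting decisions with SHAP (placeholder, see Section 6)
4. Visualizing results (throughput, SLA violations, fairness, energy)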

In [None]:
import sys
from pathlib import Path

sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.environment.ran_env import RANEnvironment
from src.models.baseline import BaselinePolicy

# Render matplotlib figures inline in the notebook output, and apply
# seaborn's 'whitegrid' style to the plots created below.
%matplotlib inline
sns.set_style('whitegrid')

## 1. Create Environment

In [None]:
# Simulated RAN: 10 cells with 15 UEs each, driven by the 'commute' traffic
# pattern. seed=42 is passed so repeated runs should match (assuming the
# environment honours the seed — TODO confirm in RANEnvironment).
env_kwargs = {
    'n_cells': 10,
    'n_ues_per_cell': 15,
    'traffic_pattern': 'commute',
    'seed': 42,
}
env = RANEnvironment(**env_kwargs)

# Report the shapes the policy must consume and produce.
print(f"Observation space: {env.observation_space.shape}")
print(f"Action space: {env.action_space.shape}")

## 2. Load Pre-trained Model (Placeholder)

In [None]:
# Placeholder for checkpoint loading: no weights are read from disk here.
# A BaselinePolicy is instantiated as a stand-in until a trained checkpoint
# is available.
# NOTE(review): n_cells=10 duplicates the value passed to RANEnvironment
# above; presumably the two must match — keep them in sync.
policy = BaselinePolicy(n_cells=10)
print("Using BaselinePolicy placeholder (no checkpoint loaded)")

## 3. Run Inference

In [None]:
# Roll out the policy for at most MAX_STEPS steps. The loop stops early if the
# environment reports the episode as terminated or truncated. The per-step
# `info` dicts returned by env.step are collected in `metrics` for the
# summary below and the plots in the next section.
MAX_STEPS = 100  # step budget for the demo rollout

obs, _ = env.reset()
total_reward = 0.0
metrics = []

for _ in range(MAX_STEPS):
    action = policy.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)

    total_reward += reward
    metrics.append(info)

    if terminated or truncated:
        break

# Averages are taken over the steps actually run, which can be fewer than
# MAX_STEPS if the episode ended early.
n_steps = len(metrics)
print(f"Steps run: {n_steps}/{MAX_STEPS}")
print(f"Total reward: {total_reward:.2f}")
print(f"Average throughput: {np.mean([m['total_throughput'] for m in metrics]):.2f} Mbps")
print(f"SLA violation rate: {np.mean([m['sla_violation_rate'] for m in metrics]):.2%}")

## 4. Visualize Performance

In [None]:
# Plot the per-step metrics collected during the rollout as a 2x2 grid of
# time series, then save a 300-dpi copy and display the figure.

# Output path, relative to the notebook's working directory. Its parent
# directory is created on demand so savefig doesn't fail on a fresh checkout.
FIGURE_PATH = Path('../results/figures/demo_performance.png')

# One entry per panel, in row-major order. Fields:
# - info key
# - panel title
# - y-axis label
# - extra plot kwargs
# - optional dashed horizontal reference line as (value, legend label)
PANELS = [
    ('total_throughput', 'Throughput Over Time', 'Throughput (Mbps)', {}, None),
    ('sla_violations', 'SLA Violations', 'Number of Violations', {'color': 'red'}, None),
    ('jain_index', "Jain's Fairness Index", 'Jain Index', {'color': 'green'}, (0.85, 'Target')),
    ('energy', 'Energy Consumption', 'Normalized Energy', {'color': 'orange'}, None),
]

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

for ax, (key, title, ylabel, plot_kwargs, reference) in zip(axes.flat, PANELS):
    ax.plot([m[key] for m in metrics], **plot_kwargs)
    if reference is not None:
        ref_value, ref_label = reference
        ax.axhline(y=ref_value, color='r', linestyle='--', label=ref_label)
        ax.legend()
    ax.set_title(title)
    ax.set_xlabel('Time Step')
    ax.set_ylabel(ylabel)

fig.tight_layout()
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH, dpi=300, bbox_inches='tight')
plt.show()

## 5. Uncertainty Quantification (Placeholder)

In the full implementation, this section would:
- Run ensemble inference
- Show uncertainty estimates
- Plot reliability diagrams

## 6. Interpretability with SHAP (Placeholder)

In the full implementation, this section would:
- Compute SHAP values
- Show feature importance
- Visualize attention weights for GNN