# 🛡️ KAAL Offline RL Training Notebook

## RAKSHAK - Agentic AI Cyber Guardian

This notebook trains the **KAAL (Knowledge-Augmented Autonomous Learner)** Dueling DQN agent for autonomous cyber defense.

### What is KAAL?
KAAL is a reinforcement learning agent that decides defensive actions against cyber threats:
- **MONITOR** - Continue observing the threat
- **DEPLOY_HONEYPOT** - Deploy a decoy to gather intelligence
- **ISOLATE_DEVICE** - Quarantine the compromised device
- **ENGAGE_ATTACKER** - Redirect attacker to honeypot
- **ALERT_USER** - Send notification to user

### Training Approach
- **Offline RL**: Train from stored attack events (no live environment needed)
- **Dueling DQN**: Separates value and advantage for better learning
- **Experience Replay**: Random sampling for stable training

### Output
- `kaal_policy.pth` - Trained model for Jetson deployment

---
**Author**: Team RAKSHAK  
**Runtime**: GPU recommended (T4 or better)

## 1️⃣ Setup & Dependencies

In [None]:
# Install dependencies
!pip install torch numpy matplotlib tqdm -q

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
import random
import json
import os
from tqdm.auto import tqdm

# Check GPU: use CUDA when PyTorch reports one, otherwise fall back to CPU.
# `device` is a notebook-level global that later cells pass to the networks.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üñ•Ô∏è Using device: {device}")
if torch.cuda.is_available():
    # Report the name of the GPU at index 0.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Mount Google Drive (for saving model) at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Create output directory; exist_ok=True makes re-running this cell harmless.
# OUTPUT_DIR is read again by the save cell further down.
OUTPUT_DIR = '/content/drive/MyDrive/RAKSHAK_Models'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"üìÅ Output directory: {OUTPUT_DIR}")

## 2️⃣ Dueling DQN Architecture

In [None]:
class DuelingDQN(nn.Module):
    """
    Dueling Deep Q-Network architecture.

    A shared feature extractor feeds two heads, a scalar state value V(s)
    and a per-action advantage A(s, a), which are recombined as:
    Q(s,a) = V(s) + (A(s,a) - mean_a(A(s,a)))

    This helps the agent learn which states are valuable
    without having to learn the effect of each action.

    Args:
        state_size: Length of the input state vector.
        action_size: Number of discrete actions (width of the Q-value output).
        hidden_size: Width of the shared layers; each head uses half of it
            for its own hidden layer.
    """

    def __init__(self, state_size: int = 10, action_size: int = 5, hidden_size: int = 128):
        super().__init__()

        self.state_size = state_size
        self.action_size = action_size
        # Kept so the architecture can be rebuilt from a model instance alone.
        self.hidden_size = hidden_size

        # Shared feature extraction
        self.feature = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

        # Value stream - estimates V(s)
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        )

        # Advantage stream - estimates A(s, a)
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, action_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Q-values for every action.

        Args:
            x: State tensor whose last dimension is ``state_size``. Both a
                batched ``(batch, state_size)`` tensor and a single unbatched
                ``(state_size,)`` vector are accepted.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``action_size``.
        """
        features = self.feature(x)
        value = self.value_stream(features)
        advantage = self.advantage_stream(features)

        # Combine: Q(s,a) = V(s) + (A(s,a) - mean(A)).
        # The action axis is always the last one, so averaging over dim=-1
        # works for batched and unbatched inputs alike (dim=1 would raise on
        # a 1-D state vector).
        q_values = value + (advantage - advantage.mean(dim=-1, keepdim=True))
        return q_values


@dataclass
class Transition:
    """One (state, action, reward, next_state, done) RL transition tuple."""
    state: np.ndarray       # state vector the action was taken in
    action: int             # index of the action that was taken
    reward: float           # scalar reward assigned to that action
    next_state: np.ndarray  # state vector that follows this one in the dataset
    done: bool              # True when no further state follows (terminal)


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity: int = 100000):
        # A deque with maxlen drops its oldest entry once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, transition: Transition):
        """Append one transition to the newest end of the buffer."""
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> List[Transition]:
        """Return up to ``batch_size`` distinct stored items chosen at random."""
        stored = len(self.buffer)
        draw_count = batch_size if batch_size < stored else stored
        return random.sample(self.buffer, draw_count)

    def __len__(self) -> int:
        """Number of transitions currently held."""
        return len(self.buffer)


# Test the architecture: push one random state through a freshly built
# network (default sizes) and print the output as a shape sanity check.
model = DuelingDQN().to(device)
test_input = torch.randn(1, 10).to(device)  # batch of one state of length 10
test_output = model(test_input)
print(f"‚úÖ DuelingDQN created successfully!")
print(f"   Input shape: {test_input.shape}")
print(f"   Output shape: {test_output.shape}")
# detach() drops the autograd graph so the values can be converted to NumPy.
print(f"   Q-values: {test_output.detach().cpu().numpy().flatten()}")

## 3️⃣ Training Data

Choose one of the options below:
- **Option A**: Generate synthetic training data (for testing)
- **Option B**: Upload real events from Jetson

In [None]:
# =============================================================================
# OPTION A: Generate Synthetic Training Data
# =============================================================================
# FIXED: Using EXCLUSIVE optimal action mapping with strong reward signals

def generate_synthetic_events(num_events: int = 5000, seed: Optional[int] = None) -> List[Dict]:
    """
    Generate synthetic attack events with EXCLUSIVE optimal actions.

    Each (attack_type, severity) pair maps to exactly ONE optimal action,
    which prevents the model from learning ambiguous preferences. Only pairs
    present in that mapping are ever generated. The logged action is the
    optimal one about 80% of the time; otherwise it is drawn uniformly from
    all actions (and may still happen to be the optimal one).

    Args:
        num_events: Number of events to generate; values <= 0 yield an
            empty list.
        seed: Optional seed for a private random generator. When given, the
            attack mix, state vectors, actions and source IPs are
            reproducible across calls (timestamps still depend on the
            current wall-clock time). When None, the global ``random``
            module is used.

    Returns:
        A list of event dicts with keys 'event_id', 'timestamp',
        'source_ip', 'target_ip', 'attack_type', 'severity',
        'state_vector' (10 floats), 'action_taken', 'action_id',
        'optimal_action', 'outcome_success' and 'metadata'.
    """
    # Attack types - each has specific optimal action mappings
    ATTACK_TYPES = ['port_scan', 'brute_force', 'dos_attack', 'malware',
                    'exploit_attempt', 'data_exfiltration', 'unauthorized_access',
                    'ping_sweep', 'unknown_device']
    SEVERITIES = ['low', 'medium', 'high', 'critical']
    ACTIONS = ['MONITOR', 'DEPLOY_HONEYPOT', 'ISOLATE_DEVICE', 'ENGAGE_ATTACKER', 'ALERT_USER']

    # ==========================================================================
    # EXCLUSIVE OPTIMAL ACTION MAPPING
    # ==========================================================================
    # Each (attack_type, severity) -> exactly ONE optimal action
    # This creates clear, unambiguous training signals

    EXCLUSIVE_OPTIMAL = {
        # MONITOR scenarios (low severity reconnaissance only)
        ('port_scan', 'low'): 'MONITOR',
        ('ping_sweep', 'low'): 'MONITOR',

        # DEPLOY_HONEYPOT scenarios (brute force - credential capture)
        ('brute_force', 'low'): 'DEPLOY_HONEYPOT',
        ('brute_force', 'medium'): 'DEPLOY_HONEYPOT',

        # ISOLATE_DEVICE scenarios (critical/high severity threats)
        ('dos_attack', 'critical'): 'ISOLATE_DEVICE',
        ('dos_attack', 'high'): 'ISOLATE_DEVICE',
        ('malware', 'critical'): 'ISOLATE_DEVICE',
        ('malware', 'high'): 'ISOLATE_DEVICE',
        ('exploit_attempt', 'critical'): 'ISOLATE_DEVICE',
        ('exploit_attempt', 'high'): 'ISOLATE_DEVICE',

        # ENGAGE_ATTACKER scenarios (medium severity - intel gathering)
        ('dos_attack', 'medium'): 'ENGAGE_ATTACKER',
        ('exploit_attempt', 'medium'): 'ENGAGE_ATTACKER',
        ('malware', 'medium'): 'ENGAGE_ATTACKER',

        # ALERT_USER scenarios (user-relevant events)
        ('data_exfiltration', 'low'): 'ALERT_USER',
        ('data_exfiltration', 'medium'): 'ALERT_USER',
        ('unknown_device', 'low'): 'ALERT_USER',
        ('unknown_device', 'medium'): 'ALERT_USER',
        ('unauthorized_access', 'low'): 'ALERT_USER',
        ('unauthorized_access', 'medium'): 'ALERT_USER',
    }

    # Only use attack_type + severity combinations that have defined optimal actions
    VALID_COMBINATIONS = list(EXCLUSIVE_OPTIMAL.keys())

    # Name -> index lookup, built once instead of calling list.index per event.
    action_ids = {name: idx for idx, name in enumerate(ACTIONS)}

    # A private generator keeps seeded runs reproducible without touching the
    # global random state; unseeded runs keep using the global module.
    rng = random.Random(seed) if seed is not None else random

    events = []
    base_time = datetime.now()

    for i in range(num_events):
        # Pick a valid combination
        attack_type, severity = rng.choice(VALID_COMBINATIONS)
        optimal_action = EXCLUSIVE_OPTIMAL[(attack_type, severity)]

        # Severity and attack encoding, both normalised to [0, 1]
        severity_val = SEVERITIES.index(severity) / 3.0
        attack_val = ATTACK_TYPES.index(attack_type) / (len(ATTACK_TYPES) - 1)

        # Generate state vector
        state = [
            attack_val,                           # [0] attack_type
            severity_val,                         # [1] severity
            rng.random(),                         # [2] source_port
            rng.random() * 0.2,                   # [3] target_port
            rng.random() * severity_val,          # [4] packets_per_sec
            rng.random(),                         # [5] duration
            1.0 if rng.random() > 0.7 else 0.0,   # [6] is_known_attacker
            rng.random() * 0.5 + severity_val * 0.5,  # [7] device_risk
            rng.random(),                         # [8] time_of_day
            rng.random() * 0.3 + 0.2              # [9] protocol_risk
        ]

        # Choose action: 80% optimal, 20% random (for exploration)
        if rng.random() > 0.2:
            action = optimal_action
        else:
            action = rng.choice(ACTIONS)

        is_optimal = (action == optimal_action)

        events.append({
            'event_id': f'evt-{i:05d}',
            # Events are spaced 5 minutes apart, going backwards from now.
            'timestamp': (base_time - timedelta(minutes=i*5)).isoformat(),
            'source_ip': f'192.168.1.{rng.randint(100, 200)}',
            'target_ip': '192.168.1.1',
            'attack_type': attack_type,
            'severity': severity,
            'state_vector': state,
            'action_taken': action,
            'action_id': action_ids[action],
            'optimal_action': optimal_action,
            'outcome_success': is_optimal,
            'metadata': {'is_optimal': is_optimal}
        })

    return events

# Generate events
# `events` is a notebook-level global consumed by the transition-building cell.
USE_SYNTHETIC = True

if USE_SYNTHETIC:
    events = generate_synthetic_events(15000)  # More data for better learning
    print(f"‚úÖ Generated {len(events)} synthetic events")

    # Show distributions
    from collections import Counter

    # Action distribution
    # NOTE(review): the generator's mapping has 2/2/6/3/6 combinations per
    # action and combinations are sampled uniformly, so the printed shares
    # are not expected to be equal despite the "should be balanced" label.
    action_counts = Counter([e['action_taken'] for e in events])
    print(f"\nüìä Action Distribution (should be balanced):")
    for action in ['MONITOR', 'DEPLOY_HONEYPOT', 'ISOLATE_DEVICE', 'ENGAGE_ATTACKER', 'ALERT_USER']:
        count = action_counts.get(action, 0)
        print(f"   {action}: {count} ({100*count/len(events):.1f}%)")

    # Optimal action rate: share of events whose logged action is the optimal one
    optimal_count = sum(1 for e in events if e['metadata']['is_optimal'])
    print(f"\nüìä Optimal Action Rate: {100*optimal_count/len(events):.1f}%")

In [None]:
# =============================================================================
# OPTION B: Upload Real Events from Jetson
# =============================================================================
# Run this cell if you have real events exported from RAKSHAK.
# When enabled it REPLACES the notebook-level `events` list with the uploaded
# data. Each file must contain either a JSON list of events or a JSON object
# with an 'events' key.

UPLOAD_EVENTS = False  # Set to True to upload

if UPLOAD_EVENTS:
    from google.colab import files

    print("üì§ Upload your events JSON file(s):")
    uploaded = files.upload()

    events = []
    for filename, content in uploaded.items():
        data = json.loads(content.decode('utf-8'))
        if isinstance(data, dict) and 'events' in data:
            file_events = data['events']
        elif isinstance(data, list):
            file_events = data
        else:
            # Report unrecognised layouts instead of silently dropping them.
            print(f"   Skipped {filename}: expected a list of events or an object with an 'events' key")
            continue

        events.extend(file_events)
        # Report this file's own count, not the running total.
        print(f"   Loaded {len(file_events)} events from {filename}")

    print(f"\n‚úÖ Total events loaded: {len(events)}")

## 4️⃣ Reward Computation & Transition Building

In [None]:
# =============================================================================
# EXCLUSIVE OPTIMAL MAPPING
# =============================================================================
# Maps (attack_type, severity) -> the single action treated as correct.
# This is a hand-maintained copy of the table defined inside
# generate_synthetic_events(); the two must be kept identical, otherwise the
# rewards computed here disagree with the generated 'optimal_action' labels.
EXCLUSIVE_OPTIMAL = {
    # MONITOR scenarios
    ('port_scan', 'low'): 'MONITOR',
    ('ping_sweep', 'low'): 'MONITOR',

    # DEPLOY_HONEYPOT scenarios
    ('brute_force', 'low'): 'DEPLOY_HONEYPOT',
    ('brute_force', 'medium'): 'DEPLOY_HONEYPOT',

    # ISOLATE_DEVICE scenarios
    ('dos_attack', 'critical'): 'ISOLATE_DEVICE',
    ('dos_attack', 'high'): 'ISOLATE_DEVICE',
    ('malware', 'critical'): 'ISOLATE_DEVICE',
    ('malware', 'high'): 'ISOLATE_DEVICE',
    ('exploit_attempt', 'critical'): 'ISOLATE_DEVICE',
    ('exploit_attempt', 'high'): 'ISOLATE_DEVICE',

    # ENGAGE_ATTACKER scenarios
    ('dos_attack', 'medium'): 'ENGAGE_ATTACKER',
    ('exploit_attempt', 'medium'): 'ENGAGE_ATTACKER',
    ('malware', 'medium'): 'ENGAGE_ATTACKER',

    # ALERT_USER scenarios
    ('data_exfiltration', 'low'): 'ALERT_USER',
    ('data_exfiltration', 'medium'): 'ALERT_USER',
    ('unknown_device', 'low'): 'ALERT_USER',
    ('unknown_device', 'medium'): 'ALERT_USER',
    ('unauthorized_access', 'low'): 'ALERT_USER',
    ('unauthorized_access', 'medium'): 'ALERT_USER',
}


def compute_reward(event: Dict) -> float:
    """
    Score a logged action against the EXCLUSIVE_OPTIMAL table.

    Returns:
        +10.0 when the event's 'action_taken' equals the table entry for its
        (attack_type, severity) pair, -5.0 when it differs, and 0.0 when the
        pair has no entry in the table.

    Missing fields fall back to attack_type '', severity 'medium' and
    action 'MONITOR'.
    """
    lookup_key = (event.get('attack_type', ''), event.get('severity', 'medium'))
    best_action = EXCLUSIVE_OPTIMAL.get(lookup_key)

    # Pairs outside the table carry no training signal.
    if best_action is None:
        return 0.0

    taken = event.get('action_taken', 'MONITOR')
    # The 15-point gap between hit and miss is what separates the Q-values.
    return 10.0 if taken == best_action else -5.0


def build_transitions(events: List[Dict], state_size: int = 10) -> List[Transition]:
    """
    Convert logged events into one-step RL transitions.

    Events whose 'state_vector' does not have exactly ``state_size`` entries
    are dropped BEFORE transitions are linked, so every ``next_state`` has
    the same length as ``state``. Without this, an invalid event that
    follows a valid one would leak in as a ragged ``next_state`` and break
    batching later. An event with no 'state_vector' key is treated as an
    all-zero state.

    Args:
        events: Event dicts. 'state_vector' and 'action_id' are read here;
            the reward comes from compute_reward(event).
        state_size: Required state-vector length (default 10).

    Returns:
        One Transition per usable event, in input order. The last one has an
        all-zero ``next_state`` and ``done=True``; all others point at the
        next usable event's state with ``done=False``.
    """
    def _state_of(event: Dict) -> list:
        # Missing vectors default to zeros of the expected length.
        return event.get('state_vector', [0.0] * state_size)

    usable = [event for event in events if len(_state_of(event)) == state_size]
    last_index = len(usable) - 1

    transitions = []
    for i, event in enumerate(usable):
        state = np.array(_state_of(event), dtype=np.float32)

        # next_state is simply the following event in the log.
        # NOTE(review): for independently generated events this successor is
        # unrelated to the action taken; confirm that bootstrapping from it
        # is intended rather than treating each event as terminal.
        is_last = i == last_index
        if is_last:
            next_state = np.zeros(state_size, dtype=np.float32)
        else:
            next_state = np.array(_state_of(usable[i + 1]), dtype=np.float32)

        transitions.append(Transition(
            state=state,
            action=event.get('action_id', 0),
            reward=compute_reward(event),
            next_state=next_state,
            done=is_last
        ))

    return transitions


# Build transitions from whichever `events` list an earlier cell produced.
transitions = build_transitions(events)
print(f"‚úÖ Built {len(transitions)} transitions")

# Analyze rewards: count by sign, since compute_reward only yields +10, -5 or 0.
rewards = [t.reward for t in transitions]
print(f"\nüìä Reward Statistics:")
print(f"   Positive (+10): {sum(1 for r in rewards if r > 0)}")
print(f"   Negative (-5):  {sum(1 for r in rewards if r < 0)}")
print(f"   Neutral (0):    {sum(1 for r in rewards if r == 0)}")
print(f"   Mean: {np.mean(rewards):.2f}")

# Action distribution with rewards
# List position is the action id stored in each Transition.
action_names = ['MONITOR', 'DEPLOY_HONEYPOT', 'ISOLATE_DEVICE', 'ENGAGE_ATTACKER', 'ALERT_USER']
print(f"\nüìä Average Reward by Action:")
for i, name in enumerate(action_names):
    action_rewards = [t.reward for t in transitions if t.action == i]
    if action_rewards:  # skip actions that never occur to avoid a mean of nothing
        avg = np.mean(action_rewards)
        print(f"   {name}: {avg:+.2f} (n={len(action_rewards)})")

## 5️⃣ Training Loop

In [None]:
# =============================================================================
# Training Configuration
# =============================================================================
# Hyperparameters read by train_kaal(); the three *_size entries are also
# written into the saved checkpoint by the save cell.

CONFIG = {
    'epochs': 200,           # Number of training epochs
    'batch_size': 64,        # Batch size
    'learning_rate': 0.001,  # Learning rate
    'gamma': 0.99,           # Discount factor
    'target_update': 10,     # Update target network every N epochs
    'state_size': 10,        # State vector size
    'action_size': 5,        # Number of actions
    'hidden_size': 128       # Hidden layer size
}

print("‚öôÔ∏è Training Configuration:")
for k, v in CONFIG.items():
    print(f"   {k}: {v}")

In [None]:
def train_kaal(
    transitions: List[Transition],
    config: Dict,
    device: torch.device
) -> tuple:
    """
    Train KAAL Dueling DQN using offline experience replay.

    Each step draws a uniform random mini-batch (without replacement) from
    the fixed transition set and applies a Double DQN update: the policy
    network picks the next action, the target network evaluates it. Loss is
    Huber (smooth L1); gradients are clipped to norm 1.0. The target network
    is re-synced every ``config['target_update']`` epochs.

    Args:
        transitions: Fixed dataset of transitions; must be non-empty and all
            state vectors must share one length.
        config: Needs the keys 'state_size', 'action_size', 'hidden_size',
            'learning_rate', 'epochs', 'batch_size', 'gamma' and
            'target_update'.
        device: Device on which the networks and the dataset tensors live.

    Returns:
        (policy_net, losses, best_loss), where ``losses`` is the mean loss
        of each epoch and ``best_loss`` the lowest of them. ``policy_net``
        holds the weights after the FINAL epoch, not the best-loss epoch.

    Raises:
        ValueError: If ``transitions`` is empty.
    """
    if not transitions:
        raise ValueError("train_kaal requires at least one transition")

    # Initialize networks
    policy_net = DuelingDQN(
        config['state_size'],
        config['action_size'],
        config['hidden_size']
    ).to(device)

    target_net = DuelingDQN(
        config['state_size'],
        config['action_size'],
        config['hidden_size']
    ).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    # Optimizer
    optimizer = optim.Adam(policy_net.parameters(), lr=config['learning_rate'])

    # The dataset never changes during offline training, so convert it to
    # tensors ONCE. Stacking into a single ndarray first avoids PyTorch's
    # slow list-of-ndarrays path, and the training loop then only indexes.
    all_states = torch.as_tensor(
        np.stack([t.state for t in transitions]), dtype=torch.float32, device=device)
    all_actions = torch.as_tensor(
        [t.action for t in transitions], dtype=torch.long, device=device)
    all_rewards = torch.as_tensor(
        [t.reward for t in transitions], dtype=torch.float32, device=device)
    all_next_states = torch.as_tensor(
        np.stack([t.next_state for t in transitions]), dtype=torch.float32, device=device)
    all_dones = torch.as_tensor(
        [float(t.done) for t in transitions], dtype=torch.float32, device=device)

    dataset_size = len(transitions)
    # Never request more distinct samples than exist.
    batch_size = min(config['batch_size'], dataset_size)
    steps_per_epoch = max(1, dataset_size // config['batch_size'])

    print(f"\nüöÄ Starting Training...")
    print(f"   Transitions: {dataset_size}")
    print(f"   Epochs: {config['epochs']}")
    print(f"   Batch size: {config['batch_size']}")

    losses = []
    best_loss = float('inf')

    # Training loop
    progress = tqdm(range(config['epochs']), desc="Training")

    for epoch in progress:
        epoch_losses = []

        for step in range(steps_per_epoch):
            # Sample batch: distinct indices, uniform over the dataset
            picks = torch.as_tensor(
                random.sample(range(dataset_size), batch_size),
                dtype=torch.long, device=device)

            states = all_states[picks]
            actions = all_actions[picks]
            rewards = all_rewards[picks]
            next_states = all_next_states[picks]
            dones = all_dones[picks]

            # Compute Q(s, a). squeeze(1) keeps the batch axis even for a
            # batch of one; a bare squeeze() would collapse it to a scalar.
            current_q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

            # Double DQN: select actions with policy, evaluate with target
            with torch.no_grad():
                next_actions = policy_net(next_states).argmax(1, keepdim=True)
                next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
                # (1 - dones) removes the bootstrap term on terminal transitions
                target_q = rewards + (1 - dones) * config['gamma'] * next_q

            # Huber loss for stability
            loss = F.smooth_l1_loss(current_q, target_q)

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
            optimizer.step()

            epoch_losses.append(loss.item())

        # Epoch stats
        avg_loss = np.mean(epoch_losses)
        losses.append(avg_loss)

        # Update target network
        if (epoch + 1) % config['target_update'] == 0:
            target_net.load_state_dict(policy_net.state_dict())

        # Track best
        if avg_loss < best_loss:
            best_loss = avg_loss

        # Update progress bar
        progress.set_postfix({'loss': f'{avg_loss:.4f}', 'best': f'{best_loss:.4f}'})

    return policy_net, losses, best_loss


# Train! Results are kept as notebook-level globals: the plotting, saving and
# testing cells below read `policy_net`, `losses` and `best_loss`.
policy_net, losses, best_loss = train_kaal(transitions, CONFIG, device)

print(f"\n‚úÖ Training Complete!")
print(f"   Final Loss: {losses[-1]:.4f}")  # mean loss of the last epoch
print(f"   Best Loss: {best_loss:.4f}")    # lowest per-epoch mean loss

In [None]:
# Plot training loss (one point per epoch)
plt.figure(figsize=(10, 4))
plt.plot(losses, 'b-', alpha=0.7)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('KAAL Training Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Plot smoothed loss (moving average)
# Clamp the window to the number of epochs: with mode='valid' np.convolve
# swaps its arguments when the kernel is longer than the data, which would
# produce a curve that is not a moving average of the losses at all.
window = max(1, min(10, len(losses)))
smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')
plt.figure(figsize=(10, 4))
plt.plot(smoothed, 'g-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Smoothed Loss')
plt.title(f'KAAL Training Loss (smoothed, window={window})')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6️⃣ Save & Download Model

In [None]:
# Save checkpoint: weights plus the sizes needed to rebuild the network and a
# record of how the model was trained.
checkpoint = {
    'policy_state_dict': policy_net.state_dict(),
    'state_size': CONFIG['state_size'],
    'action_size': CONFIG['action_size'],
    'hidden_size': CONFIG['hidden_size'],
    'training_info': {
        'epochs': CONFIG['epochs'],
        'batch_size': CONFIG['batch_size'],
        'learning_rate': CONFIG['learning_rate'],
        'gamma': CONFIG['gamma'],
        'transitions': len(transitions),
        'final_loss': losses[-1],
        'best_loss': best_loss,
        'timestamp': datetime.now().isoformat()
    }
}

# Save to Google Drive
model_path = f"{OUTPUT_DIR}/kaal_policy.pth"
torch.save(checkpoint, model_path)
print(f"‚úÖ Model saved to Google Drive: {model_path}")

# Also save locally for download (`local_path` is read by the download cell)
local_path = '/content/kaal_policy.pth'
torch.save(checkpoint, local_path)
print(f"‚úÖ Model saved locally: {local_path}")

# Save inference-only version (no training metadata).
# hidden_size is included because the layer shapes in the state dict depend
# on it; without it a loader would have to guess the width to rebuild the
# network whenever CONFIG['hidden_size'] differs from the class default.
inference_checkpoint = {
    'policy_state_dict': policy_net.state_dict(),
    'state_size': CONFIG['state_size'],
    'action_size': CONFIG['action_size'],
    'hidden_size': CONFIG['hidden_size']
}
inference_path = f"{OUTPUT_DIR}/kaal_policy_inference.pth"
torch.save(inference_checkpoint, inference_path)
print(f"‚úÖ Inference model saved: {inference_path}")

In [None]:
# Download the model: hands the locally saved checkpoint (`local_path`, set
# by the save cell) to Colab's file-download helper.
from google.colab import files

print("üì• Downloading model...")
files.download(local_path)

## 7️⃣ Test the Model

In [None]:
# =============================================================================
# Test the trained model
# =============================================================================
# Index in this list is the action id; the order is the same as the ACTIONS
# list used by generate_synthetic_events() to assign 'action_id'.
ACTIONS = ['MONITOR', 'DEPLOY_HONEYPOT', 'ISOLATE_DEVICE', 'ENGAGE_ATTACKER', 'ALERT_USER']

# Attack type encoding (index / 8 for 9 attack types)
# port_scan=0/8=0.0, brute_force=1/8=0.125, dos_attack=2/8=0.25, malware=3/8=0.375
# exploit=4/8=0.5, data_exfil=5/8=0.625, unauthorized=6/8=0.75
# ping_sweep=7/8=0.875, unknown_device=8/8=1.0

# Severity encoding: low=0/3=0.0, medium=1/3=0.33, high=2/3=0.67, critical=3/3=1.0

def test_model(model, state_vector, device):
    """Test model on a state vector."""
    model.eval()
    with torch.no_grad():
        state = torch.FloatTensor(state_vector).unsqueeze(0).to(device)
        q_values = model(state).cpu().numpy().flatten()
        action = np.argmax(q_values)
    return action, q_values


# =============================================================================
# TEST SCENARIOS - Aligned with EXCLUSIVE_OPTIMAL mapping
# =============================================================================
# Each scenario is a hand-written 10-element state vector plus the action the
# mapping prescribes for its (attack_type, severity) pair. State layout, as
# produced by generate_synthetic_events():
#   [0] attack_type  [1] severity        [2] source_port        [3] target_port
#   [4] packets_per_sec  [5] duration    [6] is_known_attacker  [7] device_risk
#   [8] time_of_day  [9] protocol_risk
# 'attack_type' and 'severity' are only used for display by the test runner.
test_scenarios = [
    # =========================================================================
    # MONITOR scenarios: (port_scan, low), (ping_sweep, low)
    # =========================================================================
    {
        'name': 'üü¢ Port Scan (Low)',
        'state': [0.0, 0.0, 0.3, 0.02, 0.0, 0.1, 0.0, 0.1, 0.5, 0.3],  # port_scan=0.0, low=0.0
        'expected': 'MONITOR',
        'attack_type': 'port_scan',
        'severity': 'low'
    },
    {
        'name': 'üü¢ Ping Sweep (Low)',
        'state': [0.875, 0.0, 0.1, 0.01, 0.0, 0.05, 0.0, 0.05, 0.6, 0.1],  # ping_sweep=0.875, low=0.0
        'expected': 'MONITOR',
        'attack_type': 'ping_sweep',
        'severity': 'low'
    },

    # =========================================================================
    # DEPLOY_HONEYPOT scenarios: (brute_force, low), (brute_force, medium)
    # =========================================================================
    {
        'name': 'üü° Brute Force (Low)',
        'state': [0.125, 0.0, 0.2, 0.01, 0.0, 0.1, 0.0, 0.1, 0.8, 0.2],  # brute_force=0.125, low=0.0
        'expected': 'DEPLOY_HONEYPOT',
        'attack_type': 'brute_force',
        'severity': 'low'
    },
    {
        'name': 'üü° Brute Force (Medium)',
        'state': [0.125, 0.33, 0.3, 0.02, 0.1, 0.2, 0.0, 0.3, 0.5, 0.3],  # brute_force=0.125, medium=0.33
        'expected': 'DEPLOY_HONEYPOT',
        'attack_type': 'brute_force',
        'severity': 'medium'
    },

    # =========================================================================
    # ISOLATE_DEVICE scenarios: (dos/malware/exploit, critical/high)
    # =========================================================================
    {
        'name': 'üî¥ DoS Attack (Critical)',
        'state': [0.25, 1.0, 0.5, 0.01, 0.9, 0.8, 1.0, 0.9, 0.3, 0.5],  # dos=0.25, critical=1.0
        'expected': 'ISOLATE_DEVICE',
        'attack_type': 'dos_attack',
        'severity': 'critical'
    },
    {
        'name': 'üî¥ Malware (Critical)',
        'state': [0.375, 1.0, 0.2, 0.05, 0.3, 0.5, 1.0, 0.8, 0.7, 0.6],  # malware=0.375, critical=1.0
        'expected': 'ISOLATE_DEVICE',
        'attack_type': 'malware',
        'severity': 'critical'
    },
    {
        'name': 'üü† Exploit (High)',
        'state': [0.5, 0.67, 0.4, 0.03, 0.5, 0.6, 1.0, 0.7, 0.2, 0.5],  # exploit=0.5, high=0.67
        'expected': 'ISOLATE_DEVICE',
        'attack_type': 'exploit_attempt',
        'severity': 'high'
    },

    # =========================================================================
    # ENGAGE_ATTACKER scenarios: (dos/exploit/malware, medium)
    # =========================================================================
    {
        'name': 'üü° DoS Probe (Medium)',
        'state': [0.25, 0.33, 0.7, 0.02, 0.3, 0.3, 0.0, 0.4, 0.1, 0.4],  # dos=0.25, medium=0.33
        'expected': 'ENGAGE_ATTACKER',
        'attack_type': 'dos_attack',
        'severity': 'medium'
    },
    {
        'name': 'üü° Exploit Attempt (Medium)',
        'state': [0.5, 0.33, 0.6, 0.04, 0.3, 0.4, 0.0, 0.5, 0.3, 0.5],  # exploit=0.5, medium=0.33
        'expected': 'ENGAGE_ATTACKER',
        'attack_type': 'exploit_attempt',
        'severity': 'medium'
    },

    # =========================================================================
    # ALERT_USER scenarios: (data_exfil/unknown/unauthorized, low/medium)
    # =========================================================================
    {
        'name': 'üîî Data Exfiltration (Medium)',
        'state': [0.625, 0.33, 0.7, 0.06, 0.2, 0.4, 0.0, 0.5, 0.4, 0.3],  # data_exfil=0.625, medium=0.33
        'expected': 'ALERT_USER',
        'attack_type': 'data_exfiltration',
        'severity': 'medium'
    },
    {
        'name': 'üîî Unknown Device (Medium)',
        'state': [1.0, 0.33, 0.5, 0.03, 0.1, 0.2, 0.0, 0.6, 0.9, 0.4],  # unknown=1.0, medium=0.33
        'expected': 'ALERT_USER',
        'attack_type': 'unknown_device',
        'severity': 'medium'
    },
    {
        'name': 'üîî Unauthorized Access (Low)',
        'state': [0.75, 0.0, 0.4, 0.02, 0.0, 0.2, 0.0, 0.4, 0.7, 0.3],  # unauthorized=0.75, low=0.0
        'expected': 'ALERT_USER',
        'attack_type': 'unauthorized_access',
        'severity': 'low'
    }
]

# =============================================================================
# Run Tests
# =============================================================================
# For every scenario: take the greedy action from the trained policy, compare
# it with the expected action, and print the full Q-value table.
print("üß™ KAAL Model Test Suite (Aligned with EXCLUSIVE_OPTIMAL)")
print("=" * 70)
print(f"Testing {len(test_scenarios)} scenarios\n")

results = {'pass': 0, 'fail': 0}
# Per-action tallies are keyed by the EXPECTED action of each scenario.
action_results = {a: {'pass': 0, 'fail': 0} for a in ACTIONS}

for scenario in test_scenarios:
    action_id, q_values = test_model(policy_net, scenario['state'], device)
    predicted = ACTIONS[action_id]
    expected = scenario['expected']
    passed = predicted == expected

    status = "‚úÖ PASS" if passed else "‚ùå FAIL"
    results['pass' if passed else 'fail'] += 1
    action_results[expected]['pass' if passed else 'fail'] += 1

    print(f"\n{scenario['name']}")
    print(f"   ({scenario['attack_type']}, {scenario['severity']})")
    print(f"   Expected: {expected}")
    print(f"   Got:      {predicted} {status}")
    print(f"   Q-values:")
    for i, (name, q) in enumerate(zip(ACTIONS, q_values)):
        marker = " ‚Üê CHOSEN" if i == action_id else ""
        exp_marker = " (expected)" if name == expected else ""
        print(f"      {name:20s}: {q:8.3f}{marker}{exp_marker}")

# =============================================================================
# Summary
# =============================================================================
print("\n" + "=" * 70)
print("üìä TEST SUMMARY")
print("=" * 70)
total = results['pass'] + results['fail']
print(f"\nOverall: {results['pass']}/{total} passed ({100*results['pass']/total:.1f}%)")

print("\nBy Action:")
for action in ACTIONS:
    p = action_results[action]['pass']
    f = action_results[action]['fail']
    total_action = p + f
    if total_action > 0:  # skip actions that no scenario expects
        pct = 100 * p / total_action
        # Three tiers: all passed, at least half passed, fewer than half passed.
        status = "‚úÖ" if pct == 100 else "‚ö†Ô∏è" if pct >= 50 else "‚ùå"
        print(f"   {status} {action:20s}: {p}/{total_action} ({pct:.0f}%)")

print("\n" + "=" * 70)

## 📋 Deployment Instructions

### Copy to Jetson
```bash
scp kaal_policy.pth user@jetson-ip:~/e-raksha/models/
```

### Verify on Jetson
```python
from core.agentic_defender import AgenticDefender
import yaml

with open('config/config.yaml') as f:
    config = yaml.safe_load(f)

# Update config to use new model
config['agent']['model_path'] = 'models/kaal_policy.pth'

agent = AgenticDefender(config)
print(f'Model loaded: {agent.model_loaded}')
print(f'Mode: {agent.get_statistics()["mode"]}')
```

### Run RAKSHAK
```bash
sudo python main.py
```