# IT Help Desk Scheduling - ROSIE Supercomputer Edition

This notebook consolidates the entire scheduling system for running on ROSIE.
It includes:
- Extended search parameters for supercomputer-scale optimization
- Parallel execution support for multi-core utilization
- Hyperparameter grid search
- Comprehensive visualization and comparison

## Quick Start
1. Run cells 1-4 to set up the environment and load data
2. Run cell 5 to configure hyperparameters (adjust for your compute budget)
3. Run cells 6-8 for individual algorithms OR cell 9 for parallel comparison
4. Run cell 10 for results visualization

## Cell 1: Imports and Dependencies

In [None]:
# Core imports
import numpy as np
import random
import math
import time
import json
import os
from datetime import datetime, timedelta
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

# Parallel processing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count, Manager
import threading

# GPU Acceleration with PyTorch
import torch
import torch.nn.functional as F

# MongoDB
from pymongo import MongoClient

# Visualization - Seaborn for better plots
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

# Apply one Seaborn look (white grid, husl palette, notebook-sized text)
# to every figure produced later in this notebook.
sns.set_theme(palette="husl", style="whitegrid")
sns.set_context(context="notebook", font_scale=1.1)

# ============================================================================
# GPU/CUDA SETUP
# ============================================================================
# multiprocessing.cpu_count() reports every core on the machine. On a shared
# cluster node the job scheduler may pin this process to fewer cores, and
# sched_getaffinity reports the cores this process may actually run on, which
# avoids oversubscribing process pools sized from NUM_CORES.
try:
    NUM_CORES = len(os.sched_getaffinity(0))
except AttributeError:  # sched_getaffinity is not available on macOS/Windows
    NUM_CORES = cpu_count()
print(f"Available CPU cores: {NUM_CORES}")

# Defaults so later cells can reference these names on CPU-only nodes too.
GPU_NAME = None
GPU_COUNT = 0
GPU_MEMORY = 0.0

# Check for GPU availability
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
    GPU_NAME = torch.cuda.get_device_name(0)
    GPU_COUNT = torch.cuda.device_count()
    GPU_MEMORY = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # GiB
    print("GPU ACCELERATION ENABLED!")
    print(f"  Device: {GPU_NAME}")
    print(f"  GPU Count: {GPU_COUNT}")
    print(f"  Memory: {GPU_MEMORY:.1f} GB")
    USE_GPU = True
else:
    DEVICE = torch.device('cpu')
    print("No GPU available - using CPU (will be slower)")
    USE_GPU = False

print(f"PyTorch device: {DEVICE}")

# Quick GPU benchmark. CUDA kernels launch asynchronously, so synchronize
# before both timestamps; perf_counter is the monotonic high-resolution clock.
if USE_GPU:
    a = torch.randn((1000, 1000), device=DEVICE)
    b = torch.randn((1000, 1000), device=DEVICE)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        c = torch.matmul(a, b)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    print(f"GPU benchmark: 100x 1000x1000 matmul in {t1-t0:.3f}s")
    del a, b, c
    torch.cuda.empty_cache()

## Cell 2: Core Classes (Worker, ShiftSlot, SchedulingEnvironment)

In [None]:
class Worker:
    """A student worker and the constraints on when they can be scheduled.

    Attributes:
        worker_id: identifier used in schedules.
        name: display name.
        tier: 1-4 (4 = manager, 3 = inventory tech).
        is_commuter: commuters cannot work before 9:00.
        desired_hours: preferred weekly hours (stored, not used here).
        busy_times: (day, start_hour, end_hour) windows, end exclusive.
    """
    def __init__(self, worker_id: int, name: str, tier: int, is_commuter: bool,
                 desired_hours: float, busy_times: List[Tuple[int, int, int]]):
        self.worker_id = worker_id
        self.name = name
        self.tier = tier  # 1-4 (4 = manager, 3 = inventory tech)
        self.is_commuter = is_commuter
        self.desired_hours = desired_hours
        self.busy_times = busy_times

    def is_available(self, day: int, hour: int) -> bool:
        """Return True unless the hour is too early for a commuter or falls
        inside one of the worker's busy windows on that day."""
        too_early = self.is_commuter and hour < 9
        if too_early:
            return False
        clashes = any(
            day == busy_day and start <= hour < end
            for busy_day, start, end in self.busy_times
        )
        return not clashes


class ShiftSlot:
    """One staffing position at a given (day, hour) for a shift type
    ('Window' or 'Remote'); starts with no worker assigned."""
    def __init__(self, day: int, hour: int, shift_type: str):
        self.day, self.hour, self.shift_type = day, hour, shift_type
        self.assigned_worker = None


class SchedulingEnvironment:
    """IT help-desk scheduling problem with CPU and batched torch evaluation.

    A schedule is a 1-D integer array with one entry per element of
    ``self.shift_slots``: the id of the assigned worker, or -1 for an empty
    slot. Penalties are non-negative and lower is better; ``evaluate_schedule``
    and ``batch_evaluate_gpu`` apply the same penalty terms and weights.
    """

    SHIFT_TYPES = ['Window', 'Remote']

    # Opening hours per weekday index (0 = Monday, as produced by the loaders'
    # datetime.weekday()) as (open, close). When slots are generated, fractional
    # opening times are rounded up and closing times truncated.
    HOURS_CONFIG = {
        'finals': {
            0: (7.5, 20), 1: (7.5, 20), 2: (7.5, 20), 3: (7.5, 20),
            4: (7.5, 17), 5: (10, 18),
        },
        'regular': {
            0: (7.5, 20), 1: (7.5, 20), 2: (7.5, 20), 3: (7.5, 20),
            4: (7.5, 17), 5: (10, 18),
        }
    }

    MIN_HOURS_PER_WORKER = 14
    MAX_HOURS_PER_WORKER = 20

    # 'max' is how many slots of each type are generated per (day, hour);
    # 'min' is the coverage target, and each missing worker costs 100 points.
    COVERAGE_REQUIREMENTS = {
        'Window': {'min': 2, 'max': 2},
        'Remote': {'min': 2, 'max': 4}
    }

    def __init__(self, workers: List[Worker], schedule_type: str = 'finals'):
        """Build the slot list and lookup tables for *workers*.

        Raises:
            KeyError: if *schedule_type* is not a key of HOURS_CONFIG.
        """
        self.workers = workers
        self.schedule_type = schedule_type
        self.hours_config = self.HOURS_CONFIG[schedule_type]
        self.shift_slots = self._generate_shift_slots()
        self.num_slots = len(self.shift_slots)

        # Pre-compute lookup tables for GPU acceleration
        self._build_gpu_lookup_tables()

    def _generate_shift_slots(self) -> List[ShiftSlot]:
        """Create 'max' Window and 'max' Remote slots for every open hour."""
        slots = []
        for day, (start_hour, end_hour) in self.hours_config.items():
            start_int = int(np.ceil(start_hour))
            end_int = int(end_hour)
            for hour in range(start_int, end_int):
                for _ in range(self.COVERAGE_REQUIREMENTS['Window']['max']):
                    slots.append(ShiftSlot(day, hour, 'Window'))
                for _ in range(self.COVERAGE_REQUIREMENTS['Remote']['max']):
                    slots.append(ShiftSlot(day, hour, 'Remote'))
        return slots

    def _build_gpu_lookup_tables(self):
        """Pre-compute id maps and tensors shared by both evaluators."""
        # Worker ID to index mapping
        self.worker_id_to_idx = {w.worker_id: i for i, w in enumerate(self.workers)}
        self.idx_to_worker_id = {i: w.worker_id for i, w in enumerate(self.workers)}
        self.num_workers = len(self.workers)
        # First worker object per id, for O(1) lookups in evaluate_schedule.
        self._worker_by_id = {}
        for w in self.workers:
            self._worker_by_id.setdefault(w.worker_id, w)

        # Slot properties as tensors
        self.slot_days = torch.tensor([s.day for s in self.shift_slots], device=DEVICE, dtype=torch.int32)
        self.slot_hours = torch.tensor([s.hour for s in self.shift_slots], device=DEVICE, dtype=torch.int32)
        self.slot_is_window = torch.tensor([1 if s.shift_type == 'Window' else 0
                                            for s in self.shift_slots], device=DEVICE, dtype=torch.int32)

        # Worker properties as tensors
        self.worker_tiers = torch.tensor([w.tier for w in self.workers], device=DEVICE, dtype=torch.int32)
        self.worker_is_commuter = torch.tensor([1 if w.is_commuter else 0
                                                for w in self.workers], device=DEVICE, dtype=torch.int32)

        # Availability matrix: (num_workers, num_slots) - 1 if available, 0 if not
        availability = np.zeros((self.num_workers, self.num_slots), dtype=np.float32)
        for w_idx, worker in enumerate(self.workers):
            for s_idx, slot in enumerate(self.shift_slots):
                if worker.is_available(slot.day, slot.hour):
                    availability[w_idx, s_idx] = 1.0
        self.availability_matrix = torch.tensor(availability, device=DEVICE)

        # Coverage group indices - which slots belong to same (day, hour) group
        self.coverage_groups = {}
        for i, slot in enumerate(self.shift_slots):
            key = (slot.day, slot.hour)
            if key not in self.coverage_groups:
                self.coverage_groups[key] = {'Window': [], 'Remote': []}
            self.coverage_groups[key][slot.shift_type].append(i)

        # Slot -> coverage-group incidence matrices of shape (num_slots,
        # num_groups), so per-group filled counts for a whole batch are a
        # single matmul.
        group_keys = list(self.coverage_groups)
        self._num_coverage_groups = len(group_keys)
        window_groups = np.zeros((self.num_slots, len(group_keys)), dtype=np.float32)
        remote_groups = np.zeros_like(window_groups)
        for g, key in enumerate(group_keys):
            window_groups[self.coverage_groups[key]['Window'], g] = 1.0
            remote_groups[self.coverage_groups[key]['Remote'], g] = 1.0
        self._window_group_matrix = torch.tensor(window_groups, device=DEVICE)
        self._remote_group_matrix = torch.tensor(remote_groups, device=DEVICE)

        # (day, hour) grid used for the shift-length check: each slot maps to
        # the flat cell day * _grid_hours + hour.
        self._grid_days = max((s.day for s in self.shift_slots), default=0) + 1
        self._grid_hours = max(24, max((s.hour for s in self.shift_slots), default=0) + 1)
        self._slot_cells = (self.slot_days.to(torch.int64) * self._grid_hours
                            + self.slot_hours.to(torch.int64))

        print(f"GPU lookup tables built: {self.num_slots} slots, {self.num_workers} workers")

    def evaluate_schedule(self, schedule: np.ndarray) -> Tuple[float, Dict]:
        """Score one schedule on the CPU.

        Args:
            schedule: one worker id (or -1 for an unfilled slot) per shift slot.

        Returns:
            (penalty, details): the total penalty and per-category counts of
            violations.

        Raises:
            KeyError: if the schedule contains an id that is neither -1 nor
                the id of a worker in this environment.
        """
        penalty = 0
        details = {
            'coverage_violations': 0, 'tier_mismatches': 0, 'worker_conflicts': 0,
            'hour_violations': 0, 'min_hour_violations': 0,
            'morning_shift_violations': 0, 'shift_length_violations': 0
        }
        window_min = self.COVERAGE_REQUIREMENTS['Window']['min']
        remote_min = self.COVERAGE_REQUIREMENTS['Remote']['min']

        # Coverage check: every (day, hour) needs the minimum of each type.
        for shifts in self.coverage_groups.values():
            window_count = sum(1 for i in shifts['Window'] if schedule[i] != -1)
            remote_count = sum(1 for i in shifts['Remote'] if schedule[i] != -1)

            if window_count < window_min:
                penalty += 100 * (window_min - window_count)
                details['coverage_violations'] += 1
            if remote_count < remote_min:
                penalty += 100 * (remote_min - remote_count)
                details['coverage_violations'] += 1

        # Worker stats
        worker_hours = {w.worker_id: 0 for w in self.workers}
        worker_morning = {w.worker_id: 0 for w in self.workers}
        worker_assignments = {w.worker_id: [] for w in self.workers}

        for i, worker_id in enumerate(schedule):
            if worker_id == -1:
                continue
            slot = self.shift_slots[i]
            worker_hours[worker_id] += 1  # KeyError for unknown worker ids
            worker_assignments[worker_id].append(i)
            if slot.hour < 12:
                worker_morning[worker_id] += 1

            worker = self._worker_by_id[worker_id]
            if not worker.is_available(slot.day, slot.hour):
                penalty += 200
                details['worker_conflicts'] += 1
            if worker.tier >= 3 and slot.shift_type == 'Window':
                penalty += 10
                details['tier_mismatches'] += 1

        # Hour constraints
        for worker in self.workers:
            hours = worker_hours[worker.worker_id]
            if hours < self.MIN_HOURS_PER_WORKER:
                penalty += (self.MIN_HOURS_PER_WORKER - hours) * 75
                details['min_hour_violations'] += 1
            if hours > self.MAX_HOURS_PER_WORKER:
                penalty += (hours - self.MAX_HOURS_PER_WORKER) * 50
                details['hour_violations'] += 1
            if worker_morning[worker.worker_id] > 2:
                penalty += (worker_morning[worker.worker_id] - 2) * 30
                details['morning_shift_violations'] += 1

        # Shift length constraints: each maximal run of consecutive hours on a
        # day must be 2-6 hours long.
        # NOTE(review): a worker placed in two slots of the same (day, hour)
        # counts twice in worker_hours but only once here (hours are a set),
        # and such double-booking is not penalised anywhere.
        for assignments in worker_assignments.values():
            if not assignments:
                continue
            day_hours = {}
            for idx in assignments:
                slot = self.shift_slots[idx]
                day_hours.setdefault(slot.day, set()).add(slot.hour)

            for hours in day_hours.values():
                sorted_hours = sorted(hours)
                blocks = [[sorted_hours[0]]]
                for h in sorted_hours[1:]:
                    if h == blocks[-1][-1] + 1:
                        blocks[-1].append(h)
                    else:
                        blocks.append([h])
                for block in blocks:
                    if len(block) < 2:
                        penalty += 500
                        details['shift_length_violations'] += 1
                    elif len(block) > 6:
                        penalty += (len(block) - 6) * 100
                        details['shift_length_violations'] += 1

        return penalty, details

    def batch_evaluate_gpu(self, population: List[np.ndarray]) -> torch.Tensor:
        """Score many schedules at once with vectorized torch operations.

        Applies the same terms and weights as ``evaluate_schedule``:
        - coverage
        - availability conflicts
        - tier mismatches
        - min/max hours
        - morning hours
        - shift length

        For schedules whose ids are all -1 or known workers, entry b equals
        ``evaluate_schedule(population[b])[0]``. Ids that are neither -1 nor a
        known worker count toward coverage but are otherwise ignored.

        Args:
            population: sequence of schedules, each of length ``num_slots``.

        Returns:
            float32 tensor of shape (len(population),) on DEVICE.
        """
        batch_size = len(population)
        if batch_size == 0:
            return torch.zeros(0, device=DEVICE)

        pop_np = np.array(population, dtype=np.int64).reshape(batch_size, self.num_slots)
        pop_tensor = torch.as_tensor(pop_np, device=DEVICE)
        penalties = torch.zeros(batch_size, device=DEVICE)

        # 1. Coverage: filled-slot counts per (day, hour) group via matmul.
        if self._num_coverage_groups > 0:
            filled = (pop_tensor != -1).to(torch.float32)
            window_counts = filled @ self._window_group_matrix
            remote_counts = filled @ self._remote_group_matrix
            window_min = self.COVERAGE_REQUIREMENTS['Window']['min']
            remote_min = self.COVERAGE_REQUIREMENTS['Remote']['min']
            penalties += torch.clamp(window_min - window_counts, min=0).sum(dim=1) * 100
            penalties += torch.clamp(remote_min - remote_counts, min=0).sum(dim=1) * 100

        if self.num_workers == 0:
            return penalties

        # Map worker ids -> worker indices (-1 for empty/unknown) on the host;
        # only the few distinct ids go through the Python dict.
        unique_ids, inverse = np.unique(pop_np, return_inverse=True)
        lookup = np.array(
            [-1 if v == -1 else self.worker_id_to_idx.get(int(v), -1) for v in unique_ids],
            dtype=np.int64,
        )
        # reshape: NumPy 1.x returns a flat inverse, 2.x the input's shape.
        worker_idx = torch.as_tensor(lookup[inverse].reshape(pop_np.shape), device=DEVICE)
        valid = worker_idx >= 0
        safe_idx = worker_idx.clamp(min=0)  # index 0 for invalid, masked by `valid`
        valid_f = valid.to(torch.float32)

        # 2. Availability conflicts (+200 each) and tier mismatches (+10 each).
        slot_range = torch.arange(self.num_slots, device=DEVICE)
        unavailable = self.availability_matrix[safe_idx, slot_range] == 0
        penalties += (valid & unavailable).sum(dim=1).to(torch.float32) * 200
        high_tier = (self.worker_tiers >= 3)[safe_idx]
        on_window = (self.slot_is_window == 1).unsqueeze(0)
        penalties += (valid & high_tier & on_window).sum(dim=1).to(torch.float32) * 10

        # 3. Per-worker hour totals and morning (< 12:00) hours.
        worker_hours = torch.zeros(batch_size, self.num_workers, device=DEVICE)
        worker_hours.scatter_add_(1, safe_idx, valid_f)
        morning_f = (valid & (self.slot_hours < 12).unsqueeze(0)).to(torch.float32)
        worker_morning = torch.zeros(batch_size, self.num_workers, device=DEVICE)
        worker_morning.scatter_add_(1, safe_idx, morning_f)

        under_hours = torch.clamp(self.MIN_HOURS_PER_WORKER - worker_hours, min=0)
        over_hours = torch.clamp(worker_hours - self.MAX_HOURS_PER_WORKER, min=0)
        penalties += under_hours.sum(dim=1) * 75
        penalties += over_hours.sum(dim=1) * 50
        penalties += torch.clamp(worker_morning - 2, min=0).sum(dim=1) * 30

        # 4. Shift length on a (batch, worker, day, hour) occupancy grid.
        cells = self._grid_days * self._grid_hours
        flat_cells = safe_idx * cells + self._slot_cells.unsqueeze(0)
        occupancy = torch.zeros(batch_size, self.num_workers * cells, device=DEVICE)
        occupancy.scatter_add_(1, flat_cells, valid_f)
        occ = (occupancy > 0).to(torch.int64).reshape(
            batch_size, self.num_workers, self._grid_days, self._grid_hours)

        # Runs of length 1 (< 2 hours): occupied hours with both neighbours free.
        prev_occ = torch.zeros_like(occ)
        prev_occ[..., 1:] = occ[..., :-1]
        next_occ = torch.zeros_like(occ)
        next_occ[..., :-1] = occ[..., 1:]
        isolated = occ * (1 - prev_occ) * (1 - next_occ)
        penalties += isolated.sum(dim=(1, 2, 3)).to(torch.float32) * 500

        # Runs longer than 6 hours: a run of length L contains exactly L - 6
        # fully-occupied 7-hour windows, matching the (L - 6) * 100 penalty.
        max_shift_hours = 6
        window = max_shift_hours + 1
        csum = torch.cumsum(occ, dim=-1)
        csum = torch.cat([torch.zeros_like(csum[..., :1]), csum], dim=-1)
        full_windows = (csum[..., window:] - csum[..., :-window]) == window
        penalties += full_windows.sum(dim=(1, 2, 3)).to(torch.float32) * 100

        return penalties

    def get_available_workers(self, day: int, hour: int) -> List[int]:
        """Get list of worker IDs available for a given day and hour"""
        return [w.worker_id for w in self.workers if w.is_available(day, hour)]

print("GPU-accelerated SchedulingEnvironment defined!")
print(f"Using device: {DEVICE}")

## Cell 3: MongoDB Data Loader

In [None]:
class MongoDBLoader:
    """Load workers and their finals (busy times) from MongoDB.

    Reads the 'Users' and 'Finals' collections of *database*. Can be used as a
    context manager so the client is closed even if loading fails.
    """

    def __init__(self, connection_string: str = "mongodb://localhost:27017/",
                 database: str = "finals_scheduler"):
        self.client = MongoClient(connection_string)
        self.db = self.client[database]
        self.users_collection = self.db['Users']
        self.finals_collection = self.db['Finals']

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def parse_tier(self, position: str) -> int:
        """Convert a position string like 'Tier 3' to its tier number (default 1)."""
        tier_map = {'Tier 1': 1, 'Tier 2': 2, 'Tier 3': 3, 'Tier 4': 4}
        return tier_map.get(position, 1)

    def get_day_from_date(self, date_str: str) -> tuple:
        """Return (weekday, datetime) for an ISO date string.

        Dates in 2024 are moved to 2025 before the weekday is computed.
        NOTE(review): replace(year=2025) raises ValueError for 2024-02-29.
        """
        date = datetime.fromisoformat(date_str.replace('Z', '+00:00'))
        if date.year == 2024:
            date = date.replace(year=2025)
        return date.weekday(), date

    def parse_time(self, time_str: str) -> float:
        """Convert 'HH:MM' to hours, rounded down to the half hour
        (int for whole hours, e.g. '10:15' -> 10, '10:45' -> 10.5)."""
        hours, minutes = time_str.split(':')
        hour = int(hours)
        if int(minutes) >= 30:
            return hour + 0.5
        return hour

    @staticmethod
    def _busy_hour_range(start_str: str, end_str: str) -> Tuple[int, int]:
        """Return the whole-hour range [start, end) covering a busy period.

        The start is rounded down and the end rounded up, so a final from
        10:00 to 11:30 blocks both the 10:00 and the 11:00 hour slots.

        Raises:
            ValueError: if either string is not 'HH:MM'.
        """
        start_hour, _ = (int(part) for part in start_str.split(':'))
        end_hour, end_minutes = (int(part) for part in end_str.split(':'))
        return start_hour, end_hour + (1 if end_minutes > 0 else 0)

    def load_workers(self) -> List["Worker"]:
        """Load all active workers and attach their non-Sunday finals as busy times."""
        users = list(self.users_collection.find({'isActive': True}))

        # One query for all finals instead of one query per user.
        user_keys = [str(user['userId']) for user in users]
        finals_by_user: Dict[str, list] = {}
        if user_keys:
            for final in self.finals_collection.find({'userId': {'$in': user_keys}}):
                finals_by_user.setdefault(str(final['userId']), []).append(final)

        workers = []
        for user in users:
            user_id = user['userId']

            busy_times = []
            for final in finals_by_user.get(str(user_id), []):
                day, _ = self.get_day_from_date(final['date'])
                start_hour, end_hour = self._busy_hour_range(final['startTime'], final['endTime'])

                if day == 6:  # Skip Sunday
                    continue
                busy_times.append((day, start_hour, end_hour))

            worker = Worker(
                worker_id=user_id,
                name=user['name'],
                tier=self.parse_tier(user.get('position', 'Tier 1')),
                is_commuter=user.get('isCommuter', False),
                desired_hours=user.get('desiredHours', 15),
                busy_times=busy_times
            )
            workers.append(worker)

        return workers

    def print_loaded_data(self, workers: List["Worker"]):
        """Print a per-worker summary of the loaded data."""
        print(f"\n{'='*60}")
        print(f"LOADED {len(workers)} WORKERS FROM MONGODB")
        print(f"{'='*60}")

        for worker in workers:
            print(f"\n{worker.name} (ID: {worker.worker_id})")
            print(f"  Tier: {worker.tier}, Commuter: {'Yes' if worker.is_commuter else 'No'}")
            print(f"  Desired Hours: {worker.desired_hours}, Busy Times: {len(worker.busy_times)}")

    def close(self):
        """Close MongoDB connection"""
        self.client.close()

print("MongoDB loader defined successfully!")

## Cell 4: Load Data (Local Files, MongoDB, or JSON Export)

**Data Source Options:**

1. **`local_files`** (Default): Loads directly from `Users.json` and `Finals.json` in the Data directory
   - Set `DATA_DIRECTORY` to the path containing your JSON files
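   - The default is an empty string, which resolves to the current working directory (normally the notebook's folder); set it to e.g. `../../Data` if the files live in the repository's Data directory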

2. **`json_export`**: Loads from a pre-exported `workers_data.json` file
   - Run `export_workers_for_rosie.py` locally to create this file
   - Useful for ROSIE where direct file access may be limited

3. **`mongodb`**: Direct MongoDB connection
   - Requires network access to MongoDB Atlas
   - May have SSL/firewall issues on HPC clusters

In [None]:
# ============================================================================
# CHOOSE DATA SOURCE: "mongodb", "json_export", or "local_files"
# ============================================================================
DATA_SOURCE = "local_files"  # Change to "mongodb" or "json_export" if needed

# MongoDB settings (only used if DATA_SOURCE = "mongodb")
# The connection string carries credentials, so it is read from the
# MONGODB_URI environment variable rather than stored in the notebook.
# NOTE(review): the username/password previously hard-coded in this cell
# should be considered leaked and rotated.
MONGODB_CONNECTION = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/")
DATABASE_NAME = "Scheduler"

# JSON export file settings (only used if DATA_SOURCE = "json_export")
# Generate this file by running export_workers_for_rosie.py locally
JSON_EXPORT_FILE = "workers_data.json"

# Local files settings (only used if DATA_SOURCE = "local_files")
# Path to the Data directory containing Users.json and Finals.json.
# An empty string means the current working directory; can be overridden
# with the SCHEDULER_DATA_DIR environment variable (e.g. in a batch job).
DATA_DIRECTORY = os.environ.get("SCHEDULER_DATA_DIR", "")

def parse_tier(position: str) -> int:
    """Map a position label ('Tier 1' .. 'Tier 4') to its number; anything else is tier 1."""
    known_tiers = {f'Tier {level}': level for level in range(1, 5)}
    return known_tiers.get(position, 1)

def get_day_from_date(date_str: str) -> tuple:
    """Parse an ISO date (a trailing 'Z' is accepted) and return (weekday, datetime).

    Dates stamped with the year 2024 are shifted to 2025 before the weekday
    (0 = Monday) is taken.
    NOTE(review): shifting raises ValueError for 2024-02-29.
    """
    parsed = datetime.fromisoformat(date_str.replace('Z', '+00:00'))
    normalized = parsed.replace(year=2025) if parsed.year == 2024 else parsed
    return normalized.weekday(), normalized

def parse_time(time_str: str) -> float:
    """Convert an 'HH:MM' string to hours, rounded down to the half hour.

    Minutes >= 30 add 0.5 and fewer minutes are dropped, e.g. '10:45' -> 10.5
    and '10:15' -> 10. Whole hours come back as int, half hours as float.

    Raises:
        ValueError: if *time_str* is not two ':'-separated integers.
    """
    hours, minutes = time_str.split(':')
    hour = int(hours)
    return hour + 0.5 if int(minutes) >= 30 else hour

def _busy_hour_range(start_str: str, end_str: str) -> Tuple[int, int]:
    """Return the whole-hour range [start, end) that covers a busy period.

    The start is rounded down and the end rounded up, so a final from 10:00 to
    11:30 blocks both the 10:00 and the 11:00 hour slots (truncating the end
    would leave the 11:00 slot looking free).

    Raises:
        ValueError: if either string is not 'HH:MM'.
    """
    start_hour, _ = (int(part) for part in start_str.split(':'))
    end_hour, end_minutes = (int(part) for part in end_str.split(':'))
    return start_hour, end_hour + (1 if end_minutes > 0 else 0)


def load_workers_from_local_files(data_dir: str) -> List[Worker]:
    """Load active workers from Users.json and attach their finals as busy times.

    Args:
        data_dir: directory containing Users.json and Finals.json.

    Returns:
        One Worker per user whose 'isActive' is not false. Finals on Sunday
        are skipped.

    Raises:
        OSError: if either file cannot be opened.
        KeyError / ValueError: on malformed records or time strings.
    """
    with open(os.path.join(data_dir, 'Users.json'), 'r', encoding='utf-8') as f:
        users = json.load(f)
    with open(os.path.join(data_dir, 'Finals.json'), 'r', encoding='utf-8') as f:
        finals = json.load(f)

    # Group finals by userId (compared as strings, as in the Users lookup).
    finals_by_user: Dict[str, list] = {}
    for final in finals:
        finals_by_user.setdefault(str(final['userId']), []).append(final)

    workers = []
    for user in users:
        # Skip inactive users (missing flag counts as active)
        if not user.get('isActive', True):
            continue

        user_id = user['userId']

        busy_times = []
        for final in finals_by_user.get(str(user_id), []):
            day, _ = get_day_from_date(final['date'])
            start_hour, end_hour = _busy_hour_range(final['startTime'], final['endTime'])

            if day == 6:  # Skip Sunday
                continue
            busy_times.append((day, start_hour, end_hour))

        workers.append(Worker(
            worker_id=user_id,
            name=user['name'],
            tier=parse_tier(user.get('position', 'Tier 1')),
            is_commuter=user.get('isCommuter', False),
            desired_hours=user.get('desiredHours', 15),
            busy_times=busy_times
        ))

    return workers

# ============================================================================
# LOAD DATA
# ============================================================================

def _print_banner(title: str) -> None:
    """Print *title* between two 60-character '=' rules, after a blank line."""
    print(f"\n{'='*60}")
    print(title)
    print(f"{'='*60}")


def _print_worker_summaries(worker_list: List[Worker]) -> None:
    """Print name, id, tier, commuter flag, desired hours and busy-time count."""
    for worker in worker_list:
        print(f"\n{worker.name} (ID: {worker.worker_id})")
        print(f"  Tier: {worker.tier}, Commuter: {'Yes' if worker.is_commuter else 'No'}")
        print(f"  Desired Hours: {worker.desired_hours}, Busy Times: {len(worker.busy_times)}")


if DATA_SOURCE == "local_files":
    print(f"Loading workers from local files in: {DATA_DIRECTORY}")

    workers = load_workers_from_local_files(DATA_DIRECTORY)

    _print_banner(f"LOADED {len(workers)} ACTIVE WORKERS FROM LOCAL FILES")
    _print_worker_summaries(workers)

elif DATA_SOURCE == "json_export":
    print(f"Loading workers from exported JSON file: {JSON_EXPORT_FILE}")

    with open(JSON_EXPORT_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    workers = [
        Worker(
            worker_id=w['worker_id'],
            name=w['name'],
            tier=w['tier'],
            is_commuter=w['is_commuter'],
            desired_hours=w['desired_hours'],
            busy_times=[tuple(bt) for bt in w['busy_times']]
        )
        for w in data['workers']
    ]

    _print_banner(f"LOADED {len(workers)} WORKERS FROM JSON EXPORT")
    print(f"Exported at: {data.get('exported_at', 'Unknown')}")
    _print_worker_summaries(workers)

elif DATA_SOURCE == "mongodb":
    print("Connecting to MongoDB...")
    loader = MongoDBLoader(MONGODB_CONNECTION, DATABASE_NAME)
    try:
        workers = loader.load_workers()
        loader.print_loaded_data(workers)
    finally:
        # Release the client even if loading fails part-way.
        loader.close()

else:
    raise ValueError(f"Invalid DATA_SOURCE: {DATA_SOURCE}. Use 'mongodb', 'json_export', or 'local_files'")

# Create scheduling environment
_print_banner("CREATING SCHEDULING ENVIRONMENT")
env = SchedulingEnvironment(workers, schedule_type='finals')
print(f"Total shift slots: {env.num_slots}")
print(f"Number of workers: {len(env.workers)}")
print(f"Schedule type: {env.schedule_type}")

## Cell 5: Extended Hyperparameter Configuration for ROSIE

These parameters are significantly extended compared to local execution.
Adjust based on your compute time allocation.

In [None]:
# ============================================================================
# ROSIE HYPERPARAMETER CONFIGURATION - Extended for ~2 Hour Runs
# ============================================================================

COMPUTE_PROFILE = "rosie_2hr"  # Options: "quick", "standard", "rosie_2hr"

PROFILES = {
    "quick": {  # ~15-30 minutes total
        "GA": {
            "population_size": 200,
            "generations": 5000,
            "crossover_rate": 0.85,
            "mutation_rate": 0.35,
            "elitism_count": 10,
            "max_time": 600.0  # 10 min max
        },
        "SA": {
            "initial_temp": 5000.0,
            "final_temp": 0.1,
            "cooling_rate": 0.9995,
            "iterations_per_temp": 150,
            "max_iterations": 100000,
            "max_time": 600.0,  # 10 min max
            "max_reheats": 10
        },
        "CSP": {
            "max_time": 600.0,  # 10 min max
            "local_search_iterations": 100000,
            "num_restarts": 5
        }
    },
    "standard": {  # ~30-60 minutes total
        "GA": {
            "population_size": 400,
            "generations": 15000,
            "crossover_rate": 0.87,
            "mutation_rate": 0.40,
            "elitism_count": 20,
            "max_time": 1800.0  # 30 min max
        },
        "SA": {
            "initial_temp": 8000.0,
            "final_temp": 0.01,
            "cooling_rate": 0.99985,
            "iterations_per_temp": 200,
            "max_iterations": 300000,
            "max_time": 1800.0,  # 30 min max
            "max_reheats": 20
        },
        "CSP": {
            "max_time": 1800.0,  # 30 min max
            "local_search_iterations": 300000,
            "num_restarts": 10
        }
    },
    "rosie_2hr": {  # ~2 hours per algorithm - RECOMMENDED FOR ROSIE
        "GA": {
            "population_size": 800,
            "generations": 50000,          # Large generation count
            "crossover_rate": 0.88,
            "mutation_rate": 0.42,
            "elitism_count": 40,
            "max_time": 7200.0,            # 2 hour hard limit
            "checkpoint_interval": 300     # Checkpoint every 5 min
        },
        "SA": {
            # Cooling schedule intended to fill 2 hours:
            # with cooling_rate=0.999995, cooling from 10000 to 0.01 takes
            # ln(1e-6)/ln(0.999995) ~= 2.76M cooling steps.
            # NOTE(review): that is ~3M iterations (~25 min at ~2000 iter/sec)
            # only if the temperature is cooled once per iteration. If it is
            # cooled once per `iterations_per_temp` (300) iterations, full
            # cooling would need ~830M iterations, far beyond max_iterations.
            # Confirm against the SA implementation.
            # Reheats are meant to extend the search to the full 2 hours.
            "initial_temp": 10000.0,
            "final_temp": 0.01,
            "cooling_rate": 0.999995,      # Very slow cooling
            "iterations_per_temp": 300,
            "max_iterations": 15000000,    # 15M iterations max
            "max_time": 7200.0,            # 2 hour hard limit
            "max_reheats": 50,             # Allow many reheats
            "checkpoint_interval": 300     # Checkpoint every 5 min
        },
        "CSP": {
            "max_time": 7200.0,            # 2 hour hard limit
            "local_search_iterations": 5000000,  # 5M iterations
            "num_restarts": 30,            # Many restarts for diversity
            "checkpoint_interval": 300     # Checkpoint every 5 min
        }
    }
}

if COMPUTE_PROFILE not in PROFILES:
    raise ValueError(f"Invalid COMPUTE_PROFILE: {COMPUTE_PROFILE}. Use one of {list(PROFILES)}")

# Select active configuration
GA_CONFIG = PROFILES[COMPUTE_PROFILE]["GA"]
SA_CONFIG = PROFILES[COMPUTE_PROFILE]["SA"]
CSP_CONFIG = PROFILES[COMPUTE_PROFILE]["CSP"]

# Checkpointing settings
CHECKPOINT_DIR = "checkpoints"
ENABLE_CHECKPOINTING = True

# Create checkpoint directory (no-op if it already exists)
import os
if ENABLE_CHECKPOINTING:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"{'='*70}")
print(f"PROFILE: {COMPUTE_PROFILE.upper()}")
print(f"{'='*70}")
print("\nGenetic Algorithm:")
print(f"  Population: {GA_CONFIG['population_size']}, Generations: {GA_CONFIG['generations']}")
print(f"  Max time: {GA_CONFIG['max_time']/3600:.1f} hours")
print("\nSimulated Annealing:")
print(f"  Initial temp: {SA_CONFIG['initial_temp']}, Cooling rate: {SA_CONFIG['cooling_rate']}")
print(f"  Max iterations: {SA_CONFIG['max_iterations']:,}, Max time: {SA_CONFIG['max_time']/3600:.1f} hours")
print(f"  Max reheats: {SA_CONFIG['max_reheats']}")
print("\nCSP Solver:")
print(f"  Max time: {CSP_CONFIG['max_time']/3600:.1f} hours")
print(f"  Local search iterations: {CSP_CONFIG['local_search_iterations']:,}")
print(f"  Num restarts: {CSP_CONFIG['num_restarts']}")
print(f"\nCheckpointing: {'ENABLED' if ENABLE_CHECKPOINTING else 'DISABLED'}")

# Runtime bound derived from the selected profile's max_time limits:
# parallel = longest single algorithm, sequential = all three back to back.
_time_limits = [cfg["max_time"] for cfg in (GA_CONFIG, SA_CONFIG, CSP_CONFIG)]
_parallel_hours = round(max(_time_limits) / 3600, 2)
_sequential_hours = round(sum(_time_limits) / 3600, 2)
print(f"\nExpected runtime: ~{_parallel_hours:g} hours (parallel) or ~{_sequential_hours:g} hours (sequential)")


## Cell 6: Enhanced Genetic Algorithm

In [None]:
class GeneticAlgorithm:
    """GPU-accelerated Genetic Algorithm with batch fitness evaluation.

    A chromosome is a 1-D integer array with one gene per entry of
    ``environment.shift_slots``; each gene is the id of the assigned worker,
    or -1 when the slot is left unstaffed.
    """

    def __init__(self, environment: "SchedulingEnvironment",
                 population_size: int = 500,
                 generations: int = 10000,
                 crossover_rate: float = 0.85,
                 mutation_rate: float = 0.40,
                 elitism_count: int = 25,
                 batch_size: int = 500,
                 adaptive_mutation: bool = True,
                 max_time: Optional[float] = None,
                 max_hours_per_worker: int = 20):
        """
        Args:
            environment: scheduling environment providing workers, shift
                slots, availability queries and schedule evaluation.
            population_size: number of chromosomes per generation.
            generations: maximum number of generations.
            crossover_rate: probability that a parent pair is recombined.
            mutation_rate: base probability that an offspring is mutated.
            elitism_count: best chromosomes copied unchanged each generation.
            batch_size: stored for configuration compatibility; not read by
                the methods of this class.
            adaptive_mutation: raise the mutation rate while stagnating.
            max_time: optional wall-clock limit in seconds. ``None`` (the
                default) bounds the run by ``generations`` only.
            max_hours_per_worker: mutations never grow a worker beyond this
                many assigned slots (default 20, the previous fixed cap).
        """
        self.env = environment
        self.population_size = population_size
        self.generations = generations
        self.crossover_rate = crossover_rate
        self.mutation_rate = mutation_rate
        self.base_mutation_rate = mutation_rate
        self.elitism_count = elitism_count
        self.batch_size = batch_size
        self.adaptive_mutation = adaptive_mutation
        self.max_time = max_time
        self.max_hours_per_worker = max_hours_per_worker

        self.chromosome_length = self.env.num_slots
        self.worker_ids = [w.worker_id for w in self.env.workers]
        # O(1) id -> worker lookup. setdefault keeps the first worker for a
        # duplicated id, matching the linear scan this replaces.
        self._workers_by_id = {}
        for worker in self.env.workers:
            self._workers_by_id.setdefault(worker.worker_id, worker)

        self.best_solution = None
        self.best_fitness = float('inf')
        self.fitness_history = []
        self.stagnation_counter = 0

    def _worker(self, worker_id):
        """Return the worker object whose ``worker_id`` matches, or None."""
        return self._workers_by_id.get(worker_id)

    @staticmethod
    def _count_hours(chromosome: np.ndarray, worker_id) -> int:
        """Return how many slots of ``chromosome`` are assigned to ``worker_id``."""
        return int(np.count_nonzero(chromosome == worker_id))

    def _group_slots(self) -> Dict[Tuple, List[Tuple[int, int]]]:
        """Group (slot index, hour) pairs by (day, shift_type), each sorted by hour."""
        groups: Dict[Tuple, List[Tuple[int, int]]] = {}
        for i, slot in enumerate(self.env.shift_slots):
            groups.setdefault((slot.day, slot.shift_type), []).append((i, slot.hour))
        for members in groups.values():
            members.sort(key=lambda x: x[1])
        return groups

    def initialize_population(self) -> List[np.ndarray]:
        """Create a diverse initial population of block-structured schedules.

        Each (day, shift_type) group is visited in a random order and filled
        with consecutive-hour blocks of 2-6 slots. The worker-choice strategy
        rotates with the individual's index: favour workers below
        MIN_HOURS_PER_WORKER, pick uniformly, or pick from the less-loaded
        half of the available workers.
        """
        population = []
        min_hours = self.env.MIN_HOURS_PER_WORKER
        min_shift, max_shift = 2, 6
        # The grouping is identical for every individual, so build it once.
        slot_groups = self._group_slots()
        group_keys = list(slot_groups.keys())

        for pop_idx in range(self.population_size):
            chromosome = np.full(self.chromosome_length, -1, dtype=int)
            worker_hours = {w.worker_id: 0 for w in self.env.workers}

            keys = group_keys.copy()
            random.shuffle(keys)

            for key in keys:
                slots = slot_groups[key]
                day, _shift_type = key
                i = 0
                while i < len(slots):
                    _slot_idx, hour = slots[i]
                    available = self.env.get_available_workers(day, hour)
                    if not available:
                        i += 1
                        continue

                    if pop_idx % 3 == 0:
                        under_min = [w for w in available if worker_hours[w] < min_hours]
                        candidates = under_min if under_min else available
                    elif pop_idx % 3 == 1:
                        candidates = available
                    else:
                        candidates = sorted(available, key=lambda w: worker_hours[w])[:max(1, len(available) // 2)]

                    chosen = random.choice(candidates)
                    block_length = random.randint(min_shift, max_shift)
                    worker = self._worker(chosen)

                    # Extend the block while hours stay consecutive and the
                    # chosen worker remains available.
                    assigned = 0
                    for j in range(i, min(i + block_length, len(slots))):
                        next_idx, next_hour = slots[j]
                        if next_hour != hour + (j - i):
                            break
                        if not (worker and worker.is_available(day, next_hour)):
                            break
                        chromosome[next_idx] = chosen
                        assigned += 1

                    if assigned > 0:
                        worker_hours[chosen] += assigned
                    i += max(1, assigned)

            population.append(chromosome)
        return population

    def batch_fitness_gpu(self, population: List[np.ndarray]) -> List[float]:
        """Return the penalty of every chromosome (lower is better).

        Uses the environment's GPU batch evaluator when the module-level
        ``USE_GPU`` flag is set and there are at least 10 chromosomes;
        otherwise evaluates each chromosome on the CPU.
        """
        if USE_GPU and len(population) >= 10:
            penalties = self.env.batch_evaluate_gpu(population)
            return penalties.cpu().numpy().tolist()
        return [self.env.evaluate_schedule(ind)[0] for ind in population]

    def select_parents(self, population: List[np.ndarray],
                       fitnesses: List[float]) -> Tuple[np.ndarray, np.ndarray]:
        """Pick two parents by tournament selection (copies are returned).

        The tournament grows from 3 to 5 contestants after 100 stagnant
        generations to increase selection pressure. It is capped at the
        population size.
        """
        tournament_size = 3 if self.stagnation_counter < 100 else 5
        tournament_size = min(tournament_size, len(population))

        def tournament_select(k):
            indices = random.sample(range(len(population)), k)
            best_idx = min(indices, key=lambda i: fitnesses[i])
            return population[best_idx].copy()

        return tournament_select(tournament_size), tournament_select(tournament_size)

    def crossover(self, parent1: np.ndarray, parent2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Multi-point (2-4 cut points) crossover; parents are not modified.

        With probability ``1 - crossover_rate``, or when the chromosome is too
        short to cut, the children are plain copies of the parents.
        """
        if random.random() > self.crossover_rate:
            return parent1.copy(), parent2.copy()

        # Cut points must be distinct interior positions in [1, length - 1].
        num_points = min(random.randint(2, 4), self.chromosome_length - 1)
        if num_points < 1:
            return parent1.copy(), parent2.copy()
        points = sorted(random.sample(range(1, self.chromosome_length), num_points))
        points = [0] + points + [self.chromosome_length]

        offspring1 = np.empty(self.chromosome_length, dtype=int)
        offspring2 = np.empty(self.chromosome_length, dtype=int)

        # Alternate the segment source between the two parents.
        for seg, (start, end) in enumerate(zip(points, points[1:])):
            if seg % 2 == 0:
                offspring1[start:end] = parent1[start:end]
                offspring2[start:end] = parent2[start:end]
            else:
                offspring1[start:end] = parent2[start:end]
                offspring2[start:end] = parent1[start:end]

        return offspring1, offspring2

    def mutate(self, chromosome: np.ndarray) -> np.ndarray:
        """With probability ``mutation_rate``, apply one random mutation in place.

        Strategies:
          * 'extend_block': grow an assigned slot into an adjacent empty hour
            of the same day/shift.
          * 'swap_blocks': exchange two assigned slots if both workers are
            available.
          * 'fill_gap': staff empty hours in a 2-4 hour window starting at a
            random empty slot, preferring workers under MIN_HOURS_PER_WORKER.
          * 'reassign': give a random slot to another available worker.
          * 'shuffle_day': only after 50 stagnant generations; permute one
            day's assignments.

        Growth moves never push a worker past ``max_hours_per_worker``.
        Returns the (same) chromosome array.
        """
        if random.random() > self.mutation_rate:
            return chromosome

        min_shift = 2
        min_hours = self.env.MIN_HOURS_PER_WORKER
        max_hours = self.max_hours_per_worker
        mutation_types = ['extend_block', 'swap_blocks', 'fill_gap', 'reassign']
        if self.stagnation_counter > 50:
            mutation_types.append('shuffle_day')

        mutation_type = random.choice(mutation_types)

        if mutation_type == 'extend_block':
            assigned = np.flatnonzero(chromosome != -1).tolist()
            if assigned:
                idx = random.choice(assigned)
                slot = self.env.shift_slots[idx]
                worker_id = chromosome[idx]
                worker = self._worker(worker_id)
                for i, s in enumerate(self.env.shift_slots):
                    if (s.day == slot.day and s.shift_type == slot.shift_type and
                            abs(s.hour - slot.hour) == 1 and chromosome[i] == -1):
                        if worker and worker.is_available(s.day, s.hour):
                            if self._count_hours(chromosome, worker_id) < max_hours:
                                chromosome[i] = worker_id
                                break

        elif mutation_type == 'swap_blocks':
            assigned = np.flatnonzero(chromosome != -1).tolist()
            if len(assigned) >= 2:
                idx1, idx2 = random.sample(assigned, 2)
                slot1, slot2 = self.env.shift_slots[idx1], self.env.shift_slots[idx2]
                w1, w2 = chromosome[idx1], chromosome[idx2]
                worker1, worker2 = self._worker(w1), self._worker(w2)
                if (worker1 and worker2 and
                        worker1.is_available(slot2.day, slot2.hour) and
                        worker2.is_available(slot1.day, slot1.hour)):
                    chromosome[idx1], chromosome[idx2] = w2, w1

        elif mutation_type == 'fill_gap':
            empty = np.flatnonzero(chromosome == -1).tolist()
            if empty:
                idx = random.choice(empty)
                slot = self.env.shift_slots[idx]
                available = self.env.get_available_workers(slot.day, slot.hour)
                if available:
                    worker_hours = {w.worker_id: self._count_hours(chromosome, w.worker_id)
                                    for w in self.env.workers}
                    under_min = [w for w in available if worker_hours.get(w, 0) < min_hours]
                    chosen = random.choice(under_min) if under_min else random.choice(available)
                    block_size = random.randint(min_shift, 4)
                    worker = self._worker(chosen)
                    for i, s in enumerate(self.env.shift_slots):
                        if (s.day == slot.day and s.shift_type == slot.shift_type and
                                slot.hour <= s.hour < slot.hour + block_size and chromosome[i] == -1):
                            if worker and worker.is_available(s.day, s.hour):
                                # Recount each time: earlier hours of this
                                # window may already have been assigned.
                                if self._count_hours(chromosome, chosen) < max_hours:
                                    chromosome[i] = chosen

        elif mutation_type == 'reassign':
            idx = random.randint(0, self.chromosome_length - 1)
            slot = self.env.shift_slots[idx]
            available = self.env.get_available_workers(slot.day, slot.hour)
            if available:
                current = chromosome[idx]
                if current in available and len(available) > 1:
                    available = [w for w in available if w != current]
                chromosome[idx] = random.choice(available)

        elif mutation_type == 'shuffle_day':
            day = random.randint(0, 5)
            day_assigned = [i for i, slot in enumerate(self.env.shift_slots)
                            if slot.day == day and chromosome[i] != -1]
            if len(day_assigned) >= 2:
                shuffled_workers = [chromosome[i] for i in day_assigned]
                random.shuffle(shuffled_workers)
                # Move each worker only where they are available; otherwise the
                # slot keeps its current assignment.
                for slot_idx, new_worker in zip(day_assigned, shuffled_workers):
                    slot = self.env.shift_slots[slot_idx]
                    worker = self._worker(new_worker)
                    if worker and worker.is_available(slot.day, slot.hour):
                        chromosome[slot_idx] = new_worker

        return chromosome

    def repair_chromosome(self, chromosome: np.ndarray) -> np.ndarray:
        """Fix availability violations in place and return the chromosome.

        Any slot whose worker is not available at that day/hour is
        reassigned to a random available worker, or cleared to -1 when nobody
        is available.
        """
        for i, worker_id in enumerate(chromosome):
            if worker_id == -1:
                continue
            slot = self.env.shift_slots[i]
            worker = self._worker(worker_id)
            if worker and not worker.is_available(slot.day, slot.hour):
                available = self.env.get_available_workers(slot.day, slot.hour)
                chromosome[i] = random.choice(available) if available else -1
        return chromosome

    def solve(self, verbose: bool = True) -> Tuple[np.ndarray, float, List[float]]:
        """Evolve the population and return the best schedule found.

        The run stops after ``generations`` generations, when a zero-penalty
        schedule is found, or, if ``max_time`` is set, once the wall-clock
        limit is reached. The time check runs after each generation is
        scored, so at least one generation is always evaluated.

        Returns:
            (best_solution, best_fitness, fitness_history), where
            fitness_history holds the best fitness after each generation.
        """
        start_time = time.time()
        population = self.initialize_population()

        if verbose:
            print(f"GA initialized: pop={self.population_size}, gens={self.generations}")
            print(f"GPU batch evaluation: {'ENABLED' if USE_GPU else 'DISABLED'}")

        for generation in range(self.generations):
            # GPU batch fitness evaluation
            fitnesses = self.batch_fitness_gpu(population)

            min_fitness_idx = int(np.argmin(fitnesses))
            current_best = fitnesses[min_fitness_idx]

            if current_best < self.best_fitness:
                self.best_fitness = current_best
                self.best_solution = population[min_fitness_idx].copy()
                self.stagnation_counter = 0
            else:
                self.stagnation_counter += 1

            self.fitness_history.append(self.best_fitness)

            # Adaptive mutation: mutate more aggressively while stagnating.
            if self.adaptive_mutation:
                if self.stagnation_counter > 100:
                    self.mutation_rate = min(0.8, self.base_mutation_rate * 2)
                elif self.stagnation_counter > 50:
                    self.mutation_rate = min(0.6, self.base_mutation_rate * 1.5)
                else:
                    self.mutation_rate = self.base_mutation_rate

            if verbose and generation % 500 == 0:
                elapsed = time.time() - start_time
                gens_per_sec = generation / elapsed if elapsed > 0 else 0
                eta = (self.generations - generation) / gens_per_sec if gens_per_sec > 0 else 0
                print(f"Gen {generation}: Best={self.best_fitness:.1f}, "
                      f"Avg={np.mean(fitnesses):.1f}, Rate={gens_per_sec:.1f} gen/s, "
                      f"ETA={eta/60:.1f}min")

            if self.best_fitness == 0:
                if verbose:
                    print(f"Perfect solution at generation {generation}!")
                break

            if self.max_time is not None and time.time() - start_time >= self.max_time:
                if verbose:
                    print(f"Time limit ({self.max_time:.0f}s) reached at generation {generation}")
                break

            # Create new population: elites first, then offspring.
            new_population = [population[idx].copy()
                              for idx in np.argsort(fitnesses)[:self.elitism_count]]

            while len(new_population) < self.population_size:
                parent1, parent2 = self.select_parents(population, fitnesses)
                offspring1, offspring2 = self.crossover(parent1, parent2)
                offspring1 = self.repair_chromosome(self.mutate(offspring1))
                offspring2 = self.repair_chromosome(self.mutate(offspring2))
                new_population.append(offspring1)
                if len(new_population) < self.population_size:
                    new_population.append(offspring2)

            population = new_population

        total_time = time.time() - start_time
        if verbose:
            print(f"\nGA completed in {total_time:.1f}s ({total_time/60:.1f} min)")
            print(f"Best fitness: {self.best_fitness:.2f}")
            # best_solution stays None when no generation was run.
            if self.best_solution is not None:
                _, details = self.env.evaluate_schedule(self.best_solution)
                print(f"Violations: {details}")

        return self.best_solution, self.best_fitness, self.fitness_history

print("GPU-accelerated Genetic Algorithm defined!")

## Cell 7: Enhanced Simulated Annealing

In [None]:
class SimulatedAnnealing:
    """
    Simulated Annealing with properly tuned cooling schedule and checkpointing.

    Cooling Schedule Math:
    - With cooling_rate=0.999995 and initial_temp=10000, final_temp=0.01:
    - Number of cooling steps: log(0.01/10000) / log(0.999995) ~ 2.76M
    - At iterations_per_temp=300, total iterations from cooling ~ 830M
    - But we use max_iterations and max_time as hard limits
    - Reheats provide escape from local minima. Once all ``max_reheats``
      reheats are used, the next stagnation period ends the run instead of
      triggering another reheat.
    """

    # Iterations without a new best solution before a reheat is triggered.
    REHEAT_PATIENCE = 50000
    # When verbose, progress is printed about once per this many iterations.
    PROGRESS_INTERVAL = 100000

    def __init__(self, environment,
                 initial_temp: float = 10000.0,
                 final_temp: float = 0.01,
                 cooling_rate: float = 0.999995,
                 iterations_per_temp: int = 300,
                 max_iterations: int = 15000000,
                 max_time: float = 7200.0,
                 max_reheats: int = 50,
                 checkpoint_interval: float = 300.0,
                 checkpoint_dir: str = "checkpoints",
                 max_hours_per_worker: int = 20):
        """
        Args:
            environment: scheduling environment providing workers, shift
                slots, availability queries and schedule evaluation.
            initial_temp / final_temp: start and stop temperatures.
            cooling_rate: geometric factor applied after each temperature step.
            iterations_per_temp: neighbour evaluations per temperature step.
            max_iterations: hard cap on neighbour evaluations.
            max_time: wall-clock limit in seconds.
            max_reheats: number of reheats allowed before stagnation stops
                the run.
            checkpoint_interval: seconds between periodic checkpoints.
            checkpoint_dir: directory for checkpoint files (falsy disables
                file checkpoints).
            max_hours_per_worker: growth moves never push a worker beyond
                this many slots (default 20, the previous fixed cap).
        """
        self.env = environment
        self.initial_temp = initial_temp
        self.final_temp = final_temp
        self.cooling_rate = cooling_rate
        self.iterations_per_temp = iterations_per_temp
        self.max_iterations = max_iterations
        self.max_time = max_time
        self.max_reheats = max_reheats
        self.checkpoint_interval = checkpoint_interval
        self.checkpoint_dir = checkpoint_dir
        self.max_hours_per_worker = max_hours_per_worker

        self.worker_ids = [w.worker_id for w in self.env.workers]
        # O(1) id -> worker lookup. The first worker wins for a duplicated id,
        # matching the linear scan this replaces.
        self._workers_by_id = {}
        for worker in self.env.workers:
            self._workers_by_id.setdefault(worker.worker_id, worker)

        self.best_solution = None
        self.best_cost = float('inf')
        self.cost_history = []
        self.checkpoint_history = []  # Store periodic snapshots for graphing

    def _worker(self, worker_id):
        """Return the worker object whose ``worker_id`` matches, or None."""
        return self._workers_by_id.get(worker_id)

    @staticmethod
    def _count_hours(schedule: np.ndarray, worker_id) -> int:
        """Return how many slots of ``schedule`` are assigned to ``worker_id``."""
        return int(np.count_nonzero(schedule == worker_id))

    def _swap_if_available(self, schedule: np.ndarray, idx1: int, idx2: int) -> None:
        """Swap two slots' workers in place if each can work the other's slot."""
        slot1, slot2 = self.env.shift_slots[idx1], self.env.shift_slots[idx2]
        w1, w2 = schedule[idx1], schedule[idx2]
        worker1, worker2 = self._worker(w1), self._worker(w2)
        if (worker1 and worker2 and
                worker1.is_available(slot2.day, slot2.hour) and
                worker2.is_available(slot1.day, slot1.hour)):
            schedule[idx1], schedule[idx2] = w2, w1

    def generate_initial_solution(self) -> np.ndarray:
        """Build a greedy starting schedule.

        Slots are grouped by (day, shift_type) and walked in hour order. At
        each hour with available workers, the least-loaded worker gets a
        block of 2-6 consecutive hours; workers still below
        MIN_HOURS_PER_WORKER are preferred. A block stops at the first hour
        gap or the first hour the worker is unavailable.
        """
        solution = np.full(self.env.num_slots, -1, dtype=int)
        worker_hours = {w.worker_id: 0 for w in self.env.workers}
        min_hours = self.env.MIN_HOURS_PER_WORKER
        min_shift, max_shift = 2, 6

        slot_groups = {}
        for i, slot in enumerate(self.env.shift_slots):
            slot_groups.setdefault((slot.day, slot.shift_type), []).append((i, slot.hour))
        for members in slot_groups.values():
            members.sort(key=lambda x: x[1])

        for (day, _shift_type), slots in slot_groups.items():
            i = 0
            while i < len(slots):
                _slot_idx, hour = slots[i]
                available = self.env.get_available_workers(day, hour)
                if not available:
                    i += 1
                    continue

                # min() returns the first minimal element, like sorted(...)[0].
                chosen = min(available,
                             key=lambda w: (0 if worker_hours[w] < min_hours else 1, worker_hours[w]))
                block_length = random.randint(min_shift, max_shift)
                worker = self._worker(chosen)

                assigned = 0
                for j in range(i, min(i + block_length, len(slots))):
                    next_idx, next_hour = slots[j]
                    if next_hour != hour + (j - i):
                        break
                    if not (worker and worker.is_available(day, next_hour)):
                        break
                    solution[next_idx] = chosen
                    assigned += 1

                if assigned > 0:
                    worker_hours[chosen] += assigned
                i += max(1, assigned)

        return solution

    def generate_neighbor(self, solution: np.ndarray, temperature: float) -> np.ndarray:
        """Return a modified copy of ``solution``; ``solution`` itself is unchanged.

        The move set depends on temperature / initial_temp:
          * above 0.7: also disruptive moves (shrink_block, shuffle_day);
          * 0.3 to 0.7: a medium set;
          * below 0.3: refinement moves (fine_tune).
        """
        neighbor = solution.copy()
        min_shift, max_shift = 2, 6
        min_hours = self.env.MIN_HOURS_PER_WORKER
        max_hours = self.max_hours_per_worker
        temp_ratio = temperature / self.initial_temp

        # More aggressive moves at higher temps, refined moves at lower temps
        if temp_ratio > 0.7:
            strategies = ['swap_blocks', 'extend_block', 'shrink_block', 'reassign_block', 'fill_block', 'shuffle_day']
        elif temp_ratio > 0.3:
            strategies = ['swap_blocks', 'extend_block', 'reassign_block', 'fill_block']
        else:
            strategies = ['swap_blocks', 'extend_block', 'fill_block', 'fine_tune']

        strategy = random.choice(strategies)

        if strategy == 'swap_blocks':
            assigned = [i for i, w in enumerate(solution) if w != -1]
            if len(assigned) >= 2:
                idx1, idx2 = random.sample(assigned, 2)
                self._swap_if_available(neighbor, idx1, idx2)

        elif strategy == 'extend_block':
            assigned = [i for i, w in enumerate(solution) if w != -1]
            if assigned:
                idx = random.choice(assigned)
                slot = self.env.shift_slots[idx]
                worker_id = neighbor[idx]
                worker = self._worker(worker_id)
                for i, s in enumerate(self.env.shift_slots):
                    if (s.day == slot.day and s.shift_type == slot.shift_type and
                            abs(s.hour - slot.hour) == 1 and neighbor[i] == -1):
                        if worker and worker.is_available(s.day, s.hour):
                            if self._count_hours(neighbor, worker_id) < max_hours:
                                neighbor[i] = worker_id
                                break

        elif strategy == 'shrink_block':
            assigned = [i for i, w in enumerate(solution) if w != -1]
            if assigned:
                idx = random.choice(assigned)
                slot = self.env.shift_slots[idx]
                worker_id = neighbor[idx]
                # Counts all of the worker's slots in this day/shift, not just
                # the contiguous block around idx.
                block_size = sum(1 for i, s in enumerate(self.env.shift_slots)
                                 if s.day == slot.day and s.shift_type == slot.shift_type
                                 and neighbor[i] == worker_id)
                if block_size > min_shift:
                    neighbor[idx] = -1

        elif strategy == 'reassign_block':
            idx = random.randint(0, len(solution) - 1)
            slot = self.env.shift_slots[idx]
            available = self.env.get_available_workers(slot.day, slot.hour)
            if available:
                current = neighbor[idx]
                if current in available and len(available) > 1:
                    available = [w for w in available if w != current]
                neighbor[idx] = random.choice(available)
            else:
                neighbor[idx] = -1

        elif strategy == 'fill_block':
            empty = [i for i, w in enumerate(solution) if w == -1]
            if empty:
                idx = random.choice(empty)
                slot = self.env.shift_slots[idx]
                available = self.env.get_available_workers(slot.day, slot.hour)
                if available:
                    worker_hours = {w.worker_id: self._count_hours(neighbor, w.worker_id)
                                    for w in self.env.workers}
                    under_min = [w for w in available if worker_hours.get(w, 0) < min_hours]
                    chosen = random.choice(under_min) if under_min else random.choice(available)
                    block_length = random.randint(min_shift, max_shift)
                    worker = self._worker(chosen)
                    for i, s in enumerate(self.env.shift_slots):
                        if (s.day == slot.day and s.shift_type == slot.shift_type and
                                slot.hour <= s.hour < slot.hour + block_length and neighbor[i] == -1):
                            if worker and worker.is_available(s.day, s.hour):
                                if self._count_hours(neighbor, chosen) < max_hours:
                                    neighbor[i] = chosen

        elif strategy == 'shuffle_day':
            day = random.randint(0, 5)
            day_indices = [i for i, slot in enumerate(self.env.shift_slots)
                           if slot.day == day and neighbor[i] != -1]
            if len(day_indices) >= 2:
                idx1, idx2 = random.sample(day_indices, 2)
                self._swap_if_available(neighbor, idx1, idx2)

        elif strategy == 'fine_tune':
            idx = random.randint(0, len(solution) - 1)
            slot = self.env.shift_slots[idx]
            current = neighbor[idx]
            available = self.env.get_available_workers(slot.day, slot.hour)
            if available and current != -1:
                worker_hours = {w.worker_id: self._count_hours(neighbor, w.worker_id)
                                for w in self.env.workers}
                current_hours = worker_hours.get(current, 0)
                # 17 is a hard-coded target hour count; a worker whose load is
                # closer to it replaces the current one.
                better = [w for w in available if abs(worker_hours.get(w, 0) - 17) < abs(current_hours - 17)]
                if better:
                    neighbor[idx] = random.choice(better)

        return neighbor

    def save_checkpoint(self, iteration: int, elapsed: float, temperature: float) -> bool:
        """Record a progress snapshot and write it to a JSON file.

        The snapshot is appended to ``checkpoint_history`` (used for graphing)
        whenever ``checkpoint_dir`` is set. The file
        ``sa_checkpoint_<YYYYmmdd_HHMMSS>.json`` in ``checkpoint_dir`` also
        stores the best solution and every 1000th cost-history entry (all
        entries when there are at most 1000).

        Returns:
            True if the file was written. False if ``checkpoint_dir`` is unset
            or writing failed; failures print a warning and do not raise.
        """
        if not getattr(self, 'checkpoint_dir', None):
            return False

        checkpoint = {
            'algorithm': 'SA',
            'iteration': iteration,
            'elapsed_seconds': elapsed,
            'elapsed_minutes': elapsed / 60,
            'temperature': temperature,
            'best_cost': self.best_cost,
            'cost_history_length': len(self.cost_history),
            'timestamp': datetime.now().isoformat()
        }

        self.checkpoint_history.append(checkpoint)

        filename = os.path.join(self.checkpoint_dir, f'sa_checkpoint_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json')
        try:
            # Serialize first, so a serialization error cannot leave a
            # truncated file behind.
            payload = json.dumps({
                'checkpoint': checkpoint,
                'best_solution': self.best_solution.tolist() if self.best_solution is not None else None,
                'cost_history_sample': self.cost_history[::1000] if len(self.cost_history) > 1000 else self.cost_history
            }, indent=2)
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as e:
            print(f"  Warning: Could not save checkpoint: {e}")
            return False
        return True

    def solve(self, verbose: bool = True) -> Tuple[np.ndarray, float, List[float]]:
        """Run SA until cooled, out of budget, solved, or stagnant with no reheats left.

        Stops at whichever comes first:
          * the temperature reaches ``final_temp``;
          * ``max_iterations`` neighbours have been evaluated;
          * ``max_time`` seconds have elapsed;
          * a zero-cost schedule is found;
          * more than ``REHEAT_PATIENCE`` iterations pass without improvement
            after all ``max_reheats`` reheats have been used.

        A final checkpoint is written when the module-level
        ``ENABLE_CHECKPOINTING`` flag is set.

        Returns:
            (best_solution, best_cost, cost_history), where cost_history holds
            the best cost sampled every 100 iterations.
        """
        start_time = time.time()
        last_checkpoint_time = start_time

        current_solution = self.generate_initial_solution()
        current_cost = self.env.evaluate_schedule(current_solution)[0]

        self.best_solution = current_solution.copy()
        self.best_cost = current_cost

        temperature = self.initial_temp
        iteration = 0
        iterations_since_improvement = 0
        reheat_count = 0
        next_report = self.PROGRESS_INTERVAL
        solved = False

        if verbose:
            print(f"SA Configuration:")
            print(f"  Initial temp: {self.initial_temp}, Final temp: {self.final_temp}")
            print(f"  Cooling rate: {self.cooling_rate}")
            print(f"  Iterations per temp: {self.iterations_per_temp}")
            print(f"  Max iterations: {self.max_iterations:,}")
            print(f"  Max time: {self.max_time/60:.0f} min ({self.max_time/3600:.1f} hr)")
            print(f"  Max reheats: {self.max_reheats}")
            print(f"  Checkpoint interval: {self.checkpoint_interval}s")
            print(f"\nInitial cost: {current_cost:.2f}")

        while temperature > self.final_temp:
            elapsed = time.time() - start_time

            if iteration >= self.max_iterations:
                if verbose:
                    print(f"\n  STOPPING: Max iterations ({self.max_iterations:,}) reached")
                break
            if elapsed >= self.max_time:
                if verbose:
                    print(f"\n  STOPPING: Max time ({self.max_time/60:.0f} min) reached")
                break

            # Periodic checkpoint, measured from the previous checkpoint.
            if (start_time + elapsed) - last_checkpoint_time >= self.checkpoint_interval:
                saved = bool(ENABLE_CHECKPOINTING) and self.save_checkpoint(iteration, elapsed, temperature)
                last_checkpoint_time = time.time()
                if verbose and saved:
                    print(f"  [Checkpoint saved at {elapsed/60:.1f} min]")

            for _ in range(self.iterations_per_temp):
                # Check the budget before counting the move, so exactly
                # max_iterations neighbours are evaluated.
                if iteration >= self.max_iterations or (time.time() - start_time) >= self.max_time:
                    break
                iteration += 1
                iterations_since_improvement += 1

                new_solution = self.generate_neighbor(current_solution, temperature)
                new_cost = self.env.evaluate_schedule(new_solution)[0]
                delta = new_cost - current_cost

                # Metropolis criterion. The temperature floor avoids division
                # by ~0; random() is drawn only for non-improving moves.
                if delta < 0 or random.random() < math.exp(-delta / max(temperature, 0.001)):
                    current_solution = new_solution
                    current_cost = new_cost

                    if current_cost < self.best_cost:
                        self.best_solution = current_solution.copy()
                        self.best_cost = current_cost
                        iterations_since_improvement = 0

                if iteration % 100 == 0:
                    self.cost_history.append(self.best_cost)

                if self.best_cost == 0:
                    solved = True
                    break

            if solved:
                if verbose:
                    print(f"\nPerfect solution found at iteration {iteration}!")
                break

            if iterations_since_improvement > self.REHEAT_PATIENCE:
                if reheat_count >= self.max_reheats:
                    # Stagnated again with the reheat budget exhausted.
                    if verbose:
                        print(f"\n  STOPPING: Max reheats ({self.max_reheats}) reached")
                    break
                reheat_count += 1
                old_temp = temperature
                # Each successive reheat restarts from a lower temperature.
                temperature = self.initial_temp * (0.5 ** (reheat_count / 10))
                iterations_since_improvement = 0
                if verbose:
                    print(f"  Reheat #{reheat_count}: {old_temp:.1f} -> {temperature:.1f} at iter {iteration:,}")

            temperature *= self.cooling_rate

            if verbose and iteration >= next_report:
                elapsed = time.time() - start_time
                rate = iteration / elapsed if elapsed > 0 else 0
                print(f"  Iter {iteration:,}: Temp={temperature:.2f}, Best={self.best_cost:.1f}, "
                      f"Time={elapsed/60:.1f}min, Rate={rate:.0f}/s")
                next_report = (iteration // self.PROGRESS_INTERVAL + 1) * self.PROGRESS_INTERVAL

        if ENABLE_CHECKPOINTING:
            self.save_checkpoint(iteration, time.time() - start_time, temperature)

        total_time = time.time() - start_time
        if verbose:
            print(f"\n{'='*60}")
            print(f"SA COMPLETED")
            print(f"{'='*60}")
            print(f"Total time: {total_time:.1f}s ({total_time/60:.1f} min)")
            print(f"Total iterations: {iteration:,}")
            print(f"Reheats used: {reheat_count}")
            print(f"Best cost: {self.best_cost:.2f}")
            _, details = self.env.evaluate_schedule(self.best_solution)
            print(f"Violations: {details}")

        return self.best_solution, self.best_cost, self.cost_history

print("Enhanced Simulated Annealing with checkpointing defined!")


## Cell 8: Enhanced CSP Solver

In [None]:
class CSPSolver:
    """Time-limited, multi-restart greedy + local-search scheduler.

    Each restart builds a schedule greedily (contiguous blocks of hours per
    worker) and then improves it with randomized local moves. It keeps the
    schedule with the lowest penalty reported by
    ``environment.evaluate_schedule``. Despite the name, there is no
    constraint propagation or backtracking. Soft constraints are handled by the
    penalty function. Worker availability and the per-worker hour cap are
    checked directly by the greedy builder and by most local moves.
    """

    def __init__(self, environment: "SchedulingEnvironment",
                 max_time: float = 1800.0,
                 local_search_iterations: int = 100000,
                 batch_size: int = 200,
                 num_restarts: int = 3,
                 max_hours: int = 20):
        """
        Args:
            environment: Scheduling environment. It must provide ``workers``,
                ``shift_slots``, ``num_slots``, ``MIN_HOURS_PER_WORKER``,
                ``get_available_workers(day, hour)`` and
                ``evaluate_schedule(solution) -> (penalty, details)``.
                (The annotation is a string forward reference.)
            max_time: Wall-clock budget in seconds for one ``solve()`` call.
            local_search_iterations: Total local-search iterations, split
                evenly across restarts.
            batch_size: Stored for configuration compatibility; this solver
                does not use it.
            num_restarts: Number of construction + local-search rounds.
            max_hours: Per-worker hour cap enforced by the greedy builder and
                the ``extend``/``fill_gap`` moves (previously hard-coded to 20).
        """
        self.env = environment
        self.max_time = max_time
        self.local_search_iterations = local_search_iterations
        self.batch_size = batch_size
        self.num_restarts = num_restarts
        self.max_hours = max_hours

        self.worker_ids = [w.worker_id for w in self.env.workers]
        # id -> worker, first occurrence wins (same as the old next(...) scan).
        # Built once, so it assumes env.workers does not change afterwards.
        self._workers_by_id = {}
        for w in self.env.workers:
            self._workers_by_id.setdefault(w.worker_id, w)

        self.nodes_explored = 0  # neighbors evaluated by local search
        self.improvements = 0    # times local search improved a restart's best
        self.start_time = None

        self.best_solution = None
        self.best_penalty = float('inf')
        self.min_shift = 2  # shortest block (hours) the greedy builder assigns
        self.max_shift = 6  # longest block (hours) the greedy builder assigns
        self.min_hours = self.env.MIN_HOURS_PER_WORKER

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _is_available(self, worker_id, day, hour):
        """Return a truthy value if ``worker_id`` is a known worker available at (day, hour)."""
        worker = self._workers_by_id.get(worker_id)
        return worker is not None and worker.is_available(day, hour)

    @staticmethod
    def _count_hours(schedule: np.ndarray, worker_id) -> int:
        """Number of slots in ``schedule`` assigned to ``worker_id``."""
        return int(np.count_nonzero(schedule == worker_id))

    @staticmethod
    def _assigned_indices(schedule: np.ndarray) -> List[int]:
        """Ascending indices of slots that have a worker (value != -1)."""
        return np.flatnonzero(schedule != -1).tolist()

    def _get_worker_hours(self, assignment: np.ndarray) -> Dict[int, int]:
        """Count assigned hours per worker; ids not in ``env.workers`` are ignored."""
        worker_hours = {w.worker_id: 0 for w in self.env.workers}
        for worker_id in assignment:
            if worker_id != -1 and worker_id in worker_hours:
                worker_hours[worker_id] += 1
        return worker_hours

    # ------------------------------------------------------------------
    # Greedy construction
    # ------------------------------------------------------------------
    def _contiguous_run(self, worker_id, day, slots: List[Tuple[int, int]],
                        start: int, max_block: int) -> int:
        """Length of the consecutive-hour run from ``slots[start]`` that the worker can cover.

        The run stops at the first gap in hours or the first unavailable hour,
        and never exceeds ``max_block``.
        """
        first_hour = slots[start][1]
        length = 0
        for offset, (_, hour) in enumerate(slots[start:]):
            if hour != first_hour + offset or not self._is_available(worker_id, day, hour):
                break
            length += 1
            if length >= max_block:
                break
        return length

    def _build_greedy_solution(self, randomize: bool = False) -> np.ndarray:
        """Build a schedule greedily, assigning contiguous blocks of hours.

        Slots are grouped by (day, shift_type) and walked in hour order. At each
        unfilled position the available workers are ranked. Workers below the
        minimum-hours target come first, then those with the fewest hours, with
        random jitter added when ``randomize`` is set. The first worker who can
        cover at least ``min_shift`` consecutive hours gets that block. The
        block is capped by ``max_shift`` and by the worker's remaining
        ``max_hours`` budget.

        Args:
            randomize: Shuffle the group order and add noise to the ranking,
                to diversify restarts.

        Returns:
            Integer array of length ``env.num_slots``. ``-1`` marks an empty slot.
        """
        solution = np.full(self.env.num_slots, -1, dtype=int)
        worker_hours = {w.worker_id: 0 for w in self.env.workers}

        # (day, shift_type) -> [(slot_index, hour), ...] sorted by hour
        slot_groups = {}
        for idx, slot in enumerate(self.env.shift_slots):
            slot_groups.setdefault((slot.day, slot.shift_type), []).append((idx, slot.hour))
        for group in slot_groups.values():
            group.sort(key=lambda entry: entry[1])

        keys = list(slot_groups)
        if randomize:
            random.shuffle(keys)

        def worker_score(w_id):
            # Below-minimum workers first (base 0 vs 1000), then fewest hours.
            hours = worker_hours[w_id]
            base = 0 if hours < self.min_hours else 1000
            noise = random.random() * 10 if randomize else 0
            return base + hours + noise

        for day, shift_type in keys:
            slots = slot_groups[(day, shift_type)]
            i = 0
            while i < len(slots):
                _, hour = slots[i]
                available = self.env.get_available_workers(day, hour)
                if not available:
                    i += 1
                    continue

                for chosen in sorted(available, key=worker_score):
                    remaining = self.max_hours - worker_hours[chosen]
                    if remaining <= 0:
                        continue
                    max_block = min(self.max_shift, remaining)
                    if max_block < self.min_shift:
                        continue

                    block_length = self._contiguous_run(chosen, day, slots, i, max_block)
                    if block_length >= self.min_shift:
                        for slot_idx, _ in slots[i:i + block_length]:
                            solution[slot_idx] = chosen
                        worker_hours[chosen] += block_length
                        i += block_length
                        break
                else:
                    # Nobody could start a valid block here; leave the slot empty.
                    i += 1

        return solution

    # ------------------------------------------------------------------
    # Local-search moves: each mutates `neighbor` in place.
    # Each returns False only when the move had nothing to try and the
    # iteration should be skipped without evaluating.
    # ------------------------------------------------------------------
    def _move_swap(self, neighbor: np.ndarray) -> bool:
        """Swap the workers of two random assigned slots if both are available for the other's slot."""
        assigned = self._assigned_indices(neighbor)
        if len(assigned) >= 2:
            idx1, idx2 = random.sample(assigned, 2)
            slot1, slot2 = self.env.shift_slots[idx1], self.env.shift_slots[idx2]
            w1, w2 = neighbor[idx1], neighbor[idx2]
            if (self._is_available(w1, slot2.day, slot2.hour) and
                    self._is_available(w2, slot1.day, slot1.hour)):
                neighbor[idx1], neighbor[idx2] = w2, w1
        return True

    def _move_reassign_block(self, neighbor: np.ndarray) -> bool:
        """Hand one worker's slots in a random (day, shift_type) to another available worker."""
        assigned = self._assigned_indices(neighbor)
        if assigned:
            idx = random.choice(assigned)
            slot = self.env.shift_slots[idx]
            old_worker = neighbor[idx]
            available = [w for w in self.env.get_available_workers(slot.day, slot.hour)
                         if w != old_worker]
            if available:
                new_worker = random.choice(available)
                for i, s in enumerate(self.env.shift_slots):
                    if (s.day == slot.day and s.shift_type == slot.shift_type
                            and neighbor[i] == old_worker
                            and self._is_available(new_worker, s.day, s.hour)):
                        neighbor[i] = new_worker
        return True

    def _move_extend(self, neighbor: np.ndarray) -> bool:
        """Extend a random assignment by one adjacent empty hour in the same day/shift type."""
        assigned = self._assigned_indices(neighbor)
        if assigned:
            idx = random.choice(assigned)
            slot = self.env.shift_slots[idx]
            worker_id = neighbor[idx]
            for i, s in enumerate(self.env.shift_slots):
                if (s.day == slot.day and s.shift_type == slot.shift_type
                        and abs(s.hour - slot.hour) == 1 and neighbor[i] == -1
                        and self._is_available(worker_id, s.day, s.hour)):
                    if self._count_hours(neighbor, worker_id) < self.max_hours:
                        neighbor[i] = worker_id
                        break
        return True

    def _move_fill_gap(self, neighbor: np.ndarray) -> bool:
        """Fill a random empty slot (plus up to 3 following hours) with an available worker.

        Workers below the minimum-hours target are preferred. Returns False if
        nobody is available for the chosen slot.
        """
        empty = np.flatnonzero(neighbor == -1).tolist()
        if not empty:
            return True
        idx = random.choice(empty)
        slot = self.env.shift_slots[idx]
        available = self.env.get_available_workers(slot.day, slot.hour)
        worker_hours = self._get_worker_hours(neighbor)
        under_min = [w for w in available if worker_hours[w] < self.min_hours]
        if under_min:
            chosen = random.choice(under_min)
        elif available:
            chosen = random.choice(available)
        else:
            return False
        block_size = random.randint(2, 4)
        for i, s in enumerate(self.env.shift_slots):
            if (s.day == slot.day and s.shift_type == slot.shift_type
                    and slot.hour <= s.hour < slot.hour + block_size
                    and neighbor[i] == -1
                    and self._is_available(chosen, s.day, s.hour)
                    and self._count_hours(neighbor, chosen) < self.max_hours):
                neighbor[i] = chosen
        return True

    # ------------------------------------------------------------------
    # Search drivers
    # ------------------------------------------------------------------
    def _local_search(self, solution: np.ndarray, iterations: int, verbose: bool = True) -> np.ndarray:
        """Improve ``solution`` with randomized local moves and return the best schedule seen.

        Each iteration applies one random move to a copy of the current
        schedule. Improving neighbors are always accepted. Non-improving ones
        are accepted with probability 0.02. After more than 300 consecutive
        non-improving iterations, the current schedule is replaced by a fresh
        randomized greedy build. The search stops early once ``max_time``,
        measured from ``self.start_time``, is exceeded.

        Args:
            solution: Starting schedule (not modified).
            iterations: Maximum number of iterations.
            verbose: Accepted for interface compatibility; currently unused.

        Returns:
            The lowest-penalty schedule encountered.
        """
        moves = {
            'swap': self._move_swap,
            'reassign_block': self._move_reassign_block,
            'extend': self._move_extend,
            'fill_gap': self._move_fill_gap,
        }
        move_names = list(moves)

        current = solution.copy()
        current_penalty = self.env.evaluate_schedule(current)[0]
        best = current.copy()
        best_penalty = current_penalty

        no_improvement_count = 0
        max_no_improvement = 300

        for _ in range(iterations):
            if time.time() - self.start_time > self.max_time:
                break

            neighbor = current.copy()
            if not moves[random.choice(move_names)](neighbor):
                # fill_gap found nobody available; skip evaluation this round.
                continue

            neighbor_penalty = self.env.evaluate_schedule(neighbor)[0]
            self.nodes_explored += 1

            if neighbor_penalty < current_penalty:
                current = neighbor
                current_penalty = neighbor_penalty
                no_improvement_count = 0
                if current_penalty < best_penalty:
                    best = current.copy()
                    best_penalty = current_penalty
                    self.improvements += 1
            else:
                no_improvement_count += 1
                # Occasionally accept a non-improving move to escape local minima.
                if random.random() < 0.02:
                    current = neighbor
                    current_penalty = neighbor_penalty

            if no_improvement_count > max_no_improvement:
                # Stagnated: restart from a randomized greedy build (best is kept).
                current = self._build_greedy_solution(randomize=True)
                current_penalty = self.env.evaluate_schedule(current)[0]
                no_improvement_count = 0

        return best

    def solve(self, verbose: bool = True) -> Tuple[Optional[np.ndarray], float, Dict]:
        """Run ``num_restarts`` rounds of greedy construction + local search.

        Restart 0 uses the deterministic greedy build and later restarts use the
        randomized one. The local-search budget is split evenly across
        restarts. No new restart begins once ``max_time`` seconds have elapsed.

        Args:
            verbose: Print progress and a final violation summary.

        Returns:
            ``(best_solution, best_penalty, stats)``.
            ``best_solution`` is ``None`` and ``best_penalty`` is ``inf`` if no
            restart ran (e.g. ``num_restarts <= 0`` or the time budget was
            already spent).
            ``stats`` has ``nodes_explored``, ``improvements``, ``time``
            (seconds) and ``success`` (whether a solution was produced).
        """
        self.start_time = time.time()
        self.nodes_explored = 0
        self.improvements = 0
        # Reset so repeated solve() calls report only this run's result.
        self.best_solution = None
        self.best_penalty = float('inf')

        if verbose:
            print(f"CSP: max_time={self.max_time/60:.0f}min, iterations={self.local_search_iterations:,}")

        # max(..., 1) avoids ZeroDivisionError; the loop below simply doesn't run.
        iterations_per_restart = self.local_search_iterations // max(self.num_restarts, 1)

        for restart in range(self.num_restarts):
            if time.time() - self.start_time > self.max_time:
                break

            if verbose:
                print(f"  Restart {restart + 1}/{self.num_restarts}")

            initial_solution = self._build_greedy_solution(randomize=(restart > 0))
            candidate = self._local_search(initial_solution, iterations_per_restart, verbose)
            candidate_penalty = self.env.evaluate_schedule(candidate)[0]

            if candidate_penalty < self.best_penalty:
                self.best_penalty = candidate_penalty
                self.best_solution = candidate.copy()
                if verbose:
                    print(f"    New best: {self.best_penalty:.2f}")

        elapsed_time = time.time() - self.start_time
        stats = {'nodes_explored': self.nodes_explored, 'improvements': self.improvements,
                 'time': elapsed_time, 'success': self.best_solution is not None}

        if verbose:
            print(f"\nCSP completed in {elapsed_time:.1f}s ({elapsed_time/60:.1f} min)")
            print(f"Best penalty: {self.best_penalty:.2f}")
            if self.best_solution is not None:
                _, details = self.env.evaluate_schedule(self.best_solution)
                print(f"Violations: {details}")

        return self.best_solution, self.best_penalty, stats

print("Optimized CSP Solver defined!")

## Cell 9: Run Algorithms in Parallel (enable or disable each with RUN_GA / RUN_SA / RUN_CSP)

In [None]:
# ============================================================================
# PARALLEL EXECUTION WITH GPU ACCELERATION
# ============================================================================
# Launches every enabled algorithm at once on a thread pool and gathers the
# outputs into `results`. Threads share the GIL, so wall-clock time approaches
# max(GA, SA, CSP) only to the extent the algorithms release it (e.g. GPU work).

from concurrent.futures import ThreadPoolExecutor, as_completed

# Per-algorithm on/off switches, consulted when this cell builds its task list
RUN_GA, RUN_SA, RUN_CSP = True, True, True

# Filled by the executor loop: algo name -> solution/penalty/history/time/details
results = {}

# ============================================================================
# ALGORITHM RUNNER FUNCTIONS
# ============================================================================

def run_ga(workers_data, schedule_type, config):
    """Execute the genetic algorithm on its own environment (thread-pool entry point).

    A fresh SchedulingEnvironment is constructed for every call.

    Args:
        workers_data: Worker records passed to SchedulingEnvironment.
        schedule_type: Schedule type passed to SchedulingEnvironment.
        config: Mapping with population_size, generations, crossover_rate,
            mutation_rate, elitism_count and batch_size.

    Returns:
        dict with keys algo, solution, penalty, history, time, details.
    """
    ga_env = SchedulingEnvironment(workers_data, schedule_type=schedule_type)
    print(f"\n{'='*60}\nSTARTING GENETIC ALGORITHM\n{'='*60}")

    hyperparams = {key: config[key] for key in ('population_size', 'generations',
                                                'crossover_rate', 'mutation_rate',
                                                'elitism_count', 'batch_size')}
    ga = GeneticAlgorithm(ga_env, **hyperparams)

    t0 = time.time()
    solution, fitness, history = ga.solve(verbose=True)
    elapsed = time.time() - t0

    _, details = ga_env.evaluate_schedule(solution)
    print(f"\n*** GA COMPLETE: {fitness:.1f} penalty in {elapsed/60:.1f} min ***")

    return dict(algo='GA', solution=solution, penalty=fitness,
                history=history, time=elapsed, details=details)

def run_sa(workers_data, schedule_type, config):
    """Execute simulated annealing on its own environment (thread-pool entry point).

    A fresh SchedulingEnvironment is constructed for every call.

    Args:
        workers_data: Worker records passed to SchedulingEnvironment.
        schedule_type: Schedule type passed to SchedulingEnvironment.
        config: Mapping with initial_temp, final_temp, cooling_rate,
            iterations_per_temp and batch_eval.

    Returns:
        dict with keys algo, solution, penalty, history, time, details.
    """
    sa_env = SchedulingEnvironment(workers_data, schedule_type=schedule_type)
    print(f"\n{'='*60}\nSTARTING SIMULATED ANNEALING\n{'='*60}")

    sa_params = {key: config[key] for key in ('initial_temp', 'final_temp',
                                              'cooling_rate', 'iterations_per_temp',
                                              'batch_eval')}
    sa = SimulatedAnnealing(sa_env, **sa_params)

    t0 = time.time()
    solution, cost, history = sa.solve(verbose=True)
    elapsed = time.time() - t0

    _, details = sa_env.evaluate_schedule(solution)
    print(f"\n*** SA COMPLETE: {cost:.1f} penalty in {elapsed/60:.1f} min ***")

    return dict(algo='SA', solution=solution, penalty=cost,
                history=history, time=elapsed, details=details)

def run_csp(workers_data, schedule_type, config):
    """Run the CSP solver on its own environment (thread-pool entry point).

    Args:
        workers_data: Worker records passed to SchedulingEnvironment.
        schedule_type: Schedule type passed to SchedulingEnvironment.
        config: Mapping with max_time, local_search_iterations and batch_size.
            An optional num_restarts key is forwarded when present (it used to
            be ignored).

    Returns:
        dict with keys algo, solution, penalty, history (always empty for
        CSP), time, details.

    Raises:
        RuntimeError: If the solver produced no solution, which happens when
            the time budget expires before any restart runs. Raising here lets
            the caller report the failure instead of crashing inside
            evaluate_schedule(None).
    """
    local_env = SchedulingEnvironment(workers_data, schedule_type=schedule_type)
    print(f"\n{'='*60}\nSTARTING CSP SOLVER\n{'='*60}")

    csp_kwargs = {
        'max_time': config['max_time'],
        'local_search_iterations': config['local_search_iterations'],
        'batch_size': config['batch_size'],
    }
    if 'num_restarts' in config:
        csp_kwargs['num_restarts'] = config['num_restarts']
    csp = CSPSolver(local_env, **csp_kwargs)

    start = time.time()
    solution, penalty, stats = csp.solve(verbose=True)
    elapsed = time.time() - start

    if solution is None:
        raise RuntimeError(
            f"CSP produced no solution (max_time={config['max_time']}s, "
            f"elapsed={elapsed:.1f}s)")

    _, details = local_env.evaluate_schedule(solution)
    print(f"\n*** CSP COMPLETE: {penalty:.1f} penalty in {elapsed/60:.1f} min ***")

    return {'algo': 'CSP', 'solution': solution, 'penalty': penalty,
            'history': [], 'time': elapsed, 'details': details}

# ============================================================================
# EXECUTE IN PARALLEL
# ============================================================================
# NOTE: these are threads, not processes. Pure-Python work (e.g. the CSP local
# search loops) holds the GIL, so CPU-bound phases mostly take turns. Only work
# that releases the GIL (GPU kernels, many NumPy/PyTorch ops) truly overlaps.
# Each algorithm's elapsed time is measured while competing for the GIL, so
# the "sequential time" and "speedup" figures below likely overstate the gain.
import traceback

print("="*70)
print(f"PARALLEL EXECUTION - GPU: {'ENABLED' if USE_GPU else 'DISABLED'}")
print(f"Profile: {COMPUTE_PROFILE.upper()}")
print("="*70)

parallel_start = time.time()

# (name, runner, config) for every enabled algorithm
tasks = []
if RUN_GA:
    tasks.append(('GA', run_ga, GA_CONFIG))
if RUN_SA:
    tasks.append(('SA', run_sa, SA_CONFIG))
if RUN_CSP:
    tasks.append(('CSP', run_csp, CSP_CONFIG))

print(f"Running {len(tasks)} algorithms in parallel...")

if not tasks:
    # ThreadPoolExecutor(max_workers=0) would raise ValueError.
    print("No algorithms enabled - set RUN_GA / RUN_SA / RUN_CSP to True.")
else:
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {executor.submit(func, workers, env.schedule_type, cfg): name
                   for name, func, cfg in tasks}

        for future in as_completed(futures):
            algo = futures[future]
            try:
                result = future.result()
            except Exception as e:
                # Report this algorithm's failure and keep collecting the others.
                print(f"ERROR in {algo}: {e}")
                traceback.print_exc()
                continue
            results[result['algo']] = {
                key: result[key]
                for key in ('solution', 'penalty', 'history', 'time', 'details')
            }

parallel_time = time.time() - parallel_start

# ============================================================================
# SUMMARY
# ============================================================================
print("\n" + "="*70)
print("EXECUTION SUMMARY")
print("="*70)
print(f"Total wall-clock time: {parallel_time/60:.1f} minutes")

if results:
    seq_time = sum(r['time'] for r in results.values())
    speedup = seq_time / parallel_time if parallel_time > 0 else 1
    print(f"Sequential time would be: {seq_time/60:.1f} minutes")
    print(f"Parallel speedup: {speedup:.2f}x")

    print(f"\nResults:")
    for algo in ['GA', 'SA', 'CSP']:
        if algo in results:
            r = results[algo]
            print(f"  {algo}: Penalty={r['penalty']:.1f}, Time={r['time']/60:.1f}min")

## Cell 10: Results Summary and Comparison

In [None]:
print("="*80)
print("ALGORITHM COMPARISON SUMMARY")
print("="*80)

# Without this, min() below fails with an opaque "empty sequence" ValueError.
if not results:
    raise RuntimeError("No algorithm results found - run Cell 9 first.")

# Fixed display order, limited to algorithms that actually produced results
display_order = [algo for algo in ('GA', 'SA', 'CSP') if algo in results]

# Summary table
print(f"\n{'Algorithm':<15} {'Penalty':<15} {'Runtime':<15} {'Status'}")
print("-"*60)

best_algo = min(results, key=lambda a: results[a]['penalty'])

for algo in display_order:
    r = results[algo]
    runtime_str = f"{r['time']:.1f}s" if r['time'] < 60 else f"{r['time']/60:.1f}min"
    status = "<-- BEST" if algo == best_algo else ""
    print(f"{algo:<15} {r['penalty']:<15.2f} {runtime_str:<15} {status}")

# Detailed violations
print(f"\n{'='*80}")
print("CONSTRAINT VIOLATIONS BREAKDOWN")
print("="*80)

constraints = [
    'coverage_violations', 'worker_conflicts', 'hour_violations',
    'min_hour_violations', 'shift_length_violations', 'tier_mismatches',
    'fairness_violations', 'morning_shift_violations'
]

print(f"\n{'Constraint':<30}", end="")
for algo in display_order:
    print(f"{algo:<10}", end="")
print()
print("-"*60)

for constraint in constraints:
    print(f"{constraint:<30}", end="")
    for algo in display_order:
        # Missing keys are shown as 0
        val = results[algo]['details'].get(constraint, 0)
        print(f"{val:<10}", end="")
    print()

# Best solution details
print(f"\n{'='*80}")
print(f"BEST SOLUTION: {best_algo} (Penalty: {results[best_algo]['penalty']:.2f})")
print("="*80)
print("="*80)

## Cell 11: Visualization

In [None]:
# ============================================================================
# PROFESSIONAL VISUALIZATIONS WITH SEABORN
# ============================================================================

import pandas as pd

# Fail clearly instead of with min() on an empty dict
if not results:
    raise RuntimeError("No algorithm results found - run Cell 9 first.")

# Determine best algorithm
best_algo = min(results.keys(), key=lambda a: results[a]['penalty'])

# Create figure with publication-quality settings
fig = plt.figure(figsize=(16, 12))

# Color palette
palette = sns.color_palette("husl", 3)
algo_colors = {'GA': palette[0], 'SA': palette[1], 'CSP': palette[2]}

# ============================================================================
# Plot 1: Algorithm Performance Comparison (Bar Chart)
# ============================================================================
ax1 = fig.add_subplot(2, 2, 1)

algos = list(results.keys())
penalties = [results[a]['penalty'] for a in algos]
colors = [algo_colors[a] for a in algos]
# Scale for label offsets and axis limit. If every penalty is 0, fall back to 1
# so the axis is not collapsed to (0, 0).
penalty_scale = max(penalties) or 1

bars = ax1.bar(algos, penalties, color=colors, edgecolor='black', linewidth=1.2)

# Highlight best algorithm
for i, (algo, bar) in enumerate(zip(algos, bars)):
    if algo == best_algo:
        bar.set_edgecolor('gold')
        bar.set_linewidth(3)
        ax1.annotate('BEST', xy=(i, penalties[i]), xytext=(0, 10),
                     textcoords='offset points', ha='center', fontweight='bold',
                     color='darkgreen', fontsize=11)

# Add value labels
for i, penalty in enumerate(penalties):
    ax1.text(i, penalty + penalty_scale*0.02, f'{penalty:.0f}',
             ha='center', va='bottom', fontweight='bold', fontsize=12)

ax1.set_ylabel('Penalty Score (lower is better)', fontsize=12)
ax1.set_title('Algorithm Performance Comparison', fontsize=14, fontweight='bold')
ax1.set_ylim(0, penalty_scale * 1.15)
sns.despine(ax=ax1)

# ============================================================================
# Plot 2: Runtime Comparison
# ============================================================================
ax2 = fig.add_subplot(2, 2, 2)

runtimes_min = [results[a]['time'] / 60 for a in algos]
runtime_scale = max(runtimes_min) or 1  # same zero guard as Plot 1

bars2 = ax2.barh(algos, runtimes_min, color=colors, edgecolor='black', linewidth=1.2)

for i, rt in enumerate(runtimes_min):
    ax2.text(rt + runtime_scale*0.02, i, f'{rt:.1f} min',
             va='center', fontweight='bold', fontsize=11)

ax2.set_xlabel('Runtime (minutes)', fontsize=12)
ax2.set_title('Algorithm Runtime Comparison', fontsize=14, fontweight='bold')
ax2.set_xlim(0, runtime_scale * 1.2)
sns.despine(ax=ax2)

# ============================================================================
# Plot 3: Convergence History
# ============================================================================
ax3 = fig.add_subplot(2, 2, 3)

for algo in algos:
    full_history = results[algo]['history']
    # len() rather than truthiness so array-valued histories also work
    if len(full_history) == 0:
        continue
    # Subsample long histories (step = len // 2000); x keeps the original indices
    step = max(len(full_history) // 2000, 1)
    history = full_history[::step]
    x = list(range(0, len(full_history), step))

    ax3.plot(x, history, label=f'{algo} (final: {results[algo]["penalty"]:.0f})',
             color=algo_colors[algo], linewidth=2, alpha=0.9)

ax3.set_xlabel('Iteration', fontsize=12)
ax3.set_ylabel('Best Penalty (log scale)', fontsize=12)
ax3.set_title('Optimization Convergence', fontsize=14, fontweight='bold')
ax3.set_yscale('log')
ax3.legend(loc='upper right', frameon=True, fancybox=True, shadow=True)
ax3.grid(True, alpha=0.3)
sns.despine(ax=ax3)

# ============================================================================
# Plot 4: Constraint Violations Heatmap
# ============================================================================
ax4 = fig.add_subplot(2, 2, 4)

constraints = ['coverage_violations', 'worker_conflicts', 'hour_violations',
               'min_hour_violations', 'shift_length_violations', 'tier_mismatches',
               'morning_shift_violations']

# Build data matrix (missing keys count as 0)
violation_data = [[results[algo]['details'].get(c, 0) for c in constraints]
                  for algo in algos]

# Clean constraint names for display
clean_names = [c.replace('_', ' ').title() for c in constraints]

# Create DataFrame
df_violations = pd.DataFrame(violation_data, index=algos, columns=clean_names)

# fmt='g' (not 'd') so float-valued violation metrics don't raise a format error
sns.heatmap(df_violations, annot=True, fmt='g', cmap='RdYlGn_r',
            linewidths=0.5, ax=ax4, cbar_kws={'label': 'Violation Count'},
            annot_kws={'size': 11, 'weight': 'bold'})

ax4.set_title('Constraint Violations by Algorithm', fontsize=14, fontweight='bold')
ax4.set_xticklabels(ax4.get_xticklabels(), rotation=45, ha='right', fontsize=10)
ax4.set_yticklabels(ax4.get_yticklabels(), rotation=0, fontsize=11)

plt.tight_layout()
plt.savefig('algorithm_comparison.png', dpi=200, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()

print(f"\nVisualization saved to: algorithm_comparison.png")

# ============================================================================
# Additional Summary Statistics
# ============================================================================
print("\n" + "="*70)
print("DETAILED RESULTS SUMMARY")
print("="*70)

print(f"\n{'Algorithm':<10} {'Penalty':<12} {'Runtime':<12} {'Violations':<15} {'Status'}")
print("-"*60)

for algo in algos:
    r = results[algo]
    total_violations = sum(r['details'].values())
    status = "*** BEST ***" if algo == best_algo else ""
    # Pad the whole "X.Ymin" string so the unit sits next to the number
    runtime_str = f"{r['time']/60:.1f}min"
    print(f"{algo:<10} {r['penalty']:<12.1f} {runtime_str:<12} {total_violations:<15} {status}")

print(f"\nBest Solution: {best_algo} with penalty {results[best_algo]['penalty']:.1f}")

## Cell 12: Export Best Schedule

In [None]:
def export_schedule(solution, env, filename_prefix='best_schedule'):
    """Write a schedule and its penalty breakdown to a timestamped JSON file.

    Args:
        solution: Sequence of worker ids, one per ``env.shift_slots`` entry.
            ``-1`` marks an unassigned slot. NumPy arrays are accepted; their
            integer scalars are converted to plain ``int`` before serializing.
        env: Provides ``shift_slots`` (day, hour, shift_type), ``workers``
            (worker_id, name, tier, desired_hours) and
            ``evaluate_schedule(solution) -> (penalty, details)``.
        filename_prefix: Output name prefix. ``_YYYYmmdd_HHMMSS.json`` is appended.

    Returns:
        The name of the file written, relative to the current directory.

    Raises:
        TypeError: If the output contains a value JSON cannot encode (other
            than NumPy scalars/arrays, which are converted). Nothing is
            written in that case.
    """
    from collections import Counter

    day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday']

    # id -> worker, first occurrence wins (matches a linear first-match search)
    workers_by_id = {}
    for worker in env.workers:
        workers_by_id.setdefault(worker.worker_id, worker)

    # Build schedule structure
    schedule_by_day = {day: [] for day in day_names}
    for slot_idx, raw_id in enumerate(solution):
        # int(): NumPy integer scalars are not JSON serializable
        worker_id = int(raw_id)
        if worker_id == -1:
            continue
        worker = workers_by_id.get(worker_id)
        if worker is None:
            continue
        slot = env.shift_slots[slot_idx]
        schedule_by_day[day_names[slot.day]].append({
            'hour': slot.hour,
            'shift_type': slot.shift_type,
            'worker_id': worker_id,
            'worker_name': worker.name,
            'tier': worker.tier,
        })

    # Sort by hour
    for entries in schedule_by_day.values():
        entries.sort(key=lambda x: (x['hour'], x['shift_type']))

    # Assigned hours per worker, counted in a single pass over the solution
    assigned_counts = Counter(int(w) for w in solution)
    worker_hours = {}
    for worker in env.workers:
        worker_hours[worker.name] = {
            'assigned': assigned_counts[worker.worker_id],
            'desired': worker.desired_hours,
            'tier': worker.tier,
        }

    # Get penalty details
    penalty, details = env.evaluate_schedule(solution)

    # One timestamp so 'generated' and the file name agree
    now = datetime.now()
    output = {
        'generated': now.isoformat(),
        'penalty': penalty,
        'violations': details,
        'schedule': schedule_by_day,
        'worker_hours': worker_hours,
    }

    def _to_builtin(obj):
        """json fallback: unwrap NumPy scalars/arrays and reject anything else."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Serialize first so an encoding error can't leave a truncated file behind
    text = json.dumps(output, indent=2, default=_to_builtin)

    filename = f"{filename_prefix}_{now.strftime('%Y%m%d_%H%M%S')}.json"
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(text)

    print(f"Schedule exported to: {filename}")
    return filename

# Export the lowest-penalty solution (best_algo is chosen in the comparison cells)
best_solution = results[best_algo]['solution']
export_schedule(best_solution, env, filename_prefix=f'best_schedule_{best_algo}')

## Cell 13: Display Human-Readable Schedule

In [None]:
def print_schedule(solution, env):
    """Pretty-print ``solution`` day by day, then a per-worker hours table.

    Args:
        solution: Sequence of worker ids aligned with ``env.shift_slots``.
            ``-1`` marks an unassigned slot.
        env: Provides ``shift_slots`` (day, hour, shift_type) and ``workers``
            (worker_id, name, tier, desired_hours).
    """
    day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday']

    def find_worker(wid):
        # First worker whose id matches, or None
        for candidate in env.workers:
            if candidate.worker_id == wid:
                return candidate
        return None

    # (day, hour, shift_type) -> ["Name (T<tier>)", ...]
    grouped = {}
    for idx, worker_id in enumerate(solution):
        if worker_id == -1:
            continue
        slot = env.shift_slots[idx]
        # The key is created even when the worker id is unknown, so that hour still prints
        bucket = grouped.setdefault((slot.day, slot.hour, slot.shift_type), [])
        worker = find_worker(worker_id)
        if worker:
            bucket.append(f"{worker.name} (T{worker.tier})")

    banner = "="*80
    print("\n" + banner)
    print("BEST SCHEDULE")
    print(banner)

    for day_idx, day_name in enumerate(day_names):
        print(f"\n{day_name}:")
        print("-"*70)

        day_hours = sorted({hour for (day, hour, _kind) in grouped if day == day_idx})
        for hour in day_hours:
            window = grouped.get((day_idx, hour, 'Window'), [])
            remote = grouped.get((day_idx, hour, 'Remote'), [])
            window_txt = ', '.join(window) if window else '---'
            remote_txt = ', '.join(remote) if remote else '---'

            print(f"  {hour:02d}:00-{hour+1:02d}:00 | Window: {window_txt:<40}")
            print(f"               | Remote: {remote_txt}")

    # Worker summary
    print("\n" + banner)
    print("WORKER HOURS SUMMARY")
    print(banner)
    print(f"{'Worker':<25} {'Assigned':<10} {'Desired':<10} {'Diff':<10}")
    print("-"*55)

    for worker in env.workers:
        assigned = sum(1 for w in solution if w == worker.worker_id)
        diff = assigned - worker.desired_hours
        if diff > 0:
            diff_str = f"+{diff:.0f}"
        else:
            diff_str = f"{diff:.0f}"
        print(f"{worker.name:<25} {assigned:<10} {worker.desired_hours:<10.0f} {diff_str:<10}")

print_schedule(solution=best_solution, env=env)

## Cell 14: Hyperparameter Grid Search (Optional - Extended Run)

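Use this to tune the Genetic Algorithm over a small grid of population size, generations, and mutation rate (2 × 2 × 2 = 8 configurations). Warning: each configuration is a full GA run, so this can take a very long time!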

In [None]:
# ============================================================================
# OPTIONAL: GPU-Accelerated Hyperparameter Grid Search
# ============================================================================
# Set to True only if you want to tune hyperparameters (adds significant time)

RUN_GRID_SEARCH = False

if RUN_GRID_SEARCH:
    import itertools
    import pandas as pd  # re-imported so this cell does not depend on Cell 11 having run

    print("="*60)
    print("GPU-ACCELERATED HYPERPARAMETER GRID SEARCH")
    print("="*60)
    print(f"GPU: {GPU_NAME if USE_GPU else 'Not available'}")

    # Quick grid search - test key parameters
    ga_grid = {
        'population_size': [500, 1000],
        'generations': [5000, 10000],
        'mutation_rate': [0.35, 0.45]
    }

    # Tune on the same schedule type the main run (Cell 9) uses, so the winning
    # configuration applies to it (this was previously hard-coded to 'finals').
    grid_schedule_type = env.schedule_type

    # Cartesian product in nested-loop order (mutation rate varies fastest)
    grid_configs = list(itertools.product(ga_grid['population_size'],
                                          ga_grid['generations'],
                                          ga_grid['mutation_rate']))
    total_configs = len(grid_configs)

    grid_results = []
    grid_start = time.time()

    for config_num, (pop, gens, mut) in enumerate(grid_configs, start=1):
        print(f"\n[{config_num}/{total_configs}] pop={pop}, gens={gens}, mut={mut}")

        test_env = SchedulingEnvironment(workers, schedule_type=grid_schedule_type)
        ga = GeneticAlgorithm(
            test_env,
            population_size=pop,
            generations=gens,
            mutation_rate=mut,
            elitism_count=max(10, pop//20),
            batch_size=pop
        )

        start = time.time()
        solution, penalty, _ = ga.solve(verbose=False)
        runtime = time.time() - start

        grid_results.append({
            'population_size': pop,
            'generations': gens,
            'mutation_rate': mut,
            'penalty': penalty,
            'runtime': runtime
        })
        print(f"  -> Penalty: {penalty:.1f}, Time: {runtime:.1f}s")

    grid_time = time.time() - grid_start

    # Find best configuration
    best_config = min(grid_results, key=lambda x: x['penalty'])

    print(f"\n{'='*60}")
    print("GRID SEARCH RESULTS")
    print(f"{'='*60}")
    print(f"Total time: {grid_time/60:.1f} minutes")
    print(f"\nBest Configuration:")
    print(f"  Population Size: {best_config['population_size']}")
    print(f"  Generations: {best_config['generations']}")
    print(f"  Mutation Rate: {best_config['mutation_rate']}")
    print(f"  Best Penalty: {best_config['penalty']:.1f}")
    print(f"  Runtime: {best_config['runtime']:.1f}s")

    # Visualize results
    fig, ax = plt.subplots(figsize=(10, 6))

    df_grid = pd.DataFrame(grid_results)
    # Labels come from the raw dicts. df.apply(axis=1) would upcast the int
    # columns to float, giving labels like "p500.0_g5000.0".
    df_grid['config'] = [
        f"p{r['population_size']}_g{r['generations']}_m{r['mutation_rate']}"
        for r in grid_results
    ]

    # Separate name so Cell 11's `colors` / `bars` are not overwritten
    bar_colors = ['green' if r['penalty'] == best_config['penalty'] else 'steelblue'
                  for r in grid_results]

    ax.bar(range(len(df_grid)), df_grid['penalty'], color=bar_colors)
    ax.set_xticks(range(len(df_grid)))
    ax.set_xticklabels(df_grid['config'], rotation=45, ha='right')
    ax.set_ylabel('Penalty')
    ax.set_title('Hyperparameter Grid Search Results')
    sns.despine()
    plt.tight_layout()
    plt.savefig('grid_search_results.png', dpi=150)
    plt.show()

else:
    print("Grid search disabled. Set RUN_GRID_SEARCH = True to enable.")