In [9]:
# -----------------------------------------------------------------------------
# "Thinking in balance: Framework for cognitive load distribution in human-AI collaborative teams" 
# 
# Cognitive Load Balancing Simulation (backend only)
#
# -----------------------------------------------------------------------------

import time
import itertools
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional, Dict

# ------------------------------- Configuration -------------------------------

# Reproducibility: the main cell seeds the global ``random`` module with this
# value, while make_tasks uses its own generator seeded with RANDOM_SEED + 1.
RANDOM_SEED: int = 42

# Soft cap on one user's cognitive load. Normal routing respects it, but a
# task is still forced onto the least-loaded user when nobody has capacity.
MAX_LOAD: float = 10.0

# Simulation size and pacing.
TASK_COUNT: int = 20                   # number of tasks generated
ASSIGN_BATCH_SIZE: int = 5             # tasks assigned per round before rebalancing
COMPLETE_PER_USER_PER_ROUND: int = 1   # tasks each user finishes per round
PRINT_EVENTS: bool = True              # print assignment / rebalance / completion events

# Task categories, one per BDI thinking style (styles 'A'..'D', in this order).
TASK_TYPES = ["analytical", "sequential", "interpersonal", "imaginative"]

# Latency proxy multiplier: simulated latency ~ effective load * this value.
BASE_TIME_PER_LOAD: float = 1.0

# Behavioral error proxy as (load threshold, error probability) pairs, checked
# top-down: the first band whose threshold is strictly below the user's
# current load supplies the chance of an error. Keep thresholds descending.
ERROR_PROB_BANDS = [
    (8.0, 0.30),
    (5.0, 0.10),
    (0.0, 0.02),
]

# ------------------------------- Data Models ---------------------------------

@dataclass
class Task:
    """One unit of work routed to a user by the coordinator.

    ``status`` moves from "pending" to "assigned" to "completed" as the task is
    handled; ``reassignments`` counts how many times it was moved between users.
    """
    id: int
    task_type: str                         # one of TASK_TYPES
    intrinsic_load: float                  # baseline cognitive effort, before skill adjustment
    assigned_to: Optional["User"] = None   # current owner; None while unassigned or once completed
    status: str = "pending"                # "pending" | "assigned" | "completed"
    reassignments: int = 0                 # moves between users (simple switching metric)

    def __repr__(self):
        load_text = f"{self.intrinsic_load:.2f}"
        return f"Task#{self.id}({self.task_type}, L={load_text}, {self.status})"


@dataclass
class User:
    """A human team member with a BDI thinking style and a running cognitive load.

    ``skill_factor`` scales a task's intrinsic load by task type; when it is not
    supplied, the type matching the user's style gets 0.8 and all others 1.0.
    """
    name: str
    style: str                      # 'A' | 'B' | 'C' | 'D'
    current_load: float = 0.0
    # Per-task-type multiplier on intrinsic load (good fit < 1.0).
    skill_factor: Dict[str, float] = field(default_factory=dict)
    # Tasks assigned to this user and not yet completed.
    active_tasks: List[Task] = field(default_factory=list)
    # Task types of recently accepted tasks (task-switching proxy).
    recent_types: List[str] = field(default_factory=list)

    def __post_init__(self):
        # An explicitly provided (non-empty) mapping is kept untouched.
        if self.skill_factor:
            return
        style_for_type = {
            "analytical": 'A',
            "sequential": 'B',
            "interpersonal": 'C',
            "imaginative": 'D',
        }
        self.skill_factor = {
            task_type: (0.8 if self.style == preferred else 1.0)
            for task_type, preferred in style_for_type.items()
        }

    def effective_load(self, task: Task) -> float:
        """Return *task*'s intrinsic load scaled by this user's fit (unknown types: 1.0)."""
        factor = self.skill_factor.get(task.task_type, 1.0)
        return task.intrinsic_load * factor

# ------------------------------- Agent Layer ---------------------------------

class ManagerAI:
    """Per-user manager that reports local state and simple heuristics.

    Accepting and releasing tasks through this class keeps the user's
    ``current_load`` in step with ``active_tasks``.
    """
    def __init__(self, user: User):
        self.user = user

    def report_load(self) -> float:
        """Return the managed user's current cognitive load."""
        return self.user.current_load

    def can_take_task(self, task: Task) -> bool:
        """True if adding *task* keeps the user at or below ``MAX_LOAD``."""
        load_after = self.user.current_load + self.user.effective_load(task)
        return load_after <= MAX_LOAD

    def accept_task(self, task: Task):
        """Add *task* to the user, raise their load, and record its type."""
        owner = self.user
        owner.active_tasks.append(task)
        owner.current_load += owner.effective_load(task)
        # Switching proxy: keep only a short window (5) of recent task types.
        history = owner.recent_types
        history.append(task.task_type)
        if len(history) > 5:
            del history[0]

    def release_task(self, task: Task):
        """Remove *task* from the user (if held) and lower their load, clamped at 0."""
        owner = self.user
        if task not in owner.active_tasks:
            return
        owner.active_tasks.remove(task)
        remaining = owner.current_load - owner.effective_load(task)
        # Float rounding can dip just below zero; never report a negative load.
        if remaining < 0:
            remaining = 0.0
        owner.current_load = remaining

class CoordinatorAI:
    """Global coordinator that assigns, monitors, and rebalances tasks.

    It keeps one ``ManagerAI`` per user, keyed by ``User.name``, and performs
    every accept/release through that manager so a user's ``current_load`` is
    updated together with their ``active_tasks``.
    """
    def __init__(self, users: List[User]):
        self.users = users
        # NOTE(review): keyed by name, so user names are assumed to be unique.
        self.managers: Dict[str, ManagerAI] = {u.name: ManagerAI(u) for u in users}

    # -------------------------- Assignment & Scoring --------------------------

    @staticmethod
    def _projected_load(user: User, task: Task) -> float:
        """Return the load *user* would carry with *task* assigned to them.

        A task the user already holds is already included in ``current_load``.
        Adding its effective load again would double-count it and make the
        current owner look busier than they are during rebalancing.
        """
        if task.assigned_to is user:
            return user.current_load
        return user.current_load + user.effective_load(task)

    def find_best_user_for(self, task: Task) -> Optional[User]:
        """Choose the user with the lowest projected load after assignment.

        Only users whose projected load stays within ``MAX_LOAD`` qualify. Ties
        on projected load go to the better style-task fit (lower effective
        load); a complete tie keeps the earliest user in ``self.users``.

        Returns:
            The chosen user, or ``None`` if nobody has capacity.
        """
        best_user = None
        best_key = None
        for u in self.users:
            projected = self._projected_load(u, task)
            # Written as ``not <=`` so a NaN projection is rejected as well.
            if not projected <= MAX_LOAD:
                continue
            key = (projected, u.effective_load(task))
            if best_key is None or key < best_key:
                best_key = key
                best_user = u
        return best_user

    def assign_task(self, task: Task, user: User):
        """Give *task* to *user*, releasing it from its previous owner if any.

        Moving a task between users increments ``task.reassignments``.
        Assigning a task to the user who already holds it does nothing.
        """
        if task.assigned_to is user:
            return
        if task.assigned_to is not None:
            # Reassignment: take it away from the previous owner first.
            prev = task.assigned_to
            self.managers[prev.name].release_task(task)
            task.reassignments += 1
            if PRINT_EVENTS:
                print(f"↪ Reassigned Task#{task.id} from {prev.name} to {user.name}")

        self.managers[user.name].accept_task(task)
        task.assigned_to = user
        task.status = "assigned"
        if PRINT_EVENTS:
            eff = user.effective_load(task)
            print(f"✓ Assigned Task#{task.id} ({task.task_type}, L={task.intrinsic_load:.2f}, eff={eff:.2f}) "
                  f"to {user.name}. Load={user.current_load:.2f}")

    # ---------------------------- Monitoring ---------------------------------

    def overloaded_users(self) -> List[User]:
        """Return users whose current load is strictly above ``MAX_LOAD``."""
        return [u for u in self.users if u.current_load > MAX_LOAD]

    def high_switching_users(self, threshold_unique_types: int = 3) -> List[User]:
        """Flag users whose recent-type window shows frequent task switching.

        A user is flagged when their ``recent_types`` holds at least 4 entries
        containing at least ``threshold_unique_types`` distinct task types.
        """
        flagged = []
        for u in self.users:
            if len(set(u.recent_types)) >= threshold_unique_types and len(u.recent_types) >= 4:
                flagged.append(u)
        return flagged

    # ----------------------------- Rebalancing --------------------------------

    def monitor_and_rebalance(self, tasks: List[Task]):
        """Run one rebalancing pass over all users.

        1. For each user above ``MAX_LOAD``, move their heaviest task (by
           effective load) to the best other user with capacity, if any.
        2. For each user flagged by ``high_switching_users``, keep the type of
           their most recently added active task and try to move every task of
           another type to a different user.

        Args:
            tasks: currently unused; kept for interface compatibility.
        """
        # 1) Handle hard overloads
        for u in self.overloaded_users():
            if PRINT_EVENTS:
                print(f"⚠ {u.name} OVERLOADED (Load={u.current_load:.2f}). Attempting to offload a task...")
            if not u.active_tasks:
                continue
            # Move the heaviest task (by this user's effective load) first.
            task_to_move = max(u.active_tasks, key=lambda t: u.effective_load(t))
            new_owner = self.find_best_user_for(task_to_move)
            if new_owner and new_owner is not u:
                self.assign_task(task_to_move, new_owner)

        # 2) Handle high task switching (reduce extraneous load)
        for u in self.high_switching_users():
            if PRINT_EVENTS:
                print(f"✱ {u.name} shows high task switching ({u.recent_types}). "
                      f"Trying to consolidate similar tasks elsewhere.")
            if not u.active_tasks:
                continue
            # Keep the type of the most recently added task; move the others if possible.
            keep_type = u.active_tasks[-1].task_type
            # Iterate over a copy: assign_task removes moved tasks from u.active_tasks.
            for t in list(u.active_tasks):
                if t.task_type != keep_type:
                    candidate = self.find_best_user_for(t)
                    # The current owner can win (it already carries t), so only move when
                    # someone else is strictly better.
                    if candidate and candidate is not u:
                        self.assign_task(t, candidate)

    # ----------------------------- Completion ---------------------------------

    def complete_some_tasks(self) -> Dict[str, List[Task]]:
        """Let each user complete up to ``COMPLETE_PER_USER_PER_ROUND`` tasks.

        Tasks are taken oldest-first from ``active_tasks``. For each one, a
        latency proxy and an error outcome are drawn from the global ``random``
        module, the task is released and marked "completed", and the outcome is
        printed when ``PRINT_EVENTS`` is set.

        Returns:
            Completed tasks grouped by user name.
        """
        completed: Dict[str, List[Task]] = defaultdict(list)
        for u in self.users:
            # choose tasks to complete: favor earliest assigned (FIFO)
            to_complete = u.active_tasks[:COMPLETE_PER_USER_PER_ROUND]
            for t in to_complete:
                before_load = u.current_load
                # latency proxy ~ effective load * base; add small randomness
                eff = u.effective_load(t)
                latency = eff * BASE_TIME_PER_LOAD * random.uniform(0.8, 1.2)

                # Error probability comes from the first band whose threshold the
                # pre-completion load exceeds; 0.02 if none match.
                error_prob = next((p for th, p in ERROR_PROB_BANDS if before_load > th), 0.02)
                did_error = random.random() < error_prob

                # Complete task: reduce load, mark done
                self.managers[u.name].release_task(t)
                t.assigned_to = None
                t.status = "completed"
                completed[u.name].append(t)

                if PRINT_EVENTS:
                    outcome = "ERROR" if did_error else "SUCCESS"
                    print(f"🏁 {u.name} completed Task#{t.id} [{t.task_type}] "
                          f"in {latency:.1f}s with {outcome}. New Load={u.current_load:.2f}")
        return completed

# ------------------------------- Simulation ----------------------------------

class Simulation:
    """Drive the assign -> rebalance -> complete loop until no work remains."""

    def __init__(self, users: List[User], tasks: List[Task]):
        self.users = users
        self.tasks = tasks
        self.coord = CoordinatorAI(users)
        # Running count of completed tasks per user name, filled in by ``run``.
        # Pre-seeded so every user appears in the summary, even with zero.
        self.completed_by_user: Dict[str, int] = defaultdict(int, {u.name: 0 for u in users})

    def run(self, assign_batch: int = ASSIGN_BATCH_SIZE, *, step_delay: float = 0.8):
        """Process tasks in rounds until nothing is pending or active.

        Each round does four things:

        1. Take up to ``assign_batch`` pending tasks and assign each to the
           best user. When nobody has capacity, force the task onto the
           least-loaded user.
        2. Rebalance.
        3. Let every user complete some tasks.
        4. Rebalance again.

        Args:
            assign_batch: maximum number of pending tasks assigned per round.
            step_delay: seconds to pause before each assignment so the event
                log can be followed live; ``0`` disables the pause.

        Raises:
            ValueError: if ``assign_batch`` is less than 1 (no task would ever
                be assigned and the loop would spin forever), or if tasks are
                pending but there are no users to take them.
        """
        if assign_batch < 1:
            raise ValueError(f"assign_batch must be at least 1, got {assign_batch}")

        pending = [t for t in self.tasks if t.status == "pending"]
        if pending and not self.users:
            raise ValueError("Cannot assign pending tasks: the simulation has no users.")

        while pending or any(u.active_tasks for u in self.users):
            # Assign the next batch, pausing before each step so events can be followed.
            for t in pending[:assign_batch]:
                if step_delay > 0:
                    time.sleep(step_delay)
                best = self.coord.find_best_user_for(t)
                if best is None:
                    # If no one has capacity, assign to the least-loaded anyway (soft overflow)
                    best = min(self.users, key=lambda u: u.current_load)
                    if PRINT_EVENTS:
                        print(f"… No capacity. Forcing assignment of Task#{t.id} to least-loaded {best.name}.")
                self.coord.assign_task(t, best)

            # Monitor & rebalance after the batch
            self.coord.monitor_and_rebalance(self.tasks)

            # Complete some tasks across users (tallying them), then monitor again
            for name, done in self.coord.complete_some_tasks().items():
                self.completed_by_user[name] += len(done)
            self.coord.monitor_and_rebalance(self.tasks)

            pending = [t for t in self.tasks if t.status == "pending"]

        if PRINT_EVENTS:
            print("\n✅ Simulation finished. All tasks completed.\n")
            self.report_summary()

    def report_summary(self):
        """Print completed/total tasks, reassignments, and per-type and per-user counts."""
        done = sum(1 for t in self.tasks if t.status == "completed")
        reassign_total = sum(t.reassignments for t in self.tasks)
        print(f"Tasks completed: {done}/{len(self.tasks)}")
        print(f"Total reassignments (load balancing actions): {reassign_total}")
        by_type = defaultdict(int)
        for t in self.tasks:
            by_type[t.task_type] += 1
        print("Tasks by type:", dict(by_type))
        print("Tasks completed by user:", dict(self.completed_by_user))

# ------------------------------- Utilities -----------------------------------

def make_users() -> List[User]:
    """Build the example team: one user per BDI style, A through D."""
    roster = [
        ("Alice", 'A'),  # Analytical
        ("Bob",   'B'),  # Sequential
        ("Cara",  'C'),  # Interpersonal
        ("Dan",   'D'),  # Imaginative
    ]
    return [User(member, style=letter) for member, letter in roster]

def make_tasks(n: int) -> List[Task]:
    """Generate *n* tasks with ids 0..n-1.

    Each task gets a random type from TASK_TYPES and an intrinsic load drawn
    uniformly between 1.0 and 5.0. A private generator seeded with
    RANDOM_SEED + 1 is used, so the result does not depend on the global
    ``random`` state.
    """
    rng = random.Random(RANDOM_SEED + 1)
    generated: List[Task] = []
    for task_id in range(n):
        # Draw order (type, then load) matters for reproducing the same tasks.
        kind = rng.choice(TASK_TYPES)
        load = rng.uniform(1.0, 5.0)
        generated.append(Task(id=task_id, task_type=kind, intrinsic_load=load))
    return generated

# --------------------------------- Main --------------------------------------

if __name__ == "__main__":
    # Seed the global RNG so the latency and error draws are reproducible.
    random.seed(RANDOM_SEED)

    users = make_users()
    tasks = make_tasks(TASK_COUNT)
    roster = [f'{u.name}:{u.style}' for u in users]

    print("=== Cognitive Load Balancing Simulation ===")
    print(f"Users: {roster}")
    print(f"Generating {len(tasks)} tasks…\n")

    sim = Simulation(users, tasks)
    sim.run(assign_batch=ASSIGN_BATCH_SIZE)

=== Cognitive Load Balancing Simulation ===
Users: ['Alice:A', 'Bob:B', 'Cara:C', 'Dan:D']
Generating 20 tasks…

✓ Assigned Task#0 (analytical, L=2.14, eff=1.72) to Alice. Load=1.72
✓ Assigned Task#1 (sequential, L=4.84, eff=3.87) to Bob. Load=3.87
✓ Assigned Task#2 (interpersonal, L=3.69, eff=2.95) to Cara. Load=2.95
✓ Assigned Task#3 (analytical, L=2.81, eff=2.81) to Dan. Load=2.81
✓ Assigned Task#4 (imaginative, L=3.44, eff=3.44) to Alice. Load=5.15
🏁 Alice completed Task#0 [analytical] in 1.8s with ERROR. New Load=3.44
🏁 Bob completed Task#1 [sequential] in 3.5s with SUCCESS. New Load=0.00
🏁 Cara completed Task#2 [interpersonal] in 3.2s with SUCCESS. New Load=0.00
🏁 Dan completed Task#3 [analytical] in 3.3s with SUCCESS. New Load=0.00


KeyboardInterrupt: 