In [1]:
import pandas as pd
import numpy as np
import random
import re
import time
import os
import pickle
import sys
from tqdm import tqdm
from mistralai import Mistral
from scipy.special import softmax

# --- CONFIGURATION ---
# SECURITY: the API key must never live in the source file (notebooks get
# shared and committed). It is read from the MISTRAL_API_KEY environment
# variable instead. Any key that was previously hard-coded here should be
# treated as leaked and revoked.
API_KEY = os.environ.get("MISTRAL_API_KEY", "")
MODEL = "magistral-small-2509"

# Experiment Parameters
N_AGENTS = 30
N_ARMS = 50
N_ROUNDS = 20
MOCK_MODE = False  # True => no API client is created and canned replies are used

# RL Parameters
RL_LEARNING_RATE = 0.5
RL_TEMPERATURE = .3

# --- CLUSTER PATHS ---
# Checkpoints are written next to wherever the job was launched from.
BASE_DIR = os.getcwd()
CHECKPOINT_DF = os.path.join(BASE_DIR, "checkpoint_experiment_df.pkl")
CHECKPOINT_LOG = os.path.join(BASE_DIR, "checkpoint_message_log.pkl")
CHECKPOINT_WEIGHTS = os.path.join(BASE_DIR, "checkpoint_weights_log.pkl")

# Initialize Client
if MOCK_MODE:
    client = None
elif not API_KEY:
    # Fail fast: without a key every API call would fail, be retried with
    # backoff, and finally be replaced by the "ERROR" placeholder, silently
    # producing a meaningless experiment.
    raise RuntimeError(
        "MISTRAL_API_KEY environment variable is not set. "
        "Export it before running, or set MOCK_MODE = True."
    )
else:
    client = Mistral(api_key=API_KEY)

# --- 1. HELPER FUNCTIONS ---
def get_empty_structure():
    """Return a fresh per-agent, per-round record with empty lists.

    A new dict (with new lists) is built on every call, so records never
    share state with each other.
    """
    return dict(assigned_arms=[], chosen_arm=[], payoff=[])

def initialize_experiment_log(num_agents, num_iterations):
    """Build the empty results table: one row per round, one column per agent.

    Each agent cell holds its own empty record from get_empty_structure();
    the extra 'Total_Payoff' column starts at 0.0.
    """
    agent_columns = [f"Agent_{idx}" for idx in range(num_agents)]

    def blank_row():
        # Every cell gets its own record so rows never alias each other.
        row = {column: get_empty_structure() for column in agent_columns}
        row['Total_Payoff'] = 0.0
        return row

    return pd.DataFrame([blank_row() for _ in range(num_iterations)])

def initialize_message_log():
    """Return a new, empty list that will collect message records."""
    return list()

def initialize_weights_log():
    """Return a new, empty list that will collect weight-matrix snapshots."""
    return list()

# --- 2. CLUSTER-SAFE CHECKPOINTING ---
def save_checkpoint(df, msg_log, weights_log):
    """Persist the results table, the message log and the weight history.

    Each file is first written to a sibling ``.tmp`` path and then moved
    over the real checkpoint with os.replace, so an individual checkpoint
    file is never left half-written. All three temp files are written
    before any of them is moved into place.
    """
    finals = (CHECKPOINT_DF, CHECKPOINT_LOG, CHECKPOINT_WEIGHTS)
    staged = [path + ".tmp" for path in finals]

    # Stage: the DataFrame uses pandas' own pickler, the logs use pickle.
    df.to_pickle(staged[0])
    for tmp_path, payload in zip(staged[1:], (msg_log, weights_log)):
        with open(tmp_path, "wb") as handle:
            pickle.dump(payload, handle)

    # Publish: swap every staged file over its final name.
    for tmp_path, final_path in zip(staged, finals):
        os.replace(tmp_path, final_path)

def load_checkpoint(n_agents, n_rounds):
    """Load a saved experiment if all three checkpoint files exist.

    Returns a tuple ``(start_round, df, msg_log, weights_log)``. When any
    checkpoint file is missing, fresh empty structures are returned with
    ``start_round == 0``. Otherwise ``start_round`` is the number of leading
    rounds for which "Agent_0" has a recorded 'chosen_arm'.
    """
    checkpoint_files = (CHECKPOINT_DF, CHECKPOINT_LOG, CHECKPOINT_WEIGHTS)
    if not all(os.path.exists(path) for path in checkpoint_files):
        print("No checkpoint found. Starting fresh experiment.")
        fresh_df = initialize_experiment_log(n_agents, n_rounds)
        fresh_msgs = initialize_message_log()
        fresh_weights = initialize_weights_log()
        return 0, fresh_df, fresh_msgs, fresh_weights

    print("Found checkpoint. Loading...")
    df = pd.read_pickle(CHECKPOINT_DF)
    with open(CHECKPOINT_LOG, "rb") as handle:
        msg_log = pickle.load(handle)
    with open(CHECKPOINT_WEIGHTS, "rb") as handle:
        weights_log = pickle.load(handle)

    # Walk forward until the first round that Agent_0 has not played yet.
    start_round = 0
    while start_round < n_rounds and df.at[start_round, "Agent_0"]['chosen_arm']:
        start_round += 1

    print(f"Resuming experiment from Round {start_round}...")
    return start_round, df, msg_log, weights_log

# --- 3. ENVIRONMENT & API ---
# Seed NumPy's global RNG so the shuffled arm layout below is repeatable.
# NOTE(review): only NumPy is seeded here; no random.seed(...) call is
# visible in this file, so draws from the stdlib `random` module are not
# covered by this seed.
np.random.seed(42)

# Arm means: N_ARMS evenly spaced values spanning 0..100, placed in a
# shuffled order so an arm's index says nothing about its quality.
_sorted_means = np.linspace(0, 100, num=N_ARMS)
TRUE_ARM_MEANS = np.array(_sorted_means)
np.random.shuffle(TRUE_ARM_MEANS)

print(f"Environment Initialized (0-100). Best Arm Mean: {TRUE_ARM_MEANS.max():.2f}")

def get_arm_reward(arm_index):
    """Draw one noisy reward for the given arm.

    The reward is sampled from a normal distribution centred on the arm's
    true mean (standard deviation 5.0), clipped to [0, 100] and rounded to
    two decimals.
    """
    noisy = np.random.normal(loc=TRUE_ARM_MEANS[arm_index], scale=5.0)
    bounded = np.clip(noisy, 0.0, 100.0)
    return round(bounded, 2)

def call_mistral(system_prompt, user_prompt):
    """Send one system+user exchange to the chat model and return its text.

    Args:
        system_prompt: Content of the "system" message.
        user_prompt: Content of the "user" message.

    Returns:
        The reply text. "MOCK_RESPONSE" when MOCK_MODE is on, an empty
        string when the API returns no content, and the literal "ERROR"
        once every retry has failed. A string is always returned so callers
        can safely run regexes on the result.
    """
    if MOCK_MODE:
        return "MOCK_RESPONSE"
    max_retries = 5
    base_wait = 5

    for attempt in range(max_retries):
        try:
            response = client.chat.complete(
                model=MODEL,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
            )
            content = response.choices[0].message.content
            if content is None:
                # Previously None leaked out and crashed the callers' re.* calls.
                return ""
            if isinstance(content, list):
                # Chunked reply: keep only the chunks tagged as text.
                return "".join(
                    c.text for c in content if getattr(c, "type", None) == 'text'
                ).strip()
            return content
        except Exception as e:
            # Deliberately broad: this is the retry boundary for any API or
            # transport failure. Every failure is reported, none is hidden.
            print(f"\n[API Warning] Attempt {attempt+1} failed: {e}")
            if attempt == max_retries - 1:
                # No attempt follows, so sleeping the longest backoff here
                # would only delay the "ERROR" result.
                break
            err = str(e).lower()
            wait = base_wait * (2 ** attempt)  # exponential backoff: 5, 10, 20, 40
            if "429" in err or "rate limit" in err:
                print(f"Rate Limit Hit. Sleeping {wait}s...")
            else:
                print(f"Retrying in {wait}s...")
            time.sleep(wait)
    return "ERROR"

# --- 4. CLASSES ---
class Agent:
    """One LLM-backed player in the multi-agent bandit game.

    Attributes:
        agent_id: Canonical identifier of the form "Agent_<n>".
        history: List of {'round', 'arm', 'payoff'} dicts, one per pull.
        inbox: Formatted messages received from other agents. This class
            never clears it, so it grows for as long as messages arrive.
    """

    # Reply format requested in generate_message. Matching is
    # case-insensitive, and only the digits of the recipient are captured so
    # the id can be rebuilt in canonical form.
    _MESSAGE_PATTERN = re.compile(
        r"TO:\s*Agent_(\d+).*?MSG:\s*(.*)", re.DOTALL | re.IGNORECASE
    )

    def __init__(self, agent_id, total_agents, total_arms, total_rounds):
        """Create agent number `agent_id` (an int) with empty history and inbox."""
        self.agent_id = f"Agent_{agent_id}"
        self.total_agents = total_agents
        self.total_arms = total_arms
        self.total_rounds = total_rounds
        self.history = []
        self.inbox = []

    def get_system_prompt(self):
        """Return the system prompt describing this agent's role and goal."""
        return (f"You are {self.agent_id}, in a {self.total_agents}-agent bandit game. "
                f"Rewards range from 0 to 100. Goal: Maximize team reward.")

    def _format_history(self):
        """Render the complete pull history, one line per round."""
        if not self.history:
            return "No history yet."
        return "\n".join(
            f"Round {r['round']}: Pulled Arm {r['arm']}, Reward: {r['payoff']:.2f}"
            for r in self.history
        )

    @staticmethod
    def _parse_message(resp):
        """Extract (recipient_id, message) from a model reply.

        The recipient is normalised to the canonical "Agent_<n>" spelling.
        Because matching is case-insensitive, a reply such as
        "to: agent_03 | msg: ..." would otherwise yield an id that matches
        no real agent key. Returns (None, None) when the reply is not a
        string or does not follow the requested format.
        """
        if not isinstance(resp, str):
            return None, None
        match = Agent._MESSAGE_PATTERN.search(resp)
        if match is None:
            return None, None
        return f"Agent_{int(match.group(1))}", match.group(2).strip()

    def generate_message(self, current_round, assigned_arms, assignment_map):
        """Ask the model to write one short message to another agent.

        Args:
            current_round: Round number shown in the prompt.
            assigned_arms: Arms assigned to this agent for the round.
            assignment_map: Mapping of every agent id to its assigned arms.

        Returns:
            (recipient_id, message), or (None, None) if the reply could not
            be parsed. In MOCK_MODE a random recipient and a fixed message
            are returned.
        """
        hist_str = self._format_history()
        user_prompt = (
            f"--- ROUND {current_round} ---\n"
            f"My Assigned Arms: {assigned_arms}\n"
            f"All Agents' Assignments: {assignment_map}\n"
            f"My History:\n{hist_str}\n\n"
            "Task: Based on history and map, message ONE agent. "
            "Tell them something useful. Keep the message **short**.\n"
            "Format: 'TO: Agent_X | MSG: <content>'"
        )
        resp = call_mistral(self.get_system_prompt(), user_prompt)

        if MOCK_MODE:
            return f"Agent_{random.randint(0, self.total_agents-1)}", "Short msg"

        return self._parse_message(resp)

    def receive_message(self, sender, content):
        """Append a message from `sender` to the inbox."""
        self.inbox.append(f"From {sender}: {content}")

    def make_choice(self, assigned_arms):
        """Pick one arm from `assigned_arms`, guided by the inbox.

        The prompt contains only the inbox messages and the assigned arms.
        The first integer found in the reply is used if it is one of the
        assigned arms; otherwise (or in MOCK_MODE, or when the reply is not
        a string) a uniformly random assigned arm is returned.
        """
        inbox_txt = "\n".join(self.inbox) if self.inbox else "No messages."
        user_prompt = (f"Msgs:\n{inbox_txt}\n\nPick one arm from {assigned_arms}. Return ONLY the integer.")
        resp = call_mistral(self.get_system_prompt(), user_prompt)

        if MOCK_MODE:
            return random.choice(assigned_arms)

        # Guard against a non-string reply instead of crashing in re.findall.
        nums = re.findall(r'\d+', resp) if isinstance(resp, str) else []
        if nums:
            choice = int(nums[0])
            if choice in assigned_arms:
                return choice
        return random.choice(assigned_arms)

    def update_history(self, r, arm, payoff):
        """Record that `arm` was pulled in round `r` and paid `payoff`."""
        self.history.append({'round': r, 'arm': arm, 'payoff': payoff})

class RLManager:
    """Assigns arms to agents using a softmax over learned (agent, arm) weights.

    Attributes:
        weights: (num_agents, num_arms) array, initialised to zeros and
            increased by update_weights.
        lr: Learning rate applied to the normalised reward.
        temp: Softmax temperature; lower values make assignment greedier.
        max_arms_per_agent: Cap on arms per agent during sampling.
    """

    def __init__(self, num_arms, num_agents, learning_rate=0.5, temperature=0.3,
                 max_arms_per_agent=5):
        """Create a manager with an all-zero weight matrix.

        Args:
            num_arms: Number of arms to distribute each round.
            num_agents: Number of agents receiving arms.
            learning_rate: Step size used by update_weights.
            temperature: Softmax temperature; must be > 0.
            max_arms_per_agent: Per-agent cap enforced while sampling;
                must be >= 1. Defaults to the previously hard-coded 5.

        Raises:
            ValueError: If temperature is not positive or the cap is < 1.
        """
        if temperature <= 0:
            # A zero temperature would turn weights / temp into inf/NaN and
            # only fail much later inside np.random.choice.
            raise ValueError("temperature must be positive")
        if max_arms_per_agent < 1:
            raise ValueError("max_arms_per_agent must be at least 1")

        self.num_arms = num_arms
        self.num_agents = num_agents
        self.lr = learning_rate
        self.temp = temperature
        self.weights = np.zeros((num_agents, num_arms))

        # Strict limit: forces arms to be spread across agents.
        self.max_arms_per_agent = max_arms_per_agent

    def assign_arms(self):
        """Sample an assignment of every arm to exactly one agent.

        Returns:
            Dict mapping "Agent_<i>" to the list of arm indices it received.
            Arms are processed in random order. Agents that reached the cap
            are excluded from sampling. Afterwards, agents left with no arm
            receive one from an agent holding more than one, while such
            donors exist.
        """
        assignment = {f"Agent_{i}": [] for i in range(self.num_agents)}
        agent_counts = np.zeros(self.num_agents, dtype=int)

        # 1. Softmax down each column: for every arm, a distribution over agents.
        probs_matrix = softmax(self.weights / self.temp, axis=0)

        # 2. Shuffle arm order so low-numbered arms are not always placed first.
        arm_order = list(range(self.num_arms))
        random.shuffle(arm_order)

        for arm_idx in arm_order:
            col_probs = probs_matrix[:, arm_idx].copy()

            # Agents at the cap cannot receive this arm.
            col_probs[agent_counts >= self.max_arms_per_agent] = 0.0

            total_p = np.sum(col_probs)
            if total_p > 0:
                # Re-normalise over the agents that still have room.
                col_probs /= total_p
                chosen = np.random.choice(self.num_agents, p=col_probs)
            else:
                # No probability mass left (everyone is full, or the
                # remaining probabilities underflowed to zero): choose
                # uniformly among agents with room, or among all agents if
                # every one is at the cap.
                avail = [i for i in range(self.num_agents)
                         if agent_counts[i] < self.max_arms_per_agent]
                if not avail:
                    avail = list(range(self.num_agents))
                chosen = random.choice(avail)

            assignment[f"Agent_{chosen}"].append(arm_idx)
            agent_counts[chosen] += 1

        # 3. Safety net: hand an arm to any agent that ended up with none.
        empty_agents = [i for i, c in enumerate(agent_counts) if c == 0]
        if empty_agents:
            donors = [i for i, c in enumerate(agent_counts) if c > 1]
            random.shuffle(empty_agents)
            for poor in empty_agents:
                if not donors:
                    break
                rich_guy = donors[0]
                # Take a random arm from the donor, not necessarily the last.
                steal_idx = random.randint(0, len(assignment[f"Agent_{rich_guy}"]) - 1)
                arm_val = assignment[f"Agent_{rich_guy}"].pop(steal_idx)
                agent_counts[rich_guy] -= 1
                assignment[f"Agent_{poor}"].append(arm_val)
                agent_counts[poor] += 1
                # A donor down to a single arm must keep it.
                if agent_counts[rich_guy] <= 1:
                    donors.pop(0)

        return assignment

    def update_weights(self, agent_id_str, arm_idx, reward):
        """Reinforce the (agent, arm) weight with the observed reward.

        Args:
            agent_id_str: Agent id of the form "Agent_<n>".
            arm_idx: Index of the arm that was pulled.
            reward: Raw reward; divided by 100 before being scaled by lr.
        """
        idx = int(agent_id_str.split("_")[1])
        norm_reward = reward / 100.0
        self.weights[idx, arm_idx] += self.lr * norm_reward
        
# --- 5. MAIN EXECUTION ---
if __name__ == "__main__":
    # 1. Load with Weights Log (fresh structures if no checkpoint exists)
    start_round, experiment_df, message_log, weights_log = load_checkpoint(N_AGENTS, N_ROUNDS)

    # 2. Setup with RL
    agents = [Agent(i, N_AGENTS, N_ARMS, N_ROUNDS) for i in range(N_AGENTS)]
    manager = RLManager(N_ARMS, N_AGENTS, learning_rate=RL_LEARNING_RATE, temperature=RL_TEMPERATURE)
    # Lookup by id; built once because the set of agents never changes.
    agent_map = {a.agent_id: a for a in agents}

    # 3. Restore Agent Memory
    if start_round > 0:
        print("Restoring Memory...")
        for r in range(start_round):
            for agent in agents:
                rec = experiment_df.at[r, agent.agent_id]
                if rec['chosen_arm']:
                    agent.update_history(r, rec['chosen_arm'][0], rec['payoff'][0])

        # Replay delivered messages so each inbox matches an uninterrupted
        # run: inboxes accumulate across rounds and feed make_choice, so a
        # resumed agent would otherwise decide with an empty inbox.
        for rec in message_log:
            if rec['iteration'] < start_round and rec['receiver'] in agent_map:
                agent_map[rec['receiver']].receive_message(rec['sender'], rec['message'])

        # Restore Manager Weights (Crucial for RL consistency)
        if weights_log:
            print("Restoring Manager Weights...")
            # Weights as they were at the end of the last completed round.
            manager.weights = weights_log[-1].copy()

    print(f"Starting RL Simulation. Alpha={RL_LEARNING_RATE}, Temp={RL_TEMPERATURE}")

    # 4. Loop
    for t in tqdm(range(start_round, N_ROUNDS), desc="RL Progress"):

        assignments = manager.assign_arms()

        # Phase 1: Message. Every agent writes before anything is delivered,
        # so no agent sees this round's messages while composing its own.
        msgs_this_round = []
        for agent in agents:
            tid, content = agent.generate_message(t, assignments[agent.agent_id], assignments)
            # Drop unparsable replies, unknown recipients and self-messages.
            if tid and tid in assignments and tid != agent.agent_id:
                rec = {'iteration': t, 'sender': agent.agent_id, 'receiver': tid, 'message': content}
                msgs_this_round.append(rec)
                message_log.append(rec)

        # Deliver
        for m in msgs_this_round:
            agent_map[m['receiver']].receive_message(m['sender'], m['message'])

        # Phase 2: Action
        round_payoffs = []
        for agent in agents:
            my_arms = assignments[agent.agent_id]
            choice = agent.make_choice(my_arms)
            payoff = get_arm_reward(choice)

            agent.update_history(t, choice, payoff)

            # Update Manager
            manager.update_weights(agent.agent_id, choice, payoff)

            experiment_df.at[t, agent.agent_id] = {
                'assigned_arms': list(my_arms),
                'chosen_arm': [choice],
                'payoff': [payoff]
            }
            round_payoffs.append(payoff)

        # Stats. NOTE: despite the column name, this stores the round's
        # average payoff across agents, not the sum.
        experiment_df.at[t, 'Total_Payoff'] = np.average(round_payoffs)

        # Snapshot the weight matrix (copy, so later updates don't alter it)
        weights_log.append(manager.weights.copy())

        # Atomic Save, once per completed round
        save_checkpoint(experiment_df, message_log, weights_log)

    print("Done! Exporting...")
    experiment_df.to_csv("final_rl_results.csv")
    pd.DataFrame(message_log).to_csv("final_rl_messages.csv")

    # Save final weights history for visualization
    with open("final_weights_history.pkl", "wb") as f:
        pickle.dump(weights_log, f)

    print("All files saved.")

  from pandas.core import (


Environment Initialized (0-100). Best Arm Mean: 100.00
No checkpoint found. Starting fresh experiment.
Starting RL Simulation. Alpha=0.5, Temp=0.3


RL Progress:   0%|                                                                              | 0/20 [00:00<?, ?it/s]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:   5%|███▎                                                               | 1/20 [08:45<2:46:19, 525.24s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  10%|██████▋                                                            | 2/20 [16:31<2:27:06, 490.34s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  15%|██████████                                                         | 3/20 [23:59<2:13:31, 471.25s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  40%|██████████████████████████                                       | 8/20 [1:07:53<1:47:24, 537.01s/it]


Rate Limit Hit. Sleeping 5s...


RL Progress:  50%|████████████████████████████████                                | 10/20 [1:27:29<1:33:52, 563.21s/it]


Rate Limit Hit. Sleeping 5s...


RL Progress:  60%|██████████████████████████████████████▍                         | 12/20 [1:45:27<1:13:30, 551.33s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  70%|██████████████████████████████████████████████▏                   | 14/20 [2:04:08<55:46, 557.83s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  75%|█████████████████████████████████████████████████▌                | 15/20 [2:13:04<45:55, 551.15s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  90%|███████████████████████████████████████████████████████████▍      | 18/20 [2:40:22<18:17, 548.90s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress:  95%|██████████████████████████████████████████████████████████████▋   | 19/20 [2:48:43<08:54, 534.56s/it]


Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...

Rate Limit Hit. Sleeping 5s...


RL Progress: 100%|██████████████████████████████████████████████████████████████████| 20/20 [2:57:41<00:00, 533.10s/it]

Done! Exporting...
All files saved.





In [None]:
class RLManager:
    """Assigns arms to agents using a softmax over learned (agent, arm) weights.

    Attributes:
        weights: (num_agents, num_arms) array, initialised to zeros.
        lr: Learning rate applied to the normalised reward.
        temp: Softmax temperature.
        max_arms_per_agent: Cap on arms per agent during sampling.
    """

    def __init__(self, num_arms, num_agents, lr=0.5, temp=0.3, *,
                 learning_rate=None, temperature=None):
        """Create a manager with an all-zero weight matrix.

        Args:
            num_arms: Number of arms to distribute each round.
            num_agents: Number of agents receiving arms.
            lr: Step size used by update_weights.
            temp: Softmax temperature.
            learning_rate: Keyword-only alias for `lr`; wins when given.
            temperature: Keyword-only alias for `temp`; wins when given.

        The aliases exist because the experiment script constructs the
        manager with `learning_rate=` / `temperature=`. Without them,
        running this cell would make that call fail with a TypeError.
        """
        if learning_rate is not None:
            lr = learning_rate
        if temperature is not None:
            temp = temperature

        self.num_arms = num_arms
        self.num_agents = num_agents
        self.lr = lr
        self.temp = temp
        self.weights = np.zeros((num_agents, num_arms))

        # Strict capacity limit to force distribution. With 50 arms and 30
        # agents the average is about 1.7 arms per agent; a cap of 5 leaves
        # a buffer while keeping any single agent from hoarding arms.
        self.max_arms_per_agent = 5

    def assign_arms(self):
        """Sample an assignment of every arm to exactly one agent.

        Returns:
            Dict mapping "Agent_<i>" to the list of arm indices it received.
        """
        assignment = {f"Agent_{i}": [] for i in range(self.num_agents)}

        # Track how many arms each agent currently has
        agent_counts = np.zeros(self.num_agents, dtype=int)

        # 1. Probability matrix: softmax down each column, i.e. for every
        #    arm a distribution over agents. No epsilon is added here; the
        #    zero-mass case is handled by the fallback branch below.
        probs_matrix = softmax(self.weights / self.temp, axis=0)

        # 2. Shuffle arms (assign in random order so arm #0 isn't always first)
        arm_order = list(range(self.num_arms))
        random.shuffle(arm_order)

        for arm_idx in arm_order:
            # Probabilities for this arm across all agents
            col_probs = probs_matrix[:, arm_idx].copy()

            # Constraint: agents that hit the cap cannot receive this arm.
            full_agents_indices = np.where(agent_counts >= self.max_arms_per_agent)[0]
            col_probs[full_agents_indices] = 0.0

            # Renormalize so probabilities sum to 1 again
            total_p = np.sum(col_probs)

            if total_p > 0:
                col_probs /= total_p
                chosen_agent = np.random.choice(self.num_agents, p=col_probs)
            else:
                # No probability mass left (everyone is full, or numerical
                # underflow): fall back to a random agent with room, or to
                # any agent if every one is at the cap.
                available = [i for i in range(self.num_agents)
                             if agent_counts[i] < self.max_arms_per_agent]
                if not available:
                    available = list(range(self.num_agents))
                chosen_agent = random.choice(available)

            assignment[f"Agent_{chosen_agent}"].append(arm_idx)
            agent_counts[chosen_agent] += 1

        # 3. Safety net: give an arm to any agent that ended up with none.
        empty_agents = [i for i, c in enumerate(agent_counts) if c == 0]

        if empty_agents:
            # Agents with more than one arm can afford to give one up.
            donors = [i for i, c in enumerate(agent_counts) if c > 1]
            random.shuffle(empty_agents)

            for poor in empty_agents:
                if not donors:
                    break

                rich_guy = donors[0]

                # Pick a random arm to move (not necessarily the last one)
                steal_idx = random.randint(0, len(assignment[f"Agent_{rich_guy}"]) - 1)
                arm_val = assignment[f"Agent_{rich_guy}"].pop(steal_idx)
                agent_counts[rich_guy] -= 1

                assignment[f"Agent_{poor}"].append(arm_val)
                agent_counts[poor] += 1

                # A donor down to a single arm must keep it.
                if agent_counts[rich_guy] <= 1:
                    donors.pop(0)

        return assignment

    def update_weights(self, agent_id_str, arm_idx, reward):
        """Reinforce the (agent, arm) weight with the observed reward.

        Args:
            agent_id_str: Agent id of the form "Agent_<n>".
            arm_idx: Index of the arm that was pulled.
            reward: Raw reward; divided by 100 before being scaled by lr.
        """
        idx = int(agent_id_str.split("_")[1])
        norm_reward = reward / 100.0
        self.weights[idx, arm_idx] += self.lr * norm_reward