# Setup

In [None]:
# This code is to be ran on Google Colab, and it clones a specific branch of a GitHub repository.
import os
if not os.path.exists('biased_recommending'):
    !git clone -b jules-housekeeping-v1-15196267096913812885 https://github.com/IgnacioOQ/biased_recommending

if os.path.basename(os.getcwd()) != 'biased_recommending':
    %cd biased_recommending

!pip install -r requirements.txt

# from google.colab import drive
# drive.mount('/content/drive')

# dumping_path = '/content/drive/My Drive/Colab Projects/Biased Recommending/'
# print("Current Directory:", dumping_path)

import ipywidgets as widgets
from IPython.display import display, clear_output
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import json
import uuid
from datetime import datetime

# Episode length (an exogenous experiment parameter); handed to
# AdvancedGameSession as steps_per_episode in the Experiment Interface cell.
TOTAL_STEPS: int = 10
# Helper for JSON serialization of Numpy types
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NumpyEncoder, self).default(obj)


# Experiment Interface

In [None]:
# Add src to path if not already (for notebook execution)
# module_path = os.path.abspath(os.path.join('..'))
# if module_path not in sys.path:
#     sys.path.append(module_path)

from src.advanced_simulation import AdvancedGameSession

# --- Game State ---
# Note: In Colab, output_dir can be pointed at a Drive path if desired;
# for local testing 'data' is fine.
# steps_per_episode is controlled by TOTAL_STEPS, set in the Setup section.

# Session metadata is created once per kernel, so re-running this cell keeps
# the same SESSION_ID (and therefore the same JSON history file name).
if 'SESSION_ID' not in locals():
    SESSION_ID = f"sessionid_{uuid.uuid4()}"
    START_TIME = datetime.now().isoformat()
    print(f"Initialized Session: {SESSION_ID}")

game = AdvancedGameSession(output_dir="data", steps_per_episode=TOTAL_STEPS, session_id=SESSION_ID)
current_recs = game.start_game()
current_step_info = None

# Per-agent counters for the accuracy table:
#   rec_count / tp      -> 'Recommend' calls / those followed by Heads
#   not_rec_count / tn  -> 'Not Recommend' calls / those followed by Tails
agent_stats = {
    agent_id: {'tp': 0, 'rec_count': 0, 'tn': 0, 'not_rec_count': 0}
    for agent_id in (0, 1)
}
output_table = widgets.Output()

# History for plotting
history_scores = []
history_episodes = []
cumulative_score = 0

# --- UI Elements ---
output_plot = widgets.Output()

header_html = widgets.HTML(value="<h1>Recommender Game</h1>")
# f-string so the stated episode length always matches TOTAL_STEPS.
instructions = widgets.HTML(value=f"""
<h3>Project Overview</h3>
<p>The project sets up the framework for an experimental study testing how recommender systems can exploit human biases in recommendation. This is done by having people play a game that involves reinforcement-learning recommendations. Each game episode has {TOTAL_STEPS} time steps. First, nature draws a number p uniformly at random between zero and one. Second, this number is observed by two recommender agents (Deep Q-Learning agents), each of which has two actions: recommend or not recommend. The human participant then observes the two recommendations and picks one. Finally, the participant plays a lottery by flipping a coin with bias equal to p. Payoff structure:</p>
<ul>
<li>If the coin lands heads and the participant followed a 'Recommend', they get a payoff of +1.</li>
<li>If the coin lands tails and the participant followed a 'Not Recommend', they get a payoff of +1.</li>
<li>Otherwise (Heads + Not Recommend, or Tails + Recommend), they get +0.</li>
</ul>
<p>This repeats for all {TOTAL_STEPS} time steps of an episode. At each time step, the TD-learning recommender agents get a reward of +1 if they are selected or -1 if they are not. The unbiased policy recommends when observing 1>=p>=0.5, and does 'not recommend' when 0.5>=p>=0. The question is how far the policies learned by the TD agents are from the unbiased policy.</p>
<hr>
""")

agent1_btn = widgets.Button(description="Follow Agent 1", button_style='info')
agent2_btn = widgets.Button(description="Follow Agent 2", button_style='info')
next_btn = widgets.Button(description="Next Step", disabled=True)

score_label = widgets.Label(value="Score: 0 | Episode: 0 | Step: 0")
feedback_label = widgets.HTML(value="")
metrics_label = widgets.HTML(value="")

# --- Logic ---

def update_table():
    """Redraw the per-agent accuracy table inside ``output_table``.

    For each agent, using the counters in ``agent_stats``, show:
      * precision (TP/Rec): share of its 'Recommend' calls followed by Heads;
      * NPV (TN/NotRec): share of its 'Not Recommend' calls followed by Tails.
    These ratios divide by the number of calls the agent *made*, so they are
    precision / negative predictive value, not TPR / TNR. A ratio with no
    calls yet is shown as "n/a" instead of a misleading 0%.
    """
    cell = "border:1px solid black;"
    centered = cell + " text-align:center;"

    def rate(hits, total):
        if total == 0:
            return f"n/a ({hits}/{total})"
        return f"{hits / total:.2%} ({hits}/{total})"

    parts = [
        "<h3>Accumulated Accuracy</h3>"
        "<table style='width:100%; border:1px solid black; border-collapse:collapse;'>",
        f"<tr><th style='{cell}'>Agent</th>"
        f"<th style='{cell}'>Precision (TP/Rec)</th>"
        f"<th style='{cell}'>NPV (TN/NotRec)</th></tr>",
    ]
    for agent_id in (0, 1):
        stats = agent_stats[agent_id]
        parts.append(
            f"<tr><td style='{centered}'>Agent {agent_id + 1}</td>"
            f"<td style='{centered}'>{rate(stats['tp'], stats['rec_count'])}</td>"
            f"<td style='{centered}'>{rate(stats['tn'], stats['not_rec_count'])}</td></tr>"
        )
    parts.append("</table>")

    with output_table:
        clear_output(wait=True)
        display(widgets.HTML(value="".join(parts)))

def update_plot():
    """Redraw the cumulative-score line chart inside ``output_plot``.

    Point k on the x-axis is the participant's cumulative score after their
    k-th choice, counted across all episodes. Nothing is drawn before the
    first choice has been made.
    """
    with output_plot:
        clear_output(wait=True)
        if not history_scores:
            return

        # Number choices from 1 so the axis counts steps played rather than
        # showing the 0-based list index (score after step 1 was plotted at 0).
        steps = range(1, len(history_scores) + 1)
        _, ax = plt.subplots(figsize=(10, 4))
        ax.plot(steps, history_scores, label="Cumulative Score")
        ax.set_xlabel("Step")
        ax.set_ylabel("Score")
        ax.set_title("Performance Over Time")
        ax.legend()
        ax.grid(True)
        plt.show()

def update_ui(recommendations):
    """Display a fresh pair of recommendations and re-arm the choice buttons.

    ``recommendations[0]`` and ``recommendations[1]`` are the actions of
    Agent 1 and Agent 2 (1 means Recommend, any other value Not Recommend).
    Each agent button gets a matching label and thumbs icon and is enabled;
    "Next Step" is disabled until the participant picks an agent.
    """
    for idx, button in enumerate((agent1_btn, agent2_btn)):
        recommends = recommendations[idx] == 1
        button.description = f"Agent {idx + 1}: {'Recommend' if recommends else 'Not Recommend'}"
        button.disabled = False
        # Icon mirrors the text so the choice is readable at a glance.
        button.icon = 'thumbs-up' if recommends else 'thumbs-down'

    next_btn.disabled = True

def _record_agent_stats(recs, outcome):
    """Update the per-agent counters in ``agent_stats`` for one finished step.

    ``recs`` are the recommendations that were shown for that step (1 means
    Recommend); an ``outcome`` of 'Heads' counts as a positive result.
    """
    heads = (outcome == 'Heads')
    for aid in (0, 1):
        counts = agent_stats[aid]
        if recs[aid] == 1:  # Recommend
            counts['rec_count'] += 1
            if heads:
                counts['tp'] += 1
        else:  # Not Recommend
            counts['not_rec_count'] += 1
            if not heads:  # Failure = Tails
                counts['tn'] += 1


def _save_episode_history(step_info):
    """Append the episode that just finished to the session's JSON history.

    The file is ``data/advanced_history_<SESSION_ID>.json``; it is created
    with a ``session_meta`` header on first use. The updated content is
    written to a temporary file and then moved over the original, so an
    interrupted write cannot corrupt the episodes saved so far.
    """
    os.makedirs('data', exist_ok=True)
    json_file = f'data/advanced_history_{SESSION_ID}.json'

    # Load existing data to append to, or start a new session record.
    if os.path.exists(json_file):
        with open(json_file, 'r', encoding='utf-8') as f:
            session_data = json.load(f)
    else:
        session_data = {
            'session_meta': {
                'data_structure': 'List of dictionaries {agent_id: [(State [p, t], Action, Reward, Next_State [p, t+1], Done), ...]}',
                'session_id': SESSION_ID,
                'total_number_of_steps_in_episode': game.env.max_steps,
                'start_time': START_TIME
            },
            'episodes': []
        }

    # NumPy values inside the episode history are converted by NumpyEncoder.
    session_data['episodes'].append(step_info['finished_episode_history'])
    tmp_file = json_file + '.tmp'
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(session_data, f, cls=NumpyEncoder, indent=4)
    os.replace(tmp_file, json_file)
    print(f'Saved JSON history for episode {step_info["episode_count"]} to {json_file}')


def on_choice(b):
    """Play one step after the participant clicks one of the agent buttons.

    Sends the choice (0 = Agent 1, 1 = Agent 2) to the game. It then updates
    the score, history plot and accuracy table, shows the outcome, and enables
    "Next Step". When the step closes an episode (``metrics`` is truthy), the
    finished episode is appended to the session's JSON history file.

    Args:
        b: The clicked button widget (``agent1_btn`` or ``agent2_btn``).
    """
    global cumulative_score, current_step_info

    choice = 0 if b is agent1_btn else 1
    # current_recs holds the recommendations currently on screen (it is only
    # replaced in on_next). Read the followed one here instead of from the
    # game object, which may already carry the next step's recommendations
    # once process_step has run.
    followed = current_recs[choice]

    # Process step
    current_step_info = game.process_step(choice)
    outcome = current_step_info['outcome']

    # Update score and history
    reward = current_step_info['human_reward']
    cumulative_score += reward
    history_scores.append(cumulative_score)

    _record_agent_stats(current_recs, outcome)
    update_table()
    update_plot()

    # Show feedback
    rec_followed = "Recommend" if followed == 1 else "Not Recommend"
    res_color = "green" if reward > 0 else "red"
    feedback_label.value = f"""
    <div style="border: 2px solid {res_color}; padding: 10px; border-radius: 5px;">
        <h3>Outcome: {outcome}</h3>
        <p>You followed Agent {choice + 1} ({rec_followed}).</p>
        <p><b>Reward: {reward}</b></p>
    </div>
    """

    # Update status bar
    score_label.value = f"Score: {cumulative_score} | Episode: {current_step_info['episode_count']} | Step: {game.env.steps}"

    # Disable choice buttons until the participant moves on.
    agent1_btn.disabled = True
    agent2_btn.disabled = True
    next_btn.disabled = False

    # End of episode: persist the history. Requirement: the end-of-episode
    # report is deliberately NOT shown to the participant.
    if current_step_info['metrics']:
        _save_episode_history(current_step_info)

def on_next(b):
    """Advance the interface to the next step after "Next Step" is clicked.

    Clears the previous feedback and shows the recommendations that the last
    ``process_step`` result already carries for the upcoming step. If that
    result flagged a new episode, a short banner says so. Does nothing beyond
    clearing the feedback when no step has been played yet.
    """
    global current_recs

    feedback_label.value = ""
    if not current_step_info:
        return

    current_recs = current_step_info['recommendations']
    update_ui(current_recs)

    if current_step_info['new_episode']:
        feedback_label.value = "<b>New Episode Started!</b>"

# --- Wiring ---
# Both agent buttons share one handler; on_choice tells them apart by the
# widget it receives.
for choice_button in (agent1_btn, agent2_btn):
    choice_button.on_click(on_choice)
next_btn.on_click(on_next)

# --- Layout ---
# Top to bottom: title, rules, status bar, the two choices, per-step feedback,
# metrics area, the Next button, and the score plot beside the accuracy table.
choice_row = widgets.HBox([agent1_btn, agent2_btn])
charts_row = widgets.HBox([output_plot, output_table])
ui = widgets.VBox([
    header_html,
    instructions,
    score_label,
    choice_row,
    feedback_label,
    metrics_label,
    next_btn,
    charts_row,
])

# Initialize: label the buttons with the first recommendations, then render.
update_ui(current_recs)
display(ui)
