##### Setup Environment

In [None]:
import boto3
import os
import xml.etree.ElementTree as ET
import ipywidgets as widgets
from IPython.display import display, clear_output
import json

# Unset AWS_PROFILE for both current process and child processes
os.environ.pop("AWS_PROFILE", None)
# NOTE(review): this setup was originally labelled "sandbox environment"
# (@Andy-LZH); confirm which environment the keys in .env belong to.

keys_id = ""
keys = ""
# Read the AWS keys from the .env file (lines of the form NAME=value).
with open(".env", "r") as f:
    for line in f:
        # Split on the FIRST "=" only, in case the value itself contains "=".
        name, separator, value = line.strip().partition("=")
        if not separator:
            continue  # blank line or anything that is not an assignment
        value = value.strip().replace('"', '')  # remove quotes if present
        if name.startswith("AWS_ACCESS_KEY_ID"):
            keys_id = value
        elif name.startswith("AWS_SECRET_ACCESS_KEY"):
            keys = value


# @Andy-LZH's account

# mturk_environment = {
#     "endpoint": "https://mturk-requester-sandbox.us-east-1.amazonaws.com",
#     "preview": "https://workersandbox.mturk.com/mturk/preview",
#     "manage": "https://requestersandbox.mturk.com/mturk/manageHITs",
#     "reward": "0",
# }

# Live (non-sandbox) MTurk URLs: the API endpoint plus the preview and
# management pages.
mturk_environment = dict(
    endpoint="https://mturk-requester.us-east-1.amazonaws.com",
    preview="https://www.mturk.com/mturk/preview",
    manage="https://requester.mturk.com/mturk/manageHITs",
)



# MTurk client using the keys read from .env and the endpoint selected in
# mturk_environment.
mturk = boto3.client(
    "mturk",
    aws_access_key_id=keys_id,
    aws_secret_access_key=keys,
    region_name="us-east-1",
    endpoint_url=mturk_environment["endpoint"],
)

# Query the balance once; a second bare call here would only repeat the
# request and discard its result, since it is not the last expression of
# the cell.
print("I have $" + mturk.get_account_balance()["AvailableBalance"] + " in my  account")


# === Helper Functions ===
def list_all_hits():
    """Collect every HIT returned by ``mturk.list_hits``, following pagination.

    Returns:
        list: the concatenated ``"HITs"`` entries of all response pages.
    """
    collected = []
    request_kwargs = {"MaxResults": 100}

    while True:
        page = mturk.list_hits(**request_kwargs)
        collected.extend(page["HITs"])

        # A missing/empty NextToken means this was the last page.
        token = page.get("NextToken")
        if not token:
            return collected
        request_kwargs["NextToken"] = token

##### List all HITs in my account

In [None]:
# Fetch every HIT and print a short summary of each one.
mturk_hits = list_all_hits()
print("The number of HITs in my account is {}".format(len(mturk_hits)))

# (output template, key in the HIT record), printed in this order per HIT
hit_summary_fields = (
    ("HIT ID: {}", "HITId"),
    ("HIT Title: {}", "Title"),
    ("HIT Status: {}", "HITStatus"),
    ("HIT Creation Time: {}", "CreationTime"),
    ("HIT Reward: {}", "Reward"),
)
for hit in mturk_hits:
    for template, key in hit_summary_fields:
        print(template.format(hit[key]))

##### Expire all HITs in my account (frequently used for testing)

In [None]:
mturk_hits = list_all_hits()
# Expire every HIT in the account, then delete it unless its (previously
# listed) status is Reviewable or Unassignable.
for hit in mturk_hits:
    hit_id = hit["HITId"]
    print("Deleting HIT ID: {}".format(hit_id))
    try:
        mturk.update_expiration_for_hit(
            HITId=hit_id,
            ExpireAt="1970-01-01T00:00:00Z"  # set expiration to a date in the past
        )
        # check status before deletion
        if hit["HITStatus"] in ("Reviewable", "Unassignable"):
            print("HIT ID {} is {}, skipping deletion.".format(hit_id, hit["HITStatus"]))
            continue
        mturk.delete_hit(HITId=hit_id)
        print("HIT ID {} deleted.".format(hit_id))
    except Exception as e:
        # Keep going: one HIT that cannot be expired/deleted should not
        # leave all the remaining HITs untouched.
        print("Error expiring/deleting HIT ID {}: {}".format(hit_id, e))
# list all HITs again to confirm deletion (all pages, not just the first 100)
mturk_hits = list_all_hits()
print("The number of HITs in my account after deletion is {}".format(len(mturk_hits)))

##### Approve and retrieve HITs

In [None]:
import os
import json
import xml.etree.ElementTree as ET
import boto3
import ipywidgets as widgets
from IPython.display import display, clear_output
from collections import defaultdict
import re

# === Config ===
DATA_DIR = "data/HITs"  # folder where all JSONs are stored

# Filename patterns, tried in order; the first one that matches wins.
# Every pattern captures (task_type, category[, environment]) in that order.
_FILENAME_PATTERNS = [
    re.compile(pattern)
    for pattern in (
        # New format: {agreeTest, main}_{subpart_category}_{sandbox, live}
        r'(agreeTest|main)_([^_]+(?:-[^_]+)*)_(live|sandbox)',  # agreeTest_Biped-Arm-Upperarm_sandbox.json, main_Vehicle-Car_live.json
        r'(spin_\w+_\w+)_(\w+)_(live|sandbox)',  # spin_test_parts_Quadruped_live.json
        r'(spin_\w+_\w+)_(\w+)',                 # spin_test_parts_Quadruped.json
        r'(\w+Test_\w+)_(\w+)_(live|sandbox)',   # QualificationTest_parts_Quadruped_live.json
        r'(\w+Test_\w+)_(\w+)',                  # QualificationTest_parts_Quadruped.json
        # "main" is captured so the groups line up with the other patterns
        # (otherwise the category would be read as the task type).
        r'(main)_(\w+)_(live|sandbox)',          # main_Quadruped_live.json
    )
]


def organize_json_files_by_category():
    """Organize the JSON files in DATA_DIR by category and task type.

    The task type, category and environment are extracted from each file
    name with ``_FILENAME_PATTERNS``; files matching no pattern are filed
    under category ``"uncategorized"`` / task type ``"other"`` with
    environment ``"unknown"``.

    Returns:
        defaultdict: ``result[category][task_type]`` is a list of dicts with
        the keys ``filename``, ``display_name``, ``task_type``, ``category``
        and ``environment``.
    """
    organized_files = defaultdict(lambda: defaultdict(list))
    suffix = ".json"

    for filename in os.listdir(DATA_DIR):
        if not filename.endswith(suffix):
            continue

        # Drop only the trailing extension, not every ".json" occurrence.
        base_name = filename[: -len(suffix)]

        task_type = "other"
        category = "uncategorized"
        environment = "unknown"

        for pattern in _FILENAME_PATTERNS:
            match = pattern.search(base_name)
            if match:
                task_type = match.group(1)
                category = match.group(2)
                if pattern.groups >= 3:
                    environment = match.group(3)
                break

        # Create a display name that shows the organization
        display_name = f"{category} - {task_type}"
        if environment != "unknown":
            display_name += f" ({environment})"

        organized_files[category][task_type].append({
            'filename': filename,
            'display_name': display_name,
            'task_type': task_type,
            'category': category,
            'environment': environment
        })

    return organized_files

def create_sorted_file_options():
    """Build the ``(label, value)`` entries for the file dropdown.

    Categories, task types and file names are each listed alphabetically.
    Every category starts with a header entry whose value is the empty
    string; file entries carry the file name as their value.
    """
    grouped = organize_json_files_by_category()
    dropdown_entries = []

    for category in sorted(grouped):
        # Header row for the category (value "" marks it as not selectable).
        dropdown_entries.append((f"üìÅ {category.upper()}", ""))

        by_task_type = grouped[category]
        for task_type in sorted(by_task_type):
            ordered = sorted(by_task_type[task_type], key=lambda info: info['filename'])
            dropdown_entries.extend(
                (f"  ‚îî‚îÄ {info['display_name']}", info['filename'])
                for info in ordered
            )

    return dropdown_entries

# Get organized file options
file_options = create_sorted_file_options()

# === MTurk Client ===
# mturk = boto3.client(
#     "mturk",
#     region_name="us-east-1",
#     aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
#     aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
#     endpoint_url="https://mturk-requester.us-east-1.amazonaws.com",
# )

# === Global State ===
local_data = []
assignment_lookup = {}


# === Functions ===
def load_json_file(filename):
    """Load and return the parsed JSON content of ``filename`` inside DATA_DIR.

    The file is read as UTF-8 explicitly so the result does not depend on
    the platform's default encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    path = os.path.join(DATA_DIR, filename)
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def list_reviewable_hits():
    """Return the entries of ``local_data`` whose HIT has at least one assignment.

    For every entry with a ``HITId`` the HIT status is fetched and printed,
    then the HIT's assignments are listed; the entry is kept when that list
    is non-empty. Errors for a single HIT are printed and the HIT is skipped.
    """
    selected = []
    for record in local_data:
        hit_id = record.get("HITId")
        if not hit_id:
            continue
        try:
            remote_hit = mturk.get_hit(HITId=hit_id)["HIT"]
            status = remote_hit.get("HITStatus", "Unknown")
            print(f"HIT {hit_id} status: {status}")
            # The status is only reported; what decides is whether the HIT
            # has any assignments.
            response = mturk.list_assignments_for_hit(HITId=hit_id)
            if response.get("Assignments", []):
                selected.append(record)
        except Exception as e:
            print(f"Error fetching HIT {hit_id}: {e}")
    return selected


def list_assignments_for_hit(hit_id):
    """Return all assignments of ``hit_id`` if that HIT is part of ``local_data``.

    Result pages are followed via ``NextToken`` (as ``list_all_hits`` does
    for HITs), so HITs with more assignments than fit in one response page
    are returned completely.

    Args:
        hit_id: the HIT to look up.

    Returns:
        list: the HIT's assignments; empty if the HIT is not in
        ``local_data`` or has no assignments.
    """
    if not any(hit.get("HITId") == hit_id for hit in local_data):
        return []

    assignments = []
    request_kwargs = {"HITId": hit_id, "MaxResults": 100}
    while True:
        response = mturk.list_assignments_for_hit(**request_kwargs)
        assignments.extend(response.get("Assignments", []))

        # Check if there's another page
        next_token = response.get("NextToken")
        if not next_token:
            return assignments
        request_kwargs["NextToken"] = next_token


def _as_mapping(value, what):
    """Return ``value`` as a dict: ``None`` becomes ``{}``, non-dicts are rejected."""
    if value is None:
        return {}
    if not isinstance(value, dict):
        raise ValueError(f"{what} is not a JSON object")
    return value


def parse_answer_xml(xml_string):
    """Parse an assignment's answer XML into answers and issue texts.

    Two layouts are understood:

    * new format: a ``submissionData`` answer whose free text is a JSON
      object with ``answers`` and ``issueTexts`` objects;
    * old format: an ``answers`` answer whose free text is a JSON object
      (no issue texts).

    If a ``submissionData`` field is present it is used exclusively; the old
    format is only consulted when it is absent.

    Args:
        xml_string: the raw answer XML (may be empty or ``None``).

    Returns:
        tuple: ``(answers, issue_texts)`` where ``answers`` is the list of
        values of the answers object in document order and ``issue_texts``
        is a dict. ``([], {})`` is returned for empty input, when neither
        field exists, or when the XML/JSON is malformed (including an empty
        free-text field or a payload that is not a JSON object); in the
        malformed case the error is printed.
    """
    if not xml_string:
        return [], {}

    ns = {
        "ns": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"
    }
    try:
        root = ET.fromstring(xml_string)

        # Try to find the submissionData field first (new format)
        submission_field = root.find(
            ".//ns:Answer[ns:QuestionIdentifier='submissionData']/ns:FreeText", ns
        )
        if submission_field is not None:
            # ``.text`` is None for an empty element; "" makes json.loads
            # raise a JSONDecodeError that is reported below.
            submission_data = _as_mapping(
                json.loads(submission_field.text or ""), "submissionData"
            )
            answers = _as_mapping(submission_data.get("answers"), "answers")
            issue_texts = _as_mapping(submission_data.get("issueTexts"), "issueTexts")
            # NOTE(review): answers are returned in JSON order, while callers
            # index them by position — confirm the keys are written in
            # instance order.
            return list(answers.values()), issue_texts

        # Fall back to old format (answers only)
        answers_field = root.find(
            ".//ns:Answer[ns:QuestionIdentifier='answers']/ns:FreeText", ns
        )
        if answers_field is None:
            return [], {}
        answers = _as_mapping(json.loads(answers_field.text or ""), "answers")
        return list(answers.values()), {}
    except (ET.ParseError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        print(f"Error parsing XML or JSON: {e}")
        return [], {}


def print_assignment_preview(assignment):
    """Print assignment information in a formatted way.

    Prints the HIT/assignment identifiers, the metadata stored for the HIT
    in ``local_data``, every submitted answer (with the annotation ID when
    the HIT's ``group_index`` points at an entry of ``local_data``) and any
    issue texts the worker reported.

    Args:
        assignment: an assignment dict; the keys read are ``HITId``,
            ``WorkerId``, ``AssignmentId`` and ``Answer``.
    """
    id_status = {
        "-1": "Something is wrong",
        "0": "One Instance",
        "1": "Multiple Instances",
    }
    hit_id = assignment.get("HITId")
    hit = next((h for h in local_data if h.get("HITId") == hit_id), None)
    if not hit:
        print("‚ùå HIT not found in local data.")
        return

    worker_id = assignment.get("WorkerId", "Unknown")
    assignment_id = assignment.get("AssignmentId", "Unknown")
    task = hit.get("task", "Unknown")
    groupIndex = hit.get("group_index", "Unknown")
    category = hit.get("categories", "Unknown")
    environment = hit.get("environment", "sandbox")

    print(f"\nüìã Assignment Preview:")
    print(f"{'='*50}")
    print(f"HIT ID: {hit_id}")
    print(f"Assignment ID: {assignment_id}")
    print(f"Worker ID: {worker_id}")
    print(f"Task: {task}")
    print(f"Category: {category}")
    print(f"Group Index: {groupIndex}")
    print(f"Environment: {environment}")

    answer_raw = assignment.get("Answer", "")
    answers, issue_texts = parse_answer_xml(answer_raw)

    # Annotation IDs can only be shown when group_index is an integer that
    # indexes local_data; a missing group_index ("Unknown") must not be
    # compared with len().
    annotations = None
    if isinstance(groupIndex, int) and groupIndex < len(local_data):
        annotations = local_data[groupIndex]["annotations"]

    print(f"\nSubmitted Answers:")
    print(f"{'-'*30}")
    for i, answer in enumerate(answers):
        label = id_status.get(answer, f"Unknown ({answer})")
        if annotations is None:
            print(f"  Instance {i + 1}: {label}")
        else:
            annotation_id = annotations[i]["id"] if i < len(annotations) else "Unknown"
            print(f"  Instance {i + 1}: {label} (ID: {annotation_id})")

        # Show issue text if this instance has "Something is wrong" selected
        if answer == "-1" and str(i) in issue_texts:
            print(f"    üí¨ Issue: {issue_texts[str(i)]}")

    # Show summary of issues if any exist
    if issue_texts:
        print(f"\nüö® Issues Reported:")
        print(f"{'-'*30}")
        for instance_idx, issue_text in issue_texts.items():
            try:
                shown_index = int(instance_idx) + 1
            except (ValueError, TypeError):
                # Non-numeric key: show it as-is instead of failing.
                shown_index = instance_idx
            print(f"  Instance {shown_index}: {issue_text}")

    print(f"{'='*50}\n")


def _format_duration(accept_time, submit_time):
    """Return ``submit_time - accept_time`` as text without microseconds.

    Both values may be ``datetime`` objects or ISO-8601 strings (a trailing
    ``Z`` is accepted). Returns ``"Unknown"`` when either value is the
    caller's missing-value placeholder and ``"Could not calculate"`` when
    the values cannot be parsed or subtracted.
    """
    from datetime import datetime

    def as_datetime(value):
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value.replace("Z", "+00:00"))

    if accept_time == "Not submitted" or submit_time == "Unknown":
        return "Unknown"
    try:
        duration = as_datetime(submit_time) - as_datetime(accept_time)
    except (AttributeError, TypeError, ValueError):
        return "Could not calculate"
    return str(duration).split(".")[0]  # Remove microseconds


def format_assignment_preview(assignment):
    """Build the HTML preview of an assignment and upload its results to S3.

    Besides returning HTML this function has a side effect: the annotation
    group of the HIT, extended with the assignment ID, worker ID, per-instance
    results and issue texts, is written to the ``spin-instance`` bucket with a
    public-read ACL. The upload works on a copy, so ``local_data`` itself is
    not modified and nothing from a previously viewed assignment leaks into
    the uploaded file.

    Text submitted by the worker (answers, issue texts) is HTML-escaped
    before being embedded.

    Args:
        assignment: an assignment dict; the keys read are ``HITId``,
            ``WorkerId``, ``AssignmentId``, ``SubmitTime``, ``AcceptTime``
            and ``Answer``.

    Returns:
        str: the HTML preview, or a short italic message when the HIT is not
        in ``local_data`` or has no usable ``group_index``.
    """
    import copy
    from html import escape

    id_status = {
        "-1": "Something is wrong",
        "0": "One Instance",
        "1": "Multiple Instances",
    }
    hit_id = assignment.get("HITId")
    hit = next((h for h in local_data if h.get("HITId") == hit_id), None)
    if not hit:
        return "<i>HIT not found in local data.</i>"

    bucket_name = "spin-instance"

    worker_id = assignment.get("WorkerId", "Unknown")
    assignment_id = assignment.get("AssignmentId", "Unknown")
    task = hit.get("task", "Unknown")
    groupIndex = hit.get("group_index", "Unknown")
    submit_time = assignment.get("SubmitTime", "Unknown")
    accept_time = assignment.get("AcceptTime", "Not submitted")
    category = hit.get("categories", "Unknown")
    # NOTE(review): hard-coded, although the HIT record may carry its own
    # "environment" value — confirm this is intended.
    environment = "live"

    # group_index is used as an index into local_data; without a valid one
    # there are no annotations to show or upload.
    if not isinstance(groupIndex, int) or not 0 <= groupIndex < len(local_data):
        return "<i>HIT has no valid group_index in local data.</i>"

    # Calculate duration if both times are available
    duration_str = _format_duration(accept_time, submit_time)

    base_url = "https://andy-lzh.github.io/InstanceSpot-Frontend/"
    url = f"{base_url}?task={task}&category={category}&groupIndex={groupIndex}&sandbox=False&review=True&assignmentId={assignment_id}"

    answer_raw = assignment.get("Answer", "")
    answers, issue_texts = parse_answer_xml(answer_raw)

    # Work on a private copy so results of one assignment never end up in
    # the data shown or uploaded for another one.
    s3_data = copy.deepcopy(local_data[groupIndex])
    annotations = s3_data["annotations"]

    # Make the number of answers match the number of annotations: extra
    # answers are dropped, missing ones are filled with "-1".
    if len(answers) != len(annotations):
        print(f"Warning: Number of answers ({len(answers)}) does not match number of annotations ({len(annotations)}). Adjusting to match.")
        if len(answers) > len(annotations):
            answers = answers[: len(annotations)]
        else:
            answers.extend(["-1"] * (len(annotations) - len(answers)))

    parts = [
        f"<h3>HIT ID: {hit_id}</h3>",
        f"<b>AssignmentId:</b> {assignment_id}<br>",
        f"<b>WorkerId:</b> {worker_id}<br>",
        f"<b>Accept Time:</b> {accept_time}<br>",
        f"<b>Completion Time:</b> {submit_time}<br>",
        f"<b>Duration:</b> {duration_str}<br>",
        f"<b>Access URL:</b> <a href='{url}' target='_blank'>Manage HIT</a><br>",
        "<b>Submitted Answers:</b><br><ul>",
    ]

    instances_results = []
    for i, answer in enumerate(answers):
        label = id_status.get(answer, f"Unknown ({answer})")
        try:
            instances_results.append(int(answer))
        except (ValueError, TypeError):
            # Anything that is not an integer is recorded as "something is wrong".
            instances_results.append(-1)
        item = f"<li><b>Instance {i + 1}:</b> {escape(str(label))}; <i>id: {annotations[i]['id']}</i>"

        # Add issue text if this instance has "Something is wrong" selected
        if answer == "-1" and str(i) in issue_texts:
            item += f"<br>&nbsp;&nbsp;&nbsp;&nbsp;üí¨ <i>Issue: {escape(str(issue_texts[str(i)]))}</i>"

        parts.append(item + "</li>")
    parts.append("</ul>")

    # Add issues summary if any exist
    if issue_texts:
        parts.append("<b>Issues Reported:</b><br><ul>")
        for instance_idx, issue_text in issue_texts.items():
            try:
                shown_index = int(instance_idx) + 1
            except (ValueError, TypeError):
                shown_index = escape(str(instance_idx))
            parts.append(
                f"<li><b>Instance {shown_index}:</b> {escape(str(issue_text))}</li>"
            )
        parts.append("</ul>")

    # TODO how to make sure assignment_id is right
    file_name = (
        f"HITs/{category}/{task}/{environment}/group_{groupIndex}_{assignment_id}.json"
    )
    s3_data["assignment_id"] = assignment_id
    s3_data["worker_id"] = worker_id

    # Store both results and issue texts
    for i, instance in enumerate(instances_results):
        annotation = annotations[i]
        annotation["result"] = instance
        if str(i) in issue_texts:
            annotation["issue_text"] = issue_texts[str(i)]
        else:
            # Drop an issue text left behind by a different assignment.
            annotation.pop("issue_text", None)

    s3 = boto3.client(
        "s3",
        aws_access_key_id="",
        aws_secret_access_key="",
        region_name="us-east-1",
    )
    s3.put_object(
        Bucket=bucket_name,
        Key=file_name,
        ContentType="application/json",
        ACL="public-read",
        Body=json.dumps(s3_data),
    )
    return "".join(parts)


def approve_assignment(assignment_id, feedback):
    """Print the preview of an assignment, then approve it on MTurk.

    Args:
        assignment_id: key of the assignment in ``assignment_lookup``.
        feedback: text passed to MTurk as ``RequesterFeedback``.

    Any error (including an unknown ``assignment_id``) is printed instead of
    being raised.
    """
    try:
        # Show what is being approved before sending the request.
        print_assignment_preview(assignment_lookup[assignment_id])

        mturk.approve_assignment(
            AssignmentId=assignment_id,
            RequesterFeedback=feedback,
        )
        print(f"‚úÖ Approved assignment: {assignment_id}")
    except Exception as e:
        print(f"Error approving assignment {assignment_id}: {e}")


def reject_assignment(assignment_id, feedback):
    """Print the preview of an assignment, then reject it on MTurk.

    Args:
        assignment_id: key of the assignment in ``assignment_lookup``.
        feedback: text passed to MTurk as ``RequesterFeedback``.

    Any error (including an unknown ``assignment_id``) is printed instead of
    being raised.
    """
    try:
        # Show what is being rejected before sending the request.
        print_assignment_preview(assignment_lookup[assignment_id])

        mturk.reject_assignment(
            AssignmentId=assignment_id,
            RequesterFeedback=feedback,
        )
        print(f"‚ùå Rejected assignment: {assignment_id}")
    except Exception as e:
        print(f"Error rejecting assignment {assignment_id}: {e}")


# === UI Widgets ===
# --- Selection widgets: file -> HIT -> assignment ---
file_dropdown = widgets.Dropdown(
    description="Select File:",
    options=file_options,
    layout=widgets.Layout(width="100%"),
)
hit_dropdown = widgets.Dropdown(
    description="Select HIT:",
    layout=widgets.Layout(width="100%"),
)
assignment_dropdown = widgets.Dropdown(description="Assignment:")

# --- Feedback text sent along with an approval/rejection ---
feedback_box = widgets.Textarea(
    value="Thank you!",
    description="Feedback:",
    layout=widgets.Layout(width="100%", height="60px"),
)

# --- Output: HTML preview of the selected assignment and a log area ---
assignment_info_box = widgets.HTML()
output_area = widgets.Output()

# --- Action buttons ---
approve_button = widgets.Button(button_style="success", description="‚úÖ Approve")
approve_all_button = widgets.Button(
    button_style="success",
    description="‚úÖ Approve all in json",
)
reject_button = widgets.Button(button_style="danger", description="‚ùå Reject")


# === Event Handlers ===
def on_file_change(change):
    """Load the selected JSON file and refresh the HIT dropdown and buttons.

    Args:
        change: change event; only ``change["new"]`` (the selected file
            name) is read. An empty value (category header) is ignored.
    """
    global local_data
    filename = change["new"]

    # Skip if it's a category header (empty value)
    if not filename:
        return

    with output_area:
        clear_output()
        print(f"üìÅ Loading file: {filename}")
    local_data = load_json_file(filename)
    hits = list_reviewable_hits()
    has_hits = bool(hits)
    if has_hits:
        first_hit_id = hits[0]["HITId"]
        # Replacing the options changes the dropdown's value and would fire
        # on_hit_change through the observer in addition to the explicit
        # call below, running the preview (and its S3 upload) twice.
        # Detach the observer while updating (the handler is attached to
        # hit_dropdown once at module level), then refresh exactly once.
        hit_dropdown.unobserve(on_hit_change, names="value")
        try:
            hit_dropdown.options = [
                (f"{hit.get('HITTitle', 'Untitled HIT')} ({hit['HITId']})", hit["HITId"])
                for hit in hits
            ]
            hit_dropdown.value = first_hit_id
        finally:
            hit_dropdown.observe(on_hit_change, names="value")
        on_hit_change({"new": first_hit_id})
    else:
        hit_dropdown.options = [("No HITs", "")]
        assignment_dropdown.options = [("No assignments", "")]
        assignment_info_box.value = "<i>No HITs in this file.</i>"

    # The action buttons are only usable when there is something to review.
    for button in (approve_all_button, approve_button, reject_button):
        button.disabled = not has_hits


def on_hit_change(change):
    """Refresh the assignment dropdown and preview for the selected HIT.

    Args:
        change: change event; only ``change["new"]`` (the HIT ID) is read.
    """
    hit_id = change["new"]
    assignments = list_assignments_for_hit(hit_id)

    # Rebuild the lookup BEFORE touching the dropdown, so that anything
    # reacting to the dropdown sees the assignments of this HIT (and never
    # those of the previously selected one).
    assignment_lookup.clear()
    for a in assignments:
        assignment_lookup[a["AssignmentId"]] = a

    if not assignments:
        assignment_dropdown.options = [("No submitted assignments", "")]
        assignment_info_box.value = "<i>No assignments yet for this HIT.</i>"
        return

    first_assignment_id = assignments[0]["AssignmentId"]
    # Replacing the options changes the dropdown's value and would fire
    # on_assignment_change through the observer in addition to the explicit
    # call below. Detach the observer while updating (the handler is
    # attached to assignment_dropdown once at module level), then refresh
    # exactly once.
    assignment_dropdown.unobserve(on_assignment_change, names="value")
    try:
        assignment_dropdown.options = [
            (a["AssignmentId"], a["AssignmentId"]) for a in assignments
        ]
        assignment_dropdown.value = first_assignment_id
    finally:
        assignment_dropdown.observe(on_assignment_change, names="value")
    on_assignment_change({"new": first_assignment_id})


def on_assignment_change(change):
    """Show the preview of the newly selected assignment.

    Args:
        change: change event; only ``change["new"]`` (the assignment ID) is
            read. IDs missing from ``assignment_lookup`` produce a hint
            instead of a preview.
    """
    assignment_id = change["new"]
    if assignment_id not in assignment_lookup:
        assignment_info_box.value = "<i>Select a valid assignment to view details.</i>"
        return

    selected = assignment_lookup[assignment_id]
    # Heading first; it is replaced by the full preview once that is built.
    assignment_info_box.value = f"<h3>Assignment ID: {assignment_id}</h3>"
    assignment_info_box.value = format_assignment_preview(selected)

    # Also print the assignment preview in the output area
    with output_area:
        clear_output()
        print_assignment_preview(selected)


def on_approve_click(_):
    """Approve the assignment currently selected in the dropdown, if any."""
    with output_area:
        clear_output()
        selected_id = assignment_dropdown.value
        if selected_id:
            approve_assignment(selected_id, feedback_box.value)


def approve_all_hit_in_json(feedback):
    """Upload and approve every assignment of the HITs in the loaded file.

    For each assignment of each HIT returned by ``list_reviewable_hits`` the
    annotation group (with assignment ID, worker ID, per-instance results and
    issue texts) is uploaded to the ``spin-instance`` bucket; the assignment
    is then approved if MTurk reports its status as ``Submitted``, i.e. it
    has not been approved or rejected yet. Errors for a single assignment
    are printed and processing continues with the next one.

    Args:
        feedback: text passed to MTurk as ``RequesterFeedback`` for every
            approval.
    """
    import copy

    hits_in_files = list_reviewable_hits()

    if not hits_in_files:
        print("No reviewable HITs found.")
        return

    # One client for all uploads.
    s3_bucket = "spin-instance"
    s3_client = boto3.client(
        "s3",
        aws_access_key_id="",
        aws_secret_access_key="",
        region_name="us-east-1",
    )
    # NOTE(review): hard-coded, although the HIT record may carry its own
    # "environment" value — confirm this is intended.
    environment = "live"

    uploaded = 0
    approved = 0
    for hit_item in hits_in_files:
        hit_id = hit_item["HITId"]
        hit = next((h for h in local_data if h.get("HITId") == hit_id), None)
        task = hit.get("task", "Unknown")
        groupIndex = hit.get("group_index", "Unknown")
        category = hit.get("categories", "Unknown")

        for a in list_assignments_for_hit(hit["HITId"]):
            aid = a.get("AssignmentId", "Unknown")
            try:
                answers, issue_texts = parse_answer_xml(a.get("Answer", ""))

                # Work on a private copy so one assignment's results never
                # end up in the file uploaded for another assignment.
                record = copy.deepcopy(local_data[groupIndex])
                annotations = record["annotations"]
                if len(answers) != len(annotations):
                    print(f"Warning: Number of answers ({len(answers)}) does not match number of annotations ({len(annotations)}). Adjusting to match.")
                    if len(answers) > len(annotations):
                        answers = answers[: len(annotations)]
                    else:
                        answers.extend(["-1"] * (len(annotations) - len(answers)))

                for i, answer in enumerate(answers):
                    try:
                        annotations[i]["result"] = int(answer)
                    except (ValueError, TypeError):
                        # Same fallback as the single-assignment preview.
                        annotations[i]["result"] = -1
                    # Store issue text if provided
                    if str(i) in issue_texts:
                        annotations[i]["issue_text"] = issue_texts[str(i)]
                    else:
                        annotations[i].pop("issue_text", None)
                record["assignment_id"] = aid
                record["worker_id"] = a.get("WorkerId", "Unknown")

                # write to s3
                file_name = f"HITs/{category}/{task}/{environment}/group_{groupIndex}_{aid}.json"
                s3_client.put_object(
                    Bucket=s3_bucket,
                    Key=file_name,
                    ContentType="application/json",
                    ACL="public-read",
                    Body=json.dumps(record),
                )
                # put_object raises on failure, so reaching this line means
                # the upload succeeded.
                print(f"‚úÖ Successfully uploaded {file_name} to S3.")
                uploaded += 1

                # Only assignments still awaiting review can be approved;
                # their AssignmentStatus is "Submitted" ("Reviewable" is a
                # HIT status and never appears here).
                status = mturk.get_assignment(AssignmentId=aid)["Assignment"][
                    "AssignmentStatus"
                ]
                if status == "Submitted":
                    print(f"Approving assignment {aid} with status {status}.")
                    mturk.approve_assignment(
                        AssignmentId=aid, RequesterFeedback=feedback
                    )
                    approved += 1
                else:
                    print(f"Skipping approval of assignment {aid}: status is {status}.")
            except Exception as e:
                print(f"Error approving assignment {aid}: {e}")

    print(
        f"‚úÖ Approved {approved} assignments ({uploaded} uploaded) "
        f"across {len(hits_in_files)} HITs."
    )


def on_approve_all_click(_):
    """Run the bulk approval for the loaded file when a HIT is selected."""
    with output_area:
        clear_output()
        hit_selected = hit_dropdown.value
        if hit_selected:
            approve_all_hit_in_json(feedback_box.value)


def on_reject_click(_):
    """Reject the assignment currently selected in the dropdown, if any."""
    with output_area:
        clear_output()
        selected_id = assignment_dropdown.value
        if selected_id:
            reject_assignment(selected_id, feedback_box.value)


# Attach event handlers
file_dropdown.observe(on_file_change, names="value")
hit_dropdown.observe(on_hit_change, names="value")
assignment_dropdown.observe(on_assignment_change, names="value")
approve_button.on_click(on_approve_click)
approve_all_button.on_click(on_approve_all_click)
reject_button.on_click(on_reject_click)

# === Display the Full Panel ===
print("üìä Local MTurk HIT Review Panel")
print(f"Found {len([opt for opt in file_options if opt[1] != ''])} JSON files organized by category")

display(file_dropdown)
display(hit_dropdown)
display(assignment_dropdown)
display(assignment_info_box)
display(feedback_box)
display(widgets.HBox([approve_button, reject_button]))
display(output_area)
display(approve_all_button)

# Auto-select the first valid file
valid_files = [opt for opt in file_options if opt[1] != '']
if valid_files:
    file_dropdown.unobserve(on_file_change, names="value")  # prevent double trigger
    file_dropdown.value = valid_files[0][1]  # Select first valid file
    on_file_change({"new": valid_files[0][1]})
    file_dropdown.observe(on_file_change, names="value")