In [10]:
import base64
import time
import datetime
import requests
import json
import os, random, csv
from PIL import Image
from openai import OpenAI
from google.colab import drive

# Mount Google Drive so stimuli are read from, and results written to, Drive.
drive.mount('/content/drive', force_remount=True)


# CONFIGURATION
# Read the key from the OPENAI_API_KEY environment variable instead of pasting
# a secret into the notebook. If the variable is unset an empty string is
# passed (as before) — set the variable or replace the fallback with your key.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
client = OpenAI(api_key=API_KEY)

base_path = "/content/drive/My Drive/CueTask"
stimuli_path = os.path.join(base_path, "stimuli")            # <color>_<shape>.png memory items
retrocue_path = os.path.join(base_path, "retrocue_stimuli")  # retrocue images
output_path = os.path.join(base_path, "generated_trials")
os.makedirs(output_path, exist_ok=True)

# Trial settings
N_TRIALS = 5
SET_SIZES = [2, 4, 6]  # candidate numbers of items in the memory array

# memory grid (3 columns × 2 rows)
grid_size = (900, 600)  # canvas (width, height) in pixels
cell_size = (300, 300)  # (width, height) each stimulus is resized to

# Top-left pixel corner of every grid cell, named pos1..posN row by row
# (left to right, top to bottom). Derived from grid_size/cell_size so the
# layout cannot drift out of sync with the canvas dimensions.
_n_cols = grid_size[0] // cell_size[0]
_n_rows = grid_size[1] // cell_size[1]
positions = {
    f"pos{row * _n_cols + col + 1}": (col * cell_size[0], row * cell_size[1])
    for row in range(_n_rows)
    for col in range(_n_cols)
}

# Fail fast instead of producing a confusing error deep inside trial generation.
if max(SET_SIZES) > len(positions):
    raise ValueError(
        f"SET_SIZES allows up to {max(SET_SIZES)} items but the grid only has "
        f"{len(positions)} cells."
    )


# HELPER FUNCTIONS
def _is_rate_limit_error(exc):
    """Return True if *exc* looks like a rate-limit (HTTP 429) error."""
    # openai>=1.0 API status errors expose ``status_code``; the substring test
    # keeps the original heuristic for errors that do not carry it.
    return getattr(exc, "status_code", None) == 429 or "rate_limit" in str(exc).lower()


def safe_completion_call(client, messages, max_retries=5, model="gpt-4o-mini", base_delay=2):
    """Call ``client.chat.completions.create``, retrying on rate-limit errors.

    Args:
        client: Object exposing ``chat.completions.create`` (an ``OpenAI`` client).
        messages: Chat message list, passed through unchanged.
        max_retries: Total number of attempts before giving up.
        model: Model name sent with the request.
        base_delay: Seconds to wait before the first retry; doubles per retry.

    Returns:
        Whatever ``client.chat.completions.create`` returns.

    Raises:
        Exception: "Failed after multiple retries." when every attempt was
            rate-limited; the last rate-limit error is chained as ``__cause__``.
        Any non-rate-limit error from the client is re-raised immediately.
    """
    last_error = None
    for attempt in range(max_retries):
        try:
            return client.chat.completions.create(
                model=model,
                messages=messages
            )
        except Exception as e:
            if not _is_rate_limit_error(e):
                raise
            last_error = e
            # No point sleeping after the final attempt — we are about to give up.
            if attempt + 1 < max_retries:
                wait_t = base_delay * (2 ** attempt)
                print(f"Rate limit hit. Waiting {wait_t}s...")
                time.sleep(wait_t)
    raise Exception("Failed after multiple retries.") from last_error



# TRIAL GENERATION
def _build_memory_array(set_size, colors, shapes):
    """Paste ``set_size`` unique color-shape stimuli into random grid cells.

    Returns:
        (canvas, loc_mapping): the composed RGB image and a dict mapping each
        occupied position name to the ``<color>_<shape>.png`` file pasted there.
    """
    # random.sample returns items in random order, so zipping two independent
    # samples already yields random colour-shape pairings (no extra shuffle).
    chosen_colors = random.sample(colors, set_size)
    chosen_shapes = random.sample(shapes, set_size)
    chosen_stimuli = [f"{c}_{s}.png" for c, s in zip(chosen_colors, chosen_shapes)]

    canvas = Image.new("RGB", grid_size, (255, 255, 255))
    occupied = random.sample(list(positions), set_size)

    loc_mapping = {}
    for pos_name, file_name in zip(occupied, chosen_stimuli):
        # The context manager closes the file handle Image.open keeps open.
        with Image.open(os.path.join(stimuli_path, file_name)) as stim_img:
            canvas.paste(stim_img.resize(cell_size), positions[pos_name])
        loc_mapping[pos_name] = file_name
    return canvas, loc_mapping


def _choose_retrocue(match, cue_type, query_color, query_shape, colors, shapes, available):
    """Pick a retrocue file name consistent with ``match`` and ``cue_type``.

    Object cues are named ``obj_<color>_<shape>.png``; feature cues are named
    ``<color>.png`` or ``<shape>.png``. A mismatch cue shares neither the
    queried colour nor the queried shape. Only names present in ``available``
    are eligible.

    Raises:
        ValueError: if no eligible retrocue file exists.
    """
    if match:
        if cue_type == "object":
            candidates = [f"obj_{query_color}_{query_shape}.png"]
        else:
            candidates = [f"{query_color}.png", f"{query_shape}.png"]
    elif cue_type == "object":
        candidates = [
            f"obj_{c}_{s}.png"
            for c in colors
            for s in shapes
            if c != query_color and s != query_shape
        ]
    else:
        candidates = [
            f"{feature}.png"
            for feature in colors + shapes
            if feature not in (query_color, query_shape)
        ]

    candidates = [f for f in candidates if f in available]
    if not candidates:
        if match:
            raise ValueError(
                f"No '{cue_type}' retrocue found for {query_color}_{query_shape}. "
                "Check stimuli naming."
            )
        raise ValueError("No suitable 'no' retrocue found. Check stimuli naming.")
    return random.choice(candidates)


def generate_trial(trial_num, run_folder):
    """Generate one trial grid with unique color-shape combos and a retrocue.

    Draws a set size from SET_SIZES, pastes that many stimuli with unique
    colours and unique shapes into random cells, picks one occupied cell as
    the query, then picks a matching ("yes") or non-matching ("no") retrocue
    of a randomly chosen cue type ("object" or "feature"). The grid is saved
    as ``trial_<trial_num>.png`` in ``run_folder``.

    Returns:
        dict describing the trial (keys match the CSV columns written by main).

    Raises:
        ValueError: if the retrocue directory lacks a suitable cue file.
    """
    all_colors = ["yellow", "purple", "pink", "orange", "green", "blue"]
    all_shapes = ["triangle", "square", "rectangle", "heart", "diamond", "circle"]

    # --- 1) Build the memory array ---
    set_size = random.choice(SET_SIZES)
    canvas, loc_mapping = _build_memory_array(set_size, all_colors, all_shapes)

    # --- 2) Choose query item ---
    query_loc = random.choice(list(loc_mapping))
    query_stim = loc_mapping[query_loc]
    query_color, query_shape = query_stim.replace(".png", "").split("_")

    # --- 3) Retrocue: the logged cue_type now always matches the cue shown ---
    match = random.choice([True, False])
    cue_type = random.choice(["object", "feature"])
    retrocue_files = {f for f in os.listdir(retrocue_path) if f.endswith(".png")}
    retrocue_file = _choose_retrocue(
        match, cue_type, query_color, query_shape, all_colors, all_shapes, retrocue_files
    )

    if match:
        cue_location = query_loc
        correct_answer = "yes"
    else:
        # NOTE(review): this may be an empty cell; confirm that is intended.
        cue_location = random.choice([loc for loc in positions if loc != query_loc])
        correct_answer = "no"

    # --- 4) Save grid image ---
    trial_filename = f"trial_{trial_num}.png"
    trial_path = os.path.join(run_folder, trial_filename)
    canvas.save(trial_path)

    return {
        "trial": trial_num,
        "set_size": set_size,
        "trial_image": trial_path,
        "retrocue_file": retrocue_file,
        "query_location": query_loc,
        "query_stimulus": query_stim,
        "retrocue_stimulus": retrocue_file,
        "correct_answer": correct_answer,
        "cue_type": cue_type,
        "cue_location": cue_location
    }



# RUN GPT RESPONSE
def encode_image_base64(image_path):
    """Read *image_path* as raw bytes and return them base64-encoded as a str."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")


def _build_trial_messages(query_location, grid_b64, retrocue_b64):
    """Return the chat messages: instructions, memory array, retrocue, question."""
    # NOTE(review): the grid image has no labels and the prompt never says how
    # "posN" maps onto the 3x2 grid (pos1-pos3 top row, pos4-pos6 bottom row,
    # left to right), so the model must guess the layout — confirm intended.
    return [
        {"role": "system", "content": "You are performing a visual working memory task."},
        {"role": "user",
         "content": [
            # Instructions FIRST
            {"type": "text", "text":
                "You are completing a visual working memory task.\n"
                "A memory array will appear containing color–shape objects arranged in a grid.\n"
                "Then, you will see a retrocue image.\n\n"
                "A retrocue is considered a match to a location if it matches the item originally "
                "in that location in ANY of the following ways:\n"
                "  (1) exact object (same color AND same shape),\n"
                "  (2) same color only, or\n"
                "  (3) same shape only.\n\n"
                "When asked, respond ONLY 'yes' if the retrocue matches the item from that location "
                "in any of these ways. Otherwise, respond 'no'.\n\n"
                "Let’s begin."
            },

            # Show memory array
            {"type": "text", "text": "Here is the memory array:"},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{grid_b64}"}},

            # Show retrocue
            {"type": "text", "text": "Here is the retrocue:"},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{retrocue_b64}"}},

            # Ask final question
            {"type": "text",
             "text":
                f"Does the retrocue correspond to the item that was originally located in the "
                f"{query_location} position? "
                "Answer ONLY 'yes' or 'no'."
            }
         ]
        }
    ]


def _normalize_answer(raw_text):
    """Reduce a model reply to a comparable lower-case answer string.

    Lower-cases and strips whitespace, '.' and '!' (as before). If the first
    word is then 'yes' or 'no' (e.g. "Yes, it does"), only that word is
    returned. ``None`` content becomes "".
    """
    text = (raw_text or "").strip().lower()
    text = text.replace(".", "").replace("!", "").strip()
    words = text.split()
    if words:
        first_word = words[0].strip(",;:'\"")
        if first_word in ("yes", "no"):
            return first_word
    return text


def run_trial(trial_info):
    """Show one trial to the model and return ``(answer, reaction_time)``.

    Args:
        trial_info: dict produced by ``generate_trial``; reads ``trial_image``,
            ``retrocue_file`` and ``query_location``.

    Returns:
        (gpt_answer, reaction_time): normalized answer string and elapsed
        seconds for the API call.
    """
    grid_b64 = encode_image_base64(trial_info["trial_image"])
    retrocue_path_full = os.path.join(retrocue_path, trial_info["retrocue_file"])
    retrocue_b64 = encode_image_base64(retrocue_path_full)

    messages = _build_trial_messages(trial_info["query_location"], grid_b64, retrocue_b64)

    # perf_counter is monotonic, unlike time.time(), so it is the right clock
    # for durations. NOTE(review): this "reaction time" includes network
    # latency and any rate-limit back-off sleeps inside safe_completion_call.
    start_time = time.perf_counter()
    response = safe_completion_call(client, messages)
    reaction_time = time.perf_counter() - start_time  # in seconds

    gpt_answer = _normalize_answer(response.choices[0].message.content)
    return gpt_answer, reaction_time



# MAIN SCRIPT
def main():
    """Run N_TRIALS trials and log one CSV row per trial in a new run folder.

    Each run gets a timestamped ``run_<run_id>`` folder under ``output_path``
    holding the trial grid images and ``results_<run_id>.csv``.
    """
    run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_folder = os.path.join(output_path, f"run_{run_id}")
    os.makedirs(run_folder, exist_ok=True)

    results_file = os.path.join(run_folder, f"results_{run_id}.csv")
    fieldnames = [
        "trial", "set_size", "trial_image", "retrocue_file",
        "query_location", "query_stimulus", "retrocue_stimulus",
        "correct_answer", "gpt_response", "accuracy",
        "cue_type", "cue_location", "reaction_time"
    ]
    n_correct = 0

    with open(results_file, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for t in range(1, N_TRIALS + 1):
            trial_info = generate_trial(t, run_folder)
            gpt_answer, rt = run_trial(trial_info)
            accuracy = (gpt_answer == trial_info["correct_answer"])
            n_correct += int(accuracy)

            trial_info["gpt_response"] = gpt_answer
            trial_info["accuracy"] = accuracy
            trial_info["reaction_time"] = rt

            writer.writerow(trial_info)
            # Flush after every trial so completed rows are handed to the OS
            # even if the notebook runtime is killed mid-run.
            csvfile.flush()
            print(f"Trial {t} | Size {trial_info['set_size']} | GPT={gpt_answer} | Correct={trial_info['correct_answer']} | RT={rt:.2f}s")

    print(f"\nExperiment finished! Results saved in {run_folder}")
    print(f"Accuracy: {n_correct}/{N_TRIALS}")


if __name__ == "__main__":
    main()


Mounted at /content/drive
Trial 1 | Size 2 | GPT=no | Correct=no | RT=0.80s
Trial 2 | Size 4 | GPT=yes | Correct=yes | RT=1.45s
Trial 3 | Size 6 | GPT=no | Correct=no | RT=2.42s
Trial 4 | Size 6 | GPT=yes | Correct=no | RT=1.31s
Trial 5 | Size 4 | GPT=no | Correct=no | RT=1.17s

Experiment finished! Results saved in /content/drive/My Drive/CueTask/generated_trials/run_20251205_222008
