# In [None]:
from psychopy import visual, core, event
import random
import csv
import os
from datetime import datetime

# =========================
# Config & Globals
# =========================
WIN_SIZE = (1920, 1080)
FULLSCR = True
UNITS = "pix"  # all positions/sizes below are in pixels
BG_COLOR = (0,0,0)

# Stimulus layout
FIELD_W, FIELD_H = 900, 600          # dot field, centred on the screen
DOT_RADIUS = 14
DOT_MIN_DIST = 2*DOT_RADIUS + 6      # preferred centre-to-centre spacing: diameter + 6 px gap

# Counts (keep single-digit recall); inclusive bounds passed to random.randint
GREEN_RANGE = (3, 9)
YELLOW_RANGE = (3, 9)

# Design
PRACTICE_SET_SIZES = [1]
CSPAN_SET_SIZES = [3]
SETS_PER_SIZE = 1

# Timing
FEEDBACK_DURATION = 2.0  # sec

# Filenames (absolute paths)
name = 'test'  # filename prefix; presumably a participant/session label — set per run
# Stamp every run: both CSV writers open their files in 'w' mode, so with a
# fixed prefix each new session would silently overwrite the previous data.
SESSION_STAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
triallog_path = os.path.abspath(f"{name}_{SESSION_STAMP}_trials.csv")
summary_path = os.path.abspath(f"{name}_{SESSION_STAMP}_summary.csv")

# =========================
# Window & IO
# =========================
# One window shared by every screen of the task; responses are mouse clicks,
# and Escape is polled separately by check_escape().
win = visual.Window(
    size=WIN_SIZE,
    units=UNITS,
    fullscr=FULLSCR,
    color=BG_COLOR,
)
mouse = event.Mouse(visible=True, win=win)

# Text styles: main messages vs. smaller prompts/footers.
text_kwargs = {'color': 'white', 'height': 32, 'alignText': 'center', 'wrapWidth': 1500}
small_text_kwargs = {'color': 'white', 'height': 24, 'alignText': 'center', 'wrapWidth': 1600}

# Two reusable text stimuli; each screen sets .text/.pos before drawing.
msg = visual.TextStim(win, text='', **text_kwargs)
msg_small = visual.TextStim(win, text='', **small_text_kwargs)

# =========================
# Utilities
# =========================
def check_escape():
    """Abort the experiment immediately if Escape has been pressed.

    Closes the window and quits; it does not write any pending data first.
    """
    pressed = event.getKeys(['escape'])
    if 'escape' not in pressed:
        return
    win.close()
    core.quit()

def show_instructions(text):
    """Show *text* centred on screen until the participant clicks a mouse button.

    A "Click to continue" footer is drawn 100 px above the bottom edge of the
    window.  Escape aborts the experiment via check_escape().

    Args:
        text: Message body to display.
    """
    # Nothing on this screen changes between frames, so configure the stimuli
    # once instead of on every iteration.
    msg.text = text
    msg.pos = (0, 0)  # don't depend on earlier screens having restored the position
    msg_small.text = "Click to continue"
    msg_small.pos = (0, -win.size[1]//2 + 100)

    mouse.clickReset()
    while True:
        check_escape()
        msg.draw()
        msg_small.draw()
        win.flip()
        # With getTime=True the second element holds per-button press times
        # since clickReset(); any non-zero entry means a click on this screen.
        _, press_times = mouse.getPressed(getTime=True)
        if any(t > 0 for t in press_times):
            core.wait(0.12)  # debouncing
            return
        core.wait(0.01)

class Button:
    """A clickable rectangle with a centred text label.

    Args:
        win: Window to draw into.
        text: Label text.
        pos: Centre of the button.
        size: (width, height) of the rectangle.
    """

    def __init__(self, win, text, pos, size=(220,70)):
        self.rect = visual.Rect(
            win,
            width=size[0],
            height=size[1],
            pos=pos,
            fillColor=(-0.2, -0.2, -0.2),
            lineColor=(0.7, 0.7, 0.7),
        )
        self.label = visual.TextStim(win, text=text, pos=pos, color='white', height=28)

    def draw(self):
        """Draw the rectangle, then the label on top of it."""
        for stim in (self.rect, self.label):
            stim.draw()

    def contains(self, mouse):
        """Return whether the mouse pointer lies inside the button rectangle."""
        return self.rect.contains(mouse)

# =========================
# Dot field generation & drawing
# =========================
def _random_position(bounds_w, bounds_h):
    return (random.uniform(-bounds_w/2.0, bounds_w/2.0),
            random.uniform(-bounds_h/2.0, bounds_h/2.0))

def _far_enough(pt, pts, min_dist):
    for (x,y) in pts:
        if (pt[0]-x)**2 + (pt[1]-y)**2 < (min_dist**2):
            return False
    return True

def generate_dot_positions(n_total, radius=DOT_RADIUS, field_w=FIELD_W, field_h=FIELD_H,
                           min_dist=DOT_MIN_DIST, max_tries=5000):
    """Place *n_total* dot centres at random inside the field.

    Placement uses up to three passes.  Each pass stops as soon as *n_total*
    centres have been placed.

    1. Up to *max_tries* candidates, each of which must be >= *min_dist* from
       every accepted centre.
    2. Up to another *max_tries* candidates with the spacing relaxed to
       ``max(0.7 * min_dist, 2 * radius)``.  Dots may then touch, but their
       centres stay at least one diameter apart.
    3. If still short, the remaining centres are placed with no spacing check,
       so they may overlap.  This happens only when the field is too small for
       the requested count.  It guarantees the function returns; previously
       the relaxed pass could loop forever.

    Args:
        n_total: Number of centres to return (values <= 0 yield []).
        radius: Dot radius; only used as the floor of the relaxed spacing.
        field_w, field_h: Field size; the field is centred on (0, 0).
        min_dist: Preferred minimum centre-to-centre distance.
        max_tries: Candidate budget for each of the first two passes.

    Returns:
        A list of (x, y) tuples.
    """
    positions = []

    def _place(spacing, budget):
        # Rejection-sample candidates until the list is full or the budget runs out.
        for _ in range(budget):
            if len(positions) >= n_total:
                return
            candidate = _random_position(field_w, field_h)
            if _far_enough(candidate, positions, spacing):
                positions.append(candidate)

    _place(min_dist, max_tries)
    _place(max(min_dist*0.7, radius*2), max_tries)
    # Last resort: fill unconditionally so the caller always gets n_total dots.
    while len(positions) < n_total:
        positions.append(_random_position(field_w, field_h))
    return positions

def draw_dot_field(win, green_positions, yellow_positions, radius=DOT_RADIUS):
    """Draw one frame of the dot display (does not flip the window).

    Yellow distractors are drawn first and green targets last, so green dots
    end up on top.

    Args:
        win: Window to draw into.
        green_positions, yellow_positions: Iterables of (x, y) dot centres.
        radius: Dot radius in window units.
    """
    # This is called once per frame inside the display loops.  Build one stimulus
    # per colour and move it between draws, instead of constructing a new
    # Circle for every dot on every frame.
    for color, positions in (('yellow', yellow_positions), ('green', green_positions)):
        dot = visual.Circle(win, radius=radius, fillColor=color, lineColor=color)
        for xy in positions:
            dot.pos = xy
            dot.draw()

# =========================
# CSPAN card display & recall UI
# =========================
def show_card_and_wait(green_pos, yellow_pos):
    """Show one counting display until the participant clicks, and return the RT.

    Args:
        green_pos, yellow_pos: Dot centres for the target (green) and
            distractor (yellow) dots.

    Returns:
        The earliest mouse-button press time reported by
        mouse.getPressed(getTime=True), measured from the clickReset() at the
        start of the display.
    """
    # Anchor the prompt to the window's actual height rather than a fixed
    # y = 500 px.  With fullscr=True the window may take the screen's
    # resolution, and on shorter displays a fixed offset can fall off-screen.
    # The 100 px margin mirrors the footer placement in show_instructions.
    msg.text = "Count the GREEN dots. Ignore YELLOW.\n\nClick to continue when you finish counting."
    msg.pos = (0, win.size[1]//2 - 100)

    mouse.clickReset()
    while True:
        check_escape()
        msg.draw()
        draw_dot_field(win, green_pos, yellow_pos)
        win.flip()
        _, press_times = mouse.getPressed(getTime=True)
        pressed = [t for t in press_times if t > 0]
        if pressed:
            msg.pos = (0, 0)  # other screens draw msg centred
            core.wait(0.12)  # debouncing
            return min(pressed)
        core.wait(0.01)

def recall_counts_ui(set_size):
    """Collect the ordered recall of green-dot counts with an on-screen keypad.

    The screen shows ten digit buttons (0-9) plus "blank", "clear" and "Exit":

    - A digit appends its value.
    - "blank" appends an explicit skip (None).
    - "clear" discards all entries.
    - "Exit" ends the screen early.

    The screen also ends as soon as *set_size* entries have been made.

    Args:
        set_size: Number of counts to recall.

    Returns:
        A list of exactly *set_size* items.  Each is an int 0-9, or None for
        a blank or a position not filled before Exit.
    """
    # Digit keypad: 0-4 on the top row (y = +40), 5-9 on the row below,
    # centred horizontally.
    nums = [str(i) for i in range(10)]
    cols = 5; row_h = 90; col_w = 120
    origin_x = - (cols-1)/2 * col_w
    origin_y = + 40
    num_buttons = []
    for idx, t in enumerate(nums):
        r = idx//cols; c = idx%cols
        pos = (origin_x + c*col_w, origin_y - r*row_h)
        num_buttons.append(Button(win, t, pos, size=(100,60)))
    blank_btn = Button(win, "blank", pos=(-200, -240), size=(160,60))
    clear_btn = Button(win, "clear", pos=(0, -240), size=(160,60))
    exit_btn  = Button(win, "Exit",  pos=(200, -240), size=(160,60))

    recalled = []
    mouse.clickReset()
    # Seed the edge detector with the current button state.  If the click that
    # ended the last display is still held, it must be released first.
    # Otherwise the first frame would see a false "rising edge" and could
    # register a keypad press wherever the pointer happens to be.
    was_down = mouse.getPressed()[0]
    while True:
        check_escape()
        msg_small.text = f"Recall the green-dot counts in order (remaining: {set_size - len(recalled)})"
        msg_small.pos = (0, 300); msg_small.draw()
        for b in num_buttons: b.draw()
        for b in [blank_btn, clear_btn, exit_btn]: b.draw()
        # Sequence entered so far; blanks are shown as "_".
        prev = " ".join([str(x) if x is not None else "_" for x in recalled])
        msg.text = prev
        msg.pos = (0, 200); msg.draw(); msg.pos = (0, 0)
        win.flip()

        down = mouse.getPressed()[0]
        if down and not was_down:  # rising edge
            if blank_btn.contains(mouse):
                if len(recalled) < set_size: recalled.append(None)
            elif clear_btn.contains(mouse):
                recalled = []
            elif exit_btn.contains(mouse):
                break  # early submit; unfilled positions are padded with None below
            else:
                for i, b in enumerate(num_buttons):
                    if b.contains(mouse) and len(recalled) < set_size:
                        recalled.append(int(nums[i]))
                        break
            # wait for release so one press registers only once
            while mouse.getPressed()[0]:
                core.wait(0.005)
            core.wait(0.05)
        was_down = down

        if len(recalled) >= set_size:
            break
        core.wait(0.005)

    while len(recalled) < set_size:
        recalled.append(None)
    return recalled

def feedback_screen(n_correct_in_set, set_size):
    """Show the per-set recall score for FEEDBACK_DURATION seconds.

    The message is redrawn every frame until the deadline passes; Escape
    aborts via check_escape().

    Args:
        n_correct_in_set: Number of positions recalled correctly.
        set_size: Number of positions in the set.
    """
    deadline = core.getTime() + FEEDBACK_DURATION
    feedback_text = f"You recalled {n_correct_in_set} / {set_size} positions correctly."
    while core.getTime() < deadline:
        check_escape()
        msg.text = feedback_text
        msg.draw()
        win.flip()

# =========================
# Scoring helpers
# =========================
def compute_span_scores(perfect_by_size, min_perfect=2):
    """Score counting span from per-set perfect-recall flags.

    Args:
        perfect_by_size: dict {set_size: [bool, ...]} with one flag per set
            given at that size (True = every count recalled correctly).
        min_perfect: Perfect sets needed to pass a size; the default is the
            classic "2 of 3" rule.  When fewer sets than this were given at a
            size, all of them must be perfect instead.  Without this, a design
            with one set per size could never pass, and every span would be 0.

    Returns:
        (span, span_partial), where:

        span: Largest size reached while passing sizes in ascending order.
            Scoring stops at the first size that fails or has no sets, and
            span is 0 if the first size fails.
        span_partial: span as a float, plus 0.3 if size span + 1 was given
            and had at least one perfect set.
    """
    span = 0
    for size in sorted(perfect_by_size):
        results = perfect_by_size[size]
        n_perfect = sum(1 for ok in results if ok)
        if not results or n_perfect < min(min_perfect, len(results)):
            break
        span = size
    span_partial = float(span)
    if any(perfect_by_size.get(span + 1, [])):
        span_partial += 0.3
    return span, span_partial

# =========================
# CSV writers
# =========================
def log_set_rows(trial_rows, block_type, set_index, set_size,
                 green_counts, yellow_counts, card_rts, recalled):
    """Append one row dict per card of a finished set to *trial_rows* (in place).

    Args:
        trial_rows: List that receives the new rows.
        block_type: 'practice' or 'cspan_main'.
        set_index: Running set counter supplied by the caller.
        set_size: Number of cards in the set; entries 0..set_size-1 of each
            per-card list below are logged.
        green_counts, yellow_counts: Dot counts shown on each card.
        card_rts: Display RT for each card.
        recalled: Recalled value per position; None is written as "".
    """
    for offset in range(set_size):
        position = offset + 1
        answer = recalled[offset]
        trial_rows.append({
            'block_type': block_type,
            'set_index': set_index,
            'set_size': set_size,
            'trial_in_set': position,
            'stim_green': green_counts[offset],
            'stim_yellow': yellow_counts[offset],
            'rt_display': card_rts[offset],
            'recall_value': "" if answer is None else answer,
            'recall_index': position,
        })

def save_triallog(trial_rows, path):
    """Write the per-card rows to *path* as CSV, replacing any existing file.

    Columns are written in a fixed order, and the file is encoded as
    'utf-8-sig' (UTF-8 with a BOM).  csv.DictWriter raises ValueError if a
    row has keys outside the column list.
    """
    columns = [
        'block_type', 'set_index', 'set_size', 'trial_in_set',
        'stim_green', 'stim_yellow', 'rt_display', 'recall_value', 'recall_index',
    ]
    with open(path, 'w', newline='', encoding='utf-8-sig') as fh:
        out = csv.DictWriter(fh, fieldnames=columns)
        out.writeheader()
        out.writerows(trial_rows)

def save_summary(summary_dict, path):
    """Write *summary_dict* to *path* as a one-row CSV.

    The header row is the dict's keys in insertion order.  The file is
    replaced if it exists and is encoded as 'utf-8-sig'.
    """
    header = [*summary_dict.keys()]
    with open(path, 'w', newline='', encoding='utf-8-sig') as fh:
        out = csv.DictWriter(fh, fieldnames=header)
        out.writeheader()
        out.writerow(summary_dict)

# =========================
# Instructions
# =========================
show_instructions(
    "Welcome!\n\nYou'll count the GREEN dots on each display (ignore YELLOW). "
    "After a short sequence of displays, you'll recall the number of GREEN dots from each display in order.\n\n"
    "Work steadily and accurately. Click to continue."
)

# Take the practice-set count from the config.  The hard-coded "two" did not
# match PRACTICE_SET_SIZES, which defines how many practice sets actually run.
_n_practice_sets = len(PRACTICE_SET_SIZES)
show_instructions(
    f"Practice\n\nWe'll start with {_n_practice_sets} short practice "
    f"set{'' if _n_practice_sets == 1 else 's'} so you can get used to the procedure and the recall screen."
)

# =========================
# Practice block (one set per entry of PRACTICE_SET_SIZES)
# =========================
trial_rows = []   # per-card log rows, shared with the main block
set_counter = 0   # running set number across practice and main blocks
for practice_size in PRACTICE_SET_SIZES:
    set_counter += 1
    # All green counts are drawn before any yellow count, then dot layouts card by card.
    green_counts = [random.randint(*GREEN_RANGE) for _ in range(practice_size)]
    yellow_counts = [random.randint(*YELLOW_RANGE) for _ in range(practice_size)]
    cards = []
    for n_green, n_yellow in zip(green_counts, yellow_counts):
        dots = generate_dot_positions(n_green + n_yellow)
        # First n_green centres are the targets, the rest are distractors.
        cards.append((dots[:n_green], dots[n_green:]))

    # Present every card, keeping each display RT.
    card_rts = [show_card_and_wait(greens, yellows) for greens, yellows in cards]

    # Recall, then show how many positions matched.
    recalled = recall_counts_ui(practice_size)
    n_correct = sum(
        1 for answer, target in zip(recalled, green_counts)
        if answer is not None and answer == target
    )
    feedback_screen(n_correct, practice_size)

    log_set_rows(trial_rows, 'practice', set_counter, practice_size,
                 green_counts, yellow_counts, card_rts, recalled)

# =========================
# CSPAN main (CSPAN_SET_SIZES in order, SETS_PER_SIZE sets at each size)
# =========================
# Describe the design from the config so the participant is told what will
# actually run.  The hard-coded "size 3 through 7, three sets at each size"
# did not match CSPAN_SET_SIZES / SETS_PER_SIZE.
_main_sizes = sorted(set(CSPAN_SET_SIZES))
_sets_word = f"{SETS_PER_SIZE} set{'' if SETS_PER_SIZE == 1 else 's'}"
if len(_main_sizes) == 1:
    _design_text = f"You'll complete {_sets_word} of size {_main_sizes[0]}. "
elif _main_sizes:
    _design_text = (f"You'll complete sets of size {_main_sizes[0]} through {_main_sizes[-1]}, "
                    f"{_sets_word} at each size. ")
else:
    _design_text = ""
show_instructions(
    "Main task\n\nNow for the real sets. " + _design_text
    + "Remember to recall the counts in order at the end of each set."
)

# Each size repeated SETS_PER_SIZE times, in config order.
set_sizes_order = []
for s in CSPAN_SET_SIZES:
    set_sizes_order.extend([s]*SETS_PER_SIZE)

# size -> one "every count recalled correctly" flag per set at that size
perfect_by_size = {s: [] for s in CSPAN_SET_SIZES}

for set_size in set_sizes_order:
    set_counter += 1  # continues the numbering started in the practice block
    green_counts = [random.randint(GREEN_RANGE[0], GREEN_RANGE[1]) for _ in range(set_size)]
    yellow_counts = [random.randint(YELLOW_RANGE[0], YELLOW_RANGE[1]) for _ in range(set_size)]
    green_positions_seq, yellow_positions_seq = [], []
    for i in range(set_size):
        n_g = green_counts[i]; n_y = yellow_counts[i]
        pos = generate_dot_positions(n_g + n_y)
        green_positions_seq.append(pos[:n_g]); yellow_positions_seq.append(pos[n_g:])

    # collect RTs
    card_rts = []
    for i in range(set_size):
        rt = show_card_and_wait(green_positions_seq[i], yellow_positions_seq[i])
        card_rts.append(rt)

    # recall & score
    recalled = recall_counts_ui(set_size)
    perfect = True
    correct_positions = 0
    for i in range(set_size):
        rv = recalled[i]; tv = green_counts[i]
        if rv is None or rv != tv: perfect = False
        if rv == tv: correct_positions += 1
    perfect_by_size[set_size].append(perfect)
    feedback_screen(correct_positions, set_size)

    # merged logging
    log_set_rows(trial_rows, 'cspan_main', set_counter, set_size,
                 green_counts, yellow_counts, card_rts, recalled)

# =========================
# Compute span scores & save CSVs
# =========================
span_2of3, span_partial = compute_span_scores(perfect_by_size)
# Sum of the set sizes of every perfectly recalled set.  Iterate the recorded
# results rather than CSPAN_SET_SIZES: a size listed twice in the config would
# otherwise have its perfect sets counted twice.
abs_perfect_sum = sum(
    size * sum(1 for ok in results if ok)
    for size, results in perfect_by_size.items()
)

summary = dict(
    span_max_2of3=span_2of3,
    span_partial=span_partial,
    abs_perfect_sum=abs_perfect_sum
)

save_triallog(trial_rows, triallog_path)
save_summary(summary, summary_path)

# Final message: report the scores on screen, then shut everything down.
_final_lines = [
    "Task complete.",
    "",
    f"Counting Span (max ≥2/3): {span_2of3}",
    f"Partial credit span: {span_partial:.1f}",
    f"Absolute perfect-sum: {abs_perfect_sum}",
    "",
    "",
]
show_instructions("\n".join(_final_lines))

win.close()
core.quit()
