In [237]:
import os
import shutil
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
import imageio
import pandas as pd
import glob

# Prefer the GPU, but fall back to the CPU so this notebook still runs on
# machines without CUDA instead of carrying a device name that cannot be used.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allow TF32 for CUDA matmuls; per the original note, the tf32 data type is
# faster than standard float32.
torch.backends.cuda.matmul.allow_tf32 = True

In [238]:
def get_images(prefix, folder_path="stimuli_labeled/"):
    """Return the paths inside ``folder_path`` that match a glob pattern, sorted.

    Args:
        prefix: Glob pattern for the file name, e.g. ``"album*.jpg"``. Despite
            the parameter name, the value is handed to ``glob`` as a full
            pattern, so wildcards are honoured.
        folder_path: Directory to search; it is joined with ``prefix``.

    Returns:
        list[str]: Matching paths in ascending lexicographic order. The list
        is empty when nothing matches.
    """
    search_pattern = os.path.join(folder_path, prefix)
    # glob.glob yields matches in arbitrary (filesystem-dependent) order.
    # Sorting makes slices such as [:-30] and seeded shuffles of the result
    # select the same files on every machine.
    return sorted(glob.glob(search_pattern))

In [239]:
# Stimulus design as (glob pattern, number of presentations, trailing trim).
# Each presentation becomes its own sub-list in `all_images`, in this order.
# A trim of None keeps every match; a negative trim drops that many paths from
# the end of the match list.
image_design = [
    ("album*.jpg", 2, None),
    ("extra*.jpg", 2, None),
    ("coco_shared*.jpg", 2, None),
    # Drops the last 30 matches so the image count fits the design.
    # NOTE(review): the previous comment said 20 while the slice removed 30 -
    # confirm 30 is the intended number.
    ("coco_nonshared*.jpg", 1, -30),
    ("target*.jpg", 12, None),
]

all_images = [
    get_images(pattern)[:trim]
    for pattern, presentations, trim in image_design
    for _ in range(presentations)
]

In [240]:
# Concatenate the per-category sub-lists into one flat list of image paths.
# (The list is only flattened here; despite its name it is shuffled later.)
all_images_shuffled = []
for image_group in all_images:
    all_images_shuffled.extend(image_group)
print(len(all_images_shuffled))

1485


In [241]:
# Experiment layout: every session has 15 runs and every run has 36 trials.
num_sessions = 3
runs_per_session = 15
num_runs_total = runs_per_session * num_sessions
trials_per_run = 36
total_trials = num_runs_total * trials_per_run

# Three trials in every run are blanks, so they carry no image.
num_blank_trials = 3 * num_runs_total

print("num_runs_total", num_runs_total)
print("total_trials", total_trials)
print("total_trials_no_blanks", total_trials - num_blank_trials)

# The image list must fill every non-blank trial exactly.
assert total_trials - num_blank_trials == len(all_images_shuffled)

def blank_generator(seed, n_trials=None):
    """Pick the three blank-trial indices for one run, reproducibly.

    Args:
        seed: Seed for the random draws; the same seed always gives the same
            indices.
        n_trials: Number of trials in the run. Defaults to the notebook-level
            ``trials_per_run`` when omitted.

    Returns:
        list[int]: ``[first, second, last]`` where ``first`` is drawn from
        8..13 inclusive, ``second`` from 18..23 inclusive, and ``last`` is the
        final trial index of the run (``n_trials - 1``).
    """
    if n_trials is None:
        n_trials = trials_per_run
    # A private generator produces the same numbers as seeding the global
    # `random` module, but does not reset the process-wide RNG state that
    # other cells depend on.
    rng = random.Random(seed)
    return [
        rng.randint(8, 13),
        rng.randint(18, 23),
        n_trials - 1,
    ]

# Blank-trial indices for every run; the run number doubles as the seed.
run_to_blanks = list(map(blank_generator, range(num_runs_total)))

num_runs_total 45
total_trials 1620
total_trials_no_blanks 1485


In [242]:
p_id = 0

# Rebuild the flat image list from `all_images` before shuffling. The shuffle
# is in place, so without this a second execution of the cell would shuffle an
# already shuffled list and produce a different trial order for the same p_id.
all_images_shuffled = [path for group in all_images for path in group]

# Shuffle image order; seeding with the participant id makes it reproducible.
random.seed(p_id)
random.shuffle(all_images_shuffled)

# The conditions file for this participant is written next to this prefix.
output_dir = "conditions_files"
participant_path = output_dir + "/participant" + str(p_id)
os.makedirs(output_dir, exist_ok=True)

# Per-trial columns of the conditions table. Every list receives exactly one
# entry per trial, so they stay index-aligned.
current_image_list = []
is_repeat_list = []
run_num_list = []
is_new_run_list = []
is_blank_trial_list = []
trial_index_list = []
image_index = 0  # position of the next unused entry in all_images_shuffled
all_blanks_list_list = []
previous_image_list = []

# Images already presented, kept in a set so the repeat check is O(1) per
# trial instead of a scan through current_image_list.
seen_images = set()
previous_image = "blank"
last_trial_index = trials_per_run - 1
last_run_num = num_runs_total - 1

for run_num in range(num_runs_total):
    blank_trial_indices = run_to_blanks[run_num]
    for trial_index in range(trials_per_run):
        run_num_list.append(run_num)
        all_blanks_list_list.append(blank_trial_indices)
        trial_index_list.append(trial_index)

        # The first trial of a run has no preceding image within that run.
        previous_image_list.append("blank" if trial_index == 0 else previous_image)

        # Flag the final trial of every run except the very last run.
        is_new_run_list.append(
            1 if (trial_index == last_trial_index and run_num != last_run_num) else 0
        )

        if trial_index in blank_trial_indices:
            current_image_list.append("blank.jpg")
            previous_image = "blank"
            is_blank_trial_list.append(1)
            is_repeat_list.append(0)
        else:
            image_path = all_images_shuffled[image_index]
            image_name = "images/" + image_path
            # A repeat is any image that already appeared on an earlier trial.
            is_repeat_list.append(1 if image_name in seen_images else 0)
            seen_images.add(image_name)
            current_image_list.append(image_name)
            is_blank_trial_list.append(0)
            # Label carried to the next trial: the text after the last "__"
            # in the path, without ".jpg", with underscores turned into
            # line breaks.
            previous_image = image_path.split("__")[-1].split(".jpg")[0].replace("_", "\n")

            image_index += 1

# Every image must have been assigned to a trial; leftovers would otherwise be
# dropped from the design without any warning.
assert image_index == len(all_images_shuffled), (
    f"placed {image_index} images but {len(all_images_shuffled)} were available"
)
            
# Assemble the per-trial lists into one table (one row per trial) and write it
# to "<participant_path>.csv" without the DataFrame index column.
output_dict = dict(
    current_image=current_image_list,
    is_repeat=is_repeat_list,
    trial_index=trial_index_list,
    is_blank_trial=is_blank_trial_list,
    is_new_run=is_new_run_list,
    run_num=run_num_list,
    previous_image=previous_image_list,
    all_blanks_list=all_blanks_list_list,
)
output_df = pd.DataFrame(output_dict)
study_test_file_path = f"{participant_path}.csv"
output_df.to_csv(study_test_file_path, index=False)