In [4]:
import os
import re
import pandas as pd
import numpy as np
import random
from math import ceil
import copy
# import cv2
import glob
import shutil
import json
from itertools import count
from collections import OrderedDict

import constants as consts

import warnings
warnings.filterwarnings('ignore')

In [5]:
class MakeFiles:
    """
    This class makes run and target files
        Args:
            block_names (list of str): options are 'visual_search', 'n_back', 'social_prediction', 'semantic_prediction', 'action_observation'
            run_name_prefix (str): prefix of run name
            tile_run (int): determines number of block repeats within a run
            instruct_dur (int): length of instruct for block_names (sec)
            block_dur_secs (int): length of block_name (sec)
            rest_dur_secs (int): length of rest (sec), 0 if no rest
            num_runs (int): number of runs
            counterbalance_runs (bool): counterbalance block order across runs

        The defaults are read from `run_config.json` in `Defaults.CONFIG_DIR`;
        any keyword argument overrides the value loaded from that file.
    """

    def __init__(self, **kwargs):
        # the context manager closes the config file handle once it is parsed
        with open(os.path.join(Defaults.CONFIG_DIR, 'run_config.json')) as f:
            config = json.load(f)
        self.config = copy.deepcopy(config)
        self.config.update(**kwargs)

    def _create_run_dataframe(self, target_files):
        """ appends one row (dict) per target file to `self.all_data`
            Args:
                target_files (list of str): paths to target files of the current block
            Side effects:
                advances `self.cum_time` to the end time of the last block added
        """
        for block_iter, target_file in enumerate(target_files, start=1):

            # load target file
            dataframe = pd.read_csv(target_file)
            first_trial = dataframe.iloc[0]
            last_trial = dataframe.iloc[-1]

            start_time = first_trial['start_time'] + self.cum_time
            end_time = last_trial['start_time'] + last_trial['trial_dur'] + self.config['instruct_dur'] + self.cum_time

            # parse the file name (not the full path) so that digits in
            # directory names cannot be picked up by mistake
            target_file_name = os.path.basename(target_file)
            num_sec = re.findall(r'\d+(?=sec)', target_file_name)[0]
            target_num = re.findall(r'\d+(?=\.csv)', target_file_name)[0]
            num_trials = len(dataframe)

            data = {'block_name': self.block_name, 'block_iter': block_iter, 'block_num': self.block_num+1,
                    'num_trials': num_trials, 'target_num': target_num, 'num_sec': num_sec,
                    'target_file': target_file_name, 'start_time': start_time, 'end_time': end_time,
                    'instruct_dur': self.config['instruct_dur'], 'display_trial_feedback': self.display_trial_feedback,
                    'replace_stimuli': self.replace_stimuli, 'feedback_type': self.feedback_type, 'target_score': self.target_score}

            self.all_data.append(data)
            self.cum_time = end_time

    def _save_run_file(self, dataframe, run_name):
        """ saves `dataframe` as `run_name` in `Defaults.RUN_DIR` (no index column) """
        dataframe.to_csv(os.path.join(Defaults.RUN_DIR, run_name), index=False, header=True)

    def _add_rest(self):
        """ makes the rest target file and inserts rest rows into every saved run file """
        run_name_prefix = self.config['run_name_prefix']
        run_files = sorted(glob.glob(os.path.join(Defaults.RUN_DIR, f'*{run_name_prefix}*.csv')))

        # make target file
        BlockClass = TASK_MAP['rest']
        config = self._load_config(fpath=os.path.join(Defaults.CONFIG_DIR, 'rest_config.json'))
        block = BlockClass(target_config=config)
        self.target_name = block.make_targetfile()

        # run files are rewritten in place
        for run_file in run_files:
            dataframe = self._add_rest_rows(pd.read_csv(run_file))
            dataframe.to_csv(run_file, index=False, header=True)

    def _counterbalance_runs(self):
        """ regenerates the run files until `_test_counterbalance` reports no violation """
        while self._test_counterbalance() > 0:
            print('not balanced ...')
            self._create_run()

        print('these runs are perfectly balanced')

    def _check_task_run(self):
        """ draws `tile_run` target files for the current block without replacement
            Returns:
                list of str: sampled target file paths (removed from the block's pool)
            Raises:
                ValueError: if the pool holds fewer files than `tile_run`
        """
        # register the pool of target files the first time a block is seen
        if self.block_name not in self.target_dict:
            self.target_dict[self.block_name] = self.fpaths

        # seed on the block number so the draw is reproducible
        random.seed(self.block_num+1)

        pool = self.target_dict[self.block_name]
        num_needed = self.config['tile_run']
        if len(pool) < num_needed:
            raise ValueError(
                f"not enough target files left for '{self.block_name}': "
                f"need {num_needed}, found {len(pool)}")

        return [pool.pop(random.randrange(len(pool))) for _ in range(num_needed)]

    def _insert_row(self, row_number, dataframe, row_value):
        """ inserts a row at position `row_number`
            Args:
                row_number (int): position of the new row
                dataframe (pandas dataframe): left unmodified
                row_value (list): one value per column, in column order
            Returns:
                new pandas dataframe with a fresh 0..n-1 index
        """
        upper = dataframe.iloc[:row_number]
        lower = dataframe.iloc[row_number:]

        # building the row as its own frame avoids writing into a slice of `dataframe`
        new_row = pd.DataFrame([row_value], columns=dataframe.columns)

        return pd.concat([upper, new_row, lower], ignore_index=True)

    def _correct_start_end_times(self, dataframe):
        """ recomputes `start_time`/`end_time` as the running sum of `num_sec` + `instruct_dur` """
        end_times = np.cumsum(dataframe['num_sec'] + dataframe['instruct_dur']).to_list()

        dataframe['end_time'] = end_times

        # every block starts when the previous one ends; the first starts at 0
        dataframe['start_time'] = ([0] + end_times)[:-1]

        return dataframe

    def _add_rest_rows(self, dataframe):
        """ inserts evenly spaced rest rows into a run dataframe
            Args:
                dataframe (pandas dataframe): run file; its columns must follow the
                    order of the keys written by `_create_run_dataframe`
            Returns:
                pandas dataframe with rest rows and corrected start/end times
                (returned unchanged when there is only one block, i.e. no gap for rest)
        """
        self.num_rest = (len(self.config['block_names']) * self.config['tile_run']) - 1

        # a single block leaves no gap to fill (and would divide by zero below)
        if self.num_rest < 1:
            return dataframe

        spacing = np.round((len(dataframe) + self.num_rest) / self.num_rest)
        rest_positions = np.cumsum(np.tile(spacing, self.num_rest)).astype(int) - 1

        # `np.float` was removed from NumPy; the builtin float gives the same NaN
        nan = float('nan')

        # row values (the key order mirrors the run file columns)
        row_dict = {'block_name': 'rest', 'block_iter': nan, 'block_num': len(self.config['block_names']) + 1,
                    'num_trials': 1, 'target_num': nan, 'num_sec': self.config['rest_dur_secs'],
                    'target_file': self.target_name, 'start_time': nan, 'end_time': nan,
                    'instruct_dur': 0, 'display_trial_feedback': nan,
                    'replace_stimuli': nan, 'feedback_type': nan, 'target_score': nan}

        for rest_iter, row_number in enumerate(rest_positions, start=1):
            row_dict['block_iter'] = rest_iter
            row_value = list(row_dict.values())
            if row_number > len(dataframe):
                print("Invalid row_number")
            else:
                dataframe = self._insert_row(row_number, dataframe, row_value)

        # update start and end times
        return self._correct_start_end_times(dataframe)

    def _load_config(self, fpath):
        """ loads JSON file as dict
            Args:
                fpath (str): full path to .json file
            Returns
                loads JSON as dict
        """
        with open(fpath) as f:
            return json.load(f)

    def _save_target_files(self, df_target):
        """ saves out target files
            Args:
                df_target (pandas dataframe)
            Side effects:
                writes `df_target` plus the timing columns (`trial_dur`, `iti_dur`,
                `start_time`, `hand`) to a csv file in `self.TARGET_DIR`
        """
        trial_len = self.config['trial_dur'] + self.config['iti_dur']
        start_time = np.round(np.arange(0, self.num_trials*trial_len, trial_len), 1)
        data = {"trial_dur":self.config['trial_dur'], "iti_dur":self.config['iti_dur'], "start_time":start_time, "hand": self.config['hand']}

        df_target = pd.concat([df_target, pd.DataFrame.from_records(data)], axis=1, ignore_index=False, sort=False)

        # get targetfile name
        tf_name = f"{self.config['block_name']}_{self.config['block_dur_secs']}sec"
        tf_name = self._get_target_file_name(tf_name)

        # save out dataframe to a csv file in the target directory (TARGET_DIR)
        df_target.to_csv(os.path.join(self.TARGET_DIR, tf_name), index=False, header=True)

        print(f'saving out {tf_name}')

    def _get_target_file_name(self, targetfile_name):
        """ returns the next free file name `<targetfile_name>_NN.csv` in `self.TARGET_DIR`
            Args:
                targetfile_name (str): file stem, e.g. 'visual_search_30sec'
            Returns:
                str: numbered file name; numbering starts at 01
            Side effects:
                creates `self.TARGET_DIR` if it does not exist
        """
        os.makedirs(self.TARGET_DIR, exist_ok=True)

        # the stem is matched literally (it may contain regex metacharacters)
        numbered = re.compile(re.escape(targetfile_name) + r'_(\d+)\.csv')
        target_num = []
        for fname in os.listdir(self.TARGET_DIR):
            match = numbered.search(fname)
            if match:
                target_num.append(int(match.group(1)))

        next_num = max(target_num) + 1 if target_num else 1

        return f"{targetfile_name}_{next_num:02d}.csv"

    def _sample_evenly_from_col(self, dataframe, num_stim, column='condition_name', **kwargs):
        """ samples `num_stim` rows with the values of `column` (roughly) evenly represented
            Args:
                dataframe (pandas dataframe)
                num_stim (int): number of rows to return
                column (str): column to balance on
            Kwargs:
                random_state (int): seed, default is 2
                replace (bool): sample with replacement, default is False
            Returns:
                pandas dataframe sorted by `column`
        """
        # test against None so that an explicit seed of 0 is honoured
        random_state = kwargs.get('random_state')
        if random_state is None:
            random_state = 2
        replace = kwargs.get('replace') or False

        num_values = len(dataframe[column].unique())
        group_size = int(np.ceil(num_stim / num_values))
        group_data = dataframe.groupby(column).apply(lambda x: x.sample(group_size, random_state=random_state, replace=replace))
        group_data = group_data.sample(num_stim, random_state=random_state, replace=replace).reset_index(drop=True).sort_values(column)
        return group_data.reset_index(drop=True)

    def _correct_block_iter(self, dataframe):
        """ renumbers `block_iter` as the 1-based occurrence count of each `block_name` """
        dataframe['block_iter'] = dataframe.groupby('block_name').cumcount() + 1

        return dataframe

    def _create_run(self):
        """ builds and saves one run file per run from the existing target files """
        # pools of unused target files per block, shared across runs
        self.target_dict = {}
        for run in range(self.config['num_runs']):
            self.cum_time = 0.0
            self.all_data = []

            # block_num/block_name are instance attributes because the helpers read them
            for self.block_num, self.block_name in enumerate(self.config['block_names']):

                # get target files for `block_name`
                self.TARGET_DIR = os.path.join(Defaults.TARGET_DIR, self.block_name)
                self.fpaths = sorted(glob.glob(os.path.join(self.TARGET_DIR, f'*{self.block_name}*.csv')))

                # sample tasks
                target_files_sample = self._check_task_run()

                # get tf info; the sampled entries are already complete paths
                df = pd.read_csv(target_files_sample[0])
                self.display_trial_feedback = np.unique(df['display_trial_feedback'])[0]
                self.replace_stimuli = np.unique(df['replace_stimuli'])[0]
                self.feedback_type = np.unique(df['feedback_type'])[0]
                self.target_score = np.unique(df['target_score'])[0]

                # create run dataframe
                self._create_run_dataframe(target_files=target_files_sample)

            # shuffle order of tasks within run
            df_run = pd.DataFrame(self.all_data)
            df_run = df_run.sample(n=len(df_run), replace=False)

            # correct `block_iter`, `start_time`, `end_time` after the shuffle
            df_run = self._correct_block_iter(dataframe=df_run)
            df_run['start_time'] = sorted(df_run['start_time'])
            df_run['end_time'] = sorted(df_run['end_time'])

            # save run file
            run_name = f"{self.config['run_name_prefix']}_{run+1:02d}.csv"
            self._save_run_file(dataframe=df_run, run_name=run_name)

    def _test_counterbalance(self):
        """ counts block transitions that occur too often across all run files
            Returns:
                int: number of (task, previous task) pairs occurring more than 5 times
            Raises:
                FileNotFoundError: if no run file is found in `Defaults.RUN_DIR`
        """
        filenames = sorted(glob.glob(os.path.join(Defaults.RUN_DIR, '*run_*')))
        if not filenames:
            raise FileNotFoundError(f"no run files matching '*run_*' in {Defaults.RUN_DIR}")

        frames = []
        for run_idx, fname in enumerate(filenames, start=1):
            dataframe = pd.read_csv(fname)
            dataframe['run'] = run_idx
            dataframe['block_num_unique'] = np.arange(len(dataframe)) + 1
            frames.append(dataframe)
        dataframe_all = pd.concat(frames)

        # create new column
        dataframe_all['block_name_unique'] = dataframe_all['block_name'] + '_' + dataframe_all['block_iter'].astype(str)

        # give every unique block a numeric id (position of its first occurrence + 1)
        task = np.array(list(map({}.setdefault, dataframe_all['block_name_unique'], count()))) + 1

        # id of the preceding block; 0 marks the first block of a run
        last_task = np.concatenate(([0], task[:-1]))
        last_task[(dataframe_all['block_num_unique'] == 1).to_numpy()] = 0

        dataframe_all['last_task'] = last_task
        dataframe_all['task'] = task
        dataframe_all['task_num'] = task

        # get pivot table
        f = pd.pivot_table(dataframe_all, index=['task'], columns=['last_task'], values=['task_num'], aggfunc=len)

        return sum([sum(f['task_num'][col]>5) for col in f['task_num'].columns])

    def check_videos(self):
        """ prints every video listed in the target files that is missing from `Defaults.STIM_DIR` """
        for block_name in ['action_observation', 'social_prediction']:

            files = glob.glob(os.path.join(Defaults.TARGET_DIR, block_name, '*.csv'))

            # loop over files
            missing_videos = []
            for file in files:
                dataframe = pd.read_csv(file)

                # loop over videos and check that they exist
                for stim in dataframe['stim']:
                    video_fpath = os.path.join(Defaults.STIM_DIR, block_name, "modified_clips", stim)
                    if not os.path.exists(video_fpath):
                        missing_videos.append(stim)
                        print(f'{block_name}: {stim} is missing from videos')

            if not missing_videos:
                print(f'there are no videos missing for {block_name}')

    def make_targetfiles(self, **kwargs):
        """ deletes the existing target files of every block and makes new ones
            Kwargs:
                forwarded to the block class taken from `TASK_MAP`
        """
        for self.block_name in self.config['block_names']:

            # delete any target files that exist in the folder
            files = glob.glob(os.path.join(Defaults.TARGET_DIR, self.block_name, '*.csv'))
            for f in files:
                os.remove(f)

            # make target files
            BlockClass = TASK_MAP[self.block_name]
            config = self._load_config(fpath=os.path.join(Defaults.CONFIG_DIR, f'{self.block_name}_config.json'))
            block = BlockClass(target_config=config, **kwargs)
            block.make_targetfile()

    def make_runfiles(self, **kwargs):
        """ makes run files, then optionally counterbalances them and adds rest """
        # make run files
        self._create_run()

        # OPTION TO COUNTERBALANCE RUNS
        if self.config['counterbalance_runs']:
            self._counterbalance_runs()

        # OPTION TO ADD REST
        if self.config['rest_dur_secs']>0:
            self._add_rest()

        # check if videos for action observation and social prediction exist
        self.check_videos()

    def make_all(self, **kwargs):
        """ makes target files followed by run files """
        # create target files
        self.make_targetfiles(**kwargs)

        # create run files
        self.make_runfiles(**kwargs)

In [None]:
class VisualSearch(MakeFiles):
    """
        This class makes target files for Visual Search using parameters from config file
        Args:
            target_config (dict): dictionary loaded from `visual_search_config.json`
        Kwargs:
            block_name (str): 'visual_search'
            orientations (int): orientations of target/distractor stims
            balance_blocks (dict): keys are 'condition_name', 'trial_type'
            trial_dur (int): length of trial (sec)
            iti_dur (iti): length of iti (sec)
            instruct_dur (int): length of instruct for block_names (sec)
            hand (str): response hand
            replace (bool): sample stim with or without replacement
            display_trial_feedback (bool): display trial-by-trial feedback
    """

    def __init__(self, target_config, **kwargs):
        # precedence (lowest to highest): run config, task config, kwargs
        super().__init__()
        self.config.update(target_config)
        self.config.update(**kwargs)

    def _get_block_info(self):
        """ sets `num_blocks`, `num_trials` and `num_stims` from the config """
        # num of blocks (i.e. target files) to make
        self.num_blocks = self.config['num_runs'] * self.config['tile_run']

        # get overall number of trials
        self.num_trials = int(self.config['block_dur_secs'] / (self.config['trial_dur'] + self.config['iti_dur']))

        # get `num_stims` - lowest denominator across `balance_blocks`
        denominator = np.prod([len(stim) for stim in self.config['balance_blocks'].values()])
        self.num_stims = ceil(self.num_trials / denominator) # round up to nearest int

    def _create_columns(self):
        """ builds the unbalanced trial table
            Returns:
                pandas dataframe with one row per (stim, trial_type) candidate and
                the feedback/replacement settings from the config
        """
        condition_map = self.config['balance_blocks']['condition_name']

        # reverse lookup: stim value -> condition name
        name_by_stim = {stim: name for name, stim in condition_map.items()}

        dataframe = pd.DataFrame()

        # make `condition_name` column
        dataframe['stim'] = self.num_trials * list(condition_map.values())
        dataframe['condition_name'] = dataframe['stim'].map(name_by_stim)
        dataframe['stim'] = dataframe['stim'].astype(int)

        # make `trial_type` column
        dataframe['trial_type'] = self.num_trials*self.config['balance_blocks']['trial_type']
        dataframe['trial_type'] = dataframe['trial_type'].sort_values().reset_index(drop=True)

        dataframe['display_trial_feedback'] = self.config['display_trial_feedback']
        dataframe['replace_stimuli'] = self.config['replace']
        dataframe['feedback_type'] = self.config['feedback_type']
        dataframe['target_score'] = self.config['target_score']

        return dataframe

    def _balance_design(self, dataframe):
        """ samples a balanced, shuffled set of `num_trials` trials
            Args:
                dataframe (pandas dataframe): output of `_create_columns`
            Returns:
                pandas dataframe with `num_trials` rows
        """
        # this assumes that there is a `condition_name` key in all tasks (which there should be)
        dataframe = dataframe.groupby([*self.config['balance_blocks']], as_index=False).apply(lambda x: x.sample(n=self.num_stims, random_state=self.random_state, replace=self.config['replace'])).reset_index(drop=True)

        # ensure that only `num_trials` are sampled
        num_stims = int(self.num_trials / len(self.config['balance_blocks']['condition_name']))
        dataframe = dataframe.groupby('condition_name', as_index=False).apply(lambda x: x.sample(n=num_stims, random_state=self.random_state, replace=False)).reset_index(drop=True)

        # shuffle the order of the trials (whole rows at once)
        dataframe = dataframe.sample(n=self.num_trials, random_state=self.random_state, replace=False).reset_index(drop=True)

        return dataframe

    def _save_visual_display(self, dataframe):
        """ saves item positions and orientations of every trial's search display
            Args:
                dataframe (pandas dataframe): balanced trials with `stim` and `trial_type`
        """
        # add visual display cols
        displays = [self._make_search_display(cond, self.config['orientations'], trial_type)
                    for cond, trial_type in zip(dataframe["stim"], dataframe["trial_type"])]
        display_pos, orientations_correct = zip(*displays)

        data_dicts = []
        for trial_idx, trial_conditions in enumerate(display_pos):
            for condition, point in trial_conditions.items():
                data_dicts.append({'trial': trial_idx, 'stim': condition, 'xpos': point[0], 'ypos': point[1], 'orientation': orientations_correct[trial_idx][condition]})

        # save out to dataframe
        df_display = pd.DataFrame.from_records(data_dicts)

        # save out visual display
        visual_display_name = self._get_visual_display_name()
        df_display.to_csv(os.path.join(self.TARGET_DIR, visual_display_name))

    def _get_visual_display_name(self):
        """ returns the display file name: the next target file name with the
            block name replaced by 'display_pos' """
        block_name = self.config['block_name']
        block_dur_secs = self.config['block_dur_secs']
        tf_name = self._get_target_file_name(f"{block_name}_{block_dur_secs}sec")

        # keep everything after the block name, e.g. '_30sec_01.csv'
        suffix = tf_name.partition(block_name)[2]

        return 'display_pos' + suffix

    def _make_search_display(self, display_size, orientations, trial_type):
        """ makes location and orientation lists (for target and distractor items)
            Args:
                display_size (int): number of items in the display
                orientations (list of int): item orientations; 90 is the target
                trial_type (bool): True if the target is present
            Returns:
                tuple of two dicts keyed by item index: (locations, orientations)
        """
        target_orientation = 90
        distractors = [o for o in orientations if o != target_orientation]

        # STIM POSITIONS
        grid_h_dva = 8.4
        grid_v_dva = 11.7

        n_h_items = 6
        n_v_items = 8

        item_h_pos = np.linspace(-grid_h_dva / 2.0, +grid_h_dva/ 2.0, n_h_items)
        item_v_pos = np.linspace(-grid_v_dva / 2.0, +grid_v_dva / 2.0, n_v_items)

        grid_pos = [[curr_h_pos, curr_v_pos] for curr_h_pos in item_h_pos for curr_v_pos in item_v_pos]

        locations = random.sample(grid_pos, display_size)

        ## STIM ORIENTATIONS
        orientations_list = orientations*int(display_size/4)

        # if trial type is false - randomly replace target stim (90)
        # with a distractor
        if not trial_type:
            orientations_list = [random.sample(distractors, 1)[0] if x == target_orientation else x for x in orientations_list]

        # if trial is true and larger than 4, leave one target stim (90) in list
        # and randomly replace ALL the others with distractor stims
        if display_size > 4 and trial_type:
            surplus = [i for i, x in enumerate(orientations_list) if x == target_orientation][1:]
            if len(surplus) <= len(distractors):
                # draw at least 2 (when possible) to match the previous behaviour for small displays
                replacements = random.sample(distractors, min(len(distractors), max(2, len(surplus))))
            else:
                # more surplus targets than distinct distractors: draw with repeats
                replacements = random.choices(distractors, k=len(surplus))
            for i, n in zip(surplus, replacements):
                orientations_list[i] = n

        return dict(enumerate(locations)), dict(enumerate(orientations_list))

    def make_targetfile(self):
        """
        makes target file(s) and the matching display files for visual search
        """
        # get info about block
        self._get_block_info()

        seeds = np.arange(self.num_blocks)+1

        for self.block in np.arange(self.num_blocks):

            # one seed per block so that every target file is sampled differently
            self.random_state = seeds[self.block]

            # create the dataframe
            df_target = self._create_columns()

            # balance the dataframe
            df_target = self._balance_design(dataframe=df_target)

            self.TARGET_DIR = os.path.join(Defaults.TARGET_DIR, self.config['block_name'])

            # save visual display dataframe (before the target file, so that
            # both files get the same number)
            self._save_visual_display(dataframe=df_target)

            # save target file
            self._save_target_files(df_target)

In [6]:
class Target():
    """
    Base class holding the settings and trial columns shared by all tasks.
    """

    def __init__(self, study_name, task_name, hand, trial_dur, iti_dur,
                 run_number, display_trial_feedback = True, task_dur = 30, tr = 1):
        """
        Store the variables and information shared across all tasks.
            Args:
                study_name (str): name of the study: 'fmri' or 'behavioral'
                task_name (str): name of the task
                hand (str): response hand: "right", "left", or "none"
                trial_dur (float): duration of a trial
                iti_dur (float): duration of the inter trial interval
                run_number (int): number of the run
                display_trial_feedback (bool): display feedback after trial
                task_dur (float): duration of the task (sec)
                tr (float): TR of the scanner
        """
        # columns of the target file, filled in by `make_trials` and subclasses
        self.target_dict = {}

        # identification
        self.study_name = study_name
        self.task_name = task_name
        self.run_number = run_number

        # timing
        self.task_dur = task_dur
        self.trial_dur = trial_dur
        self.iti_dur = iti_dur
        self.tr = tr

        # response and feedback
        self.hand = hand
        self.display_trial_feedback = display_trial_feedback

    def make_trials(self):
        """
        Fill `target_dict` with the columns (one entry per trial) shared across tasks.
        """
        # total number of trials that fit into the task duration
        self.num_trials = int(self.task_dur/(self.trial_dur + self.iti_dur))
        trial_numbers = range(self.num_trials)

        starts = []
        ends = []
        for trial_number in trial_numbers:
            starts.append((self.trial_dur + self.iti_dur)*trial_number)
            ends.append((trial_number+1)*self.trial_dur + trial_number*self.iti_dur)

        repeated_hand = np.tile(self.hand, self.num_trials)

        self.target_dict['start_time'] = starts
        self.target_dict['end_time'] = ends
        self.target_dict['hand'] = repeated_hand.T.flatten()
        self.target_dict['trial_dur'] = [self.trial_dur] * self.num_trials
        self.target_dict['iti_dur'] = [self.iti_dur] * self.num_trials

    def balance_design(self, dataframe):
        """
        Balance the task design (placeholder, does nothing here).
        """
        pass

    def save_target_file(self):
        """
        Write `target_dict` as a csv file into <target_dir>/<study_name>/<task_name>.
        """
        self.df = pd.DataFrame(self.target_dict)

        # directory of the target files for this study and task
        path2task_target = consts.target_dir / self.study_name / self.task_name
        consts.dircheck(path2task_target)

        file_name = f"{self.task_name}_{self.task_dur}sec_{self.run_number+1:02d}.csv"
        self.df.to_csv(path2task_target / file_name)

In [None]:
class VisualSearch(Target):
    """
    Target-file builder for the visual search task.
        Args:
            block_dur_secs (float): length of a block (sec)
            num_blocks (int): number of blocks, multiplied by `tile_block`
            tile_block (int): number of block repeats
            replace (bool): sample trials with or without replacement
            (all other arguments are passed on to `Target`)
    """

    def __init__(self, study_name = 'behavioral', hand = 'right', 
                 trial_dur = 2, iti_dur = 0.5, run_number = 1, display_trial_feedback = True, 
                 task_dur = 30, tr = 1, tile_block = 1, block_dur_secs = 15, 
                 num_blocks = 5, replace = False):

        super().__init__(study_name = study_name, task_name = 'visual_search', hand = hand,
                         trial_dur = trial_dur, iti_dur = iti_dur, run_number = run_number,
                         display_trial_feedback = display_trial_feedback, task_dur = task_dur,
                         tr = tr)

        self.block_dur_secs = block_dur_secs
        self.tile_block     = tile_block
        self.num_blocks     = num_blocks * tile_block
        self.replace        = replace   # used by `_balance_design`
        self.orientations   = [90, 180, 270, 360]
        self.trials_info    = {'condition_name': {'easy': '4', 'hard': '8'}, 'trial_type': [True, False]}

    def _add_task_info(self):
        """ adds the task-specific columns `stim`, `condition_name`, `trial_type` to `target_dict` """
        super().make_trials() # first fill in the common fields

        # get overall number of trials
        # NOTE(review): this overrides the `num_trials` set by `make_trials` (which is
        # based on `task_dur`); the common columns keep that length while the columns
        # below hold `num_trials * number of conditions` entries - confirm the intended lengths
        self.num_trials = int(self.block_dur_secs / (self.trial_dur + self.iti_dur))

        # get `num_stims` - lowest denominator across `trials_info`
        denominator    = np.prod([len(stim) for stim in self.trials_info.values()])
        self.num_stims = int(np.ceil(self.num_trials / denominator)) # round up to nearest int

        # display size (stim) of every candidate trial and the condition it belongs to
        condition_map = self.trials_info['condition_name']
        name_by_stim  = {stim: name for name, stim in condition_map.items()}
        stims         = self.num_trials * list(condition_map.values())

        self.target_dict['stim']           = [int(stim) for stim in stims]
        self.target_dict['condition_name'] = [name_by_stim[stim] for stim in stims]

        # make `trial_type` column (sorted: all False first, then all True)
        self.target_dict['trial_type'] = sorted(self.num_trials * self.trials_info['trial_type'])

        seeds = np.arange(self.num_blocks)+1

        # NOTE(review): the loop only leaves the seed of the last block in
        # `random_state`; presumably one target file per block is still to be made here
        for self.block in np.arange(self.num_blocks):
            # randomly sample so that conditions are equally represented
            self.random_state = seeds[self.block]

    def _save_visual_display(self, dataframe):
        """ saves item positions and orientations of every trial's search display
            Args:
                dataframe (pandas dataframe): trials with `stim` and `trial_type` columns
        """
        # add visual display cols
        displays = [self._make_search_display(cond, self.orientations, trial_type)
                    for cond, trial_type in zip(dataframe["stim"], dataframe["trial_type"])]
        display_pos, orientations_correct = zip(*displays)

        data_dicts = []
        for trial_idx, trial_conditions in enumerate(display_pos):
            for condition, point in trial_conditions.items():
                data_dicts.append({'trial': trial_idx, 'stim': condition, 'xpos': point[0], 'ypos': point[1], 'orientation': orientations_correct[trial_idx][condition]})

        # save out to dataframe
        df_display = pd.DataFrame.from_records(data_dicts)

        # save out visual display
        # NOTE(review): neither `_get_visual_display_name` nor `self.target_dir` is
        # defined in this class or in `Target` as visible here - confirm where they come from
        visual_display_name = self._get_visual_display_name()
        df_display.to_csv(os.path.join(self.target_dir, visual_display_name))

    def _make_search_display(self, display_size, orientations, trial_type):
        """ makes location and orientation lists (for target and distractor items)
            Args:
                display_size (int): number of items in the display
                orientations (list of int): item orientations; 90 is the target
                trial_type (bool): True if the target is present
            Returns:
                tuple of two dicts keyed by item index: (locations, orientations)
        """
        target_orientation = 90
        distractors = [o for o in orientations if o != target_orientation]

        # STIM POSITIONS
        grid_h_dva = 8.4
        grid_v_dva = 11.7

        n_h_items = 6
        n_v_items = 8

        item_h_pos = np.linspace(-grid_h_dva / 2.0, +grid_h_dva/ 2.0, n_h_items)
        item_v_pos = np.linspace(-grid_v_dva / 2.0, +grid_v_dva / 2.0, n_v_items)

        grid_pos = [[curr_h_pos, curr_v_pos] for curr_h_pos in item_h_pos for curr_v_pos in item_v_pos]

        # `np.random.sample` only takes an output size; drawing items without
        # replacement from a list is `random.sample`
        locations = random.sample(grid_pos, display_size)

        ## STIM ORIENTATIONS
        orientations_list = orientations*int(display_size/4)

        # if trial type is false - randomly replace target stim (90)
        # with a distractor
        if not trial_type:
            orientations_list = [random.sample(distractors, 1)[0] if x == target_orientation else x for x in orientations_list]

        # if trial is true and larger than 4, leave one target stim (90) in list
        # and randomly replace ALL the others with distractor stims
        if display_size > 4 and trial_type:
            surplus = [i for i, x in enumerate(orientations_list) if x == target_orientation][1:]
            if len(surplus) <= len(distractors):
                replacements = random.sample(distractors, min(len(distractors), max(2, len(surplus))))
            else:
                # more surplus targets than distinct distractors: draw with repeats
                replacements = random.choices(distractors, k=len(surplus))
            for i, n in zip(surplus, replacements):
                orientations_list[i] = n

        return dict(enumerate(locations)), dict(enumerate(orientations_list))

    def _balance_design(self, dataframe):
        """ samples a balanced, shuffled set of `num_trials` trials into `self.df`
            Args:
                dataframe: unused; the trials are taken from `self.target_dict`
        """
        # this assumes that there is a `condition_name` key in all tasks (which there should be)
        # NOTE(review): building the frame requires all `target_dict` columns to
        # have the same length - see the note in `_add_task_info`
        self.df = pd.DataFrame(self.target_dict)
        self.df = self.df.groupby([*self.trials_info], as_index=False).apply(lambda x: x.sample(n=self.num_stims, random_state=self.random_state, replace=self.replace)).reset_index(drop=True)

        # ensure that only `num_trials` are sampled
        num_stims = int(self.num_trials / len(self.trials_info['condition_name']))
        self.df = self.df.groupby('condition_name', as_index=False).apply(lambda x: x.sample(n=num_stims, random_state=self.random_state, replace=False)).reset_index(drop=True)

        # shuffle the order of the trials (whole rows at once)
        self.df = self.df.sample(n=self.num_trials, random_state=self.random_state, replace=False).reset_index(drop=True)