In [1]:
from PIL import Image
import math
import random
from pathlib import Path
import os
import matplotlib.pyplot as plt
import numpy as np
from retina_transform_v2 import foveat_img
from tqdm import tqdm

# Target categories for which search displays are generated.
CATEGORIES = [
    'cat',
    'dog',
    'bird',
    'turtle',
    'insect',
]

# Display geometry.
SCREEN_SIZE = (1680, 1050)    # only min(SCREEN_SIZE) is used, giving a square canvas
BACKGROUND_COLOR = 128        # value written to all three channels of the empty canvas
DISTANCE_FROM_CENTER = 369    # object distance from the canvas centre; 369 divided by 30 pixels per degree = 12 degrees
N_POSITIONS = 4               # number of object slots per display

# Input images, and the sibling directory that receives the foveated crops.
DIR_SOURCE = Path('../data/imagenet_val_segmented_resized')
foveated_object_save_path = DIR_SOURCE.with_name(f'{DIR_SOURCE.name}_foveated')

# Root directory for the generated displays and their metadata.
DIR_SAVE = Path('../exp-source')

# Integer index stored for each target label (covers more labels than CATEGORIES).
target2index = {
    label: index
    for index, label in enumerate(
        ['dog', 'cat', 'frog', 'turtle', 'bird', 'primate', 'fish', 'crab', 'insect']
    )
}


In [2]:

def generate_search_display(objectpathlist):
    """Paste object images onto a uniform square canvas, spaced evenly on a ring.

    The canvas is ``min(SCREEN_SIZE)`` pixels on each side and filled with
    ``BACKGROUND_COLOR``. Every object is centred ``DISTANCE_FROM_CENTER``
    pixels away from the canvas centre. Object 0 sits at 12 o'clock plus a
    random rotation of up to a quarter turn (drawn once per display and shared
    by all objects); the following objects are placed counter-clockwise in
    steps of ``2 * pi / N_POSITIONS``.

    Parameters
    ----------
    objectpathlist : sequence of path-like
        Image files to place, in slot order.

    Returns
    -------
    base_array : numpy.ndarray
        ``uint8`` array of shape ``(side, side, 3)`` holding the display.
    bboxlist : list of [xmin, ymin, xmax, ymax]
        Integer pixel box of each pasted object, in the order of
        ``objectpathlist``. ``xmax``/``ymax`` are exclusive slice bounds.

    Raises
    ------
    AssertionError
        If an object would extend past any edge of the canvas.
    """
    side = min(SCREEN_SIZE)
    base_array = np.full((side, side, 3), BACKGROUND_COLOR, dtype=np.uint8)
    base_center = int(side / 2)

    # random float between 0 and 0.25 to add some randomness to the start positions
    random_float = random.random() * 0.25
    bboxlist = []

    for n, objectpath in enumerate(objectpathlist):
        # Load the pixels and release the file handle immediately. Converting to
        # RGB guarantees an (h, w, 3) array that matches the canvas channels.
        with Image.open(objectpath) as img_object:
            object_array = np.array(img_object.convert('RGB'))
        # PIL's Image.size is (width, height); take the dimensions from the
        # array shape (rows, cols) so height and width cannot be mixed up.
        h_object, w_object = object_array.shape[:2]

        # assign object locations, start at 12 o'clock and move counter-clockwise.
        # object with position 0 will be the first one.
        t = math.pi + n * (2 * math.pi / N_POSITIONS) + random_float * 2 * math.pi
        y_center = int(base_center + DISTANCE_FROM_CENTER * math.cos(t))
        x_center = int(base_center + DISTANCE_FROM_CENTER * math.sin(t))

        # bbox in 'xyxy' format. The far edge is derived from the near edge plus
        # the full size so the slice matches the image exactly, including
        # odd-sized and non-square images.
        xmin = x_center - w_object // 2
        ymin = y_center - h_object // 2
        xmax = xmin + w_object
        ymax = ymin + h_object
        assert xmin >= 0 and ymin >= 0 and xmax <= side and ymax <= side, (
            f'object {objectpath} ({w_object}x{h_object}) does not fit on the '
            f'{side}x{side} canvas at slot {n}'
        )

        # insert object to base array
        base_array[ymin:ymax, xmin:xmax] = object_array
        bboxlist.append([xmin, ymin, xmax, ymax])

    return base_array, bboxlist

In [3]:

# For each target category, the nontarget classes treated as "confusing" for
# that target, written as (readable label, class id) pairs. Only the class ids
# are kept below; the labels are here for the reader.
_TA_LABELLED_TARGETS = {
    'turtle': [
        ('cowboy_hat', 'n03124170'),
        ('Dutch_oven', 'n03259280'),
        ('gong', 'n03447721'),
        ('honeycomb', 'n03530642'),
        ('mailbag', 'n03709823'),
        ('mortar', 'n03786901'),
        ('broccoli', 'n07714990'),
        ('hay', 'n07802026'),
        ('platypus', 'n01873310'),
        ('conch', 'n01943899'),
        ('Arabian_camel', 'n02437312'),
        ('shield', 'n04192698'),
        ('shopping_basket', 'n04204238'),
        ('tank', 'n04389033'),
        ('stone_wall', 'n04326547'),
        ('wok', 'n04596742'),
        ('wooden_spoon', 'n04597913'),
        ('guacamole', 'n07583066'),
    ],
    'dog': [
        ('chocolate_sauce', 'n07836838'),
        ('dough', 'n07860988'),
        ('meat_loaf', 'n07871810'),
        ('potpie', 'n07875152'),
        ('burrito', 'n07880968'),
        ('armadillo', 'n02454379'),
        ('cliff', 'n09246464'),
        ('promontory', 'n09399592'),
        ('gyromitra', 'n13037406'),
        ('French_loaf', 'n07684084'),
        ('bagel', 'n07693725'),
        ('mashed_potato', 'n07711569'),
        ('swab', 'n04367480'),
        ('ice_cream', 'n07614500'),
        ('velvet', 'n04525038'),
    ],
    'bird': [
        ('cup', 'n07930864'),
        ('coral_reef', 'n09256479'),
        ('missile', 'n03773504'),
        ('mitten', 'n03775071'),
        ('pitcher', 'n03950228'),
        ('poncho', 'n03980874'),
        ('warplane', 'n04552348'),
    ],
    'insect': [
        ('hip', 'n12620546'),
        ('buckeye', 'n12768682'),
        ('crash_helmet', 'n03127747'),
        ('knot', 'n03627232'),
        ('screw', 'n04153751'),
        ('crayfish', 'n01985128'),
        ('tripod', 'n04485082'),
        ('screwdriver', 'n04154565'),
        ('spindle', 'n04277352'),
        ('sleeping_bag', 'n04235860'),
        ('whistle', 'n04579432'),
    ],
}

# target category -> list of confusing class ids (labels dropped, order kept).
confusing_targets = {
    category: [class_id for _label, class_id in labelled]
    for category, labelled in _TA_LABELLED_TARGETS.items()
}

# Per-category aliases of the id lists (the same list objects as in the dict).
turtle_TA_targets = confusing_targets['turtle']
dog_TA_targets = confusing_targets['dog']
bird_TA_targets = confusing_targets['bird']
insect_TA_targets = confusing_targets['insect']


In [7]:

# Generate the clean and foveated search displays for every condition and
# target category, save the per-object foveated crops, and record one metadata
# dict per display in `display_info` (saved to DIR_SAVE/'display_info.npy').

N_DISPLAY_PER_CATEGORY = 250
N_DISPLAY_CAT = 50          # 'cat' gets fewer displays than the other categories
CONFUSING_TA_PROB = 0.2     # share of 'TA' displays that include one confusing object
DISPLAY_SEED = 0            # fixed seed so the sampled displays can be regenerated

# Seed the global `random` generator, which also drives generate_search_display.
random.seed(DISPLAY_SEED)

# create paths to save foveated objects
# exist_ok=False: raises FileExistsError when a directory is already present,
# so output from an earlier run is never silently mixed with this one.
for category in CATEGORIES+['nontarget']:
    (foveated_object_save_path/category).mkdir(parents=True, exist_ok=False)

# create paths to save search displays
clean_display_path = DIR_SAVE/'display_clean'
clean_display_path.mkdir(parents=True, exist_ok=False)
foveated_display_path = DIR_SAVE/'display_foveated'
foveated_display_path.mkdir(parents=True, exist_ok=False)

# The nontarget pool is identical for every condition and target, so list it
# once. os.listdir returns entries in arbitrary order; sorting makes the seeded
# sampling reproducible.
nontargetpathlist = sorted(DIR_SOURCE/'nontarget'/file for file in os.listdir(DIR_SOURCE/'nontarget'))

# generate displays
display_info = []
display_idx = 0
for cond in ['TP', 'TA']:
    print(f'\n==== condition: {cond} ====')

    for target in CATEGORIES:
        print(f"start generating for target {target}")
        targetpathlist = sorted(DIR_SOURCE/target/file for file in os.listdir(DIR_SOURCE/target))

        # Split the nontarget pool once per target into objects whose class id
        # (the filename prefix before the first '_') is listed as confusing for
        # this target, and all the others. Order-preserving comprehensions are
        # used instead of a set difference so the result order is deterministic.
        has_confusing = target in confusing_targets   # e.g. 'cat' has no confusing classes
        confusing_ids = set(confusing_targets.get(target, ()))
        confusingtargetpathlist = [path for path in nontargetpathlist
                                   if path.name.split('_')[0] in confusing_ids]
        nontargetminusconfusinglist = [path for path in nontargetpathlist
                                       if path.name.split('_')[0] not in confusing_ids]

        num_displays = N_DISPLAY_CAT if target == 'cat' else N_DISPLAY_PER_CATEGORY

        for _ in tqdm(range(num_displays)):
            display_filename = f'{cond}_{target}_d{display_idx}'

            # Drawn for every display; it only changes the outcome in the 'TA'
            # condition, where it decides whether a confusing object is included.
            use_confusing = random.random() <= CONFUSING_TA_PROB and has_confusing

            # Slot 0 of the unshuffled list holds the target ('TP') or the
            # confusing object ('TA' with use_confusing); the rest are nontargets.
            if cond == 'TP':
                selected_targetpathlist = random.sample(targetpathlist, 1)
                selected_nontargetpathlist = random.sample(nontargetpathlist, N_POSITIONS-1)
            elif use_confusing:
                selected_targetpathlist = random.sample(confusingtargetpathlist, 1)
                selected_nontargetpathlist = random.sample(nontargetminusconfusinglist, N_POSITIONS-1)
            else:
                selected_targetpathlist = []
                selected_nontargetpathlist = random.sample(nontargetpathlist, N_POSITIONS)

            # full object list, then shuffle via an index permutation so the
            # final slot of the original element 0 can be recovered below
            objectpathlist = selected_targetpathlist + selected_nontargetpathlist
            assert len(objectpathlist) == N_POSITIONS
            objectorderlist = list(range(N_POSITIONS))
            random.shuffle(objectorderlist)
            objectpathlist = [objectpathlist[i] for i in objectorderlist]

            # generate search display
            display, bboxlist = generate_search_display(objectpathlist)

            # foveate search display around the canvas centre
            # (original note: this yields around 997 center pixels at high resolution)
            xc, yc = int(display.shape[1]/2), int(display.shape[0]/2)
            display_foveated = foveat_img(display, [(xc, yc)], k=31, pixels_per_degree=31)

            # crop the areas of display using bbox and save each object image
            # under the same category sub-directory as its source image
            foveated_objectpathlist = []
            for idx_bbox, (bbox, source_path) in enumerate(zip(bboxlist, objectpathlist)):
                xmin, ymin, xmax, ymax = bbox
                crop = Image.fromarray(display_foveated[ymin:ymax, xmin:xmax])
                object_filename = f'{display_filename}-b{idx_bbox}-{source_path.name}'
                foveated_objectpath = foveated_object_save_path/source_path.parent.name/object_filename
                crop.save(foveated_objectpath)
                foveated_objectpathlist.append(foveated_objectpath)

            # save search displays
            Image.fromarray(display).save(clean_display_path/f'{display_filename}.png')
            Image.fromarray(display_foveated).save(foveated_display_path/f'{display_filename}_foveated.png')

            # save display info
            display_info.append({
                'display_idx': display_idx,
                'display_filename': display_filename,
                'condition': cond,
                'target_label': target,
                'target_index': target2index[target],

                'clean_objlist': [str(path) for path in objectpathlist],
                'foveated_objlist': [str(path) for path in foveated_objectpathlist],
                # position of the pre-shuffle element 0 (the target) after shuffling
                'target_idx_in_objlist': objectorderlist.index(0) if cond == 'TP' else None,

                'bboxlist': bboxlist,
            })

            # next index
            display_idx += 1


np.save(DIR_SAVE/'display_info.npy', display_info, allow_pickle=True)
print('display_info saved!')



==== condition: TP ====
start generating for target cat


100%|██████████| 50/50 [00:12<00:00,  4.04it/s]


start generating for target dog


100%|██████████| 250/250 [01:02<00:00,  4.03it/s]


start generating for target bird


100%|██████████| 250/250 [01:01<00:00,  4.04it/s]


start generating for target turtle


100%|██████████| 250/250 [01:02<00:00,  4.02it/s]


start generating for target insect


100%|██████████| 250/250 [01:01<00:00,  4.04it/s]



==== condition: TA ====
start generating for target cat


100%|██████████| 50/50 [00:12<00:00,  4.01it/s]


start generating for target dog


100%|██████████| 250/250 [01:02<00:00,  4.00it/s]


start generating for target bird


100%|██████████| 250/250 [01:02<00:00,  4.00it/s]


start generating for target turtle


100%|██████████| 250/250 [01:02<00:00,  4.00it/s]


start generating for target insect


100%|██████████| 250/250 [01:02<00:00,  4.01it/s]

display_info saved!





# Generate displays using blurred objects

In [8]:
# Load the display records saved by the generation cell, using the same
# DIR_SAVE location they were written to (numpy is imported in the first cell).
# NOTE: allow_pickle=True unpickles arbitrary Python objects; only load a file
# that this notebook produced.
display_info = np.load(DIR_SAVE/'display_info.npy', allow_pickle=True)
print(len(display_info))
display_info[0]

2100


{'display_idx': 0,
 'display_filename': 'TP_cat_d0',
 'condition': 'TP',
 'target_label': 'cat',
 'target_index': 1,
 'clean_objlist': ['../data/imagenet_val_segmented_resized/nontarget/n04330267_ILSVRC2012_val_00023064.JPEG',
  '../data/imagenet_val_segmented_resized/nontarget/n04597913_ILSVRC2012_val_00032112.JPEG',
  '../data/imagenet_val_segmented_resized/cat/n02123045_ILSVRC2012_val_00005358.JPEG',
  '../data/imagenet_val_segmented_resized/nontarget/n07753113_ILSVRC2012_val_00011675.JPEG'],
 'foveated_objlist': ['../data/imagenet_val_segmented_resized_foveated/nontarget/TP_cat_d0-b0-n04330267_ILSVRC2012_val_00023064.JPEG',
  '../data/imagenet_val_segmented_resized_foveated/nontarget/TP_cat_d0-b1-n04597913_ILSVRC2012_val_00032112.JPEG',
  '../data/imagenet_val_segmented_resized_foveated/cat/TP_cat_d0-b2-n02123045_ILSVRC2012_val_00005358.JPEG',
  '../data/imagenet_val_segmented_resized_foveated/nontarget/TP_cat_d0-b3-n07753113_ILSVRC2012_val_00011675.JPEG'],
 'target_idx_in_objlist'

In [10]:
# Rebuild every display from the pre-blurred version of each object, reusing
# the bounding boxes recorded in display_info, and save the result under
# DIR_SAVE/'display_blurred'. Each record also gains a 'blurred_objlist' entry.
blurred_display_dir = DIR_SAVE/'display_blurred'
blurred_display_dir.mkdir(parents=True, exist_ok=True)

# generate blurred object display
for info in tqdm(display_info):
    blurred_objpathlist = []
    canvas = np.full((min(SCREEN_SIZE), min(SCREEN_SIZE), 3), BACKGROUND_COLOR, dtype=np.uint8)
    for bbox, objpath in zip(info['bboxlist'], info['clean_objlist']):
        # Map <...>/imagenet_val_segmented_resized/<category>/<file> to
        # <...>/imagenet_val_segmented_resized_blurred/<category>/blurred-<file>
        clean_path = Path(objpath)
        blurred_objdir = Path(str(clean_path.parent).replace('imagenet_val_segmented_resized', 'imagenet_val_segmented_resized_blurred'))
        blurred_objpath = blurred_objdir/f'blurred-{clean_path.name}'
        # Stored as str, like 'clean_objlist' and 'foveated_objlist'.
        blurred_objpathlist.append(str(blurred_objpath))

        # Paste the blurred object into the same box the clean object occupied;
        # the context manager closes the image file after the pixels are copied.
        xmin, ymin, xmax, ymax = bbox
        with Image.open(blurred_objpath) as img_object:
            canvas[ymin:ymax, xmin:xmax] = np.asarray(img_object)

    # add info to display_info
    info['blurred_objlist'] = blurred_objpathlist

    # save blurred object display
    Image.fromarray(canvas).save(blurred_display_dir/f'{info["display_filename"]}_blurred.png')



100%|██████████| 2100/2100 [01:22<00:00, 25.51it/s]


In [12]:
# Show the first record again; the loop above sets info['blurred_objlist'] on every record.
display_info[0]

{'display_idx': 0,
 'display_filename': 'TP_cat_d0',
 'condition': 'TP',
 'target_label': 'cat',
 'target_index': 1,
 'clean_objlist': ['../data/imagenet_val_segmented_resized/nontarget/n04330267_ILSVRC2012_val_00023064.JPEG',
  '../data/imagenet_val_segmented_resized/nontarget/n04597913_ILSVRC2012_val_00032112.JPEG',
  '../data/imagenet_val_segmented_resized/cat/n02123045_ILSVRC2012_val_00005358.JPEG',
  '../data/imagenet_val_segmented_resized/nontarget/n07753113_ILSVRC2012_val_00011675.JPEG'],
 'foveated_objlist': ['../data/imagenet_val_segmented_resized_foveated/nontarget/TP_cat_d0-b0-n04330267_ILSVRC2012_val_00023064.JPEG',
  '../data/imagenet_val_segmented_resized_foveated/nontarget/TP_cat_d0-b1-n04597913_ILSVRC2012_val_00032112.JPEG',
  '../data/imagenet_val_segmented_resized_foveated/cat/TP_cat_d0-b2-n02123045_ILSVRC2012_val_00005358.JPEG',
  '../data/imagenet_val_segmented_resized_foveated/nontarget/TP_cat_d0-b3-n07753113_ILSVRC2012_val_00011675.JPEG'],
 'target_idx_in_objlist'

In [13]:
# Write the updated records back to disk; this replaces the display_info.npy
# file saved by the generation cell.
display_info_path = DIR_SAVE / 'display_info.npy'
np.save(display_info_path, display_info, allow_pickle=True)
print('display_info updated!')

display_info saved!
