In [1]:
import argparse
import json
from pathlib import Path

import joblib
import numpy as np

In [2]:
DATA_ROOT = Path('/home/bart/Documents/repos/L2M/visual-search-nets/data')

data_gz_fname = DATA_ROOT.joinpath('data_prepd_for_nets/alexnet_train_2_v_5_data.gz')
json_fname = DATA_ROOT.joinpath('visual_search_stimuli/alexnet_train_2_v_5/alexnet_train_2_v_5.json')
stim_abbrev = '2_v_5'

In [3]:
data_gz = joblib.load(data_gz_fname)

In [4]:
with open(json_fname) as fp:
    stim_meta_dict = json.load(fp)

stim_meta_list = []
stim_meta_dict = stim_meta_dict[stim_abbrev]
for set_size, stim_meta_this_set_size in stim_meta_dict.items():
    # *** only using present because we only care about splitting target present condition up ***
    stim_meta_list.extend(stim_meta_this_set_size['present'])

fname_grid_map = {}
for meta_d in stim_meta_list:
    stim_fname_meta = Path(meta_d['filename']).name
    char_grid = np.asarray(meta_d['grid_as_char'])
    fname_grid_map[stim_fname_meta] = char_grid

In [5]:
GRID_SHAPE = (5, 5)
train_mask = np.zeros(GRID_SHAPE).astype(np.int32)
train_mask[:, :3] = 1

In [6]:
SPLITS = ['train', 'val', 'test']

In [7]:
def list2vec(a_list):
    if type(a_list) == np.ndarray:
        return a_list

    elif type(a_list) == list:
        if all([type(item) == list for item in a_list]):
            a_list = [item for sublist in a_list for item in sublist]

        if all([type(item) == str for item in a_list]):
            a_list = np.asarray(a_list)
        
        if all([type(item) == np.ndarray for item in a_list]):
            a_list = np.concatenate(a_list)
        
        return a_list
    else:
        raise TypeError('expected list or numpy array')

Filter training set, keeping only samples where target appears within mask

In [8]:
x_train = list2vec(data_gz['x_train'])
y_train = list2vec(data_gz['y_train'])
set_size_vec_train = list2vec(data_gz['set_size_vec_train'])

splits_new = {
    'train': {
        'x': [],
        'y': [],
        'set_size_vec': [],
    }
}

for fname, target_present, set_size in zip(x_train, y_train, set_size_vec_train):
    if 'present' in fname:
        fname_name = Path(fname).name
        char_grid = fname_grid_map[fname_name]
        if np.any(np.logical_and(char_grid == 't', train_mask)):
            splits_new['train']['x'].append(fname)
            splits_new['train']['set_size_vec'].append(set_size)
            splits_new['train']['y'].append(target_present)
    elif 'absent' in fname:
        splits_new['train']['x'].append(fname)
        splits_new['train']['set_size_vec'].append(set_size)
        splits_new['train']['y'].append(target_present)

Make sure equal number of samples for target present and absent conditions for each set size

In [9]:
set_size_vec = np.asarray(splits_new['train']['set_size_vec'])
set_sizes = np.unique(set_size_vec)
y_train_new_arr = np.asarray(splits_new['train']['y'])
keep_inds = []

for set_size in set_sizes:
    inds_this_set_size_target_present = np.nonzero(
        np.logical_and(set_size_vec == set_size, y_train_new_arr == 1)
    )[0]
    inds_this_set_size_target_absent = np.nonzero(
        np.logical_and(set_size_vec == set_size, y_train_new_arr == 0)
    )[0]
    num_present = inds_this_set_size_target_present.shape[0]
    num_absent = inds_this_set_size_target_absent.shape[0]
    if num_present < num_absent:
        inds_this_set_size_target_absent = inds_this_set_size_target_absent[:num_present]
    elif num_present > num_absent:
        inds_this_set_size_target_present = inds_this_set_size_target_present[:num_absent]
    else:
        pass
    keep_inds.extend(inds_this_set_size_target_absent.tolist())
    keep_inds.extend(inds_this_set_size_target_present.tolist())

keep_inds = np.asarray(keep_inds)
for name in ['x', 'set_size_vec', 'y']:
    as_arr = np.asarray(splits_new['train'][name])
    splits_new['train'][name] = as_arr[keep_inds].tolist()

In [11]:
len(splits_new['train']['x'])

62382

In [12]:
if data_gz['shard_train']:
    shard_size = data_gz['shard_size']
    # get floor to figure out num samples per shard for each set size,
    # and then we'll throw any leftovers into the last (num_shards + 1) shard
    num_shards = int(np.floor(len(splits_new['train']['x']) / shard_size))
    set_sizes, set_size_samples_per_shard = np.unique(data_gz['set_size_vec_train'][0], return_counts=True)

In [14]:
set_size_samples_per_shard

array([ 382,  832, 1694, 3492])

In [15]:
x_train = np.asarray(splits_new['train']['x'])
y_train = np.asarray(splits_new['train']['y'])
set_size_vec_train = np.asarray(splits_new['train']['set_size_vec'])
for_sharding = {int(set_size): {} for set_size in
                set_sizes}  # will add 'present' and 'absent' keys in next loop below
for set_size in set_sizes:
    set_size_inds = np.nonzero(set_size_vec_train == set_size)[0]
    set_size_present_inds = np.nonzero(y_train[set_size_inds] == 1)[0]
    set_size_absent_inds = np.nonzero(y_train[set_size_inds] == 1)[0]
    for_sharding[int(set_size)]['present'] = x_train[set_size_present_inds].tolist()
    for_sharding[int(set_size)]['absent'] = x_train[set_size_absent_inds].tolist()

In [17]:
for set_size, num_samples in zip(set_sizes, set_size_samples_per_shard):
    is_odd = num_samples % 2
    if is_odd:
        coin_flip = np.random.choice([0, 1])
        if coin_flip:
            n_present = int(np.ceil(num_samples / 2))
            n_absent = num_samples - n_present
        else:
            n_absent = int(np.ceil(num_samples / 2))
            n_present = num_samples - n_absent
    else:
        n_present = n_absent = int(num_samples / 2)
    set_size = int(set_size)
    total_present = len(for_sharding[set_size]['present'])
    for_sharding[set_size]['present'] = [
        for_sharding[set_size]['present'][i:i + n_present] for i in range(0, total_present, n_present)]
    total_absent = len(for_sharding[set_size]['absent'])
    for_sharding[set_size]['absent'] = [
        for_sharding[set_size]['absent'][i:i + n_absent] for i in range(0, total_absent, n_absent)]

In [21]:
x_sharded = []
y_sharded = []
set_size_sharded = []
num_shards_now = set(
    len(for_sharding[set_size][target_cond])
    for set_size in set_sizes
    for target_cond in ['present', 'absent']
)
if len(num_shards_now) != 1:
    raise ValueError(f'inconsistent number of shards: {num_shards_now}')
else:
    num_shards_now = num_shards_now.pop()

In [23]:
for shard_ind in range(num_shards_now):
    x_shard = []
    y_shard = []
    set_size_shard = []
    for set_size in set_sizes:
        set_size = int(set_size)
        x_present = for_sharding[set_size]['present'][shard_ind]
        x_shard.extend(x_present)
        y_shard.extend([1 for el in x_present])
        set_size_shard.extend([set_size for el in x_present])

        x_absent = for_sharding[set_size]['absent'][shard_ind]
        x_shard.extend(x_absent)
        y_shard.extend([0 for el in x_absent])
        set_size_shard.extend([set_size for el in x_absent])

    y_shard = np.asarray(y_shard)
    set_size_shard = np.asarray(set_size_shard)
    x_sharded.append(x_shard)
    y_sharded.append(y_shard)
    set_size_sharded.append(set_size_shard)

6400

In [None]:





    splits_new['train']['x'] = x_sharded
    splits_new['train']['y'] = y_sharded
    splits_new['train']['set_size_vec'] = set_size_sharded
else:  # if shard_train is not True
    # keep x as a list but
    splits_new['train']['y'] = np.asarray(splits_new['train']['y'])
    splits_new['train']['set_size_vec'] = np.asarray(splits_new['train']['set_size_vec'])

In [17]:
len(for_sharding[2]['present'])

0

In [14]:
out_dict = {}

for split in SPLITS:
        if split == 'train':
            for name, a_list in splits_new['train'].items():
                out_dict[f'{name}_{split}'] = splits_new[split][name]
        elif split == 'val' or split == 'test':
            for name in ['x', 'y', 'set_size_vec']:
                out_dict[f'{name}_{split}'] = data_gz[f'{name}_{split}']

In [27]:
ALSO_ADD = ['set_sizes_by_stim_type', 'shard_train', 'shard_size']

In [28]:
for other_key in ALSO_ADD:
    out_dict[other_key] = data_gz[other_key]

In [25]:
NEW_DATAGZ_NAME = '/home/bart/Documents/repos/L2M/visual-search-nets/data/expt_13/data_prepd_for_nets/alexnet_train_test_target_split_RVvGV_data.gz'

In [26]:
joblib.dump(out_dict, NEW_DATAGZ_NAME)

['/home/bart/Documents/repos/L2M/visual-search-nets/data/expt_13/data_prepd_for_nets/alexnet_train_test_target_split_RVvGV_data.gz']