In [1]:
import glob
import json
import os
import random
import time
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [2]:
# Seed 100 deliberately differs from the seed of the all-negative variant of this
# notebook, so that the two generated datasets pick different attribute pairs.
RANDOM_SEED = 100
random.seed(a=RANDOM_SEED)  # stdlib RNG
np.random.seed(seed=RANDOM_SEED)  # global NumPy RNG, used by sample_attrs / sample_images

In [3]:
# Root of the UT Zappos50K data. Set the ZAPPOS_ROOT environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
ZAPPOS_ROOT = os.environ.get(
    'ZAPPOS_ROOT',
    '/home/oscarli/projects/meta-analysis-classification/fed_data/zappos/',
)
# Square-cropped images live under the data root; derived so the two cannot drift apart.
ZAPPOS_IMAGES_ROOT = os.path.join(ZAPPOS_ROOT, 'ut-zap50k-images-square')

In [4]:
# Attribute metadata with a CID column plus one column per attribute value
# (presumably 0/1 indicators, given the "-bin" file name — TODO confirm).
# Two coarser heel-height groups are appended by adding the finer bins.
df = pd.read_csv(f'{ZAPPOS_ROOT}/ut-zap50k-data/meta-data-bin.csv').assign(**{
    'HeelHeight.High.heel': lambda frame: (
        frame['HeelHeight.4in...4.3.4in'] + frame['HeelHeight.5in...over']
    ),
    'HeelHeight.Short.heel': lambda frame: (
        frame['HeelHeight.Flat'] + frame['HeelHeight.Under.1in'] + frame['HeelHeight.1in...1.3.4in']
    ),
})

In [5]:
len(df.index), len(df.columns)  # (rows, columns)

(50025, 154)

In [6]:
print(sorted(df.keys()))  # DataFrame.keys() is the column index

['CID', 'Category.Boots', 'Category.Sandals', 'Category.Shoes', 'Category.Slippers', 'Closure.Adjustable', 'Closure.Ankle.Strap', 'Closure.Ankle.Wrap', 'Closure.Belt', 'Closure.Buckle', 'Closure.Bungee', 'Closure.Button.Loop', 'Closure.Elastic.Gore', 'Closure.Hook.and.Loop', 'Closure.Lace.up', 'Closure.Monk.Strap', 'Closure.Pull.on', 'Closure.Sling.Back', 'Closure.Slip.On', 'Closure.Snap', 'Closure.Spat.Strap', 'Closure.T.Strap', 'Closure.Toggle', 'Closure.Zipper', 'Gender.Boys', 'Gender.Girls', 'Gender.Men', 'Gender.Women', 'HeelHeight.1in...1.3.4in', 'HeelHeight.2in...2.3.4in', 'HeelHeight.3in...3.3.4in', 'HeelHeight.4in...4.3.4in', 'HeelHeight.5in...over', 'HeelHeight.Flat', 'HeelHeight.High.heel', 'HeelHeight.Short.heel', 'HeelHeight.Under.1in', 'Insole.EVA', 'Insole.Gel', 'Insole.Hypoallergenic', 'Insole.Latex.Lined', 'Insole.Leather', 'Insole.Memory.Foam', 'Insole.Moisture.Wicking', 'Insole.Orthotic.Friendly', 'Insole.Padded', 'Insole.Polyurethane', 'Insole.Poron', 'Insole.Remova

In [7]:
# Attribute lists: the JSON file maps 'ID' and 'OOD' to lists of attribute column names.
# `with` closes the file handle instead of leaving it to the garbage collector.
with open('zappos_attributes_ID_OOD_partition.json', 'r') as partition_file:
    partition = json.load(partition_file)
ID_ATTR = partition['ID']
OOD_ATTR = partition['OOD']
# A partition should not share attributes between the ID and OOD splits.
assert not set(ID_ATTR) & set(OOD_ATTR), "ID and OOD attribute lists overlap"
print(len(ID_ATTR))
print(len(OOD_ATTR))

36
30


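1. Sample a pair of attributes.
2. Reject the pair if both attributes belong to the same category (the same prefix before the first `.`).
3. Sample images that are positive for both attributes ("all true"), and images that are not positive for both ("any false").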

In [8]:
def is_same_category(attrs):
    """Return True if at least two attributes share a category.

    The category is the lower-cased text before the first '.' of the attribute
    name, e.g. 'closure' for 'Closure.Lace.up'.
    """
    seen_categories = set()
    for attr in attrs:
        category = attr.lower().partition(".")[0]
        if category in seen_categories:
            return True
        seen_categories.add(category)
    return False
    
def sample_attrs(attributes_list, n_attrs=2):
    """Draw `n_attrs` distinct attributes uniformly at random using the global NumPy RNG."""
    picked = np.random.choice(attributes_list, size=n_attrs, replace=False)
    return picked

def get_all_positive_images(df, attrs):
    """Return a boolean row mask: True where every column in `attrs` sums to len(attrs).

    For 0/1 indicator columns this means the image is positive for all attributes.
    A vectorised column sum replaces the original row-wise Python `sum` via
    `.apply(..., axis=1)`, which was slow over the whole frame on every sampling
    iteration.
    """
    return df[list(attrs)].sum(axis=1) == len(attrs)

def get_all_negative_images(df, attrs):
    """Return a boolean row mask: True where every column in `attrs` is 0.

    Counts `1 - value` per row with a vectorised column sum, rather than a
    row-wise Python `sum` via `.apply(..., axis=1)`.
    """
    return (1 - df[list(attrs)]).sum(axis=1) == len(attrs)

def sample_images(df, selected_images, n_samples):
    """Draw `n_samples` distinct CIDs from the rows selected by the boolean mask `selected_images`.

    Uses the global NumPy RNG. np.random.choice raises ValueError if fewer than
    `n_samples` rows are selected.
    """
    candidate_cids = df.loc[selected_images, "CID"].to_numpy()
    drawn = np.random.choice(candidate_cids, n_samples, replace=False)
    return list(drawn)

In [9]:
# Map image key -> image path. The key is the file name without its ".jpg" suffix,
# with '.' replaced by '-' (presumably to match the CID column format — TODO confirm).
CID_to_impath = dict(
    (path.rsplit("/", 1)[-1][:-4].replace(".", "-"), path)
    for path in glob.glob(f"{ZAPPOS_IMAGES_ROOT}/**/*.jpg", recursive=True)
)

In [10]:
def sample_dataset(
        ATTR_LIST, N_SAMPLES,
        POSITIVE_IMAGES_THRESHOLD, NEGATIVE_IMAGES_THRESHOLD,
        N_SUPPORT, N_QUERY,
        verbose=False):
    """Sample N_SAMPLES binary few-shot tasks of the form "all true" vs. "any false".

    Each task draws a pair of attributes from ATTR_LIST. Pairs whose attributes
    share a category are rejected. Class 1 holds images positive for every
    attribute of the pair; class 0 holds all remaining images. Each class gets
    N_SUPPORT support and N_QUERY query image paths.

    Reads the notebook globals ``df`` (attribute table with a ``CID`` column) and
    ``CID_to_impath`` (CID -> image path).

    Args:
        ATTR_LIST: attribute column names to sample pairs from.
        N_SAMPLES: number of tasks to generate.
        POSITIVE_IMAGES_THRESHOLD, NEGATIVE_IMAGES_THRESHOLD: minimum number of
            candidate images per class; pairs below either are rejected. Raised
            internally to at least N_SUPPORT + N_QUERY, because that many images
            are drawn per class without replacement.
        N_SUPPORT, N_QUERY: number of images per class in the support/query splits.
        verbose: if True, log every sampled pair and the acceptance rate.

    Returns:
        Nested defaultdict: task index -> class label (0 or 1) ->
        {'support': [paths], 'query': [paths], 'attributes': str}.
    """
    n_per_class = N_SUPPORT + N_QUERY
    # np.random.choice(..., replace=False) raises if there are fewer candidates than draws.
    min_positive = max(POSITIVE_IMAGES_THRESHOLD, n_per_class)
    min_negative = max(NEGATIVE_IMAGES_THRESHOLD, n_per_class)

    dataset = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    n_selected = 0
    counter = 0  # number of attribute pairs tried, accepted or not
    st = time.time()
    # The context manager closes the bar even if the loop is interrupted.
    # The bar starts at 0 so that it ends exactly at N_SAMPLES.
    with tqdm(total=N_SAMPLES) as p_bar:
        while n_selected < N_SAMPLES:
            counter += 1
            attrs = sample_attrs(ATTR_LIST)
            if verbose:
                # tqdm.write expects a str; attrs is a NumPy array.
                tqdm.write(";".join(attrs))
            # Reject pairs whose attributes come from the same category.
            if is_same_category(attrs):
                continue

            positive_images = get_all_positive_images(df, attrs)
            # "Any false" is exactly the complement of "all true".
            negative_images = ~positive_images

            # Reject pairs with too few candidate images in either class.
            if positive_images.sum() < min_positive or negative_images.sum() < min_negative:
                continue

            positive = [CID_to_impath[cid] for cid in sample_images(df, positive_images, n_per_class)]
            negative = [CID_to_impath[cid] for cid in sample_images(df, negative_images, n_per_class)]

            dataset[n_selected][0]['support'] = negative[:N_SUPPORT]
            dataset[n_selected][0]['query'] = negative[N_SUPPORT:]

            dataset[n_selected][1]['support'] = positive[:N_SUPPORT]
            dataset[n_selected][1]['query'] = positive[N_SUPPORT:]

            dataset[n_selected][0]['attributes'] = "!(" + ";".join(attrs) + ")"
            dataset[n_selected][1]['attributes'] = ";".join(attrs)

            n_selected += 1
            p_bar.update(1)
            if verbose:
                tqdm.write(f'Selected {n_selected}/{counter} Time Elapsed {time.time()-st}')

    return dataset
        
        
        

In [13]:
# Minimum number of candidate images per class for an attribute pair to be accepted.
POSITIVE_IMAGES_THRESHOLD = 20
NEGATIVE_IMAGES_THRESHOLD = 20
# Images per class in each task's support and query splits.
N_SUPPORT = 10
N_QUERY = 10
# Tasks sampled from ID attributes: 1000 ID train + 5000 ID val + 5000 ID test.
N_SAMPLES_ID = 11000
# Tasks sampled from OOD attributes (OOD val/test).
# NOTE(review): the earlier comment said "5000/5000", which sums to 10000, not
# 20000 — confirm which split sizes are intended.
N_SAMPLES_OOD = 20000

# Each class draws N_SUPPORT + N_QUERY images without replacement, so neither
# threshold may be smaller than that.
assert min(POSITIVE_IMAGES_THRESHOLD, NEGATIVE_IMAGES_THRESHOLD) >= N_SUPPORT + N_QUERY

In [14]:
# Sample the in-distribution (ID) tasks from the ID attribute list.
print(f"{'-' * 20} ID {'-' * 20}")
ID_dataset = sample_dataset(
    ID_ATTR,
    N_SAMPLES=N_SAMPLES_ID,
    POSITIVE_IMAGES_THRESHOLD=POSITIVE_IMAGES_THRESHOLD,
    NEGATIVE_IMAGES_THRESHOLD=NEGATIVE_IMAGES_THRESHOLD,
    N_SUPPORT=N_SUPPORT,
    N_QUERY=N_QUERY,
)

-------------------- ID --------------------


HBox(children=(FloatProgress(value=0.0, max=11000.0), HTML(value='')))

KeyboardInterrupt: 

In [12]:
# Guard against writing a stale dataset, e.g. one from an earlier run with a
# different N_SAMPLES_ID after the sampling cell was interrupted.
assert len(ID_dataset) == N_SAMPLES_ID, f"ID_dataset has {len(ID_dataset)} tasks, expected {N_SAMPLES_ID}"
# `with` guarantees the file is flushed and closed.
with open(f'zappos-ID-alltrue_vs_anyfalse-nsamp{N_SAMPLES_ID}-ns{N_SUPPORT}-nq{N_QUERY}.json', 'w') as out_file:
    json.dump(ID_dataset, out_file)

In [None]:
# Sample the out-of-distribution (OOD) tasks from the OOD attribute list.
print(f"{'-' * 20} OOD {'-' * 20}")
OOD_dataset = sample_dataset(
    OOD_ATTR,
    N_SAMPLES=N_SAMPLES_OOD,
    POSITIVE_IMAGES_THRESHOLD=POSITIVE_IMAGES_THRESHOLD,
    NEGATIVE_IMAGES_THRESHOLD=NEGATIVE_IMAGES_THRESHOLD,
    N_SUPPORT=N_SUPPORT,
    N_QUERY=N_QUERY,
)

In [None]:
# Guard against writing a stale or incomplete dataset left over from an earlier run.
assert len(OOD_dataset) == N_SAMPLES_OOD, f"OOD_dataset has {len(OOD_dataset)} tasks, expected {N_SAMPLES_OOD}"
# `with` guarantees the file is flushed and closed.
with open(f'zappos-OOD-alltrue_vs_anyfalse-nsamp{N_SAMPLES_OOD}-ns{N_SUPPORT}-nq{N_QUERY}.json', 'w') as out_file:
    json.dump(OOD_dataset, out_file)