In [1]:
import platform
import os
# Resolve the machine-specific data and project directories from the host OS.
_system = platform.system()
if _system == 'Darwin':
    # macOS host
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif _system == 'Linux':
    # Linux host (paths point into a scratch workspace)
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"
else:
    # Fail here with a clear message instead of hitting a NameError on the
    # first later use of DATA_PATH / ROOT_PATH.
    raise RuntimeError(
        f"Unsupported platform {_system!r}: DATA_PATH and ROOT_PATH are only "
        "configured for 'Darwin' and 'Linux'."
    )

# Remember the working directory the notebook was started from.
current_wd = os.getcwd()

In [2]:
from glob import glob 
import pandas as pd
import torch
import numpy as np
import itertools

In [3]:
# Output directory for the InterfaceGAN inputs and location of the saved latents.
save_path = f"{DATA_PATH}/Models/InterfaceGAN/Inputs/e4e_00005/"
latents_path = f"{DATA_PATH}/Models/e4e/00005_snapshot_1200/inversions/latents_dict.pt"

# Load in Latents.
# map_location="cpu" lets a file that was saved from a GPU be opened on a
# machine without CUDA, and keeps the tensors convertible to NumPy later on.
# NOTE: torch.load unpickles the file, so only load files from a trusted source.
latents = torch.load(latents_path, map_location="cpu")

In [4]:
# Import metadata: the JSON is transposed so that each former column becomes a
# row, and the resulting index is exposed as the `sku` column.
metadata_file = f'{DATA_PATH}/Zalando_Germany_Dataset/dresses/metadata/dresses_metadata.json'
meta = (
    pd.read_json(metadata_file)
    .T
    .reset_index()
    .rename(columns={'index': 'sku'})
)

# Prices that cannot be parsed as numbers become NaN.
meta['original_price'] = pd.to_numeric(meta['original_price'], errors='coerce')

# Preview the first rows.
meta.head(3)

Unnamed: 0,sku,name,sku_base,sku_color_code,url,brand,original_price,current_price,brand_url,category,...,fabric,fit,neckline,pattern,collar,length,shape,sleeve_length,thumbnail_url,packshot_url
0,AN621C22S-O11,Jersey dress - brown,AN621C22S,O11,https://en.zalando.de/anna-field-shift-dress-b...,Anna Field,39.99,39.99,https://en.zalando.de/anna-field/,Shift dress,...,Jersey,Slim Fit,,Plain,Standing collar,Calf-length,Body-hugging,Short,https://img01.ztat.net/article/spp-media-p1/fb...,https://img01.ztat.net/article/spp-media-p1/c8...
1,BU321C01G-K11,Jersey dress - marine/bedruckt,BU321C01G,K11,https://en.zalando.de/buffalo-jersey-dress-mar...,Buffalo,39.99,39.99,https://en.zalando.de/buffalo/,Jersey dress,...,Jersey,Regular Fit,Low-cut v-neck,Print,,Knee-length,Fitted,Sleeveless,https://img01.ztat.net/article/spp-media-p1/50...,https://img01.ztat.net/article/spp-media-p1/17...
2,JY121C0TB-A11,JDYCARLA CATHINKA DRESS - Jersey dress - cloud...,JY121C0TB,A11,https://en.zalando.de/jdy-carla-cathinka-dress...,JDY,34.99,34.99,https://en.zalando.de/jacqueline-de-yong/,Jersey dress,...,,Regular Fit,Crew neck,Plain,Standing collar,Knee-length,Flared,Short,https://img01.ztat.net/article/spp-media-p1/20...,https://img01.ztat.net/article/spp-media-p1/20...


In [5]:
# Price thresholds: for each one, a binary target "original_price > threshold"
# is exported together with the latents, split along their first axis.
thresholds = [100, 200, 300, 400, 500, 600, 700]

summary_stats = {}

# The SKU/price table and the matching latents do not depend on the threshold,
# so they are prepared once instead of once per threshold.
subset = meta[['sku', 'original_price']].copy()
latents_subset = [latents[sku].squeeze(0) for sku in subset.sku]
if not latents_subset:
    raise ValueError("`meta` contains no rows; there is nothing to export.")

# Size of the first latent axis (previously hard-coded as 16).
num_latent_dims = latents_subset[0].shape[0]

# Create all output directories up front.
price_root = os.path.join(save_path, "price")
os.makedirs(price_root, exist_ok=True)
threshold_dirs = {t: os.path.join(price_root, f"threshold_{t}") for t in thresholds}
for out_dir in threshold_dirs.values():
    os.makedirs(out_dir, exist_ok=True)

# Split dimensions and save. Each slice is stacked once and then written into
# every threshold directory, since the latents are identical across thresholds.
for i in range(num_latent_dims):
    latents_subset_dim = torch.stack([elem[i, :] for elem in latents_subset])
    assert latents_subset_dim.shape[0] == len(subset)
    # Convert explicitly: np.save needs an array, and tensors that track
    # gradients or live on another device cannot be converted implicitly.
    latents_array = latents_subset_dim.detach().cpu().numpy()
    for out_dir in threshold_dirs.values():
        np.save(os.path.join(out_dir, f"latents_dim_{i}.npy"), latents_array)

for threshold in thresholds:
    # Binary label: 1 if the price is strictly above the threshold, else 0.
    # NaN prices (produced by errors='coerce' when the metadata is loaded)
    # compare as False and therefore get label 0.
    # NOTE(review): confirm that labelling missing prices as 0 is intended.
    subset['label'] = (subset.original_price > threshold).astype(np.int64)
    target = subset.label.to_numpy().reshape(-1, 1)

    out_dir = threshold_dirs[threshold]
    np.save(os.path.join(out_dir, "target.npy"), target)
    subset.to_csv(os.path.join(out_dir, "metadata.csv"), index=False)

    summary_stats[threshold] = {
        'num_samples': target.shape[0],
        'num_positives': int(target.sum()),
        'num_negatives': int((1 - target).sum()),
    }

# Save summary stats (one row per threshold)
stats = pd.DataFrame(summary_stats).T
stats.to_csv(os.path.join(price_root, "summary_stats.csv"))

In [6]:
# Display the per-threshold sample counts (total, positives, negatives).
stats

Unnamed: 0,num_samples,num_positives,num_negatives
100,14060,6175,7885
200,14060,2788,11272
300,14060,1336,12724
400,14060,764,13296
500,14060,493,13567
600,14060,339,13721
700,14060,238,13822
