# Outfit Creation with Style

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import Counter
import glob
import heapq
import json
import numpy
import numpy as np
import os
import pickle
import pandas as pd
import random
from tqdm import tqdm
import tensorflow as tf
import uuid

# %pylab inline
import matplotlib.pyplot as plt
from PIL import Image

import sys
sys.path.insert(0, "/recsys_data/RecSys/fashion/automl/efficientnetv2")
import effnetv2_model

from data_process import OutfitGen, OutfitGenWithImage, ZalandoOutfitGenWithImage

In [3]:
import anvil.server
import anvil.media
import anvil.mpl_util

anvil.server.connect("D3BJ3YVOSVMXFCEVTGLVI3XH-5GA6BUC4BEGN43RS")

Connecting to wss://anvil.works/uplink
Anvil websocket open
Connected to "Default environment" as SERVER


In [4]:
base_dir = "/recsys_data/RecSys/Zalando_Outfit/female/Outfit_Data"
train_dir = base_dir
image_dir = "/recsys_data/RecSys/Zalando_Outfit/resized_packshot_images_female"
embed_dir = "/recsys_data/RecSys/Zalando_Outfit/female/Outfit_Data/precomputed"
image_embedding_file = os.path.join(embed_dir, "effnet2_zalando.pkl")


In [5]:
with open(image_embedding_file, "rb") as fr:
    image_embedding_dict = pickle.load(fr)
print(f"Loaded {len(image_embedding_dict)} image embeddings")

Loaded 51974 image embeddings


In [6]:
all_files = [image for image in glob.glob(f"{image_dir}/*.*")]
print(len(all_files))

51985


In [7]:
eff2_model = tf.keras.models.Sequential(
                [
                    tf.keras.layers.InputLayer(input_shape=[224, 224, 3]),
                    effnetv2_model.get_model("efficientnetv2-b0", include_top=False),
                ]
            )

Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.


Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.


In [8]:
train_file = "zalando_female_outfit_data_18k_and_new_modified_format_train.json"
valid_file = "zalando_female_outfit_data_18k_and_new_modified_format_val.json"
test_file = "zalando_female_outfit_data_18k_and_new_modified_format_test.json"

train_json = os.path.join(base_dir, "Train", train_file)
train_data = json.load(open(train_json, 'r'))

item2cat = {}
cat2item = {}
seq_lens = []
count = 0
for outfit in train_data:
    items = outfit["item_ids"]
    categories = outfit["high_level_cats"]
    style = outfit["outfit_occasion"]
    for i,c in zip(items, categories):
        if c not in cat2item:
            cat2item[c] = []
        cat2item[c].append(i)
        item2cat[i] = {"category_id": c}
    seq_lens.append(len(items))
    count += 1
print(f"{count} training examples, average {np.mean(seq_lens):.2f} items, max {np.max(seq_lens)} item")

valid_json = os.path.join(base_dir, "Val", valid_file)
valid_data = json.load(open(valid_json, 'r'))

count = 0
seq_lens = []
for outfit in valid_data:
    items = outfit["item_ids"]
    categories = outfit["high_level_cats"]
    for i,c in zip(items, categories):
        cat2item[c].append(i)
        item2cat[i] = {"category_id": c}
    seq_lens.append(len(items))
    count += 1
print(f"{count} validation examples, average {np.mean(seq_lens):.2f} items, max {np.max(seq_lens)} items")

test_json = os.path.join(base_dir, "Test", test_file)
test_data = json.load(open(test_json, 'r'))

count = 0
seq_lens = []
for outfit in test_data:
    items = outfit["item_ids"]
    categories = outfit["high_level_cats"]
    for i,c in zip(items, categories):
        cat2item[c].append(i)
        item2cat[i] = {"category_id": c}
    seq_lens.append(len(items))
    count += 1
print(f"{count} test examples, average {np.mean(seq_lens):.2f} items, max {np.max(seq_lens)} item")

90847 training examples, average 3.89 items, max 8 item
2493 validation examples, average 4.12 items, max 7 items
5809 test examples, average 4.13 items, max 7 item


In [9]:
print(f"Total number of categories: {len(cat2item)}")
cat2item.keys()

Total number of categories: 9


dict_keys(['all-body', 'footwear', 'accessory', 'outerwear', 'jewellery', 'topwear', 'bottomwear', 'bodywear_nightwear_innerwear', 'beachwear_swimwear'])

In [10]:
for c in cat2item:
    print(c, len(cat2item[c]))

all-body 18430
footwear 65766
accessory 32378
outerwear 41939
jewellery 63876
topwear 42805
bottomwear 51935
bodywear_nightwear_innerwear 66597
beachwear_swimwear 4062


In [11]:
category_order = ['bodywear_nightwear_innerwear', 'footwear', 'jewellery', 'bottomwear', 'topwear', 
                  'outerwear', 'accessory', 'all-body', 'beachwear_swimwear']
for c in category_order:
    print(c, len(cat2item[c]))

bodywear_nightwear_innerwear 66597
footwear 65766
jewellery 63876
bottomwear 51935
topwear 42805
outerwear 41939
accessory 32378
all-body 18430
beachwear_swimwear 4062


In [12]:
all_item_categories = set([item2cat[item]["category_id"] for item in item2cat])
label_dict = {}
padding = 1
for ii, k in enumerate(all_item_categories):
    label_dict[k] = ii + padding
label_dict

{'outerwear': 1,
 'footwear': 2,
 'topwear': 3,
 'accessory': 4,
 'jewellery': 5,
 'all-body': 6,
 'bottomwear': 7,
 'bodywear_nightwear_innerwear': 8,
 'beachwear_swimwear': 9}

In [13]:
# create the item category labels based on Polyvore catalogue
import pandas as pd

pv_base_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/polyvore_outfits"
pv_item_file = "polyvore_item_metadata.json"
with open(os.path.join(pv_base_dir, pv_item_file), 'r') as fr:
    pv_items = json.load(fr)
    
# original logic of creating the labels
# taken from ImageDataGen()
X, y = [], []
for item_id in pv_items:
    X.append(item_id)
    y.append(pv_items[item_id]["category_id"])

X_col = "X"
y_col = "y"
df = pd.DataFrame({X_col: X, y_col: y})
categories = df[y_col].unique()
pv_label_dict = {ii: jj for ii, jj in enumerate(categories)} # reversing

In [15]:
pv_id2cat = dict()
for item in pv_items:
    iid, cat = pv_items[item]["category_id"], pv_items[item]["semantic_category"]
    if iid not in pv_id2cat:
        pv_id2cat[iid] = cat

In [14]:
from focal_loss import BinaryFocalLoss

model_path = 'compatibility_zalando_rnn_model_4_only_image'
model = tf.keras.models.load_model('compatibility_zalando_rnn_model_8_only_image')
model.summary()

Model: "rnn"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 8, 1280)]    0                                            
__________________________________________________________________________________________________
tf_op_layer_Sum (TensorFlowOpLa (None, 8)            0           input_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_NotEqual (TensorFlo (None, 8)            0           tf_op_layer_Sum[0][0]            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 8)]          0                                            
________________________________________________________________________________________________

In [16]:
category_model = tf.keras.models.load_model(f"finetuned_efficientnet")













































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































In [17]:
def get_product_category(query):
    image = tf.keras.preprocessing.image.load_img(query)
    image_arr = tf.keras.preprocessing.image.img_to_array(image)
    # image_arr = tf.image.resize(image_arr, (224, 224)).numpy()
    image_arr = tf.image.resize_with_pad(image_arr, target_height=224, target_width=224).numpy()
    image_arr /= 255.0
    pred_score = category_model(tf.expand_dims(image_arr, 0))
    pred_label = np.argmax(pred_score, axis=1)[0]
    pred_label = pv_label_dict[pred_label]  # class number as given in the item description
    return pred_label, pv_id2cat[pred_label]

def get_product_category_2(image_arr):
    pred_score = category_model(tf.expand_dims(image_arr, 0))
    pred_label = np.argmax(pred_score, axis=1)[0]  # index of the class
    pred_label = pv_label_dict[pred_label]  # class number as given in the item description
    return pred_label, pv_id2cat[pred_label]


In [18]:
pv2zalando = {'accessories': 'accessory',
              'all-body': 'all-body',
              'bags': 'accessory',
              'bottoms': 'bottomwear',
              'hats': 'accessory',
              'jewellery': 'jewellery',
              'outerwear': 'outerwear',
              'scarves': 'accessory',
              'shoes': 'footwear',
              'sunglasses': 'accessory',
              'tops': 'topwear'}

style_dict = {'casual': '0', 'travel': '1', 'party': '2',
              'athleisure': '3', 'work': '4', 'sporty': '5', 'relax': '6'}

# item_description_dict: dictionary with item-id as key and category in the values, item2cat
# item_category_dict: dictionary with category-id as key and all the items in that category as value, cat2item
def return_top_items(query_item, style_name, image_embedding_dict, item_dict, ignore=None, search_only=None, max_item=8):
    if type(query_item) is not list:
        query_item = [query_item]
    style_idx = style_dict[style_name]
    data_gen = ZalandoOutfitGenWithImage(query_item=query_item,
                                         embed_dir=embed_dir,
                                         text_embed_file="bert_polyvore.pkl",
                                         batch_size=256,
                                         max_len=max_item,
                                         image_embedding_dim=1280,
                                         image_embedding_dict=image_embedding_dict,
                                         item_description=item_dict,
                                         item_category_dict=cat2item,
                                         search_only_categories=search_only,
                                         label_dict=label_dict,
                                         outfit_style=style_idx,
                                         include_item_categories=True,
                                        )
    pbar = tqdm(range(len(data_gen)))
    current_score = []
    # adds one more item to the list of current items (stored in query_item)
    for ii in pbar:
        x, items = data_gen[ii]
        yhat = model(x)
        for cs, item in zip(yhat, items):
            heapq.heappush(current_score, (1-cs, query_item + [str(item)]))  # it's a min-heap
    return current_score

def filter_outfits(outfits, item_dict, max_len):
    # retain only those combinations where all categories are distinct
    count = 0
    filtered = []
    while count < max_len: 
        outfit = heapq.heappop(outfits)
        items = outfit[1]
        categories = [item_dict[item]['category_id'] for item in items]
        if len(set(categories)) == len(categories):
            filtered.append(items)
            count += 1
    return filtered

def get_next_category(cats, global_categories):
    # returns the next category to search from
    for cat in global_categories:
        if cat not in cats:
            return cat
    return None

def create_outfit(query, category, style, model, image_dict, item_dict, max_item=8, beam_length=2):
    """
        query: image embedding or polyvore item-id
        catgeory: category of the image, str
        style: outfit style, str
        model: compatibility model (pretrained)
        image_dict: image to embedding dict
        item_dict: item to item details dict
        max_item: maximum number of items in the outfit
        beam_length: number of alternative outfits
    """
    ## TODO: modify for multiple image inputs 
    # create new item-id and update embedding_dict and category_dict
    if type(query) is not str:
        item_id = str(uuid.uuid4())
        image_dict[item_id] = query
        item_dict[item_id] = {'semantic_category': category, 'category_id': category}
    else:
        item_id = query
    
    # Based on the global category distribution find the next category to search from
    top_level_category_order = ['footwear', 'jewellery', 'bottomwear', 'topwear', 'outerwear', 'accessory', 'all-body']
    next_category = get_next_category([category], top_level_category_order)

    # add the first item - only one run
    first_score = return_top_items(item_id,
                                   style,
                                   image_dict,
                                   item_dict,
                                   ignore=[category], 
                                   search_only=[next_category])
    current_items = filter_outfits(first_score, item_dict, beam_length)
    print(max_item)
    
    for jj in range(2, max_item):
        all_scores = []
        for ii in range(beam_length):
            ignore_categories = [item_dict[item]['category_id'] for item in current_items[ii]]
            next_category = get_next_category(ignore_categories, top_level_category_order)
            scores_ii = return_top_items(current_items[ii],
                                         style,
                                         image_dict, 
                                         item_dict, 
                                         ignore_categories,
                                         search_only=[next_category],
                                        )
            all_scores += scores_ii
        # reconstruct current items - with one new item
        current_items = filter_outfits(all_scores, item_dict, beam_length)
    
    return current_items

In [19]:
@anvil.server.callable
def get_outfit(file, max_item=3, style="casual"):
    with anvil.media.TempFile(file) as filename:
        image = tf.keras.preprocessing.image.load_img(filename)

    image_arr = tf.keras.preprocessing.image.img_to_array(image)
    # image_arr = tf.image.resize(image_arr, (224, 224)).numpy()
    image_arr = tf.image.resize_with_pad(image_arr, target_height=224, target_width=224).numpy()
    image_arr /= 255.0
    image_embed = tf.squeeze(eff2_model(tf.expand_dims(image_arr, 0)))
    image_cat_id, image_cat = get_product_category_2(image_arr)
    image_cat = pv2zalando[image_cat]
    print(image_cat_id, image_cat, max_item, style)
    outfits = create_outfit(query=image_embed,
                            category=image_cat,
                            style=style,
                            model=model,
                            image_dict=image_embedding_dict.copy(), 
                            item_dict=item2cat.copy(), 
                            max_item=int(max_item), 
                            beam_length=2)
    print(outfits)
    outfit_images = []
    for prods in outfits:
        res = []
        for item in prods[1:]:
            filepath = os.path.join(image_dir, item)
            res.append(anvil.media.from_file(filepath, 'image/jpeg'))
        outfit_images.append(res)
    return outfit_images


4 all-body 3 casual
Searching limited to ['footwear']
Original 51975 items, reduced to 2611 items


100%|██████████| 11/11 [00:04<00:00,  2.29it/s]


3
Searching limited to ['jewellery']
Original 51975 items, reduced to 860 items


100%|██████████| 4/4 [00:01<00:00,  2.51it/s]


Searching limited to ['jewellery']
Original 51975 items, reduced to 860 items


100%|██████████| 4/4 [00:01<00:00,  2.59it/s]


[['76b12b23-5908-420f-8de7-500849758d57', 'TT911B01B-B11@9.jpg', 'PI851L0CD-F11@7.1.jpg'], ['76b12b23-5908-420f-8de7-500849758d57', 'TT911B01B-B11@9.jpg', '4SW51L0F2-G11@6.jpg']]
4 all-body 3 travel
Searching limited to ['footwear']
Original 51975 items, reduced to 2611 items


100%|██████████| 11/11 [00:04<00:00,  2.40it/s]


3
Searching limited to ['jewellery']
Original 51975 items, reduced to 860 items


100%|██████████| 4/4 [00:01<00:00,  2.59it/s]


Searching limited to ['jewellery']
Original 51975 items, reduced to 860 items


100%|██████████| 4/4 [00:01<00:00,  2.54it/s]


[['dae2a0dd-be4e-4ebd-b847-b3f7dcefb34f', 'TT911B01B-B11@9.jpg', 'PI851L0CD-F11@7.1.jpg'], ['dae2a0dd-be4e-4ebd-b847-b3f7dcefb34f', 'TT911B01B-B11@9.jpg', '4SW51L0F2-G11@6.jpg']]


In [68]:
def get_outfit_local(filename, style="casual", max_item=3):
    image = tf.keras.preprocessing.image.load_img(filename)

    image_arr = tf.keras.preprocessing.image.img_to_array(image)
    # image_arr = tf.image.resize(image_arr, (224, 224)).numpy()
    image_arr = tf.image.resize_with_pad(image_arr, target_height=224, target_width=224).numpy()
    image_arr /= 255.0
    image_embed = tf.squeeze(eff2_model(tf.expand_dims(image_arr, 0)))
    image_cat_id, image_cat = get_product_category_2(image_arr)
    image_cat = pv2zalando[image_cat]
    print(image_cat_id, image_cat)
    outfits = create_outfit(query=image_embed,
                            category=image_cat,
                            model=model,
                            image_dict=image_embedding_dict.copy(), 
                            item_dict=item2cat.copy(), 
                            max_item=int(max_item), 
                            beam_length=2)
    return outfits


In [None]:
pv_base_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/polyvore_outfits/images"
get_outfit_local(os.path.join(pv_base_dir, "129978068.jpg"))