# Test Locally

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import heapq
import json
import logging
import math
import numpy as np
import os
import pickle
import pandas as pd
import tensorflow as tf

In [9]:
class OutfitGen(tf.keras.utils.Sequence):
    """
    Batch generator of candidate outfits built around a query.

    Every sample is the query item(s) plus one extra item taken from the
    catalogue of items that have both an image and a text embedding. The
    batches are meant to be scored by an existing compatibility model and
    the scores are then used to build the best outfit.

    Indexing the generator returns ``((image_batch, text_batch), ids)``:
    each outfit is left-padded with zero vectors up to ``max_len`` items
    and ``ids`` holds, per sample, the added item id converted to ``int``.
    """

    def __init__(
        self,
        image_embedding_dict,
        text_embedding_dict,
        batch_size,
        max_len,
        image_embedding_dim,
        query_item,
        max_search_item,
        shuffle=False,
    ):
        """
        Build the table of candidate outfits.

        Args:
            image_embedding_dict: mapping of item id -> image embedding.
            text_embedding_dict: mapping of item id -> text embedding.
            batch_size: number of candidate outfits per batch.
            max_len: number of item slots per outfit; longer samples are
                truncated and shorter ones are left-padded with zeros.
            image_embedding_dim: shape passed to ``np.zeros`` to build the
                padding vector for images.
            query_item: a single item id or a list of item ids.
            max_search_item: size of the searchable catalogue. A value
                below 1 is a fraction of the common items, a value between
                1 and the catalogue size is an absolute count, anything
                else keeps the whole catalogue.
            shuffle: reshuffle the candidates at the end of every epoch.

        Raises:
            ValueError: if ``max_search_item`` is negative.
        """
        # Newer Keras versions expect Sequence/PyDataset subclasses to run
        # the base initializer; on older versions this is a harmless no-op.
        super().__init__()

        self.image_embedding_dict = image_embedding_dict
        self.text_embedding_dict = text_embedding_dict

        # only items with both modalities can be fed to the model
        common_items = list(
            set(self.image_embedding_dict.keys()).intersection(
                self.text_embedding_dict.keys())
        )

        if max_search_item < 0:
            raise ValueError(
                f"max_search_item must be non-negative, got {max_search_item}")

        # working with such large number of items is causing timeout error
        # hence downsampling the searchable catalogue of items
        if max_search_item < 1:
            n_sample = int(len(common_items) * max_search_item)
        elif 1 < max_search_item < len(common_items):
            n_sample = int(max_search_item)
        else:
            n_sample = None  # keep the full catalogue

        if n_sample is None:
            sampled_items = common_items
        else:
            rng = np.random.default_rng()
            # sample positions rather than values so that the ids keep
            # their original Python type
            picked = rng.choice(
                len(common_items), size=n_sample, replace=False)
            sampled_items = [common_items[ii] for ii in picked]

        if not isinstance(query_item, list):
            query_item = [query_item]

        # one sample per candidate: the candidate goes first, followed by
        # the query items; the candidate id doubles as the "label"
        candidates = [item for item in sampled_items if item not in query_item]

        self.X_col = "X"
        self.y_col = "y"
        self.df = pd.DataFrame({
            self.X_col: [[item] + query_item for item in candidates],
            self.y_col: candidates,
        })
        self.batch_size = batch_size
        self.max_len = max_len
        self.image_embedding_dim = image_embedding_dim
        self.text_embedding_dim = 768
        self.shuffle = shuffle
        self.n = len(self.df)

    def on_epoch_end(self):
        """Reshuffle the candidate table when ``shuffle`` is enabled."""
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)

    def get_texts(self, item_id):
        """Return the text embedding stored for ``item_id``."""
        return self.text_embedding_dict[item_id]

    def get_image(self, item_id):
        """Return the image embedding stored for ``item_id``."""
        return self.image_embedding_dict[item_id]

    def __get_input(self, example):
        """
        Turn one outfit (list of item ids) into padded embedding lists.

        Returns a tuple ``(images, texts)``; both lists are left-padded
        with zero vectors when the outfit has fewer than ``max_len`` items.
        """
        items = example[: self.max_len]
        image_data = [self.get_image(item) for item in items]
        text_data = [self.get_texts(item) for item in items]

        # a non-positive pad count yields an empty padding list
        n_pad = self.max_len - len(items)
        zeros_image = [np.zeros(self.image_embedding_dim)] * n_pad
        zeros_text = [np.zeros(self.text_embedding_dim)] * n_pad

        return (zeros_image + image_data, zeros_text + text_data)

    def __get_output(self, label):
        # NOTE(review): ``label_dict`` is never assigned in this class and
        # nothing here calls this method; it would raise AttributeError.
        return self.label_dict[label]

    def __get_data(self, batches):
        """Generate the arrays for one slice of the candidate table."""
        x_batch = batches[self.X_col].tolist()
        y_batch = batches[self.y_col].tolist()

        combined = [self.__get_input(x) for x in x_batch]
        X_batch = (
            np.asarray([x[0] for x in combined]),
            np.asarray([x[1] for x in combined]),
        )
        # assumes item ids are int-parseable -- TODO confirm for all ids
        y_batch = np.asarray([int(y) for y in y_batch])

        return X_batch, y_batch

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``."""
        start = index * self.batch_size
        batches = self.df.iloc[start: start + self.batch_size]
        return self.__get_data(batches)

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        return math.ceil(self.n / self.batch_size)


def return_top_items(query_item, max_search_item, max_item=8):
    """
    Score every "query + one candidate" outfit and return them as a heap.

    Relies on the module-level ``model``, ``image_embedding_dict`` and
    ``text_embedding_dict`` set up by ``init()``.

    Args:
        query_item: an item id or a list of item ids forming the outfit so far.
        max_search_item: forwarded to ``OutfitGen`` to bound the catalogue.
        max_item: forwarded to ``OutfitGen`` as ``max_len``.

    Returns:
        A ``heapq`` heap of ``(1 - score, outfit)`` tuples where ``outfit``
        is the query followed by the candidate id as a string.
    """
    if type(query_item) is not list:
        query_item = [query_item]

    generator = OutfitGen(
        image_embedding_dict=image_embedding_dict,
        text_embedding_dict=text_embedding_dict,
        batch_size=256,
        max_len=max_item,
        image_embedding_dim=1280,
        query_item=query_item,
        max_search_item=max_search_item,
    )

    scored_outfits = []
    for batch_no in range(len(generator)):
        features, candidate_ids = generator[batch_no]
        predictions = model(features)
        for prediction, candidate in zip(predictions, candidate_ids):
            # heapq is a min-heap, so store 1 - score to pop the best first
            entry = (1 - prediction, query_item + [str(candidate)])
            heapq.heappush(scored_outfits, entry)
    return scored_outfits


def filter_outfits(outfits, max_len, item_lookup=None):
    """
    Pop the best outfits whose items all belong to distinct categories.

    Args:
        outfits: a valid ``heapq`` heap of ``(score, items)`` tuples; it is
            consumed (popped from) in place.
        max_len: maximum number of outfits to return.
        item_lookup: mapping of item id -> metadata dict with a
            ``'semantic_category'`` entry. Defaults to the module-level
            ``pv_items``.

    Returns:
        Up to ``max_len`` item lists in heap order. Fewer are returned when
        the heap runs out of outfits with unique categories.
    """
    if item_lookup is None:
        item_lookup = pv_items

    filtered = []
    # stop on an exhausted heap instead of letting heappop raise IndexError
    while outfits and len(filtered) < max_len:
        items = heapq.heappop(outfits)[1]
        categories = [item_lookup[item]['semantic_category'] for item in items]
        # keep only outfits without two items of the same category
        if len(set(categories)) == len(categories):
            filtered.append(items)
    return filtered


def create_outfit(query, max_search_item=1000, max_item=8, beam_length=10):
    """
    Grow outfits around ``query`` one item at a time with a beam search.

    Args:
        query: an item id or a list of item ids to start from.
        max_search_item: forwarded to ``return_top_items`` to bound the
            catalogue searched at every step.
        max_item: target number of items per outfit.
        beam_length: number of partial outfits kept after every step.

    Returns:
        The outfits (lists of item ids) kept after the last step.
    """
    if not isinstance(query, list):
        query = [query]

    # add the first item - only one run
    first_score = return_top_items(query, max_search_item)
    current_items = filter_outfits(first_score, beam_length)
    print(0, current_items)

    for jj in range(len(query) + 1, max_item):
        all_scores = []
        for outfit in current_items:
            all_scores.extend(return_top_items(outfit, max_search_item))
        # Concatenated heaps do not satisfy the heap invariant; without
        # re-heapifying, the pops below are not globally best-first.
        heapq.heapify(all_scores)
        # reconstruct current items - with one new item
        current_items = filter_outfits(all_scores, beam_length)
        print(jj, current_items)

    return current_items


def init():
    """
    Load the model, the embeddings and the item metadata into globals.

    This function is called when the container is initialized/started,
    typically after create/update of the deployment. It caches, as module
    globals, the compatibility model (``model``), the image and text
    embedding dictionaries and the item dictionary (``pv_items``).
    """
    global model, image_embedding_dict, text_embedding_dict, pv_items

    def _unpickle(path):
        # read one pickled artefact from disk
        with open(path, "rb") as handle:
            return pickle.load(handle)

    # Local-test locations. In an AzureML deployment with multiple models
    # each artefact would instead be resolved under
    # os.getenv('AZUREML_MODEL_DIR')/<model_name>/<version>/..., see
    # https://docs.microsoft.com/en-us/azure/machine-learning/v1/how-to-deploy-advanced-entry-script
    # (Model.get_model_path was tried and did not work.)
    data_type, model_type, max_seq_len = "nondisjoint", "rnn", 8
    first_model_path = f"compatibility_{data_type}_{model_type}_model_{max_seq_len}"

    embed_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/precomputed"
    image_embedding_file = os.path.join(embed_dir, "effnet_tuned_polyvore.pkl")
    text_embedding_file = os.path.join(embed_dir, "bert_polyvore.pkl")
    item_dict_file = os.path.join(embed_dir, "items_polyvore.pkl")

    # The model is a folder, there is no single model file
    model = tf.keras.models.load_model(first_model_path)
    logging.info("Loaded the compatibility model")

    image_embedding_dict = _unpickle(image_embedding_file)
    logging.info(f"Loaded {len(image_embedding_dict)} image embeddings")

    text_embedding_dict = _unpickle(text_embedding_file)
    print(f"Loaded {len(text_embedding_dict)} text embeddings")

    pv_items = _unpickle(item_dict_file)
    print(f"Loaded {len(pv_items)} items ...")

    logging.info("Init complete")


def run(data):
    """
    Handle one scoring request.

    Args:
        data: JSON string with the keys ``query``, ``max_search_item``,
            ``max_item`` and ``beam_length``.

    Returns:
        ``"Generated outfit: ..."`` on success, otherwise the message of
        the exception that was raised (errors are never propagated).
    """
    try:
        inp = json.loads(data)
        # log up front so failing requests are traceable too
        logging.info(f"received data {inp}")

        # validate/convert the request before starting the expensive search
        query = inp["query"]
        max_search_item = inp["max_search_item"]
        max_item = int(inp["max_item"])
        beam_length = int(inp["beam_length"])

        outfits = create_outfit(query=query,
                                max_search_item=max_search_item,
                                max_item=max_item,
                                beam_length=beam_length)
        return f"Generated outfit: {outfits}"
    except Exception as e:
        # the endpoint answers with the error text; keep the traceback in logs
        logging.exception("Outfit generation failed")
        return str(e)


In [10]:
# Sample request payload (JSON string) with the four keys run() reads:
# "query" is the seed item id; "max_search_item", "max_item" and
# "beam_length" are forwarded to create_outfit().
data = """{
    "query": "132621870",
    "max_search_item": 100000,
    "max_item": 5,
    "beam_length": 2
}"""

In [11]:
# Load the model, embeddings and item metadata into module globals, then
# run the sample request through the scoring entry point.
init()
run(data)

Loaded 251008 text embeddings
Loaded 251008 items ...
0 [['132621870', '212822356'], ['132621870', '151036990']]
2 [['132621870', '212822356', '184154812'], ['132621870', '212822356', '203338138']]
3 [['132621870', '212822356', '184154812', '200172677'], ['132621870', '212822356', '184154812', '185030615']]
4 [['132621870', '212822356', '184154812', '200172677', '182992370'], ['132621870', '212822356', '184154812', '200172677', '193960201']]


"Generated outfit: [['132621870', '212822356', '184154812', '200172677', '182992370'], ['132621870', '212822356', '184154812', '200172677', '193960201']]"