# Learn Graph Sage Embedding

GraphSage embeddings work beautifully on the training dataset but not on the validation data and there is no reason it'll work on the test dataset. However, for each item we have all the features (text, image, category etc.) and we can use them to learn the GraphSage embeddings. The training dataset for this model will come from the items present in the original training dataset and the model will be evaluated on the items present *only* in the validation and test dataset. 

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
from collections import Counter
from PIL import Image
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping

from tqdm import tqdm
import pickle
import sys

%pylab inline
import matplotlib.pyplot as plt

Populating the interactive namespace from numpy and matplotlib


## Load all the Data

In [3]:
base_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/polyvore_outfits"
train_dir = os.path.join(base_dir, "disjoint")
image_dir = os.path.join(base_dir, "images")
train_json = "train.json"
valid_json = "valid.json"
test_json = "test.json"

train_file = "compatibility_train.txt"
valid_file = "compatibility_valid.txt"
test_file = "compatibility_test.txt"
item_file = "polyvore_item_metadata.json"
outfit_file = "polyvore_outfit_titles.json"

In [4]:
with open(os.path.join(train_dir, train_json), 'r') as fr:
    train_pos = json.load(fr)
    
with open(os.path.join(train_dir, valid_json), 'r') as fr:
    valid_pos = json.load(fr)
    
with open(os.path.join(train_dir, test_json), 'r') as fr:
    test_pos = json.load(fr)
    
with open(os.path.join(base_dir, item_file), 'r') as fr:
    pv_items = json.load(fr)
    
with open(os.path.join(base_dir, outfit_file), 'r') as fr:
    pv_outfits = json.load(fr)

with open(os.path.join(train_dir, train_file), 'r') as fr:
    train_X, train_y = [], []
    for line in fr:
        elems = line.strip().split()
        train_y.append(elems[0])
        train_X.append(elems[1:])

with open(os.path.join(train_dir, valid_file), 'r') as fr:
    valid_X, valid_y = [], []
    for line in fr:
        elems = line.strip().split()
        valid_y.append(elems[0])
        valid_X.append(elems[1:])

with open(os.path.join(train_dir, test_file), 'r') as fr:
    test_X, test_y = [], []
    for line in fr:
        elems = line.strip().split()
        test_y.append(elems[0])
        test_X.append(elems[1:])


In [5]:
train_set = set()
for outfit in train_pos:
    items = [x['item_id'] for x in outfit['items']]
    train_set |= set(items)
print(f"Total {len(train_set)} items in the train data")

valid_set = set()
for outfit in valid_pos:
    items = [x['item_id'] for x in outfit['items']]
    valid_set |= set(items)
print(f"Total {len(valid_set)} items in the valid data")
print(f"{len(valid_set.intersection(train_set))} common items between train and validation set")

test_set = set()
for outfit in test_pos:
    items = [x['item_id'] for x in outfit['items']]
    test_set |= set(items)
print(f"Total {len(test_set)} items in the test data")
print(f"{len(test_set.intersection(train_set))} common items between train and test set")

Total 71967 items in the train data
Total 14657 items in the valid data
3781 common items between train and validation set
Total 70035 items in the test data
84 common items between train and test set


In [6]:
all_item_categories = set([pv_items[item]['category_id'] for item in pv_items])
len(all_item_categories)

153

In [7]:
label_renum_dict = {}
for ii, k in enumerate(all_item_categories):
    label_renum_dict[k] = ii

## Load all the embeddings

In [8]:
with open("effnet2_polyvore.pkl", "rb") as fr:
    image_embedding = pickle.load(fr)
    
with open("bert_polyvore.pkl", "rb") as fr:
    text_embedding = pickle.load(fr)

with open("graphsage_dict_polyvore.pkl", "rb") as fr:
    graphsage_embedding = pickle.load(fr)


In [9]:
train_X1, train_X2 = [], []
train_Y = []
train_item_list = []
count = 0
for item in tqdm(train_set):
    train_item_list.append(item)
    train_X1.append(image_embedding[item].numpy())
    train_X2.append(label_renum_dict[pv_items[item]['category_id']])
    train_Y.append(graphsage_embedding[item])
    count += 1
train_X1 = np.array(train_X1)
train_X2 = np.array(train_X2)
train_Y = np.array(train_Y)

100%|██████████| 71967/71967 [00:00<00:00, 92491.62it/s] 


## Build an Embedding Mapping Model

In [14]:
image_dim = 1280
hidden_dim = 256
out_dim = 50

in1 = tf.keras.layers.Input(shape=(image_dim))
in2 = tf.keras.layers.Input(shape=(1))
x2 = tf.keras.layers.Embedding(153, 100)(in2)
x2 = tf.squeeze(x2, -2)
x3 = tf.keras.layers.concatenate([in1, x2], axis=-1)
out = tf.keras.layers.Dense(hidden_dim, activation="relu")(x3)
out = tf.keras.layers.Dense(out_dim, activation="linear")(out)
model = tf.keras.models.Model(inputs=[in1, in2], outputs=out)
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 1, 100)       15300       input_6[0][0]                    
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 1280)]       0                                            
__________________________________________________________________________________________________
tf_op_layer_Squeeze_2 (TensorFl [(None, 100)]        0           embedding_2[0][0]                
____________________________________________________________________________________________

In [18]:
learning_rate = 1.0e-04
batch_size = 128
epochs = 50
patience = 10

opt = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(loss='mse', optimizer=opt, metrics=["mse", "mae"])  # 'adam'
callback = EarlyStopping(
                        monitor="val_accuracy",
                        min_delta=0,
                        patience=patience,
                        verbose=0,
                        mode="auto",
                        baseline=None,
                        restore_best_weights=True,
                    )
tic = time.time()
history = model.fit([train_X1, train_X2], train_Y, 
                    epochs=epochs, 
                    batch_size=batch_size,
                    steps_per_epoch=math.ceil(train_X1.shape[0]/batch_size),
#                     validation_data=(val_X, val_y),
#                     callbacks=[callback, model_checkpoint_callback],
                    verbose=1)
time.time() - tic

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


107.41560554504395

In [19]:
valid_X1, valid_X2 = [], []
valid_item_list = []
for item in valid_set:
    valid_item_list.append(item)
    valid_X1.append(image_embedding[item].numpy())
    valid_X2.append(label_renum_dict[pv_items[item]['category_id']])
valid_X1 = np.array(valid_X1)
valid_X2 = np.array(valid_X2)

In [20]:
valid_Y = model([valid_X1, valid_X2])

In [21]:
valid_Y.shape

TensorShape([14657, 50])

## Write the Model Prediction

 - Keep the embedding of the training items same as before
 - Update only the new items present in the validation set

In [22]:
new_graphsage_dict = {}
for item in train_set:
    new_graphsage_dict[item] = graphsage_embedding[item]

for item in valid_set:
    if item not in train_set:
        jj = valid_item_list.index(item)
        new_graphsage_dict[item] = valid_Y[jj]
        
with open("graphsage_dict2_polyvore.pkl", "wb") as output_file:
    pickle.dump(new_graphsage_dict, output_file)