# Enhance Item Features by Graph Neural Network

    - use the image embedding vector (from EfficientNet) and text embedding vectors (from BERT)
    - use Graph Attention Network for subsequent feature transformation
    - use outfit definition as edges (two items are connected if they are part of the same outfit), use only the training data
    - use category-id as node label

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import json
import math
import numpy as np
import pickle
import random
import time
from tqdm import tqdm
from datetime import datetime

import torch
import torch.utils.data as torch_data
import torch.nn as nn
from torch import Tensor
# from torch_geometric.nn import GCNConv

from tensorflow.random import set_seed
from numpy.random import seed
import shap

%matplotlib inline

In [3]:
from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GATConv, GCSConv, GlobalAvgPool
from spektral.transforms import LayerPreprocess
from spektral.data import Dataset, DisjointLoader, Graph
from spektral.transforms.normalize_adj import NormalizeAdj

In [8]:
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.models.gcn import GCN
from spektral.transforms import LayerPreprocess

In [9]:
# Exploratory: load the "cora" citation dataset through spektral with
# normalize_x=True and the GCNConv layer-preprocessing transform applied.
# NOTE(review): `dataset` is only inspected in the next few cells and is
# reassigned further down; it does not appear to feed the Polyvore pipeline.
dataset = Citation("cora", normalize_x=True, transforms=[LayerPreprocess(GCNConv)])

Downloading cora dataset.
Pre-processing node features


Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.


In [10]:
# Exploratory: list the attributes and methods exposed by the dataset object.
dir(dataset)

['__add__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'a',
 'apply',
 'available_datasets',
 'download',
 'dtype',
 'filter',
 'graphs',
 'map',
 'mask_te',
 'mask_tr',
 'mask_va',
 'n_edge_features',
 'n_graphs',
 'n_labels',
 'n_node_features',
 'n_nodes',
 'name',
 'normalize_x',
 'path',
 'random_split',
 'read',
 'signature',
 'suffixes',
 'url']

In [13]:
# Summary sizes of the loaded dataset: (graphs, nodes, node features, labels).
dataset.n_graphs, dataset.n_nodes, dataset.n_node_features, dataset.n_labels

(1, 2708, 1433, 7)

In [14]:
# Show the list of Graph objects held by the dataset.
dataset.graphs

[Graph(n_nodes=2708, n_node_features=1433, n_edge_features=None, n_labels=7)]

## Create node and edge matrix from Outfit data

In [3]:
# Locations of the Polyvore outfit data on disk.
base_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/polyvore_outfits"
data_type = "nondisjoint" # "disjoint"
train_dir, image_dir = (os.path.join(base_dir, sub) for sub in (data_type, "images"))

# Per-split file names follow a single naming pattern, so derive them from it.
_splits = ("train", "valid", "test")
train_json, valid_json, test_json = (f"{split}.json" for split in _splits)
train_file, valid_file, test_file = (f"compatibility_{split}.txt" for split in _splits)

# Dataset-wide metadata files (stored directly under base_dir).
item_file = "polyvore_item_metadata.json"
outfit_file = "polyvore_outfit_titles.json"

In [4]:
def _load_json(directory, filename):
    """Parse and return the JSON document stored at directory/filename."""
    with open(os.path.join(directory, filename), 'r') as handle:
        return json.load(handle)


# Outfit definitions for each split live in the split-specific directory.
train_pos = _load_json(train_dir, train_json)
valid_pos = _load_json(train_dir, valid_json)
test_pos = _load_json(train_dir, test_json)

# Item metadata and outfit titles are shared across splits (under base_dir).
pv_items = _load_json(base_dir, item_file)
pv_outfits = _load_json(base_dir, outfit_file)


In [5]:
def _read_compatibility(path):
    """Parse a whitespace-separated compatibility file.

    The first token of every line is kept as the (string) label and the
    remaining tokens as the list of outfit-item keys.

    Returns:
        (rows, labels): two parallel lists, one entry per line of the file.
    """
    rows, labels = [], []
    with open(path, 'r') as handle:
        for raw_line in handle:
            tokens = raw_line.strip().split()
            labels.append(tokens[0])
            rows.append(tokens[1:])
    return rows, labels


train_X, train_y = _read_compatibility(os.path.join(train_dir, train_file))
valid_X, valid_y = _read_compatibility(os.path.join(train_dir, valid_file))
test_X, test_y = _read_compatibility(os.path.join(train_dir, test_file))


In [6]:
# Report how many outfit definitions were loaded for each split.
print(f"{len(train_pos)} outfits in train, {len(valid_pos)} outfits in validation and {len(test_pos)} outfits in the test data")

53306 outfits in train, 5000 outfits in validation and 10000 outfits in the test data


## Create a dict that maps each outfit-item key (e.g. `199244701_1`) to the original item-id

In [7]:
# item_dict: outfit-item key (as listed in the compatibility files) -> original item-id.
# The i-th outfit of each JSON split is paired positionally with the i-th row
# of the matching compatibility file.
# NOTE(review): this assumes both sources list the outfits in the same order - confirm.
item_dict = {}
_split_sources = (
    ("Train data:", train_pos, train_X),
    ("Train & Validation data:", valid_pos, valid_X),
    ("Train, validation and test data", test_pos, test_X),
)
for _message, _outfits, _keyed_rows in _split_sources:
    for ii, outfit in enumerate(_outfits):
        items = outfit['items']
        mapped = _keyed_rows[ii]
        for item_key, item_meta in zip(mapped, items):
            item_dict[item_key] = item_meta['item_id']
    # Running total after each split has been merged in.
    print(_message, len(item_dict))

Train data: 284767
Train & Validation data: 311548
Train, validation and test data 365054


In [8]:
# Unique original item-ids per split, plus the overlap of validation/test with train.
train_set = set()
for outfit in train_pos:
    train_set.update(entry['item_id'] for entry in outfit['items'])
print(f"Total {len(train_set)} items in the train data")

valid_set = set()
for outfit in valid_pos:
    valid_set.update(entry['item_id'] for entry in outfit['items'])
print(f"Total {len(valid_set)} items in the valid data")
print(f"{len(valid_set & train_set)} common items between train and validation set")

test_set = set()
for outfit in test_pos:
    test_set.update(entry['item_id'] for entry in outfit['items'])
print(f"Total {len(test_set)} items in the test data")
print(f"{len(test_set & train_set)} common items between train and test set")

Total 204679 items in the train data
Total 25132 items in the valid data
9356 common items between train and validation set
Total 47854 items in the test data
16655 common items between train and test set


In [10]:
embed_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/precomputed"
# Load the precomputed image and text embeddings.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# image_embedding is indexed by original item-id further down, and each value
# exposes .numpy() (presumably a torch tensor - confirm).
with open(os.path.join(embed_dir, "effnet2_polyvore.pkl"), "rb") as fr:
    image_embedding = pickle.load(fr)
    
# NOTE(review): text_embedding is loaded here, but its only use in the
# node-file cell is commented out.
with open(os.path.join(embed_dir, "bert_polyvore.pkl"), "rb") as fr:
    text_embedding = pickle.load(fr)
  

In [11]:
# Number of entries in the precomputed image-embedding container.
len(image_embedding)

261057

In [12]:
# Distinct category ids found across the item metadata.
all_item_categories = {meta['category_id'] for meta in pv_items.values()}
len(all_item_categories)

153

In [13]:
# Renumber category ids to consecutive integers 0..K-1 for use as node labels.
# Enumerating the set directly would follow set iteration order, which is not
# guaranteed to be stable between kernel sessions (string hashes are randomised
# per process), so labels written to disk could not be reproduced or decoded
# later. Sorting first makes the mapping deterministic. key=str keeps the sort
# well-defined whatever type the ids have.
label_renum_dict = {
    category: index
    for index, category in enumerate(sorted(all_item_categories, key=str))
}

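The node file has one row per outfit-item key: an integer node id, the image-embedding values, and the renumbered category label. Keys have the form `<outfit-id>_<j>`, i.e. the j-th item of that outfit.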

In [14]:
# Write the node table: one tab-separated row per outfit-item key, laid out as
#   <integer node id> <image embedding values...> <renumbered category label>
node_file, edge_file = f"nodes_{data_type}.txt", f"edges_{data_type}.txt"

item_number_dict = {}  # converts item number to integers
# The context manager guarantees the file is flushed and closed even if a
# lookup fails part-way through this long loop.
with open(os.path.join(train_dir, node_file), 'w') as fn:
    for count, item in enumerate(tqdm(item_dict)):
        item_number_dict[item] = count
        item_name = item_dict[item]
        # Only the image embedding is written; the text embedding is not used here.
        # NOTE(review): .numpy() assumes each embedding is a tensor-like object - confirm.
        x_img = image_embedding[item_name].numpy().tolist()
        label = label_renum_dict[pv_items[item_name]['category_id']]

        fields = [count, *x_img, label]
        fn.write("\t".join(str(x) for x in fields) + "\n")

# Keep `count` equal to the number of rows written, as before.
count = len(item_number_dict)

100%|██████████| 365054/365054 [07:25<00:00, 819.52it/s]


In [11]:
# Inspect the outfit-item keys listed on the first row of the training compatibility file.
train_X[0]

['199244701_1',
 '199244701_2',
 '199244701_3',
 '199244701_4',
 '199244701_5',
 '199244701_6']

## Create edges based on Outfit Items

    - if there are N items in an outfit then create N(N-1)/2 edges

In [15]:
import itertools

# Demo: itertools.combinations yields every unordered pair exactly once,
# i.e. N(N-1)/2 pairs for N items (6 pairs for the 4 items below).
for comb in itertools.combinations([1,2,3, 4], 2):
    print(comb)

(1, 2)
(1, 3)
(1, 4)
(2, 3)
(2, 4)
(3, 4)


In [16]:
# Write one tab-separated edge per unordered pair of items in each training outfit.
edge_file_train = f"edges_{data_type}_train.txt"
count = 0
# The context manager closes the file even if an item key is missing from item_number_dict.
with open(os.path.join(train_dir, edge_file_train), 'w') as fw:
    for ii in tqdm(range(len(train_pos))):
        # Translate the outfit-item keys of row ii into integer node ids.
        items = [item_number_dict[k] for k in train_X[ii]]
        # Pair order is the order produced by combinations (earlier item first).
        for tgt, src in itertools.combinations(items, 2):
            fw.write(f"{tgt}\t{src}\n")
            count += 1
print(f"Total {count} edges written")

100%|██████████| 53306/53306 [00:00<00:00, 54380.45it/s]

Total 686851 edges written





In [17]:
# Write one tab-separated edge per unordered pair of items in each validation outfit.
edge_file_valid = f"edges_{data_type}_valid.txt"
count = 0
# The context manager closes the file even if an item key is missing from item_number_dict.
with open(os.path.join(train_dir, edge_file_valid), 'w') as fw:
    for ii in tqdm(range(len(valid_pos))):
        # Translate the outfit-item keys of row ii into integer node ids.
        items = [item_number_dict[k] for k in valid_X[ii]]
        # Pair order is the order produced by combinations (earlier item first).
        for tgt, src in itertools.combinations(items, 2):
            fw.write(f"{tgt}\t{src}\n")
            count += 1
print(f"Total {count} edges written")

100%|██████████| 5000/5000 [00:00<00:00, 54245.76it/s]

Total 64925 edges written





In [18]:
# Write one tab-separated edge per unordered pair of items in each test outfit.
edge_file_test = f"edges_{data_type}_test.txt"
count = 0
# The context manager closes the file even if an item key is missing from item_number_dict.
with open(os.path.join(train_dir, edge_file_test), 'w') as fw:
    for ii in tqdm(range(len(test_pos))):
        # Translate the outfit-item keys of row ii into integer node ids.
        items = [item_number_dict[k] for k in test_X[ii]]
        # Pair order is the order produced by combinations (earlier item first).
        for tgt, src in itertools.combinations(items, 2):
            fw.write(f"{tgt}\t{src}\n")
            count += 1
print(f"Total {count} edges written")

100%|██████████| 10000/10000 [00:00<00:00, 53702.28it/s]

Total 129589 edges written





## Load GraphSage Embeddings

In [19]:
import pickle

# Load the precomputed GraphSAGE embedding matrix.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The rows are indexed below by the integer node ids from item_number_dict.
with open(os.path.join(embed_dir, "graphsage_polyvore_nondisjoint.pkl"), "rb") as fr:
    gs_embed = pickle.load(fr)
gs_embed.shape

(365054, 256)

In [20]:
# Invert item_number_dict: integer node id (GraphSAGE row number) -> outfit-item key.
# Example: id2item[365053] = '209553625_5',
# and then item_dict['209553625_5'] = '172852191' (the original item-id).
# NOTE: keep this as an explicit for-loop. The following cells read the loop
# variables `key` and `value` left behind by the last iteration.
id2item = {}
for key, value in item_number_dict.items():
    id2item[value] = key


In [21]:
# Show the last (outfit-item key, node id) pair.
# NOTE(review): relies on loop variables leaked from the id2item loop in the previous cell.
key, value

('209553625_5', 365053)

In [22]:
# Original item-id for that outfit-item key (`key` is the leaked loop variable from above).
item_dict[key]

'172852191'

Create a dictionary of item embeddings

In [23]:
# Re-key the GraphSAGE rows by original item-id:
#   row number -> outfit-item key (id2item) -> original item-id (item_dict).
# If several rows resolve to the same item-id, the last row wins.
graphsage_dict = {
    item_dict[id2item[row]]: gs_embed[row]
    for row in range(gs_embed.shape[0])
}

# Persist the re-keyed embeddings next to the other precomputed files.
graphsage_dict_path = os.path.join(embed_dir, "graphsage_dict_polyvore_nondisjoint.pkl")
with open(graphsage_dict_path, "wb") as output_file:
    pickle.dump(graphsage_dict, output_file)

In [4]:
import networkx as nx
import pandas as pd
import numpy as np
import os
import pickle
import random

import stellargraph as sg
from stellargraph import StellarGraph
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import GraphSAGELinkGenerator
from stellargraph.mapper import GraphSAGENodeGenerator
from stellargraph.layer import GraphSAGE, link_classification
from stellargraph.data import UniformRandomWalk
from stellargraph.data import UnsupervisedSampler
from sklearn.model_selection import train_test_split

from tensorflow import keras
from sklearn import preprocessing, feature_extraction, model_selection
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.metrics import accuracy_score

from stellargraph import globalvar

from stellargraph import datasets
from IPython.display import display, HTML

In [5]:
# Read the edge list: a headerless, tab-separated file with two columns per row.
# NOTE(review): this reads 'edges.txt', whereas the cells above write
# edges_{data_type}_{split}.txt - confirm which file is intended.
_edge_columns = ["target", "source"]
_edges_path = os.path.join(train_dir, 'edges.txt')
edges = pd.read_csv(_edges_path, sep="\t", header=None, names=_edge_columns)[_edge_columns]
edges

Unnamed: 0,target,source
0,1,0
1,2,1
2,3,2
3,4,3
4,5,4
...,...,...
68935,85929,85928
68936,85930,85929
68937,85931,85930
68938,85932,85931


In [6]:
# Read the node table: a headerless, tab-separated file laid out as
#   id, X0 ... X{num_features-1}, label
# NOTE(review): this reads 'nodes.txt', whereas the cell above writes
# nodes_{data_type}.txt - confirm which file is intended.
num_features = 1280
feature_names = ["X" + str(idx) for idx in range(num_features)]

_node_columns = ["id"] + feature_names + ["label"]
_nodes_path = os.path.join(train_dir, 'nodes.txt')
raw_content = pd.read_csv(_nodes_path, sep="\t", header=None, names=_node_columns)
raw_content

Unnamed: 0,id,X0,X1,X2,X3,X4,X5,X6,X7,X8,...,X1271,X1272,X1273,X1274,X1275,X1276,X1277,X1278,X1279,label
0,0,-0.062523,0.421099,-0.124768,-0.097375,-0.216441,0.169326,0.057159,-0.034298,0.431150,...,-0.011671,0.050762,-0.204316,1.638582,-0.008674,-0.030689,0.126104,-0.081300,-0.033844,148
1,1,-0.113803,0.108486,0.939266,0.255427,0.002270,-0.083422,-0.054675,0.147102,-0.133714,...,-0.070046,-0.167655,-0.092253,-0.040769,-0.103261,0.121185,-0.103870,0.386899,-0.035554,33
2,2,1.412980,-0.208100,0.713474,0.632122,-0.150961,0.043015,-0.116357,-0.071905,0.387462,...,-0.148249,0.384419,0.134001,0.348973,-0.111628,0.299475,-0.156931,-0.149539,-0.045462,83
3,3,0.550926,-0.192544,-0.174773,-0.017669,-0.130582,-0.190628,-0.139337,-0.100167,0.050681,...,-0.074378,-0.112425,0.293040,0.790908,-0.222358,-0.167452,-0.106885,-0.157713,0.116547,69
4,4,0.804715,-0.180687,-0.110375,1.339518,-0.122575,0.015235,-0.144777,-0.089144,-0.178694,...,-0.133782,-0.067391,-0.144996,0.961823,0.545530,1.131453,-0.156405,-0.135552,-0.169803,84
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85930,85930,-0.113226,-0.118045,-0.222307,-0.168248,-0.160915,-0.168635,-0.141841,0.052695,-0.166531,...,-0.086500,-0.179051,0.120281,-0.042114,-0.130227,-0.138519,-0.002534,0.526933,0.029293,32
85931,85931,-0.051178,-0.035326,-0.167132,-0.085977,-0.180994,1.791017,-0.138490,-0.042086,-0.166632,...,-0.140853,0.012947,-0.164001,0.382978,-0.088951,-0.089313,2.180066,-0.143095,-0.067517,144
85932,85932,0.469863,0.312989,0.850873,0.119212,-0.089325,-0.099193,-0.133834,1.077684,0.220273,...,-0.061910,-0.124991,-0.149757,1.123994,0.141367,0.225086,-0.125461,0.035168,-0.145158,75
85933,85933,0.302072,-0.124257,-0.031438,0.091247,0.318073,0.046275,0.037056,-0.079065,-0.092398,...,0.023317,-0.092640,-0.128483,0.627211,0.042906,0.749422,-0.143885,0.043744,-0.059547,46


In [83]:
# # filter for rare classes - causes problem
# from collections import Counter

# node_subjects = raw_content["label"]
# retain_classes = [ii[0] for ii in Counter(node_subjects).most_common(100)]

# raw_content = raw_content[raw_content.label.isin(retain_classes)]
# raw_content.shape

In [87]:
# Build the StellarGraph: index the node table by node id, split off the label
# column, and pass the remaining feature columns as node type "items" together
# with the edge list as edge type "related".
content_str_subject = raw_content.set_index("id")
content_no_subject = content_str_subject.drop(columns="label")
G = StellarGraph({"items": content_no_subject}, {"related": edges})
# Category label per node id, kept separately from the graph features.
labels = content_str_subject["label"].copy()
print(G.info())

StellarGraph: Undirected multigraph
 Nodes: 85935, Edges: 68940

 Node types:
  items: [85935]
    Features: float32 vector, length 1280
    Edge types: items-related->items

 Edge types:
    items-related->items: [68940]
        Weights: all 1 (default)
        Features: none


In [75]:
from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.layer import GAT
from tensorflow.keras import layers, optimizers, losses, metrics, Model
from sklearn import preprocessing, feature_extraction, model_selection

In [None]:
# Split the labelled nodes of G into train / validation / test subsets.
# `labels` (one label per node id of G) is used as the source. The previous
# `node_subjects` is not defined at this point in a top-to-bottom run; it is
# only assigned later, where it holds labels of the Cora example graph.
# NOTE(review): stratify requires at least 2 samples per class in each split;
# very rare categories may need to be filtered or merged first - confirm.
train_subjects, test_subjects = model_selection.train_test_split(
    labels, train_size=60000, test_size=None, stratify=labels
)
# Carve the validation set out of the remainder; what is left is the test set.
val_subjects, test_subjects = model_selection.train_test_split(
    test_subjects, train_size=10000, test_size=None, stratify=test_subjects
)

In [88]:
# Full-batch node generator over G, configured with method="gat".
generator = FullBatchNodeGenerator(G, method="gat")

In [66]:
# Reference example: load the Cora dataset bundled with stellargraph.
# NOTE: this rebinds `dataset` (used earlier for the spektral Cora object) and
# assigns `node_subjects` to the labels returned for Cora, not for the Polyvore graph G.
dataset = datasets.Cora()
display(HTML(dataset.description))
Gc, node_subjects = dataset.load()

In [73]:
# Print the structure of the Cora reference graph, for comparison with G.info() above.
print(Gc.info())

StellarGraph: Undirected multigraph
 Nodes: 2708, Edges: 5429

 Node types:
  paper: [2708]
    Features: float32 vector, length 1433
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5429]
        Weights: all 1 (default)
        Features: none
