# Few-Shot Learning 

Apply the few-shot learning paradigm to learn outfit compatibility from only a few examples.

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# local modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import json
import math
import numpy as np
import pickle
import random
import time
from tqdm import tqdm
from datetime import datetime
from prettytable import PrettyTable

import torch
import torch.utils.data as torch_data
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

from tensorflow.random import set_seed
from numpy.random import seed
from numpy.random import default_rng

import shap

%matplotlib inline

In [3]:
# --- Dataset locations ---------------------------------------------------------
base_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/polyvore_outfits"
embed_dir = "/recsys_data/RecSys/fashion/polyvore-dataset/precomputed"
data_type = "nondisjoint"  # one of: "nondisjoint", "disjoint"
train_dir, image_dir = (os.path.join(base_dir, sub) for sub in (data_type, "images"))

# --- Per-split outfit files (JSON) ---------------------------------------------
train_json, valid_json, test_json = "train.json", "valid.json", "test.json"

# --- Per-split labelled compatibility files (plain text) -----------------------
train_file, valid_file, test_file = (
    "compatibility_train.txt",
    "compatibility_valid.txt",
    "compatibility_test.txt",
)

# --- Metadata files stored directly under base_dir -----------------------------
item_file, outfit_file = "polyvore_item_metadata.json", "polyvore_outfit_titles.json"

# --- Initial model settings ----------------------------------------------------
# These three names are assigned again in a later configuration cell.
model_type = "rnn"  # alternative: "set-transformer"
include_text = True
batch_size = 32

Read all the required files

In [4]:
def _read_json(directory, filename):
    """Return the parsed contents of the JSON file `filename` inside `directory`."""
    with open(os.path.join(directory, filename), 'r') as handle:
        return json.load(handle)


# Outfit lists of the three splits live in the split directory.
train_pos = _read_json(train_dir, train_json)
valid_pos = _read_json(train_dir, valid_json)
test_pos = _read_json(train_dir, test_json)

# Item and outfit metadata are read from the dataset's base directory.
pv_items = _read_json(base_dir, item_file)
pv_outfits = _read_json(base_dir, outfit_file)

print(f"Total {len(train_pos)}, {len(valid_pos)}, {len(test_pos)} outfits in train, validation and test split, respectively")

Total 53306, 5000, 10000 outfits in train, validation and test split, respectively


In [5]:
def read_compatibility_file(path):
    """Parse a compatibility file into parallel lists of examples and labels.

    Every non-blank line is split on whitespace: the first token is kept as the
    label (as a string) and the remaining tokens as that example's item tokens.

    Args:
        path: location of the text file to read.

    Returns:
        (examples, labels) where `examples[i]` is the list of item tokens of the
        i-th non-blank line and `labels[i]` is the label token of that line.
    """
    examples, labels = [], []
    with open(path, 'r') as fr:
        for line in fr:
            elems = line.split()
            # A blank or whitespace-only line (e.g. a trailing newline at the
            # end of the file) has no tokens and would raise IndexError below.
            if not elems:
                continue
            # Labels stay strings: get_balanced_samples compares them to '0'/'1'.
            labels.append(elems[0])
            examples.append(elems[1:])
    return examples, labels


train_X, train_y = read_compatibility_file(os.path.join(train_dir, train_file))
valid_X, valid_y = read_compatibility_file(os.path.join(train_dir, valid_file))
test_X, test_y = read_compatibility_file(os.path.join(train_dir, test_file))

print(f"Total {len(train_X)}, {len(valid_X)}, {len(test_X)} examples in train, validation and test split, respectively")

Total 106612, 10000, 20000 examples in train, validation and test split, respectively


In [6]:
# Map every item token used in the compatibility files to the `item_id` stored
# in the outfit JSON, accumulating over the train, validation and test splits.
# NOTE(review): example `position` of the compatibility file is paired with
# outfit `position` of the JSON, and zip() stops at the shorter of the two item
# lists without complaint -- confirm that both files are aligned this way.
item_dict = {}
for split_outfits, split_examples in (
    (train_pos, train_X),
    (valid_pos, valid_X),
    (test_pos, test_X),
):
    for position, outfit in enumerate(split_outfits):
        tokens = split_examples[position]
        for token, item in zip(tokens, outfit['items']):
            item_dict[token] = item['item_id']
    # Running size of the mapping after each split has been merged in.
    print(len(item_dict))

284767
311548
365054


In [7]:
# --- Model / input configuration -----------------------------------------------
# model_type, include_text and batch_size were already assigned in the path
# configuration cell; the values below are the ones in effect from here on.
model_type = "transformer"  # alternatives: "set-transformer", "rnn"
include_text = True
include_item_categories = True
use_graphsage = False
batch_size = 32
max_seq_len = 8
d_model_rnn = 512
image_data_type = "embedding"  # alternatives: "original", "both"
image_encoder = "resnet18"  # alternatives: "resnet50", "vgg16", "inception"

# --- Precomputed embeddings ------------------------------------------------------
# Alternatives the author left commented out:
#   with GraphSAGE:    (256, "graphsage_dict2_polyvore_nondisjoint.pkl")
#   without GraphSAGE: (256, "triplet_polyvore_image.pkl")
image_embedding_dim = 50 if use_graphsage else 1280
image_embedding_file = os.path.join(
    embed_dir,
    "graphsage_dict2_polyvore.pkl" if use_graphsage else "effnet_tuned_polyvore.pkl",
)
text_embedding_dim = 768
text_embedding_file = os.path.join(embed_dir, "bert_polyvore.pkl")

# --- Few-shot episode and optimisation settings ---------------------------------
num_support = 100
num_query = 100
num_episodes = 1
num_train_epochs = 20000
num_test_samples = 1000
learning_rate = 1e-05

In [8]:
def count_parameters(model):
    """Print a table of trainable parameter counts and return their sum.

    Only parameters whose `requires_grad` flag is set are listed and counted;
    each row holds the parameter's name and its number of elements.
    """
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])

    total_params = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


In [9]:
from few_shot_models import SimpleProtoTypeModel, proto_loss, get_prototypes
from torch.optim import AdamW
from tqdm import tqdm

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# The model only returns embeddings; the class prototypes are computed from
# those embeddings in a later cell.
model_config = dict(
    num_layers=1,
    d_model=64,
    num_heads=4,
    dff=32,
    rate=0.0,
    image_data_type=image_data_type,
    include_text=include_text,
    include_item_categories=include_item_categories,
    num_categories=154,
    embedding_activation="linear",
    encoder_activation="relu",
    embedding_dim=64,
    max_seq_len=max_seq_len,
    num_classes=2,
    use_rnn=False,
    device=device,
)
model = SimpleProtoTypeModel(**model_config)
model.to(device)

# Last expression of the cell: prints the parameter table and shows the total.
count_parameters(model)

Device: cuda
+--------------------------------------------+------------+
|                  Modules                   | Parameters |
+--------------------------------------------+------------+
|          image_projector.0.weight          |   81920    |
|           image_projector.0.bias           |     64     |
|          text_projector.0.weight           |   49152    |
|           text_projector.0.bias            |     64     |
|          category_embedder.weight          |    9856    |
| encoder.layers.0.self_attn.in_proj_weight  |   110592   |
|  encoder.layers.0.self_attn.in_proj_bias   |    576     |
| encoder.layers.0.self_attn.out_proj.weight |   36864    |
|  encoder.layers.0.self_attn.out_proj.bias  |    192     |
|      encoder.layers.0.linear1.weight       |    6144    |
|       encoder.layers.0.linear1.bias        |     32     |
|      encoder.layers.0.linear2.weight       |    6144    |
|       encoder.layers.0.linear2.bias        |    192     |
|       encoder.layers.0.no

401312

Select `num_support` outfits to learn two prototypes (one for class 0 and one for class 1), then evaluate them on `num_query` outfits.

In [10]:
def get_balanced_samples(x, y, num_samples, generator=None):
    """Draw a class-balanced, interleaved sample of example indices.

    Picks `num_samples // 2` distinct indices whose label equals '0' and the
    same number whose label equals '1', and returns them alternating
    0, 1, 0, 1, ... so that any even-length prefix is itself balanced.

    Args:
        x: the examples; not used, kept so existing calls keep working.
        y: sequence of labels; entries are compared with the strings '0' and '1'.
        num_samples: requested sample size. An odd value yields
            `num_samples - 1` indices because each class contributes
            `num_samples // 2`.
        generator: optional `numpy.random.Generator` to draw from. When omitted
            the notebook-level `rng` is used, exactly as before.

    Returns:
        A list of plain `int` indices into `y`.

    Raises:
        ValueError: if either class has fewer than `num_samples // 2` examples.
    """
    if generator is None:
        # Notebook-level generator; it has to be defined before the first call.
        generator = rng
    per_class = num_samples // 2

    index_0 = [ii for ii, label in enumerate(y) if label == '0']
    index_1 = [ii for ii, label in enumerate(y) if label == '1']

    # Fail with a message that names the short class instead of relying on the
    # generic error raised by Generator.choice.
    for label, pool in (('0', index_0), ('1', index_1)):
        if len(pool) < per_class:
            raise ValueError(
                f"need {per_class} examples of class {label!r} "
                f"but only {len(pool)} are available"
            )

    # Class 0 is drawn first, then class 1, matching the original draw order.
    sample_0 = generator.choice(index_0, size=per_class, replace=False)
    sample_1 = generator.choice(index_1, size=per_class, replace=False)

    # interleave so that both query and support have the same distribution
    indices = []
    for idx_0, idx_1 in zip(sample_0, sample_1):
        indices.append(int(idx_0))
        indices.append(int(idx_1))
    return indices

In [11]:
from utils_torch import CustomDataset

# Seed the generator used by get_balanced_samples so that the validation subset
# (and the training episode drawn later from the same generator) is identical
# on every Restart & Run All.
# NOTE(review): torch is not seeded here, so model initialisation and the
# RandomSampler order used during training remain non-deterministic.
sampling_seed = 42
rng = default_rng(sampling_seed)

# Balanced subset of the validation split (the `tst_` prefix notwithstanding).
sample_indices = get_balanced_samples(valid_X, valid_y, num_test_samples)
tst_x = [valid_X[ii] for ii in sample_indices]
tst_y = [valid_y[ii] for ii in sample_indices]

valid_set = CustomDataset(tst_x,
                          tst_y,
                          item_dict,
                          pv_items,
                          image_dir=image_dir,
                          batch_size=batch_size,
                          max_len=max_seq_len,
                          only_image=not include_text,
                          image_embedding_dim=image_embedding_dim,
                          image_embedding_file=image_embedding_file,
                          text_embedding_file=text_embedding_file,
                          number_items_in_batch=150,
                          variable_length_input=True,
                          text_embedding_dim=text_embedding_dim,
                          include_item_categories=include_item_categories,
                          image_data=image_data_type,
                          input_size=(3, 224, 224),
                         )

# Evaluation walks the subset in order, 256 examples per batch.
eval_sampler = SequentialSampler(valid_set)
eval_dataloader = DataLoader(valid_set,
                             sampler=eval_sampler,
                             batch_size=256)

In [12]:
# Optimise only the parameters that require gradients.
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
global_step = 0
num_samples = num_support + num_query

for tn in range(num_episodes):
    # Draw a balanced episode of num_support + num_query training outfits.
    sample_indices = get_balanced_samples(train_X, train_y, num_samples)
    tr_x = [train_X[ii] for ii in sample_indices]
    tr_y = [train_y[ii] for ii in sample_indices]

    train_set = CustomDataset(tr_x, tr_y,
                              item_dict,
                              pv_items,
                              image_dir=image_dir,
                              batch_size=num_samples,
                              max_len=max_seq_len,
                              only_image=not include_text,
                              image_embedding_dim=image_embedding_dim,
                              image_embedding_file=image_embedding_file,
                              text_embedding_file=text_embedding_file,
                              number_items_in_batch=150,
                              variable_length_input=True,
                              text_embedding_dim=text_embedding_dim,
                              include_item_categories=include_item_categories,
                              image_data=image_data_type,
                              input_size=(3, 224, 224),
                             )

    all_sampler = RandomSampler(train_set)
    all_dataloader = DataLoader(train_set,
                                sampler=all_sampler,
                                batch_size=num_samples)

    # Only the first batch is used; with batch_size=num_samples it presumably
    # holds the whole episode -- verify against CustomDataset.__len__.
    batch_x, batch_y = next(iter(all_dataloader))

    # The same batch feeds every optimisation step, so move it to the device
    # once instead of repeating the transfer on each iteration.
    batch_d = [x.to(device) for x in batch_x]
    batch_y_d = batch_y.to(device)

    model.train()
    # train for the current task
    pbar = tqdm(range(num_train_epochs))
    losses = []
    for epoch in pbar:
        # Clear gradients once per step; the optimizer was built over every
        # parameter that requires gradients, so a separate model.zero_grad()
        # adds nothing.
        optimizer.zero_grad()

        outputs = model(batch_d)
        loss = proto_loss(outputs, batch_y_d, num_support)
        pbar.set_description("Loss %g" % loss.item())
        losses.append(loss.item())

        # backward pass to get the gradients
        loss.backward()

        # update
        optimizer.step()
        global_step += 1

    print(f"final loss: {loss.item():.3f}")

# get the prototypes from the support set
# Inference only: switch to eval mode and disable autograd so the prototypes are
# computed without train-time behaviour and without keeping a graph alive.
model.eval()
with torch.no_grad():
    train_outputs = model(batch_d)
    prototypes, classes = get_prototypes(train_outputs, batch_y, num_support)
print(prototypes.shape, classes)

Loss 0.332929: 100%|██████████| 20000/20000 [12:18<00:00, 27.07it/s]


final loss: 0.333
torch.Size([2, 64]) tensor([0., 1.])


In [None]:
# `losses` holds one value per optimisation step of the most recent training
# episode (the list is reset at the start of every episode).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel="optimisation step", ylabel="prototype loss", title="Training loss")
plt.show()

In [13]:
from few_shot_models import evaluate

# Score whatever `eval_dataloader` currently points at (the balanced validation
# subset when the cells are run in order) against the learned prototypes.
subset_metrics = evaluate(model, eval_dataloader, prototypes, device)
subset_metrics

Evaluating: 100%|██████████| 4/4 [00:00<00:00,  5.83it/s]


{'loss': 2.5297014117240906,
 'precision': 0.5099099099099099,
 'recall': 0.566,
 'f1': 0.5364928909952607,
 'auc': 0.504934}

In [14]:
# Options shared with the other CustomDataset instances built in this notebook.
dataset_options = dict(
    image_dir=image_dir,
    batch_size=batch_size,
    max_len=max_seq_len,
    only_image=not include_text,
    image_embedding_dim=image_embedding_dim,
    image_embedding_file=image_embedding_file,
    text_embedding_file=text_embedding_file,
    number_items_in_batch=150,
    variable_length_input=True,
    text_embedding_dim=text_embedding_dim,
    include_item_categories=include_item_categories,
    image_data=image_data_type,
    input_size=(3, 224, 224),
)

# Rebinds valid_set / eval_sampler / eval_dataloader to the complete validation
# split, replacing the balanced subset built earlier.
valid_set = CustomDataset(valid_X, valid_y, item_dict, pv_items, **dataset_options)
eval_sampler = SequentialSampler(valid_set)
eval_dataloader = DataLoader(valid_set, sampler=eval_sampler, batch_size=256)

# Last expression of the cell: the metrics for the full validation split.
evaluate(model, eval_dataloader, prototypes, device)

Evaluating: 100%|██████████| 40/40 [00:06<00:00,  5.91it/s]


{'loss': 2.6250036537647246,
 'precision': 0.5038532110091744,
 'recall': 0.5492,
 'f1': 0.5255502392344498,
 'auc': 0.4996244}