In [1]:
import torch
from PIL import Image
from tqdm.auto import tqdm
import numpy as np
import os
from matplotlib import pyplot as plt
import seaborn as sns
# Apply seaborn's "darkgrid" plotting style.
sns.set_style("darkgrid")

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
import json

# data_dir holds the Polyvore metadata JSON; out_dir is the directory that is
# listed further down to find the per-item files.
data_dir = "data/polyvore/"
out_dir = "data/Re-PolyVore/all_together/"

# Parse the training outfits from the metadata file.
with open(f"{data_dir}train_no_dup.json", "r", encoding="utf-8") as metadata_file:
    outfits = json.load(metadata_file)

In [3]:
# Index every outfit by its set_id, remembering its position in `outfits` so the
# original item metadata can be looked up again below.
outfits_dict = {outfit["set_id"]: {"ind": ind} for ind, outfit in enumerate(outfits)}


# Attach to each outfit the files found in `out_dir`. File names are parsed as
# "<set_id>_<number>.<ext>", where <number> is used as a 1-based position in the
# outfit's "items" list.
for item in tqdm(os.listdir(out_dir)):
    outfit, sep, suff = item.partition("_")
    if not sep:
        # No underscore: the name does not follow the expected pattern.
        continue

    try:
        num = int(suff.split(".")[0])
    except ValueError:
        # Non-numeric suffix: skip only this file instead of hiding every
        # possible error behind a bare `except`.
        continue

    entry = outfits_dict.get(outfit)
    if entry is None:
        # File belongs to an outfit that is not part of the loaded metadata.
        continue

    source_items = outfits[entry["ind"]]["items"]
    if not 1 <= num <= len(source_items):
        # Out-of-range numbers are skipped. In particular num <= 0 must not be
        # allowed to wrap around to the end of the list via negative indexing.
        continue

    entry.setdefault("items", []).append(source_items[num - 1] | {"path": out_dir + item})

  0%|          | 0/126928 [00:00<?, ?it/s]

In [4]:
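# Disabled one-off export of `outfits_dict` as built above. Later cells reload this same file and
# add embeddings to it, so re-running this dump would overwrite any embeddings already saved there.
# with open("data/Re-PolyVore_encoded_with_OutfitTransformer.json", "w", encoding="utf-8") as f:
#     json.dump(outfits_dict, f, indent=4)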

In [5]:
# Checkpoint whose 'encoder_state_dict' entry is loaded into the item encoder below.
weights_path = "OutfitTransformer/checkpoints/2_0.923.pth"

# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint saved from a GPU run still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
outfit_transformer_weights = torch.load(weights_path, map_location=device)

In [6]:
from OutfitTransformer.model.encoder import ItemEncoder
from OutfitTransformer.model.model import OutfitTransformer
from transformers import AutoTokenizer

# Build the item encoder with a 128-dimensional embedding, move it to the
# selected device, then restore the encoder weights from the loaded checkpoint.
encoder_state = outfit_transformer_weights['encoder_state_dict']
model = ItemEncoder(embedding_dim=128)
model = model.to(device)
model.load_state_dict(encoder_state)

# Tokenizer used for the item names that are fed to the encoder.
tokenizer_name = 'sentence-transformers/paraphrase-albert-small-v2'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)



In [7]:
encoded_path = "data/Re-PolyVore_encoded_with_OutfitTransformer.json"

# Resume from the previously saved file when there is one, so outfits already
# marked as processed are not encoded again. On a fresh run the file does not
# exist yet; in that case keep the index built earlier in the notebook instead
# of failing with FileNotFoundError.
if os.path.exists(encoded_path):
    with open(encoded_path, "r", encoding="utf-8") as f:
        outfits_dict = json.load(f)
else:
    print(f"{encoded_path} not found; continuing with the in-memory outfits_dict")

In [8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
img_size = 224

# Preprocessing pipeline: resize to img_size x img_size, normalise with
# A.Normalize's default parameters, and convert to a tensor. It is built once
# here instead of on every load_img call, since it does not depend on the image.
_img_transform = A.Compose([A.Resize(img_size, img_size),
                            A.Normalize(),
                            ToTensorV2()])


def load_img(path):
    """Read the image at `path` and return it as a preprocessed tensor.

    The file is read with OpenCV, converted from BGR to RGB channel order and
    passed through `_img_transform`.

    Args:
        path: Location of the image file on disk.

    Returns:
        The transformed image (the 'image' entry of the albumentations result).

    Raises:
        OSError: If OpenCV cannot read the file (missing or undecodable).
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        raise OSError(f"Could not read image (missing or undecodable file): {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return _img_transform(image=img)['image']


def _save_encoded_outfits(data, path="data/Re-PolyVore_encoded_with_OutfitTransformer.json"):
    """Serialise `data` to `path` as indented JSON via a temporary file.

    The data is written to "<path>.tmp" first and then moved over `path`, so an
    interruption during the (repeated) checkpointing leaves the previous file
    intact instead of a truncated JSON document.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, path)


# Encode every not-yet-processed outfit: each item gets an "embedding" list and
# the outfit is flagged with "processed" so a later run can skip it.
model.eval()
with torch.no_grad():
    unsaved = False  # True while there are embeddings not yet written to disk

    for ind, (set_id, outfit) in enumerate(tqdm(outfits_dict.items())):
        items = outfit.get("items")

        # Skip outfits that are already encoded or have no matched item files
        # (an empty list would make torch.stack fail).
        if outfit.get("processed") is None and items:
            # One batch per outfit: all item images stacked along dim 0.
            images = torch.stack([load_img(item["path"]) for item in items], 0)

            encoded = tokenizer([item['name'] for item in items],
                                max_length=16,
                                padding='max_length',
                                truncation=True,
                                return_tensors='pt')
            # Read the tensors by key rather than by position in .values(),
            # which would silently break if the tokenizer's output order changed.
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]

            embeds = model(images.to(device), input_ids.to(device), attention_mask.to(device))

            # Presumably one embedding row per item, in input order -- TODO confirm
            # against ItemEncoder.
            for item, e in zip(items, embeds):
                item["embedding"] = e.cpu().tolist()

            outfit["processed"] = "True"
            unsaved = True

        # Checkpoint every 500 outfits, but only when something new was encoded,
        # so resuming over already-processed outfits does not rewrite the file.
        if (ind + 1) % 500 == 0 and unsaved:
            _save_encoded_outfits(outfits_dict)
            unsaved = False

    # Final write so the file always reflects the complete result.
    _save_encoded_outfits(outfits_dict)

  0%|          | 0/17316 [00:00<?, ?it/s]