In [1]:
from types import SimpleNamespace
import torch
import os
import json
from torchvision import transforms
import Resnet_18
from polyvore_outfits import TripletImageLoader
from tripletnet import Tripletnet
from type_specific_network import TypeSpecificNet
import pickle
import torch.nn.functional as F

# Index of the GPU that holds the model and the cached embeddings.
gpu = 1

# Root directory of the thesis data; every other path below is derived from it.
DATA_DIR = "/home/fteotini/thesis"

# Configuration object handed to TripletImageLoader, TypeSpecificNet and
# Tripletnet in the cells below. (A module-level `global` statement is a no-op
# in a notebook, so none is needed here.)
args = SimpleNamespace(
    batch_size=256,
    seed=1,
    cuda=torch.device(f"cuda:{gpu}"),
    dim_embed=64,
    use_fc=True,
    datadir=DATA_DIR,
    margin=0.3,
    resume=os.path.join(DATA_DIR, "type_aware", "runs", "nondis", "model_best.pth.tar"),
    polyvore_split="nondisjoint",
    rand_typespaces=False,
    num_rand_embed=4,
    learned=False,
    prein=False,
    l2_embed=False,
    learned_metric=False,
)

# torch.manual_seed seeds all devices already. torch.cuda.manual_seed would seed
# only the *current* CUDA device (cuda:0 unless changed), not args.cuda, so
# seed every GPU explicitly instead.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

In [2]:
# Per-item metadata for Polyvore Outfits, required by TripletImageLoader.
metadata_path = os.path.join(args.datadir, "polyvore_outfits", "polyvore_item_metadata.json")
# JSON is UTF-8 by specification; don't depend on the platform's default encoding.
with open(metadata_path, "r", encoding="utf-8") as metadata_file:
    meta_data = json.load(metadata_file)

# Evaluation preprocessing: resize the shorter edge to 112, center-crop to
# 112x112, convert to a tensor and normalize with the usual ImageNet channel
# mean/std.
transform = transforms.Compose([
    transforms.Resize(112),
    transforms.CenterCrop(112),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# "test" presumably selects the test split of args.polyvore_split — confirm in
# polyvore_outfits.TripletImageLoader.
dataset = TripletImageLoader(
    args,
    "test",
    meta_data,
    transform=transform
)

In [3]:
# Backbone: pretrained ResNet-18 producing args.dim_embed-sized embeddings.
model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)

# Type-specific wrapper, sized by the number of typespaces the dataset defines.
num_typespaces = len(dataset.typespaces)
csn_model = TypeSpecificNet(args, model, num_typespaces)

# Tripletnet's constructor takes a loss criterion; afterwards this notebook only
# calls tnet.embeddingnet.
criterion = torch.nn.MarginRankingLoss(margin=args.margin)

# NOTE(review): 6000 is an unnamed positional Tripletnet argument — presumably a
# feature dimension; confirm against tripletnet.py.
tnet = Tripletnet(args, csn_model, 6000, criterion)
tnet = tnet.cuda(args.cuda)

In [4]:
# Restore the trained weights. map_location puts every tensor directly on
# args.cuda rather than on whichever device it was saved from (which might be a
# different GPU, or one that does not exist on this machine).
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(args.resume, map_location=args.cuda)
args.start_epoch = checkpoint["epoch"]
best_acc = checkpoint["best_prec1"]
tnet.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [5]:
# shuffle=False keeps batches in dataset order. The later lookups via
# dataset.im2index presumably rely on row i of the embeddings being dataset
# item i — TODO confirm.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=20,
    pin_memory=True,
)

In [6]:
# Inference mode for the network; gradients are not needed while embedding.
tnet.eval()

# Embed every test image and concatenate the batches along dim 0, so rows
# follow the dataloader's order. The per-row layout (indexed later by a
# typespace condition) is assumed from the evaluation cell — TODO confirm shape.
with torch.no_grad():
    embeddings = torch.cat(
        [tnet.embeddingnet(batch.cuda(args.cuda)) for batch in dataloader]
    )

In [7]:
# Pre-built evaluation questions; each record is read in the next cell through
# the keys 'src', 'ground_truth' and 'negs'.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open("sip_data_nondisjoint.pkl", "rb") as f:
    sip_data = pickle.load(f)

In [8]:
def _compatibility_distance(context_items, candidate):
    """Sum of type-conditioned L2 distances from `candidate` to each context item.

    For every item in `context_items`, the typespace is chosen from the pair
    (item type, candidate type), and both embeddings are sliced with that
    condition before taking the pairwise distance. Returns a 0-dim tensor.
    """
    candidate_type = dataset.im2type[candidate]
    total = 0.0
    for item in context_items:
        item_type = dataset.im2type[item]
        condition = dataset.get_typespace(item_type, candidate_type)
        candidate_vec = embeddings[dataset.im2index[candidate]][condition].unsqueeze(0)
        item_vec = embeddings[dataset.im2index[item]][condition].unsqueeze(0)
        total += F.pairwise_distance(candidate_vec, item_vec, 2)
    return total.squeeze()


# One boolean tensor per question: True when the ground-truth item has the
# lowest summed distance among its candidates.
results = []

for outfit in sip_data:
    questions = outfit['src']
    positives = outfit['ground_truth']
    negatives = outfit['negs']
    for step in range(len(questions)):
        # Context grows by one item per step: questions[0..step].
        context = questions[:step + 1]
        candidates = [positives[step], *negatives[step]]
        scores = torch.zeros(len(candidates), dtype=torch.float)
        for slot, candidate in enumerate(candidates):
            scores[slot] = _compatibility_distance(context, candidate)
        # The positive sits in slot 0. argmin returns the first minimal index,
        # so exact ties are counted as a hit.
        results.append(torch.argmin(scores) == 0)

In [9]:
# Accuracy: fraction of questions where the ground-truth candidate ranked first.
torch.stack(results).float().mean()

tensor(0.4196)