In [None]:
import sys

# Make the repository root (one level up) importable for `shared.*` and `vsepp.*`.
# Guarded so re-running this cell does not keep appending duplicate entries.
# NOTE(review): "../" is resolved relative to the kernel's working directory.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
from shared.evaluation import utils
from vsepp.model import VSE
from vsepp.vocab import Vocabulary
import os
import torch
import pickle
from vsepp.data import get_transform, get_loader_single, collate_fn, get_paths
from torch.autograd import Variable
from shared.losses import Triplet, NTXent, SmoothAP
from shared.evaluation.utils import GradStats
from collections import defaultdict
import numpy as np

In [None]:
# Location of the trained checkpoint to analyse.
# Set the REPO_DIR environment variable (or edit the default) to the repository root;
# EXPERIMENT is the sub-directory under paper_experiments/ that is loaded.
REPO_DIR = os.environ.get("REPO_DIR", "{path to repo}")
EXPERIMENT = "ntxent"
model_path = '%s/out/vsepp/out/f30k/paper_experiments/%s/0/model_best.pth.tar' % (REPO_DIR, EXPERIMENT)

In [None]:
# Load the checkpoint: it holds the training options under 'opt' and the weights under 'model'.
# map_location="cpu" lets a checkpoint saved on GPU be opened on a CPU-only machine.
# NOTE(review): assumes VSE.load_state_dict copies tensors into the model's existing parameters
# (as nn.Module.load_state_dict does), so CPU-loaded weights end up on the model's device — confirm.
# NOTE(review): torch.load unpickles arbitrary objects (e.g. the `opt` namespace); only load
# trusted checkpoints. Newer torch versions may additionally need weights_only=False — confirm.
checkpoint = torch.load(model_path, map_location="cpu")
opt = checkpoint['opt']

# Load the vocabulary the model was trained with and record its size on `opt`
# before constructing the model (presumably it sizes the word embedding).
# NOTE(review): pickle.load can execute code from the file; only use trusted vocab files.
vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
with open(vocab_file, 'rb') as f:
    vocab = pickle.load(f)
opt.vocab_size = len(vocab)

# Build the model from the training options, then restore its trained weights.
model = VSE(opt)
model.load_state_dict(checkpoint['model'])

In [None]:
# Loader over the *training* split, preprocessed with the 'val' transform.
# NOTE(review): the 'val' transform on 'train' looks deliberate (presumably no random
# augmentation while measuring gradient statistics) — confirm.
transform = get_transform(opt.data_name, 'val', opt)
files, ids = get_paths(opt.data_path, opt.data_name, opt.use_restval)
data_loader = get_loader_single(
    opt.data_name,
    'train',
    files['train'],
    vocab,
    transform,
    ids=ids['train'],
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.workers,
    collate_fn=collate_fn,
    ranking_based=opt.ranking_based,
    n_sp=opt.n_sp,
)

In [None]:
# Accumulate gradient statistics of one loss over every batch of the training loader.
LOSS_NAME = "ntxent"   # which loss to analyse: "triplet", "ntxent" or "smoothap"
GRAD_THRESHOLD = 0.01  # passed to NTXent as `threshold` (semantics defined in shared.losses)
SEED = 0               # fixes the shuffled batch order so the statistics are reproducible

if LOSS_NAME not in ("triplet", "ntxent", "smoothap"):
    raise ValueError("unknown LOSS_NAME: %r" % LOSS_NAME)

model.val_start()

triplet = Triplet(count_grads=True, margin=opt.margin, max_violation=opt.max_violation)
ntxent = NTXent(tau=opt.tau)
smoothap = SmoothAP()
stats = GradStats()

# The loader shuffles; seed the global torch RNG right before iterating.
# NOTE(review): assumes the loader's sampler draws from the default torch generator — confirm.
torch.manual_seed(SEED)

# Batch ids are discarded on purpose: binding them to `ids` would clobber the `ids` dict
# built in the data-loader cell and break re-running that cell.
for images, captions, lengths, _batch_ids in data_loader:
    img_emb, cap_emb = model.forward_emb(images, captions, lengths)

    if LOSS_NAME == "triplet":
        stats.add_stats(triplet(img_emb, cap_emb))
    elif LOSS_NAME == "ntxent":
        stats.add_stats(ntxent(img_emb, cap_emb, count_gradients=True, threshold=GRAD_THRESHOLD))
    else:  # "smoothap" returns separate image->text and text->image statistics
        i2t, t2i = smoothap(img_emb, cap_emb, count_gradients=True)
        stats.add_stats(i2t)
        stats.add_stats(t2i)

stats.print_stats()

In [None]:
# Print each collected statistic as "mean $\pm$ std" over batches, in LaTeX-ready form.
# Raw string: "\p" is an invalid escape in a normal literal (DeprecationWarning / SyntaxWarning).
# NOTE(review): assumes each value in stats.data is a sequence of per-batch numbers — confirm.
for key, values in stats.data.items():
    print(r"%s: %.2f $\pm$  %.2f" % (key, np.mean(values), np.std(values)))

In [None]:
1-s_i2t.diag.mean: 0.56 $\pm$  0.02
1-s_t2i.diag.mean: 0.56 $\pm$  0.02
i2t_low_grad: 0.14 $\pm$  0.01
i2t_high_grad: 0.42 $\pm$  0.02
i2t_n_high_grad: 9.88 $\pm$  0.53
t2i_low_grad: 0.14 $\pm$  0.01
t2i_high_grad: 0.42 $\pm$  0.02
t2i_n_high_grad: 9.65 $\pm$  0.51