In [1]:
import sys

In [2]:
%load_ext autoreload

In [3]:
%autoreload 1

In [4]:
import argparse
import time
import torch
import torchvision.transforms as transforms
import numpy as np
import os

from misc.dataset import CocoCaptionsRV, Multi30k
from misc.evaluation import eval_recall, eval_recall5
from misc.model import joint_embedding
from misc.utils import collate_fn_padded, collate_fn_cap_index
from torch.utils.data import DataLoader
import torch.utils.data as data

os.environ["CUDA_VISIBLE_DEVICES"]="3"

## Parameters

In [5]:
normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

prepro_val = transforms.Compose([
        transforms.Resize((400, 400)),
        transforms.ToTensor(),
        normalize,
    ])

In [6]:
class arguments:
    def __init__(self, dict):
        self.dict = '/data/m.portaz/'+dict

In [7]:
batch_size = 156

In [8]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

## Test Evaluation

In [8]:
def cosine_sim(A, B):
    """
        Return similarity of each image with each caption
        One line of the output matrix correspond to one image
        Each row correspond to one caption
    """
    img_norm = np.linalg.norm(A, axis=1)
    caps_norm = np.linalg.norm(B, axis=1)
    scores = np.dot(A, B.T)
    norms = np.dot(np.expand_dims(img_norm, 1),np.expand_dims(caps_norm.T, 1).T)
    scores = (scores / norms)
    return scores

In [9]:
def multilingual_recall(imgs, caps, indices, ks=[1,5,10]):
    """
        Compute multingual recall
    """
    imgs = np.vstack(imgs)
    caps = np.vstack(caps)

    scores = -cosine_sim(imgs, caps)
    ranks = np.argsort(np.argsort(scores))
    # scores represent all the similarity between each images and each captions
    recall = {k:0 for k in ks}
    nb_imgs, nb_caps = ranks.shape
    for i in range(nb_imgs):
        for k in ks:
            for j in range(nb_caps):
                if indices[j] == i and ranks[i][j] < k: #if the caption correspond to the image and is ranked less than k
                    recall[k] += 1
                
    scores = np.transpose(scores)
    ranks_caps = np.argsort(np.argsort(scores))
    recall_caps = {k:0 for k in ks}
    print("Caption ranks :", ranks_caps)
    nb_caps, nb_imgs = ranks.shape
    for i in range(nb_caps):
        for k in ks:
            if ranks[indices[i]][i] < k:
                recall_caps[k] += 1
    
    return [recall[k] / imgs.shape[0]*100 for k in ks], [recall_caps[k] / ranks_caps.shape[0]*100 for k in ks]

# Models to evaluate
## with their corresponding dictionnaries

In [10]:
models = []

#english only models
models.append(["best_sota_coco.pth.tar", 
                {'en':"wiki.en.bin"}
              ]) # state of the art model

models.append(["best_w2vec.pth.tar", 
               {'en':"w2vec_model_vec.en.vec"}
              ]) # word2vec model
               
models.append(["best_bivec_coco.pth.tar", 
               {'en':"bivec_model_vec.en-fr.en.vec"}
               ]) # bivec on coco only
               
models.append(["best_correct_en.pth.tar", 
               {'en':"wiki.multi.en.vec"}
              ]) # muse on english


#multilingual models
models.append(["best_bivec_enfr.pth.tar",
               {'en':"bivec_model_vec.en-fr.en.vec",
                'fr':"bivec_model_vec.en-fr.fr.vec"}
            ])

models.append(["best_bivec_de.pth.tar", 
               {'en':"bivec_model_vec.de-en.en.vec",
                'de':"bivec_model_vec.de-en.de.vec"}
              ])
                

models.append(["best_correct_enfr.pth.tar", 
               {'en':"wiki.multi.en.vec",
                'fr':"wiki.multi.fr.vec",
                'fr':"wiki.multi.de.vec",
                'fr':"wiki.multi.cs.vec"}
              ])
                

models.append(["best_correct_full_enfrde.pth.tar",
               {'en':"wiki.multi.en.vec",
                'fr':"wiki.multi.fr.vec",
                'fr':"wiki.multi.de.vec",
                'fr':"wiki.multi.cs.vec"}
              ])
                
models.append(["best_correct_full_cs.pth.tar", 
               {'en':"wiki.multi.en.vec",
                'fr':"wiki.multi.fr.vec",
                'fr':"wiki.multi.de.vec",
                'fr':"wiki.multi.cs.vec"}
              ])

# Evaluation on english

In [18]:
for model, dic in models:
    aa = arguments(dic['en'])
    coco_dataset = CocoCaptionsRV(aa, sset="val", transform=prepro_val)

    coco_dataset_loader = DataLoader(coco_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=6, collate_fn=collate_fn_padded, pin_memory=True)
    
    # load model
    checkpoint = torch.load("weights/"+model, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict']).cuda()
    join_emb.load_state_dict(checkpoint["state_dict"])
    join_emb = torch.nn.DataParallel(join_emb.cuda().eval())
    
    
    imgs_enc = list()
    caps_enc = list()

    for i, (imgs, caps, lengths) in enumerate(coco_dataset_loader, 0):
        print("%2.2f" % (i*batch_size/len(coco_dataset)*100), "\%", end="\r")
        input_imgs, input_caps = imgs.cuda(), caps.cuda()
        with torch.no_grad():
            output_imgs, output_caps = join_emb(input_imgs, input_caps, lengths)

        imgs_enc.append(output_imgs.cpu().data.numpy())
        caps_enc.append(output_caps.cpu().data.numpy())
    

    print(model, eval_recall5(imgs_enc, caps_enc))

best_sota_coco.pth.tar [array([66.08, 90.7 , 96.2 ]), array([54.124, 85.748, 92.928]), 0.0, 0.0]
best_w2vec.pth.tar [array([63.48, 89.48, 95.64]), array([51.868, 84.308, 92.484]), 0.0, 0.0]
best_bivec_coco.pth.tar [array([65.58, 90.52, 96.1 ]), array([55.572, 86.924, 93.856]), 0.0, 0.0]
best_correct_en.pth.tar [array([63.1 , 89.58, 95.56]), array([51.872, 84.708, 92.824]), 0.0, 0.0]
best_bivec_enfr.pth.tar [array([67.78, 91.58, 96.92]), array([56.092, 87.22 , 94.028]), 0.0, 0.0]
best_bivec_de.pth.tar [array([67.04, 91.38, 96.68]), array([54.66 , 86.524, 93.548]), 0.0, 0.0]
best_correct_enfr.pth.tar [array([63.88, 89.2 , 95.24]), array([52.252, 84.716, 92.74 ]), 0.0, 0.0]
best_correct_full_enfrde.pth.tar [array([62.4 , 89.18, 95.16]), array([51.172, 84.092, 92.224]), 0.0, 0.2]
best_correct_full_cs.pth.tar [array([63.28, 88.3 , 94.6 ]), array([50.444, 83.388, 91.804]), 0.0, 0.4]


# Multilang evaluation

In [36]:
%aimport misc.dataset

In [37]:
from torch.nn.utils.rnn import pad_sequence

def collate_sentences(data):
    #print("Collating :", data)
    sta = np.stack(data[1])
    print(sta)
    return pad_sequence(data[0], batch_first=True), sta
    
#collate_sentences = lambda x:(pad_sequence(x[0], batch_first=True), np.stack(x[1]))

In [38]:
def collate_fn_cap_index(data):
    captions, indices = zip(*data)
    captions = pad_sequence(captions, batch_first=True)
    indices = np.stack(indices)
    return captions, indices


In [87]:
def multilingual_recall(imgs, caps, indices, ks=[1,5,10]):
    """
        Compute multingual recall
    """
    imgs = np.vstack(imgs)
    caps = np.vstack(caps)
    indices = np.concatenate(indices)
    
    scores = -cosine_sim(imgs, caps)
    ranks = np.argsort(np.argsort(scores))
    # scores represent all the similarity between each images and each captions
    recall = {k:0 for k in ks}
    nb_imgs, nb_caps = ranks.shape

    for i in range(nb_imgs):
        for k in ks:
            for c in range(nb_caps):
                if indices[c] == i and ranks[i][c] < k: #if caption correspond to image and is ranked less than k
                        recall[k] += 1
                        break
                        
    scores = -cosine_sim(caps, imgs)
    ranks_caps = np.argsort(np.argsort(scores))
    recall_caps = {k:0 for k in ks}

    nb_caps, nb_imgs = ranks_caps.shape
    for c in range(nb_caps):
        for k in ks:
            if ranks_caps[c][indices[c]] < k:
                recall_caps[k] += 1
    
    return [recall[k] / nb_imgs*100 for k in recall], [recall_caps[k] / nb_caps*100 for k in ks], np.median(ranks), np.median(ranks_caps)



In [89]:
for model, dic in models:
    # load model
    checkpoint = torch.load("weights/"+model, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict']).cuda()
    join_emb.load_state_dict(checkpoint["state_dict"])
    join_emb = torch.nn.DataParallel(join_emb.cuda().eval())
    
    
    image_dataset = misc.dataset.ImageDataset("data/image_splits/test_2016_flickr.txt",
                            "/data/datasets/flickr30k_images",
                            transform=prepro_val
                            )
    imgs_enc = []
    
    image_data_loader = DataLoader(image_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=6, pin_memory=True)
    
    
    for i, imgs in enumerate(image_data_loader):
        print("%2.2f" % (i*batch_size/len(image_dataset)*100), "\%", end="\r")
        with torch.no_grad():
            output_imgs, _ = join_emb(imgs, None, None)

        imgs_enc.append(output_imgs.cpu().data.numpy())
    
    indices=[]
    caps_enc = []
    langs = []
    for lang in dic:
        langs.append(lang)
        lang_dataset = misc.dataset.CaptionDataset("data/tok/test_2016_flickr.lc.norm.tok."+lang,
                                                  '/data/m.portaz/'+dic[lang])
        caption_loader = DataLoader(lang_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=6, collate_fn=collate_fn_cap_index, pin_memory=True)
        
        
        for i, (caps, inds) in enumerate(caption_loader):
            print("%2.2f" % (i*batch_size/len(lang_dataset)*100), "\%", end="\r")
            #caps = caps.cuda()
            with torch.no_grad():
                _, output_caps = join_emb(None, caps, None)
            caps_enc.append(output_caps.cpu().data.numpy())
            indices.append(inds)
        print(model, lang, multilingual_recall(imgs_enc,caps_enc, indices))
    print(model, langs ,multilingual_recall(imgs_enc,caps_enc, indices))


best_sota_coco.pth.tar en ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_sota_coco.pth.tar ['en'] ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_w2vec.pth.tar en ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_w2vec.pth.tar ['en'] ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_bivec_coco.pth.tar en ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_bivec_coco.pth.tar ['en'] ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_correct_en.pth.tar en ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)
best_correct_en.pth.tar ['en'] ([0.1, 0.5, 1.0], [0.1, 0.5, 1.0], 499.5, 499.5)


KeyboardInterrupt: 