In [5]:
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
from data.dataset import AP_Dataset, APeval_Dataset, SA_Dataset, SAeval_Dataset
import evaluation
from evaluation import PTBTokenizer, Cider
from models.m2 import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
# from models.grid_m2 import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
import argparse, os, pickle
from tqdm import tqdm
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile
import h5py
from utils import text_progress2, text_progress
import pandas as pd
import json

random.seed(1234)
torch.manual_seed(1234)
np.random.seed(1234)

In [16]:
def evaluate_metrics(model, dataloader, text_field, mode="multiple", is_sample=False, beam_size=5, top_k=5, top_p=0.8):
    import itertools
    print(dataloader)
    print(mode)
    model.eval()
    gen = {}
    gts = {}
    with tqdm(desc='evalulateion metrics', unit='it', total=len(dataloader)) as pbar:
        for it, batch in enumerate(iter(dataloader)):
            images = batch['roi_feat']
            caps_gt = batch['cap']
            images = images.to(device)
            with torch.no_grad():
#                 beam_size = 5
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], beam_size, out_size=1, is_sample=is_sample, top_k=5, top_p=0.8)
#                 if decode == "beam_search":
#                     out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1, is_sample=False)
#                 elif decode == "top-k_sampling":
#                     out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 1, out_size=1, is_sample=True)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen['%d_%d' % (it, i)] = [gen_i, ]
                if mode == "multiple":
                    gts['%d_%d' % (it, i)] = gts_i
                elif mode == "single":
                    gts['%d_%d' % (it, i)] = [gts_i[0]]
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen, spice=False)
    return scores

In [7]:
device = torch.device('cuda')

In [8]:
# Model
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))

In [9]:
encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})

In [10]:
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])

In [11]:
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder)

In [12]:
# Model
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))

encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])

model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

## 1. artpedia

In [13]:
train_myidx = np.load('../Dataset/artpedia/train_myidx.npy')
val_myidx = np.load('../Dataset/artpedia/val_myidx.npy')
test_myidx = np.load('../Dataset/artpedia/test_myidx.npy')

ap_train_dataset = h5py.File("../Dataset/artpedia/artpedia_train2.hdf5", "r")
ap_val_dataset = h5py.File("../Dataset/artpedia/artpedia_val2.hdf5", "r")
ap_test_dataset = h5py.File("../Dataset/artpedia/artpedia_test2.hdf5", "r")
print("loading data: done!!!")

loading data: done!!!


In [14]:
# artpedia dataset
dict_artpedia_test = APeval_Dataset(ap_test_dataset, test_myidx, text_field, max_detections=50, lower=True, remove_punctuation=True, tokenize='spacy')

# artpedia, dataloader
dict_artpedia_test_data_loader = TorchDataLoader(dict_artpedia_test, batch_size=50, collate_fn=lambda x: text_progress2(x))

#### 1.1 artpedia, one caption for evaluation, beam search

In [11]:
# scratch model
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode = "single")

<torch.utils.data.dataloader.DataLoader object at 0x7f1660140860>
single


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.81it/s]


{'BLEU': [0.0001236055972225921,
  4.035351405369087e-12,
  1.3192351808657412e-14,
  7.681117691112694e-16],
 'ROUGE': 0.00017328068630514444,
 'CIDEr': 2.058933461283359e-06}

In [14]:
# without fine-tuning, off-the-shelf model
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])

evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7f1660140860>
single


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.86it/s]


{'BLEU': [0.06869365934607014,
  0.02408617516886492,
  0.009203178196801148,
  0.004077820732889246],
 'ROUGE': 0.1265856198220996,
 'CIDEr': 0.035073845779796387}

In [9]:
# fine-tune on artpedia, one image with one caption for training
data = torch.load('saved_models/artpedia_finetune_singlecap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5db3dcc0>
single


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.85it/s]


{'BLEU': [0.10770007890173203,
  0.042905798133398614,
  0.021158617232954074,
  0.011994791859808066],
 'ROUGE': 0.16777807810792697,
 'CIDEr': 0.05152375363409006}

In [10]:
# fine-tune on artpedia, one image with multiple captions for training
data = torch.load('saved_models/artpedia_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5db3dcc0>
single


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.95it/s]


{'BLEU': [0.10001649470859435,
  0.04153007097875002,
  0.020453615889150106,
  0.010804099773927376],
 'ROUGE': 0.16959301187547615,
 'CIDEr': 0.04574429996778054}

In [12]:
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5db3dcc0>
single


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.90it/s]


{'BLEU': [0.11439772302455756,
  0.049840209225196434,
  0.023219958693446556,
  0.012087941730175456],
 'ROUGE': 0.17763442833720808,
 'CIDEr': 0.04937305896140482}

In [10]:
# beam search.   remove <unk> token
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7f49562804e0>
single


evalulateion metrics: 100%|██████████| 329/329 [01:05<00:00,  5.05it/s]


{'BLEU': [0.20352391671520823,
  0.08086088054241472,
  0.03645516117706604,
  0.017655400779776953],
 'ROUGE': 0.18542789079533953,
 'CIDEr': 0.07592433158198265}

In [19]:
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
# generation by top-k sampling    k=5
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single', is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fd22eadd0f0>
single


evalulateion metrics: 100%|██████████| 329/329 [01:00<00:00,  5.42it/s]


{'BLEU': [0.09333995593432404,
  0.038172705149673934,
  0.01817147245345931,
  0.009833963777214086],
 'ROUGE': 0.16580195687810742,
 'CIDEr': 0.038267756192650784}

In [7]:
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
# generation by top-k sampling    k=10
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single', decode="top-k_sampling")

<torch.utils.data.dataloader.DataLoader object at 0x7f70e047ff28>
single


evalulateion metrics: 100%|██████████| 329/329 [01:02<00:00,  5.26it/s]


{'BLEU': [0.13215944908390478,
  0.04599203906981929,
  0.015877252152370982,
  0.005857381895548809],
 'ROUGE': 0.1465819237593196,
 'CIDEr': 0.030852426204639675}

In [18]:
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
# generation by top-k sampling    k=10
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single', is_sample=True, top_k=10)

<torch.utils.data.dataloader.DataLoader object at 0x7fd22eadd0f0>
single


evalulateion metrics: 100%|██████████| 329/329 [01:02<00:00,  5.23it/s]


{'BLEU': [0.1336230293647883,
  0.0478548036942355,
  0.015488464819317658,
  0.005749882420519693],
 'ROUGE': 0.1460779368214854,
 'CIDEr': 0.027999531132453754}

#### 1.2 artpedia, multiple captions for evaluation, beam search

In [12]:
# scratch model
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f1660140860>
multiple


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.89it/s]


{'BLEU': [0.0012174750562701938,
  1.4052652310529447e-11,
  3.248509699697838e-14,
  1.5904823867729776e-15],
 'ROUGE': 0.0014741008774548938,
 'CIDEr': 0.00020410413642154462}

In [13]:
# without fine-tuning, off-the-shelf model
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f1660140860>
multiple


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.96it/s]


{'BLEU': [0.17192803131078616,
  0.07112076525833422,
  0.028359206677696855,
  0.012836546820123028],
 'ROUGE': 0.16342891844085322,
 'CIDEr': 0.027290209662590086}

In [15]:
# fine-tune on artpedia, one image with one caption for training
data = torch.load('saved_models/artpedia_finetune_singlecap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f1660140860>
multiple


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.94it/s]


{'BLEU': [0.2263770796898909,
  0.10189771635005827,
  0.04723688221258288,
  0.024825597404249923],
 'ROUGE': 0.2044362212159059,
 'CIDEr': 0.03947330064255181}

In [16]:
# fine-tune on artpedia, one image with multiple captions for training
data = torch.load('saved_models/artpedia_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field)

<torch.utils.data.dataloader.DataLoader object at 0x7f1660140860>
multiple


evalulateion metrics: 100%|██████████| 66/66 [00:13<00:00,  4.95it/s]


{'BLEU': [0.21496088729295892,
  0.10277365798269977,
  0.04822575058710956,
  0.024778542489085762],
 'ROUGE': 0.2091605244824866,
 'CIDEr': 0.03343695862148488}

In [22]:
# fine-tune on artpedia, one image with multiple captions for training      shuffle
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field)

<torch.utils.data.dataloader.DataLoader object at 0x7fd22eadd0f0>
multiple


evalulateion metrics: 100%|██████████| 329/329 [01:03<00:00,  5.18it/s]


{'BLEU': [0.23290164145038766,
  0.11628312648636245,
  0.052702077361140964,
  0.02689678159915748],
 'ROUGE': 0.21914811563939154,
 'CIDEr': 0.03148247854508848}

In [8]:
# beam search   remove unk
# fine-tune on artpedia, one image with multiple captions for training      shuffle
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7f49562804e0>
multiple


evalulateion metrics: 100%|██████████| 329/329 [01:05<00:00,  5.00it/s]


{'BLEU': [0.33960716243253547,
  0.16180978856929548,
  0.0747107343165195,
  0.035472046837868663],
 'ROUGE': 0.22451221642237176,
 'CIDEr': 0.06145132500410184}

## 2. semart

In [16]:
sa_test_csv = pd.read_csv("../Dataset/SemArt/prediction_csvs/semart_test_prediction.csv")
sa_test_csv = sa_test_csv[sa_test_csv['predictioin']==0]
test_roi_feats = h5py.File("../Dataset/SemArt/sa_test_roi.hdf5", "r")
test_img_names = np.unique(sa_test_csv['img_name'].to_numpy())
test_img_caps_map = json.load(open('../Dataset/SemArt/test_img_caps_map.json'))

dict_semart_test = SAeval_Dataset(sa_test_csv, test_img_names, test_img_caps_map, test_roi_feats, text_field, max_detections=50, lower=True, remove_punctuation=True, tokenize='spacy')
dict_semart_test_data_loader = TorchDataLoader(dict_semart_test, batch_size=10,
                                  collate_fn=lambda x: text_progress2(x))

#### 2.1 semart, multiple captions for evaluation

In [21]:
# scratch model
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:15<00:00,  4.00it/s]


{'BLEU': [0.0018378225188191279,
  1.1993886603098599e-11,
  2.291862529361248e-14,
  1.0199533973959484e-15],
 'ROUGE': 0.0020635112957342753,
 'CIDEr': 0.0003285390689160264}

In [18]:
# off-the-shelf model
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:15<00:00,  4.06it/s]


{'BLEU': [0.11661231133981016,
  0.04033414573061459,
  0.015035969453543152,
  0.006403400474171556],
 'ROUGE': 0.14120449334790547,
 'CIDEr': 0.017936319393811125}

In [19]:
# fine-tune on SemArt
data = torch.load('saved_models/semart_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:15<00:00,  4.04it/s]


{'BLEU': [0.1943147719540551,
  0.09214780595267547,
  0.04630569693062932,
  0.02707906791756692],
 'ROUGE': 0.21633784481124718,
 'CIDEr': 0.057332406384715213}

In [26]:
# fine-tune on SemArt   shuffle
data = torch.load('saved_models/semart_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:15<00:00,  4.07it/s]


{'BLEU': [0.19465401469523047,
  0.09556125592848534,
  0.047773953419375625,
  0.028291138877741047],
 'ROUGE': 0.2129567245656028,
 'CIDEr': 0.05563681242748405}

#### 2.2 semart, single captions for evaluation

In [25]:
# scratch model
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
single


evalulateion metrics: 100%|██████████| 62/62 [00:29<00:00,  2.12it/s]


{'BLEU': [0.0003956737196290604,
  4.901831183185136e-12,
  1.1638071419966104e-14,
  5.789889675096696e-16],
 'ROUGE': 0.0004424504289318356,
 'CIDEr': 3.8845498212968275e-05}

In [22]:
# off-the-shelf model
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
single


evalulateion metrics: 100%|██████████| 62/62 [00:25<00:00,  2.41it/s]


{'BLEU': [0.056891952408472155,
  0.01754322385391535,
  0.00642865973036798,
  0.002323683418409561],
 'ROUGE': 0.1162408217732889,
 'CIDEr': 0.024121582330860605}

In [23]:
# fine-tune on SemArt
data = torch.load('saved_models/semart_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
single


evalulateion metrics: 100%|██████████| 62/62 [00:29<00:00,  2.11it/s]


{'BLEU': [0.11726016251663722,
  0.0530377317106873,
  0.02759006235809311,
  0.01663845349934508],
 'ROUGE': 0.18710581654278827,
 'CIDEr': 0.07783775341498904}

In [28]:
# fine-tune on SemArt   shuffle
data = torch.load('saved_models/semart_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='single')

<torch.utils.data.dataloader.DataLoader object at 0x7faf5de78cf8>
single


evalulateion metrics: 100%|██████████| 62/62 [00:15<00:00,  4.03it/s]


{'BLEU': [0.12102475253065365,
  0.05635952468793378,
  0.02944552781687917,
  0.01802454846235869],
 'ROUGE': 0.18495195667006736,
 'CIDEr': 0.07142283831893503}