In [1]:
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
from data.dataset import AP_Dataset, APeval_Dataset, SA_Dataset, SAeval_Dataset
import evaluation
from evaluation import PTBTokenizer, Cider
from models.transformer import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttentionMemory, ScaledDotProductAttention
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
import argparse, os, pickle
from tqdm import tqdm
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile
import h5py
from utils import text_progress2, text_progress
import pandas as pd
import json

# Seed the Python, PyTorch and NumPy random number generators (in that order)
# with the same fixed value so repeated runs draw the same random numbers.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(1234)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def evaluate_metrics(model, dataloader, text_field, mode="multiple", is_sample=False, beam_size=5, top_k=5, top_p=0.8):
    """Generate one caption per image in ``dataloader`` and score it against the references.

    Relies on the notebook-level ``device`` for placing the image features.

    Args:
        model: captioning model exposing ``eval()`` and ``beam_search(...)``.
        dataloader: sized iterable of batches; every batch is indexed with
            ``'roi_feat'`` (image features) and ``'cap'`` (reference captions).
        text_field: object exposing ``vocab.stoi`` (looked up for ``'<eos>'``)
            and ``decode(out, join_words=False)``.
        mode: ``"multiple"`` scores against all reference captions of an image,
            ``"single"`` only against the first one.
        is_sample: forwarded to ``model.beam_search``.
        beam_size: forwarded to ``model.beam_search``.
        top_k: forwarded to ``model.beam_search``.
        top_p: forwarded to ``model.beam_search``.

    Returns:
        The first element returned by
        ``evaluation.compute_scores(gts, gen, spice=False)``.

    Raises:
        ValueError: if ``mode`` is neither ``"multiple"`` nor ``"single"``.
    """
    print(dataloader)
    print(mode)
    # Fail early: an unknown mode would otherwise leave the reference dict empty.
    if mode not in ("multiple", "single"):
        raise ValueError("mode must be 'multiple' or 'single', got %r" % (mode,))

    model.eval()
    eos_idx = text_field.vocab.stoi['<eos>']
    gen = {}
    gts = {}
    with tqdm(desc='evalulateion metrics', unit='it', total=len(dataloader)) as pbar:
        for it, batch in enumerate(dataloader):
            images = batch['roi_feat'].to(device)
            caps_gt = batch['cap']
            with torch.no_grad():
                # Decode at most 20 steps and keep the single best sequence per image.
                # top_k / top_p come from the caller instead of being hard-coded.
                out, _ = model.beam_search(images, 20, eos_idx, beam_size, out_size=1,
                                           is_sample=is_sample, top_k=top_k, top_p=top_p)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                # Collapse runs of identical consecutive tokens before joining.
                gen_i = ' '.join(token for token, _run in itertools.groupby(gen_i))
                key = '%d_%d' % (it, i)
                gen[key] = [gen_i, ]
                gts[key] = gts_i if mode == "multiple" else [gts_i[0]]
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen, spice=False)
    return scores

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook can still be executed on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Model: text field with its pickled vocabulary, then the encoder/decoder
# Transformer, moved onto `device`.
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
# NOTE: unpickling can execute arbitrary code; only load a vocab.pkl you trust.
# The `with` block closes the file handle as soon as the vocabulary is read.
with open('vocab.pkl', 'rb') as vocab_file:
    text_field.vocab = pickle.load(vocab_file)

# NOTE(review): the positional arguments (3, 0) and (54, 3) are defined by the
# project's model classes, which are not visible here — check their signatures.
encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': 40})
decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

## 1. Artpedia

In [6]:
# Index arrays for the Artpedia train / val / test splits, loaded from .npy files.
train_myidx = np.load('../Dataset/artpedia/train_myidx.npy')
val_myidx = np.load('../Dataset/artpedia/val_myidx.npy')
test_myidx = np.load('../Dataset/artpedia/test_myidx.npy')

# HDF5 files for each split, opened read-only (presumably the "grid" image
# features, judging by the file names — confirm). The handles are left open here.
ap_train_dataset = h5py.File("../Dataset/artpedia/ap_train_grid.hdf5", "r")
ap_val_dataset = h5py.File("../Dataset/artpedia/ap_val_grid.hdf5", "r")
ap_test_dataset = h5py.File("../Dataset/artpedia/ap_test_grid.hdf5", "r")
print("loading data: done!!!")

loading data: done!!!


In [8]:
# Artpedia test split wrapped as an evaluation dataset.
dict_artpedia_test = APeval_Dataset(
    ap_test_dataset,
    test_myidx,
    text_field,
    max_detections=50,
    feature_type="grid",
    lower=True,
    remove_punctuation=True,
    tokenize='spacy',
)

# Batches of 50 samples, each batch assembled by text_progress2.
dict_artpedia_test_data_loader = TorchDataLoader(
    dict_artpedia_test,
    batch_size=50,
    collate_fn=text_progress2,
)

#### 1.2 Artpedia: multiple reference captions per image for evaluation, beam search

In [9]:
# grid_std, "scratch model" (author's note): score the model on the Artpedia
# test loader, keeping every reference caption of each image.
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

<torch.utils.data.dataloader.DataLoader object at 0x7f8e6c422080>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.33it/s]


{'BLEU': [0.004326973082402499,
  2.5634695273945842e-11,
  4.723374467606041e-14,
  2.0568927010706106e-15],
 'METEOR': 0.014165827505820452,
 'ROUGE': 0.0039192435773511315,
 'CIDEr': 0.0012401598617022148}

In [10]:
# grid_std, off-the-shelf checkpoint (no fine-tuning), scored on the Artpedia test loader.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models_grid_std/grid_std_last.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7ff735e80e10>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:02<00:00,  2.43it/s]


{'BLEU': [0.17158520112747436,
  0.06986202299331647,
  0.0290571500011994,
  0.014471267217663474],
 'METEOR': 0.04434568811794901,
 'ROUGE': 0.16056835045320964,
 'CIDEr': 0.026423129336359794}

In [9]:
# grid_std fine-tuned on Artpedia (one image with multiple captions for training),
# checkpoint file for epoch 21; evaluate_metrics uses its default mode ("multiple").
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models_grid_std_apft/grid_std_last_21epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field)

<torch.utils.data.dataloader.DataLoader object at 0x7f5b73ae26a0>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:03<00:00,  2.14it/s]


{'BLEU': [0.23651595136113693,
  0.11535301254704178,
  0.05517624751153536,
  0.02861699999162214],
 'METEOR': 0.0652642460583719,
 'ROUGE': 0.21837367224144824,
 'CIDEr': 0.04142939673584087}

In [11]:
# grid_std fine-tuned on Artpedia (one image with multiple captions for training),
# checkpoint file for epoch 15; evaluate_metrics uses its default mode ("multiple").
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models_grid_std_apft/grid_std_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field)

<torch.utils.data.dataloader.DataLoader object at 0x7f5b73ae26a0>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:02<00:00,  2.38it/s]


{'BLEU': [0.24092640079647745,
  0.11738046072365003,
  0.05299045307871724,
  0.026597708072234535],
 'METEOR': 0.06354250291789304,
 'ROUGE': 0.22068494763047783,
 'CIDEr': 0.029918216239498464}

In [13]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-21 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_21epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fa019543a20>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.29it/s]


{'BLEU': [0.3571358661176454,
  0.17361371226509703,
  0.08674862267140716,
  0.04570633002247616],
 'METEOR': 0.08824295864639446,
 'ROUGE': 0.23070234570577688,
 'CIDEr': 0.07891834094960062}

In [14]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-20 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_20epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fa019543a20>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.31it/s]


{'BLEU': [0.35125560858510335,
  0.17422489071457506,
  0.08999686413731837,
  0.050536320558352665],
 'METEOR': 0.08718024689117586,
 'ROUGE': 0.22868627601219982,
 'CIDEr': 0.08440374822549629}

In [15]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-19 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_19epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fa019543a20>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.31it/s]


{'BLEU': [0.35318592451741593,
  0.16834040979707066,
  0.0825267968983749,
  0.04409835817926199],
 'METEOR': 0.08707575268205435,
 'ROUGE': 0.22855139750142697,
 'CIDEr': 0.07518868139920898}

In [16]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-18 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_18epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fa019543a20>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.32it/s]


{'BLEU': [0.3558035810957888,
  0.172262817766436,
  0.0834531145284142,
  0.04260480886226038],
 'METEOR': 0.08943341300919651,
 'ROUGE': 0.23107114942411341,
 'CIDEr': 0.07631653715059386}

In [17]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-17 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fa019543a20>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.31it/s]


{'BLEU': [0.358945790118638,
  0.17721325558301393,
  0.08763283161252743,
  0.045814453923764605],
 'METEOR': 0.08998031260008607,
 'ROUGE': 0.23837964005557438,
 'CIDEr': 0.0797819271020616}

In [18]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-16 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_16epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fa019543a20>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.30it/s]


{'BLEU': [0.3613522362075712,
  0.17697943534387128,
  0.08730867362185374,
  0.045059511288135526],
 'METEOR': 0.08840195978654111,
 'ROUGE': 0.23528738908625133,
 'CIDEr': 0.07305138761212397}

In [9]:
# Beam search (is_sample=False) with <unk> removed, Artpedia fine-tuned model
# (multiple captions per training image, "shuffle" per the author's note), epoch-15 file.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models_grid_std_apft/grid_std_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7fb07a201e10>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:05<00:00,  1.26it/s]


{'BLEU': [0.3557578858129252,
  0.17372540776847317,
  0.08281240839698314,
  0.04205395742564302],
 'METEOR': 0.08502436153130276,
 'ROUGE': 0.22745517479015018,
 'CIDEr': 0.06538858828989183}

## 2. SemArt

In [5]:
# SemArt test split: read the prediction CSV and keep only rows whose
# 'predictioin' column equals 0.
# NOTE(review): 'predictioin' is the key looked up in the CSV; presumably the file's
# header carries this spelling — confirm before "correcting" it.
sa_test_csv = pd.read_csv("../Dataset/SemArt/prediction_csvs/semart_test_prediction.csv")
sa_test_csv = sa_test_csv[sa_test_csv['predictioin']==0]
# HDF5 file opened read-only; the handle stays open because it is handed to the dataset below.
test_roi_feats = h5py.File("../Dataset/SemArt/sa_test_grid.hdf5", "r")
# Sorted unique image names among the remaining rows.
test_img_names = np.unique(sa_test_csv['img_name'].to_numpy())
# The `with` block closes the JSON file as soon as it has been parsed.
with open('../Dataset/SemArt/test_img_caps_map.json') as caps_map_file:
    test_img_caps_map = json.load(caps_map_file)

dict_semart_test = SAeval_Dataset(sa_test_csv, test_img_names, test_img_caps_map, test_roi_feats, text_field, max_detections=50, lower=True, remove_punctuation=True, tokenize='spacy')
# Batches of 50 samples, each batch assembled by text_progress2.
dict_semart_test_data_loader = TorchDataLoader(dict_semart_test, batch_size=50,
                                  collate_fn=text_progress2)

#### 2.1 SemArt: multiple reference captions per image for evaluation

In [13]:
# grid_std, off-the-shelf checkpoint, scored on the SemArt test loader.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models/saved_models_grid_std/grid_std_best.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fad61774a90>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


{'BLEU': [0.11517470621353412,
  0.04139207503728256,
  0.015460279432342003,
  0.006791966587035756],
 'METEOR': 0.03834290863524735,
 'ROUGE': 0.1425161425820527,
 'CIDEr': 0.021504803105371926}

In [6]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), "best" file for epoch 8.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_best_8epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fad61774a90>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:05<00:00,  2.21it/s]


{'BLEU': [0.18229221027173154,
  0.08893411516107362,
  0.0437774002995205,
  0.023555833187067023],
 'METEOR': 0.06352536010162067,
 'ROUGE': 0.21915605324102436,
 'CIDEr': 0.056163765645843036}

In [7]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), "best" file for epoch 12.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_best_12epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fad61774a90>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:05<00:00,  2.30it/s]


{'BLEU': [0.18404202912788434,
  0.09013336043349846,
  0.04414499284640818,
  0.025480432973197693],
 'METEOR': 0.06063164718709266,
 'ROUGE': 0.21732314625355967,
 'CIDEr': 0.05959392235369736}

In [9]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), "last" file for epoch 13.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_last_13epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fad61774a90>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:05<00:00,  2.29it/s]


{'BLEU': [0.16751313931206727,
  0.07710970825414262,
  0.04014469720747225,
  0.02311959274827015],
 'METEOR': 0.06006067955775578,
 'ROUGE': 0.20740895623879047,
 'CIDEr': 0.060322444207458324}

In [10]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), "last" file for epoch 15.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fad61774a90>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:05<00:00,  2.31it/s]


{'BLEU': [0.18965665304637247,
  0.0950157436056729,
  0.0487429384232812,
  0.02924166179150639],
 'METEOR': 0.06308580523219724,
 'ROUGE': 0.2197763846461518,
 'CIDEr': 0.06585826591430592}

In [11]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), "last" file for epoch 16.
# map_location=device places the stored tensors on the evaluation device, so the
# checkpoint loads even if the device it was saved from is not available.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_last_16epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fad61774a90>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:05<00:00,  2.30it/s]


{'BLEU': [0.1833499354514984,
  0.08673186328744482,
  0.04476226133275599,
  0.026790109186059896],
 'METEOR': 0.06085520970294295,
 'ROUGE': 0.20789636541300216,
 'CIDEr': 0.05170080322979307}

In [6]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), with <unk> removed,
# "best" file for epoch 8.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_best_8epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f182347eac8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:10<00:00,  1.21it/s]


{'BLEU': [0.3082742231174583,
  0.1542147940548198,
  0.07660736613936082,
  0.04396027473140647],
 'METEOR': 0.08305229783226224,
 'ROUGE': 0.23925052119544996,
 'CIDEr': 0.093623867165398}

In [7]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), with <unk> removed,
# "last" file for epoch 16.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_last_16epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f182347eac8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:10<00:00,  1.27it/s]


{'BLEU': [0.3079446270590496,
  0.15621797284084335,
  0.07590924176624102,
  0.04210847081680922],
 'METEOR': 0.08285630576417538,
 'ROUGE': 0.23537459462805665,
 'CIDEr': 0.09591957649036643}

In [8]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), with <unk> removed,
# "last" file for epoch 15.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f182347eac8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:10<00:00,  1.26it/s]


{'BLEU': [0.3108348492098281,
  0.1597667352405077,
  0.07847689921097227,
  0.044307602565190174],
 'METEOR': 0.08386580424106972,
 'ROUGE': 0.23696186001034636,
 'CIDEr': 0.1022901572648685}

In [9]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), with <unk> removed,
# "last" file for epoch 13.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_last_13epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f182347eac8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:10<00:00,  1.25it/s]


{'BLEU': [0.31381365138630946,
  0.15588294292885566,
  0.07716917178811405,
  0.04413896570001189],
 'METEOR': 0.08435054141732126,
 'ROUGE': 0.23767617003418728,
 'CIDEr': 0.10130022960584317}

In [11]:
# grid_std fine-tuned on SemArt ("shuffle" per the author's note), with <unk> removed,
# "best" file for epoch 12.
# NOTE(review): the <unk> removal is not visible in this notebook — presumably it
# lives in the model/decoding code; confirm.
# map_location=device places the stored tensors on the evaluation device.
data = torch.load('saved_models/saved_models_saft_grid_std/sa_gridtr_sa_best_12epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f182347eac8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:10<00:00,  1.26it/s]


{'BLEU': [0.30906015361656763,
  0.1566134762300169,
  0.07779981208717075,
  0.04461520730988741],
 'METEOR': 0.0839683366102684,
 'ROUGE': 0.23399971999165048,
 'CIDEr': 0.09828388234275036}