In [2]:
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
from data.dataset import AP_Dataset, APeval_Dataset, SA_Dataset, SAeval_Dataset
import evaluation
from evaluation import PTBTokenizer, Cider
# from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
from models.grid_m2_rst import  Transformer, TransformerEncoder, MeshedDecoder, ScaledDotProductAttention
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
import argparse, os, pickle
from tqdm import tqdm
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile
import h5py
from utils import text_progress2, text_progress
import pandas as pd
import json

# One seed shared by every RNG the notebook touches, so sampling-based
# decoding and any shuffling are repeatable across kernel restarts.
SEED = 1234
torch.manual_seed(SEED)
random.seed(SEED)
# Kept last on purpose: np.random.seed returns None, so the cell shows no output.
np.random.seed(SEED)

In [3]:
def evaluate_metrics(model, dataloader, text_field, mode="multiple", is_sample=False, beam_size=5, top_k=5, top_p=0.8):
    """Caption every batch of ``dataloader`` with ``model`` and score the captions.

    Args:
        model: object exposing ``eval()`` and
            ``beam_search(images, max_len, eos_idx, beam_size, out_size=...,
            is_sample=..., top_k=..., top_p=...)``.
        dataloader: sized iterable of batches; each batch is indexed with
            ``'roi_feat'`` (model input, moved to the notebook-level ``device``)
            and ``'cap'`` (reference captions, one entry per image).
        text_field: provides ``vocab.stoi`` (for the ``'<eos>'`` index) and
            ``decode(out, join_words=False)``.
        mode: ``"multiple"`` scores against every reference caption of an
            image; ``"single"`` scores against only the first one.
        is_sample: forwarded to ``model.beam_search``.
        beam_size: forwarded to ``model.beam_search``.
        top_k: forwarded to ``model.beam_search``.
        top_p: forwarded to ``model.beam_search``.

    Returns:
        The first element of
        ``evaluation.compute_scores(gts, gen, spice=False)``.

    Raises:
        ValueError: if ``mode`` is neither ``"multiple"`` nor ``"single"``.
    """
    # Fail before the (slow) decoding loop instead of ending up with an empty
    # reference dict after every batch has already been decoded.
    if mode not in ("multiple", "single"):
        raise ValueError("mode must be 'multiple' or 'single', got %r" % (mode,))

    print(dataloader)
    print(mode)
    model.eval()
    gen = {}
    gts = {}
    max_len = 20  # maximum generated caption length handed to beam_search
    eos_idx = text_field.vocab.stoi['<eos>']
    with tqdm(desc='evalulateion metrics', unit='it', total=len(dataloader)) as pbar:
        for it, batch in enumerate(dataloader):
            caps_gt = batch['cap']
            # `device` is the notebook-level global defined in its own cell.
            images = batch['roi_feat'].to(device)
            with torch.no_grad():
                # top_k / top_p are forwarded from the arguments; previously
                # the literals 5 and 0.8 were passed and the arguments ignored.
                out, _ = model.beam_search(images, max_len, eos_idx, beam_size, out_size=1,
                                           is_sample=is_sample, top_k=top_k, top_p=top_p)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                key = '%d_%d' % (it, i)
                # groupby collapses runs of consecutive identical tokens.
                gen_i = ' '.join(token for token, _run in itertools.groupby(gen_i))
                gen[key] = [gen_i]
                gts[key] = gts_i if mode == "multiple" else [gts_i[0]]
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen, spice=False)
    return scores

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Model
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
# Use a context manager so the vocabulary file handle is closed after loading.
# NOTE: unpickling can execute arbitrary code; only load a vocab.pkl you trust.
with open('vocab.pkl', 'rb') as vocab_file:
    text_field.vocab = pickle.load(vocab_file)

# NOTE(review): the positional literals (3, 0) and (54, 3) are constructor
# arguments whose meaning is defined in models.grid_m2_rst, which is not
# visible here -- confirm against that module before changing them.
encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## 1. Artpedia

In [6]:
# Index arrays for the Artpedia train / val / test splits.
train_myidx, val_myidx, test_myidx = [
    np.load(npy_path)
    for npy_path in (
        '../Dataset/artpedia/train_myidx.npy',
        '../Dataset/artpedia/val_myidx.npy',
        '../Dataset/artpedia/test_myidx.npy',
    )
]

# HDF5 files for the same three splits, opened read-only. The handles are left
# open because a later cell passes ap_test_dataset to APeval_Dataset.
ap_train_dataset, ap_val_dataset, ap_test_dataset = [
    h5py.File(hdf5_path, "r")
    for hdf5_path in (
        "../Dataset/artpedia/ap_train_grid.hdf5",
        "../Dataset/artpedia/ap_val_grid.hdf5",
        "../Dataset/artpedia/ap_test_grid.hdf5",
    )
]
print("loading data: done!!!")

loading data: done!!!


In [8]:
# Artpedia test dataset, built from the test HDF5 file and the test indices.
dict_artpedia_test = APeval_Dataset(ap_test_dataset, test_myidx, text_field, max_detections=49, feature_type="grid", lower=True, remove_punctuation=True, tokenize='spacy')

# Artpedia test dataloader. text_progress2 is passed directly as the collate
# function; wrapping it in a lambda added nothing and made the loader
# unpicklable.
dict_artpedia_test_data_loader = TorchDataLoader(dict_artpedia_test, batch_size=50, collate_fn=text_progress2)

#### 1.1 artpedia, one caption for evaluation, beam search

#### 1.2 artpedia, multiple captions for evaluation, beam search

In [1]:
# *** grid_m2rst
# Without fine-tuning, off-the-shelf model ("best" checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models_grid_m2rst/grid_m2_rst_best.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [2]:
# *** grid_m2rst
# Without fine-tuning, off-the-shelf model ("last" checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models_grid_m2rst/grid_m2_rst_last.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [3]:
# *** grid_m2rst
# Without fine-tuning, off-the-shelf model (epoch-41 checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models_grid_m2rst/grid_m2_rst_epoch41.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [4]:
# Fine-tuned on Artpedia, one image with one caption for training.
# Evaluated against all reference captions (mode='multiple').
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/artpedia_finetune_singlecap.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [5]:
# *** grid_m2rst
# Fine-tuned on Artpedia, one image with multiple captions for training.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/grid_m2_tr_last_18epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode='multiple' is evaluate_metrics' default; spelled out to match the other cells.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [6]:
# Fine-tuned on Artpedia, one image with multiple captions for training.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/artpedia_finetune_mulcap.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode='multiple' is evaluate_metrics' default; spelled out to match the other cells.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [7]:
# Fine-tuned on Artpedia, one image with multiple captions for training (shuffle).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode='multiple' is evaluate_metrics' default; spelled out to match the other cells.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [8]:
# Beam search (is_sample=False), "remove unk".
# NOTE(review): "remove unk" is the original note; what it refers to is not
# visible in this cell -- confirm.
# Fine-tuned on Artpedia, one image with multiple captions for training (shuffle).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode='multiple' is evaluate_metrics' default; spelled out to match the other cells.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple', is_sample=False)

In [10]:
# Beam search (is_sample=False), "remove unk".
# NOTE(review): "remove unk" is the original note; what it refers to is not
# visible in this cell -- confirm.
# Fine-tuned on Artpedia, one image with multiple captions for training (shuffle).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models_grid_m2rst_apft/grid_m2_tr_last_18epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode='multiple' is evaluate_metrics' default; spelled out to match the other cells.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple', is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7f2055ff5588>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it]


{'BLEU': [0.3609881069135,
  0.17902517899338843,
  0.09118977721377369,
  0.050181107544762536],
 'METEOR': 0.09133971685517936,
 'ROUGE': 0.23398104425547456,
 'CIDEr': 0.09330209254997349}

## 2. SemArt

In [6]:
# SemArt test split: prediction CSV, filtered to the rows whose 'predictioin'
# column equals 0.
# NOTE(review): the column name is spelled 'predictioin'; it is left untouched
# because it presumably matches the CSV header -- confirm.
sa_test_csv = pd.read_csv("../Dataset/SemArt/prediction_csvs/semart_test_prediction.csv")
sa_test_csv = sa_test_csv[sa_test_csv['predictioin']==0]
# Opened read-only and left open: the dataset below receives this handle.
test_roi_feats = h5py.File("../Dataset/SemArt/sa_test_grid.hdf5", "r")
# Unique image names among the retained rows (np.unique also sorts them).
test_img_names = np.unique(sa_test_csv['img_name'].to_numpy())
# Context manager so the JSON file handle is closed once parsed.
with open('../Dataset/SemArt/test_img_caps_map.json') as caps_map_file:
    test_img_caps_map = json.load(caps_map_file)

dict_semart_test = SAeval_Dataset(sa_test_csv, test_img_names, test_img_caps_map, test_roi_feats, text_field, max_detections=49, lower=True, remove_punctuation=True, tokenize='spacy')
# text_progress2 is passed directly as the collate function; the lambda
# wrapper added nothing and made the loader unpicklable.
dict_semart_test_data_loader = TorchDataLoader(dict_semart_test, batch_size=50,
                                  collate_fn=text_progress2)

#### 2.1 semart, multiple captions for evaluation

In [6]:
# Off-the-shelf model ("last" checkpoint, per the filename), SemArt test set.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_grid_m2rst/grid_m2_rst_last.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f70735d40b8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:22<00:00,  1.73s/it]


{'BLEU': [0.1238723982265556,
  0.04377896610833391,
  0.017749577060783853,
  0.008296285553807443],
 'METEOR': 0.03853641775444528,
 'ROUGE': 0.1403943502673313,
 'CIDEr': 0.026411970140918662}

In [7]:
# Off-the-shelf model ("best" checkpoint, per the filename), SemArt test set.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_grid_m2rst/grid_m2_rst_best.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f70735d40b8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:22<00:00,  1.71s/it]


{'BLEU': [0.12809294283327557,
  0.045397722109473716,
  0.01791234797525151,
  0.007650384719347912],
 'METEOR': 0.039360563038510016,
 'ROUGE': 0.142495980976305,
 'CIDEr': 0.020916017474374143}

In [8]:
# Off-the-shelf model (epoch-41 checkpoint, per the filename), SemArt test set.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_grid_m2rst/grid_m2_rst_epoch41.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f70735d40b8>
multiple


evalulateion metrics: 100%|██████████| 13/13 [00:22<00:00,  1.70s/it]


{'BLEU': [0.13904005551409623,
  0.04929343466036093,
  0.019266916799502324,
  0.008159874882971943],
 'METEOR': 0.03898631586267387,
 'ROUGE': 0.13827194208741386,
 'CIDEr': 0.02001249028494681}

In [9]:
# Fine-tuned on SemArt, "remove <unk>" (17-epoch "last" checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [10]:
# Fine-tuned on SemArt, "remove <unk>" (13-epoch "best" checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_best_13epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [11]:
# Fine-tuned on SemArt, "remove <unk>" (15-epoch "last" checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [12]:
# Fine-tuned on SemArt, "remove <unk>" (16-epoch "last" checkpoint, per the filename).
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_last_16epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [13]:
# Fine-tuned on SemArt, "remove <unk>" (17-epoch "last" checkpoint, per the filename).
# NOTE(review): this cell loads the same checkpoint as an earlier SemArt cell;
# it looks like a repeated run -- confirm whether both are needed.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [14]:
# Fine-tuned on SemArt, "remove <unk>" (13-epoch "best" checkpoint, per the filename).
# NOTE(review): this cell loads the same checkpoint as an earlier SemArt cell;
# it looks like a repeated run -- confirm whether both are needed.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_best_13epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [15]:
# Fine-tuned on SemArt, "remove <unk>" (15-epoch "last" checkpoint, per the filename).
# NOTE(review): this cell loads the same checkpoint as an earlier SemArt cell;
# it looks like a repeated run -- confirm whether both are needed.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [16]:
# Fine-tuned on SemArt, "remove <unk>" (16-epoch "last" checkpoint, per the filename).
# NOTE(review): this cell loads the same checkpoint as an earlier SemArt cell;
# it looks like a repeated run -- confirm whether both are needed.
# map_location puts the checkpoint tensors on `device` regardless of the
# device they were saved from.
data = torch.load('saved_models/saved_models_saft_grid_m2rst/sa_gridm2rst_sa_last_16epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')