In [None]:
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
from data.dataset import AP_Dataset, APeval_Dataset, SA_Dataset, SAeval_Dataset
import evaluation
from evaluation import PTBTokenizer, Cider
from models.m2 import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
# from models.grid_m2 import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
import argparse, os, pickle
from tqdm import tqdm
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile
import h5py
from utils import text_progress2, text_progress
import pandas as pd
import json

# Seed Python, PyTorch and NumPy with the same value so repeated runs start from the same RNG state.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def evaluate_metrics(model, dataloader, text_field, mode="multiple", is_sample=False, beam_size=5, top_k=5, top_p=0.8):
    """Generate one caption per image and score it against the reference captions.

    Args:
        model: captioning model exposing ``eval()`` and ``beam_search(...)``.
        dataloader: sized iterable of batches; each batch is indexed with
            ``'roi_feat'`` (moved to the module-level ``device``) and ``'cap'``
            (one collection of reference captions per image).
        text_field: provides ``vocab.stoi`` (for the ``'<eos>'`` index) and ``decode``.
        mode: ``"multiple"`` scores against every reference of an image,
            ``"single"`` only against the first reference.
        is_sample: forwarded to ``model.beam_search``.
        beam_size: forwarded to ``model.beam_search``.
        top_k: forwarded to ``model.beam_search``.
        top_p: forwarded to ``model.beam_search``.

    Returns:
        The first element of ``evaluation.compute_scores(gts, gen, spice=False)``.

    Raises:
        ValueError: if ``mode`` is neither ``"multiple"`` nor ``"single"``.
    """
    # Fail fast: an unknown mode would otherwise leave `gts` empty and only
    # surface later as a confusing scorer error.
    if mode not in ("multiple", "single"):
        raise ValueError("mode must be 'multiple' or 'single', got %r" % (mode,))

    print(dataloader)
    print(mode)
    model.eval()
    eos_idx = text_field.vocab.stoi['<eos>']
    gen = {}
    gts = {}
    with tqdm(desc='evalulateion metrics', unit='it', total=len(dataloader)) as pbar:
        for it, batch in enumerate(dataloader):
            images = batch['roi_feat'].to(device)
            caps_gt = batch['cap']
            with torch.no_grad():
                # Forward the caller's sampling settings; they used to be
                # hard-coded to top_k=5 / top_p=0.8 regardless of the arguments.
                out, _ = model.beam_search(images, 20, eos_idx, beam_size, out_size=1,
                                           is_sample=is_sample, top_k=top_k, top_p=top_p)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                key = '%d_%d' % (it, i)
                # Collapse runs of consecutive identical tokens before joining.
                gen[key] = [' '.join(k for k, _ in itertools.groupby(gen_i))]
                gts[key] = gts_i if mode == "multiple" else [gts_i[0]]
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen, spice=False)
    return scores

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Model
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
# NOTE: unpickling can execute arbitrary code — only load a vocab.pkl from a trusted source.
# The context manager closes the file handle, which the bare open() call left dangling.
with open('vocab.pkl', 'rb') as vocab_file:
    text_field.vocab = pickle.load(vocab_file)

# NOTE(review): the positional numbers (3, 0 / 54, 3) are presumably layer counts,
# padding index and maximum length — confirm against models.m2.
encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])

# Freshly constructed network: no checkpoint has been loaded into it yet.
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

## 1. artpedia

In [None]:
# Index arrays for the artpedia train / val / test splits.
train_myidx, val_myidx, test_myidx = (
    np.load(split_path)
    for split_path in (
        '../Dataset/artpedia/train_myidx.npy',
        '../Dataset/artpedia/val_myidx.npy',
        '../Dataset/artpedia/test_myidx.npy',
    )
)

# HDF5 files for the same splits, opened read-only. The handles are kept open
# because they are handed to the dataset wrappers in later cells.
ap_train_dataset, ap_val_dataset, ap_test_dataset = (
    h5py.File(feature_path, "r")
    for feature_path in (
        "../Dataset/artpedia/artpedia_train2.hdf5",
        "../Dataset/artpedia/artpedia_val2.hdf5",
        "../Dataset/artpedia/artpedia_test2.hdf5",
    )
)
print("loading data: done!!!")

In [None]:
# Evaluation dataset over the artpedia test split.
dict_artpedia_test = APeval_Dataset(
    ap_test_dataset,
    test_myidx,
    text_field,
    max_detections=50,
    lower=True,
    remove_punctuation=True,
    tokenize='spacy',
)

# Batches of 50 samples, collated by text_progress2 (passed directly; the lambda wrapper added nothing).
dict_artpedia_test_data_loader = TorchDataLoader(
    dict_artpedia_test,
    batch_size=50,
    collate_fn=text_progress2,
)

#### 1.1 artpedia, one caption for evaluation, beam search

In [None]:
# Untrained baseline: when the notebook is run top to bottom, no checkpoint has
# been loaded into `model` yet. Scored against the first reference caption only.
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode="single",
)

In [None]:
# Off-the-shelf checkpoint (no fine-tuning), scored against the first reference caption only.
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# Checkpoint fine-tuned on artpedia with one caption per image; single-reference scoring.
data = torch.load('saved_models/artpedia_finetune_singlecap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# Checkpoint fine-tuned on artpedia with multiple captions per image; single-reference scoring.
data = torch.load('saved_models/artpedia_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# Checkpoint fine-tuned on artpedia with multiple captions per image plus shuffling;
# single-reference scoring.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# Beam search with the <unk> token removed, on the multi-caption + shuffle checkpoint.
# NOTE(review): this call is identical to the preceding shuffle cell, so the <unk>
# removal presumably lives in the decoding code rather than here — confirm.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
# generation by top-k sampling    k=5
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
# is_sample must be True for sampling; with is_sample=False this cell silently
# repeated the plain beam-search run instead of the top-k experiment it announces.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single', is_sample=True, top_k=5)

In [None]:
# fine-tune on artpedia, one image with multiple captions for training   + shuffle
# generation by top-k sampling    k=10
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
# evaluate_metrics has no `decode` parameter, so the old call raised TypeError.
# The retired decode="top-k_sampling" branch (still visible as commented-out code
# in evaluate_metrics) used beam size 1 with is_sample=True; spell that out here.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='single',
                 is_sample=True, beam_size=1, top_k=10)

In [None]:
# Multi-caption + shuffle checkpoint, decoded with sampling enabled and top_k=10
# (beam_size and top_p stay at the evaluate_metrics defaults); single-reference scoring.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='single',
    is_sample=True,
    top_k=10,
)

#### 1.2 artpedia, multiple captions for evaluation, beam search

In [None]:
# scratch model
# `model` still holds the fine-tuned weights loaded by the preceding cell, so it
# is rebuilt here to score genuinely untrained (randomly initialised) parameters.
encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

In [None]:
# Off-the-shelf checkpoint (no fine-tuning), scored against all reference captions of each image.
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

In [None]:
# Checkpoint fine-tuned on artpedia with one caption per image; multi-reference scoring.
data = torch.load('saved_models/artpedia_finetune_singlecap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

In [None]:
# Checkpoint fine-tuned on artpedia with multiple captions per image; multi-reference scoring.
# mode="multiple" is the evaluate_metrics default and is spelled out only for readability.
data = torch.load('saved_models/artpedia_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode="multiple",
)

In [None]:
# Multi-caption + shuffle checkpoint; multi-reference scoring.
# mode="multiple" is the evaluate_metrics default and is spelled out only for readability.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode="multiple",
)

In [None]:
# Beam search with <unk> removed, on the multi-caption + shuffle checkpoint; multi-reference scoring.
# NOTE(review): nothing in this call removes <unk>; presumably that happens in the
# decoding code — confirm. mode="multiple" is the default, written out for readability.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode="multiple",
    is_sample=False,
)

## 2. semart

In [None]:
sa_test_csv = pd.read_csv("../Dataset/SemArt/prediction_csvs/semart_test_prediction.csv")
# Keep only the rows whose flag column equals 0.
# NOTE(review): 'predictioin' looks like a misspelling of 'prediction'; it is left
# untouched because it has to match the CSV header — confirm against the file.
sa_test_csv = sa_test_csv[sa_test_csv['predictioin']==0]
# Opened read-only and kept open: the handle is passed to the dataset below.
test_roi_feats = h5py.File("../Dataset/SemArt/sa_test_roi.hdf5", "r")
test_img_names = np.unique(sa_test_csv['img_name'].to_numpy())
# The context manager closes the JSON file, which the bare open() call left dangling.
with open('../Dataset/SemArt/test_img_caps_map.json') as caps_map_file:
    test_img_caps_map = json.load(caps_map_file)

dict_semart_test = SAeval_Dataset(sa_test_csv, test_img_names, test_img_caps_map, test_roi_feats, text_field, max_detections=50, lower=True, remove_punctuation=True, tokenize='spacy')
# text_progress2 is passed directly; the lambda wrapper added nothing and, unlike a
# module-level function, cannot be pickled if worker processes are ever enabled.
dict_semart_test_data_loader = TorchDataLoader(dict_semart_test, batch_size=10,
                                  collate_fn=text_progress2)

#### 2.1 semart, multiple captions for evaluation

In [None]:
# scratch model
# `model` still holds weights loaded by an earlier cell, so it is rebuilt here to
# score genuinely untrained (randomly initialised) parameters.
encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

In [None]:
# Off-the-shelf checkpoint on the SemArt test loader; multi-reference scoring.
# NOTE(review): the other off-the-shelf cells load 'meshed_memory_transformer.pth'
# from the working directory, this one from a sub-folder — confirm both are the same file.
data = torch.load('saved_models/saved_models_region_m2/meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

In [None]:
# Checkpoint fine-tuned on SemArt; multi-reference scoring.
data = torch.load('saved_models/semart_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

In [None]:
# Checkpoint fine-tuned on SemArt with shuffling; multi-reference scoring.
data = torch.load('saved_models/saved_models_saft_region_m2/semart_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

In [None]:
# SemArt shuffle checkpoint with <unk> removed; multi-reference scoring.
# NOTE(review): this call is identical to the preceding shuffle cell, so the <unk>
# removal presumably lives in the decoding code rather than here — confirm.
data = torch.load('saved_models/saved_models_saft_region_m2/semart_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

#### 2.2 semart, single caption for evaluation

In [None]:
# scratch model
# `model` still holds weights loaded by an earlier cell, so it is rebuilt here to
# score genuinely untrained (randomly initialised) parameters.
encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='single')

In [None]:
# Off-the-shelf checkpoint on the SemArt test loader; scored against the first reference caption only.
data = torch.load('meshed_memory_transformer.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# Checkpoint fine-tuned on SemArt; single-reference scoring.
data = torch.load('saved_models/semart_finetune_mulcap.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='single',
)

In [None]:
# Checkpoint fine-tuned on SemArt with shuffling; single-reference scoring.
# NOTE(review): the multi-reference cells load this checkpoint from
# 'saved_models/saved_models_saft_region_m2/' — confirm which location is current.
data = torch.load('saved_models/semart_finetune_mulcap_shuffle.pth')
model.load_state_dict(data['state_dict'])
evaluate_metrics(
    model=model,
    dataloader=dict_semart_test_data_loader,
    text_field=text_field,
    mode='single',
)