In [1]:
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
from data.dataset import AP_Dataset, APeval_Dataset, SA_Dataset, SAeval_Dataset
import evaluation
from evaluation import PTBTokenizer, Cider
from models.transformer import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttentionMemory, ScaledDotProductAttention
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
import argparse, os, pickle
from tqdm import tqdm
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile
import h5py
from utils import text_progress2, text_progress
import pandas as pd
import json

# Give every random number generator used in this notebook (Python's `random`,
# PyTorch, NumPy) the same fixed seed so that repeated runs are comparable.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(1234)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def evaluate_metrics(model, dataloader, text_field, mode="multiple", is_sample=False, beam_size=5, top_k=5, top_p=0.8):
    """Caption every item yielded by ``dataloader`` and score the captions.

    Relies on the notebook-level names ``device``, ``tqdm``, ``torch``,
    ``itertools`` and ``evaluation``.

    Args:
        model: object exposing ``eval()`` and ``beam_search(...)``.
        dataloader: sized iterable of batches; each batch is indexed with
            ``'roi_feat'`` (moved to ``device`` and fed to the model) and
            ``'cap'`` (the reference captions of each item).
        text_field: provides ``vocab.stoi['<eos>']`` and
            ``decode(out, join_words=False)``.
        mode: ``"multiple"`` scores against every reference caption of an
            item, ``"single"`` against the first reference only.
        is_sample: forwarded to ``model.beam_search``.
        beam_size: forwarded to ``model.beam_search``.
        top_k: forwarded to ``model.beam_search``.
        top_p: forwarded to ``model.beam_search``.

    Returns:
        The first element of
        ``evaluation.compute_scores(gts, gen, spice=False)``.

    Raises:
        ValueError: if ``mode`` is neither ``"multiple"`` nor ``"single"``.
    """
    # Fail fast: an unknown mode would otherwise leave `gts` empty and only
    # surface later as a confusing scoring error.
    if mode not in ("multiple", "single"):
        raise ValueError("mode must be 'multiple' or 'single', got %r" % (mode,))

    print(dataloader)
    print(mode)
    model.eval()
    gen = {}
    gts = {}
    eos_idx = text_field.vocab.stoi['<eos>']
    with tqdm(desc='evalulateion metrics', unit='it', total=len(dataloader)) as pbar:
        for it, batch in enumerate(dataloader):
            images = batch['roi_feat'].to(device)
            caps_gt = batch['cap']
            with torch.no_grad():
                # 20 is the second positional argument of beam_search
                # (presumably the maximum caption length — TODO confirm).
                # top_k / top_p are forwarded from the caller instead of
                # being pinned to 5 / 0.8.
                out, _ = model.beam_search(images, 20, eos_idx, beam_size, out_size=1,
                                           is_sample=is_sample, top_k=top_k, top_p=top_p)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                key = '%d_%d' % (it, i)
                # groupby collapses runs of identical consecutive tokens.
                gen_i = ' '.join(k for k, _ in itertools.groupby(gen_i))
                gen[key] = [gen_i]
                gts[key] = gts_i if mode == "multiple" else [gts_i[0]]
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen, spice=False)
    return scores

In [3]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Model
# Text field with <bos>/<eos> markers; its vocabulary is restored from disk.
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
# NOTE: unpickling executes arbitrary code — only load a vocab.pkl you trust.
# The context manager closes the file handle as soon as the vocab is read.
with open('vocab.pkl', 'rb') as vocab_file:
    text_field.vocab = pickle.load(vocab_file)

# Encoder/decoder hyper-parameters are positional; see models.transformer for
# their meaning. The decoder is sized by the vocabulary and told the <pad> index.
encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': 40})
decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

## 1. Artpedia

In [9]:
# Artpedia inputs, grouped by split: each split has an index array (.npy)
# and a region-feature HDF5 file opened read-only.
train_myidx = np.load('../Dataset/artpedia/train_myidx.npy')
ap_train_dataset = h5py.File("../Dataset/artpedia/ap_train_region.hdf5", "r")

val_myidx = np.load('../Dataset/artpedia/val_myidx.npy')
ap_val_dataset = h5py.File("../Dataset/artpedia/ap_val_region.hdf5", "r")

test_myidx = np.load('../Dataset/artpedia/test_myidx.npy')
ap_test_dataset = h5py.File("../Dataset/artpedia/ap_test_region.hdf5", "r")

print("loading data: done!!!")

loading data: done!!!


In [10]:
# artpedia dataset
dict_artpedia_test = APeval_Dataset(ap_test_dataset, test_myidx, text_field, max_detections=50, feature_type="detector", lower=True, remove_punctuation=True, tokenize='spacy')

# artpedia, dataloader
# text_progress2 is passed directly as the collate function: the wrapping
# lambda added nothing, and lambdas cannot be pickled if worker processes are
# ever enabled on this loader.
dict_artpedia_test_data_loader = TorchDataLoader(dict_artpedia_test, batch_size=50, collate_fn=text_progress2)

#### 1.2 Artpedia: multiple captions for evaluation, beam search

In [8]:
# *** region stdtr
# scratch model (this cell loads no checkpoint before evaluating)
evaluate_metrics(
    model=model,
    dataloader=dict_artpedia_test_data_loader,
    text_field=text_field,
    mode='multiple',
)

<torch.utils.data.dataloader.DataLoader object at 0x7ff2d81fafd0>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:04<00:00,  1.48it/s]


{'BLEU': [0.00524318430199016,
  2.836729807696152e-11,
  5.072564320885483e-14,
  2.1766323754403583e-15],
 'METEOR': 0.015595644337142557,
 'ROUGE': 0.004175462859047491,
 'CIDEr': 0.0009228591392197886}

In [9]:
# *** region stdtr
# without fine-tuning, off-the-shelf model
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models_region_std/region_std_best.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7ff2d81fafd0>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:02<00:00,  2.42it/s]


{'BLEU': [0.15064964497653152,
  0.05254520054357977,
  0.012394490245029106,
  7.835026833744943e-07],
 'METEOR': 0.034218845998585906,
 'ROUGE': 0.1514563893431952,
 'CIDEr': 0.007221904742579137}

In [10]:
# ** std transformer
# fine-tune on artpedia, one image with multiple captions for training      shuffle
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models_region_std_apft/region_std_apft_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7ff2d81fafd0>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:02<00:00,  2.42it/s]


{'BLEU': [0.257237276372721,
  0.1211514887443507,
  0.05465593843534876,
  0.0254420322759611],
 'METEOR': 0.06266429495902702,
 'ROUGE': 0.21395755142357425,
 'CIDEr': 0.02933599737380237}

In [44]:
# ** std transformer
# fine-tune on artpedia, one image with multiple captions for training      shuffle
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models_region_std_apft/region_std_apft_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f3fe39d1e48>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:07<00:00,  1.01s/it]


{'BLEU': [0.33371795247170943,
  0.15761846241713176,
  0.07185717110044607,
  0.03310101342775185],
 'METEOR': 0.08320607069392956,
 'ROUGE': 0.22028913774749376,
 'CIDEr': 0.057944515408255085}

In [41]:
file_path = "../Dataset/artpedia/artpedia_region_features/1420.npz"
# np.load on an .npz archive keeps the file open; the context manager closes
# it once the 'x' array has been read into memory.
with np.load(file_path) as npz_archive:
    region_features = npz_archive['x']
print("--------------")
print(file_path[:-4])  # the path without its ".npz" suffix
# unsqueeze(0) adds the leading axis that wrapping the array in a list used to
# add, without torch.tensor()'s slow list-of-ndarrays conversion path.
img = torch.tensor(region_features).unsqueeze(0).to(device)

--------------
../Dataset/artpedia/artpedia_region_features/1420


In [42]:
# Decode one caption for `img` with beam size 1 and sampling disabled.
# As in evaluate_metrics, switch to eval mode and disable autograd so no
# gradient graph is built during inference.
model.eval()
with torch.no_grad():
    out, _ = model.beam_search(img, 20, text_field.vocab.stoi['<eos>'], 1, out_size=1, is_sample=False)

In [43]:
# Decode the first (only) generated sequence into tokens, drop consecutive
# repeats of the same token, and join the rest into the caption string.
decoded_tokens = text_field.decode(out, join_words=False)[0]
caps_gen = ' '.join(token for token, _ in itertools.groupby(decoded_tokens))
caps_gen

'it shows a nude woman with a large mirror'

In [13]:
# Very good result; hidden
# ** std transformer
# fine-tune on artpedia, one image with multiple captions for training      shuffle
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models_region_std_apft/region_std_apft_last_21epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7ff2d81fafd0>
multiple


evalulateion metrics: 100%|██████████| 7/7 [00:02<00:00,  2.43it/s]


{'BLEU': [0.25325047864758254,
  0.1174482210374481,
  0.051749822530998496,
  0.02653207663796888],
 'METEOR': 0.06306414530109579,
 'ROUGE': 0.2156961450358094,
 'CIDEr': 0.037402463240044696}

In [22]:
# fine-tune on artpedia, one image with multiple captions for training      shuffle
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode is left at its default ("multiple").
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field)

<torch.utils.data.dataloader.DataLoader object at 0x7fd22eadd0f0>
multiple


evalulateion metrics: 100%|██████████| 329/329 [01:03<00:00,  5.18it/s]


{'BLEU': [0.23290164145038766,
  0.11628312648636245,
  0.052702077361140964,
  0.02689678159915748],
 'ROUGE': 0.21914811563939154,
 'CIDEr': 0.03148247854508848}

In [8]:
# beam search   remove unk
# fine-tune on artpedia, one image with multiple captions for training      shuffle
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/artpedia_finetune_mulcap_shuffle.pth', map_location=device)
model.load_state_dict(data['state_dict'])
# mode is left at its default ("multiple"); sampling is explicitly disabled.
evaluate_metrics(model, dict_artpedia_test_data_loader, text_field, is_sample=False)

<torch.utils.data.dataloader.DataLoader object at 0x7f49562804e0>
multiple


evalulateion metrics: 100%|██████████| 329/329 [01:05<00:00,  5.00it/s]


{'BLEU': [0.33960716243253547,
  0.16180978856929548,
  0.0747107343165195,
  0.035472046837868663],
 'ROUGE': 0.22451221642237176,
 'CIDEr': 0.06145132500410184}

## 2. SemArt

In [5]:
sa_test_csv = pd.read_csv("../Dataset/SemArt/prediction_csvs/semart_test_prediction.csv")
# Keep only the rows whose 'predictioin' value is 0. The spelling is kept
# as-is because it is presumably the CSV's actual column header — verify.
sa_test_csv = sa_test_csv[sa_test_csv['predictioin']==0]
test_roi_feats = h5py.File("../Dataset/SemArt/sa_test_roi.hdf5", "r")
test_img_names = np.unique(sa_test_csv['img_name'].to_numpy())
# The context manager closes the JSON file as soon as it has been parsed.
with open('../Dataset/SemArt/test_img_caps_map.json') as caps_map_file:
    test_img_caps_map = json.load(caps_map_file)

dict_semart_test = SAeval_Dataset(sa_test_csv, test_img_names, test_img_caps_map, test_roi_feats, text_field, max_detections=50, lower=True, remove_punctuation=True, tokenize='spacy')
# text_progress2 is passed directly as the collate function: the wrapping
# lambda added nothing, and lambdas cannot be pickled if worker processes are
# ever enabled on this loader.
dict_semart_test_data_loader = TorchDataLoader(dict_semart_test, batch_size=10,
                                  collate_fn=text_progress2)

#### 2.1 SemArt: multiple captions for evaluation

In [8]:
# off-the-shelf model   ***
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/saved_models_region_std/region_std_best.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fc12d21fdd8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:13<00:00,  4.44it/s]


{'BLEU': [0.10140459597308837,
  0.03169078168216852,
  0.007083860979364486,
  4.123190171219849e-07],
 'METEOR': 0.030852755825836443,
 'ROUGE': 0.13096962696319228,
 'CIDEr': 0.004329042746890157}

In [11]:
# fine-tune on SemArt shuffle ***
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/saved_models_saft_region_std/sa_regionstd_sa_last_15epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fc12d21fdd8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:09<00:00,  6.67it/s]


{'BLEU': [0.1837182108351129,
  0.09057445922520133,
  0.04742767887883237,
  0.028502222867494566],
 'METEOR': 0.06295136197989658,
 'ROUGE': 0.21248080847379314,
 'CIDEr': 0.06793464689087193}

In [12]:
# fine-tune on SemArt shuffle ***
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/saved_models_saft_region_std/sa_regionstd_sa_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fc12d21fdd8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:09<00:00,  6.71it/s]


{'BLEU': [0.19369142251464125,
  0.09451746759851255,
  0.05058007094099744,
  0.03200999646136212],
 'METEOR': 0.06409278986758879,
 'ROUGE': 0.20886380803394036,
 'CIDEr': 0.08181295820994731}

In [14]:
# fine-tune on SemArt shuffle ***
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/saved_models_saft_region_std/sa_regionstd_sa_best_9epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7fc12d21fdd8>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:09<00:00,  6.66it/s]


{'BLEU': [0.1844762618161031,
  0.09035194568173129,
  0.047122917335089914,
  0.028681345432281263],
 'METEOR': 0.06284682555449271,
 'ROUGE': 0.21586449108361624,
 'CIDEr': 0.06075225033169698}

In [7]:
# fine-tune on SemArt shuffle ***
# remove <unk>
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/saved_models_saft_region_std/sa_regionstd_sa_best_11epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f2b77f55a90>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:39<00:00,  1.57it/s]


{'BLEU': [0.2973848290118472,
  0.14532965320560803,
  0.06940698793346775,
  0.039141015893592895],
 'METEOR': 0.08113096148741637,
 'ROUGE': 0.23387698234426282,
 'CIDEr': 0.07548107119948975}

In [9]:
# fine-tune on SemArt shuffle ***
# remove <unk>
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on a machine that lacks the device it was saved from.
data = torch.load('saved_models/saved_models_saft_region_std/sa_regionstd_sa_last_17epoch.pth', map_location=device)
model.load_state_dict(data['state_dict'])
evaluate_metrics(model, dict_semart_test_data_loader, text_field, mode='multiple')

<torch.utils.data.dataloader.DataLoader object at 0x7f3f4b485438>
multiple


evalulateion metrics: 100%|██████████| 62/62 [00:14<00:00,  4.35it/s]


{'BLEU': [0.2998353728530462,
  0.1468716489820476,
  0.07456873559596515,
  0.04307386701642498],
 'METEOR': 0.08242227112688057,
 'ROUGE': 0.2251958917282386,
 'CIDEr': 0.10360548719242094}