In [1]:
%pip install nltk
%pip install pycocoevalcap

Collecting pycocoevalcap
  Downloading pycocoevalcap-1.2-py3-none-any.whl (104.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.3/104.3 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pycocoevalcap
Successfully installed pycocoevalcap-1.2


In [4]:
import os
import nltk
import json
import torch
import urllib
import pickle
import zipfile
import numpy as np
from PIL import Image
from collections import Counter
from pycocotools.coco import COCO
from model import Encoder, Decoder
from torchvision import transforms
from collections import defaultdict
from pycocoevalcap.eval import COCOEvalCap
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
os.makedirs('opt' , exist_ok=True)
os.chdir( '/content/opt' )
!git clone 'https://github.com/cocodataset/cocoapi.git'

Cloning into 'cocoapi'...
remote: Enumerating objects: 975, done.[K
remote: Total 975 (delta 0), reused 0 (delta 0), pack-reused 975[K
Receiving objects: 100% (975/975), 11.72 MiB | 16.98 MiB/s, done.
Resolving deltas: 100% (576/576), done.


In [5]:
os.chdir('/content/opt/cocoapi')
annotations_trainval2014 = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'
urllib.request.urlretrieve(annotations_trainval2014 , filename = 'annotations_trainval2014.zip' )

('annotations_trainval2014.zip', <http.client.HTTPMessage at 0x7b82a7394f40>)

In [6]:
with zipfile.ZipFile('annotations_trainval2014.zip' , 'r') as zip_ref:
  zip_ref.extractall( '/content/opt/cocoapi'  )

try:
  os.remove( 'annotations_trainval2014.zip' )
  print('zip removed')
except:
  None

zip removed


In [7]:
os.chdir('/content/opt/cocoapi')

val2014 = 'http://images.cocodataset.org/zips/val2014.zip'

urllib.request.urlretrieve(val2014, 'val2014')

('val2014', <http.client.HTTPMessage at 0x7b82a7395000>)

In [8]:
os.chdir('/content/opt/cocoapi')
with zipfile.ZipFile( 'val2014' , 'r' ) as zip_ref:
  zip_ref.extractall( 'images' )

try:
  os.remove( 'val2014' )
  print('zip removed')
except:
  None

zip removed


In [15]:
class Vocabulary(object):

    def __init__(self,
        vocab_threshold,
        vocab_file='./vocab.pkl',
        start_word="<start>",
        end_word="<end>",
        unk_word="<unk>",
        annotations_file='../cocoapi/annotations/captions_train2014.json',
        vocab_from_file=False):
        """Initialize the vocabulary.
        Args:
          vocab_threshold: Minimum word count threshold.
          vocab_file: File containing the vocabulary.
          start_word: Special word denoting sentence start.
          end_word: Special word denoting sentence end.
          unk_word: Special word denoting unknown words.
          annotations_file: Path for train annotation file.
          vocab_from_file: If False, create vocab from scratch & override any existing vocab_file
                           If True, load vocab from from existing vocab_file, if it exists
        """
        self.vocab_threshold = vocab_threshold
        self.vocab_file = vocab_file
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.annotations_file = annotations_file
        self.vocab_from_file = vocab_from_file
        self.get_vocab()

    def get_vocab(self):
        """Load the vocabulary from file OR build the vocabulary from scratch."""
        if os.path.exists(self.vocab_file) & self.vocab_from_file:
            with open(self.vocab_file, 'rb') as f:
                vocab = pickle.load(f)
                self.word2idx = vocab.word2idx
                self.idx2word = vocab.idx2word
            print('Vocabulary successfully loaded from vocab.pkl file!')
        else:
            self.build_vocab()
            with open(self.vocab_file, 'wb') as f:
                pickle.dump(self, f)

    def build_vocab(self):
        """Populate the dictionaries for converting tokens to integers (and vice-versa)."""
        self.init_vocab()
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        """Initialize the dictionaries for converting tokens to integers (and vice-versa)."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Add a token to the vocabulary."""
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_captions(self):
        """Loop over training captions and add all tokens to the vocabulary that meet or exceed the threshold."""
        coco = COCO(self.annotations_file)
        counter = Counter()
        ids = coco.anns.keys()
        for i, id in enumerate(ids):
            caption = str(coco.anns[id]['caption'])
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)

            if i % 100000 == 0:
                print("[%d/%d] Tokenizing captions..." % (i, len(ids)))

        words = [word for word, cnt in counter.items() if cnt >= self.vocab_threshold]

        for i, word in enumerate(words):
            self.add_word(word)

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

In [16]:
with open('/content/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [17]:
# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_file = 'encoder-3.pkl'
decoder_file = 'decoder-3.pkl'
embed_size = 256
hidden_size = 512
vocab_size = len(vocab)
encoder = Encoder(embed_size).eval().to(device)
decoder = Decoder(embed_size, hidden_size, vocab_size).eval().to(device)
encoder.load_state_dict(torch.load(os.path.join('/content/models', encoder_file), map_location=device))
decoder.load_state_dict(torch.load(os.path.join('/content/models', decoder_file), map_location=device))

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 81.9MB/s]


<All keys matched successfully>

In [18]:
# Image transformation
transform_image = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

In [19]:
def clean_sentence(output):
    cleaned_list = []
    for index in output:
        if  (index == 1) :
            continue
        cleaned_list.append(vocab.idx2word[index])
    cleaned_list = cleaned_list[1:-1] # Discard <start> and <end>

    sentence = ' '.join(cleaned_list) # Convert list of string to
    sentence = sentence.capitalize()
    return sentence

In [25]:
def generate_captions(encoder, decoder, image_path, vocab):
    image = Image.open(image_path).convert('RGB')
    image = transform_image(image).unsqueeze(0).to(device)

    features = encoder(image).unsqueeze(1)
    output = decoder.sample(features)
    sentence = clean_sentence(output)

    return sentence

In [21]:
def generate_all_captions(encoder, decoder, image_dir, dataset, vocab):
    results = []
    for img_id in dataset.imgs:
        img_path = os.path.join(image_dir, dataset.imgs[img_id]['file_name'])
        caption = generate_captions(encoder, decoder, img_path, vocab)
        results.append({'image_id': img_id, 'caption': caption})
    return results

In [22]:
# Load the COCO validation annotations
os.chdir('/content/opt/cocoapi/annotations')
annFile = 'instances_val2014.json'
coco = COCO(annFile)

loading annotations into memory...
Done (t=9.04s)
creating index...
index created!


In [23]:
# Assuming image_dir points to the folder containing validation images
image_dir = '/content/opt/cocoapi/images/val2014'

In [26]:
# Generate captions for the validation images
generated_captions = generate_all_captions(encoder, decoder, image_dir, coco, vocab)

In [36]:
del generated_captions

In [29]:
# Save the generated captions to a JSON file
with open('/content/generated_captions.json', 'w') as f:
    json.dump(generated_captions, f)

In [30]:
# Load the reference captions
coco_caps = COCO('/content/opt/cocoapi/annotations/captions_val2014.json')

loading annotations into memory...
Done (t=1.00s)
creating index...
index created!


In [31]:
# Prepare for evaluation
coco_res = coco_caps.loadRes('/content/generated_captions.json')

Loading and preparing results...
DONE (t=0.10s)
creating index...
index created!


In [32]:
# Create a COCOEval object by passing the reference and generated captions
cocoEval = COCOEvalCap(coco_caps, coco_res)

In [33]:
# Evaluate for the specified image IDs
cocoEval.params['image_id'] = coco_res.getImgIds()

In [34]:
# Evaluate the results for all metrics
cocoEval.evaluate()

tokenization...
setting up scorers...
Downloading stanford-corenlp-3.6.0 for SPICE ...
Progress: 384.5M / 384.5M (100.0%)
Extracting stanford-corenlp-3.6.0 ...
Done.
computing Bleu score...
{'testlen': 406541, 'reflen': 399519, 'guess': [406541, 366037, 325533, 285029], 'correct': [245374, 104370, 38398, 14396]}
ratio: 1.017576135302699
Bleu_1: 0.604
Bleu_2: 0.415
Bleu_3: 0.273
Bleu_4: 0.179
computing METEOR score...
METEOR: 0.188
computing Rouge score...
ROUGE_L: 0.445
computing CIDEr score...
CIDEr: 0.526
computing SPICE score...


CalledProcessError: Command '['java', '-jar', '-Xmx8G', 'spice-1.0.jar', '/usr/local/lib/python3.10/dist-packages/pycocoevalcap/spice/tmp/tmprwbvv0ki', '-cache', '/usr/local/lib/python3.10/dist-packages/pycocoevalcap/spice/cache', '-out', '/usr/local/lib/python3.10/dist-packages/pycocoevalcap/spice/tmp/tmpfmp51nmu', '-subset', '-silent']' died with <Signals.SIGKILL: 9>.

In [35]:
# Print the evaluation scores for each metric
for metric, score in cocoEval.eval.items():
    print(f'{metric}: {score:.3f}')

Bleu_1: 0.604
Bleu_2: 0.415
Bleu_3: 0.273
Bleu_4: 0.179
METEOR: 0.188
ROUGE_L: 0.445
CIDEr: 0.526
