In [64]:
import os
import nltk
import pickle
import numpy as np
from PIL import Image
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence
from collections import Counter
from torch.utils.data import Dataset,DataLoader
import re

In [65]:
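# One-time setup: nltk.tokenize.word_tokenize needs NLTK's Punkt tokenizer data.
# If it is not installed yet, uncomment and run the line below once
# (it opens NLTK's downloader; nltk.download('punkt') fetches only that resource).
# nltk.download()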

In [68]:
class Vocab(object):
    """Two-way mapping between tokens and integer ids.

    Ids are handed out consecutively, starting at 0, in the order in which
    tokens are first added via ``add_token``.
    """

    def __init__(self):
        # token -> id
        self.w2i = {}
        # id -> token
        self.i2w = {}
        # Next id to assign to a previously unseen token.
        self.index = 0

    def __call__(self, token):
        """Return the id of ``token``.

        A token that is not in the vocabulary maps to the id of ``'<unk>'``;
        if ``'<unk>'`` has not been added either, 0 is returned.
        """
        return self.w2i.get(token, self.w2i.get('<unk>', 0))

    def __len__(self):
        """Return the number of distinct tokens stored."""
        return len(self.w2i)

    def add_token(self, token):
        """Register ``token`` under the next free id; known tokens are left untouched."""
        if token not in self.w2i:
            self.w2i[token] = self.index
            self.i2w[self.index] = token
            self.index += 1

    @staticmethod
    def build_vocabulary(file_path, threshold):
        """Build a ``Vocab`` from a captions file.

        Each line is expected to look like ``<image id>.jpg<caption text>``.
        Lines starting with ``'image'`` (the header row) and lines without
        ``'.jpg'`` are skipped. Captions are lower-cased, stripped of every
        character that is not ``a-z`` or whitespace, and tokenized with
        ``nltk.tokenize.word_tokenize``.

        Args:
            file_path: Path of the UTF-8 encoded captions file.
            threshold: Minimum number of occurrences a token needs in order
                to be included in the vocabulary.

        Returns:
            A ``Vocab`` whose ids 0-3 are ``'<pad>'``, ``'<start>'``,
            ``'<end>'`` and ``'<unk>'``, followed by the retained tokens in
            order of first appearance in the file.
        """
        counter = Counter()

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.startswith('image'):
                    continue

                # Split on the FIRST '.jpg' only, so a caption that happens to
                # mention ".jpg" is kept whole instead of being cut short.
                parts = line.split('.jpg', 1)
                if len(parts) < 2:
                    continue

                caption = re.sub(r'[^a-z\s]', '', parts[1].lower()).strip()
                counter.update(nltk.tokenize.word_tokenize(caption))

        vocab = Vocab()

        # Special tokens come first so their ids are fixed (0..3).
        for special in ('<pad>', '<start>', '<end>', '<unk>'):
            vocab.add_token(special)

        # Keep only tokens that are frequent enough; Counter preserves
        # first-seen order, so the resulting ids are deterministic.
        for token, cnt in counter.items():
            if cnt >= threshold:
                vocab.add_token(token)

        return vocab

# Build vocabulary
# NOTE: machine-specific location of the captions file; adjust for your environment.
file_path = "C:\\Users\\naman\\Downloads\\flickr30k_images\\captions.txt"
vocab = Vocab.build_vocabulary(file_path, threshold=4)

# Save vocabulary
# os.path.join produces the same '.\\data_dir\\vocabulary.pkl' string on Windows
# and a valid path on other platforms (where a literal backslash is not a separator).
vocab_path = os.path.join('.', 'data_dir', 'vocabulary.pkl')
# Create the output directory up front so open() cannot fail on a fresh checkout.
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)

print("Total vocabulary size: {}".format(len(vocab)))
print("Saved the vocabulary wrapper to '{}'".format(vocab_path))


Total vocabulary size: 8577
Saved the vocabulary wrapper to '.\data_dir\vocabulary.pkl'


In [69]:
# Image preprocessing pipeline; the steps run in the order listed.
transform = transforms.Compose(
    [
        # Random 224x224 crop.
        transforms.RandomCrop(size=224),
        # Mirror left-right with probability 0.5 (the torchvision default).
        transforms.RandomHorizontalFlip(p=0.5),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel normalization; these are the mean/std values commonly
        # paired with ImageNet-pretrained torchvision models.
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)