In [None]:
# # download images and annotations to the data directory
# !wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip -P ./data_dir/ -P ./data_dir/
# !wget http://images.cocodataset.org/zips/train2014.zip -P ./data_dir/
# !wget http://images.cocodataset.org/zips/val2014.zip -P ./data_dir/
# # extract zipped images and annotations and remove the zip files
# !unzip ./data_dir/annotations_trainval2014.zip -d ./data_dir/
# !rm ./data_dir/annotations_trainval2014.zip
# !unzip ./data_dir/train2014.zip -d ./data_dir/
# !rm ./data_dir/train2014.zip
# !unzip ./data_dir/val2014.zip -d ./data_dir/
# !rm ./data_dir/val2014.zip

# Preprocessing caption data

In [1]:
import nltk
from pycocotools.coco import COCO
from collections import Counter, OrderedDict
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
from torch.nn.utils.rnn import pack_padded_sequence

In [2]:
nltk.download("punkt")

[nltk_data] Downloading package punkt to
[nltk_data]     /home/silly_ronny/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
TRAIN_PATH = "data_dir/annotations/captions_train2014.json"
VAL_PATH = "data_dir/annotations/captions_val2014.json"
TRAIN_IMAGE_PATH = "data_dir/train2014"
VAL_IMAGE_PATH = "data_dir/val2014"

In [4]:
import json
with open(TRAIN_PATH, 'r') as f:
            # annotations = json.load(f)['annotations']
            image_map = {img['id']: img['file_name'] for img in json.load(f)['images']}

In [5]:
from torch.utils.data import Dataset
import json
import os
from PIL import Image

class CocoCaptionDataset(Dataset):
    def __init__(self, annotation_file, image_dir, transforms=None, tokenizer=None):
        self.image_dir = image_dir
        self.transforms = transforms
        self.tokenizer = tokenizer
        
        # Load and parse annotations
        with open(annotation_file, 'r') as f:
            out = json.load(f)
        
        self.annotations = out['annotations']
        self.image_map = {img['id']: img['file_name'] for img in out['images']}
        
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, index):
        annotations = self.annotations[index]
        caption = annotations['caption']
        image_id = annotations['image_id']
        image_path = os.path.join(self.image_dir, self.image_map[image_id])
        
        try:
            # Load the image
            image = Image.open(image_path)
            
            # Ensure the image is in RGB format
            if image.mode != 'RGB':
                image = image.convert('RGB')
            
            # Apply transformations if specified
            if self.transforms:
                image = self.transforms(image)
            
            # Tokenize the caption if tokenizer is provided
            if self.tokenizer:
                caption = self.tokenizer(caption)
            
            return image, caption
        
        except Exception as e:
            print(f"Error processing file {image_path}: {e}")
            # Optional: Return a placeholder image and tokenized fallback caption
            placeholder_image = Image.new('RGB', (224, 224))  # Example: Blank image
            placeholder_caption = "<UNK>" if not self.tokenizer else self.tokenizer("<UNK>")
            return placeholder_image, placeholder_caption


In [6]:
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
weights = EfficientNet_V2_S_Weights.DEFAULT



train_transforms = weights.transforms()
val_transforms = weights.transforms()

tokenizer = nltk.word_tokenize

train_dataset = CocoCaptionDataset(TRAIN_PATH, TRAIN_IMAGE_PATH,
                                   transforms=train_transforms, tokenizer=tokenizer)
val_dataset = CocoCaptionDataset(VAL_PATH, VAL_IMAGE_PATH,
                                   transforms=val_transforms, tokenizer=tokenizer)

In [7]:
from typing import List
import string
import unicodedata
from tqdm import tqdm

class Vocab:
    def __init__(self):
        self.word2idx = {"<unk>": 0, "<pad>": 1}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.n = 2
    
    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.n
            self.idx2word[self.n] = word
            self.n = self.n+1
    
    def build_with_sentence(self, sentence, tokenizer, tokenized=False):
        if tokenized:
            for word in self.process_sentence(sentence):
                self.add_word(word)
        else:
            for word in self.process_sentence(tokenizer(sentence)):
                self.add_word(word)
            
    
    def build_vocab(self, list_of_sentences: List[str]):
        for sentence in tqdm(list_of_sentences):
            self.build_with_sentence(sentence=sentence)
            

    def process_sentence(self, tokens):
        """
        Normalize tokens by:
        - Converting to lowercase
        - Removing punctuation
        - Stripping whitespace
        - Normalizing Unicode characters to their ASCII equivalents
        """
        normalized_tokens = []
        for token in tokens:
            # Normalize Unicode characters to ASCII
            token = unicodedata.normalize('NFKD', token).encode('ascii', 'ignore').decode('utf-8')
            # Convert to lowercase
            token = token.lower()
            # Remove punctuation and strip whitespace
            token = token.strip(string.punctuation).strip()
            # Add to the result list if not empty
            if token:
                normalized_tokens.append(token)
        return normalized_tokens
        
    def __getitem__(self, word):
        return self.word2idx.get(word, 0)

In [None]:
# vocab = Vocab()
# for _, caption in tqdm(train_dataset):
#     vocab.build_with_sentence(caption, tokenizer=tokenizer, tokenized=True)

In [12]:
# for _, caption in tqdm(val_dataset):
#     vocab.build_with_sentence(caption, tokenizer, True)

100%|██████████| 202654/202654 [19:50<00:00, 170.20it/s]


In [9]:
import joblib
# joblib.dump(vocab, "vocab.z")
vocab = joblib.load("vocab_v2.z")

In [10]:
vocab.word2idx

{'<unk>': 1,
 '<pad>': 0,
 'a': 2,
 'very': 3,
 'clean': 4,
 'and': 5,
 'well': 6,
 'decorated': 7,
 'empty': 8,
 'bathroom': 9,
 'panoramic': 10,
 'view': 11,
 'of': 12,
 'kitchen': 13,
 'all': 14,
 'its': 15,
 'appliances': 16,
 'blue': 17,
 'white': 18,
 'with': 19,
 'butterfly': 20,
 'themed': 21,
 'wall': 22,
 'tiles': 23,
 'photo': 24,
 'dining': 25,
 'room': 26,
 'graffiti-ed': 27,
 'stop': 28,
 'sign': 29,
 'across': 30,
 'the': 31,
 'street': 32,
 'from': 33,
 'red': 34,
 'car': 35,
 'vandalized': 36,
 'beetle': 37,
 'on': 38,
 'road': 39,
 'border': 40,
 'butterflies': 41,
 'paint': 42,
 'walls': 43,
 'above': 44,
 'it': 45,
 'an': 46,
 'angled': 47,
 'beautifully': 48,
 'two': 49,
 'people': 50,
 'are': 51,
 'walking': 52,
 'down': 53,
 'beach': 54,
 'sink': 55,
 'toilet': 56,
 'inside': 57,
 'small': 58,
 'black': 59,
 'square': 60,
 'tile': 61,
 'floor': 62,
 'that': 63,
 'needs': 64,
 'repairs': 65,
 'vanity': 66,
 'contains': 67,
 'sinks': 68,
 'towel': 69,
 'for': 70,
 

In [11]:
# joblib.dump(vocab, "vocab_v2.z")

In [49]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
def collate_fn(batch):
    images, captions = zip(*batch)
    tokens = [torch.tensor([vocab[token] for token in caption])
              for caption in captions]
    lengths = [len(caption) for caption in captions]
    captions = pad_sequence(tokens, batch_first=True)
    # lengths = captions.shape[-2]
    # captions = pack_padded_sequence(input=captions, batch_first=True, lengths=lengths, enforce_sorted=False)
    return torch.stack(images), captions, lengths
    
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

In [51]:
for images, captions, lens in train_dataloader:
    print(images.shape, captions, lens)
    break


torch.Size([4, 3, 384, 384]) tensor([[   0,   12, 1538,   38,  258,   12,   31, 3108,   19,  701,    5,    2,
          617,    0,    0],
        [   0,  339, 5553,  578,  412,  330,   38,  258,   12,   31,  301,   12,
            2,  899,    0],
        [   0, 1053,  170,   78,    2,  443,   12,  279,   19, 2499,    0,    0,
            0,    0,    0],
        [   0,  169,  128,  927,   31,  174,  267,   12,   50,    0,    0,    0,
            0,    0,    0]]) [13, 15, 11, 9]


In [28]:
import torch
import torch.nn as nn

class VisionEncoder(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        self.model = efficientnet_v2_s(weights=weights)
        self.model.classifier = nn.Sequential(nn.Linear(1280, embedding_size))
        self.batch_norm = nn.BatchNorm1d(embedding_size, momentum=0.01)
    
    def forward(self, input_images):
        with torch.no_grad():
            features = self.model(input_images)
            features = self.batch_norm(features)
        
        return features

In [29]:
VisionEncoder(768)(torch.randn(4, 3, 384, 384)).shape

torch.Size([4, 768])

In [30]:
sum([p.numel() for p in VisionEncoder(768).parameters()])

21162832

In [62]:
class TextDecoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, vocabulary_size, num_layers, max_seq_len=20):
        super().__init__()
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)
        self.linear_layer = nn.Linear(hidden_size, vocabulary_size)
        self.max_seq_len = max_seq_len
    
    def forward(self, input_features, capts, lens):
        embeddings = self.embedding(capts)
        features = torch.cat([input_features.unsqueeze(1), embeddings], dim=1)
        features = pack_padded_sequence(features, lens, batch_first=True, enforce_sorted=False)
        hidden, _ = self.rnn(features)
        return self.linear_layer(hidden[0])
    
    def sample(self, input_features, lstm_states):
        sampled_indices = []
        lstm_inputs = input_features.unsqueeze(1)
        for i in range(self.max_seq_len):
            hidden_variables, lstm_states = self.lstm_layer(lstm_inputs, lstm_states)          # hiddens: (batch_size, 1, hidden_size)
            model_outputs = self.linear_layer(hidden_variables.squeeze(1))            # outputs:  (batch_size, vocab_size)
            _, predicted_outputs = model_outputs.max(1)                        # predicted: (batch_size)
            sampled_indices.append(predicted_outputs)
            lstm_inputs = self.embedding_layer(predicted_outputs)                       # inputs: (batch_size, embed_size)
            lstm_inputs = lstm_inputs.unsqueeze(1)                         # inputs: (batch_size, 1, embed_size)
        sampled_indices = torch.stack(sampled_indices, 1)                # sampled_ids: (batch_size, max_seq_length)
        return sampled_indices
        

In [72]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)
for image, captions, lens in train_dataloader:
    print(TextDecoder(768, 512, len(vocab.word2idx), 1)(torch.randn((4, 768)), captions, lens).shape)
    break

torch.Size([39, 29645])


# Training

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [73]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)
encoder_model = VisionEncoder(256).to(device)
decoder_model = TextDecoder(256, 512, len(vocab.word2idx), 1).to(device)

criterion = nn.CrossEntropyLoss()
parameters = list(decoder_model.parameters()) + \
    list(encoder_model.model.classifier.parameters()) + \
    list(encoder_model.batch_norm.parameters())

optimizer = torch.optim.Adam(parameters, lr=0.001)

In [None]:
total_num_steps = len(train_dataloader)
import numpy as np

for epoch in range(1):
    for i, (imgs, capts, lens) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        capts = capts.to(device)
        tgts = pack_padded_sequence(capts, lens, batch_first=True, enforce_sorted=False)
        optimizer.zero_grad()
        
        feats = encoder_model(imgs)
        output = decoder_model(feats, capts, lens)

        loss = criterion(output, tgts.data)
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, 5, i, total_num_steps, loss.item(), np.exp(loss.item()))) 
 
        # Save the model checkpoints
        if (i+1) % 1000 == 0:
            torch.save(decoder_model.state_dict(), os.path.join(
                'models_dir/', 'decoder-{}-{}.pt'.format(epoch+1, i+1)))
            torch.save(encoder_model.state_dict(), os.path.join(
                'models_dir/', 'encoder-{}-{}.pt'.format(epoch+1, i+1)))

Epoch [0/5], Step [0/103529], Loss: 10.3047, Perplexity: 29874.0555
Epoch [0/5], Step [10/103529], Loss: 9.5898, Perplexity: 14614.5124
Epoch [0/5], Step [20/103529], Loss: 6.7709, Perplexity: 872.0854
Epoch [0/5], Step [30/103529], Loss: 6.0673, Perplexity: 431.5164
Epoch [0/5], Step [40/103529], Loss: 6.4110, Perplexity: 608.4881
Epoch [0/5], Step [50/103529], Loss: 6.7764, Perplexity: 876.9159
Epoch [0/5], Step [60/103529], Loss: 6.0911, Perplexity: 441.9282
Epoch [0/5], Step [70/103529], Loss: 6.3808, Perplexity: 590.3780
Epoch [0/5], Step [80/103529], Loss: 5.5497, Perplexity: 257.1693
Epoch [0/5], Step [90/103529], Loss: 6.0465, Perplexity: 422.6142
Epoch [0/5], Step [100/103529], Loss: 5.5821, Perplexity: 265.6391
Epoch [0/5], Step [110/103529], Loss: 5.7238, Perplexity: 306.0673
Epoch [0/5], Step [120/103529], Loss: 5.9305, Perplexity: 376.3278
Epoch [0/5], Step [130/103529], Loss: 6.4169, Perplexity: 612.0779
Epoch [0/5], Step [140/103529], Loss: 6.5135, Perplexity: 674.2010
E