In [1]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

from tqdm.notebook import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split

In [2]:
from evaluate import load

wer_metric = load("wer")

2025-04-22 12:19:09.792550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745306349.808983    6417 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745306349.814074    6417 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1745306349.827436    6417 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745306349.827450    6417 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745306349.827452    6417 computation_placer.cc:177] computation placer alr

In [3]:
batch_size = 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
data = json.load(open('annotations/captions_train2014.json'))

In [5]:
len(data['annotations'])

414113

In [6]:
id2filename = {i['id']:i['file_name'] for i in data['images']}

In [7]:
df = pd.DataFrame(data['annotations'])

In [8]:
df['file_name'] = df['image_id'].map(id2filename)

In [9]:
df

Unnamed: 0,image_id,id,caption,file_name
0,318556,48,A very clean and well decorated empty bathroom,COCO_train2014_000000318556.jpg
1,116100,67,A panoramic view of a kitchen and all of its a...,COCO_train2014_000000116100.jpg
2,318556,126,A blue and white bathroom with butterfly theme...,COCO_train2014_000000318556.jpg
3,116100,148,A panoramic photo of a kitchen and dining room,COCO_train2014_000000116100.jpg
4,379340,173,A graffiti-ed stop sign across the street from...,COCO_train2014_000000379340.jpg
...,...,...,...,...
414108,133071,829655,a slice of bread is covered with a sour cream ...,COCO_train2014_000000133071.jpg
414109,410182,829658,A long plate hold some fries with some sliders...,COCO_train2014_000000410182.jpg
414110,180285,829665,Two women sit and pose with stuffed animals.,COCO_train2014_000000180285.jpg
414111,133071,829693,White Plate with a lot of guacamole and an ext...,COCO_train2014_000000133071.jpg


In [10]:
words = sorted(list(set(' '.join(df['caption']).split())))
print(len(words))

44535


In [11]:
vocabulary = ["[PAD]"] + words
print(len(vocabulary))
idx2word = {k:v for k,v in enumerate(vocabulary, start=0)}
word2idx = {v:k for k,v in idx2word.items()}

44536


In [12]:
train_df, test_df = train_test_split(df, test_size=0.1)
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

In [13]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, root_dir, df, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.df['file_name'][idx]
        text = self.df['caption'][idx]
        image = Image.open(self.root_dir + file_name).convert("RGB")
        image = self.transform(image)
        return image, text
    
    def transform(self, image):
        
        transform_ops = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        return transform_ops(image)

In [14]:
train_dataset = ImageCaptioningDataset(root_dir='train2014/',
                           df=train_df)
eval_dataset = ImageCaptioningDataset(root_dir='train2014/',
                           df=test_df)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=10, shuffle=True)
test_loader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=10, shuffle=False)
print(len(train_loader), len(test_loader))

5824 648


In [15]:
num_words = len(word2idx)
print(num_words)
rnn_hidden_size = 256

44536


In [16]:
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

In [17]:
class CRNN(nn.Module):
    
    def __init__(self, num_words, rnn_hidden_size=256, dropout=0.1):
        
        super(CRNN, self).__init__()
        self.num_cwords = num_words
        self.rnn_hidden_size = rnn_hidden_size
        self.dropout = dropout
        
        resnet_modules = list(resnet.children())[:-3]
        self.cnn = nn.Sequential(
            *resnet_modules
        )
        
        self.linear1 = nn.Linear(14336, rnn_hidden_size, bias=False)
        
        self.rnn1 = nn.GRU(input_size=rnn_hidden_size, 
                            hidden_size=rnn_hidden_size,
                            bidirectional=True, 
                            batch_first=True)
        self.rnn2 = nn.GRU(input_size=rnn_hidden_size*2, 
                            hidden_size=rnn_hidden_size*2,
                            bidirectional=True, 
                            batch_first=True)
        self.linear2 = nn.Linear(self.rnn_hidden_size*4, num_words)
        
    def forward(self, batch):
        
        batch = self.cnn(batch)
        
        batch = batch.permute(0, 3, 1, 2) # [batch_size, width, channels, height]
         
        batch_size = batch.size(0)
        width = batch.size(1)
        batch = batch.view(batch_size, width, -1) # [batch_size, T==width, num_features==channels*height]
        
        batch = self.linear1(batch)
        
        batch, hidden = self.rnn1(batch)
        
        batch, hidden = self.rnn2(batch)
        
        batch = self.linear2(batch)
        
        batch = batch.permute(1, 0, 2)
        return batch

In [18]:
crnn = CRNN(num_words, rnn_hidden_size=rnn_hidden_size)
crnn = crnn.to(device)

In [19]:
def encode_text(text):
    
    text_batch_targets = [word2idx[c] for c in text.split()][:14]
    text_batch_targets = text_batch_targets + [0] * (14 - len(text_batch_targets))
    text_batch_targets = torch.LongTensor(text_batch_targets)
    
    return text_batch_targets.unsqueeze(0)

In [20]:
def decode_predictions(text_batch_logits):

    text_batch_tokens = text_batch_logits.argmax(2) # [T, batch_size]
    text_batch_tokens = text_batch_tokens.numpy().T # [batch_size, T]

    text_batch_tokens_new = []
    for text_tokens in text_batch_tokens:
        text = [idx2word[idx] for idx in text_tokens if idx != 0]
        text = " ".join(text)
        text_batch_tokens_new.append(text)

    return text_batch_tokens_new

In [21]:
num_epochs = 25
lr = 1e-3
clip_norm = 5

criterion = nn.CTCLoss(blank=0)
optimizer = optim.AdamW(crnn.parameters(), lr=lr)

In [22]:
def compute_loss(text_batch, text_batch_logits):
    """
    text_batch: list of strings of length equal to batch size
    text_batch_logits: Tensor of size([T, batch_size, num_classes])
    """

    text_batch_targets = torch.cat([encode_text(text) for text in text_batch]).to(device)
    target_lengths = [int((i > 0).sum()) for i in text_batch_targets]
    
    loss = criterion(
        nn.functional.log_softmax(text_batch_logits, dim=2), 
        text_batch_targets, 
        input_lengths=[14]*len(target_lengths), 
        target_lengths=target_lengths
    )
    return loss

In [23]:
scaler = torch.amp.GradScaler('cuda', enabled = True)

epoch_losses = []
iteration_losses = []
num_updates_epochs = []
for epoch in tqdm(range(1, num_epochs+1)):
    
    crnn.train()
    
    epoch_loss_list = [] 
    num_updates_epoch = 0
    for image_batch, text_batch in tqdm(train_loader, leave=False):
        optimizer.zero_grad()
        
        with torch.amp.autocast('cuda', enabled = True):
            text_batch_logits = crnn(image_batch.to(device))
            loss = compute_loss(text_batch, text_batch_logits)
            
        iteration_loss = loss.item()
        
        if iteration_loss == float('inf'):
            continue
          
        epoch_loss_list.append(iteration_loss)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(crnn.parameters(), clip_norm)
        scaler.step(optimizer)

        scaler.update()

    crnn.eval()
    
    pred_str = []
    label_str = []
    for image_batch, text_batch in tqdm(test_loader, leave=False):
        with torch.amp.autocast('cuda', enabled = True):
            text_batch_logits = crnn(image_batch.to(device))
            
        pred_text_batch = decode_predictions(text_batch_logits.cpu())
        
        pred_str += pred_text_batch
        label_str += text_batch
        
    epoch_loss = np.mean(epoch_loss_list)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    print()
    print(f"Epoch:{epoch}    Loss:{epoch_loss}   WER:{wer}")
    print()
    for p, l in zip(pred_str[:10], label_str[:10]):
        print(l, '->', p)

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:1    Loss:5.1446688430858085   WER:0.8169326445937329

Men dressed as bears are riding a motorcycle -> A motorcycle of a street.
Two baby giraffes are standing in a grassy field.  -> A grass in a field.
Two women are sitting on a bench talking -> A man of standing on a horse.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A sitting on a street.
some elephants in some tall green grass and some trees -> A in a field.
A young blonde girl holding up 2 cell phones. -> A man is front of a phone.
A bicycle chained to two poles mounted to the ground  -> A man motorcycle on a street.
Man in red and black wet suit riding surfboard. -> A man is surfboard on a surfboard.
A big pair of scissors sticking in something by a paper. -> A next on a it.
a single crane in mid-flight along a tree scape -> A grass in a field.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:2    Loss:4.696300023512225   WER:0.804455399798553

Men dressed as bears are riding a motorcycle -> A man is motorcycle on a motorcycle.
Two baby giraffes are standing in a grassy field.  -> A giraffe in grass in a field.
Two women are sitting on a bench talking -> A man of standing on a other.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A top on a it.
some elephants in some tall green grass and some trees -> A herd of field in a field.
A young blonde girl holding up 2 cell phones. -> A man is front in a refrigerator.
A bicycle chained to two poles mounted to the ground  -> A motorcycle parked on a road.
Man in red and black wet suit riding surfboard. -> A man in surfboard on the surfboard.
A big pair of scissors sticking in something by a paper. -> A close next of a it.
a single crane in mid-flight along a tree scape -> A bird grass in a field.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:3    Loss:4.511418543368255   WER:0.797899128602715

Men dressed as bears are riding a motorcycle -> A man motorcycle on a motorcycle.
Two baby giraffes are standing in a grassy field.  -> A giraffes in in a field.
Two women are sitting on a bench talking -> A man and next on a street.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A contents items of a table.
some elephants in some tall green grass and some trees -> A of elephants in field.
A young blonde girl holding up 2 cell phones. -> A man in room in a room.
A bicycle chained to two poles mounted to the ground  -> A bicycle parked on a hydrant.
Man in red and black wet suit riding surfboard. -> A man riding a surfboard on the water.
A big pair of scissors sticking in something by a paper. -> A on next on a table.
a single crane in mid-flight along a tree scape -> A bird grass on a grass.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:4    Loss:4.366688799645777   WER:0.7955381317168282

Men dressed as bears are riding a motorcycle -> A next on luggage.
Two baby giraffes are standing in a grassy field.  -> Two giraffes in in a field.
Two women are sitting on a bench talking -> A man next in a phone.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A next on table.
some elephants in some tall green grass and some trees -> A herd of elephants in field.
A young blonde girl holding up 2 cell phones. -> A woman girl front in a computer.
A bicycle chained to two poles mounted to the ground  -> A next on a meter.
Man in red and black wet suit riding surfboard. -> A man is surfboard in the surfing.
A big pair of scissors sticking in something by a paper. -> A pair scissors on a surface.
a single crane in mid-flight along a tree scape -> A giraffe in grass in a grass.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:5    Loss:4.202155543639228   WER:0.7870274540968609

Men dressed as bears are riding a motorcycle -> A man motorcycle on a motorcycle.
Two baby giraffes are standing in a grassy field.  -> A giraffe in in a field.
Two women are sitting on a bench talking -> A man and talking on a street
Supplies and tools are arranged on the floor with a hard hat and bag. -> A next on a it.
some elephants in some tall green grass and some trees -> A of elephants in a field.
A young blonde girl holding up 2 cell phones. -> A man in glass in a mirror.
A bicycle chained to two poles mounted to the ground  -> A red next on a meter.
Man in red and black wet suit riding surfboard. -> A man man a surfboard in a water.
A big pair of scissors sticking in something by a paper. -> A pair on top on a table.
a single crane in mid-flight along a tree scape -> A bird bird a bird on a branch.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:6    Loss:4.089978411819732   WER:0.7947110897548444

Men dressed as bears are riding a motorcycle -> A woman motorcycle of a motorcycle.
Two baby giraffes are standing in a grassy field.  -> Two giraffes in a field.
Two women are sitting on a bench talking -> A man and talking on a together.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A items next on floor.
some elephants in some tall green grass and some trees -> A herd of elephants in grass.
A young blonde girl holding up 2 cell phones. -> A woman in front in a phone.
A bicycle chained to two poles mounted to the ground  -> A bicycle bicycle next of a it.
Man in red and black wet suit riding surfboard. -> A woman woman surfboard in a ocean.
A big pair of scissors sticking in something by a paper. -> A pair scissors on top on a table.
a single crane in mid-flight along a tree scape -> A bird bird flight in a tree.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:7    Loss:4.000036643678143   WER:0.7850014323073085

Men dressed as bears are riding a motorcycle -> A man of a motorcycle on a motorcycle.
Two baby giraffes are standing in a grassy field.  -> A giraffe giraffe in a in a field.
Two women are sitting on a bench talking -> A man and talking on a street.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A next on a assignment
some elephants in some tall green grass and some trees -> A herd of in field.
A young blonde girl holding up 2 cell phones. -> A woman is talking on her phone.
A bicycle chained to two poles mounted to the ground  -> A sitting on a meter.
Man in red and black wet suit riding surfboard. -> A man on a surfboard in a water
A big pair of scissors sticking in something by a paper. -> A pair next on table.
a single crane in mid-flight along a tree scape -> A bird bird flight in the field.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:8    Loss:3.9242680596172748   WER:0.7891320217710711

Men dressed as bears are riding a motorcycle -> A in taxes
Two baby giraffes are standing in a grassy field.  -> Two giraffes in in a field.
Two women are sitting on a bench talking -> A man of on a street.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A hardhat electronics on a assignment
some elephants in some tall green grass and some trees -> A of elephants in grass.
A young blonde girl holding up 2 cell phones. -> A woman is kitchen in a kitchen
A bicycle chained to two poles mounted to the ground  -> A chained locked locked on a benches.
Man in red and black wet suit riding surfboard. -> A woman in surfboard on the wave.
A big pair of scissors sticking in something by a paper. -> A pair scissors on sitting on a paper.
a single crane in mid-flight along a tree scape -> A bird bird in day.


  0%|          | 0/5824 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:13    Loss:3.708645371386441   WER:0.787096759289206

Men dressed as bears are riding a motorcycle -> A costumes with on a taxes
Two baby giraffes are standing in a grassy field.  -> A giraffe on a a field.
Two women are sitting on a bench talking -> A man talking on a phone.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A items items with neatly on assignment
some elephants in some tall green grass and some trees -> A tide of elephants in in field.
A young blonde girl holding up 2 cell phones. -> A woman holding front in a computer.
A bicycle chained to two poles mounted to the ground  -> A bicycle chained tied on a museum.
Man in red and black wet suit riding surfboard. -> A man on on on a wave on the wave.
A big pair of scissors sticking in something by a paper. -> A hate on life!"
a single crane in mid-flight along a tree scape -> A giraffe flight flapping the day.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:14    Loss:3.68446057299988   WER:0.790665976694974

Men dressed as bears are riding a motorcycle -> A on motorcycle on a motorcycles.
Two baby giraffes are standing in a grassy field.  -> Two giraffes giraffe in in field.
Two women are sitting on a bench talking -> A woman and skirt, talking on a phone.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A hardhat items on items on a it.
some elephants in some tall green grass and some trees -> A in a field.
A young blonde girl holding up 2 cell phones. -> A woman in broken in a phone.
A bicycle chained to two poles mounted to the ground  -> A bicycle chained chained chained sitting in a building
Man in red and black wet suit riding surfboard. -> A man man on a wave in the ocean.
A big pair of scissors sticking in something by a paper. -> A pair on a next on a table.
a single crane in mid-flight along a tree scape -> A bird bird flight flight in day.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:15    Loss:3.659258283229712   WER:0.7816655423824352

Men dressed as bears are riding a motorcycle -> A on on a bikes.
Two baby giraffes are standing in a grassy field.  -> Two giraffes in in a field.
Two women are sitting on a bench talking -> A man and skirt, on a street.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A devices items and out on assignment
some elephants in some tall green grass and some trees -> A herd of elephants elephants field.
A young blonde girl holding up 2 cell phones. -> A woman woman front in a keyboard.
A bicycle chained to two poles mounted to the ground  -> A bicycle bicycle next on a benches.
Man in red and black wet suit riding surfboard. -> A man on on a surfboard on a wave
A big pair of scissors sticking in something by a paper. -> A pair on next on a table.
a single crane in mid-flight along a tree scape -> A bird Pelican flight in a day.


  0%|          | 0/5824 [00:00<?, ?it/s]

  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:16    Loss:3.6449137307778963   WER:0.7905389171756748

Men dressed as bears are riding a motorcycle -> A of bear-costumed bear-costumed motorcycles "no tax"
Two baby giraffes are standing in a grassy field.  -> Two giraffes standing in in a field.
Two women are sitting on a bench talking -> A couple and talking on a conversation.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A devices devices on neatly on assignment
some elephants in some tall green grass and some trees -> A of elephants walking in the grass.
A young blonde girl holding up 2 cell phones. -> A woman holding cellphone in keyboard.
A bicycle chained to two poles mounted to the ground  -> A display of a building.
Man in red and black wet suit riding surfboard. -> A boy boy board on the surfing.
A big pair of scissors sticking in something by a paper. -> A pair of next of a paper.
a single crane in mid-flight along a tree scape -> A Pelican Pelican flight in the cloudy day.


  0%|          | 0/5824 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



  0%|          | 0/648 [00:00<?, ?it/s]


Epoch:20    Loss:3.5778672087458925   WER:0.7851977970189526

Men dressed as bears are riding a motorcycle -> A people of motorcycle on a taxes
Two baby giraffes are standing in a grassy field.  -> Two giraffes in field.
Two women are sitting on a bench talking -> A man and on talking on a cell phone.
Supplies and tools are arranged on the floor with a hard hat and bag. -> A devices items items on top on a assignment
some elephants in some tall green grass and some trees -> A tide elephants walking in in a grass.
A young blonde girl holding up 2 cell phones. -> A woman is phones in a phone.
A bicycle chained to two poles mounted to the ground  -> A locked sitting on a toothbrushes.
Man in red and black wet suit riding surfboard. -> A man is surfboard on a surfing.
A big pair of scissors sticking in something by a paper. -> A of a next on a scissors.
a single crane in mid-flight along a tree scape -> A bird bird flight in the day.


  0%|          | 0/5824 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

