https://github.com/Amir-Hofo

--------

# 00_Arguments

In [13]:
system= "colab"
grid_search= True
seed= False

wandb_enable= False
if wandb_enable:
  wandb_arg_name= input('Please input the WandB argument name:')

batch_size= 128
max_seq_length= 20

embed_size= 300
num_layers= 2
hidden_size= 500
dropout_embd= 0.5
dropout_rnn= 0.5


lr= 0.1
wd= 1e-4
momentum= 0.9
clip= 0.25

-----------

# 01_Library

## insatll

In [14]:
try:
    import torchtext
except ImportError:
    ! pip install -q torchtext==0.17.0
    import torchtext

try:
    import torchvision
except:
    ! pip uninstall -q -y torchvision
    ! pip install -q torchvision==0.17.0
    import torchvision

! pip install -q torchmetrics tqdm wandb torcheval

In [15]:
# !pip uninstall -q -y torch torchvision torchtext
# !pip install -q torch=2.2.2 torchtext==0.17.2 torchvision==0.17.2

## import

In [16]:
import os
import urllib.request
import zipfile
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import wandb
import tqdm
import torchmetrics as tm
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision import transforms
from torchvision.datasets import VisionDataset
from torchvision.models import resnet50, ResNet50_Weights

from PIL import Image

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

try:
    from torcheval.metrics import BLEUScore
except:
    import torcheval
    from bleu import BLEUScore

In [17]:
! python --version
for lib in [torch, torchtext, torchvision]:
  print(lib.__name__, '-->', lib.__version__)

Python 3.11.11
torch --> 2.2.0+cu121
torchtext --> 0.17.0+cpu
torchvision --> 0.17.0+cu121


-------

# 02_Utils

## system

In [18]:
if system== "local":
    project_path= r"./"
    dataset_path= './dataset/'

elif system== "colab":
    root_path= '/content/'
    project_path= r"/content/drive/MyDrive/Catalist/2_image captioning/"
    dataset_path= os.path.join(project_path, r'dataset/')

else:
  raise ValueError("Invalid system")

## device

In [19]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'
pin_memory= (device == 'cuda')
device

'cuda'

## drive mount

In [20]:
if system== "colab":
    from google.colab import drive
    drive.mount('/content/drive')

Mounted at /content/drive


## number of params fn

In [21]:
def num_trainable_params(model):
  nums= sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6
  return nums

## avragemeter

In [22]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## set seed

In [23]:
def set_seed(seed):
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
      torch.cuda.manual_seed(seed)

--------------

# 03_Data

## download dataset

In [24]:
image_link= "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
caption_link= "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"

files = {
    "Flickr8k_Dataset.zip": image_link,
    "Flickr8k_text.zip": caption_link
}

for filename, url in files.items():
  if not os.path.exists(os.path.join(dataset_path, filename)):
    urllib.request.urlretrieve(url, os.path.join(dataset_path, filename))
    print(f"{filename} has been downloaded.")
  else:
    print(f"{filename} already exists.")

Flickr8k_Dataset.zip already exists.
Flickr8k_text.zip already exists.


## unzip

In [25]:
files= ["Flickr8k_Dataset.zip", "Flickr8k_text.zip"]
data_path= os.path.join(root_path, "dataset/")
os.makedirs(data_path, exist_ok= True)
for file in files:
    with zipfile.ZipFile(os.path.join(dataset_path, file), 'r') as zip_ref:
        zip_ref.extractall(data_path)
    print(f"{file} extraction is complete.")

Flickr8k_Dataset.zip extraction is complete.
Flickr8k_text.zip extraction is complete.


## custom dataset

### pytorch Flickr30k class

In [26]:
# import glob
# import os
# from collections import defaultdict
# from html.parser import HTMLParser
# from pathlib import Path
# from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# from .folder import default_loader
# from .vision import VisionDataset


# class Flickr30k(VisionDataset):
#     """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

#     Args:
#         root (str or ``pathlib.Path``): Root directory where images are downloaded to.
#         ann_file (string): Path to annotation file.
#         transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#         loader (callable, optional): A function to load an image given its path.
#             By default, it uses PIL as its image loader, but users could also pass in
#             ``torchvision.io.decode_image`` for decoding image data into tensors directly.
#     """

#     def __init__(
#         self,
#         root: str,
#         ann_file: str,
#         transform: Optional[Callable] = None,
#         target_transform: Optional[Callable] = None,
#         loader: Callable[[str], Any] = default_loader,
#     ) -> None:
#         super().__init__(root, transform=transform, target_transform=target_transform)
#         self.ann_file = os.path.expanduser(ann_file)

#         # Read annotations and store in a dict
#         self.annotations = defaultdict(list)
#         with open(self.ann_file) as fh:
#             for line in fh:
#                 img_id, caption = line.strip().split("\t")
#                 self.annotations[img_id[:-2]].append(caption)

#         self.ids = list(sorted(self.annotations.keys()))
#         self.loader = loader

#     def __getitem__(self, index: int) -> Tuple[Any, Any]:
#         """
#         Args:
#             index (int): Index

#         Returns:
#             tuple: Tuple (image, target). target is a list of captions for the image.
#         """
#         img_id = self.ids[index]

#         # Image
#         filename = os.path.join(self.root, img_id)
#         img = self.loader(filename)
#         if self.transform is not None:
#             img = self.transform(img)

#         # Captions
#         target = self.annotations[img_id]
#         if self.target_transform is not None:
#             target = self.target_transform(target)

#         return img, target

#     def __len__(self) -> int:
#         return len(self.ids)

### custom Flickr dataset

In [27]:
class Flickr8k(VisionDataset):
    """
    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root: str,
                 ann_file: str,
                 split_file: str,
                 train: bool,
                 transform: Optional[Callable]= None,
                 target_transform: Optional[Callable]= None):
        super().__init__(root, transform= transform,
                         target_transform= target_transform)
        self.ann_file= os.path.expanduser(ann_file)
        self.train= train

        # Read {train/dev/test} files
        with open(split_file) as f:
            self.split_samples= f.read().strip().split("\n")

        # Read annotations and store in a dict
        self.ids, self.captions= [], []
        with open(self.ann_file) as fh:
            for line in fh:
                img_id, caption= line.strip().split("\t")
                if img_id[:-2] in self.split_samples:
                    self.ids.append(img_id[:-2])
                    self.captions.append(caption)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id= self.ids[index]

        # Image
        filename= os.path.join(self.root, img_id)
        img_raw= Image.open(filename).convert("RGB")
        if self.transform is not None:
            img= self.transform(img_raw)

        # Captions
        caption= self.captions[index]
        if self.target_transform is not None:
            target= self.target_transform(caption)

        if self.train:
            return img, target
        else:
          return img, img_raw, caption

    def __len__(self) -> int:
        return len(self.ids)

### caption transform

In [28]:
class CaptionTransform:

    def __init__(self, caption_file):
        captions= self._load_captions(caption_file)

        self.tokenizer= get_tokenizer('basic_english')
        self.vocab= build_vocab_from_iterator(map(self.tokenizer, captions),
                                              specials= ['<pad>', '<unk>', '<sos>', '<eos>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        torch.save(self.vocab, 'vocab.pt')

    def __call__(self, caption):
        indices= self.vocab(self.tokenizer(caption))
        indices= self.vocab(['<sos>']) + indices + self.vocab(['<eos>'])
        target= torch.LongTensor(indices)
        return target

    def __repr__(self):
        return f"""CaptionTransform([
          _load_captions(),
          toknizer('basic_english'),
          vocab(vocab_size={len(self.vocab)}) ])
          """

    def _load_captions(self, caption_file):
        captions= []
        with open(caption_file) as f:
            for line in f:
                _, caption= line.strip().split("\t")
                captions.append(caption)
        return captions

In [29]:
caption_transform= CaptionTransform(os.path.join(data_path, 'Flickr8k.token.txt'))

## transform

In [30]:
train_transform= transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ])

eval_transform= transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ])

## dataset

In [31]:
split_file= lambda phase: f'{data_path}Flickr_8k.{phase}Images.txt'

train_set= Flickr8k(os.path.join(data_path, 'Flicker8k_Dataset'),
                    os.path.join(data_path, 'Flickr8k.token.txt'),
                    split_file('train'), True,
                    train_transform, caption_transform)

valid_set= Flickr8k(os.path.join(data_path, 'Flicker8k_Dataset'),
                    os.path.join(data_path, 'Flickr8k.token.txt'),
                    split_file('dev'), True,
                    eval_transform, caption_transform)

test_set= Flickr8k(os.path.join(data_path, 'Flicker8k_Dataset'),
                   os.path.join(data_path, 'Flickr8k.token.txt'),
                   split_file('test'), False,
                   eval_transform, caption_transform)

len(train_set), len(valid_set), len(test_set)

(30000, 5000, 5000)

## dataloader

In [32]:
def collate_fn(batch):
  if len(batch[0]) == 2:
      x_batch, y_batch= zip(*batch)
      x_batch= torch.stack(x_batch)
      y_batch= pad_sequence(y_batch, batch_first= True,
                            padding_value= caption_transform.vocab['<pad>'])
      return x_batch, y_batch
  else:
    x_batch, x_raw, captions= zip(*batch)
    x_batch= torch.stack(x_batch)
    return x_batch, x_raw, captions

In [33]:
train_loader= DataLoader(train_set, batch_size= batch_size,
                         shuffle= True, collate_fn= collate_fn,
                         pin_memory= pin_memory)
valid_loader= DataLoader(valid_set, batch_size= batch_size*2,
                         collate_fn= collate_fn, pin_memory= pin_memory)
test_loader= DataLoader(test_set, batch_size= batch_size*2,
                        collate_fn= collate_fn, pin_memory= pin_memory)

print("train batch size:",train_loader.batch_size,
     ", num of batch:", len(train_loader))
print("valid batch size:",valid_loader.batch_size,
     ", num of batch:", len(valid_loader))
print("Test batch size:",test_loader.batch_size,
     ", num of batch:", len(test_loader))

train batch size: 128 , num of batch: 235
valid batch size: 256 , num of batch: 20
Test batch size: 256 , num of batch: 20


-----------

# 04_Model

## encoder

In [34]:
class EncoderCNN(nn.Module):

  def __init__(self, embed_size):
    super(EncoderCNN, self).__init__()
    self.resnet= resnet50(weights= ResNet50_Weights.IMAGENET1K_V2)
    self.resnet.requires_grad_(False)
    feature_size= self.resnet.fc.in_features

    self.resnet.fc= nn.Identity()
    self.linear= nn.Linear(feature_size, embed_size)
    self.bn= nn.BatchNorm1d(embed_size)

  def forward(self, images):
    self.resnet.eval()
    with torch.no_grad():
      features= self.resnet(images)
    features= self.bn(self.linear(features))
    return features

## decoder

In [35]:
class DecoderRNN(nn.Module):

  def __init__(self, embed_size, hidden_size, vocab_size, num_layers,
               dropout_embd, dropout_rnn, max_seq_length= 20):
    super(DecoderRNN, self).__init__()
    self.embedding= nn.Embedding(vocab_size, embed_size,
                                 padding_idx= caption_transform.vocab['<pad>'])
    self.dropout_embd= nn.Dropout(dropout_embd)

    self.lstm= nn.LSTM(embed_size, hidden_size, num_layers,
                       dropout= dropout_rnn, batch_first= True)

    self.linear= nn.Linear(hidden_size, vocab_size)
    self.max_seq_length= max_seq_length
    self.init_weights()

  def init_weights(self):
      self.embedding.weight.data.uniform_(-0.1, 0.1)
      self.linear.bias.data.fill_(0)
      self.linear.weight.data.uniform_(-0.1, 0.1)

  def forward(self, features, captions):
    embeddings= self.dropout_embd(self.embedding(captions[:, :-1]))
    inputs= torch.cat((features.unsqueeze(1), embeddings), dim= 1)
    outputs, _= self.lstm(inputs)
    outputs= self.linear(outputs)
    return outputs

  def generate(self, features, captions):
    if len(captions) > 0:
        embeddings= self.dropout_embd(self.embedding(captions))
        inputs= torch.cat((features.unsqueeze(1), embeddings), dim= 1)
    else:
        inputs= features.unsqueeze(1)

    outputs, _= self.lstm(inputs)
    outputs= self.linear(outputs)
    return outputs

## custom model

In [36]:
class ImageCaptioning(nn.Module):

  def __init__(self, embed_size, hidden_size, vocab_size, num_layers,
               dropout_embd, dropout_rnn, max_seq_length= 20):
    super(ImageCaptioning, self).__init__()
    self.encoder= EncoderCNN(embed_size)
    self.decoder= DecoderRNN(embed_size, hidden_size, vocab_size, num_layers,
                             dropout_embd, dropout_rnn, max_seq_length)

  def forward(self, images, captions):
    features= self.encoder(images)
    outputs= self.decoder(features, captions)
    return outputs

  def generate(self, images, captions):
    features= self.encoder(images)
    outputs= self.decoder.generate(features, captions)
    return outputs

## configuration

In [37]:
loss_fn= nn.CrossEntropyLoss(ignore_index= caption_transform.vocab['<pad>'])

----------------

# Train one epoch

In [38]:
def train_one_epoch(model, train_loader, loss_fn, optimizer, metric= None, epoch= None):
  model.train()
  loss_train= AverageMeter()
  if metric: metric.reset()

  with tqdm.tqdm(train_loader, unit= 'batch') as tepoch:
    for inputs, targets in tepoch:
      if epoch:
        tepoch.set_description(f'Epoch {epoch}')

      inputs, targets= inputs.to(device), targets.to(device)
      outputs= model(inputs, targets)
      loss= loss_fn(outputs.reshape(-1, outputs.shape[-1]), targets.flatten())

      nn.utils.clip_grad.clip_grad_norm_(model.parameters(), max_norm= clip)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      loss_train.update(loss.item(), n= len(targets))
      if metric:
        metric.update(outputs, targets)
        metric_train_val= metric.compute().item()
      else:
        metric_train_val= None

      tepoch.set_postfix(loss= loss_train.avg, metric= metric_train_val)

    return model, loss_train.avg, metric_train_val

In [39]:
def evaluate(model, test_loader, loss_fn, metric= None):
  model.eval()
  loss_eval= AverageMeter()
  if metric: metric.reset()

  with torch.inference_mode():
    for inputs, targets in test_loader:
      inputs, targets= inputs.to(device), targets.to(device)
      outputs= model(inputs, targets)
      loss= loss_fn(outputs.reshape(-1, outputs.shape[-1]), targets.flatten())
      loss_eval.update(loss.item(), n= len(targets))
      if metric: metric(outputs, targets)

  return loss_eval.avg, metric.compute().item() if metric else None

---------

# 05_Experiments before the main training

## base loss

In [40]:
model= ImageCaptioning(embed_size, hidden_size, len(caption_transform.vocab),
                       num_layers, dropout_embd, dropout_rnn, max_seq_length).to(device)

loss_base, _= evaluate(model, valid_loader, loss_fn)
print(f'{loss_base:.2f}')

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 151MB/s]


9.097713442993165


## overfit on subset of data

In [41]:
num_epochs= 20
mini_train_size= 1000

optimizer= torch.optim.SGD(model.parameters(), lr= lr, momentum= momentum)
_, mini_train_dataset= random_split(train_set, (len(train_set)- mini_train_size,
                                                mini_train_size))
mini_train_loader= DataLoader(mini_train_dataset, 20, collate_fn= collate_fn)

for epoch in range(num_epochs):
  model, _, _= train_one_epoch(model, mini_train_loader, loss_fn, optimizer, None, epoch)

del mini_train_dataset, mini_train_loader

100%|██████████| 50/50 [00:12<00:00,  4.05batch/s, loss=7.43, metric=None]
Epoch 1: 100%|██████████| 50/50 [00:13<00:00,  3.83batch/s, loss=5.27, metric=None]
Epoch 2: 100%|██████████| 50/50 [00:09<00:00,  5.10batch/s, loss=4.81, metric=None]
Epoch 3: 100%|██████████| 50/50 [00:09<00:00,  5.05batch/s, loss=4.64, metric=None]
Epoch 4: 100%|██████████| 50/50 [00:10<00:00,  4.88batch/s, loss=4.55, metric=None]
Epoch 5: 100%|██████████| 50/50 [00:10<00:00,  4.94batch/s, loss=4.47, metric=None]
Epoch 6: 100%|██████████| 50/50 [00:09<00:00,  5.08batch/s, loss=4.39, metric=None]
Epoch 7: 100%|██████████| 50/50 [00:09<00:00,  5.19batch/s, loss=4.24, metric=None]
Epoch 8: 100%|██████████| 50/50 [00:10<00:00,  4.86batch/s, loss=4.11, metric=None]
Epoch 9: 100%|██████████| 50/50 [00:10<00:00,  4.95batch/s, loss=4.02, metric=None]
Epoch 10: 100%|██████████| 50/50 [00:09<00:00,  5.10batch/s, loss=3.91, metric=None]
Epoch 11: 100%|██████████| 50/50 [00:10<00:00,  4.73batch/s, loss=3.79, metric=None]

KeyboardInterrupt: 

## rough grid

In [42]:
num_epochs= 1
loss_grid= loss_base

for lr in [0.9, 0.5, 0.125, 0.005]:
  print(f'LR={lr}')

  model= ImageCaptioning(embed_size, hidden_size, len(caption_transform.vocab),
                         num_layers, dropout_embd, dropout_rnn, max_seq_length).to(device)
  # model= torch.load('model.pt')
  optimizer = optim.SGD(model.parameters(), lr= lr,
                        weight_decay= wd, momentum= momentum)

  for epoch in range(num_epochs):
    model, loss, _ = train_one_epoch(model, train_loader, loss_fn, optimizer, None, epoch+1)
  if loss< loss_grid:
    best_lr= lr
    loss_grid= loss
    print(f'best loss is: {loss_grid} with lr: {best_lr}')
  print()

LR=0.9


Epoch 1: 100%|██████████| 235/235 [04:43<00:00,  1.21s/batch, loss=4.43, metric=None]


best loss is: 4.433724770100912 with lr: 0.9

LR=0.5


Epoch 1: 100%|██████████| 235/235 [04:43<00:00,  1.21s/batch, loss=4.63, metric=None]



LR=0.125


Epoch 1: 100%|██████████| 235/235 [04:44<00:00,  1.21s/batch, loss=5.31, metric=None]



LR=0.005


Epoch 1: 100%|██████████| 235/235 [04:43<00:00,  1.21s/batch, loss=8.36, metric=None]







## grid search

In [44]:
num_epochs= 2
lr= best_lr

for wd in [1e-4, 0]:
  print(f'LR={lr}, WD={wd}')

  model= ImageCaptioning(embed_size, hidden_size, len(caption_transform.vocab),
                         num_layers, dropout_embd, dropout_rnn, max_seq_length).to(device)
  # model= torch.load('model.pt')
  optimizer = optim.SGD(model.parameters(), lr= lr,
                        weight_decay= wd, momentum= momentum)

  for epoch in range(num_epochs):
    model, _, _ = train_one_epoch(model, train_loader, loss_fn, optimizer, None, epoch+1)
  loss_valid, _= evaluate(model, valid_loader, loss_fn, None)
  print(f'Valid: Loss= {loss_valid:.4}')
  print()

LR=0.9, WD=0.0001


Epoch 1:  67%|██████▋   | 158/235 [03:12<01:33,  1.22s/batch, loss=4.79, metric=None]


KeyboardInterrupt: 

------------

# 06_Training