https://github.com/Amir-Hofo

--------

# 00_Arguments

In [1]:
# ---- runtime -----------------------------------------------------------
# `system` picks the path layout used by the "system" cell ("local" or "colab").
system = "colab"
grid_search = True
seed = False

# ---- experiment tracking -----------------------------------------------
# The run name is only requested interactively when WandB logging is on.
wandb_enable = False
if wandb_enable:
  wandb_arg_name = input('Please input the WandB argument name:')

# ---- data --------------------------------------------------------------
batch_size = seq_len = None

# ---- model -------------------------------------------------------------
embedding_dim = num_layers = hidden_dim = weight_drop = None

# ---- optimisation ------------------------------------------------------
lr = wd = momentum = clip = None

-----------

# 01_Library

## install

In [3]:
! python --version
import torch
for lib in [torch, torchvision]:
  print(lib.__name__, '-->', lib.__version__)

Python 3.11.11
torch --> 2.6.0+cu124
torchvision --> 0.21.0+cu124


In [1]:
try:
    import torchtext
except ImportError:
    ! pip install -q torchtext==0.17.0
    import torchtext

try:
    import torchvision
except (ImportError, OSError):
    ! pip uninstall -q -y torchvision
    ! pip install -q torchvision==0.17.0
    import torchvision

! pip install -q torchmetrics tqdm wandb

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m755.5/755.5 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.6/410.6 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m80.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m69.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m731.7/731.7 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

AttributeError: module 'torch.library' has no attribute 'register_fake'

## import

In [None]:
import os
import urllib.request
import zipfile
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import wandb
import tqdm
import torchmetrics as tm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision import transforms
from torchvision.datasets import VisionDataset

from PIL import Image

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [None]:
! python --version
for lib in [torch, torchtext, torchvision]:
  print(lib.__name__, '-->', lib.__version__)

-------

# 02_Utils

## system

In [None]:
# Resolve the directories used by the rest of the notebook:
#   project_path - project folder
#   dataset_path - where the downloaded zip archives are stored
#   root_path    - parent of the folder the archives are extracted into
if system== "local":
    project_path= r"./"
    # root_path is read unconditionally by the unzip cell; it used to be
    # defined only for "colab", so a local run failed there with NameError.
    root_path= project_path
    dataset_path= './dataset/'

elif system== "colab":
    root_path= '/content/'
    project_path= r"/content/drive/MyDrive/Catalist/2_image captioning/"
    dataset_path= os.path.join(project_path, r'dataset/')

else:
    raise ValueError("Invalid system")

## device

In [None]:
# Use the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device= 'cuda'
else:
    device= 'cpu'
device

## drive mount

In [None]:
# On Colab, project_path and dataset_path point below /content/drive, so
# Google Drive has to be mounted before those paths are used.
if system== "colab":
    # Imported here rather than in the import cell because it is only needed
    # (and presumably only available) when running on Colab.
    from google.colab import drive
    drive.mount('/content/drive')

## number of params fn

In [None]:
def num_trainable_params(model):
  """Return the number of trainable parameters of ``model`` in millions.

  Only parameters whose ``requires_grad`` flag is set are counted.
  """
  total= 0
  for param in model.parameters():
    if param.requires_grad:
      total+= param.numel()
  return total/1e6

## AverageMeter

In [None]:
class AverageMeter:
    """Running statistics for a scalar metric.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: weighted sum of every value seen since the last ``reset``.
        count: total weight accumulated since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

## set seed

In [None]:
def set_seed(seed):
  """Seed the random number generators of Python, NumPy and PyTorch.

  Args:
    seed (int): value handed to ``random.seed``, ``np.random.seed`` and
      ``torch.manual_seed`` (plus every CUDA device when one is available).
  """
  # Local import: the notebook's import cell does not import `random`.
  import random

  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
      # manual_seed_all covers every visible GPU, not only the current one.
      torch.cuda.manual_seed_all(seed)

--------------

# 03_Data

## download dataset

In [None]:
image_link= "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
caption_link= "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"

# archive file name -> download URL
files = {
    "Flickr8k_Dataset.zip": image_link,
    "Flickr8k_text.zip": caption_link
}

# urlretrieve does not create missing directories, so make sure the target
# folder exists before the first download.
os.makedirs(dataset_path, exist_ok= True)

for filename, url in files.items():
  archive_path= os.path.join(dataset_path, filename)
  if not os.path.exists(archive_path):
    # Download under a temporary name and rename afterwards: an interrupted
    # transfer must not leave a truncated zip that the next run would report
    # as "already exists".
    partial_path= archive_path + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, archive_path)
    print(f"{filename} has been downloaded.")
  else:
    print(f"{filename} already exists.")

## unzip

In [None]:
# Extract both archives from dataset_path into <root_path>/dataset.
# Note: `files` is rebound here from the {name: url} dict of the download
# cell to a plain list of archive names.
files= ["Flickr8k_Dataset.zip", "Flickr8k_text.zip"]
data_path= os.path.join(root_path, "dataset")
os.makedirs(data_path, exist_ok= True)
for file in files:
    archive_path= os.path.join(dataset_path, file)
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(data_path)
    print(f"{file} extraction is complete.")

## custom dataset

### pytorch Flickr30k class

In [None]:
# import glob
# import os
# from collections import defaultdict
# from html.parser import HTMLParser
# from pathlib import Path
# from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# from .folder import default_loader
# from .vision import VisionDataset


# class Flickr30k(VisionDataset):
#     """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

#     Args:
#         root (str or ``pathlib.Path``): Root directory where images are downloaded to.
#         ann_file (string): Path to annotation file.
#         transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#         loader (callable, optional): A function to load an image given its path.
#             By default, it uses PIL as its image loader, but users could also pass in
#             ``torchvision.io.decode_image`` for decoding image data into tensors directly.
#     """

#     def __init__(
#         self,
#         root: str,
#         ann_file: str,
#         transform: Optional[Callable] = None,
#         target_transform: Optional[Callable] = None,
#         loader: Callable[[str], Any] = default_loader,
#     ) -> None:
#         super().__init__(root, transform=transform, target_transform=target_transform)
#         self.ann_file = os.path.expanduser(ann_file)

#         # Read annotations and store in a dict
#         self.annotations = defaultdict(list)
#         with open(self.ann_file) as fh:
#             for line in fh:
#                 img_id, caption = line.strip().split("\t")
#                 self.annotations[img_id[:-2]].append(caption)

#         self.ids = list(sorted(self.annotations.keys()))
#         self.loader = loader

#     def __getitem__(self, index: int) -> Tuple[Any, Any]:
#         """
#         Args:
#             index (int): Index

#         Returns:
#             tuple: Tuple (image, target). target is a list of captions for the image.
#         """
#         img_id = self.ids[index]

#         # Image
#         filename = os.path.join(self.root, img_id)
#         img = self.loader(filename)
#         if self.transform is not None:
#             img = self.transform(img)

#         # Captions
#         target = self.annotations[img_id]
#         if self.target_transform is not None:
#             target = self.target_transform(target)

#         return img, target

#     def __len__(self) -> int:
#         return len(self.ids)

### custom Flickr dataset

In [None]:
class Flickr8k(VisionDataset):
    """Flickr8k captioning dataset restricted to one split, one sample per caption.

    Every (image, caption) pair of the annotation file whose image is listed
    in ``split_file`` becomes one sample, so an image appears once for each
    of its captions.

    Args:
        root (string): Directory that contains the image files.
        ann_file (string): Path to the annotation file; each non-empty line is
            ``<image name>#<n>\\t<caption>``.
        split_file (string): Path to a text file listing, one per line, the
            image names that belong to this split.
        train (bool): If True, ``__getitem__`` returns ``(image, target)``;
            otherwise it returns ``(image, raw_image, caption)``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            caption string and transforms it.
    """

    def __init__(self,
                 root: str,
                 ann_file: str,
                 split_file: str,
                 train: bool,
                 transform: Optional[Callable]= None,
                 target_transform: Optional[Callable]= None) -> None:
        super().__init__(root, transform= transform,
                         target_transform= target_transform)
        self.ann_file= os.path.expanduser(ann_file)
        self.train= train

        # Read {train/dev/test} files
        with open(split_file) as f:
            self.split_samples= [name.strip() for name in f if name.strip()]
        # A set keeps the per-caption membership test below O(1); the list
        # attribute is kept as-is for any code that reads it.
        split_lookup= set(self.split_samples)

        # Read annotations and keep only the captions of this split
        self.ids, self.captions= [], []
        with open(self.ann_file) as fh:
            for line_no, line in enumerate(fh, start= 1):
                line= line.strip()
                if not line:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                # Split on the first tab only so a tab inside a caption is kept.
                img_id, sep, caption= line.partition("\t")
                if not sep:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: expected '<image>#<n>\\t<caption>'")
                # Drop the trailing two characters (the "#<n>" caption index).
                img_name= img_id[:-2]
                if img_name in split_lookup:
                    self.ids.append(img_name)
                    self.captions.append(caption)

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(image, target)`` when ``train`` is True, where ``target``
            is the (optionally transformed) caption of this sample; otherwise
            ``(image, raw_image, caption)`` with the untransformed RGB PIL
            image and the raw caption string.
        """
        img_id= self.ids[index]

        # Image: convert() returns an independent copy, so the file handle
        # can be closed right away instead of being left to the GC.
        filename= os.path.join(self.root, img_id)
        with Image.open(filename) as img_file:
            img_raw= img_file.convert("RGB")
        # Fall back to the raw image when no transform is given (previously
        # `img` was left unbound in that case).
        img= self.transform(img_raw) if self.transform is not None else img_raw

        # Caption: same fallback for a missing target_transform.
        caption= self.captions[index]
        if self.target_transform is not None:
            target= self.target_transform(caption)
        else:
            target= caption

        if self.train:
            return img, target
        return img, img_raw, caption

    def __len__(self) -> int:
        return len(self.ids)

### caption transform

In [None]:
class CaptionTransform:
    """Convert a caption string into a tensor of vocabulary indices.

    The vocabulary is built from every caption in ``caption_file`` using the
    ``basic_english`` tokenizer, with the specials ``<pad>``, ``<unk>``,
    ``<sos>`` and ``<eos>``; unknown tokens map to ``<unk>``.

    Args:
        caption_file (str): Annotation file; each non-empty line is
            ``<image name>#<n>\\t<caption>``.
        vocab_path (str, optional): Where the built vocabulary is saved with
            ``torch.save``. Defaults to ``'vocab.pt'``; pass ``None`` to skip
            saving.
    """

    def __init__(self, caption_file, vocab_path= 'vocab.pt'):
        captions= self._load_captions(caption_file)

        self.tokenizer= get_tokenizer('basic_english')
        self.vocab= build_vocab_from_iterator(map(self.tokenizer, captions),
                                              specials= ['<pad>', '<unk>', '<sos>', '<eos>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        if vocab_path is not None:
            torch.save(self.vocab, vocab_path)

    def __call__(self, caption):
        """Return ``<sos> tokens... <eos>`` of ``caption`` as a 1-D long tensor."""
        tokens= ['<sos>'] + list(self.tokenizer(caption)) + ['<eos>']
        return torch.tensor(self.vocab(tokens), dtype= torch.long)

    def __repr__(self):
        return f"""CaptionTransform([
          _load_captions(),
          toknizer('basic_english'),
          vocab(vocab_size={len(self.vocab)}) ])
          """

    def _load_captions(self, caption_file):
        """Read the caption text (everything after the first tab) of every line.

        Blank lines are skipped. A non-empty line without a tab separator
        raises ``ValueError`` naming the offending line.
        """
        captions= []
        with open(caption_file, encoding= 'utf-8') as f:
            for line_no, line in enumerate(f, start= 1):
                line= line.strip()
                if not line:
                    continue
                # partition splits on the first tab only, so a tab inside the
                # caption no longer breaks the two-value unpacking.
                _, sep, caption= line.partition("\t")
                if not sep:
                    raise ValueError(
                        f"{caption_file}:{line_no}: expected '<image>#<n>\\t<caption>'")
                captions.append(caption)
        return captions

In [None]:
# Build the caption vocabulary from the full Flickr8k token file extracted
# into data_path. Constructing the transform also writes the vocabulary to
# 'vocab.pt' in the current working directory.
caption_transform= CaptionTransform(os.path.join(data_path, 'Flickr8k.token.txt'))

## transform

In [None]:
# Per-channel normalisation constants shared by both pipelines.
# NOTE(review): these look like the usual ImageNet mean/std — confirm they
# match the image encoder used later in the notebook.
_norm_mean= (0.485, 0.456, 0.406)
_norm_std= (0.229, 0.224, 0.225)

# Training images: resize to 256x256, then take the central 224x224 crop.
train_transform= transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Evaluation images: resize straight to 224x224, no crop.
eval_transform= transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

## dataset

In [None]:
def split_file(phase):
    """Return the path of the image list for ``phase`` ('train', 'dev' or 'test')."""
    # data_path has no trailing separator, so plain string concatenation
    # produced ".../datasetFlickr_8k.<phase>Images.txt"; join the parts instead.
    return os.path.join(data_path, f'Flickr_8k.{phase}Images.txt')

# `root` and `ann_file` were used below without being defined in any earlier
# cell, which fails on a fresh kernel.
# Annotation file: the same token file the caption transform is built from.
ann_file= os.path.join(data_path, 'Flickr8k.token.txt')
# NOTE(review): image folder name assumed from the layout of
# Flickr8k_Dataset.zip ("Flicker8k_Dataset") — confirm after extraction.
root= os.path.join(data_path, 'Flicker8k_Dataset')

# The validation set also uses train=True so that it yields (image, target)
# pairs; only the test set returns the raw image and caption string.
train_set= Flickr8k(root, ann_file, split_file('train'),
                    True, train_transform, caption_transform)
valid_set= Flickr8k(root, ann_file, split_file('dev'),
                    True, eval_transform, caption_transform)
test_set= Flickr8k(root, ann_file, split_file('test'),
                   False, eval_transform, caption_transform)

len(train_set), len(valid_set), len(test_set)