# Image Captioning with Transformers

In [1]:
# Show which GPU, driver and CUDA version this runtime provides.
!nvidia-smi

Sun Jun 27 19:26:32 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# pigz (parallel gzip) is only needed by the commented-out tar.gz extraction cell further down.
!apt install -qq pigz
# timm provides the image encoder; wandb is used for experiment tracking.
%pip install -q timm wandb
# NVIDIA DALI (CUDA 11.0 build) from NVIDIA's package index; the DALI iterator used for the data loaders needs it.
%pip install -q --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110

The following NEW packages will be installed:
  pigz
0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.
Need to get 57.4 kB of archives.
After this operation, 259 kB of additional disk space will be used.
Selecting previously unselected package pigz.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../archives/pigz_2.4-1_amd64.deb ...
Unpacking pigz (2.4-1) ...
Setting up pigz (2.4-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
[K     |████████████████████████████████| 348kB 29.6MB/s 
[K     |████████████████████████████████| 1.8MB 40.2MB/s 
[K     |████████████████████████████████| 133kB 45.9MB/s 
[K     |████████████████████████████████| 174kB 52.3MB/s 
[K     |████████████████████████████████| 102kB 12.4MB/s 
[K     |████████████████████████████████| 71kB 12.0MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
[K     

In [3]:
# Fetch the project repository; the `imcap` package imported later presumably lives in it — confirm.
!git clone https://github.com/ShivamShrirao/Image-Captioning-Transformers

Cloning into 'Image-Captioning-Transformers'...
remote: Enumerating objects: 49, done.[K
remote: Counting objects: 100% (49/49), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 49 (delta 18), reused 41 (delta 10), pack-reused 0[K
Unpacking objects: 100% (49/49), done.


# Download Dataset and Annotations

In [4]:
!mkdir ~/.kaggle/
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [5]:
# Download the COCO train/val 2017 image archive from Kaggle (320x320 images, judging by the dataset name).
!kaggle datasets download -d shivamshrirao/coco-trainval2017-320x320

Downloading coco-trainval2017-320x320.zip to /content
100% 3.46G/3.46G [01:26<00:00, 40.9MB/s]



In [6]:
# Unpack the downloaded archive quietly (-q) into the current directory.
!unzip -q coco-trainval2017-320x320.zip

In [7]:
# !gdown --id 1-3vdwBlY-CdVultkrFwOhyJTGC5TFUV8

In [8]:
# !pigz -dc coco_trainval2017_320x320.tar.gz | tar xf -

In [9]:
from torchvision.datasets.utils import download_and_extract_archive
# Root folder for the COCO files, relative to the current working directory.
DATA_DIR = "datasets/COCO"

In [10]:
# Fetch the COCO 2017 train/val annotations archive into DATA_DIR, extract it there,
# and delete the zip once extraction has finished (remove_finished=True).
download_and_extract_archive("http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
                             download_root=DATA_DIR,
                             remove_finished=True)

Downloading http://images.cocodataset.org/annotations/annotations_trainval2017.zip to datasets/COCO/annotations_trainval2017.zip


HBox(children=(FloatProgress(value=0.0, max=252907541.0), HTML(value='')))


Extracting datasets/COCO/annotations_trainval2017.zip to datasets/COCO


In [11]:
!rm coco-trainval2017-320x320.* datasets/COCO/annotations_trainval2017.zip

# Import libraries

In [1]:
# Work from inside the cloned repository (presumably so that `imcap` is importable — confirm).
%cd Image-Captioning-Transformers

/content/Image-Captioning-Transformers


In [2]:
# Automatically reload imported modules before each cell execution (mode 2 = all modules).
%load_ext autoreload
%autoreload 2

In [3]:
# TODO: Try pretrained CLIP

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

In [6]:
import timm         # torch image models

In [7]:
# Render matplotlib figures on a white background.
plt.rcParams['figure.facecolor'] = 'white'

# Wandb Parameters

In [8]:
import wandb

In [9]:
# Default hyperparameters for a run; handed to wandb.init as the run configuration.
config_defaults = dict(
    # Data loading
    BATCH_SIZE=256,
    # Caption model (forwarded to CaptionModel)
    d_model=768,
    dim_feedforward=1536,
    nheads=12,
    num_decoder_layers=4,
    dp_rate=0.1,
    encoder='seresnext50_32x4d',    # timm model name
    activation='gelu',
    # Optimizer / LR schedule
    max_lr=5e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
    # Run control
    seed=62134,
    use_amp=True,
    use_pe=True,                    # NOTE(review): not read anywhere in this notebook
    log_interval=10,
)
CONFIG = config_defaults

In [10]:
# #hide
# run = wandb.init(id='3vhov6z0', project="Image_Captioning_Transformer", resume='must')
# CONFIG = run.config

In [11]:
# Start a new W&B run seeded with the defaults above. From here on CONFIG is wandb.config
# rather than the plain dict, so hyperparameters are read back from the run itself.
run = wandb.init(project="Image_Captioning_Transformer", entity="shivamshrirao", config=config_defaults)
CONFIG = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mshivamshrirao[0m (use `wandb login --relogin` to force relogin)


In [12]:
def seed_everything(seed=33):
    """Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs.

    Also switches cuDNN benchmark mode on. `torch.backends.cudnn.deterministic`
    is intentionally left untouched.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.benchmark = True
    
# Apply the seed from the run configuration.
seed_everything(CONFIG['seed'])

In [13]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Read COCO dataset

In [14]:
from imcap.dataset import *

In [15]:
# COCO folder, now addressed relative to the repository directory the notebook cd'ed into.
# NOTE(review): the trailing slash combines with the leading "/" appended in the next cell,
# producing "//" in the paths; harmless on POSIX but worth tidying.
DATA_DIR = "../datasets/COCO/"

In [16]:
# Caption datasets for the train and val splits: image folder plus the matching captions JSON.
# TensorCocoCaptions is presumably provided by `from imcap.dataset import *` — confirm.
train_data = TensorCocoCaptions(root=DATA_DIR+"/train2017/",
                                annFile=DATA_DIR+"/annotations/captions_train2017.json")

val_data = TensorCocoCaptions(root=DATA_DIR+"/val2017/",
                              annFile=DATA_DIR+"/annotations/captions_val2017.json")

loading annotations into memory...
Done (t=1.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


## Tokenizer and Build Vocab

In [17]:
from torchtext.data.utils import get_tokenizer

In [18]:
# torchtext's built-in 'basic_english' tokenizer; used for every caption below.
tokenizer = get_tokenizer('basic_english')

In [19]:
def yield_tokens(cap_data):
    """Lazily produce the tokenised form of every caption annotation in `cap_data`.

    Args:
        cap_data: Object exposing `coco.anns`, a mapping whose values are
            annotation dicts with a 'caption' entry.

    Yields:
        The result of the module-level `tokenizer` applied to each caption,
        in the iteration order of `coco.anns`.
    """
    yield from (
        tokenizer(annotation['caption'])
        for annotation in cap_data.coco.anns.values()
    )

In [20]:
# Reserved tokens; special_first=True places them at the start of the vocabulary.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
# Vocabulary is built from the training captions only.
# NOTE(review): build_vocab_from_iterator is not imported explicitly in this notebook —
# presumably it arrives via `from imcap.dataset import *`; confirm.
en_vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=special_symbols, special_first=True)

# Look up the ids of the four special tokens.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = en_vocab(special_symbols)
# Words missing from the vocabulary resolve to the <unk> id.
en_vocab.set_default_index(UNK_IDX)

In [21]:
# Vocabulary size, special tokens included.
len(en_vocab)

28940

In [22]:
# Pre-tokenise the captions of both splits with the training vocabulary.
# BOS/EOS ids are passed in, presumably to wrap each caption — see imcap.dataset to confirm.
train_data.fill_token_dict(tokenizer, en_vocab, BOS_IDX, EOS_IDX)
val_data.fill_token_dict(tokenizer, en_vocab, BOS_IDX, EOS_IDX)

100%|██████████| 118287/118287 [00:18<00:00, 6517.93it/s]
100%|██████████| 5000/5000 [00:00<00:00, 6739.17it/s]


## Pretrained GloVe Embeddings (currently unused)

In [23]:
# vec.get_vecs_by_tokens(tokens, lower_case_backup=True)

In [24]:
# vec = torchtext.vocab.GloVe('6B', dim=300)
# unk_vec = vec.vectors.mean(dim=0)
# vec.unk_init = lambda x: unk_vec

# Load dataset into batches

In [25]:
from imcap.dataloader import *

In [26]:
# Training loader: external-source iterator -> DALI pipeline on GPU 0 -> PyTorch-facing DALI iterator.
# NOTE(review): `input_size` is not defined anywhere in this notebook — presumably it is exported by
# `from imcap.dataloader import *`; confirm.
train_iter = ExternalInputIterator(train_data, CONFIG['BATCH_SIZE'], PAD_IDX)
pipe = ExternalSourcePipeline(batch_size=CONFIG['BATCH_SIZE'], num_threads=4, device_id=0, external_data=train_iter, input_size=input_size)
train_loader = DALIClassificationIterator(pipe, dynamic_shape=True, auto_reset=True, last_batch_padded=True, size=len(train_iter))

# Validation loader: same construction with training=False. `pipe` is rebound here; the
# training pipeline stays referenced by train_loader.
val_iter = ExternalInputIterator(val_data, CONFIG['BATCH_SIZE'], PAD_IDX, training=False)
pipe = ExternalSourcePipeline(batch_size=CONFIG['BATCH_SIZE'], num_threads=4, device_id=0, external_data=val_iter, input_size=input_size, training=False)
val_loader = DALIClassificationIterator(pipe, dynamic_shape=True, auto_reset=True, last_batch_padded=True, size=len(val_iter))



# Initialize Model

In [27]:
from imcap.layers import *
from imcap.utils import *

In [28]:
# Caption model: a pretrained timm backbone as image encoder plus a transformer decoder sized from CONFIG.
# num_classes=0 and global_pool='' presumably request a headless, unpooled backbone — see the timm docs to confirm.
model = CaptionModel(encoder = timm.create_model(CONFIG['encoder'], pretrained=True, num_classes=0, global_pool=''),
                     vocab_size = len(en_vocab),
                     num_decoder_layers = CONFIG['num_decoder_layers'],
                     nheads = CONFIG['nheads'],
                     d_model = CONFIG['d_model'],
                     dim_feedforward = CONFIG['dim_feedforward'],
                     dp_rate = CONFIG['dp_rate'],
                     activation = CONFIG['activation']).to(DEVICE, non_blocking=True)



# Learning Rate Schedule

In [29]:
# Number of batches in one training epoch; used below to size the LR schedule.
steps_per_epoch = len(train_loader)

In [30]:
# def lr_schedule(step, d_model=512, warmup_steps=2*steps_per_epoch):
#     # return 1
#     step = max(1,step)
#     arg1 = step ** -0.5
#     arg2 = step * (warmup_steps ** -1.5)
#     return (d_model ** -0.6) * min(arg1, arg2)

In [31]:
# plt.plot(list(map(lr_schedule, range(steps_per_epoch*50))))
# plt.show()

In [32]:
# plt.plot([scheduler.get_last_lr()[0] for _ in range(steps_per_epoch*50) if not scheduler.step()])
# plt.show()

# Loss Function and Optimizer

In [33]:
# Token-level cross entropy; targets equal to PAD_IDX are ignored.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam over the trainable parameters only.
optimizer = torch.optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=CONFIG['max_lr'],
    betas=CONFIG['betas'],
    eps=CONFIG['eps'],
)
# One-cycle LR schedule sized for 50 epochs of per-batch steps, with a 1% warm-up (pct_start).
# NOTE(review): per the OneCycleLR docs a final_div_factor below 1 ends training at a higher
# LR than the initial one — confirm this is intended.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=CONFIG['max_lr'],
    total_steps=50*steps_per_epoch,
    pct_start=0.01,
    final_div_factor=0.31,
)

# Gradient scaler for mixed-precision training; switched on and off by CONFIG['use_amp'].
scaler = torch.cuda.amp.GradScaler(enabled=CONFIG['use_amp'])

In [34]:
# Register the model with W&B; log=None presumably disables gradient/parameter histograms — confirm in the wandb docs.
wandb.watch(model, log=None)

[<wandb.wandb_torch.TorchGraph at 0x7f39524c1a10>]

# Training functions

In [35]:
from torch.cuda import amp

In [36]:
def train_epoch(model, train_loader, optimizer, scaler, scheduler, epoch=1, use_amp=True, log_interval=10):
    """Run one training epoch and return the mean per-batch loss.

    Relies on the notebook-level names `loss_fn`, `DEVICE`, `PAD_IDX`,
    `subsequent_mask`, `padding_mask`, `amp`, `tqdm` and `wandb`.

    Args:
        model: Caption model, called as model(img, tgt_inp, tgt_mask, tgt_pad_mask).
            Its `encoder` submodule is kept in eval mode while the rest trains.
        train_loader: Sized iterable of batches; batch[0]['data'] holds the images
            and batch[0]['label'] the caption token ids (assumed to be
            (batch, seq) before the transpose — TODO confirm).
        optimizer: Optimizer, stepped through `scaler`.
        scaler: torch.cuda.amp.GradScaler used for loss scaling.
        scheduler: LR scheduler; advanced only when the optimizer step was applied.
        epoch: Epoch number, used only for the progress-bar label.
        use_amp: Whether the forward pass runs under autocast.
        log_interval: Log the running loss and LR every this many batches;
            values <= 0 disable logging instead of raising ZeroDivisionError.

    Returns:
        float: Sum of the batch losses divided by len(train_loader).
    """
    model.train()
    model.encoder.eval()
    # Accumulated as a tensor; it is only converted to a Python float when a value is reported.
    losses = 0
    with tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}") as pbar:
        for idx, batch in pbar:
            img, tgt = batch[0]['data'], batch[0]['label'].transpose(0,1)

            tgt_inp = tgt[:-1,:]      # decoder input: every token except the last
            tgt_out = tgt[1:, :]      # targets: the same sequence shifted one step ahead (auto-regressive)

            tgt_mask, tgt_pad_mask = subsequent_mask(tgt_inp.size(0), DEVICE), padding_mask(tgt_inp, PAD_IDX)

            optimizer.zero_grad(set_to_none=True)
            with amp.autocast(enabled=use_amp):
                logits = model(img, tgt_inp, tgt_mask, tgt_pad_mask)
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            scaler.scale(loss).backward()
            # GradScaler silently skips optimizer.step() when it finds inf/NaN gradients and
            # then lowers its scale in update(). Advancing the LR schedule on such a skipped
            # step desynchronises it from the real number of updates (and triggers PyTorch's
            # "lr_scheduler.step() before optimizer.step()" warning), so step the scheduler
            # only when the scale did not shrink.
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            if scaler.get_scale() >= scale_before:
                scheduler.step()

            losses += loss.detach()

            if log_interval > 0 and idx % log_interval == 0:
                curr_lr = optimizer.param_groups[0]['lr']
                info = {'loss': float(losses)/(idx+1), 'lr': curr_lr}
                wandb.log(info)
                pbar.set_postfix(info)

    # Drop the gradients of the last batch so they do not linger between epochs.
    optimizer.zero_grad(set_to_none=True)
    return float(losses)/len(train_loader)

In [37]:
@torch.no_grad()
def evaluate(model, val_loader, use_amp=True):
    """Compute the mean per-batch loss of `model` over `val_loader` without gradients.

    Relies on the notebook-level names `loss_fn`, `DEVICE`, `PAD_IDX`,
    `subsequent_mask`, `padding_mask`, `amp` and `tqdm`.

    Args:
        model: Caption model, called as model(img, tgt_inp, tgt_mask, tgt_pad_mask);
            it is put in eval mode.
        val_loader: Sized iterable of batches with the same layout as the
            training loader (batch[0]['data'] / batch[0]['label']).
        use_amp: Whether the forward pass runs under autocast.

    Returns:
        float: Sum of the batch losses divided by len(val_loader).
    """
    model.eval()
    running_total = 0
    with tqdm(enumerate(val_loader), total=len(val_loader), desc="Evaluating") as progress:
        for step, batch in progress:
            sample = batch[0]
            img = sample['data']
            tgt = sample['label'].transpose(0, 1)

            # Decoder sees every token but the last and must predict the sequence shifted by one.
            decoder_in = tgt[:-1, :]
            expected = tgt[1:, :]

            causal_mask = subsequent_mask(decoder_in.size(0), DEVICE)
            pad_mask = padding_mask(decoder_in, PAD_IDX)

            with amp.autocast(enabled=use_amp):
                logits = model(img, decoder_in, causal_mask, pad_mask)
                flat_logits = logits.reshape(-1, logits.size(-1))
                loss = loss_fn(flat_logits, expected.reshape(-1))

            running_total += loss.item()
            progress.set_postfix({'val_loss': running_total / (step + 1)})
    return float(running_total) / len(val_loader)

# Functions to Make Predictions

In [38]:
@torch.no_grad()
def greedy_decode(model, img, max_len=100, start_symbol=BOS_IDX):
    """Generate a caption for one image by always picking the highest-scoring next token.

    Relies on the notebook-level names `DEVICE`, `EOS_IDX` and `subsequent_mask`.

    Args:
        model: Caption model exposing `encode_image`, `decode_text` and `generator`;
            it is put in eval mode.
        img: Image tensor for a single sample; moved to DEVICE here. The token
            sequence has batch size 1, and `.item()` below requires exactly one
            prediction per step.
        max_len: Maximum number of tokens to generate after the start symbol.
        start_symbol: Token id the sequence starts with.

    Returns:
        torch.Tensor: Long tensor of shape (length, 1) on DEVICE, beginning with
        `start_symbol` and ending with EOS_IDX when it is produced within `max_len` steps.
    """
    model.eval()
    img = img.to(DEVICE, non_blocking=True)
    enc_output = model.encode_image(img)
    # Build token tensors directly as int64 on the target device rather than creating a
    # CPU float32 tensor, casting it and copying it over on every step.
    tgt = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    for _ in range(max_len):
        tgt_mask = subsequent_mask(tgt.size(0), DEVICE)
        out = model.decode_text(tgt, enc_output, tgt_mask)
        out = out.transpose(0,1)
        # Scores over the vocabulary for the most recent position only.
        logits = model.generator(out[:,-1])
        _, next_word = torch.max(logits, dim = 1)
        next_word = next_word.item()
        next_token = torch.full((1, 1), next_word, dtype=torch.long, device=DEVICE)
        tgt = torch.cat([tgt, next_token], dim=0)
        if next_word == EOS_IDX:
            break
    return tgt.detach()

@torch.no_grad()
def generate_caption(model, img, tgt_vocab):
    """Greedy-decode a caption for `img` and return it as plain text.

    The `<bos>` and `<eos>` marker tokens are dropped before joining, so the
    result has no stray leading or trailing spaces (deleting the markers from
    the already-joined string left one on each side).

    Args:
        model: Caption model, forwarded to `greedy_decode`.
        img: Image tensor for a single sample, forwarded to `greedy_decode`.
        tgt_vocab: Vocabulary exposing `lookup_tokens(list_of_ids) -> list_of_str`.

    Returns:
        str: The generated words separated by single spaces.
    """
    token_ids = greedy_decode(model, img, max_len=100, start_symbol=BOS_IDX).flatten().tolist()
    words = [token for token in tgt_vocab.lookup_tokens(token_ids)
             if token not in ("<bos>", "<eos>")]
    return " ".join(words)

# Begin Training

In [39]:
# First epoch to run; raise it to resume a partially finished run.
init_epoch = 1
# Total number of epochs.
# NOTE(review): keep in sync with the literal 50 used for OneCycleLR's total_steps.
NUM_EPOCHS = 50

In [40]:
import gc
# Collect unreferenced Python objects and release cached CUDA memory before training starts.
gc.collect()
torch.cuda.empty_cache()

In [41]:
import glob
# Paths of all validation images; the training loop samples from these for example captions.
val_paths = glob.glob("../datasets/COCO/val2017/*")

In [None]:
#collapse-output
# Main loop: one training pass and one validation pass per epoch, then log a sample caption to W&B.
for epoch in range(init_epoch, NUM_EPOCHS+1):
    train_loss = train_epoch(model, train_loader, optimizer, scaler, scheduler,
                             epoch, CONFIG['use_amp'], CONFIG['log_interval'])
    # with torch.no_grad():
    val_loss = evaluate(model, val_loader, CONFIG['use_amp'])

    # Caption one random validation image as a qualitative check.
    # NOTE(review): `preproc` is not defined in this notebook — presumably it comes from one
    # of the imcap star imports; confirm.
    img = Image.open(random.choice(val_paths))
    caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
    wandb.log({"val_loss": val_loss, "epoch": epoch, "examples": wandb.Image(img, caption=caps)})
    print(f"\nEpoch: {epoch}/{NUM_EPOCHS}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}\n")
    gc.collect()
    # if not epoch%10:
    #     save_model(model, optimizer, epoch)

Epoch 1: 100%|██████████| 463/463 [07:59<00:00,  1.04s/it, loss=3.92, lr=0.0005]
Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s, val_loss=2.54]
Epoch 2:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 1/50, Train loss: 3.917, Val loss: 2.538



Epoch 2: 100%|██████████| 463/463 [07:59<00:00,  1.04s/it, loss=2.68, lr=0.000499]
Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s, val_loss=2.36]
Epoch 3:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 2/50, Train loss: 2.678, Val loss: 2.357



Epoch 3: 100%|██████████| 463/463 [07:54<00:00,  1.03s/it, loss=2.54, lr=0.000497]
Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s, val_loss=2.29]
Epoch 4:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 3/50, Train loss: 2.541, Val loss: 2.287



Epoch 4:  18%|█▊        | 82/463 [01:26<06:45,  1.06s/it, loss=2.48, lr=0.000497]

In [None]:
# Remember the epoch the loop last reached so that re-running the training cell resumes
# there (that epoch is run again from its start).
init_epoch = epoch
init_epoch

In [None]:
# def save_model(model, optimizer, epoch):
#     torch.save({
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'epoch': epoch,
#                 }, '/content/model.pth')

# Make Predictions

In [None]:
# Caption a random validation image with the current weights, print the caption and display the image.
img = Image.open(random.choice(val_paths))
caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
# wandb.log({"examples": wandb.Image(img, caption=caps)})
print(caps)
img

In [None]:
# Close the W&B run started by wandb.init.
run.finish()