# Image Captioning with Transformers

In [None]:
!nvidia-smi

Mon Jun 28 22:41:15 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    25W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!apt install -qq pigz
%pip install -q timm wandb
%pip install -q --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110

The following NEW packages will be installed:
  pigz
0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.
Need to get 57.4 kB of archives.
After this operation, 259 kB of additional disk space will be used.
Selecting previously unselected package pigz.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../archives/pigz_2.4-1_amd64.deb ...
Unpacking pigz (2.4-1) ...
Setting up pigz (2.4-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
[K     |████████████████████████████████| 348kB 14.4MB/s 
[K     |████████████████████████████████| 1.8MB 23.6MB/s 
[K     |████████████████████████████████| 133kB 65.2MB/s 
[K     |████████████████████████████████| 102kB 12.5MB/s 
[K     |████████████████████████████████| 174kB 41.3MB/s 
[K     |████████████████████████████████| 71kB 9.8MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
[K     |

In [None]:
!git clone https://github.com/ShivamShrirao/Image-Captioning-Transformers

Cloning into 'Image-Captioning-Transformers'...
remote: Enumerating objects: 70, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 70 (delta 29), reused 57 (delta 16), pack-reused 0[K
Unpacking objects: 100% (70/70), done.


# Download Dataset and Annotations

In [None]:
!mkdir ~/.kaggle/
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d shivamshrirao/coco-trainval2017-320x320

Downloading coco-trainval2017-320x320.zip to /content
100% 3.46G/3.46G [01:03<00:00, 75.8MB/s]
100% 3.46G/3.46G [01:03<00:00, 58.1MB/s]


In [None]:
!unzip -q coco-trainval2017-320x320.zip

In [None]:
# !gdown --id 1-3vdwBlY-CdVultkrFwOhyJTGC5TFUV8

In [None]:
# !pigz -dc coco_trainval2017_320x320.tar.gz | tar xf -

In [None]:
from torchvision.datasets.utils import download_and_extract_archive
DATA_DIR = "datasets/COCO"

In [None]:
download_and_extract_archive("http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
                             download_root=DATA_DIR,
                             remove_finished=True)

Downloading http://images.cocodataset.org/annotations/annotations_trainval2017.zip to datasets/COCO/annotations_trainval2017.zip


HBox(children=(FloatProgress(value=0.0, max=252907541.0), HTML(value='')))


Extracting datasets/COCO/annotations_trainval2017.zip to datasets/COCO


In [None]:
!rm coco-trainval2017-320x320.* datasets/COCO/annotations_trainval2017.zip

In [1]:
%cd /content/Image-Captioning-Transformers

/content/Image-Captioning-Transformers


In [None]:
!wandb agent shivamshrirao/Image_Captioning_Transformer/lfj2msgq

# Import libraries

In [1]:
%cd /content/Image-Captioning-Transformers

/content/Image-Captioning-Transformers


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# TODO: Try pre trained CLIP

In [4]:
import os
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
import timm         # torch image models

In [7]:
plt.rcParams['figure.facecolor'] = 'white'

# Wandb Parameters

In [8]:
import wandb

In [9]:
config_defaults = {
    'BATCH_SIZE'        : 256,
    'd_model'           : 512,
    'dim_feedforward'   : 1024,
    'nheads'            : 8,
    'num_decoder_layers': 6,
    'dp_rate'           : 0.2,
    'encoder'           : 'seresnext50_32x4d',
    'activation'        : 'gelu',
    'max_lr'            : 4e-4,
    'betas'             : (0.9, 0.98),
    'eps'               : 1e-9,
    'seed'              : 62134,
    'use_amp'           : True,
    'use_pe'            : True,
    'log_interval'      : 10,
}
CONFIG = config_defaults

In [10]:
# #hide
# run = wandb.init(id='3vhov6z0', project="Image_Captioning_Transformer", resume='must')
# CONFIG = run.config

In [11]:
run = wandb.init(project="Image_Captioning_Transformer", entity="shivamshrirao", config=config_defaults)
CONFIG = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mshivamshrirao[0m (use `wandb login --relogin` to force relogin)


In [12]:
def seed_everything(seed=33):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
    
seed_everything(CONFIG['seed'])

In [13]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Read COCO dataset

In [14]:
from imcap.dataset import *

In [15]:
DATA_DIR = "../datasets/COCO/"

In [16]:
train_data = TensorCocoCaptions(root=DATA_DIR+"/train2017/",
                                annFile=DATA_DIR+"/annotations/captions_train2017.json")

val_data = TensorCocoCaptions(root=DATA_DIR+"/val2017/",
                              annFile=DATA_DIR+"/annotations/captions_val2017.json")

loading annotations into memory...
Done (t=1.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


## Tokenizer and Build Vocab

In [17]:
from torchtext.data.utils import get_tokenizer

In [18]:
tokenizer = get_tokenizer('basic_english')

In [19]:
def yield_tokens(cap_data):
    for ann in cap_data.coco.anns.values():
        yield tokenizer(ann['caption'])

In [20]:
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
en_vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=special_symbols, special_first=True)

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = en_vocab(special_symbols)
en_vocab.set_default_index(UNK_IDX)

In [21]:
len(en_vocab)

28940

In [22]:
train_data.fill_token_dict(tokenizer, en_vocab, BOS_IDX, EOS_IDX)
val_data.fill_token_dict(tokenizer, en_vocab, BOS_IDX, EOS_IDX)

100%|██████████| 118287/118287 [00:16<00:00, 7089.15it/s]
100%|██████████| 5000/5000 [00:00<00:00, 6782.23it/s]


## Pretrained Glove Embeddings (not used rn)

In [23]:
# vec.get_vecs_by_tokens(tokens, lower_case_backup=True)

In [24]:
# vec = torchtext.vocab.GloVe('6B', dim=300)
# unk_vec = vec.vectors.mean(dim=0)
# vec.unk_init = lambda x: unk_vec

# Load dataset into batches

In [25]:
from imcap.dataloader import *

In [26]:
nthreads = 2 * len(os.sched_getaffinity(0))
nthreads

8

In [27]:
train_iter = ExternalInputIterator(train_data, CONFIG['BATCH_SIZE'], PAD_IDX)
pipe = ExternalSourcePipeline(batch_size=CONFIG['BATCH_SIZE'], num_threads=nthreads, device_id=0, external_data=train_iter, input_size=input_size)
train_loader = DALIClassificationIterator(pipe, dynamic_shape=True, auto_reset=True, last_batch_padded=True, size=len(train_iter))

val_iter = ExternalInputIterator(val_data, CONFIG['BATCH_SIZE'], PAD_IDX, training=False)
pipe = ExternalSourcePipeline(batch_size=CONFIG['BATCH_SIZE'], num_threads=nthreads, device_id=0, external_data=val_iter, input_size=input_size, training=False)
val_loader = DALIClassificationIterator(pipe, dynamic_shape=True, auto_reset=True, last_batch_padded=True, size=len(val_iter))



# Initialize Model

In [28]:
from imcap.layers import *
from imcap.utils import *

In [29]:
model = CaptionModel(encoder = timm.create_model(CONFIG['encoder'], pretrained=True, num_classes=0, global_pool=''),
                     vocab_size = len(en_vocab),
                     num_decoder_layers = CONFIG['num_decoder_layers'],
                     nheads = CONFIG['nheads'],
                     d_model = CONFIG['d_model'],
                     dim_feedforward = CONFIG['dim_feedforward'],
                     dp_rate = CONFIG['dp_rate'],
                     activation = CONFIG['activation']).to(DEVICE, non_blocking=True)



# Learning Rate Schedule

In [30]:
steps_per_epoch = len(train_loader)

In [31]:
# def lr_schedule(step, d_model=512, warmup_steps=2*steps_per_epoch):
#     # return 1
#     step = max(1,step)
#     arg1 = step ** -0.5
#     arg2 = step * (warmup_steps ** -1.5)
#     return (d_model ** -0.6) * min(arg1, arg2)

In [32]:
# plt.plot(list(map(lr_schedule, range(50*steps_per_epoch))))
# plt.show()

In [33]:
# plt.plot([scheduler.get_last_lr()[0] for _ in range(steps_per_epoch*50) if not scheduler.step()])
# plt.show()

# Loss Function and Optimizer

In [34]:
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG['max_lr'],
    betas=CONFIG['betas'], eps=CONFIG['eps']
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CONFIG['max_lr'], total_steps=50*steps_per_epoch, pct_start=0.)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

scaler = torch.cuda.amp.GradScaler(enabled=CONFIG['use_amp'])

In [35]:
wandb.watch(model, log=None)

[<wandb.wandb_torch.TorchGraph at 0x7fba30472ed0>]

# Training functions

In [36]:
from torch.cuda import amp

In [37]:
def train_epoch(model, train_loader, optimizer, scaler, scheduler, epoch=1, use_amp=True, log_interval=10):
    model.train()
    model.encoder.eval()
    losses = AverageMeter()
    with tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}") as pbar:
        for idx, batch in pbar:
            img, tgt = batch[0]['data'], batch[0]['label'].transpose(0,1)
            # img = img.to(DEVICE, non_blocking=True)
            # tgt = tgt.to(DEVICE, non_blocking=True)
            
            tgt_inp = tgt[:-1,:]      # give input until before the last word.
            tgt_out = tgt[1:, :]      # predict the last word based on input and already predicted sentence. (auto-regressive)

            tgt_mask, tgt_pad_mask = subsequent_mask(tgt_inp.size(0), DEVICE), padding_mask(tgt_inp, PAD_IDX)

            optimizer.zero_grad(set_to_none=True)
            with amp.autocast(enabled=use_amp):
                logits = model(img, tgt_inp, tgt_mask, tgt_pad_mask)
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            losses.update(loss.detach_(), img.size(0))
            del loss, logits, batch, img

            if not idx%log_interval:
                curr_lr = optimizer.param_groups[0]['lr']
                info = {'loss': float(losses.avg), 'lr': curr_lr}
                losses.reset()
                wandb.log(info)
                pbar.set_postfix(info)

    optimizer.zero_grad(set_to_none=True)
    return float(losses.avg)

In [38]:
@torch.no_grad()
def evaluate(model, val_loader, use_amp=True):
    model.eval()
    losses = AverageMeter()
    with tqdm(enumerate(val_loader), total=len(val_loader), desc="Evaluating") as pbar:
        for idx, batch in pbar:
            img, tgt = batch[0]['data'], batch[0]['label'].transpose(0,1)
            # img = img.to(DEVICE, non_blocking=True)
            # tgt = tgt.to(DEVICE, non_blocking=True)

            tgt_inp = tgt[:-1,:]      # give input until before the last word.
            tgt_out = tgt[1:, :]      # predict the last word based on input and already predicted sentence. (auto-regressive)

            tgt_mask, tgt_pad_mask = subsequent_mask(tgt_inp.size(0), DEVICE), padding_mask(tgt_inp, PAD_IDX)
            
            with amp.autocast(enabled=use_amp):
                logits = model(img, tgt_inp, tgt_mask, tgt_pad_mask)
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            losses.update(loss.detach_(), img.size(0))
            pbar.set_postfix({'val_loss': float(losses.avg)})
    return float(losses.avg)

# Functions to Make Predictions

In [39]:
@torch.no_grad()
def greedy_decode(model, img, max_len=100, start_symbol=BOS_IDX):
    model.eval()
    img = img.to(DEVICE, non_blocking=True)
    enc_output = model.encode_image(img)
    tgt = torch.ones(1, 1).fill_(start_symbol).long().to(DEVICE, non_blocking=True)
    for i in range(max_len):
        tgt_mask = subsequent_mask(tgt.size(0), DEVICE)
        out = model.decode_text(tgt, enc_output, tgt_mask)
        out = out.transpose(0,1)
        prob = model.generator(out[:,-1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.item()
        tgt = torch.cat([tgt, torch.ones(1, 1).fill_(next_word).long().to(DEVICE)], dim=0)
        if next_word == EOS_IDX:
            break
    return tgt.detach()

@torch.no_grad()
def generate_caption(model, img, tgt_vocab):
    tgt = greedy_decode(model, img, max_len=100, start_symbol=BOS_IDX).flatten()
    return " ".join(tgt_vocab.lookup_tokens(tgt.tolist())).replace("<bos>", "").replace("<eos>", "")

# Begin Training

In [40]:
init_epoch = 1
NUM_EPOCHS = 50

In [41]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [42]:
import glob
val_paths = glob.glob(DATA_DIR+"/val2017/*")

In [43]:
# LR Finder, CLIP, ViT

In [None]:
#collapse-output
for epoch in range(init_epoch, NUM_EPOCHS+1):
    train_loss = train_epoch(model, train_loader, optimizer, scaler, scheduler,
                             epoch, CONFIG['use_amp'], CONFIG['log_interval'])
    # with torch.no_grad():
    val_loss = evaluate(model, val_loader, CONFIG['use_amp'])

    img = Image.open(random.choice(val_paths))
    caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
    wandb.log({"val_loss": val_loss, "epoch": epoch, "predictions": wandb.Image(img, caption=caps)})
    print(f"\nEpoch: {epoch}/{NUM_EPOCHS}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}\n")
    gc.collect()
    # if not epoch%10:
    #     save_model(model, optimizer, epoch)

Epoch 1: 100%|██████████| 463/463 [02:22<00:00,  3.24it/s, loss=2.93, lr=0.0004]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.13it/s, val_loss=2.63]
Epoch 2:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 1/50, Train loss: 2.987, Val loss: 2.635



Epoch 2: 100%|██████████| 463/463 [02:20<00:00,  3.29it/s, loss=2.73, lr=0.000398]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.09it/s, val_loss=2.44]
Epoch 3:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 2/50, Train loss: 2.654, Val loss: 2.437



Epoch 3: 100%|██████████| 463/463 [02:19<00:00,  3.33it/s, loss=2.57, lr=0.000396]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.11it/s, val_loss=2.34]
Epoch 4:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 3/50, Train loss: 2.615, Val loss: 2.339



Epoch 4: 100%|██████████| 463/463 [02:21<00:00,  3.28it/s, loss=2.6, lr=0.000394]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.09it/s, val_loss=2.29]
Epoch 5:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 4/50, Train loss: 2.551, Val loss: 2.287



Epoch 5:  54%|█████▍    | 252/463 [01:16<01:02,  3.37it/s, loss=2.56, lr=0.000392]

In [None]:
init_epoch = epoch
init_epoch

In [None]:
# def save_model(model, optimizer, scheduler, epoch=0, path='/content/model.pth'):
#     torch.save({
#                 'projection_head': model.projection_head.state_dict(),
#                 'decoder': model.decoder.state_dict(),
#                 'generator': model.generator.state_dict(),
#                 'optimizer': optimizer.state_dict(),
#                 'scheduler': scheduler.state_dict(),
#                 'epoch': epoch,
#                 }, path)

In [None]:
# def load_model(model, optimizer, scheduler):
#     checkpoint = torch.load('/content/model.pth', map_location=DEVICE)
#     model.projection_head.load_state_dict(checkpoint['projection_head'])
#     model.decoder.load_state_dict(checkpoint['decoder'])
#     model.generator.load_state_dict(checkpoint['generator'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
#     scheduler.load_state_dict(checkpoint['scheduler'])

# Make Predictions

In [None]:
img = Image.open(random.choice(val_paths))
caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
# wandb.log({"predictions": wandb.Image(img, caption=caps)})
print(caps)
img

In [None]:
run.finish()