# Image Captioning with Transformers

In [1]:
!nvidia-smi

Wed Jul  7 00:26:22 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| N/A   45C    P8    N/A /  N/A |      3MiB /  2004MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!apt install -qq pigz
%pip install -q timm wandb
%pip install -q --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110

The following NEW packages will be installed:
  pigz
0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.
Need to get 57.4 kB of archives.
After this operation, 259 kB of additional disk space will be used.
Selecting previously unselected package pigz.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../archives/pigz_2.4-1_amd64.deb ...
Unpacking pigz (2.4-1) ...
Setting up pigz (2.4-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
[K     |████████████████████████████████| 348kB 14.4MB/s 
[K     |████████████████████████████████| 1.8MB 23.6MB/s 
[K     |████████████████████████████████| 133kB 65.2MB/s 
[K     |████████████████████████████████| 102kB 12.5MB/s 
[K     |████████████████████████████████| 174kB 41.3MB/s 
[K     |████████████████████████████████| 71kB 9.8MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
[K     |

In [None]:
!git clone https://github.com/ShivamShrirao/Image-Captioning-Transformers

Cloning into 'Image-Captioning-Transformers'...
remote: Enumerating objects: 70, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 70 (delta 29), reused 57 (delta 16), pack-reused 0[K
Unpacking objects: 100% (70/70), done.


# Download Dataset and Annotations

In [None]:
!mkdir ~/.kaggle/
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d shivamshrirao/coco-trainval2017-320x320

Downloading coco-trainval2017-320x320.zip to /content
100% 3.46G/3.46G [01:03<00:00, 75.8MB/s]
100% 3.46G/3.46G [01:03<00:00, 58.1MB/s]


In [None]:
!unzip -q coco-trainval2017-320x320.zip

In [None]:
# !gdown --id 1-3vdwBlY-CdVultkrFwOhyJTGC5TFUV8

In [None]:
# !pigz -dc coco_trainval2017_320x320.tar.gz | tar xf -

In [None]:
from torchvision.datasets.utils import download_and_extract_archive
DATA_DIR = "datasets/COCO"

In [None]:
download_and_extract_archive("http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
                             download_root=DATA_DIR,
                             remove_finished=True)

Downloading http://images.cocodataset.org/annotations/annotations_trainval2017.zip to datasets/COCO/annotations_trainval2017.zip


HBox(children=(FloatProgress(value=0.0, max=252907541.0), HTML(value='')))


Extracting datasets/COCO/annotations_trainval2017.zip to datasets/COCO


In [None]:
!rm coco-trainval2017-320x320.* datasets/COCO/annotations_trainval2017.zip

In [None]:
%cd /content/Image-Captioning-Transformers

/content/Image-Captioning-Transformers


In [None]:
!wandb agent shivamshrirao/Image_Captioning_Transformer/lfj2msgq

# Import libraries

In [None]:
%cd /content/Image-Captioning-Transformers

/content/Image-Captioning-Transformers


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# TODO: Try a pretrained CLIP model

In [None]:
import os
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import timm         # torch image models

In [None]:
plt.rcParams['figure.facecolor'] = 'white'

# Wandb Parameters

In [None]:
import wandb

In [None]:
# Default hyperparameters for a run; this dict is handed to wandb.init as the run config.
config_defaults = {
    # --- data ---
    'BATCH_SIZE': 256,               # batch size used by both DALI iterators

    # --- caption decoder (forwarded to CaptionModel) ---
    'd_model': 512,
    'dim_feedforward': 2048,
    'nheads': 8,
    'num_decoder_layers': 6,
    'dp_rate': 0.2,

    # --- image encoder ---
    'encoder': 'seresnext50_32x4d',  # timm model name

    'activation': 'gelu',            # forwarded to CaptionModel

    # --- optimisation (Adam + OneCycleLR) ---
    'max_lr': 3e-4,
    'betas': (0.9, 0.98),
    'eps': 1e-9,

    # --- run control ---
    'seed': 62134,
    'use_amp': True,                 # toggles autocast and the GradScaler
    'use_pe': True,                  # NOTE(review): not read by any visible cell - confirm where it is consumed
    'log_interval': 10,              # batches between wandb logging calls in train_epoch
}

# Plain-dict fallback; rebound to wandb.config once a run is initialised.
CONFIG = config_defaults

In [None]:
# #hide
# run = wandb.init(id='19sqz0by', project="Image_Captioning_Transformer", resume='must')
# CONFIG = run.config

In [None]:
run = wandb.init(project="Image_Captioning_Transformer", entity="shivamshrirao", config=config_defaults)
CONFIG = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mshivamshrirao[0m (use `wandb login --relogin` to force relogin)


In [None]:
def seed_everything(seed=33):
    """Seed the NumPy, Python ``random`` and PyTorch (CPU and CUDA) generators.

    Also turns on cuDNN benchmark mode. cuDNN deterministic mode is left
    disabled, so GPU results are not guaranteed to be bit-for-bit repeatable.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True   # intentionally left off
    
seed_everything(CONFIG['seed'])

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Read COCO dataset

In [None]:
from imcap.dataset import *

In [None]:
DATA_DIR = "../datasets/COCO/"

In [None]:
train_data = TensorCocoCaptions(root=DATA_DIR+"/train2017/",
                                annFile=DATA_DIR+"/annotations/captions_train2017.json")

val_data = TensorCocoCaptions(root=DATA_DIR+"/val2017/",
                              annFile=DATA_DIR+"/annotations/captions_val2017.json")

loading annotations into memory...
Done (t=1.06s)
creating index...
index created!
loading annotations into memory...
Done (t=0.23s)
creating index...
index created!


## Tokenizer and Build Vocab

In [None]:
from torchtext.data.utils import get_tokenizer

In [None]:
tokenizer = get_tokenizer('basic_english')

In [None]:
def yield_tokens(cap_data):
    """Lazily yield the tokenized caption of every annotation in ``cap_data``.

    Args:
        cap_data: dataset object exposing ``coco.anns``, a mapping whose values
            are annotation dicts with a ``'caption'`` entry.

    Yields:
        Whatever the module-level ``tokenizer`` returns for each caption.
    """
    captions = (annotation['caption'] for annotation in cap_data.coco.anns.values())
    yield from map(tokenizer, captions)

In [None]:
# Import explicitly: no visible cell imports this name, so the notebook otherwise
# relies on a wildcard import (or leftover kernel state) to provide it.
from torchtext.vocab import build_vocab_from_iterator

# Special tokens are placed first in the vocabulary (special_first=True).
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
en_vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=special_symbols, special_first=True)

# Look up the index of each special token; out-of-vocabulary tokens map to UNK_IDX.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = en_vocab(special_symbols)
en_vocab.set_default_index(UNK_IDX)

In [None]:
len(en_vocab)

28940

In [None]:
train_data.fill_token_dict(tokenizer, en_vocab, BOS_IDX, EOS_IDX)
val_data.fill_token_dict(tokenizer, en_vocab, BOS_IDX, EOS_IDX)

100%|██████████| 118287/118287 [00:14<00:00, 7931.20it/s]
100%|██████████| 5000/5000 [00:00<00:00, 9689.40it/s]


## Pretrained GloVe Embeddings (currently unused)

In [None]:
# vec.get_vecs_by_tokens(tokens, lower_case_backup=True)

In [None]:
# vec = torchtext.vocab.GloVe('6B', dim=300)
# unk_vec = vec.vectors.mean(dim=0)
# vec.unk_init = lambda x: unk_vec

# Load dataset into batches

In [None]:
from imcap.dataloader import *

In [None]:
nthreads = 2 * len(os.sched_getaffinity(0))
nthreads

8

In [None]:
# Training loader: an external-source iterator over train_data feeds a DALI pipeline
# running on GPU 0, which is then wrapped in a DALI iterator.
# NOTE(review): `input_size` is not defined in any visible cell - presumably it comes
# from one of the wildcard imports (imcap.dataset / imcap.dataloader); confirm.
train_iter = ExternalInputIterator(train_data, CONFIG['BATCH_SIZE'], PAD_IDX)
pipe = ExternalSourcePipeline(batch_size=CONFIG['BATCH_SIZE'], num_threads=nthreads, device_id=0, external_data=train_iter, input_size=input_size)
train_loader = DALIClassificationIterator(pipe, dynamic_shape=True, auto_reset=True, last_batch_padded=True, size=len(train_iter))

# Validation loader: same construction with training=False on both the iterator and the
# pipeline. Note that the name `pipe` is rebound here.
val_iter = ExternalInputIterator(val_data, CONFIG['BATCH_SIZE'], PAD_IDX, training=False)
pipe = ExternalSourcePipeline(batch_size=CONFIG['BATCH_SIZE'], num_threads=nthreads, device_id=0, external_data=val_iter, input_size=input_size, training=False)
val_loader = DALIClassificationIterator(pipe, dynamic_shape=True, auto_reset=True, last_batch_padded=True, size=len(val_iter))



# Initialize Model

In [None]:
from imcap.layers import *
from imcap.utils import *

In [None]:
# Captioning model: a pretrained timm backbone as image encoder plus a transformer
# decoder sized from CONFIG. The backbone is created with num_classes=0 and
# global_pool='' (presumably so it returns unpooled feature maps - confirm against timm).
# NOTE(review): CONFIG['use_pe'] is not passed to CaptionModel here.
model = CaptionModel(
    encoder=timm.create_model(CONFIG['encoder'], pretrained=True, num_classes=0, global_pool=''),
    vocab_size=len(en_vocab),
    d_model=CONFIG['d_model'],
    nheads=CONFIG['nheads'],
    num_decoder_layers=CONFIG['num_decoder_layers'],
    dim_feedforward=CONFIG['dim_feedforward'],
    activation=CONFIG['activation'],
    dp_rate=CONFIG['dp_rate'],
)
model = model.to(DEVICE, non_blocking=True)



# Learning Rate Schedule

In [None]:
steps_per_epoch = len(train_loader)

In [None]:
# def lr_schedule(step, d_model=512, warmup_steps=2*steps_per_epoch):
#     # return 1
#     step = max(1,step)
#     arg1 = step ** -0.5
#     arg2 = step * (warmup_steps ** -1.5)
#     return (d_model ** -0.6) * min(arg1, arg2)

In [None]:
# plt.plot(list(map(lr_schedule, range(50*steps_per_epoch))))
# plt.show()

In [None]:
# plt.plot([scheduler.get_last_lr()[0] for _ in range(steps_per_epoch*50) if not scheduler.step()])
# plt.show()

# Loss Function and Optimizer

In [None]:
# Cross entropy over vocabulary logits; targets equal to PAD_IDX do not contribute.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam over only those parameters that require gradients.
optimizer = torch.optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=CONFIG['max_lr'],
    betas=CONFIG['betas'],
    eps=CONFIG['eps'],
)

# One-cycle schedule without a warm-up phase (pct_start=0), stepped once per batch.
# NOTE(review): the literal 50 must equal the number of epochs actually trained
# (NUM_EPOCHS is defined in a later cell); OneCycleLR raises once it has been stepped
# more than total_steps times.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=CONFIG['max_lr'],
    total_steps=50 * steps_per_epoch,
    pct_start=0.0,
)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

# Gradient scaler for mixed precision; a no-op when use_amp is False.
scaler = torch.cuda.amp.GradScaler(enabled=CONFIG['use_amp'])

In [None]:
wandb.watch(model, log=None)

[<wandb.wandb_torch.TorchGraph at 0x7fdbe460c410>]

# Training functions

In [None]:
from torch.cuda import amp

In [None]:
def train_epoch(model, train_loader, optimizer, scaler, scheduler, epoch=1, use_amp=True, log_interval=10):
    """Run one training epoch with teacher forcing and return the average loss.

    Relies on module-level names: ``loss_fn``, ``DEVICE``, ``PAD_IDX``, ``wandb``,
    ``AverageMeter``, ``subsequent_mask`` and ``padding_mask``.

    Args:
        model: captioning model, called as ``model(img, tgt_inp, tgt_mask, tgt_pad_mask)``;
            must expose an ``encoder`` submodule.
        train_loader: iterable whose batches expose ``batch[0]['data']`` (images) and
            ``batch[0]['label']`` (token ids); must support ``len()``.
        optimizer: optimizer updating the trainable parameters.
        scaler: gradient scaler used for the backward pass and optimizer step.
        scheduler: learning-rate scheduler, stepped once per batch.
        epoch: epoch number, used only for the progress-bar label.
        use_amp: whether the forward pass runs under autocast.
        log_interval: log to wandb / the progress bar every this many batches.

    Returns:
        The running average loss over the epoch as a float.
    """
    model.train()
    # The encoder is kept in eval mode even while the rest of the model trains
    # (presumably because the pretrained backbone is not being fine-tuned - confirm).
    model.encoder.eval()
    losses = AverageMeter()
    with tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}") as pbar:
        for idx, batch in pbar:
            # Labels are transposed so the sliced axis below is axis 0
            # (presumably giving (seq_len, batch) - confirm against the loader).
            img, tgt = batch[0]['data'], batch[0]['label'].transpose(0,1)
            # img = img.to(DEVICE, non_blocking=True)
            # tgt = tgt.to(DEVICE, non_blocking=True)
            
            tgt_inp = tgt[:-1,:]      # decoder input: every position except the last
            tgt_out = tgt[1:, :]      # targets: the same sequence shifted left by one, so each position predicts the next token

            tgt_mask, tgt_pad_mask = subsequent_mask(tgt_inp.size(0), DEVICE), padding_mask(tgt_inp, PAD_IDX)

            optimizer.zero_grad(set_to_none=True)
            with amp.autocast(enabled=use_amp):
                logits = model(img, tgt_inp, tgt_mask, tgt_pad_mask)
                # Flatten to (positions, vocab) vs (positions,) for the token-level loss.
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            # Scaled backward pass, optimizer step through the scaler, then update its scale.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # The LR schedule advances once per batch, not once per epoch.
            scheduler.step()

            # The running average is weighted by the batch size.
            losses.update(loss.detach_(), img.size(0))
            # Drop references to this batch's tensors before the next iteration.
            del loss, logits, batch, img

            # Log on batches 0, log_interval, 2*log_interval, ...
            if not idx%log_interval:
                curr_lr = optimizer.param_groups[0]['lr']
                info = {'loss': float(losses.avg), 'lr': curr_lr}
                wandb.log(info)
                pbar.set_postfix(info)

    # Release gradient tensors before returning.
    optimizer.zero_grad(set_to_none=True)
    return float(losses.avg)

In [None]:
@torch.no_grad()
def evaluate(model, val_loader, use_amp=True):
    """Compute the average teacher-forced loss over ``val_loader`` without gradients.

    Relies on module-level names: ``loss_fn``, ``DEVICE``, ``PAD_IDX``,
    ``AverageMeter``, ``subsequent_mask`` and ``padding_mask``.

    Args:
        model: captioning model; switched to eval mode here.
        val_loader: iterable whose batches expose ``batch[0]['data']`` and
            ``batch[0]['label']``; must support ``len()``.
        use_amp: whether the forward pass runs under autocast.

    Returns:
        The batch-size-weighted running average loss as a float.
    """
    model.eval()
    meter = AverageMeter()
    progress = tqdm(enumerate(val_loader), total=len(val_loader), desc="Evaluating")
    with progress:
        for _, batch in progress:
            sample = batch[0]
            img = sample['data']
            captions = sample['label'].transpose(0, 1)

            # Shift by one: the decoder sees tokens [0, T-1) and is scored on tokens [1, T).
            decoder_in = captions[:-1]
            target = captions[1:]

            causal_mask = subsequent_mask(decoder_in.size(0), DEVICE)
            pad_mask = padding_mask(decoder_in, PAD_IDX)

            with amp.autocast(enabled=use_amp):
                logits = model(img, decoder_in, causal_mask, pad_mask)
                vocab_dim = logits.size(-1)
                loss = loss_fn(logits.reshape(-1, vocab_dim), target.reshape(-1))

            meter.update(loss.detach_(), img.size(0))
            progress.set_postfix({'val_loss': float(meter.avg)})
    return float(meter.avg)

# Functions to Make Predictions

In [None]:
@torch.no_grad()
def greedy_decode(model, img, max_len=100, start_symbol=BOS_IDX):
    """Generate a token sequence for one image by greedy (arg-max) decoding.

    The image is encoded once; the decoder is then re-run on the growing
    sequence, appending the highest-scoring token each step until ``EOS_IDX``
    is produced or ``max_len`` tokens have been generated.

    Args:
        model: captioning model exposing ``encode_image``, ``decode_text`` and
            ``generator``; switched to eval mode here.
        img: image tensor; moved to ``DEVICE``. Decoding handles a single image
            only, since the chosen token is read with ``.item()``.
        max_len: maximum number of tokens generated after the start symbol.
        start_symbol: vocabulary index used as the first token.

    Returns:
        A detached long tensor of shape ``(seq_len, 1)`` that starts with
        ``start_symbol`` and ends with ``EOS_IDX`` when it was produced in time.
    """
    model.eval()
    img = img.to(DEVICE, non_blocking=True)
    enc_output = model.encode_image(img)
    # Build the sequence directly as a long tensor on the target device rather than
    # filling a float32 CPU tensor, casting it and copying it over.
    tgt = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    for _ in range(max_len):
        tgt_mask = subsequent_mask(tgt.size(0), DEVICE)
        out = model.decode_text(tgt, enc_output, tgt_mask)
        out = out.transpose(0, 1)
        # Scores for the most recent position only.
        logits = model.generator(out[:, -1])
        next_word = logits.argmax(dim=1).item()
        # new_full inherits dtype (long) and device from tgt.
        tgt = torch.cat([tgt, tgt.new_full((1, 1), next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return tgt.detach()

@torch.no_grad()
def generate_caption(model, img, tgt_vocab):
    """Greedy-decode a caption for ``img`` and return it as a plain string.

    The ``<bos>`` and ``<eos>`` marker tokens are dropped before joining, so the
    result carries no leading or trailing whitespace.

    Args:
        model: captioning model, passed through to ``greedy_decode``.
        img: image tensor, passed through to ``greedy_decode``.
        tgt_vocab: vocabulary exposing ``lookup_tokens(list_of_ids)``.

    Returns:
        The caption tokens joined by single spaces.
    """
    token_ids = greedy_decode(model, img, max_len=100, start_symbol=BOS_IDX).flatten().tolist()
    markers = ("<bos>", "<eos>")
    # Filter the markers out before joining; stripping them from the joined
    # string would leave stray spaces at both ends.
    words = [token for token in tgt_vocab.lookup_tokens(token_ids) if token not in markers]
    return " ".join(words)

# Begin Training

In [None]:
init_epoch = 1
NUM_EPOCHS = 50

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
import glob
val_paths = glob.glob(DATA_DIR+"/val2017/*")

In [None]:
# LR Finder, CLIP, ViT

In [None]:
#collapse-output
# Main loop: one training epoch, one validation pass, then a sample caption logged to wandb.
for epoch in range(init_epoch, NUM_EPOCHS+1):
    train_loss = train_epoch(model, train_loader, optimizer, scaler, scheduler,
                             epoch, CONFIG['use_amp'], CONFIG['log_interval'])
    # No explicit torch.no_grad() needed: evaluate() is decorated with it.
    val_loss = evaluate(model, val_loader, CONFIG['use_amp'])

    # Caption one randomly chosen validation image as a qualitative check.
    # NOTE(review): `preproc` is not defined in any visible cell - presumably it comes
    # from one of the wildcard imports; confirm.
    img = Image.open(random.choice(val_paths))
    caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
    wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch, "predictions": wandb.Image(img, caption=caps)})
    print(f"\nEpoch: {epoch}/{NUM_EPOCHS}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}\n")
    gc.collect()
    # Periodic checkpointing (disabled). save_model is defined in a later cell and
    # takes the scheduler as its third argument:
    # if not epoch%10:
    #     save_model(model, optimizer, scheduler, epoch)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Epoch 1: 100%|██████████| 463/463 [02:26<00:00,  3.17it/s, loss=2.94, lr=0.0003]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.09it/s, val_loss=2.65]
Epoch 2:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 1/50, Train loss: 3.011, Val loss: 2.646



Epoch 2: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.71, lr=0.000299]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.44]
Epoch 3:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 2/50, Train loss: 2.654, Val loss: 2.436



Epoch 3: 100%|██████████| 463/463 [02:22<00:00,  3.25it/s, loss=2.57, lr=0.000297]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=2.34]
Epoch 4:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 3/50, Train loss: 2.613, Val loss: 2.339



Epoch 4: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.59, lr=0.000295]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.28]
Epoch 5:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 4/50, Train loss: 2.546, Val loss: 2.282



Epoch 5: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.46, lr=0.000293]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.24]
Epoch 6:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 5/50, Train loss: 2.470, Val loss: 2.235



Epoch 6: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.42, lr=0.000289]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=2.21]
Epoch 7:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 6/50, Train loss: 2.533, Val loss: 2.207



Epoch 7: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.43, lr=0.000286]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=2.17]
Epoch 8:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 7/50, Train loss: 2.410, Val loss: 2.174



Epoch 8: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.39, lr=0.000281]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=2.16]
Epoch 9:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 8/50, Train loss: 2.392, Val loss: 2.155



Epoch 9: 100%|██████████| 463/463 [02:23<00:00,  3.23it/s, loss=2.35, lr=0.000277]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.08it/s, val_loss=2.13]
Epoch 10:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 9/50, Train loss: 2.347, Val loss: 2.134



Epoch 10: 100%|██████████| 463/463 [02:23<00:00,  3.24it/s, loss=2.37, lr=0.000271]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.12]
Epoch 11:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 10/50, Train loss: 2.399, Val loss: 2.118



Epoch 11: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.33, lr=0.000266]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.07it/s, val_loss=2.1]
Epoch 12:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 11/50, Train loss: 2.324, Val loss: 2.104



Epoch 12: 100%|██████████| 463/463 [02:23<00:00,  3.23it/s, loss=2.32, lr=0.000259]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=2.1]
Epoch 13:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 12/50, Train loss: 2.270, Val loss: 2.100



Epoch 13: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.31, lr=0.000253]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.09]
Epoch 14:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 13/50, Train loss: 2.274, Val loss: 2.088



Epoch 14: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.29, lr=0.000246]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.07]
Epoch 15:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 14/50, Train loss: 2.306, Val loss: 2.074



Epoch 15: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.28, lr=0.000238]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.08it/s, val_loss=2.06]
Epoch 16:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 15/50, Train loss: 2.282, Val loss: 2.063



Epoch 16: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.26, lr=0.00023]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=2.05]
Epoch 17:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 16/50, Train loss: 2.286, Val loss: 2.051



Epoch 17: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.22, lr=0.000222]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=2.04]
Epoch 18:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 17/50, Train loss: 2.267, Val loss: 2.044



Epoch 18: 100%|██████████| 463/463 [02:23<00:00,  3.23it/s, loss=2.22, lr=0.000214]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.08it/s, val_loss=2.03]
Epoch 19:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 18/50, Train loss: 2.187, Val loss: 2.034



Epoch 19: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.23, lr=0.000205]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.03]
Epoch 20:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 19/50, Train loss: 2.221, Val loss: 2.031



Epoch 20: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.19, lr=0.000196]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=2.02]
Epoch 21:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 20/50, Train loss: 2.164, Val loss: 2.023



Epoch 21: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.22, lr=0.000187]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=2.02]
Epoch 22:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 21/50, Train loss: 2.227, Val loss: 2.018



Epoch 22: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.16, lr=0.000178]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=2.01]
Epoch 23:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 22/50, Train loss: 2.146, Val loss: 2.014



Epoch 23: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.16, lr=0.000169]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.07it/s, val_loss=2.01]
Epoch 24:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 23/50, Train loss: 2.177, Val loss: 2.007



Epoch 24: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.17, lr=0.000159]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=2.01]
Epoch 25:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 24/50, Train loss: 2.074, Val loss: 2.006



Epoch 25: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.16, lr=0.00015]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=2]
Epoch 26:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 25/50, Train loss: 2.168, Val loss: 1.996



Epoch 26: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.13, lr=0.000141]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.07it/s, val_loss=1.99]
Epoch 27:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 26/50, Train loss: 2.117, Val loss: 1.992



Epoch 27: 100%|██████████| 463/463 [02:23<00:00,  3.24it/s, loss=2.13, lr=0.000131]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=1.99]
Epoch 28:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 27/50, Train loss: 2.130, Val loss: 1.987



Epoch 28: 100%|██████████| 463/463 [02:23<00:00,  3.23it/s, loss=2.12, lr=0.000122]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=1.98]
Epoch 29:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 28/50, Train loss: 2.156, Val loss: 1.977



Epoch 29: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.11, lr=0.000113]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.02it/s, val_loss=1.98]
Epoch 30:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 29/50, Train loss: 2.084, Val loss: 1.977



Epoch 30: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.12, lr=0.000104]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=1.97]
Epoch 31:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 30/50, Train loss: 2.146, Val loss: 1.974



Epoch 31: 100%|██████████| 463/463 [02:23<00:00,  3.23it/s, loss=2.1, lr=9.48e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=1.97]
Epoch 32:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 31/50, Train loss: 2.081, Val loss: 1.969



Epoch 32: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.09, lr=8.62e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.01it/s, val_loss=1.97]
Epoch 33:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 32/50, Train loss: 2.141, Val loss: 1.966



Epoch 33: 100%|██████████| 463/463 [02:25<00:00,  3.19it/s, loss=2.08, lr=7.78e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.07it/s, val_loss=1.96]
Epoch 34:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 33/50, Train loss: 2.080, Val loss: 1.960



Epoch 34: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.11, lr=6.96e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.02it/s, val_loss=1.96]
Epoch 35:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 34/50, Train loss: 2.058, Val loss: 1.961



Epoch 35: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.07, lr=6.18e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.06it/s, val_loss=1.96]
Epoch 36:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 35/50, Train loss: 2.123, Val loss: 1.959



Epoch 36: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.09, lr=5.44e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=1.95]
Epoch 37:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 36/50, Train loss: 2.100, Val loss: 1.955



Epoch 37: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.04, lr=4.73e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=1.95]
Epoch 38:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 37/50, Train loss: 2.004, Val loss: 1.952



Epoch 38: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.1, lr=4.07e-5]
Evaluating: 100%|██████████| 20/20 [00:05<00:00,  4.00it/s, val_loss=1.95]
Epoch 39:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 38/50, Train loss: 2.021, Val loss: 1.949



Epoch 39: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.05, lr=3.44e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=1.95]
Epoch 40:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 39/50, Train loss: 2.024, Val loss: 1.949



Epoch 40: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.08, lr=2.87e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.02it/s, val_loss=1.95]
Epoch 41:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 40/50, Train loss: 2.051, Val loss: 1.948



Epoch 41: 100%|██████████| 463/463 [02:25<00:00,  3.19it/s, loss=2.03, lr=2.34e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.05it/s, val_loss=1.95]
Epoch 42:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 41/50, Train loss: 2.060, Val loss: 1.947



Epoch 42: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.04, lr=1.86e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.01it/s, val_loss=1.95]
Epoch 43:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 42/50, Train loss: 2.085, Val loss: 1.945



Epoch 43: 100%|██████████| 463/463 [02:25<00:00,  3.19it/s, loss=2.03, lr=1.43e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=1.94]
Epoch 44:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 43/50, Train loss: 1.964, Val loss: 1.944



Epoch 44: 100%|██████████| 463/463 [02:23<00:00,  3.23it/s, loss=2.02, lr=1.05e-5]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.02it/s, val_loss=1.94]
Epoch 45:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 44/50, Train loss: 2.027, Val loss: 1.942



Epoch 45: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.02, lr=7.35e-6]
Evaluating: 100%|██████████| 20/20 [00:05<00:00,  4.00it/s, val_loss=1.94]
Epoch 46:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 45/50, Train loss: 2.001, Val loss: 1.941



Epoch 46: 100%|██████████| 463/463 [02:24<00:00,  3.21it/s, loss=2.04, lr=4.72e-6]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=1.94]
Epoch 47:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 46/50, Train loss: 2.046, Val loss: 1.940



Epoch 47: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.05, lr=2.66e-6]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.01it/s, val_loss=1.94]
Epoch 48:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 47/50, Train loss: 2.065, Val loss: 1.940



Epoch 48: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.07, lr=1.19e-6]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=1.94]
Epoch 49:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 48/50, Train loss: 2.128, Val loss: 1.940



Epoch 49: 100%|██████████| 463/463 [02:24<00:00,  3.20it/s, loss=2.03, lr=2.98e-7]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.03it/s, val_loss=1.94]
Epoch 50:   0%|          | 0/463 [00:00<?, ?it/s]


Epoch: 49/50, Train loss: 2.147, Val loss: 1.940



Epoch 50: 100%|██████████| 463/463 [02:23<00:00,  3.22it/s, loss=2.04, lr=1.2e-9]
Evaluating: 100%|██████████| 20/20 [00:04<00:00,  4.04it/s, val_loss=1.94]



Epoch: 50/50, Train loss: 2.032, Val loss: 1.940



In [None]:
init_epoch = epoch
init_epoch

50

In [None]:
def save_model(model, optimizer, scheduler, epoch=0, path='/content/model.pth'):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    Only the ``projection_head``, ``decoder`` and ``generator`` submodules of
    ``model`` are stored (the encoder is not), together with the optimizer and
    scheduler state and the epoch number.

    Args:
        model: object exposing ``projection_head``, ``decoder`` and ``generator``.
        optimizer: optimizer whose ``state_dict()`` is stored.
        scheduler: scheduler whose ``state_dict()`` is stored.
        epoch: epoch number recorded in the checkpoint.
        path: destination file.
    """
    saved_parts = ('projection_head', 'decoder', 'generator')
    checkpoint = {part: getattr(model, part).state_dict() for part in saved_parts}
    checkpoint['optimizer'] = optimizer.state_dict()
    checkpoint['scheduler'] = scheduler.state_dict()
    checkpoint['epoch'] = epoch
    torch.save(checkpoint, path)

In [None]:
def load_model(model, optimizer, scheduler, path='/content/model.pth'):
    """Restore a checkpoint written by ``save_model`` into the given objects.

    Loads the ``projection_head``, ``decoder`` and ``generator`` weights plus the
    optimizer and scheduler state in place. Tensors are mapped to ``DEVICE``.

    Args:
        model: object exposing ``projection_head``, ``decoder`` and ``generator``.
        optimizer: optimizer to restore.
        scheduler: scheduler to restore.
        path: checkpoint file to read.

    Returns:
        The epoch number stored in the checkpoint (0 if the key is absent), so
        the caller can resume the epoch counter. It was previously discarded.
    """
    checkpoint = torch.load(path, map_location=DEVICE)
    model.projection_head.load_state_dict(checkpoint['projection_head'])
    model.decoder.load_state_dict(checkpoint['decoder'])
    model.generator.load_state_dict(checkpoint['generator'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    return checkpoint.get('epoch', 0)

# Make Predictions

In [None]:
# TODO: Plot attention on images.

In [None]:
img = Image.open(random.choice(val_paths))
caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
# wandb.log({"predictions": wandb.Image(img, caption=caps)})
print(caps)
img

In [None]:
for i in range(25):
    img = Image.open(random.choice(val_paths))
    caps = generate_caption(model, preproc['val'](img)[None,:], en_vocab)
    wandb.log({"predictions": wandb.Image(img, caption=caps)})

In [None]:
save_model(model, optimizer, scheduler, epoch=epoch, path=run.dir+'/model.pth')

In [None]:
run.finish()