# Setup

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

import glob
import os
import json
import time
import string
import re

from torch import nn
from torch import Tensor
from PIL import Image
from tqdm import tqdm

import torchvision.transforms as transforms
#from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.models import swin_t, Swin_T_Weights
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from nltk.translate.bleu_score import corpus_bleu

In [2]:
# Dataset locations on the mounted Google Drive.
# token_path is the caption file read by Flickr8KDataset; the *_images_path files
# list the image names that make up each split.
token_path = "/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr8k.token.txt"
train_images_path = '/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr_8k.trainImages.txt'
test_images_path = '/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr_8k.testImages.txt'
val_images_path = '/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr_8k.devImages.txt'

# Directory holding the image files; Flickr8KDataset joins image names onto it.
images_path = '/content/drive/MyDrive/dataset_captioning/Flicker8k_Dataset/'

# Output / auxiliary locations. run_path is used to build run_dir in the training
# cell; checkpoint_path is only referenced from a commented-out line in save_checkpoint.
test_path ='/content/drive/MyDrive/dataset_captioning/test_image/'
checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/Checkpoints/'
run_path = '/content/drive/MyDrive/Colab Notebooks/runs/'

# Class Declaration

## Model

In [3]:
class EncoderSwin(nn.Module):
    """Image encoder built from an ImageNet-pretrained Swin-T backbone.

    The torchvision ``swin_t`` model is loaded with ``IMAGENET1K_V1`` weights.
    Its last child module (presumably the classification head) is dropped and
    the remaining children are chained in an ``nn.Sequential`` stored as
    ``self.swin``.

    Args:
        embed_size: Accepted for backward compatibility; the encoder does not
            use it.
    """

    def __init__(self, embed_size = 100):
        super(EncoderSwin, self).__init__()

        # ``weights`` is a keyword-only parameter in torchvision; passing it
        # positionally only works through a deprecation shim and emits a warning.
        swin = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
        self.swin = torch.nn.Sequential(*(list(swin.children())[:-1]))

    def forward(self, images):
        """Run the truncated backbone on ``images`` and return its output."""
        img_features = self.swin(images)

        return img_features

class ResidualBlock(nn.Module):
    """Fully-connected residual block (idea from https://arxiv.org/abs/1512.03385).

    Two ``input_dim -> input_dim`` linear layers with a LeakyReLU in between;
    the block input is added back onto the result.
    """

    def __init__(self, input_dim):
        """Build the Linear -> LeakyReLU -> Linear stack for width ``input_dim``."""
        super().__init__()
        layers = [
            nn.Linear(input_dim, input_dim),
            nn.LeakyReLU(),
            nn.Linear(input_dim, input_dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the transformed ``x``."""
        return x + self.block(x)


class Normalize(nn.Module):
    """Rescale vectors along one dimension to L2 length ``eps``.

    Args:
        eps: Scale factor stored as a one-element buffer named ``eps``.

    NOTE(review): multiplying by ``eps`` (default 1e-5) yields vectors of norm
    ``eps`` rather than unit vectors; a conventional normalisation would use
    ``eps`` as a stabiliser in the denominator. The original arithmetic is kept
    here — confirm the intent before relying on this layer.
    """

    def __init__(self, eps=1e-5):
        super(Normalize, self).__init__()
        self.register_buffer("eps", torch.Tensor([eps]))

    def forward(self, x, dim=-1):
        """Return ``eps * x / ||x||_2`` with the norm taken along ``dim``.

        Args:
            x: Input tensor.
            dim: Dimension along which the L2 norm is computed.

        Returns:
            Tensor with the same shape as ``x``. Slices whose norm is zero
            produce NaN, as a plain division by zero would.
        """
        # keepdim=True keeps the reduced axis in place so the division
        # broadcasts correctly for any ``dim``, not only the last one.
        norm = x.norm(2, dim=dim, keepdim=True)
        x = self.eps * (x / norm)
        return x

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence, then applies dropout.

    A table of shape ``[max_len, 1, d_model]`` is precomputed and registered as
    the buffer ``pe``: even feature indices hold sines, odd indices cosines.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Per-pair frequency: exp(-ln(10000) * 2k / d_model) for k = 0, 1, ...
        freq_scale = -np.log(10000.0) / d_model
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * freq_scale)
        steps = torch.arange(max_len).unsqueeze(1)
        angles = steps * inv_freq

        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(angles)
        table[:, 0, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            ``x`` with the first ``seq_len`` rows of the table added, after dropout.
        """
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])

class CaptionDecoder(nn.Module):
    """Transformer decoder that predicts caption tokens from image features.

    Caption token ids are looked up in a frozen, pretrained embedding table,
    projected to ``d_model``, passed through a residual block and positional
    encoding, and decoded against the projected image features. A final linear
    layer produces one score per vocabulary entry.

    All hyper-parameters default to the values that were previously hard-coded,
    so ``CaptionDecoder()`` builds the same network as before.

    Args:
        embedding_path: Text file readable by ``np.loadtxt`` holding the
            embedding matrix, one row per vocabulary index.
        decoder_layers: Number of stacked ``TransformerDecoderLayer`` modules.
        attention_heads: Number of attention heads per layer.
        d_model: Width of the decoder.
        ff_dim: Width of each layer's feed-forward sub-network.
        dropout: Dropout probability for the positional encoding and decoder layers.
        embedding_dim: Width of the pretrained word embeddings.
        img_feature_dim: Width of the incoming image feature vectors.
        vocab_size: Number of output classes of the final classifier.
        max_len: Maximum sequence length supported by the positional encoding.

    Raises:
        ValueError: If the loaded embedding matrix is not 2-D or its width
            differs from ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_path='/content/drive/MyDrive/Colab Notebooks/embedding/w2v-embeddings.txt',
        decoder_layers=6,
        attention_heads=16,
        d_model=512,
        ff_dim=1024,
        dropout=0.5,
        embedding_dim=100,
        img_feature_dim=768,
        vocab_size=401,
        max_len=64,
    ):
        super().__init__()

        ## Embeddings ##
        word_embeddings = torch.Tensor(np.loadtxt(embedding_path))
        # Fail early with a clear message instead of a shape error in forward().
        if word_embeddings.dim() != 2 or word_embeddings.size(1) != embedding_dim:
            raise ValueError(
                f"Expected an embedding matrix of shape (vocab, {embedding_dim}), "
                f"got {tuple(word_embeddings.shape)} from {embedding_path!r}"
            )
        self.embedding_layer = nn.Embedding.from_pretrained(
            word_embeddings,
            freeze=True,
            padding_idx=0
        )

        ## Layers ##
        # Modules are created in the same order as before so that seeded
        # initialisation and state_dict keys are unchanged.
        self.entry_mapping_words = nn.Linear(embedding_dim, d_model)
        self.entry_mapping_img = nn.Linear(img_feature_dim, d_model)

        self.res_block = ResidualBlock(d_model)

        self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len)
        dec_layer = TransformerDecoderLayer(
            d_model=d_model,
            nhead=attention_heads,
            dim_feedforward=ff_dim,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(dec_layer, num_layers=decoder_layers)
        self.classifier = nn.Linear(d_model, vocab_size)

    def forward(self, x, image_features, tgt_padding_mask=None, tgt_mask=None):
        """Score the next token for every caption position.

        Args:
            x: Integer token ids; the permutes below treat it as
                ``(batch, seq_len)``.
            image_features: Image features whose last dimension is
                ``img_feature_dim``. A 2-D tensor gets a leading dimension of
                size 1 and serves as a single-element decoder memory; a 3-D
                tensor is used as given (presumably ``(mem_len, batch, dim)`` —
                confirm against the caller).
            tgt_padding_mask: Optional mask forwarded as ``tgt_key_padding_mask``.
            tgt_mask: Optional attention mask forwarded as ``tgt_mask``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` with unnormalised scores.
        """
        ## Process Image ##
        image_features = self.entry_mapping_img(image_features)
        if (image_features.dim() == 2):
            image_features = image_features.unsqueeze(0)
        image_features = F.leaky_relu(image_features)

        ## Process Caption ##
        x = self.embedding_layer(x)
        x = self.entry_mapping_words(x)
        x = F.leaky_relu(x)

        x = self.res_block(x)
        x = F.leaky_relu(x)

        # The decoder works sequence-first: (batch, seq, dim) -> (seq, batch, dim).
        x = x.permute(1,0,2)
        x = self.positional_encoding(x)

        ## Decode Image and Caption ##
        x = self.decoder(
            tgt=x,
            memory=image_features,
            tgt_key_padding_mask=tgt_padding_mask,
            tgt_mask=tgt_mask
        )

        # Back to batch-first for the classifier and the callers.
        x = x.permute(1,0,2)

        x = self.classifier(x)
        return x

## Dataloader

In [4]:
class Flickr8KDataset(Dataset):
    """Flickr8K image/caption pairs for one split.

    Each item is ``(image_tensor, input_tokens, tgt_tokens, tgt_padding_mask)``
    where the token tensors have length ``_max_len`` and the mask is ``True``
    on padded positions. ``inference_batch`` additionally yields images together
    with all of their reference captions for evaluation.

    Args:
        path_list: Text file listing the image names of the split, one per line.
        training: Stored as ``_training``; not otherwise used by this class.
    """

    # Deletes ASCII punctuation from a token; built once rather than per item.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self, path_list, training=True):
        # Image names of this split; a set gives O(1) membership tests while
        # filtering the caption file.
        with open(path_list) as g:
            split_images = {line.replace("\n", "") for line in g}
        # Keep only caption lines whose image (text before '#') is in the split.
        with open(token_path, "r") as f:
            self._data = [
                line.replace("\n", "")
                for line in f
                if line.split("#")[0] in split_images
            ]

        self._training = training
        self._inference_captions = self._group_captions(self._data)

        # Tokens
        self._pad_idx = 0
        self._start_idx = 1
        self._end_idx = 2
        self._unk_idx = 3
        self._pad_token = '<pad>'
        self._start_token = '<start>'
        self._end_token = '<end>'
        self._unk_token = '<unk>'

        # Load the vocabulary mappings
        word2idx_path = '/content/drive/MyDrive/Colab Notebooks/embedding/w2v-word2idx.json'
        with open(word2idx_path, "r", encoding="utf8") as f:
            self._word2idx = json.load(f)
        # Keys are stringified indices; callers look words up via str(idx).
        self._idx2word = {str(idx): word for word, idx in self._word2idx.items()}

        # Create (X,Y) pairs
        self._data = self._create_input_label_mappings(self._data)

        self.image_dir = images_path

        # For image preprocessing
        self._preproc = self._construct_image_transform(224)

        self._max_len = 32
        self._dataset_size = len(self._data)

    def _construct_image_transform(self, image_size):
        """Return the resize / center-crop / to-tensor / normalise pipeline."""
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        preprocessing = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])

        return preprocessing

    def _create_input_label_mappings(self, data):
        """Turn caption lines into ``(image_name, caption_words)`` pairs."""
        processed_data = []

        for line in data:
            tokens = line.split()
            # First token is "<image>#<n>"; the rest is the caption.
            img_name, caption_words = tokens[0].split("#")[0], tokens[1:]
            processed_data.append((img_name, caption_words))

        return processed_data

    def _load_and_prepare_image(self, image_name):
        """Load ``image_name`` from ``image_dir`` as RGB and preprocess it."""
        image_path = os.path.join(self.image_dir, image_name)
        # The context manager releases the file handle once the pixel data
        # has been converted.
        with Image.open(image_path) as img_pil:
            image_tensor = self._preproc(img_pil.convert("RGB"))
        return image_tensor

    def _group_captions(self, data):
        """Group lower-cased, punctuation-stripped captions by image name.

        Tokens that consist only of punctuation become empty strings and are
        kept in place.
        """
        grouped_captions = {}

        for line in data:
            tokens = line.split()
            if len(line) > 2:
                image_id, image_desc = tokens[0].split('#')[0], tokens[1:]

                image_desc = [token.strip().lower().translate(self._PUNCT_TABLE) for token in image_desc]

                grouped_captions.setdefault(image_id, []).append(image_desc)

        return grouped_captions

    def _load_and_process_images(self, image_dir, image_names):
        """Return ``{image_name: preprocessed tensor}`` for files in ``image_dir``."""
        image_processed = {}
        for fname in image_names:
            # Convert to RGB like _load_and_prepare_image so grayscale files
            # do not break the 3-channel normalisation.
            with Image.open(os.path.join(image_dir, fname)) as img:
                image_processed[fname] = self._preproc(img.convert("RGB"))

        return image_processed

    def inference_batch(self, batch_size):
        """Yield ``(image_batch, captions)`` for evaluation.

        ``image_batch`` stacks ``batch_size`` preprocessed images along
        dimension 0; ``captions`` holds, per image, the list of its reference
        captions. A final incomplete batch is dropped, so every yielded batch
        has exactly ``batch_size`` images.
        """
        caption_data_items = list(self._inference_captions.items())

        num_batches = len(caption_data_items) // batch_size
        for idx in range(num_batches):
            caption_samples = caption_data_items[idx * batch_size: (idx + 1) * batch_size]
            batch_imgs = []
            batch_captions = []

            # Create a mini batch data
            for image_name, captions in caption_samples:
                batch_captions.append(captions)
                batch_imgs.append(self._load_and_prepare_image(image_name))

            # Batch image tensors
            batch_imgs = torch.stack(batch_imgs, dim=0)

            yield batch_imgs, batch_captions

    def __len__(self):
        return self._dataset_size

    def __getitem__(self, index):
        """Return one training sample.

        Returns:
            ``(image_tensor, input_tokens, tgt_tokens, tgt_padding_mask)``.
            ``input_tokens`` starts with ``<start>`` and ``tgt_tokens`` is the
            same sequence shifted by one. Both are padded or truncated to
            ``_max_len`` so samples can always be stacked into a batch; a
            truncated target therefore may not end with ``<end>``. Words
            missing from the vocabulary map to the unknown index.
        """
        image_id, tokens = self._data[index]

        # Load and preprocess image
        image_tensor = self._load_and_prepare_image(image_id)
        # preprocess caption and add tokens
        tokens = [token.strip().lower().translate(self._PUNCT_TABLE) for token in tokens]
        tokens = [self._start_token] + tokens + [self._end_token]

        # Create input and target tokens, clipped to the fixed sequence length.
        input_tokens = tokens[:-1][:self._max_len]
        tgt_tokens = tokens[1:][:self._max_len]

        sample_size = len(input_tokens)
        padding_vec = [self._pad_token] * (self._max_len - sample_size)
        input_tokens += padding_vec
        tgt_tokens += padding_vec

        input_tokens = [self._word2idx.get(token, self._unk_idx) for token in input_tokens]
        tgt_tokens = [self._word2idx.get(token, self._unk_idx) for token in tgt_tokens]

        # Tokens to Tensor
        input_tokens = torch.tensor(input_tokens).long()
        tgt_tokens = torch.tensor(tgt_tokens).long()

        # True marks padded positions, False marks real tokens.
        tgt_padding_mask = torch.ones([self._max_len, ]).bool()
        tgt_padding_mask[:sample_size] = False

        return image_tensor, input_tokens, tgt_tokens, tgt_padding_mask



## Utils

In [5]:
def set_up_causal_mask(seq_len, device):
    """Build the additive look-ahead mask for a transformer decoder.

    Entry ``[i, j]`` is ``-inf`` when ``j > i`` (a later position) and ``0.0``
    otherwise, so position ``i`` can only attend to itself and earlier tokens.
    The returned float tensor lives on ``device`` and does not require grad.
    """
    # True strictly above the diagonal, i.e. on the "future" positions.
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 1
    mask = future.float().masked_fill(future, float('-inf'))
    mask = mask.to(device)
    mask.requires_grad_(False)
    return mask

def log_gradient_norm(model, writer, step, mode, norm_type=2):
    """Write the global gradient norm of ``model`` to TensorBoard.

    The per-parameter ``norm_type`` norms are combined into one global norm of
    the same type and logged under the tag ``Gradient/<mode>``.

    Args:
        model: Module whose parameter gradients are measured.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        step: Global step recorded with the scalar.
        mode: Suffix of the scalar tag.
        norm_type: Order of the norm; ``float('inf')`` gives the max norm.
    """
    norm_type = float(norm_type)
    param_norms = []
    for p in model.parameters():
        # Frozen parameters and parameters that took no part in the last
        # backward pass have no gradient to measure.
        if not p.requires_grad or p.grad is None:
            continue
        param_norms.append(p.grad.detach().norm(norm_type).item())

    if norm_type == float('inf'):
        total_norm = max(param_norms, default=0.0)
    else:
        # Combine with the same exponent that produced the per-parameter norms.
        total_norm = sum(n ** norm_type for n in param_norms) ** (1. / norm_type)
    writer.add_scalar(f"Gradient/{mode}", total_norm, step)

def save_checkpoint(name, encoder, decoder, enc_optimizer, dec_optimizer, start_time, epoch):
    """Write a training checkpoint to ``<start_time>_<name>/<name>_<epoch>.pth``.

    The directory is created relative to the current working directory if it
    does not exist. The file stores the epoch number plus the state dicts of
    both models and both optimizers under the keys ``epoch``, ``encoder``,
    ``decoder``, ``enc_optimizer`` and ``dec_optimizer``.
    """
    #out_dir = os.path.join(checkpoint_path, str(start_time) + f'_{name}')
    out_dir = f"{str(start_time)}_{name}"
    os.makedirs(out_dir, exist_ok=True)

    snapshot = {
        'epoch': epoch,
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'enc_optimizer': enc_optimizer.state_dict(),
        'dec_optimizer': dec_optimizer.state_dict(),
    }
    torch.save(snapshot, os.path.join(out_dir, f"{name}_{epoch}.pth"))

    print("Model saved.")

## Evaluate

In [6]:
def inference(decoder, img_features, start_idx, end_idx, pad_idx, idx2word, batch_size, max_len, device):
    """Greedily decode one caption per image.

    Decoding starts from ``<start>`` and repeatedly feeds the most likely next
    token back into the decoder. A caption stops at the first predicted end
    token or after ``max_len - 1`` generated words. Runs without gradient
    tracking regardless of the caller's context.

    Args:
        decoder: Callable ``decoder(tokens, img_features, padding_mask)``
            returning scores of shape ``(batch_size, max_len, vocab)``.
        img_features: Image features passed through to ``decoder``.
        start_idx: Vocabulary index of the start token.
        end_idx: Vocabulary index of the end token.
        pad_idx: Vocabulary index of the padding token.
        idx2word: Mapping from ``str(index)`` to word.
        batch_size: Number of captions to decode.
        max_len: Length of the token buffer fed to the decoder.
        device: Device on which the token buffer and mask are created.

    Returns:
        List of ``batch_size`` word lists; the end token is never included.
    """
    if max_len < 1:
        return [[] for _ in range(batch_size)]

    # Input words [<start>, <pad>, ...] + padding mask [True, ..., True];
    # one mask column is opened per decoding step.
    x_words = torch.Tensor([start_idx] + [pad_idx] * (max_len - 1)).to(device).long()
    x_words = x_words.repeat(batch_size, 1)
    padd_mask = torch.Tensor([True] * max_len).to(device).bool()
    padd_mask = padd_mask.repeat(batch_size, 1)

    # Flag for each image
    is_decoded = [False] * batch_size
    generated_captions = [[] for _ in range(batch_size)]

    with torch.no_grad():
        for i in range(max_len - 1):
            # Position i becomes visible to the decoder.
            padd_mask[:, i] = False

            # The scores at position i predict the token for position i + 1.
            y_pred_prob = decoder(x_words, img_features, padd_mask)
            y_pred = y_pred_prob[:, i].argmax(-1)

            # One host transfer per step instead of one per caption.
            for batch_idx, token_idx in enumerate(y_pred.tolist()):
                if is_decoded[batch_idx]:
                    continue
                if token_idx == end_idx:
                    # The end token only terminates the caption; it is not emitted.
                    is_decoded[batch_idx] = True
                else:
                    generated_captions[batch_idx].append(idx2word[str(token_idx)])

            if all(is_decoded):
                break

            # Update the input tokens for the next iteration
            x_words[:, i + 1] = y_pred

    return generated_captions



def evaluate(dataset, encoder, decoder, device, batch_size=4, max_len=None):
    """Compute corpus BLEU-1..4 (in percent) of greedy captions on ``dataset``.

    Images come from ``dataset.inference_batch`` together with their reference
    captions; captions are generated with ``inference`` and scored with NLTK's
    ``corpus_bleu``. The whole pass runs without gradient tracking. The
    train/eval mode of the models is left to the caller.

    Args:
        dataset: Object providing ``inference_batch`` and the attributes
            ``_idx2word``, ``_start_idx``, ``_end_idx``, ``_pad_idx``, ``_max_len``.
        encoder: Module mapping an image batch to image features.
        decoder: Caption decoder passed to ``inference``.
        device: Device on which images and decoding buffers are placed.
        batch_size: Number of images decoded at once.
        max_len: Decoding buffer length; defaults to ``dataset._max_len``.

    Returns:
        ``[bleu_1, bleu_2, bleu_3, bleu_4]`` as percentages, or four zeros when
        the dataset yields no complete batch.
    """
    if max_len is None:
        max_len = dataset._max_len

    idx2word = dataset._idx2word
    start_idx = dataset._start_idx
    end_idx = dataset._end_idx
    pad_idx = dataset._pad_idx

    references = []
    predictions = []

    with torch.no_grad():
        for x_img, y_caption in dataset.inference_batch(batch_size):
            x_img = x_img.to(device)

            # Extract image features
            img_features = encoder(x_img)

            # Size the decoding buffers from the batch actually received.
            pred_captions = inference(
                decoder, img_features, start_idx, end_idx, pad_idx,
                idx2word, x_img.size(0), max_len, device
            )
            references += y_caption
            predictions += pred_captions

    # corpus_bleu cannot score an empty corpus.
    if not predictions:
        return [0.0, 0.0, 0.0, 0.0]

    # Uniform n-gram weights that sum to exactly 1 for each BLEU order.
    bleu = [
        corpus_bleu(references, predictions, weights=tuple([1.0 / n] * n)) * 100
        for n in range(1, 5)
    ]

    return bleu


## Trainer

In [7]:
def _caption_batch_loss(encoder, decoder, loss_fn, causal_mask, batch, device):
    """Run one batch through encoder and decoder; return the loss on real tokens."""
    x_img, x_words, y, tgt_padding_mask = batch

    # Move tensor to device
    x_img, x_words = x_img.to(device), x_words.to(device)
    y = y.to(device)
    tgt_padding_mask = tgt_padding_mask.to(device)

    # Extract image features
    img_features = encoder(x_img)

    # Prediction from decoder
    y_pred = decoder(x_words, img_features, tgt_padding_mask, causal_mask)

    # The padding mask is True on padding; invert it to select real positions.
    real_tokens = torch.logical_not(tgt_padding_mask)
    return loss_fn(y_pred[real_tokens], y[real_tokens].long())


def train(device, writer, model_name, checkpoint=None) :
    """Train the Swin encoder and caption decoder on Flickr8K.

    Every epoch runs one pass over the training split and one over the
    validation split and logs both mean losses to ``writer``. Every
    ``save_period`` epochs a checkpoint is written with ``save_checkpoint``;
    every ``eval_period`` epochs BLEU-1..4 are computed on both splits, printed
    and logged.

    Args:
        device: Device on which models and batches are placed.
        writer: TensorBoard writer exposing ``add_scalar`` and ``flush``.
        model_name: Name used for checkpoint directories and files.
        checkpoint: Optional dict with the keys ``encoder``, ``decoder``,
            ``enc_optimizer``, ``dec_optimizer`` and ``epoch`` (as written by
            ``save_checkpoint``); training resumes after that epoch.
    """
    torch.manual_seed(2023)
    np.random.seed(2023)

    ## Encoder ##
    encoder = EncoderSwin()
    encoder = encoder.to(device)
    if checkpoint is not None:
        encoder.load_state_dict(checkpoint['encoder'])
    encoder.train()

    ## Decoder ##
    decoder = CaptionDecoder()
    decoder = decoder.to(device)
    if checkpoint is not None:
        decoder.load_state_dict(checkpoint['decoder'])
    decoder.train()

    ## Config ##
    train_config = {
        "epochs": 50,
        "warmup_steps": 0,
        "learning_rate": 5e-6,
        "l2_penalty": 1e-2,
        "gradient_clipping": 2.0,
        "save_period": 5,
        "eval_period": 5
    }
    train_hyperparams = {
        "batch_size" : 4,
        "shuffle" : True
    }
    # Early stopping is not active in this loop; only the periodic checkpoints
    # below are written.

    # Create dataloader
    train_set = Flickr8KDataset(train_images_path, training=True)
    val_set = Flickr8KDataset(val_images_path, training=False)
    train_loader = DataLoader(train_set, **train_hyperparams)
    val_loader = DataLoader(val_set, **train_hyperparams)

    # The mask must match the padded caption length produced by the dataset.
    causal_mask = set_up_causal_mask(train_set._max_len, device)

    # Optimizer
    enc_optimizer = torch.optim.AdamW(
        encoder.parameters(),
        lr = train_config["learning_rate"],
        weight_decay = train_config["l2_penalty"]
    )
    dec_optimizer = torch.optim.AdamW(
        decoder.parameters(),
        lr = train_config["learning_rate"],
        weight_decay = train_config["l2_penalty"]
    )
    if checkpoint is not None:
        enc_optimizer.load_state_dict(checkpoint['enc_optimizer'])
        dec_optimizer.load_state_dict(checkpoint['dec_optimizer'])

    # Loss function; index 0 is the padding token and is ignored.
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)

    ## Start Train ##
    start_time = time.strftime("%b-%d_%H-%M-%S")

    # Resume after the checkpointed epoch, or start from the first one.
    load_epoch = checkpoint['epoch'] if checkpoint is not None else 0

    for epoch in range(load_epoch+1, train_config["epochs"]+1):

        encoder.train()
        decoder.train()

        # Running sums avoid re-adding the whole loss history on every step.
        train_loss_sum, train_batches = 0.0, 0
        avg_train_loss = float("nan")  # stays NaN if the loader yields nothing

        with tqdm(train_loader) as tepoch:
            for batch in tepoch:
                tepoch.set_description(f"Epoch {epoch}")

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                loss = _caption_batch_loss(encoder, decoder, loss_fn, causal_mask, batch, device)
                # Backpropagation
                loss.backward()

                # Clip gradients of each model separately
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), train_config["gradient_clipping"])
                torch.nn.utils.clip_grad_norm_(decoder.parameters(), train_config["gradient_clipping"])

                # Update weights
                enc_optimizer.step()
                dec_optimizer.step()

                # Log loss
                train_loss_sum += loss.item()
                train_batches += 1
                avg_train_loss = train_loss_sum / train_batches
                tepoch.set_postfix(loss=avg_train_loss)

            writer.add_scalar("Train/Epoch-Loss", avg_train_loss, epoch)

        # Validation
        encoder.eval()
        decoder.eval()

        val_loss_sum, val_batches = 0.0, 0
        avg_val_loss = float("nan")

        with tqdm(val_loader) as vepoch:
            for batch in vepoch:
                vepoch.set_description(f"Validation ")

                with torch.no_grad():
                    loss = _caption_batch_loss(encoder, decoder, loss_fn, causal_mask, batch, device)

                val_loss_sum += loss.item()
                val_batches += 1
                avg_val_loss = val_loss_sum / val_batches
                vepoch.set_postfix(loss=avg_val_loss)

            writer.add_scalar("Valid/Epoch-Loss", avg_val_loss, epoch)

        # Save model state
        if (epoch) % train_config['save_period'] == 0:
            save_checkpoint(model_name, encoder, decoder, enc_optimizer, dec_optimizer, start_time, epoch)

        # Evaluate model performance
        if (epoch) % train_config["eval_period"] == 0:
            with torch.no_grad():
                encoder.eval()
                decoder.eval()

                # Evaluate model performance on subsets
                train_bleu = evaluate(train_set, encoder, decoder, device)
                val_bleu = evaluate(val_set, encoder, decoder, device)

                print('Train BLEU', train_bleu)
                print('Valid BLEU', val_bleu)

                # Log the evaluated BLEU score
                for i, tr_b in enumerate(train_bleu):
                    writer.add_scalar(f"Train/BLEU-{i+1}", tr_b, epoch)
                for i, tv_b in enumerate(val_bleu):
                    writer.add_scalar(f"Valid/BLEU-{i+1}", tv_b, epoch)

                encoder.train()
                decoder.train()

        # Persist this epoch's scalars even if the run is interrupted later.
        writer.flush()

    writer.flush()

# Main

## Tensorboard

In [17]:
# Register the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [18]:
#%tensorboard --logdir='/content/drive/MyDrive/Colab Notebooks/runs/'
%tensorboard

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
2023-11-07 11:18:27.972868: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 11:18:27.972943: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 11:18:27.972974: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Error: A logdir or db must be specified. For example `tensorboard --logdir mylogdir` or `tensorboard --db sqlite:~/.tensorboard.db`. Run `tensorboard --helpfull` for details and examples.

In [None]:
# NOTE(review): `kill` is invoked without a PID, so it has nothing to act on.
# Presumably this was meant to stop the TensorBoard process — confirm and pass its PID.
!kill

## Train

In [11]:
# Launch a training run: name it, create the TensorBoard writer, pick the device
# and call train() from scratch (no checkpoint is passed).
start_time = time.strftime("%b-%d_%H-%M-%S")
model_name = 'swin-trans-lr56-wd2'
# NOTE(review): run_dir is computed but not used below; SummaryWriter() is created
# without a log_dir and therefore writes to its default location.
run_dir = os.path.join(run_path, str(start_time) + f'_{model_name}')

writer = SummaryWriter()
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
print("Running on", device)

# To resume, load a saved checkpoint and pass it as train(..., checkpoint=checkpoint).
#PATH = '/content/Nov-04_04-06-34_swin-trans-finetune/swin-trans-finetune.pth'
#checkpoint = torch.load(PATH)

train(device, writer, model_name)

Running on cuda


Downloading: "https://download.pytorch.org/models/swin_t-704ceda3.pth" to /root/.cache/torch/hub/checkpoints/swin_t-704ceda3.pth
100%|██████████| 108M/108M [00:00<00:00, 141MB/s] 
Epoch 1: 100%|██████████| 871/871 [06:36<00:00,  2.20it/s, loss=4.85]
Validation : 100%|██████████| 93/93 [00:32<00:00,  2.87it/s, loss=4.58]
Epoch 2: 100%|██████████| 871/871 [02:15<00:00,  6.41it/s, loss=4.36]
Validation : 100%|██████████| 93/93 [00:09<00:00,  9.66it/s, loss=4.43]
Epoch 3: 100%|██████████| 871/871 [02:15<00:00,  6.44it/s, loss=4.18]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.16it/s, loss=4.23]
Epoch 4: 100%|██████████| 871/871 [02:15<00:00,  6.44it/s, loss=4]
Validation : 100%|██████████| 93/93 [00:07<00:00, 13.00it/s, loss=4.17]
Epoch 5: 100%|██████████| 871/871 [02:17<00:00,  6.33it/s, loss=3.81]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.40it/s, loss=4.06]


Model saved.


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


Train BLEU [22.158504937426805, 4.088006633307689e-153, 3.732489607799368e-204, 5.552609865984194e-230]
Valid BLEU [20.155851277058407, 4.0337558982386097e-153, 3.784419724859241e-204, 5.70642150360679e-230]


Epoch 6: 100%|██████████| 871/871 [02:15<00:00,  6.43it/s, loss=3.65]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.74it/s, loss=3.99]
Epoch 7: 100%|██████████| 871/871 [02:16<00:00,  6.40it/s, loss=3.54]
Validation : 100%|██████████| 93/93 [00:08<00:00, 11.18it/s, loss=3.94]
Epoch 8: 100%|██████████| 871/871 [02:16<00:00,  6.40it/s, loss=3.44]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.35it/s, loss=3.82]
Epoch 9: 100%|██████████| 871/871 [02:15<00:00,  6.42it/s, loss=3.38]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.36it/s, loss=3.83]
Epoch 10: 100%|██████████| 871/871 [02:16<00:00,  6.39it/s, loss=3.3]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.30it/s, loss=3.82]


Model saved.
Train BLEU [29.265367426576148, 4.7050794422899815e-153, 4.102936082337701e-204, 5.96586539553077e-230]
Valid BLEU [22.17143640476425, 4.2306388774311403e-153, 3.90645712016477e-204, 5.844024377529853e-230]


Epoch 11: 100%|██████████| 871/871 [02:16<00:00,  6.37it/s, loss=3.24]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.76it/s, loss=3.78]
Epoch 12: 100%|██████████| 871/871 [02:16<00:00,  6.38it/s, loss=3.18]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.68it/s, loss=3.75]
Epoch 13: 100%|██████████| 871/871 [02:17<00:00,  6.35it/s, loss=3.14]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.31it/s, loss=3.79]
Epoch 14: 100%|██████████| 871/871 [02:14<00:00,  6.45it/s, loss=3.09]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.11it/s, loss=3.75]
Epoch 15: 100%|██████████| 871/871 [02:16<00:00,  6.39it/s, loss=3.04]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.58it/s, loss=3.74]


Model saved.
Train BLEU [32.290096900067475, 3.11276287426979, 2.483260705680467e-102, 1.5526929932839672e-153]
Valid BLEU [21.163643840911327, 4.1333698091149876e-153, 3.8464078648728565e-204, 5.776452046050107e-230]


Epoch 16: 100%|██████████| 871/871 [02:14<00:00,  6.48it/s, loss=3.01]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.11it/s, loss=3.73]
Epoch 17: 100%|██████████| 871/871 [02:14<00:00,  6.49it/s, loss=2.99]
Validation : 100%|██████████| 93/93 [00:07<00:00, 13.08it/s, loss=3.77]
Epoch 18: 100%|██████████| 871/871 [02:15<00:00,  6.44it/s, loss=2.94]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.98it/s, loss=3.83]
Epoch 19: 100%|██████████| 871/871 [02:14<00:00,  6.50it/s, loss=2.92]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.47it/s, loss=3.78]
Epoch 20: 100%|██████████| 871/871 [02:14<00:00,  6.49it/s, loss=2.89]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.41it/s, loss=3.86]


Model saved.
Train BLEU [36.10310968115946, 5.23934383360648, 3.61238611553975e-102, 2.100717395267331e-153]
Valid BLEU [21.667540122837785, 4.1822871308732315e-153, 3.876665381271931e-204, 5.810532887640725e-230]


Epoch 21: 100%|██████████| 871/871 [02:13<00:00,  6.52it/s, loss=2.86]
Validation : 100%|██████████| 93/93 [00:06<00:00, 15.02it/s, loss=3.85]
Epoch 22: 100%|██████████| 871/871 [02:13<00:00,  6.52it/s, loss=2.83]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.68it/s, loss=3.87]
Epoch 23: 100%|██████████| 871/871 [02:15<00:00,  6.42it/s, loss=2.81]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.82it/s, loss=3.88]
Epoch 24: 100%|██████████| 871/871 [02:13<00:00,  6.53it/s, loss=2.78]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.92it/s, loss=3.82]
Epoch 25: 100%|██████████| 871/871 [02:15<00:00,  6.44it/s, loss=2.76]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.25it/s, loss=3.87]


Model saved.
Train BLEU [48.61283011050161, 15.140226964724226, 8.068093626394063e-102, 4.12800298454336e-153]
Valid BLEU [24.042089922115874, 4.781709294114566e-153, 4.4768431948084944e-204, 6.74355098927916e-230]


Epoch 26: 100%|██████████| 871/871 [02:13<00:00,  6.55it/s, loss=2.75]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.49it/s, loss=3.87]
Epoch 27: 100%|██████████| 871/871 [02:12<00:00,  6.55it/s, loss=2.72]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.38it/s, loss=3.95]
Epoch 28: 100%|██████████| 871/871 [02:14<00:00,  6.47it/s, loss=2.7]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.63it/s, loss=3.9]
Epoch 29: 100%|██████████| 871/871 [02:12<00:00,  6.55it/s, loss=2.69]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.53it/s, loss=3.99]
Epoch 30: 100%|██████████| 871/871 [02:12<00:00,  6.57it/s, loss=2.66]
Validation : 100%|██████████| 93/93 [00:08<00:00, 10.63it/s, loss=3.93]


Model saved.
Train BLEU [65.54101334839211, 30.599774674705056, 4.496492543228723, 1.0788245068637742e-76]
Valid BLEU [31.025835362637018, 8.117415247118956, 5.340524387248961e-102, 3.034189765567951e-153]


Epoch 31: 100%|██████████| 871/871 [02:13<00:00,  6.53it/s, loss=2.64]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.19it/s, loss=3.98]
Epoch 32: 100%|██████████| 871/871 [02:13<00:00,  6.53it/s, loss=2.63]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.05it/s, loss=4.02]
Epoch 33: 100%|██████████| 871/871 [02:14<00:00,  6.48it/s, loss=2.61]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.90it/s, loss=4.01]
Epoch 34: 100%|██████████| 871/871 [02:11<00:00,  6.60it/s, loss=2.59]
Validation : 100%|██████████| 93/93 [00:07<00:00, 13.20it/s, loss=4]
Epoch 35: 100%|██████████| 871/871 [02:12<00:00,  6.57it/s, loss=2.58]
Validation : 100%|██████████| 93/93 [00:07<00:00, 13.13it/s, loss=4.01]


Model saved.
Train BLEU [77.32727106196523, 52.391821202711306, 12.372719477663216, 2.4151942851845872e-76]
Valid BLEU [38.39884507789223, 21.575470528003958, 1.1331117735197271e-101, 5.755694466947957e-153]


Epoch 36: 100%|██████████| 871/871 [02:13<00:00,  6.54it/s, loss=2.56]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.50it/s, loss=4.02]
Epoch 37: 100%|██████████| 871/871 [02:13<00:00,  6.55it/s, loss=2.56]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.74it/s, loss=4.02]
Epoch 38: 100%|██████████| 871/871 [02:14<00:00,  6.46it/s, loss=2.54]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.89it/s, loss=4.04]
Epoch 39: 100%|██████████| 871/871 [02:12<00:00,  6.58it/s, loss=2.53]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.99it/s, loss=4.12]
Epoch 40: 100%|██████████| 871/871 [02:12<00:00,  6.58it/s, loss=2.51]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.44it/s, loss=4.04]


Model saved.
Train BLEU [81.18963308150377, 58.75509740412117, 29.034590695537588, 4.6713455612081645e-76]
Valid BLEU [40.328478885324415, 22.395831644399866, 8.215735686880324, 1.7643260223123518e-76]


Epoch 41: 100%|██████████| 871/871 [02:14<00:00,  6.49it/s, loss=2.49]
Validation : 100%|██████████| 93/93 [00:06<00:00, 15.11it/s, loss=4.06]
Epoch 42: 100%|██████████| 871/871 [02:12<00:00,  6.56it/s, loss=2.48]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.86it/s, loss=4.08]
Epoch 43: 100%|██████████| 871/871 [02:14<00:00,  6.48it/s, loss=2.48]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.96it/s, loss=4.1]
Epoch 44: 100%|██████████| 871/871 [02:12<00:00,  6.57it/s, loss=2.47]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.71it/s, loss=4.12]
Epoch 45: 100%|██████████| 871/871 [02:12<00:00,  6.59it/s, loss=2.45]
Validation : 100%|██████████| 93/93 [00:06<00:00, 15.00it/s, loss=4.09]


Model saved.
Train BLEU [81.43551568073697, 62.42729856066337, 37.79219909991227, 5.677670222389189e-76]
Valid BLEU [37.60362611045187, 22.21795536269732, 11.767079879738889, 2.3083422177148206e-76]


Epoch 46: 100%|██████████| 871/871 [02:17<00:00,  6.36it/s, loss=2.43]
Validation : 100%|██████████| 93/93 [00:07<00:00, 11.98it/s, loss=4.15]
Epoch 47: 100%|██████████| 871/871 [02:17<00:00,  6.32it/s, loss=2.41]
Validation : 100%|██████████| 93/93 [00:06<00:00, 13.99it/s, loss=4.12]
Epoch 48: 100%|██████████| 871/871 [02:16<00:00,  6.37it/s, loss=2.41]
Validation : 100%|██████████| 93/93 [00:06<00:00, 14.80it/s, loss=4.18]
Epoch 49: 100%|██████████| 871/871 [02:17<00:00,  6.32it/s, loss=2.41]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.50it/s, loss=4.1]
Epoch 50: 100%|██████████| 871/871 [02:13<00:00,  6.53it/s, loss=2.39]
Validation : 100%|██████████| 93/93 [00:07<00:00, 12.48it/s, loss=4.12]


Model saved.
Train BLEU [82.3556458995109, 66.5217156362642, 43.87715391160611, 6.379897961538142e-76]
Valid BLEU [40.52766428452431, 24.872957734923993, 14.174125227944568, 2.693863712590783e-76]


In [13]:
import shutil
# Copy the epoch-25 checkpoint of the "swin-trans-lr56-wd2" run from the
# Colab runtime disk to Google Drive. The call stays the last expression of
# the cell so the destination path is displayed as the cell output.
# Optional: back up a TensorBoard event file in the same way.
#shutil.copy('/content/runs/Nov04_04-06-13_435e95dac81f/events.out.tfevents.1699070773.435e95dac81f.249.0', "/content/drive/MyDrive/")
shutil.copy(
    src="/content/Nov-07_09-10-25_swin-trans-lr56-wd2/swin-trans-lr56-wd2_25.pth",
    dst="/content/drive/MyDrive/",
)

'/content/drive/MyDrive/swin-trans-lr56-wd2_25.pth'

# Evaluate Test

## Inference Test Function

In [None]:
def inference_test(decoder, img_features, start_idx, end_idx, pad_idx, idx2word, batch_size, max_len, device):
    """Greedily decode one caption per image in a batch.

    The decoder input starts as ``[<start>, <pad>, ..., <pad>]`` for every
    sample. At step ``i`` position ``i`` is unmasked, the decoder is run on the
    whole sequence, and the arg-max token at position ``i`` is written to
    position ``i + 1`` as the next input token.

    Args:
        decoder: Callable invoked as ``decoder(x_words, img_features, padd_mask)``.
            Its output is indexed as ``[batch, position]`` and arg-maxed over the
            last axis, so it is assumed to be ``(batch, max_len, vocab)``.
        img_features: Image features forwarded untouched to ``decoder``.
        start_idx: Vocabulary index of the start-of-sequence token.
        end_idx: Vocabulary index of the end-of-sequence token.
        pad_idx: Vocabulary index of the padding token.
        idx2word: Mapping from the *string* form of a token index to its word.
        batch_size: Number of captions to decode (rows of the decoder input).
        max_len: Length of the decoder input; at most ``max_len - 1`` words
            are generated per caption.
        device: Device on which the decoder inputs are allocated.

    Returns:
        A list of ``batch_size`` captions, each a list of words without the
        end-of-sequence token.
    """
    # Input words [<start>, <pad>, ...]; built directly as int64 instead of
    # round-tripping the ids through a float tensor.
    x_words = torch.full((batch_size, max_len), pad_idx, dtype=torch.long, device=device)
    x_words[:, :1] = start_idx
    # Padding mask: True marks a masked position; positions are unmasked one
    # by one as decoding advances.
    padd_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=device)

    # Per-sample "end token already produced" flag.
    is_decoded = [False] * batch_size
    generated_captions = [[] for _ in range(batch_size)]

    # Pure inference: do not record an autograd graph for the decoder calls.
    with torch.no_grad():
        for i in range(max_len - 1):
            padd_mask[:, i] = False

            # Prediction for the word following position i.
            y_pred = decoder(x_words, img_features, padd_mask)[:, i].argmax(-1)

            # One device->host transfer per step instead of one per sample.
            for batch_idx, word_idx in enumerate(y_pred.tolist()):
                if is_decoded[batch_idx]:
                    continue
                if word_idx == end_idx:
                    # The end token terminates the caption and is never emitted,
                    # whatever its spelling in idx2word (the previous cleanup
                    # only worked when it was literally "<end>").
                    is_decoded[batch_idx] = True
                else:
                    generated_captions[batch_idx].append(idx2word[str(word_idx)])

            if all(is_decoded):
                break

            # Feed the prediction back as the next input token
            # (i + 1 <= max_len - 1 always holds inside this loop).
            x_words[:, i + 1] = y_pred

    return generated_captions



def evaluate_test(dataset, encoder, decoder, device):
    """Compute corpus-level BLEU-1..BLEU-4 (as percentages) over ``dataset``.

    For every batch yielded by ``dataset.inference_batch`` the image is encoded,
    a caption is greedily decoded with ``inference_test`` and the result is
    scored against the references with NLTK's ``corpus_bleu``.

    Args:
        dataset: Object exposing ``_idx2word``, ``_start_idx``, ``_end_idx``,
            ``_pad_idx`` and ``inference_batch(batch_size)``, which yields
            ``(images, reference_captions)`` pairs.
        encoder: Image encoder used through ``_process_input``, ``class_token``
            and ``encoder`` (the attribute names of a torchvision
            ``VisionTransformer`` -- presumably the ``vit_b_16`` built in the
            setup cell; confirm if another encoder is passed).
        decoder: Caption decoder forwarded to ``inference_test``.
        device: Device the image batch and decoder inputs are placed on.

    Returns:
        ``[bleu_1, bleu_2, bleu_3, bleu_4]``, each multiplied by 100.
    """
    batch_size = 4
    max_len = 64
    # NOTE(review): the BLEU-3 weights sum to 0.999 rather than 1; they are kept
    # unchanged so scores do not shift relative to earlier reported numbers.
    bleu_w = {
        "bleu-1": [1.0],
        "bleu-2": [0.5, 0.5],
        "bleu-3": [0.333, 0.333, 0.333],
        "bleu-4": [0.25, 0.25, 0.25, 0.25]
    }

    idx2word = dataset._idx2word
    start_idx = dataset._start_idx
    end_idx = dataset._end_idx
    pad_idx = dataset._pad_idx

    references = []
    predictions = []

    print("Evaluating...")
    for x_img, y_caption in dataset.inference_batch(batch_size):
        x_img = x_img.to(device)

        # Encoding and decoding are inference only: no autograd graph needed.
        with torch.no_grad():
            # Patch embedding, prepend the class token, run the transformer
            # encoder and keep the class-token output as the image feature.
            img_features = encoder._process_input(x_img)
            batch_class_token = encoder.class_token.expand(img_features.shape[0], -1, -1)
            img_features = torch.cat([batch_class_token, img_features], dim=1)
            img_features = encoder.encoder(img_features)
            img_features = img_features[:, 0]

            # Decode with the size of *this* batch: the last batch may hold
            # fewer than ``batch_size`` images, and decoder inputs must match
            # the number of feature rows.
            pred_captions = inference_test(
                decoder, img_features, start_idx, end_idx, pad_idx, idx2word,
                img_features.shape[0], max_len, device,
            )

        references += y_caption
        predictions += pred_captions

    # Evaluate BLEU-1..4 in the insertion order of ``bleu_w``.
    bleu = [
        corpus_bleu(references, predictions, weights=weights) * 100
        for weights in bleu_w.values()
    ]

    return bleu

In [None]:
import matplotlib.pyplot as plt
import itertools

def caption_test(dataset, encoder, decoder, device, index):
    """Show one sample image and print its reference and predicted captions.

    Args:
        dataset: Object exposing ``_idx2word``, ``_start_idx``, ``_end_idx``,
            ``_pad_idx`` and ``inference_batch(batch_size)``, which yields
            ``(images, reference_captions)`` pairs.
        encoder: Image encoder used through ``_process_input``, ``class_token``
            and ``encoder`` (torchvision ``VisionTransformer`` attribute names).
        decoder: Caption decoder forwarded to ``inference_test``.
        device: Device the image and decoder inputs are placed on.
        index: Zero-based position of the sample in the stream produced by
            ``dataset.inference_batch(1)``.

    Raises:
        IndexError: If the dataset yields fewer than ``index + 1`` samples.
    """
    batch_size = 1
    max_len = 64

    idx2word = dataset._idx2word
    start_idx = dataset._start_idx
    end_idx = dataset._end_idx
    pad_idx = dataset._pad_idx

    # Skip ahead to the requested sample. Use a sentinel instead of letting a
    # bare StopIteration escape when ``index`` is past the end of the data.
    sample = next(itertools.islice(dataset.inference_batch(batch_size), index, None), None)
    if sample is None:
        raise IndexError(f"index {index} is out of range for the dataset's inference batches")
    x_img, y_caption = sample

    # Drop the batch axis and move channels last for imshow
    # (assumes x_img is (1, C, H, W) -- TODO confirm).
    im_prev = x_img.squeeze().permute(1, 2, 0).float()
    plt.imshow(im_prev)
    plt.show()
    x_img = x_img.to(device)

    # Encoding and decoding are inference only: no autograd graph needed.
    with torch.no_grad():
        # Patch embedding, prepend the class token, run the transformer
        # encoder and keep the class-token output as the image feature.
        img_features = encoder._process_input(x_img)
        batch_class_token = encoder.class_token.expand(img_features.shape[0], -1, -1)
        img_features = torch.cat([batch_class_token, img_features], dim=1)
        img_features = encoder.encoder(img_features)
        img_features = img_features[:, 0]

        pred_captions = inference_test(decoder, img_features, start_idx, end_idx, pad_idx, idx2word, batch_size, max_len, device)

    print('Reference:')
    for ref in y_caption[0]:
        print(ref)
    print('Prediction: ', pred_captions)




In [None]:
# The notebook's import cell only pulls in swin_t / Swin_T_Weights from
# torchvision, so import the ViT constructor here; without it this cell fails
# with a NameError on a fresh kernel.
from torchvision.models import vit_b_16

# Prefer the GPU but keep the cell runnable on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_set = Flickr8KDataset(test_images_path, training=False)
cp_path = '/content/drive/MyDrive/Colab Notebooks/Checkpoints/Sep-15_05-45-51_pretrained-vit-w2v/model_90.pth'
# map_location makes the load independent of the device the checkpoint was
# saved from.
checkpoint = torch.load(cp_path, map_location=device)

# Encoder: ViT-B/16 restored from the checkpoint's 'encoder' state dict.
encoder = vit_b_16()
encoder = encoder.to(device)
encoder.load_state_dict(checkpoint['encoder'])
# Freeze every parameter whose name starts with 'heads'.
for name, param in encoder.named_parameters():
    if re.match('^heads', name) : param.requires_grad = False
encoder.eval()

# Decoder: restored from the checkpoint's 'decoder' state dict, in eval mode.
decoder = CaptionDecoder()
decoder = decoder.to(device)
decoder.load_state_dict(checkpoint['decoder'])
decoder.eval()


## Result

In [None]:
# Optional: BLEU-1..4 over the whole test set.
# test_bleu = evaluate_test(test_set, encoder, decoder, device)
# print(test_bleu)

# Qualitative check: display the test sample at position 6 with its reference
# captions and the generated caption.
caption_test(dataset=test_set, encoder=encoder, decoder=decoder, device=device, index=6)