### Installing Necessary Packages

In [1]:
# Installing the COCO API
! pip install -U 'git+https://github.com/leimao/cocoapi.git#subdirectory=PythonAPI' > /dev/null

# For pretrained DeiT Transformer
! pip install timm > /dev/null

  Running command git clone --filter=blob:none --quiet https://github.com/leimao/cocoapi.git /tmp/pip-req-build-semk52up


### Necessary Packages

In [2]:
import torch
import os
import random
import numpy as np 
import nltk
import pickle
import timm
import time
import torch.nn.functional as F
from tqdm.notebook import tqdm
from pycocotools.coco import COCO
from collections import Counter
from torchvision.transforms import Compose,ToTensor,RandomResizedCrop,RandomHorizontalFlip,Resize
from torch.utils.data import Dataset,DataLoader
from typing import Optional,Callable
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch import nn
from torch import optim
from timm.models.layers import trunc_normal_
from torchmetrics import Metric
from transformers import get_linear_schedule_with_warmup

### Global Configuration

In [3]:
class GLOBAL:
    """Notebook-wide configuration: paths, data/model sizes and optimisation hyper-parameters."""

    working_dir = "/kaggle/working"
    input_dir = "/kaggle/input"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    max_len = 100            # length every caption row is padded to by the collate function
    img_size = 384           # images are cropped/resized to img_size x img_size
    num_bins = img_size
    num_classes = 91
    word_threshold = 20      # minimum corpus count for a word to enter the vocabulary
    # Must match the patch size of `feature_extractor_name` ("..._patch16_384").
    patch_size = 16
    # Number of encoder tokens fed to the decoder: one per image patch (576 for 384x384).
    num_patches = (img_size // patch_size) ** 2

    feature_extractor_name = "deit3_small_patch16_384.fb_in22k_ft_in1k"
    dataset_root = os.path.join(input_dir, "coco-2017-dataset", "coco2017")
    caption_path = os.path.join(dataset_root, 'annotations', 'captions_train2017.json')
    vocab_path = os.path.join(working_dir, "vocab.pkl")
    weights_dir = os.path.join(working_dir, "weights")

    lr = 1e-3
    weight_decay = 1e-3
    start_epoch = 0
    batch_size = 16
    epochs = 2
    features_dim = 256

    num_workers = 2

In [4]:
class FLAGS:
    """Run-time switches controlling which stages of the notebook execute."""

    # True: rebuild vocab.pkl from the caption annotations; False: load the saved file.
    build_vocab: bool = False
    # Run the training loop.
    train: bool = True

In [5]:
# Checkpoint directory. exist_ok makes the cell safe to re-run (no check-then-create race),
# and makedirs also creates any missing parent directories.
os.makedirs(GLOBAL.weights_dir, exist_ok=True)

### Reproducibility

In [6]:
SEED = 50

# Seed every RNG the notebook uses: Python, NumPy and PyTorch.
# torch.manual_seed covers the CPU generator (parameter init, DataLoader shuffling);
# torch.cuda.manual_seed_all covers every visible GPU.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Trade speed for determinism: deterministic kernels, no cuDNN autotuning, and cuDNN
# disabled outright (this can noticeably slow down convolution-heavy layers).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False

### Vocabulary

In [7]:
class Vocab:
    """Two-way word <-> index mapping.

    Index 0 is always the '<unk>' token, which is also what out-of-vocabulary lookups
    resolve to. Instances are pickled to disk, so the attribute names must stay stable.
    """

    def __init__(self):
        self.str_2_idx = {}   # word -> index
        self.idx_2_str = {}   # index -> word
        self.idx = 0          # next free index
        self.unk_token = '<unk>'
        self.add_word(self.unk_token)

    def add_word(self, word : str):
        """Register ``word`` under the next free index; does nothing if it is already known."""
        if word in self.str_2_idx:
            return
        new_index = self.idx
        self.str_2_idx[word] = new_index
        self.idx_2_str[new_index] = word
        self.idx = new_index + 1

    def __call__(self, word : str):
        """Return the index of ``word``, or the '<unk>' index when the word is unknown."""
        return self.str_2_idx[word] if word in self.str_2_idx else self.str_2_idx[self.unk_token]

    def get_word(self, index : int):
        """Return the word stored at ``index``, or '<unk>' for an unused index."""
        return self.idx_2_str.get(index, self.unk_token)

    def __len__(self):
        return len(self.str_2_idx)
        

In [8]:
def build_vocab(json : str, word_threshold : int):
    """Build a ``Vocab`` from the words used in a COCO captions annotation file.

    Args:
        json: path to a COCO captions annotation file (e.g. captions_train2017.json).
        word_threshold: a word is kept only if it occurs at least this many times.

    Returns:
        A ``Vocab`` holding '<unk>' followed by the kept words in first-seen order.
    """
    coco = COCO(json)
    annotations = coco.anns
    counter = Counter()

    # Lower-case and tokenize every caption, accumulating word frequencies.
    for annotation_id in tqdm(annotations, total=len(annotations)):
        caption = str(annotations[annotation_id]['caption'])
        counter.update(nltk.tokenize.word_tokenize(caption.lower()))

    # Counter preserves insertion order, so indices follow first appearance in the corpus.
    words = [word for word, count in counter.items() if count >= word_threshold]

    vocab = Vocab()
    for word in tqdm(words, total=len(words)):
        vocab.add_word(word)

    return vocab

In [9]:
def load_or_build_vocab(
    build : bool,
    caption_path : str,
    word_threshold : int,
    vocab_path : str,
    verbose : bool = True
) -> Vocab:
    """Return a vocabulary, either rebuilt from captions or loaded from ``vocab_path``.

    Args:
        build: rebuild from ``caption_path`` (and overwrite ``vocab_path``) when True;
            otherwise unpickle the existing file at ``vocab_path``.
        caption_path: COCO captions annotation file used when building.
        word_threshold: minimum word frequency used when building.
        vocab_path: pickle file to write (build) or read (load).
        verbose: print progress messages.
    """
    if not build:
        if verbose:
            print(f"-- Loading vocabulary from path : {vocab_path} --")
        # NOTE: pickle.load can execute arbitrary code; only load vocab files you produced.
        with open(vocab_path, 'rb') as f:
            return pickle.load(f)

    if verbose:
        print(f"-- Building vocabulary --")

    vocab = build_vocab(caption_path, word_threshold)
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)

    if verbose:
        print(f"-- Vocabulary size = {len(vocab)} --")
        print(f"-- Vocabulary saved to = {vocab_path} --")

    return vocab

In [10]:
# Inspect the working directory: vocab.pkl must already exist here when FLAGS.build_vocab is False.
!ls

state.db  vocab.pkl  weights


In [11]:
# Load the cached vocabulary, or rebuild it when FLAGS.build_vocab is set.
vocab = load_or_build_vocab(
    build=FLAGS.build_vocab,
    caption_path=GLOBAL.caption_path,
    word_threshold=GLOBAL.word_threshold,
    vocab_path=GLOBAL.vocab_path,
)

-- Loading vocabulary from path : /kaggle/working/vocab.pkl --


### Tokenization

In [12]:
class Tokenizer:
    """Token-id layout shared by caption preprocessing, the model and decoding.

    Id ranges:
        ``[0, num_classes + num_bins)``   reserved ids (not produced by caption preprocessing here)
        ``BOS_code, EOS_code, PAD_code``  the three special tokens right after that range
        ``[text_id_shift, ...)``          caption words, encoded as ``vocab index + text_id_shift``
    ``vocab_size`` is the size of the model's embedding / output layer.

    Args:
        text_id_shift: offset added to vocabulary indices; must exceed ``PAD_code``.
        vocab_size: total number of token ids; must exceed ``text_id_shift``.

    Raises:
        ValueError: if word ids would collide with the special tokens, or if
            ``vocab_size`` leaves no room for words.
    """

    def __init__(self, num_classes: int, 
        num_bins: int, 
        width: int, 
        height: int, 
        max_len : int,
        text_id_shift : int = 550,
        vocab_size : int = 6000,
    ):
        self.num_classes = num_classes
        self.num_bins = num_bins
        self.width = width
        self.height = height
        self.max_len = max_len
        self.max_len_obj = int((self.max_len - 2) / 5)
        self.BOS_code = num_classes + num_bins 
        self.EOS_code = self.BOS_code + 1
        self.PAD_code = self.EOS_code + 1
        self.text_id_shift = text_id_shift
        self.vocab_size = vocab_size

        # Word ids start at text_id_shift, so that offset must clear every special code.
        if self.text_id_shift <= self.PAD_code:
            raise ValueError(
                f"text_id_shift ({self.text_id_shift}) must be greater than PAD_code ({self.PAD_code})"
            )
        if self.vocab_size <= self.text_id_shift:
            raise ValueError(
                f"vocab_size ({self.vocab_size}) must be greater than text_id_shift ({self.text_id_shift})"
            )

In [13]:
# Token-id layout used by caption preprocessing, the model and decoding.
tokenizer = Tokenizer(
    max_len=GLOBAL.max_len,
    num_classes=GLOBAL.num_classes,
    num_bins=GLOBAL.num_bins,
    height=GLOBAL.img_size,
    width=GLOBAL.img_size,
)

### Preprocessing

In [14]:
def create_transfroms(size : int) -> tuple[Compose,Compose]:
    """Build the (train, val) image pipelines, each producing ``size`` x ``size`` tensors.

    Training augments with random horizontal flips and random resized crops.
    Validation is deterministic (a plain resize, the same preprocessing ``generate_caption``
    uses), so validation loss is comparable from epoch to epoch.
    """
    train_transforms = Compose([
        RandomHorizontalFlip(p=0.5),
        RandomResizedCrop(size=size),
        ToTensor()
    ])

    val_transforms = Compose([
        Resize((size, size)),
        ToTensor()
    ])

    return train_transforms, val_transforms

In [15]:
class CaptionPreprocessor:
    """Convert a raw caption into a LongTensor ``[BOS, word ids..., EOS]``.

    Word ids are vocabulary indices offset by ``tokenizer.text_id_shift``.
    """

    def __init__(self,vocab : Vocab, tokenizer : Tokenizer):
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __call__(self, caption : str):
        words = nltk.tokenize.word_tokenize(caption.lower())
        shift = self.tokenizer.text_id_shift

        ids = [self.tokenizer.BOS_code]
        ids.extend(self.vocab(word) + shift for word in words)
        ids.append(self.tokenizer.EOS_code)

        return torch.tensor(ids, dtype=torch.long)

### Data Loading

In [16]:
class COCODataset(Dataset):
    """COCO 2017 captions dataset with one item per caption annotation.

    Expects ``<root>/annotations/captions_<split>2017.json`` and images under
    ``<root>/<split>2017``. Each item is ``(image, caption)`` after the optional transforms.

    Raises:
        FileNotFoundError: if the root, annotation file or image folder is missing.
    """

    def __init__(self,
        root : str,
        split : str,
        img_transfroms : Optional[Callable] = None,
        caption_transfroms : Optional[Callable] = None
    ):
        self.root = root
        self.annot_filename = os.path.join(root, "annotations", f"captions_{split}2017.json")
        self.img_root = os.path.join(root, f"{split}2017")

        # Fail early, with the offending path in the message, if the layout is incomplete.
        for required_path in (self.root, self.annot_filename, self.img_root):
            if not os.path.exists(required_path):
                raise FileNotFoundError(f"{required_path} doesn't exist.")

        self.coco = COCO(self.annot_filename)
        self.ids = list(self.coco.anns.keys())
        self.caption_transfroms = caption_transfroms
        self.img_transfroms = img_transfroms

    def __getitem__(self, index : int):
        caption_id = self.ids[index]
        annotation = self.coco.anns[caption_id]
        caption = annotation['caption']

        file_name = self.coco.loadImgs(annotation['image_id'])[0]['file_name']
        path = os.path.join(self.img_root, file_name)

        # The context manager closes the file handle deterministically; convert('RGB')
        # gives every image (including grayscale ones) three channels.
        with Image.open(path) as raw_img:
            img = raw_img.convert('RGB')

        if self.img_transfroms is not None:
            img = self.img_transfroms(img)

        if self.caption_transfroms is not None:
            caption = self.caption_transfroms(caption)

        return img,caption

    def __len__(self) -> int:
        return len(self.ids)

In [17]:
class Collate:
    """Collate ``(image, caption_ids)`` pairs into ``(images, captions)`` batches.

    Images are stacked. Captions are right-padded with ``pad_idx``. When ``max_len`` is
    given, every caption row has exactly ``max_len`` ids: shorter ones are padded and
    longer ones are truncated. When ``max_len`` is None, rows are padded only to the
    longest caption in the batch.
    """

    def __init__(self,pad_idx : int,max_len : Optional[int]):
        self.pad_idx = pad_idx
        self.max_len = max_len

    def __call__(self, batch : list[tuple[torch.Tensor,torch.Tensor]]):
        imgs = torch.stack([img for img, _ in batch])
        captions = pad_sequence(
            [caption for _, caption in batch],
            padding_value=self.pad_idx,
            batch_first=True,
        )

        if self.max_len is not None:
            if captions.size(1) > self.max_len:
                # An overlong caption would otherwise make the pad width negative and crash.
                captions = captions[:, :self.max_len]
            else:
                pad = torch.full(
                    (captions.size(0), self.max_len - captions.size(1)),
                    self.pad_idx,
                    dtype=torch.long,
                )
                captions = torch.cat([captions, pad], dim=1)

        return imgs,captions

In [18]:
def create_dataloaders(
    dataset_root : str,
    img_size : int,
    batch_size : int,
    num_workers : int,
    max_len : int,
    vocab : Vocab,
    tokenizer : Tokenizer
) -> tuple[DataLoader,DataLoader]:
    """Build the COCO train/val loaders yielding ``(images, padded caption ids)`` batches.

    Captions are tokenized with ``vocab``/``tokenizer`` and padded with
    ``tokenizer.PAD_code`` to ``max_len``. Only the training loader is shuffled.
    """
    train_transfroms, val_transfroms = create_transfroms(img_size)

    train_caption_transfroms = CaptionPreprocessor(vocab, tokenizer)
    val_caption_transfroms = CaptionPreprocessor(vocab, tokenizer)

    train_collate = Collate(tokenizer.PAD_code,max_len)
    val_collate = Collate(tokenizer.PAD_code,max_len)

    train_data = COCODataset(dataset_root,split='train',img_transfroms=train_transfroms, caption_transfroms=train_caption_transfroms)
    val_data = COCODataset(dataset_root,split='val',img_transfroms=val_transfroms, caption_transfroms=val_caption_transfroms)

    train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=True,collate_fn=train_collate)
    # The averaged validation loss does not depend on order; a fixed order keeps runs reproducible.
    val_loader = DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=True,collate_fn=val_collate)

    return train_loader,val_loader

### Architecture

In [19]:
class Encoder(nn.Module):
    """Image encoder: a timm backbone whose token features are average-pooled to ``out_dim``.

    ``nn.AdaptiveAvgPool1d`` pools the last dimension, so assuming the backbone returns
    (batch, tokens, channels), the output is (batch, tokens - 1, out_dim). TODO confirm
    the backbone's output layout.
    """

    def __init__(self, model_name, out_dim : int, pretrained : bool = True):
        super().__init__()

        self.out_dim = out_dim

        # num_classes=0 / global_pool='' request per-token features rather than class
        # logits (timm convention; presumably yields (batch, tokens, channels)).
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='',
        )
        self.bottleneck = nn.AdaptiveAvgPool1d(out_dim)

    def forward(self, x):
        tokens = self.model(x)
        patch_tokens = tokens[:, 1:]  # drop the first token (presumably the class token)
        return self.bottleneck(patch_tokens)

In [20]:
class PosEmbeddings(nn.Module):
    """Learned absolute positional embeddings added to a (batch, seq, dim) input.

    Supports any ``seq <= max_len``: the first ``seq`` positions are used. A longer
    input fails to broadcast and raises ``RuntimeError``.
    """

    def __init__(self, max_len : int,dim : int):
        super().__init__()

        self.max_len = max_len
        self.dim = dim

        self.weight = nn.Parameter(torch.randn(1, max_len, dim) * .02)

        self.init_weights()

    def forward(self, x):
        # Slice to the input length so shorter sequences (e.g. during decoding) work too;
        # for seq == max_len this is exactly the full table.
        return x + self.weight[:, :x.size(1)]

    def init_weights(self):
        trunc_normal_(self.weight, std=.02)

In [21]:
class Mask(nn.Module):
    """Build the attention masks for a batch of target token ids.

    Args:
        device: kept for backward compatibility; masks are built on ``target``'s device.
        pad_idx: id of the padding token.
    """

    def __init__(self, 
        device : torch.device,
        pad_idx : int
    ):
        super().__init__()

        self.device = device
        self.pad_idx = pad_idx

    def forward(self, target):
        """Return ``(causal_mask, padding_mask)`` for ``target`` of shape (batch, seq).

        causal_mask: float (seq, seq); 0.0 on and below the diagonal and -inf above it,
            so position i can only attend to positions <= i.
        padding_mask: bool (batch, seq); True where ``target == pad_idx``.
        """
        target_len = target.shape[1]

        # Build the mask on the target's device so it always follows the data.
        causal_mask = torch.triu(
            torch.full((target_len, target_len), float('-inf'), device=target.device),
            diagonal=1,
        )

        padding_mask = target == self.pad_idx

        return causal_mask, padding_mask

In [22]:
class Decoder(nn.Module):
    """Transformer decoder producing next-token logits from caption ids and image features.

    Args:
        vocab_size: size of the token embedding and of the output projection.
        encoder_length: number of encoder tokens (length of the encoder positional embedding).
        dim: model width; the encoder features must have this size.
        max_len: full caption length; the decoder positional embedding covers ``max_len - 1``
            positions because training feeds the caption without its last token.
        num_heads, num_layers: Transformer decoder configuration.
        device: forwarded to ``Mask``.
        pad_idx: padding id; padded target positions are excluded as attention keys.
    """

    def __init__(self,
        vocab_size : int, 
        encoder_length : int, 
        dim : int, 
        max_len : int,
        num_heads : int, 
        num_layers : int,
        device : torch.device,
        pad_idx : int
    ):
        super().__init__()

        self.mask = Mask(device,pad_idx)

        self.embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=dim)

        self.decoder_pos_embed = PosEmbeddings(max_len-1, dim)
        self.decoder_pos_drop = nn.Dropout(p=0.05)

        decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=num_heads)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output = nn.Linear(dim, vocab_size)

        self.encoder_pos_embed = PosEmbeddings(encoder_length, dim)
        self.encoder_pos_drop = nn.Dropout(p=0.05)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for every matrix-shaped parameter except the positional embeddings."""
        for name, param in self.named_parameters():
            is_pos_embed = 'encoder_pos_embed' in name or 'decoder_pos_embed' in name
            if not is_pos_embed and param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, encoder_out, target):
        """Return logits of shape (batch, seq, vocab_size).

        Args:
            encoder_out: (batch, encoder_length, dim) image features.
            target: (batch, seq) token ids; seq must fit the decoder positional embedding.
        """
        target_mask, target_padding_mask = self.mask(target)

        # nn.MultiheadAttention *adds* float key-padding masks to the attention scores, so
        # a bool->float cast (pad -> 1.0) would not mask padding at all. Build an additive
        # mask instead: -inf on padding, 0 elsewhere, with the same dtype as the causal mask.
        # Position 0 holds BOS (never padding), so no query row ends up fully masked.
        padding_additive_mask = torch.zeros(
            target_padding_mask.shape,
            dtype=target_mask.dtype,
            device=target_padding_mask.device,
        ).masked_fill(target_padding_mask, float('-inf'))

        target_embedding = self.embedding(target)
        target_embedding = self.decoder_pos_embed(target_embedding)
        target_embedding = self.decoder_pos_drop(target_embedding)

        encoder_out = self.encoder_pos_embed(encoder_out)
        encoder_out = self.encoder_pos_drop(encoder_out)

        # nn.TransformerDecoder defaults to sequence-first layout: (seq, batch, dim).
        encoder_out = encoder_out.transpose(0, 1)
        target_embedding = target_embedding.transpose(0, 1)

        preds = self.decoder(
            memory=encoder_out,
            tgt=target_embedding,
            tgt_mask=target_mask,
            tgt_key_padding_mask=padding_additive_mask
        )

        preds = preds.transpose(0, 1)

        outputs = self.output(preds)

        return outputs

In [23]:
class EncoderDecoder(nn.Module):
    """Captioning model: an ``Encoder`` backbone whose features condition a ``Decoder``.

    ``forward(image, target)`` returns the decoder logits for ``target`` given ``image``.
    The constructor arguments are also kept as attributes.
    """

    def __init__(self,
        feature_extractor_name : str,
        features_dim : int,
        vocab_size : int,
        encoder_length : int,
        num_heads : int,
        num_layers : int,
        max_len : int,
        device : torch.device,
        pad_idx : int
    ):
        super().__init__()

        # Backbone / feature configuration.
        self.feature_extractor_name = feature_extractor_name
        self.features_dim = features_dim
        self.encoder_length = encoder_length

        # Decoder configuration.
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_len = max_len
        self.device = device
        self.pad_idx = pad_idx

        # Construct the encoder first, then the decoder (same registration order as before).
        self.encoder = Encoder(
            model_name=feature_extractor_name,
            out_dim=features_dim,
            pretrained=True,
        )
        self.decoder = Decoder(
            vocab_size=vocab_size,
            encoder_length=encoder_length,
            dim=features_dim,
            max_len=max_len,
            num_heads=num_heads,
            num_layers=num_layers,
            device=device,
            pad_idx=pad_idx,
        )

    def forward(self, image, target):
        return self.decoder(self.encoder(image), target)

In [24]:
class Seq2SeqCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy over every position of a batch of sequences.

    ``y_hat`` holds logits with classes in its last dimension; ``y`` holds the matching
    class ids. Both are flattened over their leading dimensions before the standard
    ``nn.CrossEntropyLoss`` is applied (``ignore_index`` etc. behave as usual).
    """

    def forward(self, y_hat : torch.Tensor, y : torch.Tensor) -> torch.Tensor:
        num_classes = y_hat.size(-1)
        flat_logits = y_hat.reshape(-1, num_classes)
        flat_targets = y.flatten()
        return super().forward(flat_logits, flat_targets)

### Training

In [25]:
def read_weights_folder(path : str):
    """Load the newest checkpoint written by ``Trainer`` into ``path``.

    Checkpoints are ``.pt`` files named ``epoch_<n>.pt``. They are ordered by the numeric
    epoch, not lexicographically (which would rank ``epoch_10.pt`` before ``epoch_9.pt``).

    Returns:
        ``(state_dict or None, number of .pt files found)``.
    """
    import re

    def epoch_key(filename):
        # Files without a trailing number sort first; ties are broken by name.
        match = re.search(r"(\d+)\.pt$", filename)
        return (int(match.group(1)) if match else -1, filename)

    checkpoints = sorted((f for f in os.listdir(path) if f.endswith('.pt')), key=epoch_key)
    num_epochs = len(checkpoints)

    last_weights = None
    if checkpoints:
        # Load onto the CPU so GPU-saved checkpoints also open on CPU-only machines;
        # load_state_dict copies the tensors onto the model's device.
        last_weights = torch.load(os.path.join(path, checkpoints[-1]), map_location="cpu")

    return last_weights,num_epochs

In [26]:
# Newest checkpoint (if any) and how many epochs have already been trained.
weights, num_epochs = read_weights_folder(GLOBAL.weights_dir)
print(num_epochs, weights is not None)

0 False


In [27]:
# COCO train/val loaders producing (images, padded caption ids) batches.
train_loader, val_loader = create_dataloaders(
    vocab=vocab,
    tokenizer=tokenizer,
    dataset_root=GLOBAL.dataset_root,
    img_size=GLOBAL.img_size,
    max_len=GLOBAL.max_len,
    batch_size=GLOBAL.batch_size,
    num_workers=GLOBAL.num_workers,
)

loading annotations into memory...
Done (t=2.14s)
creating index...
index created!
loading annotations into memory...
Done (t=0.11s)
creating index...
index created!


In [28]:
# Sanity-check one training batch before building the model: square RGB images and
# caption rows padded to exactly GLOBAL.max_len ids.
x, y = next(iter(train_loader))
assert x.shape[1:] == (3, GLOBAL.img_size, GLOBAL.img_size), x.shape
assert y.shape == (x.size(0), GLOBAL.max_len), y.shape
x, y = x.to(GLOBAL.device), y.to(GLOBAL.device)

In [29]:
# Captioning model: pretrained backbone + 6-layer, 8-head Transformer decoder.
model = EncoderDecoder(
    feature_extractor_name=GLOBAL.feature_extractor_name,
    features_dim=GLOBAL.features_dim,
    encoder_length=GLOBAL.num_patches,
    vocab_size=tokenizer.vocab_size,
    max_len=GLOBAL.max_len,
    num_layers=6,
    num_heads=8,
    pad_idx=tokenizer.PAD_code,
    device=GLOBAL.device,
).to(GLOBAL.device)

model.safetensors:   0%|          | 0.00/88.8M [00:00<?, ?B/s]

In [30]:
# Resume from the newest checkpoint returned by read_weights_folder, if one was found.
if weights is not None:
    model.load_state_dict(weights)

In [31]:
# Length of the whole LR schedule in optimizer steps (one per batch); the first 5% is warm-up.
num_training_steps = len(train_loader) * GLOBAL.epochs
num_warmup_steps = int(num_training_steps * 0.05)

In [32]:
optimizer = optim.AdamW(params=model.parameters(), lr=GLOBAL.lr, weight_decay=GLOBAL.weight_decay)
criterion = Seq2SeqCrossEntropyLoss(ignore_index=tokenizer.PAD_code).to(GLOBAL.device)

# The scheduler is stepped once per training *batch* (Trainer.train_on_batch), so when
# resuming after `num_epochs` finished epochs its position must be counted in steps.
completed_steps = num_epochs * len(train_loader)
if completed_steps > 0:
    # LambdaLR requires 'initial_lr' in each param group when last_epoch != -1.
    for group in optimizer.param_groups:
        group.setdefault("initial_lr", group["lr"])

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=num_warmup_steps,
    last_epoch=completed_steps - 1,
)

In [33]:
class Trainer:
    """Training loop with per-epoch validation and checkpointing.

    Args:
        model: module called as ``model(images, input_ids)`` that returns logits.
        criterion: loss called as ``criterion(logits, target_ids)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        weights_folder: directory where ``epoch_<n>.pt`` state dicts are written.
        num_epochs: epochs already completed before this run (when resuming); epoch
            numbers and checkpoint names continue from here.
        lr_scheduler: optional scheduler, stepped once per training batch.
    """

    def __init__(self,
        model : nn.Module,
        criterion : nn.Module,
        optimizer : optim.Optimizer,
        device : torch.device,
        weights_folder : str,
        num_epochs : int,
        lr_scheduler : Optional[optim.lr_scheduler.LRScheduler] = None,        
    ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.lr_scheduler = lr_scheduler
        self.history = self.init_history()
        self.weights_folder = weights_folder
        self.num_epochs = num_epochs

    def init_history(self) -> dict:
        """Return an empty history: ``{'train'|'val': {'loss'|'epoch'|'time': []}}``."""
        metrics = ['loss','epoch','time']
        return {split: {metric: [] for metric in metrics} for split in ['train','val']}

    def train_on_batch(self, train_batch : tuple[torch.Tensor,torch.Tensor]) -> tuple[torch.Tensor,torch.Tensor]:
        """Run one optimisation step with teacher forcing; return ``(logits, loss)``."""
        x, y = train_batch
        x, y = x.to(self.device), y.to(self.device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored on the next
        # tokens [1, T), so position i learns to predict token i+1 (same as val_on_batch).
        y_input = y[:,:-1]
        y_expected = y[:,1:]

        y_hat = self.model(x, y_input)
        loss = self.criterion(y_hat, y_expected)

        # Gradients accumulate by default, so clear them before backprop.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return y_hat,loss

    def train_step(self, data_loader : DataLoader, epoch_num : int) -> None:
        """Train for one pass over ``data_loader`` and record mean loss and wall time."""
        self.model.train()

        t_object = tqdm(data_loader, total=len(data_loader))
        running_loss = 0.0
        tic = time.time()

        for train_batch in t_object:
            _, loss = self.train_on_batch(train_batch)
            current_loss = loss.item()
            running_loss += current_loss
            t_object.set_description(f"Epoch {epoch_num+1} : loss = {current_loss}")

        toc = time.time()

        self.history["train"]["loss"].append(running_loss / len(data_loader))
        self.history["train"]["epoch"].append(epoch_num)
        self.history["train"]["time"].append(toc - tic)

    def val_on_batch(self, train_batch : tuple[torch.Tensor,torch.Tensor]) -> tuple[torch.Tensor,torch.Tensor]:
        """Compute ``(logits, loss)`` for one batch without updating the model."""
        x, y = train_batch
        x, y = x.to(self.device), y.to(self.device)

        y_input = y[:,:-1]
        y_expected = y[:,1:]

        y_hat = self.model(x, y_input)
        loss = self.criterion(y_hat, y_expected)

        return y_hat,loss

    def val_step(self, data_loader : DataLoader, epoch_num : int) -> None:
        """Evaluate on ``data_loader`` and record mean loss and wall time."""
        self.model.eval()

        with torch.inference_mode():
            t_object = tqdm(data_loader, total=len(data_loader))
            running_loss = 0.0
            tic = time.time()

            for val_batch in t_object:
                _, loss = self.val_on_batch(val_batch)
                running_loss += loss.item()

            toc = time.time()

            self.history["val"]["loss"].append(running_loss / len(data_loader))
            self.history["val"]["epoch"].append(epoch_num)
            self.history["val"]["time"].append(toc - tic)

    def train(self,
        train_loader : DataLoader,
        val_loader : DataLoader,
        epochs : int = 1
    ) -> None:
        """Train and validate for ``epochs`` epochs, saving a checkpoint after each one."""
        for epoch in range(epochs):
            # Continue numbering after the epochs completed before this run.
            epoch_num = self.num_epochs + epoch

            self.train_step(train_loader, epoch_num)
            self.val_step(val_loader, epoch_num)

            last_train_loss = self.history["train"]["loss"][-1]
            last_val_loss = self.history["val"]["loss"][-1]

            path = os.path.join(self.weights_folder, f"epoch_{epoch_num}.pt")
            # Save the trainer's own model (not whatever a global `model` refers to).
            torch.save(self.model.state_dict(), f=path)

            print(f"model saved to = {path}")
            print(f"Epoch = {epoch_num+1} : train_loss = {last_train_loss},val_loss = {last_val_loss}")

In [34]:
# Wire the model, loss, optimizer and LR schedule into the training loop.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=GLOBAL.device,
    weights_folder=GLOBAL.weights_dir,
    num_epochs=num_epochs,
    lr_scheduler=lr_scheduler,
)

In [None]:
if FLAGS.train:
    # Train the epochs still missing after resuming from `num_epochs` checkpoints, so the
    # total matches GLOBAL.epochs, which the LR schedule (num_training_steps) is sized for.
    trainer.train(train_loader, val_loader, epochs=max(GLOBAL.epochs - num_epochs, 0))

  0%|          | 0/36985 [00:00<?, ?it/s]

In [None]:
# Checkpoint files written by the trainer so far.
os.listdir(path=GLOBAL.weights_dir)

### Inference

In [None]:
def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits with the vocabulary in the last dimension, e.g. (batch, vocab)
            or (vocab,). NOTE: modified in place and also returned.
        top_k: if > 0, keep only the k tokens with the highest logits.
        top_p: if < 1.0, keep the smallest set of top tokens whose cumulative
            probability reaches top_p (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written into removed positions.
        min_tokens_to_keep: always keep at least this many tokens per row.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a logit lower than the k-th largest one.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter back to the original (unsorted) order along the vocabulary axis. Using the
        # last dim (instead of a hard-coded dim 1) also supports 1-D and N-D logits.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

In [None]:
def generate_caption(
    model : nn.Module,
    image : Image.Image,
    tokenizer : Tokenizer,
    device : torch.device,
    size : int,
    max_len : int = 50,
    top_k : int = 0,
    top_p : float = 1,
):
    """Autoregressively decode token ids for a single image.

    Args:
        model: module exposing ``.encoder(images)`` and ``.decoder(encoder_out, ids)``.
        image: PIL image. It is converted to RGB first, so grayscale inputs get the three
            channels the training pipeline also produces.
        tokenizer: provides ``BOS_code`` and ``PAD_code``.
        device: device to run on.
        size: side length the image is resized to.
        max_len: number of decoding steps. The decoder input is padded to this length, so
            it presumably has to fit the decoder's positional embedding — TODO confirm.
        top_k, top_p: if either differs from its default, sample from the filtered
            distribution; otherwise decode greedily.

    Returns:
        CPU LongTensor of shape (1, max_len + 1): BOS followed by ``max_len`` generated ids.
    """
    preprocessor = Compose([
        Resize((size,size)),
        ToTensor()
    ])

    x = preprocessor(image.convert('RGB')).unsqueeze(0).to(device)

    batch_preds = torch.full((x.size(0), 1), tokenizer.BOS_code, dtype=torch.long, device=device)

    if top_k != 0 or top_p != 1:
        sample = lambda preds: torch.softmax(preds, dim=-1).multinomial(num_samples=1).view(-1, 1)
    else:
        sample = lambda preds: torch.softmax(preds, dim=-1).argmax(dim=-1).view(-1, 1)

    model.eval()

    with torch.no_grad():
        # The image encoding is the same at every step, so compute it once.
        encoder_out = model.encoder(x)

        for _ in tqdm(range(max_len)):
            length = batch_preds.size(1)
            padding = torch.full(
                (batch_preds.size(0), max_len - length),
                tokenizer.PAD_code,
                dtype=torch.long,
                device=device,
            )
            tgt = torch.cat([batch_preds, padding], dim=1)

            # Logits at the last real position predict the next token.
            preds = model.decoder(encoder_out, tgt)[:, length - 1, :]
            preds = top_k_top_p_filtering(preds, top_k=top_k, top_p=top_p)

            batch_preds = torch.cat([batch_preds, sample(preds)], dim=1)

    return batch_preds.cpu()
        

In [None]:
def postprocess_caption(batch_preds : torch.Tensor, tokenizer : Tokenizer, vocab : Vocab):
    """Decode generated id rows into lists of words.

    Each row is read from position 1 (after BOS) up to, but excluding, its first EOS.
    The final position is treated as EOS, so every row terminates. Rows whose first EOS
    is at index 0 or 1 (nothing to decode) yield ``None``. ``batch_preds`` is not modified.
    """
    # Work on a copy so that forcing the last column to EOS does not mutate the caller's tensor.
    preds = batch_preds.clone()
    preds[:, -1] = tokenizer.EOS_code

    # argmax over a 0/1 mask returns the index of the first EOS in each row.
    eos_indices = (preds == tokenizer.EOS_code).float().argmax(dim=-1)

    captions = []
    for row, eos_idx in zip(preds, eos_indices.tolist()):
        if eos_idx <= 1:
            captions.append(None)
            continue
        captions.append([
            vocab.get_word(token - tokenizer.text_id_shift)
            for token in row[1:eos_idx].tolist()
        ])

    return captions

In [None]:
def generate_k_captions(
    model : nn.Module,
    image : Image.Image,
    tokenizer : Tokenizer,
    device : torch.device,
    size : int,
    vocab : Vocab,
    max_len : int,
    top_k : int = 0,
    top_p : float = 1,
    k = 5
):
    """Generate ``k`` independently decoded captions for ``image`` as sentences.

    Args:
        max_len: full sequence length including BOS; ``max_len - 1`` tokens are decoded.
        top_k, top_p: forwarded to ``generate_caption``. With the defaults decoding is
            greedy, so all ``k`` captions are identical.
        Other arguments are forwarded to ``generate_caption`` / ``postprocess_caption``.

    Returns:
        List of ``k`` strings. '<unk>' words are dropped and the first remaining word is
        title-cased. A caption with no words becomes ''.
    """
    captions = []

    for _ in range(k):
        ids = generate_caption(model, image, tokenizer, device, size, max_len - 1, top_k, top_p)
        words = postprocess_caption(ids, tokenizer, vocab)[0] or []
        # Filter unknown words before capitalising, so the first *kept* word is title-cased.
        words = [word for word in words if word.lower() != vocab.unk_token]
        if words:
            words[0] = words[0].title()
        captions.append(' '.join(words))

    return captions

In [None]:
# Sample image from the COCO 2017 test split, located under the configured dataset root.
path = os.path.join(GLOBAL.dataset_root, "test2017", "000000000001.jpg")
img = Image.open(path)

In [None]:
# Display the sample image inline (the last expression of the cell is rendered).
img

In [None]:
# Five sampled captions for the sample image (top-k = 20, nucleus p = 0.65).
generate_k_captions(
    model, img, tokenizer, GLOBAL.device, GLOBAL.img_size, vocab, GLOBAL.max_len,
    top_k=20,
    top_p=0.65,
)