In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

import os
import json
import time

from tqdm.auto import tqdm
from typing import Dict, List, Tuple

import torch

import config

from tokenizer import ByteLevelBPE

import importlib

from model.CPTR import CPTR
from model.helpers import *

from dataset.loader import DatasetLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
importlib.reload(config)
torch.cuda.empty_cache()

In [3]:
batch_size = config.BATCH_SIZE

H = config.IMG_HEIGHT
W = config.IMG_WIDTH
P = config.PATCH_SIZE
D_IMG = config.IMG_EMBEDDING_DIM

# The data will get truncated/padded to this length AFTER tokenization
L = config.MAX_TEXT_SEQUENCE_LENGTH
D_TEXT = config.TEXT_EMBEDDING_DIM
VOCAB_SIZE = config.TEXT_VOCAB_SIZE
DROPOUT_DEC = config.DECODER_DROPOUT_PROB

In [4]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42)

## Build Dataset

In [5]:
data_loader = DatasetLoader(dataset_type=config.DATASET, batch_size_train=batch_size, batch_size_test=1, shuffle_test=True)
data_loader.load_data()

train_dataloader = data_loader.get_train_dataloader()
test_dataloader = data_loader.get_test_dataloader()

Loading COCO dataset...


## Initialize Tokenizer

In [6]:
special_tokens = [config.SpecialTokens.PAD, config.SpecialTokens.BOS, config.SpecialTokens.EOS]
tokenizer = ByteLevelBPE(special_tokens=special_tokens)
tokenizer.load(folder=config.TOKENIZER_DATA_PATH, filename_prefix=config.TOKENIZER_FILENAME_PREFIX)
pad_idx = tokenizer.get_padding_token_id()

## Configure Model

In [7]:
transformer = CPTR(num_patches=(H//P)*(W//P),
                   img_emb_dim=D_IMG,
                   patch_size=P,
                   text_emb_dim=D_TEXT,
                   d_model=D_TEXT,
                   max_text_seq_len=L,
                   vocab_size=VOCAB_SIZE,
                   pad_idx=pad_idx,
                   verbose=False).to(device)


## Training code

In [8]:
# TODO: debug training and testing code
# TODO: train to overfit on new dataset
def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Trains a PyTorch model for a single epoch.

    Turns a target PyTorch model to training mode and then
    runs through all of the required training steps (forward
    pass, loss calculation, optimizer step).

    Args:
    model: A PyTorch model to be trained.
    dataloader: A DataLoader instance for the model to be trained on.
    loss_fn: A PyTorch loss function to minimize.
    optimizer: A PyTorch optimizer to help minimize the loss function.
    device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
    A tuple of training loss and training accuracy metrics.
    In the form (train_loss, train_accuracy). For example:

    (0.1112, 0.8743)
    """
    # Put model in train mode
    model.train()
    optimizer.zero_grad()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0
    
    pad_idx=tokenizer.get_padding_token_id()

    # Loop through data loader data batches
    for batch in dataloader:
        images = batch["pixel_values"].to(device)
        texts = batch["description"]

        tokens = torch.tensor(
            [tokenizer.encode(t, max_seq_length=L) for t in texts],
            device=device
        )

        targets = tokens[:, 1:]
        decoder_inputs = tokens[:, :-1]
        # print('Decoder inputs:', tokenizer.decode(decoder_inputs[0].cpu().numpy()))
        T = decoder_inputs.size(1)
        # mask = get_attention_mask(decoder_inputs=decoder_inputs, pad_idx=tokenizer.get_padding_token_id(), seq_len=T, device=device)  # [1, 1, max_seq_len, max_seq_len]
        # mask = None # debug
        attention_mask = get_causal_mask(T, device=device)
        padding_mask = get_padding_mask(decoder_inputs, pad_idx, device=device)

        logits = model(images=images, text_tokens=decoder_inputs, attn_mask=attention_mask, pad_mask=padding_mask)

        B, T, V = logits.shape
        loss = loss_fn(
            logits.reshape(-1, V),
            targets.reshape(-1)
        )
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            preds = logits.argmax(dim=-1)
            pad_idx = tokenizer.get_padding_token_id()
            non_pad = targets != pad_idx
            acc = ((preds == targets) & non_pad).sum() / non_pad.sum()
            train_acc += acc.item()
        
        # print produced tokens for debugging
        # decoded_preds = [tokenizer.decode(preds[i].cpu().numpy()) for i in range(preds.size(0))]
        # decoded_targets = [tokenizer.decode(targets[i].cpu().numpy()) for i in range(targets.size(0))]
        # for dp, dt in zip(decoded_preds, decoded_targets):
        #     print(f"Predicted: {dp} | Target: {dt}")

    # Adjust metrics to get average loss and accuracy per batch 
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    """Tests a PyTorch model for a single epoch.

    Turns a target PyTorch model to "eval" mode and then performs
    a forward pass on a testing dataset.

    Args:
    model: A PyTorch model to be tested.
    dataloader: A DataLoader instance for the model to be tested on.
    loss_fn: A PyTorch loss function to calculate loss on the test data.
    device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
    A tuple of testing loss and testing accuracy metrics.
    In the form (test_loss, test_accuracy). For example:

    (0.0223, 0.8985)
    """
    # Put model in eval mode
    model.eval()

    # Setup test loss and test accuracy values
    test_loss, test_acc = 0, 0
    
    pad_idx=tokenizer.get_padding_token_id()

    # Turn on inference context manager
    with torch.inference_mode():
        # Loop through DataLoader batches
        for batch in dataloader:
          images = batch["pixel_values"].to(device)
          texts = batch["description"]

          tokens = torch.tensor(
              [tokenizer.encode(t, max_seq_length=L) for t in texts],
              device=device
          )

          targets = tokens[:, 1:]
          decoder_inputs = tokens[:, :-1]
          T = decoder_inputs.size(1)
        #   mask = get_attention_mask(decoder_inputs=decoder_inputs, pad_idx=tokenizer.get_padding_token_id(), seq_len=T, device=device)  # [1, 1, max_seq_len, max_seq_len]
        #   mask = None # debug
          attention_mask = get_causal_mask(T, device=device)
          padding_mask = get_padding_mask(decoder_inputs, pad_idx, device=device)
          logits = model(images=images, text_tokens=decoder_inputs, attn_mask=attention_mask, pad_mask=padding_mask)

          B, T, V = logits.shape
          loss = loss_fn(
              logits.reshape(-1, V),
              targets.reshape(-1)
          )
          test_loss += loss.item()

          preds = logits.argmax(dim=-1)
          pad_idx = tokenizer.get_padding_token_id()
          non_pad = targets != pad_idx
          acc = ((preds == targets) & non_pad).sum() / non_pad.sum()
          test_acc += acc.item()
          
          # print produced tokens for debugging
          decoded_preds = [tokenizer.decode(preds[i].cpu().numpy()) for i in range(preds.size(0))]
          decoded_targets = [tokenizer.decode(targets[i].cpu().numpy()) for i in range(targets.size(0))]
          for dp, dt in zip(decoded_preds, decoded_targets):
              print(f"Predicted: {dp} | Target: {dt}")

    # Adjust metrics to get average loss and accuracy per batch 
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

def train(model: torch.nn.Module, 
          train_dataloader: torch.utils.data.DataLoader, 
          test_dataloader: torch.utils.data.DataLoader, 
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    """Trains and tests a PyTorch model.

    Passes a target PyTorch models through train_step() and test_step()
    functions for a number of epochs, training and testing the model
    in the same epoch loop.

    Calculates, prints and stores evaluation metrics throughout.

    Args:
    model: A PyTorch model to be trained and tested.
    train_dataloader: A DataLoader instance for the model to be trained on.
    test_dataloader: A DataLoader instance for the model to be tested on.
    optimizer: A PyTorch optimizer to help minimize the loss function.
    loss_fn: A PyTorch loss function to calculate loss on both datasets.
    epochs: An integer indicating how many epochs to train for.
    device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
    A dictionary of training and testing loss as well as training and
    testing accuracy metrics. Each metric has a value in a list for 
    each epoch.
    In the form: {train_loss: [...],
              train_acc: [...],
              test_loss: [...],
              test_acc: [...]} 
    For example if training for epochs=2: 
             {train_loss: [2.0616, 1.0537],
              train_acc: [0.3945, 0.3945],
              test_loss: [1.2641, 1.5706],
              test_acc: [0.3400, 0.2973]} 
    """
    # Create empty results dictionary
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []
    }
    
    # Make sure model on target device
    model.to(device)

    # Loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                          dataloader=train_dataloader,
                                          loss_fn=loss_fn,
                                          optimizer=optimizer,
                                          device=device)
        test_loss, test_acc = test_step(model=model,
          dataloader=test_dataloader,
          loss_fn=loss_fn,
          device=device)

        # Print out what's happening
        print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f}"
        )

        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # Return the filled results at the end of the epochs
    return results

In [None]:
optimizer = torch.optim.Adam(params=transformer.parameters(),
                             lr=config.LR,
                             betas=(0.9, 0.999), # default values but also mentioned in ViT paper section 4.1 (Training & Fine-tuning)
                             weight_decay=config.WEIGHT_DECAY)

# Setup the loss function for multi-class classification
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=config.LABEL_SMOOTHING) # ignore the padding token for loss calculation

x = enumerate(train_dataloader)
print(next(x))
# Train the model and save the training results to a dictionary
results = train(model=transformer,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=config.NUM_EPOCHS,
                       device=device)

(0, {'pixel_values': tensor([[[[0.7098, 0.7137, 0.7294,  ..., 0.8510, 0.8588, 0.8549],
          [0.7098, 0.7137, 0.7294,  ..., 0.8549, 0.8588, 0.8549],
          [0.7059, 0.7098, 0.7255,  ..., 0.8549, 0.8588, 0.8549],
          ...,
          [0.1529, 0.0549, 0.0039,  ..., 0.6510, 0.6510, 0.6353],
          [0.1882, 0.0706, 0.0078,  ..., 0.6471, 0.6549, 0.6471],
          [0.2157, 0.0902, 0.0118,  ..., 0.6431, 0.6431, 0.6510]],

         [[0.6941, 0.6980, 0.7216,  ..., 0.8667, 0.8745, 0.8706],
          [0.6941, 0.6980, 0.7216,  ..., 0.8706, 0.8745, 0.8706],
          [0.6902, 0.6941, 0.7176,  ..., 0.8706, 0.8745, 0.8706],
          ...,
          [0.1725, 0.0510, 0.0078,  ..., 0.5059, 0.5059, 0.4902],
          [0.2000, 0.0784, 0.0078,  ..., 0.4980, 0.5059, 0.4980],
          [0.2235, 0.1020, 0.0157,  ..., 0.4941, 0.4941, 0.5020]],

         [[0.6471, 0.6510, 0.6706,  ..., 0.8706, 0.8784, 0.8745],
          [0.6471, 0.6510, 0.6706,  ..., 0.8745, 0.8784, 0.8745],
          [0.6431, 0.

  0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
# export training results to results directory
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)
results_path = os.path.join(results_dir, 'training_results.json')
with open(results_path, 'w') as f:
    json.dump(results, f)

time_string = time.strftime("%Y%m%d-%H%M%S")
# export model weights
model_path = os.path.join(results_dir, f'cptr_model_{time_string}.pth')
torch.save(transformer.state_dict(), model_path)

## Test caption generation

In [None]:
@torch.no_grad()
def generate_caption(model, image, tokenizer, max_length=config.MAX_TEXT_SEQUENCE_LENGTH, device='cuda'):
    model.eval()
    with torch.inference_mode():
        # 1. Encode the image once (K, V)
        # image shape: [1, 3, H, W]
        img_features = model.forward_images(image.to(device))

        # 2. Start with the [BOS] token
    # current_tokens shape: [1, 1]
    bos_idx = tokenizer.token_to_id(config.SpecialTokens.BOS)
    eos_idx = tokenizer.token_to_id(config.SpecialTokens.EOS)
    current_tokens = torch.tensor([[bos_idx]], device=device)
    
    generated_ids = [bos_idx]

    with torch.inference_mode():
        for i in range(max_length):
            # 3. Generate the causal mask for the current sequence length
            # Standard triangular mask: (L, L)
            sz = current_tokens.size(1)
            # mask = get_attention_mask(decoder_inputs=current_tokens, pad_idx=pad_idx, seq_len=sz, device=device)
            attn_mask = get_causal_mask(sz, device=device)
            padding_mask = get_padding_mask(current_tokens, pad_idx, device=device)

            # 4. Get predictions (Q)
            # We only care about the very last token predicted
            print('Current tokens:', current_tokens)
            decoder_output = model.forward_text(current_tokens, img_features, attn_mask=attn_mask, pad_mask=padding_mask)
            token_logits = model.linear(decoder_output)  # shape: [1, T, V]
            print('Token logits shape: ', token_logits.shape)
            next_token_logits = token_logits[:, -1, :]
            # next_token_logits = model.linear(logits[:, -1, :])
            # print('Token Logits shape:', token_logits.shape)
            preds = next_token_logits.argmax(dim=-1).item()
            print('Predicted token IDs:', preds)
            decoded_preds = tokenizer.decode([preds])
            print(f'Predicted: {decoded_preds}')
            next_token = preds

            # 5. Greedy selection: take the most likely token
            # next_token = torch.argmax(next_token_logits, dim=-1).item()
            
            # print 10 most likely tokens at each step
            topk = 10
            topk_probs, topk_indices = torch.topk(torch.softmax(next_token_logits, dim=-1), k=topk, dim=-1)
            print(f'Step {i+1}:')
            for rank in range(topk):
                token_id = topk_indices[0, rank].item()
                token_str = tokenizer.decode([token_id])
                prob = topk_probs[0, rank].item()
                print(f'  Rank {rank+1}: Token ID {token_id} ("{token_str}") with probability {prob:.4f}')
            
            generated_ids.append(next_token)

            # 6. Stop if [EOS] is reached
            if next_token == eos_idx:
                break

            # 7. Append and continue
            # debug: append random token
            next_token = random.randint(0, 1002)
            current_tokens = torch.tensor([[next_token]], device=device) #torch.cat([current_tokens, torch.tensor([[next_token]], device=device)], dim=1)

    # 8. Convert IDs back to words
    caption = tokenizer.decode(generated_ids)
    return caption

In [None]:
batch = next(iter(test_dataloader))
image = batch["pixel_values"][0]

generate_caption(transformer, image.unsqueeze(0), tokenizer, max_length=L, device=device)