In [100]:
import torch
import torchvista

import importlib

import sys
sys.path.append("..")
import config as cfg

from model.CPTR_upd import CPTR
from tokenizer.tokenizer import TokenizerHF, ByteLevelBPE

from dataset.loader import DatasetLoader

In [101]:
importlib.reload(cfg)

<module 'config' from '/home/nad/studies/Transformer-Image-Captioning-IIW/visualization/../config.py'>

In [102]:
device = "cpu"

model_folder = cfg.CONFIG_ROOT / "results/config_20260123-020022"
config = cfg.import_config(model_folder / 'config.json')
model_path = model_folder / 'cptr_model.pth'

In [103]:
batch_size_train = config["BATCH_SIZE_TRAIN"]
batch_size_test = config["BATCH_SIZE_TEST"]

H = config["IMG_HEIGHT"]
W = config["IMG_WIDTH"]
P = config["PATCH_SIZE"]
D_IMG = config["IMG_EMBEDDING_DIM"]

# The data will get truncated/padded to this length AFTER tokenization
L = config["MAX_TEXT_SEQUENCE_LENGTH"]
D_TEXT = config["TEXT_EMBEDDING_DIM"]
DROPOUT_DEC = config["DECODER_DROPOUT_PROB"]
RANDOM_SEED = config["RANDOM_SEED"]

In [104]:
special_tokens = [
    cfg.SpecialTokens.PAD,
    cfg.SpecialTokens.BOS,
    cfg.SpecialTokens.EOS
]

if config["TOKENIZER_TYPE"] == cfg.TokenizerType.HF:
    tokenizer = TokenizerHF()
else:
    tokenizer = ByteLevelBPE(special_tokens=special_tokens)
    tokenizer.load(
        folder=cfg.TOKENIZER_DATA_PATH,
        filename_prefix=config["TOKENIZER_FILENAME_PREFIX"]
    )

pad_idx = tokenizer.get_padding_token_id()
vocab_size = tokenizer.get_vocab_size()


In [105]:
data_loader = DatasetLoader(dataset_type=config["DATASET"],
                            img_height=H,
                            img_width=W,
                            batch_size_train=batch_size_train, 
                            batch_size_test=batch_size_test,
                            split_ratio=config["SPLIT_RATIO"],
                            shuffle_test=True,
                            seed=RANDOM_SEED)
data_loader.load_data()

test_dataloader = data_loader.get_test_dataloader()

batch = next(iter(test_dataloader))
img_tensor = batch['pixel_values'][0].unsqueeze(0).to(device)

bos_token=tokenizer.get_vocab()[cfg.SpecialTokens.BOS.value]

tokens = torch.tensor(data=[[bos_token]], requires_grad=False).to(device)
attn_mask = torch.triu(torch.ones((1, 1), device=device, requires_grad=False), diagonal=1).bool()


Loading Flickr30k dataset...
DatasetDict({
    test: Dataset({
        features: ['image', 'description'],
        num_rows: 31014
    })
})


In [106]:
model = CPTR(
    num_patches=config["NUM_PATCHES"],
    encoder_arch=config["ENCODER_ARCH"],
    encoding_strategy=config["VIT_ENCODING_STRATEGY"],
    use_embedding_projection=config["USE_PROJECTION_LAYER"],
    img_emb_use_conv=config["USE_CONV_IMG_EMBEDDING"],
    img_emb_dim=config["IMG_EMBEDDING_DIM"],
    patch_size=config["PATCH_SIZE"],
    text_emb_dim=config["TEXT_EMBEDDING_DIM"],
    d_model=config["EMBEDDING_DIM"],
    max_text_seq_len=config["MAX_TEXT_SEQUENCE_LENGTH"],
    vocab_size=vocab_size,
    pad_idx=pad_idx,
    channels=config["NUM_INPUT_CHANNELS"],
    num_encoder_blocks=config["ENCODER_NUM_BLOCKS"],
    num_encoder_heads=config["ENCODER_NUM_HEADS"],
    encoder_hidden_dim=config["ENCODER_HIDDEN_DIM"],
    encoder_dropout_prob=config["ENCODER_DROPOUT_PROB"],
    num_decoder_blocks=config["DECODER_NUM_BLOCKS"],
    num_decoder_heads=config["DECODER_NUM_HEADS"],
    decoder_hidden_dim=config["DECODER_HIDDEN_DIM"],
    decoder_dropout_prob=config["DECODER_DROPOUT_PROB"],
    bias=config["USE_BIAS"],
    use_weight_tying=config["USE_WEIGHT_TYING"],
    sublayer_dropout=config["SUBLAYER_DROPOUT"],
    verbose=False
)

model.eval()

Initialized CNN + CPTR Encoder


CPTR(
  (encoder): CNN_CPTREncoder(
    (patcher): CNNEncoder(
      (backbone): Sequential(
        (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (4): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_

In [107]:
torchvista.trace_model(
    model=model,
    inputs=(img_tensor, tokens, attn_mask),
    export_format='html',
    export_path='architecture_graphs/{}.html'.format(config['ENCODER_ARCH'])
)