In [2]:
import torch
import config as cfg
from model.CPTR_upd import CPTR
from tokenizer.tokenizer import TokenizerHF, ByteLevelBPE
from torchvista import draw_graph

AttributeError: module 'torch.utils._pytree' has no attribute 'register_pytree_node'

In [None]:
# Load the run configuration from "config.json" via the project's cfg helper.
# The path is relative; presumably it is resolved against the notebook's
# working directory — TODO confirm. Later cells index the result by string key
# (e.g. config["TOKENIZER_TYPE"]).
config = cfg.import_config("config.json")

In [None]:
# Special tokens passed to the byte-level BPE tokenizer; the HF branch below
# does not receive this list.
special_tokens = [cfg.SpecialTokens.PAD, cfg.SpecialTokens.BOS, cfg.SpecialTokens.EOS]

# Pick the tokenizer implementation named by the config.
if config["TOKENIZER_TYPE"] == cfg.TokenizerType.HF:
    tokenizer = TokenizerHF()
else:
    # The BPE tokenizer is constructed first, then loads its data from the
    # configured folder and filename prefix.
    tokenizer = ByteLevelBPE(special_tokens=special_tokens)
    tokenizer.load(folder=config["TOKENIZER_DATA_PATH"],
                   filename_prefix=config["TOKENIZER_FILENAME_PREFIX"])

# Tokenizer-derived values that the model constructor and the dummy text
# input read in later cells.
pad_idx, vocab_size = tokenizer.get_padding_token_id(), tokenizer.get_vocab_size()


In [None]:
# Constructor arguments for CPTR, collected in one mapping so they can be
# inspected before the model is built. Keys are the constructor's keyword
# names; values come from the loaded config, except vocab_size / pad_idx
# (taken from the tokenizer) and verbose (fixed to False).
cptr_kwargs = {
    "num_patches": config["NUM_PATCHES"],
    "encoder_arch": config["ENCODER_ARCH"],
    "encoding_strategy": config["VIT_ENCODING_STRATEGY"],
    "use_embedding_projection": config["USE_PROJECTION_LAYER"],
    "img_emb_use_conv": config["USE_CONV_IMG_EMBEDDING"],
    "img_emb_dim": config["IMG_EMBEDDING_DIM"],
    "patch_size": config["PATCH_SIZE"],
    "text_emb_dim": config["TEXT_EMBEDDING_DIM"],
    "d_model": config["EMBEDDING_DIM"],
    "max_text_seq_len": config["MAX_TEXT_SEQUENCE_LENGTH"],
    "vocab_size": vocab_size,
    "pad_idx": pad_idx,
    "channels": config["NUM_INPUT_CHANNELS"],
    "num_encoder_blocks": config["ENCODER_NUM_BLOCKS"],
    "num_encoder_heads": config["ENCODER_NUM_HEADS"],
    "encoder_hidden_dim": config["ENCODER_HIDDEN_DIM"],
    "encoder_dropout_prob": config["ENCODER_DROPOUT_PROB"],
    "num_decoder_blocks": config["DECODER_NUM_BLOCKS"],
    "num_decoder_heads": config["DECODER_NUM_HEADS"],
    "decoder_hidden_dim": config["DECODER_HIDDEN_DIM"],
    "decoder_dropout_prob": config["DECODER_DROPOUT_PROB"],
    "bias": config["USE_BIAS"],
    "use_weight_tying": config["USE_WEIGHT_TYING"],
    "sublayer_dropout": config["SUBLAYER_DROPOUT"],
    "verbose": False,
}
model = CPTR(**cptr_kwargs)

# Kept as the cell's final expression, so its return value is rendered as the
# cell output. Presumably CPTR is a torch.nn.Module and this switches it to
# evaluation mode — TODO confirm.
model.eval()

In [None]:
# Random placeholder inputs for a single-sample forward pass; the values
# themselves carry no meaning.

# Image batch of one sample. The channel count is read from the same config
# key that is passed to CPTR(channels=...), so the input still matches the
# model when NUM_INPUT_CHANNELS is not 3 (it was previously hard-coded to 3).
# Assumes the model expects (batch, channels, height, width) — TODO confirm.
dummy_img = torch.randn(
    1,
    config["NUM_INPUT_CHANNELS"],
    config["IMG_HEIGHT"],
    config["IMG_WIDTH"]
)

# Token-id batch of one sequence: integers drawn from [0, vocab_size), with
# 10 tokens unless the configured maximum sequence length is shorter.
dummy_txt = torch.randint(
    0,
    vocab_size,
    (1, min(10, config["MAX_TEXT_SEQUENCE_LENGTH"]))
)

In [None]:
# Draw the model's graph using the dummy image/text pair as example inputs,
# on CPU, under the name "CPTR_architecture". The call is the cell's last
# expression, so whatever draw_graph returns is rendered as the cell output.
# NOTE(review): the input_data / device / graph_name keywords match
# torchview's draw_graph; confirm that torchvista (the package this notebook
# imports draw_graph from) exposes a function with the same signature.
draw_graph(
    model,
    input_data=(dummy_img, dummy_txt),
    device="cpu",
    graph_name="CPTR_architecture"
)