# Inference instructions

1. Download the pretrained model checkpoints and their associated files from [Checkpoints](https://drive.google.com/drive/folders/1VgPAr7kvh7lHdAZ1329yMct4MIAR355f?usp=sharing).

2. Create a `checkpoints/` directory and move the downloaded directories into it. Your final structure should look like this:

   ```
   checkpoints/
   |-- covctr/
   |   |-- model_snapshot.pth
   |   |-- id2token.json
   |   |-- token2id.json
   |-- mmretinal/
   |   |-- model_snapshot.pth
   |   |-- id2token.json
   |   |-- token2id.json
   |-- pgross/
   |   |-- model_snapshot.pth
   |   |-- id2token.json
   |   |-- token2id.json
   |-- roco/
       |-- model_snapshot.pth
       |-- id2token.json
       |-- token2id.json
   ```

3. Update the `CFG` dataclass below with the values for the dataset you are running inference on (the per-dataset values are given in `scripts/`). This notebook shows an example configuration for COVCTR.
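
As a quick sanity check on the layout from step 2, a small script along the following lines could confirm that every expected file is in place before running inference. This is a minimal sketch: the function name `find_missing_checkpoint_files` and the two constants are illustrative and not part of this repository; the directory and file names are taken from the tree above.

```python
from pathlib import Path

# Files each dataset directory is expected to contain (from the tree in step 2).
EXPECTED_FILES = ("model_snapshot.pth", "id2token.json", "token2id.json")
DATASETS = ("covctr", "mmretinal", "pgross", "roco")


def find_missing_checkpoint_files(root="checkpoints", datasets=DATASETS):
    """Return a sorted list of expected checkpoint paths that do not exist."""
    root = Path(root)
    missing = []
    for name in datasets:
        for filename in EXPECTED_FILES:
            path = root / name / filename
            if not path.is_file():
                missing.append(str(path))
    return sorted(missing)


if __name__ == "__main__":
    missing = find_missing_checkpoint_files()
    if missing:
        print("Missing files:")
        for path in missing:
            print(f"  {path}")
    else:
        print("All checkpoint files found.")
```

If you only downloaded one dataset, pass it explicitly, for example `find_missing_checkpoint_files(datasets=("covctr",))`.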

# Inference code

In [19]:
from dataclasses import dataclass
import torch
from architecture.models import CaptionModel
from utils.tokenizers import CustomTokenizer
from utils.dataloaders import CaptionModelDataLoaders

In [8]:
@dataclass
class CFG:
    dataset_name: str = "covctr"
    df_train_filepath: str = f"data/{dataset_name}/train.csv"
    df_val_filepath: str = f"data/{dataset_name}/val.csv"
    df_test_filepath: str = f"data/{dataset_name}/test.csv"

    token2id_filepath: str = f"checkpoints/{dataset_name}/token2id.json"
    id2token_filepath: str = f"checkpoints/{dataset_name}/id2token.json"
    min_frequency: int = 3

    d_v: int = 2048
    num_heads: int = 8
    num_layers: int = 2
    d_model: int = 512
    d_latent: int = 768
    qk_nope_dim: int = 48
    qk_rope_dim: int = 48
    d_ff: int = 2048
    act_fn: str = "silu"
    attention_dropout: float = 0.12
    dropout: float = 0.1
    num_experts: int = 8
    k: int = 2
    text_seq_len: int = 80
    beam_width: int = 3
    batch_size: int = 1

    model_snapshot_filepath: str = f"checkpoints/{dataset_name}/model_snapshot.pth"
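
One detail of the `CFG` class above is easy to miss: the path defaults are f-strings evaluated once, when the class body runs, so `CFG(dataset_name="roco")` would still point at the `covctr` files. Editing `dataset_name` in the class body, as this notebook does, avoids the problem. If you would rather switch datasets at construction time, one possible approach is to derive the paths in `__post_init__`. The sketch below uses a hypothetical `PathsCFG` class that contains only some of the path fields; it is not the repository's configuration class.

```python
from dataclasses import dataclass, field


@dataclass
class PathsCFG:
    """Hypothetical, reduced variant of CFG with paths derived per instance."""

    dataset_name: str = "covctr"
    # init=False: these are computed in __post_init__, not passed by the caller.
    df_test_filepath: str = field(init=False)
    token2id_filepath: str = field(init=False)
    id2token_filepath: str = field(init=False)
    model_snapshot_filepath: str = field(init=False)

    def __post_init__(self):
        # Rebuild every path from the dataset name given to this instance.
        ckpt_dir = f"checkpoints/{self.dataset_name}"
        self.df_test_filepath = f"data/{self.dataset_name}/test.csv"
        self.token2id_filepath = f"{ckpt_dir}/token2id.json"
        self.id2token_filepath = f"{ckpt_dir}/id2token.json"
        self.model_snapshot_filepath = f"{ckpt_dir}/model_snapshot.pth"


if __name__ == "__main__":
    print(PathsCFG(dataset_name="roco").model_snapshot_filepath)
```

The model hyperparameters still differ per dataset, so those values would need to be set from `scripts/` either way.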

In [16]:
cfg = CFG()
tokenizer = CustomTokenizer(cfg)
test_loader = CaptionModelDataLoaders.get_test_dataloader(cfg, tokenizer)
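# Use the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")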
model = CaptionModel(cfg, tokenizer.get_vocab_size())
snapshot = torch.load(cfg.model_snapshot_filepath, map_location=device)
model.load_state_dict(snapshot["model"])
model.to(device)
model.eval()

sos_id = tokenizer.get_id_by_token("<sos>")
eos_id = tokenizer.get_id_by_token("<eos>")
pad_id = tokenizer.get_id_by_token("<pad>")
generated_report = ""
actual_report = ""

batch_iter = iter(test_loader)

Vocabulary loaded with total size of 229


In [18]:
with torch.no_grad():
    batch = next(batch_iter)
    image = batch["image"].to(device)  # [b, c, h, w]
    label_ids = batch["label_ids"].to(device)  # [b, text_seq_len]
    gen_report_ids = model.beam_search(
        image,
        sos_id,
        eos_id,
        pad_id,
        cfg.beam_width,
        cfg.text_seq_len,
    )  # [b, text_seq_len]
    # Decode the first sample of the batch; no detach is needed under no_grad.
    generated_report = tokenizer.decode_by_ids(gen_report_ids[0].tolist())
    actual_report = tokenizer.decode_by_ids(label_ids[0].tolist())

    print(f"Generated report: {generated_report}")
    print(f"Actual report: {actual_report}")

Generated report: the thorax was symmetrical , the mediastinal heart shadow was centered , no enlarged lymph nodes were seen in the mediastinum , the texture of both lungs was enhanced , and a ground glass shadow was seen under the pleura of the lower lobe of the right lung with blurred margins , the bronchi of the lobe were clear , and no abnormal density shadow was seen in the bilateral thoracic cavities .
Actual report: the thorax was symmetrical , the mediastinal heart shadow was centered , no enlarged lymph nodes were seen in the mediastinum , the texture of both lungs was enhanced , a ground glass shadow was seen in the lower lobe of the right lung with blurred margins , and no abnormal density shadow was seen in the bilateral thoracic cavities .
