In [1]:
import os, torch, transformers
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from io import BytesIO
from torchvision.utils import make_grid

In [2]:
# Hugging Face Hub checkpoint that supplies the model weights, the image
# pre-processing configuration and the tokenizer files used below.
ckpt_name = 'aehrc/medicap'

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): trust_remote_code=True allows modelling code shipped with the
# checkpoint to be executed locally -- confirm the repository is trusted.
encoder_decoder = transformers.AutoModel.from_pretrained(
    ckpt_name,
    trust_remote_code=True,
)
encoder_decoder = encoder_decoder.to(device)
# Evaluation mode: this notebook only runs inference.
encoder_decoder.eval()

image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)

# Rebuild the pre-processing with torchvision from the processor's settings:
# resize to the configured shortest edge, centre-crop to a square of that
# side, convert to a tensor, then normalise with the processor's mean/std.
shortest_edge = image_processor.size['shortest_edge']
test_transforms = transforms.Compose([
    transforms.Resize(size=shortest_edge),
    transforms.CenterCrop(size=[shortest_edge, shortest_edge]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    ),
])

# Loading through PreTrainedTokenizerFast makes transformers warn that the
# checkpoint declares 'GPT2Tokenizer' as its tokenizer class (see the output
# of this cell).
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(ckpt_name)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [3]:
# Paths of the images to caption; add more entries to caption a batch.
# Kept under its own name so that `images` only ever refers to the stacked
# tensor consumed by the generation cell.
image_paths = [
    '../Thorax-231030-2310300175_Series_1001_0000-64.jpg'
]

image_tensors = []
for path in image_paths:
    # The context manager closes the underlying file handle as soon as the
    # pixels have been read; convert('RGB') yields an in-memory image that
    # no longer depends on the open file.
    with Image.open(path) as img:
        image_tensors.append(test_transforms(img.convert('RGB')))

# Stack the per-image tensors along a new leading (batch) dimension.
images = torch.stack(image_tensors, dim=0)
images.shape

torch.Size([1, 3, 384, 384])

In [4]:
# Decoding settings passed to generate(): the special-token ids come from the
# tokenizer, and return_dict_in_generate=True is what makes the
# `outputs.sequences` attribute below available.
generation_options = dict(
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256,
    num_beams=4,
)

# Move the image batch to the same device as the model before generating.
pixel_values = images.to(device)
outputs = encoder_decoder.generate(pixel_values=pixel_values, **generation_options)
outputs.sequences

tensor([[50257, 45170,  1395,    12,  2433,  4478,   257,  1364,  3339,  1523,
           914,  4241,    13, 50256]])

In [10]:
# Turn each generated id sequence back into text, dropping special tokens,
# and collect one caption per input image.
captions = []
for token_ids in outputs.sequences:
    captions.append(tokenizer.decode(token_ids, skip_special_tokens=True))
print(captions)

['Chest X-ray showing a left pleural effusion.']
