In [11]:
import os, requests, torch, transformers, warnings
from PIL import Image
from torchvision import transforms
from io import BytesIO


# Hugging Face Hub identifier of the checkpoint loaded in this cell.
ckpt_name = 'aehrc/mimic-cxr-report-gen-single'

# Download/cache location passed to every from_pretrained call below.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
cache_dir = '/scratch/pawsey0864/anicolson/checkpoints'

# Export the same location through both cache environment variables.
# NOTE(review): transformers is already imported by the time these are set, so
# they may be read too late to change its defaults -- confirm. The explicit
# cache_dir argument on each call below does not depend on them.
for cache_env_var in ('TRANSFORMERS_CACHE', 'HUGGINGFACE_HUB_CACHE'):
    os.environ[cache_env_var] = cache_dir

shared_kwargs = {'cache_dir': cache_dir}

# trust_remote_code=True runs modelling code shipped with the checkpoint.
# NOTE(review): no `revision` is pinned; the loader's own log output recommends
# passing one explicitly when custom code is involved.
encoder_decoder = transformers.AutoModel.from_pretrained(
    ckpt_name,
    trust_remote_code=True,
    **shared_kwargs,
)
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(ckpt_name, **shared_kwargs)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name, **shared_kwargs)

# Evaluation-time preprocessing, driven entirely by the checkpoint's processor
# config: resize to the configured shortest edge, centre-crop to a square of
# that same side, convert to a tensor, then normalise with the configured
# mean/std.
shortest_edge = image_processor.size['shortest_edge']
preprocessing_steps = [
    transforms.Resize(size=shortest_edge),
    transforms.CenterCrop(size=[shortest_edge, shortest_edge]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    ),
]
test_transforms = transforms.Compose(preprocessing_steps)

Downloading (…)lve/main/config.json: 100%|██████████| 78.6k/78.6k [00:00<00:00, 35.6MB/s]
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Downloading (…)/modelling_single.py: 100%|██████████| 16.4k/16.4k [00:00<00:00, 47.7MB/s]
Downloading pytorch_model.bin: 100%|██████████| 450M/450M [00:05<00:00, 84.4MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 90.0/90.0 [00:00<00:00, 961kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 433/433 [00:00<00:00, 4.99MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 1.33M/1.33M [00:01<00:00, 1.12MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 290/290 [00:00<00:00, 3.13MB/s]
Downloading (…)rocessor_config.json: 100%|██████████| 410/410 [00:00<00:00, 1.29MB/s]


In [10]:
def fetch_rgb_image(url, timeout=30):
    """Download the image at `url` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image file.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests.get can block indefinitely on a stalled server.

    Returns:
        A PIL image converted to 'RGB' mode.

    Raises:
        requests.exceptions.RequestException: On connection problems, a
            timeout, or a non-success HTTP status (via raise_for_status).
    """
    response = requests.get(url, timeout=timeout)
    # Fail here with a clear HTTP error instead of letting Image.open choke on
    # the bytes of an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


# Example images fetched over the network.
url_a = 'https://www.stritch.luc.edu/lumen/meded/radio/curriculum/IPM/PCM/86a_labelled.jpg'
url_b = 'https://prod-images-static.radiopaedia.org/images/566180/d527ff6fc1482161c9225345c4ab42_big_gallery.jpg'

# Apply the preprocessing pipeline defined in the earlier cell to each image.
image_a = test_transforms(fetch_rgb_image(url_a))
image_b = test_transforms(fetch_rgb_image(url_b))

# Stack the two preprocessed tensors along a new leading dimension to form the
# batch passed to the model, then display its shape as the cell output.
images = torch.stack([image_a, image_b], dim=0)
images.shape


In [9]:
# Arguments for generation: the image batch, the tokenizer's special-token ids,
# and the decoding settings (4 beams, max_length of 256, cache enabled).
# NOTE(review): `special_token_ids` is consumed by the checkpoint's custom
# generate code -- presumably the section separator; confirm against that code.
generation_kwargs = dict(
    pixel_values=images,
    special_token_ids=[tokenizer.sep_token_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256,
    num_beams=4,
)
outputs = encoder_decoder.generate(**generation_kwargs)

# Token ids handed to the model's section splitter, which returns the findings
# and impression parts unpacked below.
# NOTE(review): presumably these ids mark where each section ends -- verify in
# the checkpoint's modelling code.
section_delimiter_ids = [tokenizer.sep_token_id, tokenizer.eos_token_id]
findings, impression = encoder_decoder.split_and_decode_sections(
    outputs.sequences,
    section_delimiter_ids,
    tokenizer,
)

In [4]:
# Pair the i-th findings entry with the i-th impression entry and print each
# pair as one labelled block.
for findings_text, impression_text in zip(findings, impression):
    print(f'Findings: {findings_text}\nImpression: {impression_text}\n')