In [2]:
import os, requests, torch, transformers, warnings
from PIL import Image
from torchvision import transforms
from io import BytesIO


# Hugging Face Hub repository id of the checkpoint used throughout the notebook.
ckpt_name = 'aehrc/mimic-cxr-report-gen-single'

# Directory where downloaded checkpoint files are cached. A non-empty
# HUGGINGFACE_HUB_CACHE exported before the kernel started takes precedence,
# so the notebook is not tied to one user's scratch space; otherwise the
# original cluster path is used unchanged.
cache_dir = (
    os.environ.get('HUGGINGFACE_HUB_CACHE')
    or '/scratch/pawsey0864/anicolson/checkpoints'
)

# NOTE(review): `transformers` is already imported when these are assigned, so
# the libraries may have read their cache settings before this point — confirm.
# The loaders below do not depend on it: each one receives `cache_dir=` directly.
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = cache_dir

# `trust_remote_code=True` lets the checkpoint repository supply its own model
# code. The loader warns that passing an explicit `revision` is encouraged in
# that case. TODO: pin `revision=` to a reviewed commit once one is chosen.
encoder_decoder = transformers.AutoModel.from_pretrained(
    ckpt_name,
    trust_remote_code=True,
    cache_dir=cache_dir,
)
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(
    ckpt_name,
    cache_dir=cache_dir,
)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(
    ckpt_name,
    cache_dir=cache_dir,
)

# Target edge length taken from the checkpoint's own preprocessing config,
# read once instead of three times.
crop_size = image_processor.size['shortest_edge']

# Inference-time preprocessing: resize by `crop_size`, centre-crop to a
# `crop_size` x `crop_size` square, convert to a tensor, then normalise with
# the mean/std stored in the image processor.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=crop_size),
        transforms.CenterCrop(size=[crop_size, crop_size]),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        ),
    ]
)

  from .autonotebook import tqdm as notebook_tqdm
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


In [3]:
def load_rgb_image(url, timeout=30):
    """Download the image at `url` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image file.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status. Raised
            here so an error page is never handed to the image decoder.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


# Two example chest radiographs fetched from public teaching resources.
image_urls = [
    'https://www.stritch.luc.edu/lumen/meded/radio/curriculum/IPM/PCM/86a_labelled.jpg',
    'https://prod-images-static.radiopaedia.org/images/566180/d527ff6fc1482161c9225345c4ab42_big_gallery.jpg',
]

# Apply the inference preprocessing to each downloaded image.
image_a, image_b = (test_transforms(load_rgb_image(url)) for url in image_urls)

# Stack the two preprocessed images along a new leading dimension to form the batch.
images = torch.stack([image_a, image_b], dim=0)
images.shape


torch.Size([2, 3, 384, 384])

In [4]:
# Arguments for generation: beam search with 4 beams, sequences capped at 256
# tokens, and the special-token ids taken from the tokenizer.
generation_kwargs = dict(
    pixel_values=images,
    special_token_ids=[tokenizer.sep_token_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256,
    num_beams=4,
)
outputs = encoder_decoder.generate(**generation_kwargs)

# Token ids handed to the model's section splitter.
# NOTE(review): presumably the separator ends the first section and EOS ends
# the second — confirm against the checkpoint's remote code.
section_token_ids = [tokenizer.sep_token_id, tokenizer.eos_token_id]

# Decode the generated sequences into two results, one entry per input image.
findings, impression = encoder_decoder.split_and_decode_sections(
    outputs.sequences,
    section_token_ids,
    tokenizer,
)

In [5]:
# Print each image's generated report: the findings entry followed by the
# matching impression entry, with a blank line after each pair.
for findings_text, impression_text in zip(findings, impression):
    report = f'Findings: {findings_text}\nImpression: {impression_text}\n'
    print(report)

Findings: There is a moderate to large left pleural effusion with adjacent atelectasis. The right lung is clear. No pneumothorax identified. The size and appearance of the cardiomediastinal silhouette is unchanged.
Impression: Moderate to large left pleural effusion with adjacent atelectasis.

Findings: The patient has had prior median sternotomy with aortic valve replacement. Sternotomy wires are intact and aligned. Sequential radiographs show advancement of a feeding tube initially positioned in the mid esophagus, through the gastroesophageal junction, and into the stomach. Moderate cardiomegaly despite the projection is unchanged. Mediastinal contours are stable. There is no pneumothorax. Mild pulmonary edema is unchanged.
Impression: Feeding tube terminates in stomach. Stable mild pulmonary edema. Stable moderate cardiomegaly.



In [7]:
# Display the device reported by the loaded model; no cell visible here moves
# the model or the image batch to another device.
encoder_decoder.device

device(type='cpu')