In [15]:
import torch
import os, requests, torch, transformers, warnings
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from io import BytesIO
from torchvision.utils import make_grid
from monai import transforms as monai_transforms
import numpy as np

In [16]:
# Choose the compute device in order of preference: Apple MPS, then CUDA, then CPU.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

In [17]:
# Checkpoint identifier shared by the model, tokenizer and image-processor config.
ckpt_name = 'aehrc/cxrmate-single-tf'

# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint to run
# locally -- confirm the checkpoint source is trusted before executing this cell.
encoder_decoder = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True)
encoder_decoder = encoder_decoder.to(device)
encoder_decoder.eval()  # inference mode
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(ckpt_name)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)

# Side length used for both the resize and the square centre crop.
target_edge = image_processor.size['shortest_edge']

# Preprocessing for PIL images: resize, centre-crop to a square, convert to a tensor and
# normalise with the statistics stored in the image-processor config.
test_transforms = transforms.Compose([
    transforms.Resize(size=target_edge),
    transforms.CenterCrop(size=[target_edge, target_edge]),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
])



Some weights of SingleCXREncoderDecoderModel were not initialized from the model checkpoint at aehrc/cxrmate-single-tf and are newly initialized: ['decoder.bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ConvNextImageProcessor instead.


In [18]:
# Read the PNG radiograph as three-channel RGB, preprocess it and add a batch dimension.
pil_image = Image.open('CXR1_1_IM-0001-3001.png').convert('RGB')
image = torch.stack([test_transforms(pil_image)], dim=0)
# Sanity check: tensor value range and shape, plus the per-band extrema of the source image.
image.min(), image.max(), image.shape, pil_image.getextrema()

(tensor(-2.0665),
 tensor(2.1694),
 torch.Size([1, 3, 384, 384]),
 ((1, 235), (1, 235), (1, 235)))

In [19]:
# Generation settings: 4-beam search capped at 256 tokens, with the special-token ids
# taken from the tokenizer.
generation_kwargs = dict(
    pixel_values=image.to(device),
    special_token_ids=[tokenizer.sep_token_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256,
    num_beams=4,
)
outputs = encoder_decoder.generate(**generation_kwargs)
outputs.sequences



tensor([[  1, 966, 306, 120, 237,  23, 139, 356, 148, 386, 349, 150, 237,  23,
         139, 230, 705, 120, 237,  23, 661, 150, 303,  23, 198, 183, 171, 214,
         211, 120, 269,  23, 213, 150, 163, 271, 542, 666,  23,   3, 159, 271,
         397, 578,  23,   2]], device='mps:0')

In [20]:
# Split the generated token ids at the separator / end-of-sequence tokens and decode the
# findings and impression sections to text.
section_separator_ids = [tokenizer.sep_token_id, tokenizer.eos_token_id]
findings, impression = encoder_decoder.split_and_decode_sections(
    outputs.sequences, section_separator_ids, tokenizer,
)
# Print one findings/impression pair per generated sequence.
for i, j in zip(findings, impression):
    print(f'Findings: {i}\nImpression: {j}\n')

Findings: Heart size is normal. The mediastinal and hilar contours are normal. The pulmonary vasculature is normal. Lungs are clear. No pleural effusion or pneumothorax is seen. There are no acute osseous abnormalities.
Impression: No acute cardiopulmonary abnormality.



In [27]:

# Build the model input directly from the DICOM file (instead of its PNG export) so the
# report generated from each source can be compared.
import pydicom  # only this cell needs pydicom

dataset = pydicom.dcmread('1_IM-0001-3001.dcm')

# When True, round the normalised intensities to 256 levels, mimicking the precision loss
# of an 8-bit image.
quantisation_error = True

# DICOM tags holding the declared intensity range of the image.
SMALLEST_IMAGE_PIXEL_VALUE = (0x28, 0x106)
LARGEST_IMAGE_PIXEL_VALUE = (0x28, 0x107)

# The array is cast to float32 first; torchvision's ToTensor() only rescales uint8 input,
# so the raw stored intensities are preserved here and normalised explicitly below.
image = transforms.ToTensor()(dataset.pixel_array.astype(np.float32))

# Same geometry and normalisation as the PNG pipeline, minus ToTensor() (the input is
# already a tensor).
post_dicom_transforms = transforms.Compose(
    [
        transforms.Resize(size=image_processor.size['shortest_edge'], antialias=True),
        transforms.CenterCrop(size=[
            image_processor.size['shortest_edge'],
            image_processor.size['shortest_edge'],
        ]
        ),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        ),
    ]
)

# Default to the observed pixel range; prefer the range declared in the header if present.
min_intensity, max_intensity = image.min(), image.max()

# See 8.1.1 for more details: https://dicom.nema.org/medical/dicom/current/output/html/part05.html#sect_8.1.1
if (SMALLEST_IMAGE_PIXEL_VALUE in dataset) and (LARGEST_IMAGE_PIXEL_VALUE in dataset):
    min_intensity = dataset[SMALLEST_IMAGE_PIXEL_VALUE].value
    max_intensity = dataset[LARGEST_IMAGE_PIXEL_VALUE].value

min_intensity, max_intensity = float(min_intensity), float(max_intensity)
if max_intensity <= min_intensity:
    # A degenerate range would otherwise silently yield NaN/inf pixels.
    raise ValueError(
        f'Cannot normalise DICOM intensities: invalid range [{min_intensity}, {max_intensity}]'
    )

# Scale to [0, 1]. The clamp guards against header values that understate the true pixel
# range; out-of-range values would otherwise wrap around in the uint8 cast below.
image = ((image - min_intensity) / (max_intensity - min_intensity)).clamp(0.0, 1.0)

# MONOCHROME1 stores the minimum sample value as white; flip it so that polarity matches
# MONOCHROME2 (minimum = black).
# NOTE(review): assumes the model expects MONOCHROME2-style polarity, as in the PNG input
# -- confirm against the training data.
if dataset.get('PhotometricInterpretation') == 'MONOCHROME1':
    image = 1.0 - image

if quantisation_error:
    image = (255 * image).to(torch.uint8).to(torch.float32) / 255

# Replicate a single-channel image across three channels to match the RGB PNG input.
if image.shape[0] == 1:
    image = image.repeat([3, 1, 1])
image = post_dicom_transforms(image)
image = torch.stack([image], dim=0)  # add the batch dimension
image.shape

torch.Size([1, 3, 384, 384])

In [28]:
# Generation settings: 4-beam search capped at 256 tokens, with the special-token ids
# taken from the tokenizer.
generation_kwargs = dict(
    pixel_values=image.to(device),
    special_token_ids=[tokenizer.sep_token_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256,
    num_beams=4,
)
outputs = encoder_decoder.generate(**generation_kwargs)
outputs.sequences

tensor([[   1,  668,  148,  369,  546,  132,  115,  250,  854,   23,  213,  120,
          163,  322,  284,   21,  171,   21,  214,  211,   23,  139,  449,  278,
          120,  237,   23, 1427,  542,  569,  150,  721,   23,  198,  801,  541,
          780,  115,  182,  615,  120,  269,   23,    3,  159,  271,  916,  389,
           23,    2]], device='mps:0')

In [29]:
# Split the generated token ids at the separator / end-of-sequence tokens and decode the
# findings and impression sections to text.
section_separator_ids = [tokenizer.sep_token_id, tokenizer.eos_token_id]
findings, impression = encoder_decoder.split_and_decode_sections(
    outputs.sequences, section_separator_ids, tokenizer,
)
# Print one findings/impression pair per generated sequence.
for i, j in zip(findings, impression):
    print(f'Findings: {i}\nImpression: {j}\n')

Findings: PA and lateral views of the chest provided. There is no focal consolidation, effusion, or pneumothorax. The cardiomediastinal silhouette is normal. Imaged osseous structures are intact. No free air below the right hemidiaphragm is seen.
Impression: No acute intrathoracic process.

