In [9]:
import torch  
from torchvision.transforms import v2
import transformers
import datasets
from torch.utils.data import Dataset, DataLoader

# Number of studies drawn per mini-batch by the DataLoader.
mbatch_size = 10

# Tokenizer and model are loaded from the same Hub checkpoint.
# NOTE(review): trust_remote_code=True runs Python code shipped with the
# checkpoint repository -- only enable it for a source you trust.
tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)

# Image preprocessing is driven by the encoder's own configuration, so the
# resize/crop size and the normalisation statistics follow the checkpoint.
encoder_config = model.config.encoder
image_side = encoder_config.image_size
transforms = v2.Compose([
    v2.PILToTensor(),
    v2.Grayscale(num_output_channels=3),   # replicate to 3 channels
    v2.Resize(size=image_side, antialias=True),
    v2.CenterCrop(size=[image_side, image_side]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=encoder_config.image_mean, std=encoder_config.image_std),
])

# Only the 'test' split of the dataset is used below.
dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    """Preprocess and pad the images of a batch of studies.

    Each entry of ``batch['images']`` is an iterable of images belonging to
    one study. Every image is passed through the module-level ``transforms``
    pipeline and the results of a study are stacked into one tensor. Studies
    with fewer images are then zero-padded along the image axis so the whole
    batch becomes a single tensor with the batch axis first.

    Args:
        batch: Mapping with an ``'images'`` key (modified in place).
            Presumably the images are PIL images, as the pipeline starts
            with ``PILToTensor`` -- TODO confirm against the dataset.

    Returns:
        The same ``batch`` object, with ``batch['images']`` replaced by the
        padded tensor.
    """
    stacked_per_study = []
    for study_images in batch['images']:
        processed = [transforms(image) for image in study_images]
        stacked_per_study.append(torch.stack(processed))

    # Pad along the first (image-count) axis so studies of different sizes
    # can share one tensor.
    batch['images'] = torch.nn.utils.rnn.pad_sequence(
        stacked_per_study,
        batch_first=True,
        padding_value=0.0,
    )
    return batch

# Apply the preprocessing lazily, whenever examples are fetched.
dataset = dataset.with_transform(transform_batch)

# shuffle=False: this is inference on a test split, and only the first batch
# is inspected, so a fixed order keeps the result reproducible across runs.
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=False)
batch = next(iter(dataloader))

# Token ids that generation is forbidden to emit.
bad_words_ids = [[tokenizer.convert_tokens_to_ids(token)] for token in ('[NF]', '[NI]')]

# NOTE(review): in the recorded run this call failed with
#   TypeError: MultiUniFormerWithProjectionHead.forward() got an unexpected
#   keyword argument 'output_attentions'
# Presumably the installed transformers release forwards `output_attentions`
# to the encoder while the checkpoint's remote encoder code does not accept
# it -- confirm, and consider pinning transformers to the release the
# checkpoint was published with.
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=bad_words_ids,
)

# Split the generated sequences into the two report sections and decode them.
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)

TypeError: MultiUniFormerWithProjectionHead.forward() got an unexpected keyword argument 'output_attentions'

In [10]:
# Display the module tree of the loaded model.
model

CXRRGModel(
  (encoder): MultiUniFormerWithProjectionHead(
    (uniformer): UniFormer(
      (patch_embed1): PatchEmbed(
        (proj): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
      (patch_embed2): PatchEmbed(
        (proj): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (patch_embed3): PatchEmbed(
        (proj): Conv2d(128, 320, kernel_size=(2, 2), stride=(2, 2))
        (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
      (patch_embed4): PatchEmbed(
        (proj): Conv2d(320, 512, kernel_size=(2, 2), stride=(2, 2))
        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks1): ModuleList(
        (0): CBlock(
          (pos_embed): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)


In [13]:
# Inspect the backbone's bound `forward` method without calling it: the
# recorded output of this cell is the method's repr, which is what evaluating
# the bare attribute produces. A zero-argument call would presumably fail,
# since the backbone is invoked with an input tensor elsewhere.
model.encoder.uniformer.forward

<bound method UniFormer.forward of UniFormer(
  (patch_embed1): PatchEmbed(
    (proj): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (patch_embed2): PatchEmbed(
    (proj): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (patch_embed3): PatchEmbed(
    (proj): Conv2d(128, 320, kernel_size=(2, 2), stride=(2, 2))
    (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  )
  (patch_embed4): PatchEmbed(
    (proj): Conv2d(320, 512, kernel_size=(2, 2), stride=(2, 2))
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks1): ModuleList(
    (0): CBlock(
      (pos_embed): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 

In [16]:
fake_image_batch = torch.randn(1, 3, 384, 384)  # (batch, channels, height, width)

In [18]:
# Run only the UniFormer backbone (not the full encoder) on the random batch.
output = model.encoder.uniformer(fake_image_batch)

In [20]:
# Inspect the shape of the backbone output for the random input batch.
output.shape

torch.Size([1, 512, 12, 12])