In [21]:
import timm
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Load a pretrained Swin Transformer model
model_name = "swin_base_patch4_window7_224"  # Change model name for other variants
model = timm.create_model(model_name, pretrained=True)

model.set_input_size((1152,896))

# Set the model to evaluation mode
model.eval()



SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (layers): Sequential(
    (0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path1): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU(approximate='none')
            (

In [25]:
print(model.default_cfg['mean'])
print(model.default_cfg['std'])

# Print the model architecture (optional)
#print(model)

# Example: Prepare an image for inference
# Define input size and transformations
input_size = model.default_cfg['input_size'][-1]  # Model's expected input size
transform = transforms.Compose([
    transforms.Normalize(
        mean=model.default_cfg['mean'],
        std=model.default_cfg['std']
    )
])

# Load an image
image_path = "/home/alalbiol/Data/mamo/DDSM_png_16bit_1152x896/cancers/cancer_01/case0001/C_0001_1.RIGHT_CC.png"  # Path to your image
image = np.array(Image.open(image_path)).astype(np.float32)



print("Image max: ", image.max())
print("Image min: ", image.min())

image = torch.from_numpy(image)
image = image.unsqueeze(0)  # Add batch dimension
image = np.repeat(image, 3, axis=0)  # Repeat the image 3 times to create 3 channels
print(image.shape)



(0.485, 0.456, 0.406)
(0.229, 0.224, 0.225)
Image max:  65535.0
Image min:  0.0
torch.Size([3, 1152, 896])


In [23]:


# Preprocess the image
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

print("input_tensor shape: ", input_tensor.shape)

# Run inference
with torch.no_grad():
    output = model(input_tensor)

# Get predicted class
predicted_class = torch.argmax(output, dim=1).item()
print(f"Predicted class: {predicted_class}")


input_tensor shape:  torch.Size([1, 3, 1152, 896])
Predicted class: 111
