In [2]:
!pip install torch torchvision timm




# EVALUATIONS

In [1]:
from timm.layers.helpers import to_2tuple
import timm
import torch.nn as nn

class ConvStem(nn.Module):
  """Custom Patch Embed Layer.

  Adapted from https://github.com/Xiyue-Wang/TransPath/blob/main/ctran.py#L6-L44
  """

  def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, **kwargs):
    super().__init__()

    # Check input constraints
    assert patch_size == 4, "Patch size must be 4"
    assert embed_dim % 8 == 0, "Embedding dimension must be a multiple of 8"

    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)

    self.img_size = img_size
    self.patch_size = patch_size
    self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
    self.num_patches = self.grid_size[0] * self.grid_size[1]

    # Create stem network
    stem = []
    input_dim, output_dim = 3, embed_dim // 8
    for l in range(2):
      stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
      stem.append(nn.BatchNorm2d(output_dim))
      stem.append(nn.ReLU(inplace=True))
      input_dim = output_dim
      output_dim *= 2
    stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
    self.proj = nn.Sequential(*stem)

    # Apply normalization layer (if provided)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self, x):
    B, C, H, W = x.shape

    # Check input image size
    assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

    x = self.proj(x)
    x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
    x = self.norm(x)
    return x


In [None]:
from urllib.request import urlopen
from PIL import Image
import timm
import torch

# Charger l'image
img_path = "/kaggle/input/segmentation-of-nuclei-in-cryosectioned-he-images/tissue images/Human_LymphNodes_01.tif"
img = Image.open(img_path).convert("RGB")  # Convertir en RGB pour être compatible avec le modèle

# Charger le modèle préentraîné depuis Hugging Face Hub via timm
model = timm.create_model(
    model_name="hf-hub:1aurent/swin_tiny_patch4_window7_224.CTransPath",
    pretrained=True,
).eval()

# Obtenir les configurations spécifiques au modèle pour les transformations
data_config = timm.data.resolve_data_config({}, model=model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Appliquer les transformations sur l'image
data = transforms(img).unsqueeze(0)  # Ajouter la dimension batch

# Vérifier si un GPU est disponible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data = data.to(device)

# Effectuer l'inférence
with torch.no_grad():
    output = model(data)

# Afficher la sortie
print("Output shape:", output.shape)
print("Output:", output)


In [15]:
import torch
import timm
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from torchvision import transforms
from PIL import Image

# Define the ConvStem to match the checkpoint's architecture
class ConvStem(nn.Module):
    def __init__(self, in_channels=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, embed_dim // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.proj(x)

# Create a custom Swin Transformer that uses the ConvStem
class CustomSwinTransformer(SwinTransformer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.patch_embed.proj = ConvStem(in_channels=3, embed_dim=self.embed_dim)

# Load the model with the adjusted architecture
model = CustomSwinTransformer(
    patch_size=4,
    window_size=7,
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    num_classes=1000
)

# Load the pretrained weights
checkpoint_path = "/kaggle/input/ctranspath-models/checkpoint.pth"  # Update with the actual path
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)  # Load with non-strict mode to handle differences

# Move the model to GPU and set to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Define transforms
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Load and preprocess the image
img_path = "/kaggle/input/segmentation-of-nuclei-in-cryosectioned-he-images/tissue images/Human_LymphNodes_01.tif"
img = Image.open(img_path).convert("RGB")
data = transforms(img).unsqueeze(0).to(device)  # Add batch dimension and move to device

# Perform inference
with torch.no_grad():
    output = model(data)  # Output is (batch_size, num_features)
    print("Model output:", output)


  checkpoint = torch.load(checkpoint_path, map_location="cpu")


Model output: tensor([[ 8.1148e-01,  1.8935e-01, -9.0224e-02,  2.5861e-01,  3.0596e-01,
          8.2276e-03,  9.0792e-02,  6.6537e-01, -4.3929e-01,  4.7570e-01,
         -3.9093e-01, -4.5301e-01, -4.3061e-01, -2.6897e-01,  1.2209e-01,
         -9.3461e-02,  2.3282e-02, -5.1481e-01,  7.7189e-01,  2.6584e-02,
         -2.7158e-01,  4.8113e-01,  1.1812e+00,  1.3154e-01,  5.9637e-01,
          2.2954e-01,  1.3701e-01, -1.1506e+00, -2.3190e-01, -2.2726e-01,
          9.2835e-02, -1.0608e-01, -4.6960e-01,  1.2167e+00,  5.3204e-01,
         -2.2739e-01,  6.7872e-01,  4.4365e-01, -4.0916e-01, -6.2200e-02,
          8.5042e-01, -4.2307e-01,  1.0337e+00,  7.4152e-02, -5.5259e-01,
         -1.5084e-01, -4.5673e-01,  1.6890e-01, -6.2315e-01,  3.0302e-01,
         -4.0741e-01,  9.9701e-01, -3.5055e-01,  6.4751e-01,  2.1747e-02,
         -1.0295e-01, -4.4445e-01,  3.9062e-01, -5.9998e-01,  5.9981e-03,
         -8.1216e-01,  1.1108e-01, -4.3611e-01,  3.6392e-01, -6.4813e-01,
          1.9192e-01,  5