In [1]:
from medclip import MedCLIPModel, MedCLIPProcessor
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from medclip import MedCLIPModel, MedCLIPVisionModelViT, MedCLIPVisionModel

# Build MedCLIP with the ViT-style vision tower (the repr below shows a SwinModel inside it).
model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)

# map_location="cpu" lets the checkpoint deserialize even when it was saved from a GPU
# and this machine has no CUDA device; the model is moved to its target device later.
state_dict = torch.load(
    "medclip-vit-pretrained/pytorch_model.bin",
    map_location="cpu",
)

# Drop the text-encoder position_ids entry before loading. load_state_dict is strict by
# default, so a key the model does not expect would raise.
# NOTE(review): presumably the installed transformers version no longer stores this
# buffer in the state dict -- confirm. The default of None keeps this cell from raising
# KeyError when the checkpoint does not contain the key at all.
state_dict.pop("text_model.model.embeddings.position_ids", None)

model.load_state_dict(state_dict)
model

MedCLIPModel(
  (vision_model): MedCLIPVisionModelViT(
    (model): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0): SwinLayer(
                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=96, out_features=96, bias=True)
                    (key): Linear(in_features=96, out_features=96, bias=True)
                    (value): Linear(in_features=96, out_features=96, bias=True)
                    (dropout): Dropout(p=0.0, inplace=False)


In [3]:
import torch
import torch.nn as nn

class KneeMRIClassifier(nn.Module):
    """Shared image encoder followed by three independent binary classification heads.

    The heads predict ``abnormal``, ``acl`` and ``meniscus``; each ends in a sigmoid,
    so every output is a probability-like value in [0, 1].

    Args:
        vit_encoder: Image encoder module. It must expose a ``projection_head``
            attribute with ``out_features`` (used as the head input width) and be
            callable on ``pixel_values``.
        hidden_dim: Width of the hidden layer inside each head.
        dropout: Dropout probability used inside each head.
    """

    def __init__(self, vit_encoder, hidden_dim=512, dropout=0.3):
        super().__init__()
        self.backbone = vit_encoder
        # The heads consume the backbone's projected embedding, so their input width
        # is the projection head's output width.
        feature_dim = self.backbone.projection_head.out_features

        # Heads are created in the same order and under the same attribute names as
        # before, so state_dict keys and seeded initialisation are unchanged.
        self.abnormal_head = self._make_head(feature_dim, hidden_dim, dropout)
        self.acl_head = self._make_head(feature_dim, hidden_dim, dropout)
        self.meniscus_head = self._make_head(feature_dim, hidden_dim, dropout)

    @staticmethod
    def _make_head(feature_dim, hidden_dim, dropout):
        """Build one binary head: Linear -> ReLU -> Dropout -> Linear(1) -> Sigmoid."""
        return nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and run all three heads on the shared features.

        Args:
            pixel_values: Image batch, passed unchanged to the backbone.

        Returns:
            dict with keys ``'abnormal'``, ``'acl'`` and ``'meniscus'``; each value is
            the corresponding head's sigmoid output with a trailing dimension of 1.
        """
        encoded = self.backbone(pixel_values)
        # NOTE(review): MedCLIP's vision encoders appear to return the projected
        # embedding tensor directly, not an output object -- verify against the
        # installed medclip version. Reading `.pooler_output` from a plain tensor
        # raises AttributeError, so accept both return styles.
        if torch.is_tensor(encoded):
            features = encoded
        else:
            features = encoded.pooler_output

        return {
            'abnormal': self.abnormal_head(features),
            'acl': self.acl_head(features),
            'meniscus': self.meniscus_head(features),
        }

In [4]:
# Wrap MedCLIP's vision tower as the shared backbone of the three-head classifier,
# then display the assembled module tree.
classifier = KneeMRIClassifier(vit_encoder=model.vision_model)
classifier

KneeMRIClassifier(
  (backbone): MedCLIPVisionModelViT(
    (model): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0): SwinLayer(
                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=96, out_features=96, bias=True)
                    (key): Linear(in_features=96, out_features=96, bias=True)
                    (value): Linear(in_features=96, out_features=96, bias=True)
                    (dropout): Dropout(p=0.0, inplace=False)

In [5]:
# Prefer the GPU, but fall back to the CPU so this cell still runs on machines
# without CUDA instead of raising when the module is moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier.to(device)

KneeMRIClassifier(
  (backbone): MedCLIPVisionModelViT(
    (model): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0): SwinLayer(
                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=96, out_features=96, bias=True)
                    (key): Linear(in_features=96, out_features=96, bias=True)
                    (value): Linear(in_features=96, out_features=96, bias=True)
                    (dropout): Dropout(p=0.0, inplace=False)