In [1]:
import torch
from PIL import Image
from torchvision import transforms

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



In [3]:
# Inference-time preprocessing: resize to a fixed 518x518, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
# NOTE(review): these statistics must match the ones used when the checkpoint
# was trained — confirm against the training notebook.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(518, 518)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [4]:
import torch.nn as nn

In [5]:
class MultiTaskModel(nn.Module):
    """DINOv2 ViT backbone shared by a classification head and a regression head.

    Both heads are applied directly to the output of
    ``backbone.forward_features``; the model itself performs no pooling.
    NOTE(review): for timm ViTs that output is presumably per-token
    (batch, tokens, embed_dim), which would be why callers average over
    dim 1 — confirm.

    Args:
        num_classes: Width of the classification head's final layer.
    """

    @staticmethod
    def _make_head(in_dim, hidden, out_dim):
        """Build LayerNorm -> (Linear, GELU, Dropout) per ``(width, p)`` in ``hidden`` -> Linear."""
        layers = [nn.LayerNorm(in_dim)]
        width = in_dim
        for next_width, drop_p in hidden:
            layers += [nn.Linear(width, next_width), nn.GELU(), nn.Dropout(drop_p)]
            width = next_width
        layers.append(nn.Linear(width, out_dim))
        return nn.Sequential(*layers)

    def __init__(self, num_classes=2):
        super().__init__()
        # Weights come from a checkpoint later, so no pretrained download here.
        vit = timm.create_model("vit_base_patch14_dinov2", pretrained=False)
        feat_dim = vit.embed_dim
        # Replace the backbone's own head with a no-op; the heads below take over.
        vit.head = nn.Identity()
        self.backbone = vit

        # Heads are created in the same order and with the same layer layout as
        # before, so Sequential indices (and thus state_dict keys) are unchanged.
        self.class_head = self._make_head(
            feat_dim, [(1024, 0.5), (512, 0.4), (256, 0.3)], num_classes
        )
        self.reg_head = self._make_head(
            feat_dim, [(768, 0.4), (256, 0.3), (128, 0.2)], 1
        )

    def forward(self, x):
        """Return ``(class_head(features), reg_head(features))`` for input ``x``."""
        features = self.backbone.forward_features(x)
        return self.class_head(features), self.reg_head(features)

In [6]:
import timm

In [7]:
model = MultiTaskModel().to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# The file is passed straight to load_state_dict, so nothing else is needed.
model.load_state_dict(
    torch.load("best_thickness_model.pth", map_location=device, weights_only=True)
)
model.eval()  # switch Dropout layers to inference behaviour

print("Model loaded successfully")

# predict_image maps the regression output back with
#   thickness = output * thickness_std + thickness_mean
# NOTE(review): presumably these are the target statistics used to standardise
# thickness during training — confirm they belong to this checkpoint.
thickness_mean = 0.01
thickness_std  = 0.9999997947120274

Model loaded successfully


In [8]:
def predict_image(img_path):
    """Classify one image and estimate its thickness with the global ``model``.

    Relies on the notebook-level ``model``, ``test_transform``, ``device``,
    ``thickness_mean`` and ``thickness_std``.

    Args:
        img_path: Location of the image, passed to ``Image.open``.

    Returns:
        tuple: ``(label, thickness)`` where ``label`` is ``"benign"`` when the
        predicted class index is 0 and ``"malignant"`` otherwise, and
        ``thickness`` is the averaged regression output rescaled as
        ``value * thickness_std + thickness_mean``.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the pixels are read; convert() hands back an independent image object.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    # Add a leading batch dimension of size 1; the .item() calls below rely on it.
    img_t = test_transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        cls_out, reg_out = model(img_t)

    # Both head outputs are averaged over dim 1 before being interpreted.
    cls_logits = cls_out.mean(dim=1)
    pred_class_idx = cls_logits.argmax(dim=1).item()

    folder = "benign" if pred_class_idx == 0 else "malignant"

    reg_mean = reg_out.mean(dim=1).item()
    # Undo the target standardisation to get the reported thickness value.
    thickness = reg_mean * thickness_std + thickness_mean

    return folder, thickness



In [9]:
# Image to run through the model; the relative path is resolved against the
# notebook's working directory.
img_path = "sample_test.jpg"
folder, thickness = predict_image(img_path=img_path)

In [10]:
# Report the predicted label and the rescaled thickness estimate returned by
# predict_image for the sample image.
print("Predicted Class:", folder)
print("Predicted Thickness:", thickness)

Predicted Class: malignant
Predicted Thickness: 1.0285076347330275
