In [None]:
from utils.extract_wound_class import CachedWoundClassifier

# Smoke test: run the cached classifier on one image from the "Normal" test
# folder and show the predicted label (class probabilities are kept but unused).
classifier = CachedWoundClassifier()
label, probabilities = classifier.predict(
    "careful_this_contain_wound_image/test/Normal/2.jpg"
)
print(f"Predicted: {label}")

In [None]:
from utils.extract_wound_features import CLIPWoundFeatureExtractor
from utils.extract_wound_class import CachedWoundClassifier

# Two-step pipeline on a single burn image:
#   1) predict its wound class,
#   2) ask the CLIP extractor for descriptive features conditioned on that class.
extractor = CLIPWoundFeatureExtractor()
classifier = CachedWoundClassifier()
image_path = 'careful_this_contain_wound_image/Burns/burns (20).jpg'

wound_class, probabilities = classifier.predict(image_path)

# lang='th' is forwarded to the extractor (presumably selects Thai-language
# descriptions — confirm in utils.extract_wound_features).
features = extractor.extract_features(image_path, wound_class, lang='th')

print(f"\nTop features for wound class: {wound_class}")
# Each item is a (description, score) pair; scores are printed to 4 decimals.
for description, similarity in features:
    print(f"{description}: {similarity:.4f}")

In [None]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights

class WoundClassifier(nn.Module):
    """EfficientNetV2-L backbone followed by an MLP head producing class logits.

    The torchvision classifier is replaced with ``nn.Identity`` so the backbone
    returns its pooled feature vector. That vector goes through a shared
    512-unit block, then a 256-unit block ending in ``num_classes`` logits.

    Args:
        num_classes: Number of logits returned by :meth:`forward`.
        dropout: Dropout probability used in both head blocks.
        pretrained: When True (the default, same as before this parameter
            existed), build the backbone with ``EfficientNet_V2_L_Weights.DEFAULT``.
            Pass False to skip fetching/initialising those ImageNet weights when
            a fully trained checkpoint is loaded right afterwards and would
            overwrite them anyway.
    """

    def __init__(self, num_classes: int = 5, dropout: float = 0.4,
                 pretrained: bool = True) -> None:
        super().__init__()
        weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
        base = efficientnet_v2_l(weights=weights)
        # Read the input width of the original classifier's layer at index 1
        # before discarding it: that is the size of the pooled backbone features.
        n_features = base.classifier[1].in_features
        base.classifier = nn.Identity()
        self.backbone = base
        # NOTE: BatchNorm1d in training mode needs more than one sample per
        # batch; single-image inference must run under model.eval().
        self.shared_head = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
        )
        self.class_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )
        # Backbone stages grouped from shallow to deep. Deliberately a plain list
        # (not nn.ModuleList): the slices share modules with
        # self.backbone.features, and registering them again would add duplicate
        # keys to state_dict() and break loading existing checkpoints.
        # Not used in this file; presumably consumed by training code (e.g.
        # staged unfreezing / per-group learning rates) — confirm.
        self.layer_groups = [
            self.backbone.features[0:2],
            self.backbone.features[2:4],
            self.backbone.features[4:6],
            self.backbone.features[6:],
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (no softmax) class logits for an image batch ``x``.

        ``x`` is expected to be an (N, 3, H, W) image tensor, as in the ONNX
        export dummy input.
        """
        features = self.backbone(x)
        hidden = self.shared_head(features)
        return self.class_head(hidden)

# Rebuild the architecture, load the trained weights on CPU, and export to ONNX.
CHECKPOINT_PATH = "topdown_model_fold4_stage3.pt"
ONNX_PATH = "topdown_model_fold4_stage3_opset_20.onnx"
ONNX_OPSET = 20

# NOTE: constructing the model also initialises the backbone from torchvision's
# default ImageNet weights; the checkpoint loaded below overwrites them.
model = WoundClassifier(num_classes=5, dropout=0.4)
# The loaded object is handed straight to load_state_dict, so it must be a plain
# state dict of tensors. weights_only=True loads it without executing arbitrary
# pickled code, and pins the behaviour regardless of the installed PyTorch
# version's default (which changed to True in PyTorch 2.6).
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)
# Eval mode: BatchNorm uses running statistics and Dropout is disabled, which
# is what the exported inference graph should capture.
model.eval()

# Only the batch axis is declared dynamic below, so the exported graph expects
# 3x224x224 inputs. Assumed to match the resolution used at inference — confirm.
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    export_params=True,
    opset_version=ONNX_OPSET,
    do_constant_folding=True,
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={
        "image": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)

print(f"Exported WoundClassifier to {ONNX_PATH}")