In [1]:
#!pip install torch torchvision timm numpy matplotlib seaborn datasets pillow scikit-learn pandas


In [1]:
import torch
import timm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd
from timm.data.mixup import Mixup
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau

# ================================
# 1️⃣ LOAD DATASET
# ================================
dataset_name = "Team-SknAI/SknAI_300_v3_11Labels"
datasets = load_dataset(dataset_name)
datasets = datasets["train"].train_test_split(test_size=0.2, stratify_by_column="label")

In [2]:
# ================================
# 2️⃣ DEFINE TRANSFORMATIONS
# ================================
transform = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), shear=10),
    transforms.GaussianBlur(kernel_size=3),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ================================
# 3️⃣ CUSTOM DATASET CLASS
# ================================
class SkinDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]["image"]
        img = Image.fromarray(np.array(img).astype("uint8")).convert("RGB")
        img = self.transform(img)
        label = torch.tensor(self.dataset[idx]["label"], dtype=torch.long)
        return img, label

# ================================
# 4️⃣ CREATE DATA LOADERS
# ================================
train_dataset = SkinDataset(datasets["train"], transform)
val_dataset = SkinDataset(datasets["test"], transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)


In [None]:
# ================================
# 5️⃣ DEFINE MODEL
# ================================
class DenseNet121Model(nn.Module):
    def __init__(self, num_classes):
        super(DenseNet121Model, self).__init__()
        self.model = timm.create_model("densenet121", pretrained=True, num_classes=num_classes)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.model(x)

num_classes = len(datasets["train"].features["label"].names)
model = DenseNet121Model(num_classes)

# ================================
# 6️⃣ TRAINING SETUP
# ================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=5e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)
early_stopping = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

mixup_fn = Mixup(mixup_alpha=0.1, cutmix_alpha=0.5, num_classes=num_classes)



In [5]:
# ================================
# 7️⃣ TRAINING LOOP WITH EARLY STOPPING
# ================================
EPOCHS = 10
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
best_acc = 0

for epoch in range(EPOCHS):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = mixup_fn(inputs, labels)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
        total += labels.size(0)
    
    scheduler.step(running_loss)
    early_stopping.step(running_loss)
    train_losses.append(running_loss / len(train_loader))
    train_accuracies.append(100 * correct / total)
    
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {train_losses[-1]:.4f}, Train Accuracy: {train_accuracies[-1]:.2f}%")



Epoch 1/10 - Loss: 1.5650, Train Accuracy: 71.44%
Epoch 2/10 - Loss: 1.4928, Train Accuracy: 76.74%
Epoch 3/10 - Loss: 1.3728, Train Accuracy: 82.84%
Epoch 4/10 - Loss: 1.3055, Train Accuracy: 84.17%
Epoch 5/10 - Loss: 1.4659, Train Accuracy: 77.88%
Epoch 6/10 - Loss: 1.3912, Train Accuracy: 80.53%
Epoch 7/10 - Loss: 1.3198, Train Accuracy: 84.58%
Epoch 8/10 - Loss: 1.3824, Train Accuracy: 80.15%
Epoch 9/10 - Loss: 1.3023, Train Accuracy: 88.94%
Epoch 10/10 - Loss: 1.1594, Train Accuracy: 93.11%


In [None]:
    
if train_accuracies[-1] > best_acc:
    best_acc = train_accuracies[-1]
    torch.save(model.state_dict(), "best_model.pth")

In [6]:
# ================================
# 8️⃣ EVALUATION & TEST-TIME AUGMENTATION (TTA)
# ================================
def tta_predict(model, images, n_augments=5):
    model.eval()
    outputs = torch.stack([model(images) for _ in range(n_augments)]).mean(0)
    return outputs.argmax(dim=1)

model.eval()
correct, total = 0, 0
y_true, y_pred = [], []

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = tta_predict(model, inputs)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

val_accuracy = 100 * correct / total
val_accuracies.append(val_accuracy)
print(f"Test Accuracy: {val_accuracy:.2f}%")

Test Accuracy: 88.48%


In [7]:
print(classification_report(y_true, y_pred, target_names=datasets["train"].features["label"].names))


                                           precision    recall  f1-score   support

                                     Acne       0.88      0.83      0.85        60
                                 Alopecia       0.97      0.97      0.97        60
                                   Eczema       0.94      0.80      0.86        60
      Fungal Infection (Nail or ringworm)       0.94      0.98      0.96        60
                                   Herpes       0.70      0.77      0.73        60
                              Normal Skin       0.98      0.98      0.98        60
                                Psoriasis       0.73      0.73      0.73        60
                                  Rosacea       0.87      0.92      0.89        60
Viral Infection (Chicken-pox or shingles)       0.88      0.88      0.88        60
                                 Vitiligo       0.97      0.98      0.98        60
                                    Warts       0.90      0.88      0.89        60

  