In [1]:
import torch
import timm
import numpy as np
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from PIL import Image

# ================================
# 1️⃣ LOAD DATASET
# ================================
dataset_name = "Team-SknAI/SknAI_300_v3_11Labels"
datasets = load_dataset(dataset_name)
datasets = datasets["train"].train_test_split(test_size=0.2, stratify_by_column="label")

# ================================
# 2️⃣ DEFINE TRANSFORMATIONS
# ================================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ================================
# 3️⃣ CUSTOM DATASET CLASS
# ================================
class SkinDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            img = self.dataset[idx]["image"]

            # Convert image to PIL (Fix corrupt images)
            img = Image.fromarray(np.array(img).astype("uint8"))

            # Ensure image is RGB (Fix grayscale issues)
            img = img.convert("RGB")

            img = self.transform(img)
            label = torch.tensor(self.dataset[idx]["label"], dtype=torch.long)
            return img, label
        except Exception as e:
            print(f"Error loading image at index {idx}: {e}")
            return torch.zeros(3, 224, 224), torch.tensor(0)  # Return black image if error

# ================================
# 4️⃣ CREATE DATA LOADERS (Fixed)
# ================================
train_dataset = SkinDataset(datasets["train"], transform)
val_dataset = SkinDataset(datasets["test"], transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)  # No multiprocessing
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)  # No multiprocessing

# ================================
# 5️⃣ DEFINE MODEL
# ================================
class DenseNet121Model(nn.Module):
    def __init__(self, num_classes):
        super(DenseNet121Model, self).__init__()
        self.model = timm.create_model("densenet121", pretrained=True, num_classes=num_classes)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.model(x)

num_classes = len(datasets["train"].features["label"].names)
model = DenseNet121Model(num_classes)

# ================================
# 6️⃣ TRAINING SETUP
# ================================
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# ================================
# 7️⃣ TRAINING LOOP (Fixed)
# ================================
EPOCHS = 10

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    for inputs, labels in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    scheduler.step()
    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {running_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

# ================================
# 8️⃣ EVALUATION
# ================================
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Epoch 1/10 - Loss: 245.5991, Train Accuracy: 50.57%
Epoch 2/10 - Loss: 163.5084, Train Accuracy: 66.59%
Epoch 3/10 - Loss: 123.0944, Train Accuracy: 74.96%
Epoch 4/10 - Loss: 105.0996, Train Accuracy: 79.09%
Epoch 5/10 - Loss: 96.0255, Train Accuracy: 80.00%
Epoch 6/10 - Loss: 53.3952, Train Accuracy: 89.96%
Epoch 7/10 - Loss: 36.0152, Train Accuracy: 93.14%
Epoch 8/10 - Loss: 27.3937, Train Accuracy: 94.28%
Epoch 9/10 - Loss: 29.2353, Train Accuracy: 94.09%
Epoch 10/10 - Loss: 25.8235, Train Accuracy: 94.66%




Test Accuracy: 85.76%
