In [1]:
import torch
import timm
import numpy as np
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from PIL import Image

# ================================
# 1️⃣ LOAD DATASET
# ================================
dataset_name = "Team-SknAI/SknAI_300_v3_11Labels"
datasets = load_dataset(dataset_name)
datasets = datasets["train"].train_test_split(test_size=0.2, stratify_by_column="label")

# ================================
# 2️⃣ DEFINE TRANSFORMATIONS
# ================================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ================================
# 3️⃣ CUSTOM DATASET CLASS
# ================================
class SkinDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            img = self.dataset[idx]["image"]

            # Convert image to PIL (Fix corrupt images)
            img = Image.fromarray(np.array(img).astype("uint8"))

            # Ensure image is RGB (Fix grayscale issues)
            img = img.convert("RGB")

            img = self.transform(img)
            label = torch.tensor(self.dataset[idx]["label"], dtype=torch.long)
            return img, label
        except Exception as e:
            print(f"Error loading image at index {idx}: {e}")
            return torch.zeros(3, 224, 224), torch.tensor(0)  # Return black image if error

# ================================
# 4️⃣ CREATE DATA LOADERS (Fixed)
# ================================
train_dataset = SkinDataset(datasets["train"], transform)
val_dataset = SkinDataset(datasets["test"], transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)  # No multiprocessing
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)  # No multiprocessing

# ================================
# 5️⃣ DEFINE MODEL
# ================================
class DenseNet121Model(nn.Module):
    def __init__(self, num_classes):
        super(DenseNet121Model, self).__init__()
        self.model = timm.create_model("densenet121", pretrained=True, num_classes=num_classes)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.model(x)

num_classes = len(datasets["train"].features["label"].names)
model = DenseNet121Model(num_classes)

In [None]:
# ================================
# 6️⃣ TRAINING SETUP (Updated)
# ================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)  # Cosine annealing for better LR decay

# ================================
# 7️⃣ TRAINING LOOP (Updated)
# ================================
EPOCHS = 15 
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    scheduler.step()
    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {running_loss:.4f}, Train Accuracy: {train_acc:.2f}%")


# ================================
# 8️⃣ EVALUATION
# ================================
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU

        outputs = model(inputs)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")





Epoch 1/15 - Loss: 91.9136, Train Accuracy: 65.72%
Epoch 2/15 - Loss: 40.6836, Train Accuracy: 84.47%
Epoch 3/15 - Loss: 27.0344, Train Accuracy: 89.51%
Epoch 4/15 - Loss: 17.0373, Train Accuracy: 93.52%
Epoch 5/15 - Loss: 9.9543, Train Accuracy: 96.48%
Epoch 6/15 - Loss: 5.6283, Train Accuracy: 98.14%
Epoch 7/15 - Loss: 3.5899, Train Accuracy: 98.90%
Epoch 8/15 - Loss: 2.5495, Train Accuracy: 99.20%
Epoch 9/15 - Loss: 1.8241, Train Accuracy: 99.36%
Epoch 10/15 - Loss: 1.6887, Train Accuracy: 99.55%
Epoch 11/15 - Loss: 1.3513, Train Accuracy: 99.66%
Epoch 12/15 - Loss: 1.3306, Train Accuracy: 99.55%
Epoch 13/15 - Loss: 1.4210, Train Accuracy: 99.32%
Epoch 14/15 - Loss: 1.7984, Train Accuracy: 99.39%
Epoch 15/15 - Loss: 2.9855, Train Accuracy: 98.86%
Test Accuracy: 90.15%
