In [63]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import SwinForImageClassification, SwinConfig
from sklearn.model_selection import train_test_split

# Define dataset paths
HAM10000_METADATA = "D:\\Melanoma-Skin-Cancer-Detection-using-Swin-Transformer-main\\data\\HAM10000_metadata.csv"
IMAGE_DIR = "D:\\Melanoma-Skin-Cancer-Detection-using-Swin-Transformer-main\\data\\HAM10000_images"

# Load and process metadata
df = pd.read_csv(HAM10000_METADATA)
label_mapping = {'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'nv': 5, 'vasc': 6}
df["label"] = df["dx"].map(label_mapping)

# Split dataset into 80% train, 20% test
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df["label"])

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Custom Dataset Class
class SkinCancerDataset(Dataset):
    def __init__(self, img_dir, metadata, transform=None):
        self.img_dir = img_dir
        self.metadata = metadata
        self.transform = transform
        self.valid_images = [(os.path.join(img_dir, f"{row['image_id']}.jpg"), row['label'])
                             for _, row in metadata.iterrows() if os.path.exists(os.path.join(img_dir, f"{row['image_id']}.jpg"))]
    def __len__(self):
        return len(self.valid_images)
    def __getitem__(self, idx):
        img_path, label = self.valid_images[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)

# Create dataset & dataloaders
train_dataset = SkinCancerDataset(IMAGE_DIR, train_df, transform=transform)
test_dataset = SkinCancerDataset(IMAGE_DIR, test_df, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load Swin Transformer Model
config = SwinConfig.from_pretrained("microsoft/swin-tiny-patch4-window7-224", num_labels=7)
model = SwinForImageClassification(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Training Function
def train_model(model, train_loader, criterion, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

# Train the model
train_model(model, train_loader, criterion, optimizer, epochs=10)

# Save the trained model
torch.save(model.state_dict(), "swin_transformer_skin_cancer.pth")

# Prediction Function
def predict(image_path, model):
    model.eval()
    if not os.path.exists(image_path):
        print("Image not found!")
        return None 
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image).logits
        prediction = torch.argmax(output, dim=1).item()
    return prediction

# Example Prediction
example_image = os.path.join(IMAGE_DIR, "ISIC_0027413.jpg")
predicted_class = predict(example_image, model)
if predicted_class is not None:
    print(f"Predicted Class: {predicted_class}")

Epoch 1, Loss: 1.0645
Epoch 2, Loss: 0.9074
Epoch 3, Loss: 0.8171
Epoch 4, Loss: 0.7728
Epoch 5, Loss: 0.7338
Epoch 6, Loss: 0.6919
Epoch 7, Loss: 0.6623
Epoch 8, Loss: 0.6482
Epoch 9, Loss: 0.6342
Epoch 10, Loss: 0.6293
Predicted Class: 5
