# In [None]:
import os
from glob import glob
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
from sklearn.model_selection import train_test_split

# choose where the model and every batch live: the default CUDA GPU when
# PyTorch can see one, otherwise the CPU (works, but CNN training is much slower)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# load metadata (each row describes one image_id together with its diagnosis, dx)
metadata = pd.read_csv("HAM10000/HAM10000_metadata")

# convert the diagnosis (dx) into numeric labels (0–6), numbered in order of
# first appearance in the metadata file.
# named class_names (not "labels") so it is not confused with the per-batch
# label tensors used during training
class_names = metadata["dx"].unique()
label_map = {name: i for i, name in enumerate(class_names)}
metadata["label"] = metadata["dx"].map(label_map)
print("Label map:", label_map)

# map image ids (file name without the .jpg extension) to image files
paths_part1 = glob("HAM10000/HAM10000_part1/*.jpg")
paths_part2 = glob("HAM10000/HAM10000_part2/*.jpg")
all_paths = paths_part1 + paths_part2

id_to_path = {os.path.splitext(os.path.basename(p))[0]: p for p in all_paths}
print("Total images found:", len(id_to_path))

# keep only the metadata rows whose image file was found on disk
# (metadata row order is preserved)
usable = metadata[metadata["image_id"].isin(set(id_to_path))].copy()
usable["path"] = usable["image_id"].map(id_to_path)

image_paths = usable["path"].tolist()
image_labels = usable["label"].tolist()

print("Total usable images:", len(image_paths))

# train/validation split
# once we tune hyperparameters we can include a test set
if "lesion_id" in usable.columns:
    # HAM10000 can list several images under the same lesion_id. Splitting
    # image-by-image would let near-duplicate photos of one lesion land in
    # both sets and inflate validation accuracy, so whole lesions are assigned
    # to one side. Stratifying on each lesion's label keeps class proportions
    # consistent for training and validation.
    lesions = usable.drop_duplicates("lesion_id")
    train_lesion_ids, _ = train_test_split(
        lesions["lesion_id"],
        test_size=0.2,
        stratify=lesions["label"],
        random_state=42,
    )
    in_train = usable["lesion_id"].isin(set(train_lesion_ids))

    train_paths = usable.loc[in_train, "path"].tolist()
    train_labels = usable.loc[in_train, "label"].tolist()
    val_paths = usable.loc[~in_train, "path"].tolist()
    val_labels = usable.loc[~in_train, "label"].tolist()
else:
    # no lesion ids in the metadata: fall back to a plain image-level split
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths,
        image_labels,
        test_size=0.2,
        stratify=image_labels,  # makes class proportions consistent for training and validation
        random_state=42,
    )

print("Train size:", len(train_paths), "Val size:", len(val_paths))

# side length every image is resized to before entering the network.
# NOTE(review): the pretrained weights' own reference preprocessing
# (weights.transforms()) may use a larger resolution — worth trying.
image_size = 224

# standardize each channel with the mean/std the pretrained EfficientNet weights
# were trained with (the standard ImageNet statistics). Shared by the train and
# validation pipelines so the two can never drift apart.
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])

# define image transforms for training images (includes random augmentation)
train_transform = transforms.Compose([
    # make every image the same size
    transforms.Resize((image_size, image_size)),

    # random left-right flip so the model does not depend on lesion orientation
    transforms.RandomHorizontalFlip(),

    # account for images being taken in different lighting conditions
    transforms.ColorJitter(brightness=0.2, contrast=0.2),

    # CAN INSERT MORE TRANSFORMS HERE IF WE WANT TO TRY THEM

    # convert PIL image into a float tensor (C x H x W) with values scaled from 0-255 to 0-1
    transforms.ToTensor(),

    normalize,
])

# define image transforms for validation images: deterministic (no random
# augmentation) so validation accuracy is comparable from epoch to epoch
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

# helper function that loads one image at a time and applies a transformation
def load_image_tensor(path, transform):
    """Open the image at *path*, convert it to RGB, and return ``transform(img)``.

    Args:
        path: filesystem path of the image file.
        transform: callable applied to the RGB PIL image (e.g. a
            ``transforms.Compose`` pipeline).

    Returns:
        Whatever *transform* returns; with a pipeline ending in ``ToTensor`` /
        ``Normalize`` that is a (C x H x W) float tensor.
    """
    # the with-block releases the file handle as soon as the pixels are read;
    # convert() reads the pixel data and returns a new image, so the result
    # stays usable after the file is closed
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return transform(rgb)

# build EfficientNet-B2 Model, starting from torchvision's DEFAULT pretrained weights
weights = EfficientNet_B2_Weights.DEFAULT
model = efficientnet_b2(weights=weights)

# replace the final linear layer of the classifier head so the network outputs
# one score (logit) per diagnosis class instead of the pretrained class set
num_classes = len(label_map)  # we have 7 classes
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = model.to(device)

# define loss and optimizer

# cross-entropy is the standard loss for single-label, multi-class classification;
# it expects raw logits, so no softmax is applied to the model's outputs
criterion = nn.CrossEntropyLoss()

# Adam with a small learning rate, a common choice when fine-tuning pretrained weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

#################
# TRAINING
#################

# define batch size - larger batches make each epoch faster but need more GPU memory
batch_size = 32

# more epochs means more learning, but it takes longer... can overfit if too many
epochs = 10

# keep track of best validation accuracy over all epochs
best_val_acc = 0.0


def _load_batch(paths, labs, transform):
    """Load one batch and return ``(images, labels)`` on ``device``.

    images: (B x C x H x W) tensor built by stacking ``load_image_tensor``
    results; labels: 1D long tensor of class ids (0-6), aligned with images.
    """
    # NOTE(review): images are decoded one by one in this process; a torch
    # DataLoader with num_workers > 0 would overlap loading with GPU work.
    images = torch.stack([load_image_tensor(p, transform) for p in paths]).to(device)
    targets = torch.tensor(labs, dtype=torch.long, device=device)
    return images, targets


for epoch in range(epochs):
    # model is now training (enables dropout / batch-norm updates)
    model.train()

    # sum of per-image losses over the epoch; divided by the number of images
    # at the end so the reported loss is a true average even when the last
    # batch is smaller than batch_size
    running_loss = 0.0

    # shuffle training data indices for current epoch
    indices = np.random.permutation(len(train_paths))

    # loop over batches
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        images, labels = _load_batch(
            [train_paths[i] for i in batch_idx],
            [train_labels[i] for i in batch_idx],
            train_transform,
        )

        ##### TRAINING STEP FOR EACH BATCH ####
        optimizer.zero_grad()              # clear old gradients from previous batches
        outputs = model(images)            # send batch through CNN: (B x num_classes) logits
        loss = criterion(outputs, labels)  # mean cross-entropy over the batch
        loss.backward()                    # compute gradients
        optimizer.step()                   # update model weights... this is where learning occurs

        # loss is a batch mean, so weight it by the batch size before summing
        running_loss += loss.item() * labels.size(0)

    train_loss = running_loss / max(len(train_paths), 1)

    #################
    # VALIDATION
    #################

    # model is now in the mode for validation/testing
    model.eval()
    correct = 0  # how many predictions are correct
    total = 0    # total num of predictions

    # no gradient tracking bc we are not training
    with torch.no_grad():
        for start in range(0, len(val_paths), batch_size):
            images, labels = _load_batch(
                val_paths[start:start + batch_size],
                val_labels[start:start + batch_size],
                val_transform,
            )
            # predicted class = index of the largest logit
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # compute validation accuracy (guard against an empty validation split)
    # NOTE(review): overall accuracy can hide poor performance on rare classes
    # if the classes are imbalanced — consider also tracking per-class recall.
    val_acc = correct / total if total else 0.0
    print(f"Epoch {epoch + 1}: train_loss={train_loss:.3f}, val_acc={val_acc:.2%}")

    # save best model so we can load it in more easily later
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # can do: model.load_state_dict(torch.load("best_effnet_b2_nodataset.pth"))
        torch.save(model.state_dict(), "best_effnet_b2_nodataset.pth")

print("Training complete. Best validation accuracy:", best_val_acc * 100)
     