This script loads the HAM10000 skin lesion dataset and performs fine-tuning on DinoV2 ViT-B/14 model.
It compares the performance between the pre-trained model without fine-tuning and with fine-tuning.

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import timm
# from torchvision.models import dinov2_vitb14, DINOv2_ViTB14_Weights
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Set random seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
# Define dataset paths
DATA_DIR = "/home/zack/11785/project/data/HAM10000_dataset"
METADATA_FILE = os.path.join(DATA_DIR, "HAM10000_metadata.csv")

# Create HAM10000 Dataset
class HAM10000Dataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        """
        Args:
            dataframe: Pandas dataframe containing image_id and dx (diagnosis)
            img_dir: Directory with all the images
            transform: Optional transform to apply to the images
        """
        self.dataframe = dataframe
        self.img_dir = img_dir
        self.transform = transform
        
        # Create a mapping from diagnosis to class index
        self.diagnosis_mapping = {
            'akiec': 0,  # Actinic Keratosis
            'bcc': 1,    # Basal Cell Carcinoma
            'bkl': 2,    # Benign Keratosis
            'df': 3,     # Dermatofibroma
            'mel': 4,    # Melanoma
            'nv': 5,     # Melanocytic Nevus
            'vasc': 6    # Vascular Lesion
        }
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        img_id = self.dataframe.iloc[idx]['image_id']
        img_path = os.path.join(self.img_dir, img_id + '.jpg')
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        label = self.diagnosis_mapping[self.dataframe.iloc[idx]['dx']]
        return image, label

# Data loading and preprocessing
def load_data():
    # Read metadata
    metadata = pd.read_csv(METADATA_FILE)
    print(f"Dataset size: {len(metadata)}")
    
    # Check class distribution
    print("Class distribution:")
    print(metadata['dx'].value_counts())
    
    # Train/test split
    train_df, test_df = train_test_split(metadata, test_size=0.2, stratify=metadata['dx'], random_state=42)
    
    # Define transformations
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create datasets
    train_dataset = HAM10000Dataset(train_df, DATA_DIR, train_transform)
    test_dataset = HAM10000Dataset(test_df, DATA_DIR, test_transform)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    return train_loader, test_loader

In [None]:
# Create a classifier on top of DinoV2 features
class DinoClassifier(nn.Module):
    def __init__(self, backbone, num_classes=7, fine_tune=True):
        super(DinoClassifier, self).__init__()
        self.backbone = backbone
        
        # Freeze backbone if not fine-tuning
        if not fine_tune:
            for param in self.backbone.parameters():
                param.requires_grad = False
        else:
            # Only fine-tune the last 2 blocks of the backbone
            for name, param in self.backbone.named_parameters():
                if 'blocks.10' in name or 'blocks.11' in name or 'norm' in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        # Get embedding dimension from the backbone
        if hasattr(self.backbone, 'embed_dim'):
            embedding_dim = self.backbone.embed_dim
        else:
            embedding_dim = self.backbone.num_features
        
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        # Extract features using the backbone
        features = self.backbone(x)
        
        # Pass features through the classifier
        logits = self.classifier(features)
        
        return logits

# Training function
def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs=10):
    best_accuracy = 0.0
    train_losses = []
    test_accuracies = []
    scaler = torch.cuda.amp.GradScaler()  
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()        
            scaler.step(optimizer)                
            scaler.update()                       
            
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                # 使用混合精度进行推理
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())
        
        accuracy = accuracy_score(all_labels, all_preds)
        test_accuracies.append(accuracy)
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Test Accuracy: {accuracy:.4f}")
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'best_model.pth')
    
    return train_losses, test_accuracies, best_accuracy

# Evaluation function
def evaluate_model(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            
            with torch.cuda.amp.autocast():
                outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {accuracy:.4f}")
    
    # Generate classification report
    class_names = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
    report = classification_report(all_labels, all_preds, target_names=class_names)
    print("Classification Report:")
    print(report)
    
    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    
    fmt = 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")
    
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    return accuracy, report

In [6]:
train_loader, test_loader = load_data()

print("Loading DinoV2 ViT-B/14 model...")
try:
    backbone =  torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
    
    print("Successfully loaded DinoV2 from torch")
except:
    backbone = timm.create_model('vit_base_patch14_dinov2', pretrained=True)
    print("Successfully loaded DinoV2 from timm")

# Model without fine-tuning (frozen backbone)
frozen_model = DinoClassifier(backbone, fine_tune=False).to(device)

# Model with fine-tuning
finetuned_model = DinoClassifier(backbone, fine_tune=True).to(device)

# Define loss function and optimizers
criterion = nn.CrossEntropyLoss()
frozen_optimizer = optim.AdamW(frozen_model.parameters(), lr=1e-3)
finetuned_optimizer = optim.AdamW([
    {'params': [p for n, p in finetuned_model.named_parameters() if 'backbone' in n and p.requires_grad], 'lr': 5e-5},
    {'params': [p for n, p in finetuned_model.named_parameters() if 'backbone' not in n], 'lr': 1e-3}
])



Dataset size: 10015
Class distribution:
dx
nv       6705
mel      1113
bkl      1099
bcc       514
akiec     327
vasc      142
df        115
Name: count, dtype: int64
Loading DinoV2 ViT-B/14 model...


Using cache found in /home/zack/.cache/torch/hub/facebookresearch_dinov2_main


Successfully loaded DinoV2 from torch


In [7]:
# Train and evaluate frozen model
print("\n--- Training model with frozen backbone ---")
frozen_losses, frozen_accuracies, frozen_best = train_model(
    frozen_model, train_loader, test_loader, criterion, frozen_optimizer, num_epochs=10
)

# Load best frozen model and evaluate
frozen_model.load_state_dict(torch.load('best_model.pth'))
print("\n--- Final evaluation of model with frozen backbone ---")
frozen_accuracy, frozen_report = evaluate_model(frozen_model, test_loader)

# Rename the best model file to avoid overwriting
os.rename('best_model.pth', 'frozen_best_model.pth')


--- Training model with frozen backbone ---


Epoch 1/10: 100%|██████████| 251/251 [01:40<00:00,  2.49it/s]


Epoch 1/10, Loss: 1.0187, Test Accuracy: 0.7064


Epoch 2/10: 100%|██████████| 251/251 [01:45<00:00,  2.37it/s]


Epoch 2/10, Loss: 0.8224, Test Accuracy: 0.7059


Epoch 3/10: 100%|██████████| 251/251 [01:42<00:00,  2.46it/s]


Epoch 3/10, Loss: 0.7659, Test Accuracy: 0.7299


Epoch 4/10: 100%|██████████| 251/251 [01:42<00:00,  2.45it/s]


Epoch 4/10, Loss: 0.7362, Test Accuracy: 0.7364


Epoch 5/10: 100%|██████████| 251/251 [01:42<00:00,  2.45it/s]


Epoch 5/10, Loss: 0.6931, Test Accuracy: 0.7384


Epoch 6/10: 100%|██████████| 251/251 [01:42<00:00,  2.46it/s]


Epoch 6/10, Loss: 0.6654, Test Accuracy: 0.7649


Epoch 7/10: 100%|██████████| 251/251 [01:42<00:00,  2.44it/s]


Epoch 7/10, Loss: 0.6476, Test Accuracy: 0.7703


Epoch 8/10: 100%|██████████| 251/251 [01:42<00:00,  2.44it/s]


Epoch 8/10, Loss: 0.5999, Test Accuracy: 0.7843


Epoch 9/10: 100%|██████████| 251/251 [01:43<00:00,  2.43it/s]


Epoch 9/10, Loss: 0.5821, Test Accuracy: 0.7329


Epoch 10/10: 100%|██████████| 251/251 [01:39<00:00,  2.51it/s]


Epoch 10/10, Loss: 0.5719, Test Accuracy: 0.7908

--- Final evaluation of model with frozen backbone ---


Evaluating: 100%|██████████| 63/63 [00:11<00:00,  5.53it/s]


Test Accuracy: 0.7908
Classification Report:
              precision    recall  f1-score   support

       akiec       0.41      0.37      0.39        65
         bcc       0.61      0.71      0.66       103
         bkl       0.55      0.58      0.57       220
          df       0.67      0.09      0.15        23
         mel       0.49      0.52      0.51       223
          nv       0.92      0.91      0.91      1341
        vasc       0.88      0.79      0.83        28

    accuracy                           0.79      2003
   macro avg       0.65      0.57      0.57      2003
weighted avg       0.79      0.79      0.79      2003



In [11]:
# Train and evaluate fine-tuned model
print("\n--- Training model with fine-tuned backbone ---")
finetuned_losses, finetuned_accuracies, finetuned_best = train_model(
    finetuned_model, train_loader, test_loader, criterion, finetuned_optimizer, num_epochs=10
)



--- Training model with fine-tuned backbone ---


Epoch 1/10: 100%|██████████| 251/251 [01:43<00:00,  2.42it/s]


Epoch 1/10, Loss: 0.5200, Test Accuracy: 0.7823


Epoch 2/10: 100%|██████████| 251/251 [01:43<00:00,  2.42it/s]


Epoch 2/10, Loss: 0.5071, Test Accuracy: 0.8018


Epoch 3/10: 100%|██████████| 251/251 [01:45<00:00,  2.37it/s]


Epoch 3/10, Loss: 0.4895, Test Accuracy: 0.8128


Epoch 4/10: 100%|██████████| 251/251 [01:46<00:00,  2.35it/s]


Epoch 4/10, Loss: 0.4782, Test Accuracy: 0.8043


Epoch 5/10: 100%|██████████| 251/251 [01:47<00:00,  2.34it/s]


Epoch 5/10, Loss: 0.4770, Test Accuracy: 0.8123


Epoch 6/10: 100%|██████████| 251/251 [01:44<00:00,  2.40it/s]


Epoch 6/10, Loss: 0.4742, Test Accuracy: 0.8093


Epoch 7/10: 100%|██████████| 251/251 [01:46<00:00,  2.36it/s]


Epoch 7/10, Loss: 0.4711, Test Accuracy: 0.8173


Epoch 8/10: 100%|██████████| 251/251 [01:45<00:00,  2.39it/s]


Epoch 8/10, Loss: 0.4606, Test Accuracy: 0.8153


Epoch 9/10: 100%|██████████| 251/251 [01:42<00:00,  2.45it/s]


Epoch 9/10, Loss: 0.4586, Test Accuracy: 0.8173


Epoch 10/10: 100%|██████████| 251/251 [01:42<00:00,  2.46it/s]


Epoch 10/10, Loss: 0.4616, Test Accuracy: 0.8223


In [12]:
# Load best fine-tuned model and evaluate
finetuned_model.load_state_dict(torch.load('best_model.pth'))
print("\n--- Final evaluation of model with fine-tuned backbone ---")
finetuned_accuracy, finetuned_report = evaluate_model(finetuned_model, test_loader)

# Rename the fine-tuned model file for clarity
os.rename('best_model.pth', 'finetuned_best_model.pth')


--- Final evaluation of model with fine-tuned backbone ---


Evaluating: 100%|██████████| 63/63 [00:11<00:00,  5.53it/s]


Test Accuracy: 0.8223
Classification Report:
              precision    recall  f1-score   support

       akiec       0.63      0.42      0.50        65
         bcc       0.64      0.79      0.70       103
         bkl       0.65      0.64      0.64       220
          df       0.50      0.57      0.53        23
         mel       0.58      0.46      0.52       223
          nv       0.91      0.94      0.92      1341
        vasc       0.91      0.75      0.82        28

    accuracy                           0.82      2003
   macro avg       0.69      0.65      0.66      2003
weighted avg       0.82      0.82      0.82      2003



In [14]:
# Plot training curves
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(range(1, len(frozen_losses) + 1), frozen_losses, label='Frozen')
plt.plot(range(1, len(finetuned_losses) + 1), finetuned_losses, label='Fine-tuned')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, len(frozen_accuracies) + 1), frozen_accuracies, label='Frozen')
plt.plot(range(1, len(finetuned_accuracies) + 1), finetuned_accuracies, label='Fine-tuned')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.savefig('training_curves.png')
plt.close()

# Compare results
print("\n--- Comparison ---")
print(f"Frozen model best accuracy: {frozen_best:.4f}")
print(f"Fine-tuned model best accuracy: {finetuned_best:.4f}")
print(f"Improvement: {(finetuned_best - frozen_best) * 100:.2f}%")


--- Comparison ---
Frozen model best accuracy: 0.7908
Fine-tuned model best accuracy: 0.8223
Improvement: 3.15%
