In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import cv2
import matplotlib.pyplot as plt
import random
from sklearn.model_selection import KFold
from PIL import Image
from IPython import display
import time
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

The code blocks are organized in the following sequence: 

Data Preprocessing → Model and Function Definition → Model Training → Model Testing → Viewing Test Results. 

This notebook uses a pretrained DenseNet201 model and includes an implementation of Grad-CAM visualization.


In [None]:
# Data Preprocessing
# Replace the three placeholders below with the dataset root folder and the
# names of its train / test sub-folders.
file_dir = 'your-dataset-path-directory'  # e.g., '/home/user/dataset'
train_dir = os.path.join(file_dir, 'your-trainsubset-name')
test_dir = os.path.join(file_dir, 'your-testsubset-name')
"""
The dataset is stored in the following format:

file_dir/
├── train_dir/
│    ├── Normal/
│    │     ├── image1.jpg
│    │     ├── image2.jpg
│    │     └── ...
│    ├── Pneumonia-Bacterial/
│    │     ├── image1.jpg
│    │     ├── image2.jpg
│    │     └── ...
│    └── Pneumonia-Viral/
│            ├── image1.jpg
│            ├── image2.jpg
│            └── ...
│
└── test_dir/
    ├── Normal/
    │     ├── image1.jpg
    │     └── ...
    ├── Pneumonia-Bacterial/
    │     ├── image1.jpg
    │     ├── image2.jpg
    │     └── ...
    └── Pneumonia-Viral/
           ├── image1.jpg
           ├── image2.jpg
           └── ...           
"""

# Side length (pixels) every image is resized to, and the mini-batch size
# shared by all data loaders in this notebook.
img_size = 512
batch_size = 4

# Training pipeline: resize, random horizontal flip, random rotation of up to
# 15 degrees, tensor conversion, then per-channel normalization.
transform_train = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Evaluation pipeline: the same resize and normalization, without augmentation.
transform_val = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Model and Function Definition

# model definition
def create_model(num_classes=3):
    """Build a pretrained DenseNet201 with a fresh classification head.

    The stock classifier is replaced by
    ``Linear -> ReLU -> Dropout(0.5) -> Linear(num_classes)``.

    Args:
        num_classes: Number of output logits. Defaults to 3, matching the
            three class folders described in the data-preprocessing cell.

    Returns:
        The model, moved to CUDA when available and to the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `pretrained=True` is deprecated in recent torchvision releases in favour
    # of the `weights=` enum; fall back to the legacy flag on older versions
    # that do not ship the enum. Both select the same ImageNet checkpoint.
    weights_enum = getattr(models, "DenseNet201_Weights", None)
    if weights_enum is not None:
        model = models.densenet201(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.densenet201(pretrained=True)

    num_features = model.classifier.in_features
    hidden_features = num_features // 2
    model.classifier = nn.Sequential(
        nn.Linear(num_features, hidden_features),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(hidden_features, num_classes),
    )
    return model.to(device)

# Grad-CAM visualization function definition
def grad_cam(model, img_tensor, target_layer):
    """Compute a Grad-CAM heatmap for the model's top predicted class.

    The model is switched to eval mode, run on ``img_tensor``, and the logit
    of the arg-max class of the first sample is back-propagated. The gradients
    arriving at ``target_layer``'s output are averaged per channel and used to
    weight that layer's activations.

    Args:
        model: A ``torch.nn.Module`` classifier returning ``(batch, classes)``
            logits.
        img_tensor: Input batch; expected to contain a single image. It is
            moved to the device of the model's parameters.
        target_layer: Sub-module of ``model`` whose output is explained. Its
            output is expected to be 4-D ``(batch, channels, height, width)``.

    Returns:
        A 2-D ``numpy.ndarray`` with the spatial size of the target layer's
        output, scaled so the maximum is 1, or all zeros when no location has
        a positive response.

    Raises:
        RuntimeError: If ``target_layer`` was not executed during the forward
            pass or no gradient reached it.
    """
    model.eval()

    # Follow the model's own device instead of depending on a notebook-level
    # global, so the function also works on a freshly loaded model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        img_tensor = img_tensor.to(first_param.device)

    captured = {}

    def save_gradient(grad):
        captured["gradients"] = grad.detach()

    def forward_hook(module, inputs, output):
        captured["features"] = output
        # A tensor hook replaces the deprecated Module.register_backward_hook
        # and still works when a later op modifies this output in place.
        if output.requires_grad:
            output.register_hook(save_gradient)

    handle_forward = target_layer.register_forward_hook(forward_hook)
    try:
        # enable_grad lets callers invoke this from inside torch.no_grad().
        with torch.enable_grad():
            output = model(img_tensor)
            class_idx = int(output[0].argmax())
            score = output[0, class_idx]
            model.zero_grad()
            score.backward()
    finally:
        # Always detach the hook, even if the forward/backward pass raised,
        # so repeated calls never stack stale hooks on the layer.
        handle_forward.remove()

    features = captured.get("features")
    gradients = captured.get("gradients")
    if features is None or gradients is None:
        raise RuntimeError(
            "grad_cam: target_layer produced no activations/gradients; "
            "make sure it is part of the model's forward pass."
        )

    # Channel weights = gradients averaged over batch and spatial positions.
    pooled_gradients = gradients.mean(dim=(0, 2, 3))
    # Weight a detached copy via broadcasting; the original mutated the
    # autograd-tracked activation in place, one channel at a time.
    weighted = features.detach()[0] * pooled_gradients[:, None, None]

    heatmap = weighted.cpu().numpy().mean(axis=0)
    heatmap = np.maximum(heatmap, 0)
    max_value = np.max(heatmap)
    if max_value > 0:
        heatmap = heatmap / max_value
    else:
        heatmap = np.zeros_like(heatmap)

    return heatmap

def plot_grad_cam(img_path, model, target_layer):
    """Show an image next to its Grad-CAM overlay.

    The image is read from disk, resized to ``img_size`` x ``img_size``,
    preprocessed with ``transform_val`` and passed to ``grad_cam``. The
    resulting heatmap is colour-mapped and blended (60 % image, 40 % heatmap)
    over the resized image.

    Args:
        img_path: Path of the image file to explain.
        model: Classifier passed through to ``grad_cam``.
        target_layer: Layer passed through to ``grad_cam``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    bgr = cv2.resize(bgr, (img_size, img_size))

    # Work in RGB from here on: PIL, the model transform and matplotlib all
    # expect RGB channel order.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_tensor = transform_val(Image.fromarray(rgb)).unsqueeze(0)

    heatmap = grad_cam(model, img_tensor, target_layer)
    heatmap = cv2.resize(heatmap, (img_size, img_size))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so both blend inputs are RGB.
    # (Previously an RGB image was blended with a BGR heatmap and then
    # channel-swapped again, which displayed the photo with swapped colours.)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    superimposed_img = cv2.addWeighted(rgb, 0.6, heatmap, 0.4, 0)

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(rgb)
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.title("Grad-CAM")
    plt.imshow(superimposed_img)
    plt.axis('off')
    plt.show()

In [None]:
# Model Training

# If you want to save the model, uncomment the following lines:

# model_path_dir = "your-model-path-directory"  # e.g., '/home/user/model'
# current_time = time.strftime("%Y%m%d_%H%M%S")
# model_name = f"densenet201_time{current_time}.pth"
# model_path = os.path.join(model_path_dir, model_name)

n_splits = 5
epochs = 4  # epochs trained per fold

model = create_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Two views of the same image folder, one per transform. A single shared
# ImageFolder cannot serve both subsets: assigning `.transform` for the
# validation subset would overwrite the training transform and silently
# disable augmentation. ImageFolder lists files in sorted order, so the
# same index refers to the same image in both views.
train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(train_dir, transform=transform_val)

kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
folds = list(kf.split(range(len(train_dataset))))

# NOTE(review): the same model and optimizer are carried across folds, so from
# the second fold on the validation images have already been trained on and
# the reported validation metrics are optimistic. Kept as-is because later
# cells evaluate this cumulatively trained `model`; re-create the model per
# fold if unbiased cross-validation scores are required.
for fold, (train_idx, val_idx) in enumerate(folds):
    print(f"  Fold {fold+1}/{n_splits}---------------------------------------------------------------------------")
    train_subset = Subset(train_dataset, train_idx)
    val_subset = Subset(val_dataset, val_idx)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
    # Keep the final partial batch so every validation image is scored.
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
        # max(..., 1) guards against an empty loader (fold smaller than a batch).
        train_acc = train_correct / max(train_total, 1)

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        val_acc = val_correct / max(val_total, 1)

        print(f"   Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/max(len(train_loader), 1):.4f}, "
              f"Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss/max(len(val_loader), 1):.4f}, Val Acc: {val_acc:.4f}")

    torch.cuda.empty_cache()

# If you want to save the model, uncomment the following line:
# torch.save(model.state_dict(), model_path)

In [None]:
# Model Testing
# Evaluates `model` on the held-out test folder and collects the labels,
# predictions and input tensors needed by the result-viewing cells below.

# To test a previously saved model instead of the one trained above, uncomment
# the following lines. The architecture built by create_model() must match the
# one the checkpoint was saved from.
# model_path_dir = "your-model-path-directory"  # e.g., '/home/user/model'
# model_name = "your-model-name.pth"
# model_path = os.path.join(model_path_dir, model_name)
# model = create_model()
# model.load_state_dict(torch.load(model_path, map_location=device))

# `test_dir` comes from the Data Preprocessing cell; it is not redefined here
# so that the configured path is not overwritten by a placeholder.
test_dataset = datasets.ImageFolder(test_dir, transform=transform_val)
# Keep the final partial batch: dropping it would exclude test images from the
# accuracy and from the confusion matrix.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

correct = 0
total = 0

all_labels = []
all_predictions = []

# Images the model got wrong (label != prediction).
misclassified_images = []
misclassified_labels = []
misclassified_predictions = []

# Images the model got right (label == prediction).
trueclassified_images = []
trueclassified_labels = []
trueclassified_predictions = []

model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        batch_labels = labels.cpu().tolist()
        batch_predictions = predicted.cpu().tolist()
        all_labels.extend(batch_labels)
        all_predictions.extend(batch_predictions)

        for i, (label, prediction) in enumerate(zip(batch_labels, batch_predictions)):
            # The two branches were previously swapped, filing correct
            # predictions under "misclassified" and vice versa.
            if label == prediction:
                trueclassified_images.append(inputs[i].cpu())
                trueclassified_labels.append(label)
                trueclassified_predictions.append(prediction)
            else:
                misclassified_images.append(inputs[i].cpu())
                misclassified_labels.append(label)
                misclassified_predictions.append(prediction)

# Guard against an empty test folder.
test_acc = 100. * correct / total if total > 0 else 0.0
print(f"Test Accuracy: {test_acc:.2f}%")


In [None]:
# Viewing Test Results-confusion matrix and recall

class_names = test_dataset.classes
# Fix the label set so the matrix is always len(classes) x len(classes), even
# if some class never appears among the labels or predictions.
cm = confusion_matrix(all_labels, all_predictions, labels=list(range(len(class_names))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

# Pneumonia-vs-normal recall (sensitivity): of all pneumonia images (every
# class except "Normal"), the share NOT predicted as "Normal". Confusions
# between pneumonia sub-types count as detected pneumonia; the earlier formula
# dropped them from both numerator and denominator.
# Falls back to index 0 when no folder is literally named "Normal".
normal_idx = class_names.index("Normal") if "Normal" in class_names else 0
pneumonia_rows = np.delete(cm, normal_idx, axis=0)
fn = pneumonia_rows[:, normal_idx].sum()
tp = pneumonia_rows.sum() - fn
recall = tp / (tp + fn) * 100 if (tp + fn) > 0 else 0
print(f"Recall: {recall:.2f}%")

# Per-class recall: diagonal count divided by the number of true samples.
for name, hits, support in zip(class_names, np.diag(cm), cm.sum(axis=1)):
    class_recall = hits / support * 100 if support > 0 else 0
    print(f"  Recall[{name}]: {class_recall:.2f}%")

# Draw on an explicit axes; calling plt.figure() first left an empty extra
# figure because disp.plot() creates its own when no axes is given.
fig, ax = plt.subplots(figsize=(8, 8))
disp.plot(cmap=plt.cm.Blues, values_format='d', ax=ax)
ax.set_title("Confusion Matrix")
plt.show()

In [None]:
# Viewing Test Results-Grad-CAM visualization for misclassified images or trueclassified images
# Averages the Grad-CAM heatmaps of a whole group of test images into one map.

# Choose the group to summarize: misclassified_images or trueclassified_images.
images_to_explain = trueclassified_images

# Last feature layer of torchvision's DenseNet (final batch norm before the
# classifier). The previous `model.resnet.layer4[1].conv2` does not exist on
# the DenseNet201 built by create_model() and raised AttributeError.
target_layer = model.features.norm5

combined_heatmap = np.zeros((img_size, img_size))
for img_tensor in images_to_explain:
    heatmap = grad_cam(model, img_tensor.unsqueeze(0).to(device), target_layer)
    combined_heatmap += cv2.resize(heatmap, (img_size, img_size))

# Guard both divisions: an empty group or an all-zero map would otherwise
# produce NaNs.
if len(images_to_explain) > 0:
    combined_heatmap /= len(images_to_explain)
peak = np.max(combined_heatmap)
if peak > 0:
    combined_heatmap = combined_heatmap / peak
combined_heatmap = np.uint8(255 * combined_heatmap)
combined_heatmap = cv2.applyColorMap(combined_heatmap, cv2.COLORMAP_JET)

# Plain white background. The earlier `(img * 255).astype(np.uint8)` multiplied
# a uint8 array already holding 255, which wraps around to 1 (almost black).
img = np.full((img_size, img_size, 3), 255, dtype=np.uint8)
superimposed_img = cv2.addWeighted(img, 0.6, combined_heatmap, 0.4, 0)

plt.figure(figsize=(10, 10))
plt.title("Combined Grad-CAM Heatmap")
# The blend is in OpenCV's BGR order (from applyColorMap); convert for display.
plt.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()

In [None]:
# Viewing Test Results-Grad-CAM visualization for individual misclassified images
# (swap in the trueclassified_* lists below to inspect correct predictions)

# left is input image, right is Grad-CAM result
max_num = 10  # maximum number of images shown at once
page = 0      # zero-based page index; increase to see the next `max_num` images
class_names = test_dataset.classes

# Last feature layer of torchvision's DenseNet, whose activations are explained.
target_layer = model.features.norm5

# Constants used to undo transform_val's normalization for display.
norm_mean = np.array([0.485, 0.456, 0.406])
norm_std = np.array([0.229, 0.224, 0.225])

# Clamp the page to the available images. The earlier fixed offset
# `j = i + max_num*1` skipped the first page and ran past the end of the list.
start = page * max_num
indices = range(start, min(start + max_num, len(misclassified_images)))
num_images_to_show = len(indices)

if num_images_to_show == 0:
    print("No images to show for this page.")
else:
    # Three inches of height per row (the previous fixed size was 30 for 10 rows).
    plt.figure(figsize=(15, 3 * num_images_to_show))
    for i, j in enumerate(indices):
        img = misclassified_images[j].permute(1, 2, 0).numpy()
        img = np.clip(img * norm_std + norm_mean, 0, 1)
        img_uint8 = (img * 255).astype(np.uint8)  # RGB

        heatmap = grad_cam(model, misclassified_images[j].unsqueeze(0).to(device), target_layer)
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # applyColorMap returns BGR while the image is RGB; convert before
        # blending so the overlay does not end up with swapped channels.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        superimposed_img = cv2.addWeighted(img_uint8, 0.6, heatmap, 0.4, 0)

        plt.subplot(num_images_to_show, 2, 2 * i + 1)
        plt.imshow(img)
        plt.title(f"{j}:True: {class_names[misclassified_labels[j]]}, Pred: {class_names[misclassified_predictions[j]]}")
        plt.axis('off')

        plt.subplot(num_images_to_show, 2, 2 * i + 2)
        plt.imshow(superimposed_img)
        plt.title("Grad-CAM")
        plt.axis('off')

    plt.tight_layout()
    plt.show()