In [None]:
import os
import tensorflow as tf
from sklearn.model_selection import KFold
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
import random
from torchvision import datasets, transforms
from sklearn.model_selection import KFold
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import functional as F
import timm
import time
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import cv2

The code blocks are organized in the following sequence: 

Data Preprocessing → Model and Function Definition → Model Training → Model Testing → Viewing Test Results. 

This code uses a pretrained EfficientNet-B0 (loaded through `timm`) as the backbone and includes implementations of custom feature extractors, Gabor filters, and Grad-CAM visualization modules. How each of these is implemented is described in the code comments.


In [None]:
# Data Preprocessing
# Dataset root and its train/test sub-directories (replace the placeholders).
file_dir = r'your-dataset-path-directory'  # e.g., '/home/user/dataset'
train_dir, test_dir = (
    os.path.join(file_dir, subset_name)
    for subset_name in ('your-trainsubset-name', 'your-testsubset-name')
)
"""
The dataset is stored in the following format:

file_dir/
├── train_dir/
│    ├── Normal/
│    │     ├── image1.jpg
│    │     ├── image2.jpg
│    │     └── ...
│    ├── Pneumonia-Bacterial/
│    │     ├── image1.jpg
│    │     ├── image2.jpg
│    │     └── ...
│    └── Pneumonia-Viral/
│            ├── image1.jpg
│            ├── image2.jpg
│            └── ...
│
└── test_dir/
    ├── Normal/
    │     ├── image1.jpg
    │     └── ...
    ├── Pneumonia-Bacterial/
    │     ├── image1.jpg
    │     ├── image2.jpg
    │     └── ...
    └── Pneumonia-Viral/
           ├── image1.jpg
           ├── image2.jpg
           └── ...           
"""

# Square side length every image is resized to, and the mini-batch size.
img_size, batch_size = 512, 16

# Pipeline pieces shared by the training and validation transforms.
_resize_step = transforms.Resize((img_size, img_size))
_tensor_steps = [
    transforms.ToTensor(),
    # Per-channel mean / std (the commonly used ImageNet statistics).
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
# Random augmentation, applied to training images only.
_augment_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
]

transform_train = transforms.Compose([_resize_step, *_augment_steps, *_tensor_steps])
transform_val = transforms.Compose([_resize_step, *_tensor_steps])

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Model and Function Definition (① Default, ② CNN Block, ③ Gabor filters + gabor_conv, ④ Gabor filters + CNN Block)
# All four model variants are listed below; variants ②–④ are commented out by default.
# Uncomment exactly one create_model() definition and keep the other three commented out.

# model definition
# ① Default EfficientNet-B0
def create_model(num_classes=3, pretrained=True):
    """Build the default EfficientNet-B0 classifier (variant ①) on `device`.

    The timm `efficientnet_b0` backbone is kept as-is; only its classifier is
    replaced by a two-layer MLP head (Linear -> ReLU -> Dropout(0.5) -> Linear).

    Args:
        num_classes (int): Number of output logits. Defaults to 3, matching
            the three class folders described in the data-preprocessing cell.
        pretrained (bool): Whether timm should load pretrained weights.
            Defaults to True (the previous hard-coded behaviour).

    Returns:
        nn.Module: The model, already moved to the module-level `device`.
    """
    class CustomEfficientNet(nn.Module):
        """EfficientNet-B0 backbone with a custom classification head."""

        def __init__(self, num_classes=3):
            super(CustomEfficientNet, self).__init__()
            # The attribute name `efficientnet` is part of the state_dict keys,
            # so it must stay stable for saved checkpoints to load.
            self.efficientnet = timm.create_model('efficientnet_b0', pretrained=pretrained)
            in_features = self.efficientnet.classifier.in_features
            self.efficientnet.classifier = nn.Sequential(
                nn.Linear(in_features, in_features // 2),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(in_features // 2, num_classes)
            )

        def forward(self, x):
            """Return the class logits for an image batch `x`."""
            return self.efficientnet(x)

    model = CustomEfficientNet(num_classes=num_classes)
    return model.to(device)

# ② CNN Block
# def create_model():
#     class CustomEfficientNet(nn.Module):
#         def __init__(self, num_classes=3):
#             super(CustomEfficientNet, self).__init__()
#             self.feature_extractor = nn.Sequential(
#                 nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
#                 nn.BatchNorm2d(64),
#                 nn.ReLU(inplace=True),
#                 nn.MaxPool2d(kernel_size=2, stride=2),
#                 nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
#                 nn.BatchNorm2d(128),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(128, 3, kernel_size=1),
#                 nn.AdaptiveAvgPool2d((224, 224))  
#             )
#             self.efficientnet = timm.create_model('efficientnet_b0', pretrained=True)
#             in_features = self.efficientnet.classifier.in_features
#             self.efficientnet.classifier = nn.Sequential(
#                 nn.Linear(in_features, in_features // 2),
#                 nn.ReLU(),
#                 nn.Dropout(0.5),
#                 nn.Linear(in_features // 2, num_classes)
#             )
#         def forward(self, x):
#             x = self.feature_extractor(x)
#             x = self.efficientnet(x)  
#             return x
#     model = CustomEfficientNet(num_classes=3)
#     return model.to(device)

# ③ gabor filters + gabor_conv
# def create_model():
#     class CustomEfficientNet(nn.Module):
#         def __init__(self, num_classes=3):
#             super(CustomEfficientNet, self).__init__()
#             self.gabor_conv = nn.Conv2d(8, 1, kernel_size=3, padding=1) 
#             self.efficientnet = timm.create_model('efficientnet_b0', pretrained=True)
#             self.efficientnet.conv_stem = nn.Conv2d(
#                 4, 32, kernel_size=3, stride=2, padding=1, bias=False
#             )
#             in_features = self.efficientnet.classifier.in_features
#             self.efficientnet.classifier = nn.Sequential(
#                 nn.Linear(in_features, in_features // 2),
#                 nn.ReLU(),
#                 nn.Dropout(0.5),
#                 nn.Linear(in_features // 2, num_classes)
#             )
#         def forward(self, x, gabor_features):
#             gabor_features = self.gabor_conv(gabor_features)  
#             x = torch.cat((x, gabor_features), dim=1) 
#             x = self.efficientnet(x)
#             return x
#     model = CustomEfficientNet(num_classes=3)
#     return model.to(device)

# ④ gabor filters + CNN Block
# def create_model():
#     class CustomEfficientNet(nn.Module):
#         def __init__(self, num_classes=3):
#             super(CustomEfficientNet, self).__init__()
#              # Custom feature-extraction module
#             self.feature_extractor = nn.Sequential(
#                 nn.Conv2d(9, 64, kernel_size=3, stride=2, padding=1),  
#                 nn.BatchNorm2d(64),
#                 nn.ReLU(inplace=True),
#                 nn.MaxPool2d(kernel_size=2, stride=2), 
#                 nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  
#                 nn.BatchNorm2d(128),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(128, 3, kernel_size=1), 
#                 nn.AdaptiveAvgPool2d((224, 224))  
#             )
#             self.efficientnet = timm.create_model('efficientnet_b0', pretrained=True)
#             in_features = self.efficientnet.classifier.in_features
#             self.efficientnet.classifier = nn.Sequential(
#                 nn.Linear(in_features, in_features // 2),
#                 nn.ReLU(),
#                 nn.Dropout(0.5),
#                 nn.Linear(in_features // 2, num_classes)
#             )
#         def forward(self, x):
#             x = self.feature_extractor(x)
#             x = self.efficientnet(x)
#             return x
#     model = CustomEfficientNet(num_classes=3)
#     return model.to(device)

# 8 Gabor filters definition
def gabor_filter():
    """Create the bank of eight Gabor kernels used by the Gabor model variants.

    Each kernel pairs one spatial frequency with one base sigma. All kernels
    share the same window size (21 x 21), orientation (pi / 2), aspect ratio
    (1 / sqrt(2)) and phase offset (0), and are requested as CV_32F.

    Returns:
        list: Eight kernels as returned by `cv2.getGaborKernel`, ordered from
        the highest frequency to the lowest.
    """
    frequencies = (0.250, 0.177, 0.125, 0.088, 0.062, 0.044, 0.031, 0.022)
    base_sigmas = (
        4 * np.sqrt(2 * np.pi), 8 * np.pi, 8 * np.sqrt(2 * np.pi), 16 * np.pi,
        16 * np.sqrt(2 * np.pi), 32 * np.pi, 32 * np.sqrt(2 * np.pi), 64 * np.pi,
    )
    window = (21, 21)
    orientation = np.pi / 2
    aspect_ratio = 1 / np.sqrt(2)

    return [
        cv2.getGaborKernel(
            window,
            0.15 * base_sigma,   # sigma actually used is 15% of the base value
            orientation,
            1 / frequency,       # wavelength is the reciprocal of the frequency
            aspect_ratio,
            0,
            ktype=cv2.CV_32F,
        )
        for frequency, base_sigma in zip(frequencies, base_sigmas)
    ]

def gabor_filter_batch(inputs, gabor_filters, batch_size=None):
    """
    Apply Gabor filters to a batch of images.

    Parameters:

    inputs (torch.Tensor): Input image tensor with shape [N, 3, H, W].
    gabor_filters (list): A list of Gabor kernels (8 in this notebook).
    batch_size (int or None): Maximum number of images to process. ``None``
        (default) processes every image in ``inputs``. A value larger than
        ``N`` is clamped to ``N``, so a short final batch no longer raises
        an IndexError.

    Returns:

    torch.Tensor: float32 Gabor feature tensor on the CPU with shape
    [n, len(gabor_filters), H, W], where n is the number of images processed.
    or
    torch.Tensor: Gabor feature tensor with shape [n, len(gabor_filters) + 1, H, W]
    (concatenated with the original gray image) if the optional line below
    is uncommented.
    """
    num_available = inputs.shape[0]
    num_images = num_available if batch_size is None else min(batch_size, num_available)

    # Nothing to filter: return an empty tensor with the expected trailing dims
    # instead of letting np.stack fail on an empty list.
    if num_images <= 0:
        return torch.empty(
            (0, len(gabor_filters), inputs.shape[2], inputs.shape[3]),
            dtype=torch.float32,
        )

    # detach() lets the function accept tensors that are part of an autograd graph.
    images = inputs.detach().cpu()
    gabor_features = []

    for i in range(num_images):
        # CHW -> HWC; OpenCV expects a contiguous array.
        img = np.ascontiguousarray(images[i].permute(1, 2, 0).numpy())
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        gabor_responses = np.stack(
            [cv2.filter2D(img_gray, cv2.CV_32F, kernel) for kernel in gabor_filters],
            axis=0,
        )  # [8, H, W] default 8 channels
        # If you want to concatenate the original gray image with Gabor responses, uncomment the following line:
        # gabor_responses = np.concatenate((np.expand_dims(img_gray, axis=0), gabor_responses), axis=0)  # [9, H, W]
        gabor_features.append(gabor_responses)

    gabor_features = np.stack(gabor_features, axis=0)
    return torch.tensor(gabor_features, dtype=torch.float32)

# Grad-CAM visualization function definition
def grad_cam(model, img_tensor, target_layer):
    """Compute a Grad-CAM heatmap for the class predicted for the first image.

    The activation of `target_layer` and the gradient of the predicted-class
    score with respect to that activation are captured with hooks. The
    activation channels are weighted by the spatially averaged gradients,
    averaged over channels, clipped at zero and scaled to [0, 1].

    Args:
        model (nn.Module): Model called as `model(img_tensor)`. It is switched
            to eval mode and its parameter gradients are zeroed.
        img_tensor (torch.Tensor): Input batch; the heatmap is computed for
            element 0. It is moved to the model's device.
        target_layer (nn.Module): Layer inside `model` whose output is a
            4-D tensor [N, C, h, w].

    Returns:
        np.ndarray: 2-D heatmap of shape [h, w] with values in [0, 1]
        (all zeros when no positive evidence is found).

    Raises:
        RuntimeError: If `target_layer` was never called during the forward
            pass or no gradient reached it.
    """
    model.eval()

    # Run on the model's own device so input and weights always match.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = img_tensor.device
    img_tensor = img_tensor.to(model_device)

    captured = {}

    def forward_hook(module, inputs, output):
        # Keep a detached copy of the activation so later arithmetic never
        # touches the autograd graph, and attach a tensor hook to receive the
        # gradient w.r.t. this activation (replaces the deprecated
        # Module.register_backward_hook).
        captured["features"] = output.detach()
        output.register_hook(
            lambda grad: captured.__setitem__("gradients", grad.detach())
        )

    handle_forward = target_layer.register_forward_hook(forward_hook)
    try:
        # enable_grad makes the function usable even inside a no_grad block.
        with torch.enable_grad():
            output = model(img_tensor)
            pred_class = int(output[0].argmax())
            score = output[0, pred_class]
            model.zero_grad()
            score.backward()
    finally:
        # Always detach the hook, even when the forward/backward pass fails.
        handle_forward.remove()

    if "features" not in captured or "gradients" not in captured:
        raise RuntimeError(
            "grad_cam: target_layer produced no activation/gradient; "
            "make sure it is part of the model's forward pass."
        )

    # One weight per channel: gradient averaged over batch and spatial dims.
    pooled_gradients = torch.mean(captured["gradients"], dim=[0, 2, 3])
    weighted = captured["features"][0] * pooled_gradients[:, None, None]

    heatmap = weighted.mean(dim=0).cpu().numpy()
    heatmap = np.maximum(heatmap, 0)
    max_value = np.max(heatmap)
    if max_value > 0:
        heatmap = heatmap / max_value
    else:
        heatmap = np.zeros_like(heatmap)

    return heatmap

def plot_grad_cam(img_path, model, target_layer):
    """Load an image from disk and show it next to its Grad-CAM overlay.

    Args:
        img_path (str): Path of the image file to visualise.
        model (nn.Module): Model passed on to `grad_cam`.
        target_layer (nn.Module): Layer passed on to `grad_cam`.

    Raises:
        FileNotFoundError: If OpenCV cannot read `img_path`.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    img_bgr = cv2.resize(img_bgr, (img_size, img_size))
    # Work in RGB from here on: PIL, the transforms and matplotlib all expect RGB.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_tensor = transform_val(Image.fromarray(img_rgb)).unsqueeze(0)

    heatmap = grad_cam(model, img_tensor, target_layer)
    heatmap = cv2.resize(heatmap, (img_rgb.shape[1], img_rgb.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so image and heatmap share one channel order.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    superimposed_img = cv2.addWeighted(img_rgb, 0.6, heatmap, 0.4, 0)

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(img_rgb)
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.title("Grad-CAM")
    plt.imshow(superimposed_img)
    plt.axis('off')
    plt.show()

In [None]:
# Model Training (different models' training codes are not the same)
# make sure to uncomment the model you want to train and comment out the others

# If you want to save the model, uncomment the following lines:

# model_path_dir = f"your-model-path-directory"  # e.g., '/home/user/model'
# current_time = time.strftime("%Y%m%d_%H%M%S")
# model_name = f"efficientnet_b0_time{current_time}.pth"
# model_path = os.path.join(model_path_dir, model_name)

model = create_model()
# Two ImageFolder views over the same directory, one per transform.
# Sharing a single dataset between the train and validation Subsets made the
# second `.transform = ...` assignment overwrite the first, so training ran
# without augmentation. Both views list the files in the same order, so the
# KFold indices are valid for either of them.
train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(train_dir, transform=transform_val)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
folds = list(kf.split(range(len(train_dataset))))
criterion = nn.CrossEntropyLoss()
epochs = 4
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# NOTE(review): the same model and optimizer are reused across all folds, so
# from fold 2 onwards the validation images have already been trained on and
# the per-fold validation scores are optimistic. Re-create the model/optimizer
# inside the fold loop if independent cross-validation estimates are needed.
for fold, (train_idx, val_idx) in enumerate(folds):
    print(f"  Fold {fold+1}/5---------------------------------------------------------------------------")
    train_subset = Subset(train_dataset, train_idx)
    val_subset = Subset(val_dataset, val_idx)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
    # Keep every validation sample: drop_last=False so the last partial batch is scored too.
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, drop_last=False)
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        # ① Default EfficientNet-B0 & ② CNN Block
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

        # ③ Gabor filters + gabor_conv
        # for inputs, labels in train_loader:
        #     gabor_features = gabor_filter_batch(inputs, gabor_filter(), batch_size=inputs.size(0)) # Here gabor features tensor with shape [batch_size, 8, H, W]
        #     inputs, gabor_features, labels = inputs.to(device).float(), gabor_features.to(device).float(), labels.to(device)
        #     optimizer.zero_grad()
        #     outputs = model(inputs, gabor_features)

        # ④ Gabor filters + CNN Block
        # for inputs, labels in train_loader:
        #     gabor_features = gabor_filter_batch(inputs, gabor_filter(), batch_size=inputs.size(0)) # Here gabor features tensor with shape [batch_size, 9, H, W]
        #     inputs, labels = gabor_features.to(device).float(), labels.to(device)
        #     optimizer.zero_grad()
        #     outputs = model(inputs)

            # Shared tail of the training step (belongs to whichever loop header is active above).
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
        train_acc = train_correct / train_total

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            # ① Default EfficientNet-B0 & ② CNN Block
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

            # # ③ Gabor filters + gabor_conv
            # for inputs, labels in val_loader:
            #     gabor_features = gabor_filter_batch(inputs, gabor_filter(), batch_size=inputs.size(0)) # Here gabor features tensor with shape [batch_size, 8, H, W]
            #     inputs, gabor_features, labels = inputs.to(device).float(), gabor_features.to(device).float(), labels.to(device)
            #     outputs = model(inputs, gabor_features)
            #     loss = criterion(outputs, labels)
            #     val_loss += loss.item()

            # ④ Gabor filters + CNN Block
            # for inputs, labels in val_loader:
            #     gabor_features = gabor_filter_batch(inputs, gabor_filter(), batch_size=inputs.size(0)) # Here gabor features tensor with shape [batch_size, 9, H, W]
            #     inputs, labels = gabor_features.to(device).float(), labels.to(device)
            #     outputs = model(inputs)
            #     loss = criterion(outputs, labels)
            #     val_loss += loss.item()

                # Shared tail of the validation step (belongs to the active loop above).
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
            val_acc = val_correct / val_total
        print(f"   Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, "
            f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_correct/val_total:.4f}")

    # Release cached GPU memory between folds.
    torch.cuda.empty_cache()

# If you want to save the model, uncomment the following line:
# torch.save(model.state_dict(), model_path)

In [None]:
# Model Testing  (different models' testing codes are not the same)
# make sure to uncomment the model you want to test and comment out the others

# If you want to test the saved model, uncomment the following lines:
# make sure the model you want to test is the same as the one you trained
# you can uncomment corresponding model definition in "Model and Function Definition" part and run it again
# model_path_dir = f"your-model-path-directory"  # e.g., '/home/user/model'
# model_name = f"your-model-name.pth"
# model_path = os.path.join(model_path_dir, model_name)
# model = create_model()
# model.load_state_dict(torch.load(model_path, map_location=device))

file_dir = r'your-dataset-path-directory'  # e.g., '/home/user/dataset'
test_dir = os.path.join(file_dir, 'your-testsubset-name')
test_dataset = datasets.ImageFolder(test_dir, transform=transform_val)
# drop_last=False: every test image must count towards the reported accuracy.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

correct = 0
total = 0

# Ground truth / prediction for every test image (used for the confusion matrix).
all_labels = []
all_predictions = []

# Images the model got wrong, with their true and predicted class indices.
misclassified_images = []
misclassified_labels = []
misclassified_predictions = []

# Images the model got right, with their true and predicted class indices.
trueclassified_images = []
trueclassified_labels = []
trueclassified_predictions = []

model.eval()
with torch.no_grad():
    # ① Default EfficientNet-B0 & ② CNN Block
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)

    # ③ Gabor filters + gabor_conv
    # for inputs, labels in test_loader:
    #     gabor_features = gabor_filter_batch(inputs, gabor_filter(), batch_size=inputs.size(0))
    #     inputs, gabor_features, labels = inputs.to(device).float(), gabor_features.to(device).float(), labels.to(device)
    #     outputs = model(inputs, gabor_features)

    # ④ Gabor filters + CNN Block
    # for inputs, labels in test_loader:
    #     gabor_features = gabor_filter_batch(inputs, gabor_filter(), batch_size=inputs.size(0))
    #     inputs, labels = gabor_features.to(device).float(), labels.to(device)
    #     outputs = model(inputs)

        # Shared tail of the test step (belongs to whichever loop header is active above).
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Move the batch to the CPU once instead of once per element.
        inputs_cpu = inputs.cpu()
        labels_cpu = labels.cpu()
        predicted_cpu = predicted.cpu()
        all_labels.extend(labels_cpu.numpy())
        all_predictions.extend(predicted_cpu.numpy())
        for i in range(len(labels_cpu)):
            true_label = labels_cpu[i].item()
            pred_label = predicted_cpu[i].item()
            if true_label == pred_label:
                # Correct prediction -> "trueclassified" (the two branches used to be swapped).
                trueclassified_images.append(inputs_cpu[i])
                trueclassified_labels.append(true_label)
                trueclassified_predictions.append(pred_label)
            else:
                misclassified_images.append(inputs_cpu[i])
                misclassified_labels.append(true_label)
                misclassified_predictions.append(pred_label)
# Guard against an empty test set.
test_acc = 100. * correct / total if total > 0 else 0.0
print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
# Viewing Test Results-confusion matrix and recall

# Pass the full label list so the matrix keeps one row/column per class even
# if some class never appears in the labels or predictions.
class_indices = list(range(len(test_dataset.classes)))
cm = confusion_matrix(all_labels, all_predictions, labels=class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=test_dataset.classes)

# Pneumonia-vs-Normal recall (sensitivity). Rows are true classes, columns are
# predictions; index 0 is assumed to be "Normal" because ImageFolder sorts the
# class folders alphabetically (verify with test_dataset.classes).
#   tp: pneumonia images predicted as either pneumonia type
#   fn: pneumonia images predicted as Normal
# Bacterial<->viral confusions used to be left out of both tp and fn, which
# silently dropped those samples from the metric.
tp = cm[1:, 1:].sum()
fn = cm[1:, 0].sum()
recall = tp / (tp + fn)*100 if (tp + fn) > 0 else 0
print(f"Recall: {recall:.2f}%")

# Draw on an explicit Axes: disp.plot() without `ax` opens its own figure,
# which left an empty 8x8 figure behind.
fig, ax = plt.subplots(figsize=(8, 8))
disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d')
ax.set_title("Confusion Matrix")
plt.show()

In [None]:
# Viewing Test Results-Grad-CAM visualization for misclassified images or trueclassified images

# Average the Grad-CAM heatmaps of a whole group of test images.
selected_images = trueclassified_images  # choose misclassified_images or trueclassified_images

combined_heatmap = np.zeros((img_size, img_size))
# Last convolution of the timm EfficientNet backbone. The model built by
# create_model() has no `resnet` attribute, so the previous
# `model.resnet.layer4[1].conv2` raised AttributeError.
target_layer = model.efficientnet.conv_head
for img_tensor in selected_images:
    heatmap = grad_cam(model, img_tensor.unsqueeze(0).to(device), target_layer)
    heatmap_resized = cv2.resize(heatmap, (img_size, img_size))
    combined_heatmap += heatmap_resized

# Guard both divisions: an empty selection or an all-zero heatmap would
# otherwise produce NaNs.
if len(selected_images) > 0:
    combined_heatmap /= len(selected_images)
peak = np.max(combined_heatmap)
if peak > 0:
    combined_heatmap = combined_heatmap / peak
combined_heatmap = np.uint8(255 * combined_heatmap)
combined_heatmap = cv2.applyColorMap(combined_heatmap, cv2.COLORMAP_JET)

# White background. It is already uint8 in [0, 255]; multiplying by 255 again
# overflowed uint8 and turned the background almost black.
img = np.full((img_size, img_size, 3), 255, dtype=np.uint8)
superimposed_img = cv2.addWeighted(img, 0.6, combined_heatmap, 0.4, 0)
plt.figure(figsize=(10, 10))
plt.title("Combined Grad-CAM Heatmap")
# The colormap output is BGR; convert for matplotlib.
plt.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()

In [None]:
# Viewing Test Results-Grad-CAM visualization for random misclassified&trueclassified images

# left is input image, right is Grad-CAM result
max_num = 10 # choose the most number of images to show at once
page = 0     # 0 shows images 0..max_num-1, 1 shows the next max_num images, ...
class_names = test_dataset.classes
# Last convolution of the timm EfficientNet backbone; defined here as well so
# this cell does not depend on the previous one having been run.
target_layer = model.efficientnet.conv_head

# Clamp the page to the images that actually exist; the old fixed offset of
# max_num indexed past the end of the list whenever fewer than 2*max_num
# images were available.
start = max_num * page
num_images_to_show = max(0, min(max_num, len(misclassified_images) - start))

if num_images_to_show == 0:
    print("No images to show for this page.")
else:
    plt.figure(figsize=(15, 30))
    for i in range(num_images_to_show):
        j = start + i
        img = misclassified_images[j]
        img = img.permute(1, 2, 0).numpy()
        # Undo the Normalize step of transform_val to get displayable RGB in [0, 1].
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        heatmap = grad_cam(model, misclassified_images[j].unsqueeze(0).to(device), target_layer)
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # applyColorMap returns BGR while `img` is RGB; convert before blending
        # so both share one channel order and no swap is needed for display.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        superimposed_img = cv2.addWeighted((img * 255).astype(np.uint8), 0.6, heatmap, 0.4, 0)

        plt.subplot(num_images_to_show, 2, 2 * i + 1)
        plt.imshow(img)
        plt.title(f"{j}:True: {class_names[misclassified_labels[j]]}, Pred: {class_names[misclassified_predictions[j]]}")
        plt.axis('off')

        plt.subplot(num_images_to_show, 2, 2 * i + 2)
        plt.imshow(superimposed_img)
        plt.title("Grad-CAM")
        plt.axis('off')

    plt.tight_layout()
    plt.show()