# 基于ResNet152微调实现胸部X-Ray图像肺炎分类

In [None]:
import os
import time
import kagglehub
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from IPython.display import clear_output

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # last expression of the cell: shows the selected device in the notebook

## 一、下载数据集

In [None]:
# Authenticate with Kaggle: enter the Kaggle username and API token when
# asked, so that the dataset download in the next cell is authorised.
kagglehub.login()

In [None]:
# Download the dataset through kagglehub (it is stored in kagglehub's cache
# directory) and print the location so the files can be found.
dataset_handle = "jtiptj/chest-xray-pneumoniacovid19tuberculosis"
download_path = kagglehub.dataset_download(dataset_handle)
print(download_path)
# Manual follow-up step: copy the downloaded files into data/.

In [None]:
class ChestXrayDataset(Dataset):
    """Folder-per-class image dataset for the four-class chest X-ray data.

    ``data_dir`` is expected to contain one sub-directory per entry of
    ``classes``. Every ``.png`` / ``.jpg`` / ``.jpeg`` file (case-insensitive)
    found directly inside such a sub-directory becomes one
    ``(image_path, class_index)`` sample.
    """

    # File suffixes (compared in lower case) that are treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None):
        """
        Index all image files below ``data_dir``.

        :param data_dir: Root directory of one split (the 'train', 'val' or 'test' folder)
        :param transform: Optional callable applied to each RGB PIL image in ``__getitem__``
        """
        self.data_dir = data_dir
        self.transform = transform
        # NOTE(review): 'TURBERCULOSIS' is used verbatim as a folder name; it
        # presumably matches the spelling of the directory in the downloaded
        # dataset -- confirm against the files on disk before "fixing" it.
        self.classes = ['COVID19', 'NORMAL', 'PNEUMONIA', 'TURBERCULOSIS']
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = self._load_samples()

    def _load_samples(self):
        """Return the list of ``(image_path, class_index)`` pairs.

        Class folders that do not exist are skipped. File names are sorted so
        that the sample order is reproducible across runs and file systems
        (``os.listdir`` returns entries in arbitrary order).
        """
        samples = []
        for cls in self.classes:
            cls_dir = os.path.join(self.data_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            label = self.class_to_idx[cls]
            samples.extend(
                (os.path.join(cls_dir, img_name), label)
                for img_name in sorted(os.listdir(cls_dir))
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS)
            )
        return samples

    def __len__(self):
        """Number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB; ``transform`` is applied when one was
        supplied.
        """
        img_path, label = self.samples[idx]
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read, after which the handle can be closed.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        # Compare against None explicitly: a container-like transform that
        # defines __len__ must not be skipped just because it is "falsy".
        if self.transform is not None:
            image = self.transform(image)
        return image, label

## 二、数据预处理

In [None]:
# Preprocessing pipeline: resize to 224x224, apply a random horizontal flip
# and a random rotation of up to 10 degrees, convert to a tensor, then
# normalise each channel with the given mean/std.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(224, 224)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(degrees=10),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Evaluation-time preprocessing: same resize and normalisation as `transform`
# but WITHOUT the random flip/rotation. Random augmentation on the validation
# and test splits would make the reported metrics vary from run to run.
eval_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Only the training split uses the augmenting `transform`.
train_dataset = ChestXrayDataset(
    data_dir='data/chest-xray-pneumoniacovid19tuberculosis/train', transform=transform)
val_dataset = ChestXrayDataset(
    data_dir='data/chest-xray-pneumoniacovid19tuberculosis/val', transform=eval_transform)
test_dataset = ChestXrayDataset(
    data_dir='data/chest-xray-pneumoniacovid19tuberculosis/test', transform=eval_transform)

In [None]:
batch_size = 32
# Training: shuffle every epoch and drop the last incomplete batch.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, drop_last=True, num_workers=4)
# Evaluation loaders keep every sample (drop_last=False). Dropping the last
# partial batch would silently exclude validation images from the metrics and
# leave the loader empty when the split is smaller than one batch.
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, drop_last=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=False, drop_last=False, num_workers=4)

In [None]:
# Sanity check: look at the first training batch only (nothing is printed
# when the loader yields no batches).
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f"Batch images shape: {images.shape}")  # shape of the image tensor
    print(f"Batch labels: {labels}")  # labels of the batch

## 三、定义微调模型

In [None]:
# Fine-tuning network: pretrained ResNet152 backbone plus a new output head.
def get_net(device):
    """Build the fine-tuning model and move it to ``device``.

    The result is an ``nn.Sequential`` with two named children:

    * ``features`` -- torchvision's ResNet152 loaded with its default
      pretrained weights; all of its parameters have ``requires_grad=False``.
    * ``output_new`` -- ``Linear(1000, 256) -> ReLU -> Linear(256, 4)``,
      the only part left trainable.

    :param device: torch device the model is moved to
    :return: the assembled model
    """
    backbone = torchvision.models.resnet152(
        weights=torchvision.models.ResNet152_Weights.DEFAULT,
    )
    # Freeze the whole backbone so it receives no gradient updates.
    for frozen_param in backbone.parameters():
        frozen_param.requires_grad = False

    head = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 4),
    )

    # Attribute assignment registers the children under these names, in order.
    net = nn.Sequential()
    net.features = backbone
    net.output_new = head
    return net.to(device)

# Build the fine-tuning model on the selected device.
model = get_net(device)

## 四、模型训练

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``train_loader``.

    :param model: network to optimise (switched to training mode)
    :param train_loader: iterable yielding ``(inputs, labels)`` batches
    :param optimizer: optimiser stepped once per batch
    :param criterion: loss callable taking ``(outputs, labels)``
    :param device: device the batches are moved to
    :return: ``(epoch_loss, epoch_acc)`` averaged over all samples seen
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    for batch_inputs, batch_labels in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # sample count below gives a per-sample average (this assumes the
        # criterion returns a batch mean).
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        batch_preds = logits.max(dim=1).indices
        num_correct += (batch_preds == batch_labels).sum().item()
        num_seen += batch_labels.size(0)

    return loss_sum / num_seen, num_correct / num_seen

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    :param model: network to evaluate (switched to evaluation mode)
    :param val_loader: iterable yielding ``(inputs, labels)`` batches
    :param criterion: loss callable taking ``(outputs, labels)``
    :param device: device the batches are moved to
    :return: ``(epoch_loss, epoch_acc)`` averaged over all samples seen
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

            logits = model(batch_inputs)
            # Batch loss weighted by batch size; divided by the sample
            # count after the loop.
            loss_sum += criterion(logits, batch_labels).item() * batch_inputs.size(0)

            batch_preds = logits.max(dim=1).indices
            hits += batch_preds.eq(batch_labels).sum().item()
            seen += batch_labels.size(0)

    return loss_sum / seen, hits / seen

In [None]:
def plot_training_history(history, epoch, epochs):
    """
    Redraw the training/validation loss and accuracy curves in place.

    :param history: dict holding the metric lists under the keys
        'train_loss', 'val_loss', 'train_acc' and 'val_acc'
    :param epoch: zero-based index of the current epoch (shown as epoch + 1)
    :param epochs: total number of epochs
    """
    # Replace the previous cell output instead of stacking figures.
    clear_output(wait=True)
    plt.figure(figsize=(12, 4))

    panel_title = f'Epoch {epoch+1}/{epochs}'
    # One entry per subplot: y-axis label and the (history key, legend label)
    # pairs drawn in it. Left panel: loss; right panel: accuracy.
    panels = (
        ('Loss', (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss'))),
        ('Accuracy', (('train_acc', 'Train Acc'), ('val_acc', 'Val Acc'))),
    )
    for position, (y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for history_key, legend_label in curves:
            plt.plot(history[history_key], label=legend_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(panel_title)
        plt.legend()

    plt.tight_layout()
    # plt.savefig(f'assets/epoch_{epoch+1}.png')
    plt.show()

In [None]:
def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=10,
                save_path='assets/model/best_model.pth'):
    """Train for ``epochs`` epochs, keeping the best checkpoint by validation accuracy.

    After every epoch the model is validated, the scheduler is stepped, the
    metrics are appended to the history, and the curves are redrawn. Whenever
    the validation accuracy exceeds the best value so far, the model's
    ``state_dict`` is written to ``save_path``.

    :param model: network to train
    :param train_loader: loader of training batches
    :param val_loader: loader of validation batches
    :param optimizer: optimiser passed to ``train_epoch``
    :param criterion: loss callable
    :param scheduler: learning-rate scheduler, stepped once per epoch
    :param device: device the batches are moved to
    :param epochs: number of epochs to run
    :param save_path: file the best ``state_dict`` is saved to
    :return: dict with the per-epoch lists 'train_loss', 'train_acc',
        'val_loss' and 'val_acc'
    """
    # torch.save does not create missing directories; create the target
    # directory up front so a missing folder cannot abort the run after a
    # full epoch of training.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Keep only the checkpoint with the highest validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)

        plot_training_history(history, epoch, epochs)

        print(f'Epoch {epoch+1}/{epochs}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}\n')
        time.sleep(0.1)

    print(f'Best Val Acc: {best_acc:.4f}')
    return history


In [None]:
# Loss, optimiser and learning-rate schedule for fine-tuning.
criterion = nn.CrossEntropyLoss()

# Only parameters that still require gradients are handed to Adam, i.e. the
# new fully connected head; the frozen parameters are left out.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)

# Multiply the learning rate by 0.1 every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    epochs=20,
)

## 五、预测与评估

In [None]:
def test_model(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    test_loss = running_loss / total
    test_acc = correct / total
    
    return test_loss, test_acc, all_preds, all_labels

In [None]:
# Load the trained weights.
# NOTE(review): train_model saves its best checkpoint as
# 'assets/model/best_model.pth'; the file loaded here is not written by any
# code in this notebook -- confirm that it exists (or point this at the
# checkpoint that was actually produced).
checkpoint_path = 'assets/model/resnet152_finetune_epoch17.pth'
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Evaluate on the test split.
test_loss, test_acc, all_preds, all_labels = test_model(model, test_loader, criterion, device)
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_acc:.4f}')

In [None]:
# Undo the per-channel normalisation so an image can be displayed.
def unnormalize(img, mean, std):
    """Turn a normalised image tensor back into a displayable array.

    The first axis of ``img`` is moved to the end (axes ``(1, 2, 0)``), then
    the normalisation is inverted as ``img * std + mean`` and the result is
    clipped to ``[0, 1]``. The input tensor is not modified.

    :param img: 3-D image tensor; assumed channels-first (C, H, W) -- TODO confirm
    :param mean: per-channel means, broadcast along the last axis
    :param std: per-channel standard deviations, broadcast along the last axis
    :return: NumPy array with the first axis of ``img`` moved last, values in ``[0, 1]``
    """
    channels_last = np.transpose(img.clone().cpu().numpy(), (1, 2, 0))
    restored = std * channels_last + mean
    return np.clip(restored, 0, 1)


# Normalisation constants (the same values used by the preprocessing
# transform), needed to undo the normalisation for display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Randomly draw num_images distinct samples from the test set.
num_images = 6
indices = np.random.choice(len(test_dataset), num_images, replace=False)
samples = [test_dataset[i] for i in indices]
images = torch.stack([sample[0] for sample in samples])  # shape: [num_images, C, H, W]
labels = torch.tensor([sample[1] for sample in samples])

# Run the model on the sampled images.
model.eval()
with torch.no_grad():
    outputs = model(images.to(device))
    preds = torch.max(outputs, 1).indices
target_names = ['COVID19', 'NORMAL', 'PNEUMONIA', 'TUBERCULOSIS']

# Show every sampled image with its predicted and true class name.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
for ax, image, pred, label in zip(axes, images, preds, labels):
    # Undo the normalisation so the image is displayed correctly.
    ax.imshow(unnormalize(image, mean, std))
    ax.set_title(f"Pred: {target_names[pred.item()]}\nTrue: {target_names[label.item()]}")
    ax.tick_params(axis='both', which='both', labelsize=10)
plt.tight_layout()
plt.show()

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

In [None]:
# Compute the confusion matrix.
class_names = ['COVID19', 'NORMAL', 'PNEUMONIA', 'TUBERCULOSIS']
# Pass the full label set explicitly: otherwise a class that occurs neither in
# the true labels nor in the predictions is dropped from the matrix, and the
# tick labels below no longer line up with its rows/columns.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

# Visualise the confusion matrix as an annotated heat map.
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

In [None]:
# Print the per-class classification report.
report_class_names = ['COVID19', 'NORMAL', 'PNEUMONIA', 'TUBERCULOSIS']
# Passing `labels` fixes the label set to all four classes, so `target_names`
# always has the matching length even when a class is absent from both the
# true labels and the predictions.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(report_class_names))),
    target_names=report_class_names
))