ResNet-18 deep-learning classifier for three senescence classes

In [None]:
# -*- coding: utf-8 -*-
import os, glob, random, cv2, numpy as np, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import classification_report, confusion_matrix
import tifffile, matplotlib.pyplot as plt, seaborn as sns
from tqdm import tqdm
from torchvision.models import resnet18, ResNet18_Weights

# ----------------------------- Configuration -----------------------------
# Class name -> folder holding that class's images. PSData derives each
# class's integer label from the iteration order of this mapping.
data_dirs = {
    "Healthy":   r"E:\0. Research\1. Paper_data\2. Hacat cisplatin ROS\数据集\原始数据\downsample-healthy",
    "Cisplatin": r"E:\0. Research\1. Paper_data\2. Hacat cisplatin ROS\数据集\原始数据\Cisplatin",
    "H2O2":      r"E:\0. Research\1. Paper_data\2. Hacat cisplatin ROS\数据集\原始数据\H2O2"
}

# Training hyperparameters.
batch_size, resize_size = 32, 128      # mini-batch size / square image side in pixels
num_epochs, num_seeds = 150, 3         # epochs per run / number of seeded runs
learning_rate = 1e-3

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Every artifact (stats, models, histories, figures) is written here.
output_dir = r"E:\2. Senescence\CNN模型训练结果"
os.makedirs(output_dir, exist_ok=True)

if device.type == 'cuda':
    print("Device:", device, torch.cuda.get_device_name(0))
else:
    print("Device:", device, "")

# ----------------------------- Dataset -----------------------------
class PSData(Dataset):
    """Dataset of paired "-P" / "-S" TIFF images stacked as a 2-channel tensor.

    Every class folder is scanned for files matching ``*-P*.tif*``. A file is
    kept only when its partner (the same path with "-P" replaced by "-S" in
    the file name) also exists on disk.

    Args:
        data_dirs: Mapping of class name to folder path. A class's integer
            label is its position in the mapping's iteration order.
        resize: Side length in pixels that both images are resized to.
        stats: Optional mapping with tensor entries 'p_mean', 'p_std',
            's_mean' and 's_std' used to standardize channel 0 (P) and
            channel 1 (S). ``None`` leaves the images unnormalized.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data_dirs, resize=128, stats=None):
        # samples[i] is a (p_path, s_path) tuple; labels[i] is its class index.
        self.samples, self.labels = [], []
        self.resize = resize
        self.stats = stats
        for label, folder in enumerate(data_dirs.values()):
            for p_path in sorted(glob.glob(os.path.join(folder, "*-P*.tif*"))):
                # Swap the marker in the file name only. Replacing over the
                # whole path would also rewrite a "-P" that happens to occur
                # in a directory name, producing a non-existent partner path
                # and silently dropping every pair in that folder.
                filename = os.path.basename(p_path)
                prefix = p_path[:len(p_path) - len(filename)]
                s_path = prefix + filename.replace("-P", "-S")
                if os.path.exists(s_path):
                    self.samples.append((p_path, s_path))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of complete P/S pairs found."""
        return len(self.samples)

    def _load(self, path):
        """Read one TIFF as float32 scaled by 1/65535 and resize it to a square."""
        # NOTE(review): the 65535 divisor presumes 16-bit data — confirm.
        scaled = tifffile.imread(path).astype(np.float32) / 65535.
        return cv2.resize(scaled, (self.resize, self.resize))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for pair ``idx``.

        ``image`` is a float32 tensor whose channel 0 is the P image and
        channel 1 the S image; ``label`` is the integer class index.
        """
        p_path, s_path = self.samples[idx]
        img = np.stack([self._load(p_path), self._load(s_path)], 0)
        if self.stats is not None:
            # Standardize each channel with its own mean / std.
            for channel, prefix in enumerate(("p", "s")):
                mean = self.stats[f"{prefix}_mean"].cpu().numpy()
                std = self.stats[f"{prefix}_std"].cpu().numpy()
                img[channel] = (img[channel] - mean) / std
        return torch.tensor(img, dtype=torch.float32), self.labels[idx]

# ----------------------------- Compute training set statistics -----------------------------
@torch.no_grad()
def compute_stats(train_loader):
    """Compute the global mean and standard deviation of each image channel.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches. Channel 0
            of ``images`` is treated as the P image and channel 1 as the S
            image; the labels are ignored.

    Returns:
        dict with 0-dim tensors ``'p_mean'``, ``'p_std'``, ``'s_mean'`` and
        ``'s_std'``. The tensors use the dtype of the input batches when that
        is a floating type (the default float dtype otherwise), and each
        standard deviation is floored at 1e-6.

    Raises:
        ValueError: If the loader yields no pixels.
    """
    p_sum = p_sq = s_sum = s_sq = 0.0
    cnt = 0
    out_dtype = torch.get_default_dtype()
    for img, _ in train_loader:
        if img.is_floating_point():
            out_dtype = img.dtype
        # Accumulate in float64: float32 running sums of squares lose enough
        # precision on large datasets to corrupt the variance below.
        p, s = img[:, 0].double(), img[:, 1].double()
        p_sum = p_sum + p.sum()
        p_sq = p_sq + (p * p).sum()
        s_sum = s_sum + s.sum()
        s_sq = s_sq + (s * s).sum()
        cnt += p.numel()

    if cnt == 0:
        raise ValueError("compute_stats received an empty loader; cannot compute statistics")

    def _mean_std(total, total_sq):
        mean = total / cnt
        # E[x^2] - E[x]^2 can round to a tiny negative number; clamp before
        # sqrt, otherwise the result is NaN and the 1e-6 floor never applies.
        variance = (total_sq / cnt - mean ** 2).clamp(min=0.0)
        std = torch.sqrt(variance).clamp(min=1e-6)
        return mean.to(out_dtype), std.to(out_dtype)

    p_mean, p_std = _mean_std(p_sum, p_sq)
    s_mean, s_std = _mean_std(s_sum, s_sq)
    return {'p_mean': p_mean, 'p_std': p_std, 's_mean': s_mean, 's_std': s_std}

# ----------------------------- Network -----------------------------
class SimpleCNN(nn.Module):
    """ResNet-18 adapted to 2-channel input and ``num_classes`` outputs.

    The network is built without pretrained weights. Its first convolution is
    replaced by one accepting 2 input channels and its final fully connected
    layer by one producing ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # weights=None: randomly initialised backbone (the 'weights' keyword
        # is used rather than the older 'pretrained' flag to avoid warnings).
        self.net = resnet18(weights=None)
        # Stem convolution re-created for 2 input channels.
        self.net.conv1 = nn.Conv2d(
            in_channels=2, out_channels=64,
            kernel_size=7, stride=2, padding=3, bias=False,
        )
        # Classification head sized to the requested number of classes.
        feature_dim = self.net.fc.in_features
        self.net.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return the class logits for input batch ``x``."""
        logits = self.net(x)
        return logits

# ----------------------------- Training -----------------------------
def _evaluate(model, loader, criterion, collect=False):
    """Run one gradient-free pass of ``model`` over ``loader``.

    The caller is responsible for putting the model in eval mode.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        collect: When True, also gather per-sample predictions and labels.

    Returns:
        ``(accuracy, mean_loss, predictions, targets)``; the two lists are
        empty unless ``collect`` is True.
    """
    correct = total = 0
    loss_sum = 0.0
    predictions, targets = [], []
    with torch.no_grad():
        for img, lab in loader:
            img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)
            outputs = model(img)
            batch_pred = outputs.argmax(1)
            n = lab.size(0)
            correct += (batch_pred == lab).sum().item()
            total += n
            # Weight the batch-mean loss by batch size so the epoch value is
            # a per-sample average even when the last batch is smaller.
            loss_sum += criterion(outputs, lab).item() * n
            if collect:
                predictions.extend(batch_pred.cpu().numpy())
                targets.extend(lab.cpu().numpy())
    return correct / total, loss_sum / total, predictions, targets


def train_one(seed):
    """Train and evaluate one model for a given random seed.

    The dataset is split 80/10/10 into train/validation/test with a generator
    seeded by ``seed``. Normalization statistics are computed from the
    training subset only and applied to all three subsets. After every epoch
    the model is evaluated on the validation and test subsets.

    Files written to ``output_dir``: the training statistics, a confusion
    matrix PNG for the final epoch's test predictions, the model weights and
    the metric history. A classification report is printed.

    Args:
        seed: Integer seed for the Python, NumPy and PyTorch RNGs and for the
            data split.

    Returns:
        dict mapping ``"{train,val,test}_{acc,loss}"`` to a list with one
        value per epoch.

    Raises:
        ValueError: If the dataset is too small for all three subsets to be
            non-empty.
    """
    torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    full_dataset = PSData(data_dirs, resize_size, stats=None)
    total_size = len(full_dataset)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size
    # Fail up front with a clear message instead of a ZeroDivisionError when
    # averaging over an empty subset after a whole training epoch.
    if min(train_size, val_size, test_size) == 0:
        raise ValueError(
            f"Dataset has {total_size} samples; the 80/10/10 split needs "
            "every subset to be non-empty"
        )

    train_idx, val_idx, test_idx = random_split(
        range(total_size),
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(seed)
    )

    # Statistics come from the un-normalized training subset only, so nothing
    # from the validation/test subsets leaks into the normalization.
    train_dataset_raw = Subset(full_dataset, train_idx)
    train_loader_raw = DataLoader(train_dataset_raw, batch_size, shuffle=False, num_workers=0)
    train_stats = compute_stats(train_loader_raw)
    torch.save(train_stats, os.path.join(output_dir, f'train_stats_seed{seed}.pt'))

    full_dataset_normalized = PSData(data_dirs, resize_size, stats=train_stats)
    train_dataset = Subset(full_dataset_normalized, train_idx)
    val_dataset = Subset(full_dataset_normalized, val_idx)
    test_dataset = Subset(full_dataset_normalized, test_idx)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=0, pin_memory=True)

    model = SimpleCNN(len(data_dirs)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Mixed precision is only enabled on CUDA; on CPU the scaler is a no-op.
    scaler = torch.cuda.amp.GradScaler(enabled=device.type=='cuda')

    history = {
        "train_acc": [], "val_acc": [], "test_acc": [],
        "train_loss": [], "val_loss": [], "test_loss": []
    }
    preds_final, labels_final = [], []
    last_epoch = num_epochs - 1

    for epoch in tqdm(range(num_epochs), desc=f"Seed {seed}"):
        # Training phase
        model.train()
        train_correct = 0; train_total = 0; train_loss = 0.0
        for img, lab in train_loader:
            img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)
            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=device.type=='cuda'):
                outputs = model(img)
                loss = criterion(outputs, lab)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_correct += (outputs.argmax(1) == lab).sum().item()
            train_total += lab.size(0)
            train_loss += loss.item() * lab.size(0)

        history["train_acc"].append(train_correct / train_total)
        history["train_loss"].append(train_loss / train_total)

        # Validation and testing phases (both in eval mode)
        model.eval()
        val_acc, val_loss, _, _ = _evaluate(model, val_loader, criterion)
        history["val_acc"].append(val_acc)
        history["val_loss"].append(val_loss)

        is_last = epoch == last_epoch
        test_acc, test_loss, preds, targets = _evaluate(
            model, test_loader, criterion, collect=is_last
        )
        history["test_acc"].append(test_acc)
        history["test_loss"].append(test_loss)
        if is_last:
            # Only the final epoch's predictions feed the report below.
            preds_final, labels_final = preds, targets

    class_names = list(data_dirs.keys())
    # Pass the full label set explicitly: the split is not stratified, so a
    # class may be absent from the test subset, which would otherwise make
    # the report fail on the target_names length and shrink the matrix.
    class_ids = list(range(len(class_names)))

    # Output classification report
    print(f"\n=== Classification Report (Seed {seed}) ===")
    print(classification_report(
        labels_final, preds_final,
        labels=class_ids,
        target_names=class_names,
        digits=4
    ))

    # Plot confusion matrix
    cm = confusion_matrix(labels_final, preds_final, labels=class_ids)
    plt.figure(figsize=(4, 3))
    sns.heatmap(
        cm, annot=True, fmt='d',
        xticklabels=class_names,
        yticklabels=class_names,
        cmap='Blues', cbar=False
    )
    plt.title(f"Confusion Matrix (Seed {seed})")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"cm_seed{seed}.png"), dpi=300, bbox_inches='tight')
    plt.close()

    torch.save(model.state_dict(), os.path.join(output_dir, f"model_seed{seed}.pth"))
    torch.save(history, os.path.join(output_dir, f"history_seed{seed}.pt"))

    return history

# ----------------------------- Main process -----------------------------
# One training run per seed; each run's metric history is kept for the
# averaged plots and summary below.
histories = []
for seed in range(num_seeds):
    histories.append(train_one(seed))
    # Release cached GPU memory before the next run starts.
    torch.cuda.empty_cache()

# ----------------------------- Plot averaged curves -----------------------------
def smooth(y, window=5):
    """Return the trailing moving average of ``y`` as a list.

    Element ``i`` of the result is the mean of ``y[i-window+1 : i+1]``; near
    the start, where fewer than ``window`` values exist, the mean is taken
    over the values available so far. The output has the same length as ``y``.
    """
    averaged = []
    for end in range(1, len(y) + 1):
        start = max(0, end - window)
        averaged.append(np.mean(y[start:end]))
    return averaged

# 1. Accuracy curves
keys_acc = ["train_acc", "val_acc", "test_acc"]
labels_acc = ["Train", "Validation", "Test"]

# Per-epoch mean and standard deviation across seeds for every metric.
mean_acc, std_acc = {}, {}
for k in keys_acc:
    per_seed = [h[k] for h in histories]
    mean_acc[k] = np.mean(per_seed, axis=0)
    std_acc[k] = np.std(per_seed, axis=0)

plt.figure(figsize=(10, 4))
for k, lab in zip(keys_acc, labels_acc):
    # The line is the smoothed mean; the band is +/- the (unsmoothed) std.
    center = np.asarray(smooth(mean_acc[k]))
    spread = std_acc[k]
    epochs = range(len(center))
    plt.plot(center, label=lab, linewidth=1.5)
    plt.fill_between(epochs, center - spread, center + spread, alpha=0.2)
plt.xlabel("Epoch", fontsize=10)
plt.ylabel("Accuracy", fontsize=10)
plt.legend(fontsize=8)
plt.title("Accuracy Curve (Mean ± Std)", fontsize=12)
plt.grid(alpha=0.3)
# bbox_inches='tight' trims surrounding whitespace in the saved figure.
accuracy_png = os.path.join(output_dir, "accuracy_curve.png")
plt.savefig(accuracy_png, dpi=300, bbox_inches='tight')
plt.close()

# 2. Loss curves
keys_loss = ["train_loss", "val_loss", "test_loss"]
labels_loss = ["Train", "Validation", "Test"]

# Per-epoch mean and standard deviation across seeds for every metric.
mean_loss, std_loss = {}, {}
for k in keys_loss:
    per_seed = [h[k] for h in histories]
    mean_loss[k] = np.mean(per_seed, axis=0)
    std_loss[k] = np.std(per_seed, axis=0)

plt.figure(figsize=(10, 4))
for k, lab in zip(keys_loss, labels_loss):
    # The line is the smoothed mean; the band is +/- the (unsmoothed) std.
    center = np.asarray(smooth(mean_loss[k]))
    spread = std_loss[k]
    epochs = range(len(center))
    plt.plot(center, label=lab, linewidth=1.5)
    plt.fill_between(epochs, center - spread, center + spread, alpha=0.2)
plt.xlabel("Epoch", fontsize=10)
plt.ylabel("Cross-Entropy Loss", fontsize=10)
plt.legend(fontsize=8)
plt.title("Loss Curve (Mean ± Std)", fontsize=12)
plt.grid(alpha=0.3)
# bbox_inches='tight' trims surrounding whitespace in the saved figure.
loss_png = os.path.join(output_dir, "loss_curve.png")
plt.savefig(loss_png, dpi=300, bbox_inches='tight')
plt.close()

# ----------------------------- Final statistics -----------------------------
# Each seed's "final" value is the mean over its last `final_epoch_window`
# epochs; the summary reports mean ± std of those values across seeds.
final_epoch_window = 10
print("\n=== Final Average Performance (All Seeds) ===")
for phase in ["train", "val", "test"]:
    acc_key, loss_key = f"{phase}_acc", f"{phase}_loss"
    acc_list = [np.mean(h[acc_key][-final_epoch_window:]) for h in histories]
    loss_list = [np.mean(h[loss_key][-final_epoch_window:]) for h in histories]
    print(f"{phase.capitalize()} Accuracy: {np.mean(acc_list):.4f} ± {np.std(acc_list):.4f}")
    print(f"{phase.capitalize()} Loss:     {np.mean(loss_list):.4f} ± {np.std(loss_list):.4f}")

# Persist the histories of all seeds in a single file.
all_histories_path = os.path.join(output_dir, "all_seeds_histories.pt")
torch.save(histories, all_histories_path)
print(f"\nAll results saved to: {output_dir}")
