In [20]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import time
import pandas as pd


In [21]:

# Configuration
data_dir = "./datasets"
img_size = 224
batch_size = 32
epochs = 15
seed = 42
device = "cuda" if torch.cuda.is_available() else "cpu"
save_dir = "resnet_exp"
os.makedirs(save_dir, exist_ok=True)

# Seed torch's global RNG so model weight initialisation and DataLoader
# shuffling repeat across Restart & Run All. GPU kernels may still be
# non-deterministic; see the PyTorch reproducibility notes if exact
# bit-for-bit repeats are required.
torch.manual_seed(seed)

# Data loading: resize every image to img_size x img_size and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(os.path.join(data_dir, "Train"), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, "Test"), transform=transform)

# ImageFolder derives the class -> index mapping from each split's own folder
# names. If Train and Test do not contain exactly the same class folders, the
# same integer label would denote different classes, so fail loudly instead of
# silently reporting meaningless accuracy.
if train_dataset.classes != test_dataset.classes:
    raise ValueError(
        f"Train/Test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

num_classes = len(train_dataset.classes)


In [22]:

# Model 1: baseline ResNet-18, randomly initialised (no pretrained weights),
# with its final classifier replaced so it emits `num_classes` logits.
model1 = models.resnet18(weights=None)
model1.fc = nn.Linear(in_features=model1.fc.in_features, out_features=num_classes)
model1 = model1.to(device)


# Model 2: improved ResNet-18 with CBAM attention (channel + spatial gates)
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Squeezes the spatial dimensions with both adaptive average- and
    max-pooling, passes each pooled summary through a shared bottleneck MLP
    built from 1x1 convolutions, sums the two results and applies a sigmoid
    to obtain one weight per channel. The input is rescaled by those weights.

    Args:
        in_planes: number of input channels.
        ratio: bottleneck reduction factor. The hidden width is
            ``in_planes // ratio``, floored at 1.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # For in_planes < ratio the integer division yields 0, i.e. a
        # zero-width bottleneck; keep at least one hidden channel.
        hidden = max(1, in_planes // ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same MLP (self.fc) is shared by both pooled branches.
        weights = self.sigmoid(self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x)))
        return weights * x

class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    Summarises the input across the channel dimension with a mean and a max,
    stacks the two single-channel maps, and turns them into a per-position
    weight in (0, 1) via a 7x7 convolution followed by a sigmoid. The input
    is rescaled by that weight map.
    """

    def __init__(self):
        super().__init__()
        # 2 stacked maps (mean, max) -> 1 weight map; padding=3 with a 7x7
        # kernel keeps the spatial size unchanged.
        self.conv = nn.Conv2d(2, 1, 7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        weights = self.sigmoid(self.conv(stacked))
        return weights * x

class ResNetCBAM(nn.Module):
    """ResNet-18 (random init) with CBAM gates after its first two stages.

    Channel attention followed by spatial attention is applied to the output
    of ``layer1`` (gated with ``ChannelAttention(64)``) and to the output of
    ``layer2`` (gated with ``ChannelAttention(128)``). Each gate must sit
    after the stage whose channel count it was built for: applying the
    128-channel gate before ``layer2`` would feed it the 64-channel
    ``layer1`` output and fail with a channel-mismatch error.

    Args:
        num_classes: number of output logits of the final linear classifier.
    """

    def __init__(self, num_classes):
        super().__init__()
        base = models.resnet18(weights=None)
        # Reuse torchvision's stem: conv1 -> bn1 -> relu -> maxpool.
        self.stem = nn.Sequential(
            base.conv1, base.bn1, base.relu, base.maxpool
        )
        self.layer1 = base.layer1  # outputs 64 channels
        self.ca1 = ChannelAttention(64)
        self.sa1 = SpatialAttention()
        self.layer2 = base.layer2  # outputs 128 channels
        self.ca2 = ChannelAttention(128)
        self.sa2 = SpatialAttention()
        self.layer3 = base.layer3
        self.layer4 = base.layer4
        self.pool = base.avgpool
        # Same width as torchvision's own head (512 for ResNet-18).
        self.fc = nn.Linear(base.fc.in_features, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        # Gate the 64-channel layer1 output: channel attention, then spatial.
        x = self.ca1(x)
        x = self.sa1(x)
        x = self.layer2(x)
        # Gate the 128-channel layer2 output.
        x = self.ca2(x)
        x = self.sa2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

# Instantiate model 2 (CBAM-augmented ResNet-18) on the training device.
model2 = ResNetCBAM(num_classes=num_classes).to(device)



In [23]:
# Training function

def _train_one_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimisation pass over ``loader``; return (mean_loss, top1_acc).

    Accuracy is measured on the training-mode outputs of each batch, as the
    optimiser updates the weights.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # criterion returns the batch mean; re-weight by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        total_loss += loss.item() * n
        correct += (outputs.argmax(1) == labels).sum().item()
        total += n
    return total_loss / total, correct / total


def _evaluate(model, loader, criterion, topk=5):
    """Evaluate ``model`` on ``loader``; return (mean_loss, top1_acc, topk_acc).

    ``topk`` is clamped to the number of classes, because ``Tensor.topk``
    raises when k exceeds the size of the class dimension.
    """
    model.eval()
    val_loss, correct, topk_correct, total = 0.0, 0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            n = labels.size(0)
            val_loss += criterion(outputs, labels).item() * n
            total += n
            correct += (outputs.argmax(1) == labels).sum().item()
            k = min(topk, outputs.size(1))
            # (N, k) predicted indices vs (N, 1) labels -> hit if any match.
            top_idx = outputs.topk(k, dim=1).indices
            topk_correct += top_idx.eq(labels.unsqueeze(1)).any(dim=1).sum().item()
    return val_loss / total, correct / total, topk_correct / total


def train(model, model_name):
    """Train ``model`` and log per-epoch metrics.

    Uses the notebook-level ``train_loader``, ``test_loader``, ``device``,
    ``epochs`` and ``save_dir``. Optimises cross-entropy with Adam
    (lr=1e-3) under a cosine-annealed learning-rate schedule. After every
    epoch the model is evaluated on ``test_loader``. Whenever top-1 accuracy
    improves, the weights are saved to ``save_dir/best_{model_name}.pt``.
    The per-epoch history is written to ``save_dir/history_{model_name}.xlsx``.
    Writing that file requires an Excel writer engine such as openpyxl.

    Args:
        model: the network to train, already moved to ``device``.
        model_name: label used in progress bars and output file names.

    Returns:
        pandas.DataFrame with one row of metrics per epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0
    records = []

    for epoch in range(epochs):
        start_time = time.time()
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer,
            desc=f"[Train {model_name}] Epoch {epoch+1}",
        )
        val_loss, val_acc, val_acc_top5 = _evaluate(model, test_loader, criterion, topk=5)

        # Save best
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), os.path.join(save_dir, f"best_{model_name}.pt"))

        # Logging
        elapsed = time.time() - start_time
        records.append({
            "epoch": epoch+1,
            "time": round(elapsed, 2),
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            "val_loss": round(val_loss, 4),
            "val_acc_top1": round(val_acc, 4),
            "val_acc_top5": round(val_acc_top5, 4)
        })
        scheduler.step()

    df = pd.DataFrame(records)
    df.to_excel(os.path.join(save_dir, f"history_{model_name}.xlsx"), index=False)
    print(f"{model_name} Done. Best Top-1 Accuracy: {best_acc:.4f}")
    return df


In [24]:

# Train the baseline ResNet-18.
train(model=model1, model_name="resnet18_basic")


[Train resnet18_basic] Epoch 1: 100%|████████████████████████████████████████████████| 244/244 [06:52<00:00,  1.69s/it]


resnet18_basic Done. Best Top-1 Accuracy: 0.2542


In [None]:

# Train the CBAM-augmented ResNet-18.
train(model=model2, model_name="resnet18_cbam")
