In [4]:
import io
import os
import sys
import time
import datetime
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import random

# Sans-serif font fallback list that includes CJK-capable fonts, and disable the
# Unicode minus glyph for axis tick labels.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ 当前设备：{device}")

✅ 当前设备：cuda


## Model

### BirdCNN - 基础卷积神经网络模型

In [6]:
class BirdCNN(nn.Module):
    """
    Baseline CNN for bird-species classification.

    Four Conv-BN-ReLU blocks, each followed by 2x2 max pooling, plus one extra
    2x2 pooling step, feed a three-layer fully connected head. The head is
    sized for a 128 x 7 x 7 feature map, i.e. a 224 x 224 input after the five
    poolings (224 / 2**5 = 7).

    Args:
        num_classes: Size of the final linear layer (number of classes).
    """

    def __init__(self, num_classes=10):
        super(BirdCNN, self).__init__()

        # ===== Convolutional stack (3x3 kernels, stride 1, padding 1) =====
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)

        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(32, 64, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(128)

        # Channel-wise dropout, applied after the third conv block.
        self.dropout = nn.Dropout2d(p=0.25)

        # ===== Fully connected head =====
        self.fc1 = nn.Linear(128 * 7 * 7, 512)
        self.fc1_bn = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.fc2_bn = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """
        Compute per-class log-probabilities for a batch of images.

        Args:
            x: Image batch with 3 channels; the head requires 224 x 224
                spatial size.

        Returns:
            Tensor of shape [batch, num_classes] holding log-softmax outputs
            (pair with ``F.nll_loss``).
        """
        # --- Conv block 1 ---
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)
        # --- Conv block 2 ---
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        # --- Conv block 3 ---
        x = self.dropout(F.max_pool2d(F.relu(self.bn3(self.conv3(x))), 2))
        # --- Conv block 4 ---
        x = F.max_pool2d(F.relu(self.bn4(self.conv4(x))), 2)
        x = F.max_pool2d(x, 2)  # -> [batch, 128, 7, 7] for 224x224 input

        # --- Flatten + fully connected head ---
        # Keep the batch axis fixed so a wrongly sized input fails loudly in
        # fc1 instead of being silently folded into extra batch rows.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        # Functional dropout defaults to training=True; tie it to the module
        # mode so that model.eval() really disables it.
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

### BirdCNN_Optimal — 针对中小规模数据集优化模型

In [5]:
class BirdCNN_Optimal(nn.Module):
    """
    CNN variant intended for datasets of roughly 3,000 images: wider than the
    baseline and with an extra convolution in the third stage, combined with
    dropout in both the convolutional part and the classifier head.

    Args:
        num_classes: Size of the final linear layer (number of classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # ===== Convolutional stages (all 3x3, stride 1, padding 1) =====
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 3 stacks two convolutions before pooling.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv3_extra = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn3_extra = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Regularisation: channel dropout for feature maps, plain dropout for the head.
        self.dropout_conv = nn.Dropout2d(p=0.3)
        self.dropout_fc = nn.Dropout(p=0.5)

        # ===== Classifier head =====
        self.fc1 = nn.Linear(256 * 7 * 7, 512)
        self.fc1_bn = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.fc2_bn = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return log-softmax class scores of shape [batch, num_classes]."""
        # Stage 1
        feats = F.relu(self.bn1(self.conv1(x)))
        feats = F.max_pool2d(feats, 2)

        # Stage 2
        feats = F.relu(self.bn2(self.conv2(feats)))
        feats = F.max_pool2d(feats, 2)

        # Stage 3: two convs, channel dropout, then pooling
        feats = F.relu(self.bn3(self.conv3(feats)))
        feats = F.relu(self.bn3_extra(self.conv3_extra(feats)))
        feats = F.max_pool2d(self.dropout_conv(feats), 2)

        # Stage 4: a single 4x4 pooling brings a 224x224 input down to 7x7
        feats = F.relu(self.bn4(self.conv4(feats)))
        feats = F.max_pool2d(feats, 4)

        # Head
        flat = feats.view(-1, 256 * 7 * 7)
        hidden = F.relu(self.fc1_bn(self.fc1(flat)))
        hidden = self.dropout_fc(hidden)
        hidden = F.relu(self.fc2_bn(self.fc2(hidden)))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)

### BirdResNet — 残差结构CNN模型

In [44]:
class ResidualBlock(nn.Module):
    """
    Residual block: two 3x3 Conv-BN layers on the main path, added to a
    shortcut of the input and passed through ReLU.

    The shortcut is the identity when ``stride == 1`` and the channel counts
    match; otherwise it is a 1x1 Conv-BN projection so both paths have the
    same shape.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution (and of the projection).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: identity unless the shape changes
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        self.dropout = nn.Dropout2d(p=0.2)

    def forward(self, x):
        """Apply the block and return the activated sum of both paths."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        main = self.dropout(main)

        skip = self.shortcut(x)

        return F.relu(main + skip)


class BirdResNet(nn.Module):
    """
    Residual CNN for bird classification ("Bird" + "ResNet").

    Plain Conv-BN-ReLU layers alternate with ``ResidualBlock`` modules; five
    2x2 max-pool steps reduce a 224 x 224 input to the 7 x 7 map expected by
    the fully connected head.

    Args:
        num_classes: Size of the final linear layer (number of classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem convolution
        self.conv1 = nn.Conv2d(3, 32, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(32)

        # Group 1: residual block that also widens 32 -> 64 channels
        self.res_block1 = ResidualBlock(32, 64, stride=1)

        # Group 2: conv 64 -> 128, then a shape-preserving residual block
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(128)
        self.res_block2 = ResidualBlock(128, 128, stride=1)

        # Group 3: conv 128 -> 256, then a shape-preserving residual block
        self.conv3 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        self.res_block3 = ResidualBlock(256, 256, stride=1)

        # Group 4: final conv 256 -> 512
        self.conv4 = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(512)

        # Regularisation
        self.dropout_conv = nn.Dropout2d(p=0.3)
        self.dropout_fc = nn.Dropout(p=0.5)

        # Classifier head
        self.fc1 = nn.Linear(512 * 7 * 7, 512)
        self.fc1_bn = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.fc2_bn = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return log-softmax class scores of shape [batch, num_classes].

        The spatial sizes in the comments below assume a 224 x 224 input.
        """
        # Stem
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.max_pool2d(h, 2)                      # 112x112

        # Group 1
        h = F.max_pool2d(self.res_block1(h), 2)     # 56x56

        # Group 2 (channel dropout before pooling)
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.dropout_conv(self.res_block2(h))
        h = F.max_pool2d(h, 2)                      # 28x28

        # Group 3
        h = F.relu(self.bn3(self.conv3(h)))
        h = F.max_pool2d(self.res_block3(h), 2)     # 14x14

        # Group 4
        h = F.relu(self.bn4(self.conv4(h)))
        h = F.max_pool2d(h, 2)                      # 7x7

        # Flatten and classify
        flat = h.view(-1, 512 * 7 * 7)
        hidden = F.relu(self.fc1_bn(self.fc1(flat)))
        hidden = self.dropout_fc(hidden)
        hidden = F.relu(self.fc2_bn(self.fc2(hidden)))
        logits = self.fc3(hidden)

        return F.log_softmax(logits, dim=1)

In [19]:
class TrainingLogger:
    """
    Tee-style stream: every write goes to the console and to a log file.

    The stream that was ``sys.stdout`` at construction time is remembered as
    ``console``. The object can be installed as ``sys.stdout`` (see
    ``setup_logging``) and can also be used as a context manager.

    Args:
        log_path: Path of the log file; it is opened for writing (truncated)
            with UTF-8 encoding.
    """
    def __init__(self, log_path):
        self.log_path = log_path
        self.console = sys.stdout
        self.log_file = open(log_path, 'w', encoding='utf-8')

    def __getattr__(self, name):
        # Only called for attributes not found normally. Delegate stream
        # attributes (isatty, encoding, ...) to the wrapped console so the
        # object behaves like a real stdout. Never delegate our own fields,
        # otherwise a half-initialised instance would recurse forever.
        if name in ('console', 'log_file', 'log_path'):
            raise AttributeError(name)
        return getattr(self.console, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def write(self, message):
        """Write ``message`` to the console and, while open, to the log file.

        Returns:
            The number of characters in ``message``.
        """
        self.console.write(message)
        # After close() the console keeps working; only the file copy stops.
        if not self.log_file.closed:
            self.log_file.write(message)
            self.log_file.flush()
        return len(message)

    def flush(self):
        """Flush the console and, while open, the log file."""
        self.console.flush()
        if not self.log_file.closed:
            self.log_file.flush()

    def close(self):
        """Close the log file and undo the stdout redirection.

        If this logger is currently installed as ``sys.stdout``, the original
        console stream is put back so that later ``print`` calls do not hit a
        closed file. Calling ``close`` more than once is harmless.
        """
        if sys.stdout is self:
            sys.stdout = self.console
        if not self.log_file.closed:
            self.log_file.close()


def setup_logging(log_path):
    """
    Redirect ``sys.stdout`` through a ``TrainingLogger`` writing to ``log_path``.

    Returns:
        The installed ``TrainingLogger`` instance.
    """
    # The logger is constructed first (capturing the current stdout) and only
    # then installed as the new stdout.
    sys.stdout = tee = TrainingLogger(log_path)
    return tee

In [20]:
def pil_loader_with_error_handling(path):
    """
    Load the image at ``path`` as an RGB PIL image.

    Any exception while opening or converting is reported on stdout and a
    white 224x224 RGB placeholder is returned instead.
    """
    try:
        with open(path, 'rb') as handle:
            rgb = Image.open(handle).convert('RGB')
    except Exception as e:
        print(f"⚠️ 无法加载图片: {path}, 错误: {e}")
        rgb = Image.new('RGB', (224, 224), color='white')
    return rgb


def check_and_clean_dataset(data_path):
    """
    Scan ``<data_path>/train`` and ``<data_path>/test`` for unreadable images.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file inside a class sub-directory is
    checked with Pillow's ``Image.verify``. Files that fail are renamed on
    disk by appending ``.corrupted``. Entries directly under ``train`` /
    ``test`` that are not directories are ignored.

    Args:
        data_path: Dataset root containing ``train`` and ``test`` folders.

    Returns:
        List of the original paths of the files that were marked corrupted.

    Raises:
        OSError: If a split directory is missing or a rename fails.
    """
    print("正在检查数据集完整性...")
    corrupted = []
    for mode in ['train', 'test']:
        dir_path = os.path.join(data_path, mode)
        for cls in sorted(os.listdir(dir_path)):
            cls_path = os.path.join(dir_path, cls)
            # Stray files next to the class folders (README, .DS_Store, ...)
            # are not class directories; listing them would raise.
            if not os.path.isdir(cls_path):
                continue
            for f in sorted(os.listdir(cls_path)):
                if not f.lower().endswith(('.jpg', '.jpeg', '.png')):
                    continue
                path = os.path.join(cls_path, f)
                try:
                    # The context manager closes the file handle before the
                    # rename below; an open handle blocks renaming on Windows.
                    with Image.open(path) as img:
                        img.verify()
                except Exception:
                    # Exception (not a bare except) so Ctrl-C still interrupts.
                    corrupted.append(path)
                    # os.replace also succeeds when a stale ".corrupted" file
                    # from an earlier run already exists.
                    os.replace(path, path + ".corrupted")
    if corrupted:
        print(f"共发现 {len(corrupted)} 张损坏图片，已标记为 .corrupted")
    else:
        print("数据集检查通过 ✅")
    return corrupted

In [21]:
def create_data_loaders(data_path, batch_size_train=32, batch_size_test=64):
    """
    Build the train and test DataLoaders from ``<data_path>/train`` and
    ``<data_path>/test``.

    The dataset is first passed through ``check_and_clean_dataset``. Training
    images get random crop / flip / rotation / colour jitter; test images are
    only resized. Both are converted to tensors and normalised per channel.

    Returns:
        (train_loader, test_loader, number_of_classes, class_names)
    """
    check_and_clean_dataset(data_path)

    tfm = torchvision.transforms
    # Per-channel mean and std used for both splits.
    normalize = tfm.Normalize([0.485, 0.456, 0.406],
                              [0.229, 0.224, 0.225])

    train_steps = [
        tfm.RandomResizedCrop(224, scale=(0.8, 1.0)),
        tfm.RandomHorizontalFlip(0.5),
        tfm.RandomRotation(10),
        tfm.ColorJitter(0.1, 0.1, 0.1, 0.05),
        tfm.ToTensor(),
        normalize,
    ]
    test_steps = [
        tfm.Resize((224, 224)),
        tfm.ToTensor(),
        normalize,
    ]

    train_root = os.path.join(data_path, 'train')
    test_root = os.path.join(data_path, 'test')
    train_ds = torchvision.datasets.ImageFolder(
        train_root,
        transform=tfm.Compose(train_steps),
        loader=pil_loader_with_error_handling,
    )
    test_ds = torchvision.datasets.ImageFolder(
        test_root,
        transform=tfm.Compose(test_steps),
        loader=pil_loader_with_error_handling,
    )

    print(f"📊 数据集统计: {len(train_ds)} 训练样本, {len(test_ds)} 测试样本, 类别数: {len(train_ds.classes)}")

    # Only the training split is shuffled.
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size_train, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size_test, shuffle=False)

    class_names = train_ds.classes
    return train_loader, test_loader, len(class_names), class_names

In [22]:
def train(model, device, loader, optimizer, epoch):
    """
    Run one training epoch with NLL loss and a tqdm progress bar.

    Args:
        model: Network producing log-probabilities.
        device: Device the batches are moved to.
        loader: Iterable of (data, target) batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        (mean of the per-batch losses, accuracy in percent)
    """
    model.train()
    loss_total = 0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for data, target in pbar:
        data = data.to(device)
        target = target.to(device)

        # Standard optimisation step
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # Running statistics
        batch_loss = loss.item()
        loss_total += batch_loss
        correct += (output.argmax(1) == target).sum().item()
        total += data.size(0)
        pbar.set_postfix(loss=batch_loss, acc=f"{100*correct/total:.2f}%")

    mean_loss = loss_total / len(loader)
    accuracy = 100 * correct / total
    return mean_loss, accuracy


def evaluate(model, device, loader):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        (summed NLL loss divided by the dataset size, accuracy in percent),
        both normalised by ``len(loader.dataset)``.
    """
    model.eval()
    loss_total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Sum (not mean) so the final division is per sample.
            loss_total += F.nll_loss(output, target, reduction='sum').item()
            correct += (output.argmax(1) == target).sum().item()
    n_samples = len(loader.dataset)
    return loss_total / n_samples, 100 * correct / n_samples

In [41]:
def pil_loader(path):
    """
    Read the image at ``path`` and return it as an RGB PIL image.

    On any exception a message is printed and a white 224x224 RGB
    placeholder is returned instead.
    """
    try:
        with open(path, 'rb') as handle:
            rgb = Image.open(handle).convert('RGB')
    except Exception as e:
        print(f"⚠️ 图片读取失败 {path}: {e}")
        rgb = Image.new('RGB', (224,224), color='white')
    return rgb

def get_loaders(data_dir, batch_size=64):
    """
    Build train and test DataLoaders from ``<data_dir>/train`` and
    ``<data_dir>/test`` using the same batch size for both.

    Training images get a random resized crop and a random horizontal flip;
    test images are only resized. Both are converted to tensors and
    normalised per channel.

    Returns:
        (train_loader, test_loader, number_of_classes, class_names)
    """
    tfm = torchvision.transforms
    # Per-channel mean and std used for both splits.
    normalize = tfm.Normalize([0.485, 0.456, 0.406],
                              [0.229, 0.224, 0.225])

    transform_train = tfm.Compose([
        tfm.RandomResizedCrop(224, scale=(0.8, 1.0)),
        tfm.RandomHorizontalFlip(),
        tfm.ToTensor(),
        normalize,
    ])
    transform_test = tfm.Compose([
        tfm.Resize((224, 224)),
        tfm.ToTensor(),
        normalize,
    ])

    image_folder = torchvision.datasets.ImageFolder
    train_dataset = image_folder(
        os.path.join(data_dir, 'train'), transform=transform_train, loader=pil_loader)
    test_dataset = image_folder(
        os.path.join(data_dir, 'test'), transform=transform_test, loader=pil_loader)

    # Only the training split is shuffled.
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = make_loader(test_dataset, batch_size=batch_size, shuffle=False)

    class_names = train_dataset.classes
    return train_loader, test_loader, len(class_names), class_names

In [42]:
def train_one_epoch(model, device, loader, optimizer):
    """
    Train ``model`` for one pass over ``loader`` using NLL loss.

    Returns:
        (mean of the per-batch losses, accuracy in percent)
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in tqdm(loader, desc="训练中"):
        data = data.to(device)
        target = target.to(device)

        # Standard optimisation step
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # Running statistics
        total_loss += loss.item()
        correct += (output.argmax(1) == target).sum().item()
        total += len(data)

    mean_loss = total_loss / len(loader)
    accuracy = 100 * correct / total
    return mean_loss, accuracy


def evaluate(model, device, loader):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        (summed NLL loss divided by ``len(loader.dataset)``,
         accuracy in percent over the samples actually seen)
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Sum (not mean) so the final division is per sample.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            correct += (output.argmax(1) == target).sum().item()
            total += len(data)
    mean_loss = total_loss / len(loader.dataset)
    accuracy = 100 * correct / total
    return mean_loss, accuracy

In [43]:
# Dataset root relative to the working directory. Built with os.path.join so
# the separator is correct on every platform (the result is ".\Dataset" on
# Windows, "./Dataset" elsewhere).
data_path = os.path.join(".", "Dataset")
train_loader, test_loader, num_classes, class_names = get_loaders(data_path)
print(f"✅ 数据集加载完成，共 {num_classes} 类")

✅ 数据集加载完成，共 10 类
