# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
from torchvision.models import efficientnet_b0, mobilenet_v3_small, MobileNet_V3_Small_Weights
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

# Data augmentation
class PathoAugment:
    """Callable placeholder for image augmentation.

    No augmentation is implemented yet: calling an instance hands the
    input back untouched.
    """

    def __call__(self, img):
        """Return ``img`` as-is."""
        # TODO: add the concrete augmentation operations here.
        augmented = img
        return augmented

# Model definition
class CropDiseaseNet(nn.Module):
    """Two-backbone image classifier with late feature fusion.

    The input is run through the convolutional feature extractors of an
    EfficientNet-B0 and a MobileNetV3-Small. Each feature map is globally
    average-pooled, the two vectors are concatenated, projected to 512
    dimensions (Linear -> BatchNorm1d -> SiLU -> Dropout(0.3)) and classified
    by a final linear layer.

    Only the ``features`` sub-modules of the backbones are used in
    ``forward``; their own classification heads are left in place but are
    never called.

    Args:
        num_classes: Number of output logits. Defaults to 200.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        # Backbones; 'DEFAULT' selects torchvision's default pretrained weights.
        self.effnet = efficientnet_b0(weights='DEFAULT')
        self.mobilenet = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)

        # Feature fusion. These widths must equal the channel counts of the
        # two backbones' final feature maps (EfficientNet-B0: 1280,
        # MobileNetV3-Small: 576), because the pooled vectors are concatenated
        # and fed straight into the first Linear layer.
        effnet_features = 1280
        mobilenet_features = 576
        self.fusion = nn.Sequential(
            nn.Linear(effnet_features + mobilenet_features, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(inplace=True),
            nn.Dropout(0.3)
        )

        self.classifier = nn.Linear(512, num_classes)

    @staticmethod
    def _global_avg_pool(feature_map):
        """Average a ``(N, C, H, W)`` feature map over H and W, giving ``(N, C)``."""
        # Functional pooling avoids constructing a new nn.Module on every
        # forward pass; the computation is identical to AdaptiveAvgPool2d(1).
        pooled = nn.functional.adaptive_avg_pool2d(feature_map, 1)
        return torch.flatten(pooled, 1)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch passed to both backbones. Presumably shaped
                ``(N, 3, H, W)`` and normalised the way the pretrained
                weights expect -- confirm against the data pipeline.

        Returns:
            Tensor of shape ``(N, num_classes)`` with unnormalised logits.
        """
        effnet_vec = self._global_avg_pool(self.effnet.features(x))
        mobilenet_vec = self._global_avg_pool(self.mobilenet.features(x))

        fused = torch.cat([effnet_vec, mobilenet_vec], dim=1)
        return self.classifier(self.fusion(fused))

# Custom dataset class
class CUBDataset(Dataset):
    """Image-classification dataset read from CUB-200-2011 style index files.

    Three whitespace-separated, two-column text files are read from ``root``:

    * ``images.txt``             -- ``<image_id> <relative_path>``
    * ``image_class_labels.txt`` -- ``<image_id> <1-based class id>``
    * ``train_test_split.txt``   -- ``<image_id> <1 if training else 0>``

    Image paths and labels are stored in file order, and an image id ``k`` in
    ``train_test_split.txt`` is mapped to position ``k - 1`` of those lists,
    so the files are expected to list ids ``1..N`` in ascending order.

    Args:
        root: Dataset directory containing the index files and ``images/``.
        split: ``'train'`` or ``'val'``; ``'val'`` selects the rows whose
            train flag is not 1.
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'val'``, if an index
            line does not have two columns, or if the number of image paths
            and labels differ.
    """

    _SPLITS = ('train', 'val')

    def __init__(self, root, split='train', transform=None):
        # Fail fast: previously an unknown split left ``self.indices`` unset
        # and only surfaced later as an AttributeError in ``__len__``.
        if split not in self._SPLITS:
            raise ValueError(
                f"split must be one of {self._SPLITS}, got {split!r}"
            )

        self.root = root
        self.transform = transform
        self.split = split

        # Image paths and zero-based labels, both in file order.
        self.image_paths = [
            os.path.join(root, 'images', rel_path)
            for _, rel_path in self._read_pairs(os.path.join(root, 'images.txt'))
        ]
        self.labels = [
            int(label) - 1  # class ids in the file start at 1
            for _, label in self._read_pairs(
                os.path.join(root, 'image_class_labels.txt'))
        ]
        if len(self.image_paths) != len(self.labels):
            # A mismatch would silently pair images with the wrong labels.
            raise ValueError(
                f"images.txt lists {len(self.image_paths)} images but "
                f"image_class_labels.txt lists {len(self.labels)} labels"
            )

        # Train / validation partition, stored as positions into the lists above.
        self.train_indices = []
        self.val_indices = []
        split_file = os.path.join(root, 'train_test_split.txt')
        for img_id, is_train in self._read_pairs(split_file):
            bucket = self.train_indices if int(is_train) == 1 else self.val_indices
            bucket.append(int(img_id) - 1)

        self.indices = self.train_indices if split == 'train' else self.val_indices

    @staticmethod
    def _read_pairs(path):
        """Yield ``(first_column, rest_of_line)`` string pairs from ``path``.

        Blank lines are skipped. A non-blank line with fewer than two columns
        raises ``ValueError`` from the tuple unpacking.
        """
        with open(path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                key, value = line.split(None, 1)
                yield key, value

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Load the ``idx``-th sample of the selected split.

        Returns:
            ``(image, label)`` where ``image`` is the RGB PIL image (or the
            result of ``transform`` on it) and ``label`` is the zero-based
            integer class id.
        """
        sample_idx = self.indices[idx]
        path = self.image_paths[sample_idx]
        label = self.labels[sample_idx]
        # ``convert`` returns a fully loaded copy, so the underlying file can
        # be closed as soon as the ``with`` block exits.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

# Training function
def train(root=r'E:\Datasets\CUB_200_2011', epochs=30, batch_size=32,
          accum_steps=4, lr=3e-4):
    """Train ``CropDiseaseNet`` on a CUB-style dataset and report validation accuracy.

    Uses AdamW with cosine learning-rate annealing, label-smoothed
    cross-entropy, CUDA mixed precision (``autocast`` + ``GradScaler``) and
    gradient accumulation. After every epoch the model is evaluated on the
    validation split and the accuracy is printed. Requires a CUDA device
    (the model and every batch are moved with ``.cuda()``).

    Args:
        root: Dataset directory handed to ``CUBDataset``.
        epochs: Number of training epochs; also the cosine schedule's ``T_max``.
        batch_size: Per-step batch size for both loaders.
        accum_steps: Number of batches whose gradients are accumulated
            before one optimizer step.
        lr: Initial AdamW learning rate.

    Raises:
        ValueError: If ``accum_steps`` is smaller than 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    # Initialisation
    model = CropDiseaseNet().cuda()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = amp.GradScaler()
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Data loading; the same deterministic transform is used for both splits.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_set = CUBDataset(root=root, split='train', transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)

    val_set = CUBDataset(root=root, split='val', transform=transform)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=True)

    num_batches = len(train_loader)

    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (images, labels) in enumerate(train_loader):
            # pin_memory=True on the loader lets these copies be asynchronous.
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            with amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            # Scale the loss down so the accumulated gradient is the mean over
            # the window rather than a sum accum_steps times larger.
            scaler.scale(loss / accum_steps).backward()

            # Step at the end of every full window and also on the last batch,
            # so a trailing partial window is applied instead of leaking its
            # gradients into the next epoch.
            if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            if i % 10 == 0:
                mem = torch.cuda.memory_allocated()/1e9
                # The un-normalised per-batch loss is what gets reported.
                print(f'Epoch {epoch} | Batch {i} | Loss: {loss.item():.3f} | Mem: {mem:.2f}GB')

        scheduler.step()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Guard against an empty validation split.
        accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch} | Val Accuracy: {accuracy:.2f}%')

if __name__ == "__main__":
    # Start training only when the file is executed as a script, so that
    # DataLoader worker processes re-importing this module do not re-run it.
    train()
