In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from config import CACHE_DIR

In [None]:
SEED = 42
torch.manual_seed(SEED)  # make DataLoader shuffling and the new fc-layer initialisation reproducible

dataset = load_dataset('ethz/food101', cache_dir=CACHE_DIR)

# Image preprocessing for an ImageNet-pretrained ResNet
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # force 3 channels so Normalize's 3-channel stats apply
    transforms.Resize((224, 224)),  # ResNet is fed 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet mean/std
])


def apply_transform(example):
    """Convert the 'image' field of an example (or a batch of examples) to a normalized tensor.

    Works both for a single example (``example['image']`` is one PIL image) and
    for a batch as passed by ``Dataset.with_transform`` (``example['image']`` is
    a list of PIL images, in which case a list of tensors is produced).

    Returns:
        The same dict with ``'image'`` replaced by the transformed tensor(s).
    """
    images = example['image']
    if isinstance(images, list):
        example['image'] = [transform(img) for img in images]
    else:
        example['image'] = transform(images)
    return example


# Transform lazily at access time rather than with `.map()`: `.map()` would write every
# decoded 3x224x224 float tensor (~600 KB each) into the on-disk Arrow cache up front.
train_data = dataset['train'].with_transform(apply_transform)
test_data = dataset['validation'].with_transform(apply_transform)

# PyTorch DataLoaders; with_transform already yields tensors, so no with_format("torch") is needed
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

# Load an ImageNet-pretrained ResNet-18. `pretrained=True` is deprecated since torchvision 0.13
# in favour of the `weights=` enum; fall back to it only on older torchvision without the enum.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
else:
    model = models.resnet18(pretrained=True)

# Replace the final fully connected layer so it outputs one logit per dataset class
# (101 for Food-101); the count is read from the dataset's label feature instead of hard-coded.
num_classes = dataset['train'].features['label'].num_classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Move the model to the GPU when available, otherwise stay on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()  # cross-entropy loss over class logits
# Created after .to(device) so the optimizer holds the device-resident parameters
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
def train_model(epochs, net=None, loader=None, opt=None, loss_fn=None, dev=None):
    """Train a classifier for `epochs` passes over a loader, printing loss/accuracy per epoch.

    Args:
        epochs: number of full passes over the training loader.
        net: model to train; defaults to the notebook-global `model`.
        loader: iterable of dict batches with 'image' and 'label' tensors;
            defaults to the notebook-global `train_loader`.
        opt: optimizer over `net`'s parameters; defaults to the global `optimizer`.
        loss_fn: loss callable; defaults to the global `criterion`. Assumed to
            average over the batch (the default 'mean' reduction), since the
            epoch loss re-weights each batch loss by its batch size.
        dev: device to move batches to; defaults to the global `device`.

    Returns:
        A list with one ``(avg_loss, accuracy_percent)`` tuple per epoch.

    Raises:
        ValueError: if the loader yields no samples in an epoch.
    """
    # Resolve defaults at call time so re-created notebook globals are picked up
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    history = []
    for epoch in range(epochs):
        net.train()  # re-enter training mode each epoch in case eval() was called elsewhere
        total_loss, correct, total = 0.0, 0, 0
        for batch in loader:
            inputs = batch['image'].to(dev)
            labels = batch['label'].to(dev)

            opt.zero_grad()
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()

            batch_size = labels.size(0)
            # Weight by batch size so a smaller final batch does not skew the epoch mean
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("training loader yielded no samples")
        avg_loss = total_loss / total
        accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
        history.append((avg_loss, accuracy))
    return history

NUM_EPOCHS = 5  # number of full passes over the training set
train_model(NUM_EPOCHS)

def _collect_predictions(net, loader, dev):
    """Run `net` over every batch of `loader` in eval mode without gradient tracking.

    Returns:
        (truths, predictions): two parallel lists of integer class indices.
    """
    net.eval()
    truths, predictions = [], []
    with torch.no_grad():
        for batch in loader:
            outputs = net(batch['image'].to(dev))
            predictions.extend(outputs.argmax(dim=1).cpu().tolist())
            truths.extend(batch['label'].cpu().tolist())
    return truths, predictions


def evaluate_model():
    """Print sklearn's classification report for `model` on `test_loader`.

    Returns:
        (truths, predictions) so later cells (e.g. the confusion-matrix plot)
        can reuse them instead of running inference over the test set again.
    """
    truths, predictions = _collect_predictions(model, test_loader, device)
    print(classification_report(truths, predictions))
    return truths, predictions

test_truths, test_predictions = evaluate_model()

def plot_confusion_matrix(truths=None, predictions=None):
    """Plot the test-set confusion matrix as a heatmap.

    Args:
        truths: optional precomputed true labels (as returned by `evaluate_model`).
        predictions: optional precomputed predicted labels, parallel to `truths`.
            If either is omitted, inference is run over `test_loader` again.
    """
    if truths is None or predictions is None:
        truths, predictions = _collect_predictions(model, test_loader, device)
    cm = confusion_matrix(truths, predictions)
    _, ax = plt.subplots(figsize=(12, 12))
    sns.heatmap(cm, annot=False, cmap="Blues", ax=ax)  # annot=False: per-cell counts would be unreadable
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix")
    plt.show()

plot_confusion_matrix(test_truths, test_predictions)

In [None]:
import os
import torch
from config import OUTPUT_DIR

# Destination file for the trained parameters
save_path = os.path.join(OUTPUT_DIR, 'model_weights.pth')

# Create the output directory if needed. exist_ok avoids the check-then-create race,
# and the emptiness guard skips makedirs('') (which raises) when OUTPUT_DIR is ''.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Save only the state_dict (parameter tensors), not the pickled module object
torch.save(model.state_dict(), save_path)

print(f"Model parameters have been saved to {save_path}")