In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import resnet18
from torch.nn import MSELoss
from torch.optim import Adam
from PIL import Image
from tqdm import tqdm

# Dataset paths (Kaggle input mount) and the prediction output location.
# The *_img_dir values are image directories; the *_annotations values are
# text files read by PetDataset (one "<image name><TAB><age>" row per line).
train_img_dir = r"/kaggle/input/pet-age-raw/dataset5/updated_trainset"
train_annotations = r"/kaggle/input/pet-age-raw/dataset5/annotations/updated_train.txt"
val_img_dir = r"/kaggle/input/pet-age-raw/dataset5/updated_valset"
val_annotations = r"/kaggle/input/pet-age-raw/dataset5/annotations/updated_val.txt"
# The test set has no annotation file here; infer_model walks this directory.
test_img_dir = r"/kaggle/input/pet-age-raw/dataset5/updated_testset"
# File that infer_model writes its "<image name><TAB><predicted age>" rows to.
output_file = r"/kaggle/working/pred_result.txt"

# Dataset class
class PetDataset(Dataset):
    """Image/age pairs described by a tab-separated annotation file.

    Every non-blank line of the annotation file must have the form
    ``<image name><TAB><integer age>``. Blank lines are ignored. Images are
    only opened when an item is requested, relative to ``img_dir``.

    Args:
        annotations_file: Path to the annotation text file.
        img_dir: Directory that contains the images named in the file.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        ValueError: If a non-blank line is not ``<name><TAB><int>``.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = []
        with open(annotations_file, "r") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line)
                    # instead of failing on tuple unpacking.
                    continue
                try:
                    img_name, age = line.split("\t")
                    label = int(age)
                except ValueError as err:
                    # Same exception type as before, but pinpoint the row.
                    raise ValueError(
                        f"{annotations_file}:{line_no}: expected "
                        f"'<image name><TAB><integer age>', got {line!r}"
                    ) from err
                self.img_labels.append((img_name, label))
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, age)`` for the annotation at position ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was supplied; ``age`` is the integer label from the annotation file.
        """
        img_name, age = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns an independent, fully loaded image, so the source
        # file handle can be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, age

# Model definition
class AgePredictor(torch.nn.Module):
    """ResNet-18 with its classification head replaced by a 1-unit regressor.

    The backbone is created with ImageNet-pretrained weights and its final
    ``fc`` layer is swapped for ``Linear(in_features, 1)`` so the network
    emits a single value (the predicted age) per input image.
    """

    def __init__(self):
        super().__init__()
        try:
            # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
            # an explicit weights enum; IMAGENET1K_V1 is the checkpoint that
            # the legacy flag resolves to.
            from torchvision.models import ResNet18_Weights
        except ImportError:
            # Older torchvision: only the legacy flag exists.
            backbone = resnet18(pretrained=True)
        else:
            backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Attribute name kept as ``model`` so saved state_dict keys still match.
        self.model = backbone
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, 1)  # single output: age value

    def forward(self, x):
        """Run the backbone on ``x`` and return its single-unit output."""
        return self.model(x)

# Training function
def train_model(model, train_loader, val_loader, device, epochs=10, lr=1e-4):
    """Train ``model`` with MSE loss and Adam, validating after every epoch.

    Each loader is expected to yield ``(images, ages)`` batches where ``ages``
    is a 1-D tensor; it is cast to float32 and reshaped to ``(batch, 1)`` to
    match the model output before the loss is computed.

    Args:
        model: Network to optimise in place.
        train_loader: Iterable of training batches (must support ``len``).
        val_loader: Iterable of validation batches (must support ``len``).
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate handed to Adam.

    Returns:
        A list with one ``(train_loss, val_loss)`` tuple per epoch, each value
        being the mean of the per-batch losses (0.0 for an empty loader).
    """
    criterion = MSELoss()
    optimizer = Adam(model.parameters(), lr=lr)

    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    # Guard the divisors so an empty loader reports 0.0 rather than raising
    # ZeroDivisionError after a full epoch of work.
    train_divisor = max(num_train_batches, 1)
    val_divisor = max(num_val_batches, 1)

    history = []
    for epoch in range(epochs):
        train_loss = 0.0
        val_loss = 0.0

        # The context manager closes the bar even if a batch raises.
        with tqdm(total=num_train_batches + num_val_batches,
                  desc=f"Epoch {epoch+1}/{epochs}") as progress_bar:
            # Training phase
            model.train()
            for images, ages in train_loader:
                images, ages = _prepare_batch(images, ages, device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, ages)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                progress_bar.update(1)  # advance for the training batch

            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/train_divisor:.4f}")

            # Validation phase: no gradients, eval-mode layers.
            model.eval()
            with torch.no_grad():
                for images, ages in val_loader:
                    images, ages = _prepare_batch(images, ages, device)
                    outputs = model(images)
                    loss = criterion(outputs, ages)
                    val_loss += loss.item()
                    progress_bar.update(1)  # advance for the validation batch

        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {val_loss/val_divisor:.4f}")
        history.append((train_loss / train_divisor, val_loss / val_divisor))

    return history


def _prepare_batch(images, ages, device):
    """Move a batch to ``device``; cast ages to a float32 ``(batch, 1)`` tensor."""
    return images.to(device), ages.to(device, dtype=torch.float32).unsqueeze(1)

# Inference function
def infer_model(model, test_img_dir, output_file, device, transform):
    """Predict an age for every file in ``test_img_dir`` and save the results.

    Each regular file in the directory is opened as an RGB image, passed
    through ``transform`` and the model as a batch of one, and the scalar
    output is rounded to the nearest integer. Results are written to
    ``output_file`` as ``<image name><TAB><age>`` rows joined by newlines.

    Args:
        model: Trained network returning a single value per image.
        test_img_dir: Directory whose files are treated as test images.
        output_file: Path of the text file to (over)write.
        device: Device the input tensors are moved to.
        transform: Callable turning a PIL image into a tensor.
    """
    model.eval()

    # List the directory once, in sorted order, so the bar total matches the
    # work done and the output row order is reproducible. Sub-directories
    # (e.g. notebook checkpoint folders) are skipped instead of crashing
    # the image loader.
    img_names = sorted(
        name
        for name in os.listdir(test_img_dir)
        if os.path.isfile(os.path.join(test_img_dir, name))
    )

    results = []
    # The bar is closed by the context manager even if an image fails to load.
    with torch.no_grad(), tqdm(total=len(img_names), desc="Inference Progress") as progress_bar:
        for img_name in img_names:
            img_path = os.path.join(test_img_dir, img_name)
            # Release the file handle as soon as the tensor has been built.
            with Image.open(img_path) as img:
                image = transform(img.convert("RGB")).unsqueeze(0).to(device)

            age_prediction = model(image).item()
            results.append(f"{img_name}\t{int(round(age_prediction))}")
            progress_bar.update(1)  # advance for the processed image

    with open(output_file, "w") as f:
        f.write("\n".join(results))
    print(f"Predictions saved to {output_file}")

if __name__ == "__main__":
    # Stage switches: training and inference can be toggled independently.
    run_training = True  # set to True to run training
    run_inference = True  # set to True to run inference

    # Device: use the GPU when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Preprocessing shared by training, validation and inference:
    # fixed 224x224 resize, tensor conversion, per-channel normalisation
    # (the statistics commonly used with ImageNet-pretrained models).
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    # Model initialisation
    model = AgePredictor()
    model = model.to(device)

    if run_training:
        # Datasets and loaders; only the training loader is shuffled.
        train_dataset = PetDataset(
            annotations_file=train_annotations,
            img_dir=train_img_dir,
            transform=transform,
        )
        val_dataset = PetDataset(
            annotations_file=val_annotations,
            img_dir=val_img_dir,
            transform=transform,
        )
        train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=32)

        # Train
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            epochs=10,
            lr=1e-4,
        )

        # Persist the trained weights
        torch.save(model.state_dict(), "age_predictor.pth")
        print("Model saved as age_predictor.pth")

    if run_inference:
        # Reload the saved weights so inference also works without training.
        model.load_state_dict(
            torch.load("age_predictor.pth", map_location=device)
        )

        # Predict on the test directory and write the result file.
        infer_model(
            model=model,
            test_img_dir=test_img_dir,
            output_file=output_file,
            device=device,
            transform=transform,
        )
