In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd
from dataset import StoneDataset
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import os

# 1. 配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "./best_model.pth"
batch_size = 256
num_classes = 3

# 2. 定义模型
def build_model():
    weights = EfficientNet_B0_Weights.DEFAULT
    model = efficientnet_b0(weights=weights)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    return model

# 3. 加载模型
model = build_model()
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print(f"使用 {torch.cuda.device_count()} 张GPU进行 TTA 推理")

# 4. 定义 TTA 变换列表
tta_transforms = [
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4593, 0.4543, 0.4495), (0.1704, 0.1726, 0.1809))
    ]),
    transforms.Compose([
        transforms.RandomHorizontalFlip(p=1.0),#水平翻转
        transforms.ToTensor(),
        transforms.Normalize((0.4593, 0.4543, 0.4495), (0.1704, 0.1726, 0.1809))
    ]),
    transforms.Compose([
        transforms.RandomRotation(15),#角度旋转
        transforms.ToTensor(),
        transforms.Normalize((0.4593, 0.4543, 0.4495), (0.1704, 0.1726, 0.1809))
    ]),
    transforms.Compose([
        transforms.ColorJitter(brightness=0.3, contrast=0.3),#光照抖动
        transforms.ToTensor(),
        transforms.Normalize((0.4593, 0.4543, 0.4495), (0.1704, 0.1726, 0.1809))
    ]),
    transforms.Compose([
        transforms.RandomAffine(0, translate=(0.05, 0.05)),#随机平移
        transforms.ToTensor(),
        transforms.Normalize((0.4593, 0.4543, 0.4495), (0.1704, 0.1726, 0.1809))
    ])
]

# 5. 做 TTA 推理
from collections import defaultdict
import numpy as np

all_preds = defaultdict(list)

for idx, t in enumerate(tta_transforms):
    print(f"🔁 TTA Pass {idx + 1}/{len(tta_transforms)}")
    test_dataset = StoneDataset(root="./dataset_processed", split="test", transforms=t)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

    with torch.no_grad():
        for images, filenames in test_loader:
            images = images.to(device)
            outputs = model(images)  # logits
            probs = torch.softmax(outputs, dim=1).cpu().numpy()

            for fname, prob in zip(filenames, probs):
                all_preds[fname].append(prob)

# 6. 平均所有预测并生成提交
final_preds = {}
for fname, prob_list in all_preds.items():
    avg_prob = np.mean(prob_list, axis=0)
    pred_label = int(np.argmax(avg_prob))
    final_preds[fname] = pred_label

# 7. 保存为 submission.csv
submission = pd.DataFrame(list(final_preds.items()), columns=["id", "label"])
submission = submission.sort_values(by="id")
submission.to_csv("submission_tta.csv", index=False)
print("✅ 已保存 TTA 推理版本 submission_tta.csv，可提交 Kaggle")

  model.load_state_dict(torch.load(model_path, map_location=device))


🔁 TTA Pass 1/5




🔁 TTA Pass 2/5




🔁 TTA Pass 3/5




🔁 TTA Pass 4/5




🔁 TTA Pass 5/5




✅ 已保存 TTA 推理版本 submission_tta.csv，可提交 Kaggle
