In [None]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
import numpy as np
import pandas as pd
from pycocotools.coco import COCO
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import warnings
warnings.filterwarnings('ignore')
import random
import os
from sklearn.model_selection import train_test_split

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"使用设备: {device}")

def get_category_mapping(coco_json_path):
    """
    Return a ``{category_id: category_name}`` dict for a COCO-format JSON file.

    Names come from the file's ``categories`` section. Any ``category_id``
    used by an annotation but missing from ``categories`` is mapped to its ID
    rendered as a string. This includes a file with no ``categories`` section
    at all. As a result, looking up any annotated category never raises
    KeyError.
    """
    with open(coco_json_path, 'r') as f:
        data = json.load(f)

    category_mapping = {cat['id']: cat['name'] for cat in data.get('categories', [])}

    # Fallback: extract the unique category_ids from the annotations for any
    # category that has no entry in `categories`.
    for ann in data.get('annotations', []):
        category_id = ann['category_id']
        category_mapping.setdefault(category_id, str(category_id))

    return category_mapping

def get_num_classes_from_coco(json_file):
    """Count the distinct category IDs referenced by annotations in a COCO JSON file.

    Returns:
        (num_classes, id_mapping): the number of distinct annotated category
        IDs, and a dict mapping each original category_id to a contiguous label
        starting at 1 (label 0 is reserved for background), assigned in
        ascending category_id order.
    """
    with open(json_file, 'r') as handle:
        coco_dict = json.load(handle)

    # Only IDs that actually appear in annotations are counted; the
    # `categories` section is ignored.
    used_ids = sorted({ann['category_id'] for ann in coco_dict.get('annotations', [])})
    label_of = {cat_id: label for label, cat_id in enumerate(used_ids, start=1)}

    print(f"检测到的原始类别IDs: {used_ids}")
    print(f"类别数量: {len(used_ids)}")
    print(f"ID映射: {label_of}")

    return len(used_ids), label_of

class FasterRCNNDataset(Dataset):
    """COCO-style detection dataset yielding ``(image, target)`` pairs.

    Only images with at least one annotation are exposed. Category IDs are
    remapped to contiguous labels 1..K via ``get_num_classes_from_coco``
    (label 0 is background).

    Args:
        json_file: path to a COCO-format annotation file.
        img_dir: directory that ``images[*].file_name`` is relative to.
        transforms: optional callable applied to the PIL image only. Boxes are
            NOT transformed, so geometric augmentations would misalign targets.
        device: accepted for API symmetry; this class does not use it.
    """

    def __init__(self, json_file, img_dir, transforms=None, device='cpu'):
        with open(json_file, 'r') as f:
            self.data = json.load(f)

        self.img_dir = img_dir
        self.transforms = transforms

        # Class count and category_id -> label mapping, derived from this file's annotations.
        self.num_classes, self.id_mapping = get_num_classes_from_coco(json_file)

        # image id -> image record
        self.image_info = {img['id']: img for img in self.data['images']}

        # image id -> list of its annotations (first-seen image order is preserved)
        self.image_annotations = {}
        for ann in self.data.get('annotations', []):
            self.image_annotations.setdefault(ann['image_id'], []).append(ann)

        # Keep only images that have annotations.
        self.image_ids = list(self.image_annotations.keys())

        print(f"数据集包含 {len(self.image_ids)} 张有标注的图像")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, target)``; target has 'boxes' (N,4 xyxy float32), 'labels' (N,) int64, 'image_id' str."""
        image_id = self.image_ids[idx]
        img_info = self.image_info[image_id]
        img_path = os.path.join(self.img_dir, img_info['file_name'])

        image = Image.open(img_path).convert('RGB')

        boxes = []
        labels = []
        for ann in self.image_annotations[image_id]:
            # COCO bbox is [x, y, width, height]; convert to [x1, y1, x2, y2]
            # and clip to the image bounds.
            x, y, w, h = ann['bbox']
            x1 = max(0, x)
            y1 = max(0, y)
            x2 = min(img_info['width'], x + w)
            y2 = min(img_info['height'], y + h)

            # Drop boxes that are degenerate after clipping.
            if x2 > x1 and y2 > y1:
                boxes.append([x1, y1, x2, y2])
                # Unknown categories fall back to label 1. This is unreachable when
                # id_mapping was built from this same file, but callers may replace it.
                labels.append(self.id_mapping.get(ann['category_id'], 1))

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
        else:
            # No valid box survived: emit an empty target (a negative image).
            # A whole-image box labelled 0 would still be matched by the RPN as a
            # ground-truth box, teaching it that the full frame is an object.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': str(image_id)
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target

class TestDataset(Dataset):
    """
    Inference-only dataset backed by a CSV that lists the test images.

    The CSV is read with pandas and must provide an ``id`` column and a
    ``file_name`` column (relative to ``img_dir``). Each item is an
    ``(image, id)`` pair; there are no targets.
    """
    def __init__(self, csv_file, img_dir, transform=None, device='cpu'):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        # Stored only; __getitem__ never moves data to this device.
        self.device = device

        print(f"测试数据集初始化完成，共{len(self.df)}张图像")

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        record_id = record['id']
        full_path = os.path.join(self.img_dir, record['file_name'])

        picture = Image.open(full_path).convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        return picture, record_id

def create_faster_rcnn_model(num_classes, pretrained=True):
    """
    Create a torchvision Faster R-CNN (ResNet-50 FPN) with a freshly sized box predictor.

    Args:
        num_classes: number of foreground classes. One extra output is added
            for the background class.
        pretrained: forwarded to ``fasterrcnn_resnet50_fpn``. Defaults to True
            (original behavior). Callers that immediately load a full
            checkpoint can pass False to skip loading detection weights that
            would be overwritten anyway.

    Returns:
        The model whose ``roi_heads.box_predictor`` outputs ``num_classes + 1`` classes.
    """
    model = fasterrcnn_resnet50_fpn(pretrained=pretrained)

    # Replace the classification/regression head to match our class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)  # +1 for background

    return model

def collate_fn(batch):
    """
    Keep a detection batch as two parallel lists instead of stacking tensors.

    Images can differ in size and each target dict holds a variable number of
    boxes, so the default stacking collate would fail.
    """
    image_group, target_group = zip(*batch)
    return [*image_group], [*target_group]

def test_collate_fn(batch):
    """
    测试数据的collate函数
    """
    images, image_ids = zip(*batch)
    return list(images), list(image_ids)

def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001,
                save_path='best_faster_rcnn_model.pth'):
    """
    Train a Faster R-CNN model and checkpoint the epoch with the lowest validation loss.

    Args:
        model: detection model that returns a dict of losses when called as
            ``model(images, targets)`` in train mode.
        train_loader: yields ``(list_of_image_tensors, list_of_target_dicts)``.
        val_loader: same format as ``train_loader``. If it yields no batches,
            the training loss is used for model selection instead.
        num_epochs: number of passes over ``train_loader``.
        lr: initial SGD learning rate, decayed 10x every 3 epochs (StepLR).
        save_path: where the best checkpoint is written.

    Returns:
        The model after the final epoch (not necessarily the best checkpoint).

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    model.to(device)

    def _to_device(images, targets):
        # Move images and every tensor-like target field to the device; other
        # fields (e.g. the string image_id) are passed through unchanged.
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) if hasattr(v, 'to') else v for k, v in t.items()} for t in targets]
        return images, targets

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)

    # Decay the learning rate by 10x every 3 epochs.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    best_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        # --- Training phase ---
        model.train()
        train_loss = 0.0
        train_batches = 0

        train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
        for images, targets in train_pbar:
            images, targets = _to_device(images, targets)

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            train_loss += losses.item()
            train_batches += 1

            train_pbar.set_postfix({
                'Loss': f'{losses.item():.4f}',
                'Avg Loss': f'{train_loss/train_batches:.4f}'
            })

        if train_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot train.")
        avg_train_loss = train_loss / train_batches

        # --- Validation phase ---
        # The loss dict is only produced in train mode, so run the forward
        # passes in train() with gradients disabled, then return to eval().
        # NOTE(review): train() would update BatchNorm running statistics if the
        # backbone used regular BatchNorm — confirm its norm layers are frozen.
        val_loss = 0.0
        val_batches = 0
        model.train()
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
            for images, targets in val_pbar:
                images, targets = _to_device(images, targets)

                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                val_loss += losses.item()
                val_batches += 1

                val_pbar.set_postfix({
                    'Loss': f'{losses.item():.4f}',
                    'Avg Loss': f'{val_loss/val_batches:.4f}'
                })
        model.eval()

        if val_batches > 0:
            avg_val_loss = val_loss / val_batches
        else:
            # Empty validation set: select on training loss instead of dividing by zero.
            print("Validation loader yielded no batches; using train loss for model selection.")
            avg_val_loss = avg_train_loss

        print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save the best model so far.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
                # Foreground class count, recovered from the predictor's output size.
                'num_classes': model.roi_heads.box_predictor.cls_score.out_features - 1
            }, save_path)
            print(f"保存最佳模型，验证损失: {best_loss:.4f}")

        # Advance the learning-rate schedule once per epoch.
        lr_scheduler.step()

    print("\n训练完成！")
    return model

def predict_and_generate_submission(model, csv_file, img_dir, output_file='submission.csv', 
                                  device='cpu', conf_threshold=0.01, label_to_category=None):
    """
    Run inference on every image listed in ``csv_file`` and write detections to a CSV.

    One row is written per kept detection, with columns image_id, category_id,
    bbox (``[x, y, width, height]``) and score. The header is written even
    when there are no detections.

    Args:
        model: detection model; it is not moved here, so it must already live on ``device``.
        csv_file: CSV with 'id' and 'file_name' columns (read by TestDataset).
        img_dir: directory containing the test images.
        output_file: destination CSV path.
        device: device the input images are moved to.
        conf_threshold: detections with a score below this are dropped.
        label_to_category: optional dict mapping the model's contiguous labels
            back to the dataset's original category IDs, i.e. the inverse of
            the ``id_mapping`` returned by get_num_classes_from_coco. When None
            (default), the raw model labels are written.

    Returns:
        The list of detection dicts that was written.
    """
    model.eval()

    # NOTE(review): must match the normalization used at training time.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = TestDataset(csv_file, img_dir, transform, device=device)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=test_collate_fn)

    results = []

    with torch.no_grad():
        for images, image_ids in tqdm(dataloader, desc="预测中"):
            images = [img.to(device) for img in images]
            predictions = model(images)

            for pred, image_id in zip(predictions, image_ids):
                # pandas row access yields numpy scalars; store native Python values.
                if hasattr(image_id, 'item'):
                    image_id = image_id.item()

                boxes = pred['boxes'].cpu().numpy()
                scores = pred['scores'].cpu().numpy()
                labels = pred['labels'].cpu().numpy()

                # Drop low-confidence detections.
                keep = scores >= conf_threshold

                for (x1, y1, x2, y2), score, label in zip(boxes[keep], scores[keep], labels[keep]):
                    category_id = int(label)
                    if label_to_category is not None:
                        category_id = label_to_category.get(category_id, category_id)

                    # Convert [x1, y1, x2, y2] back to [x, y, width, height].
                    results.append({
                        'image_id': image_id,
                        'category_id': category_id,
                        'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                        'score': float(score)
                    })

    # Explicit columns so an empty result still produces a CSV header.
    df_results = pd.DataFrame(results, columns=['image_id', 'category_id', 'bbox', 'score'])
    df_results.to_csv(output_file, index=False)
    print(f"预测结果已保存到 {output_file}")
    print(f"总共生成了 {len(results)} 个预测结果")

    return results

def predict_from_test_csv_to_json(model, csv_file, img_dir, output_file='submission.json', 
                                  device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 
                                  conf_threshold=0.01, label_to_category=None):
    """
    Run inference on every image listed in ``csv_file`` and write detections as a JSON list.

    Each entry has keys image_id, category_id, bbox (``[x, y, width, height]``)
    and score.

    Args:
        model: detection model; it is not moved here, so it must already live on ``device``.
        csv_file: CSV with 'id' and 'file_name' columns (read by TestDataset).
        img_dir: directory containing the test images.
        output_file: destination JSON path.
        device: device the input images are moved to.
        conf_threshold: detections with a score below this are dropped.
        label_to_category: optional dict mapping the model's contiguous labels
            back to the dataset's original category IDs, i.e. the inverse of
            the ``id_mapping`` returned by get_num_classes_from_coco. When None
            (default), the raw model labels are written.

    Returns:
        The list of detection dicts that was written.
    """
    model.eval()

    # No augmentation at inference time.
    # NOTE(review): must match the normalization used at training time.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = TestDataset(csv_file, img_dir, transform, device=device)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=test_collate_fn)

    results = []

    with torch.no_grad():
        for images, image_ids in tqdm(dataloader, desc="预测中"):
            images = [img.to(device) for img in images]
            predictions = model(images)

            for pred, image_id in zip(predictions, image_ids):
                # Reading a pandas row yields numpy scalars (e.g. numpy.int64) for
                # numeric ids. json.dump cannot serialize those, so convert to the
                # native Python value.
                if hasattr(image_id, 'item'):
                    image_id = image_id.item()

                boxes = pred['boxes'].cpu().numpy()
                scores = pred['scores'].cpu().numpy()
                labels = pred['labels'].cpu().numpy()

                # Drop low-confidence detections.
                keep = scores >= conf_threshold

                for (x1, y1, x2, y2), score, label in zip(boxes[keep], scores[keep], labels[keep]):
                    category_id = int(label)
                    if label_to_category is not None:
                        category_id = label_to_category.get(category_id, category_id)

                    # Convert [x1, y1, x2, y2] back to [x, y, width, height].
                    results.append({
                        'image_id': image_id,
                        'category_id': category_id,
                        'bbox': [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                        'score': float(score)
                    })

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"预测结果已保存到 {output_file}")
    print(f"总共生成了 {len(results)} 个预测结果")

    return results

def visualize_predictions(model, json_file, img_dir, num_images=5, device='cpu', conf_threshold=0.03):
    """
    Plot ground truth (left) next to model predictions (right) for random annotated images.

    Labels drawn on both panels are the dataset's contiguous labels, not the
    original category IDs. Only predictions with score >= conf_threshold are
    drawn. Up to ``num_images`` figures are shown.
    """
    model.eval()

    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    annotated = FasterRCNNDataset(json_file, img_dir, pipeline, device=device)
    loader = DataLoader(annotated, batch_size=1, shuffle=True, collate_fn=collate_fn)

    # Per-channel statistics used to undo the Normalize step for display.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    shown = 0
    with torch.no_grad():
        for batch_images, batch_targets in loader:
            if shown >= num_images:
                break

            batch_images = [tensor.to(device) for tensor in batch_images]
            outputs = model(batch_images)

            # Denormalize the first (only) image and clamp to the displayable range.
            display = torch.clamp(batch_images[0].cpu() * channel_std + channel_mean, 0, 1)
            gt = batch_targets[0]
            out = outputs[0]

            fig, (gt_ax, pred_ax) = plt.subplots(1, 2, figsize=(15, 7))

            # Left panel: ground truth in red.
            gt_ax.imshow(display.permute(1, 2, 0))
            gt_ax.set_title('Ground Truth')
            for (bx1, by1, bx2, by2), gt_label in zip(gt['boxes'].cpu().numpy(), gt['labels'].cpu().numpy()):
                gt_ax.add_patch(patches.Rectangle((bx1, by1), bx2-bx1, by2-by1, 
                                                  linewidth=2, edgecolor='red', facecolor='none'))
                gt_ax.text(bx1, by1-5, f'GT: {gt_label}', color='red', fontsize=10, weight='bold')

            # Right panel: confident predictions in blue.
            pred_ax.imshow(display.permute(1, 2, 0))
            pred_ax.set_title('Predictions')
            all_scores = out['scores'].cpu().numpy()
            confident = all_scores >= conf_threshold
            kept_boxes = out['boxes'].cpu().numpy()[confident]
            kept_scores = all_scores[confident]
            kept_labels = out['labels'].cpu().numpy()[confident]

            for (bx1, by1, bx2, by2), conf, pred_label in zip(kept_boxes, kept_scores, kept_labels):
                pred_ax.add_patch(patches.Rectangle((bx1, by1), bx2-bx1, by2-by1, 
                                                    linewidth=2, edgecolor='blue', facecolor='none'))
                pred_ax.text(bx1, by1-5, f'Pred: {pred_label} ({conf:.2f})', 
                             color='blue', fontsize=10, weight='bold')

            for axis in (gt_ax, pred_ax):
                axis.axis('off')
            plt.tight_layout()
            plt.show()

            shown += 1

def split_train_data(json_file, train_ratio=0.8, random_seed=42):
    """
    Split a COCO JSON file into train/validation dicts at the image level.

    Images are split with sklearn's ``train_test_split`` (seeded by
    ``random_seed``). Annotations follow their image. Each returned dict
    keeps the source file's ``info``, ``licenses`` and ``categories``.
    Missing optional sections default to empty.

    Args:
        json_file: path to the COCO-format annotation file.
        train_ratio: fraction of images assigned to the training split.
        random_seed: seed for the split; also reseeds Python's global ``random`` module.

    Returns:
        (train_data, val_data) COCO-style dicts.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)

    # Kept for backward compatibility: this reseeds Python's global RNG as a
    # side effect. The split itself is driven by `random_state` below.
    random.seed(random_seed)

    images = data['images']
    annotations = data.get('annotations', [])

    train_images, val_images = train_test_split(
        images, train_size=train_ratio, random_state=random_seed
    )

    train_image_ids = {img['id'] for img in train_images}
    val_image_ids = {img['id'] for img in val_images}

    # Route each annotation to its image's split (train wins if an id is somehow in both).
    train_annotations = []
    val_annotations = []
    for ann in annotations:
        if ann['image_id'] in train_image_ids:
            train_annotations.append(ann)
        elif ann['image_id'] in val_image_ids:
            val_annotations.append(ann)

    shared = {
        'info': data.get('info', {}),
        'licenses': data.get('licenses', []),
        'categories': data.get('categories', []),
    }

    train_data = {**shared, 'images': train_images, 'annotations': train_annotations}
    val_data = {**shared, 'images': val_images, 'annotations': val_annotations}

    return train_data, val_data

def run_training(train_json='train.json', val_json=None, img_dir='images',
                num_epochs=10, batch_size=4, lr=0.001, train_ratio=0.8):
    """
    Run the full training pipeline and return the trained model.

    If ``val_json`` is None or does not exist, ``train_json`` is split with
    ``split_train_data`` and the two halves are written to temp_train.json /
    temp_val.json in the working directory; those files are not removed.

    Args:
        train_json: COCO annotation file for training.
        val_json: optional COCO annotation file for validation.
        img_dir: directory containing all images.
        num_epochs, batch_size, lr: training hyperparameters.
        train_ratio: train fraction used when splitting ``train_json``.

    Returns:
        The model after the final epoch (the best epoch is checkpointed by train_model).
    """
    print("开始训练流程...")

    # Without a validation file, carve a validation split out of the training file.
    if val_json is None or not os.path.exists(val_json):
        print(f"未找到验证集文件，从训练集中分割数据（训练集比例: {train_ratio}）")

        train_data, val_data = split_train_data(train_json, train_ratio)

        temp_train_json = 'temp_train.json'
        temp_val_json = 'temp_val.json'

        with open(temp_train_json, 'w') as f:
            json.dump(train_data, f)
        with open(temp_val_json, 'w') as f:
            json.dump(val_data, f)

        train_json_file = temp_train_json
        val_json_file = temp_val_json

        print(f"数据分割完成：")
        print(f"  训练集: {len(train_data['images'])} 张图像, {len(train_data['annotations'])} 个标注")
        print(f"  验证集: {len(val_data['images'])} 张图像, {len(val_data['annotations'])} 个标注")
    else:
        train_json_file = train_json
        val_json_file = val_json

    # NOTE(review): torchvision detection models may already normalize inputs
    # internally, making this Normalize a second normalization. If changed,
    # change the prediction/visualization transforms in the same way.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = FasterRCNNDataset(train_json_file, img_dir, train_transform, device=device)
    val_dataset = FasterRCNNDataset(val_json_file, img_dir, val_transform, device=device)

    # Each dataset derives its category_id -> label mapping from its own
    # annotations. If the validation split lacks a class, every later label
    # would shift, so reuse the training mapping for validation.
    val_dataset.id_mapping = train_dataset.id_mapping

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
                           collate_fn=collate_fn, num_workers=2)

    num_classes = train_dataset.num_classes
    model = create_faster_rcnn_model(num_classes)

    print(f"模型创建完成，类别数: {num_classes}")

    trained_model = train_model(model, train_loader, val_loader, num_epochs, lr)

    return trained_model

def run_prediction(model_path='best_faster_rcnn_model.pth', test_csv='test.csv', 
                  img_dir='images', output_file='submission.json'):
    """
    Load a saved checkpoint and write JSON predictions for every image in ``test_csv``.

    Uses the module-level ``device`` and a fixed confidence threshold of 0.01.
    Returns the list of detection dicts produced by predict_from_test_csv_to_json.
    """
    print("开始预测流程...")

    ckpt = torch.load(model_path, map_location=device)
    n_classes = ckpt['num_classes']

    detector = create_faster_rcnn_model(n_classes)
    detector.load_state_dict(ckpt['model_state_dict'])
    detector.to(device)

    print(f"模型加载完成，类别数: {n_classes}")

    return predict_from_test_csv_to_json(
        model=detector,
        csv_file=test_csv,
        img_dir=img_dir,
        output_file=output_file,
        device=device,
        conf_threshold=0.01,
    )

def run_visualization(model_path='best_faster_rcnn_model.pth', json_file='train.json', 
                     img_dir='images', conf_threshold=0.9):
    """
    Load a saved checkpoint and show ground-truth vs. prediction panels for 5 images of ``json_file``.
    """
    print("开始可视化...")

    ckpt = torch.load(model_path, map_location=device)
    detector = create_faster_rcnn_model(ckpt['num_classes'])
    detector.load_state_dict(ckpt['model_state_dict'])
    detector.to(device)

    visualize_predictions(
        model=detector,
        json_file=json_file,
        img_dir=img_dir,
        num_images=5,
        device=device,
        conf_threshold=conf_threshold,
    )

In [None]:
# Train the model (num_epochs=1 / batch_size=1 are presumably quick smoke-run settings).
trained_model = run_training(
    img_dir='images',
    train_json='train.json',
    val_json='val.json',
    lr=0.001,
    batch_size=1,
    num_epochs=1,
)

In [None]:
# Predict on test.csv with the best checkpoint and write a JSON submission.
results = run_prediction(
    output_file='submission.json',
    img_dir='images',
    test_csv='test.csv',
    model_path='best_faster_rcnn_model.pth',
)

In [None]:


def visualize_test_predictions(model_path='best_faster_rcnn_model.pth', test_csv='test.csv', jsonfile='train.json',
                              img_dir='images', num_images=5, device='cpu', conf_threshold=0.3, nms_threshold=0.5):
    """
    Visualize model predictions for randomly drawn images listed in test.csv.

    Detections with score >= conf_threshold are kept. An extra class-agnostic
    NMS is then applied, so overlapping boxes are suppressed even when they
    carry different class labels.

    Args:
        model_path: checkpoint path (must contain 'num_classes' and 'model_state_dict').
        test_csv: CSV with 'id' and 'file_name' columns.
        jsonfile: COCO JSON used to recover category names and the
            model-label -> category-ID mapping. It should be the annotation
            file whose category IDs the model was trained on.
        img_dir: directory containing the test images.
        num_images: number of images to display.
        device: device for the model and the input images.
        conf_threshold: minimum detection score to keep.
        nms_threshold: IoU threshold for the extra class-agnostic NMS.
    """
    import torchvision.ops as ops

    print(f"开始可视化test.csv中的预测结果...")
    id2name = get_category_mapping(jsonfile)
    print(id2name)

    # Build the model-label -> category-name table once, with the same helper
    # the training dataset uses to assign labels (sorted annotated category
    # IDs -> 1..K), rather than hard-coding dataset-specific IDs.
    _, category_to_label = get_num_classes_from_coco(jsonfile)
    label_to_name = {
        label: id2name.get(category_id, str(category_id))
        for category_id, label in category_to_label.items()
    }

    # Load the model.
    checkpoint = torch.load(model_path, map_location=device)
    num_classes = checkpoint['num_classes']

    model = create_faster_rcnn_model(num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"模型加载完成，类别数: {num_classes}")

    # Test dataset. NOTE(review): normalization must match the training transform.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = TestDataset(test_csv, img_dir, transform, device=device)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=test_collate_fn)

    # Statistics used to undo the Normalize step for display.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'brown']

    count = 0
    with torch.no_grad():
        for images, image_ids in dataloader:
            if count >= num_images:
                break

            images = [img.to(device) for img in images]
            predictions = model(images)

            # Denormalize the first (only) image and clamp to the displayable range.
            img = torch.clamp(images[0].cpu() * std + mean, 0, 1)

            pred = predictions[0]
            image_id = image_ids[0]

            pred_boxes = pred['boxes'].cpu().numpy()
            pred_scores = pred['scores'].cpu().numpy()
            pred_labels = pred['labels'].cpu().numpy()

            # Drop low-confidence detections.
            valid_indices = pred_scores >= conf_threshold
            pred_boxes = pred_boxes[valid_indices]
            pred_scores = pred_scores[valid_indices]
            pred_labels = pred_labels[valid_indices]

            # Extra class-agnostic NMS to remove overlapping boxes across classes.
            if len(pred_boxes) > 0:
                boxes_tensor = torch.from_numpy(pred_boxes).float()
                scores_tensor = torch.from_numpy(pred_scores).float()
                keep = ops.nms(boxes_tensor, scores_tensor, nms_threshold).numpy()

                pred_boxes = pred_boxes[keep]
                pred_scores = pred_scores[keep]
                pred_labels = pred_labels[keep]

            fig, ax = plt.subplots(1, 1, figsize=(12, 8))

            ax.imshow(img.permute(1, 2, 0))
            ax.set_title(f'Test Image: {image_id} (Found {len(pred_boxes)} objects)', fontsize=14, fontweight='bold')

            for i, (box, score, label) in enumerate(zip(pred_boxes, pred_scores, pred_labels)):
                x1, y1, x2, y2 = box
                color = colors[i % len(colors)]
                # Labels missing from the table (e.g. checkpoint/JSON mismatch) fall back to the raw label.
                class_name = label_to_name.get(int(label), str(label))

                rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, 
                                       linewidth=3, edgecolor=color, facecolor='none')
                ax.add_patch(rect)

                ax.text(x1, y1-10, f'Class {class_name}: {score:.3f}', 
                       color=color, fontsize=12, weight='bold',
                       bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

            if len(pred_boxes) == 0:
                ax.text(0.5, 0.5, 'No objects detected\n(try lowering conf_threshold)', 
                       transform=ax.transAxes, ha='center', va='center',
                       fontsize=16, color='red', weight='bold',
                       bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.8))

            ax.axis('off')
            plt.tight_layout()
            plt.show()

            # Per-image details.
            print(f"\n图像 {image_id} 的预测结果:")
            print(f"  - 检测到 {len(pred_boxes)} 个对象")
            print(f"  - 置信度阈值: {conf_threshold}")
            print(f"  - NMS阈值: {nms_threshold}")
            for i, (box, score, label) in enumerate(zip(pred_boxes, pred_scores, pred_labels)):
                print(f"  - 对象 {i+1}: 类别 {label}, 置信度 {score:.3f}, 边界框 [{box[0]:.1f}, {box[1]:.1f}, {box[2]:.1f}, {box[3]:.1f}]")
            print("-" * 50)

            count += 1

    print(f"\n可视化完成！共显示了 {count} 张图像的预测结果。")

# Example invocation
def run_test_visualization():
    """
    Visualize predictions for 5 test.csv images using the notebook's default thresholds.
    """
    return visualize_test_predictions(
        device=device,
        img_dir='images',
        test_csv='test.csv',
        model_path='best_faster_rcnn_model.pth',
        num_images=5,          # number of images to display
        conf_threshold=0.75,   # score cutoff; adjust as needed
        nms_threshold=0.7,     # IoU cutoff for the extra NMS that removes overlapping boxes
    )



def visualize_test_predictions(model_path='best_faster_rcnn_model.pth', test_csv='test.csv', jsonfile='train.json',
                              img_dir='images', num_images=5, device='cpu', conf_threshold=0.3, nms_threshold=0.5):
    """
    Visualize model predictions for randomly drawn images listed in test.csv.

    Detections with score >= conf_threshold are kept. An extra class-agnostic
    NMS is then applied, so overlapping boxes are suppressed even when they
    carry different class labels.

    Args:
        model_path: checkpoint path (must contain 'num_classes' and 'model_state_dict').
        test_csv: CSV with 'id' and 'file_name' columns.
        jsonfile: COCO JSON used to recover category names and the
            model-label -> category-ID mapping. It should be the annotation
            file whose category IDs the model was trained on.
        img_dir: directory containing the test images.
        num_images: number of images to display.
        device: device for the model and the input images.
        conf_threshold: minimum detection score to keep.
        nms_threshold: IoU threshold for the extra class-agnostic NMS.
    """
    import torchvision.ops as ops

    print(f"开始可视化test.csv中的预测结果...")
    id2name = get_category_mapping(jsonfile)
    print(id2name)

    # Build the model-label -> category-name table once, with the same helper
    # the training dataset uses to assign labels (sorted annotated category
    # IDs -> 1..K), rather than hard-coding dataset-specific IDs.
    _, category_to_label = get_num_classes_from_coco(jsonfile)
    label_to_name = {
        label: id2name.get(category_id, str(category_id))
        for category_id, label in category_to_label.items()
    }

    # Load the model.
    checkpoint = torch.load(model_path, map_location=device)
    num_classes = checkpoint['num_classes']

    model = create_faster_rcnn_model(num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"模型加载完成，类别数: {num_classes}")

    # Test dataset. NOTE(review): normalization must match the training transform.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = TestDataset(test_csv, img_dir, transform, device=device)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=test_collate_fn)

    # Statistics used to undo the Normalize step for display.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'brown']

    count = 0
    with torch.no_grad():
        for images, image_ids in dataloader:
            if count >= num_images:
                break

            images = [img.to(device) for img in images]
            predictions = model(images)

            # Denormalize the first (only) image and clamp to the displayable range.
            img = torch.clamp(images[0].cpu() * std + mean, 0, 1)

            pred = predictions[0]
            image_id = image_ids[0]

            pred_boxes = pred['boxes'].cpu().numpy()
            pred_scores = pred['scores'].cpu().numpy()
            pred_labels = pred['labels'].cpu().numpy()

            # Drop low-confidence detections.
            valid_indices = pred_scores >= conf_threshold
            pred_boxes = pred_boxes[valid_indices]
            pred_scores = pred_scores[valid_indices]
            pred_labels = pred_labels[valid_indices]

            # Extra class-agnostic NMS to remove overlapping boxes across classes.
            if len(pred_boxes) > 0:
                boxes_tensor = torch.from_numpy(pred_boxes).float()
                scores_tensor = torch.from_numpy(pred_scores).float()
                keep = ops.nms(boxes_tensor, scores_tensor, nms_threshold).numpy()

                pred_boxes = pred_boxes[keep]
                pred_scores = pred_scores[keep]
                pred_labels = pred_labels[keep]

            fig, ax = plt.subplots(1, 1, figsize=(12, 8))

            ax.imshow(img.permute(1, 2, 0))
            ax.set_title(f'Test Image: {image_id} (Found {len(pred_boxes)} objects)', fontsize=14, fontweight='bold')

            for i, (box, score, label) in enumerate(zip(pred_boxes, pred_scores, pred_labels)):
                x1, y1, x2, y2 = box
                color = colors[i % len(colors)]
                # Labels missing from the table (e.g. checkpoint/JSON mismatch) fall back to the raw label.
                class_name = label_to_name.get(int(label), str(label))

                rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, 
                                       linewidth=3, edgecolor=color, facecolor='none')
                ax.add_patch(rect)

                ax.text(x1, y1-10, f'Class {class_name}: {score:.3f}', 
                       color=color, fontsize=12, weight='bold',
                       bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

            if len(pred_boxes) == 0:
                ax.text(0.5, 0.5, 'No objects detected\n(try lowering conf_threshold)', 
                       transform=ax.transAxes, ha='center', va='center',
                       fontsize=16, color='red', weight='bold',
                       bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.8))

            ax.axis('off')
            plt.tight_layout()
            plt.show()

            # Per-image details.
            print(f"\n图像 {image_id} 的预测结果:")
            print(f"  - 检测到 {len(pred_boxes)} 个对象")
            print(f"  - 置信度阈值: {conf_threshold}")
            print(f"  - NMS阈值: {nms_threshold}")
            for i, (box, score, label) in enumerate(zip(pred_boxes, pred_scores, pred_labels)):
                print(f"  - 对象 {i+1}: 类别 {label}, 置信度 {score:.3f}, 边界框 [{box[0]:.1f}, {box[1]:.1f}, {box[2]:.1f}, {box[3]:.1f}]")
            print("-" * 50)

            count += 1

    print(f"\n可视化完成！共显示了 {count} 张图像的预测结果。")

# Example invocation
def run_test_visualization():
    """
    Visualize predictions for 5 test.csv images using the notebook's default thresholds.
    """
    return visualize_test_predictions(
        device=device,
        img_dir='images',
        test_csv='test.csv',
        model_path='best_faster_rcnn_model.pth',
        num_images=5,          # number of images to display
        conf_threshold=0.75,   # score cutoff; adjust as needed
        nms_threshold=0.7,     # IoU cutoff for the extra NMS that removes overlapping boxes
    )


# Visualize the model's predictions on images listed in test.csv.
run_test_visualization()