Last updated: 2025-04-07

Here are some useful functions that can help you judge the accuracy of your model(s). This will be updated iteratively.

In [None]:
# Calculate IoU (Intersection over Union)
def calculate_iou(box1, box2):
    """Return the Intersection over Union of two axis-aligned boxes.

    Args:
        box1: Four values unpacked as (x_min, y_min, x_max, y_max).
        box2: Four values unpacked as (x_min, y_min, x_max, y_max).

    Returns:
        0.0 when the boxes do not overlap (merely touching edges counts as
        no overlap); otherwise the intersection area divided by the union
        area.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2

    # Corners of the overlapping rectangle, if there is one.
    inter_left, inter_top = max(left_a, left_b), max(top_a, top_b)
    inter_right, inter_bottom = min(right_a, right_b), min(bottom_a, bottom_b)

    # An empty or degenerate overlap means the boxes are disjoint.
    if inter_right <= inter_left or inter_bottom <= inter_top:
        return 0.0

    overlap = (inter_right - inter_left) * (inter_bottom - inter_top)

    # Union = sum of both areas minus the part counted twice.
    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    combined = area_a + area_b - overlap

    return overlap / combined

# Calculate mAP (mean Average Precision)
def calculate_map(results, iou_threshold=0.5):
    """Average the per-class AP over all classes that yield a score.

    Class ids 1 through len(CHINESE_CHARS) are evaluated with
    calculate_ap_for_class; classes for which it returns None are left out
    of the average.

    Args:
        results: Evaluation results, passed through to
            calculate_ap_for_class unchanged.
        iou_threshold: IoU threshold forwarded to calculate_ap_for_class.

    Returns:
        The mean of the collected AP values, or 0.0 when no class produced
        one. The outcome is also printed.
    """
    class_ids = range(1, len(CHINESE_CHARS) + 1)
    candidate_aps = (
        calculate_ap_for_class(results, class_id, iou_threshold)
        for class_id in class_ids
    )
    scored_aps = [ap for ap in candidate_aps if ap is not None]

    # Nothing to average: report it and fall back to zero.
    if not scored_aps:
        print("No classes detected for mAP calculation")
        return 0.0

    map_score = sum(scored_aps, 0.0) / len(scored_aps)
    print(f"mAP@{iou_threshold}: {map_score:.4f}")
    return map_score

# Calculate AP for a specific class
def calculate_ap_for_class(results, class_id, iou_threshold):
    """Compute the 11-point interpolated Average Precision for one class.

    Detections of ``class_id`` from all images are ranked by score
    (highest first). Each detection is compared, via calculate_iou, with
    the not-yet-matched ground-truth boxes of the same class in its own
    image; it is a true positive when the best IoU is at least
    ``iou_threshold`` (that ground truth is then consumed), otherwise a
    false positive.

    Args:
        results: Iterable of per-image dicts with the keys 'pred_boxes',
            'pred_scores', 'pred_labels', 'gt_boxes' and 'gt_labels'.
            It is traversed exactly once.
        class_id: Label value selecting the class to evaluate.
        iou_threshold: Minimum IoU for a detection to count as a match.

    Returns:
        None when there is no ground truth of this class at all; otherwise
        the AP, i.e. the mean over recall levels 0.0, 0.1, ..., 1.0 of the
        best precision reached at or beyond that recall.
    """
    gt_per_image = []   # gt_per_image[i] -> boxes of this class in image i
    detections = []     # (img_idx, box, score) across all images

    for img_idx, result in enumerate(results):
        gt_per_image.append([
            box
            for box, label in zip(result['gt_boxes'], result['gt_labels'])
            if label == class_id
        ])
        detections.extend(
            (img_idx, box, score)
            for box, score, label in zip(
                result['pred_boxes'], result['pred_scores'], result['pred_labels']
            )
            if label == class_id
        )

    # Without ground truth, recall is undefined: let the caller skip the class.
    total_gt = sum(len(boxes) for boxes in gt_per_image)
    if total_gt == 0:
        return None

    # Stable sort, so equally scored detections keep their collection order.
    detections.sort(key=lambda det: det[2], reverse=True)

    tp = np.zeros(len(detections))
    fp = np.zeros(len(detections))
    gt_matched = [[False] * len(boxes) for boxes in gt_per_image]

    for det_idx, (img_idx, box, _score) in enumerate(detections):
        best_iou = -1
        best_gt = -1

        # Only ground truths not yet claimed by a higher-scored detection
        # are candidates.
        for gt_idx, gt_box in enumerate(gt_per_image[img_idx]):
            if gt_matched[img_idx][gt_idx]:
                continue
            iou = calculate_iou(box, gt_box)
            if iou > best_iou:
                best_iou = iou
                best_gt = gt_idx

        if best_gt >= 0 and best_iou >= iou_threshold:
            tp[det_idx] = 1
            gt_matched[img_idx][best_gt] = True
        else:
            fp[det_idx] = 1

    cum_tp = np.cumsum(tp)
    cum_fp = np.cumsum(fp)
    recall = cum_tp / total_gt
    # The small epsilon keeps the division well-defined.
    precision = cum_tp / (cum_tp + cum_fp + 1e-10)

    # 11-point interpolation. The recall levels are built as i / 10 rather
    # than by repeatedly stepping 0.1: stepping produces values such as
    # 0.30000000000000004, which a recall of exactly 3/10 would fail to
    # reach, silently dropping that level from the average.
    ap = 0.0
    for step in range(11):
        reached = recall >= step / 10
        if reached.any():
            ap += np.max(precision[reached]) / 11

    return ap

# Visualize predictions
def visualize_predictions(image, gt_boxes, gt_labels, pred_boxes, pred_scores, pred_labels, output_path):
    """Save an image overlaid with ground-truth and predicted boxes.

    Ground-truth boxes are outlined in green with their character drawn
    just above the top-left corner; predicted boxes are outlined in red
    with "<char> (<score>)" drawn below the bottom-left corner. Characters
    are looked up in IDX_TO_CHAR, falling back to "unknown".

    Args:
        image: Image handed to ``ax.imshow``.
        gt_boxes: Ground-truth boxes, each unpacked as (x, y, x2, y2).
        gt_labels: One label per ground-truth box.
        pred_boxes: Predicted boxes, each unpacked as (x, y, x2, y2).
        pred_scores: One score per predicted box.
        pred_labels: One label per predicted box.
        output_path: Where the rendered figure is written.
    """
    text_box = dict(facecolor='white', alpha=0.7)

    fig, ax = plt.subplots(1, figsize=(15, 15))
    try:
        ax.imshow(image)

        # Draw ground truth boxes in green
        for box, label in zip(gt_boxes, gt_labels):
            x, y, x2, y2 = box
            ax.add_patch(Rectangle((x, y), x2 - x, y2 - y,
                                   linewidth=2, edgecolor='g', facecolor='none'))
            char = IDX_TO_CHAR.get(label, "unknown")
            ax.text(x, y - 5, char, color='green', fontsize=12, bbox=text_box)

        # Draw predicted boxes in red
        for box, score, label in zip(pred_boxes, pred_scores, pred_labels):
            x, y, x2, y2 = box
            ax.add_patch(Rectangle((x, y), x2 - x, y2 - y,
                                   linewidth=2, edgecolor='r', facecolor='none'))
            char = IDX_TO_CHAR.get(label, "unknown")
            ax.text(x, y2 + 15, f"{char} ({score:.2f})", color='red', fontsize=12,
                    bbox=text_box)

        ax.set_title("Ground Truth (Green) vs Predictions (Red)")
        ax.axis('off')

        fig.savefig(output_path, bbox_inches='tight')
    finally:
        # Close this specific figure even when drawing or saving fails, so
        # figures do not pile up when the function is called in a loop.
        plt.close(fig)

    print(f"Visualization saved to {output_path}")

# Visualize results for a dataset
def visualize_results(results, output_dir):
    """Render a prediction overlay for every entry in ``results``.

    For each result the image named by result['filename'] is read from
    IMAGE_DIR, converted from BGR to RGB, and passed to
    visualize_predictions together with the result's boxes, labels and
    scores. The overlay is written to ``output_dir`` as
    "pred_<filename without extension>.png".

    Images that cannot be read are reported and skipped so one bad file
    does not abort the whole run.

    Args:
        results: Iterable of dicts with the keys 'filename', 'gt_boxes',
            'gt_labels', 'pred_boxes', 'pred_scores' and 'pred_labels'.
        output_dir: Directory for the overlays; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    for result in tqdm(results, desc="Visualizing"):
        filename = result['filename']

        # Load image
        img_path = os.path.join(IMAGE_DIR, filename)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning
            # None rather than raising; cvtColor would then fail with an
            # error that does not mention the path.
            print(f"Warning: could not read image {img_path}; skipping")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Visualize
        stem = os.path.splitext(filename)[0]
        output_path = os.path.join(output_dir, f"pred_{stem}.png")
        visualize_predictions(
            img,
            result['gt_boxes'],
            result['gt_labels'],
            result['pred_boxes'],
            result['pred_scores'],
            result['pred_labels'],
            output_path,
        )