In [None]:
import os
import requests
import zipfile
from tqdm import tqdm
import json

def download_file(url, destination):
    """Stream ``url`` to ``destination`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) address of the file to fetch.
        destination: Local path the payload is written to (opened in ``'wb'``
            mode, so an existing file is overwritten).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    block_size = 1024  # 1 Kibibyte

    # Using the response as a context manager releases the connection even
    # when writing to disk fails part-way through.
    with requests.get(url, stream=True) as response:
        # Fail fast instead of saving an error page under the target file name.
        response.raise_for_status()
        # Falls back to 0 when the server does not send Content-Length.
        total_size = int(response.headers.get('content-length', 0))

        with open(destination, 'wb') as file, tqdm(
                desc=os.path.basename(destination),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
            for data in response.iter_content(block_size):
                size = file.write(data)
                bar.update(size)

def download_coco_validation_set(root_dir='./coco_data'):
    """Download and unpack the COCO 2017 validation images and annotations.

    The download is skipped when the ``val2017`` directory and the
    ``instances_val2017.json`` annotation file are already present under
    ``root_dir``.

    Args:
        root_dir: Directory that receives ``val2017/`` and ``annotations/``.
            Created if missing.

    Returns:
        Tuple ``(images_dir, annotations_file)`` of paths below ``root_dir``.
    """
    os.makedirs(root_dir, exist_ok=True)

    images_dir = os.path.join(root_dir, 'val2017')
    annotations_file = os.path.join(root_dir, 'annotations/instances_val2017.json')

    # Check for the annotation *file* that is returned, not merely its parent
    # directory, so a half-populated annotations folder triggers a re-download.
    if os.path.exists(images_dir) and os.path.exists(annotations_file):
        print("COCO validation set already exists!")
        return images_dir, annotations_file

    val_images_url = "http://images.cocodataset.org/zips/val2017.zip"
    annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"

    print("Downloading COCO validation images (5000 images)...")
    val_zip_path = os.path.join(root_dir, "val2017.zip")
    download_file(val_images_url, val_zip_path)

    print("Downloading COCO annotations...")
    ann_zip_path = os.path.join(root_dir, "annotations_trainval2017.zip")
    download_file(annotations_url, ann_zip_path)

    # Both archives carry their own top-level folder, so extracting into
    # root_dir yields root_dir/val2017 and root_dir/annotations.
    print("extracting validation images...")
    with zipfile.ZipFile(val_zip_path, 'r') as zip_ref:
        zip_ref.extractall(root_dir)

    print("extracting annotations...")
    with zipfile.ZipFile(ann_zip_path, 'r') as zip_ref:
        zip_ref.extractall(root_dir)

    # Clean up zip files
    os.remove(val_zip_path)
    os.remove(ann_zip_path)

    print("COCO validation set downloaded successfully!")
    return images_dir, annotations_file

def verify_coco_dataset(images_dir, annotations_file):
    """Report whether the image folder and annotation file look usable.

    Prints the number of ``.jpg`` files found and the number of image and
    annotation records in the JSON file.

    Args:
        images_dir: Directory expected to contain the ``.jpg`` images.
        annotations_file: Path of the COCO-style JSON annotation file.

    Returns:
        ``False`` when the directory or the file is missing, or when the file
        cannot be read/parsed or lacks the ``images``/``annotations`` keys;
        ``True`` otherwise.
    """
    if not os.path.exists(images_dir):
        print(f"Error: Images directory {images_dir} does not exist")
        return False

    jpg_count = sum(1 for entry in os.listdir(images_dir) if entry.endswith('.jpg'))
    print(f"Found {jpg_count} images in {images_dir}")

    if not os.path.exists(annotations_file):
        print(f"Error: Annotations file {annotations_file} does not exist")
        return False

    try:
        with open(annotations_file, 'r') as handle:
            payload = json.load(handle)
        image_records = len(payload['images'])
        annotation_records = len(payload['annotations'])
        print(f"annotations file contains information for {image_records} images and {annotation_records} annotations")
    except Exception as e:
        # Any read, parse or schema problem counts as a failed verification.
        print(f"error reading annotations file: {e}")
        return False

    print("dataset verification completed successfully!")
    return True

# Fetch the dataset (returns right away when it is already on disk), then
# sanity-check what is there.
images_dir, annotations_file = download_coco_validation_set()
verify_coco_dataset(images_dir, annotations_file)

# Summary of where the data lives, emitted with a single print call.
print("\n".join([
    "\nDataset information:",
    f"- Images directory: {images_dir}",
    f"- Annotations file: {annotations_file}",
    "\nThe COCO validation set (5000 images) with annotations is ready for evaluation!",
]))

Downloading COCO validation images (5000 images)...


val2017.zip: 100%|██████████| 778M/778M [00:18<00:00, 44.0MiB/s]


Downloading COCO annotations...


annotations_trainval2017.zip: 100%|██████████| 241M/241M [00:05<00:00, 49.2MiB/s]


Extracting validation images...
Extracting annotations...
COCO validation set downloaded successfully!
Found 5000 images in ./coco_data/val2017
Annotations file contains information for 5000 images and 36781 annotations
Dataset verification completed successfully!

Dataset information:
- Images directory: ./coco_data/val2017
- Annotations file: ./coco_data/annotations/instances_val2017.json

The COCO validation set (5000 images) with annotations is ready for evaluation!


In [None]:
import torch
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from pycocotools.coco import COCO
import os
from PIL import Image

class COCOEvalDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs from a COCO annotation file.

    ``image`` is a float tensor produced by ``T.ToTensor()``. ``target`` is a
    dict with ``boxes`` (``[x1, y1, x2, y2]``), ``labels`` (COCO category ids),
    ``image_id``, ``area`` and ``iscrowd``.

    Attributes:
        ids: Sorted image ids found in the annotation file.
        categories: Category records sorted by id.
        category_ids / category_names: Parallel lists derived from ``categories``.
        category_id_to_idx: Maps a COCO category id to its position in
            ``category_ids``.
    """

    def __init__(self, root, annFile,transform=None):
        """
        Args:
            root: Directory that contains the image files.
            annFile: Path of the COCO JSON annotation file.
            transform: Accepted for interface compatibility but not applied;
                images always go through ``T.ToTensor()`` only.
                NOTE(review): confirm whether callers expect it to be applied.
        """
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))

        # Get categories
        self.categories = self.coco.loadCats(self.coco.getCatIds())
        self.categories.sort(key=lambda x: x['id'])
        self.category_ids = [category['id'] for category in self.categories]
        self.category_names = [category['name'] for category in self.categories]
        # Dict gives O(1) membership tests and the id -> contiguous index mapping.
        self.category_id_to_idx = {cat_id: idx for idx, cat_id in enumerate(self.category_ids)}

        # Built once instead of on every __getitem__ call.
        self._to_tensor = T.Compose([T.ToTensor()])

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and its annotations as ``(image, target)``."""
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])

        # convert() returns a fully loaded image, so the file handle can be
        # released as soon as the with-block ends.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        image = self._to_tensor(img)

        return image, self._build_target(img_id)

    def _build_target(self, img_id):
        """Collect the annotations of ``img_id`` into the target dict."""
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        boxes = []
        labels = []
        areas = []
        iscrowd = []

        for ann in anns:
            cat_id = ann['category_id']
            # Filter BEFORE appending anything so the four lists stay aligned
            # when an annotation with an unknown category is skipped.
            if cat_id not in self.category_id_to_idx:
                continue

            x, y, w, h = ann['bbox']
            # Convert to [x1, y1, x2, y2] format
            boxes.append([x, y, (x + w), (y + h)])
            labels.append(cat_id)
            areas.append(ann['area'])
            iscrowd.append(ann['iscrowd'])

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            areas = torch.as_tensor(areas, dtype=torch.float32)
            iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        else:
            # Images without usable annotations get empty, correctly shaped tensors.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            areas = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.int64)

        return {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([img_id]),
            'area': areas,
            'iscrowd': iscrowd
        }

def create_coco_loader(images_dir, annotations_file, batch_size=1, num_workers=4):
    """Build the COCO evaluation dataset and return it.

    Despite the name, no ``DataLoader`` is created: the function returns the
    ``COCOEvalDataset`` itself, and ``batch_size`` / ``num_workers`` are
    accepted but not used.

    Args:
        images_dir: Directory containing the images.
        annotations_file: Path of the COCO JSON annotation file.
        batch_size: Unused.
        num_workers: Unused.

    Returns:
        The constructed ``COCOEvalDataset``.
    """
    # ImageNet mean/std normalisation appended to the tensor conversion.
    # NOTE(review): whether this pipeline is actually applied depends on how
    # COCOEvalDataset treats its ``transform`` argument - confirm there.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    pipeline = T.Compose([T.ToTensor(), normalize])

    dataset = COCOEvalDataset(images_dir, annotations_file, transform=pipeline)

    print(f"Created data loader with {len(dataset)} images for evaluation")
    return dataset

images_dir = "./coco_data/val2017"
annotations_file = "./coco_data/annotations/instances_val2017.json"

# create_coco_loader returns a single object (the dataset, not a
# (loader, dataset) pair), so bind exactly one name here.
dataset = create_coco_loader(images_dir, annotations_file)

# Print some dataset statistics
print(f"Number of images: {len(dataset)}")
print(f"Number of categories: {len(dataset.categories)}")

# List the category names, numbered from 1 in id order
print("\nSome COCO categories:")
for i, name in enumerate(dataset.category_names):
    print(f"  {i+1}. {name}")

print("\nDataset is ready for evaluation!")

loading annotations into memory...
Done (t=0.73s)
creating index...
index created!
Created data loader with 5000 images for evaluation
Number of images: 5000
Number of categories: 80

Some COCO categories:
  1. person
  2. bicycle
  3. car
  4. motorcycle
  5. airplane
  6. bus
  7. train
  8. truck
  9. boat
  10. traffic light
  11. fire hydrant
  12. stop sign
  13. parking meter
  14. bench
  15. bird
  16. cat
  17. dog
  18. horse
  19. sheep
  20. cow
  21. elephant
  22. bear
  23. zebra
  24. giraffe
  25. backpack
  26. umbrella
  27. handbag
  28. tie
  29. suitcase
  30. frisbee
  31. skis
  32. snowboard
  33. sports ball
  34. kite
  35. baseball bat
  36. baseball glove
  37. skateboard
  38. surfboard
  39. tennis racket
  40. bottle
  41. wine glass
  42. cup
  43. fork
  44. knife
  45. spoon
  46. bowl
  47. banana
  48. apple
  49. sandwich
  50. orange
  51. broccoli
  52. carrot
  53. hot dog
  54. pizza
  55. donut
  56. cake
  57. chair
  58. couch
  59. potted 



In [None]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights,FasterRCNN_MobileNet_V3_Large_FPN_Weights,retinanet_resnet50_fpn_v2,RetinaNet_ResNet50_FPN_V2_Weights


import torch
from tqdm import tqdm
import numpy as np
import json
from collections import defaultdict
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import json
from PIL import Image
from torchvision.ops import nms

def calculate_iou(box1, box2):
    """
    Intersection-over-Union of two axis-aligned boxes.

    Both boxes use the corner format [x1, y1, x2, y2]. Returns 0.0 when the
    boxes do not overlap or when the union has no area.
    """
    # Corners of the overlap rectangle.
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Disjoint along either axis -> no overlap at all.
    if right < left or bottom < top:
        return 0.0

    overlap = (right - left) * (bottom - top)

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_first + area_second - overlap

    if union > 0:
        return overlap / union
    return 0.0

def convert_to_xywh(box):
    """Turn a corner box [x1, y1, x2, y2] into [x, y, width, height]."""
    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
    width = x2 - x1
    height = y2 - y1
    return [x1, y1, width, height]

def convert_to_xyxy(box):
    """Turn a box [x, y, width, height] into corner format [x1, y1, x2, y2]."""
    x, y, width, height = box[0], box[1], box[2], box[3]
    right = x + width
    bottom = y + height
    return [x, y, right, bottom]

def compute_ap(precision, recall):
    """Compute Average Precision using 11-point interpolation.

    For each recall threshold t in {0.0, 0.1, ..., 1.0} the best precision
    among the points with recall >= t is taken (0 when no point reaches t),
    and the eleven values are averaged.

    Args:
        precision: 1-D array-like of precision values.
        recall: 1-D array-like of recall values, same length as ``precision``.

    Returns:
        The interpolated average precision (0.0 for empty input).
    """
    precision = np.asarray(precision)
    recall = np.asarray(recall)

    ap = 0.0
    for step in range(11):
        # step / 10 gives the closest float to each threshold. Accumulating
        # multiples of 0.1 (as np.arange does) produces values such as
        # 0.30000000000000004, which a recall of exactly 0.3 fails to reach.
        t = step / 10.0
        reached = recall >= t
        if reached.any():
            ap += np.max(precision[reached]) / 11.0
    return ap

def evaluate_detections(predictions, ground_truth, iou_threshold=0.5, max_dets=100):
    """Compute per-category precision/recall curves, AP, best F1 and the mAP.

    Within each image and category, predictions are visited in descending
    score order (at most ``max_dets`` of them) and each one is greedily
    matched to the still-unmatched ground-truth box with the highest IoU. A
    match with IoU >= ``iou_threshold`` is a true positive, anything else a
    false positive. The ``iscrowd`` field of the ground truth is not consulted.

    Args:
        predictions: Iterable of dicts with ``image_id``, ``category_id``,
            ``bbox`` (``[x, y, w, h]``) and ``score``.
        ground_truth: Iterable of dicts with ``image_id``, ``category_id`` and
            ``bbox`` (``[x, y, w, h]``).
        iou_threshold: Minimum IoU for a prediction to count as a match.
        max_dets: Maximum predictions considered per image and category.

    Returns:
        Dict with per-category maps ``precision``, ``recall``, ``ap`` and
        ``f1_score`` plus the scalar ``mAP`` (0.0 when there is no category).
        NOTE(review): categories that have predictions but no ground truth
        get AP 0 and are included in the mean.
    """
    # image_id -> category_id -> records. Grouping by category up front avoids
    # rescanning every image's full prediction list once per category.
    pred_by_img = defaultdict(lambda: defaultdict(list))
    gt_by_img = defaultdict(lambda: defaultdict(list))
    category_ids = set()

    for pred in predictions:
        pred_by_img[pred['image_id']][pred['category_id']].append(pred)
        category_ids.add(pred['category_id'])

    for gt in ground_truth:
        gt_by_img[gt['image_id']][gt['category_id']].append(gt)
        category_ids.add(gt['category_id'])

    metrics = {
        'precision': {},
        'recall': {},
        'ap': {},
        'f1_score': {}
    }

    for cat_id in category_ids:
        true_positives = []
        scores = []

        # Total ground-truth boxes of this category over all images.
        num_gt = sum(len(by_cat.get(cat_id, ())) for by_cat in gt_by_img.values())

        for img_id, preds_by_cat in pred_by_img.items():
            img_preds = sorted(
                preds_by_cat.get(cat_id, ()),
                key=lambda x: x['score'],
                reverse=True,
            )[:max_dets]
            if not img_preds:
                continue

            # .get() keeps the lookup from inserting empty entries.
            img_gts = gt_by_img.get(img_id, {}).get(cat_id, [])
            # Convert each ground-truth box once rather than per prediction.
            gt_boxes = [convert_to_xyxy(gt['bbox']) for gt in img_gts]
            matched_gt = [False] * len(gt_boxes)

            for pred in img_preds:
                pred_bbox = convert_to_xyxy(pred['bbox'])

                best_iou = 0
                best_gt_idx = -1
                for gt_idx, gt_bbox in enumerate(gt_boxes):
                    if matched_gt[gt_idx]:
                        continue
                    iou = calculate_iou(pred_bbox, gt_bbox)
                    if iou > best_iou:
                        best_iou = iou
                        best_gt_idx = gt_idx

                is_match = best_gt_idx >= 0 and best_iou >= iou_threshold
                if is_match:
                    matched_gt[best_gt_idx] = True
                true_positives.append(1 if is_match else 0)
                scores.append(pred['score'])

        # Rank all detections of the category by descending score.
        order = np.argsort(scores)[::-1]
        tp = np.array(true_positives, dtype=np.int64)[order]

        tp_cumsum = np.cumsum(tp)
        fp_cumsum = np.cumsum(1 - tp)

        # At rank k exactly k detections have been seen (tp + fp == k), so the
        # precision denominator is never zero and needs no epsilon.
        precision = tp_cumsum / np.arange(1, tp_cumsum.size + 1)
        # Guard the zero-GT case explicitly. Adding eps to the denominator
        # would make a category with a single ground-truth box top out just
        # below recall 1.0 and lose the t = 1.0 interpolation point.
        if num_gt > 0:
            recall = tp_cumsum / num_gt
        else:
            recall = np.zeros_like(precision)

        ap = compute_ap(precision, recall)

        metrics['precision'][cat_id] = precision
        metrics['recall'][cat_id] = recall
        metrics['ap'][cat_id] = ap

        # Best F1 along the curve; eps only protects the 0 / 0 case.
        if len(precision) > 0 and len(recall) > 0:
            f1 = 2 * precision * recall / (precision + recall + np.finfo(float).eps)
            metrics['f1_score'][cat_id] = np.max(f1)
        else:
            metrics['f1_score'][cat_id] = 0.0

    # np.mean([]) would yield NaN plus a RuntimeWarning; report 0.0 instead.
    if metrics['ap']:
        metrics['mAP'] = np.mean([metrics['ap'][cat_id] for cat_id in metrics['ap']])
    else:
        metrics['mAP'] = 0.0

    return metrics

def main(max_images=5000, score_threshold=0.05):
    """Evaluate a COCO-pretrained RetinaNet on val2017 with the custom metric code.

    Steps: build the model, save a few example visualisations, run inference
    image by image, write the detections to ``eval_results.json``, compute
    AP/mAP at IoU 0.5 via ``evaluate_detections`` and save precision-recall
    curves of the first five categories to ``precision_recall_curves.png``.

    Args:
        max_images: Upper bound on the number of images processed
            (``None`` processes the whole dataset).
        score_threshold: Detections scoring below this value are discarded.
    """
    # Other torchvision detectors and their weights can be swapped in here;
    # they share the same ``weights.transforms()`` preprocessing interface.
    weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    model = retinanet_resnet50_fpn_v2(weights=weights)
    model.eval()

    # Follow the device the model actually lives on instead of assuming CUDA.
    device = next(model.parameters()).device

    # Preprocessing bundled with the pretrained weights
    preprocess = weights.transforms()

    # Load COCO dataset
    images_dir = "./coco_data/val2017"
    annotations_file = "./coco_data/annotations/instances_val2017.json"
    dataset = create_coco_loader(images_dir, annotations_file)

    results = []
    ground_truth = []
    visualize_detections(model, dataset)

    print("Running inference...")
    with torch.no_grad():
        for i, (img, target) in enumerate(tqdm(dataset)):
            if max_images is not None and i >= max_images:
                break

            image_id = target['image_id'].item()

            # Record the ground truth of this image in COCO [x, y, w, h] format.
            for obj_idx in range(len(target['boxes'])):
                ground_truth.append({
                    'image_id': image_id,
                    'category_id': target['labels'][obj_idx].item(),
                    'bbox': convert_to_xywh(target['boxes'][obj_idx].tolist()),
                    'iscrowd': int(target['iscrowd'][obj_idx].item())
                })

            # Run model on a single-image batch
            batch = [preprocess(img.to(device))]
            prediction = model(batch)[0]

            boxes = prediction['boxes'].cpu().numpy()
            scores = prediction['scores'].cpu().numpy()
            labels = prediction['labels'].cpu().numpy()

            # Convert predictions to COCO format
            for box, score, label in zip(boxes, scores, labels):
                if score >= score_threshold:
                    x1, y1, x2, y2 = box.tolist()
                    results.append({
                        'image_id': image_id,
                        'category_id': label.item(),
                        'bbox': [x1, y1, x2 - x1, y2 - y1],  # [x, y, width, height]
                        'score': float(score)
                    })

    # Save results to file
    with open('eval_results.json', 'w') as f:
        json.dump(results, f)

    # Run custom evaluation
    print("Running custom evaluation...")
    metrics = evaluate_detections(results, ground_truth)

    # Print results
    print("\nEvaluation Results:")
    print(f"mAP @ IoU={0.5}: {metrics['mAP']:.4f}")

    print("\nAP by category:")
    for cat_id in sorted(metrics['ap'].keys()):
        print(f"Category {cat_id}: {metrics['ap'][cat_id]:.4f}")

    # Precision-recall curves for the first five categories (by id).
    # matplotlib is imported at the top of the cell, so no ImportError guard
    # is needed here.
    plt.figure(figsize=(10, 8))

    for cat_id in sorted(metrics['precision'].keys())[:5]:
        if len(metrics['precision'][cat_id]) > 0 and len(metrics['recall'][cat_id]) > 0:
            plt.plot(
                metrics['recall'][cat_id],
                metrics['precision'][cat_id],
                label=f'Category {cat_id} (AP: {metrics["ap"][cat_id]:.4f})'
            )

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves')
    plt.legend()
    plt.grid()
    plt.savefig('precision_recall_curves.png')
    plt.close()
    print("Saved precision-recall curves to 'precision_recall_curves.png'")



def visualize_detections(model, dataset, n_examples=5, confidence_threshold=0.5):
    """Save a figure comparing ground truth (green) with predictions (red).

    Randomly picks ``n_examples`` images from ``dataset``, runs ``model`` on
    each and writes all panels to ``coco_detection_examples.png``.

    Args:
        model: Detection model returning dicts with ``boxes``, ``labels`` and
            ``scores`` per image.
        dataset: Dataset yielding ``(image, target)`` and exposing
            ``categories`` (list of dicts with ``id`` and ``name``). Images
            are expected as un-normalised CHW tensors in [0, 1], which is what
            ``COCOEvalDataset`` produces.
        n_examples: Number of images to draw (capped at the dataset size).
        confidence_threshold: Predictions scoring below this are not drawn.
    """
    model.eval()
    device = next(model.parameters()).device

    # np.random.choice(..., replace=False) raises when asked for more samples
    # than there are items.
    n_examples = min(n_examples, len(dataset))
    indices = np.random.choice(len(dataset), n_examples, replace=False)

    # Built once; unknown ids fall back to "unknown" instead of raising.
    id_to_name = {cat['id']: cat['name'] for cat in dataset.categories}

    fig = plt.figure(figsize=(15, max(n_examples, 1) * 5))

    for i, idx in enumerate(indices):
        img, target = dataset[int(idx)]

        # The image is already in [0, 1]; it only needs HWC layout to display.
        # No mean/std "denormalisation" is applied because none was applied
        # when the image was loaded.
        img_np = np.clip(img.permute(1, 2, 0).cpu().numpy(), 0, 1)

        with torch.no_grad():
            prediction = model([img.to(device)])[0]

        ax = fig.add_subplot(n_examples, 1, i + 1)
        ax.imshow(img_np)

        # Ground truth in green, label above the box.
        for (x1, y1, x2, y2), category_id in zip(target['boxes'].tolist(), target['labels'].tolist()):
            ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           linewidth=2, edgecolor='g', facecolor='none'))
            category_name = id_to_name.get(category_id, "unknown")
            ax.text(x1, y1 - 5, category_name, color='g', fontsize=10, backgroundcolor='w')

        # Confident predictions in red, label below the box.
        for (x1, y1, x2, y2), category_id, score in zip(prediction['boxes'].tolist(),
                                                         prediction['labels'].tolist(),
                                                         prediction['scores'].tolist()):
            if score >= confidence_threshold:
                ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                               linewidth=2, edgecolor='r', facecolor='none'))
                category_name = id_to_name.get(category_id, "unknown")
                ax.text(x1, y2 + 15, f"{category_name}: {score:.2f}",
                        color='r', fontsize=10, backgroundcolor='w')

        ax.axis('off')
        ax.set_title(f"Image {target['image_id'].item()}")

    fig.tight_layout()
    fig.savefig('coco_detection_examples.png')
    plt.close(fig)
    print(f"Saved detection visualizations to coco_detection_examples.png")

# Default device for newly created tensors and modules: CUDA when present,
# otherwise CPU (an unconditional 'cuda' fails on machines without a GPU).
torch.set_default_device('cuda' if torch.cuda.is_available() else 'cpu')
main()

Downloading: "https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth" to /root/.cache/torch/hub/checkpoints/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth
100%|██████████| 146M/146M [00:00<00:00, 170MB/s]


loading annotations into memory...
Done (t=0.75s)
creating index...
index created!
Created data loader with 5000 images for evaluation




Saved detection visualizations to coco_detection_examples.png
Running inference...


100%|██████████| 5000/5000 [10:37<00:00,  7.84it/s]


Running custom evaluation...

Evaluation Results:
mAP @ IoU=0.5: 0.5978

AP by category:
Category 1: 0.7571
Category 2: 0.5649
Category 3: 0.6503
Category 4: 0.6962
Category 5: 0.8572
Category 6: 0.8100
Category 7: 0.8342
Category 8: 0.5587
Category 9: 0.5150
Category 10: 0.5399
Category 11: 0.8495
Category 13: 0.7079
Category 14: 0.6213
Category 15: 0.3814
Category 16: 0.5174
Category 17: 0.8812
Category 18: 0.8078
Category 19: 0.7882
Category 20: 0.7620
Category 21: 0.7566
Category 22: 0.8267
Category 23: 0.8949
Category 24: 0.8615
Category 25: 0.8691
Category 27: 0.3226
Category 28: 0.6113
Category 31: 0.3025
Category 32: 0.5225
Category 33: 0.5877
Category 34: 0.8174
Category 35: 0.4750
Category 36: 0.5419
Category 37: 0.6090
Category 38: 0.5920
Category 39: 0.5602
Category 40: 0.6384
Category 41: 0.7810
Category 42: 0.5864
Category 43: 0.7733
Category 44: 0.5520
Category 46: 0.5700
Category 47: 0.5830
Category 48: 0.5067
Category 49: 0.3002
Category 50: 0.3033
Category 51: 0.5643


In [None]:
import os
import sys
import argparse
import time
import json
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from google.colab import drive
import cv2
import requests
import zipfile
import glob


IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    print("This script is designed to run in Google Colab")
    sys.exit(1)

print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=True)


!pip install -q pycocotools
!pip install -q tqdm


!git clone https://github.com/AlexeyAB/darknet


!git clone https://github.com/WongKinYiu/PyTorch_YOLOv4.git

!wget https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights -P /content/PyTorch_YOLOv4/weights/


DOWNLOAD_FULL_COCO = False  

if DOWNLOAD_FULL_COCO:

    !wget http://images.cocodataset.org/zips/val2017.zip -P /content/
    !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P /content/

    !unzip -q /content/val2017.zip -d /content/coco/
    !unzip -q /content/annotations_trainval2017.zip -d /content/coco/
else:
    !mkdir -p /content/coco/val2017
    !mkdir -p /content/coco/annotations
    
   
    for i in tqdm(range(1, 101)):
        # COCO val2017 images start from 000000000139.jpg
        img_id = 139 + i
        img_url = f"http://images.cocodataset.org/val2017/{img_id:012d}.jpg"
        img_path = f"/content/coco/val2017/{img_id:012d}.jpg"
        try:
            response = requests.get(img_url)
            if response.status_code == 200:
                with open(img_path, 'wb') as f:
                    f.write(response.content)
        except Exception as e:
            print(f"Failed to download image {img_id}: {e}")
    
    # Download annotations
    
    !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P /content/
    !unzip -q /content/annotations_trainval2017.zip -d /content/coco/

# Change to the PyTorch_YOLOv4 directory
os.chdir('/content/PyTorch_YOLOv4')

# Import YOLOv4 modules
sys.path.append('/content/PyTorch_YOLOv4')
from models.models import Darknet
from utils.datasets import LoadImagesAndLabels
from utils.utils import non_max_suppression, xywh2xyxy, box_iou, ap_per_class, compute_ap

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def coco80_to_coco91_class():
    """Return the 80 COCO category ids indexed by contiguous (YOLO) class index.

    Position ``i`` of the list holds the COCO category id (in the 1..90
    numbering) of class index ``i``.
    """
    # Ids in 1..90 that are skipped by the 80-class list.
    unused_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
    return [cat_id for cat_id in range(1, 91) if cat_id not in unused_ids]

class CocoDataset(torch.utils.data.Dataset):
    """COCO images stretched to a square input plus YOLO-style normalised targets.

    Only images that have at least one annotation AND whose file is present in
    ``img_dir`` are listed, so a partially downloaded image folder can be used.

    Each item is ``(image, targets, img_path, (height, width))`` where
    ``image`` is a float32 CHW tensor in [0, 1], ``targets`` has one row
    ``[class_index, x_center, y_center, w, h]`` per kept annotation (box
    values relative to the original image size) and ``(height, width)`` is the
    original image size.
    """

    def __init__(self, img_dir, ann_file, img_size=416, transform=None):
        """
        Args:
            img_dir: Directory that contains the image files.
            ann_file: Path of the COCO JSON annotation file.
            img_size: Side length of the square the images are resized to.
            transform: Stored on the instance but not applied by ``__getitem__``.
        """
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.transform = transform
        self.img_size = img_size

        # Category id -> contiguous class index, computed once instead of a
        # list search per annotation.
        cat_ids = self.coco.getCatIds()
        self.cat_id_to_cls = {cat_id: cls_id for cls_id, cat_id in enumerate(cat_ids)}
        self.class_names = [cat['name'] for cat in self.coco.loadCats(cat_ids)]

        # Keep annotated images whose file exists; without the existence check
        # __getitem__ would raise for every image that was not downloaded.
        self.ids = [
            img_id for img_id in sorted(self.coco.imgs.keys())
            if len(self.coco.getAnnIds(imgIds=img_id)) > 0
            and os.path.exists(self._image_path(img_id))
        ]

    def _image_path(self, img_id):
        """Return the on-disk path of the image with id ``img_id``."""
        return os.path.join(self.img_dir, self.coco.loadImgs(img_id)[0]['file_name'])

    def _load_targets(self, img_id, width, height):
        """Build the ``(n, 5)`` target array of ``img_id`` for an image of the given size."""
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        boxes = []
        labels = []
        for ann in anns:
            # Crowd regions and zero-area annotations are left out.
            if ann['area'] > 0 and ann['iscrowd'] == 0:
                x, y, w, h = ann['bbox']
                boxes.append([(x + w/2) / width, (y + h/2) / height, w / width, h / height])
                labels.append(self.cat_id_to_cls[ann['category_id']])

        if len(boxes) == 0:
            boxes = np.zeros((0, 4), dtype=np.float32)
            labels = np.zeros(0, dtype=np.int64)
        else:
            boxes = np.array(boxes, dtype=np.float32)
            labels = np.array(labels, dtype=np.int64)

        return np.hstack((labels.reshape(-1, 1), boxes))

    def __len__(self):
        """Return the number of usable images."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load, resize and scale one image and build its target array."""
        img_id = self.ids[index]
        img_path = self._image_path(img_id)

        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        height, width = img.shape[:2]

        # Plain stretch to a square: the aspect ratio is NOT preserved, so
        # boxes must be mapped back with independent x / y factors.
        img_resized = cv2.resize(img, (self.img_size, self.img_size))

        # HWC uint8 -> CHW float32 in [0, 1]
        img_resized = img_resized.transpose(2, 0, 1) / 255.0
        img_resized = np.ascontiguousarray(img_resized, dtype=np.float32)

        targets = self._load_targets(img_id, width, height)

        return torch.from_numpy(img_resized), torch.from_numpy(targets), img_path, (height, width)

def evaluate(model, dataloader, device, conf_thres=0.001, iou_thres=0.65):
    """Run the model over ``dataloader`` and score the detections with COCOeval.

    Detections are written as COCO-format JSON (to Google Drive when mounted,
    otherwise to ``/content/results.json``) and evaluated against
    ``/content/coco/annotations/instances_val2017.json``. When the module-level
    ``DOWNLOAD_FULL_COCO`` flag is false, evaluation is restricted to the
    image ids found in ``/content/coco/val2017``.

    Args:
        model: Detection model; called with a float image batch.
        dataloader: Yields ``(imgs, targets, paths, shapes)`` where ``imgs``
            is either a batch tensor or a sequence of equally sized CHW
            tensors and ``shapes[i]`` is the original ``(height, width)``.
        device: Device the images are moved to.
        conf_thres: Confidence threshold passed to ``non_max_suppression``.
        iou_thres: IoU threshold passed to ``non_max_suppression``.

    Returns:
        ``coco_eval.stats`` or, when no detection was produced, an array of
        twelve zeros.
    """
    model.eval()

    coco91class = coco80_to_coco91_class()

    # Detections in COCO result format
    jdict = []

    for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader)):
        # A transposing collate function hands over a tuple of per-image
        # tensors; stack it into one batch before moving it to the device.
        if not torch.is_tensor(imgs):
            imgs = torch.stack(list(imgs))
        imgs = imgs.to(device).float()  # already normalized 0-1
        in_h, in_w = imgs.shape[2:]

        with torch.no_grad():
            pred = model(imgs)
            # NOTE(review): presumably the model returns
            # (inference_output, training_output) in eval mode; only the first
            # element is meant for NMS - confirm against the Darknet class.
            if isinstance(pred, (tuple, list)):
                pred = pred[0]

            pred = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres)

        for i, det in enumerate(pred):  # per image
            if det is None or len(det) == 0:
                continue

            # The dataset stretches every image to the network input size
            # without padding, so boxes are mapped back with independent x and
            # y factors (un-letterboxing would shift and mis-scale them).
            orig_h, orig_w = shapes[i]
            boxes = det[:, :4].clone()
            boxes[:, [0, 2]] = (boxes[:, [0, 2]] * (orig_w / in_w)).clamp(0, orig_w)
            boxes[:, [1, 3]] = (boxes[:, [1, 3]] * (orig_h / in_h)).clamp(0, orig_h)
            boxes = boxes.round()

            # The image id is the numeric file name (e.g. 000000000139.jpg).
            image_id = int(os.path.basename(paths[i]).split('.')[0])

            for (x1, y1, x2, y2), conf, cls in zip(boxes.tolist(), det[:, 4].tolist(), det[:, 5].tolist()):
                jdict.append({
                    'image_id': image_id,
                    'category_id': coco91class[int(cls)],  # Convert to COCO category id
                    'bbox': [x1, y1, x2 - x1, y2 - y1],  # COCO format: [x, y, w, h]
                    'score': float(conf)
                })

    # Save JSON
    results_file = '/content/drive/MyDrive/results.json' if os.path.exists('/content/drive/MyDrive') else '/content/results.json'
    with open(results_file, 'w') as f:
        json.dump(jdict, f)

    print(f"Results saved to {results_file}")

    # loadRes cannot handle an empty result list; report all-zero stats instead.
    if not jdict:
        print("No detections were produced; skipping COCO evaluation.")
        return np.zeros(12)

    # Load COCO dataset
    coco_gt = COCO('/content/coco/annotations/instances_val2017.json')
    coco_dt = coco_gt.loadRes(results_file)

    # Run COCO evaluation
    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    if not DOWNLOAD_FULL_COCO:
        # If using subset, set the images to evaluate on
        img_ids = [int(os.path.basename(p).split('.')[0]) for p in glob.glob('/content/coco/val2017/*.jpg')]
        coco_eval.params.imgIds = img_ids

    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    return coco_eval.stats

def scale_coords(img_size, coords, img_shape):
    """Map xyxy boxes from a letterboxed network input back to the original image.

    Removes the centred padding implied by an aspect-preserving resize,
    divides by the resize gain and clips the boxes to the original image.
    ``coords`` is modified in place and also returned.

    Args:
        img_size: ``(height, width)`` of the network input.
        coords: Tensor whose first four columns are ``x1, y1, x2, y2``.
        img_shape: ``(height, width)`` of the original image.
    """
    net_h, net_w = img_size[0], img_size[1]
    orig_h, orig_w = img_shape[0], img_shape[1]

    # Resize factor of an aspect-preserving fit, and the padding it leaves.
    gain = min(net_h / orig_h, net_w / orig_w)
    pad_x = (net_w - orig_w * gain) / 2
    pad_y = (net_h - orig_h * gain) / 2

    coords[:, [0, 2]] -= pad_x
    coords[:, [1, 3]] -= pad_y
    coords[:, :4] /= gain

    # Clip x columns to the width and y columns to the height, in place.
    for column, upper in ((0, orig_w), (1, orig_h), (2, orig_w), (3, orig_h)):
        coords[:, column].clamp_(0, upper)

    return coords

def load_model(weights_path, cfg_path, img_size=416, device='cuda'):
    """Build a Darknet model from ``cfg_path`` and load weights into it.

    Args:
        weights_path: A Darknet ``.weights`` file, or a PyTorch checkpoint that
            is either a state dict or a dict holding one under ``'model'``.
        cfg_path: Path of the Darknet ``.cfg`` network description.
        img_size: Input size handed to the ``Darknet`` constructor.
        device: Device the model is moved to (also used as ``map_location``).

    Returns:
        The model on ``device``, switched to eval mode.
    """
    model = Darknet(cfg_path, img_size)

    if weights_path.endswith('.weights'):  # Darknet format
        # A .weights file is a raw binary blob, not a state dict, so it cannot
        # be passed to load_state_dict.
        # NOTE(review): relies on models.models exposing
        # load_darknet_weights(model, path) in the cloned PyTorch_YOLOv4 repo -
        # confirm against the checked-out revision.
        from models.models import load_darknet_weights
        load_darknet_weights(model, weights_path)
    else:  # PyTorch format
        checkpoint = torch.load(weights_path, map_location=device)
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    return model

def main():
    """Evaluate YOLOv4 on the COCO 2017 validation data and show sample detections.

    Loads the model, builds the dataset and dataloader, runs ``evaluate``,
    prints the first three entries of the returned stats, and finally calls
    ``visualize_examples`` on three images.

    Returns:
        The stats object returned by ``evaluate``.
    """
    # Prefer the GPU whenever one is visible to torch
    has_gpu = torch.cuda.is_available()
    device = torch.device('cuda') if has_gpu else torch.device('cpu')
    print(f"Using device: {device}")

    # Inference settings
    img_size = 416      # inference size (pixels)
    conf_thres = 0.001  # object confidence threshold
    iou_thres = 0.65    # IOU threshold for NMS
    batch_size = 16     # images per batch

    # Model files
    cfg_path = '/content/PyTorch_YOLOv4/cfg/yolov4.cfg'
    weights_path = '/content/PyTorch_YOLOv4/weights/yolov4.weights'

    print("Loading YOLOv4 model...")
    model = load_model(
        weights_path=weights_path,
        cfg_path=cfg_path,
        img_size=img_size,
        device=device,
    )

    print("Creating COCO dataset...")
    dataset_kwargs = dict(
        img_dir='/content/coco/val2017',
        ann_file='/content/coco/annotations/instances_val2017.json',
        img_size=img_size,
    )
    dataset = CocoDataset(**dataset_kwargs)

    # Turn a list of per-sample tuples into a list of per-field tuples
    transpose_batch = lambda samples: list(zip(*samples))
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=transpose_batch,
    )

    print(f"\nEvaluating YOLOv4 on {'subset of 100' if not DOWNLOAD_FULL_COCO else 'full'} COCO 2017 validation set...")
    print(f"Settings: Image size={img_size}, Confidence threshold={conf_thres}, IoU threshold={iou_thres}")

    # Time the whole evaluation pass
    started_at = time.time()
    stats = evaluate(model, dataloader, device, conf_thres, iou_thres)
    elapsed = time.time() - started_at

    print(f"\nEvaluation completed in {elapsed:.2f} seconds")
    print("COCO mAP metrics:")
    print(f"mAP@0.5:0.95 = {stats[0]:.5f}")
    print(f"mAP@0.5 = {stats[1]:.5f}")
    print(f"mAP@0.75 = {stats[2]:.5f}")

    print("\nVisualizing some example detections...")
    visualize_examples(model, dataset, device, conf_thres=0.5, n=3)

    return stats

def visualize_examples(model, dataset, device, conf_thres=0.5, iou_thres=0.45, n=3):
    """Plot ground-truth and predicted boxes for a few randomly chosen images.

    Ground-truth boxes are drawn in lime and predictions in red, each with a
    class label (predictions also show their confidence).

    Args:
        model: Detection model called on a single-image batch.
        dataset: Indexable dataset whose items unpack as
            ``(img_tensor, targets, img_path, (height, width))`` and which
            exposes ``class_names``. Each target row is read as
            ``(label, x, y, w, h)`` and treated as normalized center/size values.
        device: Device the image batch is moved to before inference.
        conf_thres: Confidence threshold passed to ``non_max_suppression``.
        iou_thres: IoU threshold passed to ``non_max_suppression``.
        n: Number of images to show; capped at ``len(dataset)``.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    import matplotlib.patches as patches

    # Sampling without replacement raises if more samples are requested than
    # exist, so never ask for more images than the dataset holds.
    n = min(n, len(dataset))
    if n <= 0:
        print("No images available to visualize.")
        return

    # Select random images
    indices = np.random.choice(len(dataset), n, replace=False)

    for idx in indices:
        img_tensor, targets, img_path, (height, width) = dataset[idx]

        # Get original image for display
        orig_img = cv2.imread(img_path)
        if orig_img is None:
            # cv2.imread returns None instead of raising when a file cannot be read
            print(f"Could not read image {img_path}, skipping.")
            continue
        orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)

        # Get predictions
        with torch.no_grad():
            img_batch = img_tensor.unsqueeze(0).to(device)
            pred = model(img_batch)
            pred = non_max_suppression(pred, conf_thres, iou_thres)[0]  # Predictions for this image

        # Map boxes from the network-input frame back to the original image
        if pred is not None and len(pred) > 0:
            pred[:, :4] = scale_coords(img_tensor.shape[1:], pred[:, :4], (height, width)).round()

        fig, ax = plt.subplots(1, figsize=(12, 9))
        ax.imshow(orig_img)

        # Draw ground-truth boxes
        for t in targets:
            label, x, y, w, h = t.tolist()

            # Convert normalized xywh to pixel xyxy
            x1 = (x - w/2) * width
            y1 = (y - h/2) * height
            x2 = (x + w/2) * width
            y2 = (y + h/2) * height

            # Create rectangle patch
            rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)

            # Add class label
            cls_name = dataset.class_names[int(label)]
            ax.text(x1, y1-5, f'{cls_name}', color='lime', fontsize=11, bbox=dict(facecolor='black', alpha=0.5))

        # Draw predicted boxes
        if pred is not None:
            for *xyxy, conf, cls in pred:
                x1, y1, x2, y2 = [x.item() for x in xyxy]

                # Create rectangle patch
                rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='red', facecolor='none')
                ax.add_patch(rect)

                # Add class label with confidence
                cls_name = dataset.class_names[int(cls)]
                ax.text(x1, y1-5, f'{cls_name} {conf:.2f}', color='red', fontsize=11, bbox=dict(facecolor='black', alpha=0.5))

        # Set title and display
        ax.set_title(f'Image: {os.path.basename(img_path)}')
        ax.axis('off')
        plt.tight_layout()
        plt.show()


# Run the pipeline: load the model, evaluate on COCO, and plot example detections.
main()