In [1]:
import os
import json
import random
import shutil
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image

from sklearn.utils import resample
from uuid import uuid4
from tqdm import tqdm
from typing import Dict, List, Optional
from datetime import datetime

# 1. Dataset builder

In [2]:
# Define paths.
# Defaults are the original absolute paths; override them per machine with the
# SERENGETI_METADATA_PATH / SERENGETI_IMAGES_DIR environment variables instead of
# editing or commenting out hardcoded paths (e.g. on a local Windows checkout).
METADATA_SERENGETI_PATH = os.environ.get(
    "SERENGETI_METADATA_PATH",
    "/data/luiz/dataset/serengeti/SnapshotSerengeti_S1-11_v2.1.json",
)
# Root directory containing the image files.
DATA_SERENGETI_PATH = os.environ.get(
    "SERENGETI_IMAGES_DIR",
    "/data/luiz/dataset/serengeti_images/",
)

Utility functions

In [3]:
def load_json(json_file_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        json_file_path: Path to the JSON file.

    Returns:
        The deserialized JSON value (dict, list, ...).

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print(f"Loading JSON file from {json_file_path}")
    # JSON text is UTF-8 by spec; without an explicit encoding, open() uses the
    # platform locale (e.g. cp1252 on Windows) and can mis-decode non-ASCII names.
    with open(json_file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
    
def merge_annotations_and_images(data, image_base_path):
    """
    Merge annotations with image metadata, keeping only images that exist on disk.

    Args:
        data: Dictionary containing 'images' and 'annotations' lists. Each image
            needs 'id' and 'file_name'; each annotation needs 'id', 'image_id'
            and 'category_id' (the other annotation fields are optional).
        image_base_path: Base directory that each image's 'file_name' is relative to.

    Returns:
        A tuple ``(merged_data, missing_images)``:
          - merged_data: list of dicts, one per image file found on disk, holding
            all of the image's metadata plus 'full_path', 'annotation_id',
            'category_id' and (when annotated) 'seq_id', 'season', 'subject_id',
            'count', 'standing', 'resting', 'moving', 'interacting',
            'young_present'. Optional fields absent from the annotation are None.
            Unannotated images get annotation_id = category_id = None.
          - missing_images: list of dicts with 'image_id', 'file_name' and
            'expected_path' for every image that was not found.

    Note:
        If several annotations reference the same image_id, only the LAST one in
        ``data['annotations']`` is used; the number of collapsed annotations is printed.
    """
    # Optional per-annotation fields copied verbatim onto the merged record.
    optional_fields = (
        'seq_id', 'season', 'subject_id', 'count', 'standing',
        'resting', 'moving', 'interacting', 'young_present',
    )

    # Map image_id -> annotation. The metadata can hold more annotations than
    # images, so count how many get overwritten instead of dropping them silently.
    annotations_dict = {}
    n_collapsed = 0
    for ann in data['annotations']:
        if ann['image_id'] in annotations_dict:
            n_collapsed += 1
        annotations_dict[ann['image_id']] = ann

    merged_data = []
    missing_images = []

    print(f"Total images in metadata: {len(data['images'])}")
    print(f"Total annotations: {len(data['annotations'])}")
    if n_collapsed:
        print(f"⚠ {n_collapsed} annotations share an image_id with an earlier one; "
              f"only the last annotation per image is kept")
    print(f"\nChecking which images exist on disk...")

    for img in tqdm(data['images']):
        image_path = os.path.join(image_base_path, img['file_name'])

        # isfile (not exists): a directory at this path is not a usable image.
        if not os.path.isfile(image_path):
            missing_images.append({
                'image_id': img['id'],
                'file_name': img['file_name'],
                'expected_path': image_path
            })
            continue

        merged_item = {**img, 'full_path': image_path}

        annotation = annotations_dict.get(img['id'])
        if annotation is not None:
            merged_item['annotation_id'] = annotation['id']
            merged_item['category_id'] = annotation['category_id']
            for field in optional_fields:
                merged_item[field] = annotation.get(field)
        else:
            merged_item['annotation_id'] = None
            merged_item['category_id'] = None

        merged_data.append(merged_item)

    print(f"\n✓ Found {len(merged_data)} existing images")
    print(f"✗ Missing {len(missing_images)} images")

    return merged_data, missing_images

def balance_dataset(df, category_col, min_samples=100, random_state=123):
    """Downsample every class so that all classes end up with the same row count.

    The per-class target size is ``min(smallest class count, min_samples)``:
    despite its name, ``min_samples`` acts as an upper bound (cap) on the number
    of rows kept per class.

    Rows whose ``category_col`` is NaN/None are not counted by ``value_counts()``
    and are therefore not part of the result.

    Args:
        df: Input DataFrame.
        category_col: Name of the column holding the class label.
        min_samples: Maximum number of rows to keep per class (default 100).
        random_state: Seed for the per-class sampling (default 123) so the
            result is reproducible.

    Returns:
        A new DataFrame with ``n_classes * target_size`` rows and a fresh
        RangeIndex, with classes grouped in ``value_counts()`` order (most
        frequent first). If ``df`` has no labelled rows, an empty frame with
        the same columns is returned.
    """
    class_counts = df[category_col].value_counts()

    # Nothing to balance: pd.concat([]) would raise "No objects to concatenate".
    if class_counts.empty:
        return df.iloc[0:0].reset_index(drop=True)

    target_size = int(min(class_counts.min(), min_samples))

    # Sample each class without replacement using a fresh RandomState seeded with
    # `random_state`. NOTE(review): intended to draw the same rows as the former
    # sklearn.utils.resample(replace=False, random_state=...) call; re-verify if
    # bit-exact reproducibility with older exports matters.
    balanced_parts = [
        df[df[category_col] == class_value].sample(
            n=target_size, replace=False, random_state=random_state
        )
        for class_value in class_counts.index
    ]

    return pd.concat(balanced_parts).reset_index(drop=True)

def copy_images_to_directory(df, source_col, dest_dir):
    """Copy the files referenced in ``df[source_col]`` into ``dest_dir``.

    Each file is copied under its basename, and the returned frame's
    ``source_col`` is rewritten to the destination path.

    Args:
        df: DataFrame with one file path per row in ``source_col``.
        source_col: Name of the column holding source file paths.
        dest_dir: Destination directory; created if it does not exist.

    Returns:
        A copy of ``df`` whose ``source_col`` holds
        ``os.path.join(dest_dir, basename(src))`` for every row. The input frame
        is not modified. Rows whose source file does not exist still get the
        destination path, but nothing is copied for them (a count is printed).
        Files already at their destination (e.g. on a re-run) are left in place.

    Raises:
        ValueError: If two different source files share a basename, because
            their copies would overwrite each other in ``dest_dir``.
    """
    sources = df[source_col].tolist()
    destinations = [os.path.join(dest_dir, os.path.basename(src)) for src in sources]

    # Distinct sources mapping to one destination would silently overwrite each
    # other and leave several rows pointing at the same (wrong) image.
    origin_by_dest = {}
    for src, dst in zip(sources, destinations):
        previous = origin_by_dest.setdefault(dst, os.path.abspath(src))
        if previous != os.path.abspath(src):
            raise ValueError(
                f"Basename collision: {previous!r} and {src!r} would both be copied to {dst!r}"
            )

    os.makedirs(dest_dir, exist_ok=True)

    n_missing = 0
    for src, dst in tqdm(zip(sources, destinations), total=len(sources), desc="Copying images"):
        if not os.path.exists(src):
            n_missing += 1
            continue
        # Copying a file onto itself raises shutil.SameFileError; skip instead.
        if os.path.exists(dst) and os.path.samefile(src, dst):
            continue
        shutil.copy(src, dst)

    if n_missing:
        print(f"⚠ {n_missing} source images were not found and were not copied")

    result = df.copy()
    result[source_col] = destinations
    return result


In [4]:
# Category ids excluded before any species filtering.
# NOTE(review): the meaning of ids 0, 1 and 23 is not stated here — presumably
# non-animal or ambiguous labels; confirm against the metadata's 'categories'.
CATEGORIES_TO_REMOVE = [0, 1, 23]

# Species names kept in the final dataset (12 classes).
SPECIES_TO_INCLUDE = [
    "elephant", "ostrich", "zebra", "cheetah",
    "hippopotamus", "baboon", "buffalo", "giraffe",
    "warthog", "guineafowl", "hyenaspotted", "impala",
]

# Output location of the balanced dataset CSV (relative to the working directory).
CSV_PATH = '../data/serengeti/dataset.csv'

In [None]:
data = load_json(METADATA_SERENGETI_PATH)

# Merge annotations into image metadata, keeping only images present on disk
merged_data, missing_images = merge_annotations_and_images(data, DATA_SERENGETI_PATH)

# Convert to DataFrame for easier analysis
df_original = pd.DataFrame(merged_data)

# Map category_id -> species name
category_map = {cat['id']: cat['name'] for cat in data['categories']}

# Drop unwanted categories and attach species names. Adding the column with
# `.assign` on the filtered frame (instead of `df[...] = ...` on a boolean-mask
# slice) avoids pandas' SettingWithCopyWarning / chained-assignment ambiguity.
df_known = (
    df_original[~df_original['category_id'].isin(CATEGORIES_TO_REMOVE)]
    .assign(species_name=lambda frame: frame['category_id'].map(category_map))
)

print("Species distribution:")
print(df_known['species_name'].value_counts())
print(f"\nTotal unique species: {df_known['species_name'].nunique()}")
print(f"Samples with unknown category: {df_known['species_name'].isna().sum()}")

# Keep only the target species
df = df_known[df_known['species_name'].isin(SPECIES_TO_INCLUDE)]

# A typo in SPECIES_TO_INCLUDE (or a species with no images on disk) would
# otherwise silently drop that class from the dataset.
absent_species = sorted(set(SPECIES_TO_INCLUDE) - set(df['species_name']))
if absent_species:
    print(f"\n⚠ Requested species with no samples: {absent_species}")

print("\nSpecies kept:")
print(df['species_name'].value_counts())

# Balance the dataset
df_balanced = balance_dataset(df, 'category_id')

# Copy images to the dataset's image folder
IMAGES_OUTPUT_DIR = '../data/serengeti/images'
df_balanced = copy_images_to_directory(df_balanced, 'full_path', IMAGES_OUTPUT_DIR)

Loading JSON file from /data/luiz/dataset/serengeti/SnapshotSerengeti_S1-11_v2.1.json
Total images in metadata: 7178440
Total annotations: 7261545

Checking which images exist on disk...


100%|██████████| 7178440/7178440 [02:13<00:00, 53875.56it/s] 



✓ Found 3197506 existing images
✗ Missing 3980934 images


In [None]:
# Save the balanced dataset to CSV, creating the parent folder if needed so the
# cell does not fail on a fresh checkout.
csv_parent_dir = os.path.dirname(CSV_PATH)
if csv_parent_dir:
    os.makedirs(csv_parent_dir, exist_ok=True)
df_balanced.to_csv(CSV_PATH, index=False)

Copying images:   0%|          | 0/1200 [00:00<?, ?it/s]

Copying images: 100%|██████████| 1200/1200 [00:00<00:00, 3362.73it/s]
