In [5]:
import os
import json
import random
import shutil
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image

from uuid import uuid4
from tqdm import tqdm
from typing import Dict, List, Optional
from datetime import datetime

# Dataset builder

In [None]:
# Define paths
METADATA_SERENGETI_PATH = "/data/luiz/dataset/serengeti/SnapshotSerengeti_S1-11_v2.1.json"
DATA_SERENGETI_PATH = "/data/luiz/dataset/serengeti_images/"
DATA_SERENGETI_SSD_PATH = "/ssd/luiz/dataset/serengeti_images/"
SPECIES_CSV_SERENGETI_PATH = "/data/luiz/dataset/partitions/species-classifier/serengeti"
ANIMAL_CSV_SERENGETI_PATH = "/data/luiz/dataset/partitions/species-classifier/serengeti"

METADATA_SERENGETI_PATH = "C:\\Users\\fabio\\Documents\\Workspace\\my-repos\\WildMatch\\data\\serengeti\\metadata.json"
DATA_SERENGETI_PATH = "C:\\Users\\fabio\\Documents\\Workspace\\my-repos\\WildMatch\\"

In [25]:
def load_json(json_file_path):
    print(f"Loading JSON file from {json_file_path}")
    with open(json_file_path, 'r') as f:
        return json.load(f)

In [11]:
data = load_json(METADATA_SERENGETI_PATH)

Loading JSON file from C:\Users\fabio\Documents\Workspace\my-repos\WildMatch\data\serengeti\metadata.json


In [19]:
def merge_annotations_and_images(data, image_base_path):
    """
    Merge annotations with images data and filter out images that don't exist.
    
    Args:
        data: Dictionary containing 'annotations' and 'images' lists
        image_base_path: Base path where images are stored
    
    Returns:
        List of dictionaries with merged data (only for existing images)
    """
    # Create a mapping from image_id to image metadata
    images_dict = {img['id']: img for img in data['images']}
    
    # Create a mapping from image_id to annotation
    annotations_dict = {ann['image_id']: ann for ann in data['annotations']}
    
    merged_data = []
    missing_images = []
    
    print(f"Total images in metadata: {len(data['images'])}")
    print(f"Total annotations: {len(data['annotations'])}")
    print(f"\nChecking which images exist on disk...")
    
    for img in tqdm(data['images']):
        image_id = img['id']
        file_name = img['file_name']
        
        # Construct full image path
        image_path = os.path.join(image_base_path, file_name)
        
        # Check if image exists
        if os.path.exists(image_path):
            # Get corresponding annotation if it exists
            annotation = annotations_dict.get(image_id, None)
            
            # Merge image and annotation data
            merged_item = {
                **img,  # Include all image metadata
                'full_path': image_path,
            }
            
            # Add annotation data if available
            if annotation:
                merged_item.update({
                    'annotation_id': annotation['id'],
                    'category_id': annotation['category_id'],
                    'seq_id': annotation['seq_id'],
                    'season': annotation['season'],
                    'subject_id': annotation['subject_id'],
                    'count': annotation['count'],
                    'standing': annotation['standing'],
                    'resting': annotation['resting'],
                    'moving': annotation['moving'],
                    'interacting': annotation['interacting'],
                    'young_present': annotation['young_present'],
                })
            else:
                merged_item['annotation_id'] = None
                merged_item['category_id'] = None
            
            merged_data.append(merged_item)
        else:
            missing_images.append({
                'image_id': image_id,
                'file_name': file_name,
                'expected_path': image_path
            })
    
    print(f"\n✓ Found {len(merged_data)} existing images")
    print(f"✗ Missing {len(missing_images)} images")
    
    return merged_data, missing_images

In [30]:
# Merge data and filter existing images
merged_data, missing_images = merge_annotations_and_images(data, DATA_SERENGETI_PATH)

# Convert to DataFrame for easier analysis
df = pd.DataFrame(merged_data)

Total images in metadata: 10
Total annotations: 10

Checking which images exist on disk...


100%|██████████| 10/10 [00:00<00:00, 13486.51it/s]


✓ Found 10 existing images
✗ Missing 0 images





In [29]:
df

Unnamed: 0,id,file_name,frame_num,seq_id,width,height,corrupt,location,seq_num_frames,datetime,...,annotation_id,category_id,season,subject_id,count,standing,resting,moving,interacting,young_present
0,S1/B04/B04_R1/S1_B04_R1_PICT0001,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#1,2048,1536,False,B04,1,2010-07-18 16:26:14,...,10e6bdcc-91f8-11e9-83b3-000d3a198845,1,S1,ASG0002kjh,1.0,0.62,0.06,0.0,0.5,0.0
1,S1/B04/B04_R1/S1_B04_R1_PICT0002,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#2,2048,1536,False,B04,1,2010-07-18 16:26:30,...,10e6bdcd-91f8-11e9-ab27-000d3a198845,1,S1,ASG0002kji,2.0,0.1,0.62,0.0,0.33,0.0
2,S1/B04/B04_R1/S1_B04_R1_PICT0003,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#3,2048,1536,False,B04,1,2010-07-20 06:14:06,...,10e6bdce-91f8-11e9-aa12-000d3a198845,0,S1,ASG0002kjj,,,,,,
3,S1/B04/B04_R1/S1_B04_R1_PICT0004,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#4,2048,1536,False,B04,1,2010-07-22 08:56:06,...,10e6bdcf-91f8-11e9-88e6-000d3a198845,0,S1,ASG0002kjk,,,,,,
4,S1/B04/B04_R1/S1_B04_R1_PICT0005,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#5,2048,1536,False,B04,1,2010-07-24 01:16:28,...,10e6bdd0-91f8-11e9-ad77-000d3a198845,0,S1,ASG0002kjl,,,,,,
5,S1/B04/B04_R1/S1_B04_R1_PICT0006,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#6,2048,1536,False,B04,1,2010-07-24 08:20:10,...,10e6bdd1-91f8-11e9-804b-000d3a198845,2,S1,ASG0002kjm,1.0,0.0,0.0,1.0,0.0,0.0
6,S1/B04/B04_R1/S1_B04_R1_PICT0007,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#7,2048,1536,False,B04,1,2010-07-24 10:14:32,...,10e6bdd2-91f8-11e9-90c3-000d3a198845,3,S1,ASG0002kjn,1.0,0.83,0.0,0.17,0.0,0.0
7,S1/B04/B04_R1/S1_B04_R1_PICT0008,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#8,2048,1536,False,B04,1,2010-07-30 03:59:42,...,10e6bdd3-91f8-11e9-b232-000d3a198845,4,S1,ASG0002kjo,1.0,0.81,0.0,0.19,0.0,0.0
8,S1/B04/B04_R1/S1_B04_R1_PICT0009,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#9,2048,1536,False,B04,1,2010-07-30 05:20:22,...,10e6bdd4-91f8-11e9-b80d-000d3a198845,5,S1,ASG0002kjp,1.0,0.94,0.0,0.06,0.0,0.0
9,S1/B04/B04_R1/S1_B04_R1_PICT0010,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#10,2048,1536,False,B04,1,2010-07-30 05:24:50,...,10e6bdd5-91f8-11e9-98c5-000d3a198845,0,S1,ASG0002kjq,,,,,,


# Data Requirements for WildMatch Implementation

Based on the paper (Section 2-11), we need to add:

## 1. Species Names
Map `category_id` to human-readable species names from `data["categories"]`

## 2. Taxonomic Hierarchy
For hierarchical prediction (Section 8):
- Class → Order → Family → Genus → Species

## 3. Knowledge Base (KB)
For each species, create Wikipedia-based visual descriptions:
- Extract Wikipedia articles
- Generate "Visually Relevant Summaries" using GPT-4
- Format: `SPECIES: VISUAL_DESCRIPTION`

## 4. Image Captions (during inference)
Generated by fine-tuned LLaVA-7B:
- Multiple captions per image (N=5 for self-consistency)
- Post-processing: color filtering + knowledge augmentation

## 5. Sequence Grouping
Group images by `seq_id` for sequence-level predictions (Section 11)

In [32]:
# Step 1: Add species names to the dataframe
# Create a mapping from category_id to species name
category_map = {cat['id']: cat['name'] for cat in data['categories']}

# Add species_name column
df['species_name'] = df['category_id'].map(category_map)

print("Species distribution:")
print(df['species_name'].value_counts())
print(f"\nTotal unique species: {df['species_name'].nunique()}")
print(f"Samples with unknown category: {df['species_name'].isna().sum()}")

Species distribution:
species_name
empty            4
human            2
gazellegrants    1
reedbuck         1
dikdik           1
zebra            1
Name: count, dtype: int64

Total unique species: 6
Samples with unknown category: 0


In [35]:
df

Unnamed: 0,id,file_name,frame_num,seq_id,width,height,corrupt,location,seq_num_frames,datetime,...,category_id,season,subject_id,count,standing,resting,moving,interacting,young_present,species_name
0,S1/B04/B04_R1/S1_B04_R1_PICT0001,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#1,2048,1536,False,B04,1,2010-07-18 16:26:14,...,1,S1,ASG0002kjh,1.0,0.62,0.06,0.0,0.5,0.0,human
1,S1/B04/B04_R1/S1_B04_R1_PICT0002,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#2,2048,1536,False,B04,1,2010-07-18 16:26:30,...,1,S1,ASG0002kji,2.0,0.1,0.62,0.0,0.33,0.0,human
2,S1/B04/B04_R1/S1_B04_R1_PICT0003,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#3,2048,1536,False,B04,1,2010-07-20 06:14:06,...,0,S1,ASG0002kjj,,,,,,,empty
3,S1/B04/B04_R1/S1_B04_R1_PICT0004,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#4,2048,1536,False,B04,1,2010-07-22 08:56:06,...,0,S1,ASG0002kjk,,,,,,,empty
4,S1/B04/B04_R1/S1_B04_R1_PICT0005,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#5,2048,1536,False,B04,1,2010-07-24 01:16:28,...,0,S1,ASG0002kjl,,,,,,,empty
5,S1/B04/B04_R1/S1_B04_R1_PICT0006,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#6,2048,1536,False,B04,1,2010-07-24 08:20:10,...,2,S1,ASG0002kjm,1.0,0.0,0.0,1.0,0.0,0.0,gazellegrants
6,S1/B04/B04_R1/S1_B04_R1_PICT0007,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#7,2048,1536,False,B04,1,2010-07-24 10:14:32,...,3,S1,ASG0002kjn,1.0,0.83,0.0,0.17,0.0,0.0,reedbuck
7,S1/B04/B04_R1/S1_B04_R1_PICT0008,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#8,2048,1536,False,B04,1,2010-07-30 03:59:42,...,4,S1,ASG0002kjo,1.0,0.81,0.0,0.19,0.0,0.0,dikdik
8,S1/B04/B04_R1/S1_B04_R1_PICT0009,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#9,2048,1536,False,B04,1,2010-07-30 05:20:22,...,5,S1,ASG0002kjp,1.0,0.94,0.0,0.06,0.0,0.0,zebra
9,S1/B04/B04_R1/S1_B04_R1_PICT0010,data\serengeti\images\puma.jpeg,1,SER_S1#B04#1#10,2048,1536,False,B04,1,2010-07-30 05:24:50,...,0,S1,ASG0002kjq,,,,,,,empty


In [36]:
# Step 2: Analyze sequence-level data (needed for Section 11 of paper)
# Group by sequence to see how many frames per sequence
sequence_stats = df.groupby('seq_id').agg({
    'id': 'count',  # Number of frames
    'species_name': 'first',  # Species (should be same across sequence)
    'datetime': ['min', 'max']  # Time range
}).round(2)

sequence_stats.columns = ['num_frames', 'species', 'start_time', 'end_time']
sequence_stats = sequence_stats.reset_index()

print("Sequence-level statistics:")
print(f"Total sequences: {len(sequence_stats)}")
print(f"\nFrames per sequence distribution:")
print(sequence_stats['num_frames'].describe())
print(f"\nSequences with multiple frames: {(sequence_stats['num_frames'] > 1).sum()}")


Sequence-level statistics:
Total sequences: 10

Frames per sequence distribution:
count    10.0
mean      1.0
std       0.0
min       1.0
25%       1.0
50%       1.0
75%       1.0
max       1.0
Name: num_frames, dtype: float64

Sequences with multiple frames: 0


In [37]:
sequence_stats

Unnamed: 0,seq_id,num_frames,species,start_time,end_time
0,SER_S1#B04#1#1,1,human,2010-07-18 16:26:14,2010-07-18 16:26:14
1,SER_S1#B04#1#10,1,empty,2010-07-30 05:24:50,2010-07-30 05:24:50
2,SER_S1#B04#1#2,1,human,2010-07-18 16:26:30,2010-07-18 16:26:30
3,SER_S1#B04#1#3,1,empty,2010-07-20 06:14:06,2010-07-20 06:14:06
4,SER_S1#B04#1#4,1,empty,2010-07-22 08:56:06,2010-07-22 08:56:06
5,SER_S1#B04#1#5,1,empty,2010-07-24 01:16:28,2010-07-24 01:16:28
6,SER_S1#B04#1#6,1,gazellegrants,2010-07-24 08:20:10,2010-07-24 08:20:10
7,SER_S1#B04#1#7,1,reedbuck,2010-07-24 10:14:32,2010-07-24 10:14:32
8,SER_S1#B04#1#8,1,dikdik,2010-07-30 03:59:42,2010-07-30 03:59:42
9,SER_S1#B04#1#9,1,zebra,2010-07-30 05:20:22,2010-07-30 05:20:22


# Task 1: Build Knowledge Base from Wikipedia

We'll use Wikipedia API to fetch articles and GPT-4 to generate Visually Relevant Summaries (VRS)

In [43]:
import wikipediaapi
from openai import OpenAI
from dotenv import load_dotenv

# Load OpenAI API key
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Initialize Wikipedia API
wiki_wiki = wikipediaapi.Wikipedia(
    user_agent='WildMatch/1.0 (fabio@example.com)',
    language='en'
)

In [45]:
def fetch_wikipedia_article(species_name):
    """
    Fetch Wikipedia article for a species.
    Extract summary and relevant sections about appearance.
    
    Args:
        species_name: Name of the species (e.g., 'zebra', 'elephant')
    
    Returns:
        Dictionary with summary and appearance sections
    """
    # Try to get the page
    page = wiki_wiki.page(species_name)
    
    if not page.exists():
        # Try with capital first letter
        page = wiki_wiki.page(species_name.capitalize())
    
    if not page.exists():
        print(f"⚠ Wikipedia page not found for: {species_name}")
        return None
    
    # Get the summary
    summary = page.summary
    
    # Extract relevant sections
    appearance_text = []
    relevant_keywords = ['description', 'characteristics', 'appearance', 'anatomy', 'morphology', 'physical']
    
    def extract_sections(sections, level=0):
        """Recursively extract sections with relevant keywords"""
        for section in sections:
            # Check if section title contains relevant keywords
            if any(keyword in section.title.lower() for keyword in relevant_keywords):
                appearance_text.append(f"\n## {section.title}\n{section.text}")
            
            # Recursively check subsections
            if hasattr(section, 'sections') and section.sections:
                extract_sections(section.sections, level + 1)
    
    # Extract from all sections
    if hasattr(page, 'sections') and page.sections:
        extract_sections(page.sections)
    
    return {
        'species': species_name,
        'page_title': page.title,
        'summary': summary,
        'appearance_sections': '\n'.join(appearance_text) if appearance_text else summary,
        'url': page.fullurl
    }

# Test with one species
test_species = 'zebra'
wiki_data = fetch_wikipedia_article(test_species)
if wiki_data:
    print(f"✓ Found page: {wiki_data['page_title']}")
    print(f"Summary length: {len(wiki_data['summary'])} chars")
    print(f"Appearance sections length: {len(wiki_data['appearance_sections'])} chars")
    print(f"\nFirst 500 chars of appearance text:\n{wiki_data['appearance_sections'][:500]}...")

✓ Found page: Zebra
Summary length: 2116 chars
Appearance sections length: 1112 chars

First 500 chars of appearance text:

## Characteristics
As with all wild equines, zebras have barrel-chested bodies with tufted tails, elongated faces and long necks with long, erect manes. Their thin legs are each supported by a spade-shaped toe covered in a hard hoof. Their dentition is adapted for grazing; they have large incisors that clip grass blades and rough molars and premolars well suited for grinding. Males have spade-shaped canines, which can be used as weapons in fighting. The eyes of zebras are at the sides and far u...


In [50]:
def generate_visually_relevant_summary(wiki_text, species_name, model="gpt-4o-mini"):
    """
    Use GPT-4 to generate Visually Relevant Summary (VRS) from Wikipedia text.
    This follows the prompt from Appendix A of the WildMatch paper.
    
    Args:
        wiki_text: Wikipedia article text
        species_name: Name of the species
        model: OpenAI model to use
    
    Returns:
        Visually relevant description string
    """
    system_msg = (
        "You are an AI assistant specialized in biology and providing accurate and "
        "detailed descriptions of animal species."
    )
    
    user_msg = f"""You are given the description of an animal species. Provide a very detailed description of the appearance of the species and describe each body part of the animal in detail. Only include details that can be directly visible in a photograph of the animal. Only include information related to the appearance of the animal and nothing else. Make sure to only include information that is present in the species description and is certainly true for the given species. Do not include any information related to the sound or smell of the animal. Do not include any numerical information related to measurements in the text in units: m, cm, in, inches, ft, feet, km/h, kg, lb, lbs. Remove any special characters such as unicode tags from the text. Return the answer as a single paragraph.

Species description: {wiki_text}

Answer:"""
    
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            temperature=0.0,
        )
        
        vrs = completion.choices[0].message.content.strip()
        return vrs
    
    except Exception as e:
        print(f"❌ Error generating VRS for {species_name}: {e}")
        return None

# Test VRS generation with zebra
if wiki_data:
    print(f"Generating Visually Relevant Summary for {test_species}...")
    vrs = generate_visually_relevant_summary(
        wiki_data['appearance_sections'], 
        test_species
    )
    print(vrs)

Generating Visually Relevant Summary for zebra...
Zebras possess a distinctive appearance characterized by their barrel-chested bodies and elongated faces. Their long necks are adorned with long, erect manes that stand upright, contributing to their striking silhouette. The body is supported by thin legs, each ending in a spade-shaped toe that is encased in a hard hoof, providing stability and support. The tail is tufted, adding to their unique profile. Zebras have large, prominent eyes positioned on the sides of their heads, allowing for a wide field of vision, which is particularly advantageous when grazing in tall grass. Their moderately long, erect ears are highly mobile, enabling them to detect sounds from various directions. The front legs are notably longer than the back legs, a feature that distinguishes them from other equines. On their front legs, zebras exhibit chestnut callosities, which are absent in horses. The skull of a zebra is relatively small with a straight dorsal o

In [51]:
def build_knowledge_base(species_list, output_path="data/knowledge_base.json", skip_existing=True):
    """
    Build the complete knowledge base for all species.
    
    Args:
        species_list: List of species names
        output_path: Path to save the knowledge base
        skip_existing: If True, skip species that already exist in the KB
    
    Returns:
        Dictionary mapping species names to visual descriptions
    """
    # Load existing KB if it exists
    knowledge_base = {}
    if skip_existing and os.path.exists(output_path):
        print(f"Loading existing knowledge base from {output_path}")
        with open(output_path, 'r', encoding='utf-8') as f:
            knowledge_base = json.load(f)
        print(f"Found {len(knowledge_base)} existing species")
    
    # Filter out species that already exist
    species_to_process = [s for s in species_list if s not in knowledge_base or not knowledge_base.get(s)]
    
    print(f"\nProcessing {len(species_to_process)} species (out of {len(species_list)} total)")
    print("This may take a while due to API rate limits...\n")
    
    for i, species in enumerate(tqdm(species_to_process, desc="Building KB")):
        # Skip empty species
        if not species or species == 'empty':
            continue
        
        try:
            # Fetch Wikipedia article
            wiki_data = fetch_wikipedia_article(species)
            
            if not wiki_data:
                knowledge_base[species] = None
                continue
            
            # Generate VRS
            vrs = generate_visually_relevant_summary(
                wiki_data['appearance_sections'],
                species
            )
            
            # Store in KB
            knowledge_base[species] = {
                'description': vrs,
                'wikipedia_title': wiki_data['page_title'],
                'wikipedia_url': wiki_data['url'],
                'raw_summary': wiki_data['summary'][:500] + '...',  # Store truncated version
            }
            
            # Save periodically (every 5 species)
            if (i + 1) % 5 == 0:
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(knowledge_base, f, indent=2, ensure_ascii=False)
        
        except Exception as e:
            print(f"\n❌ Error processing {species}: {e}")
            knowledge_base[species] = None
    
    # Final save
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(knowledge_base, f, indent=2, ensure_ascii=False)
    
    print(f"\n✓ Knowledge base saved to {output_path}")
    print(f"✓ Successfully processed {sum(1 for v in knowledge_base.values() if v)} species")
    print(f"✗ Failed to process {sum(1 for v in knowledge_base.values() if not v)} species")
    
    return knowledge_base

# Get unique species from your dataset
unique_species = df['species_name'].dropna().unique().tolist()
print(f"Found {len(unique_species)} unique species in dataset:")
print(unique_species, "..." if len(unique_species) > 10 else "")

Found 6 unique species in dataset:
['human', 'empty', 'gazellegrants', 'reedbuck', 'dikdik', 'zebra'] 


In [52]:
# Build the knowledge base for all species
# Uncomment the line below to start building the KB
# kb = build_knowledge_base(unique_species, output_path="../data/knowledge_base.json")

# For now, let's just build for a few test species
test_species_list = ['zebra', 'elephant', 'giraffe', 'lion', 'leopard']
print(f"Building KB for {len(test_species_list)} test species...")
kb_test = build_knowledge_base(test_species_list, output_path="../data/knowledge_base_test.json")

Building KB for 5 test species...

Processing 5 species (out of 5 total)
This may take a while due to API rate limits...



Building KB: 100%|██████████| 5/5 [00:41<00:00,  8.36s/it]


✓ Knowledge base saved to ../data/knowledge_base_test.json
✓ Successfully processed 5 species
✗ Failed to process 0 species





# Task 2: Collect Taxonomic Hierarchy

We'll use an API to get taxonomic information (Class → Order → Family → Genus → Species)

In [67]:
taxonomies = load_json("../data/taxonomy.json")
taxonomies

Loading JSON file from ../data/taxonomy.json


[{'species_name': 'zebra',
  'scientific_name': 'Equus quagga',
  'kingdom': 'Animalia',
  'phylum': 'Chordata',
  'class': 'Mammalia',
  'order': 'Perissodactyla',
  'family': 'Equidae',
  'genus': 'Equus',
  'species': 'quagga',
  'rank': 'SPECIES'},
 {'species_name': 'lion',
  'scientific_name': 'Panthera leo',
  'kingdom': 'Animalia',
  'phylum': 'Chordata',
  'class': 'Mammalia',
  'order': 'Carnivora',
  'family': 'Felidae',
  'genus': 'Panthera',
  'species': 'leo',
  'rank': 'SPECIES'},
 {'species_name': 'panthera',
  'scientific_name': 'Panthera leo',
  'kingdom': 'Animalia',
  'phylum': 'Chordata',
  'class': 'Mammalia',
  'order': 'Carnivora',
  'family': 'Felidae',
  'genus': 'Panthera',
  'species': 'leo',
  'rank': 'SPECIES'}]

# Task 3: Visual Description Generation (VLM Captioning)

For this task, you have several options:
1. **Use OpenAI's GPT-4 Vision API** (easier, but costs per image)
2. **Fine-tune LLaVA-7B** (complex, requires GPU training as in the paper)
3. **Use other VLM APIs** (Google Gemini, Anthropic Claude with vision, etc.)

Let's implement option 1 (GPT-4 Vision) as a starting point:

In [None]:
import base64

def encode_image_to_base64(image_path):
    """Encode image to base64 for OpenAI API."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def generate_visual_description(image_path, model="gpt-4o-mini", num_samples=1):
    """
    Generate detailed visual description of an animal in an image.
    Uses prompts similar to those in the WildMatch paper (Section 5.2).
    
    Args:
        image_path: Path to the image
        model: OpenAI model to use (gpt-4o, gpt-4o-mini)
        num_samples: Number of caption samples to generate
    
    Returns:
        List of caption strings
    """
    # Instruction prompts from the paper (Section 5.2)
    instructions = [
        "Give a very detailed visual description of the animal in the photo.",
        "Describe in detail the visible body parts of the animal in the photo.",
        "What are the visual characteristics of the animal in the photo?",
        "Describe the appearance of the animal in the photo.",
        "What are the identifying characteristics of the animal visible in the photo?",
        "How would you describe the animal in the photo?",
        "What does the animal in the photo look like?"
    ]
    
    captions = []
    
    for i in range(num_samples):
        # Randomly select an instruction
        instruction = random.choice(instructions)
        
        try:
            # For gpt-4o models, we can use vision
            if 'gpt-4o' in model:
                # Encode image
                base64_image = encode_image_to_base64(image_path)
                
                response = client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": instruction},
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/jpeg;base64,{base64_image}"
                                    }
                                }
                            ]
                        }
                    ],
                    max_tokens=500,
                    temperature=0.7 if num_samples > 1 else 0.0,
                )
                
                caption = response.choices[0].message.content.strip()
                captions.append(caption)
            else:
                print(f"⚠ Model {model} doesn't support vision. Use gpt-4o or gpt-4o-mini")
                return None
                
        except Exception as e:
            print(f"❌ Error generating caption: {e}")
            captions.append(None)
    
    return captions

# Test with one image
if len(df) > 0:
    test_image_path = df.iloc[0]['full_path']
    print(f"Testing caption generation on: {test_image_path}")
    print(f"Ground truth species: {df.iloc[0]['species_name']}\n")
    
    # Generate 1 caption for testing
    # Uncomment to run (will use OpenAI API credits):
    # captions = generate_visual_description(test_image_path, model="gpt-4o-mini", num_samples=1)
    # if captions and captions[0]:
    #     print(f"Generated caption:\n{captions[0]}")

# Task 4: LLM Matching Pipeline

This implements the core WildMatch prediction pipeline (Section 7 of the paper)

In [None]:
def match_caption_to_species(caption, knowledge_base, model="gpt-4o-mini"):
    """
    Use LLM to match a visual description to the most likely species.
    Follows the prompt from Appendix G of the WildMatch paper.
    
    Args:
        caption: Visual description of the animal
        knowledge_base: Dict mapping species to descriptions
        model: OpenAI model to use
    
    Returns:
        Predicted species name
    """
    # Format knowledge base
    kb_text_lines = []
    species_names = []
    
    for species, data in knowledge_base.items():
        if data and data.get('description'):
            species_names.append(species)
            kb_text_lines.append(f"{species}: {data['description']}")
    
    kb_text = "\n\n".join(kb_text_lines)
    species_list_str = ", ".join(species_names)
    
    system_msg = (
        "You are an AI expert in biology specialized in animal species identification."
    )
    
    user_msg = f"""You are given a knowledge base of animal species and their visual appearance:

{kb_text}

Now you are given the following detailed description of an animal seen in a photograph:

\"\"\"{caption}\"\"\"

Task:
- Choose the single most likely species from this list: [{species_list_str}]
- Answer with exactly one of these names, nothing else.
"""
    
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            temperature=0.0,
        )
        
        answer = completion.choices[0].message.content.strip()
        
        # Post-process to ensure it's one of the species names
        if answer not in species_names:
            # Try fuzzy matching
            answer_lower = answer.lower()
            for sp in species_names:
                if sp.lower() in answer_lower or answer_lower in sp.lower():
                    return sp
            # Default to first species if no match
            print(f"⚠ Warning: LLM returned '{answer}' which is not in species list")
            return species_names[0] if species_names else None
        
        return answer
    
    except Exception as e:
        print(f"❌ Error in LLM matching: {e}")
        return None

print("✓ LLM matching function defined")

In [None]:
def wildmatch_predict(image_path, knowledge_base, n_captions=5, vlm_model="gpt-4o-mini", llm_model="gpt-4o-mini"):
    """
    Complete WildMatch prediction pipeline with self-consistency.
    Implements Section 7 of the paper.
    
    Args:
        image_path: Path to the image
        knowledge_base: Dict mapping species to descriptions
        n_captions: Number of captions to generate for self-consistency
        vlm_model: Model for visual description generation
        llm_model: Model for species matching
    
    Returns:
        Dict with prediction, confidence, and intermediate results
    """
    # 1. Generate N visual descriptions
    print(f"Generating {n_captions} visual descriptions...")
    captions = generate_visual_description(image_path, model=vlm_model, num_samples=n_captions)
    
    if not captions or all(c is None for c in captions):
        return {
            'prediction': None,
            'confidence': 0.0,
            'error': 'Failed to generate captions'
        }
    
    # 2. Match each caption to species
    print(f"Matching captions to species...")
    species_predictions = []
    for i, caption in enumerate(captions):
        if caption:
            pred = match_caption_to_species(caption, knowledge_base, model=llm_model)
            species_predictions.append(pred)
            print(f"  Caption {i+1}: {pred}")
    
    # 3. Majority vote (self-consistency)
    from collections import Counter
    vote_counts = Counter(species_predictions)
    best_species, best_count = vote_counts.most_common(1)[0]
    confidence = best_count / n_captions
    
    result = {
        'prediction': best_species,
        'confidence': confidence,
        'captions': captions,
        'species_votes': species_predictions,
        'vote_counts': dict(vote_counts),
    }
    
    print(f"\n✓ Prediction: {best_species} (confidence: {confidence:.2%})")
    print(f"  Vote distribution: {dict(vote_counts)}")
    
    return result

print("✓ WildMatch prediction pipeline defined")