# Gemma 3N Plant Dataset Fine-tuning

<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


In [None]:
# GPU Selection - Set which GPU to use (run this before torch initializes CUDA)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Expose only physical GPU index 2 (0-based, i.e. the third GPU)
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use GPU index 0
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # Use GPU index 3
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"  # Use GPU indices 1 and 2

print(f"CUDA_VISIBLE_DEVICES set to: {os.environ.get('CUDA_VISIBLE_DEVICES', 'default')}")
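
`CUDA_VISIBLE_DEVICES` takes zero-based physical indices, and the process renumbers whatever it exposes starting from 0. That is why code that later moves tensors to `"cuda"` still lands on the selected GPU. A minimal sketch of that mapping follows; `visible_device_map` is a hypothetical helper written for illustration, not part of CUDA or PyTorch, and it assumes plain integer indices rather than GPU UUIDs.

```python
def visible_device_map(cuda_visible_devices: str) -> dict:
    """Map in-process CUDA device indices to physical GPU indices.

    Assumes the variable holds plain comma-separated integers (not GPU UUIDs).
    """
    physical = [int(part) for part in cuda_visible_devices.split(",") if part.strip()]
    # Inside the process, visible devices are renumbered from 0 in listed order.
    return {logical: phys for logical, phys in enumerate(physical)}

print(visible_device_map("2"))    # {0: 2}
print(visible_device_map("1,2"))  # {0: 1, 1: 2}
```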


### Plant Dataset Fine-tuning Configuration

**Key Features:**
- 🌿 **Plant Identification**: Fine-tune on plant dataset for species identification
- 🎯 **Configurable Dataset Size**: Set `MAX_IMAGES_PER_SPECIES` to control training data size
- 🖼️ **Vision Fine-tuning**: Enabled vision layers for multimodal learning
- 📊 **Balanced Species**: Automatically balances across different plant species
- 🔄 **Adaptive Training**: Training steps adjust automatically based on dataset size

**Configuration Parameters:**
- `DATASET_PATH`: Path to your plant dataset
- `MAX_IMAGES_PER_SPECIES`: Maximum number of images per species (the loader defaults to 200; this notebook sets 2000)
- `IDENTIFICATION_PROMPTS`: The list of prompts used for plant identification (one is picked at random per sample)
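
To give a feel for how the number of training steps can follow from the dataset size, here is a rough sketch of the arithmetic. The function name, batch size, and gradient accumulation values are assumptions chosen for illustration; the schedule actually used by the training cell may differ.

```python
import math

def estimate_training_steps(num_samples, per_device_batch_size=1,
                            gradient_accumulation_steps=4, num_epochs=1):
    """Optimizer steps needed to see every sample num_epochs times."""
    # One optimizer step consumes this many samples (assumed single GPU).
    effective_batch = per_device_batch_size * gradient_accumulation_steps
    return math.ceil(num_samples / effective_batch) * num_epochs

# 16 species capped at 200 images each gives 3200 samples
print(estimate_training_steps(16 * 200))  # 800
```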

### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !uv pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
!uv pip install --no-deps --upgrade timm # Only for Gemma 3N

In [None]:
!uv pip show transformers triton torch xformers timm unsloth unsloth-zoo

In [None]:
# Verify the plant dataset structure
import os
import glob
# dataset_base_path = "/content/drive/MyDrive/plants"
dataset_base_path = '../../data/plants/train/'

print("Checking plant dataset structure...")
if os.path.exists(dataset_base_path):
    print(f"✅ Found plant folder at: {dataset_base_path}")
    
    # Check for species folders
    species_folders = [d for d in os.listdir(dataset_base_path) if os.path.isdir(os.path.join(dataset_base_path, d))]
    
    if species_folders:
        print(f"✅ Found {len(species_folders)} species folders:")
        total_images = 0
        for species in sorted(species_folders):
            species_path = os.path.join(dataset_base_path, species)
            image_files = []
            for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
                image_files.extend(glob.glob(os.path.join(species_path, ext)))
            image_count = len(image_files)
            total_images += image_count
            print(f"   - {species}: {image_count} images")
        
        print(f"✅ Total images found: {total_images}")
    else:
        print("❌ No species folders found")

else:
    print(f"❌ Plant dataset not found at: {dataset_base_path}")
    print("Please ensure your plant dataset is available")
    print("Expected structure:")
    print("plant_data/")
    print("├── Dandelion/")
    print("│   ├── image1.jpg")
    print("│   └── image2.jpg")
    print("├── Chickweed/")
    print("│   ├── image1.jpg")
    print("│   └── image2.jpg")
    print("└── ...")


### Unsloth

`FastModel` supports loading nearly any model now! This includes Vision and Text models!

In [None]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer_inference = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E2B-it",
    dtype = None, # None for auto detection
    max_seq_length = 2048, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_8bit = True,
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

In [None]:
# Fix PyTorch Dynamo recompile limits for Unsloth + Gemma 3N
import torch._dynamo
torch._dynamo.config.cache_size_limit = 1000  # Raise the recompile cache limit well above PyTorch's default
torch._dynamo.config.suppress_errors = True   # Don't fail on compilation errors

# Set up environment for better PyTorch compilation
import os
os.environ['TORCH_LOGS'] = 'recompiles'  # Monitor recompilations
os.environ['TORCHDYNAMO_VERBOSE'] = '0'   # Reduce verbose output

# Gemma 3N can process Text, Vision and Audio!

Let's first experience how Gemma 3N can handle multimodal inputs. We use Gemma 3N's recommended settings of `temperature = 1.0, top_p = 0.95, top_k = 64`
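
To make these three settings concrete, here is a simplified, pure-Python sketch of how temperature, top-k, and top-p interact when picking the next token. It is an illustration only and not the `transformers` implementation; `sample_token` is a hypothetical name.

```python
import math
import random

def sample_token(logits, temperature=1.0, top_k=64, top_p=0.95, rng=random):
    """Pick one token index using temperature, then top-k, then top-p filtering."""
    scaled = [x / temperature for x in logits]
    # top-k: keep only the k highest-scoring token indices
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # softmax over the survivors (subtract the max for numerical stability)
    peak = scaled[ranked[0]]
    weights = [math.exp(scaled[i] - peak) for i in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]
    # top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for idx, p in zip(ranked, probs):
        kept.append((idx, p))
        cumulative += p
        if cumulative >= top_p:
            break
    indices, kept_probs = zip(*kept)
    # random.choices renormalizes the remaining weights for us
    return rng.choices(indices, weights=kept_probs, k=1)[0]

# One token dominates, so it alone covers more than 95% of the probability mass
print(sample_token([10.0, 0.0, 0.0, 0.0]))  # 0
```

With `temperature = 1.0` the logits are left unscaled, so only the two filters shape the distribution.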

In [None]:
from transformers import TextStreamer
# Helper function for inference
def do_gemma_3n_inference(messages, max_new_tokens = 128):
    _ = model.generate(
        **tokenizer_inference.apply_chat_template(
            messages,
            add_generation_prompt = True, # Must add for generation
            tokenize = True,
            return_dict = True,
            return_tensors = "pt",
        ).to("cuda"),
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer_inference, skip_prompt = True),
    )

# Gemma 3N can see images!

<img src="https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg" alt="Alt text" height="256">

In [None]:
from PIL import Image
from matplotlib import pyplot as plt

sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"
# Try to find a plant image for demonstration
import glob
import os

plant_image = None
if os.path.exists(dataset_base_path):
    for species_folder in os.listdir(dataset_base_path):
        species_path = os.path.join(dataset_base_path, species_folder)
        if os.path.isdir(species_path):
            for ext in ['*.jpg', '*.jpeg', '*.png']:
                images = glob.glob(os.path.join(species_path, ext))
                if images:
                    plant_image = images[0]
                    print(f"Found plant image: {plant_image}")
                    break
            if plant_image:
                break

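# Fall back to the sloth image URL if no local plant image was found
if plant_image:
    image_link = plant_image
    image = Image.open(image_link)
else:
    # Image.open cannot read a URL directly, so download the bytes first
    from io import BytesIO
    from urllib.request import urlopen
    image_link = sloth_link
    image = Image.open(BytesIO(urlopen(image_link).read()))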

print("Image link:", image_link)
plt.imshow(image)

messages = [{
    "role" : "user",
    "content": [
        { "type": "image", "image" : image },
        { "type": "text",  "text" : "What type of plant is this? Can you identify the species?" }
    ]
}]
# You might have to wait 1 minute for Unsloth's auto compiler
do_gemma_3n_inference(messages, max_new_tokens = 32)

# Let's finetune Gemma 3N!

For now you can choose whether to finetune the vision and text parts. The audio part can also be finetuned, and we're working to make it selectable as well!

We now add LoRA adapters so we only need to update a small number of parameters!
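
To see why LoRA keeps the trainable parameter count small, consider one weight matrix of shape `d_out x d_in`. LoRA freezes it and learns two low-rank factors of rank `r` instead. The sketch below works through the arithmetic; the layer size is a hypothetical value chosen only for illustration and is not taken from Gemma 3N.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
    # LoRA learns A (r x d_in) and B (d_out x r); the original weight stays frozen.
    return r * d_in + d_out * r

d_in = d_out = 2048  # hypothetical layer size, chosen only for illustration
full = d_in * d_out
lora = lora_param_count(d_in, d_out, r=8)
print(full, lora, f"{lora / full:.2%}")  # 4194304 32768 0.78%
```

The low-rank update is scaled by `lora_alpha / r`, which equals 1 when both are set to 8.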

In [None]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = True,  # Turn ON for vision tasks!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 8,           # Larger = higher accuracy, but might overfit
    lora_alpha = 8,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

<a name="Data"></a>
### Data Prep
We now use the `Gemma-3` format for conversation style finetunes with the plant dataset. We'll create vision-text pairs for plant identification. Each sample will contain an image and the corresponding question/answer about plant species identification.

Gemma-3 renders multi-turn conversations as shown below:
```
<bos><start_of_turn>user
<image>
What type of plant is this?<end_of_turn>
<start_of_turn>model  
This is an ...<end_of_turn>
```
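
The layout above can be reproduced with a few lines of string formatting. This is only a sketch to show the structure: `render_gemma3` is a hypothetical helper, the `<image>` placeholder is taken from the example above, and the real rendering is done by the tokenizer's chat template.

```python
def render_gemma3(messages, add_generation_prompt=False):
    """Render a list of chat messages in the Gemma-3 turn layout shown above."""
    parts = ["<bos>"]
    for message in messages:
        # Gemma-3 calls the assistant role "model"
        role = "model" if message["role"] == "assistant" else message["role"]
        pieces = []
        for item in message["content"]:
            if item["type"] == "image":
                pieces.append("<image>")
            elif item["type"] == "text":
                pieces.append(item["text"])
        parts.append(f"<start_of_turn>{role}\n" + "\n".join(pieces) + "<end_of_turn>\n")
    if add_generation_prompt:
        # Leave the model turn open so generation continues from here
        parts.append("<start_of_turn>model\n")
    return "".join(parts)

example = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "What type of plant is this?"},
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "This is an ..."}]},
]
print(render_gemma3(example))
```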

We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3` and more.

In [None]:
# Create a copy of the tokenizer for training (with chat template)
# but keep the original tokenizer_inference for multimodal inference
import copy

from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    copy.deepcopy(tokenizer_inference),  # Use a copy to preserve the original
    chat_template = "gemma-3",
)

# Keep the original tokenizer_inference for proper multimodal inference
print("Created training tokenizer with chat template")
print("Keeping original tokenizer_inference for multimodal inference")

### Plant Dataset Configuration

**Important**: Before running this notebook, first split your dataset using:
```bash
python split_plant_dataset.py --source data/plants --output data/plants
```

This creates `data/plants/train/` and `data/plants/test/` directories with a proper holdout test set.

Configure the number of images to use for training per species:

In [None]:
# Configuration for plant dataset
DATASET_PATH = "../../data/plants/train/"  # Path to the training split (after running split_plant_dataset.py)
MAX_IMAGES_PER_SPECIES = 2000  # Number of images to use per species
INCLUDE_DETAILED_RESPONSES = False  # Include detailed feature descriptions

# Train/Validation Split Configuration
VALIDATION_SPLIT = 0.0  # Fraction held out for validation (e.g. 0.2 = 20%); 0.0 disables the split
USE_VALIDATION = False   # Set to False to use all data for training

# Define the identification prompts
IDENTIFICATION_PROMPTS = [
    "What type of plant is this? Please respond concisely.",
    "Can you identify this plant species? Please respond concisely.", 
    "What species does this plant belong to? Please respond concisely.",
    "Please identify this plant. Please respond concisely.",
    "What kind of plant am I looking at? Please respond concisely.",
    "Help me identify this plant. Please respond concisely."
]

# Plant species descriptions (from prepare_plant_dataset.py)
PLANT_DESCRIPTIONS = {
    "Alfalfa": {
        "description": "A perennial legume cultivated globally as a primary forage crop. It is highly valued for its high protein content and ability to improve soil by fixing nitrogen.",
        "features": ["Compound leaves with three leaflets", "Clusters of small purple flowers", "Deep taproot system", "Grows up to 1 meter tall"],
        "habitat": "Cultivated fields, grasslands, and pastures.",
        "uses": "Forage for livestock (hay and silage); cover crop; soil improvement."
    },
    "Asparagus": {
        "description": "A popular perennial vegetable known for its tender, edible spears that emerge in the spring. Once established, plants can be productive for many years.",
        "features": ["Edible spears in spring", "Feathery, fern-like foliage", "Small bell-shaped, greenish-white flowers", "Red berries on female plants"],
        "habitat": "Cultivated in gardens and farms with well-drained soil.",
        "uses": "Culinary vegetable; ornamental plant."
    },
    "Broadleaf Plantain": {
        "description": "An extremely common and resilient perennial herb found in disturbed areas worldwide. It is well-known in traditional medicine as a soothing poultice.",
        "features": ["Rosette of broad, oval leaves", "Prominent parallel leaf veins", "Tall spike of inconspicuous green flowers", "Grows low to the ground"],
        "habitat": "Lawns, footpaths, roadsides, and compacted or disturbed soil.",
        "uses": "Traditional medicine (for stings, burns, and cuts); edible young leaves."
    },
    "Cattail": {
        "description": "A tall wetland plant easily identified by its unique brown, sausage-shaped flower spike. It is a vital resource for wildlife and has numerous survival uses.",
        "features": ["Distinctive brown cylindrical flower head", "Long, flat, blade-like leaves", "Grows in dense colonies in water", "Sturdy, tall stalk"],
        "habitat": "Marshes, ponds, ditches, and shallow freshwater edges.",
        "uses": "Edible shoots and rhizomes; weaving material; fire tinder; wildlife habitat."
    },
    "Chicory": {
        "description": "A hardy perennial with vibrant blue flowers, often seen along roadsides. Its root is famously roasted and used as a coffee substitute or additive.",
        "features": ["Bright, sky-blue daisy-like flowers", "Tough, grooved, and branching stem", "Toothed basal leaves (similar to a dandelion)", "Flowers often close in the afternoon"],
        "habitat": "Roadsides, pastures, and disturbed, sunny ground.",
        "uses": "Roasted root as a coffee substitute; edible leaves (radicchio is a variety); forage."
    },
    "Coneflower": {
        "description": "A popular North American prairie native (genus Echinacea) widely grown as an ornamental flower and as a major commercial herbal supplement.",
        "features": ["Drooping purple or pink petals", "Spiny, cone-shaped center", "Hairy stems and leaves", "Daisy-like appearance"],
        "habitat": "Native to prairies and open woodlands; widely cultivated in gardens.",
        "uses": "Popular herbal supplement (Echinacea); ornamental garden plant; vital nectar source for pollinators."
    },
    "Dandelion": {
        "description": "A ubiquitous perennial herb, often seen as a weed but also a nutritious food and a critical early-season food source for pollinators.",
        "features": ["Bright yellow composite flower", "Deeply toothed, basal leaves", "Hollow stem with milky sap", "Puffy white seed head"],
        "habitat": "Lawns, fields, roadsides, and disturbed ground worldwide.",
        "uses": "Edible leaves, flowers, and roots; traditional medicine; important for pollinators."
    },
    "Elderberry": {
        "description": "A deciduous shrub known for its large clusters of white flowers and dark purple berries. Both are widely used in culinary and medicinal preparations.",
        "features": ["Large, flat-topped clusters of creamy-white flowers", "Drooping clusters of small, dark purple-black berries", "Compound leaves with 5-9 leaflets", "Shrub or small tree form"],
        "habitat": "Woodlands, hedgerows, stream banks, and disturbed areas.",
        "uses": "Berries for syrups, jams, and wine; flowers for cordials; popular cold remedy; wildlife food."
    },
    "Japanese Knotweed": {
        "description": "A large, highly aggressive herbaceous perennial considered one of the world's most destructive invasive species. Its strong rhizomes can damage infrastructure.",
        "features": ["Hollow, bamboo-like stems with reddish speckles", "Large, spade-shaped leaves", "Plumes of small, creamy-white flowers", "Forms dense, impenetrable thickets"],
        "habitat": "Riverbanks, roadsides, gardens, and waste areas; thrives in disturbed soil.",
        "uses": "Classified as a harmful invasive pest; can damage foundations and pavement."
    },
    "Kudzu": {
        "description": "A notoriously invasive perennial vine from Asia that grows with extreme speed, blanketing trees, buildings, and landscapes in the southeastern United States.",
        "features": ["Extremely rapid vine growth", "Large compound leaves with three broad leaflets", "Purple, fragrant flowers in late summer", "Completely smothers other vegetation"],
        "habitat": "Forests, fields, and roadsides; climbs over any available structure.",
        "uses": "Considered a destructive invasive pest; sometimes used for erosion control or livestock fodder."
    },
    "Lambs Quarters": {
        "description": "A common annual weed that is also a highly nutritious wild edible, closely related to spinach and quinoa. New growth often has a powdery white coating.",
        "features": ["Diamond-shaped or triangular leaves", "White, mealy powder on new leaves and underside", "Inconspicuous green flower clusters", "Erect, branching stem"],
        "habitat": "Gardens, farms, and disturbed, nutrient-rich soil.",
        "uses": "Nutritious edible green (cooked); animal fodder."
    },
    "Mullein": {
        "description": "A distinctive biennial plant with large, fuzzy leaves that form a rosette in the first year and a tall flower spike in the second. It has a long history of medicinal use.",
        "features": ["Large, very soft, fuzzy silver-green leaves", "Tall, thick flower stalk (up to 2 meters)", "Dense spike of yellow, five-petaled flowers", "Forms a basal rosette in its first year"],
        "habitat": "Disturbed soil, pastures, roadsides, and sunny, open fields.",
        "uses": "Traditional medicine (especially for respiratory ailments); leaves as tinder."
    },
    "Red Clover": {
        "description": "A common legume grown agriculturally for forage and soil health. Its globe-shaped, pinkish-purple flower heads are a familiar sight in meadows and lawns.",
        "features": ["Three-leaflet leaves, often with a pale chevron mark", "Round, reddish-purple or pink flower head", "Low-growing, spreading habit", "Hairy stems and leaves"],
        "habitat": "Meadows, lawns, fields, and roadsides.",
        "uses": "Forage for livestock; nitrogen-fixing cover crop; traditional medicine; edible flowers."
    },
    "Sunflower": {
        "description": "A tall annual plant famous for its large flower head that tracks the sun. It is a major global agricultural crop for its seeds and oil.",
        "features": ["Very large, daisy-like flower head", "Bright yellow outer petals (ray florets)", "Tall, thick, hairy stem", "Large, rough, heart-shaped leaves"],
        "habitat": "Cultivated fields and gardens; native to prairies and dry areas.",
        "uses": "Production of sunflower oil and edible seeds; ornamental plant; bird feed."
    },
    "Tea Plant": {
        "description": "An evergreen shrub (Camellia sinensis) whose leaves and buds are harvested and processed to produce tea, one of the world's most consumed beverages.",
        "features": ["Glossy, dark green, serrated leaves", "White, fragrant flowers with yellow stamens", "Typically pruned to a waist-high shrub for cultivation", "Leathery leaf texture"],
        "habitat": "Cultivated in tropical and subtropical regions with high rainfall and acidic soil.",
        "uses": "Production of all types of tea (black, green, white, oolong); ornamental."
    },
    "Wild Grape Vine": {
        "description": "A climbing woody vine and the ancestor of most cultivated grapes. It uses tendrils to scale trees and produces small, tart fruit.",
        "features": ["Woody, climbing vine", "Grasping tendrils opposite the leaves", "Large, lobed, heart-shaped leaves", "Clusters of small, dark, tart berries"],
        "habitat": "Forests, riverbanks, fencerows, and woodland edges.",
        "uses": "Wildlife food source; edible fruit (for jams and jellies); leaves used in cooking."
    }
}

print(f"Configured to use {MAX_IMAGES_PER_SPECIES} images per species")
print(f"Include detailed responses: {INCLUDE_DETAILED_RESPONSES}")
print(f"Train/Validation split: {int((1-VALIDATION_SPLIT)*100)}% train, {int(VALIDATION_SPLIT*100)}% validation")
print(f"Use validation: {USE_VALIDATION}")
print(f"Available species: {list(PLANT_DESCRIPTIONS.keys())}")

Load the plant dataset with the custom dataset loader:

In [None]:
import os
import random
from pathlib import Path
from datasets import Dataset
from PIL import Image
import pandas as pd

class PlantDatasetLoader:
    """Custom plant dataset loader for species identification."""

    def __init__(self, dataset_path: str):
        self.dataset_path = Path(dataset_path)

    def load_dataset(self, max_per_species: int = 200, validation_split: float = 0.2, seed: int = 42, split: str = "train") -> list:
        """Load dataset with the specified number of images per species and a train/validation split."""
        random.seed(seed)
        
        all_sample_data = []
        
        # Process each species directory (sorted so the seeded shuffles are reproducible)
        for species_dir in sorted(self.dataset_path.iterdir()):
            if not species_dir.is_dir():
                continue
                
            species_name = species_dir.name
            print(f"Processing species: {species_name}")
            
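            # Get all image files. A set avoids duplicates where glob matching is
            # case-insensitive (e.g. Windows); sorting fixes the order before shuffling.
            image_files = set()
            for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
                image_files.update(species_dir.glob(ext))
            image_files = sorted(image_files)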
            
            print(f"  Found {len(image_files)} images")
            
            # Shuffle files for this species
            random.shuffle(image_files)
            
            # Limit samples per species if specified
            if max_per_species and len(image_files) > max_per_species:
                image_files = image_files[:max_per_species]
                print(f"  Limited to {max_per_species} samples")
            
            # Split this species's images into train/val
            if validation_split > 0:
                split_idx = int(len(image_files) * (1 - validation_split))
                train_files = image_files[:split_idx]
                val_files = image_files[split_idx:]
                
                # Choose which split to use
                selected_files = train_files if split == "train" else val_files
                print(f"  Using {len(selected_files)} images for {split} split")
            else:
                selected_files = image_files
            
            # Process selected images
            for image_file in selected_files:
                try:
                    image = Image.open(image_file).convert("RGB")
                    all_sample_data.append({
                        'image_path': str(image_file),
                        'image': image,
                        'species': species_name,
                        'image_name': image_file.stem
                    })
                except Exception as e:
                    print(f"  Error loading {image_file}: {e}")
                    continue
        
        return all_sample_data

# Load the plant dataset
if os.path.exists(DATASET_PATH):
    loader = PlantDatasetLoader(DATASET_PATH)
    
    # Load training data
    dataset_samples = loader.load_dataset(
        max_per_species=MAX_IMAGES_PER_SPECIES, 
        validation_split=VALIDATION_SPLIT if USE_VALIDATION else 0,
        split="train"
    )
    print(f"Loaded {len(dataset_samples)} training samples from plant dataset")

    # Show distribution
    from collections import Counter
    species_counts = Counter([sample['species'] for sample in dataset_samples])
    print(f"Training species distribution: {dict(species_counts)}")
    
    # Optionally load validation data for evaluation
    if USE_VALIDATION:
        val_dataset_samples = loader.load_dataset(
            max_per_species=MAX_IMAGES_PER_SPECIES,
            validation_split=VALIDATION_SPLIT,
            split="val"
        )
        val_species_counts = Counter([sample['species'] for sample in val_dataset_samples])
        print(f"Loaded {len(val_dataset_samples)} validation samples")
        print(f"Validation species distribution: {dict(val_species_counts)}")
    else:
        val_dataset_samples = []
        
else:
    print(f"Dataset path not found: {DATASET_PATH}")
    print("Please update DATASET_PATH to point to your plant dataset location")
    dataset_samples = []
    val_dataset_samples = []
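
The per-species split in `load_dataset` boils down to a seeded shuffle followed by a single cut. Because the training and validation calls use the same seed, they see the same shuffled order, so the two slices never overlap. A minimal standalone sketch of that idea, where `split_files` is a hypothetical helper and not part of the loader:

```python
import random

def split_files(files, validation_split=0.2, seed=42):
    """Shuffle deterministically, then cut into (train, val) lists."""
    files = sorted(files)               # fixed starting order
    random.Random(seed).shuffle(files)  # same seed gives the same order every call
    split_idx = int(len(files) * (1 - validation_split))
    return files[:split_idx], files[split_idx:]

names = [f"img_{i:03d}.jpg" for i in range(200)]
train, val = split_files(names)
print(len(train), len(val))        # 160 40
print(set(train).isdisjoint(val))  # True
```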

Let's see what a sample from our plant dataset looks like!

In [None]:
if dataset_samples:
    print("Sample from plant dataset:")
    sample = dataset_samples[0]
    print(f"Image path: {sample['image_path']}")
    print(f"Species: {sample['species']}")
    print(f"Image name: {sample['image_name']}")

    # Display the image
    from PIL import Image
    import matplotlib.pyplot as plt

    img = Image.open(sample['image_path'])
    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.title(f"Sample Image - Species: {sample['species']}")
    plt.axis('off')
    plt.show()
else:
    print("No dataset samples loaded. Please check the DATASET_PATH.")

## Format Dataset + Dataset Cache

In [None]:
import hashlib
import json
from pathlib import Path
from datasets import Dataset
import pandas as pd

class DatasetCache:
    """Cache system for processed datasets to avoid expensive HF Dataset operations."""
    
    def __init__(self, cache_dir: str = "./dataset_cache"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
    
    def _get_config_hash(self, config_dict: dict) -> str:
        """Generate hash from configuration to ensure cache validity."""
        # Sort keys to ensure consistent hashing
        config_str = json.dumps(config_dict, sort_keys=True)
        return hashlib.md5(config_str.encode()).hexdigest()
    
    def _get_cache_path(self, config_hash: str) -> Path:
        """Get cache directory path for given configuration."""
        return self.cache_dir / f"dataset_{config_hash}"
    
    def _get_metadata_path(self, config_hash: str) -> Path:
        """Get metadata file path for given configuration."""
        return self.cache_dir / f"dataset_{config_hash}_metadata.json"
    
    def get_cache_config(self, dataset_path, max_images_per_species, validation_split, 
                        include_detailed_responses, split="train"):
        """Create configuration dict for cache validation."""
        return {
            "dataset_path": str(dataset_path),
            "max_images_per_species": max_images_per_species,
            "validation_split": validation_split,
            "include_detailed_responses": include_detailed_responses,
            "split": split,
            "plant_descriptions": str(PLANT_DESCRIPTIONS),
            "identification_prompts": str(IDENTIFICATION_PROMPTS)
        }
    
    def load_dataset(self, config_dict: dict) -> tuple:
        """Load cached dataset if available and valid."""
        config_hash = self._get_config_hash(config_dict)
        cache_path = self._get_cache_path(config_hash)
        metadata_path = self._get_metadata_path(config_hash)
        
        if cache_path.exists() and metadata_path.exists():
            print(f"📦 Loading processed dataset from cache: {cache_path}")
            try:
                # Load metadata
                with open(metadata_path, 'r') as f:
                    metadata = json.load(f)
                
                # Load dataset using HF datasets
                dataset = Dataset.load_from_disk(str(cache_path))
                # save_dataset() does not store the raw samples, so this is normally an empty list
                dataset_samples = metadata.get("dataset_samples", [])
                
                print(f"✅ Loaded cached dataset with {len(dataset)} samples")
                return dataset, dataset_samples
            except Exception as e:
                print(f"⚠️  Cache load failed: {e}")
                return None, None
        
        return None, None
    
    def save_dataset(self, config_dict: dict, dataset, dataset_samples: list = None):
        """Save processed dataset to cache."""
        config_hash = self._get_config_hash(config_dict)
        cache_path = self._get_cache_path(config_hash)
        metadata_path = self._get_metadata_path(config_hash)
        
        print(f"💾 Saving processed dataset to cache: {cache_path}")
        try:
            # Save dataset using HF datasets
            dataset.save_to_disk(str(cache_path))
            
            # Save metadata (without dataset_samples which can be large)
            metadata = {
                "config": config_dict,
                "timestamp": str(pd.Timestamp.now()),
                "dataset_size": len(dataset)
            }
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            
            print(f"✅ Dataset cache saved successfully")
        except Exception as e:
            print(f"❌ Dataset cache save failed: {e}")
    
    def clear_cache(self):
        """Clear all cached datasets."""
        import shutil
        for cache_dir in self.cache_dir.glob("dataset_*"):
            if cache_dir.is_dir():
                shutil.rmtree(cache_dir)
        for metadata_file in self.cache_dir.glob("dataset_*_metadata.json"):
            metadata_file.unlink()
        print("🗑️  All dataset caches cleared")
    
    def list_caches(self):
        """List all available caches with their configurations."""
        metadata_files = list(self.cache_dir.glob("dataset_*_metadata.json"))
        if not metadata_files:
            print("No dataset caches found")
            return
        
        print(f"Found {len(metadata_files)} cached datasets:")
        for metadata_file in metadata_files:
            try:
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)
                config = metadata.get("config", {})
                timestamp = metadata.get("timestamp", "Unknown")
                dataset_size = metadata.get("dataset_size", "Unknown")
                
                print(f"  📁 {metadata_file.stem}")
                print(f"     Dataset size: {dataset_size} samples")
                print(f"     Max images per species: {config.get('max_images_per_species', 'N/A')}")
                print(f"     Validation split: {config.get('validation_split', 'N/A')}")
                print(f"     Include detailed: {config.get('include_detailed_responses', 'N/A')}")
                print(f"     Split: {config.get('split', 'N/A')}")
                print(f"     Timestamp: {timestamp}")
                print()
            except Exception as e:
                print(f"  ❌ Error reading {metadata_file.name}: {e}")

# Initialize dataset cache system
dataset_cache = DatasetCache()
print("🚀 Dataset cache system initialized (caches final processed datasets)")
print("💡 This will cache the expensive Dataset.from_dict() and dataset.map() operations!")
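
The cache key is the piece that keeps stale data from being reused: any change to the configuration produces a different hash and therefore a different cache directory. The sketch below isolates that idea from `_get_config_hash`; `config_hash` is a standalone stand-in written for illustration.

```python
import hashlib
import json

def config_hash(config: dict) -> str:
    """Hash a configuration dict the same way regardless of key order."""
    # sort_keys makes the JSON text independent of dict insertion order
    return hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()

a = config_hash({"max_images_per_species": 200, "split": "train"})
b = config_hash({"split": "train", "max_images_per_species": 200})
c = config_hash({"max_images_per_species": 2000, "split": "train"})
print(a == b)  # True
print(a == c)  # False
```

On a cache miss `load_dataset` returns `(None, None)`, so the caller can rebuild the dataset and pass it to `save_dataset` with the same configuration dict.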

Now we'll convert our plant dataset samples into the conversation format for `Gemma-3`. Each sample will be formatted as a multimodal conversation with an image and a question/answer pair about plant species identification. We remove the `<bos>` token using `removeprefix('<bos>')` since we're finetuning.
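
A short sketch of the `<bos>` removal step on a hand-written string. The usual reason for stripping it, stated here as an assumption about the downstream pipeline, is that the tokenizer or processor adds its own `<bos>` when the text is tokenized for training, so leaving it in would duplicate the token.

```python
rendered = "<bos><start_of_turn>user\nWhat type of plant is this?<end_of_turn>\n"

# str.removeprefix (Python 3.9+) strips the prefix only if it is present
training_text = rendered.removeprefix("<bos>")

print(training_text.startswith("<start_of_turn>"))           # True
print(training_text.removeprefix("<bos>") == training_text)  # True
```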

In [None]:
def format_plant_conversations(dataset_samples):
    """Convert plant samples to conversation format."""
    conversations = []

    for sample in dataset_samples:
        species = sample['species']
        image = sample['image']
        
        # Get species information
        species_info = PLANT_DESCRIPTIONS.get(species, {})
        description = species_info.get("description", f"A plant from the {species} species.")
        
        # Basic identification conversation
        prompt = random.choice(IDENTIFICATION_PROMPTS)
        basic_response = f"{species}."
        
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": image}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": basic_response}
                ]
            }
        ]
        conversations.append(conversation)
        
        # Add detailed feature conversation if enabled
        if INCLUDE_DETAILED_RESPONSES and species_info:
            features_text = ", ".join(species_info.get("features", []))
            habitat = species_info.get("habitat", "Various environments")
            uses = species_info.get("uses", "Various uses")
            
            detailed_response = f"This is {species}. Key identifying features include: {features_text}. Habitat: {habitat}. Uses: {uses}."
            
            detailed_conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Can you describe the key features of this plant?"},
                        {"type": "image", "image": image}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": detailed_response}
                    ]
                }
            ]
            conversations.append(detailed_conversation)
            
            # Usage conversation for edible/useful plants
            if "edible" in uses.lower() or "medicinal" in uses.lower():
                usage_conversation = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "What can this plant be used for?"},
                            {"type": "image", "image": image}
                        ]
                    },
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": f"This is {species}. {uses}. Always ensure proper identification before consuming wild plants."}
                        ]
                    }
                ]
                conversations.append(usage_conversation)

    return conversations

def formatting_prompts_func(examples):
    """Format conversations for training."""
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False).removeprefix('<bos>') for convo in convos]
    return {"text": texts}

In [None]:
# 🚀 CACHED DATASET CREATION - Use this instead of the expensive operations!
# This version caches the final processed dataset to avoid re-running Dataset.from_dict() and dataset.map()

if dataset_samples:
    # Create cache configuration
    cache_config = dataset_cache.get_cache_config(
        dataset_path=DATASET_PATH,
        max_images_per_species=MAX_IMAGES_PER_SPECIES,
        validation_split=VALIDATION_SPLIT if USE_VALIDATION else 0,
        include_detailed_responses=INCLUDE_DETAILED_RESPONSES,
        split="train"
    )
    
    # Try to load from cache first
    cached_dataset, cached_samples = dataset_cache.load_dataset(cache_config)
    
    if cached_dataset is not None:
        print("✅ Using cached dataset - skipping expensive Dataset.from_dict() and dataset.map()!")
        dataset = cached_dataset
        if cached_samples:
            dataset_samples = cached_samples
    else:
        print("⏳ Cache miss - creating dataset from scratch (this will take time)...")
        import time
        start_time = time.time()
        
        # Format conversations (this is relatively fast)
        print("📝 Formatting conversations...")
        conversations = format_plant_conversations(dataset_samples)
        
        # Create HF Dataset (this is slow!)
        print("📦 Creating HF Dataset from conversations...")
        dataset = Dataset.from_dict({"conversations": conversations})
        
        # Apply chat template (this is VERY slow!)
        print("🔄 Applying chat template to dataset...")
        dataset = dataset.map(formatting_prompts_func, batched=True)
        
        end_time = time.time()
        duration = end_time - start_time
        print(f"✅ Dataset processing completed in {duration:.2f} seconds")
        
        # Save to cache for next time
        dataset_cache.save_dataset(cache_config, dataset, dataset_samples)

    print(f"✅ Final dataset ready with {len(dataset)} samples")
    print(f"Dataset columns: {dataset.column_names}")
else:
    print("❌ No dataset samples to process. Please check the DATASET_PATH.")

In [None]:
# 🛠️ CACHE MANAGEMENT UTILITIES
# Use these functions to manage your dataset cache

def show_cache_status():
    """Show current cache status and available caches."""
    print("📊 Dataset Cache Management System Status")
    print("=" * 50)
    dataset_cache.list_caches()

def clear_all_caches():
    """Clear all cached datasets."""
    dataset_cache.clear_cache()
    print("✅ All dataset caches cleared!")

def get_cache_for_validation():
    """Get cached dataset for validation if available."""
    if not USE_VALIDATION:
        print("❌ Validation is disabled. Set USE_VALIDATION = True to use validation caching.")
        return None, None
    
    # Create cache configuration for validation
    val_cache_config = dataset_cache.get_cache_config(
        dataset_path=DATASET_PATH,
        max_images_per_species=MAX_IMAGES_PER_SPECIES,
        validation_split=VALIDATION_SPLIT,
        include_detailed_responses=INCLUDE_DETAILED_RESPONSES,
        split="val"  # This is for validation split
    )
    
    # Try to load validation dataset from cache
    val_dataset, val_samples = dataset_cache.load_dataset(val_cache_config)
    
    if val_dataset is not None:
        print("✅ Found cached validation dataset!")
        return val_dataset, val_samples
    else:
        print("⏳ No validation cache found. You'll need to create validation dataset.")
        return None, None

# Show current cache status
show_cache_status()

print("\n🔧 Available cache management functions:")
print("  • show_cache_status() - Show all available dataset caches")  
print("  • clear_all_caches() - Delete all cached datasets")
print("  • get_cache_for_validation() - Get cached validation dataset")
print("\n💡 Benefits of dataset caching:")
print("  • Avoids expensive Dataset.from_dict() operation")
print("  • Skips time-consuming dataset.map() chat template application")
print("  • Cache automatically validates based on your configuration")
print("  • If you change settings, a new cache will be created automatically")
print("\n⚡ Performance improvement: ~10-100x faster on subsequent runs!")

Let's see how the chat template formatted our first plant conversation! Notice there is no `<bos>` token, because the processor's tokenizer will add one.

In [None]:
if dataset_samples and len(dataset) > 0:
    print("Formatted conversation example:")
    print(dataset[0]["text"])
else:
    print("No formatted dataset available.")

<a name="Train"></a>
### Train the model
Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). To speed things up, we cap training with `max_steps`, which the cell below picks from the dataset size (20, 40, or 100 steps). For a full run, set `num_train_epochs = 1` and remove the `max_steps` argument.
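
To judge how much of the data a capped run actually covers, it can help to estimate the number of optimizer steps in one epoch. The helper below is a rough sketch: `optimizer_steps_per_epoch` is a hypothetical name, the 1000-sample dataset is invented, and the trainer's own rounding may differ slightly.

```python
import math

def optimizer_steps_per_epoch(num_samples, per_device_batch_size=1,
                              gradient_accumulation_steps=4, num_devices=1):
    """Approximate number of optimizer updates needed to see every sample once."""
    effective_batch = per_device_batch_size * gradient_accumulation_steps * num_devices
    return math.ceil(num_samples / effective_batch)

steps = optimizer_steps_per_epoch(1000)  # hypothetical dataset of 1000 conversations
fraction = 100 / steps                   # share of one epoch covered by max_steps = 100
print(steps, fraction)
```

With the batch size of 1 and gradient accumulation of 4 used in this notebook, 100 steps cover well under one epoch of a dataset that size.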

In [None]:
from trl import SFTTrainer, SFTConfig

# Prepare validation dataset if available
eval_dataset = None
if USE_VALIDATION and val_dataset_samples:
    # Convert validation samples to conversation format
    val_conversations = format_plant_conversations(val_dataset_samples)
    val_dataset = Dataset.from_dict({"conversations": val_conversations})
    val_dataset = val_dataset.map(formatting_prompts_func, batched=True)
    eval_dataset = val_dataset
    print(f"Prepared validation dataset with {len(eval_dataset)} samples")

# Configure training parameters based on dataset size
if dataset_samples:
    dataset_size = len(dataset)
    # Adjust training steps based on dataset size
    if dataset_size < 100:
        max_steps = 20
    elif dataset_size < 500:
        max_steps = 40
    else:
        max_steps = 100

    print(f"Training dataset size: {dataset_size}")
    if eval_dataset:
        print(f"Validation dataset size: {len(eval_dataset)}")
    print(f"Training steps: {max_steps}")

    trainer = SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = dataset,
        eval_dataset = eval_dataset, # Now includes validation data!
        args = SFTConfig(
            dataset_text_field = "text",
            per_device_train_batch_size = 1,
            gradient_accumulation_steps = 4, # Use GA to mimic batch size!
            warmup_steps = 5,
            # num_train_epochs = 1, # Set this for 1 full training run.
            max_steps = max_steps,
            learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
            logging_steps = 1,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "linear",
            seed = 3407,
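            report_to = "none", # Set to "wandb" etc. to enable experiment tracking

            eval_steps = 10, # Only takes effect when a step-based evaluation strategy is also configured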
        ),
    )
else:
    print("No dataset loaded. Cannot create trainer.")

We also use Unsloth's `train_on_responses_only` function to compute the loss only on the assistant outputs and ignore the user's inputs. This helps increase the accuracy of finetunes!
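
The idea behind response-only training can be sketched with plain lists of token ids. This is a simplified, single-turn illustration and not Unsloth's implementation: `mask_before_response` is a hypothetical helper, the token ids are invented, and `-100` is the label value that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def mask_before_response(input_ids, response_marker):
    """Copy input_ids into labels, ignoring everything up to and including the last marker."""
    labels = list(input_ids)
    m = len(response_marker)
    start = -1
    # Find the last position where the marker sub-sequence occurs.
    for i in range(len(input_ids) - m + 1):
        if input_ids[i:i + m] == response_marker:
            start = i
    # If the marker is missing, mask the whole sequence so it contributes no loss.
    cutoff = start + m if start >= 0 else len(input_ids)
    for i in range(cutoff):
        labels[i] = IGNORE_INDEX
    return labels

# Invented ids: [20, 21] stands in for the tokens of "<start_of_turn>model\n".
input_ids = [2, 10, 11, 12, 20, 21, 30, 31, 1]
labels = mask_before_response(input_ids, [20, 21])
print(labels)
```

Only the ids after the response marker keep their values, so the loss is computed on the answer tokens alone.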

In [None]:
if dataset_samples and 'trainer' in locals():
    from unsloth.chat_templates import train_on_responses_only
    trainer = train_on_responses_only(
        trainer,
        instruction_part = "<start_of_turn>user\n",
        response_part = "<start_of_turn>model\n",
    )
    print("Trainer configured to train only on model responses.")
else:
    print("Trainer not available for response masking configuration.")

Let's verify that the instruction part is masked! We print the first row of the training dataset. Notice how the sample has only a single `<bos>`, as expected!

In [None]:
if dataset_samples and 'trainer' in locals() and len(trainer.train_dataset) > 0:
    print("Training dataset sample (input_ids):")
    print(tokenizer.decode(trainer.train_dataset[0]["input_ids"]))
else:
    print("Training dataset not available for inspection.")

Now let's print the masked out example - you should see only the answer is present:

In [None]:
if dataset_samples and 'trainer' in locals() and len(trainer.train_dataset) > 0:
    print("Training dataset sample (labels - masked):")
    print(tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[0]["labels"]]).replace(tokenizer.pad_token, " "))
else:
    print("Training dataset not available for label inspection.")

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

# Let's train the model!

To resume a training run, set `trainer.train(resume_from_checkpoint = True)`

In [None]:
if dataset_samples and 'trainer' in locals():
    trainer_stats = trainer.train()
    print("Training completed successfully!")
else:
    print("Cannot start training: No dataset loaded or trainer not initialized.")

In [None]:
# @title Show final memory and time stats
if 'trainer_stats' in locals():
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
    print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
    print(
        f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
    )
    print(f"Peak reserved memory = {used_memory} GB.")
    print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
    print(f"Peak reserved memory % of max memory = {used_percentage} %.")
    print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
else:
    print("Training was not executed, so no stats to display.")

<a name="Inference"></a>
### Inference
Let's run the model via Unsloth native inference! According to the `Gemma-3` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`
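
For intuition about what these three settings do, here is a simplified pure-Python sketch of temperature scaling followed by top-k and top-p filtering. `sample_top_k_top_p` is a hypothetical helper and the logits are invented; library implementations differ in details such as tie handling and operate on tensors.

```python
import math
import random

def sample_top_k_top_p(logits, temperature=1.0, top_k=64, top_p=0.95, rng=random):
    """Sample one token index; also return the candidate indices that survived filtering."""
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # top-k: keep only the k most probable tokens.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    k_mass = sum(probs[i] for i in ranked)

    # top-p: keep the smallest prefix whose renormalized cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i] / k_mass
        if cumulative >= top_p:
            break

    # Sample among the survivors, weighted by their probabilities.
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0], kept

token, kept = sample_top_k_top_p([5.0, 4.0, 0.0, -1.0], top_k=3, top_p=0.9,
                                 rng=random.Random(0))
print(token, kept)
```

In this toy case the two most likely tokens already hold more than 90% of the probability mass, so only they remain as candidates.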

In [None]:
# Test the model on plant identification
if dataset_samples:
    # Use a sample from our dataset for testing
    test_sample = dataset_samples[0]
    test_image_path = test_sample['image_path']
    test_image = Image.open(test_image_path)

    print(f"Testing with image: {test_image_path}")

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What type of plant is this?"},
            {"type": "image", "image": test_image}
        ]
    }]

    # Important: Use the original tokenizer_inference for multimodal inference
    # The get_chat_template() function may break vision capabilities
    inputs = tokenizer_inference.apply_chat_template(
        messages,
        add_generation_prompt = True, # Must add for generation
        return_tensors = "pt",
        tokenize = True,
        return_dict = True,
    ).to("cuda")

    outputs = model.generate(
        **inputs,
        max_new_tokens = 64, # Longer response for plant descriptions
        # Recommended Gemma-3 settings!
        temperature = 1.0, top_p = 0.95, top_k = 64,
    )

    print(f"Testing on image: {test_sample['image_name']}")
    print(f"Ground truth species: {test_sample['species']}")
    print(f"Model response: {tokenizer_inference.batch_decode(outputs)[0]}")
else:
    print("No dataset samples available for testing.")

You can also use a `TextStreamer` for continuous inference, so you can watch the generation token by token instead of waiting for the full response!
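
A minimal sketch of streaming, assuming the `model`, tokenizer, and `inputs` objects from the inference cell above. `stream_reply` is a hypothetical helper name; `TextStreamer` and the `streamer` argument of `generate` come from the `transformers` library.

```python
def stream_reply(model, tokenizer, inputs, max_new_tokens=64):
    """Generate a reply while printing tokens to stdout as they are produced."""
    # Imported lazily so that defining this helper does not require transformers.
    from transformers import TextStreamer

    # skip_prompt=True hides the echoed input and prints only the new tokens.
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    return model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Recommended Gemma-3 settings, as in the cell above.
        temperature=1.0, top_p=0.95, top_k=64,
        streamer=streamer,
    )
```

A call such as `_ = stream_reply(model, tokenizer_inference, inputs)` would then print the answer incrementally.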

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
model.save_pretrained("gemma-3n-e2b-it-plant-8bit_lora")  # Local saving
tokenizer.save_pretrained("gemma-3n-e2b-it-plant-8bit_lora")
# model.push_to_hub("wrongryan/gemma-3n-it-plant-4bit", token = "XXX") # Online saving
# tokenizer.push_to_hub("wrongryan/gemma-3n-it-plant-4bit", token = "XXX") # Online saving

Now if you want to load the LoRA adapters we just saved for inference, uncomment the cell below, set `if False` to `if True`, and point `model_name` at the directory you saved the adapters to:

In [None]:
# if False:
#     from unsloth import FastModel
#     model, tokenizer = FastModel.from_pretrained(
#         model_name = "lora_model", # YOUR MODEL YOU USED FOR TRAINING
#         max_seq_length = 2048,
#         load_in_4bit = True,
#     )

# messages = [{
#     "role": "user",
#     "content": [{"type" : "text", "text" : "What is Gemma-3N?",}]
# }]
# inputs = tokenizer.apply_chat_template(
#     messages,
#     add_generation_prompt = True, # Must add for generation
#     return_tensors = "pt",
#     tokenize = True,
#     return_dict = True,
# ).to("cuda")

# from transformers import TextStreamer
# _ = model.generate(
#     **inputs,
#     max_new_tokens = 128, # Increase for longer outputs!
#     # Recommended Gemma-3 settings!
#     temperature = 1.0, top_p = 0.95, top_k = 64,
#     streamer = TextStreamer(tokenizer, skip_prompt = True),
# )

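### Saving to float16 for vLLM

We also support saving to `float16` directly for deployment! The cell below saves the merged model in the folder `gemma-3n-e2b-it-plant-8bit_lora`, which is the same directory used for the LoRA adapters above, so change the name if you want to keep both. The cell is enabled; set `if True` to `if False` to skip it.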

In [None]:
if True: # Set to False to skip saving the merged finetune
    model.save_pretrained_merged("gemma-3n-e2b-it-plant-8bit_lora", tokenizer)

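If you want to upload / push to your Hugging Face account, add your Hugging Face token and upload location to the cell below. The cell is enabled (`if True`), so replace the placeholder `"XXX"` token before running it, or set it to `if False` to skip the upload.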

In [None]:
if True: # Set to False to skip uploading the finetune
    model.push_to_hub_merged(
        "wrongryan/gemma-3n-e2b-it-plant-8bit-lora", tokenizer,
        token = "XXX"
    )

### Test Set Evaluation (Valid Species Only)

Let's evaluate the model's performance on a proper test set with comprehensive metrics and analysis, filtering to only include valid species from our training data.
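
The metrics cells further down score only predictions that map to a valid species, so responses parsed as `"Unknown"` or `"Error"` are left out of the accuracy figure. The sketch below, with a hypothetical helper `accuracy_report` and invented labels, shows how accuracy on the scored subset can differ from accuracy over every test sample.

```python
def accuracy_report(predictions, ground_truths, valid_species):
    """Compare accuracy on scorable predictions with accuracy over all samples."""
    valid = set(valid_species)
    # Keep only pairs where both the prediction and the truth are valid species.
    scored = [(p, t) for p, t in zip(predictions, ground_truths)
              if p in valid and t in valid]
    correct = sum(p == t for p, t in scored)
    total = len(predictions)
    return {
        "coverage": len(scored) / total if total else 0.0,
        "accuracy_on_scored": correct / len(scored) if scored else 0.0,
        "accuracy_on_all": correct / total if total else 0.0,
    }

# Invented example: two of four responses could not be mapped to a species.
report = accuracy_report(
    predictions=["Dandelion", "Unknown", "Clover", "Error"],
    ground_truths=["Dandelion", "Clover", "Dandelion", "Clover"],
    valid_species=["Dandelion", "Clover"],
)
print(report)
```

Reading the reported accuracy together with the counts of unknown and error samples gives a fuller picture of model quality.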


In [None]:
# Get valid species from training data
def get_valid_species_from_training():
    return list(PLANT_DESCRIPTIONS.keys())

# Load test set and filter for valid species only
TEST_DATASET_PATH = "../../data/plants/test/"  # Path to the holdout test set

print("Loading test set...")
if os.path.exists(TEST_DATASET_PATH):
    test_loader = PlantDatasetLoader(TEST_DATASET_PATH)
    all_test_samples = test_loader.load_dataset(
        max_per_species=None,  # Use all available test images
        validation_split=0,    # No split needed, use all as test
        split="train"         # Using "train" parameter but loading from test directory
    )
    print(f"Loaded {len(all_test_samples)} total test samples")
    
    # Get valid species from training
    valid_species = get_valid_species_from_training()
    print(f"Valid species from training: {len(valid_species)} species")
    print(f"Valid species: {valid_species[:10]}{'...' if len(valid_species) > 10 else ''}")
    
    # Filter test samples to only include valid species
    test_samples = [sample for sample in all_test_samples if sample['species'] in valid_species]
    
    print(f"Filtered to {len(test_samples)} test samples with valid species")
    print(f"Excluded {len(all_test_samples) - len(test_samples)} samples with invalid species")
    
    # Show test set distribution for valid species
    from collections import Counter
    test_species_counts = Counter([sample['species'] for sample in test_samples])
    invalid_species = Counter([sample['species'] for sample in all_test_samples if sample['species'] not in valid_species])
    
    print(f"\nValid species in test set: {len(test_species_counts)} species")
    print(f"Test set species distribution (top 10): {dict(list(test_species_counts.most_common(10)))}")
    
    if invalid_species:
        print(f"\nExcluded species: {len(invalid_species)} species")
        print(f"Excluded species (top 5): {dict(list(invalid_species.most_common(5)))}")
        
else:
    print(f"Test dataset path not found: {TEST_DATASET_PATH}")
    print("Please ensure you have a test directory with the holdout test set")
    test_samples = []
    valid_species = []


In [None]:
import re
from tqdm import tqdm
import torch

def extract_species_from_response(response_text, valid_species):
    """Extract the predicted species from model response text."""
    # Convert response to lowercase for matching
    response_lower = response_text.lower()
    
    # Create normalized species names for matching
    normalized_species = {}
    for species in valid_species:
        # Store original -> normalized mapping
        normalized_species[species.lower().replace(' ', '').replace('-', '')] = species
        normalized_species[species.lower()] = species
    
    # Common patterns to look for species mentions
    # (raw strings use single backslashes so \s, \w, and \b act as regex escapes)
    patterns = [
        # Direct patterns like "This is a Dandelion plant"
        r'(?:this|it|these)\s+(?:is|are|appears?|looks?|seems?)\s+(?:a|an)?\s*([\w\s]+?)\s+(?:plant|species|flower|weed)',
        # Patterns like "Dandelion species" or "Dandelion plant"
        r'([\w\s]+?)\s+(?:species|plant|flower|weed)',
        # Patterns like "species Dandelion"
        r'species\s+([\w\s]+)',
        # Patterns with **bold** markdown
        r'\*\*([\w\s]+?)\*\*',
        # Direct species names at start of sentences
        r'(?:^|\.\s+)([\w\s]+?)\s+(?:is|are|species)',
        # Common plant-specific patterns
        r'(?:plant|flower|weed|species).*?([\w\s]+)',
        # Just capture potential species names (2-3 words)
        r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,2})\b',
    ]
    
    # First try exact species name matching
    for species in valid_species:
        if species.lower() in response_lower:
            return species
    
    # Then try pattern matching
    for pattern in patterns:
        matches = re.findall(pattern, response_lower, re.IGNORECASE)
        for match in matches:
            match = match.strip()
            if not match:
                continue
                
            # Try direct match
            if match.lower() in normalized_species:
                return normalized_species[match.lower()]
            
            # Try normalized match (remove spaces/hyphens)
            normalized_match = match.lower().replace(' ', '').replace('-', '')
            if normalized_match in normalized_species:
                return normalized_species[normalized_match]
            
            # Try partial matching for compound names
            for species in valid_species:
                species_words = species.lower().split()
                match_words = match.lower().split()
                
                # If all words in match are in species name
                if len(match_words) >= 2 and all(word in species_words for word in match_words):
                    return species
                    
                # If match contains the key part of species name
                if len(species_words) >= 2 and any(word in match.lower() for word in species_words):
                    if len([word for word in species_words if word in match.lower()]) >= len(species_words) // 2:
                        return species
    
    return "Unknown"

def run_test_evaluation_valid_species(test_samples, model, tokenizer_inference, valid_species, max_samples=None):
    """Run comprehensive evaluation on test samples with valid species only."""
    if not test_samples:
        print("No test samples available for evaluation.")
        return [], [], []
    
    # Limit test samples if specified
    if max_samples and len(test_samples) > max_samples:
        test_samples = random.sample(test_samples, max_samples)
        print(f"Limited test set to {max_samples} samples for faster evaluation")
    
    predictions = []
    ground_truths = []
    responses = []
    
    print(f"Running inference on {len(test_samples)} test samples for {len(valid_species)} valid species...")
    
    # Run inference on each test sample
    for i, sample in enumerate(tqdm(test_samples, desc="Testing Plants")):
        try:
            # Prepare the message - using plant-specific prompt
            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What type of plant is this? Please respond concisely."},
                    {"type": "image", "image": sample['image']}
                ]
            }]
            
            # Prepare inputs
            inputs = tokenizer_inference.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
                tokenize=True,
                return_dict=True,
            ).to("cuda")
            
            # Generate response
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=50,  # Keep responses concise
                    temperature=1.0, 
                    top_p=0.95, 
                    top_k=64,
                    do_sample=True,
                    pad_token_id=tokenizer_inference.eos_token_id
                )
            
            # Decode only the newly generated tokens so the prompt text is not
            # part of the response (special tokens are skipped, so the
            # "<start_of_turn>model" marker cannot be used to split it)
            prompt_length = inputs["input_ids"].shape[1]
            model_response = tokenizer_inference.batch_decode(
                outputs[:, prompt_length:], skip_special_tokens=True
            )[0].strip()
            
            # Extract species prediction
            predicted_species = extract_species_from_response(model_response, valid_species)
            
            # Store results
            predictions.append(predicted_species)
            ground_truths.append(sample['species'])
            responses.append(model_response)
            
            # Show progress for first few samples
            if i < 5:
                print(f"\nSample {i+1}:")
                print(f"  Ground truth: {sample['species']}")
                print(f"  Model response: {model_response[:100]}...")
                print(f"  Predicted species: {predicted_species}")
        
        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            predictions.append("Error")
            ground_truths.append(sample['species'])
            responses.append(f"Error: {e}")
    
    return predictions, ground_truths, responses

# Run the evaluation on valid species only
if test_samples and valid_species:
    predictions, ground_truths, responses = run_test_evaluation_valid_species(
        test_samples, model, tokenizer_inference, valid_species, max_samples=2000
    )
    print(f"\nCompleted evaluation on {len(predictions)} samples for valid species")
    
    # Count how many predictions map to a valid species, an unrecognized label, or an error
    valid_predictions = [pred for pred, truth in zip(predictions, ground_truths) if pred in valid_species]
    invalid_predictions = [pred for pred, truth in zip(predictions, ground_truths) if pred not in valid_species and pred != "Error"]
    
    print(f"Predictions mapping to valid species: {len(valid_predictions)}/{len(predictions)} ({100*len(valid_predictions)/len(predictions):.1f}%)")
    print(f"Predictions mapping to invalid species: {len(invalid_predictions)}")
    print(f"Error predictions: {predictions.count('Error')}")
    
else:
    predictions, ground_truths, responses = [], [], []
    print("Skipping evaluation - no test samples or valid species loaded")


In [None]:
# Calculate comprehensive classification metrics for valid species only
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import numpy as np

def calculate_classification_metrics_valid_species(predictions, ground_truths, valid_species):
    """Calculate comprehensive classification metrics for valid species only."""
    if not predictions or not ground_truths:
        print("No predictions or ground truths available for metrics calculation.")
        return {}
    
    # Filter to include only:
    # 1. Valid predictions (not "Error" or "Unknown")
    # 2. Ground truth species that are in valid_species 
    # 3. Predictions that map to valid species
    valid_indices = []
    for i, (pred, truth) in enumerate(zip(predictions, ground_truths)):
        if (pred != "Error" and pred != "Unknown" and 
            truth in valid_species and pred in valid_species):
            valid_indices.append(i)
    
    valid_predictions = [predictions[i] for i in valid_indices]
    valid_ground_truths = [ground_truths[i] for i in valid_indices]
    
    if not valid_predictions:
        print("No valid predictions available for metrics calculation.")
        return {}
    
    print(f"Calculating metrics for {len(valid_predictions)} valid predictions from {len(valid_species)} species")
    
    # Overall accuracy (only for valid species)
    accuracy = accuracy_score(valid_ground_truths, valid_predictions)
    
    # Get the unique labels that appear in either the predictions or the ground truth
    all_labels = sorted(list(set(valid_ground_truths + valid_predictions)))
    
    # Per-class metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        valid_ground_truths, valid_predictions, labels=all_labels, average=None, zero_division=0
    )
    
    # Macro and weighted averages
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        valid_ground_truths, valid_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        valid_ground_truths, valid_predictions, average='weighted', zero_division=0
    )
    
    # Confusion matrix
    cm = confusion_matrix(valid_ground_truths, valid_predictions, labels=all_labels)
    
    # Calculate coverage metrics
    predicted_species = set([pred for pred in predictions if pred in valid_species])
    ground_truth_species = set([truth for truth in ground_truths if truth in valid_species])
    
    species_coverage = len(predicted_species) / len(valid_species) if valid_species else 0
    species_recall = len(predicted_species.intersection(ground_truth_species)) / len(ground_truth_species) if ground_truth_species else 0
    
    metrics = {
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'per_class_precision': dict(zip(all_labels, precision)),
        'per_class_recall': dict(zip(all_labels, recall)),
        'per_class_f1': dict(zip(all_labels, f1)),
        'per_class_support': dict(zip(all_labels, support)),
        'confusion_matrix': cm,
        'labels': all_labels,
        'valid_predictions': valid_predictions,
        'valid_ground_truths': valid_ground_truths,
        'total_samples': len(predictions),
        'valid_samples': len(valid_predictions),
        'error_samples': predictions.count("Error"),
        'unknown_samples': predictions.count("Unknown"),
        'invalid_species_predictions': len([p for p in predictions if p not in valid_species and p not in ["Error", "Unknown"]]),
        'species_coverage': species_coverage,
        'species_recall': species_recall,
        'total_valid_species': len(valid_species),
        'predicted_species_count': len(predicted_species),
        'ground_truth_species_count': len(ground_truth_species)
    }
    
    return metrics

# Calculate metrics for valid species only
if predictions and ground_truths and valid_species:
    metrics = calculate_classification_metrics_valid_species(predictions, ground_truths, valid_species)
    
    if metrics:
        print("\n" + "="*60)
        print("VALID SPECIES CLASSIFICATION METRICS")
        print("="*60)
        
        # Overall statistics
        print(f"Total test samples: {metrics['total_samples']}")
        print(f"Valid species predictions: {metrics['valid_samples']}")
        print(f"Error samples: {metrics['error_samples']}")
        print(f"Unknown species samples: {metrics['unknown_samples']}")
        print(f"Invalid species predictions: {metrics['invalid_species_predictions']}")
        
        # Species coverage analysis
        print(f"\nSpecies Coverage Analysis:")
        print(f"Total valid species in training: {metrics['total_valid_species']}")
        print(f"Species represented in test set: {metrics['ground_truth_species_count']}")
        print(f"Species predicted by model: {metrics['predicted_species_count']}")
        print(f"Species coverage: {metrics['species_coverage']:.3f} ({metrics['predicted_species_count']}/{metrics['total_valid_species']})")
        print(f"Species recall: {metrics['species_recall']:.3f}")
        
        # Classification performance
        print(f"\nClassification Performance (Valid Species Only):")
        print(f"Overall Accuracy: {metrics['accuracy']:.3f}")
        print(f"Macro-averaged Precision: {metrics['precision_macro']:.3f}")
        print(f"Macro-averaged Recall: {metrics['recall_macro']:.3f}")
        print(f"Macro-averaged F1-score: {metrics['f1_macro']:.3f}")
        print(f"Weighted-averaged Precision: {metrics['precision_weighted']:.3f}")
        print(f"Weighted-averaged Recall: {metrics['recall_weighted']:.3f}")
        print(f"Weighted-averaged F1-score: {metrics['f1_weighted']:.3f}")
        
        # Top performing species
        print("\nTop Performing Species (F1-Score):")
        print("-"*50)
        species_f1 = [(species, f1) for species, f1 in metrics['per_class_f1'].items()]
        species_f1.sort(key=lambda x: x[1], reverse=True)
        for species, f1 in species_f1[:10]:
            precision = metrics['per_class_precision'][species]
            recall = metrics['per_class_recall'][species]
            support = metrics['per_class_support'][species]
            print(f"{species:25} | F1: {f1:.3f} | P: {precision:.3f} | R: {recall:.3f} | Support: {support}")
        
        # Detailed classification report
        print("\nDetailed Classification Report (Valid Species):")
        print("-"*60)
        print(classification_report(metrics['valid_ground_truths'], metrics['valid_predictions']))
        
else:
    print("No predictions available for metrics calculation.")
    metrics = {}
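
The report above prints both macro-averaged and weighted-averaged scores. As a minimal sketch of how the two differ, the pure-Python example below computes per-class F1 on a toy, imbalanced label set and averages it both ways. The species names and the helper functions are hypothetical and are not part of the notebook's evaluation code, which relies on scikit-learn.

```python
def per_class_f1(y_true, y_pred):
    """Return {label: (f1, support)} computed from two parallel label lists."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = {}
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        scores[label] = (f1, tp + fn)  # support = number of true samples
    return scores


def macro_and_weighted_f1(y_true, y_pred):
    """Macro F1 weights every class equally; weighted F1 weights by support."""
    scores = per_class_f1(y_true, y_pred)
    macro = sum(f1 for f1, _ in scores.values()) / len(scores)
    total = sum(support for _, support in scores.values())
    weighted = sum(f1 * support for f1, support in scores.values()) / total
    return macro, weighted


# Toy data (hypothetical species): 8 "rose" samples, 2 "fern" samples,
# with one "fern" misclassified as "rose".
y_true = ["rose"] * 8 + ["fern"] * 2
y_pred = ["rose"] * 8 + ["rose", "fern"]
macro_f1, weighted_f1 = macro_and_weighted_f1(y_true, y_pred)
print(f"macro F1: {macro_f1:.3f} | weighted F1: {weighted_f1:.3f}")
```

Because the rare class scores worse here, the macro average is pulled down more than the weighted average. A large gap between the two in the real report is a hint that low-support species are underperforming.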


In [None]:
# Create comprehensive visualizations and analysis
import matplotlib.pyplot as plt
import seaborn as sns

def create_evaluation_visualizations(metrics):
    """Create comprehensive evaluation visualizations."""
    if not metrics:
        print("No metrics available for visualization.")
        return
    
    # Set up the plotting style
    plt.style.use('default')
    fig = plt.figure(figsize=(20, 15))
    
    # 1. Confusion Matrix
    ax1 = plt.subplot(2, 3, 1)
    cm = metrics['confusion_matrix']
    labels = metrics['labels']
    
    # Draw the heatmap with the built-in 'Blues' colormap
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=labels, yticklabels=labels,
                cbar_kws={'label': 'Number of Samples'})
    plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.xlabel('Predicted Species', fontsize=12)
    plt.ylabel('True Species', fontsize=12)
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    
    # 2. Per-Class Recall Bar Chart
    ax2 = plt.subplot(2, 3, 2)
    per_class_acc = []
    class_names = []
    for label in labels:
        if label in metrics['per_class_recall']:
            per_class_acc.append(metrics['per_class_recall'][label])
            class_names.append(label)
    
    bars = plt.bar(range(len(class_names)), per_class_acc, 
                   color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22'])
    plt.title('Per-Species Recall (Sensitivity)', fontsize=14, fontweight='bold')
    plt.xlabel('Species', fontsize=12)
    plt.ylabel('Recall', fontsize=12)
    plt.xticks(range(len(class_names)), class_names, rotation=45)
    plt.ylim(0, 1.1)
    
    # Add value labels on bars
    for i, (bar, acc) in enumerate(zip(bars, per_class_acc)):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')
    
    # 3. Precision vs Recall Scatter Plot
    ax3 = plt.subplot(2, 3, 3)
    precisions = [metrics['per_class_precision'][label] for label in labels if label in metrics['per_class_precision']]
    recalls = [metrics['per_class_recall'][label] for label in labels if label in metrics['per_class_recall']]
    
    # Color by point index; use len(recalls) so the color array always matches the plotted points
    scatter = plt.scatter(recalls, precisions, s=100, alpha=0.7, 
                         c=range(len(recalls)), cmap='tab10')
    plt.title('Precision vs Recall by Species', fontsize=14, fontweight='bold')
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.xlim(0, 1.1)
    plt.ylim(0, 1.1)
    
    # Add diagonal line (perfect precision = recall)
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect P=R')
    
    # Add species labels
    for i, label in enumerate(class_names):
        if i < len(recalls) and i < len(precisions):
            plt.annotate(label, (recalls[i], precisions[i]), 
                        xytext=(5, 5), textcoords='offset points', fontsize=9)
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 4. F1-Score Bar Chart
    ax4 = plt.subplot(2, 3, 4)
    f1_scores = [metrics['per_class_f1'][label] for label in labels if label in metrics['per_class_f1']]
    
    bars = plt.bar(range(len(class_names)), f1_scores, 
                   color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22'])
    plt.title('Per-Species F1-Score', fontsize=14, fontweight='bold')
    plt.xlabel('Species', fontsize=12)
    plt.ylabel('F1-Score', fontsize=12)
    plt.xticks(range(len(class_names)), class_names, rotation=45)
    plt.ylim(0, 1.1)
    
    # Add value labels on bars
    for i, (bar, f1) in enumerate(zip(bars, f1_scores)):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f'{f1:.3f}', ha='center', va='bottom', fontweight='bold')
    
    # 5. Sample Support Distribution
    ax5 = plt.subplot(2, 3, 5)
    supports = [metrics['per_class_support'][label] for label in labels if label in metrics['per_class_support']]
    
    bars = plt.bar(range(len(class_names)), supports, 
                   color=['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#ff99cc', '#c2c2f0', '#ffb3e6', '#c4e17f', '#ffff99'])
    plt.title('Test Set Sample Distribution', fontsize=14, fontweight='bold')
    plt.xlabel('Species', fontsize=12)
    plt.ylabel('Number of Test Samples', fontsize=12)
    plt.xticks(range(len(class_names)), class_names, rotation=45)
    
    # Add value labels on bars
    for i, (bar, support) in enumerate(zip(bars, supports)):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
                f'{support}', ha='center', va='bottom', fontweight='bold')
    
    # 6. Overall Metrics Summary
    ax6 = plt.subplot(2, 3, 6)
    ax6.axis('off')
    
    # Create a summary text
    summary_text = f"""
    OVERALL PERFORMANCE SUMMARY
    
    Total Test Samples: {metrics['total_samples']}
    Valid Predictions: {metrics['valid_samples']}
    Error Rate: {metrics['error_samples']}/{metrics['total_samples']} ({100*metrics['error_samples']/metrics['total_samples']:.1f}%)
    
    Overall Accuracy: {metrics['accuracy']:.3f}
    
    Macro Averages:
    • Precision: {metrics['precision_macro']:.3f}
    • Recall: {metrics['recall_macro']:.3f}
    • F1-Score: {metrics['f1_macro']:.3f}
    
    Weighted Averages:
    • Precision: {metrics['precision_weighted']:.3f}
    • Recall: {metrics['recall_weighted']:.3f}
    • F1-Score: {metrics['f1_weighted']:.3f}
    """
    
    ax6.text(0.1, 0.9, summary_text, transform=ax6.transAxes, fontsize=11,
             verticalalignment='top', bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8))
    
    plt.tight_layout()
    plt.show()
    
    # Additional Analysis: Top Confusion Pairs
    print("\n" + "="*50)
    print("CONFUSION ANALYSIS")
    print("="*50)
    
    # Find most confused pairs
    confusion_pairs = []
    for i, true_label in enumerate(labels):
        for j, pred_label in enumerate(labels):
            if i != j and cm[i, j] > 0:
                confusion_pairs.append((true_label, pred_label, cm[i, j]))
    
    # Sort by confusion count
    confusion_pairs.sort(key=lambda x: x[2], reverse=True)
    
    print("Most Common Misclassifications:")
    for true_label, pred_label, count in confusion_pairs[:10]:
        print(f"  {true_label} → {pred_label}: {count} samples")
    
    if not confusion_pairs:
        print("  No misclassifications found! Perfect performance!")

# Create visualizations
if metrics:
    create_evaluation_visualizations(metrics)
else:
    print("No metrics available for visualization creation.")


In [None]:
# Final evaluation summary and completion status
if test_samples and predictions and ground_truths and valid_species and metrics:
    print("\n" + "="*70)
    print("PLANT SPECIES TEST SET EVALUATION COMPLETED!")
    print("="*70)
    print("✅ Loaded test set and filtered for valid species")
    print("✅ Ran inference on test samples") 
    print("✅ Calculated classification metrics for valid species only")
    print("✅ Generated comprehensive visualizations and analysis")
    print("✅ Analyzed species coverage and performance")
    
    print("\nKey Results:")
    print(f"- Evaluated {len(test_samples)} test samples")
    print(f"- Valid species predictions: {metrics['valid_samples']} ({100*metrics['valid_samples']/len(predictions):.1f}%)")
    print(f"- Overall accuracy on valid species: {metrics['accuracy']:.3f}")
    print(f"- Species coverage: {metrics['predicted_species_count']}/{metrics['total_valid_species']} ({metrics['species_coverage']:.1%})")
    
    # Show some example correct and incorrect predictions
    correct_examples = [(pred, truth) for pred, truth in zip(predictions, ground_truths) if pred == truth and pred in valid_species]
    incorrect_examples = [(pred, truth) for pred, truth in zip(predictions, ground_truths) if pred != truth and pred in valid_species and truth in valid_species]
    
    if correct_examples:
        print(f"\nExample correct predictions: {correct_examples[:3]}")
    if incorrect_examples:
        print(f"Example incorrect predictions: {incorrect_examples[:3]}")
        
else:
    print("Cannot complete evaluation - missing test data, predictions, or valid species.")


In [None]:
# Show example predictions for qualitative analysis
import random  # used by random.sample below; harmless if already imported earlier
def show_prediction_examples(test_samples, predictions, ground_truths, responses, num_examples=8):
    """Show example predictions with images for qualitative analysis."""
    if not test_samples or not predictions:
        print("No test samples or predictions available for examples.")
        return
    
    # Select examples: some correct, some incorrect
    correct_indices = [i for i, (pred, true) in enumerate(zip(predictions, ground_truths)) 
                      if pred == true and pred != "Error"]
    incorrect_indices = [i for i, (pred, true) in enumerate(zip(predictions, ground_truths)) 
                        if pred != true and pred != "Error"]
    
    # Select examples to show
    selected_indices = []
    
    # Add some correct examples (up to half of num_examples)
    half = num_examples // 2
    if correct_indices:
        selected_indices.extend(random.sample(correct_indices, min(half, len(correct_indices))))
    
    # Add some incorrect examples (up to the remaining slots)
    if incorrect_indices:
        selected_indices.extend(random.sample(incorrect_indices, min(num_examples - half, len(incorrect_indices))))
    
    # If we don't have enough, fill with any available
    if len(selected_indices) < num_examples:
        remaining_indices = [i for i in range(len(predictions)) if i not in selected_indices and predictions[i] != "Error"]
        selected_indices.extend(random.sample(remaining_indices, min(num_examples - len(selected_indices), len(remaining_indices))))
    
    if not selected_indices:
        print("No valid examples to display.")
        return
    
    # Create the visualization
    cols = 4
    rows = (len(selected_indices) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(20, 5*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    
    for idx, example_idx in enumerate(selected_indices):
        row = idx // cols
        col = idx % cols
        ax = axes[row, col]  # axes is always 2D here (reshaped above when rows == 1)
        
        # Get example data
        sample = test_samples[example_idx]
        pred = predictions[example_idx]
        true = ground_truths[example_idx]
        response = responses[example_idx]
        
        # Display image
        ax.imshow(sample['image'])
        ax.axis('off')
        
        # Create title with prediction info
        is_correct = pred == true
        status = "✓ CORRECT" if is_correct else "✗ INCORRECT"
        color = 'green' if is_correct else 'red'
        
        title = f"{status}\nTrue: {true}\nPred: {pred}"
        ax.set_title(title, fontsize=10, fontweight='bold', color=color, pad=10)
        
        # Add response as text below
        response_text = response[:60] + "..." if len(response) > 60 else response
        ax.text(0.5, -0.1, f"Response: {response_text}", 
                transform=ax.transAxes, ha='center', va='top', 
                fontsize=8, wrap=True, style='italic')
    
    # Hide any unused subplots
    for idx in range(len(selected_indices), rows * cols):
        row = idx // cols
        col = idx % cols
        ax = axes[row, col]  # axes is always 2D here (reshaped above when rows == 1)
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed results
    print("\n" + "="*50)
    print("EXAMPLE PREDICTIONS ANALYSIS")
    print("="*50)
    
    correct_count = len(correct_indices)
    incorrect_count = len(incorrect_indices)
    total_valid = correct_count + incorrect_count
    
    print(f"Showing {len(selected_indices)} example predictions:")
    print(f"  • Correct predictions available: {correct_count}/{total_valid}")
    print(f"  • Incorrect predictions available: {incorrect_count}/{total_valid}")
    
    if incorrect_indices:
        print(f"\nSample of incorrect predictions:")
        for i, idx in enumerate(incorrect_indices[:5]):
            pred = predictions[idx]
            true = ground_truths[idx]
            print(f"  {i+1}. True: {true} → Predicted: {pred}")

# Show example predictions
if test_samples and predictions and ground_truths:
    show_prediction_examples(test_samples, predictions, ground_truths, responses)
    
    # Print completion summary
    print("\n🎉 Test set evaluation completed!")
    print("✅ Loaded test set")
    print("✅ Ran inference on test samples") 
    print("✅ Calculated classification metrics")
    print("✅ Generated comprehensive visualizations")
    print("✅ Analyzed prediction examples")
else:
    print("Cannot show examples - missing test data or predictions.")


### GGUF / llama.cpp Conversion
We now support saving to `GGUF` / `llama.cpp` natively for all models! For now, you can convert to `Q8_0`, `F16`, or `BF16` precision. 4-bit `Q4_K_M` support will come later!

In [None]:
if True: # Set to False to skip saving to GGUF
    model.save_pretrained_gguf(
        "gemma-3n-it-plant-4bit_gguf",
        quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
    )
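
After the export finishes, a quick sanity check can confirm that the output really is a GGUF file. GGUF files begin with the 4-byte magic `GGUF` followed by a 32-bit version number. The sketch below reads only that header. The helper names are hypothetical, it assumes the common little-endian layout, and where the `.gguf` file ends up depends on your Unsloth version, so point it at the actual output location.

```python
import struct
from pathlib import Path

GGUF_MAGIC = b"GGUF"  # first four bytes of every GGUF file


def read_gguf_version(path):
    """Return the GGUF version number, or None if the file is not GGUF."""
    with open(path, "rb") as f:
        header = f.read(8)
    if len(header) < 8 or header[:4] != GGUF_MAGIC:
        return None
    # Assumption: little-endian file (the usual case).
    (version,) = struct.unpack("<I", header[4:8])
    return version


def check_gguf_files(directory):
    """Map each *.gguf file under `directory` to its version (None if invalid)."""
    return {
        path.name: read_gguf_version(path)
        for path in sorted(Path(directory).rglob("*.gguf"))
    }
```

For example, `check_gguf_files("gemma-3n-it-plant-4bit_gguf")` returns an empty dict if no `.gguf` file was written under that directory, which is itself a useful signal to look elsewhere.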

To instead push the GGUF model to your Hugging Face account, keep `if True` in the cell below and replace `token` and `repo_id` with your own Hugging Face token and upload location!

In [None]:
if True: # Set to False to skip uploading the GGUF
    model.push_to_hub_gguf(
        "gemma-3n-it-plant-4bit",
        quantization_type = "Q8_0", # Only Q8_0, BF16, F16 supported
        repo_id = "wrongryan/gemma-3n-it-plant-4bit-gguf",
        token = "XXX",
    )
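
Rather than pasting a token into the notebook, one option is to read it from an environment variable. The sketch below assumes the token is stored in `HF_TOKEN`; the helper `get_hf_token` is hypothetical and simply fails early with a clear message when the variable is missing or still holds the `"XXX"` placeholder.

```python
import os


def get_hf_token(env_var="HF_TOKEN"):
    """Read a Hugging Face token from the environment, failing early if absent."""
    token = os.environ.get(env_var, "").strip()
    if not token or token == "XXX":  # "XXX" is the placeholder used above
        raise RuntimeError(
            f"Set the {env_var} environment variable to a valid Hugging Face token."
        )
    return token
```

The upload cell could then pass `token = get_hf_token()` instead of a hard-coded string, which keeps the secret out of the saved notebook.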