In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity

2025-09-30 04:38:07.267177: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759207087.595925      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759207087.688235      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# ---------------------------
# Configuration
# ---------------------------
CONFIG = {
    'num_reference_images': 50, 
    'embedding_method': 'mean',
    'augment_reference': True,
    'augment_query': True, 
    'image_preprocessing': True,
}

In [3]:
# ---------------------------
# Step 1. Setup paths and explore dataset
# ---------------------------
dataset_path = Path("/kaggle/input/whoi-plankton-2014/2014")

# List species folders
species_folders = [p for p in dataset_path.iterdir() if p.is_dir()]
print(f"Number of species: {len(species_folders)}")
print(f"Example species: {[f.name for f in species_folders[:5]]}")

Number of species: 94
Example species: ['DactFragCerataul', 'Rhizosolenia', 'Chaetoceros', 'bead', 'G_delicatula_external_parasite']


In [4]:
# ---------------------------
# Step 2. Image Preprocessing Functions
# ---------------------------
def preprocess_plankton_image(pil_img: Image.Image, target_size=224) -> Image.Image:
    """
    Preprocess plankton images for better CLIP performance
    - Resize while maintaining aspect ratio
    - Add padding to make square
    - Enhance contrast
    """
    # Convert to RGB
    img = pil_img.convert("RGB")
    
    # Get original size
    orig_w, orig_h = img.size
    
    # Calculate scaling to fit within target_size while maintaining aspect ratio
    scale = min(target_size / orig_w, target_size / orig_h)
    new_w, new_h = int(orig_w * scale), int(orig_h * scale)
    
    # Resize image
    img = img.resize((new_w, new_h), Image.LANCZOS)
    
    # Create white background (plankton images often have white/light backgrounds)
    background = Image.new('RGB', (target_size, target_size), (255, 255, 255))
    
    # Paste resized image in center
    offset = ((target_size - new_w) // 2, (target_size - new_h) // 2)
    background.paste(img, offset)
    
    return background


def apply_augmentation(pil_img: Image.Image) -> list:
    """
    Apply augmentations to generate multiple views of the same image
    Returns list of augmented images
    """
    augmented = [pil_img]  # Original
    
    # Horizontal flip
    augmented.append(pil_img.transpose(Image.FLIP_LEFT_RIGHT))
    
    # Vertical flip
    augmented.append(pil_img.transpose(Image.FLIP_TOP_BOTTOM))
    
    # Slight rotations
    augmented.append(pil_img.rotate(90, expand=True))
    augmented.append(pil_img.rotate(180, expand=True))
    augmented.append(pil_img.rotate(270, expand=True))
    
    return augmented

In [5]:
# ---------------------------
# Step 3. CLIP setup
# ---------------------------
model_path = "/kaggle/input/openaiclip-vit-base-patch32"

print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained(model_path)
clip_processor = CLIPProcessor.from_pretrained(model_path)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = clip_model.to(DEVICE)
print(f"Using device: {DEVICE}")

@torch.no_grad()
def image_embedding(pil_img: Image.Image) -> np.ndarray:
    """Extract CLIP embedding from PIL image"""
    inputs = clip_processor(images=pil_img, return_tensors="pt").to(DEVICE)
    feats = clip_model.get_image_features(**inputs).cpu().numpy()
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    return feats[0]

@torch.no_grad()
def batch_image_embeddings(pil_images: list) -> np.ndarray:
    """Extract CLIP embeddings for multiple images at once (faster)"""
    inputs = clip_processor(images=pil_images, return_tensors="pt").to(DEVICE)
    feats = clip_model.get_image_features(**inputs).cpu().numpy()
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    return feats

Loading CLIP model...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Using device: cpu


In [6]:
# ---------------------------
# Step 4. Build IMPROVED reference dataset
# ---------------------------
print(f"Using {CONFIG['num_reference_images']} images per species")
print(f"Augmentation: {CONFIG['augment_reference']}")
print(f"Preprocessing: {CONFIG['image_preprocessing']}")

species_embeddings = {}
records = []

for species_folder in tqdm(sorted(dataset_path.iterdir()), desc="Processing species"):
    if not species_folder.is_dir():
        continue
    
    species_name = species_folder.name
    png_files = list(species_folder.glob("*.png"))
    
    if not png_files:
        print(f"Warning: No PNG files found in {species_name}")
        continue
    
    # Select multiple reference images
    num_refs = min(CONFIG['num_reference_images'], len(png_files))
    
    # Distribute selection across the dataset (not just first N)
    if num_refs >= len(png_files):
        # Use all images if we have fewer than requested
        selected_files = png_files
    else:
        # Sample evenly distributed images
        step = max(1, len(png_files) // num_refs)
        selected_files = png_files[::step][:num_refs]
    
    all_embeddings = []
    
    try:
        for img_path in selected_files:
            img = Image.open(img_path).convert("RGB")
            
            # Apply preprocessing
            if CONFIG['image_preprocessing']:
                img = preprocess_plankton_image(img)
            
            # Apply augmentation if enabled
            if CONFIG['augment_reference']:
                augmented_images = apply_augmentation(img)
                # Get embeddings for all augmented versions
                embs = batch_image_embeddings(augmented_images)
                all_embeddings.extend(embs)
            else:
                emb = image_embedding(img)
                all_embeddings.append(emb)
        
        # Aggregate embeddings 
        all_embeddings = np.array(all_embeddings)
        
        if CONFIG['embedding_method'] == 'mean':
            final_embedding = np.mean(all_embeddings, axis=0)
        else:  # median
            final_embedding = np.median(all_embeddings, axis=0)
        
        # Renormalize
        final_embedding = final_embedding / np.linalg.norm(final_embedding)
        
        species_embeddings[species_name] = final_embedding
        
        records.append({
            "species": species_name,
            "num_reference_images": num_refs,
            "num_embeddings_used": len(all_embeddings),
            "total_images_available": len(png_files),
            "embedding": final_embedding
        })
        
    except Exception as e:
        print(f"Error processing {species_name}: {e}")
        continue

df_ref = pd.DataFrame(records)
print(f"\nReference dataset built: {df_ref.shape}")
print(f"Species with embeddings: {len(species_embeddings)}")
print("\nEmbeddings statistics:")
print(df_ref[['species', 'num_reference_images', 'num_embeddings_used', 'total_images_available']].head(10))

Using 50 images per species
Augmentation: True
Preprocessing: True


Processing species: 100%|██████████| 94/94 [27:26<00:00, 17.51s/it]


Reference dataset built: (94, 5)
Species with embeddings: 94

Embeddings statistics:
                          species  num_reference_images  num_embeddings_used  \
0                        Akashiwo                     1                    6   
1                  Amphidinium_sp                    50                  300   
2                Asterionellopsis                    50                  300   
3                     Cerataulina                    50                  300   
4          Cerataulina_flagellate                     5                   30   
5                        Ceratium                     6                   36   
6                     Chaetoceros                    50                  300   
7             Chaetoceros_didymus                    11                   66   
8  Chaetoceros_didymus_flagellate                     1                    6   
9          Chaetoceros_flagellate                     4                   24   

   total_images_available  
0    




In [7]:
# ---------------------------
# Step 5. Classify test images with improvements
# ---------------------------

print("\n" + "="*50)
print("Testing classification")
print("="*50)

test_species_folder = species_folders[0]
test_images = list(test_species_folder.glob("*.png"))[:10]

print(f"\nTest species: {test_species_folder.name}")
print(f"Found {len(test_images)} test images")

species_list = list(species_embeddings.keys())
ref_matrix = np.stack([species_embeddings[s] for s in species_list])

results = []

for img_path in tqdm(test_images, desc="Classifying images"):
    try:
        img = Image.open(img_path).convert("RGB")
        
        # Apply preprocessing
        if CONFIG['image_preprocessing']:
            img = preprocess_plankton_image(img)
        
        # Test-time augmentation 
        if CONFIG['augment_query']:
            augmented = apply_augmentation(img)
            embs = batch_image_embeddings(augmented)
            q_emb = np.mean(embs, axis=0)
            q_emb = q_emb / np.linalg.norm(q_emb)
        else:
            q_emb = image_embedding(img)
        
        # Compute similarities
        similarities = ref_matrix @ q_emb
        
        # Get predictions
        best_idx = np.argmax(similarities)
        best_species = species_list[best_idx]
        best_score = similarities[best_idx]
        
        # Top 5 predictions
        top5_indices = np.argsort(similarities)[-5:][::-1]
        top5_species = [species_list[i] for i in top5_indices]
        top5_scores = [similarities[i] for i in top5_indices]
        
        results.append({
            "query_image": img_path.name,
            "true_species": test_species_folder.name,
            "predicted_species": best_species,
            "confidence": float(best_score),
            "is_correct": best_species == test_species_folder.name,
            "top5_predictions": top5_species,
            "top5_scores": [float(s) for s in top5_scores]
        })
        
    except Exception as e:
        print(f"Error processing {img_path.name}: {e}")
        continue

df_results = pd.DataFrame(results)

print("\n" + "="*50)
print("Classification Results")
print("="*50)
print(df_results[['query_image', 'true_species', 'predicted_species', 'confidence', 'is_correct']])

if len(df_results) > 0:
    accuracy = df_results['is_correct'].mean()
    avg_confidence = df_results['confidence'].mean()
    
    print(f"\n{'='*50}")
    print(f"METRICS")
    print(f"{'='*50}")
    print(f"Accuracy: {accuracy:.2%} ({df_results['is_correct'].sum()}/{len(df_results)})")
    print(f"Average confidence: {avg_confidence:.4f}")
    print(f"Correct predictions avg confidence: {df_results[df_results['is_correct']]['confidence'].mean():.4f}")
    print(f"Wrong predictions avg confidence: {df_results[~df_results['is_correct']]['confidence'].mean():.4f}")
    
    # Top-5 accuracy
    top5_correct = sum([row['true_species'] in row['top5_predictions'] for _, row in df_results.iterrows()])
    top5_accuracy = top5_correct / len(df_results)
    print(f"Top-5 Accuracy: {top5_accuracy:.2%}")


Testing classification

Test species: DactFragCerataul
Found 10 test images


Classifying images: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]


Classification Results
                       query_image      true_species predicted_species  \
0  IFCB5_2014_248_004113_05455.png  DactFragCerataul      Rhizosolenia   
1  IFCB5_2014_328_152205_07231.png  DactFragCerataul  DactFragCerataul   
2  IFCB5_2014_315_144452_00017.png  DactFragCerataul       Skeletonema   
3  IFCB5_2014_315_135823_06074.png  DactFragCerataul  DactFragCerataul   
4  IFCB5_2014_315_144452_01507.png  DactFragCerataul       Skeletonema   
5  IFCB5_2014_259_120213_05242.png  DactFragCerataul      Rhizosolenia   
6  IFCB5_2014_002_210221_01583.png  DactFragCerataul      Rhizosolenia   
7  IFCB5_2014_315_142115_01290.png  DactFragCerataul      Rhizosolenia   
8  IFCB5_2014_248_010423_01317.png  DactFragCerataul  DactFragCerataul   
9  IFCB5_2014_315_135823_07315.png  DactFragCerataul       Skeletonema   

   confidence  is_correct  
0    0.974654       False  
1    0.972635        True  
2    0.957793       False  
3    0.953981        True  
4    0.984811       F




In [8]:
# ---------------------------
# Step 6. Save results
# ---------------------------
df_ref_save = df_ref.drop(columns=['embedding'])
df_ref_save.to_csv("reference_dataset_improved.csv", index=False)

df_results_save = df_results.drop(columns=['top5_predictions', 'top5_scores'], errors='ignore')
df_results_save.to_csv("classification_results_improved.csv", index=False)

embeddings_array = np.stack([species_embeddings[s] for s in species_list])
np.save("species_embeddings_improved.npy", embeddings_array)
np.save("species_names.npy", np.array(species_list))

In [9]:
# ---------------------------
# Step 7. Analyze Confusion Patterns
# ---------------------------
print("\n" + "="*50)
print("CONFUSION ANALYSIS")
print("="*50)

# Show what the model confused with what
if len(df_results) > 0:
    wrong_preds = df_results[~df_results['is_correct']]
    if len(wrong_preds) > 0:
        print("\nMisclassifications:")
        for idx, row in wrong_preds.iterrows():
            print(f"  {row['query_image'][:30]:<30} → Predicted: {row['predicted_species']:<30} (should be: {row['true_species']})")
            print(f"    Top-5: {', '.join(row['top5_predictions'][:5])}")
        
        # Check if true label appears in top-5
        print(f"\n{len(wrong_preds)} misclassifications, but true label in top-5: {sum([row['true_species'] in row['top5_predictions'] for _, row in wrong_preds.iterrows()])} times")


CONFUSION ANALYSIS

Misclassifications:
  IFCB5_2014_248_004113_05455.pn → Predicted: Rhizosolenia                   (should be: DactFragCerataul)
    Top-5: Rhizosolenia, Pseudonitzschia, G_delicatula_parasite, DactFragCerataul, G_delicatula_external_parasite
  IFCB5_2014_315_144452_00017.pn → Predicted: Skeletonema                    (should be: DactFragCerataul)
    Top-5: Skeletonema, DactFragCerataul, Cerataulina, Leptocylindrus, Amphidinium_sp
  IFCB5_2014_315_144452_01507.pn → Predicted: Skeletonema                    (should be: DactFragCerataul)
    Top-5: Skeletonema, Cerataulina, mix_elongated, Leptocylindrus, spore
  IFCB5_2014_259_120213_05242.pn → Predicted: Rhizosolenia                   (should be: DactFragCerataul)
    Top-5: Rhizosolenia, DactFragCerataul, Pseudonitzschia, G_delicatula_parasite, G_delicatula_external_parasite
  IFCB5_2014_002_210221_01583.pn → Predicted: Rhizosolenia                   (should be: DactFragCerataul)
    Top-5: Rhizosolenia, Pseudonitzs

In [10]:
# ---------------------------
# Step 8. Test on Multiple Species
# ---------------------------
print("\n" + "="*50)
print("TESTING ON MULTIPLE SPECIES")
print("="*50)

# Test on first 5 species to get better accuracy estimate
all_test_results = []

for test_idx, species_folder in enumerate(species_folders[:5]):
    test_images = list(species_folder.glob("*.png"))[:10]
    
    if not test_images:
        continue
    
    print(f"\nTesting {species_folder.name} ({len(test_images)} images)...")
    
    for img_path in test_images:
        try:
            img = Image.open(img_path).convert("RGB")
            
            if CONFIG['image_preprocessing']:
                img = preprocess_plankton_image(img)
            
            if CONFIG['augment_query']:
                augmented = apply_augmentation(img)
                embs = batch_image_embeddings(augmented)
                q_emb = np.mean(embs, axis=0)
                q_emb = q_emb / np.linalg.norm(q_emb)
            else:
                q_emb = image_embedding(img)
            
            similarities = ref_matrix @ q_emb
            best_idx = np.argmax(similarities)
            best_species = species_list[best_idx]
            
            top5_indices = np.argsort(similarities)[-5:][::-1]
            top5_species = [species_list[i] for i in top5_indices]
            
            all_test_results.append({
                "true_species": species_folder.name,
                "predicted_species": best_species,
                "is_correct": best_species == species_folder.name,
                "in_top5": species_folder.name in top5_species
            })
            
        except Exception as e:
            continue

if all_test_results:
    df_all = pd.DataFrame(all_test_results)
    overall_acc = df_all['is_correct'].mean()
    overall_top5 = df_all['in_top5'].mean()
    
    print(f"\n{'='*50}")
    print(f"OVERALL PERFORMANCE (5 species × 10 images)")
    print(f"{'='*50}")
    print(f"Top-1 Accuracy: {overall_acc:.2%}")
    print(f"Top-5 Accuracy: {overall_top5:.2%}")
    
    # Per-species breakdown
    print("\nPer-species accuracy:")
    for species in df_all['true_species'].unique():
        species_data = df_all[df_all['true_species'] == species]
        species_acc = species_data['is_correct'].mean()
        print(f"  {species:<30}: {species_acc:.1%} ({species_data['is_correct'].sum()}/{len(species_data)})")

print("\n" + "="*50)
print("Pipeline complete!")


TESTING ON MULTIPLE SPECIES

Testing DactFragCerataul (10 images)...

Testing Rhizosolenia (10 images)...

Testing Chaetoceros (10 images)...

Testing bead (10 images)...

Testing G_delicatula_external_parasite (9 images)...

OVERALL PERFORMANCE (5 species × 10 images)
Top-1 Accuracy: 61.22%
Top-5 Accuracy: 81.63%

Per-species accuracy:
  DactFragCerataul              : 30.0% (3/10)
  Rhizosolenia                  : 90.0% (9/10)
  Chaetoceros                   : 50.0% (5/10)
  bead                          : 90.0% (9/10)
  G_delicatula_external_parasite: 44.4% (4/9)

Pipeline complete!


In [13]:
# ---------------------------
# Step 9. Analyze Relationship Between Sample Size and Accuracy
# ---------------------------
print("\n" + "="*50)
print("SAMPLE SIZE vs ACCURACY ANALYSIS")
print("="*50)

if all_test_results:
    # Merge test results with reference dataset info
    species_performance = []
    
    for species in df_all['true_species'].unique():
        # Get accuracy for this species
        species_data = df_all[df_all['true_species'] == species]
        species_acc = species_data['is_correct'].mean()
        species_top5 = species_data['in_top5'].mean()
        
        # Get reference data for this species
        ref_info = df_ref[df_ref['species'] == species]
        
        if len(ref_info) > 0:
            species_performance.append({
                'species': species,
                'total_images_available': ref_info.iloc[0]['total_images_available'],
                'num_reference_images_used': ref_info.iloc[0]['num_reference_images'],
                'num_embeddings_used': ref_info.iloc[0]['num_embeddings_used'],
                'top1_accuracy': species_acc,
                'top5_accuracy': species_top5,
                'num_test_images': len(species_data)
            })
    
    df_performance = pd.DataFrame(species_performance)
    df_performance = df_performance.sort_values('total_images_available')
    
    print("\nSpecies Performance by Sample Size:")
    print(df_performance.to_string(index=False))
    
    # Calculate correlation
    if len(df_performance) > 1:
        # Handle NaN values in correlation calculation
        if df_performance['top1_accuracy'].std() > 0 and df_performance['total_images_available'].std() > 0:
            correlation = df_performance[['total_images_available', 'top1_accuracy']].corr().iloc[0, 1]
            if not np.isnan(correlation):
                print(f"\nCorrelation between sample size and accuracy: {correlation:.3f}")
            else:
                print("\nCorrelation could not be calculated (insufficient variation in data)")
        else:
            print("\nCorrelation could not be calculated (no variation in accuracy or sample size)")
        
        # Categorize by sample size
        print("\nAccuracy by Sample Size Category:")
        
        df_performance['size_category'] = pd.cut(
            df_performance['total_images_available'],
            bins=[0, 10, 50, 200, float('inf')],
            labels=['Very Low (1-10)', 'Low (11-50)', 'Medium (51-200)', 'High (200+)']
        )
        
        category_stats = df_performance.groupby('size_category', observed=True).agg({
            'top1_accuracy': ['mean', 'std', 'count'],
            'top5_accuracy': ['mean']
        }).round(3)
        
        print(category_stats)
        
        # Find problematic species (low accuracy despite high samples)
        print("\nSpecies with surprisingly low accuracy (>50 images but <50% accuracy):")
        problematic = df_performance[
            (df_performance['total_images_available'] > 50) & 
            (df_performance['top1_accuracy'] < 0.5)
        ]
        if len(problematic) > 0:
            print(problematic[['species', 'total_images_available', 'num_reference_images_used', 'top1_accuracy']].to_string(index=False))
        else:
            print("  None found - good!")
        
        # Find high performers with low samples
        print("\nSpecies with high accuracy despite low samples (<20 images but >60% accuracy):")
        high_performers = df_performance[
            (df_performance['total_images_available'] < 20) & 
            (df_performance['top1_accuracy'] > 0.6)
        ]
        if len(high_performers) > 0:
            print(high_performers[['species', 'total_images_available', 'num_reference_images_used', 'top1_accuracy']].to_string(index=False))
        else:
            print("  None found")


SAMPLE SIZE vs ACCURACY ANALYSIS

Species Performance by Sample Size:
                       species  total_images_available  num_reference_images_used  num_embeddings_used  top1_accuracy  top5_accuracy  num_test_images
G_delicatula_external_parasite                       9                          9                   54       0.444444       0.888889                9
                          bead                      17                         17                  102       0.900000       0.900000               10
              DactFragCerataul                     175                         50                  300       0.300000       0.800000               10
                   Chaetoceros                    1871                         50                  300       0.500000       0.600000               10
                  Rhizosolenia                    2199                         50                  300       0.900000       0.900000               10

Correlation between sample s

  has_large_values = (abs_vals > 1e6).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
