# Pipeline 2: Multi-Class Disease Classification

This pipeline builds a YOLO-format dataset for multi-class disease classification. Healthy samples are excluded.

Goal: balance every disease class to roughly 1,000 samples so that no single class dominates training.

Note: This model should only run on samples classified as 'disease' by Pipeline 1.

## Step 1: Configuration Setup

In [None]:
import os
from pathlib import Path
from dotenv import load_dotenv
import pandas as pd
import numpy as np
import re
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Pull variables from a .env file into the process environment
# (dataset paths and the plant species list are read from it below).
load_dotenv()

# Locate the project root. When run as a script, this file sits two levels
# below it; in a notebook there is no __file__, so the root is taken to be
# the parent of the current working directory.
if '__file__' not in globals():
    PROJECT_ROOT = Path.cwd().parent
else:
    PROJECT_ROOT = Path(__file__).parent.parent

def make_absolute(path_str):
    """Return *path_str* as a resolved absolute path.

    Absolute inputs are only resolved; relative inputs (as typically written
    in .env) are interpreted relative to PROJECT_ROOT first.
    """
    candidate = Path(path_str)
    if not candidate.is_absolute():
        candidate = PROJECT_ROOT / candidate
    return candidate.resolve()

def _require_env(name):
    """Return the value of environment variable *name*, failing loudly if unset.

    Without this check a missing key reaches ``Path(None)`` or ``None.split``
    and fails with an unhelpful TypeError/AttributeError.

    Raises:
        RuntimeError: if the variable is unset or blank.
    """
    value = os.getenv(name)
    if value is None or not value.strip():
        raise RuntimeError(
            f"Environment variable {name!r} is not set; add it to your .env file."
        )
    return value


# Dataset paths (relative values from .env are resolved against PROJECT_ROOT)
TRAIN_LABELS_CSV = make_absolute(_require_env('TRAIN_LABELS_CSV'))
TEST_LABELS_CSV = make_absolute(_require_env('TEST_LABELS_CSV'))
TRAIN_IMAGES_DIR = make_absolute(_require_env('TRAIN_IMAGES_DIR'))
TEST_IMAGES_DIR = make_absolute(_require_env('TEST_IMAGES_DIR'))

# Output paths for DISEASE classification: images/ and labels/, each split
# into train/ and val/
OUTPUT_DISEASE_BASE_DIR = PROJECT_ROOT / 'dataset' / 'diseases'
OUTPUT_DISEASE_IMAGES_TRAIN = OUTPUT_DISEASE_BASE_DIR / 'images' / 'train'
OUTPUT_DISEASE_IMAGES_VAL = OUTPUT_DISEASE_BASE_DIR / 'images' / 'val'
OUTPUT_DISEASE_LABELS_TRAIN = OUTPUT_DISEASE_BASE_DIR / 'labels' / 'train'
OUTPUT_DISEASE_LABELS_VAL = OUTPUT_DISEASE_BASE_DIR / 'labels' / 'val'

# Plant species, comma-separated in .env. Blank entries (e.g. from a trailing
# comma) are dropped so that no empty species name reaches the regex-based
# species matching in Step 3.
PLANT_SPECIES = [s.strip() for s in _require_env('PLANT_SPECIES').split(',') if s.strip()]

# Desired number of annotation rows per disease class after balancing
TARGET_SAMPLES_PER_CLASS = 1000

print("✓ Configuration loaded!")
print(f"\nProject root: {PROJECT_ROOT}")
print(f"Output: {OUTPUT_DISEASE_BASE_DIR}")
print(f"Target samples per class: {TARGET_SAMPLES_PER_CLASS}")

## Step 2: Load and Clean Data

In [None]:
# Load data
df_train = pd.read_csv(TRAIN_LABELS_CSV)
df_test = pd.read_csv(TEST_LABELS_CSV)

print(f"Loaded: {len(df_train)} train, {len(df_test)} test samples")

# Clean class names: turn underscores into spaces, drop the word "leaf" (any
# case), then collapse runs of whitespace and trim.
# Underscores must be converted BEFORE the whitespace collapse. Otherwise
# "Septoria_leaf_spot" would become "Septoria  spot" (double space) while
# "Septoria leaf spot" becomes "Septoria spot", splitting one disease into
# two different classes.
for df in [df_train, df_test]:
    df['class'] = (
        df['class']
        .str.replace('_', ' ', regex=False)
        .str.replace(r'(?i)leaf', '', regex=True)
        .str.replace(r'\s+', ' ', regex=True)
        .str.strip()
    )

print("✓ Class names cleaned")

## Step 3: Extract Features

In [None]:
def extract_species(text, species=None):
    """Return the first plant species whose name appears as a whole word in *text*.

    Matching is case-insensitive. Species names are regex-escaped, so names
    containing punctuation such as "Corn (maize)" match literally.

    Args:
        text: A cleaned class name, e.g. "Tomato Early blight".
        species: Candidate species names, checked in order. Defaults to
            PLANT_SPECIES.

    Returns:
        The matching entry from *species* (original capitalization), or None
        if *text* is not a string (e.g. NaN) or no species matches.
    """
    if not isinstance(text, str):
        return None
    candidates = PLANT_SPECIES if species is None else species
    for plant in candidates:
        # Look-arounds instead of \b: \b needs a word character on the inner
        # side, so it never matches after a name ending in ")" or similar.
        pattern = rf"(?<!\w){re.escape(plant)}(?!\w)"
        if re.search(pattern, text, flags=re.IGNORECASE):
            return plant
    return None

def extract_disease(text, species=None):
    """Remove every plant species name from *text* and return what remains.

    Args:
        text: A cleaned class name, e.g. "Tomato Early blight".
        species: Species names to strip (whole words, case-insensitive).
            Defaults to PLANT_SPECIES.

    Returns:
        The remaining disease name with whitespace normalized, "healthy" when
        nothing remains (the class was just a species name), or None when
        *text* is not a string (e.g. NaN).
    """
    if not isinstance(text, str):
        return None
    candidates = PLANT_SPECIES if species is None else species
    for plant in candidates:
        # Escape the name and use look-arounds so names with punctuation
        # (e.g. "Corn (maize)") are removed literally.
        pattern = rf"(?<!\w){re.escape(plant)}(?!\w)"
        text = re.sub(pattern, "", text, flags=re.IGNORECASE)
    # Removing a name from the middle of the string leaves a double space.
    text = re.sub(r"\s+", " ", text).strip()
    return text if text else "healthy"

# Derive 'species' and 'disease' columns from the cleaned class name
for frame in (df_train, df_test):
    frame['species'] = frame['class'].apply(extract_species)
    frame['disease'] = frame['class'].apply(extract_disease)

print("✓ Features extracted")
print(df_train[['class', 'species', 'disease']].head())

## Step 4: Fix Zero Dimensions

In [None]:
def fix_zero_dimensions(df, image_folder):
    """Fill in image sizes that the annotation table records as 0 or blank.

    For every row whose ``width`` or ``height`` is 0 or missing (NaN), both
    values are replaced with the pixel size read from
    ``image_folder / filename``. Rows whose image file does not exist are left
    unchanged.

    Args:
        df: Annotations with ``filename``, ``width`` and ``height`` columns.
            Modified in place.
        image_folder: Directory containing the images.

    Returns:
        *df* itself, so the call can be used in an assignment.
    """
    image_folder = Path(image_folder)
    needs_size = (
        df['width'].isna() | df['height'].isna()
        | (df['width'] == 0) | (df['height'] == 0)
    )
    # An image can have several annotation rows (one per box); open it once.
    sizes = {}
    for idx, filename in df.loc[needs_size, 'filename'].items():
        if filename not in sizes:
            image_path = image_folder / filename
            if image_path.exists():
                with Image.open(image_path) as img:
                    sizes[filename] = img.size
            else:
                sizes[filename] = None
        size = sizes[filename]
        if size is None:
            continue
        w, h = size
        df.at[idx, 'width'] = w
        df.at[idx, 'height'] = h
    return df

# Replace 0-sized image dimensions with the real sizes read from disk
df_train = fix_zero_dimensions(df_train, TRAIN_IMAGES_DIR)
df_test = fix_zero_dimensions(df_test, TEST_IMAGES_DIR)
print("✓ Dimensions fixed")

## Step 5: Verify Files Exist

In [None]:
def verify_files_exist(df, image_folder):
    """Return a copy of *df* restricted to rows whose image file exists.

    Args:
        df: Annotations with a ``filename`` column.
        image_folder: Directory the filenames are relative to.

    Returns:
        A new DataFrame with the same columns containing only rows where
        ``image_folder / filename`` exists. An empty input yields an empty
        frame with the same columns.
    """
    image_folder = Path(image_folder)
    # Check each distinct file once; images usually have several rows.
    exists = {
        name: (image_folder / name).exists() for name in df['filename'].unique()
    }
    keep = df['filename'].map(exists).astype(bool).to_numpy()
    # .loc with a boolean array selects rows even when df is empty; the
    # previous df[list_of_bools] turned into a column selection for [].
    return df.loc[keep].copy()

# Drop annotation rows whose image file is missing on disk
df_train = verify_files_exist(df_train, TRAIN_IMAGES_DIR)
df_test = verify_files_exist(df_test, TEST_IMAGES_DIR)
print(f"✓ Verified: {len(df_train)} train, {len(df_test)} test samples")

## Step 6: Filter Out Healthy Samples

In [None]:
# This model only handles diseased samples, so every row labelled 'healthy'
# is removed from the training table here.
is_healthy = df_train['disease'] == 'healthy'
df_diseases_only = df_train.loc[~is_healthy].copy()

n_before = len(df_train)
n_after = len(df_diseases_only)
print(f"Original training samples: {n_before}")
print(f"Disease samples only: {n_after}")
print(f"Removed healthy samples: {n_before - n_after}")

## Step 7: Remove Very Rare Diseases

In [None]:
# Diseases making up less than 0.1% of the disease-only rows are dropped
# outright instead of being duplicated up to the target later on.
rare_threshold = 0.001
disease_proportions = df_diseases_only['disease'].value_counts(normalize=True)
is_rare = disease_proportions < rare_threshold
rare_diseases = list(disease_proportions.index[is_rare])

print(f"Rare diseases (< {rare_threshold*100}%): {rare_diseases}")

keep_mask = ~df_diseases_only['disease'].isin(rare_diseases)
df_diseases_clean = df_diseases_only.loc[keep_mask].copy()

print(f"\nAfter removing rare diseases: {len(df_diseases_clean)} samples")

## Step 8: Analyze Disease Distribution

In [None]:
# Rows per disease, most frequent first
disease_counts = df_diseases_clean['disease'].value_counts()

print("Disease distribution (before balancing):")
for disease in disease_counts.index:
    count = disease_counts[disease]
    print(f"  {disease}: {count} samples")

# Bar chart of the counts, with the balancing target drawn as a dashed line
plt.figure(figsize=(14, 6))
ax = sns.barplot(x=disease_counts.index, y=disease_counts.values)
ax.axhline(y=TARGET_SAMPLES_PER_CLASS, color='r', linestyle='--', label=f'Target: {TARGET_SAMPLES_PER_CLASS}')
ax.set_title('Disease Distribution (Before Balancing)')
ax.set_xlabel('Disease')
ax.set_ylabel('Number of Samples')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend()
plt.tight_layout()
plt.show()

## Step 9: Balance Disease Classes (Duplicate Small Classes, Cap Large Ones)

In [None]:
def _split_by_image(group):
    """Return one sub-frame per image (all of its rows), in first-seen order."""
    return [rows for _, rows in group.groupby("filename", sort=False)]


def _duplicate_image(rows, copy_id):
    """Copy one image's rows under the new name ``<stem>_dup<copy_id><ext>``.

    ``source_filename`` still points at the real image on disk, so the export
    step can tell which file to copy.
    """
    source = Path(rows["source_filename"].iloc[0])
    dup = rows.copy()
    dup["filename"] = f"{source.stem}_dup{copy_id}{source.suffix}"
    return dup


# Counts here are annotation rows (one per bounding box), as in Step 8.
# Classes are balanced by whole images: all boxes of a disease in one image
# are duplicated or kept together, so no label file silently loses boxes.
# A class can therefore end slightly above the target when its last image
# contributes several boxes.
# NOTE(review): an image that also holds boxes of a *different* disease is
# handled per class, so a copy made here carries only this class's boxes.
balanced_dfs = []
next_copy_id = 0  # shared by all classes so duplicate filenames never collide
rng = np.random.default_rng(42)  # fixed seed: reproducible down-sampling

for disease, group in df_diseases_clean.groupby("disease"):
    group = group.assign(source_filename=group["filename"])
    n_samples = len(group)
    images = _split_by_image(group)

    if n_samples < TARGET_SAMPLES_PER_CLASS:
        # Keep every original row, then add copies of whole images, cycling
        # through them in order so the copies are spread evenly.
        balanced_dfs.append(group)
        n_total = n_samples
        n_copies = 0
        while n_total < TARGET_SAMPLES_PER_CLASS:
            dup = _duplicate_image(images[n_copies % len(images)], next_copy_id)
            balanced_dfs.append(dup)
            n_total += len(dup)
            n_copies += 1
            next_copy_id += 1
        print(f"\n{disease}: {n_samples} → {n_total} (added {n_total - n_samples} duplicate rows via {n_copies} image copies)")
    else:
        # Cap over-represented classes with a random subset of whole images
        # rather than the first N rows, which could split an image's boxes.
        n_total = 0
        for pos in rng.permutation(len(images)):
            if n_total >= TARGET_SAMPLES_PER_CLASS:
                break
            balanced_dfs.append(images[pos])
            n_total += len(images[pos])
        print(f"\n{disease}: {n_samples} → {n_total} (capped at target with randomly chosen whole images)")

# Combine all balanced data; pd.concat raises on an empty list.
if balanced_dfs:
    df_balanced = pd.concat(balanced_dfs, ignore_index=True)
else:
    df_balanced = df_diseases_clean.iloc[0:0].copy()

print(f"\n✓ Dataset balanced!")
print(f"  Total samples: {len(df_balanced)}")
print(f"  Total samples: {len(df_balanced)}")

## Step 10: Visualize Balanced Distribution

In [None]:
# Rows per disease after balancing, listed alphabetically
balanced_counts = df_balanced['disease'].value_counts()
sorted_diseases = sorted(balanced_counts.index)
sorted_counts = [balanced_counts[name] for name in sorted_diseases]

print("\nFinal disease distribution (after balancing):")
for disease, count in zip(sorted_diseases, sorted_counts):
    print(f"  {disease}: {count} samples")

# Bar chart in the same alphabetical order, target drawn as a dashed line
plt.figure(figsize=(14, 6))
ax = sns.barplot(x=sorted_diseases, y=sorted_counts)
ax.axhline(y=TARGET_SAMPLES_PER_CLASS, color='g', linestyle='--', label=f'Target: {TARGET_SAMPLES_PER_CLASS}')
ax.set_title('Disease Distribution (After Balancing)')
ax.set_xlabel('Disease')
ax.set_ylabel('Number of Samples')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend()
plt.tight_layout()
plt.show()

## Step 11: Create Disease-to-Index Mapping

In [None]:
# Class ids follow alphabetical disease order, so the same set of diseases
# always gets the same ids.
diseases = sorted(df_balanced['disease'].unique())
disease2idx = dict(zip(diseases, range(len(diseases))))

print(f"Disease to Index Mapping ({len(disease2idx)} classes):")
for disease, idx in disease2idx.items():
    count = int((df_balanced['disease'] == disease).sum())
    print(f"  {idx}: {disease} ({count} samples)")

## Step 12: Create Output Directories

In [None]:
# Create the images/{train,val} and labels/{train,val} output directories
for output_dir in (
    OUTPUT_DISEASE_IMAGES_TRAIN,
    OUTPUT_DISEASE_IMAGES_VAL,
    OUTPUT_DISEASE_LABELS_TRAIN,
    OUTPUT_DISEASE_LABELS_VAL,
):
    output_dir.mkdir(parents=True, exist_ok=True)

print("✓ Output directories created")
print(f"  {OUTPUT_DISEASE_BASE_DIR}")

## Step 13: Convert to YOLO Format and Export

In [None]:
def convert_bbox_to_yolo(row):
    """Convert one pixel-space bounding box to YOLO's normalized xywh format.

    Args:
        row: Mapping with ``xmin``, ``xmax``, ``ymin``, ``ymax`` (pixel
            corners) and ``width``, ``height`` (image size in pixels).

    Returns:
        ``(x_center, y_center, box_width, box_height)`` as fractions of the
        image size.

    Raises:
        ValueError: if the image width or height is not positive (or is NaN),
            because the box cannot be normalized.
    """
    img_w, img_h = row['width'], row['height']
    if not (img_w > 0 and img_h > 0):
        raise ValueError(f"invalid image size {img_w}x{img_h}")
    # Order each corner pair and clamp it to the image: annotations sometimes
    # extend past the border, and YOLO labels must lie within [0, 1].
    x0, x1 = (min(max(v, 0), img_w) for v in sorted((row['xmin'], row['xmax'])))
    y0, y1 = (min(max(v, 0), img_h) for v in sorted((row['ymin'], row['ymax'])))
    return (
        (x0 + x1) / 2 / img_w,
        (y0 + y1) / 2 / img_h,
        (x1 - x0) / img_w,
        (y1 - y0) / img_h,
    )


# Stem of a renamed duplicate, e.g. "IMG_12_dup3" -> "IMG_12"
_DUP_STEM = re.compile(r"(.+)_dup\d+")


def _resolve_source_image(images_dir, filename, group):
    """Return the path of the source image that output *filename* copies.

    A ``source_filename`` column, when present, names the real file. Otherwise
    *filename* itself is used if it exists, falling back to stripping a
    trailing ``_dup<N>`` from its stem (duplicates do not exist on disk).
    """
    images_dir = Path(images_dir)
    if 'source_filename' in group.columns:
        source = group['source_filename'].iloc[0]
        if isinstance(source, str) and source:
            return images_dir / source
    direct = images_dir / filename
    if direct.exists():
        return direct
    name = Path(filename)
    match = _DUP_STEM.fullmatch(name.stem)
    if match:
        return images_dir / (match.group(1) + name.suffix)
    return direct


def export_to_yolo(df, images_dir, output_images_dir, output_labels_dir, class_mapping):
    """Copy images and write one YOLO label file per image.

    Args:
        df: Annotation rows with ``filename``, ``disease`` and the bbox/size
            columns read by ``convert_bbox_to_yolo``. Rows sharing a filename
            go into the same label file.
        images_dir: Directory holding the source images.
        output_images_dir: Destination for images (created if missing).
        output_labels_dir: Destination for ``<stem>.txt`` labels (created if
            missing).
        class_mapping: Dict mapping disease name to class index.

    Returns:
        ``(exported, skipped)`` image counts. An image is skipped, with
        nothing written for it, when its source file is missing or any of its
        rows cannot be converted (unknown class, invalid size, I/O error).
    """
    output_images_dir = Path(output_images_dir)
    output_labels_dir = Path(output_labels_dir)
    output_images_dir.mkdir(parents=True, exist_ok=True)
    output_labels_dir.mkdir(parents=True, exist_ok=True)

    exported = 0
    skipped = 0
    for filename, group in df.groupby("filename"):
        try:
            src = _resolve_source_image(images_dir, filename, group)
            if not src.exists():
                skipped += 1
                continue

            # Build every label line before writing anything, so a bad row
            # cannot leave a copied image with a truncated label file.
            lines = []
            for _, row in group.iterrows():
                cls_idx = class_mapping[row['disease']]
                x_c, y_c, w, h = convert_bbox_to_yolo(row)
                lines.append(f"{cls_idx} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}\n")

            shutil.copy2(src, output_images_dir / filename)
            label_file = output_labels_dir / (Path(filename).stem + ".txt")
            with open(label_file, "w") as f:
                f.writelines(lines)
        except Exception as e:  # best-effort export: report and keep going
            print(f"Error: {filename} - {e}")
            skipped += 1
            continue
        exported += 1

    return exported, skipped

print("Exporting to YOLO format...")
exported, skipped = export_to_yolo(
    df_balanced,
    TRAIN_IMAGES_DIR,
    OUTPUT_DISEASE_IMAGES_TRAIN,
    OUTPUT_DISEASE_LABELS_TRAIN,
    disease2idx,
)

# Validation split: dataset.yaml points 'val' at images/val, which would
# otherwise stay empty. Populate it from the test annotations, keeping only
# diseases the model is trained on (healthy and rare classes are dropped),
# without any balancing.
# NOTE(review): this reuses the test split for validation; carve out a
# separate split if an untouched test set is needed.
df_val = df_test[df_test['disease'].isin(list(disease2idx))].copy()
val_exported, val_skipped = export_to_yolo(
    df_val,
    TEST_IMAGES_DIR,
    OUTPUT_DISEASE_IMAGES_VAL,
    OUTPUT_DISEASE_LABELS_VAL,
    disease2idx,
)

print(f"\n✓ Export complete!")
print(f"  Exported: {exported} images")
print(f"  Skipped: {skipped} images")
print(f"  Validation exported: {val_exported} images (skipped {val_skipped})")

## Step 14: Generate YAML Configuration

In [None]:
import yaml

# Dataset description: root path, split directories, class count and
# index -> name mapping.
yaml_content = {
    'path': str(OUTPUT_DISEASE_BASE_DIR.resolve()),
    'train': 'images/train',
    'val': 'images/val',
    'nc': len(disease2idx),
    'names': {idx: disease for disease, idx in disease2idx.items()}
}

yaml_path = OUTPUT_DISEASE_BASE_DIR / 'dataset.yaml'
# safe_dump emits only plain YAML types; explicit UTF-8 keeps non-ASCII
# disease names readable regardless of the platform's default encoding.
with open(yaml_path, 'w', encoding='utf-8') as f:
    yaml.safe_dump(yaml_content, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

print("✓ YAML configuration created")
print(f"\nLocation: {yaml_path}")
print(f"\nClasses ({len(disease2idx)}):")
for idx in sorted(yaml_content['names'].keys()):
    print(f"  {idx}: {yaml_content['names'][idx]}")

## Step 15: Summary

In [None]:
# Rows per class, computed once instead of re-filtering the frame per class
class_counts = df_balanced['disease'].value_counts()

print("=" * 60)
print("DISEASE CLASSIFICATION DATASET READY")
print("=" * 60)
print(f"\n📊 Dataset Statistics:")
print(f"  Total samples: {len(df_balanced)}")
print(f"  Number of disease classes: {len(disease2idx)}")
print(f"  Target per class: {TARGET_SAMPLES_PER_CLASS}")
print(f"\n  Class distribution:")
for disease in sorted(disease2idx.keys()):
    count = int(class_counts.get(disease, 0))
    print(f"    [{disease2idx[disease]}] {disease}: {count} samples")
print(f"\n📁 Location: {OUTPUT_DISEASE_BASE_DIR}")
print(f"📝 Config: {yaml_path}")
print(f"\n⚠️  Important: This model should only process samples")
print(f"   classified as 'disease' by the binary model (Pipeline 1)")
print(f"\n✅ Ready for YOLO training!")
print("=" * 60)