# 🌿 Mangrove Classifier Training Notebook

This notebook walks through training a deep learning model (ResNet50 with transfer learning, in PyTorch) that classifies satellite or ground images as **Mangrove** or **Non-Mangrove**.

## 📋 Table of Contents
1. [Setup & Dependencies](#setup)
2. [Data Exploration](#data-exploration)
3. [Model Training](#training)
4. [Model Testing & Evaluation](#evaluation)
5. [Next Steps & Usage](#usage)

## ⚙️ Setup & Dependencies

In [None]:
# Import required libraries
import sys
import os
sys.path.append('../src')

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Import our custom modules
import config
from utils import get_data_loaders, count_images, prepare_data_structure
from train import create_model, train_model
from predict import MangroveClassifier

print("✅ All libraries imported successfully!")
print(f"🖥️  PyTorch version: {torch.__version__}")
print(f"🎯 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")

## 📊 Data Exploration

Let's first check our data structure and see what we're working with.

In [None]:
# Check current data structure
print("🔍 Current Data Structure:")
print(f"📁 Data directory: {config.DATA_DIR}")
print(f"📁 Models directory: {config.MODEL_DIR}")

# Count images in each category
mangrove_count, non_mangrove_count = count_images(config.DATA_DIR)
print("\n📊 Data Summary:")
print(f"   🌿 Mangrove images: {mangrove_count}")
print(f"   🏞️  Non-mangrove images: {non_mangrove_count}")
print(f"   📈 Total images: {mangrove_count + non_mangrove_count}")

if mangrove_count == 0 and non_mangrove_count == 0:
    print("\n⚠️  No training data found!")
    print("Please add images to the following folders:")
    print("   - data/mangrove/ (for mangrove images)")
    print("   - data/non-mangrove/ (for non-mangrove images)")
    print("\n💡 You can download sample datasets from:")
    print("   - Kaggle mangrove datasets")
    print("   - Satellite image databases")
    print("   - Environmental monitoring websites")
elif mangrove_count == 0 or non_mangrove_count == 0:
    # A binary classifier cannot be trained from a single class
    print("\n⚠️  One class has no images. Add images for both classes before training.")
else:
    print("\n✅ Found images for both classes. Ready for training.")

    # Flag imbalance when the counts differ by more than 30% of the larger class
    if abs(mangrove_count - non_mangrove_count) > 0.3 * max(mangrove_count, non_mangrove_count):
        print("⚠️  Dataset appears imbalanced. Consider data augmentation.")
    else:
        print("✅ Dataset appears reasonably balanced.")
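
`count_images` is imported from the project's `utils` module, whose source is not part of this notebook. As a rough illustration only, a helper with the same call shape could look like the sketch below. The name `count_images_sketch`, the extension list, and the `mangrove/` and `non-mangrove/` subfolder names are assumptions taken from the other cells in this notebook, not the actual implementation.

```python
import os

# Assumed extension list, mirroring the filter used in the visualization cell
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def count_images_sketch(data_dir):
    """Return (mangrove_count, non_mangrove_count) for a data directory.

    Hypothetical stand-in for utils.count_images; the real helper may differ.
    """
    counts = []
    for class_name in ("mangrove", "non-mangrove"):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            counts.append(0)  # a missing folder counts as zero images
            continue
        counts.append(
            sum(1 for name in os.listdir(class_dir)
                if name.lower().endswith(IMAGE_EXTENSIONS))
        )
    return tuple(counts)
```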

In [None]:
# Sample image visualization (if data exists)
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

def show_sample_images(samples_per_class=4):
    class_dirs = [
        ("Mangrove", os.path.join(config.DATA_DIR, "mangrove")),
        ("Non-Mangrove", os.path.join(config.DATA_DIR, "non-mangrove")),
    ]

    fig, axes = plt.subplots(2, samples_per_class, figsize=(15, 8), squeeze=False)
    fig.suptitle('Sample Images from Dataset', fontsize=16)

    # Hide every axis up front so unused slots stay blank when a class has fewer images
    for ax in axes.flat:
        ax.axis('off')

    for row, (label, class_dir) in enumerate(class_dirs):
        if not os.path.exists(class_dir):
            continue
        files = sorted(f for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS))[:samples_per_class]
        for col, file in enumerate(files):
            # Open inside a context manager so the file handle is released after drawing
            with Image.open(os.path.join(class_dir, file)) as img:
                axes[row, col].imshow(img)
            axes[row, col].set_title(f'{label} {col + 1}')

    plt.tight_layout()
    plt.show()

# Only show samples if data exists
if mangrove_count > 0 or non_mangrove_count > 0:
    show_sample_images()
else:
    print("📷 No images to display. Please add training data first.")

## 🚀 Model Training

Now let's train our mangrove classification model using transfer learning with ResNet50.

In [None]:
# Display training configuration
print("⚙️  Training Configuration:")
print("   🏗️  Model: ResNet50 (Transfer Learning)")
print(f"   📊 Batch Size: {config.BATCH_SIZE}")
print(f"   🔄 Epochs: {config.EPOCHS}")
print(f"   📈 Learning Rate: {config.LEARNING_RATE}")
print(f"   🖼️  Image Size: {config.IMG_SIZE}")
print(f"   🎯 Classes: {config.CLASS_NAMES}")
print(f"   🔢 Number of Classes: {config.NUM_CLASSES}")

# Check if we can proceed with training (a binary classifier needs both classes)
if mangrove_count == 0 or non_mangrove_count == 0:
    print("\n❌ Cannot proceed with training - at least one class has no images!")
    print("Please add training images before running the next cell.")
else:
    print(f"\n✅ Ready to train with {mangrove_count + non_mangrove_count} total images!")
    print("Run the next cell to start training.")
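
The cell above reads its settings from the project's `config` module, which is not shown in this notebook. A hypothetical `config.py` exposing the same names could look like the sketch below. Every value is a placeholder chosen for illustration, not the project's real setting.

```python
import os

# Hypothetical config.py; all values below are illustrative placeholders
DATA_DIR = os.path.join("..", "data")      # expected to contain mangrove/ and non-mangrove/
MODEL_DIR = os.path.join("..", "models")   # where mangrove_model.pth would be saved

CLASS_NAMES = ["mangrove", "non-mangrove"]
NUM_CLASSES = len(CLASS_NAMES)             # derived so the two settings cannot drift apart

IMG_SIZE = 224        # a common input size for ResNet-style models
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-3
```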

In [None]:
# Train the model
# Note: training is skipped unless both classes have at least one image

if mangrove_count > 0 and non_mangrove_count > 0:
    print("🌿 Starting training...")
    try:
        # Train the model using our training script
        trained_model = train_model()
        print("🎉 Training completed successfully!")

    except Exception as e:
        print(f"❌ Training failed: {e}")
        print("Common issues:")
        print("- Insufficient training data")
        print("- CUDA out of memory (try reducing batch size)")
        print("- Corrupted image files")

else:
    print("⏸️  Skipping training - images are needed for both classes")
    print("Please add images to data/mangrove/ and data/non-mangrove/ folders")
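
Corrupted image files are listed above as a common cause of training failures. One way to find them before training is to ask Pillow to verify each file. The sketch below assumes the flat `data/<class>/` layout used in this notebook; `find_unreadable_images` is a hypothetical helper name.

```python
import os
from PIL import Image


def find_unreadable_images(class_dir, extensions=(".png", ".jpg", ".jpeg")):
    """Return paths under class_dir that Pillow cannot open and verify."""
    bad_paths = []
    for name in sorted(os.listdir(class_dir)):
        if not name.lower().endswith(extensions):
            continue
        path = os.path.join(class_dir, name)
        try:
            with Image.open(path) as img:
                img.verify()  # integrity check that does not decode the full image
        except Exception:
            bad_paths.append(path)
    return bad_paths
```

`verify()` catches many but not all problems. A truncated file can still fail later when it is fully decoded, so treat an empty result as encouraging rather than conclusive.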

## 🧪 Model Testing & Evaluation

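Let's load the trained model and run it on a few sample images. The samples come from the same `data/` folders that feed training, so treat the output as a quick sanity check rather than a measure of accuracy on unseen images.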
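
For a performance estimate on unseen images, a portion of each class can be set aside before training. The sketch below splits a list of file paths reproducibly. The function name, the 20% test fraction, and the fixed seed are illustrative assumptions; whether `train_model` already holds out a validation split is not shown in this notebook.

```python
import random


def split_holdout(paths, test_fraction=0.2, seed=42):
    """Shuffle paths reproducibly and return (train_paths, test_paths)."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    shuffled = sorted(paths)               # sort first so the result does not depend on input order
    random.Random(seed).shuffle(shuffled)  # local RNG leaves the global random state untouched
    # Keep at least one test item whenever there is any data at all
    n_test = max(1, round(len(shuffled) * test_fraction)) if shuffled else 0
    return shuffled[n_test:], shuffled[:n_test]
```

Applying it to each class folder separately keeps both classes represented in the held-out set.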

In [None]:
# Test the trained model
model_path = os.path.join(config.MODEL_DIR, "mangrove_model.pth")

if os.path.exists(model_path):
    print("🎯 Testing the trained model...")
    
    try:
        # Initialize the classifier
        classifier = MangroveClassifier(model_path)
        print("✅ Model loaded successfully!")
        
        # Test on sample images from the data folders
        test_images = []
        
        # Get sample images for testing
        mangrove_dir = os.path.join(config.DATA_DIR, "mangrove")
        non_mangrove_dir = os.path.join(config.DATA_DIR, "non-mangrove")
        
        if os.path.exists(mangrove_dir):
            mangrove_files = [os.path.join(mangrove_dir, f) for f in os.listdir(mangrove_dir) 
                             if f.lower().endswith(('.png', '.jpg', '.jpeg'))][:3]
            test_images.extend(mangrove_files)
        
        if os.path.exists(non_mangrove_dir):
            non_mangrove_files = [os.path.join(non_mangrove_dir, f) for f in os.listdir(non_mangrove_dir) 
                                 if f.lower().endswith(('.png', '.jpg', '.jpeg'))][:3]
            test_images.extend(non_mangrove_files)
        
        if test_images:
            print(f"🖼️  Testing on {len(test_images)} sample images...")
            predictions = classifier.predict_batch(test_images)
            
            # Display results
            for pred in predictions:
                if 'error' not in pred:
                    print(f"📸 {pred['image']}: {pred['prediction']} (confidence: {pred['confidence']:.2%})")
                else:
                    print(f"❌ {pred['image']}: Error - {pred['error']}")
        else:
            print("📷 No test images available in data folders")
            
    except Exception as e:
        print(f"❌ Error testing model: {e}")
        
else:
    print("❌ No trained model found!")
    print(f"Expected model at: {model_path}")
    print("Please train the model first by running the training cell above.")
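
The loop above prints individual predictions. To summarize results on labeled images, predicted labels can be compared with the true ones. The sketch below computes accuracy plus precision and recall for one class in plain Python. `summarize_predictions` is a hypothetical helper, and the label strings are assumed to match whatever `classifier.predict_batch` returns in the `prediction` field.

```python
def summarize_predictions(true_labels, predicted_labels, positive_label):
    """Compute accuracy, plus precision and recall for positive_label."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label lists must have the same length")
    pairs = list(zip(true_labels, predicted_labels))

    # Count outcomes with respect to the positive class
    tp = sum(1 for t, p in pairs if t == positive_label and p == positive_label)
    fp = sum(1 for t, p in pairs if t != positive_label and p == positive_label)
    fn = sum(1 for t, p in pairs if t == positive_label and p != positive_label)
    correct = sum(1 for t, p in pairs if t == p)

    # Fall back to 0.0 when a denominator would be zero
    return {
        "accuracy": correct / len(pairs) if pairs else 0.0,
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }
```

The `confusion_matrix` and `classification_report` functions imported from scikit-learn at the top of this notebook give a fuller version of the same summary.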

## 🎯 Next Steps & Usage

### Using the Trained Model

Once training is complete, you can use the model in several ways:

#### 1. **Command Line Prediction**
```bash
cd src
python predict.py path/to/your/image.jpg
```

#### 2. **Python Script Integration**
```python
from predict import MangroveClassifier

classifier = MangroveClassifier()
prediction, confidence = classifier.predict("image.jpg", return_confidence=True)
print(f"Prediction: {prediction} (confidence: {confidence:.2%})")
```

#### 3. **Web Application Integration**
The model can be integrated into your web application using the server components.

### 📁 Expected Data Structure
For training, organize your data like this:
```
data/
├── mangrove/          # Put mangrove images here
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── non-mangrove/      # Put non-mangrove images here
    ├── image1.jpg
    ├── image2.jpg
    └── ...
```
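
The layout above can be created with a few lines of Python. This is a sketch that assumes the `data/mangrove` and `data/non-mangrove` names shown in the tree. The notebook also imports `prepare_data_structure` from `utils`, which may already do this, but its source is not shown here.

```python
from pathlib import Path


def make_data_folders(root="data", class_names=("mangrove", "non-mangrove")):
    """Create one subfolder per class under root and return the folder paths."""
    created = []
    for class_name in class_names:
        class_dir = Path(root) / class_name
        class_dir.mkdir(parents=True, exist_ok=True)  # no error if it already exists
        created.append(class_dir)
    return created
```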

### 💡 Tips for Better Results
1. **More Data**: Aim for at least 100 images per class; 500 or more is better
2. **Balanced Dataset**: Keep similar numbers of mangrove/non-mangrove images
3. **Quality Images**: Use clear, high-resolution satellite or ground images
4. **Variety**: Include different seasons, lighting conditions, and angles
5. **Labeling**: Check that every image is in the correct class folder before training
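
When tip 2 is hard to satisfy by collecting more images, a common alternative is to weight the loss by inverse class frequency so that the rarer class counts more. The sketch below computes such weights in plain Python. The helper name and the example counts are made up for illustration.

```python
def inverse_frequency_weights(class_counts):
    """Return one weight per class: total / (num_classes * count)."""
    if any(count <= 0 for count in class_counts):
        raise ValueError("every class needs at least one image")
    total = sum(class_counts)
    num_classes = len(class_counts)
    return [total / (num_classes * count) for count in class_counts]


# Illustrative counts only: 300 mangrove images, 100 non-mangrove images
weights = inverse_frequency_weights([300, 100])
print(weights)  # the rarer class receives the larger weight
```

In PyTorch these values can be passed as `nn.CrossEntropyLoss(weight=torch.tensor(weights))`, with the order matching the model's class indices.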