<a href="https://colab.research.google.com/github/EricBaidoo/GhanaSegNet/blob/main/notebooks/Enhanced_GhanaSegNet_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Enhanced GhanaSegNet - 30% mIoU Target

**Objective:** Train the Enhanced GhanaSegNet architecture to a validation mIoU of at least 30%

**Model:** Enhanced GhanaSegNet (FPN + Advanced ASPP + Cross-Attention)
- **Parameters:** 10.5M
- **Architecture:** EfficientNet-B0 + FPN + Enhanced ASPP + Cross-Attention Transformer
- **Target Performance:** 30% mIoU
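Every cell below is judged against mIoU, so here is a minimal sketch of how mean IoU is commonly computed from a confusion matrix. It assumes integer label maps with values in `0..num_classes-1` and skips classes that appear in neither prediction nor ground truth. The repository's own `utils.metrics.compute_iou` may average differently (for example per batch, or with background handled separately), so treat this as illustrative only.

```python
import numpy as np

def confusion_matrix(pred, target, num_classes):
    """num_classes x num_classes counts: rows = ground truth, columns = prediction."""
    pred = np.asarray(pred, dtype=np.int64).ravel()
    target = np.asarray(target, dtype=np.int64).ravel()
    idx = target * num_classes + pred
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def mean_iou(pred, target, num_classes):
    """Per-class IoU = TP / (TP + FP + FN); mIoU averages over classes that occur."""
    cm = confusion_matrix(pred, target, num_classes)
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # pixels predicted as class c that belong elsewhere
    fn = cm.sum(axis=1) - tp  # pixels of class c predicted as something else
    denom = tp + fp + fn
    valid = denom > 0         # ignore classes absent from both pred and target
    iou = np.zeros(num_classes)
    iou[valid] = tp[valid] / denom[valid]
    return iou[valid].mean(), iou
```

Under this definition, the 30% target means `mean_iou(pred, target, 6)[0] >= 0.30` on the validation set.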

In [None]:
# Mount Google Drive and check GPU
from google.colab import drive
drive.mount('/content/drive')

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️  No GPU detected - switch to GPU runtime for training!")

In [None]:
# Clone repository
!git clone https://github.com/EricBaidoo/GhanaSegNet.git
%cd GhanaSegNet

# Verify repository structure
print("Repository contents:")
!ls -la

## Dataset Setup

**Before running the next cell:**

1. **Upload your dataset to Google Drive** in this structure:
   ```
   MyDrive/
     dataset/
       train/
         images/
         masks/
       val/
         images/
         masks/
   ```

2. **Update the path below** if your dataset is in a different location
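Before copying, it can help to confirm that every image has a matching mask. The sketch below is a hypothetical helper (`check_split` is not part of the repository) and assumes that a mask shares its image's filename stem (e.g. `001.jpg` and `001.png`) and that files use the extensions listed in `IMAGE_EXTS`; adjust both to your data.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumption: change to match your files

def check_split(split_dir):
    """Compare image and mask filename stems in <split_dir>/images and <split_dir>/masks."""
    split_dir = Path(split_dir)
    images = {p.stem for p in (split_dir / "images").iterdir() if p.suffix.lower() in IMAGE_EXTS}
    masks = {p.stem for p in (split_dir / "masks").iterdir() if p.suffix.lower() in IMAGE_EXTS}
    return {
        "images": len(images),
        "masks": len(masks),
        "missing_masks": sorted(images - masks),   # images without a mask
        "missing_images": sorted(masks - images),  # masks without an image
    }
```

For example, `check_split("/content/drive/MyDrive/dataset/train")` should report empty `missing_masks` and `missing_images` lists for a clean split.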

In [None]:
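# Copy dataset from Google Drive into the repository's data/ folder
DATASET_PATH = "/content/drive/MyDrive/dataset"
LOCAL_DATA_PATH = "/content/GhanaSegNet/data"

print(f"Copying dataset from: {DATASET_PATH}")
# Copy the *contents* of DATASET_PATH so train/ and val/ land directly under data/,
# even if data/ already exists (`cp -r src data` would then create data/dataset/)
!mkdir -p "{LOCAL_DATA_PATH}"
!cp -r "{DATASET_PATH}/." "{LOCAL_DATA_PATH}/"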

# Verify dataset structure
print("\n📊 Dataset verification:")
!echo "Train images:" && ls data/train/images/ | wc -l
!echo "Train masks:" && ls data/train/masks/ | wc -l
!echo "Val images:" && ls data/val/images/ | wc -l
!echo "Val masks:" && ls data/val/masks/ | wc -l

In [None]:
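# Install dependencies
print("🔧 Installing dependencies...")
# Colab already ships a CUDA-enabled PyTorch. Because torch was imported in the first cell,
# reinstalling it here would not take effect until the runtime is restarted. Only uncomment
# this if your runtime lacks a suitable version, then restart the runtime:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install efficientnet-pytorch opencv-python pillow tqdm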

# Verify installations
import torch
from efficientnet_pytorch import EfficientNet
print(f"✅ PyTorch: {torch.__version__}")
print(f"✅ CUDA: {torch.cuda.is_available()}")
print("✅ EfficientNet installed")

In [None]:
# Verify Enhanced GhanaSegNet can be imported
import os
os.chdir('/content/GhanaSegNet')

try:
    from models.ghanasegnet import GhanaSegNet
    from utils.losses import CombinedLoss
    from utils.metrics import compute_iou
    
    # Test model creation
    model = GhanaSegNet(num_classes=6)
    total_params = sum(p.numel() for p in model.parameters())
    
    print("✅ Enhanced GhanaSegNet imported successfully")
    print(f"✅ Model parameters: {total_params:,} (Target: ~10.5M)")
    print("✅ Enhanced loss function ready")
    print("✅ All systems ready for 30% mIoU training!")
    
except ImportError as e:
    print(f"❌ Import error: {e}")

In [None]:
# Set up a helper that copies results to Google Drive (called after training)
import os
import shutil

# Create results directory
RESULTS_DIR = '/content/drive/MyDrive/Enhanced_GhanaSegNet_Results'
os.makedirs(RESULTS_DIR, exist_ok=True)

def save_results():
    """Save training results to Google Drive"""
    if os.path.exists('checkpoints/ghanasegnet'):
        shutil.copytree('checkpoints/ghanasegnet', f'{RESULTS_DIR}/checkpoints', dirs_exist_ok=True)
        print(f"✅ Results saved to: {RESULTS_DIR}")
    else:
        print("❌ No results to save")

print(f"📁 Auto-save configured to: {RESULTS_DIR}")
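`save_results()` always copies into the same `checkpoints` folder with `dirs_exist_ok=True`, so a later run overwrites same-named files from an earlier one. If you want to keep every run, one option is to copy into a timestamped subfolder. `save_results_versioned` below is a hypothetical helper, not part of the repository; two calls within the same second would target the same folder and raise `FileExistsError`.

```python
import os
import shutil
from datetime import datetime

def save_results_versioned(src="checkpoints/ghanasegnet",
                           results_dir="/content/drive/MyDrive/Enhanced_GhanaSegNet_Results"):
    """Copy src into a new run_<timestamp> folder under results_dir; return its path or None."""
    if not os.path.isdir(src):
        print("❌ No results to save")
        return None
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    dest = os.path.join(results_dir, f"run_{stamp}")
    shutil.copytree(src, dest)  # dest must not exist, so earlier runs are never overwritten
    print(f"✅ Results saved to: {dest}")
    return dest
```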

In [None]:
# Train Enhanced GhanaSegNet - Quick Test (15 epochs)
from scripts.train_baselines import train_model

# Training configuration for 30% mIoU target
config = {
    'epochs': 15,
    'batch_size': 8,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'num_classes': 6,
    'custom_seed': 789,  # Enhanced GhanaSegNet seed
    'benchmark_mode': True,
    'dataset_path': 'data',
    'device': 'cuda',
    'timestamp': '2025-10-12',
    'note': 'Enhanced GhanaSegNet - 30% mIoU Target'
}

print("🚀 Starting Enhanced GhanaSegNet Training...")
print("🎯 Target: 30% mIoU")
print(f"📋 Config: {config}")
print("=" * 60)

try:
    result = train_model('ghanasegnet', config)
    
    # Display results
    best_iou = result['best_iou']
    print("=" * 60)
    print("🎯 TRAINING COMPLETED!")
    print(f"📊 Best IoU: {best_iou:.4f} ({best_iou*100:.2f}%)")
    
    if best_iou >= 0.30:
        print("🏆 30% mIoU TARGET ACHIEVED! 🎉")
    elif best_iou >= 0.25:
        print(f"📈 Strong Progress! {(best_iou*100):.1f}% (Target: 30%)")
    else:
        print(f"📊 Current: {(best_iou*100):.1f}% (Target: 30%) - Consider full training")
    
    # Auto-save results
    save_results()
    
except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()
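The config passes `'custom_seed': 789`, but how `train_model` applies it is not shown in this notebook. A typical seeding helper looks like the sketch below; `set_seed` is hypothetical, and the deterministic cuDNN settings trade speed for repeatability (how they interact with the repository's `benchmark_mode` flag is not documented here).

```python
import random

import numpy as np

def set_seed(seed=789):
    """Seed Python, NumPy and (if installed) PyTorch RNGs for more repeatable runs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # PyTorch not installed: only Python and NumPy are seeded
```

Even with fixed seeds, some GPU kernels are nondeterministic, so repeated runs may still differ slightly in IoU.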

## Full Training (80+ epochs)

If the quick test looks promising, run full training by setting `'epochs': 80` in the config above and re-running the training cell.

**Approximate timeline** (varies with GPU type and dataset size):
- 15 epochs: ~10-15 minutes (quick validation)
- 80 epochs: ~45-60 minutes (full training)
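Instead of editing the quick-test cell, you could derive a second config from it. The sketch also extrapolates full-run time from a measured quick-test time, assuming every epoch takes roughly the same time (validation and checkpointing overhead can break this). `make_full_config` and `estimate_minutes` are hypothetical helpers; the dict below is a trimmed stand-in for the notebook's `config`.

```python
import copy

def make_full_config(quick_config, epochs=80):
    """Copy the quick-test config and raise the epoch count; the original dict is unchanged."""
    full = copy.deepcopy(quick_config)
    full["epochs"] = epochs
    return full

def estimate_minutes(measured_minutes, measured_epochs, target_epochs):
    """Linear extrapolation of wall-clock time from a shorter run."""
    return measured_minutes * target_epochs / measured_epochs

quick = {"epochs": 15, "batch_size": 8, "learning_rate": 1e-4}
full = make_full_config(quick)
print(full["epochs"], quick["epochs"])  # 80 15
print(estimate_minutes(12, 15, 80))     # 64.0
```

In the notebook this would be used as `train_model('ghanasegnet', make_full_config(config))`.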

In [None]:
# Load and analyze results
import json
import os

if os.path.exists('checkpoints/ghanasegnet/training_history.json'):
    with open('checkpoints/ghanasegnet/training_history.json', 'r') as f:
        history = json.load(f)
    
    best = max(history, key=lambda x: x['val_iou'])
    print("📈 Training History:")
    print(f"📊 Final IoU: {history[-1]['val_iou']:.4f}")
    print(f"📊 Final Accuracy: {history[-1]['val_accuracy']:.4f}")
    print(f"📊 Best Epoch: {best['epoch']} (IoU: {best['val_iou']:.4f})")
    
    # Plot training curves if possible
    try:
        import matplotlib.pyplot as plt
        
        epochs = [h['epoch'] for h in history]
        val_iou = [h['val_iou'] for h in history]
        train_loss = [h['train_loss'] for h in history]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        ax1.plot(epochs, val_iou, 'b-', label='Validation IoU')
        ax1.axhline(y=0.30, color='r', linestyle='--', label='30% Target')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('IoU')
        ax1.set_title('Enhanced GhanaSegNet - IoU Progress')
        ax1.legend()
        ax1.grid(True)
        
        ax2.plot(epochs, train_loss, 'g-', label='Training Loss')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss')
        ax2.set_title('Training Loss')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.show()
        
    except ImportError:
        print("📊 Install matplotlib for training curves: !pip install matplotlib")
        
else:
    print("❌ No training history found")
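To check when (and whether) the 30% target was reached, the history list can be summarized in one pass. This is a sketch that assumes each entry has the `epoch` and `val_iou` keys used in the cell above; the `demo` values are toy numbers, not real results.

```python
def summarize_history(history, target=0.30):
    """Summarize a list of per-epoch dicts with 'epoch' and 'val_iou' keys."""
    best = max(history, key=lambda h: h["val_iou"])
    reached = [h["epoch"] for h in history if h["val_iou"] >= target]
    return {
        "best_epoch": best["epoch"],
        "best_iou": best["val_iou"],
        "final_iou": history[-1]["val_iou"],
        "first_epoch_at_target": reached[0] if reached else None,
    }

# Toy values for illustration only
demo = [{"epoch": 1, "val_iou": 0.12}, {"epoch": 2, "val_iou": 0.31}, {"epoch": 3, "val_iou": 0.29}]
print(summarize_history(demo))
```

A large gap between `best_iou` and `final_iou` suggests keeping the best checkpoint rather than the last one.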

## Results Summary

**Enhanced GhanaSegNet Architecture:**
- **Backbone:** EfficientNet-B0 (pretrained)
- **Decoder:** FPN-style multi-scale fusion
- **ASPP:** Advanced with 4 dilation rates [2,4,8,16]
- **Attention:** Cross-attention transformer (8+4 heads)
- **Loss:** Multi-scale supervision + Dice + Focal + Boundary
- **Parameters:** ~10.5M
- **Target:** 30% mIoU
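The dilation rates determine how much context each ASPP branch sees. For a k×k convolution with dilation d, the covered span is k + (k - 1)(d - 1) feature-map pixels. The sketch assumes 3×3 atrous branches, which the summary does not state explicitly.

```python
def effective_kernel(k, d):
    """Span in feature-map pixels covered by one k x k convolution with dilation d."""
    return k + (k - 1) * (d - 1)

DILATIONS = [2, 4, 8, 16]  # from the summary above
KERNEL = 3                 # assumption: 3x3 atrous branches

for d in DILATIONS:
    print(f"dilation {d:2d}: {effective_kernel(KERNEL, d)} px")
```

With 3×3 kernels this gives spans of 5, 9, 17 and 33 feature-map pixels, so the branches cover small to large context at the same resolution.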

**Key Innovations:**
1. Multi-scale feature pyramid network
2. Cross-scale attention mechanism
3. Enhanced ASPP with depth-wise convolutions
4. Multi-scale auxiliary supervision
5. Class-balanced loss for food segmentation
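The summary mentions a class-balanced loss but not the weighting scheme. One common choice is inverse pixel frequency. The sketch below computes such weights from label masks, giving absent classes weight 0 and normalizing present classes to a mean of 1. This is an assumption about the approach, not the repository's confirmed method; weights like these could be passed as `weight=` to `torch.nn.CrossEntropyLoss`.

```python
import numpy as np

def class_weights_from_masks(masks, num_classes=6):
    """Inverse-frequency class weights from integer label masks."""
    counts = np.zeros(num_classes)
    for m in masks:
        flat = np.asarray(m, dtype=np.int64).ravel()
        counts += np.bincount(flat, minlength=num_classes)[:num_classes]
    weights = np.zeros(num_classes)
    present = counts > 0                 # classes with no pixels keep weight 0
    freq = counts[present] / counts.sum()
    weights[present] = 1.0 / freq        # rarer classes get larger weights
    weights[present] /= weights[present].mean()
    return weights
```

For a single mask with 3, 2 and 1 pixels of classes 0, 1 and 2, this yields weights 6/11, 9/11 and 18/11, so the rarest class counts three times as much as the most common one.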