<a href="https://colab.research.google.com/github/EricBaidoo/GhanaSegNet/blob/main/GhanaSegNet_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🎯 Enhanced GhanaSegNet - 30% mIoU Training Notebook

**Objective:** Train the enhanced GhanaSegNet architecture, targeting 30% validation mIoU

**Enhanced Features:**
- 12-head transformer (384 dimensions)
- 384-channel ASPP module
- Progressive training (256→320→384px)
- Multi-scale supervision
- Advanced loss functions (Dice + Boundary + Focal + CE)

**Models:** UNet, DeepLabV3+, SegFormer-B0, **Enhanced GhanaSegNet**

In [None]:
# Mount Google Drive so the dataset copy step further down can read from it
from google.colab import drive
drive.mount('/content/drive')

# Report which accelerator this runtime provides before doing any heavy work
import torch

cuda_ok = torch.cuda.is_available()
print(f"🚀 CUDA available: {cuda_ok}")
if not cuda_ok:
    print("❌ No GPU detected - switch to GPU runtime!")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"🔥 GPU: {gpu_name}")
    print(f"📊 GPU Memory: {gpu_mem_gb:.1f} GB")

In [None]:
# Clone your GitHub repo
!git clone https://github.com/EricBaidoo/GhanaSegNet.git
%cd GhanaSegNet

# Check if we have the expected files
!ls -la
print("\n✅ Repository cloned successfully!")

## 📁 Dataset Connection Instructions

**Before running the next cell:**

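1. **Locate your data folder in Google Drive** - Find where you uploaded your `data` folder
2. **Check the path** - Note the exact path relative to your Drive root (e.g., `MyDrive/data` or `MyDrive/GhanaSegNet/data`). Once Drive is mounted it lives under `/content/drive/`, so `MyDrive/data` becomes `/content/drive/MyDrive/data`
3. **Update the copy command** - Change the source path in the next cell to match your Drive structure
4. **Run the cell** - The dataset will be copied into the cloned repository's working directory (`/content/GhanaSegNet/data`)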

**Expected folder structure after copying:**
```
data/
  train/
    images/
    masks/
  val/
    images/
    masks/
  test/ (optional)
    images/
    masks/
```

In [None]:
# Copy the dataset from Google Drive into the current working directory (the
# cloned repo), producing ./data.
# MODIFY THIS PATH to match your Google Drive structure
import os
import subprocess

DRIVE_DATA_DIR = "/content/drive/MyDrive/data"

if not os.path.isdir(DRIVE_DATA_DIR):
    print(f"❌ {DRIVE_DATA_DIR} not found - check that Drive is mounted and fix DRIVE_DATA_DIR above")
else:
    # Capture stderr: a plain subprocess's output would not appear in the cell.
    copy_result = subprocess.run(['cp', '-r', DRIVE_DATA_DIR, '.'], capture_output=True, text=True)
    if copy_result.returncode == 0:
        print("✅ Dataset copied from Google Drive!")
    else:
        print(f"❌ Copy failed (cp exit code {copy_result.returncode}):")
        print(copy_result.stderr)

In [None]:
# Verify the dataset copy: list data/ and count the files in each split folder.
# Counting is done in Python so that a missing folder is actually reported. A
# shell `ls DIR | wc -l || echo ...` pipeline cannot do that, because a
# pipeline's exit status is that of `wc`, which always succeeds.
import os


def count_visible_files(folder):
    """Return the number of non-hidden entries in `folder` (what `ls | wc -l`
    counts), or None if `folder` does not exist."""
    if not os.path.isdir(folder):
        return None
    return sum(1 for name in os.listdir(folder) if not name.startswith('.'))


print("🔍 Checking dataset structure...")
if os.path.isdir('data'):
    print(sorted(os.listdir('data')))
else:
    print("❌ data/ not found - run the copy cell above first")

print("\n📊 Dataset statistics:")
for label, folder in [
    ("Train images", "data/train/images"),
    ("Train masks", "data/train/masks"),
    ("Val images", "data/val/images"),
    ("Val masks", "data/val/masks"),
]:
    n_files = count_visible_files(folder)
    if n_files is None:
        print(f"{label}: not found ({folder})")
    else:
        print(f"{label}: {n_files}")

In [None]:
# 🛠️ Setup and Dependencies
# Install all required dependencies
print("🔧 Installing PyTorch and dependencies...")
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install opencv-python pillow tqdm matplotlib seaborn
!pip install efficientnet-pytorch  # Required for GhanaSegNet backbone
!pip install segmentation-models-pytorch  # For DeepLabV3+ and other models

import torch
import os
print(f"\n✅ CUDA available: {torch.cuda.is_available()}")
print(f"✅ GPU device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}")
print(f"✅ PyTorch version: {torch.__version__}")

# Verify EfficientNet installation
try:
    from efficientnet_pytorch import EfficientNet
    print("✅ EfficientNet-PyTorch installed successfully")
except ImportError:
    print("⚠️ EfficientNet-PyTorch not found - installing...")
    !pip install efficientnet-pytorch
    from efficientnet_pytorch import EfficientNet
    print("✅ EfficientNet-PyTorch installed successfully")

print("\n🎯 All dependencies installed! Ready for enhanced training!")

In [None]:
# 🔍 Verify Environment Setup
# Make the cloned repo the working directory and confirm its core modules import
# and the model can be constructed.
import os

os.chdir('/content/GhanaSegNet')
print(f"📁 Current directory: {os.getcwd()}")
print(f"📂 Files: {os.listdir('.')[:10]}")  # Show first 10 files

# Verify all required modules can be imported
try:
    from models.ghanasegnet import GhanaSegNet
    from utils.losses import CombinedLoss
    from data.dataset_loader import GhanaFoodDataset
    print("\n✅ All GhanaSegNet modules imported successfully!")

    # Quick model test: report the measured parameter count rather than a
    # hard-coded figure, so architecture changes show up here.
    model = GhanaSegNet(num_classes=6)
    params = sum(p.numel() for p in model.parameters())
    print(f"✅ Enhanced GhanaSegNet: {params:,} parameters ({params / 1e6:.1f}M)")
    # Design summary only - this line is not verified against the model.
    print("✅ Architecture: 12-head transformer, 384-channel ASPP, multi-scale supervision")

except Exception as e:
    # Covers both import failures and errors while constructing the model.
    print(f"❌ Setup error ({type(e).__name__}): {e}")
    print("Please check that all files are properly uploaded!")

In [None]:
# 🩺 Quick Training Environment Diagnostic
# Read-only checks to run before training: interpreter, dataset folders, a
# model forward pass, and importability of the training script's helpers.
import os
import sys
import traceback

import torch

print("🩺 ENHANCED TRAINING DIAGNOSTIC CHECK")
print("="*45)

# Check Python path and imports
print(f"🐍 Python executable: {sys.executable}")
print(f"🐍 Python version: {sys.version}")

# Check dataset availability
print(f"\n📁 Current working directory: {os.getcwd()}")
print(f"📊 Data directory exists: {os.path.exists('data')}")

if os.path.exists('data'):
    print(f"📊 Data contents: {os.listdir('data')}")
    if os.path.exists('data/train'):
        train_images = len(os.listdir('data/train/images')) if os.path.exists('data/train/images') else 0
        train_masks = len(os.listdir('data/train/masks')) if os.path.exists('data/train/masks') else 0
        print(f"🖼️ Training images: {train_images}")
        print(f"🎭 Training masks: {train_masks}")

# Quick model instantiation test
print("\n🤖 Testing model creation...")
try:
    from models.ghanasegnet import GhanaSegNet
    model = GhanaSegNet(num_classes=6)

    # Smoke-test in eval mode. With a batch of one, train-mode BatchNorm raises
    # ("Expected more than 1 value per channel when training") in any branch
    # that pools to 1x1 (e.g. an ASPP image-pooling branch, if the model has
    # one), which would be a false alarm for this check.
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    test_input = torch.randn(1, 3, 256, 256, device=device)

    with torch.no_grad():
        output = model(test_input)

    # The model may return several outputs (multi-scale supervision); report the first.
    main_output = output[0] if isinstance(output, (tuple, list)) else output
    print("✅ Model forward pass successful!")
    print(f"✅ Output shape: {main_output.shape}")

except Exception as e:
    print(f"❌ Model test failed: {e}")
    print(f"📋 Full traceback:\n{traceback.format_exc()}")

# Check training script
print("\n📜 Training script check:")
script_exists = os.path.exists('scripts/train_baselines.py')
print(f"✅ Script exists: {script_exists}")

if script_exists:
    # Test import of training functions
    try:
        if '.' not in sys.path:  # avoid growing sys.path on every re-run
            sys.path.append('.')
        from scripts.train_baselines import train_model, get_model_and_criterion
        print("✅ Training functions importable")

        # Test model creation through training script
        model, criterion = get_model_and_criterion('ghanasegnet', num_classes=6)
        print("✅ Model creation via training script successful")

    except Exception as e:
        print(f"❌ Training script import failed: {e}")
        print(f"📋 Traceback:\n{traceback.format_exc()}")

print("="*45)
print("🎯 Diagnostic complete! Check above for any issues before training.")

In [None]:
# 🛠️ Quick Fixes for Common Training Issues
# Idempotent environment prep before training: import paths, checkpoint
# folders, a GPU-memory batch-size hint, and a count check on the training
# images/masks.
import os
import sys

RULE = "=" * 30

print("🛠️ APPLYING QUICK FIXES")
print(RULE)

# Fix 1: put the repo on the import path. Each missing entry is inserted at the
# front, in this order, exactly like two sequential sys.path.insert(0, ...) calls.
for path_entry, note in (
    ('.', "✅ Added current directory to Python path"),
    ('/content/GhanaSegNet', "✅ Added GhanaSegNet to Python path"),
):
    if path_entry in sys.path:
        continue
    sys.path.insert(0, path_entry)
    print(note)

# Fix 2: checkpoint folders (exist_ok keeps this safe to re-run)
for ckpt_dir in ('checkpoints', 'checkpoints/ghanasegnet'):
    os.makedirs(ckpt_dir, exist_ok=True)
print("✅ Ensured checkpoints directories exist")

# Fix 3: GPU memory -> batch-size hint (8 GB threshold)
import torch

if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"🔍 GPU Memory: {gpu_memory:.1f} GB")
    if gpu_memory >= 8:
        print("✅ Sufficient GPU memory for training")
    else:
        print("⚠️ Limited GPU memory detected. Consider using smaller batch sizes.")
        print("💡 Recommended: --batch-size 2 or --batch-size 1")

# Fix 4: both training folders must exist and hold equally many image files
image_dir, mask_dir = 'data/train/images', 'data/train/masks'
if not (os.path.exists(image_dir) and os.path.exists(mask_dir)):
    print("❌ Training data directories not found")
    print("💡 Make sure to run the dataset copy cell first")
else:
    image_suffixes = ('.png', '.jpg', '.jpeg')
    train_imgs = sum(1 for name in os.listdir(image_dir) if name.lower().endswith(image_suffixes))
    train_masks = sum(1 for name in os.listdir(mask_dir) if name.lower().endswith(image_suffixes))

    print(f"📊 Dataset check: {train_imgs} images, {train_masks} masks")

    if not train_imgs or not train_masks:
        print("❌ No training data found!")
        print("💡 Please check your data path in the previous cell")
    elif train_imgs != train_masks:
        print("⚠️ Mismatch between images and masks count")
    else:
        print("✅ Dataset appears valid")

print(RULE)
print("🎯 Quick fixes applied! Ready for training.")

In [None]:
# 🚀 Enhanced GhanaSegNet Training (30% mIoU Target) - WITH ERROR DEBUGGING
# Runs scripts/train_baselines.py for GhanaSegNet in a subprocess and prints its
# captured stdout/stderr once it finishes (nothing is shown while it runs).
import os
import subprocess
import sys

TRAIN_SCRIPT = 'scripts/train_baselines.py'
TRAIN_EPOCHS = 5        # short debug run; raise it for the full progressive schedule
TRAIN_TIMEOUT_S = 3600  # 1 hour


def to_text(stream):
    """Return captured process output as str.

    Normal `subprocess.run(..., text=True)` results are already str, but
    `subprocess.TimeoutExpired` stores captured output as bytes regardless of
    `text=True`. `None` (nothing captured) becomes ''.
    """
    if stream is None:
        return ''
    if isinstance(stream, bytes):
        return stream.decode(errors='replace')
    return stream


print("🎯 STARTING ENHANCED GHANASEGNET TRAINING")
print("="*50)
print("🔥 TARGET: 30% mIoU")
print("⚡ FEATURES: Progressive training, 12-head transformer, enhanced loss")
print("📈 STRATEGY: 256px→320px→384px resolution scaling")
print("="*50)

# First, check that the training script exists and is accessible
print(f"\n📁 Current directory: {os.getcwd()}")
script_found = os.path.exists(TRAIN_SCRIPT)
print(f"📂 Training script exists: {script_found}")

if not script_found:
    print("❌ Training script not found! Let's check available files:")
    print("Scripts directory:", os.listdir('scripts/') if os.path.exists('scripts/') else "Scripts folder not found")
    print("Root directory:", os.listdir('.')[:10])
    print("⏭️ Skipping training - fix the repository checkout first.")
else:
    print("✅ Training script found!")
    print("\n🔧 Running training with detailed error capture...")
    try:
        # sys.executable runs the script with this kernel's interpreter and packages.
        result = subprocess.run(
            [sys.executable, TRAIN_SCRIPT, '--model', 'ghanasegnet', '--epochs', str(TRAIN_EPOCHS)],
            capture_output=True,
            text=True,
            timeout=TRAIN_TIMEOUT_S,
        )

        print("📤 STDOUT:")
        print(result.stdout)

        if result.stderr:
            print("\n📥 STDERR:")
            print(result.stderr)

        print(f"\n📊 Return code: {result.returncode}")

        if result.returncode == 0:
            print("\n🏆 ENHANCED GHANASEGNET TRAINING COMPLETED!")
            print("📊 Check results in checkpoints/ghanasegnet/ folder")
        else:
            print(f"\n❌ ERROR: Enhanced GhanaSegNet training failed with return code {result.returncode}!")
            print("🔍 Common issues to check:")
            print("   1. Dataset path correct?")
            print("   2. Sufficient GPU memory?")
            print("   3. All dependencies installed?")
            print("   4. Data format compatible?")

    except subprocess.TimeoutExpired as e:
        print(f"⏰ Training timed out after {TRAIN_TIMEOUT_S // 60} minutes")
        # Show whatever the run produced before it was killed.
        partial_out, partial_err = to_text(e.stdout), to_text(e.stderr)
        if partial_out:
            print("📤 Partial STDOUT:")
            print(partial_out)
        if partial_err:
            print("\n📥 Partial STDERR:")
            print(partial_err)
    except Exception as e:
        print(f"💥 Unexpected error: {e}")
        print(f"Error type: {type(e).__name__}")

print("\n" + "="*50)

In [None]:
# 🔧 Fallback: Simple GhanaSegNet Training (if enhanced fails)
# Opt-in, shorter debug run with a small batch size; disabled by default.
import subprocess
import sys

print("🔧 FALLBACK TRAINING OPTION")
print("="*40)
print("This cell provides a simpler training approach if the enhanced training fails")
print("Run this ONLY if the enhanced training above failed")
print("="*40)

# Simpler training with fewer epochs for debugging
fallback_training = False  # Change to True if you want to run this

if fallback_training:
    print("🚀 Starting simplified GhanaSegNet training...")

    try:
        # sys.executable runs the script with this kernel's interpreter and packages.
        result = subprocess.run([
            sys.executable, 'scripts/train_baselines.py',
            '--model', 'ghanasegnet',
            '--epochs', '3',  # Very short for debugging
            '--batch-size', '4'  # Small batch size (NOTE(review): confirm the script accepts --batch-size)
        ], capture_output=True, text=True, timeout=1800)  # 30 min timeout

        print("📤 Training output:")
        print(result.stdout)

        if result.stderr:
            print("\n🔍 Error details:")
            print(result.stderr)

        if result.returncode == 0:
            print("✅ Fallback training completed successfully!")
        else:
            print(f"❌ Fallback training also failed (code: {result.returncode})")

    except subprocess.TimeoutExpired:
        print("⏰ Fallback training timed out after 30 minutes")
    except Exception as e:
        print(f"💥 Fallback training error: {e}")

else:
    print("ℹ️ Fallback training not activated. Set fallback_training=True to run.")
    print("💡 Try the enhanced training first - this is just a backup option.")

In [None]:
# 📊 Load and Display Enhanced GhanaSegNet Results
# Reads the training summary JSON written by the training script and plots the
# per-epoch mIoU and loss curves when a history is present.
import json

import matplotlib.pyplot as plt

RESULTS_FILE = 'checkpoints/ghanasegnet/training_results.json'
TARGET_MIOU = 0.30
# Epochs at which progressive training is expected to switch resolution.
# NOTE(review): assumed to match the schedule in scripts/train_baselines.py - confirm.
# A marker is only drawn if the run reached that epoch, so short debug runs are
# not stretched out on the x-axis.
RESOLUTION_SWITCHES = [(15, '256→320px'), (30, '320→384px')]

try:
    # Load enhanced GhanaSegNet results
    with open(RESULTS_FILE, 'r') as f:
        results = json.load(f)

    print("🎯 ENHANCED GHANASEGNET RESULTS")
    print("="*40)
    best_iou = results['best_val_iou']
    final_iou = results['final_val_iou']
    # If the script did not record the flag, compare against the target directly.
    target_hit = results.get('target_achieved', best_iou >= TARGET_MIOU)
    print(f"🏆 Best Validation mIoU: {best_iou:.4f} ({best_iou*100:.2f}%)")
    print(f"📈 Final Validation mIoU: {final_iou:.4f} ({final_iou*100:.2f}%)")
    print(f"🎯 30% Target Achieved: {'✅ YES!' if target_hit else '❌ Not quite'}")

    # Plot training progress
    if 'training_history' in results:
        history = results['training_history']
        epochs = list(range(1, len(history) + 1))

        fig, (ax_iou, ax_loss) = plt.subplots(1, 2, figsize=(12, 6))

        ax_iou.plot(epochs, [h['train_iou'] * 100 for h in history], 'b-', label='Training mIoU', linewidth=2)
        ax_iou.plot(epochs, [h['val_iou'] * 100 for h in history], 'r-', label='Validation mIoU', linewidth=2)
        ax_iou.axhline(y=TARGET_MIOU * 100, color='g', linestyle='--', alpha=0.7, label='30% Target')
        # Phase markers are drawn before legend() so they appear in it.
        for switch_epoch, switch_label in RESOLUTION_SWITCHES:
            if switch_epoch <= len(history):
                ax_iou.axvline(x=switch_epoch, color='orange', linestyle=':', alpha=0.5, label=switch_label)
        ax_iou.set(xlabel='Epoch', ylabel='mIoU (%)', title='Enhanced GhanaSegNet Training Progress')
        ax_iou.legend()
        ax_iou.grid(True, alpha=0.3)

        ax_loss.plot(epochs, [h['train_loss'] for h in history], 'b-', label='Training Loss', linewidth=2)
        ax_loss.plot(epochs, [h['val_loss'] for h in history], 'r-', label='Validation Loss', linewidth=2)
        ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Enhanced GhanaSegNet Loss Progress')
        ax_loss.legend()
        ax_loss.grid(True, alpha=0.3)

        fig.tight_layout()
        plt.show()

    print("\n📈 Training completed with enhanced architecture!")

except FileNotFoundError:
    print("❌ Results file not found. Please run training first.")
except Exception as e:
    print(f"❌ Error loading results: {e}")

In [None]:
# 🔄 Optional: Train Baseline Models for Comparison
# Opt-in: trains each baseline through scripts/train_baselines.py and streams
# its output into this cell as it runs.
import subprocess
import sys

# Set to True to train the baselines. A flag (rather than an input() prompt)
# keeps "Run All" from blocking on user input and records the choice in the
# notebook, like the fallback_training switch.
train_baselines = False
baseline_models = ['unet', 'deeplabv3plus', 'segformer']
# NOTE(review): the GhanaSegNet training cell uses its own epoch count; give
# every model the same budget before comparing results.
BASELINE_EPOCHS = 15


def run_and_stream(cmd):
    """Run `cmd`, echo its combined stdout/stderr line by line, return the exit code.

    A plain subprocess writes to the kernel process's own stdout rather than to
    the notebook cell, so output is piped and re-printed here.
    """
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
        for line in proc.stdout:
            print(line, end='')
        return proc.wait()


print("🔄 TRAINING BASELINE MODELS FOR COMPARISON")
print("(This step is optional - you can skip if you only want GhanaSegNet results)")

if train_baselines:
    failed_models = []
    # `baseline` (not `model`) so the model object from earlier cells is not clobbered.
    for baseline in baseline_models:
        print(f"\n🏃‍♂️ Training {baseline.upper()}...")
        # -u: unbuffered child output so progress lines arrive as they are printed.
        returncode = run_and_stream([sys.executable, '-u', 'scripts/train_baselines.py',
                                     '--model', baseline, '--epochs', str(BASELINE_EPOCHS)])

        if returncode == 0:
            print(f"✅ {baseline.upper()} training completed!")
        else:
            failed_models.append(baseline)
            print(f"❌ {baseline.upper()} training failed!")

    if failed_models:
        print(f"\n⚠️ Baseline training finished with failures: {', '.join(m.upper() for m in failed_models)}")
    else:
        print("\n🏆 All baseline training completed!")
else:
    print("⏭️ Skipping baseline model training (set train_baselines = True to run)")

In [None]:
# 📊 Final Results Comparison (if baselines were trained)
# Collects each model's training_results.json, prints a summary table, charts
# best validation mIoU when more than one model has results, and checks
# GhanaSegNet against the 30% target.
import json
import os

import matplotlib.pyplot as plt

TARGET_MIOU_PCT = 30.0
# Named `model_keys` (not `models`) so the repo's `models` package name is not
# shadowed in the notebook namespace.
model_keys = ['ghanasegnet', 'unet', 'deeplabv3plus', 'segformer']
model_names = ['Enhanced GhanaSegNet', 'UNet', 'DeepLabV3+', 'SegFormer-B0']
results_data = {}

print("📊 FINAL RESULTS SUMMARY")
print("="*50)

for key, display_name in zip(model_keys, model_names):
    results_file = f'checkpoints/{key}/training_results.json'
    if not os.path.exists(results_file):
        print(f"❌ {display_name:20}: Not trained")
        continue
    try:
        with open(results_file, 'r') as f:
            data = json.load(f)
        results_data[key] = {
            'name': display_name,
            'best_iou': data['best_val_iou'],
            'final_iou': data['final_val_iou'],
        }
        print(f"✅ {display_name:20}: {data['best_val_iou']*100:.2f}% mIoU")
    except (OSError, ValueError, KeyError, TypeError) as e:
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt.
        # ValueError covers json.JSONDecodeError and bad numeric formatting.
        print(f"⚠️ {display_name:20}: Results file corrupted ({type(e).__name__}: {e})")

# Create comparison chart if we have multiple results
if len(results_data) > 1:
    fig, ax = plt.subplots(figsize=(12, 6))

    names = [entry['name'] for entry in results_data.values()]
    ious = [entry['best_iou'] * 100 for entry in results_data.values()]

    colors = ['#ff6b6b', '#4ecdc4', '#45b7d1', '#96ceb4']
    bars = ax.bar(names, ious, color=colors[:len(names)])

    # Add 30% target line
    ax.axhline(y=TARGET_MIOU_PCT, color='red', linestyle='--', alpha=0.7, linewidth=2, label='30% Target')

    # Add value labels on bars
    for bar, iou in zip(bars, ious):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
                f'{iou:.2f}%', ha='center', va='bottom', fontweight='bold')

    ax.set_ylabel('Best Validation mIoU (%)')
    ax.set_title('Model Performance Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    plt.show()

    # Find best performing model
    best_key, best_entry = max(results_data.items(), key=lambda item: item[1]['best_iou'])
    print(f"\n🏆 BEST PERFORMER: {best_entry['name']} with {best_entry['best_iou']*100:.2f}% mIoU")

# Target check runs even when GhanaSegNet is the only model with results.
if 'ghanasegnet' in results_data:
    ghanasegnet_iou = results_data['ghanasegnet']['best_iou'] * 100
    if ghanasegnet_iou >= TARGET_MIOU_PCT:
        print("🎯 SUCCESS! Enhanced GhanaSegNet achieved the 30% mIoU target!")
    else:
        gap = TARGET_MIOU_PCT - ghanasegnet_iou
        print(f"📈 Enhanced GhanaSegNet achieved {ghanasegnet_iou:.2f}% - {gap:.2f} points below the 30% target")

print("\n🎉 Training and evaluation completed!")