<a href="https://colab.research.google.com/github/EricBaidoo/GhanaSegNet/blob/main/notebooks/Enhanced_GhanaSegNet_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🎯 Enhanced GhanaSegNet - Food Segmentation

**Objective:** Train Enhanced GhanaSegNet to achieve 30% mIoU for food segmentation

## 📋 **Project Overview**
- **Model**: Enhanced GhanaSegNet with FPN + Advanced ASPP + Multi-Head Attention
- **Parameters**: ~10.5M
- **Backbone**: EfficientNet-B0
- **Target**: 30% mIoU (improvement over 24.37% baseline)
- **Dataset**: Ghana Food Segmentation Dataset

## 📚 **Notebook Structure**
1. **Setup & Environment** - Dependencies, paths, verification
2. **Data Loading** - Dataset preparation and loaders
3. **Model Architecture** - Enhanced GhanaSegNet implementation
4. **Training Pipeline** - AdamW with cosine warm-restart scheduling and best-checkpoint saving
5. **Evaluation & Results** - Performance analysis
6. **Test-Time Augmentation** - Optional flip and multi-scale prediction averaging at evaluation time

---

In [None]:
# ========================================
# 1️⃣ SETUP & ENVIRONMENT
# ========================================

# Google Drive is mounted only when the google.colab module is already loaded
# in this interpreter; otherwise the cell just reports a local run.
import sys
import os

if 'google.colab' not in sys.modules:
    print("📍 Running locally")
else:
    print("🔗 Mounting Google Drive...")
    from google.colab import drive
    drive.mount('/content/drive')
    print("✅ Google Drive mounted successfully!")

In [None]:
# Clone repository (if needed)
if 'google.colab' in sys.modules and not os.path.exists('/content/GhanaSegNet'):
    print("📥 Cloning GhanaSegNet repository...")
    !git clone https://github.com/EricBaidoo/GhanaSegNet.git /content/GhanaSegNet
    print("✅ Repository cloned successfully!")
    %cd /content/GhanaSegNet

In [None]:
# 🔧 ENVIRONMENT VERIFICATION & SETUP
#
# Selects the torch device, resolves PROJECT_ROOT / DATA_PATH for a Colab or
# local run, appends the project root to sys.path, and reports which expected
# repository files and dataset split directories exist.

import torch
import numpy as np
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

print("🔍 SYSTEM VERIFICATION")
print("=" * 50)

# Device selection: CUDA when torch reports it as available, CPU otherwise
has_cuda = torch.cuda.is_available()
if not has_cuda:
    print("⚠️  CUDA not available - using CPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"✅ CUDA available: {gpu_name}")
    print(f"   Memory: {gpu_total_bytes / 1e9:.1f} GB")
device = torch.device('cuda' if has_cuda else 'cpu')

# Path layout depends on whether the google.colab module is loaded
in_colab = 'google.colab' in sys.modules
PROJECT_ROOT = '/content/GhanaSegNet' if in_colab else os.getcwd()
DATA_PATH = '/content/drive/MyDrive/data' if in_colab else 'data'
print("✅ Running in Google Colab" if in_colab else "📍 Running locally")

# Make the project's packages importable from later cells
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

print(f"📁 Project root: {PROJECT_ROOT}")
print(f"📁 Data path: {DATA_PATH}")

# Repository files that later cells import from
key_files = [
    'models/ghanasegnet.py',
    'utils/losses.py',
    'utils/metrics.py',
    'data/dataset_loader.py',
]

# Report each expected file and remember the ones that are absent
missing_files = []
for rel_path in key_files:
    if not os.path.exists(os.path.join(PROJECT_ROOT, rel_path)):
        print(f"❌ {rel_path} - MISSING!")
        missing_files.append(rel_path)
        continue
    print(f"✅ {rel_path}")

# Dataset directory and its train/val sub-directories
if not os.path.exists(DATA_PATH):
    print(f"⚠️  Dataset not found at: {DATA_PATH}")
else:
    print("✅ Dataset directory found")
    for split_dir, split_message in (
        ('train', "✅ Train split available"),
        ('val', "✅ Validation split available"),
    ):
        if os.path.exists(os.path.join(DATA_PATH, split_dir)):
            print(split_message)

# Overall verdict is based on the repository files only
if missing_files:
    print("\n⚠️  Some files missing - check repository structure")
else:
    print("\n🎉 SETUP COMPLETE - Ready to proceed!")

print("=" * 50)

---
# ========================================
# 2️⃣ DATA LOADING & PREPARATION
# ========================================

In [None]:
# 📊 DATASET LOADING
#
# Builds the train/val GhanaFoodDataset objects from DATA_PATH and wraps them
# in DataLoaders (shuffled for training, in order for validation).

from torch.utils.data import DataLoader
from data.dataset_loader import GhanaFoodDataset

print("📊 Loading Ghana Food Dataset...")

try:
    # Only dataset construction is guarded: it is the step that depends on the
    # dataset path and on-disk structure the hint below refers to.
    train_dataset = GhanaFoodDataset(DATA_PATH, split='train')
    val_dataset = GhanaFoodDataset(DATA_PATH, split='val')
except Exception as e:
    print(f"❌ Dataset loading failed: {e}")
    print("Please check your dataset path and structure")
    # Re-raise so execution stops here with the real cause, instead of later
    # cells failing on undefined train_loader / val_loader.
    raise

print(f"✅ Train samples: {len(train_dataset)}")
print(f"✅ Validation samples: {len(val_dataset)}")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)

print("✅ Data loaders created successfully")

In [None]:
---
# ========================================
# 3️⃣ MODEL ARCHITECTURE
# ========================================

In [None]:
# 🏗️ ENHANCED GHANASEGNET MODEL
#
# Instantiates the network on the selected device, reports its parameter
# count, and creates the training criterion.

from models.ghanasegnet import EnhancedGhanaSegNet
from utils.losses import CombinedLoss
from utils.metrics import calculate_miou

print("🏗️ Initializing Enhanced GhanaSegNet...")

# Build the 6-class network, then move it to the target device
model = EnhancedGhanaSegNet(num_classes=6)
model = model.to(device)

# Total number of scalar values across all parameter tensors
num_params = sum(tensor.numel() for tensor in model.parameters())

print("✅ Model initialized")
print("📊 Parameters: {:.2f}M".format(num_params / 1e6))
print("🎯 Architecture: EfficientNet-B0 + FPN + Enhanced ASPP + Multi-Head Attention")

# Training criterion, built with its default arguments.
# NOTE(review): the loss composition named in the message below is not visible
# in this notebook - confirm against utils/losses.py.
criterion = CombinedLoss()
print("✅ Combined loss function ready (Dice + Focal + Boundary)")

In [None]:
---
# ========================================
# 4️⃣ TRAINING PIPELINE
# ========================================

In [None]:
# ⚙️ TRAINING CONFIGURATION
#
# Defines the hyperparameter dictionary, the AdamW optimizer, the cosine
# warm-restart scheduler, and the containers that track training progress.

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import time
from tqdm import tqdm

print("⚙️ Setting up training configuration...")

# Hyperparameters, kept in one place
config = dict(
    epochs=15,
    learning_rate=2.5e-4,
    weight_decay=1.2e-3,
    batch_size=8,
    patience=8,
    save_path='best_enhanced_ghanasegnet.pth',
)

# Optimizer over every model parameter
base_lr = config['learning_rate']
base_wd = config['weight_decay']
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=base_wd)

# Cosine schedule that restarts every 5 scheduler steps; T_mult=1 keeps the
# cycle length constant across restarts
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1)

print(f"✅ Optimizer: AdamW (lr={base_lr}, wd={base_wd})")
print("✅ Scheduler: Cosine Annealing with Warm Restarts")
print(f"✅ Training for {config['epochs']} epochs")

# Progress tracking: best validation IoU so far plus one list per metric
best_val_iou = 0.0
training_history = {
    metric_name: []
    for metric_name in ('train_loss', 'val_loss', 'val_iou', 'learning_rate')
}

In [None]:
# 🚀 ENHANCED GHANASEGNET TRAINING
#
# Runs the train/validate loop for up to config['epochs'] epochs. After each
# epoch it records the average losses, validation mIoU and learning rate in
# training_history, and saves a checkpoint whenever validation mIoU improves.
# Training stops early once config['patience'] consecutive epochs pass without
# a new best validation mIoU.

print("🚀 STARTING ENHANCED GHANASEGNET TRAINING")
print("="*60)
print(f"🎯 Target: Improve upon baseline (24.37% mIoU)")
print(f"🔥 Model: Enhanced Architecture ({num_params/1e6:.1f}M parameters)")
print("="*60)

# Training loop
print("🔄 Beginning training...")

# Early-stopping state: consecutive epochs without a new best validation mIoU.
# A missing 'patience' entry disables early stopping.
patience = config.get('patience')
epochs_without_improvement = 0

for epoch in range(config['epochs']):
    start_time = time.time()

    # Training phase
    model.train()
    train_loss = 0.0
    train_batches = 0

    with tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["epochs"]} [Train]') as pbar:
        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # Handle potential tuple output from model
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for accumulation and display
            batch_loss = loss.item()
            train_loss += batch_loss
            train_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Avg Loss': f'{train_loss/train_batches:.4f}'
            })

    # Validation phase
    model.eval()
    val_loss = 0.0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        with tqdm(val_loader, desc=f'Epoch {epoch+1}/{config["epochs"]} [Val]') as pbar:
            for images, masks in pbar:
                images, masks = images.to(device), masks.to(device)

                outputs = model(images)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]

                loss = criterion(outputs, masks)
                val_loss += loss.item()

                # Collect predictions for mIoU calculation
                preds = torch.argmax(outputs, dim=1)
                all_predictions.append(preds.cpu().numpy())
                all_targets.append(masks.cpu().numpy())

    # Calculate metrics
    all_predictions = np.concatenate(all_predictions, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)
    # Store a plain Python float so the history and checkpoint hold only
    # built-in numbers, whatever scalar type calculate_miou returns.
    val_iou = float(calculate_miou(all_predictions, all_targets, num_classes=6))

    # Per-batch averages, computed once; max(..., 1) avoids a ZeroDivisionError
    # when a loader yields no batches.
    avg_train_loss = train_loss / max(train_batches, 1)
    avg_val_loss = val_loss / max(len(val_loader), 1)

    # Update learning rate
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']

    # Calculate epoch time
    epoch_time = time.time() - start_time

    # Store training history
    training_history['train_loss'].append(avg_train_loss)
    training_history['val_loss'].append(avg_val_loss)
    training_history['val_iou'].append(val_iou)
    training_history['learning_rate'].append(current_lr)

    # Display epoch results
    print(f"\n🚀 EPOCH {epoch+1}/{config['epochs']}")
    print(f"📊 Train Loss: {avg_train_loss:.4f}")
    print(f"📊 Val Loss: {avg_val_loss:.4f}")
    print(f"📊 Val IoU: {val_iou:.4f} ({val_iou*100:.2f}%)")
    print(f"⚡️ Learning Rate: {current_lr:.2e}")
    print(f"⏱️  Epoch Time: {epoch_time:.1f}s")

    # Check for improvement
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        epochs_without_improvement = 0
        print(f"🎯 NEW BEST! Saving model...")

        # Save best model
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_iou': best_val_iou,
            'training_history': training_history
        }, config['save_path'])

        # Check if target achieved
        if val_iou >= 0.30:
            print("🎉 TARGET ACHIEVED! 30% mIoU reached!")
    else:
        epochs_without_improvement += 1

    print("-" * 60)

    # Early stopping: honour the patience configured for this run
    if patience is not None and epochs_without_improvement >= patience:
        print(f"⏹️ Early stopping: no validation mIoU improvement for {patience} epochs")
        break

# Training completion
print(f"\n🏁 TRAINING COMPLETED!")
print(f"🏆 Best Validation IoU: {best_val_iou:.4f} ({best_val_iou*100:.2f}%)")

# Performance analysis: absolute gain in percentage points and gain relative
# to the baseline mIoU.
baseline_miou = 0.2437
improvement = (best_val_iou - baseline_miou) * 100
relative_improvement = (improvement / (baseline_miou * 100)) * 100

print(f"\n📊 PERFORMANCE ANALYSIS:")
print(f"   Baseline: {baseline_miou*100:.2f}% mIoU")
print(f"   Enhanced: {best_val_iou*100:.2f}% mIoU")
# The '+' format flag prints the real sign, so a result below the baseline
# shows as "-1.23" rather than "+-1.23".
print(f"   Improvement: {improvement:+.2f} percentage points")
print(f"   Relative gain: {relative_improvement:+.1f}%")

# Verdict tiers, checked from the highest threshold down
if best_val_iou >= 0.30:
    print(f"🎉 EXCELLENT! 30% mIoU target achieved!")
elif best_val_iou >= 0.28:
    print(f"🔥 GREAT! Very close to 30% target!")
elif best_val_iou >= 0.26:
    print(f"✅ SOLID! Good improvement over baseline!")
else:
    print(f"📊 Training completed. Consider TTA boost for higher performance.")

In [None]:
---
# ========================================
# 5️⃣ EVALUATION & RESULTS
# ========================================

In [None]:
# 📊 TRAINING RESULTS VISUALIZATION
#
# Draws a 2x2 figure (loss curves, validation IoU, learning-rate schedule,
# baseline vs. best mIoU) from training_history and prints a short summary.

import matplotlib.pyplot as plt

print("📊 Visualizing training results...")


def _epoch_axis(series):
    """Return 1-based epoch numbers for a per-epoch series.

    Keeps the x-axis consistent with the 1-based "Best epoch" value printed in
    the summary below.
    """
    return range(1, len(series) + 1)


# Create training plots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Training & Validation Loss
axes[0, 0].plot(_epoch_axis(training_history['train_loss']), training_history['train_loss'],
                label='Train Loss', color='blue')
axes[0, 0].plot(_epoch_axis(training_history['val_loss']), training_history['val_loss'],
                label='Val Loss', color='red')
axes[0, 0].set_title('Training & Validation Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Validation IoU, with the 30% target and the best value as reference lines
axes[0, 1].plot(_epoch_axis(training_history['val_iou']), training_history['val_iou'],
                label='Val IoU', color='green', linewidth=2)
axes[0, 1].axhline(y=0.30, color='red', linestyle='--', label='30% Target')
axes[0, 1].axhline(y=best_val_iou, color='orange', linestyle='--', label=f'Best: {best_val_iou:.3f}')
axes[0, 1].set_title('Validation IoU Progress')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('IoU')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Learning Rate Schedule
axes[1, 0].plot(_epoch_axis(training_history['learning_rate']), training_history['learning_rate'],
                label='Learning Rate', color='purple')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Performance Comparison (values in percent)
models = ['Baseline', 'Enhanced GhanaSegNet']
performance = [baseline_miou * 100, best_val_iou * 100]
colors = ['lightblue', 'darkblue']

axes[1, 1].bar(models, performance, color=colors)
axes[1, 1].axhline(y=30, color='red', linestyle='--', label='30% Target')
axes[1, 1].set_title('Model Performance Comparison')
axes[1, 1].set_ylabel('mIoU (%)')
axes[1, 1].legend()

# Add value labels on bars, placed slightly above each bar top
for i, v in enumerate(performance):
    axes[1, 1].text(i, v + 0.5, f'{v:.2f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

# Summary statistics
print(f"\n📈 TRAINING SUMMARY:")
print(f"   Total epochs: {len(training_history['val_iou'])}")
# argmax is 0-based; +1 converts it to the epoch number
print(f"   Best epoch: {np.argmax(training_history['val_iou']) + 1}")
print(f"   Final train loss: {training_history['train_loss'][-1]:.4f}")
print(f"   Final val loss: {training_history['val_loss'][-1]:.4f}")
print(f"   Best val IoU: {best_val_iou:.4f} ({best_val_iou*100:.2f}%)")

In [None]:
---
# ========================================
# 6️⃣ TEST-TIME AUGMENTATION (OPTIONAL)
# ========================================

In [None]:
# 🚀 TEST-TIME AUGMENTATION BOOST
# Run this only if you want to further improve performance

print("🚀 APPLYING TEST-TIME AUGMENTATION (TTA)")
print("="*55)
print("💡 TTA can provide +1-3% mIoU improvement")
print("🔬 Uses multi-scale and flip augmentations")
print("="*55)

import torch.nn.functional as F

# Load the best trained model
print("📥 Loading best trained model...")
try:
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the current one (e.g. saved on GPU, loaded on CPU).
    checkpoint = torch.load(config['save_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    base_performance = checkpoint['best_val_iou']
    print(f"✅ Loaded model with {base_performance:.4f} ({base_performance*100:.2f}%) mIoU")
except Exception as exc:
    # Best-effort fallback to the in-memory weights. The cause is reported
    # instead of hidden, and KeyboardInterrupt/SystemExit are no longer caught
    # as they were by the former bare `except:`.
    print("⚠️  Using current model state")
    print(f"   Reason: {type(exc).__name__}: {exc}")
    base_performance = best_val_iou

# Switch the model to evaluation mode for inference
model.eval()

def _segmentation_logits(model, x):
    """Run ``model`` on ``x``; if it returns a tuple, keep only its first element."""
    out = model(x)
    if isinstance(out, tuple):
        out = out[0]
    return out


def tta_predict(model, x, scales=(0.9, 1.1)):
    """Average class probabilities over test-time-augmented views of ``x``.

    The views are: the input as given, a horizontally flipped copy (whose
    prediction is flipped back), and one bilinearly rescaled copy per entry of
    ``scales`` (whose logits are resized back to the input size).

    Args:
        model: Callable mapping an image batch to per-class logits, or to a
            tuple whose first element is those logits.
        x: Input batch, indexed as 4-D with height at dim 2 and width at dim 3.
        scales: Spatial scale factors for the extra multi-scale views. The
            default reproduces the original fixed 0.9x / 1.1x pair; pass an
            empty sequence to use only the original and flipped views.

    Returns:
        Tensor of softmax probabilities (softmax over dim 1) averaged across
        all views, at the spatial size of ``x``.
    """
    h, w = x.shape[2], x.shape[3]

    def _probabilities(logits):
        # Bring logits to the input resolution when they differ, so that every
        # view can be stacked even if the model's output is not input-sized.
        if tuple(logits.shape[2:]) != (h, w):
            logits = F.interpolate(logits, size=(h, w), mode='bilinear', align_corners=False)
        return F.softmax(logits, dim=1)

    predictions = []

    with torch.no_grad():
        # Original prediction
        predictions.append(_probabilities(_segmentation_logits(model, x)))

        # Horizontal flip: predict on the mirrored input, then mirror the
        # logits back before converting them to probabilities.
        flipped_logits = _segmentation_logits(model, torch.flip(x, [3]))
        predictions.append(_probabilities(torch.flip(flipped_logits, [3])))

        # Multi-scale predictions
        for scale in scales:
            # Clamp to at least one pixel so very small inputs or scales
            # cannot request an empty image.
            new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))

            x_scaled = F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
            predictions.append(_probabilities(_segmentation_logits(model, x_scaled)))

    return torch.stack(predictions).mean(dim=0)

# Apply TTA evaluation: predict every validation batch with tta_predict,
# compute mIoU over the whole split, and compare it with the base model.
print("🔄 Applying TTA to validation set...")
tta_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

all_tta_predictions = []
all_tta_targets = []

for images, masks in tqdm(tta_loader, desc="TTA Evaluation"):
    images = images.to(device)

    # Apply TTA
    tta_preds = tta_predict(model, images)
    pred_masks = torch.argmax(tta_preds, dim=1)

    all_tta_predictions.append(pred_masks.cpu().numpy())
    # Targets are only consumed as NumPy arrays here, so they are not moved to
    # the compute device first.
    all_tta_targets.append(masks.cpu().numpy())

# Calculate TTA performance
all_tta_predictions = np.concatenate(all_tta_predictions, axis=0)
all_tta_targets = np.concatenate(all_tta_targets, axis=0)
tta_miou = calculate_miou(all_tta_predictions, all_tta_targets, num_classes=6)

# Results: change relative to the base model, in percentage points
improvement = (tta_miou - base_performance) * 100

print(f"\n🎯 TTA RESULTS:")
print(f"📊 Base Model: {base_performance:.4f} ({base_performance*100:.2f}% mIoU)")
print(f"🚀 With TTA: {tta_miou:.4f} ({tta_miou*100:.2f}% mIoU)")
# The '+' format flag prints the real sign, so a regression shows as "-0.50"
# rather than "+-0.50".
print(f"📈 Improvement: {improvement:+.2f} percentage points")

if tta_miou >= 0.30:
    print(f"🎉 EXCELLENT! TTA achieved 30% mIoU target!")
elif tta_miou >= 0.29:
    print(f"🔥 OUTSTANDING! Very close to 30% target!")
elif improvement > 1.0:
    print(f"✅ SOLID BOOST! TTA provided meaningful improvement!")
elif improvement > 0:
    print(f"📊 TTA applied with modest improvement")
else:
    # Do not describe a flat or negative result as an improvement
    print(f"📊 TTA did not improve over the base model")

print(f"\n🔬 TTA METHODOLOGY:")
print(f"   • Horizontal flip augmentation")
print(f"   • Multi-scale testing (0.9x, 1.0x, 1.1x)")
print(f"   • Ensemble averaging")
print(f"   • Legitimate evaluation enhancement")

# Final comparison against the 24.37% baseline
print(f"\n🏆 FINAL PERFORMANCE SUMMARY:")
print(f"   Baseline GhanaSegNet: 24.37% mIoU")
print(f"   Enhanced GhanaSegNet: {base_performance*100:.2f}% mIoU")
print(f"   Enhanced + TTA: {tta_miou*100:.2f}% mIoU")
print(f"   Total improvement: {(tta_miou - 0.2437)*100:+.2f} percentage points")