# 🚀 Enhanced GhanaSegNet Training - 30% mIoU Target

**Objective**: Train the enhanced GhanaSegNet architecture to reach at least 30% validation mIoU

**Key Features**:
- 🔧 Progressive resolution training (256→320→384px)
- 🧠 12-head transformer with 384-channel ASPP
- 📊 Multi-component boundary-aware loss
- ⚡ Mixed precision training with cosine scheduling
- 🎯 Real-time milestone tracking

---

## 📋 Section 1: Environment Setup & Verification

In [None]:
# Import essential libraries and check environment
import sys
import os
import torch
import subprocess
from datetime import datetime

# Report interpreter, PyTorch, and GPU details so setup problems surface
# before any training work starts.
print("🔍 ENVIRONMENT VERIFICATION")
print("=" * 50)

# Interpreter version: first whitespace-separated token of sys.version
python_version = sys.version.split()[0]
print(f"🐍 Python Version: {python_version}")

# PyTorch build and whether CUDA is usable from it
print(f"🔥 PyTorch Version: {torch.__version__}")
cuda_available = torch.cuda.is_available()
print(f"🖥️  CUDA Available: {cuda_available}")
if cuda_available:
    # Only device index 0 is reported
    gpu_name = torch.cuda.get_device_name(0)
    print(f"🎮 GPU Device: {gpu_name}")
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"💾 GPU Memory: {gpu_memory_gb:.1f} GB")

# Working directory and wall-clock time at which this cell ran
working_dir = os.getcwd()
print(f"📁 Current Directory: {working_dir}")
started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"📅 Training Start Time: {started_at}")

print("\n✅ Environment check complete!")

In [None]:
# Move into the GhanaSegNet project root (when not already there) and confirm
# that the project files checked below are present.
try:
    # Relocate only when the working directory does not already end with the
    # project name; the first candidate location that exists wins.
    if not os.getcwd().endswith('GhanaSegNet'):
        for project_dir in ('GhanaSegNet', '/content/GhanaSegNet'):
            if os.path.exists(project_dir):
                os.chdir(project_dir)
                print(f"📂 Changed to: {os.getcwd()}")
                break

    # Paths are relative to the (possibly just changed) working directory
    essential_files = [
        'scripts/train_baselines.py',
        'models/ghanasegnet.py',
        'utils/losses.py',
        'utils/metrics.py'
    ]

    print("\n🔍 CHECKING ESSENTIAL FILES:")
    missing_count = 0
    for file_path in essential_files:
        if os.path.exists(file_path):
            print(f"✅ {file_path}")
        else:
            print(f"❌ {file_path}")
            missing_count += 1
    all_files_exist = missing_count == 0

    if not all_files_exist:
        print("\n⚠️  Some files are missing. Please ensure you're in the correct directory.")
    else:
        print("\n🎉 All essential files found! Ready for training.")

except Exception as e:
    # Best-effort setup: report the problem and let the notebook continue
    print(f"⚠️  Directory setup issue: {e}")
    print("Please ensure you're in the GhanaSegNet project directory.")

## 📊 Section 2: Dataset Verification

In [None]:
# Locate the dataset and report how many image files each split contains.
# Leaves `dataset_path` set to the first candidate directory that exists
# (or None when none does); the training configuration cell reads it.
print("📊 DATASET VERIFICATION")
print("=" * 50)

# Extensions counted as images by the report (compared case-insensitively)
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def is_image_file(filename: str) -> bool:
    """Return True when *filename* ends with a supported image extension.

    The comparison ignores letter case so that files such as ``IMG_001.JPG``
    are counted as well.
    """
    return filename.lower().endswith(IMAGE_EXTENSIONS)


# Candidate locations, checked in order
data_paths_to_check = [
    'data',
    '/content/drive/MyDrive/GhanaFoodDataset',
    '/content/GhanaFoodDataset',
    '../data'
]

dataset_path = None
for path in data_paths_to_check:
    # Require a directory: a stray file with the same name cannot hold the splits
    if os.path.isdir(path):
        dataset_path = path
        print(f"✅ Found dataset at: {path}")
        break

if dataset_path:
    # Expected layout: image and mask folders for the train and val splits
    required_dirs = ['train/images', 'train/masks', 'val/images', 'val/masks']
    for dir_path in required_dirs:
        full_path = os.path.join(dataset_path, dir_path)
        # isdir (not exists) so os.listdir below cannot fail on a plain file
        if os.path.isdir(full_path):
            count = sum(1 for f in os.listdir(full_path) if is_image_file(f))
            print(f"✅ {dir_path}: {count} files")
        else:
            print(f"❌ Missing: {dir_path}")
else:
    print("❌ Dataset not found! Please ensure dataset is available.")
    print("Expected locations:")
    for path in data_paths_to_check:
        print(f"   - {path}")

print(f"\n🎯 Dataset Path: {dataset_path if dataset_path else 'Not Found'}")

## 🧠 Section 3: Model Architecture Verification

In [None]:
# Smoke-test the GhanaSegNet model: import it, count its parameters, and run
# one forward pass on random input.
print("🧠 MODEL ARCHITECTURE VERIFICATION")
print("=" * 50)

try:
    # Make the current directory importable so `models.ghanasegnet` resolves
    # (assumes the working directory is the project root)
    if '.' not in sys.path:
        sys.path.append('.')

    # Import the enhanced model
    from models.ghanasegnet import GhanaSegNet

    # Create model instance
    model = GhanaSegNet(num_classes=6)

    # Calculate parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("✅ Model loaded successfully!")
    print(f"📊 Total Parameters: {total_params:,}")
    print(f"🎯 Trainable Parameters: {trainable_params:,}")

    # The parameter count is only a heuristic for telling the two variants
    # apart; report the measured size instead of a fixed "16M+" label.
    if total_params > 15_000_000:
        print(f"🚀 Enhanced architecture detected! ({total_params/1e6:.1f}M parameters)")
        print("   - 12-head transformer")
        print("   - 384-channel ASPP")
        print("   - Advanced spatial attention")
    else:
        print(f"⚠️  Standard architecture detected ({total_params/1e6:.1f}M parameters)")

    # Run the check in evaluation mode: with a batch of one, normalisation
    # layers in training mode can fail or give batch-dependent results, and
    # dropout would make the pass non-deterministic.
    # NOTE(review): if the model only returns auxiliary outputs in training
    # mode, the multi-output branch below will not trigger - confirm.
    model.eval()
    dummy_input = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        output = model(dummy_input)

    if isinstance(output, (tuple, list)):
        # First element is treated as the main prediction, the rest as auxiliary
        main_output, *aux_outputs = output
        print("✅ Multi-scale supervision active")
        print(f"   Main output: {main_output.shape}")
        for aux_output in aux_outputs:
            print(f"   Auxiliary output: {aux_output.shape}")
    else:
        print(f"✅ Forward pass successful: {output.shape}")

    print("\n🎉 Model verification complete!")

except Exception as e:
    # Verification is advisory: report the failure and let the notebook continue
    print(f"❌ Model loading failed: {e}")
    print("Please check that all model files are present and properly configured.")

## ⚙️ Section 4: Training Configuration

In [None]:
# Assemble the settings for the training run and print them for reference
print("⚙️ TRAINING CONFIGURATION")
print("=" * 50)

# Single source of settings for the training launch cell
TRAINING_CONFIG = dict(
    # Model settings
    model='ghanasegnet',
    num_classes=6,

    # Optimisation settings
    epochs=15,
    batch_size=6,
    learning_rate=1.8e-4,
    weight_decay=1.5e-3,

    # Feature switches
    use_progressive_training=True,
    use_mixed_precision=True,
    use_cosine_schedule=True,

    # Environment settings; the dataset path comes from the dataset
    # verification cell when that cell has run, otherwise 'data' is assumed
    device='auto',
    dataset_path=locals().get('dataset_path', 'data'),
    seed=789,

    # Target settings
    target_miou=30.0,
    early_stopping_patience=6,
)

print("📋 Training Configuration:")
for setting_name in TRAINING_CONFIG:
    print(f"   {setting_name}: {TRAINING_CONFIG[setting_name]}")

# The remaining output is informational text only
print("\n🎯 PROGRESSIVE TRAINING SCHEDULE:")
for schedule_line in (
    "   Epochs 1-5:   256x256 resolution, batch size 8",
    "   Epochs 6-11:  320x320 resolution, batch size 6",
    "   Epochs 12-15: 384x384 resolution, batch size 4",
):
    print(schedule_line)

print("\n🚀 ENHANCED FEATURES ACTIVE:")
for feature_line in (
    "   ✅ 12-head transformer attention",
    "   ✅ 384-channel ASPP module",
    "   ✅ Multi-component loss (Dice+Boundary+Focal+CE)",
    "   ✅ Mixed precision training",
    "   ✅ Cosine annealing with warmup",
    "   ✅ Early stopping protection",
):
    print(feature_line)

print("\n✅ Configuration ready for 30% mIoU target!")

## 🚀 Section 5: Execute Enhanced Training

In [None]:
# Launch scripts/train_baselines.py as a child process using the settings in
# TRAINING_CONFIG and echo its output into the notebook while it runs.
import sys  # local import so the interpreter path is available in this cell

print("🚀 STARTING ENHANCED GHANASEGNET TRAINING")
print("=" * 60)
print("🎯 TARGET: 30% mIoU")
print("⚡ STRATEGY: Progressive resolution + Enhanced architecture")
print(f"🕐 Start Time: {datetime.now().strftime('%H:%M:%S')}")
print("=" * 60)

try:
    # Construct training command.
    # - sys.executable: run the child with the same interpreter (and therefore
    #   the same installed packages) as this kernel, not whatever `python`
    #   happens to be first on PATH.
    # - -u: unbuffered child output; without it Python block-buffers stdout
    #   when writing to a pipe and progress lines arrive in large bursts.
    # NOTE(review): TRAINING_CONFIG also holds weight_decay and the use_*
    # switches, which are not forwarded here - confirm whether
    # train_baselines.py accepts matching options.
    cmd = [
        sys.executable or 'python', '-u', 'scripts/train_baselines.py',
        '--model', TRAINING_CONFIG['model'],
        '--epochs', str(TRAINING_CONFIG['epochs']),
        '--batch-size', str(TRAINING_CONFIG['batch_size']),
        '--lr', str(TRAINING_CONFIG['learning_rate']),
        '--num-classes', str(TRAINING_CONFIG['num_classes']),
        '--device', TRAINING_CONFIG['device'],
        '--seed', str(TRAINING_CONFIG['seed'])
    ]

    # Add dataset path if available (it is None when no dataset was found)
    if TRAINING_CONFIG['dataset_path']:
        cmd.extend(['--dataset-path', str(TRAINING_CONFIG['dataset_path'])])

    print(f"🔧 Training Command: {' '.join(cmd)}")
    print("\n📊 Training Progress:")

    # The context manager closes the pipe and waits for the child on exit;
    # stderr is merged into stdout so errors appear in order with progress.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    ) as process:
        try:
            # Stream output line by line as the child produces it
            for line in process.stdout:
                print(line.rstrip())
        except KeyboardInterrupt:
            # Interrupting the cell should not leave training running in the
            # background
            process.terminate()
            raise

    return_code = process.returncode

    print(f"\n📊 Training completed with return code: {return_code}")
    print(f"🕐 End Time: {datetime.now().strftime('%H:%M:%S')}")

    if return_code == 0:
        print("\n🎉 TRAINING SUCCESSFUL! 🎉")
        print("✅ Check results in the next section")
    else:
        print("\n❌ Training encountered issues")
        print("🔍 Check error messages above")

except Exception as e:
    # Covers launch failures (e.g. missing script or config key); a non-zero
    # exit status of the child is reported above instead
    print(f"❌ Training execution failed: {e}")
    print("🔍 Please check the error details above")

## 📈 Section 6: Results Analysis & Validation

In [None]:
# Analyze training results and check 30% mIoU achievement
import json
import matplotlib.pyplot as plt
import numpy as np

print("📈 TRAINING RESULTS ANALYSIS")
print("=" * 50)


def format_param_count(value) -> str:
    """Return *value* formatted for display as a parameter count.

    Numbers get thousands separators, a missing value (None) becomes 'N/A',
    and anything else is shown as-is. Applying the ',' format directly to a
    non-numeric fallback such as 'N/A' raises ValueError.
    """
    if value is None:
        return 'N/A'
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:,}"
    return str(value)


def plot_training_history(history: list) -> None:
    """Draw the 2x2 training dashboard for a non-empty list of epoch records.

    Each record must provide 'epoch', 'train_loss', 'val_loss' and 'val_iou'
    (a fraction, converted to percent here); 'learning_rate' is optional and
    is looked up on the first record to decide whether to plot it.
    """
    # Extract metrics
    epochs = [h['epoch'] for h in history]
    train_losses = [h['train_loss'] for h in history]
    val_losses = [h['val_loss'] for h in history]
    val_ious = [h['val_iou'] * 100 for h in history]  # Convert to percentage
    best_val_iou = max(val_ious)

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Training & Validation Loss
    ax1.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
    ax1.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    ax1.set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Validation IoU Progress, with the 30% target drawn as a reference line
    ax2.plot(epochs, val_ious, 'g-', linewidth=3, marker='o', markersize=4)
    ax2.axhline(y=30.0, color='red', linestyle='--', linewidth=2, label='30% Target')
    ax2.set_title('Validation mIoU Progress', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('mIoU (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    # Keep the target line visible even when results are well below it
    ax2.set_ylim(0, max(35, best_val_iou + 2))

    # Learning Rate Schedule (only when the history recorded it)
    if 'learning_rate' in history[0]:
        lrs = [h['learning_rate'] for h in history]
        ax3.plot(epochs, lrs, 'purple', linewidth=2)
        ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)
    else:
        # Hide the unused panel instead of showing an empty frame
        ax3.axis('off')

    # Performance Summary (text-only panel)
    ax4.text(0.1, 0.8, f"Best mIoU: {best_val_iou:.2f}%", fontsize=16, fontweight='bold')
    ax4.text(0.1, 0.7, "Target: 30.00%", fontsize=14)
    ax4.text(0.1, 0.6, f"Gap: {30.0 - best_val_iou:+.2f}pp", fontsize=14)
    ax4.text(0.1, 0.5, f"Final Epoch: {max(epochs)}", fontsize=14)

    # Achievement status
    if best_val_iou >= 30.0:
        status = "🏆 TARGET ACHIEVED!"
        color = 'green'
    elif best_val_iou >= 28.0:
        status = "🎉 EXCELLENT!"
        color = 'orange'
    else:
        status = "📊 COMPLETED"
        color = 'blue'

    ax4.text(0.1, 0.3, status, fontsize=18, fontweight='bold', color=color)
    ax4.set_xlim(0, 1)
    ax4.set_ylim(0, 1)
    ax4.axis('off')

    plt.tight_layout()
    plt.show()


try:
    # Files this cell expects the training run to have written
    results_file = 'checkpoints/ghanasegnet/ghanasegnet_results.json'
    history_file = 'checkpoints/ghanasegnet/training_history.json'

    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            results = json.load(f)

        print("✅ Training results loaded successfully!")
        print("\n🎯 FINAL PERFORMANCE:")
        # 'best_iou' is stored as a fraction; convert to percent for display
        best_iou = results.get('best_iou', 0)
        best_iou_percent = best_iou * 100

        print(f"   Best mIoU: {best_iou:.4f} ({best_iou_percent:.2f}%)")
        print("   Target:    0.3000 (30.00%)")
        print(f"   Gap:       {30.0 - best_iou_percent:+.2f} percentage points")

        # Achievement status
        if best_iou >= 0.30:
            print("\n🏆 TARGET ACHIEVED! 30%+ mIoU reached!")
        elif best_iou >= 0.28:
            print("\n🎉 EXCELLENT! Within 2% of target!")
        elif best_iou >= 0.27:
            print("\n✅ GREAT! Solid improvement achieved!")
        elif best_iou >= 0.25:
            print("\n📊 GOOD! Meaningful progress made!")
        else:
            print("\n📈 Training completed - check for improvements")

        print("\n📊 TRAINING STATISTICS:")
        print(f"   Total Parameters: {format_param_count(results.get('total_parameters'))}")
        print(f"   Final Epoch: {results.get('final_epoch', 'N/A')}")
        # The stored value is a timestamp, not a duration
        print(f"   Timestamp: {results.get('timestamp', 'N/A')}")
    else:
        print("❌ Results file not found. Training may not have completed successfully.")

    # Load and visualize training history
    if os.path.exists(history_file):
        with open(history_file, 'r') as f:
            history = json.load(f)

        if history:
            print("\n📊 TRAINING HISTORY LOADED")
            plot_training_history(history)
            print("\n📈 Training visualization complete!")
        else:
            # An empty history has no epochs to plot or take a maximum over
            print("⚠️  Training history is empty - nothing to plot")

    else:
        print("⚠️  Training history not found")

except Exception as e:
    # Analysis is advisory: report the failure and let the notebook continue
    print(f"❌ Results analysis failed: {e}")
    print("Please check if training completed successfully")

## 💾 Section 7: Model Checkpoint Information

In [None]:
# Summarise what is stored in the GhanaSegNet checkpoint directory
print("💾 MODEL CHECKPOINT INFORMATION")
print("=" * 50)

checkpoint_dir = 'checkpoints/ghanasegnet'

if not os.path.exists(checkpoint_dir):
    print("❌ No checkpoints found. Training may not have started or completed.")
else:
    print(f"📁 Checkpoint Directory: {checkpoint_dir}")

    # Every directory entry, in name order, with its size in MiB
    entry_names = os.listdir(checkpoint_dir)

    print("\n📋 Available Files:")
    for entry_name in sorted(entry_names):
        entry_path = os.path.join(checkpoint_dir, entry_name)
        entry_size_mb = os.path.getsize(entry_path) / (1024 * 1024)
        print(f"   ✅ {entry_name} ({entry_size_mb:.1f} MB)")

    # Extra detail for the best-model checkpoint, when one was saved
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
    if os.path.exists(best_model_path):
        print("\n🏆 BEST MODEL AVAILABLE:")
        print(f"   Path: {best_model_path}")
        best_model_size_mb = os.path.getsize(best_model_path) / (1024 * 1024)
        print(f"   Size: {best_model_size_mb:.1f} MB")

        # Metadata is read from the checkpoint itself, loaded onto the CPU.
        # NOTE(review): torch.load unpickles the file - only open checkpoints
        # from a trusted source.
        try:
            checkpoint = torch.load(best_model_path, map_location='cpu')
            print(f"   Epoch: {checkpoint.get('epoch', 'N/A')}")
            best_val_iou = checkpoint.get('best_val_iou', 0)
            print(f"   Best IoU: {best_val_iou:.4f}")
            print(f"   Performance: {checkpoint.get('best_val_iou', 0) * 100:.2f}% mIoU")
        except Exception as e:
            print(f"   ⚠️  Could not load checkpoint details: {e}")

    print("\n🔄 TO RESUME TRAINING:")
    print("   Re-run the training cell with same configuration")

    print("\n🧪 TO USE FOR INFERENCE:")
    print("   Load best_model.pth for evaluation or deployment")

print("\n✅ Checkpoint information complete!")

## 🎯 Section 8: Quick Training Summary

In [None]:
# Print a condensed end-of-session report from the saved results file
import json  # imported here so this cell does not rely on the analysis cell having run

print("🎯 ENHANCED GHANASEGNET TRAINING SUMMARY")
print("=" * 60)

# Reference mIoU (percent) that the improvement figure is measured against
BASELINE_MIOU_PERCENT = 24.37


def format_param_count(value) -> str:
    """Return *value* formatted for display as a parameter count.

    Numbers get thousands separators, a missing value (None) becomes 'N/A',
    and anything else is shown as-is. Applying the ',' format directly to a
    non-numeric fallback such as 'N/A' raises ValueError.
    """
    if value is None:
        return 'N/A'
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:,}"
    return str(value)


try:
    # Get final results
    results_file = 'checkpoints/ghanasegnet/ghanasegnet_results.json'
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            results = json.load(f)

        # 'best_iou' is stored as a fraction; convert to percent for display
        best_iou = results.get('best_iou', 0)
        best_iou_percent = best_iou * 100

        print("📊 PERFORMANCE RESULTS:")
        print("   🎯 Target mIoU:     30.00%")
        print(f"   🏆 Achieved mIoU:   {best_iou_percent:.2f}%")
        print(f"   📈 Improvement:     {best_iou_percent - BASELINE_MIOU_PERCENT:+.2f}pp from baseline")

        if best_iou >= 0.30:
            print("\n🏆 SUCCESS! TARGET ACHIEVED!")
            print("   The enhanced GhanaSegNet reached 30%+ mIoU!")
        elif best_iou >= 0.27:
            print("\n🎉 EXCELLENT PERFORMANCE!")
            print("   Very close to 30% target with significant improvement!")
        else:
            print("\n📊 TRAINING COMPLETED")
            # Only claim an improvement when the result actually beats the baseline
            if best_iou_percent > BASELINE_MIOU_PERCENT:
                print("   Model shows improvement over baseline performance")
            else:
                print("   Model did not improve on the baseline performance")

        # NOTE(review): this list is static text; it is not derived from the
        # results file - confirm these features were active in the run.
        print("\n🔧 ARCHITECTURE FEATURES USED:")
        print("   ✅ 12-head transformer attention")
        print("   ✅ 384-channel ASPP module")
        print("   ✅ Progressive training (256→320→384px)")
        print("   ✅ Multi-component boundary-aware loss")
        print("   ✅ Mixed precision training")
        print("   ✅ Cosine annealing with warmup")

        print("\n📋 TECHNICAL DETAILS:")
        print(f"   Parameters: {format_param_count(results.get('total_parameters'))}")
        print(f"   Epochs: {results.get('final_epoch', 'N/A')}")
        print("   Training Method: Enhanced Progressive Training")

    else:
        print("⚠️  Training results not available")
        print("   Please check if training completed successfully")

except Exception as e:
    # The summary is advisory: report the failure and still print the footer
    print(f"❌ Summary generation failed: {e}")

print("\n" + "=" * 60)
print("🚀 Enhanced GhanaSegNet Training Session Complete! 🚀")
print("=" * 60)