<a href="https://colab.research.google.com/github/EricBaidoo/GhanaSegNet/blob/main/GhanaSegNet_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GhanaSegNet Colab Training Notebook

This notebook sets up the Colab environment, installs dependencies, and runs the training script (`scripts/train_baselines.py`) for GhanaSegNet and three baselines (UNet, DeepLabV3+, and SegFormer-B0) on a Colab GPU.

## Setup Instructions:
1. Run each cell in order
2. Make sure GPU is enabled: Runtime > Change runtime type > Hardware accelerator > GPU
3. Your data should be uploaded to Google Drive or included in your GitHub repo

In [1]:
# Mount Google Drive (if your data is stored there)
from google.colab import drive
drive.mount('/content/drive')

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU detected - switch to GPU runtime!")

Mounted at /content/drive
CUDA available: True
GPU: NVIDIA A100-SXM4-40GB


In [2]:
# Clone your GitHub repo
!git clone https://github.com/EricBaidoo/GhanaSegNet.git
%cd GhanaSegNet

# Check if we have the expected files
!ls -la

Cloning into 'GhanaSegNet'...
remote: Enumerating objects: 5567, done.
remote: Counting objects: 100% (136/136), done.
remote: Compressing objects: 100% (92/92), done.
remote: Total 5567 (delta 69), reused 98 (delta 43), pack-reused 5431 (from 2)
Receiving objects: 100% (5567/5567), 701.56 MiB | 17.01 MiB/s, done.
Resolving deltas: 100% (90/90), done.
/content/GhanaSegNet
total 204
drwxr-xr-x 7 root root  4096 Sep 14 17:10 .
drwxr-xr-x 1 root root  4096 Sep 14 17:10 ..
-rw-r--r-- 1 root root  3692 Sep 14 17:10 cleanup_repo.py
-rw-r--r-- 1 root root  6024 Sep 14 17:10 colab_training_cells.py
-rw-r--r-- 1 root root 83713 Sep 14 17:10 GhanaSegNet_Colab.ipynb
drwxr-xr-x 8 root root  4096 Sep 14 17:10 .git
-rw-r--r-- 1 root root   627 Sep 14 17:10 .gitignore
-rw-r--r-- 1 root root 57714 Sep 14 17:10 image.png
-rw-r--r-- 1 root root    30 Sep 14 17:10 LICENSE
drwxr-xr-x 2 root root  4096 Sep 14 17:10 models
drwxr-xr-x 2 root root  4096 Sep 14 17:10 notebooks
-rw-r--r-- 1 root roo

## 📁 Dataset Connection Instructions

**Before running the next cell:**

1. **Locate your data folder in Google Drive** - Find where you uploaded your `data` folder
2. **Check the path** - Note the exact path (e.g., `MyDrive/data` or `MyDrive/GhanaSegNet/data`)
3. **Update the copy command** - Modify the path in the next cell to match your Drive structure
4. **Run the cell** - The dataset will be copied to your Colab workspace

**Expected folder structure after copying:**
```
data/
  train/
    images/
    masks/
  val/
    images/
    masks/
  test/ (optional)
    images/
    masks/
```
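
Before copying anything, it can help to confirm that the folder tree matches the layout above. The following is a minimal sketch, not part of the repository: `REQUIRED_DIRS` and `missing_dataset_dirs` are hypothetical names introduced here, and the optional `test/` split is deliberately not checked.

```python
from pathlib import Path

# Required sub-folders, taken from the layout above; test/ is optional and not checked.
REQUIRED_DIRS = [
    "train/images",
    "train/masks",
    "val/images",
    "val/masks",
]

def missing_dataset_dirs(root="data"):
    """Return the required sub-folders that are absent under `root`."""
    root = Path(root)
    return [d for d in REQUIRED_DIRS if not (root / d).is_dir()]

if __name__ == "__main__":
    missing = missing_dataset_dirs("data")
    if missing:
        print("Missing folders:", ", ".join(missing))
    else:
        print("Dataset layout looks complete.")
```

Run it after the copy cell; an empty result means the four required folders exist, although it says nothing about their contents.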

In [3]:
# Copy the dataset from Google Drive into the Colab workspace.
# IMPORTANT: update the path below to match where your data lives in Google Drive.

# Option 1: the data folder is in the root of MyDrive
!cp -r "/content/drive/MyDrive/data" .

# Option 2: the data folder is in a subfolder (update the path as needed)
# !cp -r "/content/drive/MyDrive/YourFolder/data" .

# Option 3: you uploaded a compressed archive (data.tar.gz)
# !cp "/content/drive/MyDrive/data.tar.gz" .
# !tar -xzf data.tar.gz



In [4]:
# Verify dataset is copied successfully
print("Checking dataset structure...")
!ls -la data/
print("Dataset statistics:")
!echo "Train images:" && ls data/train/images/ | wc -l
!echo "Train masks:" && ls data/train/masks/ | wc -l
!echo "Val images:" && if [ -d data/val/images ]; then ls data/val/images/ | wc -l; else echo "No val images found"; fi
!echo "Val masks:" && if [ -d data/val/masks ]; then ls data/val/masks/ | wc -l; else echo "No val masks found"; fi

Checking dataset structure...
total 24
drwx------ 5 root root 4096 Sep 14 17:13 .
drwxr-xr-x 8 root root 4096 Sep 14 17:11 ..
-rw------- 1 root root 2277 Sep 14 17:11 dataset_loader.py
drwx------ 6 root root 4096 Sep 14 17:12 test
drwx------ 6 root root 4096 Sep 14 17:18 train
drwx------ 6 root root 4096 Sep 14 17:13 val
Dataset statistics:
Train images:
3451
Train masks:
3436
Val images:
741
Val masks:
738
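
The counts above do not match (3451 train images against 3436 masks, 741 val images against 738 masks), so some entries appear to have no counterpart. A minimal sketch for listing them follows. It assumes that an image and its mask share the same file stem (for example `0001.jpg` and `0001.png`); the source does not state the naming scheme, so adjust the matching rule if the dataset differs. `unpaired_files` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

def unpaired_files(images_dir, masks_dir):
    """Compare file stems in two folders.

    Returns (images_without_mask, masks_without_image) as sorted lists of stems.
    ASSUMPTION: an image and its mask share the same stem, e.g. 0001.jpg / 0001.png.
    """
    image_stems = {p.stem for p in Path(images_dir).iterdir() if p.is_file()}
    mask_stems = {p.stem for p in Path(masks_dir).iterdir() if p.is_file()}
    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)

# Usage in the notebook (paths from the layout described earlier):
# no_mask, no_image = unpaired_files("data/train/images", "data/train/masks")
# print(f"{len(no_mask)} images without a mask, {len(no_image)} masks without an image")
```

Whether unpaired files cause a failure depends on how `dataset_loader.py` builds its file list, which is not shown here.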


In [6]:
# ===== CELL 1: Setup and Dependencies =====
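# Install all required dependencies.
# Colab typically ships with PyTorch preinstalled, in which case pip keeps the existing build
# (the output below reports 2.8.0+cu126 even though the cu118 index is requested).
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118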
!pip install opencv-python pillow tqdm matplotlib seaborn
!pip install efficientnet-pytorch  # Required for GhanaSegNet backbone
!pip install segmentation-models-pytorch  # For DeepLabV3+ and other models

import torch
import os
print(f"🖥️  CUDA available: {torch.cuda.is_available()}")
print(f"🎮 GPU device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}")
print(f"📊 PyTorch version: {torch.__version__}")

# Verify EfficientNet installation
try:
    from efficientnet_pytorch import EfficientNet
    print("✅ EfficientNet-PyTorch installed successfully")
except ImportError:
    print("❌ EfficientNet-PyTorch not found - installing...")
    !pip install efficientnet-pytorch
    from efficientnet_pytorch import EfficientNet
    print("✅ EfficientNet-PyTorch installed successfully")

Looking in indexes: https://download.pytorch.org/whl/cu118
🖥️  CUDA available: True
🎮 GPU device: NVIDIA A100-SXM4-40GB
📊 PyTorch version: 2.8.0+cu126
✅ EfficientNet-PyTorch installed successfully


In [7]:
# ===== CELL 2: Navigate to Project =====
# The repo was cloned to /content/GhanaSegNet in an earlier cell; run that cell first.
os.chdir('/content/GhanaSegNet')
print(f"📁 Current directory: {os.getcwd()}")
print(f"📄 Files: {os.listdir('.')[:10]}")  # Show first 10 files

# Verify all required modules can be imported
print("\n🔍 Verifying project imports...")
try:
    from models.ghanasegnet import GhanaSegNet
    print("✅ GhanaSegNet model imported successfully")
except ImportError as e:
    print(f"❌ GhanaSegNet import failed: {e}")

try:
    from models.unet import UNet
    from models.deeplabv3plus import DeepLabV3Plus
    from models.segformer import SegFormerB0
    print("✅ All baseline models imported successfully")
except ImportError as e:
    print(f"❌ Baseline model import failed: {e}")

📁 Current directory: /content/GhanaSegNet
📄 Files: ['LICENSE', '.gitignore', 'notebooks', 'scripts', '.git', 'README.md', 'Pipfile', 'cleanup_repo.py', 'requirements.txt', 'GhanaSegNet_Colab.ipynb']

🔍 Verifying project imports...
✅ GhanaSegNet model imported successfully
✅ All baseline models imported successfully


In [None]:
# ===== CELL 3: Train GhanaSegNet =====
import subprocess

print("🚀 Starting GhanaSegNet training...")
print("🧠 Your novel hybrid CNN-Transformer architecture")
print("📊 Expected performance: >24% mIoU (based on previous results)")
subprocess.run(['python', 'scripts/train_baselines.py', '--model', 'ghanasegnet', '--epochs', '15'])


In [None]:
# ===== CELL 4: Train SegFormer (Best Transformer Baseline) =====
print("🚀 Starting SegFormer training...")
print("🤖 Pure Transformer architecture baseline")
print("📊 Expected performance: 18-23% mIoU")
subprocess.run(['python', 'scripts/train_baselines.py', '--model', 'segformer', '--epochs', '15'])


In [None]:
# ===== CELL 5: Train DeepLabV3+ (CNN State-of-the-art) =====
print("🚀 Starting DeepLabV3+ training...")
print("🔬 ResNet-50 backbone with atrous convolutions")
print("📊 Expected performance: 15-21% mIoU")
subprocess.run(['python', 'scripts/train_baselines.py', '--model', 'deeplabv3plus', '--epochs', '15'])

In [None]:
# ===== CELL 6: Train UNet (Medical Baseline) =====
print("🚀 Starting UNet training...")
print("🏥 Medical imaging architecture baseline")
print("📊 Expected performance: 12-18% mIoU")
subprocess.run(['python', 'scripts/train_baselines.py', '--model', 'unet', '--epochs', '15'])
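
The four training cells share one command shape and differ only in the `--model` value. They also ignore the exit code, and in a notebook the output of a child process started with `subprocess.run` may not show up in the cell. One possible variant is sketched below: it builds the command in one place, echoes the child's output line by line, and returns the exit code. `build_train_command` and `run_and_stream` are hypothetical helpers, not part of the repository; the script path and the `--model` / `--epochs` flags are the ones used in the cells above, and `-u` is Python's standard unbuffered-output flag.

```python
import subprocess
import sys

def build_train_command(model, epochs=15, script="scripts/train_baselines.py"):
    """Build the training command used by the cells above as an argument list."""
    # -u keeps the child's output unbuffered so progress lines appear promptly.
    return [sys.executable, "-u", script, "--model", model, "--epochs", str(epochs)]

def run_and_stream(cmd):
    """Run `cmd`, echo its combined stdout/stderr line by line, return the exit code."""
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so tracebacks are visible too
        text=True,
    )
    for line in proc.stdout:
        print(line, end="")
    return proc.wait()

# Usage in the notebook:
# code = run_and_stream(build_train_command("unet", epochs=15))
# if code != 0:
#     print(f"Training exited with code {code}")
```

Using `sys.executable` instead of the bare `python` name runs the script with the same interpreter as the notebook kernel, which avoids picking up a different environment.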


# Results Analysis

In [None]:

import json
import matplotlib.pyplot as plt
import numpy as np

# Load all results
models = ['ghanasegnet', 'segformer', 'deeplabv3plus', 'unet']
results = {}
print("📊 Loading results...")

for model in models:
    try:
        with open(f'checkpoints/{model}/{model}_results.json', 'r') as f:
            results[model] = json.load(f)
        print(f"✅ {model.upper()}: {results[model]['best_iou']*100:.2f}% mIoU")
    except FileNotFoundError:
        print(f"❌ {model.upper()}: Results not found")

# Create comprehensive comparison plot
if len(results) > 0:
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # Model names and colors
    model_names = [m.upper() for m in results.keys()]
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'][:len(results)]

    # 1. Performance comparison
    ious = [results[m]['best_iou'] * 100 for m in results.keys()]
    bars1 = ax1.bar(model_names, ious, color=colors)
    ax1.set_ylabel('mIoU (%)', fontsize=12)
    ax1.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')  # no emoji: the default Matplotlib font may lack the glyph
    ax1.set_ylim(0, max(ious) * 1.2)
    for i, v in enumerate(ious):
        ax1.text(i, v + 1, f'{v:.1f}%', ha='center', va='bottom', fontweight='bold')

    # 2. Parameter efficiency
    params = [results[m]['total_parameters'] / 1e6 for m in results.keys()]
    bars2 = ax2.bar(model_names, params, color=colors)
    ax2.set_ylabel('Parameters (Millions)', fontsize=12)
    ax2.set_title('Model Efficiency Comparison', fontsize=14, fontweight='bold')  # no emoji: the default Matplotlib font may lack the glyph
    for i, v in enumerate(params):
        ax2.text(i, v + 0.5, f'{v:.1f}M', ha='center', va='bottom', fontweight='bold')

    # 3. Training epochs (final)
    epochs = [results[m]['final_epoch'] for m in results.keys()]
    bars3 = ax3.bar(model_names, epochs, color=colors)
    ax3.set_ylabel('Training Epochs', fontsize=12)
    ax3.set_title('Early Stopping Analysis', fontsize=14, fontweight='bold')  # no emoji: the default Matplotlib font may lack the glyph
    for i, v in enumerate(epochs):
        ax3.text(i, v + 0.2, f'{v}', ha='center', va='bottom', fontweight='bold')

    # 4. Efficiency scatter plot
    ax4.scatter(params, ious, c=colors, s=200, alpha=0.7)
    for i, model in enumerate(model_names):
        ax4.annotate(model, (params[i], ious[i]), xytext=(5, 5),
                    textcoords='offset points', fontweight='bold')
    ax4.set_xlabel('Parameters (Millions)', fontsize=12)
    ax4.set_ylabel('mIoU (%)', fontsize=12)
    ax4.set_title('Efficiency vs Performance', fontsize=14, fontweight='bold')  # no emoji: the default Matplotlib font may lack the glyph
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary table
    print("\n" + "="*80)
    print("🏆 GHANASEGNET FAIR COMPARISON RESULTS")
    print("="*80)
    print(f"{'Model':<15} {'mIoU (%)':<10} {'Parameters':<12} {'Epochs':<8} {'Status'}")
    print("-"*80)

    best_iou = max(ious)
    for i, model in enumerate(results.keys()):
        status = "🥇 WINNER" if ious[i] == best_iou else ""
        # Format the parameter count with its "M" suffix first, then pad, so the columns line up
        print(f"{model.upper():<15} {ious[i]:<10.1f} {f'{params[i]:.1f}M':<12} {epochs[i]:<8} {status}")

    print("-"*80)
    if len(ious) > 1:
        sorted_ious = sorted(ious, reverse=True)
        if sorted_ious[1] > 0:  # avoid dividing by zero when the runner-up scored 0
            improvement = ((sorted_ious[0] - sorted_ious[1]) / sorted_ious[1]) * 100
            print(f"🚀 Best model relative improvement over 2nd best: {improvement:.1f}%")
    print(f"⚡ Most efficient model: {model_names[params.index(min(params))]} ({min(params):.1f}M params)")
    print("="*80)
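
The summary reports the gap between the two best models as a relative percentage, which is easy to confuse with a difference in mIoU percentage points. The sketch below returns both. `compare_top_two` is a hypothetical helper and the numbers in the example are made up for illustration; they are not results from this notebook.

```python
def compare_top_two(scores):
    """Rank models by score and compare the best with the runner-up.

    `scores` maps model name -> mIoU in percent.
    Returns (best, runner_up, gap_in_points, relative_gain_percent);
    relative_gain_percent is None when the runner-up scored 0.
    """
    if len(scores) < 2:
        raise ValueError("need at least two models to compare")
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    (best, best_score), (second, second_score) = ranked[0], ranked[1]
    gap = best_score - second_score
    relative = gap / second_score * 100 if second_score > 0 else None
    return best, second, gap, relative

# Hypothetical numbers, for illustration only (not results from this notebook)
example = {"model_a": 25.0, "model_b": 20.0, "model_c": 10.0}
print(compare_top_two(example))  # ('model_a', 'model_b', 5.0, 25.0)
```

In this example the best model leads by 5.0 percentage points, which is a 25.0% relative gain over the runner-up.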

# Download Results

In [None]:

# Create downloadable results package
import zipfile

def create_results_package():
    with zipfile.ZipFile('ghanasegnet_results.zip', 'w') as zipf:
        # Add model checkpoints
        for model in ['ghanasegnet', 'segformer', 'deeplabv3plus', 'unet']:
            try:
                zipf.write(f'checkpoints/{model}/best_{model}.pth')
                zipf.write(f'checkpoints/{model}/{model}_results.json')
                print(f"✅ Added {model} results to package")
            except FileNotFoundError:
                print(f"⚠️ {model} results not found")

        # Add training summary
        try:
            zipf.write('checkpoints/training_summary.json')
            print("✅ Added training summary")
        except FileNotFoundError:
            print("⚠️ Training summary not found")

        # Add updated training log
        try:
            zipf.write('Training_Results_Log.md')
            print("✅ Added training log")
        except FileNotFoundError:
            print("⚠️ Training log not found")

    print("📦 Results package created: ghanasegnet_results.zip")
    print("💾 Download this file to save your training results!")

create_results_package()

# Download command for Colab
from google.colab import files
files.download('ghanasegnet_results.zip')

print("🎉 Training completed! Check your downloads for the results package.")
print(f"📊 Total models trained: {len(results)}")
if results:
    best_model = max(results.keys(), key=lambda k: results[k]['best_iou'])
    print(f"🏆 Best performing model: {best_model.upper()} ({results[best_model]['best_iou']*100:.2f}% mIoU)")
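
A browser download can be lost if the Colab session disconnects first. Since Google Drive is already mounted in the first cell, copying the archive there is a possible fallback. This is a minimal sketch: `backup_file` is a hypothetical helper, and the destination folder name in the usage comment is an arbitrary choice, not something the repository defines.

```python
import shutil
from pathlib import Path

def backup_file(src, dest_dir):
    """Copy `src` into `dest_dir` (created if needed) and return the new path."""
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    return Path(shutil.copy2(src, dest_dir))

# Usage in the notebook; the destination folder name is a hypothetical choice:
# backup_file("ghanasegnet_results.zip", "/content/drive/MyDrive/GhanaSegNet_results")
```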