# Dendritic YOLOv8: Proper PAI Graph Output

This notebook trains YOLOv8n with PerforatedAI (PAI) dendrites and produces training graphs in the format shown in *Dendrite Recommendations.pdf*:
- Green line: Training scores
- Orange line: Validation scores  
- Blue/Red lines: What would have happened without dendrites
- Vertical blue bars: Where dendrites were added

## Setup
1. **Runtime → Change runtime type → GPU**
2. **Run the cells in order** — later cells depend on variables created by earlier ones.

In [None]:
# Cell 1: Install dependencies
import sys
print(f"Python: {sys.version}")

IN_COLAB = 'google.colab' in sys.modules
print(f"In Colab: {IN_COLAB}")

if IN_COLAB:
    print("Installing dependencies...")
    !pip install -q ultralytics matplotlib pandas
    !pip install -q perforatedai==3.0.7
    print("Done!")

In [None]:
# Cell 2: Imports and GPU check
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import os
import json
from ultralytics import YOLO
from ultralytics.data import build_dataloader, build_yolo_dataset

# PerforatedAI imports
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"GPU: {torch.cuda.get_device_name(0)}" if use_cuda else "Using CPU")

print(f"\nDevice: {device}")

In [None]:
# Cell 3: Helper functions
def count_parameters(model):
    """Return the total number of elements in ``model``'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def get_yolo_dataloader(yolo_model, data_yaml='coco128.yaml', batch_size=16, mode='train',
                        workers=4, imgsz=640):
    """
    Build a PyTorch dataloader using ultralytics' own dataset/dataloader builders.

    Args:
        yolo_model: An ``ultralytics.YOLO`` wrapper; its ``.model`` supplies the stride.
        data_yaml: Dataset YAML name/path, resolved by ``check_det_dataset``.
        batch_size: Images per batch.
        mode: ``'train'`` uses the train split and shuffles. Any other value uses
            the ``'val'`` split with rectangular batches and no shuffling.
        workers: Number of dataloader worker processes.
        imgsz: Image size stored in the dataset configuration.

    Returns:
        tuple: ``(dataloader, data_dict)``, where ``data_dict`` is the parsed
        dataset description returned by ``check_det_dataset``.
    """
    from ultralytics.cfg import get_cfg
    from ultralytics.data.utils import check_det_dataset
    from ultralytics.utils import DEFAULT_CFG

    data_dict = check_det_dataset(data_yaml)
    is_train = mode == 'train'
    split = 'train' if is_train else 'val'

    # build_yolo_dataset's first argument is the hyper-parameter namespace
    # (imgsz, cache, single_cls, task, ...), not the nn.Module. Passing the
    # model there fails as soon as the builder reads an attribute like imgsz.
    cfg = get_cfg(DEFAULT_CFG, overrides={'imgsz': imgsz})

    # Use the model's largest stride (at least 32) so image padding matches the network.
    model_stride = getattr(yolo_model.model, 'stride', None)
    stride = max(int(model_stride.max()), 32) if model_stride is not None else 32

    dataset = build_yolo_dataset(
        cfg,
        data_dict[split],
        batch=batch_size,
        data=data_dict,
        mode=split,
        rect=not is_train,  # rectangular batches only for evaluation
        stride=stride,
    )

    dataloader = build_dataloader(
        dataset,
        batch=batch_size,
        workers=workers,
        shuffle=is_train,
    )
    return dataloader, data_dict

print("Helper functions ready")

In [None]:
# Cell 4: Load YOLOv8n and initialize PerforatedAI
#
# This cell:
# - loads pretrained YOLOv8n and records its trainable-parameter count;
# - sets the PerforatedAI (PAI) global options;
# - wraps the detection model with UPA.initialize_pai so the PAI tracker can
#   restructure it during training.

print("Loading YOLOv8n...")
yolo = YOLO('yolov8n.pt')
model = yolo.model

baseline_params = count_parameters(model)
print(f"Baseline parameters: {baseline_params / 1e6:.2f}M")

print("\nConfiguring PerforatedAI...")
pai_config = GPA.pc
pai_config.set_testing_dendrite_capacity(False)
pai_config.set_verbose(True)
pai_config.set_dendrite_update_mode(True)

# Output directory for PAI artifacts (later cells also write results.json here).
os.makedirs('PAI', exist_ok=True)

# Keyword arguments for UPA.initialize_pai, gathered in one place.
# NOTE(review): later cells expect the graph at PAI/PAI.png; that path is
# derived from save_name by PAI. Confirm against the PAI documentation.
pai_init_kwargs = dict(
    doing_pai=True,
    save_name='PAI',
    maximizing_score=True,   # the score reported to PAI is mAP50, where higher is better
    making_graphs=True,
)
model = UPA.initialize_pai(model, **pai_init_kwargs).to(device)
yolo.model = model

dendritic_params = count_parameters(model)
print(f"\nDendritic parameters (initial): {dendritic_params / 1e6:.2f}M")

In [None]:
# Cell 5: Setup optimizer through PAI tracker
#
# The optimizer and scheduler classes are registered with the PAI tracker and
# then created through GPA.pai_tracker.setup_optimizer. The training loop can
# rebuild them the same way after the tracker restructures the model.
#
# NOTE(review): train_one_epoch() trains via yolo.train() and is never given
# this optimizer/scheduler. Verify they are actually used for training.

GPA.pai_tracker.set_optimizer(optim.Adam)
GPA.pai_tracker.set_scheduler(ReduceLROnPlateau)

lr = 0.001
# 'max' mode: a higher score is better, matching maximizing_score=True above.
schedArgs = dict(mode='max', patience=3, factor=0.5)
optimArgs = dict(params=model.parameters(), lr=lr)

optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

print("Optimizer and scheduler configured through PAI tracker")

In [None]:
# Cell 6: Training functions that track scores properly

def train_one_epoch(model, yolo, device, data='coco128.yaml', imgsz=640, batch=16):
    """
    Run one epoch of ultralytics training and return the resulting model and score.

    Args:
        model: Model to train; it is attached to ``yolo.model`` before training.
        yolo: The ``ultralytics.YOLO`` wrapper whose ``train()`` does the work.
        device: Device passed through to ``yolo.train``.
        data: Dataset YAML passed to ``yolo.train``.
        imgsz: Training image size.
        batch: Batch size.

    Returns:
        tuple: ``(model, score)``.
        - ``model`` is ``yolo.model`` after training.
        - ``score`` is the ``metrics/mAP50(B)`` value from the training results,
          as a 0-1 fraction; callers multiply it by 100 for percentages.
          It is 0.0 when unavailable.

    Note:
        The score comes from the validation pass ultralytics runs during
        training. Despite the name, it is a validation-split mAP50, not a
        metric computed on the training images.

    NOTE(review): ``yolo.train()`` appears to build its own trainer and
    optimizer, and may reload ``yolo.model`` from the saved checkpoint
    afterwards. Verify that the PAI-wrapped model passed in, and the optimizer
    from ``GPA.pai_tracker.setup_optimizer``, are what actually get trained and
    returned. Otherwise the dendrites added by PAI may never be trained.
    """
    yolo.model = model

    results = yolo.train(
        data=data,
        epochs=1,
        imgsz=imgsz,
        batch=batch,
        device=device,
        exist_ok=True,
        verbose=False,
        project='runs/train',
        name='dendritic',
    )

    # Hand back whatever model the wrapper holds after training. It may be a
    # different object from the one passed in.
    model = yolo.model

    # The training call may return no metrics object, or the metric may be
    # missing or None; report 0.0 instead of crashing the training loop.
    results_dict = getattr(results, 'results_dict', None) or {}
    train_map50 = float(results_dict.get('metrics/mAP50(B)') or 0.0)

    return model, train_map50


def validate_model(model, yolo, device, data='coco128.yaml'):
    """
    Validate ``model`` through ``yolo.val`` and return its box mAP50.

    Args:
        model: Model to evaluate; it is switched to eval mode and attached to
            ``yolo.model``.
        yolo: The ``ultralytics.YOLO`` wrapper whose ``val()`` runs validation.
        device: Device passed through to ``yolo.val``.
        data: Dataset YAML passed to ``yolo.val``.

    Returns:
        float: ``results.box.map50`` as a 0-1 fraction. Returns 0.0 when the
        value is missing, zero, or not finite.
    """
    import math

    model.eval()
    yolo.model = model

    results = yolo.val(
        data=data,
        device=device,
        verbose=False,
    )

    # A missing or NaN score would distort the values handed to the PAI
    # tracker, so report 0.0 in those cases.
    box = getattr(results, 'box', None)
    map50 = getattr(box, 'map50', None)
    if not map50:
        return 0.0
    val_map50 = float(map50)
    return val_map50 if math.isfinite(val_map50) else 0.0

print("Training functions defined")

In [None]:
# Cell 7: Main training loop with proper PAI integration
#
# Each epoch:
#   1. Train one epoch and report the training score with add_extra_score(...).
#      This is the green line.
#   2. Validate and report the validation score with add_validation_score(...).
#      This is the orange line. The tracker returns the (possibly restructured)
#      model and whether training is complete.
#   3. If the tracker restructured the model, rebuild the optimizer and
#      scheduler through the tracker.
#   4. Stop when the tracker reports completion or MAX_EPOCHS is reached.
#
# Scores go to the tracker as percentages (x100). `history` keeps the raw
# 0-1 mAP50 values.
#
# NOTE(review): the optimizer/scheduler rebuilt here are never passed to
# train_one_epoch(), which trains through yolo.train(). Confirm the PAI
# optimizer is actually used.

MAX_EPOCHS = 100  # hard cap; the loop normally ends earlier when PAI reports completion

print("Starting training with PerforatedAI integration...")
print("=" * 60)

epoch = 0
training_complete = False
history = {'train_scores': [], 'val_scores': [], 'params': []}

while epoch < MAX_EPOCHS and not training_complete:
    print(f"\nEpoch {epoch + 1}/{MAX_EPOCHS}")

    # Train, then report the training score as a percentage.
    model, train_score = train_one_epoch(model, yolo, device)
    GPA.pai_tracker.add_extra_score(train_score * 100, 'train')

    # Validate, then report the validation score. The tracker decides here
    # whether to restructure the model and whether training is finished.
    val_score = validate_model(model, yolo, device)
    model, restructured, training_complete = GPA.pai_tracker.add_validation_score(
        val_score * 100,
        model,
    )

    # The tracker may return a different module object. Move it to the
    # device and re-attach it to the YOLO wrapper.
    model = model.to(device)
    yolo.model = model

    if restructured:
        # NOTE(review): the message assumes every restructure adds dendrites;
        # confirm against the PAI documentation.
        print(">>> MODEL RESTRUCTURED - Dendrites added! <<<")
        # The parameter set changed, so the optimizer must be rebuilt over it.
        optimArgs['params'] = model.parameters()
        optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)
        print(f"    New parameter count: {count_parameters(model) / 1e6:.2f}M")

    current_params = count_parameters(model)
    for key, value in (
        ('train_scores', train_score),
        ('val_scores', val_score),
        ('params', current_params),
    ):
        history[key].append(value)

    print(f"  Train mAP50: {train_score:.4f}")
    print(f"  Val mAP50:   {val_score:.4f}")
    print(f"  Parameters:  {current_params / 1e6:.2f}M")

    if training_complete:
        print("\n" + "=" * 60)
        print("TRAINING COMPLETE - PerforatedAI optimization finished!")
        print("=" * 60)

    epoch += 1

print(f"\nFinished after {epoch} epochs")

In [None]:
# Cell 8: Save final graphs and results
#
# This cell:
# - asks the PAI tracker to write its graphs (best effort);
# - re-validates the final model;
# - prints a parameter/score summary and stores it in PAI/results.json.

print("Saving final graphs...")

# Graph saving is best effort. A failure is reported, but the summary below
# is still produced.
try:
    GPA.pai_tracker.save_graphs()
except Exception as e:
    print(f"Note: Graph saving encountered: {e}")
else:
    print("Graphs saved to PAI/PAI.png")

final_params = count_parameters(model)
final_val_score = validate_model(model, yolo, device)
param_change_pct = (final_params - baseline_params) / baseline_params * 100

banner = "=" * 60
print("\n" + banner)
print("FINAL RESULTS")
print(banner)
print(f"Baseline parameters:  {baseline_params / 1e6:.2f}M")
print(f"Final parameters:     {final_params / 1e6:.2f}M")
print(f"Parameter change:     {param_change_pct:+.1f}%")
print(f"Final validation mAP50: {final_val_score:.4f}")
print(banner)

# Parameter counts are stored in millions; scores are the raw 0-1 mAP50 values.
results = {
    'baseline': {'params_M': baseline_params / 1e6},
    'dendritic': {
        'params_M': final_params / 1e6,
        'final_val_mAP50': final_val_score,
    },
    'epochs_trained': epoch,
    'history': {
        'train_scores': history['train_scores'],
        'val_scores': history['val_scores'],
        'params': [count / 1e6 for count in history['params']],
    },
}

with open('PAI/results.json', 'w') as fh:
    json.dump(results, fh, indent=2)

print("\nResults saved to PAI/results.json")
print("Graph output at PAI/PAI.png")
print("Graph output at PAI/PAI.png")

In [None]:
# Cell 9: Display the generated graph
#
# Shows the graph image expected at PAI/PAI.png, followed by a short legend.
# Prints a hint instead when the file does not exist.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

graph_path = 'PAI/PAI.png'

if not os.path.exists(graph_path):
    print(f"Graph not found at {graph_path}")
    print("This may mean training didn't complete or PAI wasn't configured correctly.")
else:
    print("Displaying PAI output graph:")
    print("(This should match the format in Dendrite Recommendations.pdf)")
    print()

    graph_image = mpimg.imread(graph_path)
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(graph_image)
    ax.axis('off')
    fig.tight_layout()
    plt.show()

    legend = (
        "\nGraph interpretation:",
        "- Green line: Training scores",
        "- Orange line: Validation scores",
        "- Vertical blue/red bars: Epochs where dendrites were added",
        "- Blue/Red continuation lines: What would have happened without dendrites",
    )
    for legend_line in legend:
        print(legend_line)

In [None]:
# Cell 10: Download files
#
# On Colab, send the generated artifacts to the browser as downloads: the
# graph, the results JSON, and any CSV files found in PAI/. Elsewhere, just
# point at the PAI/ directory.
if IN_COLAB:
    from google.colab import files

    print("Downloading output files...")

    for artifact in ('PAI/PAI.png', 'PAI/results.json'):
        if os.path.exists(artifact):
            files.download(artifact)

    # Any CSV files written to PAI/ are downloaded as well.
    csv_names = [name for name in os.listdir('PAI') if name.endswith('.csv')]
    for csv_name in csv_names:
        files.download(f'PAI/{csv_name}')

    print("Done! Check your downloads.")
else:
    print("Files are available in the PAI/ directory")