# Dendritic YOLOv8 - Proper PAI Graph Generation

This notebook generates the **required PAI graph** using PerforatedAI's built-in graphing system.

## IMPORTANT: The PAI graph is generated automatically by the library!
- Do NOT create custom matplotlib graphs
- Use `add_extra_score()` for training scores (green line)
- Use `add_validation_score()` for validation scores (orange line); this call can also restructure the model (add dendrites) and signals when training is complete
- The graph is written to `PAI/PAI.png` when `GPA.pai_tracker.save_graphs()` is called (Cell 6)

## Setup
1. In Colab, select **Runtime → Change runtime type → GPU** (T4 or L4)
2. **Run all cells in order**

In [None]:
# Cell 1: Install dependencies
!pip install -q ultralytics
!pip install -q perforatedai==3.0.7

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 2: Imports
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from ultralytics import YOLO

# PerforatedAI imports - REQUIRED for proper graph generation
from perforatedai import globals_perforatedai as GPA
from perforatedai import utils_perforatedai as UPA

# Use CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Directory for PAI artifacts (later cells look for PAI/PAI.png and PAI/*.csv here).
os.makedirs("PAI", exist_ok=True)

In [None]:
# Cell 3: Load model and initialize PerforatedAI
# `yolo` is the Ultralytics wrapper object; `model` is the torch module it holds,
# which is what gets counted, wrapped by PerforatedAI, and moved to `device` below.
print("Loading YOLOv8n...")
yolo = YOLO('yolov8n.pt')
model = yolo.model

def count_params(m):
    """Return the total number of elements in ``m``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted, so frozen weights
    are excluded.
    """
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Record the trainable-parameter count of the unmodified model *before* PAI wraps
# it, so later cells can report how much the parameter count changed.
baseline_params = count_params(model)
print(f"Baseline parameters: {baseline_params:,}")

# Configure PerforatedAI global settings; these are set before initialize_pai()
# is called below. (Exact meaning of each flag: see PerforatedAI docs — TODO confirm.)
GPA.pc.set_testing_dendrite_capacity(False)
GPA.pc.set_verbose(True)
GPA.pc.set_dendrite_update_mode(True)

# Initialize PAI - wraps `model` for dendrite optimization and graph generation.
# save_name='PAI' is expected to place outputs under PAI/ (later cells look for
# PAI/PAI.png) — TODO confirm the exact path against PerforatedAI docs.
model = UPA.initialize_pai(
    model,
    save_name='PAI',
    maximizing_score=True,  # Higher is better: the scores reported to PAI are mAP
    making_graphs=True  # Enable graph generation
)

# Move the (wrapped) model to the selected device and hand it back to the
# Ultralytics wrapper so subsequent yolo.* calls see it.
model = model.to(device)
yolo.model = model
print("PerforatedAI initialized!")

In [None]:
# Cell 4: Register the optimizer/scheduler classes with the PAI tracker, then build them.
# NOTE(review): the `optimizer`/`scheduler` created here are never passed to
# `yolo.train()` in the training loop, which builds its own optimizer — TODO confirm
# whether PerforatedAI needs them wired into training.
GPA.pai_tracker.set_optimizer(optim.Adam)
GPA.pai_tracker.set_scheduler(ReduceLROnPlateau)

lr = 0.001
# Keyword arguments the tracker presumably forwards to Adam / ReduceLROnPlateau.
optimArgs = dict(params=model.parameters(), lr=lr)
schedArgs = dict(mode='max', patience=5)  # 'max': the monitored score is mAP
optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

print("Optimizer configured through PAI tracker")

In [None]:
# Cell 5: Training loop with PROPER PAI API calls
#
# CRITICAL for correct PAI graph:
# 1. add_extra_score(train_score, 'train') -> GREEN line
# 2. add_validation_score(val_score, model) -> ORANGE line + may trigger dendrites
# 3. Loop until training_complete is True (or MAX_EPOCHS is reached)
#
# The graph at PAI/PAI.png is written when save_graphs() is called (Cell 6).
#
# NOTE(review): Ultralytics' `YOLO.train()` builds its own trainer/optimizer and may
# reload the model from its checkpoint, so the PAI-wrapped `model` (and the optimizer
# from Cell 4) may not be what is actually trained each epoch — TODO confirm.

MAX_EPOCHS = 100  # Upper bound; stops earlier when PAI signals completion
DATA = 'coco128.yaml'
BATCH = 16
IMGSZ = 640

# Shared evaluation settings so train/val scores are measured identically and on
# the dataset configured above, instead of whatever settings `yolo` last used.
eval_kwargs = dict(data=DATA, imgsz=IMGSZ, batch=BATCH, device=device, verbose=False)

print("Starting training...")
print(f"Baseline parameters: {baseline_params:,}")
print("="*60)

epoch = 0
training_complete = False

while not training_complete and epoch < MAX_EPOCHS:
    print(f"\nEpoch {epoch + 1}/{MAX_EPOCHS}")

    # Train one epoch
    yolo.model = model
    yolo.train(
        data=DATA,
        epochs=1,
        imgsz=IMGSZ,
        batch=BATCH,
        device=device,
        exist_ok=True,
        verbose=False,
        project='runs/train',
        name='dendritic'
    )
    model = yolo.model

    # Training score: evaluate on the TRAIN split. The metrics returned by
    # `yolo.train()` come from its validator on the val split, so using them
    # would plot validation numbers as the green "train" line.
    yolo.model = model
    train_score = float(yolo.val(split='train', **eval_kwargs).box.map50)

    # CRITICAL: Add TRAINING score (creates GREEN line in PAI graph)
    GPA.pai_tracker.add_extra_score(train_score * 100, 'train')

    # Validation score on the VAL split
    val_score = float(yolo.val(split='val', **eval_kwargs).box.map50)

    print(f"  Train mAP50: {train_score:.4f}")
    print(f"  Val mAP50:   {val_score:.4f}")

    # CRITICAL: Add VALIDATION score (creates ORANGE line, may trigger dendrites)
    model, restructured, training_complete = GPA.pai_tracker.add_validation_score(
        val_score * 100, model
    )
    model = model.to(device)
    yolo.model = model

    # If restructured, the parameter set changed: rebuild optimizer/scheduler
    if restructured:
        print("\n>>> DENDRITES ADDED! <<<")
        print(f"    Parameters: {count_params(model):,}")
        optimArgs['params'] = model.parameters()
        optimizer, scheduler = GPA.pai_tracker.setup_optimizer(model, optimArgs, schedArgs)

    if training_complete:
        print("\n" + "="*60)
        print("TRAINING COMPLETE!")
        print("="*60)

    epoch += 1

print(f"\nFinished after {epoch} epochs")

In [None]:
# Cell 6: Save the PAI graphs and report final results
# save_graphs() writes the REQUIRED graph to PAI/PAI.png.
print("Saving PAI graphs...")
try:
    GPA.pai_tracker.save_graphs()
except Exception as e:
    # Best-effort: report the failure but still print the final metrics below.
    print(f"Error: {e}")
else:
    # Only claim success if the expected file actually exists.
    if os.path.exists('PAI/PAI.png'):
        print("SUCCESS! Graphs saved to PAI/PAI.png")
    else:
        print("WARNING: save_graphs() returned but PAI/PAI.png was not found")

# Final results, evaluated with the same dataset/settings used during training
final_params = count_params(model)
yolo.model = model
final_val = yolo.val(data=DATA, imgsz=IMGSZ, batch=BATCH, device=device, verbose=False)

print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Baseline parameters:  {baseline_params:,}")
print(f"Final parameters:     {final_params:,}")
print(f"Parameter change:     {((final_params - baseline_params) / baseline_params) * 100:+.1f}%")
print(f"Final mAP@0.5:        {final_val.box.map50:.4f}")
print("="*60)

In [None]:
# Cell 7: Display the PAI graph (this is the REQUIRED output)
from IPython.display import Image, display

_pai_png = 'PAI/PAI.png'
if not os.path.exists(_pai_png):
    # Nothing to show: the graph is only produced once save_graphs() has run.
    print("ERROR: PAI/PAI.png not found!")
    print("Make sure training completed and save_graphs() was called.")
else:
    print("PAI Output Graph (REQUIRED for submission):")
    print()
    display(Image(_pai_png))
    # Legend text; the blue/red description reflects PAI's graph conventions — TODO confirm.
    for _legend_line in (
        "\nThis graph shows:",
        "- Green line: Training scores",
        "- Orange line: Validation scores",
        "- Vertical bars: Dendrite addition epochs",
        "- Blue/Red lines: What would have happened without dendrites",
    ):
        print(_legend_line)

In [None]:
# Cell 8: Download the PAI graph (and any CSVs) when running in Google Colab.
# Only a missing `google.colab` module means "not in Colab"; download failures are
# reported per file instead of being silently swallowed by a bare `except:`.
try:
    from google.colab import files as colab_files
except ImportError:
    colab_files = None

if colab_files is None:
    print("Files are in PAI/ directory")
else:
    to_download = []
    if os.path.exists('PAI/PAI.png'):
        to_download.append('PAI/PAI.png')

    # Also download any CSVs (guarded in case the PAI/ directory is missing)
    if os.path.isdir('PAI'):
        to_download += [
            f'PAI/{name}' for name in sorted(os.listdir('PAI')) if name.endswith('.csv')
        ]

    for path in to_download:
        try:
            colab_files.download(path)
        except Exception as e:
            # Best-effort: one failed download should not stop the others.
            print(f"Could not download {path}: {e}")
        else:
            print(f"Downloaded {path}")