# PE-YOLO v2: YOLOv8n + PENet with Adaptive Fusion

**Changes from v1:**
- No beta residual scaling — PENet output goes directly to YOLO (matching author)
- Learnable AdaptiveFusion for HF/LF merging (improvement over author's concat+conv)
- GPU-native Sobel (author uses CPU cv2 loop)
- 3ch→3ch output — ALL pretrained weights load perfectly (no shape mismatch)

**Training Strategy:**
- Phase 1: Train PENet only (YOLO frozen), 30 epochs
- Phase 2: Full end-to-end fine-tuning, 80 epochs
- Phase 3: (Optional) Aggressive fine-tuning, 40 epochs

In [None]:
# Cell 1: Install the ultralytics package (the later cells import YOLO from it)
!pip install ultralytics --quiet

In [None]:
# Cell 2: Register PENet v2 Wrapper in Ultralytics
#
# Publishes the PENetWrapper class as an attribute of two Ultralytics modules,
# presumably so the model YAML can refer to it by name (confirm against the YAML).
import sys
sys.path.insert(0, 'penet_v2')  # Add penet_v2 folder to path

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules

# Same effect as writing into each module's __dict__, spelled with setattr.
for _namespace in (_tasks, _modules):
    setattr(_namespace, "PENetWrapper", PENetWrapper)
print("✅ PENetWrapper v2 registered in ultralytics.")

In [None]:
# Cell 3: Build model, load pretrained weights, SAVE checkpoint, then validate
#
# v2 design note: PENet emits a 3-channel image, so the first YOLO conv keeps a
# 3-channel input and the pretrained tensors are expected to match shape-for-shape.

import copy
import torch
from ultralytics import YOLO

# Prefer the GPU when one is visible to PyTorch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# 1) Instantiate the architecture described by the v2 YAML and move it to `device`
model = YOLO("penet_v2/yolov8-penet-v2.yaml")
model.model.to(device)


def _count_params(module):
    """Return the total number of scalar parameters held by `module`."""
    return sum(tensor.numel() for tensor in module.parameters())


# Layer 0 of the sequential stack is the PENet wrapper; report its share of the total
total_params = _count_params(model.model)
penet_params = _count_params(model.model.model[0])
print(f"PENet v2 params: {penet_params:,} ({100*penet_params/total_params:.1f}% of total {total_params:,})")

# 2) Load pretrained YOLOv8n weights with +1 index shift
#
# PENet sits at layer index 0 of the new model, so a pretrained tensor named
# "model.N.<rest>" is copied into "model.(N+1).<rest>". Every pretrained tensor
# ends up in exactly one of three buckets so nothing is dropped silently:
#   shifted   - key exists in the new model and the shapes agree (will be loaded)
#   skipped   - key exists in the new model but the shapes differ
#   unmatched - key has no counterpart in the new model at all
PRETRAINED_PATH = "runs/detect/train12/weights/best.pt"
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so only
# point PRETRAINED_PATH at a checkpoint you trust.
pretrained = torch.load(PRETRAINED_PATH, weights_only=False, map_location=device)
pretrained_sd = pretrained["model"].state_dict()
model_sd = model.model.state_dict()

shifted = {}
skipped = []
unmatched = []
for key, val in pretrained_sd.items():
    parts = key.split(".")
    # Only keys of the form "model.<int>.<rest>" can be shifted.
    if len(parts) < 3 or parts[0] != "model" or not parts[1].isdigit():
        unmatched.append(key)
        continue
    new_key = ".".join(["model", str(int(parts[1]) + 1), *parts[2:]])
    if new_key not in model_sd:
        unmatched.append(key)
    elif model_sd[new_key].shape == val.shape:
        shifted[new_key] = val
    else:
        skipped.append(f"{new_key}: {val.shape} → {model_sd[new_key].shape}")

# strict=False because the PENet tensors (model.0.*) are absent from the checkpoint.
load_result = model.model.load_state_dict(shifted, strict=False)
# Any missing key outside model.0.* is a YOLO tensor that kept its initial value.
uninitialised = [k for k in load_result.missing_keys if not k.startswith("model.0.")]

print(f"\n✅ Loaded {len(shifted)} pretrained weights (shifted +1 for PENet)")
if skipped:
    print(f"⚠️  Skipped {len(skipped)} shape-mismatched weights:")
    for s in skipped:
        print(f"   {s}")
if unmatched:
    print(f"⚠️  {len(unmatched)} pretrained tensors have no matching key in the new model:")
    for k in unmatched:
        print(f"   {k}")
if uninitialised:
    print(f"⚠️  {len(uninitialised)} YOLO tensors received no pretrained value:")
    for k in uninitialised:
        print(f"   {k}")
if not (skipped or unmatched or uninitialised):
    print("✅ Zero shape mismatches — all YOLO weights loaded perfectly!")

# 3) SAVE BEFORE val() (val fuses Conv+BN in-place, destroys BN weights)
#
# The checkpoint is a dict with a "model" entry (the pickled module) and a
# "train_args" entry; the Phase 1 cell loads this file through YOLO(...).
INIT_CKPT = "penet_v2/penet_v2_yolov8n_init.pt"
torch.save({
    # Deep copy first: nn.Module.half() converts the module it is called on,
    # and the live model is still needed in full precision for val() below.
    "model": copy.deepcopy(model.model).half(),
    "train_args": {
        "task": "detect",
        "data": "exdark.yaml",
        "imgsz": 640,
        "model": "penet_v2/yolov8-penet-v2.yaml",
    },
}, INIT_CKPT)
print(f"✅ Saved UNFUSED model to {INIT_CKPT}")

# 4) Validate to check starting mAP
# NOTE: mAP will be LOWER than baseline because PENet has random weights
# and is distorting the image. This is expected!
print("\n── Validation (PENet has random weights, expect lower mAP) ──")
metrics = model.val(data="exdark.yaml", imgsz=640, batch=48, workers=8)
# metrics.box.map50 / metrics.box.map are reported as mAP50 / mAP50-95 below
print(f"mAP50:    {metrics.box.map50:.4f}  (baseline was 0.669)")
print(f"mAP50-95: {metrics.box.map:.4f}")

In [None]:
# Cell 4: PHASE 1 — Train PENet only (YOLO frozen)
# PENet learns to enhance images in a way that helps YOLO detect

import sys
sys.path.insert(0, 'penet_v2')

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules
_tasks.__dict__["PENetWrapper"] = PENetWrapper
_modules.__dict__["PENetWrapper"] = PENetWrapper

import torch
from ultralytics import YOLO

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

model = YOLO("penet_v2/penet_v2_yolov8n_init.pt")

# Layer 0 is the PENet wrapper; every later index in the layer stack is YOLO.
# The Ultralytics trainer switches requires_grad back on for every layer that is
# not named in its `freeze` argument, so flipping the flags by hand is not enough
# to keep YOLO frozen during train(); the layer indices must be passed explicitly.
frozen_layers = list(range(1, len(model.model.model)))
# NOTE(review): the trainer matches "model.<idx>." as a substring of parameter
# names; check the "Freezing layer ..." lines in the training log to confirm that
# no PENet (model.0.*) parameter is caught by one of these indices.

# Mirror the freeze on the in-memory model so the summary below reflects Phase 1.
for name, param in model.model.named_parameters():
    param.requires_grad = name.startswith("model.0.")

trainable = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.model.parameters())
print(f"Phase 1 — trainable: {trainable:,} / {total:,} params")

results_p1 = model.train(
    data="exdark.yaml",
    epochs=30,
    imgsz=640,
    batch=48,
    lr0=1e-3,
    lrf=0.1,
    cos_lr=True,
    warmup_epochs=3,
    optimizer="AdamW",
    weight_decay=0.01,
    mosaic=0.5,
    mixup=0.0,
    hsv_h=0.015,
    hsv_s=0.4,
    hsv_v=0.3,
    workers=8,
    cache="ram",
    amp=True,
    pretrained=False,
    freeze=frozen_layers,  # this is what actually keeps YOLO frozen
    project="runs/penet_v2",
    name="phase1",
)

In [None]:
# Cell 5: Quick diagnostic — what did PENet learn?
#
# Feeds one random image through the PENet wrapper (layer 0 of the Phase 1
# checkpoint), reports how far the output moves from the input, then prints
# summary statistics of each adaptive-fusion gate.

import sys
sys.path.insert(0, 'penet_v2')

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules
_tasks.__dict__["PENetWrapper"] = PENetWrapper
_modules.__dict__["PENetWrapper"] = PENetWrapper

import torch
from ultralytics import YOLO

model = YOLO("runs/penet_v2/phase1/weights/best.pt")
penet_wrapper = model.model.model[0]
device = next(model.model.parameters()).device

# Test with random input (uniform in [0, 1), same device as the model)
dummy = torch.rand(1, 3, 640, 640, device=device)
with torch.no_grad():
    enhanced = penet_wrapper(dummy)

diff = (enhanced - dummy).abs()
mean_change = diff.mean().item()  # computed once, reused by every report line below
print("--- PENet v2 Enhancement Analysis ---")
print(f"Input  — mean: {dummy.mean():.4f}, std: {dummy.std():.4f}")
print(f"Output — mean: {enhanced.mean():.4f}, std: {enhanced.std():.4f}")
print(f"Mean pixel change:  {mean_change:.6f}")
print(f"Max pixel change:   {diff.max():.6f}")
print(f"% pixels changed > 0.01: {(diff > 0.01).float().mean() * 100:.1f}%")
print(f"% pixels changed > 0.05: {(diff > 0.05).float().mean() * 100:.1f}%")

if mean_change < 1e-4:
    print("\n⚠️  PENet is near-identity — Phase 2 will push it further")
elif mean_change < 0.01:
    print(f"\n✅ Subtle enhancement (mean change: {mean_change:.4f})")
else:
    print(f"\n🔥 Strong enhancement (mean change: {mean_change:.4f}) — PENet is active!")

# Check adaptive fusion gate statistics
penet = penet_wrapper.model
for i, ae in enumerate(penet.AEs):
    gate_layer = ae.adaptive_fusion.gate[-2]  # last conv before sigmoid
    gate_weight = gate_layer.weight
    gate_bias = gate_layer.bias
    # A layer built without a bias exposes bias=None; report that instead of crashing.
    bias_text = "n/a" if gate_bias is None else f"{gate_bias.mean():.4f}"
    print(f"AE_{i} fusion gate — bias mean: {bias_text}, weight std: {gate_weight.std():.4f}")

In [None]:
# Cell 6: PHASE 2 — Full end-to-end fine-tuning
# Starts from the best Phase 1 checkpoint and trains PENet and YOLO together.

import sys
sys.path.insert(0, 'penet_v2')

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules
_tasks.__dict__["PENetWrapper"] = PENetWrapper
_modules.__dict__["PENetWrapper"] = PENetWrapper

import torch
from ultralytics import YOLO

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

model = YOLO("runs/penet_v2/phase1/weights/best.pt")

# Unfreeze everything (no `freeze` argument is passed to train() below)
for param in model.model.parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
print(f"Phase 2 — all {trainable:,} params trainable")

results_p2 = model.train(
    data="exdark.yaml",
    epochs=80,
    imgsz=640,
    batch=48,
    lr0=5e-4,
    lrf=0.01,
    cos_lr=True,
    warmup_epochs=3,
    optimizer="AdamW",
    weight_decay=0.005,
    mosaic=0.5,
    mixup=0.1,
    hsv_h=0.015,
    hsv_s=0.4,
    hsv_v=0.3,
    degrees=5.0,
    translate=0.1,
    scale=0.3,
    workers=8,
    cache="ram",
    amp=True,
    pretrained=False,
    patience=20,
    project="runs/penet_v2",
    name="phase2",
)

In [None]:
# Cell 7: Evaluate final model
# Validates the best Phase 2 checkpoint and compares it with the baseline scores.

import sys
sys.path.insert(0, 'penet_v2')

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules
_tasks.__dict__["PENetWrapper"] = PENetWrapper
_modules.__dict__["PENetWrapper"] = PENetWrapper

from ultralytics import YOLO

# Baseline scores quoted in the report; defined once so the printed reference
# values and the comparison below cannot drift apart.
BASELINE_MAP50 = 0.669
BASELINE_MAP50_95 = 0.420

final = YOLO("runs/penet_v2/phase2/weights/best.pt")
metrics = final.val(data="exdark.yaml", imgsz=640, batch=48, workers=8)
map50 = metrics.box.map50
rule = "=" * 50
print(f"\n{rule}")
print("PE-YOLO v2 Final Results")
print(rule)
print(f"mAP50:    {map50:.4f}  (baseline: {BASELINE_MAP50:.3f})")
print(f"mAP50-95: {metrics.box.map:.4f}  (baseline: {BASELINE_MAP50_95:.3f})")
print(rule)
# The difference is printed in percentage points of mAP50.
if map50 > BASELINE_MAP50:
    print(f"🎉 PENet v2 IMPROVED over baseline by +{(map50 - BASELINE_MAP50)*100:.1f}%")
else:
    print(f"⚠️  Below baseline by {(BASELINE_MAP50 - map50)*100:.1f}%")

In [None]:
# Cell 8: Visualize PENet v2 enhancement on real dark images
#
# For up to six images under images/val, shows a three-panel row: the resized
# original, the PENet output, and the absolute difference amplified 10x.
# The figure is saved to penet_v2/penet_v2_visualization.png.

import sys
sys.path.insert(0, 'penet_v2')

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules
_tasks.__dict__["PENetWrapper"] = PENetWrapper
_modules.__dict__["PENetWrapper"] = PENetWrapper

import torch, cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import YOLO

# Load best model (change path to phase1/phase2 as needed)
model = YOLO("runs/penet_v2/phase2/weights/best.pt")
penet_wrapper = model.model.model[0]
device = next(model.model.parameters()).device

# Find images
img_dir = Path("images/val")
img_paths = sorted(p for p in img_dir.rglob("*") if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp"})
sample_paths = img_paths[:6]
print(f"Found {len(img_paths)} images, showing {len(sample_paths)}")

# Fail with a clear message instead of an opaque error from plt.subplots(0, 3).
if not sample_paths:
    raise FileNotFoundError(f"No .jpg/.jpeg/.png/.bmp images found under {img_dir}")

# squeeze=False keeps `axes` two-dimensional even when only one image is shown.
fig, axes = plt.subplots(len(sample_paths), 3, figsize=(15, 5 * len(sample_paths)), squeeze=False)

for idx, img_path in enumerate(sample_paths):
    img_bgr = cv2.imread(str(img_path))
    if img_bgr is None:
        # cv2.imread returns None for unreadable files; leave this row blank.
        print(f"⚠️  Could not read {img_path}, skipping")
        for ax in axes[idx]:
            ax.axis("off")
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_resized = cv2.resize(img_rgb, (640, 640))
    # HWC uint8 image -> 1xCxHxW float tensor scaled to [0, 1]
    img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        enhanced_tensor = penet_wrapper(img_tensor)

    diff_tensor = (enhanced_tensor - img_tensor).abs()
    original = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    enhanced = enhanced_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1)
    diff = diff_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    diff_vis = (diff * 10).clip(0, 1)  # amplify so small changes become visible

    panels = (
        (original, f"Original: {img_path.name}"),
        (enhanced, "PENet v2 Enhanced"),
        (diff_vis, f"Difference (10x) | mean: {diff.mean():.4f}"),
    )
    for ax, (image, title) in zip(axes[idx], panels):
        ax.imshow(image)
        ax.set_title(title, fontsize=10)
        ax.axis("off")

plt.suptitle("PENet v2 Enhancement (Adaptive Fusion)", fontsize=16, fontweight="bold")
plt.tight_layout()
plt.savefig("penet_v2/penet_v2_visualization.png", dpi=150, bbox_inches="tight")
plt.show()
print("✅ Saved visualization")

In [None]:
# Cell 9: (Optional) PHASE 3 - Aggressive fine-tuning
# Continues from the best Phase 2 checkpoint with every parameter trainable.

import sys
sys.path.insert(0, 'penet_v2')

from PENet_v2 import PENetWrapper
import ultralytics.nn.tasks as _tasks
import ultralytics.nn.modules as _modules

# Publish the wrapper class under its own name in both Ultralytics modules
for _ns in (_tasks, _modules):
    setattr(_ns, "PENetWrapper", PENetWrapper)

import torch
from ultralytics import YOLO

# TF32 kernels for matmul and cuDNN, plus cuDNN autotuning
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

model = YOLO("runs/penet_v2/phase2/weights/best.pt")
for _p in model.model.parameters():
    _p.requires_grad = True

phase3_args = {
    # dataset / schedule
    "data": "exdark.yaml",
    "epochs": 40,
    "imgsz": 640,
    "batch": 48,
    # optimisation
    "lr0": 1e-4,
    "lrf": 0.01,
    "cos_lr": True,
    "warmup_epochs": 2,
    "optimizer": "AdamW",
    "weight_decay": 0.001,
    # augmentation
    "mosaic": 1.0,
    "mixup": 0.2,
    "copy_paste": 0.1,
    "hsv_h": 0.02,
    "hsv_s": 0.5,
    "hsv_v": 0.4,
    "degrees": 10.0,
    "translate": 0.15,
    "scale": 0.5,
    # runtime
    "workers": 8,
    "cache": "ram",
    "amp": True,
    "pretrained": False,
    "patience": 15,
    # output location
    "project": "runs/penet_v2",
    "name": "phase3",
}
results_p3 = model.train(**phase3_args)