# Training a UNet Model for Solar Radio Burst Segmentation using PyTorch

In this notebook, we demonstrate how to train a UNet model for segmenting solar radio bursts using transfer learning. 
The training is split into two phases:

1. **Phase 1:** Freeze the encoder (pre-trained on ImageNet) and train only the decoder.  
2. **Phase 2:** Unfreeze the encoder and fine-tune the entire model using a lower learning rate.
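In plain PyTorch, the two phases come down to toggling `requires_grad` on the encoder and rebuilding the optimizer. The sketch below is illustrative only and is not the notebook's `train_model`. `TinySegNet` is a hypothetical model with `encoder`/`decoder` attributes, and the 10× lower phase-2 learning rate is an assumption.

```python
import torch
from torch import nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for a segmentation model with .encoder and .decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU())
        self.decoder = nn.Sequential(nn.Conv2d(8, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, x):
        return self.decoder(self.encoder(x))


def set_encoder_trainable(model, trainable):
    # Freezing = stop computing gradients for the encoder parameters
    for p in model.encoder.parameters():
        p.requires_grad = trainable


def make_optimizer(model, lr):
    # Hand the optimizer only the parameters that currently require gradients
    return torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)


demo_model = TinySegNet()
initial_lr = 1e-3

# Phase 1: encoder frozen, decoder trained at the initial learning rate
set_encoder_trainable(demo_model, False)
opt_phase1 = make_optimizer(demo_model, initial_lr)
# ... run phase-1 epochs with opt_phase1 ...

# Phase 2: unfreeze the encoder and fine-tune everything at a lower rate
set_encoder_trainable(demo_model, True)
opt_phase2 = make_optimizer(demo_model, initial_lr / 10)  # 10x lower is an assumption
# ... run phase-2 epochs with opt_phase2 ...
```

Rebuilding the optimizer at the phase boundary ensures that phase 2 actually updates the encoder parameters, which were excluded from the phase-1 optimizer.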

We use a combined loss consisting of binary cross-entropy (BCE) and Jaccard (IoU) loss, and we monitor the IoU and F1 metrics on a validation set.
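A minimal sketch of a BCE + soft-IoU loss and thresholded IoU/F1 metrics follows. It assumes the model outputs probabilities (consistent with the sigmoid-range output reported when the model is built) and uses the `(y_true, y_pred)` argument order that `simple_combined_loss` is called with later in this notebook. The actual formulation in `train_utils_old` is not shown here and may differ, for example in smoothing or weighting.

```python
import torch
import torch.nn.functional as F


def bce_soft_iou_loss(y_true, y_pred, eps=1e-7):
    """BCE plus (1 - soft IoU), averaged over the batch; y_pred holds probabilities."""
    y_pred = y_pred.clamp(eps, 1 - eps)
    bce = F.binary_cross_entropy(y_pred, y_true)
    # Soft IoU per sample, computed on probabilities so it stays differentiable
    inter = (y_true * y_pred).sum(dim=(1, 2, 3))
    union = (y_true + y_pred - y_true * y_pred).sum(dim=(1, 2, 3))
    soft_iou = (inter + eps) / (union + eps)
    return bce + (1 - soft_iou).mean()


def iou_f1(y_true, y_pred, threshold=0.5, eps=1e-7):
    """IoU and F1 over the whole batch after thresholding the probabilities."""
    pred = (y_pred > threshold).float()
    tp = (pred * y_true).sum()
    fp = (pred * (1 - y_true)).sum()
    fn = ((1 - pred) * y_true).sum()
    iou = (tp + eps) / (tp + fp + fn + eps)
    f1 = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return iou.item(), f1.item()


y = torch.zeros(2, 1, 8, 8)
y[:, :, 2:5, 2:5] = 1.0        # a 3x3 burst in each sample
pred = y.clone()
pred[:, :, 5:8, 2:5] = 1.0     # plus a 3x3 false positive just below it
print(iou_f1(y, pred))         # IoU 0.5, F1 about 0.667
```

The loss uses soft (probability-based) overlap so that it is differentiable. The metrics threshold at 0.5 because they report what a binary mask would score.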

## 🎯 Enhanced Loss Function Setup

This notebook also supports an enhanced loss function system with:
- **Focal Loss**: For handling class imbalance (radio bursts are rare)
- **Boundary Loss**: For improving edge detection accuracy
- **Adaptive IoU Loss**: For better overlap quality assessment

Predefined configurations of these losses are provided for radio burst data, including an `imbalanced` scenario for sparse positive samples.
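For reference, the standard binary focal loss is $\mathrm{FL} = -\alpha_t (1 - p_t)^\gamma \log p_t$, where $p_t$ is the predicted probability of the true class at each pixel. The factor $(1 - p_t)^\gamma$ down-weights pixels that are already well classified, which in this data means mostly the abundant background. The sketch below assumes probability inputs, and its defaults `alpha=0.8, gamma=2.5` mirror the values passed to `focal_loss` in the comparison cell at the end of this notebook. It is not the `train_utils` implementation. The boundary and adaptive IoU losses are not sketched because their definitions are not shown in this notebook.

```python
import torch


def binary_focal_loss(y_pred, y_true, alpha=0.8, gamma=2.5, eps=1e-7):
    """Mean binary focal loss; y_pred are probabilities, y_true is a 0/1 mask."""
    p = y_pred.clamp(eps, 1 - eps)
    # p_t: predicted probability of the true class at each pixel
    p_t = torch.where(y_true == 1, p, 1 - p)
    # alpha_t: alpha weights positive (burst) pixels, 1 - alpha weights background
    alpha_t = torch.where(y_true == 1, torch.full_like(p, alpha), torch.full_like(p, 1 - alpha))
    # (1 - p_t)^gamma shrinks the contribution of well-classified pixels
    return (-alpha_t * (1 - p_t) ** gamma * torch.log(p_t)).mean()
```

With `gamma=0` and `alpha=0.5`, this reduces to half the ordinary BCE. Larger `gamma` focuses training on the hard pixels.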


In [1]:

import os
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# 🎯 Import enhanced loss function utilities  
# from train_utils import (
#     create_dataset, build_unet, freeze_encoder_weights, unfreeze_encoder_weights,
#     combined_loss, simple_combined_loss, focal_loss, boundary_loss, adaptive_iou_loss,
#     compute_metrics, train_one_epoch, validate_one_epoch, adjust_learning_rate, 
#     save_checkpoint, train_model
# )
# from loss_config_example import get_loss_config_for_scenario
# from loss_tuner import LossTuner, quick_tune

# print("🎯 Enhanced loss functions loaded successfully!")

# # Show available loss configurations
# configs = ["balanced", "imbalanced", "noisy", "boundary_critical"]
# print("📋 Available loss configurations:")
# for config in configs:
#     desc = get_loss_config_for_scenario(config)["description"]
#     print(f"  • {config}: {desc}")

from train_utils_old import (
    create_dataset,
    build_unet,
    build_deeplabv3,      
    build_deeplabv3_rgb_to_mono,
    build_model_by_name,
    train_model,          
    simple_combined_loss
)



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)




Using device: cpu


## 2. Data Loading and Preprocessing

We assume that the image and mask CSV files are stored in a single directory and named as follows:

- Burst slices have an image file and a matching mask file with a `_mask` suffix, e.g. `slice_20240608_y155_x270406_SkylineHS.csv` and `slice_20240608_y155_x270406_SkylineHS_mask.csv`.
- Non-burst slices carry a `_nonburst` suffix, e.g. `slice_20240420_y0_x3391_PeachMountain_2020_nonburst.csv`.

The `create_dataset` function reads the files, normalizes them to [0, 1], and splits the data into training and validation sets.
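`create_dataset` itself is not shown in this notebook. As a hypothetical sketch of two of the steps it describes, the helpers below pair each burst slice with its `_mask` file by name and min-max normalize an array to [0, 1]. Treating non-burst slices as having an all-zero mask is an assumption, as is per-array min-max scaling. Reading the CSV contents and resizing to 256×256 are left out.

```python
import numpy as np


def pair_slices(filenames):
    """Hypothetical helper: return (image_file, mask_file or None) pairs by naming convention."""
    names = set(filenames)
    pairs = []
    for name in sorted(names):
        if not name.endswith(".csv") or name.endswith("_mask.csv"):
            continue  # masks are attached to their image file below
        if name.endswith("_nonburst.csv"):
            pairs.append((name, None))  # assumption: caller substitutes an all-zero mask
            continue
        mask_name = name[: -len(".csv")] + "_mask.csv"
        if mask_name in names:
            pairs.append((name, mask_name))
        # burst images without a mask file are skipped
    return pairs


def minmax_normalize(arr, eps=1e-8):
    """Scale an array to [0, 1]; eps avoids division by zero for constant arrays."""
    arr = np.asarray(arr, dtype=np.float32)
    lo, hi = arr.min(), arr.max()
    return (arr - lo) / (hi - lo + eps)


files = [
    "slice_20240608_y155_x270406_SkylineHS.csv",
    "slice_20240608_y155_x270406_SkylineHS_mask.csv",
    "slice_20240420_y0_x3391_PeachMountain_2020_nonburst.csv",
]
for image_file, mask_file in pair_slices(files):
    print(image_file, "->", mask_file)
```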

In [2]:
data_dir = '/Users/remiliascarlet/Desktop/MDP/transfer_learning/burst_data/csv/saved_slices/25spring'

(train_images, train_masks), (val_images, val_masks) = create_dataset(data_dir, img_size=(256, 256), test_size=0.2, random_state=42)

print("Training images shape:", train_images.shape)
print("Validation images shape:", val_images.shape)

Training images shape: (1171, 256, 256, 1)
Validation images shape: (293, 256, 256, 1)


## 3. Create DataLoaders

Convert the numpy arrays into PyTorch tensors and create DataLoaders.

We also need to permute dimensions from (N, H, W, C) to (N, C, H, W).
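A small check of what the permutation does, using a random array in place of the real data (the batch size of 4 is arbitrary):

```python
import numpy as np
import torch

# Toy batch in the (N, H, W, C) layout returned by create_dataset
images_nhwc = np.random.rand(4, 256, 256, 1).astype(np.float32)

# permute reorders the axes to (N, C, H, W) without copying the data
images_nchw = torch.from_numpy(images_nhwc).permute(0, 3, 1, 2)

print(images_nchw.shape)  # torch.Size([4, 1, 256, 256])
# Same pixel, new index order: [n, h, w, c] -> [n, c, h, w]
assert images_nchw[2, 0, 10, 20].item() == images_nhwc[2, 10, 20, 0]
```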

In [3]:
import torch

# Convert NumPy arrays to float32 tensors (the model weights are float32)
# and permute from (N, H, W, C) to (N, C, H, W)
train_images_tensor = torch.as_tensor(train_images, dtype=torch.float32).permute(0, 3, 1, 2)
train_masks_tensor  = torch.as_tensor(train_masks, dtype=torch.float32).permute(0, 3, 1, 2)

val_images_tensor = torch.as_tensor(val_images, dtype=torch.float32).permute(0, 3, 1, 2)
val_masks_tensor  = torch.as_tensor(val_masks, dtype=torch.float32).permute(0, 3, 1, 2)

# Create TensorDatasets
train_dataset = TensorDataset(train_images_tensor, train_masks_tensor)
val_dataset   = TensorDataset(val_images_tensor, val_masks_tensor)

# Create DataLoaders; pinned memory only speeds up host-to-GPU copies,
# so enable it only when training on CUDA
batch_size = 16
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

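## 4. Model Construction

The model builders in `train_utils_old` are built on the segmentation_models_pytorch library. The commented-out variants build a UNet (`build_unet`) with or without ImageNet weights, or a DeepLabV3+ (`build_deeplabv3`) trained from scratch. The active code builds a DeepLabV3+ with a ResNet34 encoder using `build_deeplabv3_rgb_to_mono`. This loads the ImageNet-pretrained RGB model and converts its first convolution to a single input channel using luminance weighting.

In all cases we set `input_shape=(256, 256, 1)` for grayscale images and `num_classes=1` for binary segmentation.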
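The log from the model-building cell reports a "luminance" conversion of the first convolution weights from `[64, 3, 7, 7]` to `[64, 1, 7, 7]`. The sketch below shows such a conversion under the assumption that it uses the ITU-R BT.601 luma coefficients (0.299, 0.587, 0.114). The coefficients actually used by `build_deeplabv3_rgb_to_mono` are not shown in this notebook and may differ.

```python
import torch
from torch import nn

# Assumed ITU-R BT.601 luma coefficients for R, G, B
LUMA_BT601 = (0.299, 0.587, 0.114)


def rgb_conv_to_mono(weight_rgb, coeffs=LUMA_BT601):
    """Collapse (out, 3, kh, kw) conv weights to (out, 1, kh, kw) by luminance weighting."""
    if weight_rgb.dim() != 4 or weight_rgb.shape[1] != 3:
        raise ValueError(f"expected (out, 3, kh, kw) weights, got {tuple(weight_rgb.shape)}")
    c = torch.tensor(coeffs, dtype=weight_rgb.dtype, device=weight_rgb.device).view(1, 3, 1, 1)
    return (weight_rgb * c).sum(dim=1, keepdim=True)


# Same shapes as in the log: ResNet34 conv1 has 64 filters of size 7x7
w_rgb = torch.randn(64, 3, 7, 7)
w_mono = rgb_conv_to_mono(w_rgb)
print(w_mono.shape)  # torch.Size([64, 1, 7, 7])

# Load the converted weights into a single-channel conv like encoder.conv1
conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    conv1.weight.copy_(w_mono)
```

Because convolution is linear in its input, the mono filter applied to a grayscale image $x$ gives the same response as the original RGB filter applied to $(0.299x, 0.587x, 0.114x)$.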

In [4]:
# Import segmentation_models_pytorch library if not already installed: pip install segmentation-models-pytorch
# model = build_unet(input_shape=(256, 256, 1), num_classes=1, encoder_weights='imagenet')
# model.to(device)

# Build UNet from scratch (no ImageNet)
# model = build_unet(
#     input_shape=(256, 256, 1), 
#     num_classes=1, 
#     encoder_weights=None        # 🔥 Key: No ImageNet pretraining
# )

# model.to(device)

# model = build_deeplabv3(
#     input_shape=(256, 256, 1), 
#     num_classes=1, 
#     encoder_weights=None,        # 🔥 Key: No ImageNet pretraining
#     encoder_name='resnet34'      # Can change to 'resnet18' for faster training
# )

# model.to(device)

model = build_deeplabv3_rgb_to_mono(
    input_shape=(256, 256, 1),
    num_classes=1,
    conversion_method='luminance',  # Use human perception weights
    encoder_name='resnet34'         # Can try 'resnet18' for faster training
)

model.to(device)

print(model)

🚀 Building DeepLabV3+ with RGB-to-mono conversion
   Conversion method: luminance
   Encoder: resnet34
📥 Loading ImageNet pretrained RGB model...
🎯 Creating target single channel model...
🔄 Converting RGB first layer weights using 'luminance' method...
   RGB weights shape: torch.Size([64, 3, 7, 7])
   Applied luminance-weighted conversion
   Converted weights shape: torch.Size([64, 1, 7, 7])
🔗 Transferring weights...
   ✅ Converted: encoder.conv1.weight
📊 Weight transfer summary:
   ✅ Converted layers: 1
   ✅ Transferred layers: 276
   ⚠️ Skipped layers: 0
🔍 Validating converted model...
✅ Validation successful:
   Input shape: torch.Size([1, 1, 256, 256])
   Output shape: torch.Size([1, 1, 256, 256])
   Output range: [0.4927, 0.5008]
   Total parameters: 22,431,185
🎉 RGB-to-mono conversion completed successfully!
   Model ready for training on radio burst data
DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias

## Simple Loss Function

In [None]:
# 🚀 DeepLabV3+ Training with Simple Loss Function
print("🚀 Starting DeepLabV3+ training with simple loss...")
print("   Using BCE + IoU loss (no complex focal/boundary loss)")
print("   This focuses on pure architecture comparison")

# Simple training configuration for DeepLabV3+
training_config_simple = {
    'initial_lr': 1e-3,       # Standard learning rate
    'freeze_epochs': 50,      # Phase 1: train only the decoder with the ImageNet-pretrained encoder frozen
    'total_epochs': 100,       # Can be increased for better results
    'patience': 5,           
    'checkpoint_dir': './checkpoints_deeplabv3_rgb_convert'
}

print("✅ Simple configuration ready:")
print(f"   Learning rate: {training_config_simple['initial_lr']}")
print(f"   Training epochs: {training_config_simple['total_epochs']}")
print(f"   Encoder frozen for the first {training_config_simple['freeze_epochs']} epochs")
print(f"   Using simple BCE + IoU loss")

# Train the model with simple loss function
trained_model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    initial_lr=training_config_simple['initial_lr'],
    freeze_epochs=training_config_simple['freeze_epochs'],  # 50: encoder frozen during phase 1
    total_epochs=training_config_simple['total_epochs'],
    checkpoint_dir=training_config_simple['checkpoint_dir'],
    patience=training_config_simple['patience'],
    device=device
    # Note: no loss_weights, focal_params, or other advanced parameters;
    # simple_combined_loss (BCE + IoU) is used automatically
)

print("✅ Training completed!")

# 📊 Quick Results Analysis
if history:
    final_metrics = history[-1]
    final_f1 = final_metrics.get('val_f1', 0)
    final_iou = final_metrics.get('val_iou', 0)
    
    print(f"\n🎯 DeepLabV3+ + Simple Loss Results:")
    print(f"   Final F1 Score: {final_f1:.4f}")
    print(f"   Final IoU Score: {final_iou:.4f}")
    
    # Compare with your baseline
    baseline_f1 = 0.3  # F1 of the earlier UNet + ImageNet run
    improvement = final_f1 - baseline_f1
    
    print(f"\n📈 vs UNet+ImageNet baseline:")
    print(f"   Baseline F1: {baseline_f1:.4f}")
    print(f"   DeepLabV3+ F1: {final_f1:.4f}")
    print(f"   Change: {improvement:+.4f} ({improvement / baseline_f1 * 100:+.1f}%)")
    
    if final_f1 > 0.4:
        print("   🎉 Excellent! Clear architecture improvement")
    elif final_f1 > 0.35:
        print("   👍 Good improvement! DeepLabV3+ is working")
    else:
        print("   🤔 Need investigation - try different backbone or more epochs")
else:
    print("❌ No training history available")

🚀 Starting DeepLabV3+ training with simple loss...
   Using BCE + IoU loss (no complex focal/boundary loss)
   This focuses on pure architecture comparison
✅ Simple configuration ready:
   Learning rate: 0.001
   Training epochs: 100
   No encoder freezing (training from scratch)
   Using simple BCE + IoU loss
🔧 Using SIMPLE loss function (BCE + IoU only)
   No focal loss, boundary loss, or other enhancements
   Good for clean model architecture comparison
Phase 1: Freezing encoder and training decoder for 50 epochs.
Epoch 1/50 - Train Loss: 1.1220 - Val Loss: 1.0423 - IOU: 0.2383 - F1: 0.2862
Checkpoint saved at epoch 1 with best metric 1.0423162893245095 -> ./checkpoints_deeplabv3_rgb_convert/checkpoint_epoch_1_metric_1.0423.pth
Epoch 2/50 - Train Loss: 1.0268 - Val Loss: 1.0102 - IOU: 0.2437 - F1: 0.2895
Checkpoint saved at epoch 2 with best metric 1.0102182313015586 -> ./checkpoints_deeplabv3_rgb_convert/checkpoint_epoch_2_metric_1.0102.pth
Epoch 3/50 - Train Loss: 0.9725 - Val Los

## 🔧 Enhanced Training with Loss Function Tuning

When `train_model` is imported from the enhanced `train_utils` module (rather than `train_utils_old`, which uses the simple BCE + IoU loss), training uses the enhanced loss functions by default. You can also customize the loss parameters for your data.
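The `loss_weights` dictionaries used in this notebook (for example `{'focal': 1.5, 'iou': 1.2, 'boundary': 0.3}`) suggest a weighted sum of named loss terms. Assuming that is how they are combined (the `train_utils` code is not shown here), a hypothetical helper could look like this:

```python
import torch


def weighted_total_loss(components, weights):
    """Hypothetical: combine named loss terms as sum(weight * term)."""
    missing = set(weights) - set(components)
    if missing:
        raise KeyError(f"no loss value for: {sorted(missing)}")
    return sum(w * components[name] for name, w in weights.items())


components = {
    "focal": torch.tensor(0.2),
    "iou": torch.tensor(0.5),
    "boundary": torch.tensor(1.0),
}
weights = {"focal": 1.5, "iou": 1.2, "boundary": 0.3}
print(weighted_total_loss(components, weights))  # 1.5*0.2 + 1.2*0.5 + 0.3*1.0, about 1.2
```

Raising a weight makes the optimizer trade more of the other terms for that one. For example, a larger `boundary` weight puts more emphasis on edge accuracy.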


In [None]:
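# 🎯 Loss Function Configuration Options
# get_loss_config_for_scenario comes from loss_config_example (commented out in the first cell)
from loss_config_example import get_loss_config_for_scenario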
print("🎯 Enhanced Loss Function Configuration")
print("="*50)

# Option 1: Use predefined scenarios (RECOMMENDED)
print("📋 Option 1: Predefined Scenarios")
scenarios = {
    "imbalanced": "Sparse positive samples (recommended for radio bursts)",  
    "balanced": "Roughly balanced positive/negative samples",
    "noisy": "Significant noise and artifacts in data",
    "boundary_critical": "High precision required for edge detection",
    "original": "Use simple BCE+IoU loss (backward compatibility)"
}

for scenario, description in scenarios.items():
    config = get_loss_config_for_scenario(scenario)
    print(f"  • {scenario}: {description}")
    if not config.get('use_simple_loss', False):
        print(f"    Weights: {config['loss_weights']}")
        print(f"    Focal: {config['focal_params']}")

print()

# Option 2: Custom parameters
print("🔧 Option 2: Custom Parameters")
print("  You can specify custom loss_weights and focal_params in train_model()")
print("  Example:")
print("    loss_weights = {'focal': 1.5, 'iou': 1.2, 'boundary': 0.3}")
print("    focal_params = {'alpha': 0.85, 'gamma': 3.0}")

print()

# Option 3: Auto-tuning based on validation metrics
print("📊 Option 3: Auto-tuning (Advanced)")
print("  Use LossTuner to get suggestions based on validation metrics")
example_metrics = {'precision': 0.85, 'recall': 0.45, 'f1': 0.59, 'iou': 0.42}
print(f"  Example: {example_metrics}")

# Uncomment to use auto-tuning:
# from loss_tuner import LossTuner
# tuner = LossTuner()
# suggested = tuner.suggest_parameters(example_metrics)
# print(f"  Suggested: {suggested}")

In [None]:
# 🎯 DeepLabV3+ Training Configuration
print("⚙️ Configuring training parameters for DeepLabV3+...")

# Training parameters optimized for DeepLabV3+
training_config_deeplabv3 = {
    'initial_lr': 1e-3,          # Standard learning rate
    'freeze_epochs': 0,          # No frozen-encoder phase: all layers train from the first epoch
    'total_epochs': 50,          # Can be increased for better results
    'patience': 15,              # More patience than the simple run (5)
    'loss_config': 'imbalanced'  # Predefined configuration for sparse positive samples
}

print("✅ Configuration ready:")
print(f"   Loss config: {training_config_deeplabv3['loss_config']}")
print(f"   Training epochs: {training_config_deeplabv3['total_epochs']}")
print("   No encoder freezing (all layers trained from the start)")

In [None]:
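# Enhanced Training with Configurable Loss Parameters
# The loss_config / loss_weights / focal_params options below assume the enhanced train_model
# from train_utils (commented out in the first cell); the train_utils_old version imported
# there uses the simple BCE + IoU loss.
from train_utils import train_model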
# initial_lr = 1e-3
# freeze_epochs = 100
# total_epochs = 150
# patience = 10
# checkpoint_dir = './checkpoints_enhanced'

print("🚀 Starting Enhanced Training with Configurable Loss...")
print("="*60)

# Choose your training method:

# # Method A: Use predefined scenario (RECOMMENDED)
# loss_config = "imbalanced"  # Choose: "balanced", "imbalanced", "noisy", "boundary_critical", "original"

# train_model(model, train_loader, val_loader, 
#            initial_lr=initial_lr,
#            freeze_epochs=freeze_epochs, 
#            total_epochs=total_epochs,
#            checkpoint_dir=checkpoint_dir, 
#            patience=patience, 
#            device=device,
#            loss_config=loss_config)

# Method B: Use custom parameters (ADVANCED)
# custom_loss_weights = {'focal': 1.5, 'iou': 1.2, 'boundary': 0.3}
# custom_focal_params = {'alpha': 0.85, 'gamma': 3.0}
# 
# train_model(model, train_loader, val_loader, 
#            initial_lr=initial_lr,
#            freeze_epochs=freeze_epochs, 
#            total_epochs=total_epochs,
#            checkpoint_dir=checkpoint_dir, 
#            patience=patience, 
#            device=device,
#            loss_weights=custom_loss_weights,
#            focal_params=custom_focal_params)

# Train with enhanced loss configuration
trained_model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    initial_lr=training_config_deeplabv3['initial_lr'],
    freeze_epochs=training_config_deeplabv3['freeze_epochs'],  # 0 for DeepLabV3+
    total_epochs=training_config_deeplabv3['total_epochs'],
    checkpoint_dir='./checkpoints_deeplabv3',  # Separate directory
    patience=training_config_deeplabv3['patience'],
    device=device,
    loss_config=training_config_deeplabv3['loss_config']  # Use imbalanced config
)

# Method C: Custom DeepLabV3+ parameters (ADVANCED)
# custom_loss_weights = {'focal': 0.8, 'iou': 1.0, 'boundary': 0.2}  # Adjust weights
# custom_focal_params = {'alpha': 0.85, 'gamma': 2.5}                # Fine-tune focal loss
# custom_backbone = 'resnet18'  # 'resnet18', 'resnet34', 'efficientnet-b0'
# 
# # Rebuild model with custom backbone if needed
# # model = build_deeplabv3(input_shape=(256, 256, 1), num_classes=1,
# #                         encoder_weights=None, encoder_name=custom_backbone)
# # model.to(device)
# 
# trained_model, history = train_model(
#     model=model,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     initial_lr=training_config_deeplabv3['initial_lr'],
#     freeze_epochs=training_config_deeplabv3['freeze_epochs'],  # 0 for DeepLabV3+
#     total_epochs=training_config_deeplabv3['total_epochs'],
#     checkpoint_dir='./checkpoints_deeplabv3',  # Separate directory
#     patience=training_config_deeplabv3['patience'],
#     device=device,
#     loss_weights=custom_loss_weights,
#     focal_params=custom_focal_params
# )

## 📊 Enhanced vs Original Loss Comparison

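After training, load the best saved checkpoint and evaluate the model on the validation set.

You can also compute the original BCE+IoU loss on the same predictions. Keep in mind that the enhanced and original losses are different functions with different scales, so the gap between their values is not a measure of improvement. IoU and F1 are the metrics to compare, ideally for models trained with each loss.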
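The training log saves checkpoints with the epoch and the monitored metric in the filename (for example `checkpoint_epoch_2_metric_1.0102.pth`), and that metric is the validation loss, so lower is better. Assuming that naming pattern holds, a hypothetical helper like this one can fill in the placeholder checkpoint path:

```python
import re
from pathlib import Path

# Matches names like checkpoint_epoch_2_metric_1.0102.pth
CKPT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)_metric_([0-9.]+)\.pth$")


def best_checkpoint(paths):
    """Hypothetical: return the checkpoint whose filename encodes the lowest metric."""
    scored = []
    for path in map(Path, paths):
        match = CKPT_PATTERN.search(path.name)
        if match:
            scored.append((float(match.group(2)), int(match.group(1)), path))
    if not scored:
        raise FileNotFoundError("no files match checkpoint_epoch_<n>_metric_<value>.pth")
    return min(scored)[2]  # lowest metric; ties broken by earlier epoch


# Usage in the notebook, e.g.:
# best_enhanced_path = best_checkpoint(Path('./checkpoints_deeplabv3').glob('*.pth'))
```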


In [None]:
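# Load the best enhanced model and compare it with the original loss
# validate_one_epoch and the individual losses come from the enhanced train_utils module
from train_utils import validate_one_epoch, focal_loss, boundary_loss, adaptive_iou_loss

print("🔍 Comparing Enhanced vs Original Loss Functions")
print("="*50)

# Path to the best enhanced checkpoint (the enhanced run above saved to './checkpoints_deeplabv3')
best_enhanced_path = './checkpoints_enhanced/[your_best_checkpoint].pth'  # Update this path
# checkpoint = torch.load(best_enhanced_path, map_location=device)
# model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate the model currently in memory (uncomment the two lines above to load the checkpoint first)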
print("📊 Enhanced Loss Results:")
val_loss_enhanced, val_metrics_enhanced = validate_one_epoch(model, val_loader, device)
print(f"  Validation Loss: {val_loss_enhanced:.4f}")
print(f"  IoU: {val_metrics_enhanced['iou']:.4f}")
print(f"  F1:  {val_metrics_enhanced['f1']:.4f}")

# For comparison, evaluate with original loss
print("\n📊 Original Loss (BCE+IoU) for comparison:")
model.eval()
total_original_loss = 0.0
num_batches = 0

with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        preds = model(x)
        # Use original simple loss
        original_loss = simple_combined_loss(y, preds)
        total_original_loss += original_loss.item()
        num_batches += 1

avg_original_loss = total_original_loss / num_batches
print(f"  Original Loss: {avg_original_loss:.4f}")
print(f"  IoU: {val_metrics_enhanced['iou']:.4f} (same model)")
print(f"  F1:  {val_metrics_enhanced['f1']:.4f} (same model)")

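# Difference between the two loss values. They come from different loss functions
# with different scales, so this is not a measure of model improvement.
difference = avg_original_loss - val_loss_enhanced
difference_pct = difference / avg_original_loss * 100
print(f"\n🎯 Loss difference (original - enhanced): {difference:.4f} ({difference_pct:+.1f}%)")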

# Individual loss component analysis
print("\n🧩 Enhanced Loss Component Breakdown:")
with torch.no_grad():
    sample_x, sample_y = next(iter(val_loader))
    sample_x, sample_y = sample_x.to(device), sample_y.to(device)
    sample_preds = model(sample_x)
    
    focal_val = focal_loss(sample_preds, sample_y, alpha=0.8, gamma=2.5)
    iou_val = adaptive_iou_loss(sample_preds, sample_y, power=1.5)
    boundary_val = boundary_loss(sample_preds, sample_y)
    
    print(f"  Focal Loss:    {focal_val.item():.4f}")
    print(f"  IoU Loss:      {iou_val.item():.4f}")
    print(f"  Boundary Loss: {boundary_val.item():.4f}")
