# Training a UNet Model for Solar Radio Burst Segmentation using PyTorch

In this notebook, we demonstrate how to train a UNet model for segmenting solar radio bursts using transfer learning. 
The training is split into two phases:

1. **Phase 1:** Freeze the encoder (pre-trained on ImageNet) and train only the decoder.  
2. **Phase 2:** Unfreeze the encoder and fine-tune the entire model using a lower learning rate.

We use a combined loss function consisting of binary cross-entropy (BCE) loss and a Jaccard (IOU) loss, and we monitor the IOU and F1 metrics on a validation set.
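The actual `combined_loss` lives in `train_utils` and is not shown in this notebook. As a rough illustration of the idea, the sketch below adds a soft Jaccard term to BCE. The function names `jaccard_loss` and `bce_jaccard_loss`, the equal weighting, and the assumption that the model outputs raw logits are ours, not taken from `train_utils`.

```python
import torch
import torch.nn.functional as F


def jaccard_loss(logits, targets, eps=1e-7):
    """Soft Jaccard (IoU) loss: 1 - |P ∩ T| / |P ∪ T|, averaged over samples."""
    probs = torch.sigmoid(logits)
    dims = (1, 2, 3)  # sum over channel, height, width
    intersection = (probs * targets).sum(dim=dims)
    union = probs.sum(dim=dims) + targets.sum(dim=dims) - intersection
    return (1.0 - (intersection + eps) / (union + eps)).mean()


def bce_jaccard_loss(logits, targets, bce_weight=1.0, jaccard_weight=1.0):
    """Weighted sum of BCE-with-logits and soft Jaccard loss (weights are assumptions)."""
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    return bce_weight * bce + jaccard_weight * jaccard_loss(logits, targets)


# Tiny check: a confident, correct prediction should score far lower than a wrong one.
target = torch.zeros(2, 1, 4, 4)
target[:, :, :2, :] = 1.0
good_logits = (target * 2 - 1) * 10.0   # +10 on burst pixels, -10 elsewhere
bad_logits = -good_logits
print("good:", bce_jaccard_loss(good_logits, target).item())
print("bad: ", bce_jaccard_loss(bad_logits, target).item())
```

The BCE term gives per-pixel gradients, while the Jaccard term rewards overlap directly. The overlap term is useful when burst pixels are a small fraction of each slice.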

1. Import Libraries

In [1]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from train_utils import (
    create_dataset, build_unet, freeze_encoder_weights, unfreeze_encoder_weights,
    combined_loss, compute_metrics, train_one_epoch, validate_one_epoch,
    adjust_learning_rate, save_checkpoint, train_model,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


2. Data Loading and Preprocessing

We assume that image and mask CSV files are stored in a single directory.
The naming convention is as follows:

- For burst slices:
       slice_20240608_y155_x270406_SkylineHS.csv
       slice_20240608_y155_x270406_SkylineHS_mask.csv

- For non-burst slices:
       slice_20240420_y0_x3391_PeachMountain_2020_nonburst.csv

The `create_dataset` function reads the files, normalizes them to [0, 1], and splits the data into training and validation sets.
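The internals of `create_dataset` are not shown here. As a sketch of how one slice could be loaded under the naming convention above, the helper below reads a CSV, min-max normalizes it, and pairs it with its mask. The names `min_max_normalize` and `load_slice` are hypothetical. We also assume comma-delimited files without headers, per-slice normalization, and all-zero masks for non-burst slices.

```python
import os
import tempfile

import numpy as np


def min_max_normalize(arr):
    """Scale an array to [0, 1]; a constant array maps to all zeros."""
    arr = arr.astype(np.float32)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


def load_slice(csv_path):
    """Return (image, mask) for one slice file, following the naming convention."""
    image = min_max_normalize(np.loadtxt(csv_path, delimiter=","))
    stem = csv_path[: -len(".csv")]
    if stem.endswith("_nonburst"):
        mask = np.zeros_like(image)  # assumption: non-burst slices have empty masks
    else:
        mask = (np.loadtxt(stem + "_mask.csv", delimiter=",") > 0).astype(np.float32)
    return image, mask


# Demo on synthetic files written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    burst = os.path.join(tmp, "slice_demo_burst.csv")
    np.savetxt(burst, np.arange(16).reshape(4, 4), delimiter=",")
    np.savetxt(burst[:-4] + "_mask.csv", np.eye(4), delimiter=",")
    quiet = os.path.join(tmp, "slice_demo_nonburst.csv")
    np.savetxt(quiet, np.ones((4, 4)), delimiter=",")

    burst_image, burst_mask = load_slice(burst)
    quiet_image, quiet_mask = load_slice(quiet)

print(burst_image.min(), burst_image.max(), burst_mask.sum(), quiet_mask.sum())
```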

In [2]:
data_dir = '/Users/remiliascarlet/Desktop/MDP/transfer_learning/burst_data/csv/saved_slices/finished'

(train_images, train_masks), (val_images, val_masks) = create_dataset(data_dir, img_size=(256, 256), test_size=0.2, random_state=42)

print("Training images shape:", train_images.shape)
print("Validation images shape:", val_images.shape)

Training images shape: (1171, 256, 256, 1)
Validation images shape: (293, 256, 256, 1)


3. Create DataLoaders

Convert the NumPy arrays into PyTorch tensors and wrap them in DataLoaders.

PyTorch convolution layers expect channels-first input, so we also permute the dimensions from (N, H, W, C) to (N, C, H, W).

In [3]:
import torch

# Convert numpy arrays to Torch tensors and permute to (N, channels, H, W)
train_images_tensor = torch.tensor(train_images).permute(0, 3, 1, 2)
train_masks_tensor  = torch.tensor(train_masks).permute(0, 3, 1, 2)

val_images_tensor = torch.tensor(val_images).permute(0, 3, 1, 2)
val_masks_tensor  = torch.tensor(val_masks).permute(0, 3, 1, 2)

# Create TensorDatasets
train_dataset = TensorDataset(train_images_tensor, train_masks_tensor)
val_dataset   = TensorDataset(val_images_tensor, val_masks_tensor)

# Create DataLoaders
batch_size = 16
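# Pinned host memory only speeds up transfers to a GPU, so enable it only when CUDA is in use.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)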

4. Model Construction

Build a UNet model using the `build_unet` helper imported from `train_utils`, which builds on the segmentation_models_pytorch library.

Here we set `input_shape=(256,256,1)` for grayscale images, `num_classes=1` for binary segmentation,
and use pre-trained ImageNet weights with a ResNet34 backbone.

In [4]:
# Requires segmentation_models_pytorch; if it is not installed: pip install segmentation-models-pytorch
model = build_unet(input_shape=(256, 256, 1), num_classes=1, encoder_weights='imagenet')
model.to(device)
print(model)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /Users/remiliascarlet/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth
100%|██████████| 83.3M/83.3M [00:03<00:00, 28.7MB/s]


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

5. Training the Model

We train the model using a two-phase process:

- **Phase 1:** The encoder is frozen for a number of epochs, training only the decoder.
- **Phase 2:** The encoder is unfrozen, the learning rate is reduced, and the entire model is fine-tuned.

Early stopping is applied if no improvement is observed for a given number of epochs.
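The schedule described above can be sketched on a toy model. This is not the implementation of `train_model`. The helper `set_encoder_trainable`, the `EarlyStopping` class, the `TinyNet` stand-in, the 0.1 learning-rate factor, and the validation losses are all made up for illustration.

```python
import torch
from torch import nn


def set_encoder_trainable(model, trainable):
    """Freeze or unfreeze every parameter under `model.encoder`."""
    for param in model.encoder.parameters():
        param.requires_grad = trainable


class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without a lower loss."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


class TinyNet(nn.Module):
    """Stand-in with the same encoder/decoder split as a UNet."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(1, 4, 3, padding=1)
        self.decoder = nn.Conv2d(4, 1, 3, padding=1)

    def forward(self, x):
        return self.decoder(torch.relu(self.encoder(x)))


model = TinyNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
stopper = EarlyStopping(patience=3)
freeze_epochs = 2
fake_val_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.6]  # made-up values
history = []

set_encoder_trainable(model, False)            # Phase 1: decoder only
for epoch, val_loss in enumerate(fake_val_losses, start=1):
    if epoch == freeze_epochs + 1:             # Phase 2: fine-tune everything
        set_encoder_trainable(model, True)
        for group in optimizer.param_groups:
            group["lr"] *= 0.1                 # assumed reduction factor
    # ... a real loop would train and validate here ...
    encoder_trainable = model.encoder.weight.requires_grad
    history.append((epoch, encoder_trainable, optimizer.param_groups[0]["lr"]))
    if stopper.step(val_loss):
        print(f"Early stopping at epoch {epoch}")
        break
```

With these made-up losses, the best value (0.7) occurs at epoch 3, and the three following epochs fail to beat it, so the loop stops at epoch 6 without reaching the final value.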

In [6]:
initial_lr = 1e-3
freeze_epochs = 10    # Adjust as needed.
total_epochs = 20
patience = 3
checkpoint_dir = './checkpoints'

# Train the model using the `train_model` function imported from train_utils.
train_model(model, train_loader, val_loader, initial_lr=initial_lr,
            freeze_epochs=freeze_epochs, total_epochs=total_epochs,
            checkpoint_dir=checkpoint_dir, patience=patience, device=device)

Phase 1: Freezing encoder and training decoder only.
Epoch 1/10 - Train Loss: 1.1909 - Val Loss: 1.0954 - IOU: 0.2082 - F1: 0.2572
Checkpoint saved at epoch 1 with best metric 1.0953744963595742 -> ./checkpoints/checkpoint_epoch_1_metric_1.0954.pth


KeyboardInterrupt: 

6. Load and Evaluate the Best Model

After training, load the best saved checkpoint and evaluate the model on the validation set.
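For reference, a checkpoint round trip can be sketched as follows. Only the `model_state_dict` key is confirmed by the loading cell in this section. The `epoch` field, the file name, and the small convolution standing in for the UNet are assumptions, and `save_checkpoint` in `train_utils` may store more.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # stand-in for the UNet

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint_demo.pth")
    # Assumed layout: a plain dict holding the weights plus some bookkeeping.
    torch.save({"epoch": 1, "model_state_dict": model.state_dict()}, path)

    restored = nn.Conv2d(1, 1, kernel_size=3, padding=1)
    checkpoint = torch.load(path, map_location="cpu")
    restored.load_state_dict(checkpoint["model_state_dict"])

restored.eval()  # switch layers such as BatchNorm and dropout to inference mode
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print("weights match:", same)
```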

In [None]:
checkpoint_path = 'xxxxx'  # placeholder: set this to the best checkpoint file saved during training
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print("Loaded best model from checkpoint.")

# Evaluate the model on the validation DataLoader
val_loss, val_metrics = validate_one_epoch(model, val_loader, device)
print(f"Validation Loss: {val_loss:.4f}, IOU: {val_metrics['iou']:.4f}, F1: {val_metrics['f1']:.4f}")
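The IOU and F1 values come from `validate_one_epoch` and `compute_metrics`, whose code is not shown in this notebook. A minimal sketch of how such pixel-wise metrics are commonly computed is given below. The function name `iou_and_f1`, the 0.5 threshold, and pooling counts over the whole batch are assumptions.

```python
import numpy as np


def iou_and_f1(probs, targets, threshold=0.5, eps=1e-7):
    """Pixel-wise IoU and F1 (Dice) for binary masks, pooled over the batch."""
    preds = probs >= threshold
    truth = targets >= 0.5
    tp = np.logical_and(preds, truth).sum()    # predicted burst, is burst
    fp = np.logical_and(preds, ~truth).sum()   # predicted burst, is background
    fn = np.logical_and(~preds, truth).sum()   # missed burst pixels
    iou = tp / (tp + fp + fn + eps)
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return float(iou), float(f1)


probs = np.array([[0.9, 0.8, 0.2, 0.1]])
targets = np.array([[1.0, 0.0, 1.0, 0.0]])
iou, f1 = iou_and_f1(probs, targets)
print(f"IoU: {iou:.4f}, F1: {f1:.4f}")   # IoU: 0.3333, F1: 0.5000
```

The two metrics are linked by F1 = 2·IoU / (1 + IoU) when computed from the same counts, so they always rank models in the same order. F1 is never lower than IoU.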