# Training a UNet Model for Solar Radio Burst Segmentation using PyTorch

In this notebook, we demonstrate how to train a UNet model for segmenting solar radio bursts using transfer learning. 
The training is split into two phases:

1. **Phase 1:** Freeze the encoder (pre-trained on ImageNet) and train only the decoder.  
2. **Phase 2:** Unfreeze the encoder and fine-tune the entire model using a lower learning rate.
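The two phases above can be sketched on a toy model. This is a minimal illustration and not the notebook's `freeze_encoder_weights` / `unfreeze_encoder_weights` helpers: `TinySegNet`, `set_encoder_trainable`, `make_optimizer`, and both learning rates are hypothetical, and the toy encoder is randomly initialized rather than pre-trained on ImageNet.

```python
import torch
from torch import nn


class TinySegNet(nn.Module):
    """Toy stand-in for a UNet (hypothetical): exposes `encoder` and `decoder`."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())
        self.decoder = nn.Conv2d(8, 1, 1)

    def forward(self, x):
        return self.decoder(self.encoder(x))


def set_encoder_trainable(model, trainable):
    """Freeze (False) or unfreeze (True) every encoder parameter."""
    for param in model.encoder.parameters():
        param.requires_grad = trainable


def make_optimizer(model, lr):
    """Build an optimizer over the parameters that currently require gradients."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable, lr=lr)


model = TinySegNet()

# Phase 1: encoder frozen, only the decoder is optimized.
set_encoder_trainable(model, False)
opt_phase1 = make_optimizer(model, lr=1e-3)  # assumed learning rate

# Phase 2: encoder unfrozen, whole model fine-tuned at a lower rate.
set_encoder_trainable(model, True)
opt_phase2 = make_optimizer(model, lr=1e-4)  # assumed lower learning rate
```

Rebuilding the optimizer when switching phases is one simple way to make sure the newly unfrozen parameters are actually updated.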

We use a combined loss function consisting of binary cross-entropy (BCE) loss and a Jaccard (IoU) loss, and we monitor the IoU and F1 metrics on a validation set.
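As a rough sketch of what such a loss and such metrics could look like, the block below combines BCE with a soft Jaccard term and computes thresholded IoU and F1. It is not the `combined_loss` or `compute_metrics` from `train_utils`: the function names, the equal weighting (`bce_weight=0.5`), the 0.5 threshold, and the assumption that the model outputs raw logits of shape `(N, 1, H, W)` are all illustrative choices.

```python
import torch
import torch.nn.functional as F


def soft_jaccard_loss(logits, targets, eps=1e-7):
    """1 - soft IoU, computed per sample on sigmoid probabilities."""
    probs = torch.sigmoid(logits)
    dims = (1, 2, 3)  # sum over channel, height, width
    intersection = (probs * targets).sum(dim=dims)
    union = probs.sum(dim=dims) + targets.sum(dim=dims) - intersection
    return (1.0 - (intersection + eps) / (union + eps)).mean()


def bce_jaccard_loss(logits, targets, bce_weight=0.5):
    """Weighted sum of BCE-with-logits and soft Jaccard loss (assumed weighting)."""
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    jaccard = soft_jaccard_loss(logits, targets)
    return bce_weight * bce + (1.0 - bce_weight) * jaccard


def iou_and_f1(logits, targets, threshold=0.5, eps=1e-7):
    """IoU and F1 over the whole batch after thresholding the probabilities."""
    preds = (torch.sigmoid(logits) > threshold).float()
    tp = (preds * targets).sum()
    fp = (preds * (1.0 - targets)).sum()
    fn = ((1.0 - preds) * targets).sum()
    iou = (tp + eps) / (tp + fp + fn + eps)
    f1 = (2.0 * tp + eps) / (2.0 * tp + fp + fn + eps)
    return iou.item(), f1.item()
```

The soft Jaccard term works on probabilities so that it stays differentiable, while the monitored metrics use hard thresholded predictions.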

## 1. Import Libraries

In [1]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from train_utils import (
    create_dataset,
    build_unet,
    freeze_encoder_weights,
    unfreeze_encoder_weights,
    combined_loss,
    compute_metrics,
    train_one_epoch,
    validate_one_epoch,
    adjust_learning_rate,
    save_checkpoint,
    train_model,
)

ModuleNotFoundError: No module named 'train_utils'

## 2. Data Loading and Preprocessing

We assume that image and mask CSV files are stored in a single directory.
The naming convention is as follows:

- For burst slices (a data file and its mask):
  - `slice_20240608_y155_x270406_SkylineHS.csv`
  - `slice_20240608_y155_x270406_SkylineHS_mask.csv`

- For non-burst slices:
  - `slice_20240420_y0_x3391_PeachMountain_2020_nonburst.csv`

The `create_dataset` function reads the files, normalizes them to [0, 1], and splits the data into training and validation sets.
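To make the naming convention and the normalization step concrete, here is a small sketch of how files might be paired and scaled. It is not the actual `create_dataset` implementation: `pair_slices` and `min_max_normalize` are hypothetical helpers, per-image min-max scaling is only one way to reach [0, 1], and treating non-burst slices as having no mask file is an assumption based on the example file names.

```python
import numpy as np


def pair_slices(filenames):
    """Pair each data CSV with its mask CSV (None for non-burst slices)."""
    names = set(filenames)
    pairs = []
    for name in sorted(names):
        if not name.endswith(".csv") or name.endswith("_mask.csv"):
            continue  # skip non-CSV files and the mask files themselves
        if name.endswith("_nonburst.csv"):
            pairs.append((name, None))  # assumption: no mask file exists
            continue
        mask_name = name[: -len(".csv")] + "_mask.csv"
        if mask_name in names:
            pairs.append((name, mask_name))
    return pairs


def min_max_normalize(array):
    """Scale an array to [0, 1]; a constant array maps to all zeros."""
    array = np.asarray(array, dtype=np.float32)
    lo, hi = array.min(), array.max()
    if hi == lo:
        return np.zeros_like(array)
    return (array - lo) / (hi - lo)


files = [
    "slice_20240608_y155_x270406_SkylineHS.csv",
    "slice_20240608_y155_x270406_SkylineHS_mask.csv",
    "slice_20240420_y0_x3391_PeachMountain_2020_nonburst.csv",
]
pairs = pair_slices(files)
scaled = min_max_normalize([[0, 5], [10, 20]])
```

In this sketch a `None` mask would later be replaced by an all-zero array of the same shape as the image, which is again an assumption about how non-burst slices are labeled.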

In [None]:
data_dir = '/Users/remiliascarlet/Desktop/MDP/transfer_learning/burst_data/csv/saved_slices/finished'

(train_images, train_masks), (val_images, val_masks) = create_dataset(
    data_dir,
    img_size=(256, 256),
    test_size=0.2,
    random_state=42,
)