In [1]:
import os
import random
import warnings
from datetime import datetime

import torch
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms

import segmentation_models_pytorch as smp
from segmentation_models_pytorch import Unet
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from segmentation_models_pytorch.utils import train as smp_train
from segmentation_models_pytorch import utils

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler

from torchmetrics.segmentation import MeanIoU, GeneralizedDiceScore
from torchmetrics import Accuracy, F1Score, Precision, Recall, ConfusionMatrix, AUROC

import cv2
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

# SET UP INPUT/OUTPUT PATHS

In [3]:
# Directory that is scanned for EVR region files (matched to Zarr files by name prefix).
evr_dir = '/media/ubuntu/E/EVR_region_files'

# Output directories written by chunk_mask(): normalized image chunks,
# multi-class masks and binarized masks (one .npy per chunk in each).
# NOTE: img_dir and mask_dir are re-assigned in the train/val/test split cell,
# where mask_dir points at the binary masks instead.
img_dir = '/media/ubuntu/E/ML_data/imgs'
mask_dir = '/media/ubuntu/E/ML_data/masks'
binary_mask_dir = '/media/ubuntu/E/ML_data/binary_masks'

# COLLECT ZARR PATHS

In [4]:
from pathlib import Path
import xarray as xr
import pandas as pd
import warnings

import echoregions as er

# Find EVR for each ZARR
# Every Zarr entry becomes one row: (date as 'MMDD', path to its '<name>_Sv.zarr' store,
# list of EVR file names whose name starts with '<day><month name>').
zarr_dir = '/media/ubuntu/E/processed'

# Month number -> month name as used at the start of the EVR file names.
month_names = {'06': 'June', '07': 'July'}

# The EVR directory does not change while we iterate, so list it only once.
evr_file_names = os.listdir(evr_dir)

zarr_evr_records = []
for zarr_subdir in os.listdir(zarr_dir):
    for zarr_name in os.listdir(os.path.join(zarr_dir, zarr_subdir)):
        name_parts = zarr_name.split('-')
        if len(name_parts) != 3:
            # Not a '<prefix>-<date>-<time>' entry: report it instead of crashing on the unpack.
            print(zarr_name, 'Unexpected file name, skipped')
            continue

        # Characters from position 5 on are used as 'MMDD'
        # (assumes a date token like 'DYYYYMMDD' -- TODO confirm against the file names).
        date = name_parts[1][5:]
        month = date[:2]
        day = date[2:4]
        if month in month_names:
            month = month_names[month]
        else:
            # Unknown month: keep the numeric string, which normally matches no EVR file.
            print(date, 'Different month', month)

        evr_date = day + month
        evr_files = [name for name in evr_file_names if name.startswith(evr_date)]

        zarr_path = os.path.join(zarr_dir, zarr_subdir, zarr_name, zarr_name + '_Sv.zarr')
        zarr_evr_records.append([date, zarr_path, evr_files])

# Build the frame in one go instead of growing it row by row.
zarr_evr_df = pd.DataFrame(zarr_evr_records, columns = ['date', 'zarr', 'evr_files'])

zarr_evr_df

NameError: name 'pd' is not defined

# CUT ALL DATASETS

In [None]:
import math

# Suppress two known UserWarnings (matched by message) so they do not flood the cell output.
warnings.filterwarnings(
    "ignore",
    message="Returning No Mask. Empty 3D Mask cannot be converted to 2D Mask.",
    category=UserWarning,
)
warnings.filterwarnings(
    "ignore",
    message="No gridpoint belongs to any region.",
    category=UserWarning,
)

# Default Dask chunking handed to load_zarr_lazy().
chunk_sizes = dict(
    channel=-1,        # -1: all channels in one chunk
    ping_time=100,     # 100 pings per chunk
    range_sample=650,  # 650 range samples per chunk
)

def find_nan_depths(channel_data):
    """Return the `range_sample` values at which every ping is NaN.

    Parameters:
    - channel_data: xarray object with a `ping_time` dimension and a `range_sample` coordinate.

    Returns:
    - numpy array with the `range_sample` values that are null for all `ping_time` entries.
    """
    empty_for_all_pings = channel_data.isnull().all(dim='ping_time')
    return channel_data.range_sample.where(empty_for_all_pings, drop=True).values

# Use Dask
def load_zarr_lazy(zarr_path, chunk_sizes=chunk_sizes, ignore_vars = []):
    """Open a Zarr store with `xr.open_zarr` without loading the data eagerly.

    Parameters:
    - zarr_path: path of the Zarr store.
    - chunk_sizes: chunk specification forwarded as `chunks`; defaults to the module-level `chunk_sizes`.
    - ignore_vars: variable names forwarded as `drop_variables`.

    Returns:
    - the opened xarray dataset.
    """
    open_kwargs = {'chunks': chunk_sizes, 'drop_variables': ignore_vars}
    return xr.open_zarr(zarr_path, **open_kwargs)

def correct_echo_range(ds, transducer_offset=8.6):
    """Replace the `range_sample` index by depth values and drop the NaN part of the range.

    The `echo_range` profile of the first channel and first ping is used for the
    whole dataset (this assumes the profile is the same for all channels and pings --
    TODO confirm for new data).

    Parameters:
    - ds: xarray dataset with an `echo_range` variable and `channel`, `ping_time`
      and `range_sample` dimensions.
    - transducer_offset: float, value added to every `echo_range` entry. Default is 8.6,
      the transducer offset in metres that was previously hard-coded.

    Returns:
    - the dataset with `range_sample` holding the shifted range values, restricted to the
      interval between the smallest and largest non-NaN value.
    """
    # Select by position rather than by label: after concatenating several files the
    # ping_time labels are not guaranteed to be unique, and a label lookup could then
    # return more than one profile.
    first_profile = ds["echo_range"].isel(channel=0, ping_time=0).values
    depth_values = np.asarray(first_profile, dtype=float) + transducer_offset

    # Find min and max ignoring NaNs
    min_val = np.nanmin(depth_values)
    max_val = np.nanmax(depth_values)

    # Use the shifted ranges as the coordinate of the range_sample dimension
    ds = ds.assign_coords(range_sample=depth_values)

    # Keep only the part of the range that carries valid (non-NaN) values
    ds = ds.sel(range_sample=slice(min_val, max_val))

    return ds

def normalize_each_channel(data):
    """Min-max normalize every channel of an array whose last axis is the channel axis.

    NaNs are first replaced by (channel minimum - 10), so after scaling they map to 0
    and every real value lies strictly above them. A constant channel maps to 0.
    A channel that contains only NaNs also maps to 0 instead of propagating NaN
    into the output.

    Parameters:
    - data: array of shape (rows, cols, channels).

    Returns:
    - array of the same shape with every channel scaled to [0, 1].
    """
    data = np.asarray(data)
    nan_mask = np.isnan(data)

    # Per-channel minimum of the valid values; silence the "All-NaN slice" warning
    # because that case is handled explicitly right below.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        channel_min = np.nanmin(data, axis=(0, 1), keepdims=True)
    # A channel without any valid value has no minimum: use 0 so the result is finite.
    channel_min[np.isnan(channel_min)] = 0

    # Replace NaNs with the minimum value - 10 for each channel
    filled = np.where(nan_mask, channel_min - 10, data)

    # Per-channel range after filling (no NaNs are left at this point)
    min_vals = filled.min(axis=(0, 1), keepdims=True)
    max_vals = filled.max(axis=(0, 1), keepdims=True)
    ranges = max_vals - min_vals
    ranges[ranges == 0] = 1  # constant channel: avoid division by zero

    return (filled - min_vals) / ranges

def fill_na_with_interpolation(chunk):
    """Fill NaNs of every channel by linear interpolation along `ping_time`.

    The interpolated channels are written back into `chunk` through `.loc`,
    and that same object is returned.

    Parameters:
    - chunk: xarray object with `channel` and `ping_time` dimensions.

    Returns:
    - chunk: the object that was passed in, after the per-channel assignment.
    """
    for channel in chunk.channel:
        # Interpolate one channel at a time and store it back under the same channel label.
        chunk.loc[dict(channel=channel)] = chunk.sel(channel=channel).interpolate_na(dim='ping_time', method='linear')
    return chunk

def load_combine_process_zarrs(zarr_paths, ignore_vars = []):
    """Load several Zarr stores and merge them into one dataset ready for chunking.

    Steps: open every store, concatenate and sort along `ping_time`, keep the first
    three channels, drop pings whose `Sv` is entirely NaN, convert `range_sample`
    with correct_echo_range() and rename it to `depth`.

    Parameters:
    - zarr_paths: iterable of Zarr store paths.
    - ignore_vars: variable names to drop while opening each store.

    Returns:
    - the combined xarray dataset.
    """
    # An empty chunk spec is passed on purpose instead of the load_zarr_lazy() default.
    per_file = []
    for path in zarr_paths:
        per_file.append(load_zarr_lazy(path, chunk_sizes = {}, ignore_vars = ignore_vars))

    merged = xr.concat(per_file, dim='ping_time').sortby('ping_time')
    # select first 3 channels
    merged = merged.isel(channel=slice(0, 3))
    # remove empty pings
    merged = merged.dropna(dim='ping_time', how='all', subset=['Sv'])
    merged = correct_echo_range(merged)
    # Background-noise removal (apply_remove_background_noise) is currently disabled.
    return merged.rename({'range_sample': 'depth'})

def _region_ids_to_class_ids(region_id_mask, region_class_mapping):
    """Translate a mask of region ids into a mask of class ids.

    The result is built in a fresh array from the untouched region-id mask, so a class id
    that happens to equal another region id can never be remapped a second time.
    Pixels whose region id is not listed in the mapping become 0 (background).
    """
    class_mask = np.zeros_like(region_id_mask)
    for region_id, class_ind in region_class_mapping.itertuples(index=False):
        class_mask[region_id_mask == region_id] = class_ind
    return class_mask


def chunk_mask(combined_dataset, regions2d_list, date, chunk_ratio = 1.5, min_nonzero = 10):
    """Cut one day of data into chunks and save image / mask pairs for the annotated ones.

    The dataset is split along `ping_time` into chunks of `int(depth size * chunk_ratio)`
    pings. For every chunk the regions of all EVR files are rasterized, region ids are
    replaced by class ids, and - if at least `min_nonzero` pixels are annotated - three
    files named '<date>_<chunk index>.npy' are written:
    the normalized image to `img_dir`, the binarized mask to `binary_mask_dir`
    and the class mask to `mask_dir`.

    Relies on the module-level `region_classes`, `region_classes_df`, `img_dir`,
    `mask_dir` and `binary_mask_dir`.

    Parameters:
    - combined_dataset: xarray dataset with an `Sv` variable and `ping_time` / `depth` dimensions.
    - regions2d_list: list of region objects as returned by `er.read_evr`.
    - date: str, used as file name prefix.
    - chunk_ratio: float, chunk width as a multiple of the depth size. Default is 1.5.
    - min_nonzero: int, minimum number of annotated pixels for a chunk to be saved. Default is 10.
    """
    # Cut image and mask into chunks equal to the height * chunk_ratio
    ds_length = combined_dataset.sizes['ping_time']
    chunk_size = int(combined_dataset.sizes['depth'] * chunk_ratio)
    num_chunks = math.ceil(ds_length / chunk_size)

    # Iterate over chunks, normalize each, overlay mask
    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, ds_length)
        chunk = combined_dataset.isel(ping_time=slice(start, end))["Sv"]

        # Overlay mask for this chunk
        mask = None
        for regions2d in regions2d_list:
            region_mask_ds, region_points = regions2d.mask(
                        chunk.isel(channel=1).drop_vars("channel"),
                        region_class=region_classes,
                        collapse_to_2d = True
                    )
            if not region_mask_ds:
                # No region of this EVR file falls into the chunk.
                continue

            region_id_mask = region_mask_ds['mask_2d'].fillna(0).values.astype(int)

            # Replace region_id with class_id in the mask
            region_class_mapping = regions2d.data.merge(region_classes_df, how='left', left_on = 'region_class', right_on = 'class')
            present = region_class_mapping.region_id.isin(np.unique(region_id_mask))
            region_class_mapping = region_class_mapping[present][['region_id', 'class_ind']].astype(int).sort_values(by='region_id')
            loc_mask = _region_ids_to_class_ids(region_id_mask, region_class_mapping)

            if mask is not None and mask.size > 0:
                # Merge with the masks of the previous EVR files: pixels that are already
                # labelled keep their class, only unlabelled pixels are filled in.
                # (Adding the masks would turn overlapping pixels into a wrong class id.)
                mask = np.where(mask > 0, mask, loc_mask)
            else:
                mask = loc_mask

        # Only save chunks for which there is non-empty mask
        if mask is None or mask.size == 0:
            continue

        print(i, chunk['ping_time'].min().values, chunk['ping_time'].max().values, np.count_nonzero(mask), np.unique(mask))
        # check that there are enough annotated pixels
        if np.count_nonzero(mask) < min_nonzero:
            continue

        fname = '%s_%d.npy' % (date, i)

        # normalized image, transposed so that the channel axis comes last
        np.save(os.path.join(img_dir, fname), normalize_each_channel(chunk.values.T))

        # binarize mask
        np.save(os.path.join(binary_mask_dir, fname), (mask > 0).astype(int))

        # mask with class types
        np.save(os.path.join(mask_dir, fname), mask)

def process_one_day(zarr_paths, evr_paths, date, chunk_ratio = 1.5, min_nonzero = 10, ignore_vars = []):
    """Create the image / mask chunks for one survey day.

    Parameters:
    - zarr_paths: paths of the Zarr stores recorded on that day.
    - evr_paths: paths of the EVR region files of that day (there can be several).
    - date: str, used as file name prefix for the saved chunks.
    - chunk_ratio: float, forwarded to chunk_mask(). Default is 1.5.
    - min_nonzero: int, forwarded to chunk_mask(). Default is 10.
    - ignore_vars: variable names that are dropped while opening the Zarr stores.
    """
    # ignore_vars was previously accepted but never forwarded, so nothing was dropped.
    combined_dataset = load_combine_process_zarrs(zarr_paths, ignore_vars = ignore_vars)

    # there can be several evr files
    regions2d_list = [er.read_evr(evr_file) for evr_file in evr_paths]

    chunk_mask(combined_dataset, regions2d_list, date, chunk_ratio, min_nonzero)

In [11]:
# Region classes that are rasterized into the masks; the 1-based index of the table
# below is stored as `class_ind` and written to classes.csv next to the masks.
region_classes = [
    "Def herring", "Prob herring", "Poss herring", "Surface herring",
    "Mackerel", "Gadoids", "Norway pout", "unidentified fish",
]
region_classes_df = pd.DataFrame(
    {'class': region_classes},
    index=range(1, len(region_classes) + 1),
)
region_classes_df.to_csv(os.path.join(mask_dir, 'classes.csv'))
region_classes_df['class_ind'] = region_classes_df.index.astype(int)

# Variables that are not needed for chunking and can be dropped while opening the Zarr stores.
ignore_vars = [
    'source_filenames', 'filenames',
    'angle_offset_alongship', 'angle_offset_athwartship',
    'beamwidth_alongship', 'beamwidth_athwartship',
    'water_level',
    'angle_sensitivity_alongship', 'angle_sensitivity_athwartship',
    'equivalent_beam_angle',  # frequency_nominal is kept
    'gain_correction', 'sa_correction', 'sound_absorption', 'sound_speed',
]

# Process all datasets, one survey day at a time
for date, rows in tqdm(zarr_evr_df.groupby('date'), total = len(zarr_evr_df['date'].unique())):
    # Exclude broken datasets
    if date in ('0629', '0630', '0702'):
        continue
    zarr_paths = rows['zarr'].values
    evr_paths = [
        os.path.join(evr_dir, evr_name)
        for evr_name in np.unique(np.concatenate(rows['evr_files'].values))
    ]
    process_one_day(zarr_paths, evr_paths, date, ignore_vars = ignore_vars)

  0%|          | 0/18 [00:00<?, ?it/s]

  result = blockwise(
  result = blockwise(


10 2007-07-01T04:09:42.938826000 2007-07-01T04:34:29.032574000 798 [0 3]
11 2007-07-01T04:34:29.798199000 2007-07-01T04:59:17.516951000 2214 [0 3]
12 2007-07-01T04:59:18.282574000 2007-07-01T05:24:04.001324000 903 [0 3]
13 2007-07-01T05:24:04.766951000 2007-07-01T05:48:45.595074000 256 [0 2]
14 2007-07-01T05:48:46.360701000 2007-07-01T06:13:26.063826000 241 [0 2 8]
18 2007-07-01T07:27:13.532574000 2007-07-01T07:51:43.079449000 977 [0 2]
19 2007-07-01T07:51:43.829449000 2007-07-01T08:16:13.954449000 135 [0 2]


# LOAD DATASET

In [3]:
from torchvision.transforms.functional import hflip
def random_horizontal_flip(image, mask, p=0.5):
    """
    Flip image and mask together along the horizontal axis with probability `p`.

    Parameters:
    - image: torch.Tensor, the input image tensor.
    - mask: torch.Tensor, the input mask tensor.
    - p: float, probability that both inputs are flipped. Default is 0.5.

    Returns:
    - image: torch.Tensor, the image, flipped or unchanged.
    - mask: torch.Tensor, the mask, flipped exactly when the image was flipped.
    """
    # One random draw decides for both tensors so they always stay aligned.
    keep_as_is = torch.rand(1).item() >= p
    if keep_as_is:
        return image, mask
    return hflip(image), hflip(mask)

class SonarDataset(Dataset):
    """Dataset of random square patches cut from saved image / mask chunk files.

    Every file contributes `num_patches_per_image` items. Each item is a patch whose side
    equals the image height, taken at a random horizontal position, optionally resized to
    `resize_size`, normalized with the ImageNet statistics and randomly flipped.
    """

    def __init__(self, data_paths, mask_paths, resize_size=512, num_patches_per_image=3):
        """
        Initialize the SonarDataset.

        Parameters:
        - data_paths: list of str, paths to the image data files.
        - mask_paths: list of str, paths to the mask data files (same order as data_paths).
        - resize_size: int, the size to which the patches will be resized. Default is 512.
        - num_patches_per_image: int, number of patches to extract from each image. Default is 3.

        Raises:
        - ValueError: if the two path lists differ in length.
        """
        if len(data_paths) != len(mask_paths):
            raise ValueError("data_paths and mask_paths must have the same length.")

        self.data_paths = data_paths
        self.mask_paths = mask_paths
        self.resize_size = resize_size
        self.num_patches_per_image = num_patches_per_image

        # ImageNet normalization (the statistics expected by ImageNet-pretrained encoders).
        # Drop the Normalize step to train without it.
        self.transforms = transforms.Compose([
            transforms.ToTensor(),  # Applies only if your data is not already a tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """
        Return the total number of patches in the dataset.
        """
        return len(self.data_paths) * self.num_patches_per_image

    def __getitem__(self, idx):
        """
        Retrieve a patch and its corresponding mask by index.

        Parameters:
        - idx: int, the index of the patch to retrieve.

        Returns:
        - image_patch: torch.Tensor, the transformed image patch.
        - mask_patch: torch.Tensor, float32 mask patch with a leading channel axis.

        Raises:
        - ValueError: if the image is narrower than it is high, so no square patch fits.
        """
        file_idx = idx // self.num_patches_per_image
        data_path = self.data_paths[file_idx]
        mask_path = self.mask_paths[file_idx]

        image = np.load(data_path, mmap_mode='r')
        # take only first 3 channels
        image = image[..., :3]

        mask = np.load(mask_path, mmap_mode='r')
        patch_size = image.shape[0]

        # max_x == 0 means the image is exactly one patch wide, which is still valid.
        max_x = image.shape[1] - patch_size
        if max_x < 0:
            raise ValueError("Patch size is larger than the image width.")

        x = random.randint(0, max_x)
        # Copy the patch out of the read-only memmap as contiguous float32 arrays.
        # Both end up as float32 anyway, and cv2.resize cannot take every integer dtype.
        image_patch = np.array(image[:, x:x + patch_size], dtype=np.float32, order='C')
        mask_patch = np.array(mask[:, x:x + patch_size], dtype=np.float32, order='C')

        if self.resize_size != patch_size:
            # Nearest-neighbour keeps the mask values discrete.
            image_patch = cv2.resize(image_patch, (self.resize_size, self.resize_size), interpolation=cv2.INTER_NEAREST)
            mask_patch = cv2.resize(mask_patch, (self.resize_size, self.resize_size), interpolation=cv2.INTER_NEAREST)

        image_patch = self.transforms(image_patch)
        mask_patch = torch.tensor(mask_patch, dtype=torch.float32).unsqueeze(0)

        image_patch, mask_patch = random_horizontal_flip(image_patch, mask_patch)

        return image_patch, mask_patch

## Split dataset into Train/Test/Val

In [4]:
# NOTE: this re-assigns img_dir / mask_dir from the preprocessing section;
# from here on mask_dir points at the binarized masks.
img_dir = '/media/ubuntu/E/ML_data/imgs/'
mask_dir = '/media/ubuntu/E/ML_data/binary_masks/'

# Sort file names and days: os.listdir() and set() give no stable order, and with an
# unstable input order the fixed random_state below would not give a reproducible split.
files = sorted(f for f in os.listdir(img_dir) if f.endswith('.npy'))
days = sorted({f.split('_')[0] for f in files})

# Split by day so that chunks of the same day never end up in different subsets.
temp_train_days, test_days = train_test_split(days, test_size=0.1, random_state=1)
train_days, val_days = train_test_split(temp_train_days, test_size=0.2, random_state=1)


def _paths_for_days(selected_days):
    """Return (image_paths, mask_paths) of all chunk files that belong to the given days."""
    selected_days = set(selected_days)
    names = [f for f in files if f.split('_')[0] in selected_days]
    image_paths = [os.path.join(img_dir, f) for f in names]
    mask_paths = [os.path.join(mask_dir, f) for f in names]
    return image_paths, mask_paths


train_images, train_masks = _paths_for_days(train_days)
val_images, val_masks = _paths_for_days(val_days)
test_images, test_masks = _paths_for_days(test_days)

# Create datasets
train_dataset = SonarDataset(train_images, train_masks)
val_dataset = SonarDataset(val_images, val_masks)
test_dataset = SonarDataset(test_images, test_masks)

batch_size = 4

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# TRAIN MODEL

In [5]:
class FocalLoss(torch.nn.Module):
    """Binary focal loss: alpha[target] * (1 - pt)^gamma * BCE, with pt = exp(-BCE)."""

    def __init__(self, alpha=None, gamma=2, logits=True, reduction='mean'):
        """
        Parameters:
        alpha (tensor, optional): Weights for each class, indexed by the target value
            (alpha[0] for background, alpha[1] for the positive class). Default is equal weight.
        gamma (float, optional): Focusing parameter. Default is 2.
        logits (bool, optional): If True, expects inputs as raw logits. If False, expects probabilities. Default is True.
        reduction (str, optional): 'mean' or 'sum'; any other value returns the unreduced loss. Default is 'mean'.
        """
        super(FocalLoss, self).__init__()
        alpha = torch.as_tensor(alpha) if alpha is not None else torch.tensor([1.0, 1.0])
        # Registered as a buffer so it moves together with the module (.to(), Lightning).
        # persistent=False keeps it out of the state_dict, so existing checkpoints still load.
        self.register_buffer('alpha', alpha, persistent=False)
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of `inputs` against `targets` (same shape, values 0 / 1)."""
        # Ensure targets are on the same device as inputs before they are used
        targets = targets.to(inputs.device)

        if self.logits:
            # Compute the binary cross-entropy loss with logits
            BCE_loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            # Compute the binary cross-entropy loss
            BCE_loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')

        # Per-pixel class weight, looked up by the target class
        # (.to() is a no-op once the module lives on the same device as the inputs)
        alpha = self.alpha.to(inputs.device)[targets.long()]

        # Compute the modulating factor (1 - pt)^gamma
        pt = torch.exp(-BCE_loss)
        focal_loss = alpha * ((1 - pt) ** self.gamma) * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss

class SegModel(pl.LightningModule):
    """LightningModule that wraps a binary segmentation model.

    Handles training / validation (loss and IoU logging), testing with per-batch
    metrics that are averaged at the end of the epoch, threshold selection and
    saving of raw predictions.
    """

    def __init__(self, model, criterion, optimizer, threshold = 0.5):
        """
        Initialize the SegModel.

        Parameters:
        - model: PyTorch model, the segmentation model to be used.
        - criterion: loss function.
        - optimizer: optimizer function.
        - threshold: float, threshold for converting probabilities to binary predictions. Default is 0.5.
        """
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.threshold = threshold

        # Initialize metrics
        self.iou = MeanIoU(num_classes=2, per_class=True)
        self.precision = Precision(task="multiclass", num_classes=2, average='none')
        self.recall = Recall(task="multiclass", num_classes=2, average='none')
        self.f1 = F1Score(task="multiclass", num_classes=2, average='none')
        self.dice = GeneralizedDiceScore(num_classes=2, include_background=True, per_class=True)
        self.confusion_matrix = ConfusionMatrix(task="multiclass", num_classes=2)
        self.auroc = AUROC(task="binary")

        # Per-batch metric dicts collected by test_step()
        self.test_outputs = []

    def forward(self, x):
        """Return the raw logits of the wrapped model."""
        return self.model(x)

    def shared_step(self, batch, stage):
        """
        Shared step for training, validation.

        Parameters:
        - batch: the input batch containing images and masks.
        - stage: str, the stage of the training process (e.g., "train", "valid").

        Returns:
        - dict: contains loss and IoU for the current batch.
        """
        image, mask = batch
        out = self.forward(image)

        # Ensure mask is float for focal loss compatibility
        loss = self.criterion(out, mask.float())

        # Convert logits to binary predictions with the configured threshold
        # (the same threshold that test_step uses)
        preds = (out.sigmoid() > self.threshold).long()

        tp, fp, fn, tn = smp.metrics.get_stats(preds, mask.long(), mode='binary')
        iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        self.log(f"{stage}_IoU", iou, prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"loss": loss, "iou": iou}

    def training_step(self, batch, batch_idx):
        """Run shared_step() and log under the "train" prefix."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run shared_step() and log under the "valid" prefix."""
        return self.shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        """
        Test step to compute various metrics.

        Parameters:
        - batch: the input batch containing images and masks.
        - batch_idx: int, the index of the current batch.

        Returns:
        - dict: contains various metrics for the current batch.
        """
        x, y = batch
        logits = self(x)
        probs = torch.sigmoid(logits).float()
        preds = (probs > self.threshold).int()  # Convert boolean to integers
        y = y.int()  # Ensure targets are also integers

        # The metrics are evaluated per batch and averaged in on_test_epoch_end(),
        # so none of them should keep state from earlier batches.
        for metric in (self.precision, self.recall, self.f1, self.iou,
                       self.dice, self.confusion_matrix, self.auroc):
            metric.reset()

        # Update metrics
        iou_score = self.iou(preds, y)
        dice_score = self.dice(preds, y)
        precision = self.precision(preds, y)
        recall = self.recall(preds, y)
        f1 = self.f1(preds, y)
        cm = self.confusion_matrix(preds, y).float()
        # AUROC ranks scores, so it gets the probabilities rather than the thresholded predictions
        auroc = self.auroc(probs, y)

        outputs = {
            "iou": iou_score,
            "dice": dice_score,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "confusion_matrix": cm,
            "auroc": auroc
        }

        self.test_outputs.append(outputs)

        return outputs

    def on_test_epoch_end(self):
        """
        Aggregates metrics at the end of the test epoch and logs them.
        """
        if not self.test_outputs:
            # Nothing was collected (e.g. empty test dataloader): nothing to aggregate.
            return

        # Aggregate metrics
        iou_scores = torch.stack([x['iou'] for x in self.test_outputs])
        dice_scores = torch.stack([x['dice'] for x in self.test_outputs])
        precisions = torch.stack([x['precision'] for x in self.test_outputs])
        recalls = torch.stack([x['recall'] for x in self.test_outputs])
        f1_scores = torch.stack([x['f1'] for x in self.test_outputs])
        cm_scores = torch.stack([x['confusion_matrix'] for x in self.test_outputs])
        auroc_scores = torch.stack([x['auroc'] for x in self.test_outputs])

        # Sum confusion matrices and normalize every row (actual class) to 1;
        # a class without any sample keeps a row of zeros instead of NaNs.
        sum_cm_np = cm_scores.sum(dim=0).cpu().numpy().astype('float')
        row_sums = sum_cm_np.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(sum_cm_np, row_sums, out=np.zeros_like(sum_cm_np), where=row_sums != 0)

        # Average metrics over batches; column 0 is the background, column 1 the class of interest
        averaged = {
            'avg_auroc': auroc_scores.mean(),
            'avg_iou_background': iou_scores[:, 0].mean(),
            'avg_iou_class_of_interest': iou_scores[:, 1].mean(),
            'avg_dice_background': dice_scores[:, 0].mean(),
            'avg_dice_class_of_interest': dice_scores[:, 1].mean(),
            'avg_precision_background': precisions[:, 0].mean(),
            'avg_precision_class_of_interest': precisions[:, 1].mean(),
            'avg_recall_background': recalls[:, 0].mean(),
            'avg_recall_class_of_interest': recalls[:, 1].mean(),
            'avg_f1_background': f1_scores[:, 0].mean(),
            'avg_f1_class_of_interest': f1_scores[:, 1].mean(),
        }

        # Log aggregated metrics
        for metric_name, metric_value in averaged.items():
            self.log(metric_name, metric_value)

        # Plot the normalized confusion matrix
        plt.figure(figsize=(4, 3))
        sns.heatmap(cm_normalized, annot=True, fmt='.4f', cmap='Blues', xticklabels=['BG', 'Fish'], yticklabels=['BG', 'Fish'])
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Normalized Confusion Matrix')
        plt.show()

        # Clear the outputs for the next epoch
        self.test_outputs = []

    def select_threshold(self, dataloader, device='cuda'):
        """
        Compute metrics for different probability thresholds to select the best threshold.

        The dataloader is iterated once per threshold (0.1, 0.2, ..., 0.9).

        Parameters:
        - dataloader: DataLoader, the DataLoader providing the images and masks.
        - device: str, the device to use for computation. Default is 'cuda'.

        Returns:
        - pd.DataFrame: a DataFrame containing metrics for different thresholds.
        """
        self.model.eval()
        results = []

        # Define probability thresholds
        thresholds = torch.arange(0.1, 1, 0.1)

        with torch.no_grad():
            for threshold in tqdm(thresholds):
                self.precision.reset()
                self.recall.reset()
                self.f1.reset()
                self.iou.reset()
                iou_list = []

                for batch_idx, batch in tqdm(enumerate(dataloader), leave=False):
                    images, masks = batch
                    images = images.to(device)
                    masks = masks.to(device)
                    # Make predictions
                    logits = self.model(images)
                    # Binarize predictions
                    preds = torch.sigmoid(logits) > threshold
                    preds = preds.int()  # Convert boolean to integers
                    masks = masks.int()  # Ensure targets are also integers

                    self.f1.update(preds, masks)
                    self.precision.update(preds, masks)
                    self.recall.update(preds, masks)
                    # Update for IoU does not work correctly, so the per-batch
                    # values are collected in a list and averaged instead
                    iou_list.append(self.iou(preds, masks)[1].item())

                # Index 1 selects the class of interest
                average_precision = self.precision.compute()[1].item()
                average_recall = self.recall.compute()[1].item()
                average_f1 = self.f1.compute()[1].item()
                average_iou = np.sum(iou_list) / len(iou_list)

                results.append({
                    'Threshold': threshold.item(),
                    'IoU': average_iou,
                    'F1': average_f1,
                    'Precision': average_precision,
                    'Recall': average_recall
                })

        df_results = pd.DataFrame(results)

        return df_results

    def save_preds(self, logits, batch_idx, output_dir, file_prefix):
        """
        Save predictions to the specified directory.

        The file is named '<file_prefix>_<batch_idx>.npy', or '<batch_idx>.npy'
        when no prefix is given.

        Parameters:
        - logits: torch.Tensor, the logits output from the model.
        - batch_idx: int, the index of the current batch.
        - output_dir: str, directory where predictions will be saved (created if missing).
        - file_prefix: str, prefix for the output file names; may be None or empty.
        """
        filename = str(batch_idx) + '.npy'
        if file_prefix:
            filename = file_prefix + '_' + filename
        os.makedirs(output_dir, exist_ok=True)
        pred_file_path = os.path.join(output_dir, filename)
        np.save(pred_file_path, logits.detach().cpu().numpy())

    def predict(self, dataloader, device='cuda', output_dir=None, file_prefix=None):
        """
        Make predictions on a dataset and save the raw logits, one file per batch.

        Parameters:
        - dataloader: DataLoader, the DataLoader providing the images.
        - device: str, the device to use for computation. Default is 'cuda'.
        - output_dir: str, directory where predictions will be saved; nothing is
          computed when it is None.
        - file_prefix: str, prefix for the output file names.
        """
        self.model.eval()

        if output_dir is None:
            print('No output dir for predictions specified.')
            return

        with torch.no_grad():
            for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
                images, _ = batch
                images = images.to(device)
                logits = self.model(images)
                self.save_preds(logits, batch_idx, output_dir, file_prefix)

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Returns:
        - dict: containing the optimizer, a ReduceLROnPlateau scheduler and the
          name of the metric it monitors ('valid_loss').
        """
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.8, patience=10, min_lr=1e-05)
        return {'optimizer': self.optimizer, 'lr_scheduler': scheduler, 'monitor': 'valid_loss'}

In [6]:
# Set matmul precision
torch.set_float32_matmul_precision('medium')

# Select model
model_name = 'unet'
encoder = 'resnet34'
pretrained = 'imagenet'

# Name of the model: <model>_<encoder>[_<pretrained weights>]_<day month - hour minute>
version_name = datetime.now().strftime('%d%m-%H%M')
run_name = '_'.join([model_name, encoder] + ([pretrained] if pretrained else []) + [version_name])

# A single output channel: binary segmentation
classes = 1

# Create the model
model = smp.create_model(
    model_name,
    encoder_name=encoder,
    encoder_weights=pretrained,
    in_channels=3,
    classes=classes,
)
model = model.to(device)

# Assuming class 0 is the background and class 1 is the class of interest
# alpha shows relative importance of background vs class
class_weights = torch.tensor([0.05, 0.95])
criterion = FocalLoss(alpha=class_weights, gamma=2.0, logits=True)
# Initialize optimizer (weight decay is currently not used)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-03)

# Set up how to save trained model weights
# checkpoint_dir can be changed
checkpoint_dir = f'./models/checkpoints_{run_name}'
# Keep only the checkpoint with the lowest validation loss
checkpoint = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename='best_model',
    monitor='valid_loss',
    mode='min',
    save_top_k=1,
    verbose=True,
)
# Set up early training stopping
# NOTE(review): this patience equals the ReduceLROnPlateau patience in
# SegModel.configure_optimizers, so training may stop before the learning rate
# is ever reduced - confirm this is intended.
early_stopping = EarlyStopping(monitor='valid_loss', mode='min', patience=10)
# Log the learning rate once per epoch
lr_monitor = LearningRateMonitor(logging_interval='epoch')

In [7]:
# Wrap model, loss and optimizer into the LightningModule
pl_model = SegModel(model, criterion, optimizer)

# Select profiler
profiler = SimpleProfiler()

# Initialize the logger (TensorBoard logs go to tb_logs/<run_name>)
logger = TensorBoardLogger(save_dir="tb_logs", name=run_name)

# Initialize the trainer
trainer = Trainer(
    accelerator='gpu',
    precision='16-mixed',
    max_epochs=100,
    gradient_clip_val=0.5,
    num_sanity_val_steps=5,
    logger=logger,
    profiler=profiler,
    callbacks=[checkpoint, early_stopping, lr_monitor],
    # val_check_interval=0.1 would validate more frequently
)

# Train the model
trainer.fit(pl_model, train_loader, val_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: tb_logs/unet_resnet34_imagenet_1807-1054
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | model            | Unet                      | 24.4 M
1 | criterion        | FocalLoss                 | 0     
2 | iou              | MeanIoU                   | 0     
3 | precision        | MulticlassPrecision       | 0     
4 | recall           | MulticlassRecall          | 0     
5 | f1               | MulticlassF1Score         | 0     
6 | dice             | GeneralizedDiceScore      | 0     
7 | confusion_matrix | MulticlassConfusionMatrix | 0     
8 | auroc            | BinaryAUROC               | 0     
-----------------------------------------------------

Sanity Checking: |                                                                                            …

  return F.conv2d(input, weight, bias, self.stride,


Training: |                                                                                                   …

Validation: |                                                                                                 …

Epoch 0, global step 189: 'valid_loss' reached 0.00013 (best 0.00013), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 1, global step 378: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 2, global step 567: 'valid_loss' reached 0.00009 (best 0.00009), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 3, global step 756: 'valid_loss' reached 0.00009 (best 0.00009), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 4, global step 945: 'valid_loss' reached 0.00008 (best 0.00008), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 5, global step 1134: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 6, global step 1323: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 7, global step 1512: 'valid_loss' reached 0.00008 (best 0.00008), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 8, global step 1701: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 9, global step 1890: 'valid_loss' reached 0.00006 (best 0.00006), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 10, global step 2079: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 11, global step 2268: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 12, global step 2457: 'valid_loss' reached 0.00006 (best 0.00006), saving model to '/home/ubuntu/SSP/echosounder-segmentation/models/checkpoints_unet_resnet34_imagenet_1807-1054/best_model.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 13, global step 2646: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 14, global step 2835: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 15, global step 3024: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 16, global step 3213: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 17, global step 3402: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 18, global step 3591: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 19, global step 3780: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 20, global step 3969: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 21, global step 4158: 'valid_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 22, global step 4347: 'valid_loss' was not in top 1
FIT Profiler Report

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                 	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                