In [1]:
import os
import time
from typing import Tuple, Union, List, Dict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset



In [2]:

class DebrisFlowDataset(Dataset):
    """A PyTorch Dataset that serves debris flow samples as stacked multi-channel tensors.

    Every sample is assembled from three ``.npy`` files (elevation, thickness and
    velocity) that share a filename prefix and live in the same directory.
    """

    def __init__(self, main_dir: str, scaling_params: Dict[str, Dict[str, Union[None, float]]]):
        """
        Initialize the dataset with the main data directory and scaling parameters.

        Args:
            main_dir (str): The main directory where the data is stored.
            scaling_params (dict): Maps each channel name ('elevation', 'thickness',
                'velocity') to a dict with the keys 'median' and 'mad'. A value of
                None for either key disables scaling for that channel.
        """
        self.main_dir = main_dir
        self.scaling_params = scaling_params
        self.file_paths = self._gather_file_paths()

    def set_scaling_params(self, scaling_params: Dict[str, Dict[str, Union[None, float]]]) -> None:
        """
        Replace the scaling parameters used by subsequent ``__getitem__`` calls.

        Args:
            scaling_params (dict): Same structure as the ``scaling_params`` argument
                of ``__init__``.
        """
        self.scaling_params = scaling_params

    def _gather_file_paths(self) -> List[Tuple[str, str, str]]:
        """
        Gather and pair file paths for elevation, thickness, and velocity channels.

        Only directories whose path contains 'elevation' are inspected. Within such a
        directory, a file named ``<prefix>_elevation...`` is kept when both
        ``<prefix>_thickness.npy`` and ``<prefix>_velocity.npy`` exist next to it.

        Returns:
            list: (elevation_path, thickness_path, velocity_path) tuples in a
            reproducible (sorted) order.
        """
        file_paths = []
        for dirpath, dirnames, filenames in os.walk(self.main_dir):
            # Sorting in place makes os.walk descend in a reproducible order, so a
            # given index refers to the same sample regardless of the filesystem.
            dirnames.sort()
            if 'elevation' not in dirpath:
                continue
            for elevation_file in sorted(filenames):
                # Skip the thickness/velocity files (and anything else) that share
                # the directory; only elevation files define a sample.
                if '_elevation' not in elevation_file:
                    continue
                prefix = elevation_file.split('_elevation')[0]
                thickness_path = os.path.join(dirpath, f"{prefix}_thickness.npy")
                velocity_path = os.path.join(dirpath, f"{prefix}_velocity.npy")
                if os.path.exists(thickness_path) and os.path.exists(velocity_path):
                    file_paths.append((
                        os.path.join(dirpath, elevation_file),
                        thickness_path,
                        velocity_path,
                    ))
        return file_paths

    def _scale_data(self, data: np.ndarray, channel_name: str) -> np.ndarray:
        """
        Scale the data for a given channel using scaling parameters.

        Args:
            data (np.ndarray): The data to scale.
            channel_name (str): The name of the channel to which the data belongs.

        Returns:
            np.ndarray: ``(data - median) / mad`` when both parameters are set,
            otherwise the input unchanged.

        Raises:
            KeyError: If ``channel_name`` (or its 'median'/'mad' entry) is missing
                from the scaling parameters.
        """
        channel_params = self.scaling_params[channel_name]
        median = channel_params['median']
        mad = channel_params['mad']
        if median is None or mad is None:
            return data
        # NOTE: a MAD of exactly zero is not special-cased; the division then
        # produces inf/nan values just as plain NumPy division does.
        return (data - median) / mad

    def __len__(self) -> int:
        """Denotes the total number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Get the item at the given index in the form of a multi-channel image.

        Args:
            idx (int): The index of the item.

        Returns:
            torch.Tensor: A single float32 tensor with the scaled elevation,
            thickness and velocity arrays stacked along a new leading axis, in
            that order.
        """
        elevation_path, thickness_path, velocity_path = self.file_paths[idx]

        # Load and scale each channel with its own parameters
        elevation_scaled = self._scale_data(np.load(elevation_path), 'elevation')
        thickness_scaled = self._scale_data(np.load(thickness_path), 'thickness')
        velocity_scaled = self._scale_data(np.load(velocity_path), 'velocity')

        # Stack as channels (axis 0) and convert to torch.Tensor
        stacked = np.stack((elevation_scaled, thickness_scaled, velocity_scaled), axis=0)
        return torch.tensor(stacked, dtype=torch.float32)

In [None]:
def compute_channel_scaling_params(elevation_files, thickness_files, velocity_files):
    """Compute the non-zero median and MAD for each channel using unique filenames.

    Every file of a channel is loaded with ``np.load``; its non-zero values are
    pooled with those of the channel's other files, and the median and the median
    absolute deviation (MAD) of the pooled values are computed. Zeros are dropped
    per file before pooling, so files of one channel may differ in shape and only
    the non-zero values are kept in memory.

    Args:
        elevation_files (set): A set of unique elevation filenames.
        thickness_files (set): A set of unique thickness filenames.
        velocity_files (set): A set of unique velocity filenames.

    Returns:
        tuple: ``(median_vals, mad_vals)``, two dictionaries keyed by 'elevation',
        'thickness' and 'velocity'. If a channel contains no non-zero value at
        all, NumPy's median of an empty array (NaN) is returned for it.

    Raises:
        ValueError: If the collection of files for a channel is empty.
    """
    print(f"Processing {len(elevation_files)} elevation files, {len(thickness_files)} thickness files, and {len(velocity_files)} velocity files.")

    def compute_median_and_mad(non_zero_data):
        """Compute the median and MAD of an already zero-filtered 1-D array."""
        median_val = np.median(non_zero_data)
        mad_val = np.median(np.abs(non_zero_data - median_val))
        return median_val, mad_val

    median_vals = {}
    mad_vals = {}

    channels = zip(
        ['elevation', 'thickness', 'velocity'],
        [elevation_files, thickness_files, velocity_files],
    )
    for channel, channel_files in channels:
        start_time = time.perf_counter()

        # Boolean indexing flattens each array, so differently shaped files can
        # be pooled and the zero entries never have to be held together.
        non_zero_parts = []
        for file in channel_files:
            array = np.load(file)
            non_zero_parts.append(array[array != 0])
        if not non_zero_parts:
            raise ValueError(f"No {channel} files were provided; cannot compute scaling parameters.")

        non_zero_data = np.concatenate(non_zero_parts)
        median_vals[channel], mad_vals[channel] = compute_median_and_mad(non_zero_data)
        del non_zero_parts, non_zero_data  # Free up memory
        print(f"Processed {channel} files in {time.perf_counter() - start_time:.2f} seconds.")

    return median_vals, mad_vals

def set_channel_scaling_params_to_dataset(dataset, median_vals, mad_vals):
    """Apply the computed non-zero median and MAD values for each channel to the dataset.

    The per-channel values are combined into one dictionary of the form
    ``{channel: {'median': ..., 'mad': ...}}`` and handed to the dataset. If the
    dataset exposes a callable ``set_scaling_params`` it is used; otherwise the
    dictionary is assigned to the dataset's ``scaling_params`` attribute, so
    datasets without a dedicated setter are supported as well.

    Args:
        dataset (DebrisFlowDataset): The dataset to apply the scaling parameters to.
        median_vals (dict): A dictionary containing the non-zero median values for each channel.
        mad_vals (dict): A dictionary containing the MAD values for each channel.

    Raises:
        KeyError: If 'elevation', 'thickness' or 'velocity' is missing from
            ``median_vals`` or ``mad_vals``.
    """
    scaling_params = {channel: {'median': median_vals[channel], 'mad': mad_vals[channel]}
                      for channel in ['elevation', 'thickness', 'velocity']}

    setter = getattr(dataset, 'set_scaling_params', None)
    if callable(setter):
        setter(scaling_params)
    else:
        # No setter available: write the attribute the dataset reads when scaling.
        dataset.scaling_params = scaling_params

In [None]:
class MediumUNetPlus(nn.Module):
    """A four-level U-Net style encoder/decoder with dropout regularization.

    Attributes:
        enc1: First encoder block (in_channels -> 32).
        enc2: Second encoder block (32 -> 64).
        enc3: Third encoder block (64 -> 128).
        enc4: Fourth encoder block (128 -> 256).
        bottleneck: Two conv/batch-norm/ReLU/dropout stages at 512 channels.
        dec1: First decoder block (512 + 256 -> 256).
        dec2: Second decoder block (256 + 128 -> 128).
        dec3: Third decoder block (128 + 64 -> 64).
        dec4: Fourth decoder block (64 + 32 -> 32).
        out_conv: Final 1x1 convolution producing out_channels.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5):
        """Builds the encoder, bottleneck, decoder and output layers.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            dropout_rate: The dropout rate to use in the bottleneck and decoder blocks.
        """
        super().__init__()

        # Contracting path: each block halves the spatial size via max pooling.
        self.enc1 = self.encoder_block(in_channels, 32)
        self.enc2 = self.encoder_block(32, 64)
        self.enc3 = self.encoder_block(64, 128)
        self.enc4 = self.encoder_block(128, 256)

        # Bottleneck: the same conv/norm/activation/dropout stage twice.
        bottleneck_layers = []
        for stage_in in (256, 512):
            bottleneck_layers.append(nn.Conv2d(stage_in, 512, kernel_size=3, padding=1))
            bottleneck_layers.append(nn.BatchNorm2d(512))
            bottleneck_layers.append(nn.ReLU())
            bottleneck_layers.append(nn.Dropout(dropout_rate))
        self.bottleneck = nn.Sequential(*bottleneck_layers)

        # Expanding path: input widths are the previous output plus the skip tensor.
        self.dec1 = self.decoder_block(512 + 256, 256, dropout_rate)
        self.dec2 = self.decoder_block(256 + 128, 128, dropout_rate)
        self.dec3 = self.decoder_block(128 + 64, 64, dropout_rate)
        self.dec4 = self.decoder_block(64 + 32, 32, dropout_rate)

        # Projection to the requested number of output channels.
        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def encoder_block(self, in_channels, out_channels):
        """Creates an encoder block: 3x3 convolution, batch norm, ReLU, 2x2 max pooling.

        Args:
            in_channels: The number of input channels for the block.
            out_channels: The number of output channels for the block.

        Returns:
            An nn.Sequential module comprising the encoder block layers.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        return nn.Sequential(*layers)

    def decoder_block(self, in_channels, out_channels, dropout_rate):
        """Creates a decoder block: two conv/batch-norm/ReLU stages, then 2x upsampling.

        Dropout is applied after the first stage only.

        Args:
            in_channels: The number of input channels for the block.
            out_channels: The number of output channels for the block.
            dropout_rate: The dropout rate applied after the first ReLU.

        Returns:
            An nn.Sequential module comprising the decoder block layers.
        """
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        return nn.Sequential(*first_stage, *second_stage, upsample)

    def forward(self, x):
        """Runs the forward pass, concatenating encoder outputs as skip connections.

        Args:
            x: The input tensor.

        Returns:
            The output tensor after passing through the network.
        """
        # Contracting path; every intermediate result is kept for a skip connection.
        skip_1 = self.enc1(x)
        skip_2 = self.enc2(skip_1)
        skip_3 = self.enc3(skip_2)
        skip_4 = self.enc4(skip_3)

        # Expanding path; the skip tensor always comes first in the channel concat.
        features = self.bottleneck(skip_4)
        features = self.dec1(torch.cat((skip_4, features), dim=1))
        features = self.dec2(torch.cat((skip_3, features), dim=1))
        features = self.dec3(torch.cat((skip_2, features), dim=1))
        features = self.dec4(torch.cat((skip_1, features), dim=1))

        return self.out_conv(features)