In [6]:
import os
import glob
import rasterio
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

In [7]:
# This is the function we built in Week 1. We'll use it again.
def load_and_stack_by_id(base_path, label_filename):
    """Load the five band rasters and the label mask for a single sample.

    The sample identifier is built from the last two underscore-separated
    fields of the label file's basename (file extension included). For each of
    the folders ``Band1`` .. ``Band5`` under ``base_path`` the first file whose
    name ends with that identifier is read; the mask is read from
    ``<base_path>/label/<label_filename>``.

    Args:
        base_path: Directory that contains the ``Band1`` .. ``Band5`` and
            ``label`` sub-folders.
        label_filename: Name of the mask file inside the ``label`` folder.

    Returns:
        A ``(stacked_image, mask)`` tuple. ``stacked_image`` is band 1 of each
        of the five rasters stacked along a new leading axis; ``mask`` is
        band 1 of the label raster. ``(None, None)`` is returned when the
        filename has fewer than two underscore-separated fields, so no
        identifier can be derived from it.

    Raises:
        IndexError: If one of the band folders has no file matching the
            identifier.
    """
    name_parts = os.path.basename(label_filename).split('_')
    if len(name_parts) < 2:
        # Not enough fields to build an identifier; signal "no sample".
        return None, None
    identifier = f"_{name_parts[-2]}_{name_parts[-1]}"

    band_folders = ['Band1', 'Band2', 'Band3', 'Band4', 'Band5']

    # Resolve every band path before opening anything, so a missing file is
    # reported up front with a message that says which folder is incomplete.
    layer_paths = []
    for folder in band_folders:
        matches = glob.glob(os.path.join(base_path, folder, f'*{identifier}'))
        if not matches:
            # IndexError is kept (it is what indexing the empty match list
            # raised before) but now carries a descriptive message.
            raise IndexError(
                f"No file matching '*{identifier}' found in "
                f"'{os.path.join(base_path, folder)}'"
            )
        layer_paths.append(matches[0])

    # Open each raster in a context manager so its file handle is released
    # as soon as the band has been read.
    image_layers = []
    for path in layer_paths:
        with rasterio.open(path) as src:
            image_layers.append(src.read(1))
    stacked_image = np.stack(image_layers, axis=0)

    mask_path = os.path.join(base_path, 'label', label_filename)
    with rasterio.open(mask_path) as src:
        mask = src.read(1)

    return stacked_image, mask

# Status message so the reader can see this cell finished executing.
print("Imports and helper function are ready.")

Imports and helper function are ready.


In [8]:
# --- Creating the custom GlacierDataset class ---

class GlacierDataset(Dataset):
    """Map-style PyTorch dataset that yields ``(image, mask)`` tensor pairs.

    Each sample is identified by a label filename; the raster data is read
    from disk on demand by ``load_and_stack_by_id`` every time the sample is
    requested, so nothing is cached in memory.
    """

    def __init__(self, base_path, image_ids):
        """Store the data location and the list of samples.

        Runs once when the dataset is created; no files are read here.

        Args:
            base_path: Root directory passed to ``load_and_stack_by_id``.
            image_ids: Sequence of label filenames, one per sample.
        """
        super().__init__()
        self.base_path = base_path
        self.image_ids = image_ids  # list of label filenames
        print(f"Dataset created with {len(self.image_ids)} samples.")

    def __len__(self):
        """Return the number of samples (PyTorch uses this to size the dataset)."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load a single sample and convert it to float32 tensors.

        Args:
            idx: Position of the sample in ``image_ids``.

        Returns:
            ``(image_tensor, mask_tensor)``. Both are ``torch.float32``; the
            mask gets an extra leading axis so it carries a channel dimension
            like the image does.

        Raises:
            ValueError: If ``load_and_stack_by_id`` could not derive a sample
                from the label filename and returned ``(None, None)``.
        """
        # Get the filename for the requested sample
        label_filename = self.image_ids[idx]

        # Use our helper function to load the image and mask
        image, mask = load_and_stack_by_id(self.base_path, label_filename)

        # The helper signals an unusable filename with (None, None). Fail with
        # a message that names the file instead of an opaque AttributeError.
        if image is None or mask is None:
            raise ValueError(
                f"Could not load sample {idx!r}: label filename "
                f"{label_filename!r} does not yield a usable image identifier."
            )

        # Convert the NumPy arrays to float32 tensors, the dtype the network expects.
        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)
        return image_tensor, mask_tensor

# Status message so the reader can see the class definition cell ran.
print("GlacierDataset class is defined.")

GlacierDataset class is defined.


In [9]:
# --- Split the data and create the DataLoaders ---

TRAIN_DATA_PATH = 'D:/GlacierHack_practice/train'

# Share of the samples used for training; the rest is held out for validation.
TRAIN_FRACTION = 0.8

# Get all label filenames, sorted so the split is identical on every run
all_label_filenames = sorted(os.listdir(os.path.join(TRAIN_DATA_PATH, 'label')))

# Simple 80/20 split. The cut-off is derived from the number of files rather
# than hard-coded, so the ratio holds for any dataset size
# (25 files -> 20 for training, 5 for validation).
num_train = int(len(all_label_filenames) * TRAIN_FRACTION)
train_ids = all_label_filenames[:num_train]
val_ids = all_label_filenames[num_train:]

# Create an instance of our dataset for training
train_dataset = GlacierDataset(base_path=TRAIN_DATA_PATH, image_ids=train_ids)

# Create an instance for validation
val_dataset = GlacierDataset(base_path=TRAIN_DATA_PATH, image_ids=val_ids)

# Now, create the DataLoaders.
# Training batches are reshuffled each epoch; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

print("DataLoaders are ready.")

Dataset created with 20 samples.
Dataset created with 5 samples.
DataLoaders are ready.


In [10]:
# --- Verify the DataLoader ---

# Sanity check: pull one batch from the training loader and inspect its shape.
print("\n--- Verifying the DataLoader ---")
# 'next(iter(loader))' builds a fresh iterator and returns its first batch.
# train_loader was created with shuffle=True, so the samples in this batch
# can differ between runs.
images_batch, masks_batch = next(iter(train_loader))

# Let's check the shape. It should be (batch_size, channels, height, width)
print(f"Shape of one batch of images: {images_batch.shape}")
print(f"Shape of one batch of masks: {masks_batch.shape}")


--- Verifying the DataLoader ---
Shape of one batch of images: torch.Size([4, 5, 512, 512])
Shape of one batch of masks: torch.Size([4, 1, 512, 512])
