# PV Segmentation with PyTorch Lightning & segmentation-models-pytorch

This notebook presents our initial approach to training and evaluating baseline deep learning models for solar photovoltaic (PV) panel segmentation from satellite imagery. It uses `PyTorch Lightning` for streamlined training and `segmentation-models-pytorch` (SMP) for easy access to a wide variety of modern segmentation architectures.

**Key Goals:**
1.  *Data Preparation:* Set up a PyTorch Lightning `LightningDataModule` to efficiently load and preprocess image patches and their corresponding masks.
2.  *Model Definition:* Define a reusable `LightningModule` that can accommodate various segmentation architectures from SMP.
3.  *Training:* Execute training loops, leveraging PyTorch Lightning's features for hardware acceleration (including Apple Silicon *MPS*), logging, and checkpointing with wandb.
4.  *Evaluation:* Assess model performance using relevant segmentation metrics and visualize predictions and ground truth alongside samples of lower-resolution STAC imagery retrieved with `cubo`.

**Assumptions:**
*   Image patches and their corresponding binary masks are assumed to be pre-prepared and stored in specified directories.
*   The notebook is designed to be adaptable for different datasets, such as those derived from Maxar imagery with YOLO labels converted to masks, or other datasets providing pixel-coordinate labels.

This notebook will demonstrate the advantages of using PyTorch Lightning for organizing code and simplifying complex training workflows, and how `segmentation-models-pytorch` allows for rapid experimentation with different model backbones and architectures.

## 2. Research Problem & Motivation

### The Challenge of PV Segmentation

As the world transitions to renewable energy sources, solar photovoltaic (PV) installations are growing exponentially worldwide. Accurate mapping and monitoring of these installations is crucial for energy planning, grid management, carbon accounting, and sustainable development. However, traditional methods of tracking PV installations rely on incomplete permit data, manual surveys, or voluntary reporting—all of which present significant gaps in coverage and accuracy.

Satellite imagery offers a promising solution for automated detection of PV installations at regional and global scales. Yet, several challenges make this a non-trivial computer vision problem:

1. **Multi-scale challenge**: PV installations vary dramatically in size, from small residential rooftop panels (a few m²) to utility-scale solar farms (several km²)
2. **Visual variability**: PV panels appear differently depending on panel type, orientation, age, viewing angle, and illumination conditions
3. **Resolution trade-offs**: As demonstrated in Clark et al.'s study (2023), detection performance is strongly affected by image resolution, creating a compromise between coverage area and detection accuracy
4. **Class imbalance**: PV installations typically occupy a small fraction of any given geographic area, creating extreme class imbalance in training data
5. **Data scarcity**: High-quality labeled datasets for training are limited and geographically biased toward certain regions
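
To make the class-imbalance point concrete, here is a minimal sketch on a synthetic mask. The 256x256 patch and the single 16x16 PV block are illustrative assumptions, not statistics from any of the datasets. It shows that a model predicting "background" everywhere scores very high pixel accuracy while its IoU for the PV class is zero, which is why overlap-based metrics are preferred for this task.

```python
import torch

# Synthetic 256x256 ground-truth mask with one 16x16 "PV" block (illustrative sizes).
target = torch.zeros(256, 256, dtype=torch.bool)
target[100:116, 100:116] = True

# A trivial "model" that predicts background for every pixel.
pred = torch.zeros_like(target)

# Share of pixels that belong to the PV class.
positive_fraction = target.float().mean().item()

# Pixel accuracy rewards the trivial prediction because background dominates.
accuracy = (pred == target).float().mean().item()

# Intersection over union for the PV class exposes the failure.
intersection = (pred & target).sum().item()
union = (pred | target).sum().item()
iou = intersection / union if union > 0 else 1.0

print(f"PV pixels: {positive_fraction:.2%}")
print(f"Pixel accuracy: {accuracy:.2%}")
print(f"IoU (PV class): {iou:.2f}")
```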



### Research Questions

This project addresses the following key questions:

1. **How can we leverage state-of-the-art deep learning architectures to improve PV segmentation accuracy across diverse geographic regions including those with sparse data?**

2. **Can PyTorch Lightning's framework enable more efficient experimentation across multiple model architectures to identify optimal approaches for this domain-specific problem?**

3. **What combination of data augmentation strategies, model architectures, loss functions, and training approaches best addresses the unique challenges of PV segmentation?**

4. **How can we optimize models to work effectively across different spatial resolutions while maximizing the area that can be covered in operational settings?**

### Global Significance

Accurate mapping of solar PV installations has far-reaching implications:

- **Energy transition monitoring**: Tracking actual deployment rates of solar PV against climate targets
- **Grid integration**: Supporting power system planning by precisely locating distributed energy resources
- **Environmental impact assessment**: Understanding land use changes and habitat effects of renewable energy development
- **Socioeconomic analysis**: Studying adoption patterns across different communities to inform equitable energy transition policies
- **Sustainable Development Goals**: Contributing directly to SDG 7 (Affordable and Clean Energy) and SDG 13 (Climate Action)

In their seminal work with very high resolution (VHR) satellite imagery (*< 1 meter/pixel*), Cecilia Clark and Fabio Pacifici (2023) demonstrated that resolution significantly impacts detection performance. Imagery processed with the ["HD Technology" product](https://blog.maxar.com/tech-and-tradecraft/2022/maxars-hd-technology-provides-measurable-improvements-in-machine-learning-applications) from our employer, Maxar Intelligence (a proprietary upscaling algorithm capable of simulating 15.5 cm GSD), delivered substantially better results than native-resolution (31 cm) imagery. This **resolution-performance trade-off** informs our approach to developing models that can work effectively across varying image resolutions. They summarize the challenges succinctly:

**"Residential solar panels are considered small, weak targets even in VHR satellite imagery due to the average number of pixels per object, variation among  
objects, and complex context**. *Existing satellite imagery datasets often include large-scale, or non-residential, solar panel annotations* due to resolution  
of the imagery and therefore ability to detect small objects. There are available datasets of VHR imagery to support accurate detection
of small-scale and residential installations, but **the imagery is generally sourced from aerial platforms.**" 

By developing improved segmentation techniques, this research contributes to the broader goal of creating comprehensive, accurate, and timely inventories of global PV installations—a critical capability for managing the ongoing energy transition.

<figure style="text-align: center">
<img src="report/assets/figures/Munich_2021-06-18_WV03_HD_16x9.jpg" style="width:70%; height:auto;">
<figcaption align = "center"> 31cm native resolution vs simulated "15.5"cm spatial resolution </figcaption>
</figure>
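
As a rough back-of-the-envelope sketch of why residential panels count as small targets, the snippet below estimates how many pixels cover one panel at each ground sample distance (GSD). The 1.0 m x 1.7 m panel footprint is an assumption made for illustration and is not taken from Clark and Pacifici. Note that the 15.5 cm figures are nominal, since that resolution is simulated by upscaling.

```python
# Ground sample distances (metres per pixel) quoted in the text above.
GSD_NATIVE_M = 0.31
GSD_HD_M = 0.155

# Assumed panel footprint, for illustration only (not from the source).
PANEL_WIDTH_M = 1.0
PANEL_LENGTH_M = 1.7

CHIP_SIZE_PX = 256  # same value as PATCH_SIZE_PIXELS in the configuration cell


def pixels_on_target(width_m: float, length_m: float, gsd_m: float) -> float:
    """Approximate number of pixels covering a rectangle at a given GSD."""
    return (width_m / gsd_m) * (length_m / gsd_m)


for name, gsd in [("native 31 cm", GSD_NATIVE_M), ("HD 15.5 cm", GSD_HD_M)]:
    n_px = pixels_on_target(PANEL_WIDTH_M, PANEL_LENGTH_M, gsd)
    footprint_m = CHIP_SIZE_PX * gsd
    print(f"{name}: ~{n_px:.0f} px per panel, chip covers {footprint_m:.1f} m per side")
```

Halving the GSD quadruples the nominal pixel count per panel but also halves the ground distance covered by each fixed-size chip, which is the coverage side of the trade-off.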

## The Power of PyTorch Lightning

PyTorch Lightning is a lightweight PyTorch wrapper that significantly simplifies the process of training deep learning models. It provides a structured framework that abstracts away much of the boilerplate code typically associated with PyTorch training loops, allowing researchers and developers to focus more on the model architecture and data.

**Key Advantages of PyTorch Lightning:**

*   **Reduced Boilerplate:** Lightning handles the engineering aspects of training, such as the training loop, validation loop, and test loop. This means you write less code for common tasks.
*   **Organized Code:** It promotes a clean and organized code structure through its core components: `LightningModule` and `LightningDataModule`.
    *   The `LightningModule` encapsulates all model-related code (architecture, optimizers, training steps, validation steps, etc.).
    *   The `LightningDataModule` handles all data-related operations (data loading, transformations, splitting, batching).
*   **Simplified Training & Iteration:** With the boilerplate handled, iterating on different model architectures or hyperparameters becomes much faster and more straightforward.
*   **Hardware Agnostic:** Lightning makes it easy to train models on CPUs, GPUs (single or multiple), and TPUs with minimal code changes. You can specify the `accelerator` (e.g., `"gpu"`, `"mps"`, `"tpu"`, `"cpu"`) and `devices` (e.g., number of GPUs) directly in the `Trainer`.
*   **Scalability:** It seamlessly supports distributed training (multi-GPU, multi-node) and mixed-precision training (`precision='16-mixed'`), which are crucial for training large models or training on large datasets.
*   **Reproducibility:** By organizing code and managing training details, Lightning helps in creating more reproducible experiments.
*   **Callbacks & Loggers:** It has a rich ecosystem of callbacks (for checkpointing, early stopping, learning rate monitoring, etc.) and loggers (TensorBoard, CSVLogger, etc.) that integrate easily into the training process.

As highlighted in the [PyTorch Lightning tutorial by DataCamp](https://www.datacamp.com/tutorial/pytorch-lightning-tutorial), its broad yet useful abstractions allow for quick training and iteration on multiple model architectures and facilitate scaling to multi-GPU or cloud environments. This notebook will leverage these features to efficiently train our PV segmentation models.

<figure style="text-align: center">
<img src="report/assets/figures/xkcd_python.png" style="width:60%; height:auto;">
<figcaption align = "center"> Illustration of what modern Python DL workflows can feel like</figcaption>
</figure>

In [None]:
# %% --- 2. Configuration ---

# --- Data Parameters ---
# These paths point to pre-prepared image patches and their corresponding masks
# For initial testing with datasets like Maxar's (Clark et al.) or Jiang et al.,
# ensure you have converted their pixel-coordinate labels into raster mask images.
MASK_DIR = Path('data/maxar_sample_masks_native/') # <<< --- UPDATE (e.g., where you save generated masks)
IMAGE_PATCH_DIR = Path('data/maxar_sample_chips_native/') # <<< --- UPDATE (e.g., where Maxar image chips are)
# LABEL_FILE_PATH is not directly used by DataModule if image/mask paths are directly globbed,
# but can be useful for cross-referencing or generating file lists.
# LABEL_FILE_PATH = 'path/to/your/pv_labels.gpkg'

PATCH_SIZE_PIXELS = 256 # Should match your prepared image/mask chip size
NUM_WORKERS = os.cpu_count() // 2 if os.cpu_count() else 0 # For DataLoaders

# --- Model & Training Hyperparameters (Common) ---
IN_CHANNELS = 3 # RGB. Change to 4 if using RGB+NIR
NUM_CLASSES = 1 # Binary segmentation (PV vs background)
TARGET_ACTIVATION = 'sigmoid' # Output activation for the model

LEARNING_RATE = 1e-4
BATCH_SIZE = 8 # Adjust based on GPU/MPS memory
NUM_EPOCHS = 5 # Start with very few epochs for initial weekend testing
VAL_SPLIT_RATIO = 0.2

# --- MPS/GPU Configuration ---
# Check for MPS availability (Apple Silicon GPU)
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    ACCELERATOR = "mps"
    DEVICES = 1
    PRECISION_TRAINER = "32" # MPS generally prefers 32-bit for stability, though 16-mixed might work for some ops
    print("MPS (Apple Silicon GPU) backend is available and will be used.")
elif torch.cuda.is_available():
    ACCELERATOR = "gpu"
    DEVICES = 1 # Or specify number of GPUs [0, 1] or "auto"
    PRECISION_TRAINER = '16-mixed'
    print("CUDA GPU is available and will be used.")
else:
    ACCELERATOR = "cpu"
    DEVICES = 1
    PRECISION_TRAINER = "32"
    print("No GPU or MPS found. CPU will be used (training will be slow).")


# --- Define Architectures and Encoders to Test ---
# Prioritizing lighter versions for initial testing
MODEL_CONFIGURATIONS = {
    "UnetPlusPlus_ResNet18": {"arch": "UnetPlusPlus", "encoder": "resnet18"},
    "DeepLabV3Plus_MobileNetV2": {"arch": "DeepLabV3Plus", "encoder": "mobilenet_v2"},
    "FPN_EfficientNetB0": {"arch": "FPN", "encoder": "efficientnet-b0"},
    "MAnet_ResNet18": {"arch": "MAnet", "encoder": "resnet18"},
    # DPT and Segformer are generally heavier, can be uncommented later
    # "DPT_Hybrid_Tiny": {"arch": "DPT", "encoder": "vit_tiny_patch16_224"}, # Check SMP for exact ViT encoder names for DPT
    # "Segformer_MiTB0": {"arch": "Segformer", "encoder": "mit_b0"},
}
ENCODER_WEIGHTS = 'imagenet'

## The Composability of the PyTorch Ecosystem for Data Handling

The PyTorch ecosystem offers a highly composable and flexible set of tools for data loading and preprocessing, which are essential for any deep learning pipeline. Key components include `Dataset`, `DataLoader`, and various transformation libraries.

**Core Components:**

*   **`torch.utils.data.Dataset`:** This is an abstract class representing a dataset. To create a custom dataset, you typically inherit from `Dataset` and override two methods:
    *   `__len__(self)`: Should return the size of the dataset.
    *   `__getitem__(self, idx)`: Should return the sample (e.g., an image and its corresponding mask) at the given index `idx`. This is where you load and preprocess individual data points.

*   **`torch.utils.data.DataLoader`:** This utility wraps an iterable around the `Dataset` to enable easy access to the samples. It handles many crucial aspects of data loading efficiently:
    *   **Batching:** Groups multiple samples into batches.
    *   **Shuffling:** Reshuffles the data at every epoch so that batches are not presented in the same order each time, which reduces the risk of the model fitting to sample order.
    *   **Parallel Loading:** Uses multiple worker processes (`num_workers`) to load data in parallel, which can significantly speed up training by preventing the GPU from waiting for data.
    *   **Memory Pinning (`pin_memory`):** When using GPUs, setting `pin_memory=True` can speed up data transfer from CPU to GPU memory.

*   **Transformation Libraries (e.g., `torchvision.transforms`, `albumentations`):**
    *   **`torchvision.transforms`:** Provides common image transformations (resizing, cropping, normalization, conversion to tensor, etc.). These are often composed together using `transforms.Compose`.
    *   **`albumentations`:** A powerful library specifically designed for image augmentation. It offers a wide variety of augmentations (flips, rotations, color adjustments, noise, blurs, etc.) and is highly optimized for performance. It integrates well with PyTorch and other frameworks.

**Composability in Action:**

These components are designed to work together seamlessly. A typical workflow involves:
1.  Creating a custom `Dataset` class to load and apply initial transformations to individual image-mask pairs.
2.  Wrapping this `Dataset` instance with a `DataLoader` to manage batching, shuffling, and parallel loading.
3.  The `DataLoader` then provides an iterator that yields batches of data (images and masks) ready to be fed into the model during training or evaluation.
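
The three steps above can be sketched as follows. `SyntheticPatchDataset` is a hypothetical class that generates random image/mask pairs so the example runs without any files on disk; it is not the dataset class used in this notebook.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SyntheticPatchDataset(Dataset):
    """Hypothetical dataset yielding random image/mask pairs (illustration only)."""

    def __init__(self, num_samples: int = 10, channels: int = 3, size: int = 64):
        self.num_samples = num_samples
        self.channels = channels
        self.size = size

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        # Seed per index so a given sample is identical on every access.
        gen = torch.Generator().manual_seed(idx)
        image = torch.rand(self.channels, self.size, self.size, generator=gen)
        mask = (torch.rand(1, self.size, self.size, generator=gen) > 0.95).float()
        return image, mask


# Step 1: the Dataset; step 2: the DataLoader wrapping it.
dataset = SyntheticPatchDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Step 3: iterate over batches, as a training loop would.
for images, masks in loader:
    print(images.shape, masks.shape)
```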

This modular approach, as generally seen in the PyTorch world and highlighted in guides like the [PyTorch Segmentation Models practical guide](https://medium.com/@heyamit10/pytorch-segmentation-models-a-practical-guide-5bf973a32e30), makes the data pipeline flexible, maintainable, and efficient. In this notebook, we use `PVSegmentationDataset` (a custom `Dataset`) and `PVSegmentationDataModule` (which internally uses `DataLoader` and `transforms`) to manage our data.
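
One detail matters specifically for segmentation: random geometric augmentations have to be applied identically to an image and its mask, otherwise the labels no longer line up with the pixels. The sketch below shows one way to do this in plain PyTorch. `paired_random_flip` is a hypothetical helper written for this illustration; libraries such as `albumentations` offer joint image/mask transforms for the same purpose.

```python
import torch


def paired_random_flip(image, mask, p=0.5, generator=None):
    """Flip image and mask together so they stay pixel-aligned.

    image: (C, H, W) tensor, mask: (1, H, W) tensor.
    """
    # One random draw per flip direction, shared by image and mask.
    if torch.rand(1, generator=generator).item() < p:
        image = torch.flip(image, dims=[-1])  # horizontal flip
        mask = torch.flip(mask, dims=[-1])
    if torch.rand(1, generator=generator).item() < p:
        image = torch.flip(image, dims=[-2])  # vertical flip
        mask = torch.flip(mask, dims=[-2])
    return image, mask


# Toy check: copy the mask into channel 0 of the image, then augment.
gen = torch.Generator().manual_seed(0)
mask = (torch.rand(1, 8, 8, generator=gen) > 0.7).float()
image = torch.cat([mask, torch.rand(2, 8, 8, generator=gen)], dim=0)

aug_image, aug_mask = paired_random_flip(image, mask, generator=gen)

# Channel 0 still matches the mask, whichever flips were drawn.
print(torch.equal(aug_image[:1], aug_mask))
```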

# PV Segmentation with PyTorch Lightning & SMP

This notebook provides a framework for training and evaluating deep learning models for solar photovoltaic (PV) panel segmentation from satellite imagery. It leverages the power and flexibility of PyTorch Lightning for streamlined training and `segmentation-models-pytorch` (SMP) for easy access to a wide variety of cutting-edge model architectures.

This notebook will demonstrate the benefits of using PyTorch Lightning for organizing code and simplifying complex training workflows, and how `segmentation-models-pytorch` (SMP) allows for rapid experimentation with different model backbones and architectures.

## Datasets Overview

The training and evaluation of PV segmentation models rely on diverse, publicly available datasets. Many of these datasets are hosted on [Zenodo](https://zenodo.org/), a general-purpose open-access repository developed under the European OpenAIRE program and operated by CERN. Others are hosted on figshare, a web-based platform for sharing research data and other types of content. The rest are hosted in GitHub repositories or on other open-access data platforms.

The dataset labels are available in a variety of formats, including CSV, GeoJSON, GeoPackage, ESRI shapefiles, raw raster masks, and GeoParquet. For this notebook, we assume that these datasets have been preprocessed into image patches and corresponding *raster* segmentation masks.

Here is a list of some prominent Solar Panel dataset publications, their first authors, DOI links, and approximate number of labels, which can be sources for preparing data for this notebook:

-   **"Distributed solar photovoltaic array location and extent dataset for remote sensing object identification"** - K. Bradbury, 2016 | [paper DOI](https://doi.org/10.1038/sdata.2016.106) | [dataset DOI](https://doi.org/10.6084/m9.figshare.3385780.v4) | polygon annotations for 19,433 PV modules in 4 cities in California, USA
-   **"A solar panel dataset of very high resolution satellite imagery to support the Sustainable Development Goals"** - C. Clark et al, 2023 | [paper DOI](https://doi.org/10.1038/s41597-023-02539-8) | [dataset DOI](https://doi.org/10.6084/m9.figshare.22081091.v3) | 2,542 object labels (per spatial resolution)
-   "A harmonised, high-coverage, open dataset of solar photovoltaic installations in the UK" - D. Stowell et al, 2020 | [paper DOI](https://doi.org/10.1038/s41597-020-00739-0) | [dataset DOI](https://zenodo.org/records/4059881) | 265,418 data points (over 255,000 are stand-alone installations, 1,067 are solar farms, and the rest are subcomponents within solar farms)
-   "Georectified polygon database of ground-mounted large-scale solar photovoltaic sites in the United States" - K. S. Fujita, 2023 | [paper DOI](https://doi.org/10.1038/s41597-023-02644-8) | [dataset DOI](https://www.sciencebase.gov/catalog/item/6671c479d34e84915adb7536) | 4,186 data points
-   "Vectorized solar photovoltaic installation dataset across China in 2015 and 2020" - J. Liu et al, 2024 | [paper DOI](https://doi.org/10.1038/s41597-024-04356-z) | [dataset link](https://github.com/qingfengxitu/ChinaPV) | 3,356 PV labels (inspect quality!)
-   *"Multi-resolution dataset for photovoltaic panel segmentation from satellite and aerial imagery"* - H. Jiang, 2021 | [paper DOI](https://doi.org/10.5194/essd-13-5389-2021) | [dataset DOI](https://doi.org/10.5281/zenodo.5171712) | 3,716 samples of PV data points
- **"A crowdsourced dataset of aerial images with annotated solar photovoltaic arrays and installation metadata"** - G. Kasmi, 2023 | [paper DOI](https://doi.org/10.1038/s41597-023-01951-4) | [dataset DOI](https://doi.org/10.5281/zenodo.6865878) | > 28K points of PV installations; 13K+ segmentation masks for PV arrays; metadata for 8K+ installations
-   "An Artificial Intelligence Dataset for Solar Energy Locations in India" - A. Ortiz, 2022 | [paper DOI](https://doi.org/10.1038/s41597-022-01499-9) | [dataset link 1](https://researchlabwuopendata.blob.core.windows.net/solar-farms/solar_farms_india_2021.geojson) or [dataset link 2](https://raw.githubusercontent.com/microsoft/solar-farms-mapping/refs/heads/main/data/solar_farms_india_2021_merged_simplified.geojson) | 117 geo-referenced points of solar installations across India
-   **"GloSoFarID: Global multispectral dataset for Solar Farm IDentification in satellite imagery"** - Z. Yang, 2024 | [paper DOI](https://doi.org/10.48550/arXiv.2404.05180) | [dataset link](https://github.com/yzyly1992/GloSoFarID/tree/main/data_coordinates) | 6,793 PV samples across 3 years (double counting of samples)
-   "A global inventory of photovoltaic solar energy generating units" - L. Kruitwagen et al, 2021 | [paper DOI](https://doi.org/10.1038/s41586-021-03957-7) | [dataset DOI](https://doi.org/10.5281/zenodo.5005867) | 50,426 for training, cross-validation, and testing; 68,661 predicted polygon labels
-   "Harmonised global datasets of wind and solar farm locations and power" - S. Dunnett et al, 2020 | [paper DOI](https://doi.org/10.1038/s41597-020-0469-8) | [dataset DOI](https://doi.org/10.6084/m9.figshare.11310269.v6) | 35272 PV installations

**Key Goals for this Notebook:**
1.  **Data Preparation:** Set up a PyTorch Lightning DataModule to efficiently load and preprocess image patches and their corresponding masks derived from the datasets listed above (or similar).
2.  **Model Definition:** Define a reusable PyTorch LightningModule that can accommodate various segmentation architectures from `segmentation-models-pytorch`.
3.  **Training:** Execute training loops, leveraging PyTorch Lightning's features for hardware acceleration (including Apple Silicon MPS), logging, and checkpointing.
4.  **Evaluation:** Assess model performance using relevant metrics and visualize predictions.

**Assumptions for this Notebook:**
*   Image patches and their corresponding binary masks are assumed to be pre-prepared (e.g., using the `fetch-pv-datasets-ESDA.ipynb` notebook or similar methods) and stored in specified directories.
*   The notebook is designed to be adaptable for different datasets, such as those derived from Maxar imagery with YOLO labels converted to masks, or other datasets providing pixel-coordinate labels that have been rasterized to masks.


In [None]:
# -*- coding: utf-8 -*-
# %% --- 1. Setup & Imports ---
# Standard libraries
import os
import time
from pathlib import Path
import warnings

# Geospatial libraries (primarily for data prep, less so for core training loop if data is pre-processed)
# import geopandas as gpd # Keep if your label_gpkg_path is used to derive file lists
import xarray as xr
import rasterio # For reading image patches
# import pystac_client
# from shapely.geometry import Point, box

# ML/DL libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# --- Suppress specific warnings ---
warnings.filterwarnings("ignore", category=UserWarning, module="rasterio")
warnings.filterwarnings("ignore", message=".* Shapely GEOS version .*")

In [None]:
# %% --- 3. PyTorch Lightning DataModule ---

class PVSegmentationDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None, mask_transform=None,
                 target_size=(PATCH_SIZE_PIXELS, PATCH_SIZE_PIXELS), in_channels=IN_CHANNELS):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform
        self.target_size = target_size
        self.in_channels = in_channels
        assert len(self.image_paths) == len(self.mask_paths), "Mismatch between number of images and masks"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            # Load image using PIL, then convert to numpy for consistency if needed by rasterio-like processing
            # Or directly use PIL if images are standard formats like PNG/JPG
            # If images are GeoTIFFs:
            if str(img_path).lower().endswith(('.tif', '.tiff')):
                with rasterio.open(img_path) as src:
                    # Read specified channels
                    num_bands_to_read = self.in_channels
                    # Ensure we don't try to read more bands than available
                    if src.count < num_bands_to_read:
                        print(f"Warning: Image {img_path.name} has {src.count} bands, but {num_bands_to_read} were requested. Reading available bands.")
                        num_bands_to_read = src.count

                    image_data = src.read(list(range(1, num_bands_to_read + 1))) # Bands are 1-indexed

                    # If fewer channels read than expected, pad with zeros (or handle differently)
                    if image_data.shape[0] < self.in_channels:
                        padding = np.zeros((self.in_channels - image_data.shape[0], src.height, src.width), dtype=image_data.dtype)
                        image_data = np.concatenate((image_data, padding), axis=0)

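                    # Convert to HWC for PIL
                    image = np.moveaxis(image_data, 0, -1).astype(np.uint8) # Assuming 8-bit after scaling
                    if image.shape[-1] == 1:
                        image = image[..., 0] # PIL expects (H, W), not (H, W, 1), for single-band images
                    image_pil = Image.fromarray(image)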


            else: # Assume PNG, JPG etc.
                image_pil = Image.open(img_path)
                if self.in_channels == 3 and image_pil.mode != 'RGB':
                    image_pil = image_pil.convert('RGB')
                elif self.in_channels == 4 and image_pil.mode != 'RGBA': # Example for RGBA
                    image_pil = image_pil.convert('RGBA')
                elif self.in_channels == 1 and image_pil.mode != 'L':
                    image_pil = image_pil.convert('L')


            if image_pil.size != self.target_size:
                image_pil = image_pil.resize(self.target_size, Image.BILINEAR)

        except Exception as e:
            print(f"Error loading image {img_path}: {e}. Returning zeros.")
            # Create zero image with correct number of channels.
            # PIL size is (W, H); a single-channel image must be a 2-D (H, W) array.
            zero_shape = (self.target_size[1], self.target_size[0])
            if self.in_channels > 1:
                zero_shape += (self.in_channels,)
            image_pil = Image.fromarray(np.zeros(zero_shape, dtype=np.uint8))


        try: # Load Mask
            mask = Image.open(mask_path).convert('L') # Grayscale
            if mask.size != self.target_size:
                mask = mask.resize(self.target_size, Image.NEAREST) # Use NEAREST for masks
            mask_np = np.array(mask)
            mask_np = (mask_np > 0).astype(np.float32) # Ensure binary 0 or 1
            # mask_np = np.expand_dims(mask_np, axis=-1) # H, W, C (C=1) # Not needed if ToTensor adds channel
            mask_pil = Image.fromarray(mask_np) # A 2-D float32 array is read as PIL mode 'F'
        except Exception as e:
            print(f"Error loading mask {mask_path}: {e}. Returning zeros.")
            mask_pil = Image.fromarray(np.zeros((self.target_size[1], self.target_size[0]), dtype=np.float32)) # (H, W) float32 -> mode 'F'


        if self.transform: image_tensor = self.transform(image_pil)
        else: image_tensor = transforms.ToTensor()(image_pil)

        if self.mask_transform: mask_tensor = self.mask_transform(mask_pil)
        else: mask_tensor = transforms.ToTensor()(mask_pil) # ToTensor on (H,W) PIL gives (1,H,W)

        return image_tensor, mask_tensor


class PVSegmentationDataModule(pl.LightningDataModule):
    def __init__(self, image_dir: str, mask_dir: str,
                 batch_size: int = 32, num_workers: int = 0,
                 val_split_ratio: float = 0.2, seed: int = 42,
                 patch_size: int = PATCH_SIZE_PIXELS, in_channels: int = IN_CHANNELS): # Use global defaults
        super().__init__()
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_ratio = val_split_ratio
        self.seed = seed
        self.patch_size = (patch_size, patch_size) # Target size as tuple
        self.in_channels = in_channels

        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]
        if self.in_channels == 4:
            self.imagenet_mean.append(0.406) # Placeholder for NIR mean
            self.imagenet_std.append(0.225)  # Placeholder for NIR std
        elif self.in_channels == 1: # For grayscale
            self.imagenet_mean = [0.449] # Approx ImageNet grayscale mean
            self.imagenet_std = [0.226]  # Approx ImageNet grayscale std


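        # NOTE: Random flips are deliberately not applied here. This transform only sees
        # the image, so flipping it would misalign it with its mask. Geometric
        # augmentation must be applied jointly to the image and the mask.
        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])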
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])
        self.mask_transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def prepare_data(self):
        if not self.image_dir.exists() or not self.mask_dir.exists():
            raise FileNotFoundError(f"Image dir ({self.image_dir}) or Mask dir ({self.mask_dir}) not found.")
        print(f"Data located in {self.image_dir} and {self.mask_dir}")

    def setup(self, stage: str = None):
        img_extensions = ['*.tif', '*.png', '*.jpg', '*.jpeg']
        all_image_paths = []
        for ext in img_extensions:
            all_image_paths.extend(list(self.image_dir.glob(ext)))
        all_image_paths = sorted([p for p in all_image_paths if p.is_file() and p.stat().st_size > 100]) # Basic check

        all_mask_paths_final = []
        valid_image_paths_final = []
        missing_masks_count = 0

        for img_path in all_image_paths:
            mask_found = False
            for mask_ext in ['.png', '.tif', '.jpg', '.jpeg']: # Check common mask extensions
                potential_mask_path = self.mask_dir / (img_path.stem + mask_ext)
                if potential_mask_path.exists() and potential_mask_path.is_file() and potential_mask_path.stat().st_size > 0:
                    all_mask_paths_final.append(potential_mask_path)
                    valid_image_paths_final.append(img_path)
                    mask_found = True
                    break
            if not mask_found:
                missing_masks_count += 1

        if missing_masks_count > 0:
             print(f"Warning: Skipped {missing_masks_count} images due to missing or invalid masks.")
        if not valid_image_paths_final:
            raise ValueError("No valid image/mask pairs found. Check data directories and file names/extensions.")

        dataset_size = len(valid_image_paths_final)
        val_size = int(dataset_size * self.val_split_ratio)
        train_size = dataset_size - val_size

        indices = list(range(dataset_size))
        if train_size > 0 and val_size > 0 :
            train_indices, val_indices = random_split(indices, [train_size, val_size],
                                                  generator=torch.Generator().manual_seed(self.seed))
        elif train_size > 0 : # Use all for training if val_size is 0
            print("Warning: Validation split resulted in 0 validation samples. Using all data for training.")
            train_indices = indices
            val_indices = []
        else:
            raise ValueError(f"Train size is {train_size} and val size is {val_size}. Cannot create datasets.")


        train_img_p = [valid_image_paths_final[i] for i in train_indices]
        train_msk_p = [all_mask_paths_final[i] for i in train_indices]
        val_img_p = [valid_image_paths_final[i] for i in val_indices] if val_indices else []
        val_msk_p = [all_mask_paths_final[i] for i in val_indices] if val_indices else []


        if stage in ('fit', 'validate', None): # 'validate' also needs val_dataset
            self.train_dataset = PVSegmentationDataset(train_img_p, train_msk_p,
                                                       transform=self.train_transform,
                                                       mask_transform=self.mask_transform,
                                                       target_size=self.patch_size, in_channels=self.in_channels)
            if val_img_p: # Only create val_dataset if there are validation samples
                self.val_dataset = PVSegmentationDataset(val_img_p, val_msk_p,
                                                        transform=self.val_transform,
                                                        mask_transform=self.mask_transform,
                                                        target_size=self.patch_size, in_channels=self.in_channels)
                print(f"Setup complete. Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}")
            else:
                self.val_dataset = None # Explicitly set to None
                print(f"Setup complete. Train: {len(self.train_dataset)}, Val: No validation set.")


    def train_dataloader(self):
        if hasattr(self, 'train_dataset') and self.train_dataset:
            return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                              num_workers=self.num_workers,
                              pin_memory=(ACCELERATOR in ("gpu", "cuda")),  # Pinned host memory only benefits CUDA transfers
                              persistent_workers=self.num_workers > 0)
        return None

    def val_dataloader(self):
        if hasattr(self, 'val_dataset') and self.val_dataset:
            return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                              num_workers=self.num_workers,
                              pin_memory=(ACCELERATOR in ("gpu", "cuda")),  # Pinned host memory only benefits CUDA transfers
                              persistent_workers=self.num_workers > 0)
        return None

## `segmentation-models-pytorch` (SMP): A Rich Toolkit for Segmentation

`segmentation-models-pytorch` (SMP) is a Python library built on PyTorch that provides a high-level API for image segmentation tasks. It simplifies the implementation of various state-of-the-art segmentation architectures and allows for easy integration of pre-trained encoders.

**Key Features of SMP:**

*   **Variety of Architectures:** SMP offers a collection of popular and effective segmentation architectures, including:
    *   Unet
    *   Unet++
    *   MAnet
    *   Linknet
    *   FPN (Feature Pyramid Network)
    *   PSPNet (Pyramid Scene Parsing Network)
    *   DeepLabV3 / DeepLabV3+
    *   PAN (Pyramid Attention Network)
    *   DPT (Dense Prediction Transformer)
    *   SegFormer
*   **Pre-trained Encoders:** One of the most powerful features of SMP is its seamless integration with a vast number of pre-trained encoders (backbones). This is largely facilitated by its use of libraries like `timm` (PyTorch Image Models by Ross Wightman).
    *   This allows you to use encoders like ResNets (resnet18, resnet34, resnet50, etc.), EfficientNets (efficientnet-b0 to b7), MobileNets, ViTs (Vision Transformers), and many others, often with weights pre-trained on ImageNet.
    *   Using pre-trained encoders can significantly speed up convergence and improve performance, especially when working with limited datasets.
*   **Ease of Use:** Creating a segmentation model is typically a one-liner: `smp.Unet(encoder_name='resnet34', encoder_weights='imagenet', in_channels=3, classes=1)`.
*   **Flexibility:** You can easily switch between different architectures and encoders to experiment and find the best combination for your specific task.
*   **Loss Functions and Metrics:** SMP also includes common loss functions (e.g., DiceLoss, JaccardLoss, FocalLoss) and metrics relevant to segmentation.
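
To make the loss-function bullet concrete, here is a minimal plain-PyTorch sketch of a combined BCE + soft Dice loss computed on raw logits. It is not SMP's implementation; the names `soft_dice_loss` and `combined_loss`, the `smooth` value, and the toy tensors are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss for binary masks; both tensors have shape (N, 1, H, W)."""
    probs = torch.sigmoid(logits)
    dims = (1, 2, 3)  # reduce over channel and spatial dims, keep the batch dim
    intersection = (probs * target).sum(dim=dims)
    cardinality = probs.sum(dim=dims) + target.sum(dim=dims)
    dice = (2.0 * intersection + smooth) / (cardinality + smooth)
    return (1.0 - dice).mean()


def combined_loss(logits: torch.Tensor, target: torch.Tensor, weights=(0.5, 0.5)) -> torch.Tensor:
    """Weighted sum of BCE-with-logits and soft Dice."""
    bce = F.binary_cross_entropy_with_logits(logits, target)
    return weights[0] * bce + weights[1] * soft_dice_loss(logits, target)


# Toy example: a 4x4 "panel" inside an 8x8 patch.
target = torch.zeros(2, 1, 8, 8)
target[:, :, 2:6, 2:6] = 1.0
good_logits = (target * 2 - 1) * 10.0  # confident and correct
bad_logits = -good_logits              # confident and wrong

loss_good = combined_loss(good_logits, target)
loss_bad = combined_loss(bad_logits, target)
print(f"good: {loss_good.item():.4f}, bad: {loss_bad.item():.4f}")
```

BCE penalizes each pixel independently, while Dice measures region overlap, which is why the two are often combined when the foreground class is small.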

As mentioned in the [PyTorch Segmentation Models practical guide](https://medium.com/@heyamit10/pytorch-segmentation-models-a-practical-guide-5bf973a32e30), SMP's strength lies in providing ready-to-use segmentation models with a wide choice of decoders and a huge variety of pre-trained encoders from `timm`. This composability allows for rapid prototyping and benchmarking of different approaches.

In this notebook, we define a `PVSegmentationTask` (a `LightningModule`) that utilizes `smp.create_model` to dynamically build segmentation models based on the configurations specified, allowing us to easily test different architectures and encoders.

In [None]:
# %% --- 4. PyTorch Lightning Module ---
# PVSegmentationTask class definition (largely the same as previous version)
class PVSegmentationTask(pl.LightningModule):
    def __init__(self, model_arch: str, encoder_name: str, encoder_weights: str,
                 in_channels: int, num_classes: int, activation: str,
                 learning_rate: float = 1e-4, loss_weights: tuple = (0.5, 0.5)):
        super().__init__()
        self.save_hyperparameters()

        self.model = smp.create_model(
            arch=self.hparams.model_arch,
            encoder_name=self.hparams.encoder_name,
            encoder_weights=self.hparams.encoder_weights,
            in_channels=self.hparams.in_channels,
            classes=self.hparams.num_classes,
            activation=self.hparams.activation
        )
        self.dice_loss = DiceLoss(mode='binary', from_logits=(self.hparams.activation is None))
        self.bce_loss = SoftBCEWithLogitsLoss()

        metrics_args = {"threshold": 0.5}  # Binary* metric classes take no `task` argument; only the task-dispatching wrappers (e.g. JaccardIndex) do
        metrics_collection = MetricCollection({
            'iou': BinaryJaccardIndex(**metrics_args), 'f1': BinaryF1Score(**metrics_args),
            'accuracy': BinaryAccuracy(**metrics_args), 'precision': BinaryPrecision(**metrics_args),
            'recall': BinaryRecall(**metrics_args),
        })
        self.train_metrics = metrics_collection.clone(prefix='train_')
        self.val_metrics = metrics_collection.clone(prefix='val_')

    def forward(self, x): return self.model(x)

    def _calculate_loss(self, y_pred, y_true):
        y_true = y_true.float()
        if self.hparams.activation is None:
            # Model outputs raw logits: both losses consume them directly.
            bce = self.bce_loss(y_pred, y_true)
            dice = self.dice_loss(y_pred, y_true)
        else:
            # Model outputs probabilities (assumes a sigmoid activation):
            # invert the sigmoid so the logits-based BCE loss can be reused.
            epsilon = 1e-7
            y_pred_clamped = torch.clamp(y_pred, epsilon, 1.0 - epsilon)
            logits = torch.log(y_pred_clamped / (1.0 - y_pred_clamped))
            bce = self.bce_loss(logits, y_true)
            dice = self.dice_loss(y_pred, y_true)
        return self.hparams.loss_weights[0] * bce + self.hparams.loss_weights[1] * dice

    def training_step(self, batch, batch_idx):
        x, y_true = batch; y_pred = self(x); loss = self._calculate_loss(y_pred, y_true)
        self.train_metrics.update(y_pred, y_true.int())
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        self.log_dict(self.train_metrics.compute(), logger=True); self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        x, y_true = batch; y_pred = self(x); loss = self._calculate_loss(y_pred, y_true)
        self.val_metrics.update(y_pred, y_true.int())
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        if not self.trainer.sanity_checking:  # Skip logging during the sanity check
            self.log_dict(metrics, logger=True)
        # Always reset, so sanity-check batches do not leak into the first epoch's metrics
        self.val_metrics.reset()


    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        # `verbose` is deprecated in recent PyTorch releases; LearningRateMonitor already logs LR changes
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_iou"}}

## Streamlined Training with PyTorch Lightning Trainer

Once the `LightningDataModule` (data handling) and the `LightningModule` (model, optimizers, and training/validation logic) are defined, training reduces to configuring a `Trainer` and calling its `fit` method.

**The `Trainer` Class:**

The `Trainer` automates most of the training loop, including:
*   Iterating over epochs and batches.
*   Calling the appropriate methods in your `LightningModule` (e.g., `training_step`, `validation_step`).
*   Performing optimizer steps and learning rate scheduler adjustments.
*   Moving data to the correct device (CPU/GPU/TPU).
*   Handling distributed training and mixed-precision if configured.
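
For comparison, the loop below is a simplified plain-PyTorch sketch of the work listed above. It is not Lightning's implementation: the toy convolutional model, the random data, and the hyperparameters are stand-ins chosen only to keep the example self-contained.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins for the real model and dataset.
model = nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)
images = torch.rand(8, 3, 16, 16)
masks = (images.mean(dim=1, keepdim=True) > 0.5).float()
loader = DataLoader(TensorDataset(images, masks), batch_size=4, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()

epoch_losses = []
for epoch in range(3):                      # iterate over epochs
    model.train()
    running = 0.0
    for x, y in loader:                     # iterate over batches
        x, y = x.to(device), y.to(device)   # move data to the correct device
        loss = loss_fn(model(x), y)         # the value training_step would return
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                    # optimizer step
        running += loss.item()
    epoch_losses.append(running / len(loader))

print(epoch_losses)
```

Validation, checkpointing, logging, and scheduler steps would each add more boilerplate to this loop; the `Trainer` takes over those parts.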

**Key `Trainer` Arguments Used in this Notebook:**

*   `accelerator`: Specifies the hardware to use (e.g., "mps", "gpu", "cpu").
*   `devices`: Specifies the number of devices or specific device IDs.
*   `max_epochs`: The maximum number of epochs to train for.
*   `logger`: Accepts one or more logger instances (e.g., `TensorBoardLogger`, `CSVLogger`) to record metrics and hyperparameters.
*   `callbacks`: A list of callback objects that can customize the training behavior at various points. Common callbacks include:
    *   `ModelCheckpoint`: Saves the model periodically, often based on a monitored metric (e.g., best validation IoU).
    *   `LearningRateMonitor`: Logs the learning rate at each epoch or step.
    *   `EarlyStopping`: Stops training if a monitored metric stops improving for a certain number of epochs (patience).
*   `precision`: Configures training precision (e.g., "32" for full precision, "16-mixed" for mixed-precision training).
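
The patience rule behind `EarlyStopping` can be illustrated with a few lines of plain Python. This is a simplified model of the behaviour for `mode="max"`, not Lightning's code; `epochs_until_stop` and the `val_iou` history are made up for the example.

```python
def epochs_until_stop(metric_history, patience=5, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best = float("-inf")
    wait = 0  # epochs since the last improvement
    for epoch, value in enumerate(metric_history, start=1):
        if value > best + min_delta:  # mode="max": higher is better
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation IoU per epoch: the best value is first reached at epoch 3.
val_iou = [0.40, 0.55, 0.61, 0.60, 0.59, 0.61, 0.60, 0.58]
stop_epoch = epochs_until_stop(val_iou, patience=5)
stop_epoch_short = epochs_until_stop(val_iou, patience=3)
print(stop_epoch, stop_epoch_short)
```

Note that matching the previous best (epoch 6) does not count as an improvement, so the counter keeps running.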

**Initiating Training:**

Training is typically started with a single line: `trainer.fit(model, datamodule=data_module)`.

PyTorch Lightning's `Trainer` abstracts away the complexities of the training loop, allowing you to focus on the core components of your deep learning model and experiment more rapidly.

In [None]:
import traceback

# %% --- 5. Training Execution ---
data_module = PVSegmentationDataModule(
    image_dir=str(IMAGE_PATCH_DIR), mask_dir=str(MASK_DIR),
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
    val_split_ratio=VAL_SPLIT_RATIO, patch_size=PATCH_SIZE_PIXELS, in_channels=IN_CHANNELS
)
try:
    data_module.prepare_data()
    data_module.setup(stage='fit')

    # Visualize a sample batch from DataModule
    train_dl = data_module.train_dataloader()
    if train_dl and len(train_dl) > 0 :
        print("Visualizing a sample batch from DataModule's train_dataloader...")
        images, masks = next(iter(train_dl))
        def show_dm_batch(image_tensor, mask_tensor, num_samples=4, mean_val=None, std_val=None, in_channels=IN_CHANNELS):
            num_samples = min(num_samples, image_tensor.size(0))  # The batch may hold fewer samples than requested
            images_np = image_tensor[:num_samples].cpu().numpy()
            masks_np = mask_tensor[:num_samples].cpu().numpy()
            fig, axes = plt.subplots(num_samples, 2, figsize=(8, num_samples * 4))
            if num_samples == 1: axes = np.array([axes])
            for i in range(num_samples):
                img = images_np[i].transpose(1, 2, 0)
                if mean_val is not None and std_val is not None:
                    m, s = np.array(mean_val), np.array(std_val)
                    img = s * img + m  # Undo the normalization for display
                img = np.clip(img, 0, 1)
                if img.shape[-1] == 1: img = img.squeeze(-1) # For grayscale display

                mask = masks_np[i].squeeze()
                axes[i, 0].imshow(img, cmap='gray' if img.ndim==2 else None); axes[i, 0].set_title("Image"); axes[i, 0].axis('off')
                axes[i, 1].imshow(mask, cmap='gray'); axes[i, 1].set_title("Mask"); axes[i, 1].axis('off')
            plt.tight_layout(); plt.show()
        show_dm_batch(images, masks, mean_val=data_module.imagenet_mean, std_val=data_module.imagenet_std)
    else:
        print("Train dataloader from DataModule is empty or not available.")
        data_module = None # Prevent training if data is not loaded

except Exception as e:
    print(f"Error setting up DataModule or visualizing batch: {e}")
    data_module = None

if data_module:
    for config_name, params in MODEL_CONFIGURATIONS.items():
        print(f"\n--- Training Model: {config_name} ---")
        lightning_model = PVSegmentationTask(
            model_arch=params['arch'], encoder_name=params['encoder'], encoder_weights=ENCODER_WEIGHTS,
            in_channels=IN_CHANNELS, num_classes=NUM_CLASSES, activation=TARGET_ACTIVATION,
            learning_rate=LEARNING_RATE
        )
        tb_logger = TensorBoardLogger("tb_logs", name=config_name)
        csv_logger = CSVLogger("csv_logs", name=config_name)
        checkpoint_cb = pl.callbacks.ModelCheckpoint(
            dirpath=f"checkpoints/{config_name}", filename="{epoch}-{val_iou:.4f}",
            monitor="val_iou", mode="max", save_top_k=1,
        )
        lr_monitor_cb = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
        early_stop_cb = pl.callbacks.EarlyStopping(monitor="val_iou", patience=5, verbose=True, mode="max")

        trainer = pl.Trainer(
            accelerator=ACCELERATOR, devices=DEVICES, max_epochs=NUM_EPOCHS,
            logger=[tb_logger, csv_logger], callbacks=[checkpoint_cb, lr_monitor_cb, early_stop_cb],
            precision=PRECISION_TRAINER,
            # strategy="ddp_find_unused_parameters_true" if ACCELERATOR=="gpu" and DEVICES > 1 else "auto"
        )
        try:
            print(f"Starting training for {config_name} with {ACCELERATOR}...")
            trainer.fit(lightning_model, datamodule=data_module)
            print(f"Training finished for {config_name}.")
            if checkpoint_cb.best_model_path:
                print(f"Best model for {config_name} saved at: {checkpoint_cb.best_model_path}")
        except Exception as e:
            print(f"Error during training of {config_name}: {e}")
            traceback.print_exc()
            continue
else:
    print("DataModule not initialized. Skipping training.")

## Model Evaluation and Prediction Visualization

After training, it's crucial to evaluate the model's performance on unseen data (typically the validation or a separate test set) and visualize its predictions to gain qualitative insights. This section demonstrates a basic approach to:

1.  **Loading the Best Model:** PyTorch Lightning's `ModelCheckpoint` callback saves the best performing model based on a monitored metric. We load this checkpoint for evaluation.
2.  **Making Predictions:** The loaded model is used to make predictions on a batch of data from the validation set.
3.  **Visualizing Results:** The original images, ground truth masks, and the model's predicted masks are displayed side-by-side for comparison.

This allows for a visual assessment of how well the model is segmenting the PV panels.
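
A quantitative check can accompany the visual one. The sketch below computes IoU and F1 for a single predicted mask with NumPy; it is a hand-rolled illustration rather than the `torchmetrics` implementation used during training, and returning 1.0 when both masks are empty is an assumed convention.

```python
import numpy as np


def binary_scores(logits: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> dict:
    """IoU and F1 for one binary mask, starting from raw logits."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    pred = probs > threshold
    truth = target.astype(bool)
    tp = np.logical_and(pred, truth).sum()
    fp = np.logical_and(pred, ~truth).sum()
    fn = np.logical_and(~pred, truth).sum()
    # Assumed convention: two empty masks count as a perfect match.
    iou = tp / (tp + fp + fn) if (tp + fp + fn) else 1.0
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1.0
    return {"iou": float(iou), "f1": float(f1)}


target = np.zeros((8, 8), dtype=np.uint8)
target[2:6, 2:6] = 1           # 16 ground-truth pixels
logits = np.full((8, 8), -5.0)
logits[2:6, 2:8] = 5.0         # 24 predicted pixels, 16 of them correct

scores = binary_scores(logits, target)
print(scores)
```

Here the prediction covers the whole panel plus 8 extra pixels, giving an IoU of 16/24 and an F1 of 32/40.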

In [None]:
# %% --- 6. Evaluation & Visualization (Example) ---
if data_module and MODEL_CONFIGURATIONS:
    first_config_name = list(MODEL_CONFIGURATIONS.keys())[0]
    # Determine device for evaluation (mps, cuda, or cpu)
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        DEVICE_EVAL = torch.device("mps")
    elif torch.cuda.is_available():
        DEVICE_EVAL = torch.device("cuda")
    else:
        DEVICE_EVAL = torch.device("cpu")
    print(f"Using device for evaluation: {DEVICE_EVAL}")

    checkpoint_dir = Path(f"checkpoints/{first_config_name}/")
    if checkpoint_dir.exists():
        ckpt_files = sorted(list(checkpoint_dir.glob("*.ckpt")), key=os.path.getmtime, reverse=True)
        if ckpt_files:
            best_model_path_eval = ckpt_files[0] # Load the most recently saved best model
            print(f"\n--- Evaluating and Visualizing: {first_config_name} from {best_model_path_eval} ---")
            try:
                eval_model = PVSegmentationTask.load_from_checkpoint(best_model_path_eval, map_location=DEVICE_EVAL)
                eval_model.to(DEVICE_EVAL) # Ensure model is on the correct device
                eval_model.eval()
                val_loader_eval = data_module.val_dataloader()
                if val_loader_eval and len(val_loader_eval) > 0:
                    images_eval, masks_gt_eval = next(iter(val_loader_eval))
                    images_eval = images_eval.to(DEVICE_EVAL) # Move images to device

                    with torch.no_grad():
                        masks_pred_eval = eval_model(images_eval).cpu() # Move predictions to CPU for numpy/PIL

                    unnorm_transform = transforms.Compose([
                        transforms.Normalize(mean=[0.]*IN_CHANNELS, std=[1/s for s in data_module.imagenet_std]),
                        transforms.Normalize(mean=[-m for m in data_module.imagenet_mean], std=[1.]*IN_CHANNELS),
                        transforms.Lambda(lambda t: t.clamp(0, 1)),  # Keep values in [0, 1] so the uint8 conversion does not wrap around
                        transforms.ToPILImage()
                    ])
                    masks_gt_pil = [transforms.ToPILImage()(m.cpu()) for m in masks_gt_eval]
                    if eval_model.hparams.activation is None: masks_pred_eval = torch.sigmoid(masks_pred_eval)
                    masks_pred_binary_pil = [transforms.ToPILImage()((m > 0.5).float().cpu()) for m in masks_pred_eval]

                    num_to_show = min(4, images_eval.size(0))
                    fig, axes = plt.subplots(num_to_show, 3, figsize=(12, num_to_show * 4))
                    if num_to_show == 1: axes = np.array([axes]) # Ensure axes is always 2D for consistent indexing
                    for i in range(num_to_show):
                        img_pil = unnorm_transform(images_eval[i].cpu()) # Move image to CPU before unnorm
                        axes[i, 0].imshow(img_pil); axes[i, 0].set_title("Image"); axes[i, 0].axis('off')
                        axes[i, 1].imshow(masks_gt_pil[i], cmap='gray'); axes[i, 1].set_title("Ground Truth"); axes[i, 1].axis('off')
                        axes[i, 2].imshow(masks_pred_binary_pil[i], cmap='gray'); axes[i, 2].set_title("Prediction"); axes[i, 2].axis('off')
                    plt.tight_layout(); plt.show()
                else: print("Validation dataloader not available or empty for visualization.")
            except Exception as e:
                print(f"Error during eval/viz of {first_config_name}: {e}")
                import traceback
                traceback.print_exc()
        else: print(f"No checkpoint file found in {checkpoint_dir} for {first_config_name}.")
    else: print(f"Checkpoint directory not found for {first_config_name}. Train first.")

## Next Steps and Further Experimentation

This notebook provides a foundational framework for PV panel segmentation. Here are some potential next steps and areas for further experimentation:

*   **Full Training Runs:** Increase `NUM_EPOCHS` for more comprehensive training.
*   **Hyperparameter Tuning:** Experiment with different learning rates, batch sizes, optimizer settings, and loss function weights.
*   **Explore More Architectures/Encoders:** Leverage the flexibility of `segmentation-models-pytorch` to try other model configurations available in `MODEL_CONFIGURATIONS` or add new ones.
*   **Data Augmentation:** Implement more sophisticated data augmentation techniques using `albumentations` within the `PVSegmentationDataset` or `PVSegmentationDataModule` to improve model generalization.
*   **Advanced Loss Functions:** Explore other loss functions or combinations suitable for imbalanced segmentation tasks.
*   **Test Set Evaluation:** Create a dedicated test set (if not already done) and evaluate the final model on it for an unbiased performance measure.
*   **Post-processing:** Implement post-processing steps (e.g., removing small predicted regions, morphological operations) to potentially improve segmentation quality.
*   **Larger Datasets:** Train on larger and more diverse datasets if available.
*   **Cross-Validation:** Implement k-fold cross-validation for more robust performance estimation.
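
As one example of these steps, k-fold splits over the patch indices could be generated as in the sketch below. The helper `kfold_indices`, the fold count, and the seed are illustrative assumptions; each fold's index arrays would then be used to build the train and validation datasets for one run.

```python
import numpy as np


def kfold_indices(n_samples: int, k: int = 5, seed: int = 42):
    """Yield (train_idx, val_idx) pairs for k shuffled, non-overlapping folds."""
    rng = np.random.default_rng(seed)
    permuted = rng.permutation(n_samples)
    folds = np.array_split(permuted, k)  # fold sizes differ by at most one
    for i in range(k):
        val_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, val_idx


splits = list(kfold_indices(n_samples=23, k=5))
for fold, (train_idx, val_idx) in enumerate(splits):
    print(f"fold {fold}: train={len(train_idx)}, val={len(val_idx)}")
```

For PV imagery, patches cut from the same scene are correlated, so grouping folds by source scene rather than by individual patch may give a more honest estimate.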

Review the logs generated in the `tb_logs/` (TensorBoard) and `csv_logs/` directories, and inspect the saved model checkpoints in the `checkpoints/` directory to monitor training progress and select the best models.

In [None]:
# %% --- 7. Next Steps ---
print("\nNotebook execution finished.")
print("Next steps: Review logs in 'tb_logs/' and 'csv_logs/'. Check 'checkpoints/' for saved models.")
print("Consider increasing NUM_EPOCHS for full training runs.")

## Interactive Visualization Slideshow

The cell below creates an interactive slideshow that displays multiple screenshots without cluttering the notebook. This is particularly useful for showing a series of visualizations from the NYT article about clean energy or for comparing model results.

The slideshow has Previous/Next buttons and a slider, and it automatically loads all `jpg`, `jpeg`, `png`, and `gif` files from the specified directory.

In [None]:
# %% --- 9. Interactive Visualization Slideshow ---

from IPython.display import display  # HTML and Image were imported but never used
import os
from pathlib import Path
import ipywidgets as widgets
import matplotlib.pyplot as plt
import glob

def create_slideshow(image_dir="report/assets/visualizations", height=500):
    """Create an interactive slideshow from images in the specified directory.
    
    Args:
        image_dir: Path to directory containing the screenshots/images
        height: Height in pixels for the display area
    
    Returns:
        Interactive widget displaying the slideshow
    """
    # Create directory if it doesn't exist
    os.makedirs(image_dir, exist_ok=True)
    
    # Find all image files
    image_extensions = ["jpg", "jpeg", "png", "gif"]
    image_files = []
    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(image_dir, f"*.{ext}")))
        image_files.extend(glob.glob(os.path.join(image_dir, f"*.{ext.upper()}")))  # Include uppercase extensions
    
    # De-duplicate before sorting: on case-insensitive filesystems both glob patterns match the same file
    image_files = sorted(set(image_files))
    
    if not image_files:
        print(f"No images found in {image_dir}")
        print(f"Please add your screenshots to the {image_dir} directory")
        print(f"Supported formats: {', '.join(image_extensions)}")
        return None
    
    # Create widgets
    slider = widgets.IntSlider(
        value=0,
        min=0, 
        max=len(image_files)-1, 
        step=1,
        description='Image:',
        continuous_update=False,
        layout=widgets.Layout(width='50%')
    )
    
    prev_button = widgets.Button(
        description='Previous',
        disabled=False,
        button_style='', 
        tooltip='Previous image',
        icon='arrow-left'
    )
    
    next_button = widgets.Button(
        description='Next',
        disabled=False,
        button_style='', 
        tooltip='Next image',
        icon='arrow-right'
    )
    
    image_widget = widgets.Image(
        layout=widgets.Layout(height=f"{height}px"),
    )
    
    title_widget = widgets.HTML(
        layout=widgets.Layout(height='auto')
    )

    # Define callback functions
    def on_slider_change(change):
        index = change['new']
        display_image(index)
    
    def on_prev_button_click(b):
        if slider.value > 0:
            slider.value -= 1
    
    def on_next_button_click(b):
        if slider.value < len(image_files) - 1:
            slider.value += 1
    
    def display_image(index):
        filename = image_files[index]
        with open(filename, 'rb') as f:
            image_widget.value = f.read()
        
        base_filename = os.path.basename(filename)
        title_widget.value = f"<div style='text-align: center; font-weight: bold;'>{base_filename} ({index + 1}/{len(image_files)})</div>"
    
    # Attach callbacks to widgets
    slider.observe(on_slider_change, names='value')
    prev_button.on_click(on_prev_button_click)
    next_button.on_click(on_next_button_click)
    
    # Display initial image
    display_image(0)
    
    # Arrange widgets
    button_box = widgets.HBox([prev_button, next_button])
    main_box = widgets.VBox([title_widget, image_widget, slider, button_box])
    
    return main_box

# Create and display the slideshow widget
slideshow = create_slideshow()

if slideshow:
    display(slideshow)
else:
    # Create directory structure if it doesn't exist
    viz_dir = "report/assets/visualizations"
    os.makedirs(viz_dir, exist_ok=True)
    
    # Display placeholder image with instructions
    plt.figure(figsize=(10, 6))
    plt.text(0.5, 0.5, f"Add your screenshots to:\n{os.path.abspath(viz_dir)}", 
             ha='center', va='center', fontsize=16, wrap=True)
    plt.axis('off')
    plt.title("Interactive Slideshow - Setup Instructions", fontsize=18)
    plt.tight_layout()
    plt.show()
    
    print("\nTo use the slideshow:")
    print(f"1. Add your screenshot images to: {os.path.abspath(viz_dir)}")
    print("2. Make sure they're in jpg, jpeg, png, or gif format")
    print("3. Re-run this cell to see the interactive slideshow")
    print("\nTip: You can change the images directory by modifying the argument to create_slideshow()")


In [None]:
# %% --- 8. Export to CoreML and Quantization ---

# Only run this cell after successful training and evaluation

if 'eval_model' not in locals() or eval_model is None:
    print("No model available for export. Please run the evaluation cell first.")
else:
    print("Preparing model for export to CoreML...")
    
    try:
        import coremltools as ct
        from coremltools.models.neural_network import quantization_utils
        
        # Move model to CPU for export
        eval_model.to('cpu')
        eval_model.eval()
        
        # Define input shape
        example_input = torch.rand(1, IN_CHANNELS, PATCH_SIZE_PIXELS, PATCH_SIZE_PIXELS)
        
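        # First, trace the underlying SMP model to TorchScript
        # (forward() only calls self.model, and tracing the bare nn.Module avoids LightningModule-specific attributes)
        with torch.no_grad():
            scripted_model = torch.jit.trace(eval_model.model, example_input)
        
        # Convert to CoreML
        print("Converting model to CoreML format...")
        mlmodel = ct.convert(
            scripted_model,
            inputs=[ct.TensorType(name="input", shape=example_input.shape)],
            outputs=[ct.TensorType(name="output")],  # Name the output so the metadata below can reference it
            convert_to="mlprogram",  # Use the newer ML Program format for better performance
            # compute_units=ct.ComputeUnit.ALL  # Or CPU_ONLY, CPU_AND_GPU, CPU_AND_NE (Neural Engine)
        )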
        
        # Set model metadata
        mlmodel.short_description = "PV Segmentation using PyTorch Lightning and SMP"
        mlmodel.input_description['input'] = "Input image (RGB or multi-channel)"
        mlmodel.output_description['output'] = "Segmentation mask for PV panels"
        
        # Save the model
        output_path = Path("exported_models") / f"{first_config_name}_coreml.mlpackage"
        os.makedirs(output_path.parent, exist_ok=True)
        mlmodel.save(str(output_path))
        print(f"Model successfully exported to: {output_path}")
        
        # ---------- QUANTIZATION OPTIONS (COMMENTED) ----------
        # Note: `quantization_utils.quantize_weights` targets the older "neuralnetwork" format.
        # For an ML Program (.mlpackage), use `coremltools.optimize.coreml` (coremltools >= 7).
        # ML Programs are converted with FP16 compute precision by default, so a separate
        # FP16 weight-quantization pass is typically unnecessary.
        
        # # INT8 weight quantization (more aggressive compression, may affect accuracy)
        # import coremltools.optimize.coreml as cto
        # op_config = cto.OpLinearQuantizerConfig(mode="linear_symmetric")
        # config = cto.OptimizationConfig(global_config=op_config)
        # print("\nCreating INT8 quantized model...")
        # mlmodel_int8 = cto.linear_quantize_weights(mlmodel, config=config)
        # mlmodel_int8.save(str(output_path).replace(".mlpackage", "_int8.mlpackage"))
        # print("INT8 model saved.")
        
        
    except ImportError as e:
        print(f"CoreML export failed: {e}")
        print("Please install coremltools using: pip install coremltools")
    except Exception as e:
        print(f"Error during CoreML export: {e}")
        import traceback
        traceback.print_exc()
