In [1]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset
from torchgeo.samplers import RandomGeoSampler, Units
# from torchgeo.transforms import stack_samples
from typing import Dict, List, Optional, Callable
from torchgeo.datasets import RasterDataset, unbind_samples, stack_samples
import wandb
from sklearn.metrics import jaccard_score
from pathlib import Path
import schedulefree
import torchseg

In [3]:
import torch
from flipnslide.tiling import FlipnSlide
from typing import Any
from torch import Tensor
import kornia as K
from torchgeo.transforms import AugmentationSequential
from torchgeo.transforms import indices


class _FlipnSlide(K.augmentation.GeometricAugmentationBase2D):
    """Flip and slide a tensor."""

    def __init__(self, tilesize: int, viz: bool = False) -> None:
        """Initialize a new _FlipnSlide instance.

        Args:
            tilesize: desired tile size
            viz: visualization flag
        """
        super().__init__(same_on_batch=True, p=1)
        self.flags = {'tilesize': tilesize, 'viz': viz}

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Compute the transformation.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation
        """
        out: Tensor = self.identity_matrix(input)
        return out

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Tensor = None,
    ) -> Tensor:
        """Apply the transform.

        Args:
            input: the input tensor (nSamples, Channels, Height, Width)
            params: generated parameters
            flags: static parameters
            transform: the geometric transformation tensor

        Returns:
            the augmented input (with tiling applied to each sample)
        """
        n_samples, channels, height, width = input.shape
        # Prepare a list to store the transformed samples
        transformed_samples = []
        
        # Loop over each sample in the batch
        for sample in input:
            np_array = sample.cpu().numpy()  # Convert sample to NumPy format
            sample_tiled = FlipnSlide(
                tile_size=flags['tilesize'], 
                data_type='tensor',
                save=False, 
                image=np_array,
                viz=flags['viz']
            )
            # Collect the transformed sample (converted back to tensor)
            transformed_samples.append(sample_tiled.tiles.unsqueeze(0))

        transformed_samples = torch.cat(transformed_samples, dim=0)
        
        # Stack all transformed samples back into a batch
        return transformed_samples.view(-1, *transformed_samples.shape[2:])


# Usage with AugmentationSequential
flipnslide = _FlipnSlide(tilesize=64, viz=False)

tfms_img = AugmentationSequential(
    indices.AppendNDBI(index_swir=5, index_nir=3),
    indices.AppendNDWI(index_green=1, index_nir=3),
    indices.AppendNDVI(index_nir=3, index_red=2),
    data_keys = ['image']
)
tfms_fns = AugmentationSequential(
    flipnslide,
    data_keys= ['image', 'mask']
)

# dataloader = DataLoader(train_dset, sampler=sampler, batch_size=1, collate_fn=stack_samples)
# batch = next(iter(dataloader))
# print(batch.keys())
# x = batch
# print(x['image'].shape)
# x = tfms_img(x)
# print(x['image'].shape)

# # print(tfms_both(x)['image'].shape)
# x = tfms_fns(x)
# print(x['mask'].unsqueeze(1).shape)

In [8]:
# Define the LightningModule
class SegmentationModel(pl.LightningModule):
    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: Callable,
        acc_fns: Optional[List[Callable]] = None,
        learning_rate: float = 0.0025,
        save_model_path: Optional[str] = None,
        train_tfms: Optional[Callable] = None,
        val_tfms: Optional[Callable] = None,
    ):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.acc_fns = acc_fns
        self.learning_rate = learning_rate
        self.save_model_path = save_model_path
        self.best_chkpt_score = float("-inf")
        self.train_tfms = train_tfms
        self.val_tfms = val_tfms

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Set optimizer to train mode
        self.optimizers().train()

        if self.train_tfms is not None:
            # if multiple transforms, Apply transforms iteratively
            if isinstance(self.train_tfms, list):
                batch = [tfm(batch) for tfm in self.train_tfms][0]
            else:
                batch = self.train_tfms(batch)

        X = batch["image"].to(self.device)
        y = batch["mask"].type(torch.long).to(self.device)
        y = y.squeeze(1)
        pred = self(torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0))
        loss = self.loss_fn(pred, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Set optimizer to eval mode
        self.optimizers().eval()

        if self.val_tfms is not None:
            # if multiple transforms, Apply transforms iteratively
            if isinstance(self.val_tfms, list):
                batch = [tfm(batch) for tfm in self.val_tfms][0]
            else:
                batch = self.val_tfms(batch)      

        X = batch["image"].type(torch.float32).to(self.device)
        y = batch["mask"].type(torch.long).to(self.device)
        y = y.squeeze(1)
        pred = self(torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0))
        val_loss = self.loss_fn(pred, y)
        self.log("val_loss", val_loss, on_step=True, on_epoch=False, prog_bar=True)

        if self.acc_fns is not None:
            acc = [acc_fn(pred, y) for acc_fn in self.acc_fns]
            for i, acc_fn in enumerate(self.acc_fns):
                self.log(f"val_acc_{i}", acc[i], on_step=True, on_epoch=False, prog_bar=True)

            # Save the best model based on the first accuracy metric
            if self.save_model_path and acc[0] > self.best_chkpt_score:
                self.best_chkpt_score = acc[0]
                torch.save(self.model.state_dict(), self.save_model_path)
                print(f"Saving model with validation score: {self.best_chkpt_score:.4f}")

    def configure_optimizers(self):
        optimizer = schedulefree.AdamWScheduleFree(self.parameters(), lr=self.learning_rate)
        return optimizer


# Define accuracy functions
def oa(pred, y):
    flat_y = y.squeeze()
    flat_pred = pred.argmax(dim=1)
    acc = torch.count_nonzero(flat_y == flat_pred) / torch.numel(flat_y)
    return acc


def iou(pred, y):
    flat_y = y.cpu().numpy().squeeze()
    flat_pred = pred.argmax(dim=1).detach().cpu().numpy()
    return jaccard_score(flat_y.reshape(-1), flat_pred.reshape(-1), zero_division=1.0, average="weighted")


# Define the loss function
def loss(p, t):
    return torch.nn.functional.cross_entropy(p, t)


# Load datasets
root = Path(r"C:\Users\coach\myfiles\postdoc\Fire\data\DNN")
train_imgs = RasterDataset(paths=(root / "train/X_chips").as_posix(), crs="epsg:4326", res=0.00025)
train_msks = RasterDataset(paths=(root / "train/Y_chips").as_posix(), crs="epsg:4326", res=0.00025)
valid_imgs = RasterDataset(paths=(root / "test/X_chips").as_posix(), crs="epsg:4326", res=0.00025)
valid_msks = RasterDataset(paths=(root / "test/Y_chips").as_posix(), crs="epsg:4326", res=0.00025)

train_msks.is_image = False
valid_msks.is_image = False

train_dset = train_imgs & train_msks
valid_dset = valid_imgs & valid_msks

train_sampler = RandomGeoSampler(train_imgs, size=256, length=3, units=Units.PIXELS)
valid_sampler = RandomGeoSampler(valid_imgs, size=256, length=3, units=Units.PIXELS)

# Custom collation function to filter out non-tensor objects
def custom_collate_fn(batch):
    # Extract only tensor-like objects (images and masks)
    filtered_batch = []
    for sample in batch:
        filtered_sample = {
            "image": sample["image"],  # Assuming "image" is a tensor
            "mask": sample["mask"],    # Assuming "mask" is a tensor
        }
        filtered_batch.append(filtered_sample)

    # Stack images and masks into batches
    images = torch.stack([item["image"] for item in filtered_batch])
    masks = torch.stack([item["mask"] for item in filtered_batch])
    return {
        "image": images,
        "mask": masks,
    }

train_dataloader = DataLoader(train_dset, sampler=train_sampler, batch_size=3, collate_fn=custom_collate_fn)
valid_dataloader = DataLoader(valid_dset, sampler=valid_sampler, batch_size=3, collate_fn=custom_collate_fn)

# Initialize wandb
wandb.init(project="burn_area_mapping", name="DLconvnextv2_PL", job_type="L8_pretraining_test", mode='dryrun')

# Define the model
model = torchseg.Unet(
    encoder_name="convnextv2_tiny",
    encoder_weights="imagenet",
    in_channels=14,
    classes=2,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2,
)

# Create the LightningModule
lightning_model = SegmentationModel(
    model=model,
    loss_fn=loss,
    acc_fns=[oa, iou],
    learning_rate=0.0025,
    save_model_path=r"C:\Users\coach\myfiles\postdoc\Fire\models\DL_imgnet_convnextT_14012025.pth",
    train_tfms = [tfms_img, tfms_fns],
    val_tfms = tfms_img
)

# Define the Trainer
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,  # Use 2 GPUs
    max_epochs=150,
    logger=pl.loggers.WandbLogger(project="burn_area_mapping")
)

# Train the model
trainer.fit(lightning_model, train_dataloader, valid_dataloader)

# Finish wandb run
wandb.finish()

0,1
trainer/global_step,▁
val_acc_0,▁
val_acc_1,▁
val_loss,▁

0,1
trainer/global_step,0.0
val_acc_0,0.77883
val_acc_1,0.62109
val_loss,5.00203


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\ProgramData\Anaconda3\envs\erthy\Lib\site-packages\pytorch_lightning\loggers\wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                   | Params | Mode 
------------------------------------------------------------
0 | model    | Unet                   | 32.0 M | train
1 | val_tfms | AugmentationSequential | 0      | eval 
------------------------------------------------------------
32.0 M    Trainable params
0         Non-trainable params
32.0 M    Total params
127.942   Total estimated model params size (MB)
297       Modules in train mode
6         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\ProgramData\Anaconda3\envs\erthy\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Saving model with validation score: 0.1385


c:\ProgramData\Anaconda3\envs\erthy\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\ProgramData\Anaconda3\envs\erthy\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Saving model with validation score: 0.7929


Validation: |          | 0/? [00:00<?, ?it/s]

Saving model with validation score: 0.9357


Validation: |          | 0/? [00:00<?, ?it/s]

Saving model with validation score: 0.9896


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

: 