In [1]:
import sys
import multiprocessing as mp
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchgeo.datasets import NonGeoDataset, stack_samples, unbind_samples
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.trainers import PixelwiseRegressionTask, SemanticSegmentationTask
from torchvision.transforms.functional import pad
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import geopandas as gpd
import rasterio
import numpy as np
import wandb

sys.path.append("/n/home07/kayan/asm/scripts/")
from asm_datamodules import *

## Set parameters

In [2]:
# device configuration
device, num_devices = ("cuda", torch.cuda.device_count()) if torch.cuda.is_available() else ("cpu", mp.cpu_count())
workers = mp.cpu_count()
print(f"Running on {num_devices} {device}(s) with {workers} cpus")

Running on 1 cuda(s) with 32 cpus


In [3]:
# model parameters
lr = 1e-3
n_epoch = 5
batch_size = 8
loss = "ce"

In [4]:
# file names and paths
root = "/n/home07/kayan/asm/data/" # root for data files
project = "ASM_seg_test" # project name in WandB
run_name = "0_losses+images"

## Create datamodule

In [5]:
datamodule = ASMDataModule(batch_size=batch_size, num_workers=1, split=True, split_n=100, root=root, transforms=min_max_transform)

Split with 64 train images, 16 validation images, and 20 test images
Mine proportions
 Train: 1.0
 Validation: 1.0
 Test: 1.0


## Create prediction task

In [6]:
class MySemanticSegmentationTask(SemanticSegmentationTask):
    def validation_step(
        self, batch, batch_idx, dataloader_idx=0
    ) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)
        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss)
        self.val_metrics(y_hat_hard, y)
        self.log_dict(self.val_metrics)

        if (
            batch_idx < 10
            and hasattr(self.trainer, "datamodule")
            and hasattr(self.trainer.datamodule, "plot")
            and self.logger
            and hasattr(self.logger, "experiment")
            and hasattr(self.logger.experiment, "add_figure")
        ):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat_hard
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    summary_writer.add_figure(
                        f"image/{batch_idx}", fig, global_step=self.global_step
                    )
                    plt.close()
            except ValueError:
                pass
        return y_hat_hard

In [7]:
task = MySemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=True,
    loss=loss,
    in_channels=4,
    num_classes=2,
    lr=lr,
    patience=5,
    freeze_backbone=True,
    freeze_decoder=False
)

## Set up WandB Logging

In [59]:
wandb_logger = WandbLogger(project=project, name=run_name, log_model="all")

In [60]:
class WandBCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # log train loss to WandB
        train_loss = trainer.callback_metrics.get("train_loss_epoch")
        if train_loss is not None:
            wandb.log({"train_loss": train_loss.item()}, step=trainer.global_step)
            
    def on_validation_epoch_end(self, trainer, pl_module):
        # Log validation loss to WandB
        val_loss = trainer.callback_metrics.get("val_loss_epoch")
        if val_loss is not None:
            wandb.log({"val_loss": val_loss.item()}, step=trainer.global_step)
            
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
 
        # outputs corresponds to our model predictions
        # log n sample image predictions from first batch
        if batch_idx == 0:
            n = 5
            imgs = batch["image"]
            masks = batch["mask"].to(torch.float64)
            outputs = outputs.to(torch.float64)
            captions = ["Image", "Ground truth", "Prediction"]
            for i in range(n):
                wandb_logger.log_image(key=str(i), images=[imgs[i], masks[i], outputs[i]], caption=captions)
            
            '''wandb_logger.log_image(key='Images', images = [img for img in imgs[:n]])
            wandb_logger.log_image(key='Ground truth', images = [mask for mask in masks[:n]])
            wandb_logger.log_image(key='Prediction', images = [pred for pred in outputs[:n]])'''

# Set up trainer

In [61]:
trainer = Trainer(
        accelerator=device,
        devices=num_devices,
        max_epochs=n_epoch,
        callbacks=[WandBCallback()],
        logger=wandb_logger
    )

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.1 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=task, datamodule=datamodule)

/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/n/home07/kayan/miniconda3/envs/geo-ml/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory ./ASM_seg_test/2qk3vnv6/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | Unet             | 14.3 M
---------------------------------------------------
3.2 M     Trainable params
11.2 M    Non-trainable params
14.3 M    Total par

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [51]:
wandb.finish()



VBox(children=(Label(value='397.746 MB of 397.754 MB uploaded\r'), FloatProgress(value=0.9999804567533667, max…

0,1
epoch,▁▃▅▆█
trainer/global_step,▁▁▁▁▃▃▃▃▅▅▅▅▆▆▆▆█
val_MulticlassAccuracy,▁▁▁▁▁
val_MulticlassJaccardIndex,▁▁▁▁▁
val_loss,▄▁▂▆█

0,1
epoch,4.0
trainer/global_step,39.0
val_MulticlassAccuracy,0.99325
val_MulticlassJaccardIndex,0.9866
val_loss,0.04573
