## Import libraries and set dataset paths


In [2]:
import datetime
import os
from pathlib import Path

from lib.dataloader import *

# Root folder that holds every dataset variant. Set the WATERSEGNET_DATASETS
# environment variable to run the notebook on another machine; when it is
# unset, the original location is used unchanged.
datasets_root = Path(
    os.environ.get("WATERSEGNET_DATASETS", "/home/emilia/WaterSegNet/datasets")
)

uav_set = datasets_root / "uav_dataset"
satelite_set = datasets_root / "satelite_dataset"
full_set = datasets_root / "complete_dataset"

# Dataset used for this run; point root_dir at satelite_set or full_set to
# switch. Every path below is derived from it.
root_dir = uav_set

# Image / label folders for each split of the semantic segmentation dataset.
dir_train_img = root_dir / "train/images"
dir_train_mask = root_dir / "train/labels"
dir_valid_img = root_dir / "valid/images"
dir_valid_mask = root_dir / "valid/labels"
dir_test_img = root_dir / "test/images"
dir_test_mask = root_dir / "test/labels"

# Directory paths for checkpoints and best models. The best-model folder gets
# a day-month-year-hour-minute suffix, so its value changes whenever this cell
# is re-run in a different minute.
dir_checkpoint = root_dir / "checkpoints/"
run_stamp = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M")
dir_best_model = root_dir / "best_models/" / run_stamp

# One SegDataset per split, each built from its (images, labels) folder pair.
train_set = SegDataset(dir_train_img, dir_train_mask)
valid_set = SegDataset(dir_valid_img, dir_valid_mask)
test_set = SegDataset(dir_test_img, dir_test_mask)

In [None]:
import matplotlib.pyplot as plt

# Show test samples 6 through 14, one figure per sample.
first_idx, stop_idx = 6, 15
for sample_idx in range(first_idx, stop_idx):
    sample = test_set[sample_idx]
    img = sample["image"]
    mask = sample["mask"]

    # Move axis 0 to the end before plotting (back to HWC for visualization).
    img_hwc = np.transpose(img, (1, 2, 0))
    plt.imshow(img_hwc)
    # Optional mask overlay, currently disabled:
    # plt.imshow(mask, alpha=0.5, cmap="Accent")
    plt.title("Padded and resized image")
    plt.axis("off")
    plt.show()

## Train


In [3]:
from lib.train import *
import segmentation_models_pytorch as smp
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from segnet.SegNet_model import SegNet

# U-Net++ on a ResNet-34 encoder initialised with ImageNet weights. It takes
# 3-channel input and emits a single output channel with no final activation.
unetpp_config = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation=None,
)

# Build the network and switch its parameters to the channels_last memory format.
model = smp.UnetPlusPlus(**unetpp_config).to(memory_format=torch.channels_last)

# Keyword arguments shared by every metric in the collection.
binary_metric_kwargs = dict(task="binary", num_classes=1, multiclass=False)

# Accuracy, recall, precision and F1 evaluated together as one collection.
metrics = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(**binary_metric_kwargs),
    torchmetrics.Recall(**binary_metric_kwargs),
    torchmetrics.Precision(**binary_metric_kwargs),
    torchmetrics.F1Score(**binary_metric_kwargs),  # Dice Coefficient
)

# One tracker per phase, each wrapping the collection defined above.
train_metrics, val_metrics, test_metrics = (
    torchmetrics.MetricTracker(metrics) for _ in range(3)
)


# Hyperparameters for this run, kept in one place so the run name below is
# always built from the values that are actually used.
LEARNING_RATE = 1e-3
OPTIMIZER_TYPE = "adamw"
BATCH_SIZE = 8

# Wrap the network together with its optimizer settings and per-phase metric
# trackers; freeze_encoder=False is passed through unchanged.
seg_model = SegModel(
    model,
    lr=LEARNING_RATE,
    optimizer_type=OPTIMIZER_TYPE,
    train_metrics=train_metrics,
    val_metrics=val_metrics,
    test_metrics=test_metrics,
    freeze_encoder=False,
)

# Data module built from the train / validation / test datasets.
data_module = SegDataModule(train_set, valid_set, test_set, batch_size=BATCH_SIZE)

# Run name used for checkpoint files and logs. It is derived from the settings
# above (currently "UNetPlusPlus_adamw_b8") so it cannot drift out of sync when
# the optimizer or batch size is changed.
model_name = f"UNetPlusPlus_{OPTIMIZER_TYPE}_b{BATCH_SIZE}"

# Callbacks
# NOTE(review): both output locations below live under "satelite/", while the
# first cell selects uav_set as root_dir. Confirm which dataset these
# artifacts should be filed under.

# Keep only the single best checkpoint, ranked by the lowest val_loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/satelite/",
    filename=f"model-{model_name}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# Stop training once val_loss has gone 10 checks without improving.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=10,
    verbose=True,
)

# TensorBoard logs, grouped under the run name.
tb_logger = TensorBoardLogger("lightning_logs/satelite/", name=model_name)

# Use one GPU when CUDA is present, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()

trainer = Trainer(
    max_epochs=200,
    gpus=1 if use_gpu else 0,
    callbacks=[checkpoint_callback, early_stopping],
    logger=tb_logger,
    # Mixed precision training is only requested on the GPU. The CPU fallback
    # uses full 32-bit precision, because 16-bit AMP is rejected there and
    # would make the fallback unusable.
    precision=16 if use_gpu else 32,
)

# Train on the data module; the callbacks above handle checkpointing and
# early stopping.
trainer.fit(seg_model, data_module)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type              | Params
-----------------------------------------------------------
0 | model                | UnetPlusPlus      | 26.1 M
1 | criterion            | BCEWithLogitsLoss | 0     
2 | dice_loss            | DiceLoss          | 0     
3 | train_metric_tracker | MetricTracker     | 0     
4 | val_metric_tracker   | MetricTracker     | 0     
5 | test_metric_tracker  | MetricTracker     | 0     
-----------------------------------------------------------
26.1 M    Trainable params
0         Non-trainable params
26.1 M    Total params
52.157    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Epoch 0: 100%|██████████| 84/84 [00:21<00:00,  3.96it/s, loss=0.198, v_num=0]

Metric val_loss improved. New best score: 0.294


Epoch 1: 100%|██████████| 84/84 [00:24<00:00,  3.50it/s, loss=0.212, v_num=0]

Metric val_loss improved by 0.110 >= min_delta = 0.0. New best score: 0.183


Epoch 2: 100%|██████████| 84/84 [00:24<00:00,  3.36it/s, loss=0.226, v_num=0]

Metric val_loss improved by 0.032 >= min_delta = 0.0. New best score: 0.152


Epoch 6: 100%|██████████| 84/84 [00:25<00:00,  3.35it/s, loss=0.0568, v_num=0]

Metric val_loss improved by 0.032 >= min_delta = 0.0. New best score: 0.120


Epoch 7: 100%|██████████| 84/84 [00:24<00:00,  3.40it/s, loss=0.0468, v_num=0]

Metric val_loss improved by 0.059 >= min_delta = 0.0. New best score: 0.060


Epoch 8: 100%|██████████| 84/84 [00:25<00:00,  3.34it/s, loss=0.0363, v_num=0]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.056


Epoch 13: 100%|██████████| 84/84 [00:25<00:00,  3.32it/s, loss=0.0271, v_num=0]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.053


Epoch 23: 100%|██████████| 84/84 [00:25<00:00,  3.34it/s, loss=0.0227, v_num=0]

Monitored metric val_loss did not improve in the last 10 records. Best score: 0.053. Signaling Trainer to stop.


Epoch 23: 100%|██████████| 84/84 [00:25<00:00,  3.34it/s, loss=0.0227, v_num=0]


In [4]:
%load_ext tensorboard

In [4]:
# Ask PyTorch's CUDA caching allocator to release cached GPU memory that it is
# holding but not currently using.
torch.cuda.empty_cache()

In [6]:
# memory_summary() returns a multi-line string. Print it so the table renders
# with real line breaks instead of an escaped one-line repr.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    # Querying CUDA statistics needs a GPU, so report that instead of failing.
    print("CUDA is not available; no GPU memory summary to show.")

