In [1]:
from pathlib import Path

import mlflow
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import MLFlowLogger
from tqdm import tqdm

In [2]:
from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1, Hold, Test
from inz.models.unet_siamese import UNetSiamese
from inz.models.unet_siamese_pl import (
    OrdinalCrossEntropyLoss,
    SemanticSegmentorSiamese,
    FocalLoss,
    DiceLoss,
    CrossEntropyDiceLoss,
)

In [3]:
# Run-wide configuration: RNG seed, target device and float32 matmul precision.
RANDOM_SEED = 123
device = torch.device("cuda")

# Set the float32 matmul precision mode before any model code runs.
torch.set_float32_matmul_precision("high")

# Seed everything through Lightning's helper so the run is repeatable.
pl.seed_everything(seed=RANDOM_SEED)

Seed set to 123


In [4]:
# Events to load, keyed by the Tier1 / Tier3 / Hold / Test names imported from
# inz.data.event. Only the Joplin tornado (Tier3) is active; the commented-out
# entries are alternative selections kept for quick toggling.
selected_events = {
    # Tier1: [
    #     Event.hurricane_florence, Event.hurricane_harvey,
    #     Event.hurricane_matthew, Event.hurricane_michael,
    # ],
    Tier3: [
        Event.joplin_tornado,
        # Event.moore_tornado,
        # Event.tuscaloosa_tornado,
    ],
    # Hold: [
    #     Event.hurricane_florence, Event.hurricane_harvey,
    #     Event.hurricane_matthew, Event.hurricane_michael,
    # ],
    # Test: [
    #     Event.hurricane_florence, Event.hurricane_harvey,
    #     Event.hurricane_matthew, Event.hurricane_michael,
    # ],
}

# The same batch size is used for the train, val and test loaders.
batch_size = 24

dm = XBDDataModule(
    path=Path("data/xBD_processed"),
    events=selected_events,
    # NOTE(review): the keyword really is spelled `val_faction` here; presumably
    # it is the validation fraction - confirm against XBDDataModule.
    val_faction=0.15,
    test_fraction=0.0,
    train_batch_size=batch_size,
    val_batch_size=batch_size,
    test_batch_size=batch_size,
    # Alternative (disabled): explicit per-split event selection.
    # split_events={
    #     "train": {
    #         Tier1: [],  # hurricane_florence / harvey / matthew / michael were all disabled
    #         Tier3: [Event.joplin_tornado, Event.moore_tornado, Event.tuscaloosa_tornado],
    #     },
    #     "val": {
    #         Hold: [Event.hurricane_matthew],  # florence / harvey / michael were disabled
    #         Test: [Event.hurricane_matthew],  # florence / harvey / michael were disabled
    #     },
    # },
)

# Prepare the data and set the module up for the "fit" stage.
dm.prepare_data()
dm.setup("fit")

# Report how many batches each loader yields.
print(f"{len(dm.train_dataloader())} train batches, {len(dm.val_dataloader())} val batches")

85 train batches, 15 val batches


In [5]:
# Number of classes counted per mask; the last class is "unclassified".
N_CLASSES = 6


def inverse_frequency_weights(counts: torch.Tensor) -> torch.Tensor:
    """Turn per-class pixel counts into inverse-frequency weights that sum to 1.

    Args:
        counts: 1-D tensor of per-class counts.

    Returns:
        Float tensor of the same length. Each present class gets
        ``total / count`` normalised by the sum of those ratios. A class with a
        zero count gets weight 0 instead of ``inf`` (which would turn every
        weight into ``nan`` after normalisation).
    """
    counts = counts.to(torch.float)
    inverse = torch.where(counts > 0, counts.sum() / counts, torch.zeros_like(counts))
    total = inverse.sum()
    # All-zero counts: nothing to normalise, return the zeros as they are.
    return inverse / total if total > 0 else inverse


# Per-batch pixel counts: [background, building] from the pre-disaster masks and
# all N_CLASSES classes from the post-disaster masks.
# Assumes masks are one-hot along dim 1, hence argmax(dim=1) - TODO confirm.
loc_counts_per_batch = []
cls_counts_per_batch = []
for batch in tqdm(dm.train_dataloader()):
    pre_images, pre_masks, post_images, post_masks = batch
    counts_post = torch.bincount(post_masks.argmax(dim=1).reshape(-1), minlength=N_CLASSES)
    cls_counts_per_batch.append(counts_post)
    counts_pre = torch.bincount(pre_masks.argmax(dim=1).reshape(-1), minlength=N_CLASSES)
    # Localization is binary: class 0 versus everything else.
    loc_counts_per_batch.append(torch.stack([counts_pre[0], counts_pre[1:].sum()]))

loc_counts = torch.stack(loc_counts_per_batch).sum(dim=0).to(torch.float)
cls_counts = torch.stack(cls_counts_per_batch).sum(dim=0).to(torch.float)

print(cls_counts)

loc_weights = inverse_frequency_weights(loc_counts).to(device)
# Zero out the last class ("unclassified") after normalisation, so the remaining
# weights intentionally do not sum to 1.
cls_weights = inverse_frequency_weights(cls_counts).to(device) * torch.tensor(
    [1, 1, 1, 1, 1, 0], device=device
)

print(f"Localization weights: {loc_weights}\nClassification weights: {cls_weights}")

100%|██████████| 85/85 [00:20<00:00,  4.08it/s]


tensor([1.1890e+08, 7.7670e+06, 2.1349e+06, 1.3483e+06, 2.4458e+06, 2.4971e+05])
Localization weights: tensor([0.1052, 0.8948], device='cuda:0')
Classification weights: tensor([0.0015, 0.0223, 0.0813, 0.1287, 0.0710, 0.0000], device='cuda:0')


In [6]:
# Build the components first, in the same order the arguments were evaluated
# before, then hand them to the Lightning wrapper.
siamese_unet = UNetSiamese(in_channels=3, out_channels=6)

# Localization loss, weighted by the second entry of loc_weights.
# Alternative: torch.nn.BCEWithLogitsLoss(pos_weight=loc_weights[1])
loc_loss = FocalLoss(reduction="mean", weight=loc_weights[1])

# Classification loss with the per-class weights.
# Alternative: OrdinalCrossEntropyLoss(n_classes=6, weights=cls_weights)
cls_loss = CrossEntropyDiceLoss(weights=cls_weights, reduction="mean")

model = SemanticSegmentorSiamese(
    model=siamese_unet,
    localization_loss=loc_loss,
    classification_loss=cls_loss,
    n_classes=6,
)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'localization_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['localization_loss'])`.
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'classification_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classification_loss'])`.


In [7]:
# NOTE(review): autolog is enabled alongside an explicit MLFlowLogger. The
# captured output shows an autologging run ID that differs from the run directory
# the checkpoints are written to, so this looks like it creates two MLflow runs -
# confirm whether both are wanted.
mlflow.pytorch.autolog()


# Keep only the single best checkpoint according to the monitored "f1" metric
# (higher is better). The file name embeds epoch, f1, iou and loss.
f1_checkpoint_callback = ModelCheckpoint(
    save_top_k=1, verbose=True, monitor="f1", mode="max", filename="f1_{epoch}{f1:.5f}{iou:.5f}{loss:.5f}"
)

# iou_checkpoint_callback = ModelCheckpoint(
#     save_top_k=1, verbose=True, monitor="iou", mode="max", filename="iou_{epoch}{f1:.5f}{iou:.5f}{loss:.5f}"
# )

trainer = pl.Trainer(
    max_epochs=500,
    callbacks=[
        RichProgressBar(),
        f1_checkpoint_callback,
        # iou_checkpoint_callback,
    ],
    logger=MLFlowLogger(experiment_name="basic_siamese"),
    # "bf16" is a discouraged legacy alias; Lightning's own warning asks for the
    # explicit "bf16-mixed" spelling.
    precision="bf16-mixed",
)
trainer.fit(model, datamodule=dm)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/lightning_fabric/connector.py:563: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024/06/12 11:52:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'be77ae391c2c4c3485f1f18ab4c7838a', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:44: attribute 'localization_loss' removed from hparams because it cannot be pickled
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Epoch 0, global step 85: 'f1' reached 0.29326 (best 0.29326), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=0f1=0.29326iou=0.16177loss=0.99997.ckpt' as top 1


Epoch 1, global step 170: 'f1' reached 0.31373 (best 0.31373), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=1f1=0.31373iou=0.16703loss=0.99110.ckpt' as top 1


Epoch 2, global step 255: 'f1' was not in top 1


Epoch 3, global step 340: 'f1' reached 0.32523 (best 0.32523), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=3f1=0.32523iou=0.17059loss=0.98505.ckpt' as top 1


Epoch 4, global step 425: 'f1' reached 0.36036 (best 0.36036), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=4f1=0.36036iou=0.18346loss=0.98052.ckpt' as top 1


Epoch 5, global step 510: 'f1' was not in top 1


Epoch 6, global step 595: 'f1' reached 0.42306 (best 0.42306), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=6f1=0.42306iou=0.21762loss=0.97899.ckpt' as top 1


Epoch 7, global step 680: 'f1' was not in top 1


Epoch 8, global step 765: 'f1' was not in top 1


Epoch 9, global step 850: 'f1' was not in top 1


Epoch 10, global step 935: 'f1' was not in top 1


Epoch 11, global step 1020: 'f1' was not in top 1


Epoch 12, global step 1105: 'f1' was not in top 1


Epoch 13, global step 1190: 'f1' was not in top 1


Epoch 14, global step 1275: 'f1' was not in top 1


Epoch 15, global step 1360: 'f1' was not in top 1


Epoch 16, global step 1445: 'f1' was not in top 1


Epoch 17, global step 1530: 'f1' was not in top 1


Epoch 18, global step 1615: 'f1' was not in top 1


Epoch 19, global step 1700: 'f1' was not in top 1


Epoch 20, global step 1785: 'f1' was not in top 1


Epoch 21, global step 1870: 'f1' was not in top 1


Epoch 22, global step 1955: 'f1' reached 0.42945 (best 0.42945), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=22f1=0.42945iou=0.20948loss=0.97864.ckpt' as top 1


Epoch 23, global step 2040: 'f1' reached 0.44323 (best 0.44323), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=23f1=0.44323iou=0.21721loss=0.97230.ckpt' as top 1


Epoch 24, global step 2125: 'f1' was not in top 1


Epoch 25, global step 2210: 'f1' was not in top 1


Epoch 26, global step 2295: 'f1' reached 0.44530 (best 0.44530), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=26f1=0.44530iou=0.20638loss=0.98119.ckpt' as top 1


Epoch 27, global step 2380: 'f1' was not in top 1


Epoch 28, global step 2465: 'f1' was not in top 1


Epoch 29, global step 2550: 'f1' was not in top 1


Epoch 30, global step 2635: 'f1' was not in top 1


Epoch 31, global step 2720: 'f1' was not in top 1


Epoch 32, global step 2805: 'f1' was not in top 1


Epoch 33, global step 2890: 'f1' was not in top 1


Epoch 34, global step 2975: 'f1' reached 0.47219 (best 0.47219), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=34f1=0.47219iou=0.22780loss=0.97680.ckpt' as top 1


Epoch 35, global step 3060: 'f1' was not in top 1


Epoch 36, global step 3145: 'f1' was not in top 1


Epoch 37, global step 3230: 'f1' reached 0.48220 (best 0.48220), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=37f1=0.48220iou=0.24055loss=0.96727.ckpt' as top 1


Epoch 38, global step 3315: 'f1' was not in top 1


Epoch 39, global step 3400: 'f1' was not in top 1


Epoch 40, global step 3485: 'f1' was not in top 1


Epoch 41, global step 3570: 'f1' was not in top 1


Epoch 42, global step 3655: 'f1' was not in top 1


Epoch 43, global step 3740: 'f1' was not in top 1


Epoch 44, global step 3825: 'f1' was not in top 1


Epoch 45, global step 3910: 'f1' was not in top 1


Epoch 46, global step 3995: 'f1' was not in top 1


Epoch 47, global step 4080: 'f1' was not in top 1


Epoch 48, global step 4165: 'f1' reached 0.48839 (best 0.48839), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=48f1=0.48839iou=0.22835loss=0.96364.ckpt' as top 1


Epoch 49, global step 4250: 'f1' was not in top 1


Epoch 50, global step 4335: 'f1' was not in top 1


Epoch 51, global step 4420: 'f1' reached 0.48857 (best 0.48857), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=51f1=0.48857iou=0.24307loss=0.97100.ckpt' as top 1


Epoch 52, global step 4505: 'f1' was not in top 1


Epoch 53, global step 4590: 'f1' was not in top 1


Epoch 54, global step 4675: 'f1' was not in top 1


Epoch 55, global step 4760: 'f1' was not in top 1


Epoch 56, global step 4845: 'f1' was not in top 1


Epoch 57, global step 4930: 'f1' was not in top 1


Epoch 58, global step 5015: 'f1' was not in top 1


Epoch 59, global step 5100: 'f1' was not in top 1


Epoch 60, global step 5185: 'f1' was not in top 1


Epoch 61, global step 5270: 'f1' reached 0.49691 (best 0.49691), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=61f1=0.49691iou=0.24135loss=0.97519.ckpt' as top 1


Epoch 62, global step 5355: 'f1' was not in top 1


Epoch 63, global step 5440: 'f1' was not in top 1


Epoch 64, global step 5525: 'f1' reached 0.49949 (best 0.49949), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=64f1=0.49949iou=0.23804loss=0.96841.ckpt' as top 1


Epoch 65, global step 5610: 'f1' reached 0.50024 (best 0.50024), saving model to './mlruns/124960261813571015/5801385d6f0245a7b11b53c65fbc3e14/checkpoints/f1_epoch=65f1=0.50024iou=0.23871loss=0.96088.ckpt' as top 1


Epoch 66, global step 5695: 'f1' was not in top 1


Epoch 67, global step 5780: 'f1' was not in top 1


Epoch 68, global step 5865: 'f1' was not in top 1


Epoch 69, global step 5950: 'f1' was not in top 1


Epoch 70, global step 6035: 'f1' was not in top 1


Epoch 71, global step 6120: 'f1' was not in top 1


Epoch 72, global step 6205: 'f1' was not in top 1


Epoch 73, global step 6290: 'f1' was not in top 1


Epoch 74, global step 6375: 'f1' was not in top 1


Epoch 75, global step 6460: 'f1' was not in top 1


