In [1]:
from pathlib import Path

import mlflow
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import MLFlowLogger
from tqdm import tqdm

In [2]:
from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1, Hold
from inz.models.unet_siamese import UNetSiamese
from inz.models.unet_siamese_pl import OrdinalCrossEntropyLoss, SemanticSegmentorSiamese

In [3]:
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

Seed set to 123


In [4]:
dm = XBDDataModule(
    path=Path("data/xBD_processed"),
    # events={
    #     # Tier1: [
    #     #     Event.hurricane_florence,
    #     #     Event.hurricane_harvey,
    #     #     Event.hurricane_matthew,
    #     #     Event.hurricane_michael,
    #     # ],
    #     # Tier3: [
    #     #     Event.joplin_tornado,
    #     #     # Event.moore_tornado,
    #     #     # Event.tuscaloosa_tornado
    #     # ],
    #     # Hold: [
    #     #     Event.hurricane_florence,
    #     #     Event.hurricane_harvey,
    #     #     Event.hurricane_matthew,
    #     #     Event.hurricane_michael,
    #     # ],
    #     # Test: [
    #     #     Event.hurricane_florence,
    #     #     Event.hurricane_harvey,
    #     #     Event.hurricane_matthew,
    #     #     Event.hurricane_michael,
    #     # ],
    # },
    # val_faction=0.15,
    # test_fraction=0.0,
    train_batch_size=24,
    val_batch_size=24,
    test_batch_size=24,
    split_events={
        "train": {
            Tier1: [
                Event.hurricane_florence,
                Event.hurricane_harvey,
                Event.hurricane_matthew,
                Event.hurricane_michael,
            ],
            Tier3: [Event.joplin_tornado, Event.moore_tornado, Event.tuscaloosa_tornado],
        },
        "val": {
            Hold: [
                Event.hurricane_florence,
                Event.hurricane_harvey,
                Event.hurricane_matthew,
                Event.hurricane_michael,
            ],
        },
    },
)
dm.prepare_data()
dm.setup("fit")

print(f"{len(dm.train_dataloader())} train batches, {len(dm.val_dataloader())} val batches")

1292 train batches, 278 val batches


In [5]:
aaa_loc = []
aaa_cls = []
for batch in tqdm(dm.train_dataloader()):
    pre_images, pre_masks, post_images, post_masks = batch
    counts_post = torch.bincount(post_masks.argmax(dim=1).reshape(-1), minlength=6)
    aaa_cls.append(counts_post)
    counts_pre = torch.bincount(pre_masks.argmax(dim=1).reshape(-1), minlength=6)
    aaa_loc.append(torch.tensor([counts_pre[0], counts_pre[1:].sum()]))

loc_counts = torch.stack(aaa_loc).sum(dim=0).to(torch.float)
cls_counts = torch.stack(aaa_cls).sum(dim=0).to(torch.float)

print(cls_counts)

loc_weights = loc_counts.sum() / loc_counts
loc_weights = (loc_weights / loc_weights.sum()).cuda()
cls_weights = cls_counts.sum() / cls_counts
cls_weights = (cls_weights / cls_weights.sum()).cuda() * torch.tensor(
    [1, 1, 1, 1, 1, 0]
).cuda()  # last class is "unclassified"

print(f"Localization weights: {loc_weights}\nClassification weights: {cls_weights}")

100%|██████████| 1292/1292 [04:56<00:00,  4.36it/s]


tensor([1.8963e+09, 8.6628e+07, 2.0118e+07, 2.0930e+07, 6.4454e+06, 1.7008e+06])
Localization weights: tensor([0.0670, 0.9330], device='cuda:0')
Classification weights: tensor([0.0006, 0.0135, 0.0583, 0.0560, 0.1820, 0.0000], device='cuda:0')


In [6]:
model = SemanticSegmentorSiamese(
    model=UNetSiamese(in_channels=3, out_channels=6),
    localization_loss=torch.nn.BCEWithLogitsLoss(pos_weight=loc_weights[1]),
    classification_loss=OrdinalCrossEntropyLoss(n_classes=6, weights=cls_weights),
    n_classes=6,
)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'localization_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['localization_loss'])`.
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'classification_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classification_loss'])`.


In [7]:
mlflow.pytorch.autolog()


f1_checkpoint_callback = ModelCheckpoint(
    save_top_k=1, verbose=True, monitor="f1", mode="max", filename="f1_{epoch}{f1:.5f}{iou:.5f}{loss:.5f}"
)

# iou_checkpoint_callback = ModelCheckpoint(
#     save_top_k=1, verbose=True, monitor="iou", mode="max", filename="iou_{epoch}{f1:.5f}{iou:.5f}{loss:.5f}"
# )

trainer = pl.Trainer(
    max_epochs=500,
    callbacks=[
        RichProgressBar(),
        f1_checkpoint_callback,
        # iou_checkpoint_callback,
    ],
    logger=MLFlowLogger(experiment_name="basic_siamese"),
    precision="bf16",
)
trainer.fit(model, datamodule=dm)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/lightning_fabric/connector.py:563: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024/06/11 01:32:08 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'c23cf976feab440bb6b77cb20276783f', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Epoch 0, global step 1292: 'f1' reached 0.21725 (best 0.21725), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=0f1=0.21725iou=0.14719loss=2.27019.ckpt' as top 1


Epoch 1, global step 2584: 'f1' was not in top 1


Epoch 2, global step 3876: 'f1' was not in top 1


Epoch 3, global step 5168: 'f1' reached 0.22599 (best 0.22599), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=3f1=0.22599iou=0.15062loss=2.58500.ckpt' as top 1


Epoch 4, global step 6460: 'f1' reached 0.23778 (best 0.23778), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=4f1=0.23778iou=0.15932loss=2.31731.ckpt' as top 1


Epoch 5, global step 7752: 'f1' was not in top 1


Epoch 6, global step 9044: 'f1' was not in top 1


Epoch 7, global step 10336: 'f1' was not in top 1


Epoch 8, global step 11628: 'f1' was not in top 1


Epoch 9, global step 12920: 'f1' was not in top 1


Epoch 10, global step 14212: 'f1' reached 0.25349 (best 0.25349), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=10f1=0.25349iou=0.16302loss=2.11377.ckpt' as top 1


Epoch 11, global step 15504: 'f1' reached 0.26285 (best 0.26285), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=11f1=0.26285iou=0.16575loss=1.47692.ckpt' as top 1


Epoch 12, global step 16796: 'f1' reached 0.27523 (best 0.27523), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=12f1=0.27523iou=0.17239loss=1.52741.ckpt' as top 1


Epoch 13, global step 18088: 'f1' was not in top 1


Epoch 14, global step 19380: 'f1' reached 0.27671 (best 0.27671), saving model to './mlruns/124960261813571015/459c2992a1bc4e3da497bee5c0cb3633/checkpoints/f1_epoch=14f1=0.27671iou=0.17031loss=1.71462.ckpt' as top 1
