In [1]:
from pathlib import Path
import os
import numpy as np
import torch
import torchvision.transforms.functional as T
from matplotlib import pyplot as plt
from torchvision.utils import draw_segmentation_masks, make_grid
import pytorch_lightning as pl
import mlflow
from tqdm import tqdm
import torch.nn.functional as F
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint

In [2]:
from inz.data.event import Event, Hold, Test, Tier1, Tier3
from inz.data.data_module import XBDDataModule
from inz.models.unet_basic import UNet
from inz.models.unet_basic_pl import SemanticSegmentor, OrdinalCrossEntropyLoss

In [3]:
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

Seed set to 123


In [4]:
dm = XBDDataModule(
    path=Path("data/xBD_processed"),
    events={
        # Tier1: [
        #     Event.hurricane_florence,
        #     Event.hurricane_harvey,
        #     Event.hurricane_matthew,
        #     Event.hurricane_michael,
        # ],
        Tier3: [
            Event.joplin_tornado,
            # Event.moore_tornado,
            # Event.tuscaloosa_tornado
        ],
        # Hold: [
        #     Event.hurricane_florence,
        #     Event.hurricane_harvey,
        #     Event.hurricane_matthew,
        #     Event.hurricane_michael,
        # ],
        # Test: [
        #     Event.hurricane_florence,
        #     Event.hurricane_harvey,
        #     Event.hurricane_matthew,
        #     Event.hurricane_michael,
        # ],
    },
    val_faction=0.3,
    test_fraction=0.,
    train_batch_size=20,
    val_batch_size=20,
    test_batch_size=20,
)
dm.prepare_data()
dm.setup("fit")

print(
    f"{len(dm.train_dataloader())} train batches, {len(dm.val_dataloader())} val batches, {len(dm.test_dataloader())} test batches, "
)

84 train batches, 36 val batches, 0 test batches, 


In [5]:
aaa_loc = []
aaa_cls = []
for batch in tqdm(dm.train_dataloader()):
    pre_images, pre_masks, post_images, post_masks = batch
    counts_post = torch.bincount(post_masks.argmax(dim=1).reshape(-1), minlength=6)
    aaa_cls.append(counts_post)
    counts_pre = torch.bincount(pre_masks.argmax(dim=1).reshape(-1), minlength=6)
    aaa_loc.append(torch.tensor([counts_pre[0], counts_pre[1:].sum()]))

loc_counts = torch.stack(aaa_loc).sum(dim=0).to(torch.float)
cls_counts = torch.stack(aaa_cls).sum(dim=0).to(torch.float)

print(cls_counts)

loc_weights = loc_counts.sum() / loc_counts
loc_weights = (loc_weights / loc_weights.sum()).cuda()
cls_weights = cls_counts.sum() / cls_counts
cls_weights = (cls_weights / cls_weights.sum()).cuda() * torch.tensor([0.1, 0.1, 99999, 99999, 99999, 99999]).cuda()

print(f"Localization weights: {loc_weights}\nClassification weights: {cls_weights}")

100%|██████████| 84/84 [00:16<00:00,  5.02it/s]


tensor([97941888.,  6290833.,  1816537.,  1126032.,  1985791.,   218499.])
Localization weights: tensor([0.1048, 0.8952], device='cuda:0')
Classification weights: tensor([1.5266e-04, 2.3768e-03, 8.2310e+03, 1.3278e+04, 7.5295e+03, 6.8431e+04],
       device='cuda:0')


In [6]:
model = SemanticSegmentor(
    model=UNet(in_channels=3, out_channels=6),
    localization_loss=torch.nn.BCEWithLogitsLoss(pos_weight=loc_weights[1]),
    classification_loss=OrdinalCrossEntropyLoss(n_classes=6, weights=cls_weights.cuda()),
    n_classes=6,
)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'localization_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['localization_loss'])`.
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'classification_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classification_loss'])`.


In [7]:
mlflow.pytorch.autolog()


checkpoint_callback = ModelCheckpoint(
    save_top_k=1, verbose=True, monitor="f1", mode="min", filename="{epoch}-f1-{f1:.5f}"
)

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[RichProgressBar(), checkpoint_callback],
    logger=MLFlowLogger(),
    precision=16,
)
trainer.fit(model, datamodule=dm)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024/06/09 18:10:37 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '05d05ce695124b49afa344b54e245cc1', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Epoch 0, global step 84: 'f1' reached 0.18614 (best 0.18614), saving model to './mlruns/971144892744864990/d9decc1e0241479c87ca644307c0e062/checkpoints/epoch=0-f1-f1=0.18614.ckpt' as top 1
