In [1]:
from pathlib import Path
from functools import partial
import dotenv
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1, Hold, Test
from inz.models.baseline_module import BaselineModule
from inz.util import get_loc_cls_weights, get_wandb_logger, show_masks_comparison
from inz.xview2_strong_baseline.legacy.losses import ComboLoss
from inz.xview2_strong_baseline.legacy.zoo.models import Res34_Unet_Double
from inz.models.baseline_module import BaselineModule
from inz.data.zipped_data_module import ZippedDataModule
from inz.models.farseg_singlebranch_module import SingleBranchFarSeg, FarSegSingleBranchModule
import torch.nn as nn
from torchvision.utils import draw_segmentation_masks, make_grid
import numpy as np
from torchmetrics.functional.classification import multiclass_f1_score, binary_f1_score
from torchmetrics.functional.classification import binary_accuracy
import simplecv

In [2]:
import sys  # noqa: I001


def _append_to_sys_path(path: str) -> None:
    """Append ``path`` to ``sys.path`` unless it is already present.

    Keeps this cell idempotent: re-running it no longer grows ``sys.path`` with
    duplicate entries. The paths used below are relative, so they resolve against
    the kernel's current working directory (assumed to be the repository root).
    """
    if path not in sys.path:
        sys.path.append(path)


# `GradientReversal` is imported from a top-level module named `utils`, presumably
# inz/revgrad/utils.py, so that directory has to be importable by its own name.
_append_to_sys_path("inz/revgrad")

from utils import GradientReversal

_append_to_sys_path("inz/xview2_strong_baseline")

# NOTE(review): this rebinds `Res34_Unet_Double`, which the first cell already imported from
# `inz.xview2_strong_baseline.legacy.zoo.models`. Loaded under a different module name it is a
# distinct class object, so isinstance checks / pickled objects across the two may not match.
from legacy.zoo.models import Res34_Unet_Double

In [3]:
# Pull configuration from a local .env file (if present) into os.environ.
dotenv.load_dotenv()

# Reproducibility: Lightning seeds Python's `random`, NumPy and torch in a single call.
RANDOM_SEED = 123
pl.seed_everything(seed=RANDOM_SEED)

# Allow float32 matmuls to use faster reduced-precision internal kernels (e.g. TF32).
torch.set_float32_matmul_precision("high")

# Target device for the modules/tensors that are moved explicitly in later cells.
device = torch.device("cuda")

Seed set to 123


In [4]:
from inz.data.data_module_floodnet import FloodNetDataset, FloodNetModule
import torchvision.transforms as T


BATCH_SIZE = 16

# Source domain (xBD): the three hurricane events taken from the Tier1, Test and Hold
# subsets, plus Nepal flooding from Tier3. Each subset gets its own list object.
_SOURCE_HURRICANES = (
    Event.hurricane_florence,
    Event.hurricane_harvey,
    Event.hurricane_michael,
)
source_events = {subset: list(_SOURCE_HURRICANES) for subset in (Tier1, Test, Hold)}
source_events[Tier3] = [Event.nepal_flooding]

# Every loader (train/val/test) of every data module uses the same batch size.
_loader_batch_sizes = dict(
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)

# No val/test split is carved out of the source data.
_dm_source = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events=source_events,
    val_fraction=0.0,
    test_fraction=0.0,
    **_loader_batch_sizes,
)

# Target-domain (FloodNet) augmentation: random horizontal flip, plus a mild random affine
# warp applied 60% of the time.
# NOTE(review): geometric augmentations must hit image and mask identically — confirm that
# FloodNetModule applies `transform` jointly rather than to the image alone.
_target_augmentation = T.Compose(
    transforms=[
        T.RandomHorizontalFlip(p=0.5),
        T.RandomApply(
            p=0.6,
            transforms=[T.RandomAffine(degrees=(-10, 10), scale=(0.9, 1.1), translate=(0.1, 0.1))],
        ),
    ]
)
_dm_target = FloodNetModule(
    path=Path("data/floodnet_processed_512/FloodNet-Supervised_v1.0"),
    transform=_target_augmentation,
    **_loader_batch_sizes,
)

# Pair source and target batches. match_type="max" presumably cycles the shorter dataset
# to the length of the longer one — confirm in ZippedDataModule.
dm = ZippedDataModule(
    dm1=_dm_source,
    dm2=_dm_target,
    match_type="max",
    num_workers=2,
    **_loader_batch_sizes,
)
dm.prepare_data()
dm.setup("fit")


In [5]:
# Grab one zipped training batch for inspection: presumably element 0 comes from dm1
# (xBD source) and element 1 from dm2 (FloodNet target).
dataloader = dm.train_dataloader()
source_batch, target_batch = next(iter(dataloader))

# Split each domain's batch into (pre image, pre mask, post image, post mask),
# following the variable naming used throughout this notebook.
(s_img_pre, s_mask_pre, s_img_post, s_mask_post), (t_img_pre, t_mask_pre, t_img_post, t_mask_post) = (
    source_batch,
    target_batch,
)

In [6]:
# plt.rcParams["savefig.bbox"] = "tight"


# def show(imgs: list[torch.Tensor]):
#     if not isinstance(imgs, list):
#         imgs = [imgs]
#     fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
#     for i, img in enumerate(imgs):
#         img = img.detach()
#         img = T.functional.to_pil_image(img)
#         axs[0, i].imshow(np.asarray(img))
#         axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# plt.rcParams["figure.figsize"] = [12, 24]
# colors_source = [
#     (128, 128, 128),
#     (0, 255, 0),
#     (244, 255, 0),
#     (255, 174, 0),
#     (255, 0, 0),
#     (255, 255, 255),
# ]
# show(
#     [
#         make_grid((s_img_post + 1) / 2, nrow=2),
#         make_grid(
#             [draw_segmentation_masks(((i + 1) * 127.5).to(torch.uint8), m, colors=colors_source, alpha=0.5) for i, m in zip(s_img_post, s_mask_post.to(torch.bool))], nrow=2
#         ),
#     ]
# )
# colors_target = [
#     (128, 128, 128),
#     (0, 255, 0),
#     (255, 0, 0),
# ]
# show(
#     [
#         make_grid((t_img_post + 1) / 2, nrow=2),
#         make_grid(
#             [draw_segmentation_masks(((i + 1) * 127.5).to(torch.uint8), m, colors=colors_target, alpha=0.5) for i, m in zip(t_img_post, t_mask_post.to(torch.bool))], nrow=2
#         ),
#     ]
# )

In [7]:
import os

from inz.models.baseline_singlebranch import SingleBranchBaselinePLModule, BaselineSingleBranchModule
from inz.models.msl.msl_module_wrapper import FloodnetMslModuleWrapper
from inz.models.msl.msl_loss import IW_MaxSquareloss
from functools import partial

# One loss weight per source class (5 here); the first class — presumably background —
# is heavily down-weighted.
CLASS_WEIGHTS = torch.Tensor([0.01, 1.0, 9.0478803, 8.68207691, 12.9632271]).to(device)

# AdamW tuning notes:
# - lr=0.0002, weight_decay=1e-6 works
# - lr=0.00005, weight_decay=1e-6 works better
# - consider early stopping? 0.0002 initially improved the score, but degraded before the
#   batch ended; the degradation is slower / not present with 0.00005
OPTIM_FACTORY = partial(torch.optim.AdamW, lr=0.00005, weight_decay=1e-6)

# Halve the learning rate at each milestone.
# NOTE(review): milestones count scheduler steps (epochs, if the scheduler steps per epoch);
# with max_epochs=5 in the training cell, most of these are never reached — confirm intent.
SCHED_FACTORY = partial(
    torch.optim.lr_scheduler.MultiStepLR,
    gamma=0.5,
    milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
)

# Checkpoint of the single-branch baseline that is being adapted. It can be overridden with
# the MODEL_CKPT environment variable (e.g. via the .env file loaded by dotenv at the top), so
# the notebook is not tied to one machine's absolute path; the default keeps the original one.
_DEFAULT_MODEL_CKPT = "/home/tomek/inz/inz/outputs/baseline_singlebranch_flood/latest_run/checkpoints/experiment_name-0-epoch-31-step-7616-challenge_score_safe-0.5146-best-challenge-score.ckpt"
MODEL_CKPT = Path(os.environ.get("MODEL_CKPT", _DEFAULT_MODEL_CKPT))

# Fail fast with an actionable message, before constructing the pretrained backbone
# (which may fetch weights) only to discover the checkpoint is missing.
if not MODEL_CKPT.is_file():
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_CKPT} (set MODEL_CKPT to override)")

# Restore the baseline; its loss is an equally weighted dice + focal ComboLoss.
_model = SingleBranchBaselinePLModule.load_from_checkpoint(
    checkpoint_path=MODEL_CKPT,
    model=BaselineSingleBranchModule(pretrained=True),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}),
    optimizer_factory=OPTIM_FACTORY,
    scheduler_factory=SCHED_FACTORY,
    class_weights=CLASS_WEIGHTS,
).to(device)

# MODEL_CKPT = Path('/home/tomek/inz/inz/saved_checkpoints/runs/farseg_single/2024-10-25_00-48-01/checkpoints/experiment_name-0-epoch-28-step-28275-challenge_score_safe-0.6489-best-challenge-score.ckpt')
# _model = FarSegSingleBranchModule.load_from_checkpoint(
#     checkpoint_path=MODEL_CKPT,
#     model=SingleBranchFarSeg(
#         n_classes=5,
#         farseg_config={
#             "resnet_encoder": {
#                 "resnet_type": "resnet50",
#                 "include_conv5": True,
#                 "batchnorm_trainable": True,
#                 "pretrained": True,
#                 "freeze_at": 0,
#                 # 8, 16 or 32
#                 "output_stride": 32,
#                 "with_cp": [False, False, False, False],
#                 "stem3_3x3": False
#             },
#             "fpn": {
#                 "in_channels_list": [256, 512, 1024, 2048],
#                 "out_channels": 256,
#                 "conv_block": simplecv.module.fpn.default_conv_block,
#                 "top_blocks": None,
#             },
#             "scene_relation": {
#                 "in_channels": 2048,
#                 "channel_list": [256, 256, 256, 256],
#                 "out_channels": 256
#             },
#             "decoder": {
#                 "in_channels": 256,
#                 "out_channels": 128,
#                 "in_feat_output_strides": [4, 8, 16, 32],
#                 "out_feat_output_stride": 4,
#                 "norm_fn": torch.nn.BatchNorm2d,
#                 "num_groups_gn": None
#             },
#             "num_classes": 5,
#             "loss": {
#                 "cls_weight": 1.0,
#                 "ignore_index": 255
#             },
#             "annealing_softmax_focalloss": {
#                 "gamma": 2.0,
#                 "max_step": 10000,
#                 "annealing_type": "cosine",
#             }
#         }
#     ),
#     optimizer_factory=OPTIM_FACTORY,
#     scheduler_factory=SCHED_FACTORY,
#     class_weights=CLASS_WEIGHTS,
# ).to(device)

# Target-domain (FloodNet) class names, presumably in label-index order; also used as the
# confusion-matrix labels.
TARGET_CLASS_LABELS = ("Background", "Non-flooded", "Flooded")
N_TARGET_CLASSES = len(TARGET_CLASS_LABELS)

# Max-square loss on target-domain predictions ("IW" presumably = image-wise class weighting).
_msl_loss = IW_MaxSquareloss(ignore_index=-1, num_class=N_TARGET_CLASSES, ratio=0.2).to(device)

# Wrap the restored baseline for MSL adaptation; msl_lambda presumably scales the MSL term
# in the total loss — confirm in FloodnetMslModuleWrapper.
model = FloodnetMslModuleWrapper(
    pl_module=_model,
    n_classes_target=N_TARGET_CLASSES,
    msl_loss_module=_msl_loss,
    msl_lambda=0.2,
    optimizer_factory=OPTIM_FACTORY,
    scheduler_factory=SCHED_FACTORY,
    target_conf_matrix_labels=TARGET_CLASS_LABELS,
).to(device)

# model.forward_target(torch.cat([torch.zeros_like(t_img_post), t_img_post], dim=1).to(device))

INFO:simplecv.util.logger:ResNetEncoder: pretrained = True


scene_relation: on
loss type: cosine


/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'msl_loss_module' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['msl_loss_module'])`.


In [8]:
import datetime

In [9]:
# from copy import deepcopy

# wandb_logger_test_initial = get_wandb_logger(
#     run_name=f"delete-me-test-initial-{datetime.datetime.now().replace(microsecond=0).isoformat()}",
#     project="inz",
#     watch_model=True,
#     watch_model_log_frequency=500,
#     watch_model_model=model,
# )

In [10]:
# dm_test_initial = deepcopy(dm)
# dm_test_initial.prepare_data()
# dm_test_initial.setup("test")
# dm_test_initial.test_dataloader = dm_test_initial.train_dataloader

# trainer = pl.Trainer(
#     accelerator="gpu",
#     max_epochs=1,
#     precision="bf16-mixed",
#     deterministic=True,
#     sync_batchnorm=True,
#     callbacks=[
#         pl.callbacks.RichProgressBar()
#     ],
#     logger=wandb_logger_test_initial
# )

# trainer.test(model=model, datamodule=dm_test_initial)

In [11]:
# Each execution logs to a fresh, timestamped W&B run (second resolution). The "delete-me-"
# prefix presumably marks it as a throwaway experiment.
_run_started_at = datetime.datetime.now().replace(microsecond=0).isoformat()
wandb_logger = get_wandb_logger(
    run_name=f"delete-me-{_run_started_at}",
    project="inz",
    watch_model=True,
    watch_model_log_frequency=500,
    watch_model_model=model,
)

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=5,
    precision="bf16-mixed",
    deterministic=True,
    # NOTE(review): with a single visible GPU (see the log below) this is likely a no-op.
    sync_batchnorm=True,
    callbacks=[RichProgressBar()],
    logger=wandb_logger,
)

# Lightning reports that the validation loop is skipped because the wrapper has no
# validation_step (see this cell's output).
trainer.fit(model=model, datamodule=dm)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtomasz-owienko-stud[0m ([33mtomasz-owienko-stud-warsaw-university-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/tomek/.netrc


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:44: attribute 'msl_loss_module' removed from hparams because it cannot be pickled
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()