In [1]:
from pathlib import Path

import dotenv
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
import torchvision.transforms as T


In [2]:
import os 
if str(Path.cwd()).endswith("notebooks"):
    os.chdir("..")

In [3]:
from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1
from inz.models.baseline_module import BaselineModule
from inz.util import get_loc_cls_weights, get_wandb_logger, show_masks_comparison
from inz.xview2_strong_baseline.legacy.losses import ComboLoss
from inz.xview2_strong_baseline.legacy.zoo.models import Res34_Unet_Double

In [4]:
dotenv.load_dotenv()
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

Seed set to 123


In [5]:
BATCH_SIZE = 64

dm = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events={
        Tier1: [
            Event.guatemala_volcano,
            Event.hurricane_florence,
            Event.hurricane_harvey,
            Event.hurricane_matthew,
            Event.hurricane_michael,
            Event.mexico_earthquake,
            Event.midwest_flooding,
            Event.palu_tsunami,
            Event.santa_rosa_wildfire,
            Event.socal_fire,
        ],
        Tier3: [
            Event.joplin_tornado,
            Event.lower_puna_volcano,
            Event.moore_tornado,
            Event.nepal_flooding,
            Event.pinery_bushfire,
            Event.portugal_wildfire,
            Event.sunda_tsunami,
            Event.tuscaloosa_tornado,
            Event.woolsey_fire,
        ],
    },
    val_fraction=0.15,
    test_fraction=0.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
    transform=T.Compose(
        transforms=[
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply(
                p=0.6, transforms=[T.RandomAffine(degrees=(-10, 10), scale=(0.9, 1.1), translate=(0.1, 0.1))]
            ),
        ]
    ),
)
dm.prepare_data()
dm.setup("fit")

print(f"{len(dm.train_dataloader())} train batches, {len(dm.val_dataloader())} val batches")

488 train batches, 86 val batches


In [6]:
# loc_weights, cls_weights = get_loc_cls_weights(
#     dataloader=dm.train_dataloader(), device=device, drop_unclassified_class=True
# )

cls_weights = torch.Tensor([0.01, 1, 9.04788032, 8.68207691, 12.9632271]).to(device)

print(f"Classification weights: {cls_weights}")

Classification weights: tensor([1.0000e-02, 1.0000e+00, 9.0479e+00, 8.6821e+00, 1.2963e+01],
       device='cuda:0')


In [7]:
class BaselineSingleBranchModule(Res34_Unet_Double):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # reduce the filter number by half
        self.res = torch.nn.Conv2d(48, 5, 1, stride=1, padding=0)

    # Since this is no longer a siamese network, we forward the tensor only once
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.forward_once(x)
        return self.res(output)

In [8]:
class SingleBranchBaselinePLModule(BaselineModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pretend you didn't see anything
        return super().forward(x[:, 3:, ...])

In [9]:
model = SingleBranchBaselinePLModule(
    model=BaselineSingleBranchModule(pretrained=True).to(device),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}).to(device),
    class_weights=cls_weights.to(device),
    optimizer_factory=lambda params: torch.optim.AdamW(params, lr=0.0002, weight_decay=1e-6),
    scheduler_factory=lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        gamma=0.5,
        milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
    ),
).to(device)

x_post = torch.rand([16, 6, 256, 256]).to(device)

with torch.no_grad():
    out = model(x_post)
    assert out.shape == (16, 5, 256, 256), f"Shape is {out.size()}"

using weights from ResNet34_Weights.IMAGENET1K_V1


In [10]:
CKPT_DIR = "baseline_singlebranch_ckpt"

wandb_logger = get_wandb_logger(run_name="baseline_singlebranch")

trainer = pl.Trainer(
    max_epochs=40,
    callbacks=[
        RichProgressBar(),
        ModelCheckpoint(
            dirpath=CKPT_DIR,
            save_last="link",
            save_top_k=1,
            monitor="epoch",
            mode="max",
            filename="baseline_singlebranch-{epoch:02d}-{step:03d}-{f1:.6f}-last"
        ),
        ModelCheckpoint(
            dirpath=CKPT_DIR,
            save_last="link",
            save_top_k=2,
            monitor="f1_safe",
            mode="max",
            filename="baseline_singlebranch-{epoch:02d}-{step:03d}-{f1_safe:.6f}-best-f1"
        ),
        ModelCheckpoint(
            dirpath=CKPT_DIR,
            save_last="link",
            save_top_k=2,
            monitor="challenge_score_safe",
            mode="max",
            filename="baseline_singlebranch-{epoch:02d}-{step:03d}-{challenge_score_safe:.6f}-best-challenge-score"
        )
    ],
    precision="bf16",
    logger=wandb_logger,
)
trainer.fit(model, datamodule=dm)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtomasz-owienko-stud[0m ([33mtomasz-owienko-stud-warsaw-university-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/tomek/.netrc
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/lightning_fabric/connector.py:563: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:44: attribute 'optimizer_factory' removed from hparams because it cannot be pickle

Output()

In [11]:
model.eval()

SingleBranchBaselinePLModule(
  (model): BaselineSingleBranchModule(
    (conv6): ConvRelu(
      (layer): Sequential(
        (0): Conv2d(512, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (conv6_2): ConvRelu(
      (layer): Sequential(
        (0): Conv2d(576, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (conv7): ConvRelu(
      (layer): Sequential(
        (0): Conv2d(320, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (conv7_2): ConvRelu(
      (layer): Sequential(
        (0): Conv2d(288, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (conv8): ConvRelu(
      (layer): Sequential(
        (0): Conv2d(160, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (conv8_2): ConvRelu(
      (layer): Seq

In [None]:
it = iter(dm.val_dataloader())
batch = next(it)
images_pre, masks_pre, images_post, masks_post = batch
with torch.no_grad():
    m = model.to(device)
    preds = m(torch.cat([images_pre, images_post], dim=1).to(device))

In [None]:
plt.rcParams["savefig.bbox"] = "tight"
plt.rcParams["figure.figsize"] = [20, 120]

show_masks_comparison(
    images_pre=images_pre, images_post=images_post, masks_pre=masks_pre, masks_post=masks_post, preds=preds, opacity=0.3
)