In [1]:
import sys
from simplecv.module import fpn
import torch
import torch.nn as nn
import dotenv
import pytorch_lightning as pl
from pathlib import Path
from typing import Any, Callable

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchmetrics.classification
import torchmetrics.segmentation
from torch import Tensor

import dotenv
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3
from inz.models.baseline_module import BaselineModule
from inz.util import get_loc_cls_weights, get_wandb_logger, show_masks_comparison
from inz.xview2_strong_baseline.legacy.losses import ComboLoss
from inz.xview2_strong_baseline.legacy.zoo.models import Res34_Unet_Double

sys.path.append("inz/farseg")

from inz.farseg.module.farseg import FarSeg

In [2]:
dotenv.load_dotenv()
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

INFO:lightning_fabric.utilities.seed:Seed set to 123


In [3]:
dm = XBDDataModule(
    path=Path("data/xBD_processed_noresize"),
    drop_unclassified_channel=True,
    events={
        Tier3: [
            Event.joplin_tornado,
        ],
    },
    val_fraction=0.15,
    test_fraction=0.0,
    train_batch_size=8,
    val_batch_size=8,
    test_batch_size=8,
)
dm.prepare_data()
dm.setup("fit")

print(f"{len(dm.train_dataloader())} train batches, {len(dm.val_dataloader())} val batches")

16 train batches, 3 val batches


In [4]:
config = dict(
    resnet_encoder=dict(
        resnet_type="resnet50",
        include_conv5=True,
        batchnorm_trainable=True,
        pretrained=True,
        freeze_at=0,
        # 8, 16 or 32
        output_stride=32,
        with_cp=(False, False, False, False),
        stem3_3x3=False,
    ),
    fpn=dict(
        in_channels_list=(256, 512, 1024, 2048),
        out_channels=256,
        conv_block=fpn.default_conv_block,
        top_blocks=None,
    ),
    scene_relation=dict(
        in_channels=2048,
        channel_list=(256, 256, 256, 256),
        out_channels=256,
        scale_aware_proj=True,
    ),
    decoder=dict(
        in_channels=256,
        out_channels=128,
        in_feat_output_strides=(4, 8, 16, 32),
        out_feat_output_stride=4,
        norm_fn=nn.BatchNorm2d,
        num_groups_gn=None,
    ),
    num_classes=5,
    loss=dict(
        cls_weight=1.0,
        ignore_index=255,
    ),
    annealing_softmax_focalloss=dict(gamma=2.0, max_step=10000, annealing_type="cosine"),
)

In [5]:
class DoubleBranchFarSeg(nn.Module):
    def __init__(
        self,
        farseg_config: dict,
        n_classes: int
    ):
        super(DoubleBranchFarSeg, self).__init__()
        self.farseg_config = farseg_config
        self.n_classes = n_classes
        self.module = FarSeg(config=farseg_config)
        self.outconv = nn.Conv2d(farseg_config["decoder"]["out_channels"] * 2, n_classes, 1)

    def _module_forward(self, x, module):
        feat_list = module.en(x)
        fpn_feat_list = module.fpn(feat_list)
        if 'scene_relation' in module.config:
            c5 = feat_list[-1]
            c6 = module.gap(c5)
            refined_fpn_feat_list = module.sr(c6, fpn_feat_list)
        else:
            refined_fpn_feat_list = fpn_feat_list

        return module.decoder(refined_fpn_feat_list)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self._module_forward(x1, self.module)
        x2 = self._module_forward(x2, self.module)
        return self.outconv.forward(torch.cat([x1, x2], dim=1))

In [9]:
module = DoubleBranchFarSeg(farseg_config=config, n_classes=5).to(device)
module

INFO:simplecv.util.logger:ResNetEncoder: pretrained = True


scene_relation: on
loss type: cosine


DoubleBranchFarSeg(
  (module): FarSeg(
    (en): ResNetEncoder(
      (resnet): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True

In [7]:
dummy_batch = torch.rand((64, 3, 256, 256), dtype=torch.float).to(device)
out = module.forward(dummy_batch, dummy_batch)
out

tensor([[[[ 2.3378e-01,  1.3187e-01, -8.7459e-03,  ..., -3.3284e-02,
           -1.6368e-02,  1.4929e-01],
          [ 1.6156e-01,  1.8875e-01,  1.4566e-01,  ...,  3.7550e-02,
            3.4384e-02,  1.2225e-01],
          [ 9.2675e-02,  1.4514e-01,  1.9306e-01,  ..., -4.4589e-02,
            3.8915e-03,  7.1288e-02],
          ...,
          [ 1.0555e-01,  1.2524e-04, -5.6513e-02,  ...,  2.1224e-01,
            1.1614e-01,  1.7065e-01],
          [ 9.3180e-02,  1.0335e-01,  3.2186e-02,  ...,  1.3751e-01,
            5.7803e-02,  2.2310e-01],
          [ 6.4131e-02,  8.1607e-02,  3.8296e-02,  ...,  7.8346e-03,
            9.1136e-02,  1.1579e-01]],

         [[ 1.3928e-01,  1.4977e-01,  1.9038e-01,  ..., -1.1426e-01,
           -2.0719e-01, -7.7461e-02],
          [ 1.0370e-01,  6.4732e-02,  1.1475e-01,  ..., -6.8705e-02,
           -1.5121e-01, -2.4426e-02],
          [ 5.3508e-02,  5.5013e-02,  4.7841e-02,  ..., -9.4921e-02,
           -8.4437e-02, -2.2836e-01],
          ...,
     

In [8]:
class DoubleBranchFarSegModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        loss: nn.Module,
        optimizer_factory: Callable[[Any], torch.optim.Optimizer],
        scheduler_factory: Callable[[Any], torch.optim.lr_scheduler.LRScheduler] | None = None,
        class_weights: Tensor | None = None,
    ):
        super(DoubleBranchFarSegModule, self).__init__()
        # n classes
        n_classes = 5
        self.n_classes = n_classes

        self.class_weights = class_weights

        self.save_hyperparameters(ignore=["model", "loss"])

        self.model = model

        self.optimizer_factory = optimizer_factory
        self.scheduler_factory = scheduler_factory

        # loss function
        self.loss_fn = loss
        # metrics
        self.accuracy_loc = torchmetrics.classification.BinaryAccuracy()
        self.iou_loc = torchmetrics.segmentation.MeanIoU(num_classes=2)

        self.f1 = torchmetrics.classification.MulticlassF1Score(num_classes=n_classes)
        self.precision = torchmetrics.classification.MulticlassPrecision(num_classes=n_classes)
        self.recall = torchmetrics.classification.MulticlassRecall(num_classes=n_classes)
        self.iou = torchmetrics.segmentation.MeanIoU(num_classes=n_classes)

        self.f1_per_class = torchmetrics.classification.MulticlassF1Score(num_classes=n_classes, average="none")
        self.precision_per_class = torchmetrics.classification.MulticlassPrecision(
            num_classes=n_classes, average="none"
        )
        self.recall_per_class = torchmetrics.classification.MulticlassRecall(num_classes=n_classes, average="none")
        self.iou_per_class = torchmetrics.segmentation.MeanIoU(num_classes=n_classes, per_class=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)  # type: ignore[no-any-return]

    def loss(
        self, images_pre: Tensor, masks_pre: Tensor, images_post: Tensor, masks_post: Tensor
    ) -> tuple[Tensor, Tensor]:
        preds = self.forward(torch.cat([images_pre, images_post], dim=1))
        if self.class_weights is not None:
            # per-class loss in unweighted!
            class_loss = torch.stack(
                [self.loss_fn(preds[:, i, ...], masks_post.to(torch.float)[:, i, ...]) for i in range(preds.shape[1])]
            )
            loss = class_loss.dot(self.class_weights).sum()
        else:
            loss = self.loss_fn(preds, masks_post.to(torch.float))
            class_loss = Tensor([0, 0, 0, 0, 0])

        return loss, class_loss

    def training_step(self, batch: list[Tensor], batch_idx: int) -> Tensor:
        loss, class_loss = self.loss(*batch)

        class_loss_dict = {f"train_loss_{i}": loss_val for i, loss_val in enumerate(class_loss)}
        self.log_dict(class_loss_dict | {"train_loss": loss}, prog_bar=True, batch_size=batch.shape[0])
        return loss  # type: ignore[no-any-return]

    def validation_step(self, batch: list[Tensor], batch_idx: int):  # type: ignore[no-untyped-def]
        with torch.no_grad():
            images_pre, _, images_post, masks_post = batch

            cls_preds = self.forward(torch.cat([images_pre, images_post], dim=1))
            cls_preds_masks = F.one_hot(cls_preds.argmax(dim=1), num_classes=self.n_classes).moveaxis(-1, 1)

            loss, class_loss = self.loss(*batch)

            log_dict = (
                {
                    "acc_loc": self.accuracy_loc(
                        cls_preds.argmax(dim=1).gt(0).to(torch.float), masks_post.argmax(dim=1).gt(0).to(torch.float)
                    ),
                    "iou_loc": self.iou_loc(
                        F.one_hot(cls_preds.argmax(dim=1).gt(0).to(torch.long), num_classes=2).moveaxis(-1, 1),
                        F.one_hot(masks_post.argmax(dim=1).gt(0).to(torch.long), num_classes=2).moveaxis(-1, 1),
                    ),
                }
                | {
                    name: getattr(self, name)(cls_preds.argmax(dim=1), masks_post.argmax(dim=1))
                    for name in ["f1", "precision", "recall"]
                }
                | {"iou": self.iou(cls_preds_masks, masks_post.to(torch.uint8))}
                | {
                    f"{name}_{i}": val
                    for name, vec in {
                        name: getattr(self, f"{name}_per_class")(cls_preds.argmax(dim=1), masks_post.argmax(dim=1))
                        for name in ["f1", "precision", "recall"]
                    }.items()
                    for i, val in enumerate(vec)
                }
                | {
                    f"iou_{i}": val
                    for i, val in enumerate(self.iou_per_class(cls_preds_masks, masks_post.to(torch.uint8)))
                }
                | {f"val_loss_{i}": loss_val for i, loss_val in enumerate(class_loss)}
                | {"val_loss": loss}
            )
            self.log_dict(log_dict, prog_bar=True, batch_size=batch.shape[0])

            return log_dict

    def configure_optimizers(self):  # type: ignore[no-untyped-def]
        optimizer = self.optimizer_factory(self.model.parameters())
        if self.scheduler_factory:
            scheduler = self.scheduler_factory(optimizer)
            return [optimizer], [scheduler]
        else:
            return optimizer