https://arxiv.org/pdf/1909.13589
https://github.com/ZJULearning/MaxSquareLoss


In [1]:
from pathlib import Path
from functools import partial
import dotenv
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1, Hold, Test
from inz.models.baseline_module import BaselineModule
from inz.util import get_loc_cls_weights, get_wandb_logger, show_masks_comparison
from inz.xview2_strong_baseline.legacy.losses import ComboLoss
from inz.xview2_strong_baseline.legacy.zoo.models import Res34_Unet_Double
from inz.models.baseline_module import BaselineModule
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics.functional.classification import multiclass_f1_score, binary_f1_score
from torchmetrics.functional.classification import binary_accuracy

In [2]:
import sys  # noqa: I001

sys.path.append("inz/revgrad")

from utils import GradientReversal

sys.path.append("inz/xview2_strong_baseline")

from legacy.zoo.models import Res34_Unet_Double

In [3]:
dotenv.load_dotenv()
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

Seed set to 123


In [4]:
from inz.data.data_module_floodnet import FloodNetDataset, FloodNetModule
import torchvision.transforms as T


# BATCH_SIZE = 16
BATCH_SIZE = 16

source_events = {
    Tier1: [
        Event.hurricane_florence,
        Event.hurricane_harvey,
        Event.hurricane_michael
    ],
    Test: [
        Event.hurricane_florence,
        Event.hurricane_harvey,
        Event.hurricane_michael
    ],
    Hold: [
        Event.hurricane_florence,
        Event.hurricane_harvey,
        Event.hurricane_michael
    ],
    Tier3: [
        Event.nepal_flooding
    ],
}

dm1 = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events=source_events,
    val_fraction=0.15,
    test_fraction=0.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)
dm1.prepare_data()
dm1.setup("fit")
dm2 = FloodNetModule(
    path=Path("data/floodnet_processed_512/FloodNet-Supervised_v1.0"),
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
    transform=T.Compose(
        transforms=[
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply(
                p=0.6, transforms=[T.RandomAffine(degrees=(-10, 10), scale=(0.9, 1.1), translate=(0.1, 0.1))]
            ),
        ]
    ),
)
dm2.prepare_data()
dm2.setup("fit")


In [5]:
CLASS_WEIGHTS = torch.Tensor([0.01, 1.0, 9.0478803, 8.68207691, 12.9632271])

In [6]:
# Baseline single branch

from inz.models.baseline_singlebranch import SingleBranchBaselinePLModule, BaselineSingleBranchModule

# MODEL_CKPT = Path('/home/tomek/inz/inz/saved_checkpoints/runs/baseline_singlebranch/baseline_singlebranch_ckpt/baseline_singlebranch-epoch=33-step=16592-challenge_score_safe=0.639932-best-challenge-score.ckpt')
MODEL_CKPT = Path('/home/tomek/inz/inz/outputs/baseline_singlebranch_flood/latest_run/checkpoints/experiment_name-0-epoch-31-step-7616-challenge_score_safe-0.5146-best-challenge-score.ckpt')
model = SingleBranchBaselinePLModule.load_from_checkpoint(
    checkpoint_path=MODEL_CKPT,
    model=BaselineSingleBranchModule(pretrained=True),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}),
    # optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
    # lr=0.0002, weight_decay=1e-6 works
    # lr=0.00005, weight_decay=1e-6 works better
    # consider early stopping? 0.0002 initially improved the score, but degraded before the batch ended
    # ^ degradation slower / not present with 0.00005
    optimizer_factory=partial(torch.optim.AdamW, lr=0.00005, weight_decay=1e-6),
    scheduler_factory=partial(
        torch.optim.lr_scheduler.MultiStepLR,
        gamma=0.5,
        milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
    ),
    class_weights=CLASS_WEIGHTS,
).to(device)
model.class_weights = model.class_weights.to(device)



# Farseg single branch

# from inz.models.farseg_singlebranch_module import SingleBranchFarSeg, FarSegSingleBranchModule
# import simplecv

# MODEL_CKPT = Path('/home/tomek/inz/inz/saved_checkpoints/runs/farseg_single/2024-10-25_00-48-01/checkpoints/experiment_name-0-epoch-28-step-28275-challenge_score_safe-0.6489-best-challenge-score.ckpt')
# model = FarSegSingleBranchModule.load_from_checkpoint(
#     checkpoint_path=MODEL_CKPT,
#     model=SingleBranchFarSeg(
#         n_classes=5,
#         farseg_config={
#             "resnet_encoder": {
#                 "resnet_type": "resnet50",
#                 "include_conv5": True,
#                 "batchnorm_trainable": True,
#                 "pretrained": True,
#                 "freeze_at": 0,
#                 # 8, 16 or 32
#                 "output_stride": 32,
#                 "with_cp": [False, False, False, False],
#                 "stem3_3x3": False
#             },
#             "fpn": {
#                 "in_channels_list": [256, 512, 1024, 2048],
#                 "out_channels": 256,
#                 "conv_block": simplecv.module.fpn.default_conv_block,
#                 "top_blocks": None,
#             },
#             "scene_relation": {
#                 "in_channels": 2048,
#                 "channel_list": [256, 256, 256, 256],
#                 "out_channels": 256
#             },
#             "decoder": {
#                 "in_channels": 256,
#                 "out_channels": 128,
#                 "in_feat_output_strides": [4, 8, 16, 32],
#                 "out_feat_output_stride": 4,
#                 "norm_fn": torch.nn.BatchNorm2d,
#                 "num_groups_gn": None
#             },
#             "num_classes": 5,
#             "loss": {
#                 "cls_weight": 1.0,
#                 "ignore_index": 255
#             },
#             "annealing_softmax_focalloss": {
#                 "gamma": 2.0,
#                 "max_step": 10000,
#                 "annealing_type": "cosine",
#             }
#         }
#     ),
#     optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
#     scheduler_factory=partial(
#         torch.optim.lr_scheduler.PolynomialLR,
#         power=0.9,
#         total_iters=5000,
#     ),
#     class_weights=CLASS_WEIGHTS,
# ).to(device)
# model.class_weights = model.class_weights.to(device)

using weights from ResNet34_Weights.IMAGENET1K_V1


In [7]:
optim = model.optimizer_factory(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optim,
    gamma=0.5,
    milestones=[
        5,
        11,
        17,
        23,
        29,
        33,
        47,
        50,
        60,
        70,
        90,
        110,
        130,
        150,
        170,
        180,
        190,
    ],
)

https://arxiv.org/pdf/1909.13589
https://github.com/ZJULearning/MaxSquareLoss


In [8]:
class MaxSquareloss(nn.Module):
    def __init__(self, ignore_index=-1, num_class=19):
        super().__init__()
        self.ignore_index = ignore_index
        self.num_class = num_class

    def forward(self, pred, prob):
        """
        :param pred: predictions (N, C, H, W)
        :param prob: probability of pred (N, C, H, W)
        :return: maximum squares loss
        """
        # prob -= 0.5
        mask = prob != self.ignore_index
        loss = -torch.mean(torch.pow(prob, 2)[mask]) / 2
        return loss


class IW_MaxSquareloss(nn.Module):
    """

    #############################
    NO GUARANTEE THIS WORKS
    #############################

    """

    def __init__(self, ignore_index=-1, num_class=19, ratio=0.2):
        super().__init__()
        self.ignore_index = ignore_index
        self.num_class = num_class
        self.ratio = ratio

    def forward(self, pred, prob, label=None):
        """
        :param pred: predictions (N, C, H, W)
        :param prob: probability of pred (N, C, H, W)
        :param label(optional): the map for counting label numbers (N, C, H, W)
        :return: maximum squares loss with image-wise weighting factor
        """
        # prob -= 0.5
        N, C, H, W = prob.size()
        mask = prob != self.ignore_index
        maxpred, argpred = torch.max(prob, 1)
        mask_arg = maxpred != self.ignore_index
        argpred = torch.where(mask_arg, argpred, torch.ones(1).to(prob.device, dtype=torch.long) * self.ignore_index)
        if label is None:
            label = argpred
        weights = []
        batch_size = prob.size(0)
        for i in range(batch_size):
            hist = torch.histc(
                label[i].cpu().data.float(), bins=self.num_class + 1, min=-1, max=self.num_class - 1
            ).float()
            hist = hist[1:]
            weight = (
                (1 / torch.max(torch.pow(hist, self.ratio) * torch.pow(hist.sum(), 1 - self.ratio), torch.ones(1)))
                .to(argpred.device)[argpred[i]]
                .detach()
            )
            weights.append(weight)
        weights = torch.stack(weights, dim=0)
        mask = mask_arg.unsqueeze(1).expand_as(prob)
        prior = torch.mean(prob, (2, 3), True).detach()
        # loss = -torch.sum((torch.pow(prob, 2)*weights)[mask]) / (batch_size*self.num_class)
        # fix for tensor shape issue
        loss = -torch.sum((torch.pow(prob, 2) * weights.unsqueeze(dim=1))[mask]) / (batch_size * self.num_class)
        return loss

In [9]:

import gc
from tqdm.auto import tqdm
from itertools import islice, cycle

source_loader = dm1.train_dataloader()
target_loader = dm2.train_dataloader()

EPOCHS = 40
LAMBDA_TARGET = 0.2

# target_loss = MaxSquareloss(ignore_index=-1, num_class=3).to(device)
target_loss = IW_MaxSquareloss(ignore_index=-1, num_class=3, ratio=0.2).to(device)

for epoch in range(1, EPOCHS + 1):
    # batches = zip(source_loader, target_loader)
    # n_batches = min(len(source_loader), len(target_loader)) - 1
    batches = zip(source_loader, cycle(target_loader))  # assuming target_loader is smaller
    n_batches = len(source_loader) - 1

    total_source_loss = total_target_loss = total_source_f1 = total_target_f1 = total_source_challenge_score = (
        total_target_challenge_score
    ) = 0
    total_source_f1_class = torch.Tensor([0.0, 0.0, 0.0, 0.0, 0.0]).to(device)
    total_target_f1_class = torch.Tensor([0.0, 0.0, 0.0]).to(device)
    for step, ((source_batch), (target_batch)) in enumerate(
        tqdm(islice(batches, n_batches), total=n_batches, position=0, leave=True), start=1
    ):

        gc.collect()
        if epoch == 0:
            model.eval()
            torch.set_grad_enabled(False)
        else:
            model.train()
            torch.set_grad_enabled(True)

        s_img_pre, s_mask_pre, s_img_post, s_mask_post = source_batch

        s_img_pre = s_img_pre.to(device)
        s_mask_pre = s_mask_pre.to(device)
        s_img_post = s_img_post.to(device)
        s_mask_post = s_mask_post.to(device)

        source_preds = model(torch.cat([s_img_pre, s_img_post], dim=1))

        if epoch != 0:
            # Baseline
            loss_source = torch.stack(
                [
                    model.loss_fn(source_preds[:, i, ...], s_mask_post.to(torch.float)[:, i, ...])
                    for i in range(source_preds.shape[1])
                ]
            )
            loss_source = loss_source.dot(model.class_weights).sum()

            # FarSeg
            # loss_source = model.model.module.config.loss.cls_weight * model.model.module.cls_loss(
            #     source_preds, torch.argmax(s_mask_post, dim=1)
            # )

            loss_source.backward()
            total_source_loss += loss_source.cpu().item()

        t_img_pre, t_mask_pre, t_img_post, t_mask_post = target_batch

        t_img_pre = t_img_pre.to(device)
        t_mask_pre = t_mask_pre.to(device)
        t_img_post = t_img_post.to(device)
        t_mask_post = t_mask_post.to(device)

        _target_preds = model(torch.cat([t_img_pre, t_img_post], dim=1))
        target_preds = torch.cat([_target_preds[:, :2, ...], _target_preds[:, 2:, ...].max(dim=1, keepdim=True).values], dim=1)

        if epoch != 0:
            target_preds_P = F.softmax(target_preds, dim=1)
            target_preds_labels = target_preds_P
            loss_target = LAMBDA_TARGET * target_loss(target_preds, target_preds_labels)

            loss_target.backward()
            total_target_loss += loss_target.cpu().item()

        if epoch != 0:
            optim.step()
            optim.zero_grad()

        f1_source = multiclass_f1_score(source_preds.argmax(dim=1), s_mask_post.argmax(dim=1), num_classes=5)
        total_source_f1 += f1_source

        f1_source_class = multiclass_f1_score(
            source_preds.argmax(dim=1), s_mask_post.argmax(dim=1), num_classes=5, average="none"
        )
        total_source_f1_class += f1_source_class

        source_challenge_score = 0.3 * binary_f1_score(
            (source_preds.argmax(dim=1) > 0).to(torch.int), (s_mask_post.argmax(dim=1) > 0).to(torch.int)
        ) + 0.7 * 5 / sum(1 / v for v in f1_source_class)
        total_source_challenge_score += source_challenge_score

        with torch.no_grad():
            _label_preds_target = model(torch.cat([t_img_pre, t_img_post], dim=1))
            label_preds_target = torch.cat([_label_preds_target[:, :2, ...], _label_preds_target[:, 2:, ...].max(dim=1, keepdim=True).values], dim=1)
            f1_target = multiclass_f1_score(
                label_preds_target.argmax(dim=1), t_mask_post.to(device).argmax(dim=1), num_classes=3
            )
            total_target_f1 += f1_target

            f1_target_class = multiclass_f1_score(
                label_preds_target.argmax(dim=1), t_mask_post.to(device).argmax(dim=1), num_classes=3, average="none"
            )
            total_target_f1_class += f1_target_class

            target_challenge_score = 0.3 * binary_f1_score(
                (target_preds.argmax(dim=1) > 0).to(torch.int), (t_mask_post.argmax(dim=1) > 0).to(torch.int)
            ) + 0.7 * 3 / sum(1 / v for v in f1_target_class)
            total_target_challenge_score += target_challenge_score

        if epoch != 0:
            tqdm.write(
                f"STEP {(epoch-1) * BATCH_SIZE + step:03d}: source_loss={loss_source.item():.4f}, target_loss={loss_target.item():.4f}, "
                f"source_f1={f1_source:.4f}, "
                f"target_f1={f1_target:.4f}, "
                f"\nsource_f1_class={f1_source_class}"
                f"\ntarget_f1_class={f1_target_class}"
                f"\nsource_challenge_score={source_challenge_score:.4f}"
                f"\ntarget_challenge_score={target_challenge_score:.4f}"
            )


    if epoch != 0:
        scheduler.step()

    mean_source_loss = total_source_loss / n_batches
    mean_target_loss = total_target_loss / n_batches
    mean_f1_source = total_source_f1 / n_batches
    mean_f1_target = total_target_f1 / n_batches
    mean_f1_source_class = total_source_f1_class / n_batches
    mean_f1_target_class = total_target_f1_class / n_batches
    mean_source_challenge_score = total_source_challenge_score / n_batches
    mean_target_challenge_score = total_target_challenge_score / n_batches
    tqdm.write(
        f"############## EPOCH {epoch:03d}: source_loss={mean_source_loss:.4f}, target_loss={mean_target_loss:.4f}, "
        f"source_f1={mean_f1_source:.4f}, "
        f"target_f1={mean_f1_target:.4f}, "
        f"\nsource_f1_class={mean_f1_source_class}"
        f"\ntarget_f1_class={mean_f1_target_class}"
        f"\nsource_challenge_score={mean_source_challenge_score:.4f}"
        f"\ntarget_challenge_score={mean_target_challenge_score:.4f}"
    )

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 001: source_loss=26.3826, target_loss=-0.0697, source_f1=0.4592, target_f1=0.5438, 
source_f1_class=tensor([0.9829, 0.6566, 0.4046, 0.2521, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9645, 0.2985, 0.3684], device='cuda:0')
source_challenge_score=0.2199
target_challenge_score=0.4732
STEP 002: source_loss=21.1042, target_loss=-0.0689, source_f1=0.5768, target_f1=0.6828, 
source_f1_class=tensor([0.9689, 0.6908, 0.4066, 0.6593, 0.1584], device='cuda:0')
target_f1_class=tensor([0.9797, 0.5755, 0.4933], device='cuda:0')
source_challenge_score=0.4919
target_challenge_score=0.6630
STEP 003: source_loss=24.2987, target_loss=-0.0695, source_f1=0.4766, target_f1=0.5814, 
source_f1_class=tensor([0.9944, 0.5211, 0.2727, 0.5950, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9765, 0.2614, 0.5063], device='cuda:0')
source_challenge_score=0.2071
target_challenge_score=0.5267
STEP 004: source_loss=16.8321, target_loss=-0.0701, source_f1=0.6400, target_f1=0.6288, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 017: source_loss=12.4074, target_loss=-0.0699, source_f1=0.7429, target_f1=0.5648, 
source_f1_class=tensor([0.9813, 0.7417, 0.7149, 0.7296, 0.5469], device='cuda:0')
target_f1_class=tensor([0.9693, 0.5149, 0.2101], device='cuda:0')
source_challenge_score=0.7339
target_challenge_score=0.4636
STEP 018: source_loss=20.8705, target_loss=-0.0700, source_f1=0.5855, target_f1=0.6164, 
source_f1_class=tensor([0.9890, 0.6695, 0.5700, 0.6988, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9421, 0.3064, 0.6006], device='cuda:0')
source_challenge_score=0.2315
target_challenge_score=0.5067
STEP 019: source_loss=23.6847, target_loss=-0.0696, source_f1=0.4986, target_f1=0.6862, 
source_f1_class=tensor([0.9850, 0.6532, 0.3487, 0.4499, 0.0559], device='cuda:0')
target_f1_class=tensor([0.9550, 0.4745, 0.6291], device='cuda:0')
source_challenge_score=0.3506
target_challenge_score=0.6226
STEP 020: source_loss=24.2402, target_loss=-0.0703, source_f1=0.4659, target_f1=0.6588, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 033: source_loss=23.0568, target_loss=-0.0697, source_f1=0.5365, target_f1=0.6532, 
source_f1_class=tensor([0.9770, 0.6348, 0.4641, 0.6066, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9677, 0.3319, 0.6599], device='cuda:0')
source_challenge_score=0.2113
target_challenge_score=0.5859
STEP 034: source_loss=16.5081, target_loss=-0.0693, source_f1=0.6306, target_f1=0.7133, 
source_f1_class=tensor([0.9915, 0.7064, 0.5239, 0.3444, 0.5867], device='cuda:0')
target_f1_class=tensor([0.9580, 0.4032, 0.7788], device='cuda:0')
source_challenge_score=0.6278
target_challenge_score=0.6131
STEP 035: source_loss=18.9986, target_loss=-0.0695, source_f1=0.5988, target_f1=0.5554, 
source_f1_class=tensor([0.9824, 0.5731, 0.4556, 0.6838, 0.2989], device='cuda:0')
target_f1_class=tensor([0.9700, 0.4737, 0.2225], device='cuda:0')
source_challenge_score=0.5738
target_challenge_score=0.4674
STEP 036: source_loss=22.3347, target_loss=-0.0698, source_f1=0.6952, target_f1=0.6126, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 049: source_loss=22.4280, target_loss=-0.0704, source_f1=0.5448, target_f1=0.5171, 
source_f1_class=tensor([0.9764, 0.7012, 0.4650, 0.5816, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9638, 0.2333, 0.3542], device='cuda:0')
source_challenge_score=0.2137
target_challenge_score=0.4591
STEP 050: source_loss=24.4238, target_loss=-0.0710, source_f1=0.4662, target_f1=0.6840, 
source_f1_class=tensor([0.9944, 0.5371, 0.3398, 0.4598, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9591, 0.5114, 0.5816], device='cuda:0')
source_challenge_score=0.1809
target_challenge_score=0.6595
STEP 051: source_loss=20.8516, target_loss=-0.0697, source_f1=0.5654, target_f1=0.7201, 
source_f1_class=tensor([0.9864, 0.6376, 0.4687, 0.7343, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9653, 0.4944, 0.7004], device='cuda:0')
source_challenge_score=0.2004
target_challenge_score=0.6615
STEP 052: source_loss=19.7077, target_loss=-0.0707, source_f1=0.6019, target_f1=0.5725, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 065: source_loss=21.7844, target_loss=-0.0702, source_f1=0.5682, target_f1=0.7344, 
source_f1_class=tensor([0.9835, 0.6705, 0.6275, 0.5597, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9651, 0.5672, 0.6710], device='cuda:0')
source_challenge_score=0.2373
target_challenge_score=0.6996
STEP 066: source_loss=22.3046, target_loss=-0.0701, source_f1=0.5239, target_f1=0.5969, 
source_f1_class=tensor([0.9891, 0.6617, 0.3171, 0.4691, 0.1828], device='cuda:0')
target_f1_class=tensor([0.9574, 0.4263, 0.4069], device='cuda:0')
source_challenge_score=0.4603
target_challenge_score=0.5418
STEP 067: source_loss=25.9635, target_loss=-0.0696, source_f1=0.4524, target_f1=0.6396, 
source_f1_class=tensor([0.9849, 0.4970, 0.2113, 0.5688, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9642, 0.3738, 0.5807], device='cuda:0')
source_challenge_score=0.2145
target_challenge_score=0.5603
STEP 068: source_loss=24.0631, target_loss=-0.0699, source_f1=0.5470, target_f1=0.7137, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 081: source_loss=18.5072, target_loss=-0.0704, source_f1=0.8102, target_f1=0.6927, 
source_f1_class=tensor([0.9842, 0.7570, 0.7649, 0.7346, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9678, 0.4950, 0.6152], device='cuda:0')
source_challenge_score=0.2325
target_challenge_score=0.6641
STEP 082: source_loss=20.6316, target_loss=-0.0686, source_f1=0.5729, target_f1=0.7233, 
source_f1_class=tensor([0.9913, 0.6439, 0.7433, 0.4857, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9740, 0.4466, 0.7494], device='cuda:0')
source_challenge_score=0.2160
target_challenge_score=0.6319
STEP 083: source_loss=9.4897, target_loss=-0.0701, source_f1=0.7571, target_f1=0.7006, 
source_f1_class=tensor([0.9859, 0.7145, 0.5545, 0.6416, 0.8890], device='cuda:0')
target_f1_class=tensor([0.9585, 0.4932, 0.6501], device='cuda:0')
source_challenge_score=0.7356
target_challenge_score=0.6633
STEP 084: source_loss=20.3407, target_loss=-0.0698, source_f1=0.5732, target_f1=0.7613, 
source_f1_class=tenso

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 097: source_loss=10.8993, target_loss=-0.0705, source_f1=0.7388, target_f1=0.5728, 
source_f1_class=tensor([0.9847, 0.6413, 0.6431, 0.6904, 0.7346], device='cuda:0')
target_f1_class=tensor([0.9620, 0.4888, 0.2675], device='cuda:0')
source_challenge_score=0.7280
target_challenge_score=0.5053
STEP 098: source_loss=23.8150, target_loss=-0.0712, source_f1=0.5197, target_f1=0.5733, 
source_f1_class=tensor([0.9697, 0.6581, 0.3379, 0.6327, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9446, 0.2960, 0.4792], device='cuda:0')
source_challenge_score=0.2028
target_challenge_score=0.5161
STEP 099: source_loss=21.2273, target_loss=-0.0700, source_f1=0.5730, target_f1=0.6769, 
source_f1_class=tensor([0.9824, 0.6963, 0.6543, 0.5320, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9697, 0.5548, 0.5061], device='cuda:0')
source_challenge_score=0.2097
target_challenge_score=0.6242
STEP 100: source_loss=20.5900, target_loss=-0.0703, source_f1=0.5899, target_f1=0.6718, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 113: source_loss=17.2431, target_loss=-0.0698, source_f1=0.6297, target_f1=0.5616, 
source_f1_class=tensor([0.9853, 0.6727, 0.6168, 0.5421, 0.3318], device='cuda:0')
target_f1_class=tensor([0.9501, 0.2657, 0.4689], device='cuda:0')
source_challenge_score=0.6101
target_challenge_score=0.4639
STEP 114: source_loss=21.7539, target_loss=-0.0700, source_f1=0.5537, target_f1=0.6182, 
source_f1_class=tensor([0.9873, 0.6551, 0.4894, 0.6368, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9618, 0.2737, 0.6190], device='cuda:0')
source_challenge_score=0.2176
target_challenge_score=0.5171
STEP 115: source_loss=21.0009, target_loss=-0.0705, source_f1=0.5722, target_f1=0.7295, 
source_f1_class=tensor([0.9821, 0.6925, 0.6737, 0.5130, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9711, 0.5271, 0.6903], device='cuda:0')
source_challenge_score=0.2187
target_challenge_score=0.7116
STEP 116: source_loss=23.9988, target_loss=-0.0703, source_f1=0.6444, target_f1=0.5796, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 129: source_loss=19.4138, target_loss=-0.0703, source_f1=0.7767, target_f1=0.5792, 
source_f1_class=tensor([0.9953, 0.7480, 0.4931, 0.8702, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9525, 0.3893, 0.3958], device='cuda:0')
source_challenge_score=0.2377
target_challenge_score=0.5131
STEP 130: source_loss=14.2896, target_loss=-0.0703, source_f1=0.6944, target_f1=0.6150, 
source_f1_class=tensor([0.9830, 0.7082, 0.6315, 0.6472, 0.5020], device='cuda:0')
target_f1_class=tensor([0.9531, 0.2824, 0.6094], device='cuda:0')
source_challenge_score=0.6976
target_challenge_score=0.5166


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fbc986a53d0>>
Traceback (most recent call last):
  File "/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


STEP 131: source_loss=21.0563, target_loss=-0.0705, source_f1=0.5928, target_f1=0.6081, 
source_f1_class=tensor([0.9808, 0.7496, 0.4939, 0.7398, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9285, 0.2553, 0.6405], device='cuda:0')
source_challenge_score=0.2290
target_challenge_score=0.4958
STEP 132: source_loss=22.1393, target_loss=-0.0706, source_f1=0.5310, target_f1=0.6281, 
source_f1_class=tensor([0.9779, 0.5979, 0.4724, 0.4657, 0.1412], device='cuda:0')
target_f1_class=tensor([0.9595, 0.3929, 0.5318], device='cuda:0')
source_challenge_score=0.4524
target_challenge_score=0.5890
STEP 133: source_loss=22.4236, target_loss=-0.0702, source_f1=0.5554, target_f1=0.6112, 
source_f1_class=tensor([0.9876, 0.7506, 0.7149, 0.3240, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9695, 0.3887, 0.4755], device='cuda:0')
source_challenge_score=0.2361
target_challenge_score=0.5643
STEP 134: source_loss=18.5187, target_loss=-0.0713, source_f1=0.6274, target_f1=0.7027, 
source_f1_class=tens

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

from torchvision.utils import draw_segmentation_masks, make_grid  # type: ignore[import-untyped]
import numpy as np
from PIL import Image

In [None]:
class SemanticSegmentationTarget:
    def __init__(self, category, mask):
        self.category = category
        self.mask = mask
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        return (model_output[self.category, :, :] * self.mask).sum()

In [None]:
dm_test = dm2
dm_test.prepare_data()
dm_test.setup("test")

In [None]:
# IMG_NUM = 7340
# IMG_NUM = 6710
IMG_NUM = 6716
IMG_PATH = Path(f"data/floodnet_processed_512/FloodNet-Supervised_v1.0/train/train-org-img/{IMG_NUM}.png")

# FarSeg
# target_layers = model.model.module.decoder.blocks

# Baseline
target_layers = [
    model.model.conv6_2,
    model.model.conv7_2,
    model.model.conv8_2,
    model.model.conv9_2,
    model.model.conv10,
    model.model.res,
]

# region boring
item = dm_test._train_dataset[dm_test._train_dataset.image_paths.index(IMG_PATH)]
image_pre, mask_pre, image_post, mask_post = item

with torch.no_grad():
    m = model.to(device)
    _preds = m(torch.cat([torch.unsqueeze(image_pre, 0), torch.unsqueeze(image_post, 0)], dim=1).to(device))
    preds = torch.cat([_preds[:, :2, ], _preds[:, 2:, ].max(dim=1, keepdim=True).values], dim=1)

for c in range(0, 5):
    targets = [SemanticSegmentationTarget(c, mask_post.argmax(axis=0))]
    with GradCAM(
        model=model,
        target_layers=target_layers,
    ) as cam:
        grayscale_cam = cam(
            input_tensor=torch.unsqueeze(torch.cat([image_pre, image_post], dim=0), 0), targets=targets
        )[0, :]
        cam_image = show_cam_on_image(
            ((image_post + 1) / 2).moveaxis(0, -1).detach().cpu().numpy(), grayscale_cam, use_rgb=True
        )

    colors = [
        (128, 128, 128),
        (0, 255, 0),
        (255, 0, 0),
    ]
    preds_masks = draw_segmentation_masks(
        ((image_post + 1) * 127.5).to(torch.uint8),
        torch.nn.functional.one_hot(preds[0].argmax(dim=0), num_classes=3).moveaxis(-1, 0).to(bool),
        colors=colors,
        alpha=0.3,
    )
    ground_truth_masks = draw_segmentation_masks(
        ((image_post + 1) * 127.5).to(torch.uint8), mask_post.to(bool), colors=colors, alpha=0.3
    )
    comp = np.hstack(
        (
            ground_truth_masks.moveaxis(0, -1).detach().cpu().numpy(),
            preds_masks.moveaxis(0, -1).detach().cpu().numpy(),
            cam_image,
        )
    )
    display(Image.fromarray(comp))

# acts = []
# for c in range(0, 3):
#     targets = [SemanticSegmentationTarget(c, mask_post.argmax(axis=0))]
#     with GradCAM(
#         model=model,
#         target_layers=target_layers,
#     ) as cam:
#         grayscale_cam = cam(
#             input_tensor=torch.unsqueeze(torch.cat([image_pre, image_post], dim=0), 0), targets=targets
#         )[0, :]
#         cam_image = show_cam_on_image(
#             ((image_post + 1) / 2).moveaxis(0, -1).detach().cpu().numpy(), grayscale_cam, use_rgb=True
#         )

#     acts.append(torch.    from_numpy(cam_image).moveaxis(-1, 0))
# comp_acts = make_grid(torch.stack(acts), padding=5, pad_value=255)
# display(Image.fromarray(comp_acts.moveaxis(0, -1).numpy()))

# endregion