https://arxiv.org/pdf/1909.13589
https://github.com/ZJULearning/MaxSquareLoss


In [1]:
from pathlib import Path
from functools import partial
import dotenv
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1, Hold, Test
from inz.models.baseline_module import BaselineModule
from inz.util import get_loc_cls_weights, get_wandb_logger, show_masks_comparison
from inz.xview2_strong_baseline.legacy.losses import ComboLoss
from inz.xview2_strong_baseline.legacy.zoo.models import Res34_Unet_Double
from inz.models.baseline_module import BaselineModule
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics.functional.classification import multiclass_f1_score, binary_f1_score
from torchmetrics.functional.classification import binary_accuracy

In [2]:
import sys  # noqa: I001

sys.path.append("inz/revgrad")

from utils import GradientReversal

sys.path.append("inz/xview2_strong_baseline")

from legacy.zoo.models import Res34_Unet_Double

In [3]:
dotenv.load_dotenv()
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

Seed set to 123


In [4]:
from inz.data.data_module_floodnet import FloodNetDataset, FloodNetModule
import torchvision.transforms as T


# BATCH_SIZE = 16
BATCH_SIZE = 16

source_events = {
    Tier1: [
        Event.hurricane_florence,
        Event.hurricane_harvey,
        Event.hurricane_michael
    ],
    Test: [
        Event.hurricane_florence,
        Event.hurricane_harvey,
        Event.hurricane_michael
    ],
    Hold: [
        Event.hurricane_florence,
        Event.hurricane_harvey,
        Event.hurricane_michael
    ],
    Tier3: [
        Event.nepal_flooding
    ],
}

dm1 = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events=source_events,
    val_fraction=0.15,
    test_fraction=0.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)
dm1.prepare_data()
dm1.setup("fit")
dm2 = FloodNetModule(
    path=Path("data/floodnet_processed_512/FloodNet-Supervised_v1.0"),
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
    transform=T.Compose(
        transforms=[
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply(
                p=0.6, transforms=[T.RandomAffine(degrees=(-10, 10), scale=(0.9, 1.1), translate=(0.1, 0.1))]
            ),
        ]
    ),
)
dm2.prepare_data()
dm2.setup("fit")


In [5]:
CLASS_WEIGHTS = torch.Tensor([0.01, 1.0, 9.0478803, 8.68207691, 12.9632271])

In [6]:
# Baseline single branch

from inz.models.baseline_singlebranch import SingleBranchBaselinePLModule, BaselineSingleBranchModule

# MODEL_CKPT = Path('/home/tomek/inz/inz/saved_checkpoints/runs/baseline_singlebranch/baseline_singlebranch_ckpt/baseline_singlebranch-epoch=33-step=16592-challenge_score_safe=0.639932-best-challenge-score.ckpt')
MODEL_CKPT = Path('/home/tomek/inz/inz/outputs/baseline_singlebranch_flood/latest_run/checkpoints/experiment_name-0-epoch-31-step-7616-challenge_score_safe-0.5146-best-challenge-score.ckpt')
model = SingleBranchBaselinePLModule.load_from_checkpoint(
    checkpoint_path=MODEL_CKPT,
    model=BaselineSingleBranchModule(pretrained=True),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}),
    # optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
    # lr=0.0002, weight_decay=1e-6 SOMETIMES gives dmg f1>0 (wow)
    optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
    scheduler_factory=partial(
        torch.optim.lr_scheduler.MultiStepLR,
        gamma=0.5,
        milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
    ),
    class_weights=CLASS_WEIGHTS,
).to(device)
model.class_weights = model.class_weights.to(device)



# Farseg single branch

# from inz.models.farseg_singlebranch_module import SingleBranchFarSeg, SingleBranchFarSegModule
# import simplecv

# MODEL_CKPT = Path('/home/tomek/inz/inz/saved_checkpoints/runs/farseg_single/2024-10-25_00-48-01/checkpoints/experiment_name-0-epoch-28-step-28275-challenge_score_safe-0.6489-best-challenge-score.ckpt')
# model = SingleBranchFarSegModule.load_from_checkpoint(
#     checkpoint_path=MODEL_CKPT,
#     model=SingleBranchFarSeg(
#         n_classes=5,
#         farseg_config={
#             "resnet_encoder": {
#                 "resnet_type": "resnet50",
#                 "include_conv5": True,
#                 "batchnorm_trainable": True,
#                 "pretrained": True,
#                 "freeze_at": 0,
#                 # 8, 16 or 32
#                 "output_stride": 32,
#                 "with_cp": [False, False, False, False],
#                 "stem3_3x3": False
#             },
#             "fpn": {
#                 "in_channels_list": [256, 512, 1024, 2048],
#                 "out_channels": 256,
#                 "conv_block": simplecv.module.fpn.default_conv_block,
#                 "top_blocks": None,
#             },
#             "scene_relation": {
#                 "in_channels": 2048,
#                 "channel_list": [256, 256, 256, 256],
#                 "out_channels": 256
#             },
#             "decoder": {
#                 "in_channels": 256,
#                 "out_channels": 128,
#                 "in_feat_output_strides": [4, 8, 16, 32],
#                 "out_feat_output_stride": 4,
#                 "norm_fn": torch.nn.BatchNorm2d,
#                 "num_groups_gn": None
#             },
#             "num_classes": 5,
#             "loss": {
#                 "cls_weight": 1.0,
#                 "ignore_index": 255
#             },
#             "annealing_softmax_focalloss": {
#                 "gamma": 2.0,
#                 "max_step": 10000,
#                 "annealing_type": "cosine",
#             }
#         }
#     ),
#     optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
#     scheduler_factory=partial(
#         torch.optim.lr_scheduler.PolynomialLR,
#         power=0.9,
#         total_iters=5000,
#     ),
#     class_weights=CLASS_WEIGHTS,
# ).to(device)
# model.class_weights = model.class_weights.to(device)

using weights from ResNet34_Weights.IMAGENET1K_V1


In [7]:
optim = model.optimizer_factory(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optim,
    gamma=0.5,
    milestones=[
        5,
        11,
        17,
        23,
        29,
        33,
        47,
        50,
        60,
        70,
        90,
        110,
        130,
        150,
        170,
        180,
        190,
    ],
)

https://arxiv.org/pdf/1909.13589
https://github.com/ZJULearning/MaxSquareLoss


In [8]:
class MaxSquareloss(nn.Module):
    def __init__(self, ignore_index=-1, num_class=19):
        super().__init__()
        self.ignore_index = ignore_index
        self.num_class = num_class

    def forward(self, pred, prob):
        """
        :param pred: predictions (N, C, H, W)
        :param prob: probability of pred (N, C, H, W)
        :return: maximum squares loss
        """
        # prob -= 0.5
        mask = prob != self.ignore_index
        loss = -torch.mean(torch.pow(prob, 2)[mask]) / 2
        return loss


class IW_MaxSquareloss(nn.Module):
    """

    #############################
    NO GUARANTEE THIS WORKS
    #############################

    """

    def __init__(self, ignore_index=-1, num_class=19, ratio=0.2):
        super().__init__()
        self.ignore_index = ignore_index
        self.num_class = num_class
        self.ratio = ratio

    def forward(self, pred, prob, label=None):
        """
        :param pred: predictions (N, C, H, W)
        :param prob: probability of pred (N, C, H, W)
        :param label(optional): the map for counting label numbers (N, C, H, W)
        :return: maximum squares loss with image-wise weighting factor
        """
        # prob -= 0.5
        N, C, H, W = prob.size()
        mask = prob != self.ignore_index
        maxpred, argpred = torch.max(prob, 1)
        mask_arg = maxpred != self.ignore_index
        argpred = torch.where(mask_arg, argpred, torch.ones(1).to(prob.device, dtype=torch.long) * self.ignore_index)
        if label is None:
            label = argpred
        weights = []
        batch_size = prob.size(0)
        for i in range(batch_size):
            hist = torch.histc(
                label[i].cpu().data.float(), bins=self.num_class + 1, min=-1, max=self.num_class - 1
            ).float()
            hist = hist[1:]
            weight = (
                (1 / torch.max(torch.pow(hist, self.ratio) * torch.pow(hist.sum(), 1 - self.ratio), torch.ones(1)))
                .to(argpred.device)[argpred[i]]
                .detach()
            )
            weights.append(weight)
        weights = torch.stack(weights, dim=0)
        mask = mask_arg.unsqueeze(1).expand_as(prob)
        prior = torch.mean(prob, (2, 3), True).detach()
        # loss = -torch.sum((torch.pow(prob, 2)*weights)[mask]) / (batch_size*self.num_class)
        # fix for tensor shape issue
        loss = -torch.sum((torch.pow(prob, 2) * weights.unsqueeze(dim=1))[mask]) / (batch_size * self.num_class)
        return loss

In [9]:

import gc
from tqdm.auto import tqdm
from itertools import islice, cycle

source_loader = dm1.train_dataloader()
target_loader = dm2.train_dataloader()

EPOCHS = 40
LAMBDA_TARGET = 0.2

# target_loss = MaxSquareloss(ignore_index=-1, num_class=3).to(device)
target_loss = IW_MaxSquareloss(ignore_index=-1, num_class=3, ratio=0.2).to(device)

for epoch in range(1, EPOCHS + 1):
    # batches = zip(source_loader, target_loader)
    # n_batches = min(len(source_loader), len(target_loader)) - 1
    batches = zip(source_loader, cycle(target_loader))  # assuming target_loader is smaller
    n_batches = len(source_loader) - 1

    total_source_loss = total_target_loss = total_source_f1 = total_target_f1 = total_source_challenge_score = (
        total_target_challenge_score
    ) = 0
    total_source_f1_class = torch.Tensor([0.0, 0.0, 0.0, 0.0, 0.0]).to(device)
    total_target_f1_class = torch.Tensor([0.0, 0.0, 0.0]).to(device)
    for step, ((source_batch), (target_batch)) in enumerate(
        tqdm(islice(batches, n_batches), total=n_batches, position=0, leave=True), start=1
    ):

        gc.collect()
        if epoch == 0:
            model.eval()
            torch.set_grad_enabled(False)
        else:
            model.train()
            torch.set_grad_enabled(True)

        s_img_pre, s_mask_pre, s_img_post, s_mask_post = source_batch

        s_img_pre = s_img_pre.to(device)
        s_mask_pre = s_mask_pre.to(device)
        s_img_post = s_img_post.to(device)
        s_mask_post = s_mask_post.to(device)

        source_preds = model(torch.cat([s_img_pre, s_img_post], dim=1))

        if epoch != 0:
            # Baseline
            loss_source = torch.stack(
                [
                    model.loss_fn(source_preds[:, i, ...], s_mask_post.to(torch.float)[:, i, ...])
                    for i in range(source_preds.shape[1])
                ]
            )
            loss_source = loss_source.dot(model.class_weights).sum()

            # FarSeg
            # loss_source = model.model.module.config.loss.cls_weight * model.model.module.cls_loss(
            #     source_preds, torch.argmax(s_mask_post, dim=1)
            # )

            loss_source.backward()
            total_source_loss += loss_source.cpu().item()

        t_img_pre, t_mask_pre, t_img_post, t_mask_post = target_batch

        t_img_pre = t_img_pre.to(device)
        t_mask_pre = t_mask_pre.to(device)
        t_img_post = t_img_post.to(device)
        t_mask_post = t_mask_post.to(device)

        _target_preds = model(torch.cat([t_img_pre, t_img_post], dim=1))
        target_preds = torch.cat([_target_preds[:, :2, ...], _target_preds[:, 2:, ...].max(dim=1, keepdim=True).values], dim=1)

        if epoch != 0:
            target_preds_P = F.softmax(target_preds, dim=1)
            target_preds_labels = target_preds_P
            loss_target = LAMBDA_TARGET * target_loss(target_preds, target_preds_labels)

            loss_target.backward()
            total_target_loss += loss_target.cpu().item()

        if epoch != 0:
            optim.step()
            optim.zero_grad()

        f1_source = multiclass_f1_score(source_preds.argmax(dim=1), s_mask_post.argmax(dim=1), num_classes=5)
        total_source_f1 += f1_source

        f1_source_class = multiclass_f1_score(
            source_preds.argmax(dim=1), s_mask_post.argmax(dim=1), num_classes=5, average="none"
        )
        total_source_f1_class += f1_source_class

        source_challenge_score = 0.3 * binary_f1_score(
            (source_preds.argmax(dim=1) > 0).to(torch.int), (s_mask_post.argmax(dim=1) > 0).to(torch.int)
        ) + 0.7 * 5 / sum(1 / v for v in f1_source_class)
        total_source_challenge_score += source_challenge_score

        with torch.no_grad():
            _label_preds_target = model(torch.cat([t_img_pre, t_img_post], dim=1))
            label_preds_target = torch.cat([_label_preds_target[:, :2, ...], _label_preds_target[:, 2:, ...].max(dim=1, keepdim=True).values], dim=1)
            f1_target = multiclass_f1_score(
                label_preds_target.argmax(dim=1), t_mask_post.to(device).argmax(dim=1), num_classes=3
            )
            total_target_f1 += f1_target

            f1_target_class = multiclass_f1_score(
                label_preds_target.argmax(dim=1), t_mask_post.to(device).argmax(dim=1), num_classes=3, average="none"
            )
            total_target_f1_class += f1_target_class

            target_challenge_score = 0.3 * binary_f1_score(
                (target_preds.argmax(dim=1) > 0).to(torch.int), (t_mask_post.argmax(dim=1) > 0).to(torch.int)
            ) + 0.7 * 3 / sum(1 / v for v in f1_target_class)
            total_target_challenge_score += target_challenge_score

        if epoch != 0:
            tqdm.write(
                f"STEP {(epoch-1) * BATCH_SIZE + step:03d}: source_loss={loss_source.item():.4f}, target_loss={loss_target.item():.4f}, "
                f"source_f1={f1_source:.4f}, "
                f"target_f1={f1_target:.4f}, "
                f"\nsource_f1_class={f1_source_class}"
                f"\ntarget_f1_class={f1_target_class}"
                f"\nsource_challenge_score={source_challenge_score:.4f}"
                f"\ntarget_challenge_score={target_challenge_score:.4f}"
            )


    if epoch != 0:
        scheduler.step()

    mean_source_loss = total_source_loss / n_batches
    mean_target_loss = total_target_loss / n_batches
    mean_f1_source = total_source_f1 / n_batches
    mean_f1_target = total_target_f1 / n_batches
    mean_f1_source_class = total_source_f1_class / n_batches
    mean_f1_target_class = total_target_f1_class / n_batches
    mean_source_challenge_score = total_source_challenge_score / n_batches
    mean_target_challenge_score = total_target_challenge_score / n_batches
    tqdm.write(
        f"############## EPOCH {epoch:03d}: source_loss={mean_source_loss:.4f}, target_loss={mean_target_loss:.4f}, "
        f"source_f1={mean_f1_source:.4f}, "
        f"target_f1={mean_f1_target:.4f}, "
        f"\nsource_f1_class={mean_f1_source_class}"
        f"\ntarget_f1_class={mean_f1_target_class}"
        f"\nsource_challenge_score={mean_source_challenge_score:.4f}"
        f"\ntarget_challenge_score={mean_target_challenge_score:.4f}"
    )

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 001: source_loss=26.3826, target_loss=-0.0687, source_f1=0.4592, target_f1=0.5028, 
source_f1_class=tensor([0.9829, 0.6566, 0.4046, 0.2521, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9598, 0.2926, 0.2561], device='cuda:0')
source_challenge_score=0.2199
target_challenge_score=0.3887
STEP 002: source_loss=21.1806, target_loss=-0.0679, source_f1=0.5621, target_f1=0.6326, 
source_f1_class=tensor([0.9651, 0.6083, 0.4216, 0.6593, 0.1563], device='cuda:0')
target_f1_class=tensor([0.9758, 0.5133, 0.4086], device='cuda:0')
source_challenge_score=0.4723
target_challenge_score=0.5644
STEP 003: source_loss=25.7667, target_loss=-0.0686, source_f1=0.4228, target_f1=0.6407, 
source_f1_class=tensor([0.9934, 0.4247, 0.2144, 0.4815, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9683, 0.3183, 0.6353], device='cuda:0')
source_challenge_score=0.2007
target_challenge_score=0.5403
STEP 004: source_loss=21.4511, target_loss=-0.0687, source_f1=0.5207, target_f1=0.5079, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 017: source_loss=16.0533, target_loss=-0.0695, source_f1=0.6888, target_f1=0.5517, 
source_f1_class=tensor([0.9783, 0.7000, 0.6504, 0.6008, 0.5143], device='cuda:0')
target_f1_class=tensor([0.9628, 0.5417, 0.1506], device='cuda:0')
source_challenge_score=0.6841
target_challenge_score=0.3941
STEP 018: source_loss=23.3407, target_loss=-0.0696, source_f1=0.5073, target_f1=0.5707, 
source_f1_class=tensor([0.9835, 0.5285, 0.3838, 0.6406, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9291, 0.3375, 0.4456], device='cuda:0')
source_challenge_score=0.2041
target_challenge_score=0.4652
STEP 019: source_loss=25.3537, target_loss=-0.0699, source_f1=0.4400, target_f1=0.6513, 
source_f1_class=tensor([0.9708, 0.4903, 0.2875, 0.3861, 0.0653], device='cuda:0')
target_f1_class=tensor([0.9569, 0.5317, 0.4652], device='cuda:0')
source_challenge_score=0.3161
target_challenge_score=0.6089
STEP 020: source_loss=24.3239, target_loss=-0.0700, source_f1=0.4517, target_f1=0.6362, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 033: source_loss=22.4007, target_loss=-0.0689, source_f1=0.5389, target_f1=0.5772, 
source_f1_class=tensor([0.9775, 0.6031, 0.4821, 0.6318, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9653, 0.2250, 0.5413], device='cuda:0')
source_challenge_score=0.2124
target_challenge_score=0.4757
STEP 034: source_loss=18.2115, target_loss=-0.0687, source_f1=0.5711, target_f1=0.5446, 
source_f1_class=tensor([0.9913, 0.6631, 0.3878, 0.2890, 0.5242], device='cuda:0')
target_f1_class=tensor([0.9512, 0.2104, 0.4722], device='cuda:0')
source_challenge_score=0.5690
target_challenge_score=0.4054
STEP 035: source_loss=22.3598, target_loss=-0.0686, source_f1=0.5205, target_f1=0.5061, 
source_f1_class=tensor([0.9817, 0.5367, 0.3456, 0.5861, 0.1523], device='cuda:0')
target_f1_class=tensor([0.9669, 0.2622, 0.2893], device='cuda:0')
source_challenge_score=0.4631
target_challenge_score=0.4045
STEP 036: source_loss=22.5102, target_loss=-0.0693, source_f1=0.5493, target_f1=0.5498, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 049: source_loss=25.9697, target_loss=-0.0693, source_f1=0.4604, target_f1=0.5761, 
source_f1_class=tensor([0.9776, 0.5970, 0.2097, 0.5178, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9638, 0.3759, 0.3886], device='cuda:0')
source_challenge_score=0.2201
target_challenge_score=0.5248
STEP 050: source_loss=26.7345, target_loss=-0.0693, source_f1=0.4061, target_f1=0.6021, 
source_f1_class=tensor([0.9931, 0.4407, 0.2131, 0.3836, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9514, 0.4759, 0.3790], device='cuda:0')
source_challenge_score=0.1667
target_challenge_score=0.5405
STEP 051: source_loss=24.1389, target_loss=-0.0687, source_f1=0.5054, target_f1=0.5120, 
source_f1_class=tensor([0.9856, 0.5994, 0.3202, 0.6220, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9553, 0.2870, 0.2938], device='cuda:0')
source_challenge_score=0.1954
target_challenge_score=0.3997
STEP 052: source_loss=24.3262, target_loss=-0.0686, source_f1=0.4918, target_f1=0.4885, 
source_f1_class=tens

  0%|          | 0/475 [00:00<?, ?it/s]

STEP 065: source_loss=24.1649, target_loss=-0.0696, source_f1=0.5277, target_f1=0.6294, 
source_f1_class=tensor([0.9833, 0.6174, 0.5448, 0.4929, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9578, 0.4654, 0.4650], device='cuda:0')
source_challenge_score=0.2364
target_challenge_score=0.5762
STEP 066: source_loss=26.1579, target_loss=-0.0690, source_f1=0.4421, target_f1=0.4586, 
source_f1_class=tensor([0.9881, 0.6352, 0.1403, 0.4468, 0.0000], device='cuda:0')
target_f1_class=tensor([0.9452, 0.2419, 0.1888], device='cuda:0')
source_challenge_score=0.1950
target_challenge_score=0.3238
STEP 067: source_loss=22.6002, target_loss=-0.0692, source_f1=0.4867, target_f1=0.5022, 
source_f1_class=tensor([0.9846, 0.3599, 0.2888, 0.6514, 0.1487], device='cuda:0')
target_f1_class=tensor([0.9565, 0.2706, 0.2794], device='cuda:0')
source_challenge_score=0.4418
target_challenge_score=0.3871
STEP 068: source_loss=24.6952, target_loss=-0.0694, source_f1=0.5016, target_f1=0.5651, 
source_f1_class=tens

KeyboardInterrupt: 

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

from torchvision.utils import draw_segmentation_masks, make_grid  # type: ignore[import-untyped]
import numpy as np
from PIL import Image

In [None]:
class SemanticSegmentationTarget:
    def __init__(self, category, mask):
        self.category = category
        self.mask = mask
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        return (model_output[self.category, :, :] * self.mask).sum()

In [None]:
dm_test = dm2
dm_test.prepare_data()
dm_test.setup("test")

In [None]:
# IMG_NUM = 7340
# IMG_NUM = 6710
IMG_NUM = 6716
IMG_PATH = Path(f"data/floodnet_processed_512/FloodNet-Supervised_v1.0/train/train-org-img/{IMG_NUM}.png")

# FarSeg
# target_layers = model.model.module.decoder.blocks

# Baseline
target_layers = [
    model.model.conv6_2,
    model.model.conv7_2,
    model.model.conv8_2,
    model.model.conv9_2,
    model.model.conv10,
    model.model.res,
]

# region boring
item = dm_test._train_dataset[dm_test._train_dataset.image_paths.index(IMG_PATH)]
image_pre, mask_pre, image_post, mask_post = item

with torch.no_grad():
    m = model.to(device)
    _preds = m(torch.cat([torch.unsqueeze(image_pre, 0), torch.unsqueeze(image_post, 0)], dim=1).to(device))
    preds = torch.cat([_preds[:, :2, ], _preds[:, 2:, ].max(dim=1, keepdim=True).values], dim=1)

for c in range(0, 5):
    targets = [SemanticSegmentationTarget(c, mask_post.argmax(axis=0))]
    with GradCAM(
        model=model,
        target_layers=target_layers,
    ) as cam:
        grayscale_cam = cam(
            input_tensor=torch.unsqueeze(torch.cat([image_pre, image_post], dim=0), 0), targets=targets
        )[0, :]
        cam_image = show_cam_on_image(
            ((image_post + 1) / 2).moveaxis(0, -1).detach().cpu().numpy(), grayscale_cam, use_rgb=True
        )

    colors = [
        (128, 128, 128),
        (0, 255, 0),
        (255, 0, 0),
    ]
    preds_masks = draw_segmentation_masks(
        ((image_post + 1) * 127.5).to(torch.uint8),
        torch.nn.functional.one_hot(preds[0].argmax(dim=0), num_classes=3).moveaxis(-1, 0).to(bool),
        colors=colors,
        alpha=0.3,
    )
    ground_truth_masks = draw_segmentation_masks(
        ((image_post + 1) * 127.5).to(torch.uint8), mask_post.to(bool), colors=colors, alpha=0.3
    )
    comp = np.hstack(
        (
            ground_truth_masks.moveaxis(0, -1).detach().cpu().numpy(),
            preds_masks.moveaxis(0, -1).detach().cpu().numpy(),
            cam_image,
        )
    )
    display(Image.fromarray(comp))

# acts = []
# for c in range(0, 3):
#     targets = [SemanticSegmentationTarget(c, mask_post.argmax(axis=0))]
#     with GradCAM(
#         model=model,
#         target_layers=target_layers,
#     ) as cam:
#         grayscale_cam = cam(
#             input_tensor=torch.unsqueeze(torch.cat([image_pre, image_post], dim=0), 0), targets=targets
#         )[0, :]
#         cam_image = show_cam_on_image(
#             ((image_post + 1) / 2).moveaxis(0, -1).detach().cpu().numpy(), grayscale_cam, use_rgb=True
#         )

#     acts.append(torch.    from_numpy(cam_image).moveaxis(-1, 0))
# comp_acts = make_grid(torch.stack(acts), padding=5, pad_value=255)
# display(Image.fromarray(comp_acts.moveaxis(0, -1).numpy()))

# endregion