https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/adda.py

In [1]:
from pathlib import Path
from functools import partial
import dotenv
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Tier3, Tier1
from inz.models.baseline_module import BaselineModule
from inz.util import get_loc_cls_weights, get_wandb_logger, show_masks_comparison
from inz.xview2_strong_baseline.legacy.losses import ComboLoss
from inz.xview2_strong_baseline.legacy.zoo.models import Res34_Unet_Double
from inz.models.baseline_module import BaselineModule
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics.functional.classification import multiclass_f1_score
from torchmetrics.functional.classification import binary_accuracy

In [2]:
import sys  # noqa: I001

sys.path.append("inz/revgrad")

from utils import GradientReversal

sys.path.append("inz/xview2_strong_baseline")

from legacy.zoo.models import Res34_Unet_Double

In [3]:
dotenv.load_dotenv()
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
device = torch.device("cuda")
torch.set_float32_matmul_precision("high")

Seed set to 123


In [4]:
BATCH_SIZE = 8

dm1 = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events={
        Tier1: [
            Event.hurricane_matthew,
        ],
    },
    val_fraction=0.15,
    test_fraction=0.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)
dm1.prepare_data()
dm1.setup("fit")
dm2 = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events={
        Tier1: [
            Event.palu_tsunami,
        ],
    },
    val_fraction=0.15,
    test_fraction=0.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)
dm2.prepare_data()
dm2.setup("fit")

print(f"{len(dm1.train_dataloader())} train batches, {len(dm1.val_dataloader())} val batches")

102 train batches, 18 val batches


In [5]:
MODEL_CKPT = "/home/tomek/inz/inz/outputs/split_wind_test_hurricane_matthew_baseline/2024-11-06_05-02-23/checkpoints/experiment_name-0-epoch-79-step-9600-f1-0.601407-last.ckpt"

In [6]:
source_model = BaselineModule.load_from_checkpoint(
    checkpoint_path=MODEL_CKPT,
    model=Res34_Unet_Double(pretrained=True),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}),
    optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
    scheduler_factory=partial(
        torch.optim.lr_scheduler.MultiStepLR,
        gamma=0.5,
        milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
    ),
    class_weights=torch.Tensor([0.01, 1.0, 9.0478803, 8.68207691, 12.9632271]),
)

target_model = BaselineModule.load_from_checkpoint(
    checkpoint_path=MODEL_CKPT,
    model=Res34_Unet_Double(pretrained=True),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}),
    optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
    scheduler_factory=partial(
        torch.optim.lr_scheduler.MultiStepLR,
        gamma=0.5,
        milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
    ),
    class_weights=torch.Tensor([0.01, 1.0, 9.0478803, 8.68207691, 12.9632271]),
)

using weights from ResNet34_Weights.IMAGENET1K_V1
using weights from ResNet34_Weights.IMAGENET1K_V1


In [7]:
class EncoderWrapper(torch.nn.Module):
    def __init__(self, enc_layers: list[torch.nn.Module]) -> None:
        super().__init__()
        self.conv1, self.conv2, self.conv3, self.conv4, self.conv5 = enc_layers

    # def forward(self, x):
    #     return (self._encode_once(x[:, :3, ...]), self._encode_once(x[:, 3:, ...]))
    #     # return torch.cat([torch.flatten(y, start_dim=1) for y in outputs])

    def forward(self, x):
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)
        return [enc1, enc2, enc3, enc4, enc5]

In [8]:
class DecoderWrapper(torch.nn.Module):
    def __init__(self, dec_layers: list[torch.nn.Module], outconv: torch.nn.Module) -> None:
        super().__init__()
        (
            self.conv6,
            self.conv6_2,
            self.conv7,
            self.conv7_2,
            self.conv8,
            self.conv8_2,
            self.conv9,
            self.conv9_2,
            self.conv10,
        ) = dec_layers
        self.res = outconv

    def forward(self, x1, x2):
        enc1_1, enc2_1, enc3_1, enc4_1, enc5_1 = x1
        enc1_2, enc2_2, enc3_2, enc4_2, enc5_2 = x2
        return self._forward(enc1_1, enc2_1, enc3_1, enc4_1, enc5_1, enc1_2, enc2_2, enc3_2, enc4_2, enc5_2)

    def _forward(self, enc1_1, enc2_1, enc3_1, enc4_1, enc5_1, enc1_2, enc2_2, enc3_2, enc4_2, enc5_2):
        output1 = self._decode_once(enc1_1, enc2_1, enc3_1, enc4_1, enc5_1)
        output2 = self._decode_once(enc1_2, enc2_2, enc3_2, enc4_2, enc5_2)
        return self.res(torch.cat([output1, output2], 1))

    def _decode_once(self, enc1, enc2, enc3, enc4, enc5):
        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))
        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))
        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))
        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))
        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))
        return dec10

In [9]:
def set_requires_grad(model, requires_grad=True):
    for param in model.parameters():
        param.requires_grad = requires_grad

In [10]:
source_encoder = (
    EncoderWrapper(
        enc_layers=[
            source_model.model.conv1,
            source_model.model.conv2,
            source_model.model.conv3,
            source_model.model.conv4,
            source_model.model.conv5,
        ]
    )
    .eval()
    .to(device)
)
set_requires_grad(source_encoder, False)

target_encoder = EncoderWrapper(
    enc_layers=[
        target_model.model.conv1,
        target_model.model.conv2,
        target_model.model.conv3,
        target_model.model.conv4,
        target_model.model.conv5,
    ]
).to(device)

decoder = (
    DecoderWrapper(
        dec_layers=[
            source_model.model.conv6,
            source_model.model.conv6_2,
            source_model.model.conv7,
            source_model.model.conv7_2,
            source_model.model.conv8,
            source_model.model.conv8_2,
            source_model.model.conv9,
            source_model.model.conv9_2,
            source_model.model.conv10,
        ],
        outconv=source_model.model.res,
    )
    .eval()
    .to(device)
)
set_requires_grad(decoder, False)

In [11]:
from torch.nn.modules import Module

REVERSAL_LAMBDA = 0.2


class DecoderDiscriminator(DecoderWrapper):
    def __init__(self, dec_layers: list[Module], outconv: Module) -> None:
        super().__init__(dec_layers, outconv)
        self.res = nn.Sequential(outconv, nn.Flatten(), nn.Linear(5 * 512**2, 1))


discriminator_base = BaselineModule(
    model=Res34_Unet_Double(pretrained=True),
    loss=ComboLoss(weights={"dice": 1, "focal": 1}),
    optimizer_factory=partial(torch.optim.AdamW, lr=0.0002, weight_decay=1e-6),
    scheduler_factory=partial(
        torch.optim.lr_scheduler.MultiStepLR,
        gamma=0.5,
        milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190],
    ),
    class_weights=torch.Tensor([0.1, 1.0, 9.0478803, 8.68207691, 12.9632271]),
)

discriminator = DecoderDiscriminator(
    dec_layers=[
        discriminator_base.model.conv6,
        discriminator_base.model.conv6_2,
        discriminator_base.model.conv7,
        discriminator_base.model.conv7_2,
        discriminator_base.model.conv8,
        discriminator_base.model.conv8_2,
        discriminator_base.model.conv9,
        discriminator_base.model.conv9_2,
        discriminator_base.model.conv10,
    ],
    outconv=discriminator_base.model.res,
).to(device)

using weights from ResNet34_Weights.IMAGENET1K_V1


In [12]:
source_loader = dm1.train_dataloader()
target_loader = dm2.train_dataloader()


optim_target_encoder = torch.optim.AdamW(list(target_encoder.parameters()) + list(target_encoder.parameters()), lr=0.00001, weight_decay=1e-7)
# scheduler_target_encoder = torch.optim.lr_scheduler.MultiStepLR(
#     optim_target_encoder, gamma=0.5, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190]
# )

optim_discriminator = torch.optim.AdamW(list(discriminator.parameters()) + list(discriminator.parameters()), lr=0.00001, weight_decay=1e-7)
# scheduler_discriminator = torch.optim.lr_scheduler.MultiStepLR(
#     optim_discriminator, gamma=0.5, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190]
# )

  return torch._dynamo.disable(fn, recursive)(*args, **kwargs)


# TODO only use selected (deeper/shallower?) layers for feature regularization

https://arxiv.org/abs/1702.05464

https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/adda.py

In [13]:
from tqdm.auto import tqdm
from itertools import islice

EPOCHS = 40

STEPS_ENC = 14
STEPS_DISC = 2
c_enc = 0
c_disc = 0
current = "disc"

for epoch in range(1, EPOCHS + 1):
    batches = zip(source_loader, target_loader)
    n_batches = min(len(source_loader), len(target_loader)) - 1

    total_discriminator_loss = total_encoder_loss =  total_discriminator_accuracy = discriminator_steps = encoder_steps = 0

    target_encoder.train()
    discriminator.train()

    for step, ((source_batch), (target_batch)) in enumerate(
        tqdm(islice(batches, n_batches), total=n_batches, position=0, leave=True), start=1
    ):
        s_img_pre, s_mask_pre, s_img_post, s_mask_post = source_batch
        t_img_pre, t_mask_pre, t_img_post, t_mask_post = target_batch

        img_source = torch.cat([s_img_pre, s_img_post], dim=1).to(device)
        img_target = torch.cat([s_img_pre, s_img_post], dim=1).to(device)

        if current == "disc":
            # train discriminator

            set_requires_grad(target_encoder, False)
            set_requires_grad(discriminator, True)

            domain_y = torch.cat([torch.ones(img_source.shape[0]), torch.zeros(img_target.shape[0])]).to(device)

            features_source_pre = source_encoder(s_img_pre.to(device))
            features_source_post = source_encoder(s_img_post.to(device))
            features_target_pre = target_encoder(t_img_pre.to(device))
            features_target_post = target_encoder(t_img_post.to(device))

            features_source_flat = torch.cat(
                (
                    torch.cat([t.flatten(start_dim=1) for t in features_source_pre], dim=1),
                    torch.cat([t.flatten(start_dim=1) for t in features_source_post], dim=1),
                ),
                dim=1,
            )
            features_target_flat = torch.cat(
                (
                    torch.cat([t.flatten(start_dim=1) for t in features_target_pre], dim=1),
                    torch.cat([t.flatten(start_dim=1) for t in features_target_post], dim=1),
                ),
                dim=1,
            )

            features_cat = torch.cat((features_source_flat, features_target_flat))

            domain_preds_source = discriminator(features_source_pre, features_source_post)
            domain_preds_target = discriminator(features_target_pre, features_target_post)

            domain_preds = torch.cat((domain_preds_source, domain_preds_target)).squeeze()

            discriminator_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)


            optim_discriminator.zero_grad()
            discriminator_loss.backward()
            optim_discriminator.step()

            total_discriminator_loss += discriminator_loss.item()
            discriminator_accuracy = binary_accuracy(domain_preds, domain_y)
            total_discriminator_accuracy += discriminator_accuracy

            discriminator_steps += 1

            c_disc += 1
            if c_disc == STEPS_DISC:
                c_disc = 0
                current = "enc"

            tqdm.write(f"STEP {(epoch-1) * BATCH_SIZE + step:03d}: discriminator_loss={discriminator_loss.item():.4f}, discriminator_accuracy={discriminator_accuracy.item():.4f}")

        else:
            # train encoder

            set_requires_grad(target_encoder, True)
            set_requires_grad(discriminator, False)

            features_target_pre = target_encoder(t_img_pre.to(device))
            features_target_post = target_encoder(t_img_post.to(device))
            features_target_flat = torch.cat(
                (
                    torch.cat([t.flatten(start_dim=1) for t in features_target_pre], dim=1),
                    torch.cat([t.flatten(start_dim=1) for t in features_target_post], dim=1),
                ),
                dim=1,
            )

            # flipped labels
            domain_y = torch.ones(img_source.shape[0]).to(device)

            domain_preds = discriminator(features_target_pre, features_target_post).squeeze()

            encoder_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)

            optim_target_encoder.zero_grad()
            encoder_loss.backward()
            optim_target_encoder.step()

            total_encoder_loss += encoder_loss.item()

            encoder_steps += 1

            c_enc += 1
            if c_enc == STEPS_ENC:
                c_enc = 0
                current = "disc"

            tqdm.write(f"STEP {(epoch-1) * BATCH_SIZE + step:03d}: encoder_loss={encoder_loss.item():.4f}")



    # evaluate
    target_encoder.eval()
    discriminator.eval()

    batches = zip(source_loader, target_loader)
    n_batches = min(len(source_loader), len(target_loader)) - 1

    total_f1_target = 0
    total_f1_cls_target = torch.Tensor([0., 0., 0., 0., 0.]).to(device)

    with torch.no_grad():
        for step, ((source_batch), (target_batch)) in enumerate(
            tqdm(islice(batches, n_batches), total=n_batches, position=0, leave=True), start=1
        ):
            t_img_pre, t_mask_pre, t_img_post, t_mask_post = target_batch

            features_target_pre = target_encoder(t_img_pre.to(device))
            features_target_post = target_encoder(t_img_post.to(device))

            preds_target = decoder(features_target_pre, features_target_post)
            total_f1_target += multiclass_f1_score(preds_target.argmax(dim=1), t_mask_post.to(device).argmax(dim=1), num_classes=5)
            total_f1_cls_target += multiclass_f1_score(preds_target.argmax(dim=1), t_mask_post.to(device).argmax(dim=1), num_classes=5, average='none')


    mean_encoder_loss = total_encoder_loss / encoder_steps
    mean_discriminator_loss = total_discriminator_loss / discriminator_steps
    mean_f1_target = total_f1_target / n_batches
    mean_f1_target_class = total_f1_cls_target / n_batches
    mean_discriminator_accuracy = total_discriminator_accuracy / discriminator_steps
    tqdm.write(
        f"############## EPOCH {epoch:03d}: encoder_loss={mean_encoder_loss:.4f}, discriminator_loss={mean_discriminator_loss:.4f}\n"
        f"discriminator_accuracy={mean_discriminator_accuracy:.4f}\n"
        f"target_f1={mean_f1_target:.4f}\n"
        f"target_f1_class={mean_f1_target_class}"
    )

  0%|          | 0/48 [00:00<?, ?it/s]

STEP 001: discriminator_loss=0.6898, discriminator_accuracy=0.6250
STEP 002: discriminator_loss=1.0371, discriminator_accuracy=0.5000
STEP 003: encoder_loss=0.3783
STEP 004: encoder_loss=0.3850
STEP 005: encoder_loss=0.3443
STEP 006: encoder_loss=0.3265
STEP 007: encoder_loss=0.3544
STEP 008: encoder_loss=0.3290
STEP 009: encoder_loss=0.3255
STEP 010: encoder_loss=0.3317
STEP 011: encoder_loss=0.3118
STEP 012: encoder_loss=0.3190
STEP 013: encoder_loss=0.3064
STEP 014: encoder_loss=0.3410
STEP 015: encoder_loss=0.2816
STEP 016: encoder_loss=0.2927
STEP 017: discriminator_loss=0.8584, discriminator_accuracy=0.5000
STEP 018: discriminator_loss=0.7874, discriminator_accuracy=0.5000
STEP 019: encoder_loss=0.6140
STEP 020: encoder_loss=0.6223
STEP 021: encoder_loss=0.5916
STEP 022: encoder_loss=0.5940
STEP 023: encoder_loss=0.5413
STEP 024: encoder_loss=0.5475
STEP 025: encoder_loss=0.5614
STEP 026: encoder_loss=0.5390
STEP 027: encoder_loss=0.5248
STEP 028: encoder_loss=0.6017
STEP 029: en

  0%|          | 0/48 [00:00<?, ?it/s]

############## EPOCH 001: encoder_loss=0.6666, discriminator_loss=0.7985
discriminator_accuracy=0.5104
target_f1=0.3829
target_f1_class=tensor([9.5475e-01, 5.4677e-01, 2.4741e-05, 9.1625e-02, 3.2111e-01],
       device='cuda:0')


  0%|          | 0/48 [00:00<?, ?it/s]

STEP 009: discriminator_loss=0.7383, discriminator_accuracy=0.5000
STEP 010: discriminator_loss=0.7469, discriminator_accuracy=0.5000
STEP 011: encoder_loss=0.9859
STEP 012: encoder_loss=0.9348
STEP 013: encoder_loss=0.9578
STEP 014: encoder_loss=0.8599
STEP 015: encoder_loss=0.8443
STEP 016: encoder_loss=0.9742
STEP 017: encoder_loss=0.9150
STEP 018: encoder_loss=0.8987
STEP 019: encoder_loss=0.8642
STEP 020: encoder_loss=0.8364
STEP 021: encoder_loss=0.8477
STEP 022: encoder_loss=0.8175
STEP 023: encoder_loss=0.8033
STEP 024: encoder_loss=0.7637
STEP 025: discriminator_loss=0.7164, discriminator_accuracy=0.5000
STEP 026: discriminator_loss=0.7207, discriminator_accuracy=0.4375
STEP 027: encoder_loss=0.8518
STEP 028: encoder_loss=0.7543
STEP 029: encoder_loss=0.7619
STEP 030: encoder_loss=0.6764
STEP 031: encoder_loss=0.6982
STEP 032: encoder_loss=0.7141
STEP 033: encoder_loss=0.7161
STEP 034: encoder_loss=0.7096
STEP 035: encoder_loss=0.6665
STEP 036: encoder_loss=0.7049
STEP 037: en

  0%|          | 0/48 [00:00<?, ?it/s]

KeyboardInterrupt: 