In [1]:
import datetime
import os
from pathlib import Path
import pickle
import random
from typing import Any, Dict, List, Tuple, Union

import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pandas as pd
import pytorch_lightning as L
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import train_test_split
import timm
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Precision, Recall
from torchmetrics.functional.classification import multiclass_confusion_matrix, multiclass_f1_score

from configs.config import CFG
from model.multitask_unet import MultiTaskUNet
from ewc.ewc import EWC
from util.get_logger import get_logger
from util.my_dataset import MyDataModule, MyDataset


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.set_float32_matmul_precision("high")

  check_for_updates()


In [2]:
from timm import list_models

list_models()

['aimv2_1b_patch14_224',
 'aimv2_1b_patch14_336',
 'aimv2_1b_patch14_448',
 'aimv2_3b_patch14_224',
 'aimv2_3b_patch14_336',
 'aimv2_3b_patch14_448',
 'aimv2_huge_patch14_224',
 'aimv2_huge_patch14_336',
 'aimv2_huge_patch14_448',
 'aimv2_large_patch14_224',
 'aimv2_large_patch14_336',
 'aimv2_large_patch14_448',
 'bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'beitv2_base_patch16_224',
 'beitv2_large_patch16_224',
 'botnet26t_256',
 'botnet50ts_256',
 'caformer_b36',
 'caformer_m36',
 'caformer_s18',
 'caformer_s36',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_medium',
 'coat_lite_medium_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_small',
 'coat_tiny',
 'coatnet_0_224',
 'coatnet_0_rw_224',
 'coa

In [3]:
class LitUNetModel(L.LightningModule):
    def __init__(
            self,
            model_name: str,
            pretrained: bool,
            num_classes: int,
            height: int,
            width: int,
            learning_rate: float,
            mean_y: torch.Tensor,
            std_y: torch.Tensor,
            ewc: Union[nn.Module, None] = None,
            lambda_ewc: float = 1000.0,
        ) -> None:

        super().__init__()
        self.model = MultiTaskUNet(model_name, num_classes, pretrained, height, width)
        self.num_classes = num_classes
        self.criterion1 = nn.L1Loss()
        self.criterion2 = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.mean_y = mean_y
        self.std_y = std_y
        self.ewc = ewc
        self.lambda_ewc = lambda_ewc
        self.save_hyperparameters(ignore=["criterion1", "criterion2"])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        reg_logit, clf_logit = self.model(x)
        pred_class = clf_logit.argmax(dim=1)
        mean_by_class = self.mean_y[pred_class]
        std_by_class = self.std_y[pred_class]
        reg_logit = reg_logit * std_by_class + mean_by_class
        return reg_logit, clf_logit

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        x, y, label, _ = batch
        batch_size = len(x)
        reg_logit, clf_logit = self.forward(x)
        loss1 = self.criterion1(reg_logit, y)
        loss2 = self.criterion2(clf_logit, label)
        loss = loss1 + loss2

        if self.ewc is not None:
            ewc_loss = self.ewc.penalty()
            loss += self.lambda_ewc * ewc_loss
            self.log("train_ewc_loss", ewc_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)

        self.log("train_loss1", loss1, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("train_loss2", loss2, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        return loss
    
    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        x, y, label, _ = batch
        batch_size = len(x)
        reg_logit, clf_logit = self.forward(x)
        loss1 = self.criterion1(reg_logit, y)
        loss2 = self.criterion2(clf_logit, label)
        loss = loss1 + loss2
        f1 = multiclass_f1_score(clf_logit, label, num_classes=self.num_classes, average="macro")
        self.log("val_loss1", loss1, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("val_loss2", loss2, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("val_f1", f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        return loss
    
    def on_test_epoch_start(self):
        self.clf_targets = []
        self.clf_preds = []
        self.mae_all = 0
        self.num_data = 0
    
    def test_step(
            self,
            batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            batch_idx: int
        ) -> Dict[str, float]:
        
        x, y, label, path = batch
        batch_size = len(x)
        reg_logit, clf_logit = self.forward(x)
        reg_logit = torch.clip(reg_logit, min=1500, max=4500)
        loss1 = F.l1_loss(reg_logit, y)
        loss2 = F.cross_entropy(clf_logit, label)
        loss = loss1 + loss2
        f1 = multiclass_f1_score(clf_logit, label, num_classes=self.num_classes, average="macro")
        self.mae_all += F.l1_loss(reg_logit, y, reduction="sum")
        self.num_data += len(x)

        self.clf_targets.append(label.cpu())
        self.clf_preds.append(clf_logit.argmax(dim=1).cpu())
        
        self.log("test_loss1", loss1, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("test_loss2", loss2, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log("test_f1", f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)

        if batch_idx == 0:
            print(path[0])
            _, axs = plt.subplots(1, 8, figsize=(32, 8))
            for i in range(5):
                im0 = axs[i].imshow(x.float()[0, i].cpu(), aspect="auto")
                plt.colorbar(im0, ax=axs[i])
            im1 = axs[5].imshow(reg_logit.float()[0, 0].cpu(), aspect="auto")
            im2 = axs[6].imshow(y.float()[0, 0].cpu(), aspect="auto")
            im3 = axs[7].imshow(y.float()[0, 0].cpu()-reg_logit.float()[0, 0].cpu(), aspect="auto")
            plt.colorbar(im1, ax=axs[5])
            plt.colorbar(im2, ax=axs[6])
            plt.colorbar(im3, ax=axs[7])
            plt.suptitle(path[0])
            plt.tight_layout()
            plt.show()

        return {"loss": loss}
    
    def on_test_epoch_end(self):
        clf_targets = torch.cat(self.clf_targets)
        clf_preds = torch.cat(self.clf_preds)
        cm = multiclass_confusion_matrix(clf_preds, clf_targets, num_classes=self.num_classes)
        print(cm)
        print(f"Test MAE: {self.mae_all / self.num_data / 70 / 70:.4f}")
        print(f"# of test data: {self.num_data}")
        del self.clf_targets
        del self.clf_preds
    
    def predict_step(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        reg_logit, clf_logit = self.forward(x)
        return reg_logit, clf_logit

    def configure_optimizers(self) -> Dict[str, object]:
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            div_factor=25,
            final_div_factor=1e+04,
        )
        scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
            "monitor": "val_loss",
            "strict": False,
        }
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": scheduler_config,
            },
        )

### Define Configurations

In [4]:
now_time = datetime.datetime.now()
output_dir = Path(f"../output/multitask_{now_time.date()}-{now_time.hour:02}-{now_time.minute:02}")

config = CFG(
    output_dir=output_dir,
    model_name="convnextv2_tiny",
    pretrained=True,
    debag=True,
    train_ratio=0.8,
    seed=42,
    height=288,
    width=288,
    batch_size=32,
    epochs=10,
    patience=5,
    accumulation_steps=1,
)
config.seed_everything()

logger = get_logger(output_dir.joinpath('output.log'))
config_log = [
    f'{k} = {config.__dict__[k]}'
    for k, _ in config.__dict__.items()
    if not k.startswith('__')
]
logger.info('\n'.join(config_log))
logger.info('\n')

2025-05-24 18:47:47,945 util.get_logger:26 <module> [INFO]:
output_dir = ../output/multitask_2025-05-24-18-47
model_name = convnextv2_tiny
pretrained = True
debag = True
train_ratio = 0.8
seed = 42
height = 288
width = 288
batch_size = 32
epochs = 10
patience = 5
accumulation_steps = 1
device = cuda
2025-05-24 18:47:47,946 util.get_logger:27 <module> [INFO]:




### Load Data Paths

In [5]:
dir_path = Path("../data")
print([p.stem for p in dir_path.glob("*")])

['FlatFault_A', 'Style_A', 'CurveVel_B', 'FlatVel_A', 'CurveVel_A', 'CurveFault_A', 'CurveFault_B', 'FlatFault_B', 'FlatVel_B', 'Style_B']


In [6]:
families = {
    "CurveFault_A": 0,
    "CurveFault_B": 0,
    "CurveVel_A": 1,
    "CurveVel_B": 1,
    "FlatFault_A": 2,
    "FlatFault_B": 2,
    "FlatVel_A": 3,
    "FlatVel_B": 3,
    "Style_A": 4,
    "Style_B": 4,
}

paths = []
for family, label in families.items():
    for i, p in enumerate(dir_path.joinpath(family).glob("*.npz")):
        paths.append((family, label, p))
paths = pd.DataFrame(paths, columns=["family", "label", "path"])
if config.debag:
    paths = paths.sample(n=10_000, replace=False)
display(paths)

Unnamed: 0,family,label,path
207110,FlatFault_A,2,../data/FlatFault_A/seis3_1_1_vel3_1_1_141.npz
40471,CurveFault_A,0,../data/CurveFault_A/seis2_1_27_vel2_1_27_254.npz
220536,FlatFault_A,2,../data/FlatFault_A/seis2_1_3_vel2_1_3_289.npz
239834,FlatFault_B,2,../data/FlatFault_B/seis6_1_29_vel6_1_29_204.npz
406539,Style_B,4,../data/Style_B/data18_model18_287.npz
...,...,...,...
47422,CurveFault_A,0,../data/CurveFault_A/seis4_1_28_vel4_1_28_253.npz
420060,Style_B,4,../data/Style_B/data50_model50_40.npz
186880,FlatFault_A,2,../data/FlatFault_A/seis3_1_30_vel3_1_30_276.npz
388244,Style_A,4,../data/Style_A/data40_model40_85.npz


### Split Paths into training, validation, and test.

In [7]:
train_valid_paths, test_paths = train_test_split(
    paths,
    train_size=config.train_ratio,
    shuffle=True,
    random_state=config.seed,
    stratify=paths["family"]
)
train_paths, valid_paths = train_test_split(
    train_valid_paths,
    train_size=config.train_ratio,
    shuffle=True,
    random_state=config.seed,
    stratify=train_valid_paths["family"]
)
display(train_paths)
display(valid_paths)
display(test_paths)

Unnamed: 0,family,label,path
21657,CurveFault_A,0,../data/CurveFault_A/seis2_1_14_vel2_1_14_389.npz
342986,Style_A,4,../data/Style_A/data11_model11_435.npz
223236,FlatFault_B,2,../data/FlatFault_B/seis6_1_28_vel6_1_28_489.npz
155504,CurveVel_B,1,../data/CurveVel_B/data56_model56_139.npz
426763,Style_B,4,../data/Style_B/data98_model98_358.npz
...,...,...,...
75779,CurveFault_B,0,../data/CurveFault_B/seis8_1_28_vel8_1_28_249.npz
109163,CurveVel_A,1,../data/CurveVel_A/data55_model55_266.npz
352024,Style_A,4,../data/Style_A/data17_model17_194.npz
96038,CurveFault_B,0,../data/CurveFault_B/seis8_1_24_vel8_1_24_154.npz


Unnamed: 0,family,label,path
432502,Style_B,4,../data/Style_B/data129_model129_81.npz
411348,Style_B,4,../data/Style_B/data18_model18_123.npz
458440,Style_B,4,../data/Style_B/data18_model18_200.npz
37184,CurveFault_A,0,../data/CurveFault_A/seis3_1_7_vel3_1_7_467.npz
104937,CurveFault_B,0,../data/CurveFault_B/seis8_1_10_vel8_1_10_443.npz
...,...,...,...
404179,Style_B,4,../data/Style_B/data112_model112_199.npz
131687,CurveVel_A,1,../data/CurveVel_A/data24_model24_469.npz
158756,CurveVel_B,1,../data/CurveVel_B/data49_model49_446.npz
417408,Style_B,4,../data/Style_B/data87_model87_493.npz


Unnamed: 0,family,label,path
200790,FlatFault_A,2,../data/FlatFault_A/seis2_1_4_vel2_1_4_364.npz
190750,FlatFault_A,2,../data/FlatFault_A/seis2_1_31_vel2_1_31_249.npz
39752,CurveFault_A,0,../data/CurveFault_A/seis2_1_8_vel2_1_8_336.npz
85525,CurveFault_B,0,../data/CurveFault_B/seis7_1_12_vel7_1_12_346.npz
173283,FlatFault_A,2,../data/FlatFault_A/seis4_1_28_vel4_1_28_358.npz
...,...,...,...
62377,CurveFault_B,0,../data/CurveFault_B/seis7_1_7_vel7_1_7_112.npz
301306,FlatVel_A,3,../data/FlatVel_A/data5_model5_162.npz
200958,FlatFault_A,2,../data/FlatFault_A/seis4_1_31_vel4_1_31_129.npz
91224,CurveFault_B,0,../data/CurveFault_B/seis6_1_7_vel6_1_7_1.npz


In [8]:
mean_x = []
std_x = []
with open("../output/statistics_A.pkl", "rb") as f:
    statistics_A = pickle.load(f)
    mean_x.append(statistics_A["All"]["mean_log_x"])
    std_x.append(statistics_A["All"]["std_log_x"])
mean_x = torch.tensor(mean_x).reshape(-1, 1, 1)
std_x = torch.tensor(std_x).reshape(-1, 1, 1)
display(mean_x.shape)
display(std_x.shape)

mean_y = []
std_y = []
for f in ["CurveFault", "CurveVel", "FlatFault", "FlatVel", "Style"]:
    with open(f"../output/statistics_{f}.pkl", "rb") as f:
        statistics = pickle.load(f)
        mean_y.append(statistics["All"]["mean_y"])
        std_y.append(statistics["All"]["std_y"])
mean_y = torch.tensor(mean_y)
std_y = torch.tensor(std_y)
mean_y = mean_y.reshape(-1, 1, 1, 1)
std_y = std_y.reshape(-1, 1, 1, 1)
print(mean_y.shape)
display(std_y.shape)

  mean_x = torch.tensor(mean_x).reshape(-1, 1, 1)


torch.Size([5, 1, 1])

torch.Size([5, 1, 1])

torch.Size([5, 1, 1, 1])


torch.Size([5, 1, 1, 1])

In [9]:
family_pairs = {
    "All": (
        [
            "CurveFault_A",
            "CurveVel_A",
            "FlatFault_A",
            "FlatVel_A",
            "Style_A",
        ],
        [
            "CurveFault_B",
            "CurveVel_B",
            "FlatFault_B",
            "FlatVel_B",
            "Style_B",
        ],
    )
}

In [10]:
display(pd.crosstab(train_paths["family"], train_paths["label"]))
display(pd.crosstab(valid_paths["family"], valid_paths["label"]))
display(pd.crosstab(test_paths["family"], test_paths["label"]))

label,0,1,2,3,4
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CurveFault_A,741,0,0,0,0
CurveFault_B,730,0,0,0,0
CurveVel_A,0,415,0,0,0
CurveVel_B,0,404,0,0,0
FlatFault_A,0,0,731,0,0
FlatFault_B,0,0,769,0,0
FlatVel_A,0,0,0,398,0
FlatVel_B,0,0,0,399,0
Style_A,0,0,0,0,882
Style_B,0,0,0,0,931


label,0,1,2,3,4
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CurveFault_A,185,0,0,0,0
CurveFault_B,182,0,0,0,0
CurveVel_A,0,103,0,0,0
CurveVel_B,0,101,0,0,0
FlatFault_A,0,0,183,0,0
FlatFault_B,0,0,192,0,0
FlatVel_A,0,0,0,100,0
FlatVel_B,0,0,0,100,0
Style_A,0,0,0,0,221
Style_B,0,0,0,0,233


label,0,1,2,3,4
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CurveFault_A,231,0,0,0,0
CurveFault_B,228,0,0,0,0
CurveVel_A,0,130,0,0,0
CurveVel_B,0,126,0,0,0
FlatFault_A,0,0,229,0,0
FlatFault_B,0,0,240,0,0
FlatVel_A,0,0,0,125,0
FlatVel_B,0,0,0,124,0
Style_A,0,0,0,0,276
Style_B,0,0,0,0,291


In [None]:
%%time


for family, (family_A, family_B) in family_pairs.items():

    model = LitUNetModel(
        model_name=config.model_name,
        pretrained=config.pretrained,
        num_classes=5,
        height=config.height,
        width=config.width,
        learning_rate=1e-04,
        mean_y=mean_y.to(config.device),
        std_y=std_y.to(config.device),
    )
    
    """
    train with dataset A
    """
    train_paths_A = train_paths.query("family in @family_A")
    valid_paths_A = valid_paths.query("family in @family_A")
    test_paths_A = test_paths.query("family in @family_A")
    display(pd.crosstab(train_paths_A["family"], train_paths_A["label"]))
    display(pd.crosstab(valid_paths_A["family"], valid_paths_A["label"]))
    display(pd.crosstab(test_paths_A["family"], test_paths_A["label"]))

    datamodule = MyDataModule(
        train_paths=train_paths_A,
        valid_paths=valid_paths_A,
        test_paths=test_paths_A,
        seed=config.seed,
        batch_size=config.batch_size,
        height=config.height,
        width=config.width,
        mean_x=mean_x,
        std_x=std_x,
        mean_y=None,
        std_y=None,
    )

    callbacks=[
        EarlyStopping(monitor="val_loss", patience=config.patience, mode='min'),
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(),
        # StochasticWeightAveraging(
        #     swa_lrs=1e-5,
        #     swa_epoch_start=int(0.8*config.epochs),
        #     annealing_epochs=int(0.2*config.epochs),
        # ),
    ]

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_epochs=config.epochs,
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=CSVLogger(config.output_dir, name="A"),
        log_every_n_steps=150,
        val_check_interval=None,
        check_val_every_n_epoch=1,
        accumulate_grad_batches=config.accumulation_steps,
        gradient_clip_val=0,
    )

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    checkpoint_path = config.output_dir.joinpath(f"A/multitask_A_{config.seed}.ckpt")
    trainer.save_checkpoint(checkpoint_path)

    """
    calculate the ewc
    """
    ewc = EWC(model=model, datamodule=datamodule, loss_fn1=F.l1_loss(), loss_fn2=F.cross_entropy())
    del datamodule
    
    """
    train with resampled dataset A and dataset B
    """
    train_paths_B = train_paths.query("family in @family_B")
    train_paths_A = train_paths_A.sample(n=len(train_paths_B)//3, replace=False, random_state=config.seed)
    train_paths_B = pd.concat([train_paths_A, train_paths_B], ignore_index=True).reset_index(drop=True)
    valid_paths_B = valid_paths.query("family in @family_B")
    display(pd.crosstab(train_paths_B["family"], train_paths_B["label"]))
    display(pd.crosstab(valid_paths_B["family"], valid_paths_B["label"]))
    display(pd.crosstab(test_paths["family"], test_paths["label"]))
    
    datamodule = MyDataModule(
        train_paths=train_paths_B,
        valid_paths=valid_paths_B,
        test_paths=test_paths,
        seed=config.seed,
        batch_size=config.batch_size,
        height=config.height,
        width=config.width,
        mean_x=mean_x,
        std_x=std_x,
        mean_y=None,
        std_y=None,
    )

    callbacks=[
        EarlyStopping(monitor="val_loss", patience=config.patience, mode='min'),
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(),
        # StochasticWeightAveraging(
        #     swa_lrs=1e-5,
        #     swa_epoch_start=int(0.8*config.epochs),
        #     annealing_epochs=int(0.2*config.epochs),
        # ),
    ]

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_epochs=config.epochs,
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=CSVLogger(config.output_dir, name="B"),
        log_every_n_steps=150,
        val_check_interval=None,
        check_val_every_n_epoch=1,
        accumulate_grad_batches=config.accumulation_steps,
        gradient_clip_val=0,
    )

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    checkpoint_path = config.output_dir.joinpath(f"B/multitask_B_{config.seed}.ckpt")
    trainer.save_checkpoint(checkpoint_path)

    del datamodule
    del callbacks
    del trainer

0 torch.Size([1, 96, 72, 72])
1 torch.Size([1, 192, 36, 36])
2 torch.Size([1, 384, 18, 18])
3 torch.Size([1, 768, 9, 9])


label,0,1,2,3,4
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CurveFault_A,741,0,0,0,0
CurveVel_A,0,415,0,0,0
FlatFault_A,0,0,731,0,0
FlatVel_A,0,0,0,398,0
Style_A,0,0,0,0,882


label,0,1,2,3,4
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CurveFault_A,185,0,0,0,0
CurveVel_A,0,103,0,0,0
FlatFault_A,0,0,183,0,0
FlatVel_A,0,0,0,100,0
Style_A,0,0,0,0,221


label,0,1,2,3,4
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CurveFault_A,231,0,0,0,0
CurveVel_A,0,130,0,0,0
FlatFault_A,0,0,229,0,0
FlatVel_A,0,0,0,125,0
Style_A,0,0,0,0,276


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/ss/kaggle_work/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (99) is smaller than the logging interval Trainer(log_every_n_steps=150). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name       | Type             | Params | Mode 
--------------------------------------------------------
0 | model      | MultiTaskUNet    | 38.1 M | train
1 | criterion1 | L1Loss           | 0      | train
2 | criterion2 | CrossEntropyLoss | 0      | train
--------------------------------------------------------
38.1 M    Trainable params
0         Non-trainable params
38.1 M    Total params
152.526   Total estimated model params

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
for family, (family_A, family_B) in family_pairs.items():
    metrics0 = pd.read_csv(config.output_dir.joinpath(f"A/version_0/metrics.csv"))
    metrics0 = metrics0.sort_values(["step", "epoch"]).reset_index(drop=True)
    display(metrics0.head())
    display(metrics0[["epoch", "val_loss_epoch"]].dropna())
    
    _, axs = plt.subplots(5, 1, figsize=(8, 8))
    metrics0[["step", "lr-AdamW"]].dropna().plot(x="step", y="lr-AdamW", kind="line", marker=".", ax=axs[0])
    metrics0[["epoch", "val_loss_epoch"]].dropna().plot(x="epoch", y="val_loss_epoch", kind="line", marker=".", ax=axs[1])
    metrics0[["epoch", "val_loss1_epoch"]].dropna().plot(x="epoch", y="val_loss1_epoch", kind="line", marker=".", ax=axs[2])
    metrics0[["epoch", "val_loss2_epoch"]].dropna().plot(x="epoch", y="val_loss2_epoch", kind="line", marker=".", ax=axs[3])
    metrics0[["epoch", "val_f1_epoch"]].dropna().plot(x="epoch", y="val_f1_epoch", kind="line", marker=".", ax=axs[4])
    axs[0].set_xlabel("step")
    axs[0].set_ylabel("learning rate")
    axs[1].set_xlabel("epoch")
    axs[1].set_ylabel("Loss")
    axs[2].set_xlabel("epoch")
    axs[2].set_ylabel("Loss1")
    axs[3].set_xlabel("epoch")
    axs[3].set_ylabel("Loss2")
    axs[4].set_xlabel("epoch")
    axs[4].set_ylabel("F1")
    plt.tight_layout()
    plt.show()

    metrics1 = pd.read_csv(config.output_dir.joinpath(f"B/version_0/metrics.csv"))
    metrics1 = metrics1.sort_values(["step", "epoch"]).reset_index(drop=True)
    display(metrics1.head())
    display(metrics1[["epoch", "val_loss_epoch"]].dropna())

    _, axs = plt.subplots(5, 1, figsize=(8, 8))
    metrics1[["step", "lr-AdamW"]].dropna().plot(x="step", y="lr-AdamW", kind="line", marker=".", ax=axs[0])
    metrics1[["epoch", "val_loss_epoch"]].dropna().plot(x="epoch", y="val_loss_epoch", kind="line", marker=".", ax=axs[1])
    metrics1[["epoch", "val_loss1_epoch"]].dropna().plot(x="epoch", y="val_loss1_epoch", kind="line", marker=".", ax=axs[2])
    metrics1[["epoch", "val_loss2_epoch"]].dropna().plot(x="epoch", y="val_loss2_epoch", kind="line", marker=".", ax=axs[3])
    metrics1[["epoch", "val_f1_epoch"]].dropna().plot(x="epoch", y="val_f1_epoch", kind="line", marker=".", ax=axs[4])
    axs[0].set_xlabel("step")
    axs[0].set_ylabel("learning rate")
    axs[1].set_xlabel("epoch")
    axs[1].set_ylabel("Loss")
    axs[2].set_xlabel("epoch")
    axs[2].set_ylabel("Loss1")
    axs[3].set_xlabel("epoch")
    axs[3].set_ylabel("Loss2")
    axs[4].set_xlabel("epoch")
    axs[4].set_ylabel("F1")
    plt.tight_layout()
    plt.show()

### Test

In [None]:
from util.log_transform import log_transform_torch
    

class InferenceDataset(Dataset):
    def __init__(
            self,
            paths: List[Path],
            height: int,
            width: int,
            mean_x: torch.Tensor,
            std_x: torch.Tensor,
        ) -> None:
        
        self.paths = paths
        self.height = height
        self.width = width
        self.mean_x = mean_x
        self.std_x = std_x
    
    def __len__(self) -> int:
        return len(self.paths)
    
    def __getitem__(self, index: int) -> torch.Tensor:
        path = self.paths[index]
        images = np.load(path)
        x = images["x"] # (5, 1000, 70)
        x = torch.from_numpy(x)
        x = log_transform_torch(x)
        x = (x - self.mean_x) / self.std_x
        x = x.unsqueeze(dim=0) # (1, 5, 1000, 70)
        x = F.interpolate(x, size=(self.height, self.width), mode="nearest")
        x = x.squeeze(dim=0)
        x = x.float()
        return x

In [None]:
display(pd.crosstab(test_paths["family"], test_paths["label"]))

In [None]:
checkpoint_path = config.output_dir.joinpath(f"B/multitask_B_{config.seed}.ckpt")

model = LitUNetModel.load_from_checkpoint(checkpoint_path)
model.eval()

model.to("cuda")
print(model.device)

test_dataset = InferenceDataset(
    paths=list(test_paths["path"]),
    height=config.height,
    width=config.width,
    mean_x=mean_x,
    std_x=std_x,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=os.cpu_count()//8,
    pin_memory=False,
)

trainer = L.Trainer(
    default_root_dir=config.output_dir,
    enable_checkpointing=False,
    logger=False,
)
predictions = trainer.predict(model, test_dataloader)
test_reg_logits, test_clf_logits = zip(*predictions)
test_reg_logits = torch.cat(test_reg_logits)
test_clf_logits = torch.cat(test_clf_logits)
print(test_reg_logits.shape, test_reg_logits.min(), test_reg_logits.max())

random_index = np.random.choice(range(len(test_dataset)), size=5, replace=False)
print(random_index)

_, axs = plt.subplots(1, 5, figsize=(12, 4))
for e, i in enumerate(random_index):
    img = axs[e].imshow(test_reg_logits[i, 0], aspect="auto")
    plt.colorbar(img, ax=axs[e])
plt.tight_layout()
plt.show()

### Compute the confusion matrix and F1 Score

In [None]:
test_paths["pred_label"] = torch.argmax(test_clf_logits, dim=1)
display(test_paths)

cm = multiclass_confusion_matrix(
    torch.Tensor(test_paths["pred_label"].values),
    torch.Tensor(test_paths["label"].values),
    num_classes=5,
)
display(cm)

f1 = multiclass_f1_score(
    torch.Tensor(test_paths["pred_label"].values),
    torch.Tensor(test_paths["label"].values),
    num_classes=5,
    average="macro",
)
display(f1)

### Compute the MAE for each class

In [None]:
test_paths.reset_index(drop=True)

In [None]:
count = 0
mae_all = 0
for g_name, g in test_paths.reset_index(drop=True).groupby("family"):
    print(g_name)
    idx = g.index.to_list()
    g_reg_images = test_reg_logits[idx]
    g_true_images = [np.load(p)["y"] for p in g["path"]]
    g_true_images = np.stack(g_true_images, axis=0)
    g_true_images = torch.Tensor(g_true_images)
    # g_true_images = (g_true_images - statistics["All"]["mean_y"]) / statistics["All"]["std_y"]
    # g_true_images = g_true_images * statistics["All"]["std_y"] + statistics["All"]["mean_y"]
    
    _, axs = plt.subplots(1, 3, figsize=(12, 4))
    img0 = axs[0].imshow(g_reg_images[0, 0], aspect="auto")
    img1 = axs[1].imshow(g_true_images[0, 0], aspect="auto")
    img2 = axs[2].imshow(g_true_images[0, 0]-g_reg_images[0, 0], aspect="auto")
    plt.colorbar(img0, ax=axs[0])
    plt.colorbar(img1, ax=axs[1])
    plt.colorbar(img2, ax=axs[2])
    plt.tight_layout()
    plt.show()

    mae = F.l1_loss(
        g_reg_images,
        g_true_images,
        reduction="sum",
    )
    mae_all += mae.item()
    print(f"MAE: {mae.item() / len(g) / 70 / 70:.4f}")
    print("="*50)

mae_all = mae_all / len(test_paths) / 70 / 70
print(f"All MAE: {mae_all:.4f}")

### Export logs as HTML file

In [None]:
os.system(f"jupyter nbconvert --to html --output-dir {config.output_dir} 3_train_multitask.ipynb")

In [None]:
test_paths.drop(columns=["pred_label"], inplace=True)

In [None]:
checkpoint_path = config.output_dir.joinpath(f"B/multitask_B_{config.seed}.ckpt")

model = LitUNetModel.load_from_checkpoint(checkpoint_path)
model.eval()

model.to(config.device)
print(model.device)

datamodule = MyDataModule(
    train_paths=train_paths,
    valid_paths=valid_paths,
    test_paths=test_paths,
    seed=config.seed,
    batch_size=config.batch_size,
    height=config.height,
    width=config.width,
    mean_x=mean_x,
    std_x=std_x,
    mean_y=None,
    std_y=None,
)

trainer = L.Trainer(
    default_root_dir=config.output_dir,
    enable_checkpointing=False,
    logger=False,
)

trainer.test(model, datamodule)

In [None]:
image = np.load("../data/FlatFault_A/seis2_1_4_vel2_1_4_364.npz")
x = torch.from_numpy(image["x"]).float()
x = log_transform_torch(x)
print(x.shape, mean_x.shape, std_x.shape)

x = (x - mean_x) / std_x
x = x.unsqueeze(dim=0)
x = F.interpolate(x, (config.height, config.width), mode="nearest")
print(x.shape)
x = x.float()

y = image["y"].astype(np.float32)

x = x.to(config.device)

model = model.to(config.device)
model.eval()
reg, clf = model(x)

# reg = reg * statistics["All"]["std_y"] + statistics["All"]["mean_y"]
reg = reg.cpu().detach().numpy()
clf = clf.cpu().detach().numpy()
print(reg.shape, clf.shape)

In [None]:
_, axs = plt.subplots(1, 3, figsize=(12, 4))
img0 = axs[0].imshow(reg[0, 0], aspect="auto")
img1 = axs[1].imshow(y[0], aspect="auto")
img2 = axs[2].imshow(y[0]-reg[0, 0], aspect="auto")
plt.colorbar(img0, ax=axs[0])
plt.colorbar(img1, ax=axs[1])
plt.colorbar(img2, ax=axs[2])
plt.tight_layout()
plt.show()