In [None]:
import datetime
import os
from pathlib import Path
import pickle
import random
from typing import Any, Dict, List, Tuple, Union

import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pandas as pd
import pytorch_lightning as L
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import train_test_split
import timm
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Precision, Recall
from torchmetrics.functional.classification import multiclass_confusion_matrix, multiclass_f1_score

from configs.config import CFG
from model.image2image_unet import Image2ImageUNet
from util.my_dataset import MyDataModule, MyDataset


# Request synchronous CUDA kernel launches so that device-side errors surface at
# the call that caused them.
# NOTE(review): this is normally a debugging aid and can slow GPU work — confirm
# it is meant to stay enabled for full training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Select PyTorch's "high" float32 matmul precision mode (trades a little
# precision for speed on hardware that supports it).
torch.set_float32_matmul_precision("high")

In [None]:
# Show every architecture name timm knows about, to help choose a value for
# `model_name` in the configuration cell.
from timm import list_models

list_models()

In [None]:
class LitUNetModel(L.LightningModule):
    """LightningModule wrapping ``Image2ImageUNet`` for image-to-image regression.

    Training and validation minimise an L1 loss between the network output and
    the target exactly as delivered by the dataloader. ``test_step`` and
    ``predict_step`` additionally map the output back to the original target
    scale with ``mean_y`` / ``std_y``.

    Args:
        model_name: Backbone name forwarded to ``Image2ImageUNet``.
        learning_rate: AdamW learning rate; also the peak of the one-cycle schedule.
        mean_y: Offset added when undoing the target standardisation.
        std_y: Scale multiplied in when undoing the target standardisation.
    """

    # Bounds applied to de-normalised predictions in ``test_step``.
    # NOTE(review): presumably the physically valid target range — confirm.
    # ``predict_step`` does NOT clip; confirm whether that asymmetry is intended.
    CLIP_MIN = 1500
    CLIP_MAX = 4500

    # Number of input channels drawn in the ``test_step`` preview figure.
    N_PREVIEW_CHANNELS = 5

    def __init__(
            self,
            model_name: str,
            learning_rate: float,
            mean_y: float,
            std_y: float,
        ) -> None:

        super().__init__()
        self.model = Image2ImageUNet(model_name)
        self.criterion = nn.L1Loss()
        self.learning_rate = learning_rate
        self.mean_y = mean_y
        self.std_y = std_y
        self.save_hyperparameters(ignore=['criterion'])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped network on ``x`` and return its raw output tensor."""
        return self.model(x)

    def _denormalize_prediction(self, logit: torch.Tensor) -> torch.Tensor:
        """Map a network output back to the original target scale in float32.

        The upcast happens *before* scaling: when a step runs under bf16
        autocast the network output may be bfloat16, and multiplying up to
        values in the thousands in that dtype would round them coarsely and
        distort any metric computed afterwards.
        """
        return logit.float() * self.std_y + self.mean_y

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Compute and log the L1 training loss for one batch.

        The third element of ``batch`` is ignored.
        """
        x, y, _ = batch
        loss = self.criterion(self.forward(x), y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Compute and log the L1 validation loss for one batch.

        The third element of ``batch`` is ignored.
        """
        x, y, _ = batch
        loss = self.criterion(self.forward(x), y)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(
            self,
            batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            batch_idx: int
        ) -> Dict[str, float]:
        """Log the L1 error on the original target scale and preview the first batch.

        Predictions are de-normalised and clipped to ``[CLIP_MIN, CLIP_MAX]``;
        targets are de-normalised with the same statistics. For
        ``batch_idx == 0`` a figure is shown with the first
        ``N_PREVIEW_CHANNELS`` input channels of sample 0, its prediction, its
        target, and the residual (target - prediction).

        Returns:
            ``{"loss": loss}`` with the de-normalised L1 loss.
        """
        x, y, _ = batch
        logit = self._denormalize_prediction(self.forward(x))
        logit = torch.clip(logit, min=self.CLIP_MIN, max=self.CLIP_MAX)
        y = y * self.std_y + self.mean_y
        loss = F.l1_loss(logit, y)
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        if batch_idx == 0:
            pred_img = logit.float()[0, 0].cpu()
            true_img = y.float()[0, 0].cpu()

            fig, axs = plt.subplots(1, 8, figsize=(32, 8))
            for i in range(self.N_PREVIEW_CHANNELS):
                im0 = axs[i].imshow(x.float()[0, i].cpu(), aspect="auto")
                plt.colorbar(im0, ax=axs[i])
            im1 = axs[5].imshow(pred_img, aspect="auto")
            im2 = axs[6].imshow(true_img, aspect="auto")
            im3 = axs[7].imshow(true_img - pred_img, aspect="auto")
            plt.colorbar(im1, ax=axs[5])
            plt.colorbar(im2, ax=axs[6])
            plt.colorbar(im3, ax=axs[7])
            plt.tight_layout()
            plt.show()
            # Release the figure explicitly so repeated test runs do not pile up
            # open figures on backends that keep them alive after show().
            plt.close(fig)

        return {"loss": loss}

    def predict_step(self, x: torch.Tensor) -> torch.Tensor:
        """Return the de-normalised (unclipped) float32 prediction for input batch ``x``."""
        return self._denormalize_prediction(self.forward(x))

    def configure_optimizers(self) -> Tuple[Dict[str, Any], ...]:
        """Build AdamW plus a per-step ``OneCycleLR`` schedule.

        The schedule length is taken from ``self.trainer.estimated_stepping_batches``
        and ``self.learning_rate`` is read at call time, so changing the
        attribute before fitting with a new trainer changes the peak rate.

        Returns:
            A one-element tuple holding the optimizer / lr-scheduler dictionary.
        """
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            div_factor=25,
            final_div_factor=1e+04,
        )
        # NOTE(review): "monitor" / "strict" look relevant only to metric-driven
        # schedulers; they are kept unchanged here — confirm they can be dropped.
        scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
            "monitor": "val_loss",
            "strict": False,
        }
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": scheduler_config,
            },
        )

### Define Configurations

In [None]:
# Each run writes into its own timestamped directory:
# ../output/image2image_<YYYY-MM-DD>-<HH>-<MM>
now_time = datetime.datetime.now()
run_stamp = f"{now_time.date()}-{now_time.hour:02}-{now_time.minute:02}"
output_dir = Path("../output") / f"image2image_{run_stamp}"

# NOTE(review): `debag` is presumably a misspelling of "debug"; the name is kept
# because the path-loading cell reads it back as `config.debag`.
config = CFG(
    # where to write / what to train
    output_dir=output_dir,
    model_name="convnextv2_tiny",
    # data handling
    debag=False,
    train_ratio=0.8,
    seed=42,
    # optimisation schedule
    batch_size=64,
    epochs=10,
    patience=5,
)
config.seed_everything()

### Load Data Paths

In [None]:
# Root directory of the dataset; print the stem of every entry directly inside it.
dir_path = Path("../data")
entry_names = []
for entry in dir_path.glob("*"):
    entry_names.append(entry.stem)
print(entry_names)

In [None]:
# Sub-family directory name -> integer label; the _A/_B variants of a family
# share one label.
families = {
    "CurveFault_A": 0,
    "CurveFault_B": 0,
    "CurveVel_A": 1,
    "CurveVel_B": 1,
    "FlatFault_A": 2,
    "FlatFault_B": 2,
    "FlatVel_A": 3,
    "FlatVel_B": 3,
    "Style_A": 4,
    "Style_B": 4,
}

# One row per .npz file. The files are sorted because glob() yields them in a
# filesystem-dependent order, and the seeded train/valid/test split further
# down is only reproducible when the row order is.
path_records = [
    (family_name, family_label, npz_path)
    for family_name, family_label in families.items()
    for npz_path in sorted(dir_path.joinpath(family_name).glob("*.npz"))
]
paths = pd.DataFrame(path_records, columns=["family", "label", "path"])
if config.debag:
    # Debug mode works on a subsample. The size is capped so small datasets do
    # not raise, and the draw is seeded so re-running the cell gives the same rows.
    paths = paths.sample(
        n=min(10_000, len(paths)),
        replace=False,
        random_state=config.seed,
    )
display(paths)

In [None]:
# def get_parent(x: Path):
#     parent = "_".join(x.stem.split("_")[:-1])
#     return parent


# paths["path"] = "../data/" + paths["family"] + "/" + paths["path"].apply(get_parent)
# paths = paths[["family", "label", "path"]].drop_duplicates()
# paths = paths.sort_values(["family", "path"]).reset_index(drop=True)
# if config.debag:
#     paths = paths.sample(n=100, replace=False).reset_index(drop=True)
# display(paths)

### Split paths into training, validation, and test sets

In [None]:
# Two-step stratified split: first hold out the test set, then carve the
# validation set out of what remains. The same `train_ratio` is applied at both
# steps (with 0.8 that is roughly 64% train / 16% valid / 20% test).
split_options = {
    "train_size": config.train_ratio,
    "shuffle": True,
    "random_state": config.seed,
}
train_valid_paths, test_paths = train_test_split(
    paths,
    stratify=paths["family"],
    **split_options,
)
train_paths, valid_paths = train_test_split(
    train_valid_paths,
    stratify=train_valid_paths["family"],
    **split_options,
)
for split_frame in (train_paths, valid_paths, test_paths):
    display(split_frame)

In [None]:
# Normalisation statistics, indexed below as statistics[<family>][<stat name>].
# NOTE: unpickling can execute arbitrary code — only load a statistics file you
# produced yourself.
statistics = pickle.loads(Path("../output/statistics.pkl").read_bytes())

In [None]:
# Base families to process; each one expands to its ("<name>_A", "<name>_B")
# pair. Uncomment an entry to include that family.
active_families = (
    "CurveFault",
    # "CurveVel",
    # "FlatFault",
    # "FlatVel",
    # "Style",
)
family_pairs = {
    base_name: (f"{base_name}_A", f"{base_name}_B")
    for base_name in active_families
}

In [None]:
%%time


for family, (family_A, family_B) in family_pairs.items():
    print(family_A, family_B)
    print(statistics[family_A]["mean_y"], statistics[family_A]["std_y"])

    model = LitUNetModel(
        model_name=config.model_name,
        learning_rate=1e-03,
        mean_y=statistics[family_A]["mean_y"],
        std_y=statistics[family_A]["std_y"],
    )
    
    """
    train with dataset A
    """
    train_paths_A = train_paths.query("family == @family_A")
    valid_paths_A = valid_paths.query("family == @family_A")
    test_paths_A = test_paths.query("family == @family_A")
    display(pd.crosstab(train_paths_A["family"], train_paths_A["label"]))
    display(pd.crosstab(valid_paths_A["family"], valid_paths_A["label"]))
    display(pd.crosstab(test_paths_A["family"], test_paths_A["label"]))

    datamodule_A = MyDataModule(
        train_paths=train_paths_A,
        valid_paths=valid_paths_A,
        test_paths=test_paths_A,
        seed=config.seed,
        batch_size=config.batch_size,
        mean_x=statistics[family_A]["mean_log_x"],
        std_x=statistics[family_A]["std_log_x"],
        mean_y=statistics[family_A]["mean_y"],
        std_y=statistics[family_A]["mean_y"],
    )
    del train_paths_A, valid_paths_A, test_paths_A

    callbacks=[
        EarlyStopping(monitor="val_loss", patience=config.patience, mode='min'),
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(),
        # StochasticWeightAveraging(
        #     swa_lrs=1e-5,
        #     swa_epoch_start=int(0.8*config.epochs),
        #     annealing_epochs=int(0.2*config.epochs),
        # ),
    ]

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
        accelerator="cuda",
        max_epochs=config.epochs,
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=CSVLogger(config.output_dir, name=family),
        log_every_n_steps=150,
        val_check_interval=None,
        check_val_every_n_epoch=1,
        accumulate_grad_batches=1,
        gradient_clip_val=0,
    )

    trainer.fit(model, datamodule=datamodule_A)
    trainer.test(model, datamodule=datamodule_A)
    del datamodule_A
    del callbacks
    del trainer

    """
    train with dataset B
    """
    model.learning_rate = 1e-03

    train_paths_B = train_paths.query("family == @family_B")
    valid_paths_B = valid_paths.query("family == @family_B")
    test_paths_B = test_paths.query("family == @family_B")
    display(pd.crosstab(train_paths_B["family"], train_paths_B["label"]))
    display(pd.crosstab(valid_paths_B["family"], valid_paths_B["label"]))
    display(pd.crosstab(test_paths_B["family"], test_paths_B["label"]))

    datamodule_B = MyDataModule(
        train_paths=train_paths_B,
        valid_paths=valid_paths_B,
        test_paths=test_paths_B,
        seed=config.seed,
        batch_size=config.batch_size,
        mean_x=statistics[family_A]["mean_log_x"],
        std_x=statistics[family_A]["std_log_x"],
        mean_y=statistics[family_A]["mean_y"],
        std_y=statistics[family_A]["std_y"],
    )
    del train_paths_B, valid_paths_B, test_paths_B

    callbacks=[
        EarlyStopping(monitor="val_loss", patience=config.patience, mode='min'),
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(),
        # StochasticWeightAveraging(
        #     swa_lrs=1e-6,
        #     swa_epoch_start=int(0.8*config.epochs),
        #     annealing_epochs=int(0.2*config.epochs),
        # ),
    ]

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
        accelerator="cuda",
        max_epochs=config.epochs,
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=CSVLogger(config.output_dir, name=family),
        log_every_n_steps=150,
        val_check_interval=None,
        check_val_every_n_epoch=1,
        accumulate_grad_batches=1,
        gradient_clip_val=0,
    )

    trainer.fit(model, datamodule=datamodule_B)
    trainer.test(model, datamodule=datamodule_B)

    checkpoint_path = config.output_dir.joinpath(f"{family}/image2image_{family}_{config.seed}.ckpt")
    trainer.save_checkpoint(checkpoint_path)
    del datamodule_B
    del callbacks
    del trainer

In [None]:
%%time


for family, (family_A, family_B) in family_pairs.items():
    print(family_A, family_B)
    print(statistics[family_A]["mean_y"], statistics[family_A]["std_y"])

    model = LitUNetModel(
        model_name=config.model_name,
        learning_rate=1e-03,
        mean_y=statistics[family_A]["mean_y"],
        std_y=statistics[family_A]["std_y"],
    )
    
    """
    train with dataset A
    """
    train_paths_A = train_paths.query("family == @family_A")
    valid_paths_A = valid_paths.query("family == @family_A")
    test_paths_A = test_paths.query("family == @family_A")
    display(pd.crosstab(train_paths_A["family"], train_paths_A["label"]))
    display(pd.crosstab(valid_paths_A["family"], valid_paths_A["label"]))
    display(pd.crosstab(test_paths_A["family"], test_paths_A["label"]))

    datamodule_A = MyDataModule(
        train_paths=train_paths_A,
        valid_paths=valid_paths_A,
        test_paths=test_paths_A,
        seed=config.seed,
        batch_size=config.batch_size,
        mean_x=statistics[family_A]["mean_log_x"],
        std_x=statistics[family_A]["std_log_x"],
        mean_y=statistics[family_A]["mean_y"],
        std_y=statistics[family_A]["mean_y"],
    )
    del train_paths_A, valid_paths_A, test_paths_A

    callbacks=[
        EarlyStopping(monitor="val_loss", patience=config.patience, mode='min'),
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(),
        # StochasticWeightAveraging(
        #     swa_lrs=1e-5,
        #     swa_epoch_start=int(0.8*config.epochs),
        #     annealing_epochs=int(0.2*config.epochs),
        # ),
    ]

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
        accelerator="cuda",
        max_epochs=config.epochs,
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=CSVLogger(config.output_dir, name=family),
        log_every_n_steps=150,
        val_check_interval=None,
        check_val_every_n_epoch=1,
        accumulate_grad_batches=1,
        gradient_clip_val=0,
    )

    trainer.fit(model, datamodule=datamodule_A)
    trainer.test(model, datamodule=datamodule_A)
    del datamodule_A
    del callbacks
    del trainer

    """
    train with dataset B
    """
    model.learning_rate = 1e-04

    train_paths_B = train_paths.query("family == @family_B")
    valid_paths_B = valid_paths.query("family == @family_B")
    test_paths_B = test_paths.query("family == @family_B")
    display(pd.crosstab(train_paths_B["family"], train_paths_B["label"]))
    display(pd.crosstab(valid_paths_B["family"], valid_paths_B["label"]))
    display(pd.crosstab(test_paths_B["family"], test_paths_B["label"]))

    datamodule_B = MyDataModule(
        train_paths=train_paths_B,
        valid_paths=valid_paths_B,
        test_paths=test_paths_B,
        seed=config.seed,
        batch_size=config.batch_size,
        mean_x=statistics[family_A]["mean_log_x"],
        std_x=statistics[family_A]["std_log_x"],
        mean_y=statistics[family_A]["mean_y"],
        std_y=statistics[family_A]["std_y"],
    )
    del train_paths_B, valid_paths_B, test_paths_B

    callbacks=[
        EarlyStopping(monitor="val_loss", patience=config.patience, mode='min'),
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(),
        # StochasticWeightAveraging(
        #     swa_lrs=1e-6,
        #     swa_epoch_start=int(0.8*config.epochs),
        #     annealing_epochs=int(0.2*config.epochs),
        # ),
    ]

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
        accelerator="cuda",
        max_epochs=config.epochs,
        precision="bf16-mixed",
        callbacks=callbacks,
        logger=CSVLogger(config.output_dir, name=family),
        log_every_n_steps=150,
        val_check_interval=None,
        check_val_every_n_epoch=1,
        accumulate_grad_batches=1,
        gradient_clip_val=0,
    )

    trainer.fit(model, datamodule=datamodule_B)
    trainer.test(model, datamodule=datamodule_B)

    checkpoint_path = config.output_dir.joinpath(f"{family}/image2image_{family}_{config.seed}.ckpt")
    trainer.save_checkpoint(checkpoint_path)
    del datamodule_B
    del callbacks
    del trainer

In [None]:
# Plot the learning-rate schedule and per-epoch validation loss for the first
# two logged runs of every family.
# NOTE(review): version_0 / version_1 are presumably stage A and stage B of the
# first training cell that was executed; later runs land in higher version
# numbers and are not shown here — confirm these are the runs you want.
for family in family_pairs:
    for version in (0, 1):
        metrics_file = config.output_dir.joinpath(f"{family}/version_{version}/metrics.csv")
        metrics = pd.read_csv(metrics_file)
        metrics = metrics.sort_values(["step", "epoch"]).reset_index(drop=True)
        display(metrics.head())

        # Each CSV row carries only the metrics logged at that moment, so drop
        # the rows where the column of interest is empty.
        lr_curve = metrics[["step", "lr-AdamW"]].dropna()
        val_curve = metrics[["epoch", "val_loss_epoch"]].dropna()
        display(val_curve)

        fig, axs = plt.subplots(2, 1)
        lr_curve.plot(x="step", y="lr-AdamW", kind="line", marker=".", ax=axs[0])
        val_curve.plot(x="epoch", y="val_loss_epoch", kind="line", marker=".", ax=axs[1])
        axs[0].set_xlabel("step")
        axs[0].set_ylabel("learning rate")
        axs[1].set_xlabel("epoch")
        axs[1].set_ylabel("Loss")
        fig.suptitle(f"{family} / version_{version}")
        plt.tight_layout()
        plt.show()
        plt.close(fig)

### Test

In [None]:
from util.log_transform import log_transform_torch


class TestDataset(Dataset):
    """Inference-only dataset yielding normalised, resized inputs from ``.npz`` files.

    Each item is produced by loading the ``"x"`` array of one file, applying
    ``log_transform_torch``, standardising with ``mean_x`` / ``std_x`` and
    resizing the last two axes to 288 x 288 with nearest-neighbour
    interpolation. No target is returned.

    Args:
        paths: Files to read; item ``i`` comes from ``paths[i]``.
        mean_x: Values subtracted after the log transform.
            Assumes a 1-D sequence with one entry per input channel — TODO confirm.
        std_x: Values divided by after the log transform (same layout as ``mean_x``).
    """

    def __init__(
            self,
            paths: List[Path],
            mean_x: Union[np.ndarray, Tuple[float]],
            std_x: Union[np.ndarray, Tuple[float]],
        ) -> None:

        self.paths = paths
        # Two trailing singleton axes are appended so the statistics broadcast
        # over the last two axes of every sample.
        self.mean_x = torch.tensor(mean_x).unsqueeze(dim=1).unsqueeze(dim=2)
        self.std_x = torch.tensor(std_x).unsqueeze(dim=1).unsqueeze(dim=2)

    def __len__(self) -> int:
        """Return the number of files in the dataset."""
        return len(self.paths)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Load, normalise and resize the input stored at ``paths[index]``.

        Returns:
            A float32 tensor whose last two axes are 288 x 288.
        """
        path = self.paths[index]
        # An .npz archive keeps a file handle open until it is closed; the
        # context manager releases it as soon as the array has been read
        # instead of leaving one open handle per fetched sample.
        with np.load(path) as images:
            x = images["x"] # (5, 1000, 70) according to the original author's note
        x = torch.from_numpy(x)
        x = log_transform_torch(x)
        x = (x - self.mean_x) / self.std_x
        # F.interpolate needs a leading batch axis: add it, resize, remove it.
        x = x.unsqueeze(dim=0)
        x = F.interpolate(x, size=(288, 288), mode="nearest")
        x = x.squeeze(dim=0)
        x = x.float()
        return x

In [None]:
# Sample counts of the held-out test split, family (rows) by label (columns).
test_family_counts = pd.crosstab(test_paths["family"], test_paths["label"])
display(test_family_counts)

In [None]:
# For every family: reload the saved checkpoint, predict on the A+B test files
# and preview a few randomly chosen predictions.
for family, (family_A, family_B) in family_pairs.items():
    print(family)
    checkpoint_path = config.output_dir.joinpath(f"{family}/image2image_{family}_{config.seed}.ckpt")

    model = LitUNetModel.load_from_checkpoint(checkpoint_path)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(model.device)

    # Inputs are normalised with family A's statistics for both sub-families,
    # matching what the training cells pass to their datamodules.
    test_dataset = TestDataset(
        paths=list(test_paths.query("family == @family_A or family == @family_B")["path"]),
        mean_x=statistics[family_A]["mean_log_x"],
        std_x=statistics[family_A]["std_log_x"],
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        # os.cpu_count() may return None; fall back to 1 so the division cannot raise.
        num_workers=(os.cpu_count() or 1) // 8,
        pin_memory=False,
    )

    trainer = L.Trainer(
        default_root_dir=config.output_dir,
        enable_checkpointing=False,
    )
    # predict() returns one tensor per batch; concatenate along the sample axis.
    test_logits = trainer.predict(model, test_dataloader)
    test_logits = torch.cat(test_logits)
    print(test_logits.shape, test_logits.min(), test_logits.max())

    # Preview up to five predictions; the cap keeps tiny test sets from raising.
    n_preview = min(5, len(test_dataset))
    random_index = np.random.choice(range(len(test_dataset)), size=n_preview, replace=False)
    print(random_index)

    # squeeze=False keeps `axs` two-dimensional even when only one panel is drawn.
    fig, axs = plt.subplots(1, n_preview, figsize=(12, 4), squeeze=False)
    for e, i in enumerate(random_index):
        img = axs[0, e].imshow(test_logits[i, 0], aspect="auto")
        plt.colorbar(img, ax=axs[0, e])
    plt.tight_layout()
    plt.show()
    plt.close(fig)