In [None]:
import datetime
import os
from pathlib import Path
import pickle
import random
from typing import Any, List, Tuple

import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pandas as pd
import pytorch_lightning as L
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import train_test_split
import timm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Precision, Recall
from torchmetrics.functional.classification import multiclass_confusion_matrix, multiclass_f1_score

from configs.config import CFG
from util.my_dataset import MyDataModule, MyDataset


# Show up to 100 rows when a DataFrame is displayed.
pd.set_option("display.max_rows", 100)
# Trade some float32 matmul precision for speed (see torch docs for what "high" enables).
torch.set_float32_matmul_precision("high")

In [None]:
class LitClassifierModel(L.LightningModule):
    """Lightning wrapper around a timm EfficientNet-B2 classifier.

    The backbone is created with ``in_chans=5`` and its classification head is
    replaced by a fresh ``nn.Linear`` producing ``num_classes`` logits.
    Training uses cross-entropy loss, AdamW and a per-step OneCycle schedule.

    Batches for train/val/test are unpacked as ``(x, _, label)``; the middle
    element is ignored here.

    Args:
        num_classes: Number of output classes.
        learning_rate: Peak learning rate for AdamW / OneCycleLR. It is read in
            ``configure_optimizers`` each time ``fit`` starts, so it may be
            changed between ``fit`` calls.
    """

    def __init__(
            self,
            num_classes: int,
            learning_rate: float,
        ) -> None:
        # Local import so the notebook's top import cell does not need to change.
        from torchmetrics.classification import MulticlassF1Score

        super().__init__()
        self.num_classes = num_classes
        # num_classes=0 removes timm's default head; a new linear head is attached below.
        self.model = timm.create_model("efficientnet_b2", pretrained=True, in_chans=5, num_classes=0)
        self.model.classifier = nn.Linear(self.model.num_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        # Stateful metrics accumulate over the whole epoch, so the logged
        # "*_f1_epoch" value is the macro F1 of the entire split rather than the
        # mean of per-batch macro F1 scores (which is biased, e.g. for a small
        # last batch or batches missing some classes). Metric states are
        # non-persistent, so the checkpoint state_dict keys are unchanged.
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
        self.test_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
        self.save_hyperparameters(ignore=['criterion'])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for the input batch ``x``."""
        return self.model(x)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Compute and log cross-entropy loss for one training batch."""
        x, _, label = batch
        logit = self.forward(x)
        loss = self.criterion(logit, label)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Log validation loss and macro F1 (per step and exact per epoch)."""
        x, _, label = batch
        logit = self.forward(x)
        loss = self.criterion(logit, label)
        # Updates the epoch accumulator; the returned batch value feeds the step log.
        self.val_f1(logit, label)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_f1", self.val_f1, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_test_epoch_start(self):
        """Reset the buffers used to build the test confusion matrix."""
        self.targets = []
        self.preds = []

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Log test loss/F1 and collect predictions for the confusion matrix."""
        x, _, label = batch
        logit = self.forward(x)
        self.targets.append(label.cpu())
        self.preds.append(logit.argmax(dim=1).cpu())
        loss = self.criterion(logit, label)
        # Batch F1 is returned; the epoch value logged below covers the whole test split.
        f1 = self.test_f1(logit, label)
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_f1", self.test_f1, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss, "f1": f1}

    def on_test_epoch_end(self):
        """Print the confusion matrix over all test batches and free the buffers."""
        targets = torch.cat(self.targets)
        preds = torch.cat(self.preds)
        cm = multiclass_confusion_matrix(preds, targets, num_classes=self.num_classes)
        print(cm)
        del self.targets
        del self.preds

    def predict_step(self, x: torch.Tensor):
        """Return logits; prediction batches are bare input tensors."""
        return self.forward(x)

    def configure_optimizers(self):
        """AdamW with a OneCycle LR schedule stepped once per optimizer step."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            div_factor=25,
            final_div_factor=1e+04,
        )
        scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
            # Only consulted by plateau-style schedulers; kept for easy swapping.
            "monitor": "val_loss",
            "strict": False,
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler_config,
        }

### Define Configurations

In [None]:
# Each run writes its logs and checkpoints to its own timestamped folder.
now_time = datetime.datetime.now()
output_dir = Path(f"../output/classifier_{now_time.date()}-{now_time.hour:02}-{now_time.minute:02}")

# Run configuration. `debag` (sic) is the CFG field name; it is read later as
# `config.debag` to cap the number of files loaded per family.
config = CFG(
    output_dir=output_dir,
    debag=False,
    train_ratio=0.8,
    seed=42,
    batch_size=256,
    epochs=10,
    patience=5,
)
# NOTE(review): presumably seeds the Python/NumPy/torch RNGs — confirm in configs.config.
config.seed_everything()

### Load Data Paths

In [None]:
# Root folder holding one sub-folder per dataset family.
dir_path = Path("../data")
entry_stems = [entry.stem for entry in dir_path.glob("*")]
print(entry_stems)

In [None]:
# Dataset folder name -> class label. The "_A" and "_B" variants of a family
# share the same label.
families = {
    "CurveFault_A": 0,
    "CurveFault_B": 0,
    "CurveVel_A": 1,
    "CurveVel_B": 1,
    "FlatFault_A": 2,
    "FlatFault_B": 2,
    "FlatVel_A": 3,
    "FlatVel_B": 3,
    "Style_A": 4,
    "Style_B": 4,
}

# Number of files kept per family when `config.debag` is set.
DEBUG_FILES_PER_FAMILY = 1000

records = []
for family, label in families.items():
    # glob() order is filesystem-dependent; sort so that the debug subset and
    # the seeded train/valid/test split below are reproducible across machines.
    family_files = natsorted(dir_path.joinpath(family).glob("*.npz"), key=str)
    if config.debag:
        family_files = family_files[:DEBUG_FILES_PER_FAMILY]
    records.extend((family, label, p) for p in family_files)
paths = pd.DataFrame(records, columns=["family", "label", "path"])

### Split Paths into Training, Validation, and Test Sets

In [None]:
def _split_stratified(frame: pd.DataFrame) -> List[pd.DataFrame]:
    """Seeded, family-stratified split of `frame` into [train_ratio part, remainder]."""
    return train_test_split(
        frame,
        train_size=config.train_ratio,
        shuffle=True,
        random_state=config.seed,
        stratify=frame["family"],
    )

# First hold out the test set, then carve the validation set out of the rest.
train_valid_paths, test_paths = _split_stratified(paths)
train_paths, valid_paths = _split_stratified(train_valid_paths)
for split_frame in (train_paths, valid_paths, test_paths):
    display(split_frame)

### Define the Model

In [None]:
# Derive the class count from the label map so the head always covers every
# label in `families` (labels are used directly as class indices).
model = LitClassifierModel(
    num_classes=max(families.values()) + 1,
    learning_rate=1e-03,
)

### Start Training with Dataset A

In [None]:
# Normalization statistics for dataset A; the "All" entry supplies the
# mean_x / std_x used by the data modules below.
# pickle can execute arbitrary code on load — only open files you produced.
statistics_path_A = Path("../output/statistics_A.pkl")
with statistics_path_A.open("rb") as f:
    statistics_A = pickle.load(f)
stats_all_A = statistics_A["All"]
print(stats_all_A["mean_x"])
print(stats_all_A["std_x"])

In [None]:
# The "_A" variant of every family.
families_A = [f"{name}_A" for name in ("CurveFault", "CurveVel", "FlatFault", "FlatVel", "Style")]

in_A_train = train_paths["family"].isin(families_A)
in_A_valid = valid_paths["family"].isin(families_A)
in_A_test = test_paths["family"].isin(families_A)
train_paths_A = train_paths[in_A_train]
valid_paths_A = valid_paths[in_A_valid]
test_paths_A = test_paths[in_A_test]

splits_A = (train_paths_A, valid_paths_A, test_paths_A)
for split_frame in splits_A:
    display(split_frame)
# Family-vs-label tables confirm each split's label mapping and balance.
for split_frame in splits_A:
    display(pd.crosstab(split_frame["family"], split_frame["label"]))

In [None]:
%%time


datamodule_A = MyDataModule(
    train_paths=train_paths_A,
    valid_paths=valid_paths_A,
    test_paths=test_paths_A,
    seed=config.seed,
    batch_size=config.batch_size,
    mean_x=statistics_A["All"]["mean_x"],
    std_x=statistics_A["All"]["std_x"],
    mean_y=None,
    std_y=None,
)

callbacks=[
    EarlyStopping(monitor="val_f1", patience=config.patience, mode='max'),
    LearningRateMonitor(logging_interval="step"),
    TQDMProgressBar(),
    StochasticWeightAveraging(
        swa_lrs=1e-5,
        swa_epoch_start=int(0.8*config.epochs),
        annealing_epochs=int(0.2*config.epochs),
    ),
]

trainer = L.Trainer(
    default_root_dir=config.output_dir,
    enable_checkpointing=False,
    accelerator="cuda",
    max_epochs=config.epochs,
    precision="bf16-mixed",
    callbacks=callbacks,
    logger=CSVLogger(config.output_dir, name="classifier_A"),
    log_every_n_steps=1,
    val_check_interval=None,
    check_val_every_n_epoch=1,
)

trainer.fit(model, datamodule=datamodule_A)
trainer.test(model, datamodule=datamodule_A)

### Continue Training with Dataset B

In [None]:
# The "_B" variant of every family.
families_B = [f"{name}_B" for name in ("CurveFault", "CurveVel", "FlatFault", "FlatVel", "Style")]

in_B_train = train_paths["family"].isin(families_B)
train_paths_B = train_paths[in_B_train]
display(train_paths_B)
# Stage B trains on "_B" only but validates/tests on the full (A + B) splits.
for split_frame in (train_paths_B, valid_paths, test_paths):
    display(pd.crosstab(split_frame["family"], split_frame["label"]))

In [None]:
%%time


model.learning_rate = 1e-04

datamodule_B = MyDataModule(
    train_paths=train_paths_B,
    valid_paths=valid_paths,
    test_paths=test_paths,
    seed=config.seed,
    batch_size=config.batch_size,
    mean_x=statistics_A["All"]["mean_x"],
    std_x=statistics_A["All"]["std_x"],
    mean_y=None,
    std_y=None,
)

callbacks=[
    EarlyStopping(monitor="val_f1", patience=config.patience, mode='max'),
    LearningRateMonitor(logging_interval="step"),
    TQDMProgressBar(),
    StochasticWeightAveraging(
        swa_lrs=1e-6,
        swa_epoch_start=int(0.8*config.epochs),
        annealing_epochs=int(0.2*config.epochs),
    ),
]

trainer = L.Trainer(
    default_root_dir=config.output_dir,
    enable_checkpointing=False,
    accelerator="cuda",
    max_epochs=config.epochs,
    precision="bf16-mixed",
    callbacks=callbacks,
    logger=CSVLogger(config.output_dir, name="classifier_B"),
    log_every_n_steps=1,
    val_check_interval=None,
    check_val_every_n_epoch=1,
)

trainer.fit(model, datamodule=datamodule_B)
trainer.test(model, datamodule=datamodule_B)

In [None]:
# CSVLogger stores each run as classifier_A/version_<n>. Re-running the
# training cell in the same output_dir adds a new version, so load the most
# recent one instead of hard-coding version_0 (natsorted puts version_10
# after version_9).
run_dirs_A = natsorted(config.output_dir.joinpath("classifier_A").glob("version_*"), key=str)
assert run_dirs_A, f"no CSVLogger runs found under {config.output_dir / 'classifier_A'}"
metrics0 = pd.read_csv(run_dirs_A[-1].joinpath("metrics.csv"))
metrics0 = metrics0.sort_values(["step", "epoch"]).reset_index(drop=True)
display(metrics0.head())
display(metrics0[["epoch", "val_f1_epoch"]].dropna())

# (x column, y column, y-axis label) for each panel.
panels = [
    ("step", "lr-AdamW", "learning rate"),
    ("epoch", "val_f1_epoch", "F1"),
    ("epoch", "val_loss_epoch", "Loss"),
]
fig, axs = plt.subplots(len(panels), 1, figsize=(8, 9))
fig.suptitle("classifier_A")
for ax, (x_col, y_col, y_label) in zip(axs, panels):
    metrics0[[x_col, y_col]].dropna().plot(x=x_col, y=y_col, kind="line", marker=".", ax=ax)
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_label)
plt.tight_layout()
plt.show()

In [None]:
# CSVLogger stores each run as classifier_B/version_<n>. Re-running the
# training cell in the same output_dir adds a new version, so load the most
# recent one instead of hard-coding version_0 (natsorted puts version_10
# after version_9).
run_dirs_B = natsorted(config.output_dir.joinpath("classifier_B").glob("version_*"), key=str)
assert run_dirs_B, f"no CSVLogger runs found under {config.output_dir / 'classifier_B'}"
metrics1 = pd.read_csv(run_dirs_B[-1].joinpath("metrics.csv"))
metrics1 = metrics1.sort_values(["step", "epoch"]).reset_index(drop=True)
display(metrics1.head())
display(metrics1[["epoch", "val_f1_epoch"]].dropna())

# (x column, y column, y-axis label) for each panel.
panels = [
    ("step", "lr-AdamW", "learning rate"),
    ("epoch", "val_f1_epoch", "F1"),
    ("epoch", "val_loss_epoch", "Loss"),
]
fig, axs = plt.subplots(len(panels), 1, figsize=(8, 9))
fig.suptitle("classifier_B")
for ax, (x_col, y_col, y_label) in zip(axs, panels):
    metrics1[[x_col, y_col]].dropna().plot(x=x_col, y=y_col, kind="line", marker=".", ax=ax)
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_label)
plt.tight_layout()
plt.show()

### Save the Model

In [None]:
# Persist the current model weights through the most recently created trainer.
checkpoint_path = config.output_dir / f"classifier_{config.seed}.ckpt"
trainer.save_checkpoint(checkpoint_path)

### Test: Predict on the Held-out Test Split

In [None]:
class TestDataset(Dataset):
    """Inference-only dataset that yields transformed inputs from ``.npz`` files.

    Each file must contain an array under key ``"x"``. The array is moved from
    channel-first to channel-last for albumentations, normalized and resized to
    512x72, then returned channel-first as a tensor. No label is returned.

    NOTE(review): ``A.Normalize`` defaults to ``max_pixel_value=255.0``, i.e. it
    computes ``(x - mean*255) / (std*255)``. Keep this identical to the training
    transform in ``MyDataset`` — confirm that is what training uses.

    Args:
        paths: Paths to the ``.npz`` files, in output order.
        mean_x: Per-channel means passed to ``A.Normalize``.
        std_x: Per-channel standard deviations passed to ``A.Normalize``.
    """

    def __init__(
            self,
            paths: List[Path],
            mean_x: Tuple[float, ...],
            std_x: Tuple[float, ...],
        ) -> None:
        self.paths = paths
        self.transform_x = A.Compose([
            A.Normalize(mean=mean_x, std=std_x),
            A.Resize(512, 72),
        ])

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int) -> torch.Tensor:
        path = self.paths[index]
        # An NpzFile keeps the archive open until closed; use a context manager
        # so each sample does not leak a file handle in the DataLoader workers.
        # Indexing the archive reads the array fully, so it stays valid afterwards.
        with np.load(path) as images:
            x = images["x"].transpose(1, 2, 0)
        x = self.transform_x(image=x)["image"]
        return torch.from_numpy(x).permute(2, 0, 1)

In [None]:
# Rebuild the classifier from the saved checkpoint for inference.
model = LitClassifierModel.load_from_checkpoint(checkpoint_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(model.device)

# Same dataset-A normalization statistics as used by both training stages.
test_dataset = TestDataset(
    paths=list(test_paths["path"]),
    mean_x=statistics_A["All"]["mean_x"],
    std_x=statistics_A["All"]["std_x"],
)

# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to a single worker instead of failing on None // 2.
test_num_workers = (os.cpu_count() or 2) // 2
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,  # keep prediction order aligned with test_paths
    num_workers=test_num_workers,
    pin_memory=True,
)

In [None]:
# A plain Trainer used only to run predict_step over the test loader;
# predict() returns one logits tensor per batch.
trainer = L.Trainer(
    enable_checkpointing=False,
    default_root_dir=config.output_dir,
)
per_batch_logits = trainer.predict(model, test_dataloader)
test_logits = torch.cat(per_batch_logits)
print(test_logits.shape)