In [None]:
import datetime
import os
from pathlib import Path
import pickle
import random
import subprocess
from typing import Any, Dict, List, Tuple, Union

import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pandas as pd
import pytorch_lightning as L
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, StochasticWeightAveraging, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import train_test_split
import timm
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Precision, Recall
from torchmetrics.functional.classification import multiclass_confusion_matrix, multiclass_f1_score

from configs.config import CFG
from ema.ema import EMACallback
from model.multitask_unet import MultiTaskUNet
from util.get_logger import get_logger
from util.my_dataset import MyDataModule, MyDataset


# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous.
# That gives accurate error locations but slows training; drop it once debugging is done.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})
# Let float32 matmuls use faster reduced-precision internal kernels (e.g. TF32).
torch.set_float32_matmul_precision("high")

In [None]:
from timm import list_models

# Names of every architecture registered in timm. This is presumably the pool that
# `config.model_name` (the MultiTaskUNet backbone) is chosen from.
available_timm_models = list_models()
available_timm_models

In [None]:
import timm.optim


class LitUNetModel(L.LightningModule):
    """LightningModule wrapping ``MultiTaskUNet`` for joint regression and classification.

    ``MultiTaskUNet`` returns a normalized regression map and class logits. ``forward``
    de-normalizes the map using the ``mean_y`` / ``std_y`` entry of the *predicted*
    class, then clips it to ``[1500, 4500]``.

    The optimized loss is ``L1(reg, y) + CE(clf, label) + alpha * L1(sobel(reg), sobel(y))``.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        num_classes: int,
        height: int,
        width: int,
        learning_rate: float,
        mean_y: torch.Tensor,
        std_y: torch.Tensor,
    ) -> None:
        """
        Args:
            model_name: Backbone name forwarded to ``MultiTaskUNet``.
            pretrained: Forwarded to ``MultiTaskUNet`` (whether to load pretrained weights).
            num_classes: Number of classification classes.
            height: Input height forwarded to ``MultiTaskUNet``.
            width: Input width forwarded to ``MultiTaskUNet``.
            learning_rate: Peak learning rate for AdamW / OneCycleLR.
            mean_y: Per-class target means, indexed by predicted class in ``forward``.
                This notebook passes shape ``(num_classes, 1, 1, 1)``.
            std_y: Per-class target standard deviations, same layout as ``mean_y``.
        """
        super().__init__()
        self.model = MultiTaskUNet(model_name, num_classes, pretrained, height, width)
        self.num_classes = num_classes
        self.criterion1 = nn.L1Loss()
        self.criterion2 = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.mean_y = mean_y
        self.std_y = std_y
        # Weight of the Sobel edge loss relative to the L1 and cross-entropy terms.
        self.alpha = 1e-02
        self.save_hyperparameters(ignore=["criterion1", "criterion2"])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network and map the regression output back to target units.

        Returns:
            ``(reg_logit, clf_logit)``. ``reg_logit`` has been de-normalized with the
            statistics of the argmax class and clipped to ``[1500, 4500]``.
        """
        reg_logit, clf_logit = self.model(x)
        pred_class = clf_logit.argmax(dim=1)
        # mean_y / std_y are plain tensor attributes, not registered buffers, so moving
        # the module (trainer -> GPU, or a CPU-side checkpoint load) does not move them.
        # Align them with the logits before indexing; `.to` is a no-op when already there.
        mean_y = self.mean_y.to(clf_logit.device)
        std_y = self.std_y.to(clf_logit.device)
        reg_logit = reg_logit * std_y[pred_class] + mean_y[pred_class]
        # Clipping also zeroes the gradient for pixels outside the range during training.
        reg_logit = torch.clip(reg_logit, min=1500, max=4500)
        return reg_logit, clf_logit

    def _compute_losses(
        self,
        reg_logit: torch.Tensor,
        clf_logit: torch.Tensor,
        y: torch.Tensor,
        label: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return the three loss terms and their weighted sum, keyed by log-name suffix."""
        loss1 = self.criterion1(reg_logit, y)
        loss2 = self.criterion2(clf_logit, label)
        loss3 = self._edge_loss(reg_logit, y)
        return {
            "loss1": loss1,
            "loss2": loss2,
            "loss3": loss3,
            "loss": loss1 + loss2 + self.alpha * loss3,
        }

    def _log_epoch_metrics(self, prefix: str, metrics: Dict[str, torch.Tensor], batch_size: int) -> None:
        """Log each metric as ``{prefix}_{name}``, averaged over the epoch."""
        for name, value in metrics.items():
            self.log(
                f"{prefix}_{name}", value,
                on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size,
            )

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]) -> torch.Tensor:
        """Compute and log the combined loss for one training batch.

        ``batch`` is unpacked as ``(x, y, label, path)``; the path is ignored.
        """
        x, y, label, _ = batch
        batch_size = len(x)
        reg_logit, clf_logit = self.forward(x)
        losses = self._compute_losses(reg_logit, clf_logit, y, label)
        self._log_epoch_metrics("train", losses, batch_size)
        # `self.optimizer` is assigned in configure_optimizers.
        lr = self.optimizer.param_groups[0]["lr"]
        self.log("lr", lr, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        return losses["loss"]

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]) -> torch.Tensor:
        """Compute and log validation losses and macro-F1 for one batch."""
        x, y, label, _ = batch
        batch_size = len(x)
        reg_logit, clf_logit = self.forward(x)
        metrics = self._compute_losses(reg_logit, clf_logit, y, label)
        metrics["f1"] = multiclass_f1_score(clf_logit, label, num_classes=self.num_classes, average="macro")
        self._log_epoch_metrics("val", metrics, batch_size)
        return metrics["loss"]

    def on_test_epoch_start(self) -> None:
        """Reset the accumulators used by ``on_test_epoch_end``."""
        self.clf_targets = []
        self.clf_preds = []
        self.mae_all = 0  # summed absolute error over all target elements
        self.num_data = 0  # number of samples seen
        self.num_pixels = 0  # number of target elements seen

    def test_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any],
        batch_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Log test losses/F1 and accumulate MAE and class predictions for the epoch summary."""
        x, y, label, _ = batch
        batch_size = len(x)
        reg_logit, clf_logit = self.forward(x)
        metrics = self._compute_losses(reg_logit, clf_logit, y, label)
        metrics["f1"] = multiclass_f1_score(clf_logit, label, num_classes=self.num_classes, average="macro")

        self.mae_all += F.l1_loss(reg_logit, y, reduction="sum")
        self.num_data += batch_size
        self.num_pixels += y.numel()

        self.clf_targets.append(label.cpu())
        self.clf_preds.append(clf_logit.argmax(dim=1).cpu())

        self._log_epoch_metrics("test", metrics, batch_size)
        return {"loss": metrics["loss"]}

    def on_test_epoch_end(self) -> None:
        """Print the confusion matrix and per-pixel MAE over the whole test set."""
        clf_targets = torch.cat(self.clf_targets)
        clf_preds = torch.cat(self.clf_preds)
        cm = multiclass_confusion_matrix(clf_preds, clf_targets, num_classes=self.num_classes)
        print(cm)
        # Per-pixel MAE: summed absolute error divided by the number of target elements seen.
        print(f"Test MAE: {self.mae_all / self.num_pixels:.4f}")
        print(f"# of test data: {self.num_data}")
        del self.clf_targets
        del self.clf_preds

    def predict_step(self, batch: List[Union[Tuple[str], torch.Tensor]]) -> Tuple[str, torch.Tensor, torch.Tensor]:
        """Return ``(file_names, reg_logit, clf_logit)`` for a ``(file_names, x)`` batch."""
        file_names, x = batch
        reg_logit, clf_logit = self.forward(x)
        return file_names, reg_logit, clf_logit

    def configure_optimizers(self) -> Tuple[Dict[str, Any]]:
        """AdamW (via timm) with a per-step OneCycleLR schedule spanning the whole fit."""
        self.optimizer = timm.optim.create_optimizer_v2(self, opt="adamw", lr=self.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=self.optimizer,
            max_lr=self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
            div_factor=25,
            final_div_factor=1e+04,
        )
        scheduler_config = {
            "scheduler": self.scheduler,
            "interval": "step",
            "frequency": 1,
            # OneCycleLR ignores metrics; "monitor" is kept for config parity, strict=False.
            "monitor": "val_loss",
            "strict": False,
        }
        return (
            {
                "optimizer": self.optimizer,
                "lr_scheduler": scheduler_config,
            },
        )

    def _total_variation_loss(self, img: torch.Tensor) -> torch.Tensor:
        """Anisotropic total variation: mean |difference| of vertically plus horizontally adjacent pixels.

        Currently unused by the losses in this module.
        """
        loss = (
            torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]))
            + torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]))
        )
        return loss

    def _get_sobel_edges(self, x: torch.Tensor) -> torch.Tensor:
        """Sobel gradient magnitude of a 4-D ``(N, C, H, W)`` tensor.

        Multi-channel inputs are averaged to a single channel first; padding keeps ``H, W``.
        """
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)

        if x.shape[1] > 1:
            x = x.mean(dim=1, keepdim=True)

        grad_x = F.conv2d(x, sobel_x, padding=1)
        grad_y = F.conv2d(x, sobel_y, padding=1)
        # The epsilon keeps sqrt differentiable where both gradients are zero.
        edge = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-6)
        return edge

    def _edge_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """L1 distance between the Sobel edge maps of ``pred`` and ``target``."""
        pred_edge = self._get_sobel_edges(pred)
        target_edge = self._get_sobel_edges(target)
        return F.l1_loss(pred_edge, target_edge)

### Define Configurations

In [None]:
now_time = datetime.datetime.now()
# Timestamped run directory, e.g. ../output/multitask_using_AB_v2_2024-01-31-09-05
output_dir = Path(f"../output/multitask_using_AB_v2_{now_time:%Y-%m-%d-%H-%M}")

config = CFG(
    output_dir=output_dir,
    model_name="focalnet_base_lrf",
    pretrained=True,
    debag=False,  # field name as spelled by CFG; True switches to a small debug subset
    train_ratio=0.8,
    seed=42,
    height=288,
    width=288,
    batch_size=32,
    epochs=100,
    learning_rate=4e-04,
    patience=10,
    accumulation_steps=4,
    do_transform=True,
)
config.seed_everything()

# The log file lives inside output_dir; create the directory in case it does not exist yet.
output_dir.mkdir(parents=True, exist_ok=True)
logger = get_logger(output_dir.joinpath("output.log"))

# Record every (non-dunder) config attribute at the top of the run log.
config_log = [f"{k} = {v}" for k, v in vars(config).items() if not k.startswith("__")]
logger.info("\n".join(config_log))
logger.info("\n")
logger.info('\n')

### Load Data Paths

In [None]:
dir_path = Path("../data")
# Show which dataset folders/files exist under the data root.
entry_names = [entry.stem for entry in dir_path.glob("*")]
print(entry_names)

In [None]:
# Dataset family -> class label used by the classification head.
families = {
    "CurveFault_A": 0,
    "CurveFault_B": 1,
    "CurveVel_A": 2,
    "CurveVel_B": 3,
    "FlatFault_A": 4,
    "FlatFault_B": 5,
    "FlatVel_A": 6,
    "FlatVel_B": 7,
    "Style_A": 8,
    "Style_B": 9,
}

# Sort each family's files. The row order (and therefore every seeded split or sample
# below) then no longer depends on the filesystem's arbitrary glob order.
paths = pd.DataFrame(
    [
        (family, label, p)
        for family, label in families.items()
        for p in sorted(dir_path.joinpath(family).glob("*.npz"))
    ],
    columns=["family", "label", "path"],
)
if config.debag:
    # Debug mode: a reproducible random subset (or everything, if there are fewer files).
    paths = paths.sample(n=min(5_000, len(paths)), replace=False, random_state=config.seed)
display(paths)

### Split paths into training, validation, and test sets

In [None]:
# NOTE(review): the test hold-out uses a fixed random_state=42 while the train/valid split
# uses config.seed. This is presumably so the test set stays identical across seeds; confirm.
train_valid_paths, test_paths = train_test_split(
    paths,
    random_state=42,
    stratify=paths["family"],
    train_size=config.train_ratio,
    shuffle=True,
)
# The second split is applied to the remainder, so valid is about (1 - train_ratio) * train_ratio of all rows.
train_paths, valid_paths = train_test_split(
    train_valid_paths,
    random_state=config.seed,
    stratify=train_valid_paths["family"],
    train_size=config.train_ratio,
    shuffle=True,
)
for split_frame in (train_paths, valid_paths, test_paths):
    display(split_frame)

In [None]:
# NOTE: pickle.load can execute arbitrary code. Only load a statistics file you produced yourself.
with open("../output/statistics.pkl", "rb") as stats_file:
    statistics = pickle.load(stats_file)

# Input statistics are shared by all families and broadcast per input channel: shape (C, 1, 1).
mean_x = torch.tensor([list(statistics["All"]["mean_log_x"])]).reshape(-1, 1, 1)
std_x = torch.tensor([list(statistics["All"]["std_log_x"])]).reshape(-1, 1, 1)
# Target statistics are per family, in `families` order so they can be indexed by class label.
mean_y = torch.tensor([statistics[family]["mean_y"] for family in families]).reshape(-1, 1, 1, 1)
std_y = torch.tensor([statistics[family]["std_y"] for family in families]).reshape(-1, 1, 1, 1)
print(mean_x)
print(std_x)
print(mean_y)
print(std_y)

In [None]:
# Every family, all "_A" variants first and then all "_B" variants.
_FAMILY_KINDS = ("CurveFault", "CurveVel", "FlatFault", "FlatVel", "Style")
family_pairs = {
    "All": [f"{kind}_{variant}" for variant in ("A", "B") for kind in _FAMILY_KINDS],
}

In [None]:
# Family-vs-label counts for train, valid and test; each family should map to exactly one label.
for split_frame in (train_paths, valid_paths, test_paths):
    display(pd.crosstab(split_frame["family"], split_frame["label"]))

In [None]:
%%time


model = LitUNetModel(
    model_name=config.model_name,
    pretrained=config.pretrained,
    num_classes=len(families),
    height=config.height,
    width=config.width,
    learning_rate=config.learning_rate,
    mean_y=mean_y.to(config.device),
    std_y=std_y.to(config.device),
)

datamodule = MyDataModule(
    train_paths=train_paths,
    valid_paths=valid_paths,
    test_paths=test_paths,
    seed=config.seed,
    batch_size=config.batch_size,
    height=config.height,
    width=config.width,
    mean_x=mean_x,
    std_x=std_x,
    do_transform=config.do_transform,
)

ema_callback = EMACallback(decay=0.99, ema_device="cuda" if torch.cuda.is_available() else "cpu")

callbacks = [
    # Early stopping and best-checkpoint selection track the regression L1 (val_loss1),
    # not the combined loss.
    EarlyStopping(monitor="val_loss1", patience=config.patience, mode="min"),
    TQDMProgressBar(),
    ModelCheckpoint(
        monitor="val_loss1",
        dirpath=config.output_dir,
        filename="model-{epoch:02d}-{val_loss1:.2f}",
        save_top_k=1,
        mode="min",
    ),
    ema_callback,
]

trainer = L.Trainer(
    default_root_dir=config.output_dir.joinpath("All"),
    enable_checkpointing=True,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    max_epochs=config.epochs,
    precision="bf16-mixed",
    callbacks=callbacks,
    logger=CSVLogger(config.output_dir, name="All"),
    log_every_n_steps=1_000,
    val_check_interval=None,
    check_val_every_n_epoch=1,
    accumulate_grad_batches=config.accumulation_steps,
    gradient_clip_val=0,
)

# training
trainer.fit(model, datamodule=datamodule)

# Persist the model state at the end of fit plus the EMA weights.
checkpoint_path = config.output_dir.joinpath(f"All/multitask_v2_All_{config.seed}.ckpt")
ema_checkpoint_path = config.output_dir.joinpath("All/ema_state.pth")
trainer.save_checkpoint(checkpoint_path)
torch.save(ema_callback.ema_state, ema_checkpoint_path)

# Swap the EMA weights into the model before testing. Loading onto the CPU also works
# without a GPU; load_state_dict copies the values into the existing parameters.
# weights_only=False is acceptable here only because this file was written just above.
ema_state_dict = torch.load(ema_checkpoint_path, map_location="cpu", weights_only=False)
ema_load = model.load_state_dict(ema_state_dict, strict=False)
# strict=False silently ignores mismatched keys; surface them so a partial load gets noticed.
if ema_load.missing_keys or ema_load.unexpected_keys:
    logger.info(
        f"EMA state load: {len(ema_load.missing_keys)} missing, "
        f"{len(ema_load.unexpected_keys)} unexpected keys"
    )

# testing
trainer.test(model, datamodule=datamodule)

del datamodule
del callbacks, ema_callback
del trainer
del ema_state_dict

In [None]:
# CSVLogger writes <output_dir>/All/version_N/metrics.csv. Read the newest version so
# re-running training in the same kernel does not plot a stale version_0.
metrics_csv = max(
    config.output_dir.joinpath("All").glob("version_*/metrics.csv"),
    key=lambda p: int(p.parent.name.split("_")[-1]),
)
raw_metrics = pd.read_csv(metrics_csv).sort_values(["step", "epoch"]).reset_index(drop=True)
# Train and validation values appear to sit on separate CSV rows (hence dropna + merge):
# select each group by column name, drop the empty rows, then re-join on epoch.
train_cols = ["epoch", "lr", *raw_metrics.columns[raw_metrics.columns.str.contains("train")]]
valid_cols = ["epoch", *raw_metrics.columns[raw_metrics.columns.str.contains("val")]]
train_metric = raw_metrics[train_cols].dropna().reset_index(drop=True)
valid_metric = raw_metrics[valid_cols].dropna().reset_index(drop=True)
metrics0 = train_metric.merge(valid_metric, on="epoch")
display(metrics0.head())

# (column(s) to plot, y-axis label); every panel is plotted against epoch.
panels = [
    ("lr", "learning rate"),
    (["train_loss", "val_loss"], "Loss"),
    ("val_loss1", "Loss1"),
    ("val_loss2", "Loss2"),
    ("val_loss3", "Loss3"),
    ("val_f1", "F1"),
]
fig, axs = plt.subplots(len(panels), 1, figsize=(8, 8))
for ax, (y_cols, y_label) in zip(axs, panels):
    cols = [y_cols] if isinstance(y_cols, str) else y_cols
    metrics0[["epoch", *cols]].dropna().plot(x="epoch", y=y_cols, kind="line", marker=".", ax=ax)
    ax.set_xlabel("epoch")
    ax.set_ylabel(y_label)

plt.tight_layout()
plt.show()

### Test

In [None]:
from util.log_transform import log_transform_torch


class InferenceDataset(Dataset):
    """Dataset yielding ``(file_stem, normalized_input)`` pairs from ``.npz`` files.

    Each file must contain an ``"x"`` array. The input is log-transformed, normalized with
    ``mean_x`` / ``std_x``, and bicubically resized to ``(height, width)``. Targets are not
    loaded, which makes this suitable for ``trainer.predict``.
    """

    def __init__(
        self,
        paths: List[Path],
        height: int,
        width: int,
        mean_x: torch.Tensor,
        std_x: torch.Tensor,
    ) -> None:
        """
        Args:
            paths: ``.npz`` file paths; items are returned in this order.
            height: Output height after resizing.
            width: Output width after resizing.
            mean_x: Normalization mean broadcast against ``x`` (this notebook uses ``(C, 1, 1)``).
            std_x: Normalization std, same layout as ``mean_x``.
        """
        self.paths = paths
        self.height = height
        self.width = width
        self.mean_x = mean_x
        self.std_x = std_x

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[str, torch.Tensor]:
        """Return ``(path.stem, x)``, where ``x`` is a float32 tensor of shape ``(C, height, width)``."""
        path = self.paths[index]
        # The context manager closes the .npz file handle right away; the array is
        # already read into memory, so it stays valid afterwards.
        with np.load(path) as npz:
            x = torch.from_numpy(npz["x"])  # raw layout, e.g. (5, 1000, 70)
        x = log_transform_torch(x)
        x = (x - self.mean_x) / self.std_x
        # F.interpolate needs a batch dimension: (C, H, W) -> (1, C, H, W) -> resize -> (C, height, width).
        x = F.interpolate(x.unsqueeze(dim=0), size=(self.height, self.width), mode="bicubic")
        x = x.squeeze(dim=0)
        return path.stem, x.float()

In [None]:
test_family_label_counts = pd.crosstab(index=test_paths["family"], columns=test_paths["label"])
display(test_family_label_counts)

In [None]:
checkpoint_path = config.output_dir.joinpath(f"All/multitask_v2_All_{config.seed}.ckpt")
ema_checkpoint_path = config.output_dir.joinpath("All/ema_state.pth")

model = LitUNetModel.load_from_checkpoint(checkpoint_path)
model.eval()

# Replace the checkpoint weights with the EMA weights saved after training. Loading onto
# the CPU also works without a GPU; values are copied into the model's parameters.
ema_state_dict = torch.load(ema_checkpoint_path, map_location="cpu", weights_only=False)
ema_load = model.load_state_dict(ema_state_dict, strict=False)
if ema_load.missing_keys or ema_load.unexpected_keys:
    logger.info(
        f"EMA state load: {len(ema_load.missing_keys)} missing, "
        f"{len(ema_load.unexpected_keys)} unexpected keys"
    )

test_dataset = InferenceDataset(
    paths=list(test_paths["path"]),
    height=config.height,
    width=config.width,
    mean_x=mean_x,
    std_x=std_x,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,  # keep prediction order aligned with the rows of test_paths
    num_workers=(os.cpu_count() or 1) // 6,  # os.cpu_count() may return None
    pin_memory=False,
)

trainer = L.Trainer(
    default_root_dir=".",
    enable_checkpointing=False,
    logger=False,
    callbacks=[TQDMProgressBar()],
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices="auto",
)
predictions = trainer.predict(model, test_dataloader)
test_file_names, test_reg_logits, test_clf_logits = zip(*predictions)
test_reg_logits = torch.cat(test_reg_logits)
test_clf_logits = torch.cat(test_clf_logits)
print(test_reg_logits.shape, test_reg_logits.min(), test_reg_logits.max())

# Show a few random predictions (fewer if the test set is tiny).
num_examples = min(5, len(test_dataset))
random_index = np.random.choice(len(test_dataset), size=num_examples, replace=False)
print(random_index)

fig, axs = plt.subplots(1, num_examples, figsize=(12, 4), squeeze=False)
for ax, i in zip(axs[0], random_index):
    img = ax.imshow(test_reg_logits[i, 0], aspect="auto")
    fig.colorbar(img, ax=ax)
plt.tight_layout()
plt.show()

### Compute the confusion matrix and F1 score

In [None]:
# Predicted class per test file. Rows of test_paths are in prediction order (shuffle=False).
test_paths["pred_label"] = test_clf_logits.argmax(dim=1).cpu().numpy()
display(test_paths)

# torchmetrics expects integer class indices; build the tensors once and reuse them.
pred_labels = torch.as_tensor(test_paths["pred_label"].to_numpy(), dtype=torch.long)
true_labels = torch.as_tensor(test_paths["label"].to_numpy(), dtype=torch.long)

cm = multiclass_confusion_matrix(pred_labels, true_labels, num_classes=len(families))
display(cm)

f1 = multiclass_f1_score(pred_labels, true_labels, num_classes=len(families), average="macro")
display(f1)

### Compute the MAE for each family (class)

In [None]:
test_paths_reindexed = test_paths.reset_index(drop=True)
test_paths_reindexed

In [None]:
mae_sum_all = 0.0
num_pixels_all = 0
for g_name, g in test_paths.reset_index(drop=True).groupby("family"):
    print(g_name)
    # After reset_index, row positions match the order of test_reg_logits
    # (predictions were made with shuffle=False over test_paths["path"]).
    idx = g.index.to_list()
    g_reg_images = test_reg_logits[idx]

    true_arrays = []
    for p in g["path"]:
        with np.load(p) as npz:  # closes the .npz handle immediately
            true_arrays.append(npz["y"])
    g_true_images = torch.Tensor(np.stack(true_arrays, axis=0))

    # Prediction, ground truth and residual for the first sample in this family.
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    panels = (g_reg_images[0, 0], g_true_images[0, 0], g_true_images[0, 0] - g_reg_images[0, 0])
    for ax, panel in zip(axs, panels):
        img = ax.imshow(panel, aspect="auto")
        fig.colorbar(img, ax=ax)
    plt.tight_layout()
    plt.show()

    # Per-pixel MAE = summed absolute error / number of target elements.
    mae_sum = F.l1_loss(g_reg_images, g_true_images, reduction="sum").item()
    mae_sum_all += mae_sum
    num_pixels_all += g_true_images.numel()
    print(f"MAE: {mae_sum / g_true_images.numel():.4f}")
    print("=" * 50)

mae_all = mae_sum_all / num_pixels_all
print(f"All MAE: {mae_all:.4f}")

### Export this notebook (with its outputs) as an HTML file

In [None]:
# Export the notebook, as last saved to disk, to HTML in the run's output directory.
# Arguments are passed as a list with no shell, so paths containing spaces or shell
# metacharacters are safe. NOTE: the notebook filename is hard-coded.
nbconvert_result = subprocess.run(
    [
        "jupyter", "nbconvert",
        "--to", "html",
        "--output-dir", str(config.output_dir),
        "3_train_multitask_using_AB_v2.ipynb",
    ],
    check=False,
)
nbconvert_result.returncode

In [None]:
# Remove the prediction column added for the metrics above so test_paths matches what
# MyDataModule was built with. errors="ignore" keeps this cell safe to re-run.
test_paths = test_paths.drop(columns=["pred_label"], errors="ignore")

In [None]:
checkpoint_path = config.output_dir.joinpath(f"All/multitask_v2_All_{config.seed}.ckpt")
ema_checkpoint_path = config.output_dir.joinpath("All/ema_state.pth")

model = LitUNetModel.load_from_checkpoint(checkpoint_path)
model.eval()

# Replace the checkpoint weights with the EMA weights. Loading onto the CPU also works
# without a GPU; values are copied into the model's parameters.
ema_state_dict = torch.load(ema_checkpoint_path, map_location="cpu", weights_only=False)
ema_load = model.load_state_dict(ema_state_dict, strict=False)
if ema_load.missing_keys or ema_load.unexpected_keys:
    logger.info(
        f"EMA state load: {len(ema_load.missing_keys)} missing, "
        f"{len(ema_load.unexpected_keys)} unexpected keys"
    )

datamodule = MyDataModule(
    train_paths=train_paths,
    valid_paths=valid_paths,
    test_paths=test_paths,
    seed=config.seed,
    batch_size=2 * config.batch_size,
    height=config.height,
    width=config.width,
    mean_x=mean_x,
    std_x=std_x,
    do_transform=False,  # presumably disables augmentation for evaluation
)

trainer = L.Trainer(
    default_root_dir=".",
    enable_checkpointing=False,
    logger=False,
    callbacks=[TQDMProgressBar()],
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices="auto",
)

# Pass the datamodule by keyword: the second positional parameter of Trainer.test is `dataloaders`.
trainer.test(model, datamodule=datamodule)

In [None]:
# Sanity-check one sample end to end, mirroring InferenceDataset's preprocessing.
sample_path = Path("../data/FlatFault_A/seis2_1_4_vel2_1_4_364.npz")
with np.load(sample_path) as image:  # closes the .npz handle immediately
    x = torch.from_numpy(image["x"]).float()
    y = image["y"].astype(np.float32)

x = log_transform_torch(x)
x = (x - mean_x) / std_x
print(x.shape, mean_x.shape, std_x.shape)

# Add a batch dimension, resize to the model input size, and move to the model's device.
x = x.unsqueeze(dim=0)
x = F.interpolate(x, size=(config.height, config.width), mode="bicubic")
x = x.float().to(config.device)

model = model.to(config.device)
model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    reg, clf = model(x)

reg = reg.cpu().numpy()
clf = clf.cpu().numpy()
print(reg.shape, clf.shape)

In [None]:
# Prediction, ground truth and residual (truth - prediction) for the single sample above.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, panel in zip(axs, (reg[0, 0], y[0], y[0] - reg[0, 0])):
    im = ax.imshow(panel, aspect="auto")
    fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()