In [1]:
import importlib
from pathlib import Path

import dotenv
import hydra
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Hold
from inz.util import show_masks_comparison

In [2]:
# Runtime setup: load environment variables from .env, seed the Python/NumPy/torch
# RNGs through Lightning, and choose the compute device.
dotenv.load_dotenv()
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Allow reduced-precision (TF32) float32 matmuls where the hardware supports them.
torch.set_float32_matmul_precision("high")

Seed set to 123


In [3]:
# Checkpoint and Hydra run directory of the model to evaluate.
# NOTE(review): CKPT_PATH is an absolute, machine-specific path while CONFIG_PATH is
# relative; consider deriving both from a single configurable run directory.
CKPT_PATH = "/home/tomek/inz/inz/saved_checkpoints/runs/farseg-double/checkpoints/experiment_name-0-epoch-39-step-39000-f1-0.660326-best-f1.ckpt"
CONFIG_PATH = "../saved_checkpoints/runs/farseg-double/.hydra"
BATCH_SIZE = 1

# region misc

# Re-compose the Hydra config that was saved alongside the training run.
# `hydra.initialize`/`hydra.compose` come from the top-level `hydra` import, so no
# extra mid-notebook import is needed.
with hydra.initialize(version_base="1.3", config_path=CONFIG_PATH):
    cfg = hydra.compose(config_name="config", overrides=[])

# Resolve the model class named by the config's `_target_`:
# "package.module.ClassName" -> module "package.module", class "ClassName".
model_class_str = cfg["module"]["module"]["_target_"]
module_path, _, model_class_name = model_class_str.rpartition(".")
imported_module = importlib.import_module(module_path)
model_class = getattr(imported_module, model_class_name)

# Instantiating the config entry yields a partial (it exposes `.args`/`.keywords`);
# its bound constructor arguments are forwarded to `load_from_checkpoint`.
model_partial = hydra.utils.instantiate(cfg["module"]["module"])

model = model_class.load_from_checkpoint(CKPT_PATH, *model_partial.args, **model_partial.keywords).to(device)


# Every listed event goes to the test split (val 0%, test 100%), so the whole
# processed dataset is evaluated.
dm = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events={
        Hold: [
            Event.guatemala_volcano,
            Event.hurricane_florence,
            Event.hurricane_harvey,
            Event.hurricane_matthew,
            Event.hurricane_michael,
            Event.mexico_earthquake,
            Event.midwest_flooding,
            Event.palu_tsunami,
            Event.santa_rosa_wildfire,
            Event.socal_fire,
        ],
    },
    val_fraction=0.0,
    test_fraction=1.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)
dm.prepare_data()
dm.setup("test")

print(f"{len(dm.test_dataloader())} test batches")

# endregion

# Output directory for the rendered comparison figures (created with any parents).
PREDS_DIR = Path(f"preds/{cfg['experiment_name']}")
PREDS_DIR.mkdir(parents=True, exist_ok=True)

INFO:simplecv.util.logger:ResNetEncoder: pretrained = True


scene_relation: on
loss type: cosine
3732 test batches


In [4]:

# Render a ground-truth vs. prediction comparison for every test sample and save it
# as PREDS_DIR/<image stem>.jpg.

# Lightning's documented inference recipe calls `model.eval()` after
# `load_from_checkpoint`; without it, normalization layers (e.g. BatchNorm in the ResNet
# encoder) would use single-sample batch statistics and keep updating running stats.
model.eval()
model.to(device)  # loop-invariant: move the model once, not once per sample

test_dataset = dm._test_dataset
base_dataset = test_dataset.dataset
# `_test_dataset` wraps `base_dataset` (it exposes `.dataset`). If it is a
# `torch.utils.data.Subset`, position i is `base_dataset[indices[i]]`, so the file name
# must be looked up through the same index to match the rendered sample. Without
# `indices`, fall back to pairing by position over all post-disaster image paths.
base_indices = getattr(test_dataset, "indices", range(len(base_dataset._image_paths_post)))

for i, base_idx in enumerate(tqdm(base_indices)):
    fname = base_dataset._image_paths_post[base_idx]
    # Each sample unpacks to (images_pre, masks_pre, images_post, masks_post);
    # prepend a batch dimension of size 1 to each.
    images_pre, masks_pre, images_post, masks_post = (t.unsqueeze(0) for t in test_dataset[i])
    with torch.no_grad():
        # The model takes pre- and post-disaster images concatenated along dim 1.
        preds = model(torch.cat([images_pre, images_post], dim=1).to(device))

    plt.subplots_adjust(top=0.99)
    stem = Path(fname).stem
    show_masks_comparison(
        images_pre=images_pre,
        images_post=images_post,
        masks_pre=masks_pre,
        masks_post=masks_post,
        preds=preds,
        opacity=0.3,
        compact=True,
        title=stem,
    )
    plt.savefig(PREDS_DIR / f"{stem}.jpg", dpi=150, bbox_inches='tight')
    # Close every open figure, not only the current one: `plt.subplots_adjust` acts on
    # (and creates, if none exists) the current figure, which is not necessarily the one
    # `show_masks_comparison` draws into, and would otherwise be left open.
    plt.close("all")

 53%|█████▎    | 1962/3732 [11:35<10:27,  2.82it/s]


KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>