In [1]:
import torch
from pprint import pprint
import importlib
import hydra
import sys
from pathlib import Path
from typing import Any, Callable

import dotenv
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchmetrics.classification
import torchmetrics.segmentation
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from simplecv.module import fpn
from torch import Tensor
import torchvision.transforms as T

from inz.models.base_pl_module import BasePLModule
from inz.data.data_module import XBDDataModule
from inz.data.event import Event, Hold, Test, Tier1, Tier3
from inz.util import get_wandb_logger, show_masks_comparison

In [2]:
from inz.models.farseg_module import DoubleBranchFarSegModule

In [3]:
# Global notebook setup: environment variables, RNG seed, compute device and matmul precision.
dotenv.load_dotenv()

# Fixed seed so that repeated runs of the notebook are comparable.
RANDOM_SEED = 123
pl.seed_everything(RANDOM_SEED)

# Prefer the GPU, but fall back to the CPU instead of failing later when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "high" lets PyTorch use faster, reduced-precision kernels for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")

Seed set to 123


In [4]:
# Checkpoint to evaluate (the file name suggests it is the best-F1 checkpoint of the double-branch FarSeg run).
CKPT_PATH = (
    "/home/tomek/inz/inz/saved_checkpoints/farseg_doublebranch-epoch-39-step-39000-f1-0.660326-best-f1.ckpt"
)

In [5]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the Hydra config stored under the run's `.hydra` output directory.
with initialize(
    version_base="1.3",
    config_path="../outputs/farseg_tier1_tier3/2024-10-21_09-20-16/.hydra",
):
    cfg = compose(config_name="config", overrides=[])

# Config node describing the Lightning module; it is read several times below.
module_cfg = cfg["module"]["module"]

# Split the dotted `_target_` path into "package.module" and "ClassName", then import the class.
model_class_str = module_cfg["_target_"]
module_path, _sep, model_class_name = model_class_str.rpartition(".")
imported_module = importlib.import_module(module_path)
model_class = getattr(imported_module, model_class_name)

# The instantiated object is used further down through its `.args` / `.keywords` attributes
# (presumably a `functools.partial` produced by a `_partial_: true` config entry; verify against the config).
model_partial = hydra.utils.instantiate(module_cfg)


INFO:simplecv.util.logger:ResNetEncoder: pretrained = True


scene_relation: on
loss type: cosine


In [6]:
# Restore the trained module from the checkpoint, re-supplying the constructor arguments from the config.
#
# `load_from_checkpoint` forwards only *keyword* arguments to the module's `__init__`; anything passed
# positionally after the checkpoint path is consumed as `map_location`, `hparams_file` or `strict`.
# Positional arguments captured in the partial therefore cannot be forwarded, so fail loudly
# instead of silently misrouting them.
if model_partial.args:
    raise ValueError(
        "model_partial carries positional arguments, which load_from_checkpoint cannot forward "
        f"to the module constructor: {model_partial.args!r}"
    )

model = DoubleBranchFarSegModule.load_from_checkpoint(
    CKPT_PATH,
    map_location=device,  # deserialize the weights straight onto the target device
    **model_partial.keywords,
).to(device)
model

DoubleBranchFarSegModule(
  (model): DoubleBranchFarSeg(
    (module): FarSeg(
      (en): ResNetEncoder(
        (resnet): ResNet(
          (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          (layer1): Sequential(
            (0): Bottleneck(
              (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       

In [7]:
BATCH_SIZE = 32

# Disaster events to evaluate on; they are requested under the `Hold` key of the `events` mapping.
HOLD_EVENTS = [
    Event.guatemala_volcano,
    Event.hurricane_florence,
    Event.hurricane_harvey,
    Event.hurricane_matthew,
    Event.hurricane_michael,
    Event.mexico_earthquake,
    Event.midwest_flooding,
    Event.palu_tsunami,
    Event.santa_rosa_wildfire,
    Event.socal_fire,
]

# Data module for the test run. No `transform` is passed (the augmentation pipeline was
# already disabled in this notebook).
dm = XBDDataModule(
    path=Path("data/xBD_processed_512"),
    drop_unclassified_channel=True,
    events={Hold: HOLD_EVENTS},
    # Presumably routes all of the selected data to the test split; confirm in XBDDataModule.
    val_fraction=0.0,
    test_fraction=1.0,
    train_batch_size=BATCH_SIZE,
    val_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
)
dm.prepare_data()
dm.setup("test")

# Quick sanity check on the size of the test set.
n_test_batches = len(dm.test_dataloader())
print(f"{n_test_batches} test batches")

117 test batches


In [8]:
# Trainer used only to run the test loop on the restored model.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[RichProgressBar()],
    # "bf16" is a legacy alias that Lightning warns about; "bf16-mixed" is the explicit name for the
    # same bfloat16 automatic mixed precision mode.
    precision="bf16-mixed",
    # TODO logger?
)
# The returned list of metric dictionaries is displayed as the cell output.
trainer.test(model, datamodule=dm)

/home/tomek/inz/inz/.venv/lib/python3.11/site-packages/lightning_fabric/connector.py:563: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'acc_loc': 0.9842587113380432,
  'iou_loc': 0.6999634504318237,
  'f1': 0.6866701245307922,
  'precision': 0.7344251275062561,
  'recall': 0.6704961061477661,
  'iou': 0.28329113125801086,
  'f1_0': 0.9916307926177979,
  'f1_1': 0.8104465007781982,
  'f1_2': 0.4231886863708496,
  'f1_3': 0.5836737751960754,
  'f1_4': 0.6244109869003296,
  'precision_0': 0.9892022013664246,
  'precision_1': 0.8475274443626404,
  'precision_2': 0.48553794622421265,
  'precision_3': 0.60467928647995,
  'precision_4': 0.7451770901679993,
  'recall_0': 0.9940763711929321,
  'recall_1': 0.7799804210662842,
  'recall_2': 0.4207552373409271,
  'recall_3': 0.5933434963226318,
  'recall_4': 0.5643253922462463,
  'iou_0': 0.9802424311637878,
  'iou_1': 0.2626893222332001,
  'iou_2': 0.041807446628808975,
  'iou_3': 0.06591569632291794,
  'iou_4': 0.06580130010843277,
  'val_loss_0': 0.0,
  'val_loss_1': 0.0,
  'val_loss_2': 0.0,
  'val_loss_3': 0.0,
  'val_loss_4': 0.0,
  'val_loss': 0.062313638627529144,
  'f1