In [11]:
import torch
import copy
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline

# Make the repository root importable so that `src` and `LCRP` resolve.
# The root is derived from the working directory instead of a user-specific
# absolute path: the notebook is expected to run from a direct subdirectory of
# the repository (other cells rely on the same layout via "../data", "../models"
# and "../src").
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().resolve().parent
if str(PROJECT_ROOT) not in sys.path:  # re-running the cell must not add duplicates
    sys.path.append(str(PROJECT_ROOT))

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

from src.glocal_analysis import run_analysis 
from src.datasets.flood_dataset import FloodDataset
from src.datasets.DLR_dataset import DatasetDLR
from src.plot_crp_explanations import plot_explanations, plot_one_image_explanation
from src.minio_client import MinIOClient
from LCRP.models import get_model 
from LCRP.utils.crp_configs import ATTRIBUTORS, CANONIZERS, VISUALIZATIONS, COMPOSITES

import logging
# Raise the threshold of chatty third-party loggers to WARNING so their
# DEBUG/INFO records do not flood the notebook output.
for _noisy_logger in ("PIL", "matplotlib", "numba"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)



In [12]:
# --- imports
import torch
from torchvision import transforms

# --- device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# --- model / checkpoint
model_name = "pidnet"
ckpt_path = "../models/flood_model.pt"

# --- data / transforms
# The only transform applied to the samples is the conversion to a tensor.
root_dir = "../data/flood_segmentation/"
transform = transforms.Compose([transforms.ToTensor()])
dataset = FloodDataset(root_dir=root_dir, split="train", transform=transform)

print('Loaded dataset:', type(dataset))
print('len(dataset) =', len(dataset))

# ========= HACK: prevent get_model()/get_pidnet() from trying to strictly load its OWN checkpoint =========
# Some internal path calls `model.load_state_dict(torch.load(cfg["ckpt_path"]))` with strict=True,
# which raises due to the "model." prefix or a different head. While the model is built we force:
#   - torch.load -> map to CPU unless the caller already chose a map_location (safe on CPU-only)
#   - nn.Module.load_state_dict -> always strict=False
# Both patches are process-global, so they are undone in `finally` even when construction fails.
_orig_torch_load = torch.load
_orig_load_state_dict = torch.nn.Module.load_state_dict


def _cpu_load(*args, **kwargs):
    """Call the real ``torch.load`` with ``map_location`` defaulting to ``"cpu"``.

    ``map_location`` is the second positional parameter of ``torch.load``; the
    default is only injected when the caller did not supply it positionally,
    otherwise the call would fail with "multiple values for argument".
    """
    if len(args) < 2:
        kwargs.setdefault("map_location", "cpu")
    return _orig_torch_load(*args, **kwargs)


def _lenient_load_state_dict(self, state_dict, strict=True, *args, **kwargs):
    """Call the real ``load_state_dict`` but always with ``strict=False``.

    The caller's ``strict`` value is accepted and ignored. Any further
    arguments (for example ``assign`` in newer PyTorch releases) are forwarded
    unchanged so that inner callers using them do not hit a ``TypeError``.
    """
    return _orig_load_state_dict(self, state_dict, False, *args, **kwargs)


torch.load = _cpu_load
torch.nn.Module.load_state_dict = _lenient_load_state_dict
try:
    # Build the model; any internal checkpoint load will be lenient and won't crash
    model = get_model(model_name=model_name, device=device, classes=2)
finally:
    # Restore patched functions
    torch.load = _orig_torch_load
    torch.nn.Module.load_state_dict = _orig_load_state_dict

# ========= Now load YOUR checkpoint properly (handle "model." / "module." prefixes) =========
def _extract_state_dict(obj):
    if isinstance(obj, dict):
        for k in ("state_dict", "model_state", "model", "net", "module"):
            if k in obj and isinstance(obj[k], dict):
                return obj[k]
    return obj

def _strip_prefix(sd, prefix):
    if any(k.startswith(prefix) for k in sd.keys()):
        return {k[len(prefix):]: v for k, v in sd.items()}
    return sd

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
raw_sd = torch.load(ckpt_path, map_location="cpu")
sd = _extract_state_dict(raw_sd)
sd = _strip_prefix(sd, "model.")
sd = _strip_prefix(sd, "module.")

# strict=False tolerates partial mismatches and reports them instead of raising.
missing, unexpected = model.load_state_dict(sd, strict=False)

# With strict=False a checkpoint whose keys all mismatch is accepted silently and the
# model keeps its current weights; fail loudly so the analysis never runs on them.
if len(unexpected) >= len(sd):
    raise RuntimeError(
        f"No parameter from {ckpt_path} matched the model "
        f"({len(sd)} checkpoint keys, {len(unexpected)} unexpected); "
        "check the key prefixes of the checkpoint."
    )

if missing:
    print("[load_state_dict] Missing keys:", len(missing))
    # print(missing)  # uncomment for full list
if unexpected:
    print("[load_state_dict] Unexpected keys:", len(unexpected))
    # print(unexpected)  # uncomment for full list

# finalize: move to the target device and switch to inference mode
model.to(device)
model.eval()
# NOTE(review): presumably turns off the model's auxiliary/augmented outputs — confirm in the model definition.
model.augment = False

output_dir = "../src/output/crp/pidnet_flood_new"
print(f"Model ready on {device}. Output dir: {output_dir}")


Loaded dataset: <class 'src.datasets.flood_dataset.FloodDataset'>
len(dataset) = 1321
Loaded checkpoint /Users/heydari/Desktop/test/FHHI-XAI-PIDNET/models/flood_model.pt
Model ready on cpu. Output dir: ../src/output/crp/pidnet_flood_new


In [13]:
import torch

ckpt_path = "../models/flood_model.pt"
ckpt = torch.load(ckpt_path, map_location="cpu")

# Unwrap the checkpoint: if it is a dict containing one of the common wrapper
# keys, take the value under the first key found (in the order listed);
# otherwise treat the loaded object itself as the state dict.
if isinstance(ckpt, dict):
    state = next(
        (ckpt[name] for name in ("model", "state_dict", "state_dicts", "net") if name in ckpt),
        ckpt,
    )
else:
    state = ckpt

print("Loaded checkpoint type:", type(ckpt))
print("State is type:", type(state))
print("Number of keys in state:", len(state.keys()))
print("Sample keys (first 20):")
first_keys = list(state.keys())[:20]
for idx, key in enumerate(first_keys):
    print(idx, key)

Loaded checkpoint type: <class 'collections.OrderedDict'>
State is type: <class 'collections.OrderedDict'>
Number of keys in state: 479
Sample keys (first 20):
0 model.conv1.0.weight
1 model.conv1.0.bias
2 model.conv1.1.weight
3 model.conv1.1.bias
4 model.conv1.1.running_mean
5 model.conv1.1.running_var
6 model.conv1.1.num_batches_tracked
7 model.conv1.3.weight
8 model.conv1.3.bias
9 model.conv1.4.weight
10 model.conv1.4.bias
11 model.conv1.4.running_mean
12 model.conv1.4.running_var
13 model.conv1.4.num_batches_tracked
14 model.layer1.0.conv1.weight
15 model.layer1.0.bn1.weight
16 model.layer1.0.bn1.bias
17 model.layer1.0.bn1.running_mean
18 model.layer1.0.bn1.running_var
19 model.layer1.0.bn1.num_batches_tracked


In [15]:
# Run the analysis for the loaded model over the dataset; results go to `output_dir`.
run_analysis(
    model_name,
    model,
    dataset,
    output_dir=output_dir,
    device=device,
)

Running Analysis...


  0%|          | 0/166 [00:00<?, ?it/s][A

KeyboardInterrupt: 

In [None]:
# Same call as the executed analysis cell. `output_dir` and `device` are passed by
# keyword (the form the executed cell shows is accepted) so the call does not depend
# on the positional order of run_analysis's parameters.
run_analysis(model_name, model, dataset, output_dir=output_dir, device=device)


In [None]:

# Display the number of samples in `dataset` (the "train" split loaded earlier).
len(dataset)