In [None]:
from pathlib import Path

import yaml

# Project root. `__file__` does not exist inside a Jupyter kernel, so fall back to the
# working directory's parent (same rule as the evaluation cells); this assumes the
# kernel is started from a direct sub-folder of the project such as `notebooks/`.
if "__file__" in globals():
    base_dir = Path(__file__).resolve().parent.parent
else:
    base_dir = Path.cwd().parent
config_path = base_dir / "notebooks" / "00_config.yaml"
print(base_dir)

# Parse the experiment configuration (safe_load: no arbitrary object construction).
with open(config_path, "r") as file:
    config = yaml.safe_load(file)


In [None]:
from pathlib import Path

import torch

# Logging / storage folder names from the YAML config.
# NOTE(review): none of these five names is used further down in this cell.
tensorboard_logs = config["foldernames"]["tensorboard_logs"]
checkpoints = config["foldernames"]["checkpoints"]
logfiles = config["foldernames"]["logfiles"]
class_counts_file = config["foldernames"]["class_counts_file"]
sample_weights_file = config["foldernames"]["sample_weights_file"]

# Experiment identifiers (paths may be adapted to the inference setup).
# `USER` was referenced below but never defined; read it from the config exactly
# like the evaluation cells do.
USER = config["data"]["user"]
experiment_group = config["data"]["experiment_group"]
experiment_id = config["data"]["experiment_id"]

# TensorBoard log directory for this experiment (created if missing).
tensorboard_dir = Path(f"/dss/dsshome1/08/{USER}/{experiment_group}/tensorboard_logs/{experiment_id}")
tensorboard_dir.mkdir(parents=True, exist_ok=True)

# Checkpoint directory of the experiment group (created if missing).
checkpoint_dir = Path(f"/dss/dsshome1/08/{USER}/{experiment_group}/checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Optionally load the best checkpoint. The file name follows the
# "<experiment_id>_best_siamese_unet_state.pth" scheme used by the other cells and
# visible in this notebook's logged checkpoint paths.
checkpoint_path = checkpoint_dir / f"{experiment_id}_best_siamese_unet_state.pth"
if checkpoint_path.exists():
    # map_location="cpu" lets a GPU-saved checkpoint be read on a CPU-only kernel;
    # load_state_dict then copies the tensors into the model's own parameters.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    if "model" in globals():
        # NOTE(review): a DataParallel-wrapped model expects "module."-prefixed keys.
        model.load_state_dict(state_dict)
        print(f"Checkpoint geladen: {checkpoint_path}")
    else:
        # No model has been built yet at this point of the notebook.
        print(f"Checkpoint gefunden, aber kein Modell definiert: {checkpoint_path}")
else:
    print(f"Kein Checkpoint unter {checkpoint_path} gefunden.")
    

In [None]:
from sklearn.metrics import accuracy_score, f1_score

In [1]:
import torch
import numpy as np
import yaml
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from tqdm import tqdm

from utils.helperfunctions import get_data_folder
from utils.dataset import xView2Dataset, collate_fn, transform, image_transform
from model.siameseNetwork import SiameseUnet
from utils.helperfunctions import find_best_checkpoint, load_checkpoint

import os
# Project root = parent of the kernel's working directory (assumes the notebook is
# started from a direct sub-folder of the project, e.g. `notebooks/`).
base_dir = Path.cwd().parent
config_path = base_dir / "notebooks" / "00_config.yaml"
with open(config_path, "r") as file:
    config = yaml.safe_load(file)

# Paths derived from the config.
USER = config["data"]["user"]
USER_HOME_PATH = Path(f"/dss/dsshome1/08/{USER}")
EXPERIMENT_GROUP = config["data"]["experiment_group"]
EXPERIMENT_ID = config["data"]["experiment_id"]
CHECKPOINTS_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "checkpoints"

# DataLoader settings for this evaluation (tune here).
NUM_WORKERS = 4
PIN_MEMORY = True

# NOTE: this cell evaluates the *validation* split (`validation_name`), not the test split.
DATA_ROOT, EVAL_ROOT, EVAL_IMG, EVAL_LABEL, EVAL_TARGET, EVAL_PNG_IMAGES = get_data_folder(
    config["data"]["validation_name"],
    main_dataset=config["data"]["use_main_dataset"],
)

# Evaluation dataset (named test_dataset for compatibility with later cells).
test_dataset = xView2Dataset(
    png_path=EVAL_PNG_IMAGES,
    target_path=EVAL_TARGET,
    transform=transform(),
    image_transform=image_transform(),
)

# Device selection.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Verwende Gerät: {device}")

# Build the model; wrap in DataParallel when several GPUs are visible.
model = SiameseUnet(num_pre_classes=2, num_post_classes=6)
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print(f"Verwende {n_gpus} GPUs!")
    model = torch.nn.DataParallel(model)
model = model.to(device)

# Load the best checkpoint of this experiment; evaluate_model() switches to eval mode itself.
best_checkpoint_path = find_best_checkpoint(CHECKPOINTS_DIR, EXPERIMENT_ID)
model = load_checkpoint(model, best_checkpoint_path)

# Deterministic (unshuffled) evaluation loader.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=PIN_MEMORY,
)

def _flatten_to_numpy(tensor, dtype):
    """Flatten a tensor into a 1-D host NumPy array of `dtype` (no copy when already that dtype)."""
    return tensor.reshape(-1).cpu().numpy().astype(dtype, copy=False)


def _concat_or_empty(chunks, dtype):
    """Concatenate 1-D per-batch arrays; return an empty `dtype` array if there were no batches."""
    if not chunks:
        return np.empty(0, dtype=dtype)
    return np.concatenate(chunks)


def evaluate_model(model, dataloader, device, dtype=np.int64):
    """Run `model` over `dataloader` and collect flattened per-pixel predictions and labels.

    Args:
        model: network called as ``model(pre_images, post_images)``; assumed to return
            logits with the pre head in channels 0-1 and the post head in the remaining
            channels (TODO confirm against SiameseUnet).
        dataloader: iterable of ``(pre_images, post_images, pre_labels, post_labels)``
            batches; labels are squeezed on dim 1, i.e. expected as (B, 1, H, W).
        device: device the images are moved to for the forward pass.
        dtype: NumPy dtype of the returned arrays. The default keeps the previous
            int64 output; e.g. ``np.uint8`` cuts memory 8x for small class counts.

    Returns:
        dict with 1-D arrays 'pre_preds', 'pre_true', 'post_preds', 'post_true'.
    """
    model.eval()
    # Per-batch arrays are kept and concatenated once at the end. Extending a Python
    # list with a NumPy array stores one boxed scalar per pixel, which exhausts memory
    # on full-size evaluation sets.
    pre_pred_chunks, pre_true_chunks = [], []
    post_pred_chunks, post_true_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluiere"):
            pre_images, post_images, pre_labels, post_labels = batch

            outputs = model(pre_images.to(device), post_images.to(device))

            # Split the fused logits: first 2 channels -> pre, remaining channels -> post.
            pre_preds = torch.argmax(outputs[:, :2, :, :], dim=1)   # (B, H, W)
            post_preds = torch.argmax(outputs[:, 2:, :, :], dim=1)

            # Labels only feed NumPy, so they never need to visit the device.
            pre_true = pre_labels.squeeze(1).long()                  # (B, 1, H, W) -> (B, H, W)
            post_true = post_labels.squeeze(1).long()

            pre_pred_chunks.append(_flatten_to_numpy(pre_preds, dtype))
            pre_true_chunks.append(_flatten_to_numpy(pre_true, dtype))
            post_pred_chunks.append(_flatten_to_numpy(post_preds, dtype))
            post_true_chunks.append(_flatten_to_numpy(post_true, dtype))

    return {
        'pre_preds': _concat_or_empty(pre_pred_chunks, dtype),
        'pre_true': _concat_or_empty(pre_true_chunks, dtype),
        'post_preds': _concat_or_empty(post_pred_chunks, dtype),
        'post_true': _concat_or_empty(post_true_chunks, dtype),
    }

# Run the evaluation on the loader built above.
print("Starte Modell-Evaluierung...")
results = evaluate_model(model, test_dataloader, device)
print("pre_preds shape:", results['pre_preds'].shape)
print("pre_true shape:", results['pre_true'].shape)
print("unique pre_preds:", np.unique(results['pre_preds']))
print("unique pre_true:", np.unique(results['pre_true']))

# Number of classes per head, matching SiameseUnet(num_pre_classes=2, num_post_classes=6).
PRE_NUM_CLASSES = 2
POST_NUM_CLASSES = 6


def _head_scores(y_true, y_pred, num_classes):
    """Return (accuracy, weighted F1, macro F1, confusion matrix) for one head.

    `labels=range(num_classes)` pins the confusion matrix to num_classes x num_classes,
    so row/column k always means class k even when a class never occurs; label ids
    outside that range are left out of the matrix.
    """
    accuracy = accuracy_score(y_true, y_pred)
    f1_weighted = f1_score(y_true, y_pred, average='weighted')
    f1_macro = f1_score(y_true, y_pred, average='macro')
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    return accuracy, f1_weighted, f1_macro, cm


def _format_performance(title, accuracy, f1_weighted, f1_macro, cm):
    """Render one head's scores as text (shared by console output and the results file)."""
    return (
        f"{title} Performance:\n"
        f"Accuracy: {accuracy:.4f}\n"
        f"F1 Score (weighted): {f1_weighted:.4f}\n"
        f"F1 Score (macro): {f1_macro:.4f}\n"
        "Confusion Matrix:\n"
        f"{cm}\n"
    )


pre_accuracy, pre_f1_weighted, pre_f1_macro, pre_cm = _head_scores(
    results['pre_true'], results['pre_preds'], PRE_NUM_CLASSES
)
post_accuracy, post_f1_weighted, post_f1_macro, post_cm = _head_scores(
    results['post_true'], results['post_preds'], POST_NUM_CLASSES
)

pre_report = _format_performance("Pre-Disaster", pre_accuracy, pre_f1_weighted, pre_f1_macro, pre_cm)
post_report = _format_performance("Post-Disaster", post_accuracy, post_f1_weighted, post_f1_macro, post_cm)

# Console: one blank line before each section.
print("\n" + pre_report, end="")
print("\n" + post_report, end="")

# Persist the same text to the experiment's evaluation_results folder.
results_dir = USER_HOME_PATH / EXPERIMENT_GROUP / "evaluation_results"
results_dir.mkdir(parents=True, exist_ok=True)
result_file = results_dir / f"{EXPERIMENT_ID}_evaluation_results.txt"

with open(result_file, 'w') as f:
    f.write("MODEL EVALUATION RESULTS\n")
    f.write("=======================\n\n")
    f.write(pre_report + "\n")
    f.write(post_report)

print(f"\nEvaluierungsergebnisse gespeichert in: {result_file}")

Verwende Gerät: cuda
Verwende 3 GPUs!
Loaded raw state_dict from /dss/dsshome1/08/di97ren/xView2_Subset/checkpoints/003_best_siamese_unet_state.pth
Checkpoint erfolgreich in DataParallel-Modell geladen.
Starte Modell-Evaluierung...


Evaluiere:  79%|███████▉  | 93/117 [02:26<00:34,  1.43s/it]

: 

: 

In [1]:
import gc
import json
import torch
import torch.nn as nn
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall, MulticlassF1Score, MulticlassJaccardIndex

from utils.helperfunctions import get_data_folder
from utils.dataset import xView2Dataset, collate_fn, transform, image_transform
from model.siameseNetwork import SiameseUnet
from model.loss import FocalLoss, combined_loss_function

def _load_eval_config(base_dir):
    """Load the YAML config stored at `<base_dir>/notebooks/00_config.yaml`."""
    config_path = base_dir / "notebooks" / "00_config.yaml"
    print(f"Lade Konfiguration von: {config_path}")
    with open(config_path, "r") as file:
        return yaml.safe_load(file)


def _inverse_frequency_weights(raw_counts, fallbacks, device):
    """Return a per-class weight tensor using the training rule ``1 / (count / total)``.

    raw_counts: mapping class id -> count; keys/values may be strings (JSON) and are cast to int.
    fallbacks: default weight for class index i, used when class i is missing from raw_counts.
    A class whose count is 0 gets weight 1.0.
    """
    counts = {int(cls): int(n) for cls, n in raw_counts.items()}
    total = sum(counts.values())
    weights = {cls: (1.0 / (n / total) if n > 0 else 1.0) for cls, n in counts.items()}
    return torch.tensor(
        [weights.get(cls, fallback) for cls, fallback in enumerate(fallbacks)],
        device=device,
    )


def _load_siamese_unet(checkpoint_path, device):
    """Build SiameseUnet, load `checkpoint_path`, move it to `device` and set eval mode.

    Weights are loaded into the bare model *before* DataParallel wrapping, and a leading
    "module." is stripped from every key. Loading a prefix-free state dict into a
    DataParallel wrapper fails with missing/unexpected keys on multi-GPU nodes.
    """
    model = SiameseUnet(num_pre_classes=2, num_post_classes=6)
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print(f"Verwende {n_gpus} GPUs!")
    print(f"Lade Modell von: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=device)
    prefix = "module."
    model.load_state_dict({
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    })
    if n_gpus > 1:
        model = torch.nn.DataParallel(model)
    model = model.to(device)
    model.eval()
    return model


def _per_class_metrics(num_classes, device):
    """Precision/recall/F1/IoU metric objects that return one value per class.

    average=None is required: with the torchmetrics default ("macro") compute() returns
    a single scalar and the per-class report cannot index it.
    """
    return {
        "precision": MulticlassPrecision(num_classes=num_classes, average=None).to(device),
        "recall": MulticlassRecall(num_classes=num_classes, average=None).to(device),
        "f1": MulticlassF1Score(num_classes=num_classes, average=None).to(device),
        "iou": MulticlassJaccardIndex(num_classes=num_classes, average=None).to(device),
    }


def _summary_lines(title, values):
    """Header plus the four class-averaged (unweighted mean) metric lines of one head."""
    return [
        f"{title} Metriken:",
        f"Durchschnittliche Precision: {values['precision'].mean().item():.4f}",
        f"Durchschnittliche Recall: {values['recall'].mean().item():.4f}",
        f"Durchschnittlicher F1-Score: {values['f1'].mean().item():.4f}",
        f"Durchschnittlicher IoU: {values['iou'].mean().item():.4f}",
    ]


def _write_evaluation_report(path, avg_test_loss, heads):
    """Write the text report; `heads` is a sequence of (title, class_names, values)."""
    with open(path, "w") as f:
        f.write(f"Test Loss: {avg_test_loss:.4f}\n\n")
        for title, _, values in heads:
            f.write("\n".join(_summary_lines(title, values)) + "\n\n")
        for title, class_names, values in heads:
            f.write(f"Klassenspezifische Metriken ({title}):\n")
            for idx, class_name in enumerate(class_names):
                f.write(f"Klasse {idx} ({class_name}):\n")
                f.write(f"  Precision: {values['precision'][idx].item():.4f}\n")
                f.write(f"  Recall: {values['recall'][idx].item():.4f}\n")
                f.write(f"  F1-Score: {values['f1'][idx].item():.4f}\n")
                f.write(f"  IoU: {values['iou'][idx].item():.4f}\n\n")


def _plot_metric_summary(pre_values, post_values, out_path):
    """Grouped bar chart of mean pre/post metrics saved to `out_path` (figure left open for inline display)."""
    keys = ("precision", "recall", "f1", "iou")
    labels = ['Precision', 'Recall', 'F1-Score', 'IoU']
    pre_means = [pre_values[key].mean().item() for key in keys]
    post_means = [post_values[key].mean().item() for key in keys]

    x = np.arange(len(labels))
    width = 0.35
    plt.figure(figsize=(10, 6))
    plt.bar(x - width/2, pre_means, width, label='Pre-Disaster')
    plt.bar(x + width/2, post_means, width, label='Post-Disaster')
    plt.xlabel('Metrik')
    plt.ylabel('Wert')
    plt.title('Zusammenfassung der Evaluationsmetriken')
    plt.xticks(x, labels)
    plt.legend()
    plt.savefig(out_path)


def evaluate_model():
    """Evaluate the best checkpoint of the configured experiment on the test split.

    Reads `00_config.yaml` and `precalculations/class_counts.json` below the project
    root (parent of the working directory). Writes `evaluation_results.txt` and
    `metrics_summary.png` into `<user home>/<group>/evaluation/<experiment id>` and
    prints a summary.
    """
    base_dir = Path(os.getcwd()).parent  # project root = one level above the working directory
    config = _load_eval_config(base_dir)

    # Test-split data paths.
    DATA_ROOT, TEST_ROOT, TEST_IMG, TEST_LABEL, TEST_TARGET, TEST_PNG_IMAGES = get_data_folder(
        config["data"]["test_name"],
        main_dataset=config["data"]["use_main_dataset"],
    )

    USER_HOME_PATH = Path(f"/dss/dsshome1/08/{config['data']['user']}")
    EXPERIMENT_GROUP = config["data"]["experiment_group"]
    EXPERIMENT_ID = config["data"]["experiment_id"]
    CHECKPOINTS_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "checkpoints"
    EVALUATION_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "evaluation" / EXPERIMENT_ID
    EVALUATION_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Evaluationsergebnisse werden gespeichert in: {EVALUATION_DIR}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Verwende Device: {device}")

    # Class weights for the loss, computed as in training.
    class_counts_path = os.path.join(base_dir, "precalculations", "class_counts.json")
    with open(class_counts_path, 'r') as f:
        class_counts = json.load(f)
    class_weights_pre = _inverse_frequency_weights(class_counts["pre"], (1.0, 10.0), device)
    class_weights_post = _inverse_frequency_weights(
        class_counts["post"], (1.0, 10.0, 30.0, 20.0, 50.0, 100.0), device
    )

    gamma = config["focal_loss"]["gamma"]
    focal_loss_pre = FocalLoss(alpha=class_weights_pre, gamma=gamma)
    focal_loss_post = FocalLoss(alpha=class_weights_post, gamma=gamma)

    test_dataset = xView2Dataset(
        png_path=TEST_PNG_IMAGES,
        target_path=TEST_TARGET,
        transform=transform(),
        image_transform=image_transform(),
    )

    if device.type == "cuda":
        num_workers = torch.cuda.device_count() * config["dataloader"]["num_workers_multiplier"]
    else:
        num_workers = 3

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["training"]["batch_size"],
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=config["dataloader"]["pin_memory"],
    )

    checkpoint_path = CHECKPOINTS_DIR / f"{EXPERIMENT_ID}_best_siamese_unet_state.pth"
    model = _load_siamese_unet(checkpoint_path, device)

    pre_metrics = _per_class_metrics(2, device)
    post_metrics = _per_class_metrics(6, device)

    # Human-readable class names.
    # NOTE(review): verify these names against the label encoding of the dataset.
    pre_class_names = ["No Building", "Building"]
    post_class_names = ["No Building", "Minor Damage", "Major Damage", "Destroyed", "Flooded", "Other Damage"]

    test_loss = 0.0
    print("Starte Evaluation...")
    with torch.no_grad():
        for pre_imgs, post_imgs, pre_masks, post_masks in tqdm(test_dataloader):
            X_pre = pre_imgs.to(device).float()
            X_post = post_imgs.to(device).float()
            # (B, 1, H, W) masks -> (B, H, W) integer targets for loss and metrics.
            y_pre_metric = pre_masks.to(device).squeeze(1).long()
            y_post_metric = post_masks.to(device).squeeze(1).long()

            pred = model(X_pre, X_post)

            loss = combined_loss_function(pred, y_pre_metric, y_post_metric, focal_loss_pre, focal_loss_post)
            test_loss += loss.item()

            # Channels 0-1 are the pre head, the remaining channels the post head.
            for metric in pre_metrics.values():
                metric.update(pred[:, :2], y_pre_metric)
            for metric in post_metrics.values():
                metric.update(pred[:, 2:], y_post_metric)

    num_batches = len(test_dataloader)
    avg_test_loss = test_loss / num_batches if num_batches else float("nan")

    pre_values = {name: metric.compute() for name, metric in pre_metrics.items()}
    post_values = {name: metric.compute() for name, metric in post_metrics.items()}
    heads = [
        ("Pre-Disaster", pre_class_names, pre_values),
        ("Post-Disaster", post_class_names, post_values),
    ]

    _write_evaluation_report(EVALUATION_DIR / "evaluation_results.txt", avg_test_loss, heads)

    print("\nEvaluationsergebnisse:")
    print(f"Test Loss: {avg_test_loss:.4f}")
    for title, _, values in heads:
        print("\n" + "\n".join(_summary_lines(title, values)))

    print(f"\nAusführliche Ergebnisse wurden gespeichert in: {EVALUATION_DIR / 'evaluation_results.txt'}")

    _plot_metric_summary(pre_values, post_values, EVALUATION_DIR / "metrics_summary.png")
    print(f"Metriken-Zusammenfassung gespeichert als: {EVALUATION_DIR / 'metrics_summary.png'}")

if __name__ == "__main__":
    evaluate_model()

Lade Konfiguration von: /dss/dsshome1/08/di97ren/04-geo-oma24/xView2SiameseUNet/notebooks/00_config.yaml
Evaluationsergebnisse werden gespeichert in: /dss/dsshome1/08/di97ren/xView2_all_data/evaluation/001
Verwende Device: cpu




Lade Modell von: /dss/dsshome1/08/di97ren/xView2_all_data/checkpoints/001_best_siamese_unet_state.pth
Starte Evaluation...


  0%|          | 1/234 [02:03<7:57:43, 123.02s/it]


KeyboardInterrupt: 

In [None]:

import pandas as pd

# --- Paths --------------------------------------------------------------------
# TEST_PNG_IMAGES / TEST_TARGET, RESULTS_DIR and VISUALIZATION_DIR are used in this
# cell and by the functions below, but were never defined at notebook level (the test
# paths only existed as locals inside an earlier evaluate_model()).
data_cfg = config["data"]
USER_HOME_PATH = Path(f"/dss/dsshome1/08/{data_cfg['user']}")
EXPERIMENT_GROUP = data_cfg["experiment_group"]
EXPERIMENT_ID = data_cfg["experiment_id"]
CHECKPOINTS_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "checkpoints"
RESULTS_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "evaluation" / EXPERIMENT_ID
VISUALIZATION_DIR = RESULTS_DIR / "visualizations"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
VISUALIZATION_DIR.mkdir(parents=True, exist_ok=True)

DATA_ROOT, TEST_ROOT, TEST_IMG, TEST_LABEL, TEST_TARGET, TEST_PNG_IMAGES = get_data_folder(
    data_cfg["test_name"],
    main_dataset=data_cfg["use_main_dataset"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Class weights for the loss (same inverse-frequency rule as in training) ------
class_counts_path = os.path.join(base_dir, "precalculations", "class_counts.json")
with open(class_counts_path, 'r') as f:
    class_counts = json.load(f)
# The other evaluation cell reads the keys "pre"/"post"; accept either naming.
pre_counts = class_counts["pre_counts" if "pre_counts" in class_counts else "pre"]
post_counts = class_counts["post_counts" if "post_counts" in class_counts else "post"]


def _inverse_frequency(counts):
    """Map class id -> 1 / (count / total); ids/counts are cast to int, zero counts get 1.0."""
    counts = {int(cls): int(n) for cls, n in counts.items()}
    total = sum(counts.values())
    return {cls: (1.0 / (n / total) if n > 0 else 1.0) for cls, n in counts.items()}


pre_weights = _inverse_frequency(pre_counts)
post_weights = _inverse_frequency(post_counts)

# Dense weight tensors; the fallbacks apply only to classes missing from the JSON.
class_weights_pre = torch.tensor([
    pre_weights.get(0, 1.0),
    pre_weights.get(1, 10.0)
], device=device)

class_weights_post = torch.tensor([
    post_weights.get(0, 1.0),
    post_weights.get(1, 10.0),
    post_weights.get(2, 30.0),
    post_weights.get(3, 20.0),
    post_weights.get(4, 50.0),
    post_weights.get(5, 100.0)
], device=device)

focal_loss_pre = FocalLoss(alpha=class_weights_pre, gamma=config["focal_loss"]["gamma"])
focal_loss_post = FocalLoss(alpha=class_weights_post, gamma=config["focal_loss"]["gamma"])

# --- Data ------------------------------------------------------------------------
test_dataset = xView2Dataset(
    png_path=TEST_PNG_IMAGES,
    target_path=TEST_TARGET,
    transform=transform(),
    image_transform=image_transform()
)

# `device` is a torch.device, so compare its .type (comparing to the str "cuda" is
# not a reliable check).
if device.type == "cuda":
    num_workers = torch.cuda.device_count() * config["dataloader"]["num_workers_multiplier"]
else:
    num_workers = 4

test_dataloader = DataLoader(
    test_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn,
    pin_memory=config["dataloader"]["pin_memory"]
)

# --- Model -------------------------------------------------------------------------
model = SiameseUnet(num_pre_classes=2, num_post_classes=6)
checkpoint_path = CHECKPOINTS_DIR / f"{EXPERIMENT_ID}_best_siamese_unet_state.pth"
print(f"Loading model from: {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location=device)
# Load into the bare model before any DataParallel wrapping and strip a "module."
# prefix, so the same checkpoint loads on one or several GPUs.
state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = torch.nn.DataParallel(model)
model = model.to(device)
model.eval()

# --- Metrics -----------------------------------------------------------------------
# average=None -> compute() returns one value per class, which the per-class tables
# rely on (the default "macro" would return a single scalar).
precision_pre = MulticlassPrecision(num_classes=2, average=None).to(device)
recall_pre = MulticlassRecall(num_classes=2, average=None).to(device)
f1_pre = MulticlassF1Score(num_classes=2, average=None).to(device)
iou_pre = MulticlassJaccardIndex(num_classes=2, average=None).to(device)

precision_post = MulticlassPrecision(num_classes=6, average=None).to(device)
recall_post = MulticlassRecall(num_classes=6, average=None).to(device)
f1_post = MulticlassF1Score(num_classes=6, average=None).to(device)
iou_post = MulticlassJaccardIndex(num_classes=6, average=None).to(device)

# Class names for readable reports.
# NOTE(review): verify these names against the label encoding of the dataset.
pre_class_names = ["No Building", "Building"]
post_class_names = ["No Building", "Minor Damage", "Major Damage", "Destroyed", "Flooded", "Other Damage"]

def _summarize_head(class_names, precision, recall, f1, iou):
    """Package per-class metric tensors and their unweighted means for reporting."""
    return {
        'class_names': class_names,
        'precision': precision.cpu().numpy(),
        'recall': recall.cpu().numpy(),
        'f1_score': f1.cpu().numpy(),
        'iou': iou.cpu().numpy(),
        'avg_precision': precision.mean().item(),
        'avg_recall': recall.mean().item(),
        'avg_f1': f1.mean().item(),
        'avg_iou': iou.mean().item(),
    }


def evaluate_model():
    """Evaluate the notebook-level `model` on `test_dataloader`.

    Uses the notebook-level loss functions and torchmetrics objects defined above.

    Returns:
        (metrics, sample_images, (pre_preds, pre_targets, post_preds, post_targets)):
        `metrics` holds the test loss and per-head summaries, `sample_images` up to two
        samples from each of the first five batches, and the tuple the concatenated
        argmax predictions and integer targets as NumPy arrays.

    Raises:
        ValueError: if `test_dataloader` yields no batches.
    """
    num_batches = len(test_dataloader)
    if num_batches == 0:
        raise ValueError("test_dataloader is empty - nothing to evaluate")

    pre_head = (precision_pre, recall_pre, f1_pre, iou_pre)
    post_head = (precision_post, recall_post, f1_post, iou_post)
    # The metric objects live at notebook level; without a reset, re-running this
    # function would accumulate statistics from earlier runs.
    for metric in pre_head + post_head:
        metric.reset()

    model.eval()
    test_loss = 0.0
    all_pre_preds, all_pre_targets = [], []
    all_post_preds, all_post_targets = [], []
    sample_images = []

    print("Starting evaluation...")
    with torch.no_grad():
        for i, (pre_imgs, post_imgs, pre_masks, post_masks) in enumerate(tqdm(test_dataloader)):
            X_pre = pre_imgs.to(device).float()
            X_post = post_imgs.to(device).float()
            # (B, 1, H, W) masks -> (B, H, W) integer targets.
            y_pre_metric = pre_masks.to(device).squeeze(1).long()
            y_post_metric = post_masks.to(device).squeeze(1).long()

            pred = model(X_pre, X_post)
            loss = combined_loss_function(pred, y_pre_metric, y_post_metric, focal_loss_pre, focal_loss_post)
            test_loss += loss.item()

            # Channels 0-1 are the pre head, the remaining channels the post head.
            pre_logits = pred[:, :2]
            post_logits = pred[:, 2:]
            pre_pred = torch.argmax(pre_logits, dim=1)
            post_pred = torch.argmax(post_logits, dim=1)

            all_pre_preds.append(pre_pred.cpu())
            all_pre_targets.append(y_pre_metric.cpu())
            all_post_preds.append(post_pred.cpu())
            all_post_targets.append(y_post_metric.cpu())

            for metric in pre_head:
                metric.update(pre_logits, y_pre_metric)
            for metric in post_head:
                metric.update(post_logits, y_post_metric)

            # Keep up to two samples from each of the first five batches for plots.
            if i < 5:
                for j in range(min(2, len(pre_imgs))):
                    sample_images.append({
                        'pre_img': pre_imgs[j].cpu().numpy(),
                        'post_img': post_imgs[j].cpu().numpy(),
                        'pre_mask': pre_masks[j].cpu().numpy(),
                        'post_mask': post_masks[j].cpu().numpy(),
                        'pre_pred': pre_pred[j].cpu().numpy(),
                        'post_pred': post_pred[j].cpu().numpy()
                    })

    metrics = {
        'test_loss': test_loss / num_batches,
        'pre_disaster': _summarize_head(pre_class_names, *(m.compute() for m in pre_head)),
        'post_disaster': _summarize_head(post_class_names, *(m.compute() for m in post_head)),
    }

    # Tensors were already moved to the CPU per batch.
    prediction_data = tuple(
        torch.cat(chunks).numpy()
        for chunks in (all_pre_preds, all_pre_targets, all_post_preds, all_post_targets)
    )
    return metrics, sample_images, prediction_data

def save_metrics(metrics):
    """Write summary metrics and per-class tables into RESULTS_DIR.

    Produces `metrics_summary.txt`, `pre_disaster_metrics.csv` and
    `post_disaster_metrics.csv` from the dict returned by `evaluate_model()`.
    """
    import pandas as pd  # pandas is not imported anywhere else in this notebook

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    heads = (("Pre-Disaster", "pre_disaster"), ("Post-Disaster", "post_disaster"))

    # Overall metrics: one text section per head, separated by a blank line.
    sections = []
    for title, key in heads:
        head = metrics[key]
        sections.append(
            f"{title} Metrics:\n"
            f"Average Precision: {head['avg_precision']:.4f}\n"
            f"Average Recall: {head['avg_recall']:.4f}\n"
            f"Average F1-Score: {head['avg_f1']:.4f}\n"
            f"Average IoU: {head['avg_iou']:.4f}\n"
        )
    with open(RESULTS_DIR / "metrics_summary.txt", "w") as f:
        f.write(f"Test Loss: {metrics['test_loss']:.4f}\n\n")
        f.write("\n".join(sections))

    # Class-specific metrics, one CSV per head.
    for _, key in heads:
        head = metrics[key]
        pd.DataFrame({
            'Class': head['class_names'],
            'Precision': head['precision'],
            'Recall': head['recall'],
            'F1-Score': head['f1_score'],
            'IoU': head['iou']
        }).to_csv(RESULTS_DIR / f"{key}_metrics.csv", index=False)

    print(f"Metrics saved to {RESULTS_DIR}")

def _confusion_counts(targets, preds, num_classes):
    """Return a num_classes x num_classes count matrix (rows = true, cols = predicted).

    Only pairs where both ids lie in [0, num_classes) are counted, which mirrors
    sklearn's `confusion_matrix(..., labels=range(num_classes))`. The matrix therefore
    always matches the class-name tick labels, even when a class never occurs.
    """
    t = np.asarray(targets).ravel().astype(np.int64)
    p = np.asarray(preds).ravel().astype(np.int64)
    keep = (t >= 0) & (t < num_classes) & (p >= 0) & (p < num_classes)
    flat = t[keep] * num_classes + p[keep]
    return np.bincount(flat, minlength=num_classes * num_classes).reshape(num_classes, num_classes)


def _draw_confusion_matrix(cm, class_names, title, figsize, out_path):
    """Draw an annotated heatmap of `cm` with matplotlib and save it to `out_path`."""
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(cm, cmap='Blues')
    fig.colorbar(image, ax=ax)
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    # White text on dark cells, black on light ones.
    threshold = cm.max() / 2 if cm.size else 0
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            value = int(cm[row, col])
            ax.text(col, row, f"{value:d}", ha='center', va='center',
                    color='white' if value > threshold else 'black')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path)


def plot_confusion_matrices(prediction_data):
    """Plot and save the pre- and post-disaster confusion matrices into RESULTS_DIR.

    Uses matplotlib only (seaborn is never imported in this notebook). Figures are
    left open so they also render inline.
    """
    pre_preds, pre_targets, post_preds, post_targets = prediction_data

    _draw_confusion_matrix(
        _confusion_counts(pre_targets, pre_preds, len(pre_class_names)),
        pre_class_names, 'Pre-Disaster Confusion Matrix', (10, 8),
        RESULTS_DIR / "pre_disaster_confusion_matrix.png",
    )
    _draw_confusion_matrix(
        _confusion_counts(post_targets, post_preds, len(post_class_names)),
        post_class_names, 'Post-Disaster Confusion Matrix', (12, 10),
        RESULTS_DIR / "post_disaster_confusion_matrix.png",
    )

    print(f"Confusion matrices saved to {RESULTS_DIR}")

def visualize_samples(sample_images):
    """Save one 2x3 panel figure per sample into VISUALIZATION_DIR.

    Row 1 shows the pre-disaster image, ground truth and prediction; row 2 shows the
    post-disaster equivalents. Images are assumed to be CHW arrays (they are transposed
    to HWC for display).
    NOTE(review): if the image transform normalizes, imshow will clip the values.
    """
    # The folder is never created elsewhere, so savefig would fail without this.
    VISUALIZATION_DIR.mkdir(parents=True, exist_ok=True)

    pre_style = dict(cmap='tab10', vmin=0, vmax=1)    # 2 pre classes
    post_style = dict(cmap='tab10', vmin=0, vmax=5)   # 6 post classes

    for i, sample in enumerate(sample_images):
        fig, axs = plt.subplots(2, 3, figsize=(18, 12))
        panels = [
            (axs[0, 0], np.transpose(sample['pre_img'], (1, 2, 0)), 'Pre-Disaster Image', {}),
            (axs[0, 1], sample['pre_mask'].squeeze(), 'Pre-Disaster Ground Truth', pre_style),
            (axs[0, 2], sample['pre_pred'], 'Pre-Disaster Prediction', pre_style),
            (axs[1, 0], np.transpose(sample['post_img'], (1, 2, 0)), 'Post-Disaster Image', {}),
            (axs[1, 1], sample['post_mask'].squeeze(), 'Post-Disaster Ground Truth', post_style),
            (axs[1, 2], sample['post_pred'], 'Post-Disaster Prediction', post_style),
        ]
        for ax, picture, title, style in panels:
            ax.imshow(picture, **style)
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(VISUALIZATION_DIR / f"sample_{i+1}.png")
        plt.close(fig)  # many figures in a loop: free each one after saving

    print(f"Sample visualizations saved to {VISUALIZATION_DIR}")

def _plot_class_distribution(targets, preds, class_names, title, figsize, out_path, rotation=0):
    """Grouped bar chart of ground-truth vs. predicted pixel counts per class, saved to `out_path`.

    `minlength` is derived from `class_names`, so the bars always line up with the tick labels.
    """
    n_classes = len(class_names)
    target_counts = np.bincount(np.ravel(targets), minlength=n_classes)
    pred_counts = np.bincount(np.ravel(preds), minlength=n_classes)

    x = np.arange(n_classes)
    width = 0.35
    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(x - width/2, target_counts, width, label='Ground Truth')
    ax.bar(x + width/2, pred_counts, width, label='Predicted')
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=rotation)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)


def create_class_distribution_plots(prediction_data):
    """Create and save pre/post class-distribution plots into RESULTS_DIR.

    Figures are left open so they also render inline in the notebook.
    """
    pre_preds, pre_targets, post_preds, post_targets = prediction_data

    _plot_class_distribution(
        pre_targets, pre_preds, pre_class_names, 'Pre-Disaster Class Distribution',
        (10, 6), RESULTS_DIR / "pre_disaster_class_distribution.png",
    )
    _plot_class_distribution(
        post_targets, post_preds, post_class_names, 'Post-Disaster Class Distribution',
        (12, 6), RESULTS_DIR / "post_disaster_class_distribution.png", rotation=45,
    )

    print(f"Class distribution plots saved to {RESULTS_DIR}")

def main():
    """Evaluation pipeline: compute metrics, print a summary, then save files and figures."""
    print("Starting model evaluation...")

    metrics, sample_images, prediction_data = evaluate_model()

    print("\nEvaluation Results:")
    print(f"Test Loss: {metrics['test_loss']:.4f}")
    section_headers = (
        ("\nPre-Disaster Metrics:", "pre_disaster"),
        ("\nPost-Disaster Metrics:", "post_disaster"),
    )
    for header, key in section_headers:
        summary = metrics[key]
        print(header)
        print(f"Average Precision: {summary['avg_precision']:.4f}")
        print(f"Average Recall: {summary['avg_recall']:.4f}")
        print(f"Average F1-Score: {summary['avg_f1']:.4f}")
        print(f"Average IoU: {summary['avg_iou']:.4f}")

    # Numbers first, then the figures (in the same order as before).
    save_metrics(metrics)
    renderers = (
        (plot_confusion_matrices, prediction_data),
        (visualize_samples, sample_images),
        (create_class_distribution_plots, prediction_data),
    )
    for render, payload in renderers:
        render(payload)

    print("\nEvaluation completed successfully!")

if __name__ == "__main__":
    main()