In [1]:
import gc 
from typing import Tuple, Dict, List
from pathlib import Path
import numpy as np
import os
import sys

from torch.utils.data import DataLoader  # Falls du mit einem DataLoader arbeitest

from utils.dataset import xView2Dataset, collate_fn_test, image_transform, transform

In [2]:
# Cluster login of the person running this notebook -- normally the only value to edit.
USER = "di97ren"

# Shared project storage and personal home directory on the cluster (leave these as they are).
ROOT = Path("/dss/dsstbyfs02/pn49ci/pn49ci-dss-0022")
USER_HOME_PATH = Path(f"/dss/dsshome1/08/{USER}")
DATA_PATH = ROOT / "data"

# Location of the xView2 subset; adjust DATASET_ROOT if your copy lives somewhere else.
DATASET_ROOT = DATA_PATH / "xview2-subset"
TEST_ROOT = DATASET_ROOT.joinpath("test")
TEST_IMG = TEST_ROOT.joinpath("png_images")

# Experiment identity: together with USER_HOME_PATH it determines where the
# TensorBoard logs and the checkpoints of this run are stored.
EXPERIMENT_GROUP = "xView2_Subset"
EXPERIMENT_ID = "003"

EXPERIMENT_DIR = USER_HOME_PATH.joinpath(EXPERIMENT_GROUP, "tensorboard_logs", EXPERIMENT_ID)
CHECKPOINTS_DIR = USER_HOME_PATH.joinpath(EXPERIMENT_GROUP, "checkpoints")

# Create both output directories up front (no-op when they already exist).
for _output_dir in (EXPERIMENT_DIR, CHECKPOINTS_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)
del _output_dir


In [3]:
# Test split of the xView2 subset, opened in inference mode.
test_dataset = xView2Dataset(
    png_path=TEST_IMG,
    image_transform=image_transform(),
    inference=True,
)

# Loader settings for inference; the batch may be larger than the one used for training.
TEST_BATCH_SIZE = 64
TEST_NUM_WORKERS = 5

test_dataloader = DataLoader(
    test_dataset,
    batch_size=TEST_BATCH_SIZE,
    collate_fn=collate_fn_test,
    shuffle=False,  # no shuffling at inference time: batches follow the dataset's index order
    num_workers=TEST_NUM_WORKERS,
)

In [8]:
from utils.helperfunctions import find_best_checkpoint, load_checkpoint
import torch
from model.siameseNetwork import SiameseUnet

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network: 2 output classes for the pre branch, 6 for the post branch.
model = SiameseUnet(num_pre_classes=2, num_post_classes=6)
model.to(device)

best_checkpoint_path = find_best_checkpoint(CHECKPOINTS_DIR, EXPERIMENT_ID)
# Fail fast with a clear message instead of an opaque error inside load_checkpoint.
if best_checkpoint_path is None:
    raise FileNotFoundError(
        f"No checkpoint for experiment {EXPERIMENT_ID!r} found in {CHECKPOINTS_DIR}"
    )

# Load the best checkpoint. load_checkpoint returns the model to use from here on
# (its log output reports loading into a DataParallel model), so rebind `model`.
model = load_checkpoint(model, best_checkpoint_path)

# Evaluation mode: disables dropout and makes batch-norm use its running statistics,
# which is required for correct, deterministic predictions. Trailing ';' hides the module repr.
model.eval();

Loaded raw state_dict from /dss/dsshome1/08/di97ren/xView2_Subset/checkpoints/003_best_siamese_unet_state.pth
Checkpoint erfolgreich in DataParallel-Modell geladen.


In [12]:
# Run inference on the test split and visualise a few predictions.
import torch
from utils.inference_step import inference
from utils.viz import visualize_predictions

# Fixed seed so the same samples are shown on every run; set to None for a fresh random pick.
# NOTE(review): assumes visualize_predictions uses random_seed to seed its sample selection -- confirm.
VIS_SEED = 42
VIS_NUM_SAMPLES = 5

# No backward pass is needed here; no_grad guarantees no autograd graph is kept for the
# (large) outputs. NOTE(review): inference() may already do this internally -- nesting is harmless.
with torch.no_grad():
    results = inference(model, test_dataloader)

visualize_predictions(results, num_samples=VIS_NUM_SAMPLES, random_seed=VIS_SEED)

Outputs shape: torch.Size([50, 8, 1024, 1024])
Pre-outputs stats: min=-8.2386, max=150.5160, mean=4.7182
Post-outputs stats: min=-181.8403, max=124.4734, mean=-1.9284
Bild 0:
  Pre-Klasse 0: min=-0.7193, max=35.4962, mean=7.2498
  Pre-Klasse 1: min=-8.2386, max=13.9006, mean=0.8933
  Post-Klasse 0: min=-3.5084, max=32.9576, mean=8.3953
  Post-Klasse 1: min=-67.8368, max=0.5039, mean=-13.5436
  Post-Klasse 2: min=-35.5794, max=2.2669, mean=-5.8434
  Post-Klasse 3: min=-35.2704, max=0.2501, mean=-8.1684
  Post-Klasse 4: min=-16.9713, max=8.7846, mean=-1.4143
  Post-Klasse 5: min=-0.0734, max=31.7061, mean=6.0811
Sample at position (100,100):
  Pre-logits: tensor([15.9075,  6.2959], device='cuda:0')
  Post-logits: tensor([ 12.2762, -19.4789,  -6.7663, -10.3962,  -7.4509,   9.0879],
       device='cuda:0')
  Pre-probs: tensor([9.9993e-01, 6.6944e-05], device='cuda:0')
  Post-probs: tensor([9.6039e-01, 1.5537e-14, 5.1572e-09, 1.3676e-10, 2.6007e-09, 3.9610e-02],
       device='cuda:0')
Pre-