# Inference
This notebook runs inference on the xView2 test data with the best saved checkpoint of an experiment and visualizes the resulting predictions.

In [3]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from utils.dataset import xView2Dataset, collate_fn_test, image_transform
from utils.helperfunctions import load_checkpoint, find_best_checkpoint, get_data_folder
from utils.inference_step import inference

## Resolve data paths and create the Dataset and DataLoader

In [4]:

# DataLoader settings for inference.
BATCH_SIZE = 32   # images per inference batch
NUM_WORKERS = 5   # DataLoader worker processes

# Resolve the folders of the "test" split; only TEST_PNG_IMAGES is used below.
# NOTE(review): the third return value is named VAL_IMG although the "test" split
# is requested -- confirm the naming against get_data_folder().
DATA_ROOT, TEST_ROOT, VAL_IMG, TEST_LABEL, TEST_TARGET, TEST_PNG_IMAGES = get_data_folder("test", main_dataset=False)

# inference=True is passed to the dataset; presumably it then yields images without
# labels -- confirm against xView2Dataset.
test_dataset = xView2Dataset(
    png_path=TEST_PNG_IMAGES,
    image_transform=image_transform(),
    inference=True,
)

# shuffle=False keeps batches in dataset order, so outputs follow the image order.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn_test,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# NOTE(review): absolute, user-specific path -- adjust USER / USER_HOME_PATH
# for your own environment.
USER = "di97ren"
USER_HOME_PATH = Path(f"/dss/dsshome1/08/{USER}")

# Experiment identity: selects whose checkpoints are loaded for inference.
EXPERIMENT_GROUP = "xView2_Subset"
EXPERIMENT_ID = "003"

# Directory for this experiment's TensorBoard logs.
EXPERIMENT_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "tensorboard_logs" / EXPERIMENT_ID
EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)
print(EXPERIMENT_DIR)

# Checkpoints are only read here (find_best_checkpoint below). Fail fast on a wrong
# path instead of silently creating an empty checkpoint directory.
CHECKPOINTS_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "checkpoints"
if not CHECKPOINTS_DIR.is_dir():
    raise FileNotFoundError(f"Checkpoint directory not found: {CHECKPOINTS_DIR}")

# Directory for this experiment's log files.
LOGFILES_DIR = USER_HOME_PATH / EXPERIMENT_GROUP / "logfiles" / EXPERIMENT_ID
LOGFILES_DIR.mkdir(parents=True, exist_ok=True)
print(f"Logfiles werden gespeichert in: {LOGFILES_DIR}")


In [None]:


# Select the compute device first -- it must exist before the model is moved onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model; presumably the class counts must match the checkpoint loaded below.
# NOTE(review): SiameseUnet is not imported anywhere in this notebook -- add its
# import (from its defining module) to the imports cell.
model = SiameseUnet(num_pre_classes=2, num_post_classes=6)

# Load the best checkpoint of this experiment.
best_checkpoint_path = find_best_checkpoint(CHECKPOINTS_DIR, EXPERIMENT_ID)
model = load_checkpoint(model, best_checkpoint_path)

# Move the loaded model to the device, so it ends up on `device` regardless of
# how load_checkpoint maps tensors.
model.to(device)
# Put Dropout/BatchNorm layers into evaluation mode; harmless if inference()
# also does this itself.
model.eval()

results = inference(model, test_dataloader)


## Plot the results

In [None]:
from utils.viz import vizualize_predictions

In [None]:
# Visualize the predictions returned by inference() in the cell above.
# The spelling "vizualize" matches the function name defined in utils.viz.
vizualize_predictions(results)