In [4]:
# Imports
import init
import torch
import os
from src.core.config import Conf
from src.statistics.statistics import TestSetValidation
from src.core.inference import RCNNInferenceHandler
from src.models.CNNPatternDetection import CNNPatternNetwork
from src.models.CNNRegionClassification import CNNRegionClassifier

# Root of the preprocessed dataset; a "valid" split directory is expected inside it.
# The path is relative, so it is resolved against the kernel's working directory.
DATA_ROOT = "../kaggle-processed"
if not os.path.isdir(DATA_ROOT):
    # Report the resolved absolute path: a wrong kernel cwd is the usual culprit.
    raise FileNotFoundError(
        f"Dataset directory not found: {os.path.abspath(DATA_ROOT)}"
    )

In [5]:
# Saved state dicts for the two networks used by the inference pipeline.
REGION_MODEL_PATH = "./checkpoints/best_model_region.h5"
PATTERN_MODEL_PATH = "./checkpoints/best_model_pattern.h5"

# Inference knobs forwarded to Conf; their exact semantics are defined in
# src.core.config (not visible here), so tune them with that module at hand.
INFERENCE_OVERLAP = 32
VOTE_FRACTION_NEEDED = 0.5
REGION_CLASSIFICATION_CONFIDENCE = 0.75

config = Conf(
    inference_overlap=INFERENCE_OVERLAP,
    vote_fraction_needed=VOTE_FRACTION_NEEDED,
    region_classification_confidence=REGION_CLASSIFICATION_CONFIDENCE,
)

# Prefer the GPU, but fall back to CPU so the notebook still runs (more slowly)
# on machines without CUDA instead of crashing at model construction.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:

def _load_model(model_cls, checkpoint_path):
    """Build ``model_cls(config)`` on ``device``, load its saved state dict, and set eval mode.

    Args:
        model_cls: network class constructed with the notebook's ``config``.
        checkpoint_path: path to a state dict saved with ``torch.save``.

    Returns:
        The model instance, on ``device``, with weights loaded and in eval mode.
    """
    model = model_cls(config).to(device)
    # map_location lets a checkpoint saved on another device (e.g. GPU) load on
    # the current one. torch.load unpickles the file: only use trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Switch any dropout / batch-norm layers (if the network has them) to
    # inference behaviour; validation must not run in training mode.
    model.eval()
    return model


region_model = _load_model(CNNRegionClassifier, REGION_MODEL_PATH)
pattern_model = _load_model(CNNPatternNetwork, PATTERN_MODEL_PATH)

inference = RCNNInferenceHandler(
    config,
    region_model,
    pattern_model,
    device
)

validation_helper = TestSetValidation(
    config,
    os.path.join(DATA_ROOT, "valid"),
    inference
)

# Validation needs no gradients; disabling autograd avoids building graphs
# and saves memory (harmless if the handler already does this internally).
with torch.no_grad():
    validation_metrics = validation_helper.run()

validation_metrics


2024-06-14 11:50:37.434 | INFO     | src.statistics.statistics:prep_data:177 - Preparing validation data
Preparing data: 100%|██████████| 644/644 [02:07<00:00,  5.05it/s]
Preparing statistics: 322it [00:00, 70400.64it/s]


{'TP': 8.507875409169289,
 'TN': 579.2517239338579,
 'FP': 10.375688066141509,
 'FN': 45.86471259083142,
 'avg_iou': 0.13139925764939675,
 'accuracy': 0.9126701853152595,
 'precision': 0.450543956933386,
 'recall': 0.15647361514535926,
 'specificity': 0.9824029754129859,
 'f1_score': 0.2322774330299505}