## 0. Imports and configuration

In [None]:
#!pip install hydra-core
import torch
import matplotlib.pyplot as plt
import numpy as np

from hydra.core.global_hydra import GlobalHydra
from hydra import compose, initialize

import sys

sys.path.append("../")

from DCNN.datasets.base_dataset import BaseDataset
from DCNN.model import DCNN
from DCNN.trainer import DCNNLightniningModule

import os

# Reset any Hydra state left over from a previous run of this cell so that
# initialize() can be called again, then compose the primary config named
# "config" from the ../config directory.
GlobalHydra.instance().clear()
initialize(config_path="../config")
config = compose("config")

# Input locations. The defaults point at the original author's machine; set the
# matching environment variable to run the notebook elsewhere without editing
# this cell.
MODEL_CHECKPOINT_PATH = os.environ.get(
    "MODEL_CHECKPOINT_PATH",
    "./weights-epoch=19-validation_loss=-17.90.ckpt",
)
NOISY_DATASET_PATH = os.environ.get(
    "NOISY_DATASET_PATH",
    "/Users/vtokala/Documents/Research/di_nn/Dataset/noisy_testset_1f",
)
CLEAN_DATASET_PATH = os.environ.get(
    "CLEAN_DATASET_PATH",
    "/Users/vtokala/Documents/Research/di_nn/Dataset/clean_testset_1f",
)

## 1. Load model and dataset

In [None]:
# Build the evaluation dataset from the paired noisy/clean directories.
dataset = BaseDataset(NOISY_DATASET_PATH, CLEAN_DATASET_PATH)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,  # keep a fixed order so printed results are reproducible
    # Pinned host memory only speeds up host->GPU copies. This notebook maps the
    # checkpoint to CPU, so request pinning only when CUDA is actually present.
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
    num_workers=1,
)

# Expose an explicit (single-pass) iterator so batches can be pulled one at a
# time with next().
dataloader = iter(dataloader)

model = DCNNLightniningModule(config)
model.eval()
# Inference only: disable autograd globally for the rest of the session.
torch.set_grad_enabled(False)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(MODEL_CHECKPOINT_PATH, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint["state_dict"])


## 2. Evaluate the model on dataset samples

In [None]:
# Run the model on every batch, print the target vs. estimate and the
# Euclidean error for each sample, then print the mean error.
# NOTE(review): this assumes batch[1] is a dict holding "source_coordinates"
# and that the model output is directly comparable to it; the dataset is built
# from noisy/clean directories, so confirm this schema against BaseDataset.
errors = []
# no_grad keeps .numpy() safe even if autograd was not disabled globally.
with torch.no_grad():
    for batch in dataloader:
        model_output = model(batch[0])[0].numpy()
        true_coords = batch[1]["source_coordinates"][0].numpy()

        error = np.linalg.norm(true_coords - model_output)
        errors.append(error)

        print("True vs estimated coordinates:", true_coords, model_output)
        print("Error (meters):", error)
        print("\n")

if errors:
    print("Mean error (meters):", np.mean(errors))