In [None]:
# Change the working directory two levels up (presumably the repository root) so that the
# project packages imported below (`models`, `datasets`) can be resolved.
# NOTE(review): this is not idempotent -- re-running the cell moves up another two levels;
# restart the kernel instead of re-running it.
%cd ../..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from itertools import product
from collections import OrderedDict


from PIL import Image
import os
import models.ViCCT.ViCCTModels  # Need to register the models!
import models.ViCCT.ViCCTModelsFunctional  # Need to register the models!

from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

In [None]:
# --- Evaluation configuration ---------------------------------------------------------
model_name = 'ViCCT_small'  # Name of a model registered with timm (registration happens via the ViCCT imports above).
# Path to the trained checkpoint (.pth). Set the VICCT_MODEL_PATH environment variable to
# override this machine-specific default instead of editing the notebook.
trained_model_path = os.environ.get(
    'VICCT_MODEL_PATH', 'D:\\Bureaublad\\save_state_ep_250_new_best_MAE_4.971.pth')
label_factor = 3000  # The label factor used to train this specific model.
dataset = 'WE_ViCCT_Meta'  # Exact name of the dataset package under datasets/meta/.
save_results = False  # When True, create a timestamped output folder and save results into it.
set_to_eval = 'test'  # 'val' or 'test': which split to evaluate on. 'train' is not supported.

In [None]:
# Mean-squared-error criterion (explicit 'mean' reduction, which is also PyTorch's default).
# NOTE(review): appears unused by the evaluation cells below -- kept for interactive use.
loss_fn = torch.nn.MSELoss(reduction='mean')

In [None]:
# Each dataset package under datasets/meta/<dataset>/ exposes a `loading_data` factory and a
# `cfg_data` settings object.
dataloader = importlib.import_module(f'datasets.meta.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.meta.{dataset}.settings').cfg_data

# The factory returns (train, val, test) loader collections plus the inverse image transform.
# The evaluation loop below iterates the selected collection one scene at a time.
train_loaders, val_loaders, test_loaders, restore_transform = dataloader()

# Fail loudly on an unsupported split. Merely printing would leave `my_dataloaders` undefined,
# or silently stale from an earlier run of this cell.
if set_to_eval == 'val':
    my_dataloaders = val_loaders
elif set_to_eval == 'test':
    my_dataloaders = test_loaders
else:
    raise ValueError(f"Invalid set_to_eval {set_to_eval!r}; expected 'val' or 'test'.")

In [None]:
# Output folder for results; stays None (nothing is saved) unless `save_results` is True.
save_dir = None
if save_results:
    timestamp = time.strftime("%m-%d_%H-%M", time.localtime())
    save_folder = f'DeiT_meta_{dataset}_{set_to_eval}_{timestamp}'
    # Relative to the current working directory. Change this if you want to save somewhere else.
    save_dir = os.path.join('notebooks', save_folder)
    # makedirs also creates 'notebooks/' if it is missing, and exist_ok tolerates re-running the
    # cell within the same minute (which yields the same timestamped folder name).
    os.makedirs(save_dir, exist_ok=True)

In [None]:
# Build the architecture. Weights come from the checkpoint loaded below.
model = create_model(
    model_name,
    init_path=None,
    num_classes=1000,  # Not used yet, but must match the pretrained model.
    drop_rate=0.,
    drop_path_rate=0.,
    drop_block_rate=None,
)
model.cuda()

# Load the checkpoint onto the CPU first. load_state_dict copies the tensors into the model's
# (CUDA) parameters, so this works regardless of the device the checkpoint was saved from.
resume_state = torch.load(trained_model_path, map_location='cpu')
model.load_state_dict(resume_state['net'])

# The trailing ';' suppresses the (very long) module repr as cell output.
model.eval();


In [None]:
def save_scene_graph(preds, gts, save_name, save_dir=None):
    """Plot per-image predicted vs. ground-truth counts for one scene, and optionally save it.

    The figure is titled with the scene's mean absolute error and shown inline. If a save
    directory is available, the figure is also written to ``os.path.join(save_dir, save_name)``.

    Args:
        preds: Sequence of predicted counts (floats or 0-d tensors), one per image.
        gts: Sequence of ground-truth counts, in the same order as ``preds``.
        save_name: File name for the saved figure, e.g. ``'scene_1.jpg'``.
        save_dir: Directory to save into. If omitted, the notebook-level ``save_dir`` is used.
            That global is ``None`` unless ``save_results`` is enabled, and when no directory
            is available nothing is saved.
    """
    if save_dir is None:
        save_dir = globals().get('save_dir')

    # float() handles both Python numbers and 0-d tensors.
    pred_counts = np.array([float(p) for p in preds])
    gt_counts = np.array([float(g) for g in gts])
    mae = np.mean(np.abs(pred_counts - gt_counts)) if len(gt_counts) else float('nan')

    xs = np.arange(len(gt_counts))
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.set_title(f'MAE: {mae:.3f}')
    ax.plot(xs, gt_counts, color='green', label='GT')
    ax.plot(xs, pred_counts, color='blue', label='Predictions')
    ax.set_xlabel('Image index')
    ax.set_ylabel('Count')
    ax.legend()

    if save_dir is not None:
        # Save before showing: some backends clear the figure once it has been shown.
        fig.savefig(os.path.join(save_dir, save_name))
    plt.show()
    # Release the figure; this function is called once per scene inside a loop.
    plt.close(fig)

In [None]:
def eval_on_scene(model, scene_dataloader):
    """Evaluate ``model`` on every image of one scene; return per-image counts and errors.

    Each item from ``scene_dataloader`` is ``(img, img_patches, gt_patches)`` for a single image
    (batch size 1). The model predicts a density map per crop. Crops are stitched back into
    full-size maps with ``img_equal_unsplit`` and summed to obtain counts.

    Args:
        model: Density-map model that accepts a batch of image crops.
        scene_dataloader: Iterable over ``(img, img_patches, gt_patches)`` with batch size 1.

    Returns:
        ``(preds, gts, MAE, MSE)``:
        - ``preds`` and ``gts`` are per-image predicted and ground-truth counts (0-d tensors).
        - ``MAE`` is the mean absolute error.
        - ``MSE`` is the *root* mean squared error. This follows the crowd-counting convention
          of reporting RMSE under the name "MSE".
    """
    model.eval()
    # Send inputs to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device

    preds = []
    gts = []
    AEs = []  # Absolute errors, one per image
    SEs = []  # Squared errors, one per image

    with torch.no_grad():
        for img, img_patches, gt_patches in scene_dataloader:
            # Drop only the batch dim (size 1). A bare squeeze() would also drop the crop dim
            # when an image yields a single crop, breaking both the model input and the unsplit.
            img_patches = img_patches.squeeze(0).to(device)
            gt_patches = gt_patches.squeeze(0).unsqueeze(1)  # Insert channel dim
            img = img.squeeze(0)
            _, img_h, img_w = img.shape  # Needed to reconstruct the full-size maps

            pred_den = model(img_patches).cpu()  # Predicted density crops

            # Restore full-size GT and predicted density maps from their crops.
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            gt = gt.squeeze()  # Remove channel dim
            den = den.squeeze()  # Remove channel dim

            # The prediction is rescaled by the factor this model was trained with, and the GT
            # by the dataset's own label factor.
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

            preds.append(pred_cnt)
            gts.append(gt_cnt)
            err = pred_cnt - gt_cnt
            AEs.append(torch.abs(err).item())
            SEs.append(torch.square(err).item())

    MAE = np.mean(AEs)
    MSE = np.sqrt(np.mean(SEs))  # RMSE, see docstring
    return preds, gts, MAE, MSE

In [None]:
# Evaluate every scene of the selected split, plot each one, then report the unweighted mean
# of the per-scene metrics. Note that the per-scene "MSE" is an RMSE (see eval_on_scene).
MAEs, MSEs = [], []
for idx, scene_dataloader in enumerate(my_dataloaders):
    scene_number = idx + 1  # 1-based index for display and file names
    print(f'Scene {scene_number}')

    preds, gts, MAE, MSE = eval_on_scene(model, scene_dataloader)
    print(f'    MAE/MSE: {MAE:.3f}/{MSE:.3f}')
    MAEs.append(MAE)
    MSEs.append(MSE)

    save_scene_graph(preds, gts, f'scene_{scene_number}.jpg')

overal_MAE, overal_MSE = np.mean(MAEs), np.mean(MSEs)
print(f'avg MAE/MSE: {overal_MAE:.3f}/{overal_MSE:.3f}')
        