In [None]:
# Change the working directory two levels up — presumably to the repository root,
# since later cells import `models.*` / `datasets.*` and write into `notebooks/`.
# NOTE: this is not idempotent; re-running the cell moves up two more levels.
%cd ../..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from PIL import Image
import os
from models.CSRNet.CSRNet import CSRNet

import importlib
import time

In [None]:
# Path to the trained model checkpoint (something like XYZ.pth). Exactly one of the
# assignments below is active; the commented-out ones are alternative checkpoints.
# trained_model_path = 'D:\\OneDrive\\OneDrive - UvA\\ThesisData\\trained_models\\CSRNet TL SHTA\\save_state_ep_121_new_best_MAE_74.482.pth'  # The path to trained model file (something like XYZ.pth)
trained_model_path = 'D:\\OneDrive\\OneDrive - UvA\\ThesisData\\trained_models\\CSRNet TL SHTA new\\save_state_ep_140_new_best_MAE_84.146.pth'  # Checkpoint evaluated by this notebook
# trained_model_path = 'D:\\Downloads\\PartAmodel_best.pth.tar'  # The path to trained model file (something like XYZ.pth)


label_factor = 100  # The label factor used to train this specific model; predicted density sums are divided by it to obtain counts.
dataset = 'SHTB_CSRNet'  # Must be the exact name of the dataset (it is used to import `datasets.standard.<dataset>`)
save_results = False  # When True, save the images, GTs and predictions. A folder for this is created automatically.
set_to_eval = 'test'  # 'val' or 'test': which split to evaluate the model on. 'train' does not work!

In [None]:
model = CSRNet()

# Deserialize the checkpoint onto the CPU first so that it can be read no matter
# which device it was saved from; the model itself is moved to the GPU below.
resume_state = torch.load(trained_model_path, map_location='cpu')

# Alternative for checkpoints that are a bare state dict whose keys carry a
# 4-character prefix that has to be stripped before loading:
# new_dict = {}
# for k, v in resume_state.items():
#     k = k[4:]
#     new_dict[k] = v
# model.load_state_dict(new_dict)

# The checkpoint stores the network weights under the 'net' key.
model.load_state_dict(resume_state['net'])
# Alternative for checkpoints that store the weights under 'state_dict':
# model.load_state_dict(resume_state['state_dict'])


model.cuda()
# Inference mode (affects layers such as dropout / batch norm, if present).
model.eval()

In [None]:
# The dataset package is resolved by name, so `dataset` must match a sub-package of
# `datasets.standard` that provides `loading_data.loading_data` and `settings.cfg_data`.
dataloader = importlib.import_module(f'datasets.standard.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.standard.{dataset}.settings').cfg_data

train_loader, val_loader, test_loader, restore_transform = dataloader()
if set_to_eval in ('val', 'eval'):
    my_dataloader = val_loader
elif set_to_eval == 'test':
    my_dataloader = test_loader
else:
    # Fail loudly: merely printing would leave `my_dataloader` undefined (or, worse,
    # still bound to the loader selected in a previous run of this cell).
    raise ValueError(f'Error: invalid set --> {set_to_eval}')

In [None]:
# `save_path` stays None when results are not saved; later cells use its truthiness
# to decide whether per-image figures are written.
save_path = None
if save_results:
    save_folder = 'CSRNet' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here if you want to save somewhere else
    # makedirs also creates a missing 'notebooks' parent, and exist_ok keeps the cell
    # re-runnable within the same minute (the folder name has minute resolution).
    os.makedirs(save_path, exist_ok=True)

In [None]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    """Save a three-panel figure: input image, ground-truth map and predicted map.

    The file is written to `save_path` as `IMG_<img_idx>_AE_<abs error>.jpg`, where
    the absolute error is `abs(pred_cnt - gt_cnt)` formatted with three decimals.

    Args:
        save_path: Existing directory in which the figure is stored.
        img: Image shown in the left panel (anything `imshow` accepts).
        img_idx: Index of the image; only used in the file name.
        gt: Ground-truth density map, shown in the middle panel.
        prediction: Predicted density map, shown in the right panel.
        pred_cnt: Predicted count, shown in the right panel's title.
        gt_cnt: Ground-truth count, shown in the middle panel's title.
    """
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')

    # `subplots` already creates the figure; no separate `plt.figure()` is needed.
    fig, axarr = plt.subplots(1, 3, figsize=(13, 13))
    try:
        axarr[0].imshow(img)
        axarr[1].imshow(gt, cmap=cm.jet)
        axarr[1].set_title(f'GT count: {gt_cnt:.3f}')
        axarr[2].imshow(prediction, cmap=cm.jet)
        axarr[2].set_title(f'predicted count: {pred_cnt:.3f}')
        fig.tight_layout()
        fig.savefig(img_save_path)
    finally:
        # Close only the figure created here (even when saving fails), so figures
        # do not pile up over a whole dataset and unrelated figures stay open.
        plt.close(fig)

In [None]:
def eval_model(model, my_dataloader, show_predictions, restore_transform, label_factor, cfg_data):
    """Evaluate a density-map counting model on every sample of a dataloader.

    For each `(img, gt)` pair the predicted count is the sum of the model output
    divided by `label_factor`, and the ground-truth count is the sum of `gt`
    divided by `cfg_data.LABEL_FACTOR`. A line with both counts and the absolute
    and relative error is printed per image.

    Args:
        model: Network mapping an image batch to a density map. Inputs are moved
            to the device of its first parameter (CUDA if it has no parameters).
        my_dataloader: Iterable yielding `(img, gt)` tensor pairs. The code
            squeezes all singleton dimensions, so a batch size of 1 is assumed.
        show_predictions: Directory in which a figure is saved for every image,
            or None / empty to save nothing.
        restore_transform: Callable turning the (squeezed) network input back into
            a displayable image.
        label_factor: Scaling factor the model's density maps were trained with.
        cfg_data: Dataset settings object; only `LABEL_FACTOR` is read.

    Returns:
        Tuple `(MAE, MSE, GTs, preds)`: the mean absolute error, the square root of
        the mean squared error (named MSE, but it is a root mean squared error),
        and the per-image ground-truth and predicted counts as lists of floats.
    """
    # Follow the model's device instead of hard-coding CUDA, so that the inputs
    # always end up where the weights are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    AEs = []  # Absolute Errors
    SEs = []  # Squared Errors
    GTs = []
    preds = []

    with torch.no_grad():
        for idx, (img, gt) in enumerate(my_dataloader):
            img = img.to(device)

            den = model(img)  # Predicted density map
            den = den.cpu()

            gt = gt.squeeze()  # Remove singleton (batch/channel) dims
            den = den.squeeze()  # Remove singleton (batch/channel) dims

            img = restore_transform(img.squeeze())  # Original image
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())
            GTs.append(gt_cnt.item())
            preds.append(pred_cnt.item())
            # Tensor division: a ground-truth count of zero gives inf/nan, not an exception.
            relative_error = AEs[-1] / gt_cnt * 100
            print(f'IMG {idx:<3} '
                  f'Prediction: {pred_cnt:<9.3f} '
                  f'GT: {gt_cnt:<9.3f} '
                  f'Absolute Error: {AEs[-1]:<9.3f} '
                  f'Relative Error: {relative_error:.1f}%')

            # Use the argument rather than a notebook-level global, so the function
            # behaves the same regardless of the surrounding kernel state.
            if show_predictions:
                plot_and_save_results(show_predictions, img, idx, gt, den, pred_cnt, gt_cnt)

    MAE = np.mean(AEs)
    MSE = np.sqrt(np.mean(SEs))

    return MAE, MSE, GTs, preds

In [None]:
# Evaluate the chosen split. `save_path` is handed over as `show_predictions`
# (it is None when results are not being saved).
MAE, MSE, GTs, preds = eval_model(
    model=model,
    my_dataloader=my_dataloader,
    show_predictions=save_path,
    restore_transform=restore_transform,
    label_factor=label_factor,
    cfg_data=cfg_data,
)
print(f'MAE/MSE: {MAE:.3f}/{MSE:.3f}')

In [None]:
# Plot ground-truth and predicted counts with the images ordered by ascending
# ground-truth count, and save the figure in the current working directory.
GTs = np.array(GTs)
preds = np.array(preds)
sorted_idxs = np.argsort(GTs)
img_nrs = np.arange(len(GTs))

plt.rcParams.update({'font.size': 14})

fig, ax = plt.subplots()
ax.plot(img_nrs, GTs[sorted_idxs], label='Ground truths')
ax.plot(img_nrs, preds[sorted_idxs], label='Predictions')
ax.set_ylabel('Crowd count')
ax.set_xlabel('Sorted image')
ax.legend(loc='upper left', frameon=False)
fig.tight_layout()
fig.savefig(f'CSRNet_{dataset}_pred_vs_gt.jpg')
plt.show()


In [None]:
# Image indices ordered from the largest to the smallest absolute counting error.
sorted_error_idxs = np.argsort(np.abs(GTs - preds))[::-1]

# Re-run the model on the ten worst-predicted images and save, for each of them,
# the input image and the predicted density map as separate figures.
with torch.no_grad():
    for idx in sorted_error_idxs[:10]:
        img, gt = my_dataloader.dataset[idx]
        img = img.unsqueeze(0).cuda()  # Add a batch dim and move to the GPU
        gt = gt.unsqueeze(0)

        den = model(img).cpu()  # Predicted density map

        # Drop all singleton dims
        gt = gt.squeeze()
        den = den.squeeze()

        img = restore_transform(img.squeeze())  # Original image
        pred_cnt = den.sum() / label_factor
        gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

        print(f'IMG {idx}, pred: {pred_cnt:.3f}, gt: {gt_cnt:.3f}. Error: {pred_cnt - gt_cnt:.3f}')

        # Input image, titled with the ground-truth count.
        fig, ax = plt.subplots()
        ax.imshow(np.asarray(img))
        ax.set_title(f'GT count: {gt_cnt:.3f}')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(f'CSRNet_IMG_{idx + 1}_{dataset}.jpg')
        plt.show()

        # Predicted density map, titled with the predicted count.
        fig, ax = plt.subplots()
        ax.imshow(den.numpy(), cmap=cm.jet)
        ax.set_title(f'Predicted count: {pred_cnt:.3f}')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(f'CSRNet_IMG_{idx + 1}_{dataset}_prediction.jpg')
        plt.show()

In [None]:
# Mean absolute error over the (up to) 300 images with the smallest absolute error;
# `sorted_error_idxs` is in descending error order, so these are its last entries.
np.abs(GTs - preds)[sorted_error_idxs[-300:]].mean()