In [1]:
%cd ..

C:\Users\Wight\PycharmProjects\ThesisMain


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from PIL import Image
import os
import models.DeiTModels  # Need to register the models!
from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

# Settings and Parameters
Here we define the DeiT model that we wish to evaluate and the corresponding parameters to evaluate it.

In [3]:
model_name = 'deit_small_distilled_patch16_224'  # Must be something like 'deit_small_distilled_patch16_224'.
trained_model_path = 'notebooks/save_state_ep_140_new_best_MAE_6.442.pth'  # The path to trained model file (something like XYZ.pth)
label_factor = 10000  # The label factor used to train this specific model.
dataset = 'SHTB_DeiT'  # Must be the exact name of the dataset
save_results = True  # When true, save the images, GTs and predictions. A folder for this is created automatically.

# Prepare for evaluation
Use the settings to load the DeiT model and dataloader for the test set. Also loads the transform with which we can restore the original images. Cuda is required!
If save_results is True, also create the directory in which the predictions are saved.

In [4]:
model = create_model(
        model_name,
        init_path=None,
        num_classes=1000,  # Not yet used anyway. Must match pretrained model!
        drop_rate=0.,
        drop_path_rate=0.1,  # TODO: What does this do?
        drop_block_rate=None,
    )

model.cuda()

resume_state = torch.load(trained_model_path)
model.load_state_dict(resume_state['net'])

model.eval()

DistilledRegressionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attenti

In [5]:
dataloader = importlib.import_module(f'datasets.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.{dataset}.settings').cfg_data

_, _, test_loader, restore_transform = dataloader(model.crop_size)

316 test images found.


In [6]:
save_path = None
if save_results:
    save_folder = 'DeiT' + '_' + dataset + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here is you want to save somewhere else
    os.mkdir(save_path)

# Evaluation loop and save funtion

In [7]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')
    
    plt.figure()
    f, axarr = plt.subplots(1, 3, figsize=(13, 13))
    axarr[0].imshow(img)
    axarr[1].imshow(gt, cmap=cm.jet)
    axarr[1].title.set_text(f'GT count: {gt_cnt:.3f}')
    axarr[2].imshow(prediction, cmap=cm.jet)
    axarr[2].title.set_text(f'predicted count: {pred_cnt:.3f}')
    plt.tight_layout()
    plt.savefig(img_save_path)
    plt.close('all')

In [8]:
def eval_model(model, test_loader, show_predictions, restore_transform, label_factor, cfg_data):
    with torch.no_grad():
        AEs = []  # Absolute Errors
        SEs = []  # Squared Errors

        for idx, (img, img_patches, gt_patches) in enumerate(test_loader):
            img_patches = img_patches.squeeze().cuda()
            gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
            img = img.squeeze()  # Remove batch dimension
            _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction
            
            img = restore_transform(img)

            pred_den = model(img_patches)  # Precicted density crops
            pred_den = pred_den.cpu()

            # Restore GT and Prediction
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            gt = gt.squeeze()  # Remove channel dim
            den = den.squeeze()  # Remove channel dim

            
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR
            
            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())
            relative_error = AEs[-1] / gt_cnt * 100
            print(f'IMG {idx:<3} '
                  f'Prediction: {pred_cnt:<9.3f} '
                  f'GT: {gt_cnt:<9.3f} '
                  f'Absolute Error: {AEs[-1]:<9.3f} '
                  f'Relative Error: {relative_error:.1f}%')
            
            if save_path:
                plot_and_save_results(save_path, img, idx, gt, den, pred_cnt, gt_cnt)
            
        MAE = np.mean(AEs)
        MSE = np.sqrt(np.mean(SEs))

    return MAE, MSE

In [9]:
MAE, MSE = eval_model(model, test_loader, save_path, restore_transform, label_factor, cfg_data)
print(f'MAE: {MAE:<9.3f} Root MSE: {MSE:.3f}')

IMG 0   Prediction: 25.830    GT: 22.743    Absolute Error: 3.088     Relative Error: 13.6%
IMG 1   Prediction: 163.276   GT: 175.638   Absolute Error: 12.362    Relative Error: 7.0%
IMG 2   Prediction: 126.008   GT: 157.493   Absolute Error: 31.485    Relative Error: 20.0%
IMG 3   Prediction: 36.522    GT: 36.219    Absolute Error: 0.302     Relative Error: 0.8%
IMG 4   Prediction: 67.306    GT: 68.812    Absolute Error: 1.506     Relative Error: 2.2%
IMG 5   Prediction: 59.498    GT: 56.467    Absolute Error: 3.031     Relative Error: 5.4%
IMG 6   Prediction: 44.254    GT: 43.457    Absolute Error: 0.796     Relative Error: 1.8%
IMG 7   Prediction: 219.760   GT: 221.371   Absolute Error: 1.611     Relative Error: 0.7%
IMG 8   Prediction: 179.915   GT: 163.917   Absolute Error: 15.998    Relative Error: 9.8%
IMG 9   Prediction: 485.413   GT: 471.458   Absolute Error: 13.955    Relative Error: 3.0%
IMG 10  Prediction: 138.967   GT: 137.556   Absolute Error: 1.412     Relative Error: 1.

IMG 90  Prediction: 105.931   GT: 99.229    Absolute Error: 6.701     Relative Error: 6.8%
IMG 91  Prediction: 61.236    GT: 59.737    Absolute Error: 1.499     Relative Error: 2.5%
IMG 92  Prediction: 239.140   GT: 245.455   Absolute Error: 6.314     Relative Error: 2.6%
IMG 93  Prediction: 178.922   GT: 185.273   Absolute Error: 6.351     Relative Error: 3.4%
IMG 94  Prediction: 188.254   GT: 182.204   Absolute Error: 6.049     Relative Error: 3.3%
IMG 95  Prediction: 107.120   GT: 104.747   Absolute Error: 2.373     Relative Error: 2.3%
IMG 96  Prediction: 101.125   GT: 102.004   Absolute Error: 0.879     Relative Error: 0.9%
IMG 97  Prediction: 157.237   GT: 158.964   Absolute Error: 1.726     Relative Error: 1.1%
IMG 98  Prediction: 55.012    GT: 55.964    Absolute Error: 0.951     Relative Error: 1.7%
IMG 99  Prediction: 43.992    GT: 38.688    Absolute Error: 5.304     Relative Error: 13.7%
IMG 100 Prediction: 281.537   GT: 276.026   Absolute Error: 5.511     Relative Error: 2.0

IMG 180 Prediction: 180.122   GT: 182.070   Absolute Error: 1.949     Relative Error: 1.1%
IMG 181 Prediction: 141.193   GT: 137.000   Absolute Error: 4.193     Relative Error: 3.1%
IMG 182 Prediction: 93.908    GT: 94.777    Absolute Error: 0.869     Relative Error: 0.9%
IMG 183 Prediction: 31.781    GT: 35.590    Absolute Error: 3.809     Relative Error: 10.7%
IMG 184 Prediction: 23.386    GT: 26.776    Absolute Error: 3.390     Relative Error: 12.7%
IMG 185 Prediction: 56.298    GT: 57.832    Absolute Error: 1.534     Relative Error: 2.7%
IMG 186 Prediction: 61.856    GT: 60.757    Absolute Error: 1.099     Relative Error: 1.8%
IMG 187 Prediction: 75.006    GT: 78.197    Absolute Error: 3.191     Relative Error: 4.1%
IMG 188 Prediction: 325.767   GT: 304.351   Absolute Error: 21.416    Relative Error: 7.0%
IMG 189 Prediction: 212.353   GT: 206.282   Absolute Error: 6.071     Relative Error: 2.9%
IMG 190 Prediction: 135.144   GT: 140.479   Absolute Error: 5.335     Relative Error: 3.

IMG 270 Prediction: 102.142   GT: 94.167    Absolute Error: 7.975     Relative Error: 8.5%
IMG 271 Prediction: 46.490    GT: 47.370    Absolute Error: 0.880     Relative Error: 1.9%
IMG 272 Prediction: 211.408   GT: 194.818   Absolute Error: 16.591    Relative Error: 8.5%
IMG 273 Prediction: 29.569    GT: 33.250    Absolute Error: 3.681     Relative Error: 11.1%
IMG 274 Prediction: 445.736   GT: 403.789   Absolute Error: 41.947    Relative Error: 10.4%
IMG 275 Prediction: 173.345   GT: 169.180   Absolute Error: 4.165     Relative Error: 2.5%
IMG 276 Prediction: 57.660    GT: 59.014    Absolute Error: 1.354     Relative Error: 2.3%
IMG 277 Prediction: 49.157    GT: 49.715    Absolute Error: 0.558     Relative Error: 1.1%
IMG 278 Prediction: 242.868   GT: 228.937   Absolute Error: 13.931    Relative Error: 6.1%
IMG 279 Prediction: 165.320   GT: 172.959   Absolute Error: 7.639     Relative Error: 4.4%
IMG 280 Prediction: 114.133   GT: 107.538   Absolute Error: 6.595     Relative Error: 6.