In [1]:
%cd ..

C:\Users\Wight\PycharmProjects\ThesisMain


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from PIL import Image
import os
import models.DeiTModels  # Need to register the models!
from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

# Settings and Parameters
Here we define the DeiT model that we wish to evaluate and the corresponding parameters to evaluate it.

In [3]:
model_name = 'deit_small_distilled_patch16_224'  # Must be something like 'deit_small_distilled_patch16_224'.
trained_model_path = 'notebooks/added_relu/LeakyReLU/state_dicts/save_state_ep_800_new_best_MAE_7.909.pth'  # The path to trained model file (something like XYZ.pth)
label_factor = 10000  # The label factor used to train this specific model.
dataset = 'SHTB_DeiT'  # Must be the exact name of the dataset
save_results = False  # When true, save the images, GTs and predictions. A folder for this is created automatically.
set_to_eval = 'test'  # val', 'test'. Which split to test the model on. 'train' does not work!

# Prepare for evaluation
Use the settings to load the DeiT model and dataloader for the test set. Also loads the transform with which we can restore the original images. Cuda is required!
If save_results is True, also create the directory in which the predictions are saved.

In [4]:
model = create_model(
        model_name,
        init_path=None,
        num_classes=1000,  # Not yet used anyway. Must match pretrained model!
        drop_rate=0.,
        drop_path_rate=0.1,  # TODO: What does this do?
        drop_block_rate=None,
    )

model.cuda()

resume_state = torch.load(trained_model_path)
model.load_state_dict(resume_state['net'])

model.eval()

DistilledRegressionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attenti

In [5]:
dataloader = importlib.import_module(f'datasets.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.{dataset}.settings').cfg_data

train_loader, val_loader, test_loader, restore_transform = dataloader(model.crop_size)
if set_to_eval == 'val':
    my_dataloader = val_loader
elif set_to_eval == 'test':
    my_dataloader = test_loader
else:
    print(f'Error: invalid set --> {set_to_eval}')

320 train images found.
80 val images found.
316 test images found.


In [6]:
save_path = None
if save_results:
    save_folder = 'DeiT' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here is you want to save somewhere else
    os.mkdir(save_path)

# Evaluation loop and save funtion

In [7]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')
    
    plt.figure()
    f, axarr = plt.subplots(1, 3, figsize=(13, 13))
    axarr[0].imshow(img)
    axarr[1].imshow(gt, cmap=cm.jet)
    axarr[1].title.set_text(f'GT count: {gt_cnt:.3f}')
    axarr[2].imshow(prediction, cmap=cm.jet)
    axarr[2].title.set_text(f'predicted count: {pred_cnt:.3f}')
    plt.tight_layout()
    plt.savefig(img_save_path)
    plt.close('all')

In [8]:
def eval_model(model, my_dataloader, show_predictions, restore_transform, label_factor, cfg_data):
    with torch.no_grad():
        AEs = []  # Absolute Errors
        SEs = []  # Squared Errors

        for idx, (img, img_patches, gt_patches) in enumerate(my_dataloader):
            img_patches = img_patches.squeeze().cuda()
            gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
            img = img.squeeze()  # Remove batch dimension
            _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction
            
            img = restore_transform(img)

            pred_den = model(img_patches)  # Precicted density crops
            pred_den = pred_den.cpu()

            # Restore GT and Prediction
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            gt = gt.squeeze()  # Remove channel dim
            den = den.squeeze()  # Remove channel dim
            

            
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR
            
            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())
            relative_error = AEs[-1] / gt_cnt * 100
            print(f'IMG {idx:<3} '
                  f'Prediction: {pred_cnt:<9.3f} '
                  f'GT: {gt_cnt:<9.3f} '
                  f'Absolute Error: {AEs[-1]:<9.3f} '
                  f'Relative Error: {relative_error:.1f}%')
            
            if save_path:
                plot_and_save_results(save_path, img, idx, gt, den, pred_cnt, gt_cnt)
            
        MAE = np.mean(AEs)
        MSE = np.sqrt(np.mean(SEs))

    return MAE, MSE

In [9]:
MAE, MSE = eval_model(model, my_dataloader, save_path, restore_transform, label_factor, cfg_data)
print(f'MAE: {MAE:<9.3f} Root MSE: {MSE:.3f}')

IMG 0   Prediction: 22.336    GT: 22.743    Absolute Error: 0.407     Relative Error: 1.8%
IMG 1   Prediction: 166.549   GT: 175.638   Absolute Error: 9.089     Relative Error: 5.2%
IMG 2   Prediction: 130.153   GT: 156.493   Absolute Error: 26.341    Relative Error: 16.8%
IMG 3   Prediction: 32.463    GT: 36.219    Absolute Error: 3.756     Relative Error: 10.4%
IMG 4   Prediction: 62.371    GT: 68.812    Absolute Error: 6.441     Relative Error: 9.4%
IMG 5   Prediction: 57.267    GT: 56.467    Absolute Error: 0.800     Relative Error: 1.4%
IMG 6   Prediction: 39.601    GT: 43.457    Absolute Error: 3.856     Relative Error: 8.9%
IMG 7   Prediction: 220.054   GT: 220.488   Absolute Error: 0.434     Relative Error: 0.2%
IMG 8   Prediction: 174.925   GT: 163.917   Absolute Error: 11.008    Relative Error: 6.7%
IMG 9   Prediction: 480.522   GT: 471.458   Absolute Error: 9.064     Relative Error: 1.9%
IMG 10  Prediction: 134.937   GT: 137.556   Absolute Error: 2.619     Relative Error: 1.

IMG 90  Prediction: 106.891   GT: 99.229    Absolute Error: 7.662     Relative Error: 7.7%
IMG 91  Prediction: 58.941    GT: 59.737    Absolute Error: 0.795     Relative Error: 1.3%
IMG 92  Prediction: 244.716   GT: 245.455   Absolute Error: 0.738     Relative Error: 0.3%
IMG 93  Prediction: 169.411   GT: 185.273   Absolute Error: 15.862    Relative Error: 8.6%
IMG 94  Prediction: 182.916   GT: 182.204   Absolute Error: 0.711     Relative Error: 0.4%
IMG 95  Prediction: 106.556   GT: 104.747   Absolute Error: 1.809     Relative Error: 1.7%
IMG 96  Prediction: 99.095    GT: 102.004   Absolute Error: 2.908     Relative Error: 2.9%
IMG 97  Prediction: 155.536   GT: 158.964   Absolute Error: 3.428     Relative Error: 2.2%
IMG 98  Prediction: 55.305    GT: 55.964    Absolute Error: 0.659     Relative Error: 1.2%
IMG 99  Prediction: 38.713    GT: 38.688    Absolute Error: 0.024     Relative Error: 0.1%
IMG 100 Prediction: 276.704   GT: 276.026   Absolute Error: 0.678     Relative Error: 0.2%

IMG 180 Prediction: 190.762   GT: 182.070   Absolute Error: 8.691     Relative Error: 4.8%
IMG 181 Prediction: 128.435   GT: 137.000   Absolute Error: 8.565     Relative Error: 6.3%
IMG 182 Prediction: 87.748    GT: 94.777    Absolute Error: 7.029     Relative Error: 7.4%
IMG 183 Prediction: 31.150    GT: 35.590    Absolute Error: 4.440     Relative Error: 12.5%
IMG 184 Prediction: 24.371    GT: 26.776    Absolute Error: 2.405     Relative Error: 9.0%
IMG 185 Prediction: 52.672    GT: 57.832    Absolute Error: 5.160     Relative Error: 8.9%
IMG 186 Prediction: 60.050    GT: 60.757    Absolute Error: 0.706     Relative Error: 1.2%
IMG 187 Prediction: 74.871    GT: 78.197    Absolute Error: 3.326     Relative Error: 4.3%
IMG 188 Prediction: 316.110   GT: 300.800   Absolute Error: 15.310    Relative Error: 5.1%
IMG 189 Prediction: 215.713   GT: 206.282   Absolute Error: 9.431     Relative Error: 4.6%
IMG 190 Prediction: 131.373   GT: 140.479   Absolute Error: 9.106     Relative Error: 6.5

IMG 270 Prediction: 101.338   GT: 94.167    Absolute Error: 7.171     Relative Error: 7.6%
IMG 271 Prediction: 40.181    GT: 47.370    Absolute Error: 7.189     Relative Error: 15.2%
IMG 272 Prediction: 195.119   GT: 193.818   Absolute Error: 1.301     Relative Error: 0.7%
IMG 273 Prediction: 28.274    GT: 32.698    Absolute Error: 4.424     Relative Error: 13.5%
IMG 274 Prediction: 452.635   GT: 403.789   Absolute Error: 48.846    Relative Error: 12.1%
IMG 275 Prediction: 164.682   GT: 168.180   Absolute Error: 3.498     Relative Error: 2.1%
IMG 276 Prediction: 58.009    GT: 59.014    Absolute Error: 1.005     Relative Error: 1.7%
IMG 277 Prediction: 44.164    GT: 48.715    Absolute Error: 4.551     Relative Error: 9.3%
IMG 278 Prediction: 230.516   GT: 228.937   Absolute Error: 1.579     Relative Error: 0.7%
IMG 279 Prediction: 158.611   GT: 171.959   Absolute Error: 13.348    Relative Error: 7.8%
IMG 280 Prediction: 116.368   GT: 107.538   Absolute Error: 8.831     Relative Error: 8