In [1]:
%cd ..

C:\Users\Wight\PycharmProjects\ThesisMain


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from PIL import Image
import os
import models.DeiT.DeiTModels  # Need to register the models!
from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

# Settings and Parameters
Here we define the DeiT model that we wish to evaluate and the corresponding parameters to evaluate it.

In [3]:
model_name = 'deit_small_distilled_patch16_224'  # Must be something like 'deit_small_distilled_patch16_224'.
trained_model_path = 'notebooks\\TL\\save_state_ep_20_new_best_MAE_5.397.pth'  # The path to trained model file (something like XYZ.pth)
label_factor = 3000  # The label factor used to train this specific model.
dataset = 'WE_DeiT'  # Must be the exact name of the dataset
save_results = False  # When true, save the images, GTs and predictions. A folder for this is created automatically.
set_to_eval = 'test'  # val', 'test'. Which split to test the model on. 'train' does not work!

# Prepare for evaluation
Use the settings to load the DeiT model and dataloader for the test set. Also loads the transform with which we can restore the original images. Cuda is required!
If save_results is True, also create the directory in which the predictions are saved.

In [4]:
model = create_model(
        model_name,
        init_path=None,
        num_classes=1000,  # Not yet used anyway. Must match pretrained model!
        drop_rate=0.,
        drop_path_rate=0.,  
        drop_block_rate=None,
    )

model.cuda()

resume_state = torch.load(trained_model_path)
model.load_state_dict(resume_state['net'])

model.eval()

DistilledRegressionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attenti

In [5]:
dataloader = importlib.import_module(f'datasets.standard.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.standard.{dataset}.settings').cfg_data

train_loader, val_loader, test_loader, restore_transform = dataloader(model.crop_size)
if set_to_eval == 'val':
    my_dataloader = val_loader
elif set_to_eval == 'test':
    my_dataloader = test_loader
else:
    print(f'Error: invalid set --> {set_to_eval}')

2731 train images found.
636 val images found.
599 test images found.


In [6]:
save_path = None
if save_results:
    save_folder = 'DeiT' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here is you want to save somewhere else
    os.mkdir(save_path)

# Evaluation loop and save funtion

In [7]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')
    
    plt.figure()
    f, axarr = plt.subplots(1, 3, figsize=(13, 13))
    axarr[0].imshow(img)
    axarr[1].imshow(gt, cmap=cm.jet)
    axarr[1].title.set_text(f'GT count: {gt_cnt:.3f}')
    axarr[2].imshow(prediction, cmap=cm.jet)
    axarr[2].title.set_text(f'predicted count: {pred_cnt:.3f}')
    plt.tight_layout()
    plt.savefig(img_save_path)
    plt.close('all')

In [8]:
def eval_model(model, my_dataloader, show_predictions, restore_transform, label_factor, cfg_data):
    with torch.no_grad():
        AEs = []  # Absolute Errors
        SEs = []  # Squared Errors

        for idx, (img, img_patches, gt_patches) in enumerate(my_dataloader):
            img_patches = img_patches.squeeze().cuda()
            gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
            img = img.squeeze()  # Remove batch dimension
            _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction
            
            img = restore_transform(img)

            pred_den = model(img_patches)  # Precicted density crops
            pred_den = pred_den.cpu()

            # Restore GT and Prediction
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            gt = gt.squeeze()  # Remove channel dim
            den = den.squeeze()  # Remove channel dim
            
            
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR
            
            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())
            relative_error = AEs[-1] / gt_cnt * 100
            print(f'IMG {idx:<3} '
                  f'Prediction: {pred_cnt:<9.3f} '
                  f'GT: {gt_cnt:<9.3f} '
                  f'Absolute Error: {AEs[-1]:<9.3f} '
                  f'Relative Error: {relative_error:.1f}%')
            
            if save_path:
                plot_and_save_results(save_path, img, idx, gt, den, pred_cnt, gt_cnt)
            
        MAE = np.mean(AEs)
        MSE = np.sqrt(np.mean(SEs))

    return MAE, MSE

In [9]:
MAE, MSE = eval_model(model, my_dataloader, save_path, restore_transform, label_factor, cfg_data)
print(f'MAE/MSE: {MAE:.3f}/{MSE:.3f}')

IMG 0   Prediction: 6.224     GT: 4.000     Absolute Error: 2.224     Relative Error: 55.6%
IMG 1   Prediction: 3.825     GT: 2.000     Absolute Error: 1.825     Relative Error: 91.2%
IMG 2   Prediction: 21.614    GT: 23.000    Absolute Error: 1.386     Relative Error: 6.0%
IMG 3   Prediction: 20.575    GT: 23.000    Absolute Error: 2.425     Relative Error: 10.5%
IMG 4   Prediction: 12.226    GT: 11.000    Absolute Error: 1.226     Relative Error: 11.1%
IMG 5   Prediction: 2.733     GT: 1.000     Absolute Error: 1.733     Relative Error: 173.3%
IMG 6   Prediction: 3.964     GT: 2.000     Absolute Error: 1.964     Relative Error: 98.2%
IMG 7   Prediction: 3.486     GT: 3.000     Absolute Error: 0.486     Relative Error: 16.2%
IMG 8   Prediction: 7.600     GT: 7.000     Absolute Error: 0.600     Relative Error: 8.6%
IMG 9   Prediction: 6.503     GT: 6.000     Absolute Error: 0.503     Relative Error: 8.4%
IMG 10  Prediction: 6.854     GT: 4.000     Absolute Error: 2.854     Relative Err

IMG 91  Prediction: 15.461    GT: 14.000    Absolute Error: 1.461     Relative Error: 10.4%
IMG 92  Prediction: 10.809    GT: 9.000     Absolute Error: 1.809     Relative Error: 20.1%
IMG 93  Prediction: 18.037    GT: 16.000    Absolute Error: 2.037     Relative Error: 12.7%
IMG 94  Prediction: 7.755     GT: 5.000     Absolute Error: 2.755     Relative Error: 55.1%
IMG 95  Prediction: 6.607     GT: 6.000     Absolute Error: 0.607     Relative Error: 10.1%
IMG 96  Prediction: 12.564    GT: 18.000    Absolute Error: 5.436     Relative Error: 30.2%
IMG 97  Prediction: 11.489    GT: 11.000    Absolute Error: 0.489     Relative Error: 4.4%
IMG 98  Prediction: 19.285    GT: 14.000    Absolute Error: 5.285     Relative Error: 37.7%
IMG 99  Prediction: 6.630     GT: 5.000     Absolute Error: 1.630     Relative Error: 32.6%
IMG 100 Prediction: 4.493     GT: 2.000     Absolute Error: 2.493     Relative Error: 124.7%
IMG 101 Prediction: 11.734    GT: 11.000    Absolute Error: 0.734     Relative E

IMG 182 Prediction: 180.565   GT: 162.999   Absolute Error: 17.565    Relative Error: 10.8%
IMG 183 Prediction: 150.415   GT: 143.000   Absolute Error: 7.415     Relative Error: 5.2%
IMG 184 Prediction: 147.581   GT: 140.000   Absolute Error: 7.581     Relative Error: 5.4%
IMG 185 Prediction: 169.293   GT: 161.999   Absolute Error: 7.294     Relative Error: 4.5%
IMG 186 Prediction: 164.080   GT: 178.000   Absolute Error: 13.920    Relative Error: 7.8%
IMG 187 Prediction: 153.314   GT: 169.999   Absolute Error: 16.685    Relative Error: 9.8%
IMG 188 Prediction: 133.401   GT: 160.000   Absolute Error: 26.599    Relative Error: 16.6%
IMG 189 Prediction: 132.877   GT: 157.000   Absolute Error: 24.122    Relative Error: 15.4%
IMG 190 Prediction: 131.113   GT: 178.000   Absolute Error: 46.886    Relative Error: 26.3%
IMG 191 Prediction: 149.438   GT: 185.000   Absolute Error: 35.562    Relative Error: 19.2%
IMG 192 Prediction: 167.692   GT: 214.000   Absolute Error: 46.307    Relative Error:

IMG 272 Prediction: 21.429    GT: 22.000    Absolute Error: 0.571     Relative Error: 2.6%
IMG 273 Prediction: 31.671    GT: 33.000    Absolute Error: 1.329     Relative Error: 4.0%
IMG 274 Prediction: 37.328    GT: 28.000    Absolute Error: 9.328     Relative Error: 33.3%
IMG 275 Prediction: 28.799    GT: 25.000    Absolute Error: 3.799     Relative Error: 15.2%
IMG 276 Prediction: 26.833    GT: 26.000    Absolute Error: 0.834     Relative Error: 3.2%
IMG 277 Prediction: 38.293    GT: 37.000    Absolute Error: 1.293     Relative Error: 3.5%
IMG 278 Prediction: 54.525    GT: 61.000    Absolute Error: 6.474     Relative Error: 10.6%
IMG 279 Prediction: 73.100    GT: 81.000    Absolute Error: 7.900     Relative Error: 9.8%
IMG 280 Prediction: 76.965    GT: 95.000    Absolute Error: 18.035    Relative Error: 19.0%
IMG 281 Prediction: 98.029    GT: 104.000   Absolute Error: 5.970     Relative Error: 5.7%
IMG 282 Prediction: 96.715    GT: 99.000    Absolute Error: 2.284     Relative Error: 

IMG 362 Prediction: 90.694    GT: 69.000    Absolute Error: 21.694    Relative Error: 31.4%
IMG 363 Prediction: 59.917    GT: 45.000    Absolute Error: 14.917    Relative Error: 33.1%
IMG 364 Prediction: 82.530    GT: 54.000    Absolute Error: 28.530    Relative Error: 52.8%
IMG 365 Prediction: 66.775    GT: 51.000    Absolute Error: 15.775    Relative Error: 30.9%
IMG 366 Prediction: 107.817   GT: 94.000    Absolute Error: 13.817    Relative Error: 14.7%
IMG 367 Prediction: 105.427   GT: 83.000    Absolute Error: 22.427    Relative Error: 27.0%
IMG 368 Prediction: 133.943   GT: 77.000    Absolute Error: 56.943    Relative Error: 74.0%
IMG 369 Prediction: 139.734   GT: 95.000    Absolute Error: 44.734    Relative Error: 47.1%
IMG 370 Prediction: 144.229   GT: 103.000   Absolute Error: 41.229    Relative Error: 40.0%
IMG 371 Prediction: 170.541   GT: 106.000   Absolute Error: 64.541    Relative Error: 60.9%
IMG 372 Prediction: 186.638   GT: 111.000   Absolute Error: 75.638    Relative E

IMG 451 Prediction: 120.783   GT: 85.000    Absolute Error: 35.783    Relative Error: 42.1%
IMG 452 Prediction: 72.448    GT: 60.000    Absolute Error: 12.449    Relative Error: 20.7%
IMG 453 Prediction: 100.669   GT: 74.000    Absolute Error: 26.669    Relative Error: 36.0%
IMG 454 Prediction: 126.633   GT: 89.000    Absolute Error: 37.633    Relative Error: 42.3%
IMG 455 Prediction: 63.942    GT: 50.000    Absolute Error: 13.942    Relative Error: 27.9%
IMG 456 Prediction: 104.532   GT: 80.000    Absolute Error: 24.533    Relative Error: 30.7%
IMG 457 Prediction: 77.524    GT: 63.000    Absolute Error: 14.525    Relative Error: 23.1%
IMG 458 Prediction: 101.662   GT: 70.000    Absolute Error: 31.662    Relative Error: 45.2%
IMG 459 Prediction: 115.714   GT: 77.000    Absolute Error: 38.714    Relative Error: 50.3%
IMG 460 Prediction: 98.126    GT: 80.000    Absolute Error: 18.126    Relative Error: 22.7%
IMG 461 Prediction: 144.278   GT: 104.000   Absolute Error: 40.279    Relative E

IMG 541 Prediction: 26.266    GT: 22.000    Absolute Error: 4.266     Relative Error: 19.4%
IMG 542 Prediction: 21.354    GT: 15.000    Absolute Error: 6.354     Relative Error: 42.4%
IMG 543 Prediction: 19.902    GT: 17.000    Absolute Error: 2.902     Relative Error: 17.1%
IMG 544 Prediction: 32.282    GT: 36.000    Absolute Error: 3.718     Relative Error: 10.3%
IMG 545 Prediction: 70.276    GT: 82.000    Absolute Error: 11.724    Relative Error: 14.3%
IMG 546 Prediction: 90.870    GT: 86.000    Absolute Error: 4.871     Relative Error: 5.7%
IMG 547 Prediction: 78.709    GT: 62.000    Absolute Error: 16.710    Relative Error: 27.0%
IMG 548 Prediction: 31.745    GT: 34.000    Absolute Error: 2.255     Relative Error: 6.6%
IMG 549 Prediction: 22.843    GT: 18.000    Absolute Error: 4.843     Relative Error: 26.9%
IMG 550 Prediction: 13.265    GT: 13.000    Absolute Error: 0.265     Relative Error: 2.0%
IMG 551 Prediction: 15.499    GT: 16.000    Absolute Error: 0.501     Relative Erro