In [1]:
%cd ..

C:\Users\Wight\PycharmProjects\ThesisMain


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from PIL import Image
import os
import models.DeiTModels  # Need to register the models!
from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

# Settings and Parameters
Here we define the DeiT model that we wish to evaluate and the corresponding parameters to evaluate it.

In [10]:
model_name = 'deit_small_distilled_patch16_224'  # Must be something like 'deit_small_distilled_patch16_224'.
trained_model_path = 'notebooks/other splits/03-22_15-15 SPLIT3/state_dicts/save_state_ep_870_new_best_MAE_8.186.pth'  # The path to trained model file (something like XYZ.pth)
label_factor = 10000  # The label factor used to train this specific model.
dataset = 'SHTB_DeiT'  # Must be the exact name of the dataset
save_results = False  # When true, save the images, GTs and predictions. A folder for this is created automatically.
set_to_eval = 'test'  # val', 'test'. Which split to test the model on. 'train' does not work!

# Prepare for evaluation
Use the settings to load the DeiT model and dataloader for the test set. Also loads the transform with which we can restore the original images. Cuda is required!
If save_results is True, also create the directory in which the predictions are saved.

In [11]:
model = create_model(
        model_name,
        init_path=None,
        num_classes=1000,  # Not yet used anyway. Must match pretrained model!
        drop_rate=0.,
        drop_path_rate=0.1,  # TODO: What does this do?
        drop_block_rate=None,
    )

model.cuda()

resume_state = torch.load(trained_model_path)
model.load_state_dict(resume_state['net'])

model.eval()

DistilledRegressionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attenti

In [12]:
dataloader = importlib.import_module(f'datasets.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.{dataset}.settings').cfg_data

train_loader, val_loader, test_loader, restore_transform = dataloader(model.crop_size)
if set_to_eval == 'val':
    my_dataloader = val_loader
elif set_to_eval == 'test':
    my_dataloader = test_loader
else:
    print(f'Error: invalid set --> {set_to_eval}')

320 train images found.
80 val images found.
316 test images found.


In [13]:
save_path = None
if save_results:
    save_folder = 'DeiT' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here is you want to save somewhere else
    os.mkdir(save_path)

# Evaluation loop and save funtion

In [14]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')
    
    plt.figure()
    f, axarr = plt.subplots(1, 3, figsize=(13, 13))
    axarr[0].imshow(img)
    axarr[1].imshow(gt, cmap=cm.jet)
    axarr[1].title.set_text(f'GT count: {gt_cnt:.3f}')
    axarr[2].imshow(prediction, cmap=cm.jet)
    axarr[2].title.set_text(f'predicted count: {pred_cnt:.3f}')
    plt.tight_layout()
    plt.savefig(img_save_path)
    plt.close('all')

In [15]:
def eval_model(model, my_dataloader, show_predictions, restore_transform, label_factor, cfg_data):
    with torch.no_grad():
        AEs = []  # Absolute Errors
        SEs = []  # Squared Errors

        for idx, (img, img_patches, gt_patches) in enumerate(my_dataloader):
            img_patches = img_patches.squeeze().cuda()
            gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
            img = img.squeeze()  # Remove batch dimension
            _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction
            
            img = restore_transform(img)

            pred_den = model(img_patches)  # Precicted density crops
            pred_den = pred_den.cpu()

            # Restore GT and Prediction
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            gt = gt.squeeze()  # Remove channel dim
            den = den.squeeze()  # Remove channel dim
            

            
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR
            
            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())
            relative_error = AEs[-1] / gt_cnt * 100
            print(f'IMG {idx:<3} '
                  f'Prediction: {pred_cnt:<9.3f} '
                  f'GT: {gt_cnt:<9.3f} '
                  f'Absolute Error: {AEs[-1]:<9.3f} '
                  f'Relative Error: {relative_error:.1f}%')
            
            if save_path:
                plot_and_save_results(save_path, img, idx, gt, den, pred_cnt, gt_cnt)
            
        MAE = np.mean(AEs)
        MSE = np.sqrt(np.mean(SEs))

    return MAE, MSE

In [16]:
MAE, MSE = eval_model(model, my_dataloader, save_path, restore_transform, label_factor, cfg_data)
print(f'MAE: {MAE:<9.3f} Root MSE: {MSE:.3f}')

IMG 0   Prediction: 22.875    GT: 22.743    Absolute Error: 0.133     Relative Error: 0.6%
IMG 1   Prediction: 166.042   GT: 175.638   Absolute Error: 9.596     Relative Error: 5.5%
IMG 2   Prediction: 136.603   GT: 156.493   Absolute Error: 19.890    Relative Error: 12.7%
IMG 3   Prediction: 34.461    GT: 36.219    Absolute Error: 1.758     Relative Error: 4.9%
IMG 4   Prediction: 63.833    GT: 68.812    Absolute Error: 4.979     Relative Error: 7.2%
IMG 5   Prediction: 59.964    GT: 56.467    Absolute Error: 3.497     Relative Error: 6.2%
IMG 6   Prediction: 41.985    GT: 43.457    Absolute Error: 1.472     Relative Error: 3.4%
IMG 7   Prediction: 224.546   GT: 220.488   Absolute Error: 4.057     Relative Error: 1.8%
IMG 8   Prediction: 176.396   GT: 163.917   Absolute Error: 12.478    Relative Error: 7.6%
IMG 9   Prediction: 506.964   GT: 471.458   Absolute Error: 35.506    Relative Error: 7.5%
IMG 10  Prediction: 141.303   GT: 137.556   Absolute Error: 3.748     Relative Error: 2.7

IMG 90  Prediction: 106.979   GT: 99.229    Absolute Error: 7.749     Relative Error: 7.8%
IMG 91  Prediction: 62.846    GT: 59.737    Absolute Error: 3.109     Relative Error: 5.2%
IMG 92  Prediction: 262.048   GT: 245.455   Absolute Error: 16.593    Relative Error: 6.8%
IMG 93  Prediction: 168.525   GT: 185.273   Absolute Error: 16.748    Relative Error: 9.0%
IMG 94  Prediction: 191.934   GT: 182.204   Absolute Error: 9.730     Relative Error: 5.3%
IMG 95  Prediction: 111.167   GT: 104.747   Absolute Error: 6.420     Relative Error: 6.1%
IMG 96  Prediction: 104.371   GT: 102.004   Absolute Error: 2.367     Relative Error: 2.3%
IMG 97  Prediction: 151.202   GT: 158.964   Absolute Error: 7.761     Relative Error: 4.9%
IMG 98  Prediction: 56.097    GT: 55.964    Absolute Error: 0.134     Relative Error: 0.2%
IMG 99  Prediction: 40.629    GT: 38.688    Absolute Error: 1.940     Relative Error: 5.0%
IMG 100 Prediction: 280.343   GT: 276.026   Absolute Error: 4.317     Relative Error: 1.6%

IMG 180 Prediction: 179.742   GT: 182.070   Absolute Error: 2.329     Relative Error: 1.3%
IMG 181 Prediction: 137.343   GT: 137.000   Absolute Error: 0.343     Relative Error: 0.3%
IMG 182 Prediction: 92.611    GT: 94.777    Absolute Error: 2.166     Relative Error: 2.3%
IMG 183 Prediction: 31.899    GT: 35.590    Absolute Error: 3.691     Relative Error: 10.4%
IMG 184 Prediction: 24.260    GT: 26.776    Absolute Error: 2.516     Relative Error: 9.4%
IMG 185 Prediction: 56.727    GT: 57.832    Absolute Error: 1.105     Relative Error: 1.9%
IMG 186 Prediction: 59.067    GT: 60.757    Absolute Error: 1.690     Relative Error: 2.8%
IMG 187 Prediction: 76.311    GT: 78.197    Absolute Error: 1.886     Relative Error: 2.4%
IMG 188 Prediction: 319.111   GT: 300.800   Absolute Error: 18.312    Relative Error: 6.1%
IMG 189 Prediction: 212.464   GT: 206.282   Absolute Error: 6.181     Relative Error: 3.0%
IMG 190 Prediction: 159.777   GT: 140.479   Absolute Error: 19.298    Relative Error: 13.

IMG 270 Prediction: 99.904    GT: 94.167    Absolute Error: 5.737     Relative Error: 6.1%
IMG 271 Prediction: 42.026    GT: 47.370    Absolute Error: 5.343     Relative Error: 11.3%
IMG 272 Prediction: 211.533   GT: 193.818   Absolute Error: 17.716    Relative Error: 9.1%
IMG 273 Prediction: 31.522    GT: 32.698    Absolute Error: 1.176     Relative Error: 3.6%
IMG 274 Prediction: 470.219   GT: 403.789   Absolute Error: 66.430    Relative Error: 16.5%
IMG 275 Prediction: 167.649   GT: 168.180   Absolute Error: 0.531     Relative Error: 0.3%
IMG 276 Prediction: 61.597    GT: 59.014    Absolute Error: 2.584     Relative Error: 4.4%
IMG 277 Prediction: 45.290    GT: 48.715    Absolute Error: 3.425     Relative Error: 7.0%
IMG 278 Prediction: 233.214   GT: 228.937   Absolute Error: 4.277     Relative Error: 1.9%
IMG 279 Prediction: 170.147   GT: 171.959   Absolute Error: 1.812     Relative Error: 1.1%
IMG 280 Prediction: 112.346   GT: 107.538   Absolute Error: 4.808     Relative Error: 4.