In [None]:
# Move the working directory two levels up. Presumably this is the repository root, so that the
# project-local imports below (models.*, datasets.*) resolve -- confirm against where this notebook lives.
# NOTE: the path is relative, so re-running this cell climbs another two levels. Run it once per kernel.
%cd ../..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from PIL import Image
import os
import models.ViCCT_models  # Need to register the models!
from models.Swin_ViCCT_models import Swin_ViCCT_large_22k
from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

In [None]:
# Evaluation settings. Edit these before running the cells below.
model_name = 'ViCCT_base'  # Name passed to create_model below; it must be a registered model (see the models.ViCCT_models import).
# Alternative checkpoint, kept for reference:
# model_path = 'D:\\OneDrive\\OneDrive - UvA\\ThesisData\\trained_models\\SWIN generic\\save_state_ep_1600.pth'
# NOTE: absolute, machine-specific path. Point it at your own checkpoint before running.
model_path = 'D:\\OneDrive\\OneDrive - UvA\\ThesisData\\trained_models\\ViCCT base most public\\save_state_ep_1300.pth'
label_factor = 3000  # The label factor used to train this specific model. Predicted density sums are divided by it to obtain counts.
dataset = 'Generic_ViCCT'  # Must be the exact name of the dataset package: datasets.<dataset>.loading_data and .settings are imported below.
save_results = False  # When True, save the images, GTs and predictions. A folder for this is created automatically.
set_to_eval = 'test'  # 'val' or 'test': which split to test the model on. 'train' does not work!

In [None]:
# The Swin variants receive None for the dropout-style rates; every other model gets an explicit 0.
is_swin_model = 'Swin' in model_name
regularisation_rate = None if is_swin_model else 0.

# create_model comes from the timm library and builds the architecture registered under `model_name`.
# The result is switched to evaluation mode and moved to the GPU in one go.
model = create_model(
    model_name,
    init_path=model_path,
    pretrained_cc=True,
    drop_rate=regularisation_rate,  # Dropout
    # Original author's note: despite the name, this is believed to act as 'drop_connect' rather than
    # drop-path, and its meaning for the Swin version was not yet verified.
    drop_path_rate=regularisation_rate,
    drop_block_rate=None,  # Original author's note: probably drops whole Transformer blocks. Not used for ViCCT.
).eval().cuda()

In [None]:
# Validate the requested split first. Previously an invalid value was only printed, which left
# `my_dataloader` undefined (or stale from an earlier run) and made a later cell fail instead.
if set_to_eval not in ('val', 'test'):
    raise ValueError(f'Error: invalid set --> {set_to_eval}')

# Each dataset package exposes a `loading_data` factory and a `cfg_data` settings object.
dataloader = importlib.import_module(f'datasets.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.{dataset}.settings').cfg_data

# The loaders are built for the crop size the model expects.
train_loader, val_loader, test_loader, restore_transform = dataloader(model.crop_size)
my_dataloader = val_loader if set_to_eval == 'val' else test_loader

In [None]:
# Directory that receives the per-image result plots; stays None when nothing should be saved.
save_path = None
if save_results:
    # The folder name carries a minute-resolution timestamp.
    save_folder = 'DeiT' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here if you want to save somewhere else
    # makedirs with exist_ok: also creates a missing 'notebooks' parent and does not fail when the
    # cell is re-run within the same minute (os.mkdir raised in both situations).
    os.makedirs(save_path, exist_ok=True)

In [None]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    """Save a three-panel figure (image, ground-truth map, predicted map) for one sample.

    The file is written to `save_path` as ``IMG_<img_idx>_AE_<absolute error>.jpg``.

    Args:
        save_path: Existing directory to write the figure into.
        img: The image, in any form accepted by ``Axes.imshow``.
        img_idx: Index of the image; only used in the file name.
        gt: Ground-truth density map (2D array-like).
        prediction: Predicted density map (2D array-like).
        pred_cnt: Predicted count, shown in the title and used for the error in the file name.
        gt_cnt: Ground-truth count, shown in the title and used for the error in the file name.
    """
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')

    # plt.subplots already creates the figure; the extra plt.figure() that used to precede it only
    # produced a stray empty figure.
    fig, axarr = plt.subplots(1, 3, figsize=(13, 13))
    try:
        axarr[0].imshow(img)
        axarr[1].imshow(gt, cmap=cm.jet)
        axarr[1].set_title(f'GT count: {gt_cnt:.3f}')
        axarr[2].imshow(prediction, cmap=cm.jet)
        axarr[2].set_title(f'predicted count: {pred_cnt:.3f}')
        fig.tight_layout()
        fig.savefig(img_save_path)
    finally:
        # Close only this figure (also when saving fails), leaving unrelated open figures alone.
        plt.close(fig)

In [None]:
def eval_model(model, my_dataloader, show_predictions, restore_transform, label_factor, cfg_data):
    """Evaluate a crop-based density model on every image of a dataloader.

    Each image arrives as a stack of crops. The model predicts a density map for every crop, the
    crops are stitched back together with `img_equal_unsplit`, and the counts of the stitched
    prediction and ground truth are compared. One line with the errors is printed per image.

    Args:
        model: The network to evaluate. Inputs are moved to the device of its parameters.
        my_dataloader: Iterable yielding ``(img, img_patches, gt_patches)``.
            Assumed to yield one image per batch -- TODO confirm against the dataset's loader.
        show_predictions: Directory in which to save a result plot per image, or a falsy value
            (e.g. None) to save nothing. The notebook passes its `save_path` here.
        restore_transform: Callable turning the normalised image tensor back into a displayable image.
        label_factor: Factor the model was trained with; the predicted density sum is divided by it.
        cfg_data: Dataset settings; `OVERLAP`, `IGNORE_BUFFER` and `LABEL_FACTOR` are read.

    Returns:
        Tuple ``(MAE, MSE, Mean_crop_loss, Mean_whole_img_loss, GTs, preds)``. `MSE` is the square
        root of the mean squared count error; `GTs` and `preds` are per-image lists of counts.
    """
    loss_fn = torch.nn.MSELoss(reduction='none')
    # Run on whatever device the model lives on instead of hard-coding CUDA.
    device = next(model.parameters()).device

    AEs = []  # Absolute Errors
    SEs = []  # Squared Errors
    GTs = []
    preds = []

    crop_losses = []  # The loss of just the crops before recombining
    whole_img_losses = []  # The loss of the image after the crops are combined
    with torch.no_grad():
        for idx, (img, img_patches, gt_patches) in enumerate(my_dataloader):
            # Drop the batch dimension by reshaping on the trailing dims. A bare .squeeze() would
            # also remove the crop dimension whenever an image consists of exactly one crop.
            img_patches = img_patches.reshape(-1, *img_patches.shape[-3:]).to(device)
            gt_patches = gt_patches.reshape(-1, 1, *gt_patches.shape[-2:])  # (crops, channel, H, W)
            img = img.reshape(img.shape[-3:])
            _, img_h, img_w = img.shape  # Image dimensions, used to reconstruct GT and prediction

            img = restore_transform(img)

            # Crops are processed one at a time; the results are collected on the CPU.
            # NOTE(review): the 224x224 output size is hard-coded -- confirm it matches model.crop_size.
            pred_den = torch.zeros(img_patches.shape[0], 1, 224, 224)
            for crop_idx, img_crop in enumerate(img_patches):
                # Call the module itself rather than .forward() so registered hooks run as well.
                pred_den[crop_idx] = model(img_crop.unsqueeze(0))

            crop_loss = loss_fn(pred_den, gt_patches).mean((-2, -1))
            crop_losses.extend(crop_loss.tolist())

            # Restore GT and Prediction
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
            gt = gt.squeeze()  # Remove channel dim
            den = den.squeeze()  # Remove channel dim

            whole_img_loss = loss_fn(den, gt).mean((-2, -1))
            whole_img_losses.append(whole_img_loss.item())

            # The prediction is scaled with the model's label factor, the GT with the dataset's.
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())
            GTs.append(gt_cnt.item())
            preds.append(pred_cnt.item())
            relative_error = AEs[-1] / gt_cnt * 100
            print(f'IMG {idx:<3} '
                  f'Prediction: {pred_cnt:<9.3f} '
                  f'GT: {gt_cnt:<9.3f} '
                  f'Absolute Error: {AEs[-1]:<9.3f} '
                  f'Relative Error: {relative_error:.1f}%')

            # Use the argument rather than the notebook-global `save_path`, so the function no
            # longer depends on hidden state.
            if show_predictions:
                plot_and_save_results(show_predictions, img, idx, gt, den, pred_cnt, gt_cnt)

    MAE = np.mean(AEs)
    MSE = np.sqrt(np.mean(SEs))
    Mean_crop_loss = np.mean(crop_losses)
    Mean_whole_img_loss = np.mean(whole_img_losses)

    return MAE, MSE, Mean_crop_loss, Mean_whole_img_loss, GTs, preds

In [None]:
# Evaluate the chosen split. `save_path` is handed over as the `show_predictions` argument.
eval_outputs = eval_model(
    model=model,
    my_dataloader=my_dataloader,
    show_predictions=save_path,
    restore_transform=restore_transform,
    label_factor=label_factor,
    cfg_data=cfg_data,
)
MAE, MSE, Mean_crop_loss, Mean_whole_img_loss, GTs, preds = eval_outputs
print(f'MAE/MSE: {MAE}/{MSE}, Mean crop loss: {Mean_crop_loss:.3f}, Mean whole image loss: {Mean_whole_img_loss:.3f}.')

In [None]:
# Convert the per-image counts to arrays. Later cells rely on `GTs` and `preds` being arrays.
GTs = np.array(GTs)
preds = np.array(preds)

# Images are drawn in order of increasing ground-truth count.
sorted_idxs = np.argsort(GTs)
img_nrs = np.arange(len(GTs))

plt.rcParams.update({'font.size': 14})

fig, ax = plt.subplots()
ax.plot(img_nrs, GTs[sorted_idxs], label='Ground truths')
ax.plot(img_nrs, preds[sorted_idxs], label='Predictions')
ax.set_ylabel('Crowd count')
ax.set_xlabel('Sorted image')
ax.legend()
fig.tight_layout()
# fig.savefig(f'DeiT_{dataset}_pred_vs_gt.jpg')
plt.show()

In [None]:
sorted_error_idxs = np.flip(np.argsort(np.abs(GTs - preds)))

with torch.no_grad():
    for idx in sorted_error_idxs[:10]:
        img, img_patches, gt_patches = my_dataloader.dataset.__getitem__(idx)

        img_patches = img_patches.cuda()
        gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
        img = img.squeeze()  # Remove batch dimension
        _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction

        img = restore_transform(img)

        pred_stack = torch.zeros(img_patches.shape[0], 1, 224, 224)


        for idx2, img_crop in enumerate(img_patches):
            pred_stack[idx2] = model.forward(img_crop.unsqueeze(0))
#             pred_den = model(img_patches)  # Precicted density crops
#             pred_den = pred_den.cpu()
        pred_den = pred_stack.cpu()
        
        # Restore GT and Prediction
        gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
        den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
        gt = gt.squeeze()  # Remove channel dim
        den = den.squeeze()  # Remove channel dim


        pred_cnt = den.sum() / label_factor
        gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

        print(f'IMG {idx}, pred: {pred_cnt:.3f}, gt: {gt_cnt:.3f}. Error: {pred_cnt - gt_cnt:.3f}')
        
        plt.figure(figsize=(15, 15))
        plt.imshow(np.asarray(img))
#         plt.savefig(f'DeiT_IMG_{idx + 1}_{dataset}.jpg')
        plt.show()
        
        plt.figure(figsize=(15, 15))
        plt.imshow(gt.numpy(), cmap=cm.jet)
        plt.title(f'GT count: {gt_cnt:.3f}')
#         plt.savefig(f'DeiT_IMG_{idx + 1}_{dataset}_prediction.jpg')
        plt.show()
        
        plt.figure(figsize=(15, 15))
        plt.imshow(den.numpy(), cmap=cm.jet)
        plt.title(f'Predicted count: {pred_cnt:.3f}')
#         plt.savefig(f'DeiT_IMG_{idx + 1}_{dataset}_prediction.jpg')
        plt.show()

In [None]:
sorted_good_idxs = np.argsort(np.abs(GTs - preds))

with torch.no_grad():
    for idx in sorted_good_idxs[:20]:
        img, img_patches, gt_patches = my_dataloader.dataset.__getitem__(idx)

        img_patches = img_patches.cuda()
        gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
        img = img.squeeze()  # Remove batch dimension
        _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction

        img = restore_transform(img)

        pred_stack = torch.zeros(img_patches.shape[0], 1, 224, 224)


        for idx2, img_crop in enumerate(img_patches):
            pred_stack[idx2] = model.forward(img_crop.unsqueeze(0))
#             pred_den = model(img_patches)  # Precicted density crops
#             pred_den = pred_den.cpu()
        pred_den = pred_stack.cpu()
        
        # Restore GT and Prediction
        gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
        den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1)
        gt = gt.squeeze()  # Remove channel dim
        den = den.squeeze()  # Remove channel dim


        pred_cnt = den.sum() / label_factor
        gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

        print(f'IMG {idx}, pred: {pred_cnt:.3f}, gt: {gt_cnt:.3f}. Error: {pred_cnt - gt_cnt:.3f}')
        
        plt.figure(figsize=(15, 15))
        plt.imshow(np.asarray(img))
#         plt.savefig(f'DeiT_IMG_{idx + 1}_{dataset}.jpg')
        plt.show()
        
        plt.figure(figsize=(15, 15))
        plt.imshow(gt.numpy(), cmap=cm.jet)
        plt.title(f'GT count: {gt_cnt:.3f}')
#         plt.savefig(f'DeiT_IMG_{idx + 1}_{dataset}_prediction.jpg')
        plt.show()
        
        plt.figure(figsize=(15, 15))
        plt.imshow(den.numpy(), cmap=cm.jet)
        plt.title(f'Predicted count: {pred_cnt:.3f}')
#         plt.savefig(f'DeiT_IMG_{idx + 1}_{dataset}_prediction.jpg')
        plt.show()

In [None]:
# Optional: uncomment to move the model to the CPU and export its weights as a checkpoint file.
# model.cpu()
# sd = model.state_dict()
# save_d = {'state_dict': sd}
# torch.save(save_d, '40_45_adapted.pth')