In [None]:
%cd ../..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from itertools import product

from PIL import Image
import os
from models.CSRNet.CSRNet import CSRNet
from models.CSRNet.CSRNet_functional import CSRNet_functional

import importlib
import time

In [None]:
# Path to the trained model file (something like XYZ.pth). Built with os.path.join so it
# works on both Windows and POSIX (a literal 'a\\b' path only resolves on Windows).
trained_model_path = os.path.join('notebooks', 'TL', 'save_state_ep_560_new_best_MAE_7.454.pth')
# The label factor used to train this specific model.
# NOTE(review): eval_on_scene divides counts by cfg_data.LABEL_FACTOR (dataset settings), not by this value.
label_factor = 100
dataset = 'WE_CSRNet_Meta'  # Must be the exact name of the dataset
save_results = False  # When true, save the images, GTs and predictions. A folder for this is created automatically.
set_to_eval = 'test'  # 'val' or 'test': which split to evaluate the model on. 'train' does not work!

all_adapt_lrs = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]  # Learning rates to sweep when adapting the model to each scene

# Alternative layout (kept for reference): three groups, each holding one adaptation image per scene.
# all_adapt_imgs = [
#     [
#         ['104207_1-04-S20100821071000000E20100821120000000_007550.jpg'], 
#         ['200608_C08-02-S20100626083000000E20100626233000000_clip1_021050.jpg'], 
#         ['200702_C09-01-S20100717083000000E20100717233000000_007550.jpg'],
#         ['202201_1-01-S20100922060000000E20100922235959000_clip1_004550.jpg'],
#         ['500717_D11-03-S20100717083000000E20100717233000000_016550.jpg']
#     ],
#     [
#         ['104207_1-04-S20100821071000000E20100821120000000_172550.jpg'],
#         ['200608_C08-02-S20100626083000000E20100626233000000_clip1_061550.jpg'],
#         ['200702_C09-01-S20100717083000000E20100717233000000_141050.jpg'],
#         ['202201_1-01-S20100922060000000E20100922235959000_clip1_023450.jpg'],
#         ['500717_D11-03-S20100717083000000E20100717233000000_142550.jpg']
#     ],
#     [
#         ['104207_1-04-S20100821071000000E20100821120000000_078050.jpg'],
#         ['200608_C08-02-S20100626083000000E20100626233000000_clip1_159050.jpg'],
#         ['200702_C09-01-S20100717083000000E20100717233000000_091550.jpg'],
#         ['202201_1-01-S20100922060000000E20100922235959000_clip1_064850.jpg'],
#         ['500717_D11-03-S20100717083000000E20100717233000000_144050.jpg']
#     ]
# ]

# Adaptation images: one list per scene (5 scenes, 5 images each), passed to the dataset loader.
adapt_imgs = [
    [
        '104207_1-04-S20100821071000000E20100821120000000_007550.jpg',
        '104207_1-04-S20100821071000000E20100821120000000_090050.jpg',
        '104207_1-04-S20100821071000000E20100821120000000_172550.jpg',
        '104207_1-04-S20100821071000000E20100821120000000_069050.jpg',
        '104207_1-04-S20100821071000000E20100821120000000_078050.jpg',
    ],
    [
        '200608_C08-02-S20100626083000000E20100626233000000_clip1_021050.jpg',
        '200608_C08-02-S20100626083000000E20100626233000000_clip1_169550.jpg',
        '200608_C08-02-S20100626083000000E20100626233000000_clip1_061550.jpg',
        '200608_C08-02-S20100626083000000E20100626233000000_clip1_091550.jpg',
        '200608_C08-02-S20100626083000000E20100626233000000_clip1_159050.jpg',
    ],
    [
        '200702_C09-01-S20100717083000000E20100717233000000_007550.jpg',
        '200702_C09-01-S20100717083000000E20100717233000000_141050.jpg',
        '200702_C09-01-S20100717083000000E20100717233000000_003050.jpg',
        '200702_C09-01-S20100717083000000E20100717233000000_076550.jpg',
        '200702_C09-01-S20100717083000000E20100717233000000_091550.jpg',
    ],
    [
        '202201_1-01-S20100922060000000E20100922235959000_clip1_004550.jpg',
        '202201_1-01-S20100922060000000E20100922235959000_clip1_041450.jpg',
        '202201_1-01-S20100922060000000E20100922235959000_clip1_023450.jpg',
        '202201_1-01-S20100922060000000E20100922235959000_clip1_011750.jpg',
        '202201_1-01-S20100922060000000E20100922235959000_clip1_064850.jpg',
    ],
    [
        '500717_D11-03-S20100717083000000E20100717233000000_016550.jpg',
        '500717_D11-03-S20100717083000000E20100717233000000_057050.jpg',
        '500717_D11-03-S20100717083000000E20100717233000000_142550.jpg',
        '500717_D11-03-S20100717083000000E20100717233000000_102050.jpg',
        '500717_D11-03-S20100717083000000E20100717233000000_144050.jpg',
    ],
]

In [None]:
# Adaptation loss: pixel-wise squared error between predicted and ground-truth
# density maps, averaged over all elements ('mean' is MSELoss's default reduction).
loss_fn = torch.nn.MSELoss(reduction='mean')

In [None]:
# Resolve the dataset-specific loader function and settings by name,
# i.e. datasets/meta/<dataset>/loading_data.py and datasets/meta/<dataset>/settings.py.
dataloader = importlib.import_module(f'datasets.meta.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.meta.{dataset}.settings').cfg_data

# Evaluation divides counts by cfg_data.LABEL_FACTOR; warn if the model was trained with another factor.
if label_factor != cfg_data.LABEL_FACTOR:
    print(f'Warning: label_factor={label_factor} differs from cfg_data.LABEL_FACTOR={cfg_data.LABEL_FACTOR}')

# Each *_loaders value is indexed per scene by later cells; adapt_imgs presumably
# tells the loader which images form each scene's adaptation batch -- confirm in loading_data.
train_loaders, val_loaders, test_loaders, restore_transform = dataloader(adapt_imgs)
if set_to_eval == 'val':
    my_dataloaders = val_loaders
elif set_to_eval == 'test':
    my_dataloaders = test_loaders
else:
    # Fail loudly here: merely printing left my_dataloaders undefined and
    # every later cell crashed with a confusing NameError.
    raise ValueError(f"Error: invalid set --> {set_to_eval!r} (expected 'val' or 'test')")

In [None]:
def load_model_and_optim(adapt_lr, checkpoint_path=None):
    """Build a fresh CSRNet from the trained checkpoint, plus an SGD optimizer for adaptation.

    A new model is created and re-loaded on every call, so each adaptation run
    starts from the same trained weights.

    Args:
        adapt_lr: learning rate of the returned SGD optimizer.
        checkpoint_path: checkpoint file to load; defaults to the notebook-level
            ``trained_model_path``. The file must hold a dict with the model
            weights under the ``'net'`` key.

    Returns:
        (model, optim): the CSRNet moved to the GPU and set to eval mode, and a
        ``torch.optim.SGD`` over its parameters.
    """
    if checkpoint_path is None:
        checkpoint_path = trained_model_path

    model = CSRNet()

    # map_location='cpu': deserialize on the CPU regardless of the device the checkpoint
    # was saved from (without it, loading fails on machines lacking that exact GPU).
    # The model is moved to the GPU below anyway.
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    resume_state = torch.load(checkpoint_path, map_location='cpu')

    # For checkpoints stored as a bare state_dict whose keys carry a 4-character prefix,
    # load {k[4:]: v for k, v in resume_state.items()} instead of resume_state['net'].
    model.load_state_dict(resume_state['net'])

    model.cuda()
    model.eval()

    optim = torch.optim.SGD(model.parameters(), lr=adapt_lr)

    return model, optim

In [None]:
save_path = None  # None means results are not saved
if save_results:
    # Timestamped folder: CSRNet_<dataset>_<split>_<MM-DD_HH-MM>
    save_folder = 'CSRNet' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here if you want to save somewhere else
    # The timestamp only has minute resolution, so re-running this cell within the same minute
    # made os.mkdir raise FileExistsError; makedirs also creates any missing parent directories.
    os.makedirs(save_path, exist_ok=True)

In [None]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    """Save a side-by-side figure of an image, its GT density map and the prediction.

    The figure is written to ``save_path`` as ``IMG_<img_idx>_AE_<abs error>.jpg``, so the
    absolute count error is visible from the filename alone.

    Args:
        save_path: existing directory the figure is written into.
        img: image for the left panel (anything ``imshow`` accepts).
        img_idx: index used in the filename.
        gt: ground-truth density map, drawn with the jet colormap (presumably 2-D).
        prediction: predicted density map, drawn with the jet colormap (presumably 2-D).
        pred_cnt: predicted count, shown as the right panel's title.
        gt_cnt: ground-truth count, shown as the middle panel's title.
    """
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')

    # Exactly one figure per call; closing it below keeps memory flat when called in a loop.
    fig, axarr = plt.subplots(1, 3, figsize=(13, 13))
    axarr[0].imshow(img)
    axarr[1].imshow(gt, cmap=cm.jet)
    axarr[1].title.set_text(f'GT count: {gt_cnt:.3f}')
    axarr[2].imshow(prediction, cmap=cm.jet)
    axarr[2].title.set_text(f'predicted count: {pred_cnt:.3f}')
    fig.tight_layout()
    fig.savefig(img_save_path)
    # Close only this figure so figures opened elsewhere in the notebook survive.
    plt.close(fig)

In [None]:
def eval_on_scene(model, scene_dataloader, label_factor=None, device='cuda', verbose=False):
    """Evaluate ``model`` on every image of one scene and return count metrics.

    Args:
        model: density-map network; switched to eval mode and run under ``torch.no_grad()``.
        scene_dataloader: iterable of ``(img, gt)`` pairs. ``img`` is moved to ``device``;
            prediction and GT are squeezed before summing, which presumably assumes a
            batch size of 1 -- TODO confirm against the dataset code.
        label_factor: factor the density maps are scaled by; counts are map sums divided
            by it. Defaults to ``cfg_data.LABEL_FACTOR`` from the loaded dataset settings.
        device: device the images are moved to before the forward pass (default ``'cuda'``).
        verbose: when True, print one line per image with its count errors.

    Returns:
        (preds, gts, MAE, MSE): per-image predicted and ground-truth counts (lists of
        floats), the mean absolute error, and the square root of the mean squared error.
        Following crowd-counting convention the latter is reported as "MSE" even though it
        is an RMSE. Both metrics are ``nan`` when the loader yields no images.
    """
    if label_factor is None:
        label_factor = cfg_data.LABEL_FACTOR

    model.eval()
    preds = []
    gts = []
    AEs = []  # Absolute Errors
    SEs = []  # Squared Errors

    with torch.no_grad():
        for idx, (img, gt) in enumerate(scene_dataloader):
            den = model(img.to(device)).cpu()  # Predicted density map

            gt = gt.squeeze()  # Drop size-1 dims (channel / batch)
            den = den.squeeze()

            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / label_factor

            preds.append(pred_cnt.item())
            gts.append(gt_cnt.item())
            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())

            if verbose:
                # Relative error is undefined for an empty (zero-count) GT.
                rel_err = f'{AEs[-1] / gts[-1] * 100:.1f}%' if gts[-1] else 'n/a'
                print(f'IMG {idx:<3} '
                      f'Prediction: {preds[-1]:<9.3f} '
                      f'GT: {gts[-1]:<9.3f} '
                      f'Absolute Error: {AEs[-1]:<9.3f} '
                      f'Relative Error: {rel_err}')

    if not AEs:
        # np.mean([]) would yield nan with a RuntimeWarning; make the empty case explicit.
        return preds, gts, float('nan'), float('nan')

    MAE = np.mean(AEs)
    MSE = np.sqrt(np.mean(SEs))

    return preds, gts, MAE, MSE

In [None]:
def adapt_to_scene(model, scene_dataloader, optim, n_steps=1, device='cuda', criterion=None):
    """Fine-tune ``model`` on a scene's adaptation batch with plain optimizer steps.

    Args:
        model: density-map network; switched to train mode and updated in place.
        scene_dataloader: loader whose ``dataset`` provides ``get_adapt_batch()``,
            returning ``(imgs, gts)`` for this scene's adaptation images.
        optim: optimizer over ``model.parameters()``; its learning rate is the adaptation lr.
        n_steps: number of optimizer steps taken on that batch (default 1).
        device: device the batch is moved to (default ``'cuda'``).
        criterion: loss between predicted and GT maps; defaults to the notebook-level ``loss_fn``.

    Returns:
        The same ``model`` object, updated in place (returned for convenience).
    """
    if criterion is None:
        criterion = loss_fn

    model.train()

    imgs, gts = scene_dataloader.dataset.get_adapt_batch()
    imgs, gts = imgs.to(device), gts.to(device)

    for _ in range(n_steps):
        optim.zero_grad()
        # Call the module itself (not .forward) so any registered hooks also run.
        preds = model(imgs).squeeze(1)  # Remove channel dim; presumably gts are (B, H, W)
        loss = criterion(preds, gts)
        loss.backward()
        optim.step()

    return model

In [None]:
# Baseline: score the trained checkpoint on every scene before any adaptation.
# The learning rate is irrelevant here because no optimizer step is taken.
model, optim = load_model_and_optim(1.0)
for idx, scene_dataloader in enumerate(my_dataloaders):
    print(f'scene {idx + 1}')
    preds_before, gts, MAE_before, MSE_before = eval_on_scene(
        model,
        scene_dataloader,
    )
    print(f'  No adapt MAE: {MAE_before:.3f}, MSE: {MSE_before:.3f}')
  

In [None]:
# Adaptation sweep: for every scene and every learning rate, restart from the trained
# checkpoint, adapt on the scene's adaptation batch, then re-evaluate on that scene.
# Iterating the loaders directly (instead of a hard-coded range(5)) follows whatever
# number of scenes the dataset actually provides.
# NOTE(review): confirm the adaptation images are excluded from the evaluated scene split.
for scene_idx, scene_dataloader in enumerate(my_dataloaders):
    print(f'Scene {scene_idx + 1}')
    for adapt_lr in all_adapt_lrs:
        print(f'  lr={adapt_lr}')

        # Fresh model per learning rate so runs do not contaminate each other.
        model, optim = load_model_and_optim(adapt_lr)

        model = adapt_to_scene(model, scene_dataloader, optim)

        preds_after, gts, MAE_after, MSE_after = eval_on_scene(model, scene_dataloader)
        print(f'    After adapt MAE/MSE: {MAE_after:.3f}/{MSE_after:.3f}')