In [None]:
# Move the working directory two levels up so the relative imports and paths used below
# (models.*, datasets.*, notebooks/...) resolve -- presumably this lands on the repository root.
# NOTE(review): the path is relative, so re-running this cell moves up two more levels;
# run it exactly once per kernel session.
%cd ../..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from itertools import product


from PIL import Image
import os
import models.ViCCT.ViCCTModels  # Need to register the models!
from timm.models import create_model
from datasets.dataset_utils import img_equal_unsplit
import importlib
import time

In [None]:
# ----------------------------------------------------------------------------
# Configuration: everything a reader is expected to tune lives in this cell.
# ----------------------------------------------------------------------------
model_name = 'ViCCT_small'  # Name passed to timm's create_model; must be something like 'ViCCT_small'.

# Path to the trained model file (something like XYZ.pth), relative to the working directory.
# Built with os.path.join so it resolves on every OS; on Windows this yields exactly the
# backslash-separated string that used to be hard-coded here.
trained_model_path = os.path.join('notebooks', 'TL', 'save_state_ep_250_new_best_MAE_4.971.pth')

label_factor = 3000  # The label factor used to train this specific model; predicted density sums are divided by it.
dataset = 'WE_ViCCT_Meta'  # Must be the exact name of the dataset package under datasets/meta.

# When True, a result folder is created automatically for images, GTs and predictions.
# NOTE(review): the plotting call inside eval_on_scene is currently commented out, so
# enabling this only creates the (empty) folder.
save_results = False

set_to_eval = 'test'  # 'val' or 'test': which split to test the model on. 'train' does not work!

# Learning rates to try for the adaptation step; a freshly loaded model is adapted once per rate.
all_adapt_lrs = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
# all_adapt_imgs = [
#     [
#         ['104207_1-04-S20100821071000000E20100821120000000_007550.jpg'], 
#         ['200608_C08-02-S20100626083000000E20100626233000000_clip1_021050.jpg'], 
#         ['200702_C09-01-S20100717083000000E20100717233000000_007550.jpg'],
#         ['202201_1-01-S20100922060000000E20100922235959000_clip1_004550.jpg'],
#         ['500717_D11-03-S20100717083000000E20100717233000000_016550.jpg']
#     ],
#     [
#         ['104207_1-04-S20100821071000000E20100821120000000_172550.jpg'],
#         ['200608_C08-02-S20100626083000000E20100626233000000_clip1_061550.jpg'],
#         ['200702_C09-01-S20100717083000000E20100717233000000_141050.jpg'],
#         ['202201_1-01-S20100922060000000E20100922235959000_clip1_023450.jpg'],
#         ['500717_D11-03-S20100717083000000E20100717233000000_142550.jpg']
#     ],
#     [
#         ['104207_1-04-S20100821071000000E20100821120000000_078050.jpg'],
#         ['200608_C08-02-S20100626083000000E20100626233000000_clip1_159050.jpg'],
#         ['200702_C09-01-S20100717083000000E20100717233000000_091550.jpg'],
#         ['202201_1-01-S20100922060000000E20100922235959000_clip1_064850.jpg'],
#         ['500717_D11-03-S20100717083000000E20100717233000000_144050.jpg']
#     ]
# ]

# Images used for adaptation: one inner list per scene, five frames each.
# Every filename is '<scene prefix>_<frame id>.jpg'; the table below lists the prefix of
# each scene followed by its frame ids, in the order they are handed to the data loader.
adapt_imgs = [
    [f'{scene_prefix}_{frame_id}.jpg' for frame_id in frame_ids]
    for scene_prefix, frame_ids in (
        ('104207_1-04-S20100821071000000E20100821120000000',
         ('007550', '090050', '172550', '069050', '078050')),
        ('200608_C08-02-S20100626083000000E20100626233000000_clip1',
         ('021050', '169550', '061550', '091550', '159050')),
        ('200702_C09-01-S20100717083000000E20100717233000000',
         ('007550', '141050', '003050', '076550', '091550')),
        ('202201_1-01-S20100922060000000E20100922235959000_clip1',
         ('004550', '041450', '023450', '011750', '064850')),
        ('500717_D11-03-S20100717083000000E20100717233000000',
         ('016550', '057050', '142550', '102050', '144050')),
    )
]

In [None]:
# Loss minimised during the adaptation step: mean squared error between the predicted
# and the ground-truth density maps, averaged over all elements (the default reduction,
# spelled out here for clarity).
loss_fn = torch.nn.MSELoss(reduction='mean')

In [None]:
# Fail fast on an unsupported split. Previously an invalid value only printed a message and
# left `my_dataloaders` undefined (or stale from an earlier run), so the mistake surfaced
# much later as a NameError or as results computed on the wrong split.
if set_to_eval not in ('val', 'test'):
    raise ValueError(f'Error: invalid set --> {set_to_eval}')

# The dataset package is resolved by name: datasets/meta/<dataset>/{loading_data,settings}.py
dataloader = importlib.import_module(f'datasets.meta.{dataset}.loading_data').loading_data
cfg_data = importlib.import_module(f'datasets.meta.{dataset}.settings').cfg_data

# `adapt_imgs` tells the loader which images are reserved for adaptation.
# The returned loader collections are indexed per scene further down -- presumably one
# DataLoader per scene; confirm against loading_data.
train_loaders, val_loaders, test_loaders, restore_transform = dataloader(adapt_imgs)
my_dataloaders = val_loaders if set_to_eval == 'val' else test_loaders

In [None]:
# Folder in which results are stored; stays None when saving is disabled.
save_path = None
if save_results:
    # The folder name is only minute-granular, e.g. 'ViCCT_<dataset>_<split>_<MM-DD_HH-MM>'.
    save_folder = 'ViCCT' + '_' + dataset + '_' + set_to_eval + '_' + time.strftime("%m-%d_%H-%M", time.localtime())
    save_path = os.path.join('notebooks', save_folder)  # Manually change here if you want to save somewhere else
    # makedirs + exist_ok: re-running this cell within the same minute, or running it when
    # the parent folder does not exist yet, no longer raises.
    os.makedirs(save_path, exist_ok=True)

In [None]:
def plot_and_save_results(save_path, img, img_idx, gt, prediction, pred_cnt, gt_cnt):
    """Save a three-panel figure (image, GT density, predicted density) for one sample.

    The file is written to `save_path` as 'IMG_<img_idx>_AE_<absolute error>.jpg'.

    Args:
        save_path: Existing directory in which the figure is stored.
        img: The image shown in the left panel; must be something `imshow` accepts.
        img_idx: Index of the image, used in the file name only.
        gt: Ground-truth density map, shown in the middle panel.
        prediction: Predicted density map, shown in the right panel.
        pred_cnt: Predicted count, shown in the title of the right panel.
        gt_cnt: Ground-truth count, shown in the title of the middle panel.
    """
    img_save_path = os.path.join(save_path, f'IMG_{img_idx}_AE_{abs(pred_cnt - gt_cnt):.3f}.jpg')

    # plt.subplots already creates the figure; the extra plt.figure() that used to precede
    # it only produced an additional empty figure on every call.
    fig, axarr = plt.subplots(1, 3, figsize=(13, 13))
    try:
        axarr[0].imshow(img)
        axarr[1].imshow(gt, cmap=cm.jet)
        axarr[1].set_title(f'GT count: {gt_cnt:.3f}')
        axarr[2].imshow(prediction, cmap=cm.jet)
        axarr[2].set_title(f'predicted count: {pred_cnt:.3f}')
        fig.tight_layout()
        fig.savefig(img_save_path)
    finally:
        # Close only our own figure (also when plotting fails) instead of every open
        # figure, so figures the user has open elsewhere in the notebook survive.
        plt.close(fig)

In [None]:
def load_model_and_optim(adapt_lr):
    """Build a fresh model with the trained weights and an SGD optimizer for it.

    Reads the module-level `model_name` and `trained_model_path`. A new model is created
    on every call, so each adaptation run starts from the same trained weights.

    Args:
        adapt_lr: Learning rate of the returned SGD optimizer.

    Returns:
        (model, optim): the model on the GPU with the checkpoint weights loaded, and a
        `torch.optim.SGD` over all of its parameters.
    """
    model = create_model(
        model_name,
        init_path=None,
        num_classes=1000,  # Not yet used anyway. Must match pretrained model!
        drop_rate=0.,
        drop_path_rate=0.,
        drop_block_rate=None,
    )
    model.cuda()

    # Load the checkpoint onto the CPU first: without map_location, torch.load restores every
    # tensor to the device it was saved from, which fails when that device is unavailable and
    # otherwise keeps a second copy of the weights on the GPU. load_state_dict copies the
    # values onto the model's own device.
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    resume_state = torch.load(trained_model_path, map_location='cpu')
    model.load_state_dict(resume_state['net'])

    optim = torch.optim.SGD(model.parameters(), lr=adapt_lr)

    return model, optim

In [None]:
def eval_on_scene(model, scene_dataloader):
    """Evaluate `model` on every image yielded by `scene_dataloader`.

    For each image the patch-wise predictions and ground truths are stitched back into
    full-size density maps, which are summed and rescaled into counts. Reads the
    module-level `label_factor` and `cfg_data`. The model is left in eval mode.

    Args:
        model: The network; it is called on the image patches on the GPU.
        scene_dataloader: Iterable yielding (img, img_patches, gt_patches) per image.

    Returns:
        preds: List with the predicted count of every image.
        gts: List with the ground-truth count of every image.
        MAE: Mean of the absolute count errors.
        MSE: Square ROOT of the mean squared count error (i.e. an RMSE, despite the name).
    """
    model.eval()

    preds = []
    gts = []
    AEs = []  # Absolute Errors
    SEs = []  # Squared Errors

    with torch.no_grad():
        for img, img_patches, gt_patches in scene_dataloader:
            # NOTE(review): squeeze() drops EVERY size-1 dim, not just the batch dim, so an
            # image that yields a single patch would presumably lose its patch dim as well.
            img_patches = img_patches.squeeze().cuda()
            gt_patches = gt_patches.squeeze().unsqueeze(1)  # Remove batch dim, insert channel dim
            img = img.squeeze()  # Remove batch dimension
            _, img_h, img_w = img.shape  # Obtain image dimensions. Used to reconstruct GT and Prediction

            pred_den = model(img_patches).cpu()  # Predicted density crops

            # Restore GT and Prediction to full image size, then drop the channel dim.
            gt = img_equal_unsplit(gt_patches, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1).squeeze()
            den = img_equal_unsplit(pred_den, cfg_data.OVERLAP, cfg_data.IGNORE_BUFFER, img_h, img_w, 1).squeeze()

            # The prediction is rescaled with the factor the model was trained with, the GT
            # with the factor from the dataset settings.
            pred_cnt = den.sum() / label_factor
            gt_cnt = gt.sum() / cfg_data.LABEL_FACTOR

            preds.append(pred_cnt)
            gts.append(gt_cnt)
            AEs.append(torch.abs(pred_cnt - gt_cnt).item())
            SEs.append(torch.square(pred_cnt - gt_cnt).item())

            # Per-image plotting is disabled. To re-enable, `img` presumably has to go through
            # restore_transform first (TODO confirm), followed by:
            #     plot_and_save_results(save_path, img, idx, gt, den, pred_cnt, gt_cnt)

    MAE = np.mean(AEs)
    MSE = np.sqrt(np.mean(SEs))

    return preds, gts, MAE, MSE

In [None]:
def adapt_to_scene(model, scene_dataloader, optim):
    """Adapt `model` to a scene with a single optimizer step on the scene's adapt batch.

    The batch comes from `scene_dataloader.dataset.get_adapt_batch()`; the loss is the
    module-level `loss_fn`. The model is put in train mode and modified in place.

    Args:
        model: The network to adapt.
        scene_dataloader: Loader whose dataset provides `get_adapt_batch()`.
        optim: Optimizer over the parameters of `model`.

    Returns:
        The same (now adapted) model, still in train mode.
    """
    model.train()

    img_stack, gt_stack = scene_dataloader.dataset.get_adapt_batch()
    img_stack, gt_stack = img_stack.squeeze(0).cuda(), gt_stack.squeeze(0).cuda()

    optim.zero_grad()
    # Call the module itself rather than .forward(): only __call__ runs registered hooks.
    pred_stack = model(img_stack)
    loss = loss_fn(pred_stack, gt_stack)
    loss.backward()
    optim.step()

    return model

In [None]:
# Baseline: evaluate the trained model on every scene of the selected split WITHOUT adapting.
# NOTE(review): this cell used to call `get_dataloaders(None)`, which is neither defined nor
# imported anywhere in this notebook and therefore raised a NameError on a fresh kernel.
# It now reuses `my_dataloaders` built above, so the baseline is measured on exactly the
# same loaders as the adapted models below.
model, optim = load_model_and_optim(1.)  # Learning rate is not used when not adapting
for idx, scene_dataloader in enumerate(my_dataloaders):
    print(f'scene {idx + 1}')
    preds_before, gts, MAE_before, MSE_before = eval_on_scene(model, scene_dataloader)
    print(f'  No adapt MAE: {MAE_before:.3f}, MSE: {MSE_before:.3f}')
  

In [None]:
# all_adapt_params = product(all_adapt_imgs, all_adapt_lrs)
# for adapt_imgs, adapt_lr in all_adapt_params:
#     print(f'lr={adapt_lr}')
    
#     scene_dataloaders, restore_transform, cfg_data = get_dataloaders(adapt_imgs)
#     for idx, scene_dataloader in enumerate(scene_dataloaders):
#         print(f'  scene {idx + 1}')
#         model, optim = load_model_and_optim(adapt_lr)

#         model = adapt_to_scene(model, scene_dataloader, optim)

#         preds_after, gts, MAE_after, MSE_after = eval_on_scene(model, scene_dataloader)
#         print(f'    After adapt -->  MAE: {MAE_after:.3f}, MSE: {MSE_after:.3f}')    

In [None]:

# Adaptation sweep: for every scene and every learning rate, start from the trained
# weights, take one adaptation step on the scene's adapt batch and re-evaluate.
# Iterating over the loaders directly (instead of a hard-coded range(5)) keeps the loop
# in sync with however many scenes the selected split provides.
for scene_idx, scene_dataloader in enumerate(my_dataloaders):
    print(f'Scene {scene_idx + 1}')
    for adapt_lr in all_adapt_lrs:
        print(f'  lr={adapt_lr}')

        # A fresh model per learning rate, so the runs do not influence each other.
        model, optim = load_model_and_optim(adapt_lr)
        model = adapt_to_scene(model, scene_dataloader, optim)

        preds_after, gts, MAE_after, MSE_after = eval_on_scene(model, scene_dataloader)
        print(f'    After adapt MAE/MSE: {MAE_after:.3f}/{MSE_after:.3f}')

In [None]:
# all_adapt_params = product(all_adapt_imgs, all_adapt_lrs)
# for idx, (adapt_imgs, adapt_lr) in enumerate(all_adapt_params):
#     print(f'adapt lr: {adapt_lr}')
#     for adapt_img in adapt_imgs:
#         print(adapt_img)
#     print()