In [6]:
import numpy as np 
import pandas as pd 
import os 
import torch 
import matplotlib.pyplot as plt
from Model.UNetResNet import UNetResNet

In [7]:
models = {}
# Fall back to CPU when no GPU is present so checkpoints can still be loaded.
# NOTE(review): assumes UNetResNet accepts device='cpu' — confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SAVED_MODEL_DIR = './SavedModel/'

# Checkpoint filenames are '_'-separated. This cell uses field 0 as the backbone name
# (e.g. 'resnet101') and field 4 as the L1 loss weight that is later passed to `test`.
# NOTE(review): the meaning of fields 1-3 is not visible here — confirm the naming scheme.
# sorted() makes the load order deterministic (os.listdir order is arbitrary).
for model_path in sorted(os.listdir(SAVED_MODEL_DIR)):
    if not model_path.startswith('res'):
        continue
    splitted = model_path.split('_')
    arch, weight = splitted[0], splitted[4]
    net = UNetResNet(device=device, backbone=arch)
    # map_location restores tensors onto `device` even if the checkpoint was saved elsewhere.
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    net.load_state_dict(
        torch.load(os.path.join(SAVED_MODEL_DIR, model_path), map_location=device)
    )
    models[f'{arch}_{weight}'] = net

In [8]:
from diodetools.TrainTest import test
from diodetools.DiodeLoader import DiodeDataLoader
from torch.utils.data import DataLoader

In [9]:
def getData(path):
    """Index a DIODE-style directory tree into a table of aligned sample file paths.

    Every ``<stem>.png`` found anywhere under ``path`` is paired with the
    ``<stem>_depth.npy`` and ``<stem>_depth_mask.npy`` files from the same directory.

    Files are paired by stem instead of by sorting three independent lists.
    With independent sorting, stems where one is a prefix of another (``abc`` vs
    ``abc_1``) sort differently for ``.png`` than for ``_depth.npy``. The rows
    would then silently pair the wrong image, depth and mask.

    Args:
        path: Root directory, walked recursively.

    Returns:
        pandas.DataFrame with columns ``image``, ``depth`` and ``mask`` (full paths).
        It has one row per ``.png`` file and is sorted by image path. If no
        ``.png`` files are found, it is empty but still has those columns.

    Raises:
        FileNotFoundError: if an image has no matching depth or mask file
            in its directory.
    """
    records = []
    for root, _dirs, files in os.walk(path):
        present = set(files)
        for name in files:
            if not name.endswith(".png"):
                continue
            stem = name[:-len(".png")]
            depth_name = stem + "_depth.npy"
            mask_name = stem + "_depth_mask.npy"
            missing = [f for f in (depth_name, mask_name) if f not in present]
            if missing:
                raise FileNotFoundError(
                    f"{os.path.join(root, name)} has no matching "
                    f"{', '.join(missing)} in {root}"
                )
            records.append({
                "image": os.path.join(root, name),
                "depth": os.path.join(root, depth_name),
                "mask": os.path.join(root, mask_name),
            })

    # Same row order as before: ascending full image path.
    records.sort(key=lambda rec: rec["image"])
    return pd.DataFrame(records, columns=["image", "depth", "mask"])

# Evaluation data: the DIODE indoor validation split, indexed into a path table.
TEST_PATH = r'../datasets/diode/val/indoors'
df_test = getData(TEST_PATH)

max_depth = 50.            # handed to both the dataset and, later, to `test`
height, width = 192, 256   # (height, width) used for both img_dim and depth_dim

test_dataset = DiodeDataLoader(
    data_frame=df_test,
    max_depth=max_depth,
    img_dim=(height, width),
    depth_dim=(height, width),
)

# shuffle=False keeps batch order identical for every model being compared.
loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [10]:
# Reuses `device` from the model-loading cell rather than redefining it here, so every
# model is evaluated on the same device it was loaded for.
outputs = {}
for key in sorted(models.keys()):
    print(key)
    # Keys are built as f'{arch}_{weight}'; split on the LAST '_' so the weight is
    # always the final field.
    _, weight = key.rsplit('_', 1)
    out_loss, out_rmse, out_rel, out_acc1, out_acc2, out_acc3 = test(
        model=models[key].to(device),
        l1_weight=float(weight),
        loader=loader,
        max_depth=max_depth,
        device=device)
    # float() accepts plain Python numbers and one-element tensors alike, so every metric
    # is stored the same way regardless of which form `test` returns it in.
    outputs[key] = {
        'loss': float(out_loss),
        'rmse': float(out_rmse),
        'rel': float(out_rel),
        'acc1': float(out_acc1),
        'acc2': float(out_acc2),
        'acc3': float(out_acc3),
    }
    print()

resnet101_0.1
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 0.2660 RMSE : 1.8356 REL : 0.3385 ACC^1 : 0.5326 ACC^2 : 0.7681 ACC^3 : 0.8691

resnet101_0.3
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 0.5319 RMSE : 1.8724 REL : 0.3516 ACC^1 : 0.5247 ACC^2 : 0.7565 ACC^3 : 0.8633

resnet101_0.5
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 0.7968 RMSE : 1.9145 REL : 0.3444 ACC^1 : 0.5175 ACC^2 : 0.7569 ACC^3 : 0.8631

resnet101_0.7
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 1.0314 RMSE : 1.8210 REL : 0.3246 ACC^1 : 0.5451 ACC^2 : 0.7726 ACC^3 : 0.8673

resnet101_1.0
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 1.4269 RMSE : 1.8717 REL : 0.3351 ACC^1 : 0.5258 ACC^2 : 0.7646 ACC^3 : 0.8662

resnet152_0.1
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 0.2617 RMSE : 1.7832 REL : 0.3245 ACC^1 : 0.5638 ACC^2 : 0.7754 ACC^3 : 0.8679

resnet152_0.3
Testing Phase [Total batch : 6]
  Batch[6/6]		Loss : 0.5009 RMSE : 1.7777 REL : 0.3286 ACC^1 : 0.5646 AC

In [11]:
# One row per model, one column per metric. Column labels are mapped from the metric
# keys themselves, so headers cannot drift out of alignment with the values.
METRIC_LABELS = {
    'loss': 'Loss', 'rmse': 'RMSE', 'rel': 'REL',
    'acc1': 'ACC1', 'acc2': 'ACC2', 'acc3': 'ACC3',
}
results_df = (
    pd.DataFrame.from_dict(outputs, orient='index')
    .reindex(columns=list(METRIC_LABELS))
    .rename(columns=METRIC_LABELS)
    .rename_axis('ModelName')
    .sort_index()
    .round(4)
)
results_df

ModelName 	 Loss 	 RMSE 	 REL 	 ACC1 	 ACC2 	 ACC3
resnet101_0.1	0.266 	1.8356 	0.3385 	0.5326 	0.7681 	0.8691 	
resnet101_0.3	0.5319 	1.8724 	0.3516 	0.5247 	0.7565 	0.8633 	
resnet101_0.5	0.7968 	1.9145 	0.3444 	0.5175 	0.7569 	0.8631 	
resnet101_0.7	1.0314 	1.821 	0.3246 	0.5451 	0.7726 	0.8673 	
resnet101_1.0	1.4269 	1.8717 	0.3351 	0.5258 	0.7646 	0.8662 	
resnet152_0.1	0.2617 	1.7832 	0.3245 	0.5638 	0.7754 	0.8679 	
resnet152_0.3	0.5009 	1.7777 	0.3286 	0.5646 	0.774 	0.8712 	
resnet152_0.5	0.7536 	1.7821 	0.3269 	0.5733 	0.7738 	0.8665 	
resnet152_0.7	0.9919 	1.7673 	0.3281 	0.5738 	0.7718 	0.8671 	
resnet152_1.0	1.355 	1.7393 	0.3281 	0.5716 	0.7712 	0.8678 	
resnet18_0.1	0.2841 	2.0689 	0.3721 	0.4887 	0.728 	0.8322 	
resnet18_0.3	0.5661 	2.0057 	0.3903 	0.4819 	0.7253 	0.8344 	
resnet18_0.5	0.8447 	1.996 	0.3766 	0.4923 	0.7246 	0.8331 	
resnet18_0.7	1.1393 	2.0343 	0.3798 	0.4831 	0.7206 	0.8285 	
resnet18_1.0	1.5823 	2.036 	0.3821 	0.4744 	0.72 	0.8313 	
resnet34_0.1	0.292