In [2]:
import numpy as np
import torch
from scipy import stats
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import monoSimDataset
from model.quality_model import EmbeddingNet, MobileNetV3_wtA, MobileNetV3_wA, MobileNetV3_Lite, MobileNetV2_Lite

In [10]:
# --- Evaluation settings: edit these before running the cells below. ---

# Architecture to build; must be one of the names handled by the
# model-construction cell.
model_name = "MobileNetV2_Lite"

# Checkpoint file to restore. A falsy value (e.g. "" or None) skips loading.
cp_path = "checkpoints/1203_202301_MobileNetV2_Lite/421_1.3395e-03.pth"

# Dataset root handed to monoSimDataset.
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
dataset_path = "/home/dl/wangleyuan/dataset/SD-Eyes"

# Requested device string; resolved to a torch.device (or CPU) in the next cell.
device = "cuda:2"

In [6]:
# Turn the configured device string into a torch.device. The CPU is used
# whenever CUDA is unavailable or no device was requested (device is None).
cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("hey man, buy a GPU!")
device = torch.device(device) if (cuda_ok and device is not None) else torch.device("cpu")

In [7]:
# Build the mode='test' split of the dataset and wrap it in a loader that
# yields one sample per batch, in dataset order.
test_data = monoSimDataset(
    path=dataset_path,
    mode='test',
    seed=3141,
    debug_data=False,
    upsample=True,
)
test_data_loader = DataLoader(
    dataset=test_data,
    batch_size=1,
    shuffle=False,
    drop_last=True,
)

In [11]:
# Build the network selected by `model_name` and restore its weights.
#
# Each entry is a zero-argument factory, so only the requested architecture
# is actually instantiated.
model_builders = {
    'MobileNetV3_wtA': lambda: MobileNetV3_wtA(),
    'MobileNetV3_wA': lambda: MobileNetV3_wA(0.5),
    'MobileNetV3_Lite': lambda: MobileNetV3_Lite(True),
    'MobileNetV2_Lite': lambda: MobileNetV2_Lite(True, True, 0.5),
    'EmbeddingNet': lambda: EmbeddingNet(False),
}
assert model_name in model_builders, (
    "unknown model_name {!r}; expected one of {}".format(model_name, sorted(model_builders))
)
model = model_builders[model_name]()
model.to(device)

if cp_path:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    cp_data = torch.load(cp_path, map_location=device)
    state_dict = cp_data['model']
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        # load_state_dict reports missing/unexpected keys as a RuntimeError.
        # Show the mismatch first (so it is visible even if the retry fails),
        # then fall back to a best-effort non-strict load.
        print(e)
        model.load_state_dict(state_dict, strict=False)

In [None]:
# Evaluate the restored model on the test loader: report the mean prediction
# loss, the mean mask loss, and the Pearson (LCC) / Spearman (SROCC)
# correlation between predictions and targets.
#
# NOTE(review): `epoch`, `pred_criterion` and `mask_criterion` are not defined
# in any cell visible in this notebook, so this cell raises NameError on a
# fresh kernel. Define them in an earlier cell (e.g. next to the config)
# before relying on Restart & Run All.

# NOTE(review): MobileNetV3* variants are switched to train mode for this
# evaluation pass. If the model contains BatchNorm/Dropout layers they will
# behave differently from eval mode (and BatchNorm running statistics get
# updated) -- confirm this is intentional.
if model_name.split('_')[0] == 'MobileNetV3':
    model.train()
else:
    model.eval()

with torch.no_grad():
    num_batches = len(test_data_loader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("test_data_loader is empty; nothing to evaluate")

    test_pred_epoch_loss = 0
    test_mask_epoch_loss = 0
    # Per-batch results are collected in lists and joined once after the loop;
    # np.append inside the loop would re-copy the whole array on every batch.
    target_chunks = []
    pred_chunks = []
    for img, mask, target in tqdm(test_data_loader, desc='[{}] test_batch'.format(epoch),
                                  bar_format='{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
        if model_name == 'MobileNetV3_Lite':
            mask = mask.to(torch.long)
        img = img.to(device)
        mask = mask.to(device)
        target = target.to(device)
        pred, heatmap = model(img)
        # Separate names for per-batch losses so the epoch-level
        # test_pred_loss / test_mask_loss below are only ever plain floats.
        batch_pred_loss = pred_criterion(pred, target)
        batch_mask_loss = mask_criterion(heatmap, mask)
        test_pred_epoch_loss += batch_pred_loss.item()
        test_mask_epoch_loss += batch_mask_loss.item()
        target_chunks.append(target.cpu().numpy().reshape(-1))
        pred_chunks.append(pred.cpu().numpy().reshape(-1))

    test_pred_loss = test_pred_epoch_loss / num_batches
    test_mask_loss = test_mask_epoch_loss / num_batches

    # The leading placeholder zero (index 0) is kept so `scores` and
    # `prediction` have the same layout and float64 dtype as before; the real
    # samples start at index 1.
    scores = np.concatenate([np.zeros(1)] + target_chunks)
    prediction = np.nan_to_num(np.concatenate([np.zeros(1)] + pred_chunks))
    srocc = stats.spearmanr(prediction[1:], scores[1:])[0]
    lcc = stats.pearsonr(prediction[1:], scores[1:])[0]

    print("[{}] Test - prediction loss: {:.4e}".format(epoch, test_pred_loss))
    print("[{}] Test - mask loss: {:.4e}".format(epoch, test_mask_loss))
    print("[{}] Test - LCC: {:.4f}, SROCC: {:.4f}".format(epoch, lcc, srocc))