In [4]:
import stage1
import stage2
from PIL import Image
import os
from metrics import mse, ssim_score, psnr, ncc, dice_score
import torch
import numpy as np
from tqdm import tqdm 
import cv2
import matplotlib.pyplot as plt

In [5]:
def _list_visible_files(root):
    """Return the sorted names of regular, non-hidden files directly under ``root``.

    Hidden entries (e.g. ``.DS_Store``) and sub-directories (e.g. the
    ``.ipynb_checkpoints`` folder Jupyter creates) are skipped so they cannot
    crash image loading or shift the index-based pairing of images and masks.
    """
    return sorted(
        name
        for name in os.listdir(root)
        if not name.startswith(".") and os.path.isfile(os.path.join(root, name))
    )


def _load_gt_mask(gt_path):
    """Load a ground-truth image in PIL mode '1' and return it as a float64 array of 0/1."""
    with Image.open(gt_path) as gt_img:
        return np.array(gt_img.convert("1")).astype(np.float64)


def get_metrics(image_root, gt_root, threshold=0.4, preds_dir="sample_data/preds"):
    """Run the two-stage pipeline over a folder and average metrics against ground truth.

    Each image is passed through ``stage1.Stage_1().forward`` and then
    ``stage2.Stage2().forward``. The result is thresholded into a binary mask
    and scored against the matching ground-truth mask. Images and masks are
    paired by position after sorting each folder's file names, so both folders
    must list corresponding files in the same sorted order.

    Args:
        image_root: Directory of input images (each converted to RGB).
        gt_root: Directory of ground-truth masks (each converted to PIL mode '1').
        threshold: Prediction values strictly greater than this count as
            foreground. Defaults to 0.4, the value previously hard-coded.
        preds_dir: Directory created (with parents, if missing) before
            evaluation. Nothing in this function writes to it. Pass ``None``
            to skip creating it.

    Returns:
        A dict with the mean "ssim", "psnr", "mse", "ncc" and "dice" over all
        image/mask pairs. Returns ``None``, after printing a message, if either
        root is not a directory, the folders hold different numbers of files,
        or the folders are empty.
    """
    if not os.path.isdir(image_root) or not os.path.isdir(gt_root):
        print("Incorrect Path!")
        return None

    images = _list_visible_files(image_root)
    gts = _list_visible_files(gt_root)

    if len(images) != len(gts):
        print("Number of Images mismatch with GT")
        return None
    if not images:
        # Without this guard the averaging at the end would divide by zero.
        print("No images found!")
        return None

    if preds_dir is not None:
        # makedirs (unlike mkdir) creates missing parents and tolerates an existing dir.
        os.makedirs(preds_dir, exist_ok=True)

    model1 = stage1.Stage_1()
    model2 = stage2.Stage2()
    model1.model.eval()
    model2.unet_model.eval()

    # Metric functions are called in this order for every pair; each is fn(pred, gt).
    metric_fns = {
        "mse": mse,
        "ssim": ssim_score,
        "psnr": psnr,
        "ncc": ncc,
        "dice": dice_score,
    }
    totals = dict.fromkeys(metric_fns, 0.0)

    for img_name, gt_name in tqdm(zip(images, gts), total=len(images), desc="Computing metrics"):
        with Image.open(os.path.join(image_root, img_name)) as raw_img:
            img = raw_img.convert("RGB")
        size_og = img.size  # PIL reports (width, height)

        with torch.no_grad():
            masked_img = model1.forward(img)
            pred = model2.forward(masked_img, size_og)

        gt = _load_gt_mask(os.path.join(gt_root, gt_name))

        # Cast to float32 before thresholding, exactly as the previous
        # torch.from_numpy(...).float() round-trip did, so values near the
        # threshold are classified the same way.
        pred_prob = np.asarray(pred, dtype=np.float32)
        # float64 on both sides avoids mixed-dtype inputs to the metric functions.
        pred_mask = (pred_prob > threshold).astype(np.float64)

        for key, fn in metric_fns.items():
            totals[key] += fn(pred_mask, gt)

    n_pairs = len(images)
    # Keep the key order of the returned dict identical to the original.
    return {key: totals[key] / n_pairs for key in ("ssim", "psnr", "mse", "ncc", "dice")}


In [6]:
# Evaluate the full two-stage pipeline on the training split.
metrics = get_metrics("sample_data/train", "sample_data/train_gt")
# get_metrics prints a message and returns None on bad input; fail loudly here
# so "Restart & Run All" cannot silently continue without results.
assert metrics is not None, "get_metrics failed - see the message printed above"
# NOTE(review): the recorded run reports psnr ~70 dB together with mse ~0.0073
# on 0/1 masks. 10*log10(255**2 / 0.0073) ~ 69.5 dB, whereas a peak of 1 would
# give ~21 dB, so `metrics.psnr` presumably assumes an 8-bit (255) peak.
# Confirm before reporting PSNR for binary masks.
print(metrics)

Computing metrics: 100%|██████████| 128/128 [03:53<00:00,  1.82s/it]

{'ssim': 0.9779788919104669, 'psnr': 70.22066102478281, 'mse': 0.007259247933862208, 'ncc': 0.534853588555341, 'dice': 0.533907647464024}



