In [16]:
import os, time
from operator import add
import numpy as np
from glob import glob
import cv2
from tqdm import tqdm
import imageio
import torch
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

from model import build_unet
from utils import create_dir, seeding

In [17]:
def mask_parse(mask):
    """Replicate a mask into three identical channels along a new last axis.

    For a 2-D mask of shape (H, W) the result has shape (H, W, 3), which lets it
    be concatenated side by side with a 3-channel image. The mask's dtype is kept.
    """
    channels = (mask, mask, mask)
    return np.stack(channels, axis=-1)

In [19]:
""" Seeding """
seeding(42)

# --- Folders -----------------------------------------------------------
# Side-by-side visualizations (input | prediction) are written here.
create_dir("results")

# --- Load dataset ------------------------------------------------------
# NOTE(review): machine-specific absolute default; set the TEST_IMAGE_DIR
# environment variable to run this notebook elsewhere.
TEST_IMAGE_DIR = os.environ.get(
    "TEST_IMAGE_DIR",
    r"C:\Users\shiqi\PycharmProjects\DSCM_fundus\deepscm\assets\data\fundus\test2_image",
)
MAX_TEST_IMAGES = 100  # only the first N images (in sorted path order) are processed

test_x = sorted(glob(os.path.join(TEST_IMAGE_DIR, "*")))
# Ground-truth masks are globbed but currently unused: mask loading and
# metric computation are disabled in this cell.
test_y = sorted(glob("../new_data/test/mask/*"))
test_x = test_x[:MAX_TEST_IMAGES]

# --- Hyperparameters ---------------------------------------------------
H = 512
W = 512
size = (W, H)  # (width, height), the order cv2.resize expects; resizing is currently disabled
checkpoint_path = "files/checkpoint.pth"

# --- Load the checkpoint -----------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = build_unet()
model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# --- Inference loop ----------------------------------------------------
time_taken = []  # per-image forward-pass time in seconds
on_cuda = device.type == "cuda"

for image_path in tqdm(test_x, total=len(test_x)):
    # Output file name: basename without extension. This is portable across OS
    # path separators and keeps dots inside the stem (e.g. "a.b.png" -> "a.b").
    name = os.path.splitext(os.path.basename(image_path))[0]

    # Read as a 3-channel BGR uint8 image (OpenCV's default channel order).
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None rather than raising on unreadable/non-image files.
        tqdm.write(f"Skipping unreadable file: {image_path}")
        continue

    # HWC uint8 -> NCHW float32 scaled to [0, 1].
    x = np.transpose(image, (2, 0, 1)) / 255.0
    x = np.expand_dims(x, axis=0).astype(np.float32)
    x = torch.from_numpy(x).to(device)

    with torch.no_grad():
        # CUDA kernels launch asynchronously; synchronize on both sides so the
        # timer measures the forward pass rather than just the kernel launches.
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        pred_y = torch.sigmoid(model(x))
        if on_cuda:
            torch.cuda.synchronize()
        time_taken.append(time.perf_counter() - start_time)

    # Probabilities -> binary {0, 1} uint8 mask.
    # Assumes the model emits a single output channel, i.e. (1, 1, H', W'); squeeze
    # raises if the channel dimension is not 1.
    pred_mask = np.squeeze(pred_y[0].cpu().numpy(), axis=0)
    pred_mask = (pred_mask > 0.5).astype(np.uint8)

    # --- Saving masks ---
    # Layout: input | 20 px gray gap | predicted mask (white = foreground).
    # The double separator is presumably left over from a now-disabled ground-truth panel.
    pred_rgb = mask_parse(pred_mask) * 255
    # Size the separator from the image itself: inputs are not resized to `size`,
    # so a fixed 512 height would break the concatenation for other resolutions.
    separator = np.full((image.shape[0], 10, 3), 128, dtype=np.uint8)
    cat_images = np.concatenate([image, separator, separator, pred_rgb], axis=1)
    cv2.imwrite(f"results/{name}.png", cat_images)

# --- Report speed --------------------------------------------------------
if time_taken:
    mean_time = float(np.mean(time_taken))
    if mean_time > 0:
        print(f"Mean inference time: {mean_time:.4f} s/image ({1.0 / mean_time:.2f} FPS)")
    else:
        print(f"Mean inference time: {mean_time:.4f} s/image")