In [1]:
import skimage.io as skio
import skimage
import glob
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
from filters import *
from tools import *
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import random

import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask

from UNet.predict import predict_img

In [2]:
# Dataset selection: `idx` picks which entry of DATASET is evaluated below.
DATASET = ['Fluo-N2DL-HeLa','PhC-C2DH-U373']
idx = 0

# Image and target files are paired purely by sorted filename order
# (presumably image and target names sort identically — confirm per dataset).
IM_TEST_PATH = sorted(glob.glob(f'data/{DATASET[idx]}/IMG_TEST/*.tif'))
TG_TEST_PATH = sorted(glob.glob(f'data/{DATASET[idx]}/TARGET_TEST/*.tif'))

# Fail fast: an empty glob (wrong working directory or dataset name) or an
# image/target count mismatch would otherwise surface much later as an empty
# histogram, a `randrange(0)` error, or silently mis-paired samples (the
# evaluation loop pairs the two lists with `zip`, which truncates).
assert IM_TEST_PATH, f'no test images found under data/{DATASET[idx]}/IMG_TEST'
assert len(IM_TEST_PATH) == len(TG_TEST_PATH), (
    f'{len(IM_TEST_PATH)} test images but {len(TG_TEST_PATH)} test targets'
)

X_test, y_test = load_img_tg(IM_TEST_PATH, TG_TEST_PATH)

In [None]:
# Path to the trained UNet checkpoint for the selected dataset.
model = f'./UNet/checkpoints/{DATASET[idx]}/checkpoint_epoch5.pth'

# Single-channel input, single-class output, transposed-conv upsampling.
net = UNet(n_channels=1, n_classes=1, bilinear=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# There is no argparse `args` in this notebook (it was copied from predict.py);
# log the checkpoint path held in `model` instead.
logging.info(f'Loading model {model}')
logging.info(f'Using device {device}')

net.to(device=device)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# NOTE(review): if the checkpoint stores extra keys alongside the weights
# (e.g. 'mask_values' in newer Pytorch-UNet versions), pop them before
# load_state_dict — confirm against how the checkpoint was saved.
net.load_state_dict(torch.load(model, map_location=device))

logging.info('Model loaded!')

In [None]:
# Per-image IoU (Jaccard score) between the UNet prediction and its target.
results = []
n_pairs = min(len(X_test), len(y_test))
for i, (img, target) in enumerate(tqdm(zip(X_test, y_test), total=n_pairs, desc='predicting')):
    logging.info(f'\nPredicting image {IM_TEST_PATH[i]} ...')

    # Pass only the image: zip() yields (image, target) pairs, and the target
    # must not be fed to the network.
    pred = predict_img(net=net,
                       full_img=img,
                       device=device)

    # NOTE(review): average='micro' over 2-D masks treats each column as a
    # label; assumes target and pred use the same binary label encoding.
    results.append(jaccard_score(target, pred, average='micro'))

# Plot and save the IoU distribution once, after all images are scored.
fig, ax = plt.subplots()
ax.hist(results, bins=100, range=(0, 1))
ax.set_title(f'IoU of UNet on {DATASET[idx]}')
ax.set_xlabel('IoU score')
ax.set_ylabel('Frequency')
fig.savefig(f'output/UNet_performance_dataset_{idx}.png')

In [None]:
# Qualitative spot check: segment one randomly chosen test image, score it
# against its target, and save an overlay figure of prediction vs. target.
n_im = random.randrange(len(y_test))

sample_image = X_test[n_im]
target = y_test[n_im]

seg = predict_img(net=net, full_img=sample_image, device=device)
score = jaccard_score(target, seg, average='micro')

out_path = f'output/inference_UNet_dataset_{idx}.png'
fig = plot_pred_with_target(target, seg, score)
fig.savefig(out_path)