In [1]:
import os

import cv2 as cv
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from catalyst.dl.utils import UtilsFactory
from tqdm import tqdm

import argparse

from models.utils import get_model

In [4]:
# Notebook configuration: every file path below lives under the data directory.
dataset_path = '../data'
network = 'linknet'
# CSV listing the test samples (a path string, despite the `_df` name).
test_df = dataset_path + '/test_df.csv'
# Checkpoint whose 'model_state_dict' entry is loaded into the network.
model_weights_path = dataset_path + '/best.pth'

In [5]:
def predict(dataset_path, model_weights_path, network, test_df_path):
    """Run a trained model over every row of a test CSV and save the predicted masks.

    For each row, the input image is loaded via ``get_image``, passed through the
    model, turned into per-pixel probabilities with a sigmoid, scaled to 0-255 and
    written as a PNG to ``<dataset_path>/predictions/<stem>.png``. ``<stem>`` is
    ``<name>_<channel>_<position>``, the same stem ``get_image`` uses to locate
    the input files.

    Args:
        dataset_path: Directory under which the ``predictions`` folder is created.
        model_weights_path: Path to a checkpoint dict with a ``model_state_dict`` entry.
        network: Architecture name passed to ``get_model``.
        test_df_path: Path to the CSV describing the test samples.
    """
    model = get_model(network)
    checkpoint = torch.load(model_weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Named test_frame so it is not confused with the notebook-level `test_df` path.
    test_frame = pd.read_csv(test_df_path)

    predictions_path = os.path.join(dataset_path, "predictions")

    if not os.path.exists(predictions_path):
        os.makedirs(predictions_path)
        print("Prediction directory created.")

    # Inference only: no_grad avoids building an autograd graph for every sample.
    with torch.no_grad():
        for _, data_info in tqdm(test_frame.iterrows(), total=len(test_frame)):
            data = get_image(data_info)

            # Add the batch dimension instead of hard-coding (1, 3, 224, 224).
            prediction = model(data['features'].unsqueeze(0))

            # Like the original view(224, 224), this assumes a single-channel output;
            # collapse it to an (H, W) probability map of whatever size the model emits.
            probabilities = torch.sigmoid(prediction).reshape(prediction.shape[-2:])
            # Explicit 8-bit conversion instead of relying on imwrite's implicit cast.
            mask = (probabilities * 255).round().to(torch.uint8).numpy()

            # The original referenced an undefined `filename`; build the same stem
            # get_image uses so each prediction is named after its input image.
            stem = '_'.join(
                str(data_info[key]) for key in ('name', 'channel', 'position'))
            cv.imwrite(os.path.join(predictions_path, stem + ".png"), mask)


In [None]:
# NOTE(review): get_image is defined in a later cell (In [7]); on Restart & Run All
# this call fails with NameError unless that cell is moved above this one.
# `test_df` from the config cell is the CSV *path*, hence the test_df_path keyword.
predict(
    dataset_path=dataset_path,
    model_weights_path=model_weights_path,
    network=network,
    test_df_path=test_df,
)

In [7]:
def get_image(data_info):
    """Load one sample's image and mask as tensors.

    Both files are located from fields of ``data_info``. The file stem is
    ``<name>_<channel>_<position>``. The image is read from the ``image_path``
    directory with extension ``image_type``. The mask is read from the
    ``mask_path`` directory with extension ``mask_type``.

    Args:
        data_info: Row/mapping with keys ``name``, ``channel``, ``position``,
            ``image_path``, ``image_type``, ``mask_path`` and ``mask_type``.

    Returns:
        dict with ``'features'`` (image converted by ``transforms.ToTensor``)
        and ``'targets'`` (mask converted by ``transforms.ToTensor``).
    """
    # str(): pandas parses all-numeric CSV columns as numbers, and str.join
    # raises TypeError on non-string items.
    filename = '_'.join(
        str(data_info[key]) for key in ('name', 'channel', 'position'))
    image_path = os.path.join(
        data_info['image_path'],
        '{}.{}'.format(filename, data_info['image_type']))
    mask_path = os.path.join(
        data_info['mask_path'],
        '{}.{}'.format(filename, data_info['mask_type']))

    to_tensor = transforms.ToTensor()
    # Context managers guarantee the files are closed once the pixels are
    # converted. Pillow can otherwise keep a handle open (e.g. multi-frame
    # formats such as TIFF).
    with Image.open(image_path) as img:
        img_tensor = to_tensor(img)
    with Image.open(mask_path) as mask:
        mask_tensor = to_tensor(mask)

    return {'features': img_tensor, 'targets': mask_tensor}