In [1]:
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np
import torch
from basic_unet import UNet
from testnet import load_dataset
from torchvision.transforms import ToTensor
from scipy.spatial.distance import directed_hausdorff


class HausdorffMask:
    """Occlusion-sensitivity map based on the Hausdorff distance.

    A grid of circular occlusion masks is generated once with
    ``generate_masks``. ``evaluate`` then applies each mask to an input
    image, runs the model on the occluded image and records how far the
    model output moves away from a reference segmentation.
    """

    def __init__(self, width, height):
        """Store the mask size in pixels.

        Args:
            width: Width of every generated mask.
            height: Height of every generated mask.
        """
        self.width = width
        self.height = height
        # Filled in by generate_masks(); ``masks`` stays None until then so
        # that evaluate() can report the missing call clearly.
        self.x_count = 0
        self.y_count = 0
        self.masks = None

    def generate_masks(self, circle_size, offset):
        """Build the grid of occlusion masks.

        The result is stored in ``self.masks`` as a list of rows, i.e. it is
        indexed ``self.masks[y_index][x_index]``. Each mask is a
        ``(height, width)`` float tensor that is 1.0 everywhere except for a
        filled circle of zeros.

        Args:
            circle_size: Side length of the bounding box of each circle.
            offset: Step in pixels between neighbouring circle positions.
        """
        self.x_count = int(self.width / offset)
        self.y_count = int(self.height / offset)

        to_tensor = ToTensor()  # stateless, so one instance serves all masks
        self.masks = []
        for y_offset in range(self.y_count):
            y = y_offset * offset
            row = []
            for x_offset in range(self.x_count):
                x = x_offset * offset
                image = Image.new('L', (self.width, self.height), 255)
                draw = ImageDraw.Draw(image)
                draw.ellipse([(x, y), (x + circle_size, y + circle_size)], fill=0)
                # ToTensor already rescales 8-bit 'L' images from [0, 255] to
                # [0, 1]; dividing by 255 again would turn the "keep" value
                # into 1/255 and dim the whole image instead of masking it.
                # squeeze(0) drops only the channel axis.
                row.append(to_tensor(image).squeeze(0))
            self.masks.append(row)

    def evaluate(self, image, segment, model, device):
        """Compute the distance map for one image.

        Args:
            image: Model input tensor, multiplied element-wise with each
                mask (presumably shaped (1, 1, height, width) -- confirm
                against the data loader).
            segment: Reference segmentation as a 2-D tensor or array.
            model: Callable applied to the masked image. Its output, after
                dropping the first axis, must be 2-D because
                ``directed_hausdorff`` only accepts 2-D arrays.
            device: Torch device the masks are moved to before masking.

        Returns:
            ``numpy.ndarray`` of shape ``(y_count, x_count)`` whose entry
            ``[y, x]`` is the symmetric Hausdorff distance obtained with
            mask ``self.masks[y][x]``.

        Raises:
            RuntimeError: If ``generate_masks`` has not been called yet.
        """
        if self.masks is None:
            raise RuntimeError(
                "generate_masks() must be called before evaluate()")

        # Convert the reference once, outside the loop; this also accepts
        # tensors that live on another device.
        if torch.is_tensor(segment):
            segment = segment.detach().cpu().numpy()
        segment = np.asarray(segment, dtype=np.float64)

        distances = np.zeros((self.y_count, self.x_count))

        # NOTE(review): directed_hausdorff treats every row of its inputs as
        # one point, so this compares image rows rather than sets of pixel
        # coordinates -- confirm that this is the intended metric.
        with torch.no_grad():  # inference only, no autograd graph needed
            for y_offset, row in enumerate(self.masks):
                for x_offset, mask in enumerate(row):
                    masked_image = image * mask.to(device)
                    output = model(masked_image).detach().cpu().numpy()[0]

                    # directed_hausdorff returns (distance, index_u,
                    # index_v); only the distance may enter the maximum.
                    forward = directed_hausdorff(output, segment)[0]
                    backward = directed_hausdorff(segment, output)[0]
                    distances[y_offset][x_offset] = max(forward, backward)
        return distances


# Prefer the first GPU when available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = load_dataset(1)

# Restore the trained network. map_location lets a checkpoint that was saved
# on a GPU load on a CPU-only machine as well.
model = UNet(in_channels=1, out_channels=1)
state_dict = torch.load('models/2_testnet_0490.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
# Inference only: switch layers such as dropout / batch norm (if the network
# has any) to evaluation behaviour.
model.eval()

# One circular occlusion of diameter 20 px every 5 px over a 240x240 image.
hdm = HausdorffMask(240, 240)
hdm.generate_masks(circle_size=20, offset=5)

# Visualise the first test sample only: reference segmentation, input image
# and the resulting distance map.
for sample in test_loader:
    segment = sample['segment']
    segment = segment.squeeze()
    plt.imshow(segment)
    plt.show()

    image = sample['input'].to(device)
    plt.imshow(sample['input'].squeeze())
    plt.show()

    distances = hdm.evaluate(image, segment, model, device)
    plt.imshow(distances)
    plt.show()
    break

<matplotlib.figure.Figure at 0x7fd791f52e80>

<matplotlib.figure.Figure at 0x7fd78865e080>

<matplotlib.figure.Figure at 0x7fd78862d3c8>