In [7]:
import torch
import os
from glob import glob
from PIL import Image
import numpy as np
import cv2
from tqdm import tqdm

# Restrict CUDA to device index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

SIGMA_X2 = 4   # divisor of the squared spatial distance (see generate_distance_tensor)
SIGMA_I2 = 10  # divisor of the squared intensity difference (see get_weights_tensor)
RADIUS = 5     # half-width of the square neighbourhood window around each pixel
N_CLASSES = 1

# [height, width] of the neighbourhood window. get_weights_tensor reads the
# upper-case name, which previously was never defined (NameError on a fresh kernel).
REGION_SIZE = 2*[2*RADIUS+1]
# Lower-case alias kept for the cells that build their own Unfold/Unflatten.
region_size = REGION_SIZE

In [8]:
def get_weights_tensor(image, d, regions_mold, radius=None, sigma_i2=None):
    """Compute the per-pixel neighbourhood weight tensor of a 2-D image.

    For every pixel the image is unfolded into a (2*radius+1) x (2*radius+1)
    window centred on that pixel (zero-padded at the borders). Each window
    entry is weighted by ``exp(-(neighbour - centre)**2 / sigma_i2)``, then
    multiplied by the spatial kernel ``d`` and by ``regions_mold``.

    Args:
        image: 2-D tensor (height, width); two leading axes are added before
            unfolding.
        d: spatial kernel broadcastable against the windows, e.g. the
            (1, side, side) output of ``generate_distance_tensor``.
        regions_mold: tensor broadcastable against the
            (1, n_pixels, side, side) window tensor; the notebook passes the
            unfolded all-ones image so padded entries are zeroed out.
        radius: half-width of the window. Defaults to the module-level
            ``RADIUS``.
        sigma_i2: divisor of the squared intensity difference. Defaults to
            the module-level ``SIGMA_I2``.

    Returns:
        Tensor of shape (n_pixels, side, side) when the leading axis of the
        product has size 1 (it is squeezed away).
    """
    if radius is None:
        radius = RADIUS
    if sigma_i2 is None:
        sigma_i2 = SIGMA_I2

    # Window size is derived from the radius instead of the previously
    # undefined REGION_SIZE global.
    side = 2 * radius + 1
    window = (side, side)

    # Create a region matrix (height, width) for each pixel in the input image and stack them into a tensor.
    # Each pixel in the image corresponds to a single matrix with the given pixel in the center.
    # Tensor shape: (1, # of pixels in image, R_height, R_width)
    columns = torch.nn.Unfold(window, padding=radius)(image[None, None])
    regions = torch.nn.Unflatten(1, window)(columns).permute(0, 3, 1, 2)

    # Centre intensities shaped (1, n_pixels, 1, 1) so they broadcast over each window.
    centres = image.flatten()[None, :, None, None]

    intensity_term = torch.exp(-torch.pow(regions - centres, 2) / sigma_i2)
    weights = (intensity_term * torch.unsqueeze(d, dim=0)) * regions_mold
    return torch.squeeze(weights, dim=0)

def get_distance_sq(p1, p2):
    """Return the dot product of ``(p1 - p2)`` transposed with itself.

    For 1-D arrays this is the squared Euclidean distance between the points.
    """
    offset = p1 - p2
    offset_t = offset.T
    return np.dot(offset_t, offset)

def generate_distance_tensor(r, sigma_x2=None):
    """Build the spatial kernel exp(-dist**2 / sigma_x2) around a window centre.

    Args:
        r: integer half-width of the window; the window is (2*r+1) x (2*r+1)
            with its centre at index (r, r).
        sigma_x2: divisor of the squared distance. Defaults to the
            module-level ``SIGMA_X2``.

    Returns:
        Tensor of shape (1, 2*r+1, 2*r+1) whose entry [0, x, y] equals
        ``exp(-((x - r)**2 + (y - r)**2) / sigma_x2)``.
    """
    if sigma_x2 is None:
        sigma_x2 = SIGMA_X2

    # Signed offsets from the centre along one axis; broadcasting a column
    # against a row gives every squared distance without a Python loop.
    offsets = torch.arange(-r, r + 1, dtype=torch.get_default_dtype())
    dist_sq = offsets[:, None] ** 2 + offsets[None, :] ** 2
    return torch.exp(torch.unsqueeze(-dist_sq, dim=0) / sigma_x2)

### Weight tensor generation

In [9]:
# Glob pattern of the input patches and path template of the saved weight tensors.
PATCH_GLOB = r"C:\Users\clohk\Desktop\Projects\WNet\wnet_pytorch\patches_tries\*"
WEIGHTS_TEMPLATE = r"C:\Users\clohk\Desktop\Projects\WNet\wnet_pytorch\weights\{}.pt"
# Every patch is resized to PATCH_SIDE x PATCH_SIDE before its weights are computed.
PATCH_SIDE = 256

d = generate_distance_tensor(RADIUS)
# Unfolding an all-ones image (zero-padded by Unfold) yields 1 where a window
# entry lies inside the image and 0 where it falls into the padding.
mold = torch.ones(PATCH_SIDE, PATCH_SIDE)
unfold = torch.nn.Unfold(region_size, padding=RADIUS)
unflatten = torch.nn.Unflatten(1, region_size)
regions_mold = unflatten(unfold(mold[(None,)*2])).permute(0,3,1,2)

# List the patches once so the progress-bar total and the loop see the same files.
patch_paths = glob(PATCH_GLOB)
for patch_path in tqdm(patch_paths):
    filename = os.path.basename(os.path.splitext(os.path.normpath(patch_path))[0])
    # Read the pixels inside the context manager so the file handle is released promptly.
    # NOTE(review): get_weights_tensor treats the image as 2-D, so patches are
    # presumably single-channel — confirm.
    with Image.open(patch_path) as patch_image:
        img = np.array(patch_image)
    img = cv2.resize(img, (PATCH_SIDE, PATCH_SIDE), interpolation = cv2.INTER_AREA)
    torch.save(get_weights_tensor(torch.Tensor(img), d, regions_mold), WEIGHTS_TEMPLATE.format(filename))

100%|██████████| 6254/6254 [57:47<00:00,  1.80it/s]  


### Loss function

In [145]:
unfold = torch.nn.Unfold(region_size, padding=RADIUS)
unflatten = torch.nn.Unflatten(1, region_size)
# NOTE(review): `img` is whatever the weight-generation loop left behind (its
# last patch); below it only determines the shape and dtype of the random tensor.
image = torch.Tensor(img)
# Seed the generator so the random stand-in probabilities, and therefore the
# loss value computed from them, are reproducible across kernel restarts.
torch.manual_seed(0)
CLASS_PROBABILITIES = torch.rand_like(image)


In [146]:
# Precomputed per-pixel neighbourhood weights for one patch.
W = torch.load(r"C:\Users\clohk\Desktop\Projects\WNet\wnet_pytorch\weights\180_s10_10_2.pt")

# Neighbourhood windows of the class probabilities, unfolded the same way as W.
P = unflatten(unfold(CLASS_PROBABILITIES[None, None])).permute(0, 3, 1, 2).squeeze(dim=0)

# Ratio of the probability-weighted window sums of W * P and of W.
# Presumably the single-class soft normalized-cut loss (N_CLASSES = 1) — confirm.
probabilities_flat = CLASS_PROBABILITIES.flatten()
weighted_association = torch.matmul(probabilities_flat, (W * P).sum(dim=(1, 2)))
total_association = torch.matmul(probabilities_flat, W.sum(dim=(1, 2)))
L = 1 - weighted_association / total_association
L

tensor(0.4934)

### Testing

In [None]:
# Spatial kernel for the configured neighbourhood radius.
d = generate_distance_tensor(RADIUS)

# Window extraction: Unfold gathers each pixel's neighbourhood into a column,
# Unflatten restores the two window axes.
unfold = torch.nn.Unfold(region_size, padding=RADIUS)
unflatten = torch.nn.Unflatten(1, region_size)

# Unfolded all-ones image, rearranged to (1, n_pixels, window_h, window_w).
mold = torch.ones(256, 256)
mold_windows = unflatten(unfold(mold[None, None]))
regions_mold = mold_windows.permute(0, 3, 1, 2)