In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
import typing as T

# Make 'viridis' the default colormap for the images plotted in this notebook.
plt.set_cmap('viridis')

# Version 1 (old code) - Square trigger, closest-position constraint only

In [None]:
class SquareFilter(nn.Module):
    """Separable box (mean) filter applied independently to every channel.

    The input is padded so that the output has the same spatial size as the
    input for any `ksize`. For odd `ksize` the padding is `ksize // 2` on every
    side; for even `ksize` the extra pixel goes to the right and bottom.

    Args:
        ksize (int): Side length of the square averaging window (>= 1).
        padding_mode (str, optional): Mode passed to `F.pad`. Default is 'reflect'.

    Raises:
        ValueError: If `ksize` is smaller than 1.
    """

    def __init__(self, ksize, padding_mode='reflect'):
        super().__init__()
        if ksize < 1:
            raise ValueError(f'ksize must be at least 1, got {ksize}')
        # Padding order is [left, right, top, bottom]. before + after == ksize - 1,
        # which is exactly what the two 1D convolutions remove again.
        before, after = (ksize - 1) // 2, ksize // 2
        self.padding = [before, after, before, after]
        self.padding_mode = padding_mode

        # A buffer is not a parameter: it never requires gradients and it
        # follows the module across devices.
        kernel = torch.ones(ksize)
        self.register_buffer('kernel', kernel / kernel.sum())  # normalize

    def forward(self, x):
        """Filter `x`, indexed as (N, C, H, W), and return a tensor of the same shape."""
        channels = x.shape[1]
        # One horizontal (1 x k) and one vertical (k x 1) kernel per channel;
        # applying both is equivalent to a k x k mean filter.
        row_kernel = self.kernel.view(1, 1, 1, -1).expand(channels, 1, 1, -1)
        col_kernel = self.kernel.view(1, 1, -1, 1).expand(channels, 1, -1, 1)
        x = F.pad(x, self.padding, mode=self.padding_mode)
        for ker in (row_kernel, col_kernel):
            x = F.conv2d(x, weight=ker, groups=channels, padding=0)
        return x

In [None]:
def get_inside_positions(segmap, trigger_shape, num_classes=-1):
    """Find positions whose square neighbourhood contains a single class only.

    Args:
        segmap (torch.Tensor): Integer class map indexed as (N, H, W).
        trigger_shape (int): Trigger side length; rounded up to the next odd
            number to obtain the filter size.
        num_classes (int, optional): Forwarded to `F.one_hot`. Default is -1.

    Returns:
        torch.Tensor: Boolean tensor (N, H, W), True where the box-filtered
        one-hot map of some class is (numerically) 1.
    """
    # Create the filter on the same device as the input so that non-CPU
    # segmentation maps work as well.
    box_filter = SquareFilter(trigger_shape // 2 * 2 + 1).to(segmap.device)
    segmap_oh = F.one_hot(segmap.long(), num_classes=num_classes).float()  # NHWC
    class_fractions = box_filter(segmap_oh.permute(0, 3, 1, 2))  # NCHW
    dominant_fraction = class_fractions.max(1).values
    # ones_like keeps the comparison target on the input's device and dtype.
    return torch.isclose(dominant_fraction, torch.ones_like(dominant_fraction))

def get_mask_distance_map(mask):
    """Return, per pixel, the rounded Chebyshev distance to the nearest zero pixel of `mask`.

    Args:
        mask (np.ndarray): 8-bit single-channel array, as required by
            `cv2.distanceTransform`.

    Returns:
        np.ndarray: uint8 array of rounded distances. Distances above 255
        saturate at 255.
    """
    distances = cv2.distanceTransform(mask, cv2.DIST_C, 5)
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    # The cap at 255 is needed because casting an out-of-range float to uint8
    # is not well defined and can make far-away pixels look close.
    return np.minimum(distances + 0.5, 255).astype(np.uint8)

def get_closest_valid_trigger_centers(segmap, victim_class, trigger_shape, valid_mask):
    """Select the valid positions that are closest to the victim class.

    Args:
        segmap (np.ndarray): 2D class map.
        victim_class (int): Label of the victim class.
        trigger_shape: Unused; kept for interface compatibility.
        valid_mask (np.ndarray): 2D mask of allowed positions, same shape as `segmap`.

    Returns:
        np.ndarray: Boolean array, True at the valid positions with the
        smallest positive distance. All False when no valid position has a
        positive distance.
    """
    # Victim pixels are the zeros of fg_map, so the distance map measures the
    # distance to the nearest victim pixel.
    fg_map = (segmap != victim_class).astype(np.uint8)
    dist_map = get_mask_distance_map(fg_map)  # TODO: check 3
    masked_dist_map = dist_map * valid_mask
    candidates = masked_dist_map[masked_dist_map > 0]
    if candidates.size == 0:
        # np.min would raise on an empty selection; report "no position" instead.
        return np.zeros(masked_dist_map.shape, dtype=bool)
    return masked_dist_map == candidates.min()

In [None]:
num_classes = 4  # the example segmentation map below uses labels 0..3
victim_class = 0  # label treated as the victim class
trigger_shape = 5  # a single int here: used as the size of the square filter

def make_example_segmap():
    """Build the fixed 50x50 toy segmentation map used in the examples.

    Returns:
        torch.Tensor: uint8 tensor of shape (1, 50, 50) with labels 0..3.
    """
    # (rows, columns, label), painted in this order onto a background of zeros.
    regions = (
        (slice(20, 30), slice(20, 30), 1),
        (slice(30, 40), slice(0, 20), 2),
        (slice(30, 50), slice(20, 50), 3),
        (slice(44, 45), slice(41, 42), 0),  # one-pixel hole of class 0 inside class 3
    )
    canvas = np.zeros((50, 50), dtype=np.uint8)
    for rows, cols, label in regions:
        canvas[rows, cols] = label
    return torch.tensor(canvas).unsqueeze(0)

# Build the example segmentation map and display it.
segmap = make_example_segmap()
print(f'{segmap.shape=}')
plt.imshow(segmap.numpy().squeeze())
plt.show()

In [None]:
# Box-filter the one-hot class masks. Channel 0 is dropped by [:, 1:], and the
# remaining channels are moved last so that imshow receives an (H, W, 3) image.
segmap_oh_filtered = SquareFilter(trigger_shape)(F.one_hot(segmap.long(), num_classes=num_classes).float().permute(0, 3, 1, 2))[:,1:].permute(0, 2, 3, 1)
print(f'{segmap_oh_filtered.shape=}')
plt.imshow(segmap_oh_filtered.numpy().squeeze())
plt.show()

# Pixels that belong to the victim class.
victim_class_map = segmap == victim_class
print(f'{victim_class_map.shape=}')
plt.imshow(victim_class_map.numpy().squeeze())
plt.show()

# Positions reported by get_inside_positions, excluding those on the victim class.
inside_positions = get_inside_positions(segmap, trigger_shape)
non_victim_inside_positions = inside_positions & ~victim_class_map
print(f'{non_victim_inside_positions.shape=}')
plt.imshow((non_victim_inside_positions).numpy().squeeze())
plt.show()

In [None]:
# Among the non-victim inside positions, keep those with the smallest positive distance value.
valid_position_map = get_closest_valid_trigger_centers(segmap.numpy().squeeze(), victim_class, trigger_shape, valid_mask=non_victim_inside_positions.numpy().squeeze())
plt.imshow(valid_position_map)
plt.show()

In [None]:
# Step-by-step version of get_closest_valid_trigger_centers, part 1: the distance map.
# Victim pixels are the zeros of fg_map.
fg_map = (segmap.numpy().squeeze() != victim_class).astype(np.uint8)
dist_map = get_mask_distance_map(fg_map)  # TODO: check 3

print(np.unique(dist_map))
plt.imshow(dist_map)
plt.show()

In [None]:
# Part 2: keep distance values only at the valid positions; every other pixel becomes 0.
valid_mask = non_victim_inside_positions.numpy().squeeze()
masked_dist_map = dist_map * valid_mask
plt.imshow(masked_dist_map)
plt.show()

# Part 3: highlight the positions that share the smallest positive masked distance.
min_distance = np.min(masked_dist_map[masked_dist_map > 0])
plt.imshow(masked_dist_map == min_distance)
plt.show()

# Version 2 - Arbitrary kernel shape and position constraint options

In [None]:
def channelwise_conv2d(x, kernel, padding, padding_mode='reflect'):
    """Applies a 2D channel-wise convolution on input tensor `x` with the given `kernel`.

    The same kernel is applied to every channel independently. The kernel is
    moved to the device of `x` and, when `x` is floating point, cast to its
    dtype, so callers do not have to match them by hand.

    Args:
        x (torch.Tensor): Input tensor with shape (N, C, H, W).
        kernel (torch.Tensor): Convolution kernel with shape (kH, kW).
        padding (tuple): Tuple specifying the padding on each side [left, right, top, bottom].
        padding_mode (str, optional): Padding mode. Default is 'reflect'.

    Returns:
        torch.Tensor: Tensor with shape (N, C, H', W'), where H' and W' are the
        padded sizes reduced by kH - 1 and kW - 1 respectively.
    """
    channels = x.shape[1]
    # Only follow the input's dtype when it is floating point: casting a
    # fractional kernel to an integer dtype would silently zero it out.
    weight_dtype = x.dtype if x.is_floating_point() else kernel.dtype
    weight = kernel.to(device=x.device, dtype=weight_dtype)
    weight = weight.expand(channels, 1, *kernel.shape)  # C, C/groups, kH, kW
    x = F.pad(x, padding, mode=padding_mode)  # padding: L, R, T, B
    return F.conv2d(x, weight=weight, groups=channels, padding=0)


def get_overlaps(masks, trigger_shape, anchor='top left'):
    """Compute, per position, the fraction of the trigger window covered by each mask.

    The window is anchored with its top-left corner at the position, so the
    input is padded on the right and bottom only and the output keeps the
    spatial size of `masks`.

    Args:
        masks (torch.Tensor): Floating-point masks with shape (N, C, H, W).
        trigger_shape (sequence of int): Window size as (height, width).
        anchor (str, optional): Only 'top left' is supported.

    Returns:
        torch.Tensor: Tensor with shape (N, C, H, W) holding the window means.
    """
    assert anchor == 'top left'
    # new_ones creates the kernel on the device (and with the dtype) of `masks`,
    # so non-CPU inputs do not hit a device mismatch in the convolution.
    kernel = masks.new_ones(trigger_shape) / np.prod(trigger_shape)
    pad_right, pad_bottom = kernel.shape[1] - 1, kernel.shape[0] - 1
    return channelwise_conv2d(masks, kernel, padding=[0, pad_right, 0, pad_bottom])


def get_inside_positions(segmap, trigger_shape, num_classes=-1, anchor='top left'):
    """Find anchor positions whose trigger window contains a single class only.

    Args:
        segmap (torch.Tensor): Integer class map indexed as (N, H, W).
        trigger_shape (sequence of int): Window size as (height, width).
        num_classes (int, optional): Forwarded to `F.one_hot`. Default is -1.
        anchor (str, optional): Forwarded to `get_overlaps`.

    Returns:
        torch.Tensor: Boolean tensor (N, H, W), True where some class fills
        the whole window (its window mean is numerically 1).
    """
    masks = F.one_hot(segmap.long(), num_classes=num_classes).permute(0, 3, 1, 2)  # NHWC -> NCHW
    class_fractions = get_overlaps(masks.float(), trigger_shape, anchor=anchor)
    dominant_fraction = class_fractions.max(1).values
    # ones_like keeps the comparison target on the input's device and dtype.
    return torch.isclose(dominant_fraction, torch.ones_like(dominant_fraction))


def get_outer_border(mask, anchor='top left'):
    """Mark the pixels just outside a binary mask (its outer border).

    Args:
        mask (torch.Tensor): Floating-point binary mask indexed as (N, H, W).
        anchor (str, optional): Only 'top left' is supported.

    Returns:
        torch.Tensor: Boolean tensor with shape (N, 1, H, W). Note the channel
        dimension added for the convolution is kept in the result.
    """
    assert anchor == 'top left'
    # Mean of the 8 neighbours minus the centre value. For a 0/1 mask this is
    # positive only at zero pixels that have at least one neighbour set.
    # new_ones creates the kernel on the device (and with the dtype) of `mask`.
    kernel = mask.new_ones((3, 3)) / 8
    kernel[1, 1] = -1
    borderness = channelwise_conv2d(mask.unsqueeze(1), kernel, padding=[1] * 4)
    return borderness > 0


def get_mask_distance_map(mask):
    """Return, per pixel, the rounded Chebyshev distance to the nearest zero pixel of `mask`.

    Args:
        mask (np.ndarray): 8-bit single-channel array, as required by
            `cv2.distanceTransform`.

    Returns:
        np.ndarray: uint8 array of rounded distances. Distances above 255
        saturate at 255.
    """
    distances = cv2.distanceTransform(mask, cv2.DIST_C, 5)
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    # The cap at 255 is needed because casting an out-of-range float to uint8
    # is not well defined and can make far-away pixels look close.
    return np.minimum(distances + 0.5, 255).astype(np.uint8)


def get_closest_valid_trigger_centers_np(fg_mask, valid_mask):  # NumPy, single
    """Select the valid positions that are closest to the foreground mask.

    Args:
        fg_mask (np.ndarray): 2D 0/1 mask; distances are measured to its nonzero pixels.
        valid_mask (np.ndarray): 2D mask of allowed positions, same shape as `fg_mask`.

    Returns:
        np.ndarray: Boolean array, True at the valid positions with the
        smallest positive distance. All False when no valid position has a
        positive distance.
    """
    # The distance transform measures the distance to the nearest zero pixel,
    # hence the inversion of the foreground mask.
    dist_map = get_mask_distance_map((1 - fg_mask).astype(np.uint8))  # TODO: check 3
    masked_dist_map = dist_map * valid_mask
    candidates = masked_dist_map[masked_dist_map > 0]
    if candidates.size == 0:
        # np.min would raise on an empty selection; report "no position" instead.
        return np.zeros(masked_dist_map.shape, dtype=bool)
    return masked_dist_map == candidates.min()


def get_closest_valid_trigger_centers(fg_mask, valid_mask):  # PyTorch wrapper, single
    """Torch wrapper around `get_closest_valid_trigger_centers_np` for a single image.

    Args:
        fg_mask (torch.Tensor): 2D 0/1 mask; distances are measured to its nonzero pixels.
        valid_mask (torch.Tensor): 2D mask of allowed positions, same shape as `fg_mask`.

    Returns:
        torch.Tensor: Boolean tensor on the device of `valid_mask`, True at the
        closest valid positions.
    """
    closest = get_closest_valid_trigger_centers_np(
        fg_mask.cpu().numpy(), valid_mask.cpu().numpy())
    # The NumPy round trip happens on the CPU; hand the result back on the
    # device the caller's tensors live on.
    return torch.from_numpy(closest).to(device=valid_mask.device, dtype=torch.bool)


def get_valid_trigger_centers(segmap, victim_class, trigger_shape, num_classes=-1, constraint='closest'):
    """Compute the positions at which a trigger may be anchored.

    A position is valid when the trigger window lies inside a single class
    that is not the victim class. `constraint` narrows this set further.

    Args:
        segmap (torch.Tensor): Integer class map indexed as (N, H, W).
        victim_class (int): Label of the victim class.
        trigger_shape (sequence of int): Window size as (height, width).
        num_classes (int, optional): Forwarded to `F.one_hot`. Default is -1.
        constraint (str or None, optional):
            'border': keep valid positions on the outer border of the region
            in which the trigger would overlap the victim class.
            'closest': keep valid positions closest to that region; requires
            batch size 1.
            Any other value: no additional constraint.

    Returns:
        torch.Tensor: Float tensor (N, H, W) for 'border' and for no
        constraint; boolean tensor (H, W) for 'closest'.
    """
    victim_mask = (segmap == victim_class).float()
    inside_positions = get_inside_positions(segmap, trigger_shape, num_classes=num_classes)
    valid_positions = inside_positions * (1 - victim_mask)  # TODO
    if constraint in ('closest', 'border'):
        victim_overlaps = get_overlaps(victim_mask.unsqueeze(1), trigger_shape).squeeze(1)
        overlap_mask = (victim_overlaps > 0).float()
        if constraint == 'border':
            # get_outer_border returns (N, 1, H, W). Without dropping the
            # channel dimension the product with (N, H, W) broadcasts to
            # (N, N, H, W) and mixes images across the batch.
            return valid_positions * get_outer_border(overlap_mask).squeeze(1)
        else:
            assert overlap_mask.shape[0] == 1
            return get_closest_valid_trigger_centers(overlap_mask.squeeze(0), valid_positions.squeeze(0))
    return valid_positions

In [None]:
num_classes = 4  # the example segmentation map below uses labels 0..3
victim_class = 0  # label treated as the victim class
trigger_shape = [2, 5]  # (height, width): used as the shape of the averaging kernel

def make_example_segmap():
    """Build the fixed 50x50 toy segmentation map used in the examples.

    Returns:
        torch.Tensor: uint8 tensor of shape (1, 50, 50) with labels 0..3.
    """
    # (rows, columns, label), painted in this order onto a background of zeros.
    regions = (
        (slice(20, 30), slice(20, 30), 1),
        (slice(30, 40), slice(0, 20), 2),
        (slice(30, 50), slice(20, 50), 3),
        (slice(44, 45), slice(41, 42), 0),  # one-pixel hole of class 0 inside class 3
    )
    canvas = np.zeros((50, 50), dtype=np.uint8)
    for rows, cols, label in regions:
        canvas[rows, cols] = label
    return torch.tensor(canvas).unsqueeze(0)

# Build the example segmentation map and display it.
segmap = make_example_segmap()
print(f'{segmap.shape=}')
plt.imshow(segmap.numpy().squeeze())
plt.show()

In [None]:
# Anchor positions whose trigger window is filled by a single class.
inside_positions = get_inside_positions(segmap, trigger_shape, num_classes=num_classes)
plt.imshow(inside_positions.numpy().squeeze())
plt.show()

In [None]:
# Remove positions on the victim class. Adding 4 at the valid positions lifts
# them above the class labels (0..3) so they stand out in the plot.
victim_mask = segmap == victim_class
valid_positions = inside_positions * ~victim_mask
plt.imshow((segmap + 4 * valid_positions).numpy().squeeze())
plt.show()

In [None]:
# Outer border of the region in which the trigger window would overlap the victim class.
victim_overlaps = get_overlaps(victim_mask.float().unsqueeze(1), trigger_shape).squeeze(1)
outer_border = get_outer_border((victim_overlaps > 0).float())
plt.imshow((segmap + 4 * outer_border).numpy().squeeze())
plt.show()

# Valid positions restricted to that outer border.
valid_positions_border = valid_positions * outer_border
plt.imshow((segmap + 4 * valid_positions_border).numpy().squeeze())
plt.show()

In [None]:
# All valid trigger positions, without an additional constraint.
valid_trigger_centers = get_valid_trigger_centers(segmap, victim_class, trigger_shape, num_classes=num_classes, constraint=None)
plt.imshow((segmap + 4 * valid_trigger_centers).numpy().squeeze())
plt.show()

In [None]:
# Valid trigger positions under the 'border' constraint.
valid_trigger_centers = get_valid_trigger_centers(segmap, victim_class, trigger_shape, num_classes=num_classes, constraint='border')
plt.imshow((segmap + 4 * valid_trigger_centers).numpy().squeeze())
plt.show()

In [None]:
# Valid trigger positions under the 'closest' constraint.
valid_trigger_centers = get_valid_trigger_centers(segmap, victim_class, trigger_shape, num_classes=num_classes, constraint='closest')
plt.imshow((segmap + 4 * valid_trigger_centers).numpy().squeeze())
plt.show()