In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


In [8]:
def getBCEmask(shape: Tuple[int, int, int, int], dim):
    B, C, H, W = shape


    mask = torch.zeros((B, C, H, W), dtype=torch.uint8)

    mask[:, :, :dim, :] = 1    
    mask[:, :, -dim:, :] = 1    
    mask[:, :, :, :dim] = 1      
    mask[:, :, :, -dim:] = 1  

    return mask


In [24]:
class BoundaryAwareBCE(nn.Module):
    def __init__(self, lambda_w=1.0):
        super(BoundaryAwareBCE, self).__init__()

        self.lambda_w = lambda_w
    
    def forward(self, pred, target, b_mask):
        assert b_mask.shape == pred.shape == target.shape

        BCE_loss = F.binary_cross_entropy(pred, target, reduction='none')
        BCE_loss = BCE_loss.mean()

        mask_ones = torch.sum(b_mask == 1)
        print(mask_ones)

        boundary_aware = (BCE_loss*b_mask).sum()
        
        boundary_aware = (self.lambda_w/mask_ones)*boundary_aware
        

        return BCE_loss + boundary_aware
    

In [25]:


pred = torch.rand(2, 21, 10, 10)
target = torch.rand(2, 21, 10, 10)
m = getBCEmask((2, 21, 10, 10), 2) 

print(pred.shape)
print(target.shape)
print(m.shape)

torch.Size([2, 21, 10, 10])
torch.Size([2, 21, 10, 10])
torch.Size([2, 21, 10, 10])


In [26]:
loss = BoundaryAwareBCE()

l = loss(pred, target, m)
l

tensor(2688)


tensor(2.0283)

In [18]:
loss = nn.CrossEntropyLoss()

input = torch.rand(2, 21, 10, 10)
target = torch.rand(2, 21, 10, 10)
output = loss(input, target)

print(input.shape)
print(output)

torch.Size([2, 21, 10, 10])
tensor(32.5631)


In [13]:
from utils import ToBMask
import globals as glob
from data import CSDataset
from torchvision.transforms import v2

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
MASK_FN = ToBMask(glob.CS_COLOR2LABEL, fill=0)
DATASET_NAME = "Fine512x192"

data_train = CSDataset(f"{DATASET_NAME}/train",
                        transform_x=v2.Compose(
                            [v2.ToImage(), v2.ToDtype(dtype=torch.float32, scale=True)]
                        ),
                        transform_y=v2.Compose([MASK_FN]),
                       )

x, y, _ = data_train[0]

for i in range(255):
    for j in range(192):
        p = y[:, i, j]

        if p.sum() == 0:
            print(i,j, y[:, i, j])



(16, 192, 512)
90086
0 1 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 2 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 3 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 4 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 5 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 6 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 7 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 8 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 9 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 10 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 11 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 12 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
0 13 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

IndexError: index 192 is out of bounds for dimension 1 with size 192