# 1. bce2d和boundary_loss

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [20]:
def bce2d(input, target):
        """To calculate weighted cross-entropy of two binary boundary map,
        input and target are both binary boundary map.

        Args:
        input(Tensor): pred binary boundary map.
        target(Tensor): ground truth binary boundary map generated by using Sobel
            filter and threshold on ground truth segmentation mask.
        """
        print(f"input.shape: {input.shape}")
        log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1)
        print(f"log_p.shape: {log_p.shape}")
        target_t = target.transpose(1,2).transpose(2,3).contiguous().view(1, -1)
        target_trans = target_t.clone()

        pos_index = (target_t == 1)
        neg_index = (target_t == 0)
        ignore_index = (target_t > 1)

        target_trans[pos_index] = 1
        target_trans[neg_index] = 0

        pos_index = pos_index.data.cpu().numpy().astype(bool)
        neg_index = neg_index.data.cpu().numpy().astype(bool)
        ignore_index = ignore_index.data.cpu().numpy().astype(bool)

        weight = torch.Tensor(log_p.size()).fill_(0)
        weight = weight.numpy()
        pos_num = pos_index.sum()
        neg_num = neg_index.sum()
        sum_num = pos_num + neg_num
        weight[pos_index] = neg_num * 1.0 / sum_num
        weight[neg_index] = pos_num * 1.0 / sum_num

        weight[ignore_index] = 0

        weight = torch.from_numpy(weight)
        weight = weight
        loss = F.binary_cross_entropy_with_logits(
            log_p, target_t, weight, size_average=True)
        return loss

In [16]:
# https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122
def weighted_bce(bd_pre, target):
    n, c, h, w = bd_pre.size()
    log_p = bd_pre.permute(0,2,3,1).contiguous().view(1, -1)
    target_t = target.view(1, -1)

    pos_index = (target_t == 1)
    neg_index = (target_t == 0)

    weight = torch.zeros_like(log_p)
    pos_num = pos_index.sum()
    neg_num = neg_index.sum()
    sum_num = pos_num + neg_num
    weight[pos_index] = neg_num * 1.0 / sum_num
    weight[neg_index] = pos_num * 1.0 / sum_num

    loss = F.binary_cross_entropy_with_logits(log_p, target_t, weight, reduction='mean')

    return loss

In [26]:
a = torch.zeros(2,1,16,16)
a[:,:,5,:] = 1
a[:,:,1,:] = 3
# print(a.shape)
# print(a)
pre = torch.randn(2,1,16,16)
print(pre.shape)

print(bce2d(pre, a))
print(weighted_bce(pre, a))

torch.Size([2, 1, 16, 16])
input.shape: torch.Size([2, 1, 16, 16])
log_p.shape: torch.Size([1, 512])
tensor(0.0991)
tensor(0.0991)


确实结果是一样的。。

In [27]:
# Boundary Loss 原理与代码解析 https://blog.csdn.net/ooooocj/article/details/126560722
# 和这个好像不是一个东西
import torch
import numpy as np
from torch import einsum
from torch import Tensor
from scipy.ndimage import distance_transform_edt as distance
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union
 
 
# switch between representations
def probs2class(probs: Tensor) -> Tensor:
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    assert simplex(probs)
 
    res = probs.argmax(dim=1)
    assert res.shape == (b, w, h)
 
    return res
 
def probs2one_hot(probs: Tensor) -> Tensor:
    _, C, _, _ = probs.shape
    assert simplex(probs)
 
    res = class2one_hot(probs2class(probs), C)
    assert res.shape == probs.shape
    assert one_hot(res)
 
    return res
 
def class2one_hot(seg: Tensor, C: int) -> Tensor:
    if len(seg.shape) == 2:  # Only w, h, used by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))
 
    b, w, h = seg.shape  # type: Tuple[int, int, int]
 
    res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
    assert res.shape == (b, C, w, h)
    assert one_hot(res)
 
    return res
 
def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)
 
    res = np.zeros_like(seg)
    # res = res.astype(np.float64)
    for c in range(C):
        posmask = seg[c].astype(np.bool)
 
        if posmask.any():
            negmask = ~posmask
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res
 
def simplex(t: Tensor, axis=1) -> bool:
    _sum = t.sum(axis).type(torch.float32)
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)
 
def one_hot(t: Tensor, axis=1) -> bool:
    return simplex(t, axis) and sset(t, [0, 1])
 
    # Assert utils
 
 
def uniq(a: Tensor) -> Set:
    return set(torch.unique(a.cpu()).numpy())
 
def sset(a: Tensor, sub: Iterable) -> bool:
    return uniq(a).issubset(sub)
 
class SurfaceLoss():
    def __init__(self):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = [1]  # 这里忽略背景类  https://github.com/LIVIAETS/surface-loss/issues/3
 
    # probs: bcwh, dist_maps: bcwh
    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)
 
        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)
 
        multiplied = einsum("bcwh,bcwh->bcwh", pc, dc)
 
        loss = multiplied.mean()
 
        return loss
 
def test(data,logits):
    # data = torch.tensor([[[0, 0, 0, 0, 0, 0, 0],
    #                       [0, 1, 1, 0, 0, 0, 0],
    #                       [0, 1, 1, 0, 0, 0, 0],
    #                       [0, 0, 0, 0, 0, 0, 0]]])  # (b, h, w)->(1,4,7)
    data2 = class2one_hot(data, 2)  # (b, num_class, h, w): (1,2,4,7)
    data2 = data2[0].numpy()  # (2,4,7)
    data3 = one_hot2dist(data2)  # bcwh
    # logits = torch.tensor([[[0, 0, 0, 0, 0, 0, 0],
    #                         [0, 1, 1, 1, 1, 1, 0],
    #                         [0, 1, 1, 0, 0, 0, 0],
    #                         [0, 0, 0, 0, 0, 0, 0]]])  # (b, h, w)

    logits = class2one_hot(logits, 2)
    Loss = SurfaceLoss()
    data3 = torch.tensor(data3).unsqueeze(0)
    res = Loss(logits, data3, None)
    print('loss:', res)

In [31]:
print(pre[1].shape)
test(pre[1],a[1])

torch.Size([1, 16, 16])


AssertionError: 

# 2. InverseNet

In [1]:
import torch
pred = torch.rand((1, 10, 4, 4))
target = torch.randint(0, 10, (1, 4, 4))
print(pred.shape)
print(target)

torch.Size([1, 10, 4, 4])
tensor([[[2, 3, 3, 2],
         [8, 4, 3, 4],
         [7, 9, 7, 7],
         [1, 9, 9, 7]]])


In [2]:
print(target.shape)

torch.Size([1, 4, 4])


In [3]:
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
print(logits.shape)
print(labels.shape)

torch.Size([8, 3, 4, 4])
torch.Size([8, 4, 4])


我的跟其他的似乎不是很一样，其他都是对Segmentation的mask进行loss的计算，

这个是对两个boundary map进行计算的。。。不一样，这个只有一类前景，不是多类的前景。

In [13]:
pred = torch.zeros(3, 1, 4, 4)
target = torch.zeros(3, 1, 4, 4)

for i in range(3):
    pred[i, 0, i, :] = 1
    target[i, 0, :, i] = 1
    print(f"pred is:{pred[i]}, \n target is {target[i]}")

pred is:tensor([[[1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]]), 
 target is tensor([[[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]]])
pred is:tensor([[[0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]]), 
 target is tensor([[[0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.]]])
pred is:tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]]]), 
 target is tensor([[[0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.]]])


In [8]:
pred.shape

torch.Size([3, 1, 4, 4])

In [11]:
pred[0,0,0,:]

tensor([0., 0., 0., 0.])

In [15]:
import urllib
pretraind_model_url = 'https://github.com/Qualcomm-AI-research/InverseForm/releases/download/v1.0/distance_measures_regressor.pth'
inverseNet_path = 'distance_measures_regressor.pth'
urllib.request.urlretrieve(pretraind_model_url, inverseNet_path)

('distance_measures_regressor.pth',
 <http.client.HTTPMessage at 0x7fb073291210>)

In [17]:
from tqdm import tqdm
import requests

repo = "https://github.com/Qualcomm-AI-research/InverseForm/"
download_folder="releases/download/v1.0/distance_measures_regressor.pth"
pretraind_model_url=repo+download_folder
print(pretraind_model_url)

response = requests.get(pretraind_model_url, stream=True)
total_size_in_bytes= int(response.headers.get('content-length', 0))
block_size = 1024 #1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(inverseNet_path, 'wb') as file:
    for data in response.iter_content(block_size):
        progress_bar.update(len(data))
        file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
    print("ERROR, something went wrong")

https://github.com/Qualcomm-AI-research/InverseForm/releases/download/v1.0/distance_measures_regressor.pth


100%|███████████████████████████████████████| 402M/402M [00:17<00:00, 22.7MiB/s]


参考：https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests

In [19]:
import os
os.makedirs("./checkpoints", exist_ok=True)

In [20]:
import sys
sys.path

['/Users/huangshan/Documents/DailyStudy/openMMLabCampusLearn/PR/PR3',
 '/Users/huangshan/Documents/software/miniconda3/miniconda3/envs/py37/lib/python37.zip',
 '/Users/huangshan/Documents/software/miniconda3/miniconda3/envs/py37/lib/python3.7',
 '/Users/huangshan/Documents/software/miniconda3/miniconda3/envs/py37/lib/python3.7/lib-dynload',
 '',
 '/Users/huangshan/Documents/software/miniconda3/miniconda3/envs/py37/lib/python3.7/site-packages',
 '/Users/huangshan/Documents/DailyStudy/openMMLabCampusLearn/selfExercise/mmsegmentation',
 '/Users/huangshan/Documents/DailyStudy/openMMLabCampusLearn/selfExercise/mmdetection',
 '/Users/huangshan/Documents/software/miniconda3/miniconda3/envs/py37/lib/python3.7/site-packages/IPython/extensions',
 '/Users/huangshan/.ipython']

In [None]:
sys.path.append('/Users/huangshan/Documents/DailyStudy/mmsegmentation')