In [1]:
import torch
import numpy as np

In [2]:
def generate_class_mask(label, classes):
    label, classes = torch.broadcast_tensors(label,
                                             classes.unsqueeze(1).unsqueeze(2))
    class_mask = label.eq(classes).sum(0, keepdims=True)
    return class_mask

def get_class_masks(labels):
    class_masks = []
    for label in labels:
        classes = torch.unique(labels)
        nclasses = classes.shape[0]
        class_choice = np.random.choice(
            nclasses, int((nclasses + nclasses % 2) / 2), replace=False)
        classes = classes[torch.Tensor(class_choice).long()]
        # print(classes)
        class_masks.append(generate_class_mask(label, classes).unsqueeze(0))
    return class_masks

def one_mix(mask, data=None, target=None):
    # (1,1,H,W) (2,3,H,W) (2,H,W)
    if mask is None:
        return data, target
    if not (data is None):
        # (3,H,W)
        stackedMask0, _ = torch.broadcast_tensors(mask[0], data[0])
        # (1,3,H,W)
        data = (stackedMask0 * data[0] +
                (1 - stackedMask0) * data[1]).unsqueeze(0)
    if not (target is None):
        stackedMask0, _ = torch.broadcast_tensors(mask[0], target[0])
        # (1,1,H,W)
        target = (stackedMask0 * target[0] +
                  (1 - stackedMask0) * target[1]).unsqueeze(0)
    return data, target

def lbl_retain(mask, lbl):
    # (1,H,W)
    stackedMask0, _ = torch.broadcast_tensors(mask[0], lbl[0])
    # (1,1,H,W)
    src_lbl = lbl[0].masked_fill(stackedMask0==0, 255).unsqueeze(0)
    tgt_lbl = lbl[1].masked_fill(stackedMask0==1, 255).unsqueeze(0)
    return src_lbl, tgt_lbl

def get_one_hot(label, N):
    b,_,h,w = label.shape
    label = torch.where(label==255, N, label)
    label = label.squeeze(1).view(-1)
    ones = torch.sparse.torch.eye(N)
    ones = torch.cat((ones, torch.zeros(1, N)), dim=0)
    ones = ones.index_select(0, label)
    return ones.view(b, h, w, N).permute(0, 3, 1, 2)

def get_one_hot_cls(label, N):
    assert label.dim() == 2
    b,c = label.shape
    label_one_hot = [None] * b
    ones = torch.sparse.torch.eye(N)
    for i in range(b):
        label_one_hot[i] = (label[i].unsqueeze(-1)*ones).unsqueeze(0)
    label_one_hot = torch.cat(label_one_hot)
    return label_one_hot

In [3]:
src_labels = torch.tensor([[
    [0,2,3],
    [0,2,2],
    [2,255,4]
]])
tgt_labels = torch.tensor([[
    [1,3,3],
    [1,3,5],
    [3,0,255]
]])

print(src_labels)
print(tgt_labels)

mix_masks = get_class_masks(src_labels)
# print(mix_masks[0].shape)

mix_lbl = [None] * 1
src_lbl_retain, tgt_lbl_retain = [None] * 1, [None] * 1

for i in range(1):
    print(mix_masks[i])
    _, mix_lbl[i] = one_mix(
        mix_masks[i], 
        target=torch.stack((src_labels[i], tgt_labels[i]))
    )
    src_lbl_retain[i], tgt_lbl_retain[i] = lbl_retain(
        mix_masks[i], 
        torch.stack((src_labels[i], tgt_labels[i]))
    )
    
mix_lbl = torch.cat(mix_lbl)
src_lbl_retain = torch.cat(src_lbl_retain)
tgt_lbl_retain = torch.cat(tgt_lbl_retain)

print(src_lbl_retain)
print(tgt_lbl_retain)
print(mix_lbl)

src_lbl_retain = get_one_hot(src_lbl_retain, 6)
tgt_lbl_retain = get_one_hot(tgt_lbl_retain, 6)
mix_lbl_onehot = get_one_hot(mix_lbl, 6)

# print(src_lbl_retain)
# print(tgt_lbl_retain)
# print(mix_lbl_onehot)

b,c,_,_ = mix_lbl_onehot.shape

src_sum = src_lbl_retain.view(b,c,-1).sum(dim=2).float()
tgt_sum = tgt_lbl_retain.view(b,c,-1).sum(dim=2).float()
mix_sum = mix_lbl_onehot.view(b,c,-1).sum(dim=2).float()

print(src_sum)
print(tgt_sum)
print(mix_sum)
print("+++")
src_cls = (src_sum / mix_sum).nan_to_num(nan=0.0)
tgt_cls = (tgt_sum / mix_sum).nan_to_num(nan=0.0)
mix_cls = src_cls + tgt_cls

print(src_cls.shape)
print(tgt_cls)
print(mix_cls)
print("+++")
# lbl = torch.cat((src_cls, tgt_cls), dim=1)
# print(lbl)

st = get_one_hot_cls(src_cls, src_cls.shape[1])
tt = get_one_hot_cls(tgt_cls, tgt_cls.shape[1])
mm = torch.cat((st, tt), dim=2)

print(st)
print(tt)
print(mm)

# for i in range(1):
    # st = get_one_hot_cls(src_cls[0], src_cls[0].shape[0])
    # tt = get_one_hot_cls(tgt_cls[0], tgt_cls[0].shape[0])
    # t = torch.cat((st, tt), dim=1)
    # print(st)
    # print(tt)

tensor([[[  0,   2,   3],
         [  0,   2,   2],
         [  2, 255,   4]]])
tensor([[[  1,   3,   3],
         [  1,   3,   5],
         [  3,   0, 255]]])
tensor([[[[1, 0, 0],
          [1, 0, 0],
          [0, 1, 1]]]])
tensor([[[[  0, 255, 255],
          [  0, 255, 255],
          [255, 255,   4]]]])
tensor([[[[255,   3,   3],
          [255,   3,   5],
          [  3, 255, 255]]]])
tensor([[[[  0,   3,   3],
          [  0,   3,   5],
          [  3, 255,   4]]]])
tensor([[2., 0., 0., 0., 1., 0.]])
tensor([[0., 0., 0., 4., 0., 1.]])
tensor([[2., 0., 0., 4., 1., 1.]])
+++
torch.Size([1, 6])
tensor([[0., 0., 0., 1., 0., 1.]])
tensor([[1., 0., 0., 1., 1., 1.]])
+++
tensor([[[1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0.]]])
tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
   

In [4]:
x = [None] * 2
for i in range(2):
    x[i] = torch.rand(3,3)
print(len(x))
x = torch.cat(x)
print(type(x))
print(x.shape)

2
<class 'torch.Tensor'>
torch.Size([6, 3])


In [5]:
conda activate daformer


CommandNotFoundError: Your shell has not been properly configured to use 'conda activate'.
To initialize your shell, run

    $ conda init <SHELL_NAME>

Currently supported shells are:
  - bash
  - fish
  - tcsh
  - xonsh
  - zsh
  - powershell

See 'conda init --help' for more information and options.

IMPORTANT: You may need to close and restart your shell after running 'conda init'.



Note: you may need to restart the kernel to use updated packages.


In [6]:
import torch
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

img = torch.ones(2,3,4,4)
print(img)

mean = torch.tensor(img_norm_cfg['mean']).reshape(1,3,1,1)
std = torch.tensor(img_norm_cfg['std']).reshape(1,3,1,1)
stdinv = 1.0 / std
img = img - mean
print(img)

img = torch.mul(img, stdinv)
print(img)




tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]],


        [[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
tensor([[[[-122.6750, -122.6750, -122.6750, -122.6750],
          [-122.6750, -122.6750, -122.6750, -122.6750],
          [-122.6750, -122.6750, -122.6750, -122.6750],
          [-122.6750, -122.6750, -122.6750, -122.6750]],

         [[-115.2800, -115.2800, -115.2800, -115.2800],
          [-115.2800, -115.2800

In [19]:
import torch
import torch.nn.functional as F

x = torch.tensor([1.,2.,3.])
y = torch.tensor([1.,0.,1.])
z = x / y
print(z)

z = z.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
print(z)

tensor([1., inf, 3.])
tensor([1., 0., 3.])
