In [25]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F

In [75]:
# GENERATE CLASS PROTOTYPES

# Bilinear resizer that maps any (N, C, h, w) feature map to a fixed 256x256 grid.
interp_up = nn.Upsample(
    size=(256, 256),
    mode='bilinear',
    align_corners=True,
)

def generate_random_orthogonal_matrix(feature_dim, num_classes, rng=None):
    """Draw a random matrix with orthonormal columns.

    A uniform-random ``(feature_dim, num_classes)`` matrix is QR-factorised and
    its Q factor is returned, so ``P.T @ P`` equals the identity.

    Args:
        feature_dim: Number of rows (embedding dimension). Must be at least
            ``num_classes``, otherwise that many orthonormal columns cannot exist.
        num_classes: Number of orthonormal columns to produce (>= 1).
        rng: Optional ``numpy.random.Generator`` for reproducible draws. When
            omitted, NumPy's global random state is used.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(feature_dim, num_classes)``.

    Raises:
        ValueError: If ``num_classes < 1`` or ``feature_dim < num_classes``.
        AssertionError: If the Q factor is not orthonormal to within 1e-7.
    """
    if num_classes < 1 or feature_dim < num_classes:
        raise ValueError(
            f"need feature_dim >= num_classes >= 1, got "
            f"feature_dim={feature_dim}, num_classes={num_classes}"
        )

    sample = np.random.random if rng is None else rng.random
    a = sample(size=(feature_dim, num_classes))
    P, _ = np.linalg.qr(a)

    # Check orthonormality in float64, before the downcast: a 1e-7 tolerance is
    # at float32 round-off level and could fail spuriously after the cast.
    gram_error = np.max(np.abs(P.T @ P - np.eye(num_classes)))
    assert gram_error < 1e-07, gram_error

    return torch.tensor(P).float()

def generate_etf_class_prototypes(feature_dim, num_classes, device=None):
    """Build simplex equiangular-tight-frame (ETF) class prototypes.

    Computes ``M = sqrt(K / (K - 1)) * P @ (I - (1/K) * 1 1^T)`` where ``P`` is a
    random ``(d, K)`` matrix with orthonormal columns. Each column of ``M`` has
    unit L2 norm and any two distinct columns have inner product ``-1 / (K - 1)``.

    Args:
        feature_dim: Embedding dimension ``d``.
        num_classes: Number of classes ``K``; must be at least 2 because the
            scale factor divides by ``K - 1``.
        device: Target device for the result. Defaults to CUDA when it is
            available and falls back to the CPU otherwise.

    Returns:
        A float32 tensor of shape ``(feature_dim, num_classes)`` on ``device``.

    Raises:
        ValueError: If ``num_classes < 2``.
    """
    d = feature_dim
    K = num_classes
    if K < 2:
        raise ValueError(f"num_classes must be >= 2 to form a simplex ETF, got {K}")

    print(f"Generating ETF class prototypes for K={num_classes} and d={feature_dim}.")
    P = generate_random_orthogonal_matrix(feature_dim=d, num_classes=K)
    I = torch.eye(K)
    one = torch.ones(K, K)
    # Plain Python float so the product is an ordinary scalar-tensor multiply.
    scale = float(np.sqrt(K / (K - 1)))
    M_star = scale * torch.matmul(P, I - ((1 / K) * one))

    if device is None:
        # Previously an unconditional .cuda(), which raises on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return M_star.to(device)

In [76]:
# Build the prototype matrix for 5 classes in a 2048-dimensional feature space, then echo it.
class_prototypes = generate_etf_class_prototypes(
    num_classes=5,
    feature_dim=2048,
)
class_prototypes

Generating ETF class prototypes for K=5 and d=2048.


tensor([[-0.0201, -0.0195,  0.0518,  0.0056, -0.0178],
        [-0.0207,  0.0010,  0.0187,  0.0126, -0.0117],
        [-0.0154, -0.0023, -0.0076,  0.0098,  0.0155],
        ...,
        [-0.0236, -0.0041,  0.0160,  0.0285, -0.0168],
        [-0.0194,  0.0179, -0.0243,  0.0290, -0.0032],
        [-0.0434,  0.0126, -0.0100,  0.0093,  0.0315]], device='cuda:0')

In [77]:
# Notation used in the cells below:
#     B: batch size
#     feat_dim: feature dimension
#     H: Height
#     W: Width
#     mode: Choose out of three options - ["thresholding", "thresh_feat_consistency", "pixel_self_labeling_OT"]
#
#     target_features: B*feat_dim*H*W
#     domain_agnostic_prototypes (`class_prototypes` in this notebook): feat_dim*C
#
#     domain_agnostic_prototypes are already normalized.

print(class_prototypes.shape) # (feat_dim, C)

torch.Size([2048, 5])


In [78]:
# Synthetic stand-in for a batch of feature maps laid out as (B, feat_dim, H, W).
B, feat_dim, H, W = 4, 2048, 33, 33

target_features = torch.randn((B, feat_dim, H, W))
print(target_features.size())
# Resize the spatial grid with the bilinear upsampler defined earlier.
target_features = interp_up(target_features)
print(target_features.size())

torch.Size([4, 2048, 33, 33])
torch.Size([4, 2048, 256, 256])


In [79]:
# single pixel representation

# Feature vector of the top-left pixel of the first image, and its L2 norm before normalisation.
print(target_features[0, :, 0, 0].size())
print(target_features[0, :, 0, 0].norm(p=2))

torch.Size([2048])
tensor(44.7465)


In [80]:
# L2-normalise every pixel's feature vector along the channel axis (dim=1),
# then re-check the same pixel: its norm should now be 1.
target_features = F.normalize(target_features, dim=1, p=2)
print(target_features[0, :, 0, 0].size())
print(target_features[0, :, 0, 0].norm(p=2))

torch.Size([2048])
tensor(1.)


In [81]:
# Move the channel axis to the end: (B, feat_dim, H, W) -> (B, H, W, feat_dim),
# so the feature axis lines up with the rows of the prototype matrix.
target_features = torch.movedim(target_features, 1, -1)
print(target_features.size())

torch.Size([4, 256, 256, 2048])


In [82]:
# Similarity of every pixel feature to every class prototype:
# (B, H, W, feat_dim) @ (feat_dim, C) -> (B, H, W, C).
# The features are moved to whatever device the prototypes live on, instead of
# forcing both operands through .cuda() (redundant for the prototypes and a
# hard failure on CPU-only machines).
batch_pixel_cosine_sim = torch.matmul(
    target_features.to(class_prototypes.device),
    class_prototypes,
)
print(batch_pixel_cosine_sim.shape)

torch.Size([4, 256, 256, 5])


In [83]:
# Similarities of the first image's top-left pixel to each class prototype.
print(batch_pixel_cosine_sim[0, 0, 0])

tensor([-0.0152,  0.0082,  0.0056,  0.0132, -0.0118], device='cuda:0')


In [84]:
# Threshold compared against the top-1 minus top-2 similarity gap in later cells.
threshold = 0.6

# Ascending sort over the class axis. Only the sorted values are needed; taking
# `.values` avoids binding the unused index tensor to `_`, which at notebook top
# level shadows IPython's last-output variable and keeps that tensor alive.
batch_sort_cosine = torch.sort(batch_pixel_cosine_sim, dim=-1).values
print(batch_sort_cosine.shape)

torch.Size([4, 256, 256, 5])


In [85]:
# Same pixel as before, similarities now in ascending order.
print(batch_sort_cosine[0, 0, 0])

tensor([-0.0152, -0.0118,  0.0056,  0.0082,  0.0132], device='cuda:0')


In [86]:
# Largest similarity per pixel (last entry of the ascending sort).
batch_sort_cosine[..., -1]

tensor([[[0.0132, 0.0167, 0.0205,  ..., 0.0328, 0.0334, 0.0332],
         [0.0127, 0.0160, 0.0205,  ..., 0.0326, 0.0332, 0.0329],
         [0.0116, 0.0173, 0.0247,  ..., 0.0314, 0.0327, 0.0344],
         ...,
         [0.0291, 0.0282, 0.0261,  ..., 0.0265, 0.0262, 0.0254],
         [0.0306, 0.0300, 0.0285,  ..., 0.0235, 0.0241, 0.0241],
         [0.0331, 0.0329, 0.0318,  ..., 0.0234, 0.0241, 0.0241]],

        [[0.0384, 0.0400, 0.0407,  ..., 0.0269, 0.0319, 0.0352],
         [0.0374, 0.0389, 0.0396,  ..., 0.0266, 0.0321, 0.0358],
         [0.0349, 0.0363, 0.0368,  ..., 0.0255, 0.0315, 0.0355],
         ...,
         [0.0138, 0.0168, 0.0229,  ..., 0.0301, 0.0291, 0.0277],
         [0.0136, 0.0191, 0.0255,  ..., 0.0329, 0.0321, 0.0307],
         [0.0162, 0.0212, 0.0269,  ..., 0.0343, 0.0337, 0.0324]],

        [[0.0115, 0.0104, 0.0141,  ..., 0.0486, 0.0460, 0.0429],
         [0.0083, 0.0122, 0.0176,  ..., 0.0511, 0.0476, 0.0438],
         [0.0121, 0.0165, 0.0216,  ..., 0.0528, 0.0483, 0.

In [87]:
# Per-pixel margin: best similarity minus second-best similarity.
pixel_sub_cosine = torch.sub(batch_sort_cosine[..., -1], batch_sort_cosine[..., -2])
print(pixel_sub_cosine.size())

torch.Size([4, 256, 256])


In [88]:
# True where the top-1/top-2 margin is strictly greater than the threshold.
pixel_mask = torch.gt(pixel_sub_cosine, threshold)

In [89]:
# Display the boolean mask; True marks pixels whose top-1/top-2 similarity margin exceeds `threshold`.
pixel_mask

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [

In [90]:
# One boolean per pixel: (B, H, W).
pixel_mask.size()

torch.Size([4, 256, 256])

In [91]:
# Hard label per pixel: index of the most similar class prototype.
hard_pixel_label = batch_pixel_cosine_sim.argmax(dim=-1)
hard_pixel_label.size()

torch.Size([4, 256, 256])

In [92]:
# Display the per-pixel class indices obtained by argmax over the last axis of batch_pixel_cosine_sim.
hard_pixel_label

tensor([[[3, 3, 3,  ..., 1, 1, 1],
         [3, 3, 2,  ..., 1, 1, 1],
         [3, 2, 2,  ..., 1, 4, 4],
         ...,
         [2, 2, 2,  ..., 4, 4, 4],
         [3, 3, 3,  ..., 4, 4, 4],
         [3, 3, 3,  ..., 2, 2, 2]],

        [[2, 2, 2,  ..., 3, 3, 3],
         [2, 2, 2,  ..., 3, 3, 3],
         [2, 2, 2,  ..., 3, 3, 3],
         ...,
         [0, 0, 3,  ..., 3, 3, 3],
         [3, 3, 3,  ..., 3, 3, 3],
         [3, 3, 3,  ..., 3, 3, 3]],

        [[3, 3, 0,  ..., 4, 4, 4],
         [3, 0, 0,  ..., 4, 4, 4],
         [0, 0, 0,  ..., 4, 4, 4],
         ...,
         [4, 4, 4,  ..., 1, 1, 3],
         [4, 4, 4,  ..., 1, 1, 3],
         [4, 4, 4,  ..., 1, 1, 3]],

        [[2, 2, 2,  ..., 2, 2, 2],
         [2, 2, 2,  ..., 2, 2, 2],
         [4, 4, 0,  ..., 2, 2, 2],
         ...,
         [2, 2, 2,  ..., 2, 2, 2],
         [2, 2, 2,  ..., 2, 2, 2],
         [2, 2, 0,  ..., 2, 2, 2]]], device='cuda:0')

In [93]:
# Labels of the pixels that pass the margin threshold, returned as a 1-D tensor.
torch.masked_select(hard_pixel_label, pixel_mask)

tensor([], device='cuda:0', dtype=torch.int64)

In [94]:
# Collapse (B, H, W) labels into a single 1-D vector.
hard_pixel_label_flat = torch.flatten(hard_pixel_label)
hard_pixel_label_flat.size()

torch.Size([262144])

In [95]:
# Collapse the (B, H, W) mask the same way so it lines up with the flat labels.
pixel_mask_flat = torch.flatten(pixel_mask)
pixel_mask_flat.size()

torch.Size([262144])

In [96]:
# Same selection as on the 3-D tensors, done on the flattened views.
torch.masked_select(hard_pixel_label_flat, pixel_mask_flat)

tensor([], device='cuda:0', dtype=torch.int64)

In [102]:
# All-True mask used to sanity-check boolean indexing. It is built from the
# tensor it will index, so its length and device always match, replacing the
# hard-coded 262144 and the float-then-cast CPU tensor.
test_mask = torch.ones_like(hard_pixel_label_flat, dtype=torch.bool)
test_mask.shape

torch.Size([262144])

In [103]:
# Sanity check: indexing with the all-True mask returns every flattened label.
hard_pixel_label_flat[test_mask]

tensor([3, 3, 3,  ..., 2, 2, 2], device='cuda:0')