In [104]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F

In [105]:
# GENERATE CLASS PROTOTYPES

# Bilinear upsampler that resizes 4-D (N, C, h, w) feature maps onto a fixed 256x256 grid;
# align_corners=True keeps the corner pixels of the two resolutions aligned.
interp_up = nn.Upsample(
    size=(256, 256),
    mode='bilinear',
    align_corners=True,
)

def generate_random_orthogonal_matrix(feature_dim, num_classes, rng=None):
    """Return a (feature_dim, num_classes) float32 tensor with orthonormal columns.

    The columns come from the reduced QR decomposition of a matrix whose entries
    are drawn uniformly from [0, 1).

    Args:
        feature_dim: number of rows d.
        num_classes: number of orthonormal columns K; must not exceed feature_dim.
        rng: optional ``numpy.random.Generator`` for reproducible output. When None
            (default) NumPy's legacy global RNG (``np.random``) is used, exactly as
            before, so ``np.random.seed`` still controls the result.

    Returns:
        torch.float32 tensor P of shape (feature_dim, num_classes) with P.T @ P ~= I.

    Raises:
        ValueError: if num_classes > feature_dim (a reduced QR of a wide matrix
            cannot produce num_classes orthonormal columns).
        AssertionError: if the numerical orthogonality check fails.
    """
    if num_classes > feature_dim:
        raise ValueError(
            f"need num_classes <= feature_dim, got num_classes={num_classes}, "
            f"feature_dim={feature_dim}"
        )
    sampler = np.random if rng is None else rng
    a = sampler.random(size=(feature_dim, num_classes))
    q, _ = np.linalg.qr(a)  # reduced mode: q has shape (feature_dim, num_classes)
    # Verify orthogonality in float64, before the float32 cast adds rounding error.
    gram = q.T @ q
    identity = np.eye(num_classes)
    assert np.allclose(gram, identity, atol=1e-07), np.abs(gram - identity).max()
    return torch.tensor(q).float()

def generate_etf_class_prototypes(feature_dim, num_classes, device=None, rng=None):
    """Build simplex equiangular tight frame (ETF) class prototypes.

    Computes M* = sqrt(K / (K - 1)) * P @ (I_K - (1/K) * 1 1^T), where P is a random
    (d, K) matrix with orthonormal columns. By construction every column of M* has
    unit L2 norm, every pair of distinct columns has inner product -1/(K-1), and the
    columns sum to zero.

    Args:
        feature_dim: embedding dimension d (must be >= num_classes).
        num_classes: number of classes K (must be >= 2).
        device: device for the returned tensor. Defaults to "cuda" when CUDA is
            available (the previous hard-coded behaviour) and "cpu" otherwise.
        rng: optional ``numpy.random.Generator`` forwarded to
            ``generate_random_orthogonal_matrix`` for reproducibility.

    Returns:
        float32 tensor of shape (feature_dim, num_classes): one prototype per column.

    Raises:
        ValueError: if num_classes < 2 (the K / (K - 1) scale is undefined for K = 1)
            or num_classes > feature_dim.
    """
    import math

    print(f"Generating ETF class prototypes for K={num_classes} and d={feature_dim}.")
    if num_classes < 2:
        raise ValueError(f"ETF prototypes need num_classes >= 2, got {num_classes}")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    d = feature_dim
    K = num_classes
    P = generate_random_orthogonal_matrix(feature_dim=d, num_classes=K, rng=rng)
    # Centering matrix I - (1/K) 11^T removes the mean direction of the K columns.
    centering = torch.eye(K) - (1 / K) * torch.ones(K, K)
    M_star = math.sqrt(K / (K - 1)) * torch.matmul(P, centering)
    return M_star.to(device)

In [106]:
# d=2048 matches feat_dim in the synthetic-feature cell; K=5 classes.
class_prototypes = generate_etf_class_prototypes(2048, 5)
class_prototypes

Generating ETF class prototypes for K=5 and d=2048.


tensor([[-0.0238,  0.0183,  0.0116, -0.0147,  0.0085],
        [-0.0238, -0.0238, -0.0087,  0.0048,  0.0515],
        [-0.0276,  0.0288, -0.0199,  0.0072,  0.0115],
        ...,
        [-0.0239, -0.0309,  0.0081,  0.0428,  0.0040],
        [-0.0178,  0.0310, -0.0032,  0.0147, -0.0248],
        [-0.0192, -0.0163,  0.0194, -0.0169,  0.0331]], device='cuda:0')

In [107]:
# Notation used in the cells below:
#   B: batch size
#   feat_dim: feature dimension
#   H: height
#   W: width
#   C: number of classes
#   mode: one of "thresholding", "thresh_feat_consistency", "pixel_self_labeling_OT";
#         only "thresholding" is exercised in this notebook.
#
#   target_features: (B, feat_dim, H, W)
#   domain_agnostic_prototypes (here: class_prototypes): (feat_dim, C),
#       i.e. one prototype per COLUMN, as returned by generate_etf_class_prototypes.
#
# The prototype columns are already unit-norm (ETF construction), so a matmul with
# L2-normalised features yields cosine similarities directly.

print(class_prototypes.shape) # (feat_dim, C)

torch.Size([2048, 5])


In [108]:
# Synthetic (unseeded) stand-in for backbone features: (B, feat_dim, H, W) at 33x33,
# then bilinearly upsampled to 256x256. Note: 4*2048*256*256 float32 values ~= 2 GiB.
B, feat_dim, H, W = 4, 2048, 33, 33

target_features = torch.randn(B, feat_dim, H, W)
print(target_features.size())
target_features = interp_up(target_features)
print(target_features.size())

torch.Size([4, 2048, 33, 33])
torch.Size([4, 2048, 256, 256])


In [109]:
# Inspect one pixel's feature vector (batch 0, row 0, col 0) before normalisation.

print(target_features[0, :, 0, 0].size())
print(target_features[0, :, 0, 0].norm(p=2))

torch.Size([2048])
tensor(45.3399)


In [110]:
# L2-normalise each pixel's feature vector along the channel axis (dim=1), then
# confirm the sampled pixel now has unit norm.
target_features = F.normalize(target_features, dim=1, p=2)
print(target_features[0, :, 0, 0].size())
print(target_features[0, :, 0, 0].norm(p=2))

torch.Size([2048])
tensor(1.0000)


In [111]:
# Channels-first -> channels-last: (B, feat_dim, H, W) -> (B, H, W, feat_dim), so each
# pixel's feature vector sits on the last axis for the matmul with the prototypes.
# permute is not idempotent: guard against re-running this cell and scrambling axes.
assert target_features.shape[1] == feat_dim, (
    f"expected channels-first features, got shape {tuple(target_features.shape)}; "
    "was this cell already run?"
)
target_features = target_features.permute(0, 2, 3, 1)
print(target_features.shape)

torch.Size([4, 256, 256, 2048])


In [112]:
# Features (normalised along feat_dim) times unit-norm prototype columns gives per-pixel
# cosine similarity to each class: (B, H, W, feat_dim) @ (feat_dim, C) -> (B, H, W, C).
# Follow the prototypes' device instead of hard-coding .cuda(), so this also runs on CPU.
batch_pixel_cosine_sim = torch.matmul(target_features.to(class_prototypes.device), class_prototypes)
print(batch_pixel_cosine_sim.shape)

torch.Size([4, 256, 256, 5])


In [113]:
# Cosine similarity of the first pixel (batch 0, row 0, col 0) to each class prototype.
print(batch_pixel_cosine_sim[0,0,0,:])

tensor([-0.0030, -0.0254,  0.0085,  0.0188,  0.0011], device='cuda:0')


In [114]:
# A pixel is kept only if its top-1 minus top-2 similarity strictly exceeds this margin.
threshold = 0.6

# Ascending sort over the class axis: [..., -1] is the top-1 similarity, [..., -2] the runner-up.
batch_sort_cosine = batch_pixel_cosine_sim.sort(dim=-1).values
print(batch_sort_cosine.size())

torch.Size([4, 256, 256, 5])


In [115]:
# Same pixel after the ascending sort: the last two entries are the top-2 similarities.
print(batch_sort_cosine[0,0,0,:])

tensor([-0.0254, -0.0030,  0.0011,  0.0085,  0.0188], device='cuda:0')


In [116]:
# Top-1 cosine similarity for every pixel, shape (B, H, W).
batch_sort_cosine[:,:,:,-1]

tensor([[[0.0188, 0.0170, 0.0143,  ..., 0.0167, 0.0188, 0.0200],
         [0.0205, 0.0185, 0.0154,  ..., 0.0151, 0.0187, 0.0211],
         [0.0220, 0.0197, 0.0164,  ..., 0.0127, 0.0181, 0.0218],
         ...,
         [0.0316, 0.0364, 0.0412,  ..., 0.0243, 0.0291, 0.0323],
         [0.0227, 0.0280, 0.0338,  ..., 0.0278, 0.0314, 0.0336],
         [0.0170, 0.0206, 0.0272,  ..., 0.0299, 0.0325, 0.0338]],

        [[0.0240, 0.0248, 0.0251,  ..., 0.0358, 0.0362, 0.0357],
         [0.0220, 0.0230, 0.0235,  ..., 0.0375, 0.0386, 0.0387],
         [0.0253, 0.0258, 0.0256,  ..., 0.0385, 0.0405, 0.0412],
         ...,
         [0.0324, 0.0328, 0.0322,  ..., 0.0343, 0.0367, 0.0377],
         [0.0338, 0.0337, 0.0325,  ..., 0.0251, 0.0267, 0.0273],
         [0.0342, 0.0337, 0.0320,  ..., 0.0237, 0.0182, 0.0184]],

        [[0.0204, 0.0216, 0.0224,  ..., 0.0168, 0.0138, 0.0190],
         [0.0167, 0.0187, 0.0205,  ..., 0.0176, 0.0148, 0.0213],
         [0.0166, 0.0154, 0.0175,  ..., 0.0182, 0.0158, 0.

In [117]:
# Confidence margin per pixel: top-1 minus top-2 cosine similarity -> (B, H, W).
pixel_sub_cosine = batch_sort_cosine[..., -1] - batch_sort_cosine[..., -2]
print(pixel_sub_cosine.size())

torch.Size([4, 256, 256])


In [118]:
# Keep only pixels whose top-1/top-2 margin strictly exceeds the threshold.
pixel_mask = pixel_sub_cosine.gt(threshold)

In [119]:
# Boolean confidence mask; with random synthetic features the margins are tiny, so the
# displayed entries are all False.
pixel_mask

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [

In [120]:
# One mask entry per pixel: (B, H, W).
pixel_mask.shape

torch.Size([4, 256, 256])

In [121]:
# Hard pseudo-label per pixel: index of the most similar class prototype.
hard_pixel_label = batch_pixel_cosine_sim.argmax(dim=-1)
hard_pixel_label.size()

torch.Size([4, 256, 256])

In [122]:
# Predicted class index for every pixel.
hard_pixel_label

tensor([[[3, 3, 3,  ..., 1, 1, 1],
         [3, 3, 3,  ..., 1, 1, 1],
         [3, 3, 3,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 2, 2, 2],
         [1, 1, 1,  ..., 2, 2, 2],
         [2, 1, 1,  ..., 2, 2, 2]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [3, 3, 3,  ..., 1, 1, 1],
         ...,
         [3, 3, 3,  ..., 4, 4, 4],
         [3, 3, 3,  ..., 4, 4, 4],
         [3, 3, 3,  ..., 2, 4, 4]],

        [[4, 4, 4,  ..., 0, 2, 4],
         [4, 4, 4,  ..., 0, 0, 4],
         [3, 3, 4,  ..., 0, 0, 4],
         ...,
         [4, 4, 4,  ..., 3, 1, 1],
         [4, 4, 4,  ..., 3, 1, 1],
         [4, 4, 4,  ..., 3, 1, 1]],

        [[3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         [3, 3, 3,  ..., 2, 2, 2],
         ...,
         [2, 2, 2,  ..., 0, 0, 0],
         [2, 2, 2,  ..., 0, 0, 0],
         [2, 2, 2,  ..., 0, 0, 0]]], device='cuda:0')

In [123]:
# Labels of confident pixels only; empty here because no pixel passed the margin threshold.
hard_pixel_label[pixel_mask]

tensor([], device='cuda:0', dtype=torch.int64)

In [124]:
# Flatten the (B, H, W) labels into a single vector of B*H*W pixels.
hard_pixel_label_flat = hard_pixel_label.reshape(-1)
hard_pixel_label_flat.size()

torch.Size([262144])

In [125]:
# Flatten the (B, H, W) mask in the same order as the labels.
pixel_mask_flat = pixel_mask.reshape(-1)
pixel_mask_flat.size()

torch.Size([262144])

In [126]:
# Same selection on the flattened tensors; also empty.
hard_pixel_label_flat[pixel_mask_flat]

tensor([], device='cuda:0', dtype=torch.int64)

In [127]:
# All-True mask with one entry per flattened pixel (B*H*W), used to sanity-check boolean
# indexing. Sized from the data rather than a hard-coded 262144 so it tracks B, H, W.
test_mask = torch.ones(hard_pixel_label_flat.numel(), dtype=torch.bool)
test_mask.shape

torch.Size([262144])

In [128]:
# Sanity check: an all-True mask selects every flattened label.
hard_pixel_label_flat[test_mask]

tensor([3, 3, 3,  ..., 0, 0, 0], device='cuda:0')

In [130]:
# Feature vectors of confident pixels; the mask is moved to CPU to match target_features,
# which was never reassigned to the GPU.
target_features[pixel_mask.cpu()]

tensor([], size=(0, 2048))

In [131]:
# Channels-last features: (B, H, W, feat_dim).
target_features.shape

torch.Size([4, 256, 256, 2048])

In [134]:
# One row per pixel: (B, H, W, feat_dim) -> (B*H*W, feat_dim). Row order matches
# pixel_mask.flatten(), since both flatten B, H, W in the same row-major order.
# Use the actual channel width instead of a hard-coded 2048.
target_features_flat = target_features.reshape(-1, target_features.shape[-1])
print(target_features_flat.shape)

torch.Size([262144, 2048])


In [136]:
# One mask entry per flattened pixel, matching the rows of target_features_flat.
pixel_mask_flat.shape

torch.Size([262144])

In [141]:
# Sanity check: an all-True mask selects every row of the flattened features.
target_features_flat[test_mask.cpu()]

tensor([[-0.0224,  0.0190, -0.0069,  ...,  0.0020, -0.0003,  0.0406],
        [-0.0224,  0.0180, -0.0102,  ..., -0.0042,  0.0019,  0.0412],
        [-0.0219,  0.0162, -0.0140,  ..., -0.0120,  0.0046,  0.0407],
        ...,
        [ 0.0121,  0.0048,  0.0329,  ...,  0.0055,  0.0105, -0.0013],
        [ 0.0152,  0.0036,  0.0269,  ...,  0.0048,  0.0092, -0.0003],
        [ 0.0173,  0.0025,  0.0215,  ...,  0.0041,  0.0080,  0.0005]])

In [142]:
# Number of flattened pixels (B*H*W).
len(target_features_flat)

262144

In [2]:
# testing: does flattening a (4, 33, 33) tensor give the same length as one created flat?
import torch

preflattened_label = torch.randn(4 * 33 * 33)
preflattened_label.size()

torch.Size([4356])

In [3]:
# Label-shaped tensor: (4, 33, 33).
label = torch.randn(4, 33, 33)
label.size()

torch.Size([4, 33, 33])

In [4]:
# view(-1) on an already 1-D tensor leaves its length unchanged.
preflattened_label_flat = preflattened_label.view(-1)
preflattened_label_flat.size()

torch.Size([4356])

In [5]:
# Flattening (4, 33, 33) yields 4*33*33 = 4356 elements.
label_flat = label.view(-1)
label_flat.size()

torch.Size([4356])

In [None]:
# End of scratch tests. (A stray bare `s` here raised NameError under Restart & Run All.)