In [29]:
import torch
import torch.nn.functional as F

In [39]:
def reco_loss_func(rep, label, mask, prob=None, strong_threshold=1.0, temp=0.5,
                   num_queries=256, num_negatives=256, ignore_index=255):
    """Pixel-level contrastive (ReCo-style) loss over the classes present in a batch.

    For every class with at least one valid pixel, a set of query pixels is
    contrasted against one positive key (the mean representation of that
    class) and a set of negative keys (pixels drawn from the other classes,
    with more samples taken from classes whose mean representation is more
    similar to the query class).

    Args:
        rep: Feature tensor of shape ``[B, C, ...]`` (e.g. ``[B, C, H, W]``);
            trailing dimensions are flattened into pixels.
        label: Integer class map with ``B`` as first dimension and as many
            pixels per image as ``rep``.
        mask: Per-pixel validity map, same layout as ``label``. Any dtype is
            accepted; non-zero entries mark pixels that may be used.
        prob: Optional class probabilities of shape ``[B, num_classes, ...]``.
            When given, only pixels whose probability for their labelled
            class is below ``strong_threshold`` are eligible as queries, and
            if more than ``num_queries`` are eligible the lowest-probability
            ones are kept. When ``None``, queries are drawn uniformly at
            random.
        strong_threshold: Confidence threshold for query eligibility.
        temp: Temperature dividing the cosine similarities.
        num_queries: Maximum number of query pixels per class.
        num_negatives: Maximum number of negative keys per class.
        ignore_index: Label value that is excluded from the loss.

    Returns:
        A scalar tensor: the cross-entropy loss averaged over the classes
        that produced at least one query. A zero tensor is returned when
        fewer than two classes have valid pixels or no class yields queries.
    """
    B, C = rep.shape[:2]
    device = rep.device

    # [B, HW, C]: a boolean [B, HW] pixel mask then selects rows of features.
    feats = rep.reshape(B, C, -1).permute(0, 2, 1)
    label = label.reshape(B, -1)
    # Cast so that float / integer validity masks work with the `&` below.
    mask = mask.reshape(B, -1).bool()
    if prob is not None:
        prob = prob.reshape(B, prob.size(1), -1)

    def zero_loss():
        return torch.tensor(0.0, device=device, requires_grad=True)

    # Per-class pixel features, confidences and mean representation
    # (the mean is the positive key of the class).
    class_feats = []
    class_conf = []
    prototypes = []
    for c in torch.unique(label).tolist():
        if c == ignore_index:
            continue
        c_mask = (label == c) & mask
        if not c_mask.any():
            continue
        pixels = feats[c_mask]  # [n, C], ordered by image then pixel
        class_feats.append(pixels)
        class_conf.append(prob[:, c, :][c_mask] if prob is not None else None)
        prototypes.append(pixels.mean(dim=0))

    num_present = len(prototypes)
    if num_present <= 1:
        # A contrast needs at least one other class to draw negatives from.
        return zero_loss()

    # Class relationship graph G (Eq. 3): pairwise cosine similarity between
    # class means. Only used to weight the negative sampling, so no gradient.
    with torch.no_grad():
        unit_protos = F.normalize(torch.stack(prototypes), p=2, dim=1)
        G = unit_protos @ unit_protos.t()
        G.fill_diagonal_(0.0)

    total_loss = 0.0
    valid_classes = 0

    for idx in range(num_present):
        pixels = class_feats[idx]
        conf = class_conf[idx]

        # Active query sampling (Eq. 4).
        if conf is not None:
            hard = conf < strong_threshold
            hard_pixels = pixels[hard]
            if hard_pixels.size(0) > num_queries:
                # Keep the least confident pixels.
                _, keep = torch.topk(conf[hard], num_queries, largest=False)
                queries = hard_pixels[keep]
            else:
                queries = hard_pixels
        elif pixels.size(0) > num_queries:
            keep = torch.randperm(pixels.size(0), device=device)[:num_queries]
            queries = pixels[keep]
        else:
            queries = pixels

        if queries.size(0) == 0:
            continue

        # Keys are built without gradient: only the queries are trained.
        with torch.no_grad():
            others = [j for j in range(num_present) if j != idx]
            weights = F.softmax(G[idx], dim=0)[others]
            weights = weights / weights.sum()

            negatives = []
            budget = num_negatives
            for w, j in zip(weights.tolist(), others):
                if budget <= 0:
                    break
                pool = class_feats[j]
                # At least one sample per class, capped by what the class
                # has available and by the remaining overall budget.
                take = min(max(1, int(w * num_negatives)), pool.size(0), budget)
                if take > 0:
                    pick = torch.randperm(pool.size(0), device=device)[:take]
                    negatives.append(pool[pick])
                    budget -= take

            if not negatives:
                continue

            # keys = [positive key | negative keys], shape [1 + N, C]
            keys = torch.cat([prototypes[idx].unsqueeze(0)] + negatives, dim=0)

        # Contrastive loss (Eq. 1). `expand` shares the keys across queries
        # as a view instead of copying them once per query.
        shared_keys = keys.unsqueeze(0).expand(queries.size(0), -1, -1)
        seg_logits = F.cosine_similarity(queries.unsqueeze(1), shared_keys, dim=2) / temp

        # The positive key sits at index 0 of every row.
        target = torch.zeros(queries.size(0), dtype=torch.long, device=device)
        total_loss = total_loss + F.cross_entropy(seg_logits, target)
        valid_classes += 1

    if valid_classes > 0:
        return total_loss / valid_classes
    return zero_loss()

In [40]:
# Toy inputs for a quick smoke test of reco_loss_func.
batch_size, feat_dim = 2, 16
height = width = 4
num_classes = 3

# Seed before any random draw so the features and probabilities are reproducible.
torch.manual_seed(1823)
rep = torch.randn((batch_size, feat_dim, height, width))

# Image 0: top half class 0, bottom half class 1.
# Image 1: left half class 1, right half class 2.
label = torch.zeros((batch_size, height, width), dtype=torch.long)
label[0, 2:] = 1
label[1, :, :2] = 1
label[1, :, 2:] = 2

# All pixels are valid except the top-left pixel of image 0.
mask = torch.ones_like(label, dtype=torch.bool)
mask[0, 0, 0] = False

# Random per-pixel class distribution (sums to 1 over the class axis).
prob = torch.softmax(torch.rand(batch_size, num_classes, height, width), dim=1)

# Loss hyper-parameters, kept small to match the 4x4 inputs.
strong_threshold, temp = 0.9, 0.5
num_queries = num_negatives = 8

In [38]:
# Evaluate the loss on the toy batch; keyword arguments make the mapping explicit.
loss = reco_loss_func(
    rep,
    label,
    mask,
    prob=prob,
    strong_threshold=strong_threshold,
    temp=temp,
    num_queries=num_queries,
    num_negatives=num_negatives,
)
print(f"ReCo Loss: {loss.item()}")

1 tensor([[-1.0016,  0.3446,  0.2901, -1.3508, -0.5025, -0.3367, -0.9438, -0.1312],
        [ 0.2963,  0.2302,  1.3373,  0.7256,  1.1815,  0.0697,  1.3870,  0.5442],
        [ 0.5031,  0.1988, -0.4222,  0.4480, -0.2400,  0.6908, -0.4035, -1.1064],
        [-0.2810, -0.7745,  0.8408, -0.1935, -0.8418,  0.4919,  0.2604,  1.4582],
        [ 0.5528, -0.2664, -1.4722, -0.0469,  0.3454,  1.1731,  0.8137, -1.1522],
        [ 0.2722, -0.4692,  0.6343,  0.5207, -0.6887,  0.4594,  0.5794,  0.5612],
        [-0.9507,  0.9871,  0.4305, -0.3972, -0.2546, -0.9206,  0.3315,  1.1926],
        [ 0.5082, -0.5687, -0.9777,  1.0867,  0.5834,  0.3881, -1.1428, -0.0261],
        [ 0.6318,  0.3515,  0.2428,  0.4017, -1.2224,  0.2660, -1.1095, -0.8849],
        [-1.7641,  1.0241, -2.1486, -2.7698, -0.2632,  0.1761, -0.5551, -0.2191],
        [-0.8859,  1.1656, -1.2972,  0.1934, -0.9513,  1.2077, -1.7894, -0.0391],
        [-0.4229,  0.6063, -0.2001,  1.0872, -0.9963,  0.2359,  0.7518,  0.4993],
        [-0.76