In [1]:
import torch

In [11]:
# Seed the global RNG so the random fixtures below are identical on every
# Restart & Run All.
torch.manual_seed(0)

# Flat boolean mask over all anchors: 8400 = 80*80 + 40*40 + 20*20.
t_mask = torch.rand(1, 8400) > 0.5

# Random stand-ins for per-level head outputs, shaped [B, 144, H, W].
# The cells below read the first 64 channels as DFL logits and the last 80
# channels as class logits.
teacher_preds = [
    torch.randn(1, 144, 80, 80, dtype=torch.float32),
    torch.randn(1, 144, 40, 40, dtype=torch.float32),
    torch.randn(1, 144, 20, 20, dtype=torch.float32),
]
student_preds = [
    torch.randn(1, 144, 80, 80, dtype=torch.float32),
    torch.randn(1, 144, 40, 40, dtype=torch.float32),
    torch.randn(1, 144, 20, 20, dtype=torch.float32),
]

In [18]:
def generate_masks_from_teacher_tal(t_mask, teacher_preds):
    """
    Split a flat per-anchor mask into one single-channel spatial mask per level.

    The anchors of ``t_mask`` are consumed level by level, in the order of
    ``teacher_preds``: level ``i`` takes the next ``H_i * W_i`` entries, which
    are reshaped (row-major) to ``[B, 1, H_i, W_i]``.

    Args:
        t_mask: Tensor of shape [B, N] (e.g. [4, 8400]) containing boolean
            values, where N must equal the sum of ``H_i * W_i`` over all levels.
        teacher_preds: List of tensors of shape [B, C, H_i, W_i]
            (e.g. [4, 144, 80, 80], [4, 144, 40, 40], [4, 144, 20, 20]).
            Only the two spatial dimensions are read.

    Returns:
        masks: List with one float tensor of 0.0/1.0 values per level, with
               shapes [B, 1, H_i, W_i]
               (e.g. [4, 1, 80, 80], [4, 1, 40, 40], [4, 1, 20, 20]).

    Raises:
        ValueError: If ``t_mask`` is not 2-D, or if its anchor count does not
            match the total number of spatial locations in ``teacher_preds``.
    """
    if t_mask.dim() != 2:
        raise ValueError(
            f"t_mask must be 2-D [batch, anchors], got shape {tuple(t_mask.shape)}"
        )
    batch_size, num_anchors = t_mask.shape

    # Extract spatial dims dynamically from teacher_preds
    spatial_dims = [(pred.shape[2], pred.shape[3]) for pred in teacher_preds]

    # Fail loudly on a mismatch instead of silently dropping trailing anchors
    # or hitting an opaque reshape error further down.
    expected_anchors = sum(h * w for h, w in spatial_dims)
    if num_anchors != expected_anchors:
        raise ValueError(
            f"t_mask has {num_anchors} anchors but teacher_preds cover "
            f"{expected_anchors} spatial locations"
        )

    masks = []
    start_idx = 0  # running offset into the flat anchor axis
    for h, w in spatial_dims:
        end_idx = start_idx + h * w

        # Slice this level's anchors and convert bool -> float (0.0/1.0)
        level_mask = t_mask[:, start_idx:end_idx].float()

        # Reshape to [batch_size, 1, h, w] - single channel
        masks.append(level_mask.reshape(batch_size, 1, h, w))

        start_idx = end_idx

    return masks

In [20]:
# Per-level foreground masks: one [B, 1, H, W] float tensor per prediction level.
mask = generate_masks_from_teacher_tal(
    t_mask=t_mask,
    teacher_preds=teacher_preds,
)

In [33]:
import torch.nn.functional as F
# Set to True to restrict the class loss to foreground locations given by `mask`.
cls_fg_mask = False

# --- class distillation setup ---
class_channels = 80                # class logits are the last 80 channels of each prediction
per_scale_weights = [1., 1., 1.]   # one weight per prediction level
alpha = 1.0                        # overall weight applied to the accumulated class loss
T = 1.0                            # distillation temperature

distill_cls_loss = 0.0

for i, (s_pred, t_pred) in enumerate(zip(student_preds, teacher_preds)):
    # --- Extract class logits ---
    s_logits = s_pred[:, -class_channels:, :, :]  # [B, 80, H, W]
    t_logits = t_pred[:, -class_channels:, :, :]  # [B, 80, H, W]

    # --- Teacher probabilities with temperature (no gradient to the teacher) ---
    with torch.no_grad():
        t_probs = torch.sigmoid(t_logits / T)  # [B, 80, H, W]

    # --- Compute distillation loss ---
    # The student logits are softened with the same temperature as the teacher
    # so both sides are compared at the same scale (no change when T == 1).
    bce = F.binary_cross_entropy_with_logits(
        s_logits / T, t_probs, reduction='none'
    )  # [B, 80, H, W]

    if cls_fg_mask:
        # Apply foreground mask: mask[i] is [B, 1, H, W] and broadcasts
        # over the 80 class channels.
        fg_mask = mask[i]
        masked_bce = bce * fg_mask  # [B, 80, H, W]
        # Number of contributing elements = foreground locations * classes;
        # the epsilon avoids division by zero when nothing is foreground.
        active_elems = fg_mask.sum() * class_channels + 1e-6

        loss_ = masked_bce.sum() / active_elems
    else:
        # No mask: average over all elements
        loss_ = bce.mean()

    # Standard T^2 rescale so the loss magnitude does not depend on T
    # (identity when T == 1).
    loss_ = loss_ * (T ** 2)
    print(loss_)

    # Weighted accumulation across prediction levels
    distill_cls_loss += per_scale_weights[i] * loss_

# Final weight (previously `alpha` was defined but never applied)
distill_cls_loss = alpha * distill_cls_loss


tensor(0.8063)
tensor(0.8056)
tensor(0.8046)


In [39]:
# ----- DFL distillation -----
# Set to False to average the KL term over every location instead of only
# the foreground locations given by `mask`.
dfl_fg_mask = True

reg_max = 16      # number of distribution bins per box side
T = 1.0           # distillation temperature
lambda_d = 0.5    # final weight of the DFL distillation loss
dfl_channels = 4 * reg_max  # DFL logits are the first 64 channels

dfl_distill_loss = 0.0
count = 0

for scale_idx in range(len(student_preds)):
    sp = student_preds[scale_idx]      # [B, Ctot, H, W]
    tp = teacher_preds[scale_idx]

    B, _, H, W = sp.shape

    # Extract DFL logits (first 64 channels)
    sp_dfl = sp[:, :dfl_channels, :, :]   # [B, 64, H, W]
    tp_dfl = tp[:, :dfl_channels, :, :]   # [B, 64, H, W]

    # Reshape to [B, 4, reg_max, H, W]
    sp_reshaped = sp_dfl.reshape(B, 4, reg_max, H, W)
    tp_reshaped = tp_dfl.reshape(B, 4, reg_max, H, W)
    print(sp_reshaped.shape, tp_reshaped.shape, sep= "\n")

    # Teacher distribution over the reg_max bins (no gradient to the teacher).
    # log_softmax gives the log-probabilities directly, avoiding the bias and
    # precision loss of log(softmax(x) + eps).
    with torch.no_grad():
        tp_logsoft = torch.log_softmax(tp_reshaped / T, dim=2)  # [B, 4, reg_max, H, W]
        tp_soft = tp_logsoft.exp()                              # [B, 4, reg_max, H, W]
    sp_logsoft = torch.log_softmax(sp_reshaped / T, dim=2)      # [B, 4, reg_max, H, W]

    # KL(teacher || student), summed over reg_max -> [B, 4, H, W]
    kl_per_pixel = torch.sum(tp_soft * (tp_logsoft - sp_logsoft), dim=2)

    # Reduce over the 4 coordinates (mean; .sum(dim=1) is the alternative)
    kl_spatial = kl_per_pixel.mean(dim=1)  # [B, H, W]

    # --- Apply foreground mask if enabled ---
    if dfl_fg_mask:
        fg_mask = mask[scale_idx].squeeze(1)  # [B, 1, H, W] -> [B, H, W]
        masked_kl = kl_spatial * fg_mask
        # Normalize by the number of foreground locations; the epsilon avoids
        # division by zero when nothing is foreground.
        active = fg_mask.sum() + 1e-6
        loss_kl = masked_kl.sum() / active
    else:
        # Mean over all pixels and batch
        loss_kl = kl_spatial.mean()

    # Scale by T^2 (standard in distillation)
    loss_kl = loss_kl * (T ** 2)

    dfl_distill_loss += loss_kl
    count += 1

# Average across scales
if count > 0:
    dfl_distill_loss = dfl_distill_loss / count

# Final weight
dfl_distill_loss = lambda_d * dfl_distill_loss


torch.Size([1, 4, 16, 80, 80])
torch.Size([1, 4, 16, 80, 80])
torch.Size([1, 4, 16, 40, 40])
torch.Size([1, 4, 16, 40, 40])
torch.Size([1, 4, 16, 20, 20])
torch.Size([1, 4, 16, 20, 20])


In [45]:
# ----- L2 box regression distillation -----
# Set to False to average the squared error over every location instead of
# only the foreground locations given by `mask`.
box_fg_mask = True

reg_max = 16        # number of distribution bins per box side
λ_box_reg = 1.0     # final weight of the box regression loss
dfl_channels = 4 * reg_max  # DFL logits are the first 64 channels

box_reg_loss = 0.0
count = 0

# Bin indices 0..reg_max-1, shaped [1, 1, reg_max, 1, 1] so they broadcast
# against [B, 4, reg_max, H, W] probabilities.
bins = torch.arange(reg_max, dtype=torch.float32).view(1, 1, reg_max, 1, 1)
print(bins)

for scale_idx in range(len(student_preds)):
    sp = student_preds[scale_idx]      # [B, Ctot, H, W]
    tp = teacher_preds[scale_idx]

    B, _, H, W = sp.shape

    # Extract DFL logits (first 64 channels)
    sp_dfl = sp[:, :dfl_channels, :, :]   # [B, 64, H, W]
    tp_dfl = tp[:, :dfl_channels, :, :]   # [B, 64, H, W]

    # Reshape to [B, 4, reg_max, H, W]
    sp_reshaped = sp_dfl.reshape(B, 4, reg_max, H, W)
    tp_reshaped = tp_dfl.reshape(B, 4, reg_max, H, W)

    # Keep the bin indices on the same device as the predictions
    level_bins = bins.to(device=sp.device)

    # Student: probabilities over bins -> expected bin index, divided by reg_max
    sp_prob = torch.softmax(sp_reshaped, dim=2)               # [B, 4, reg_max, H, W]
    sp_val = torch.sum(sp_prob * level_bins, dim=2) / reg_max  # [B, 4, H, W]

    # Teacher: same computation, but without tracking gradients so the
    # teacher is treated as a fixed target (as in the other distillation cells).
    with torch.no_grad():
        tp_prob = torch.softmax(tp_reshaped, dim=2)               # [B, 4, reg_max, H, W]
        tp_val = torch.sum(tp_prob * level_bins, dim=2) / reg_max  # [B, 4, H, W]

    # Squared error per coordinate -> [B, 4, H, W]
    sq_error = (sp_val - tp_val) ** 2

    # Reduce over the 4 box sides (mean)
    l2_spatial = sq_error.mean(dim=1)  # [B, H, W]

    # --- Apply foreground mask if enabled ---
    if box_fg_mask:
        fg_mask = mask[scale_idx].squeeze(1)  # [B, 1, H, W] -> [B, H, W]
        masked_l2 = l2_spatial * fg_mask
        # The epsilon avoids division by zero when nothing is foreground.
        active = fg_mask.sum() + 1e-6
        loss_reg = masked_l2.sum() / active
    else:
        loss_reg = l2_spatial.mean()

    box_reg_loss += loss_reg
    count += 1

# Average across scales
if count > 0:
    box_reg_loss = box_reg_loss / count

# Apply final weight
box_reg_loss = λ_box_reg * box_reg_loss

tensor([[[[[ 0.]],

          [[ 1.]],

          [[ 2.]],

          [[ 3.]],

          [[ 4.]],

          [[ 5.]],

          [[ 6.]],

          [[ 7.]],

          [[ 8.]],

          [[ 9.]],

          [[10.]],

          [[11.]],

          [[12.]],

          [[13.]],

          [[14.]],

          [[15.]]]]])
