In [1]:
import torch
import torch.nn.functional as F

def generate_masks_from_teacher_tal(t_mask, teacher_preds, mask_type="both"):
    """
    Generate:
      1) Original TAL masks (hard binary per level)
      2) Mask Pyramid (multi-scale OR fusion mask)

    Args:
        t_mask: [B, 8400] boolean mask from TAL.
        teacher_preds: list of three prediction tensors:
            [B, C, 80, 80], [B, C, 40, 40], [B, C, 20, 20]
        mask_type: 
            "original" -> return only original masks
            "pyramid"  -> return only pyramid masks
            "both"     -> return both (default)

    Returns:
        Based on mask_type:
            original_masks: [B,1,80,80], [B,1,40,40], [B,1,20,20]
            pyramid_masks:  [B,1,80,80], [B,1,40,40], [B,1,20,20]
    """

    batch = t_mask.shape[0]

    # Extract spatial sizes dynamically
    spatial_dims = [(p.shape[2], p.shape[3]) for p in teacher_preds]

    # ---------------------------------------------------------
    # 1. ORIGINAL HARD (BINARY) MASKS
    # ---------------------------------------------------------
    original_masks = []
    start = 0

    for h, w in spatial_dims:
        N = h * w
        end = start + N

        # boolean -> float 0/1
        mask = t_mask[:, start:end].float()
        mask = mask.reshape(batch, 1, h, w)

        original_masks.append(mask)
        start = end

    m3, m4, m5 = original_masks  # 80x80, 40x40, 20x20

    # If user wants only original masks
    if mask_type == "original":
        return original_masks

    # ---------------------------------------------------------
    # 2. MASK PYRAMID (MULTI-SCALE OR)
    # ---------------------------------------------------------

    # Step A: Upsample 40→80 and 20→80
    m4_up = F.interpolate(m4, size=(80, 80), mode="nearest")
    m5_up = F.interpolate(m5, size=(80, 80), mode="nearest")

    # Ensure float domain (binary mix safety)
    m3f = m3.float()
    m4f = m4_up.float()
    m5f = m5_up.float()

    # Step B: OR fusion across scales
    # OR rule: if any mask has 1 → output must be 1
    pyramid_80 = torch.maximum(torch.maximum(m3f, m4f), m5f)
    # Equivalent to OR: pyramid_80 = (m3f > 0) | (m4f > 0) | (m5f > 0)

    # Step C: Downscale back to 40 and 20
    pyramid_40 = F.interpolate(pyramid_80, size=(40, 40), mode="nearest")
    pyramid_20 = F.interpolate(pyramid_80, size=(20, 20), mode="nearest")

    pyramid_masks = [pyramid_80, pyramid_40, pyramid_20]

    # If user wants only pyramid
    if mask_type == "pyramid":
        return pyramid_masks

    # Else return both
    return original_masks, pyramid_masks

In [2]:
t_mask = t_mask = torch.rand(1, 8400) > 0.5 
teacher_preds = [
    torch.randn(1, 144, 80, 80, dtype=torch.float32),
    torch.randn(1, 144, 40, 40, dtype=torch.float32),
    torch.randn(1, 144, 20, 20, dtype=torch.float32)
]
student_preds = [
    torch.randn(1, 144, 80, 80, dtype=torch.float32),
    torch.randn(1, 144, 40, 40, dtype=torch.float32),
    torch.randn(1, 144, 20, 20, dtype=torch.float32)
]

In [3]:
org, pyr = generate_masks_from_teacher_tal(t_mask, teacher_preds, mask_type="both")

In [4]:
# import matplotlib.pyplot as plt

# def visualize_masks(original_masks, pyramid_masks, batch_idx=0):
#     """
#     Plots:
#       - Original masks: 80×80, 40×40, 20×20
#       - Pyramid masks: 80×80, 40×40, 20×20

#     Args:
#         original_masks: list of 3 tensors [B,1,H,W]
#         pyramid_masks:  list of 3 tensors [B,1,H,W]
#         batch_idx: which sample to visualize
#     """

#     titles = [
#         "Original Mask 80×80", "Original Mask 40×40", "Original Mask 20×20",
#         "Pyramid Mask 80×80", "Pyramid Mask 40×40", "Pyramid Mask 20×20"
#     ]

#     masks_to_show = [
#         original_masks[0][batch_idx, 0].cpu().numpy(),
#         original_masks[1][batch_idx, 0].cpu().numpy(),
#         original_masks[2][batch_idx, 0].cpu().numpy(),
#         pyramid_masks[0][batch_idx, 0].cpu().numpy(),
#         pyramid_masks[1][batch_idx, 0].cpu().numpy(),
#         pyramid_masks[2][batch_idx, 0].cpu().numpy(),
#     ]

#     plt.figure(figsize=(12, 8))

#     for i, (mask, title) in enumerate(zip(masks_to_show, titles)):
#         plt.subplot(2, 3, i + 1)
#         plt.imshow(mask, cmap="gray", interpolation="nearest")
#         plt.title(title)
#         plt.axis("off")

#     plt.tight_layout()
#     plt.show()

In [5]:
# visualize_masks(org, pyr, batch_idx=0)

In [6]:
# import cv2
# import numpy as np

# def visualize_tensor_values_opencv(tensor, window_name="Tensor Values", scale=40):
#     """
#     Visualize a 2D tensor with OpenCV, showing each cell value as text.
#     - tensor: torch.Tensor or numpy array, shape [H, W]
#     - scale: pixel size for each cell
#     """

#     # Convert PyTorch → numpy
#     if hasattr(tensor, "cpu"):
#         tensor = tensor.cpu().numpy()

#     H, W = tensor.shape

#     # Create blank canvas
#     img = np.zeros((H * scale, W * scale, 3), dtype=np.uint8)

#     # Normalize tensor for coloring
#     t_min, t_max = tensor.min(), tensor.max()
#     if t_max - t_min == 0:
#         norm = np.zeros_like(tensor)
#     else:
#         norm = (tensor - t_min) / (t_max - t_min)

#     # Draw each cell
#     for i in range(H):
#         for j in range(W):
#             # Pick color based on value
#             intensity = int(norm[i, j] * 255)
#             color = (intensity, intensity, 255)  # slight red tint

#             # Draw rectangle
#             cv2.rectangle(
#                 img,
#                 (j * scale, i * scale),
#                 ((j + 1) * scale, (i + 1) * scale),
#                 color,
#                 thickness=-1
#             )

#             # Put value text
#             val_str = f"{tensor[i,j]:.2f}" if tensor.dtype != bool else str(int(tensor[i,j]))
#             cv2.putText(
#                 img,
#                 val_str,
#                 (j * scale + 5, i * scale + scale - 5),
#                 cv2.FONT_HERSHEY_SIMPLEX,
#                 0.5,
#                 (0, 0, 0),   # black text
#                 1,
#                 cv2.LINE_AA
#             )

#     cv2.imshow(window_name, img)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()


In [8]:
for mask in pyr:
    has_non_binary = torch.any((mask != 0) & (mask != 1))
    print("Mask contains non-binary values:", has_non_binary.item())

Mask contains non-binary values: False
Mask contains non-binary values: False
Mask contains non-binary values: False


In [18]:
# First install: pip install rich numpy

from rich.console import Console
from rich.table import Table
from rich.text import Text
import numpy as np

def plot_binary_tensors_rich(tensors):
    """
    Plot binary tensors using the Rich library for beautiful terminal output.
    """
    console = Console()
    
    for i, tensor in enumerate(tensors):
        # Create a table for each tensor
        table = Table(
            title=f"Binary Tensor {i} (Shape: {tensor.shape})",
            show_header=False,
            box=None,
            padding=0
        )
        
        # Add columns
        for _ in range(tensor.shape[1]):
            table.add_column(style="bold", justify="center")
        
        # Add rows
        for row in tensor:
            row_cells = []
            for val in row:
                val_int = int(val)
                if val_int == 1:
                    # Green background for 1s
                    cell = Text(" 1 ", style="black on green")
                else:
                    # Red background for 0s
                    cell = Text(" 0 ", style="black on red")
                row_cells.append(cell)
            table.add_row(*row_cells)
        
        console.print(table)
        console.print()  # Empty line between tensors

# Example usage
if __name__ == "__main__":
    # Create sample tensors
    tensors = org[0].squeeze()
    
    plot_binary_tensors_rich(tensors)

IndexError: tuple index out of range

In [19]:
org[0].shape

torch.Size([1, 1, 80, 80])

In [22]:
def print_tensor_ascii(tensor, threshold=0.5):
    """
    Print tensor as ASCII art with '█' for 1 and ' ' for 0
    """
    # Assuming tensor shape is [1, 1, H, W]
    if len(tensor.shape) == 4:
        data = tensor[0, 0]  # Extract the 2D data
    else:
        data = tensor
    
    for i in range(data.shape[0]):
        row = ''
        for j in range(data.shape[1]):
            if data[i, j] > threshold:
                row += '██'
            else:
                row += '  '
        print(row)

# Example usage
tensor_80x80 = pyr[0]  # Your 1x1x80x80 tensor
print("80x80 Tensor:")
print_tensor_ascii(tensor_80x80)
print("\n" + "="*50 + "\n")

80x80 Tensor:
██████████████████████████      ██████████████████████████████████████  ██████████  ████████████  ████████████████████████████  ████████    ██  ████████████████
██████████████████████████  ██████████████████████████████████████████  ████████████████████████████████████████████████        ████████      ██████████████████
████████████████████████████  ██████████████████  ██████████████████████████████████████████████████████████████      ██████████████████████████████████████████
████████████████████████████    ██████████████████  ██  ██████████████  ████████    ██████████████████████████████      ████████████████████████████████  ██████
████████  ██████████████████████████████████████████████████████    ████  ██████      ██████████████  ██████████████████  ██████████████    ██████████████  ████
████████████  ████  ████████████████████████████████████████████    ████████████      ██████████████  ██████    ████████  ██████████████  ██████████████████████
████████████  ██  ██