In [None]:
import pickle
import torch
file = 'results/attn_maps/seed-4/blur_regions-begin___mid)/seg_applied_layers-down___mid/seg_blur_sigma-10_seg_scale-3_guidance_scale-0.pkl'

with open(file, 'rb') as f:
    data = pickle.load(f)
    
# Compute total size in bytes
total_size_bytes = sum(arr.nbytes for arr in data['up'][0])

# Convert to MB
total_size_mb = total_size_bytes / (1024 * 1024)

print(f"Total size of data['up'][0]: {total_size_bytes} bytes ({total_size_mb:.2f} MB)")
data['up'][0].dtype, data['up'][30].shape
data['up'][0].shape

Total size of data['up'][0]: 67108864 bytes (64.00 MB)


(dtype('uint16'), (2, 1024, 1024))

In [3]:
import torch
import numpy as np

def reduce_precision(attention_map: torch.Tensor) -> np.ndarray:
    """
    Reduces the precision of the input tensor to unsigned 16-bit integer (uint16) and converts it to a NumPy array.

    Args:
        attention_map (Tensor): Input tensor of shape [B, S, S], expected to be in range [0, 1].

    Returns:
        np.ndarray: Reduced-precision NumPy array of shape [B, S, S] in uint16 format.
    """
    if torch.any(torch.isnan(attention_map)):
        print("Warning: Attention map contains NaN values. Clipping to [0, 1] before scaling.")
        attention_map = attention_map.clamp(0, 1)

    # Scale and convert to uint16
    scaled_attn_maps = (attention_map * 10000).clamp(0, 65535).to(torch.uint16)

    # Convert to NumPy and return
    return scaled_attn_maps.cpu().numpy().astype(np.uint16)


# Create a sample tensor in range [0,1] with shape (2, 1024, 1024)
torch.manual_seed(42)  # For reproducibility
sample_tensor = torch.rand((2, 1024, 1024), dtype=torch.float32)  # Values in [0,1]

# Apply the function
compressed_numpy = reduce_precision(sample_tensor)

# Check memory size
memory_size_mb = compressed_numpy.nbytes / (1024 * 1024)  # Convert bytes to MB

# Verify range and dtype
sample_min, sample_max = compressed_numpy.min(), compressed_numpy.max()

sample_tensor.shape, compressed_numpy.shape, compressed_numpy.dtype, memory_size_mb, (sample_min, sample_max)


(torch.Size([2, 1024, 1024]),
 (2, 1024, 1024),
 dtype('uint16'),
 4.0,
 (np.uint16(0), np.uint16(9999)))