In [3]:
import torch
import psutil
import os
import sys
sys.path.append(os.path.abspath(".."))
import torch
from models.change_classifier import ChangeClassifier

In [1]:
import torch
from collections import defaultdict

def precise_memory_calculation(batch_size=24, image_size=256):
    """
    PRECISE memory calculation based on actual layer dimensions
    """
    print("=== PRECISE MEMORY CALCULATION ===")
    
    # Based on your state_dict, here are the critical layer dimensions:
    # This shows the feature map progression through the network
    
    # Feature map dimensions at each stage (estimated from layer shapes)
    feature_map_progression = [
        # (height, width, channels, layer_type)
        (256, 256, 3, "input"),
        (128, 128, 48, "backbone.0"),      # After first conv
        (64, 64, 24, "backbone.1"),        # After MBConv blocks
        (32, 32, 32, "backbone.2"),        # More MBConv blocks  
        (16, 16, 56, "backbone.3"),        # Final backbone output
        (16, 16, 56, "mixing_mask.2"),     # After mixing
        (32, 32, 64, "up.0"),              # After first upsample
        (64, 64, 64, "up.1"),              # After second upsample
        (128, 128, 32, "up.2"),            # After third upsample
        (256, 256, 1, "output")            # Final output
    ]
    
    bytes_per_float = 4
    total_activation_memory = 0
    
    print("--- Feature Map Memory Breakdown ---")
    for i, (h, w, c, layer_name) in enumerate(feature_map_progression):
        layer_memory = batch_size * h * w * c * bytes_per_float / (1024 * 1024)
        total_activation_memory += layer_memory
        print(f"{layer_name}: {h}x{w}x{c} = {layer_memory:.1f} MB")
    
    # Parameter memory (from your state_dict - I can count them)
    total_params = 285803  # From your earlier output
    param_memory_mb = total_params * bytes_per_float / (1024 * 1024)
    
    # Input data memory
    input_memory_mb = batch_size * (2 * 256 * 256 * 3 + 256 * 256 * 1) * bytes_per_float / (1024 * 1024)
    
    # Gradients and optimizer
    gradients_memory_mb = param_memory_mb
    optimizer_memory_mb = 2 * param_memory_mb
    
    # PyTorch overhead (CUDA context, fragmentation, etc.)
    pytorch_overhead = 3.0  # Higher factor for complex models
    
    total_memory_mb = (
        param_memory_mb + 
        input_memory_mb + 
        total_activation_memory + 
        gradients_memory_mb + 
        optimizer_memory_mb
    ) * pytorch_overhead
    
    print(f"\n--- PRECISE Memory Breakdown ---")
    print(f"Parameters: {param_memory_mb:.1f} MB")
    print(f"Input data: {input_memory_mb:.1f} MB")
    print(f"Activations (TOTAL): {total_activation_memory:.1f} MB")
    print(f"Gradients: {gradients_memory_mb:.1f} MB")
    print(f"Optimizer: {optimizer_memory_mb:.1f} MB")
    print(f"PyTorch overhead (x{pytorch_overhead}): {total_memory_mb:.1f} MB")
    print(f"TOTAL: {total_memory_mb:.1f} MB ({total_memory_mb/1024:.1f} GB)")
    
    # Compare with your actual usage
    print(f"\n--- Comparison with Reality ---")
    print(f"Your actual usage: 13.6 GB")
    print(f"My calculation: {total_memory_mb/1024:.1f} GB")
    print(f"Difference: {abs(13.6 - total_memory_mb/1024):.1f} GB")
    
    return total_memory_mb

# Run the precise calculation
if __name__ == "__main__":
    precise_memory_calculation(batch_size=24, image_size=256)

=== PRECISE MEMORY CALCULATION ===
--- Feature Map Memory Breakdown ---
input: 256x256x3 = 18.0 MB
backbone.0: 128x128x48 = 72.0 MB
backbone.1: 64x64x24 = 9.0 MB
backbone.2: 32x32x32 = 3.0 MB
backbone.3: 16x16x56 = 1.3 MB
mixing_mask.2: 16x16x56 = 1.3 MB
up.0: 32x32x64 = 6.0 MB
up.1: 64x64x64 = 24.0 MB
up.2: 128x128x32 = 48.0 MB
output: 256x256x1 = 6.0 MB

--- PRECISE Memory Breakdown ---
Parameters: 1.1 MB
Input data: 42.0 MB
Activations (TOTAL): 188.6 MB
Gradients: 1.1 MB
Optimizer: 2.2 MB
PyTorch overhead (x3.0): 705.0 MB
TOTAL: 705.0 MB (0.7 GB)

--- Comparison with Reality ---
Your actual usage: 13.6 GB
My calculation: 0.7 GB
Difference: 12.9 GB
