In [9]:
import os
import sys
# Make the project checkout importable (presumably where `pl_modules` lives — confirm).
# Guarded so that re-running this cell does not stack duplicate entries on sys.path.
repo_root = "/home/ubuntu/workspace/code/compositional-representation-learning"
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [10]:
import torchvision
from torchvision import transforms

from pl_modules.BoxEmbeddings.PatchBoxEmbeddingsVAE import VanillaVAE

In [12]:
# Directory handed to torchvision as the dataset root.
data_root_dir = "/home/ubuntu/workspace/data_root_dir"

# CIFAR-10 with train=False (the non-training split); every image goes through
# ToTensor. download=True asks torchvision to fetch the files when needed.
ds = torchvision.datasets.CIFAR10(
    data_root_dir,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [13]:
# Number of samples in the dataset split loaded above.
len(ds)

10000

In [8]:
# Show the first item of sample 0 (the image tensor produced by the ToTensor transform).
ds[0][0]

tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.

In [11]:
import torch
import torch.nn as nn

# Example configuration - you can modify these
latent_dim = 128
hidden_dims = [32, 64, 128, 256]  # Example hidden dimensions
# NOTE(review): this list object is handed to the constructor as-is; if VanillaVAE
# mutates it (e.g. reverses it), later reads of hidden_dims see the changed order — confirm.

# Build the model and switch it to evaluation mode for inspection
model = VanillaVAE(hidden_dims=hidden_dims, latent_dim=latent_dim)
model.eval()

# Take the first CIFAR sample, discard its second item, and add a batch axis
first_image, _ = ds[0]
sample_image = first_image.unsqueeze(0)  # Add batch dimension: (1, 3, 32, 32)

# Banner announcing the shape of the tensor fed to the model
print(
    "=" * 80,
    "INPUT SHAPE:",
    f"  Input image: {sample_image.shape}",
    "=" * 80,
    sep="\n",
)

# Registry filled in by the hooks: "<name>_input" / "<name>_output" -> recorded shape(s)
shapes_dict = {}

def get_shape_hook(name):
    """Build a forward hook that records input and output shapes under ``name``.

    The returned hook writes two entries into the module-level ``shapes_dict``:
    ``f"{name}_input"`` and ``f"{name}_output"``. Anything exposing a ``.shape``
    attribute is recorded by that shape; anything else is recorded as ``str(value)``.
    A tuple input, or a list/tuple output, is recorded element by element as a list.
    """
    def _summarise(value, container_types):
        # Containers are expanded one level; everything else is described directly.
        if isinstance(value, container_types):
            return [item.shape if hasattr(item, 'shape') else str(item) for item in value]
        return value.shape if hasattr(value, 'shape') else str(value)

    def hook(module, input, output):
        # Inputs are only expanded when they arrive as a tuple; outputs may be list or tuple.
        shapes_dict[f"{name}_input"] = _summarise(input, tuple)
        shapes_dict[f"{name}_output"] = _summarise(output, (list, tuple))
    return hook

# Attach a shape-recording hook to each sub-module of interest.

# Encoder stages, numbered by their position in model.encoder
for idx, stage in enumerate(model.encoder):
    stage.register_forward_hook(get_shape_hook(f"encoder_layer_{idx}"))

# Latent heads and the decoder's input projection, keyed by attribute name
for attr_name in ("fc_mu_min", "fc_mu_max", "decoder_input"):
    getattr(model, attr_name).register_forward_hook(get_shape_hook(attr_name))

# Decoder stages, numbered by their position in model.decoder
for idx, stage in enumerate(model.decoder):
    stage.register_forward_hook(get_shape_hook(f"decoder_layer_{idx}"))

# Components of the final layer, numbered by their position in model.final_layer
for idx, stage in enumerate(model.final_layer):
    stage.register_forward_hook(get_shape_hook(f"final_layer_{idx}"))

# Forward pass, traced stage by stage so every intermediate shape can be printed.
# Runs under no_grad because this is shape inspection only.
print("\nENCODER:")
print("-" * 80)
with torch.no_grad():
    # Encode: the model hands back the two latent vectors as a pair
    encoded = model.encode(sample_image)
    mu_min, mu_max = encoded

    # Raw encoder feature map, computed only for the shape report
    encoder_features = model.encoder(sample_image)
    print(f"  After encoder (flattened): {encoder_features.flatten(start_dim=1).shape}")
    print(f"  mu_min: {mu_min.shape}")
    print(f"  mu_max: {mu_max.shape}")

    # Concatenate both latent vectors along the last axis (built once, reused below)
    z = torch.cat([mu_min, mu_max], dim=-1)
    print(f"  z (concatenated): {z.shape}")

    # Decode
    print("\nDECODER:")
    print("-" * 80)

    decoder_input = model.decoder_input(z)
    print(f"  After decoder_input: {decoder_input.shape}")

    # Reshape for decoder: keep the batch axis explicit and let view() infer the
    # channel count for the 2x2 starting grid. The previous code read the channel
    # count from hidden_dims[0], which is only correct if the list's order differs
    # from how it is written in the config (the recorded run shows 256 channels here,
    # suggesting the constructor reorders the list — confirm). Inferring it from the
    # tensor gives the same shape in the recorded run and does not depend on list order.
    decoder_input_reshaped = decoder_input.view(decoder_input.size(0), -1, 2, 2)
    print(f"  After reshape for decoder: {decoder_input_reshaped.shape}")

    decoder_output = model.decoder(decoder_input_reshaped)
    print(f"  After decoder: {decoder_output.shape}")

    final_output = model.final_layer(decoder_output)
    print(f"  Final output: {final_output.shape}")

# Section banner for the per-layer dump
print("\n" + "=" * 80, "DETAILED LAYER-BY-LAYER SHAPES:", "=" * 80, sep="\n")

# Print all captured shapes, ordered alphabetically by hook key
for hook_key, recorded_shape in sorted(shapes_dict.items()):
    print(f"\n{hook_key}:", f"  {recorded_shape}", sep="\n")

# Run the model end to end and report the shape of each of its four outputs
print("\n" + "=" * 80, "FULL FORWARD PASS RESULT:", "=" * 80, sep="\n")
with torch.no_grad():
    result = model(sample_image)
    # The forward call is unpacked as (reconstruction, original input, mu_min, mu_max)
    reconstructed, original, mu_min, mu_max = result
    print(
        f"  Reconstructed: {reconstructed.shape}",
        f"  Original: {original.shape}",
        f"  mu_min: {mu_min.shape}",
        f"  mu_max: {mu_max.shape}",
        sep="\n",
    )

# Calculate spatial dimensions at each stage
# Sizes are measured by pushing the sample through the layers one at a time instead
# of being derived from hidden_dims. The earlier arithmetic stopped after the decoder
# stack, so "Final output" reported the pre-final-layer size (16x16 by that arithmetic,
# while the recorded final tensor is 32x32), and its per-layer channel labels and
# flattened size depended on the current order of hidden_dims.
# Re-running the layers fires the registered shape hooks again with identical shapes.
print("\n" + "=" * 80)
print("SPATIAL DIMENSION PROGRESSION:")
print("=" * 80)
print(f"  Input: {sample_image.shape[2]}x{sample_image.shape[3]}")

with torch.no_grad():
    # Encoder: walk the stack and report channels and spatial size after each stage
    print(f"  After encoder layers (stride=2 each):")
    stage_output = sample_image
    for i, layer in enumerate(model.encoder):
        stage_output = layer(stage_output)
        print(f"    Layer {i} (hidden_dim={stage_output.shape[1]}): {stage_output.shape[2]}x{stage_output.shape[3]}")
    current_h, current_w = stage_output.shape[2], stage_output.shape[3]

    print(f"  Flattened size: {stage_output.flatten(start_dim=1).shape[1]}")
    print(f"  Latent dim: {latent_dim} (each of mu_min and mu_max)")
    print(f"  Concatenated z: {latent_dim * 2}")

    # Decoder: start from the reshaped decoder input produced in the traced forward pass
    print(f"  After decoder layers (stride=2 each):")
    stage_output = decoder_input_reshaped
    print(f"    Initial decoder input: {stage_output.shape[2]}x{stage_output.shape[3]}")
    for i, layer in enumerate(model.decoder):
        stage_output = layer(stage_output)
        print(f"    After decoder layer {i}: {stage_output.shape[2]}x{stage_output.shape[3]}")

    # The final layer changes the resolution once more, so include it before reporting
    stage_output = model.final_layer(stage_output)
    print(f"    After final layer: {stage_output.shape[2]}x{stage_output.shape[3]}")
    decoder_h, decoder_w = stage_output.shape[2], stage_output.shape[3]

print(f"  Final output: {decoder_h}x{decoder_w}")


INPUT SHAPE:
  Input image: torch.Size([1, 3, 32, 32])

ENCODER:
--------------------------------------------------------------------------------
  After encoder (flattened): torch.Size([1, 1024])
  mu_min: torch.Size([1, 128])
  mu_max: torch.Size([1, 128])
  z (concatenated): torch.Size([1, 256])

DECODER:
--------------------------------------------------------------------------------
  After decoder_input: torch.Size([1, 1024])
  After reshape for decoder: torch.Size([1, 256, 2, 2])
  After decoder: torch.Size([1, 32, 16, 16])
  Final output: torch.Size([1, 3, 32, 32])

DETAILED LAYER-BY-LAYER SHAPES:

decoder_input_input:
  [torch.Size([1, 256])]

decoder_input_output:
  torch.Size([1, 1024])

decoder_layer_0_input:
  [torch.Size([1, 256, 2, 2])]

decoder_layer_0_output:
  torch.Size([1, 128, 4, 4])

decoder_layer_1_input:
  [torch.Size([1, 128, 4, 4])]

decoder_layer_1_output:
  torch.Size([1, 64, 8, 8])

decoder_layer_2_input:
  [torch.Size([1, 64, 8, 8])]

decoder_layer_2_outpu