In [2]:
# %% [markdown]
# # Linear Probing Experiment: Box Embeddings vs Vanilla VAE
# 
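# This notebook loads two pre-trained `PatchBoxEmbeddingsVAE` models and runs a linear probing experiment on CIFAR10. One model was trained with the box-embedding losses (inclusion loss and minimum-side-length regularization). The other is a "vanilla" baseline with those two losses set to zero.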
#
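# **Goal:** Evaluate whether the box-embedding constraints yield more linearly separable representations than a standard VAE. Both probes are trained on the flattened `z` tensor, i.e. every patch embedding plus the full-image embedding.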

# %%
import copy
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

# --- USER SETUP ---
# Paths default to the original workstation layout; override them with the
# PROJECT_ROOT / DATA_ROOT_DIR environment variables to run elsewhere.
PROJECT_ROOT = os.environ.get(
    "PROJECT_ROOT", "/home/ubuntu/workspace/code/compositional-representation-learning"
)
# Prepend (rather than append) so the project's own `pl_modules` / `datasets`
# packages take precedence over same-named installed packages (e.g. Hugging Face
# `datasets`), and skip it when already present so re-running the cell does not
# grow sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Set Data Root for the Dataset class (a DATA_ROOT_DIR already in the environment wins).
os.environ.setdefault("DATA_ROOT_DIR", "/home/ubuntu/workspace/data")

# Import project modules (must come after the sys.path update above)
from pl_modules import PatchBoxEmbeddingsVAE
from datasets import get_dataset

# Seed torch's global RNG so probe initialisation and probe-batch shuffling are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

# %% [markdown]
# ## 1. Configuration and Checkpoints
# 
# The configurations are defined inline, mirroring the training configs of the two runs. Each one is paired with the epoch-99 checkpoint of its run.

# %%
def _make_patch_box_vae_config(inclusion_loss, min_side_regularization_loss):
    """
    Build a fresh, fully independent PatchBoxEmbeddingsVAE experiment config.

    The box-embedding run and the vanilla run share the architecture, the
    CIFAR10 data section and the optimizer. They differ only in the weights of
    the inclusion loss and the minimum-side-length regularizer, which are
    therefore the only parameters here. A new dict literal is built on every
    call, so the two configs share no nested objects.
    """
    return {
        "model": {
            "type": "PatchBoxEmbeddingsVAE",
            "config": {
                "embed_dim": 64,
                "hidden_dims": [32, 64, 128, 256],
                "grid_size": [4, 4],
                "gumbel_temp": 1.0,
                "min_side_length": 0.1,
                "crop_objects": False,
                "loss_weights": {
                    "reconstruction_loss": 10.0,
                    "inclusion_loss": inclusion_loss,
                    "box_volume_regularization_loss": 0.0,
                    "min_side_regularization_loss": min_side_regularization_loss,
                    "lpips_loss": 0.0,
                    "ssim_loss": 0.5,
                    "full_image_weight": 10.0,
                },
            },
        },
        "data": {
            "train": {
                "type": "CIFAR10Dataset",
                "config": {
                    "image_size": [32, 32],
                    "train": True,
                },
                "dataloader_config": {
                    "batch_size": 64,
                    "shuffle": True,  # Shuffle for probe training
                },
            },
        },
        "trainer": {
            "optimizer": {"type": "Adam", "config": {"lr": 0.0001}},
        },
    }


# --- Experiment 1: Box Embeddings (Inclusion Loss) ---
box_checkpoint_path = "/home/ubuntu/workspace/experiment_root_dir/PatchBoxEmbeddingsVAE/PatchBoxEmbeddingsVAE_CIFAR10_exp_1/PatchBoxEmbeddingsVAE_CIFAR10_exp_1___2025-11-19__10-16-25/checkpoints/model-epoch=099.ckpt"
box_config = _make_patch_box_vae_config(
    inclusion_loss=0.5,
    min_side_regularization_loss=0.5,
)

# --- Experiment 2: Vanilla VAE (box losses switched off) ---
vanilla_checkpoint_path = "/home/ubuntu/workspace/experiment_root_dir/PatchBoxEmbeddingsVAE/PatchBoxEmbeddingsVAE_CIFAR10_vanilla_exp_0/PatchBoxEmbeddingsVAE_CIFAR10_vanilla_exp_0___2025-11-19__11-24-11/checkpoints/model-epoch=099.ckpt"
vanilla_config = _make_patch_box_vae_config(
    inclusion_loss=0.0,  # Zero inclusion
    min_side_regularization_loss=0.0,
)

# %% [markdown]
# ## 2. Helper Functions
# 
# We need functions to:
# 1. Load the VAE model.
# 2. Extract features (dataset -> frozen VAE -> tensors).
# 3. Train the Linear Probe.

# %%
def load_model(checkpoint_path, config):
    """
    Restore a ``PatchBoxEmbeddingsVAE`` from a Lightning checkpoint for inference.

    ``config`` is forwarded to ``load_from_checkpoint`` as the ``config`` keyword.
    The returned module has been switched to eval mode and moved to the
    notebook-level ``device``.
    """
    print(f"Loading model from {checkpoint_path}...")
    # Module.eval() and Module.to() both return the module itself, so the
    # restore / eval / device-move steps can be chained.
    return (
        PatchBoxEmbeddingsVAE
        .load_from_checkpoint(checkpoint_path, config=config)
        .eval()
        .to(device)
    )

def extract_features(model, dataloader, include_patch_embeddings=False):
    """
    Run a frozen VAE over ``dataloader`` and collect its ``"z"`` embeddings.

    ``outputs["z"]`` has shape (batch_size, num_patches + 1, box_embed_dim).
    Index -1 along dim 1 is the Full Image embedding.

    Args:
        model: the VAE. It is called as ``model(batch)`` and must return a dict
            containing ``"z"``. Gradients are disabled here, but putting the
            model in eval mode is the caller's job (``load_model`` does it).
        dataloader: yields dicts with ``"images"`` and ``"object_masks"``
            tensors and labels under ``batch["metadata"]["label"]``.
        include_patch_embeddings: if False (default), keep only the Full Image
            embedding ``z[:, -1, :]``, giving D = box_embed_dim. If True,
            flatten all num_patches + 1 embeddings of each sample into one
            vector, giving D = (num_patches + 1) * box_embed_dim.

    Returns:
        (features, labels): ``features`` is a CPU tensor of shape (N, D).
        ``labels`` is the concatenation of the per-batch label tensors.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    features = []
    labels = []

    print("Extracting features...")
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Move inputs to device without mutating the dict the loader yielded.
            inputs = {
                **batch,
                "images": batch["images"].to(device),
                "object_masks": batch["object_masks"].to(device),
            }

            z = model(inputs)["z"]
            if include_patch_embeddings:
                # (B, num_patches + 1, box_embed_dim) -> (B, (num_patches + 1) * box_embed_dim)
                embeddings = z.flatten(start_dim=1)
            else:
                # Full Image embedding only
                embeddings = z[:, -1, :]

            features.append(embeddings.cpu())
            labels.append(batch["metadata"]["label"])

    if not features:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError("dataloader yielded no batches; cannot extract features")

    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)

    print(f"Extracted Feature Shape: {features.shape}")
    print(f"Extracted Label Shape: {labels.shape}")

    return features, labels

class LinearProbe(nn.Module):
    """A single affine layer mapping frozen feature vectors to class logits."""

    def __init__(self, input_dim, num_classes=10):
        super(LinearProbe, self).__init__()
        # The attribute name `linear` fixes the state_dict keys
        # (`linear.weight`, `linear.bias`).
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised logits of shape (..., num_classes)."""
        logits = self.linear(x)
        return logits

def train_linear_probe(train_features, train_labels, test_features, test_labels, lr=3e-4, epochs=100,
                       standardize=False):
    """
    Fit a linear classifier on frozen features and report its test accuracy.

    The probe is trained with Adam and cross-entropy for ``epochs`` passes over
    the training features (shuffled mini-batches of 128). Test accuracy is
    computed after every epoch and printed every 5 epochs, for monitoring only.
    The returned value is the accuracy of the FINAL probe. Returning the best
    accuracy seen on the test set across epochs would choose the epoch using
    test data and bias the reported number upward.

    Args:
        train_features, test_features: (N, D) feature tensors.
        train_labels, test_labels: (N,) class-index tensors.
        lr: Adam learning rate.
        epochs: number of training epochs (must be >= 1).
        standardize: if True, z-score every feature dimension using the mean
            and std of the TRAINING features, then apply the same transform to
            the test features. This makes the probe less sensitive to feature
            scale. Off by default to preserve the original behaviour.

    Returns:
        Test accuracy in percent (0-100) of the probe after the last epoch.

    Raises:
        ValueError: if ``epochs < 1`` or either feature set is empty.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if len(train_features) == 0 or len(test_features) == 0:
        raise ValueError("train_features and test_features must both be non-empty")

    if standardize:
        # Statistics come from the training split only, so no test information
        # leaks into the probe's input transform.
        mean = train_features.mean(dim=0, keepdim=True)
        # unbiased=False keeps std finite for a single training sample.
        # The clamp avoids division by zero on constant (dead) feature dimensions.
        std = train_features.std(dim=0, unbiased=False, keepdim=True).clamp_min(1e-6)
        train_features = (train_features - mean) / std
        test_features = (test_features - mean) / std

    input_dim = train_features.shape[1]
    probe = LinearProbe(input_dim=input_dim).to(device)

    optimizer = optim.Adam(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Create simple dataloaders for the features
    train_ds = torch.utils.data.TensorDataset(train_features, train_labels)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)

    test_ds = torch.utils.data.TensorDataset(test_features, test_labels)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False)

    print(f"\nTraining Linear Probe (Input Dim: {input_dim})...")

    for epoch in range(epochs):
        probe.train()
        total_loss = 0.0
        for x, y in train_dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(probe(x), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluation (monitoring only: never used to pick an epoch)
        probe.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_dl:
                x, y = x.to(device), y.to(device)
                preds = probe(x).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        acc = 100 * correct / total

        # Always log the final epoch, even when `epochs` is not a multiple of 5.
        if (epoch + 1) % 5 == 0 or epoch + 1 == epochs:
            print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_dl):.4f} | Test Acc: {acc:.2f}%")

    return acc

# %% [markdown]
# ## 3. Prepare Datasets
# We need both Train (to train the linear probe) and Test (to evaluate it) sets for CIFAR10.

# %%
# Train Dataset (box_config and vanilla_config share the same data section,
# so either one can be used to build the loaders).
# No shuffling is needed here: features are extracted once in a fixed order, and
# the probe's own feature DataLoader does the shuffling.
train_dataset = get_dataset(box_config)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4)

# Test Dataset: deep-copy before flipping the split flag. A shallow
# `dict.copy()` shares the nested dicts, so the assignment below would also
# switch `box_config` to the test split. Re-running this cell would then
# silently build the "train" loader from the test set.
test_config_dict = copy.deepcopy(box_config)
test_config_dict["data"]["train"]["config"]["train"] = False  # Switch to test set
test_dataset = get_dataset(test_config_dict)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

assert box_config["data"]["train"]["config"]["train"] is True, "box_config must still point at the train split"

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# %% [markdown]
# ## 4. Evaluate Box Embeddings Model
# 1. Load Box Model
# 2. Extract Features
# 3. Train Probe

# %%
# Restore the box-regularised VAE from its checkpoint.
box_model = load_model(box_checkpoint_path, box_config)

# Embed both CIFAR10 splits, keeping every patch embedding plus the full-image one.
print("\n--- Extracting Features for BOX Model ---")
box_train_feats, box_train_labels = extract_features(
    model=box_model, dataloader=train_loader, include_patch_embeddings=True
)
box_test_feats, box_test_labels = extract_features(
    model=box_model, dataloader=test_loader, include_patch_embeddings=True
)

# The VAE is not needed once its features have been moved to the CPU:
# drop it and release cached GPU memory before fitting the probe.
del box_model
torch.cuda.empty_cache()

# Fit the linear probe on the extracted features.
print("\n--- Training Probe on BOX Features ---")
box_acc = train_linear_probe(
    train_features=box_train_feats,
    train_labels=box_train_labels,
    test_features=box_test_feats,
    test_labels=box_test_labels,
)
print(f"Final Box Embedding Accuracy: {box_acc:.2f}%")

# %% [markdown]
# ## 5. Evaluate Vanilla VAE Model
# 1. Load Vanilla Model
# 2. Extract Features
# 3. Train Probe

# %%
# Restore the vanilla VAE (box losses switched off) from its checkpoint.
vanilla_model = load_model(vanilla_checkpoint_path, vanilla_config)

# Embed both CIFAR10 splits exactly as for the box model.
print("\n--- Extracting Features for VANILLA Model ---")
van_train_feats, van_train_labels = extract_features(
    model=vanilla_model, dataloader=train_loader, include_patch_embeddings=True
)
van_test_feats, van_test_labels = extract_features(
    model=vanilla_model, dataloader=test_loader, include_patch_embeddings=True
)

# Drop the VAE and release cached GPU memory before fitting the probe.
del vanilla_model
torch.cuda.empty_cache()

# Fit the linear probe on the extracted features.
print("\n--- Training Probe on VANILLA Features ---")
van_acc = train_linear_probe(
    train_features=van_train_feats,
    train_labels=van_train_labels,
    test_features=van_test_feats,
    test_labels=van_test_labels,
)
print(f"Final Vanilla Embedding Accuracy: {van_acc:.2f}%")

# %% [markdown]
# ## 6. Results Comparison

# %%
# Side-by-side summary of the two probe accuracies.
# Caveat: each number comes from a single probe run (one seed, one split),
# so small deltas should not be over-interpreted.
print("=" * 40)
print("LINEAR PROBING RESULTS (CIFAR10)")
print("=" * 40)
print(f"{'Model Type':<20} | {'Test Accuracy':<15}")
print("-" * 38)
print(f"{'Box Embeddings':<20} | {box_acc:.2f}%")
print(f"{'Vanilla VAE':<20} | {van_acc:.2f}%")
print("-" * 38)

# The difference of two percentages is expressed in percentage points (pp).
diff = box_acc - van_acc
print(f"Delta (Box - Vanilla): {diff:+.2f} pp")
if diff > 0:
    print("Box Embeddings learned a better linearly separable representation.")
elif diff < 0:
    print("Vanilla VAE learned a better linearly separable representation.")
else:
    # Previously a tie fell into the "Vanilla better" branch.
    print("Both models reached the same linear-probe accuracy.")

Running on device: cuda


Train samples: 50000
Test samples: 10000
Loading model from /home/ubuntu/workspace/experiment_root_dir/PatchBoxEmbeddingsVAE/PatchBoxEmbeddingsVAE_CIFAR10_exp_1/PatchBoxEmbeddingsVAE_CIFAR10_exp_1___2025-11-19__10-16-25/checkpoints/model-epoch=099.ckpt...

--- Extracting Features for BOX Model ---
Extracting features...


100%|██████████| 391/391 [00:12<00:00, 30.55it/s]


Extracted Feature Shape: torch.Size([50000, 2176])
Extracted Label Shape: torch.Size([50000])
Extracting features...


100%|██████████| 79/79 [00:02<00:00, 27.88it/s]


Extracted Feature Shape: torch.Size([10000, 2176])
Extracted Label Shape: torch.Size([10000])

--- Training Probe on BOX Features ---

Training Linear Probe (Input Dim: 2176)...
Epoch 5/100 | Loss: 2.1040 | Test Acc: 28.95%
Epoch 10/100 | Loss: 2.0620 | Test Acc: 28.74%
Epoch 15/100 | Loss: 2.0609 | Test Acc: 25.94%
Epoch 20/100 | Loss: 2.0317 | Test Acc: 37.31%
Epoch 25/100 | Loss: 1.9830 | Test Acc: 33.34%
Epoch 30/100 | Loss: 2.0158 | Test Acc: 34.72%
Epoch 35/100 | Loss: 1.9885 | Test Acc: 32.10%
Epoch 40/100 | Loss: 1.9847 | Test Acc: 36.35%
Epoch 45/100 | Loss: 1.9772 | Test Acc: 35.80%
Epoch 50/100 | Loss: 2.0290 | Test Acc: 32.03%
Epoch 55/100 | Loss: 1.9221 | Test Acc: 38.22%
Epoch 60/100 | Loss: 1.9289 | Test Acc: 32.45%
Epoch 65/100 | Loss: 1.9570 | Test Acc: 31.66%
Epoch 70/100 | Loss: 1.9240 | Test Acc: 35.88%
Epoch 75/100 | Loss: 1.9276 | Test Acc: 37.51%
Epoch 80/100 | Loss: 1.9177 | Test Acc: 38.23%
Epoch 85/100 | Loss: 2.0109 | Test Acc: 38.54%
Epoch 90/100 | Loss: 1.9

100%|██████████| 391/391 [00:12<00:00, 30.19it/s]


Extracted Feature Shape: torch.Size([50000, 2176])
Extracted Label Shape: torch.Size([50000])
Extracting features...


100%|██████████| 79/79 [00:02<00:00, 26.86it/s]


Extracted Feature Shape: torch.Size([10000, 2176])
Extracted Label Shape: torch.Size([10000])

--- Training Probe on VANILLA Features ---

Training Linear Probe (Input Dim: 2176)...
Epoch 5/100 | Loss: 1.7224 | Test Acc: 40.38%
Epoch 10/100 | Loss: 1.6826 | Test Acc: 40.95%
Epoch 15/100 | Loss: 1.6601 | Test Acc: 40.88%
Epoch 20/100 | Loss: 1.6439 | Test Acc: 40.88%
Epoch 25/100 | Loss: 1.6301 | Test Acc: 41.08%
Epoch 30/100 | Loss: 1.6204 | Test Acc: 41.50%
Epoch 35/100 | Loss: 1.6099 | Test Acc: 41.43%
Epoch 40/100 | Loss: 1.6019 | Test Acc: 41.56%
Epoch 45/100 | Loss: 1.5950 | Test Acc: 41.79%
Epoch 50/100 | Loss: 1.5870 | Test Acc: 41.80%
Epoch 55/100 | Loss: 1.5814 | Test Acc: 41.52%
Epoch 60/100 | Loss: 1.5746 | Test Acc: 41.77%
Epoch 65/100 | Loss: 1.5700 | Test Acc: 41.11%
Epoch 70/100 | Loss: 1.5647 | Test Acc: 41.63%
Epoch 75/100 | Loss: 1.5605 | Test Acc: 41.58%
Epoch 80/100 | Loss: 1.5555 | Test Acc: 41.99%
Epoch 85/100 | Loss: 1.5513 | Test Acc: 41.89%
Epoch 90/100 | Loss: