## 1. Setup & Imports


In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import torchvision.models as models
import timm

# Import existing utils
from utils_criterion import compute_errors
from models.unetbaseline_model import define_G  # ‚Üê U-Net ÏßÄÏõê!

# GPU setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
GPU: NVIDIA RTX A5000


## 2. Locations


In [2]:
locations = [
    "2ndFloorLuxembourg",
    "3rd_Floor_Luxembourg",
    "Attic",
    "Outdoor_Cobblestone_Path",
    "Salle_Chevalier",
    "Salle_des_Colonnes",
    "V119_Cake_Corridors"
]


## 3. Configuration


In [3]:
# ========== Configuration ==========

# Data settings
ROOT_DIR = '/root/dev/data/dataset/Batvision/BatvisionV2/'
USE_ALL_LOCATIONS = True  # True: Ï†ÑÏ≤¥ ÏãúÌÄÄÏä§, False: Îã®Ïùº ÏãúÌÄÄÏä§
LOCATION = 'Salle_des_Colonnes'
MAX_DEPTH = 30.0
IMG_SIZE = 256

# Train/Eval split
USE_TRAIN_AS_EVAL = False  # True: trainÏúºÎ°ú ÌèâÍ∞Ä, False: val ÏÇ¨Ïö©

# Model settings
MODEL_TYPE = 'resnet'  # 'resnet' or 'unet_256'
PRETRAINED = True

# Training settings
BATCH_SIZE = 8
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# Logging
PRINT_EVERY = 5
VIS_SAMPLES = 3

print("Configuration:")
if USE_ALL_LOCATIONS:
    print(f"  Dataset: ALL LOCATIONS ({len(locations)} sequences)")
else:
    print(f"  Dataset: {LOCATION}")
print(f"  Model: {MODEL_TYPE}")
print(f"  Epochs: {NUM_EPOCHS}, LR: {LEARNING_RATE}")


Configuration:
  Dataset: ALL LOCATIONS (7 sequences)
  Model: resnet
  Epochs: 50, LR: 0.0001


# üéµ Depth Any Audio: Cross-Modal Distillation

Inspired by **Depth AnyEvent** (ICCV 2025), we implement cross-modal distillation:
- **Teacher**: Depth Anything V2 (RGB ‚Üí Depth) - Vision Foundation Model
- **Student**: Audio U-Net (Binaural Audio ‚Üí Depth) - Event-like sensor
- **Strategy**: Teacher generates proxy depth labels from RGB images to supervise the audio-based student

This approach eliminates the need for expensive depth annotations!


## Configuration for Depth Any Audio


In [4]:
# ========== Depth Any Audio Configuration ==========

# Training mode
# OPTIONS:
#   - USE_DISTILLATION=True: Cross-modal distillation (Teacher RGB + Student Audio)
#   - USE_DISTILLATION=False: Supervised learning (Student Audio + GT only)
USE_DISTILLATION = True  # True: Depth Any Audio (distillation), False: Standard supervised

# Teacher model settings
# OPTIONS:
#   - 'depthanything_v2_vitl': Depth Anything V2 (depth prediction level KD)
#   - 'vitl_feature': ViT-L (feature level KD from spectrogram)
TEACHER_MODEL = 'vitl_feature'  # 'depthanything_v2_vitl' or 'vitl_feature'
TEACHER_ENCODER = 'vitl'  # vits, vitb, vitl
FREEZE_TEACHER = True  # Always freeze teacher

# Feature-level KD settings (for ViT-L teacher)
USE_FEATURE_KD = True  # True: Feature-level KD, False: Prediction-level KD
FEATURE_KD_LAYERS = [6, 12, 18, 24]  # Which transformer blocks to use for KD (0-indexed, total 24 for ViT-L)
FEATURE_KD_LAMBDA = 2.0  # Weight for feature-level KD loss (increased for better KD effect)
FEATURE_KD_LOSS_TYPE = 'cosine'  # 'mse' or 'cosine' for feature matching (cosine is better for normalized features)

# Student model settings (Audio U-Net)
STUDENT_MODEL = 'unet_256'  # Audio-based depth estimator
STUDENT_BASE_CHANNELS = 64  # 64 for standard, 32 for lightweight

# Distillation loss settings
DISTILLATION_LOSS = 'combined'  # 'l1', 'silog', 'combined'
LAMBDA_L1 = 0.5
LAMBDA_SILOG = 0.5
SILOG_LAMBDA = 0.85

# Ground truth supervision settings
USE_GT_SUPERVISION = True  # Use GT depth for supervision (always True for fair comparison)
if USE_DISTILLATION:
    if USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature':
        # Feature-level KD: Only feature KD + GT supervision
        LAMBDA_DISTILL = 0.0  # No prediction-level distillation
        LAMBDA_GT = 1.0  # Weight for GT supervision loss
    else:
        # Prediction-level KD: Teacher pseudo-label + GT supervision
        LAMBDA_DISTILL = 0.5  # Weight for distillation loss (teacher pseudo-label)
        LAMBDA_GT = 0.5  # Weight for GT supervision loss
else:
    # Supervised learning mode: Only GT supervision
    LAMBDA_DISTILL = 0.0
    LAMBDA_GT = 1.0

# Data settings for distillation
USE_RGB_TEACHER = True  # Use RGB images for teacher (only for depthanything_v2)
USE_AUDIO_STUDENT = True  # Use audio for student

# Training settings for distillation
DISTILL_BATCH_SIZE = 2  # Further reduced batch size for ViT-L (memory intensive)
DISTILL_GRAD_ACCUM = 8  # Gradient accumulation steps (effective batch = 2 * 8 = 16)
DISTILL_EPOCHS = 100
DISTILL_LR = 1e-4
DISTILL_WEIGHT_DECAY = 1e-4
USE_MIXED_PRECISION = True  # Use mixed precision to save memory

print("=" * 80)
print("Depth Any Audio Configuration")
print("=" * 80)
if USE_DISTILLATION:
    if USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature':
        print("Mode: üéµ Feature-Level Knowledge Distillation")
        print(f"  Teacher: {TEACHER_MODEL} (ViT-L, frozen: {FREEZE_TEACHER})")
        print(f"    Feature layers: {FEATURE_KD_LAYERS}")
        print(f"    Feature KD loss: {FEATURE_KD_LOSS_TYPE} (Œª={FEATURE_KD_LAMBDA})")
        print(f"  Student: {STUDENT_MODEL} (base_channels: {STUDENT_BASE_CHANNELS})")
        print(f"  Loss: {DISTILLATION_LOSS}")
        print(f"    Œª_L1: {LAMBDA_L1}, Œª_SIlog: {LAMBDA_SILOG}, SIlog_Œª: {SILOG_LAMBDA}")
        print(f"  Supervision: Feature KD + GT (Œª_feature={FEATURE_KD_LAMBDA}, Œª_GT={LAMBDA_GT})")
    else:
        print("Mode: üéµ Cross-Modal Distillation (Prediction-Level KD)")
        print(f"  Teacher: {TEACHER_MODEL} (frozen: {FREEZE_TEACHER})")
        print(f"  Student: {STUDENT_MODEL} (base_channels: {STUDENT_BASE_CHANNELS})")
        print(f"  Loss: {DISTILLATION_LOSS}")
        print(f"    Œª_L1: {LAMBDA_L1}, Œª_SIlog: {LAMBDA_SILOG}, SIlog_Œª: {SILOG_LAMBDA}")
        print(f"  Supervision: Œª_distill={LAMBDA_DISTILL}, Œª_GT={LAMBDA_GT}")
else:
    print("Mode: üìö Supervised Learning (Student Audio + GT only)")
    print(f"  Student: {STUDENT_MODEL} (base_channels: {STUDENT_BASE_CHANNELS})")
    print(f"  Loss: {DISTILLATION_LOSS}")
    print(f"    Œª_L1: {LAMBDA_L1}, Œª_SIlog: {LAMBDA_SILOG}, SIlog_Œª: {SILOG_LAMBDA}")
    print(f"  Supervision: GT only (Œª_GT={LAMBDA_GT})")
    print("  ‚ö†Ô∏è  Teacher model will NOT be loaded or used")
print("=" * 80)


Depth Any Audio Configuration
Mode: üéµ Feature-Level Knowledge Distillation
  Teacher: vitl_feature (ViT-L, frozen: True)
    Feature layers: [6, 12, 18, 24]
    Feature KD loss: cosine (Œª=2.0)
  Student: unet_256 (base_channels: 64)
  Loss: combined
    Œª_L1: 0.5, Œª_SIlog: 0.5, SIlog_Œª: 0.85
  Supervision: Feature KD + GT (Œª_feature=2.0, Œª_GT=1.0)


In [5]:
import torchaudio
import torchaudio.transforms as T

class DepthAnyAudioDataset(Dataset):
    """
    Dataset for Depth Any Audio: Returns RGB image, Audio spectrogram, and Depth
    - RGB: For teacher model (Depth Anything V2)
    - Audio: For student model (Audio U-Net)
    - Depth: Ground truth (optional, for validation)
    
    Supports multiple locations (similar to BatvisionV2_Dataset.py)
    """
    
    def __init__(self, root_dir, locations=None, split='train', max_depth=30.0, img_size=256, 
                 audio_format='spectrogram', use_gt_depth=True, location_blacklist=None):
        """
        Args:
            root_dir: Root directory of the dataset
            locations: List of locations or single location string. If None, uses all valid locations.
            split: 'train' or 'val'
            max_depth: Maximum depth in meters
            img_size: Image size for resizing
            audio_format: 'spectrogram' or 'waveform'
            use_gt_depth: Whether to load ground truth depth
            location_blacklist: List of locations to exclude
        """
        self.root_dir = root_dir
        self.max_depth = max_depth
        self.img_size = img_size
        self.audio_format = audio_format
        self.use_gt_depth = use_gt_depth
        
        # Handle locations
        if locations is None:
            # Use all valid directories in root_dir
            location_list = [item for item in os.listdir(root_dir) 
                           if os.path.isdir(os.path.join(root_dir, item))
                           and not item.startswith('.')
                           and not item.startswith('__')
                           and not item.endswith('_unzipped')]
        elif isinstance(locations, str):
            # Single location
            location_list = [locations]
        else:
            # Multiple locations
            location_list = locations
        
        # Apply blacklist
        if location_blacklist:
            location_list = [loc for loc in location_list if loc not in location_blacklist]
        
        # Load CSVs from all valid locations
        location_csv_paths = []
        for location in location_list:
            csv_path = os.path.join(root_dir, location, f'{split}.csv')
            if os.path.exists(csv_path):
                location_csv_paths.append(csv_path)
            else:
                print(f"‚ö†Ô∏è  Warning: {csv_path} not found, skipping location {location}")
        
        if len(location_csv_paths) == 0:
            raise ValueError(f"No valid locations found with {split}.csv in {root_dir}. "
                           f"Checked {len(location_list)} directories: {location_list[:5]}...")
        
        # Load and concatenate all CSVs
        self.instances = []
        for csv_path in location_csv_paths:
            df = pd.read_csv(csv_path)
            location_name = os.path.basename(os.path.dirname(csv_path))
            print(f"Loaded {len(df)} samples from {location_name} ({split})")
            self.instances.append(df)
        
        self.data = pd.concat(self.instances, ignore_index=True)
        print(f"‚úÖ Total: {len(self.data)} samples from {len(location_csv_paths)} location(s)")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # === Load RGB Image ===
        # Fix: Use 'camera path' and 'camera file name' instead of 'image path' and 'image file name'
        img_path = os.path.join(self.root_dir, row['camera path'], row['camera file name'])
        image = cv2.imread(img_path)
        if image is None:
            raise RuntimeError(f"Could not load image file {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.img_size, self.img_size))
        image = image.astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1)  # [3, H, W]
        
        # === Load Audio (Binaural) ===
        # Fix: Use torchaudio.load() with fallback methods instead of np.load()
        audio_path = os.path.join(self.root_dir, row['audio path'], row['audio file name'])
        waveform, sr = self._load_audio(audio_path)
        
        # Cut audio to fit max depth (same as BatvisionV2_Dataset.py)
        if self.max_depth:
            cut = int((2 * self.max_depth / 340) * sr)
            waveform = waveform[:, :cut]
        
        # Convert to spectrogram
        if self.audio_format == 'spectrogram':
            # Use parameters from BatvisionV2_Dataset.py for cut audio
            win_length = 64
            n_fft = 512
            hop_length = 64 // 4
            
            spec_transform = T.Spectrogram(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                power=1.0
            )
            audio = spec_transform(waveform)  # [2, freq, time]
            
            # Apply log scale (same as BatvisionV2_Dataset.py)
            audio = torch.log(audio + 1e-8)
            
            # Min-max normalize each channel independently to [0, 1]
            for c in range(audio.shape[0]):
                spec_min = audio[c].min()
                spec_max = audio[c].max()
                if spec_max > spec_min:
                    audio[c] = (audio[c] - spec_min) / (spec_max - spec_min)
                else:
                    audio[c] = torch.zeros_like(audio[c])
            
            # Resize to match image size
            audio = F.interpolate(audio.unsqueeze(0), size=(self.img_size, self.img_size), 
                                mode='bilinear', align_corners=False).squeeze(0)
        else:
            audio = waveform
        
        # === Load Depth (Ground Truth) ===
        if self.use_gt_depth:
            depth_path = os.path.join(self.root_dir, row['depth path'], row['depth file name'])
            depth = np.load(depth_path).astype(np.float32)
            depth = depth / 1000.0  # mm to m
            if self.max_depth:
                depth[depth > self.max_depth] = self.max_depth
            depth[depth < 0] = 0
            depth = cv2.resize(depth, (self.img_size, self.img_size), 
                             interpolation=cv2.INTER_NEAREST)
            depth = torch.from_numpy(depth).unsqueeze(0)  # [1, H, W]
        else:
            depth = torch.zeros(1, self.img_size, self.img_size)
        
        return {
            'image': image,      # RGB for teacher
            'audio': audio,      # Audio for student
            'depth_gt': depth,   # Ground truth (optional)
            'filename': row['camera file name']
        }
    
    def _load_audio(self, audio_path):
        """Load audio with multiple fallback methods (from BatvisionV2_Dataset.py)"""
        try:
            waveform, sr = torchaudio.load(audio_path)
        except (RuntimeError, ValueError) as e:
            try:
                waveform, sr = torchaudio.load(audio_path, backend="soundfile")
            except:
                try:
                    from scipy.io import wavfile
                    sr, audio_data = wavfile.read(audio_path)
                    if audio_data.ndim == 1:
                        waveform = torch.from_numpy(audio_data).float().unsqueeze(0)
                    else:
                        waveform = torch.from_numpy(audio_data.T).float()
                    if waveform.dtype == torch.int16:
                        waveform = waveform / 32768.0
                    elif waveform.dtype == torch.int32:
                        waveform = waveform / 2147483648.0
                except Exception as e2:
                    try:
                        import soundfile as sf
                        audio_data, sr = sf.read(audio_path)
                        if audio_data.ndim == 1:
                            waveform = torch.from_numpy(audio_data).float().unsqueeze(0)
                        else:
                            waveform = torch.from_numpy(audio_data.T).float()
                    except Exception as e3:
                        raise RuntimeError(
                            f"Could not load audio file {audio_path} with any method. "
                            f"Tried: torchaudio (error: {e}), scipy (error: {e2}), soundfile (error: {e3})"
                        )
        return waveform, sr

# Test dataset
print("\n" + "="*80)
print("Testing DepthAnyAudioDataset")
print("="*80)

# Test with single location or multiple locations based on USE_ALL_LOCATIONS
if USE_ALL_LOCATIONS:
    test_locations = locations  # Use predefined locations list
else:
    test_locations = LOCATION  # Single location

test_dataset = DepthAnyAudioDataset(
    root_dir=ROOT_DIR,
    locations=test_locations,
    split='train',
    max_depth=MAX_DEPTH,
    img_size=IMG_SIZE
)

sample = test_dataset[0]
print(f"\nSample 0:")
print(f"  Image shape: {sample['image'].shape}, range: [{sample['image'].min():.3f}, {sample['image'].max():.3f}]")
print(f"  Audio shape: {sample['audio'].shape}, range: [{sample['audio'].min():.3f}, {sample['audio'].max():.3f}]")
print(f"  Depth shape: {sample['depth_gt'].shape}, range: [{sample['depth_gt'].min():.3f}, {sample['depth_gt'].max():.3f}]")
print(f"  Filename: {sample['filename']}")
print("‚úÖ Dataset test passed!")



Testing DepthAnyAudioDataset
Loaded 431 samples from 2ndFloorLuxembourg (train)
Loaded 290 samples from 3rd_Floor_Luxembourg (train)
Loaded 37 samples from Attic (train)
Loaded 377 samples from Outdoor_Cobblestone_Path (train)
Loaded 116 samples from Salle_Chevalier (train)
Loaded 240 samples from Salle_des_Colonnes (train)
Loaded 420 samples from V119_Cake_Corridors (train)
‚úÖ Total: 1911 samples from 7 location(s)

Sample 0:
  Image shape: torch.Size([3, 256, 256]), range: [0.000, 1.000]
  Audio shape: torch.Size([2, 256, 256]), range: [0.063, 0.998]
  Depth shape: torch.Size([1, 256, 256]), range: [0.000, 11.929]
  Filename: camera_0.jpeg
‚úÖ Dataset test passed!


In [6]:
if USE_DISTILLATION:
    # Only load teacher model if distillation is enabled
    print("\n" + "="*80)
    print(f"Loading Teacher Model: {TEACHER_MODEL}")
    print("="*80)
    
    if TEACHER_MODEL == 'vitl_feature':
        # Load ViT-L for feature-level KD
        # Teacher: RGB Ïù¥ÎØ∏ÏßÄ ÏûÖÎ†•Ïùò ÏÇ¨Ï†ÑÌïôÏäµ ViT-L
        try:
            import timm
            print(f"Loading ViT-L Teacher (RGB input) from timm...")
            
            # ViT-L feature extractor wrapper for RGB images
            class ViTLFeatureExtractor(nn.Module):
                """ViT-L model that extracts features at specified transformer blocks from RGB images"""
                def __init__(self, model_name=None, feature_layers=[6, 12, 18, 24], input_channels=3):
                    super().__init__()
                    # Try different ViT-L model names if not specified
                    if model_name is None:
                        # Try common ViT-L model names in order of preference
                        possible_models = [
                            'vit_large_patch16_224',  # Most common, ImageNet-1k pretrained
                            'vit_large_patch16_224_in21k',  # ImageNet-21k pretrained
                            'vit_large_patch14_224_in21k',  # ImageNet-21k pretrained, patch14
                            'vit_large_patch14_224',  # Less common
                        ]
                        model_name = None
                        for m in possible_models:
                            try:
                                # Check if model exists in timm registry
                                if m in timm.list_models('*vit_large*'):
                                    model_name = m
                                    break
                            except:
                                # Fallback: try to create model
                                try:
                                    test_model = timm.create_model(m, pretrained=False, num_classes=0, check_cfg=False)
                                    model_name = m
                                    del test_model
                                    break
                                except:
                                    continue
                        
                        if model_name is None:
                            # Last resort: try to find any vit_large model
                            try:
                                available_models = timm.list_models('*vit_large*')
                                if available_models:
                                    model_name = available_models[0]
                                    print(f"‚ö†Ô∏è  Using first available ViT-L model: {model_name}")
                                else:
                                    raise ValueError("No ViT-L models found in timm. Please install timm or specify model_name.")
                            except:
                                raise ValueError("Could not find a valid ViT-L model. Please install timm or specify model_name.")
                    
                    # Load pretrained ViT-L
                    try:
                        self.vit = timm.create_model(
                            model_name,
                            pretrained=True,
                            num_classes=0,  # Remove classification head
                            img_size=224
                        )
                    except RuntimeError as e:
                        # If pretrained weights don't exist, try without pretrained
                        print(f"‚ö†Ô∏è  Warning: Pretrained weights not found for {model_name}. Using random initialization.")
                        self.vit = timm.create_model(
                            model_name,
                            pretrained=False,
                            num_classes=0,
                            img_size=224
                        )
                    
                    self.feature_layers = feature_layers
                    self.model_name = model_name  # Store model name for reference
                    self.input_channels = input_channels
                    # Determine patch size from model
                    if 'patch14' in model_name:
                        self.patch_size = 14
                    elif 'patch16' in model_name:
                        self.patch_size = 16
                    else:
                        self.patch_size = 16  # Default
                    self.embed_dim = 1024  # ViT-L embedding dimension
                    
                def forward(self, x, return_features=True):
                    """
                    Args:
                        x: [B, 3, H, W] - RGB image input
                        return_features: If True, return intermediate features
                    
                    Returns:
                        features: Dict of features at specified layers
                        or output: Final output if return_features=False
                    """
                    # Ensure RGB input (3 channels)
                    if x.shape[1] != 3:
                        if x.shape[1] == 1:
                            x = x.repeat(1, 3, 1, 1)
                        elif x.shape[1] == 2:
                            # Convert 2-channel to 3-channel
                            x = torch.cat([x, (x[:, 0:1] + x[:, 1:2]) / 2], dim=1)
                    
                    # Resize to ViT input size (224x224)
                    if x.shape[-1] != 224:
                        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
                    
                    # Normalize to ImageNet stats
                    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
                    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
                    x = (x - mean) / std
                    
                    if return_features:
                        # Extract features at specified layers
                        features = {}
                        
                        # Patch embedding
                        x = self.vit.patch_embed(x)  # [B, N, embed_dim]
                        
                        # Add cls token if exists
                        if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                            cls_tokens = self.vit.cls_token.expand(x.shape[0], -1, -1)  # [B, 1, embed_dim]
                            x = torch.cat([cls_tokens, x], dim=1)  # [B, N+1, embed_dim]
                        
                        # Add positional embedding
                        if hasattr(self.vit, 'pos_embed') and self.vit.pos_embed is not None:
                            # pos_embed shape: [1, N+1, embed_dim] (includes cls token)
                            if x.shape[1] == self.vit.pos_embed.shape[1]:
                                x = x + self.vit.pos_embed
                            else:
                                # Handle dynamic pos_embed (slice to match)
                                x = x + self.vit.pos_embed[:, :x.shape[1], :]
                        
                        # Apply pos_drop if exists
                        if hasattr(self.vit, 'pos_drop'):
                            x = self.vit.pos_drop(x)
                        
                        # Pre-norm if exists
                        if hasattr(self.vit, 'norm_pre'):
                            x = self.vit.norm_pre(x)
                        
                        # Extract features at specified transformer blocks
                        # blocks can be nn.ModuleList or nn.Sequential
                        if isinstance(self.vit.blocks, nn.ModuleList):
                            for i, block in enumerate(self.vit.blocks):
                                x = block(x)
                                if (i + 1) in self.feature_layers:
                                    # Remove cls token if exists before reshaping
                                    if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                                        # x shape: [B, N+1, C], remove first token
                                        x_patches = x[:, 1:, :]  # [B, N, C]
                                    else:
                                        x_patches = x  # [B, N, C]
                                    
                                    # Reshape to spatial format: [B, N, C] -> [B, C, H, W]
                                    B, N, C = x_patches.shape
                                    # Calculate spatial dimensions from patch embedding
                                    # For patch_size=14 and img_size=224: N = (224/14)^2 = 16^2 = 256
                                    H = W = int(N ** 0.5)  # Assuming square patches
                                    if H * W == N:
                                        feat = x_patches.permute(0, 2, 1).view(B, C, H, W)
                                        features[f'layer_{i+1}'] = feat
                        else:
                            # If blocks is Sequential, we need to hook into it
                            # For now, just process all blocks
                            x = self.vit.blocks(x)
                            # Note: Sequential doesn't allow per-layer extraction easily
                            # This is a fallback - ideally blocks should be ModuleList
                            if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                                x_patches = x[:, 1:, :]
                            else:
                                x_patches = x
                            B, N, C = x_patches.shape
                            H = W = int(N ** 0.5)
                            if H * W == N:
                                feat = x_patches.permute(0, 2, 1).view(B, C, H, W)
                                # Store as last layer if no specific layers found
                                features['layer_final'] = feat
                        
                        return features
                    else:
                        # Use forward_features if available (more reliable)
                        if hasattr(self.vit, 'forward_features'):
                            x = self.vit.forward_features(x)
                        else:
                            # Manual forward pass
                            x = self.vit.patch_embed(x)
                            if hasattr(self.vit, 'pos_embed'):
                                if x.shape[1] == self.vit.pos_embed.shape[1]:
                                    x = x + self.vit.pos_embed
                                else:
                                    x = x + self.vit.pos_embed[:, :x.shape[1], :]
                            if hasattr(self.vit, 'pos_drop'):
                                x = self.vit.pos_drop(x)
                            if hasattr(self.vit, 'norm_pre'):
                                x = self.vit.norm_pre(x)
                            x = self.vit.blocks(x)
                            if hasattr(self.vit, 'norm'):
                                x = self.vit.norm(x)
                        return x
            
            # Create ViT-L teacher (RGB input)
            teacher_model = ViTLFeatureExtractor(
                model_name=None,  # Auto-detect available ViT-L model
                feature_layers=FEATURE_KD_LAYERS,
                input_channels=3  # RGB input
            ).to(device)
            
            # Print which model was loaded
            print(f"‚úÖ Loaded ViT-L Teacher model: {teacher_model.model_name}")
            print(f"   Input: RGB images (3 channels)")
            print(f"   Patch size: {teacher_model.patch_size}x{teacher_model.patch_size}")
            print(f"   Embedding dim: {teacher_model.embed_dim}")
            
            # Freeze teacher
            if FREEZE_TEACHER:
                for param in teacher_model.parameters():
                    param.requires_grad = False
                teacher_model.eval()
                print("‚úÖ Teacher model (ViT-L) frozen")
            
            # Test teacher with RGB image
            with torch.no_grad():
                test_img = torch.randn(1, 3, 256, 256).to(device)
                teacher_features = teacher_model(test_img, return_features=True)
                print(f"Teacher features extracted at layers: {list(teacher_features.keys())}")
                for k, v in teacher_features.items():
                    print(f"  {k}: shape {v.shape}")
            
            print("‚úÖ ViT-L teacher model (RGB input) loaded successfully!")
            
        except ImportError:
            print("‚ö†Ô∏è  timm not found. Installing...")
            print("Run: pip install timm")
            raise ImportError("timm is required for ViT-L feature extraction")
    
    elif TEACHER_MODEL == 'depthanything_v2_vitl':
        # Load Depth Anything V2 (original prediction-level KD)
        try:
            from depth_anything_v2.dpt import DepthAnythingV2
        except ImportError:
            print("‚ö†Ô∏è  Depth Anything V2 not found. Installing...")
            print("Run: pip install depth-anything-v2")
            print("\nFor now, we'll use a placeholder teacher model.")
            
            # Placeholder teacher (ResNet-based)
            class PlaceholderTeacher(nn.Module):
                def __init__(self):
                    super().__init__()
                    resnet = models.resnet18(pretrained=True)
                    self.encoder = nn.Sequential(*list(resnet.children())[:-2])
                    self.decoder = nn.Sequential(
                        nn.Conv2d(512, 256, 3, padding=1),
                        nn.ReLU(),
                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(256, 128, 3, padding=1),
                        nn.ReLU(),
                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(128, 64, 3, padding=1),
                        nn.ReLU(),
                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(64, 32, 3, padding=1),
                        nn.ReLU(),
                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(32, 1, 3, padding=1),
                    )
                
                def forward(self, x):
                    x = self.encoder(x)
                    x = self.decoder(x)
                    # Resize to match input
                    x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
                    return x * 30.0  # Scale to depth range
            
            DepthAnythingV2 = PlaceholderTeacher
            USING_PLACEHOLDER_TEACHER = True
        else:
            USING_PLACEHOLDER_TEACHER = False

        if USING_PLACEHOLDER_TEACHER:
            print("‚ö†Ô∏è  Using placeholder teacher (ResNet-based)")
            teacher_model = PlaceholderTeacher().to(device)
        else:
            print(f"Loading Depth Anything V2: {TEACHER_ENCODER}")
            model_configs = {
                'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
                'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
                'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            }
            
            teacher_model = DepthAnythingV2(**model_configs[TEACHER_ENCODER])
            teacher_model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{TEACHER_ENCODER}.pth', map_location='cpu'))
            teacher_model = teacher_model.to(device)

        # Freeze teacher
        if FREEZE_TEACHER:
            for param in teacher_model.parameters():
                param.requires_grad = False
            teacher_model.eval()
            print("‚úÖ Teacher model frozen")

        # Test teacher
        with torch.no_grad():
            test_img = torch.randn(1, 3, 256, 256).to(device)
            teacher_output = teacher_model(test_img)
            print(f"Teacher output shape: {teacher_output.shape}")
            print(f"Teacher output range: [{teacher_output.min():.3f}, {teacher_output.max():.3f}]")

        print("‚úÖ Teacher model loaded successfully!")
    else:
        raise ValueError(f"Unknown teacher model: {TEACHER_MODEL}")
else:
    # Supervised learning mode: No teacher model needed
    print("\n" + "="*80)
    print("Teacher Model: SKIPPED")
    print("="*80)
    print("üìö Supervised learning mode: Training with GT depth only")
    print("‚ö†Ô∏è  Teacher model will not be loaded or used")
    teacher_model = None
    print("‚úÖ Ready for supervised training!")



Loading Teacher Model: vitl_feature
Loading ViT-L Teacher (RGB input) from timm...
‚úÖ Loaded ViT-L Teacher model: vit_large_patch16_224
   Input: RGB images (3 channels)
   Patch size: 16x16
   Embedding dim: 1024
‚úÖ Teacher model (ViT-L) frozen
Teacher features extracted at layers: ['layer_final']
  layer_final: shape torch.Size([1, 1024, 14, 14])
‚úÖ ViT-L teacher model (RGB input) loaded successfully!


## Student Model: Audio U-Net


In [7]:
print("\n" + "="*80)
print("Loading Student Model (Audio ViT-L with Spectrogram-Specialized Architecture)")
print("="*80)

if USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature':
    # Student: Audio Ïä§ÌéôÌä∏Î°úÍ∑∏Îû® ÏûÖÎ†•Ïùò ViT-L (TeacherÎ°úÎ∂ÄÌÑ∞ Ï¥àÍ∏∞Ìôî)
    import timm
    
    # Audio ViT-L Student model with spectrogram-specialized architecture
    class AudioViTLStudent(nn.Module):
        """
        ViT-L Student model for audio spectrogram input, initialized from teacher.
        Spectrogram-specialized architecture:
        1. Spectrogram-aware patch embedding (2-channel input with multi-layer conv)
        2. Spatial depth decoder (patch tokens -> depth map via multi-scale fusion)
        3. Multi-scale feature fusion for better depth prediction
        """
        def __init__(self, teacher_model, feature_layers=[6, 12, 18, 24], output_size=256):
            super().__init__()
            self.feature_layers = feature_layers
            self.output_size = output_size
            
            # Create ViT-L backbone (same architecture as teacher)
            # Use teacher's model name and structure
            teacher_vit = teacher_model.vit
            
            # Create new ViT-L model with same architecture
            self.vit = timm.create_model(
                teacher_model.model_name,
                pretrained=False,  # Will initialize from teacher
                num_classes=0,
                img_size=224
            )
            
            # Initialize from teacher weights
            self._init_from_teacher(teacher_vit)
            
            # Modify patch embedding to accept 2-channel audio input
            # Original patch_embed expects 3 channels, we need to adapt it
            original_patch_embed = self.vit.patch_embed
            # timm's patch_embed is typically a Conv2d layer
            if hasattr(original_patch_embed, 'proj'):
                # PatchEmbed with proj attribute
                original_proj = original_patch_embed.proj
                embed_dim = original_proj.out_channels
                kernel_size = original_proj.kernel_size[0] if isinstance(original_proj.kernel_size, tuple) else original_proj.kernel_size
                stride = original_proj.stride[0] if isinstance(original_proj.stride, tuple) else original_proj.stride
                padding = original_proj.padding[0] if isinstance(original_proj.padding, tuple) else original_proj.padding
            else:
                # Direct Conv2d
                original_proj = original_patch_embed
                embed_dim = original_proj.out_channels
                kernel_size = original_proj.kernel_size[0] if isinstance(original_proj.kernel_size, tuple) else original_proj.kernel_size
                stride = original_proj.stride[0] if isinstance(original_proj.stride, tuple) else original_proj.stride
                padding = original_proj.padding[0] if isinstance(original_proj.padding, tuple) else original_proj.padding
            
            # ========== Spectrogram-Specialized Patch Embedding ==========
            # Use multi-layer conv to better capture frequency-time patterns in spectrogram
            self.audio_patch_embed = nn.Sequential(
                nn.Conv2d(2, embed_dim // 2, kernel_size=3, stride=1, padding=1),  # Initial projection
                nn.BatchNorm2d(embed_dim // 2),
                nn.GELU(),
                nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)  # Patch embedding
            )
            
            # Initialize from RGB patch embedding (smart initialization)
            with torch.no_grad():
                if hasattr(original_patch_embed, 'proj'):
                    rgb_weight = original_patch_embed.proj.weight.data  # [out_dim, 3, kernel, kernel]
                    rgb_bias = original_patch_embed.proj.bias.data if original_patch_embed.proj.bias is not None else None
                else:
                    rgb_weight = original_patch_embed.weight.data
                    rgb_bias = original_patch_embed.bias.data if original_patch_embed.bias is not None else None
                
                # First conv: 2 channels -> embed_dim//2 (use first 2 RGB channels)
                audio_weight_1 = rgb_weight[:, :2, :, :].mean(dim=0, keepdim=True).repeat(embed_dim // 2, 1, 1, 1)
                self.audio_patch_embed[0].weight.data = audio_weight_1
                
                # Second conv: embed_dim//2 -> embed_dim (initialize from RGB mean)
                rgb_mean = rgb_weight.mean(dim=1, keepdim=True)  # [out_dim, 1, kernel, kernel]
                audio_weight_2 = rgb_mean.repeat(1, embed_dim // 2, 1, 1)
                self.audio_patch_embed[3].weight.data = audio_weight_2
                if rgb_bias is not None and self.audio_patch_embed[3].bias is not None:
                    self.audio_patch_embed[3].bias.data = rgb_bias.clone()
            
            # ========== Spatial Depth Decoder (Spectrogram-Specialized) ==========
            # Instead of global pooling, decode patch tokens spatially for better depth prediction
            self.embed_dim = 1024  # ViT-L embedding dimension
            patch_size = teacher_model.patch_size  # 16 for vit_large_patch16_224
            vit_h = vit_w = 224 // patch_size  # 14 for patch16_224
            
            # Multi-scale feature fusion decoder
            # Fuse features from different transformer layers for better depth prediction
            self.depth_decoder = nn.ModuleDict({
                'projection': nn.ModuleList([
                    nn.Conv2d(self.embed_dim, self.embed_dim // 2, 1) for _ in range(len(feature_layers))
                ]),
                'fusion': nn.Sequential(
                    nn.Conv2d(self.embed_dim // 2 * len(feature_layers), self.embed_dim // 2, 3, padding=1),
                    nn.BatchNorm2d(self.embed_dim // 2),
                    nn.GELU(),
                ),
                'upsample': nn.Sequential(
                    # vit_h x vit_w (14x14) -> output_size x output_size
                    nn.ConvTranspose2d(self.embed_dim // 2, self.embed_dim // 4, kernel_size=4, stride=2, padding=1),  # 14->28
                    nn.BatchNorm2d(self.embed_dim // 4),
                    nn.GELU(),
                    nn.ConvTranspose2d(self.embed_dim // 4, self.embed_dim // 8, kernel_size=4, stride=2, padding=1),  # 28->56
                    nn.BatchNorm2d(self.embed_dim // 8),
                    nn.GELU(),
                    nn.ConvTranspose2d(self.embed_dim // 8, self.embed_dim // 16, kernel_size=4, stride=2, padding=1),  # 56->112
                    nn.BatchNorm2d(self.embed_dim // 16),
                    nn.GELU(),
                    nn.ConvTranspose2d(self.embed_dim // 16, self.embed_dim // 32, kernel_size=4, stride=2, padding=1),  # 112->224
                    nn.BatchNorm2d(self.embed_dim // 32),
                    nn.GELU(),
                    nn.Conv2d(self.embed_dim // 32, 64, 3, padding=1),
                    nn.BatchNorm2d(64),
                    nn.GELU(),
                    nn.Conv2d(64, 32, 3, padding=1),
                    nn.BatchNorm2d(32),
                    nn.GELU(),
                    nn.Conv2d(32, 1, 3, padding=1),  # Depth output
                )
            })
            
            # Final resize to exact output size if needed
            if output_size != 224:
                self.final_resize = True
            else:
                self.final_resize = False
            
        def _init_from_teacher(self, teacher_vit):
            """Initialize student ViT from teacher ViT weights"""
            student_state = self.vit.state_dict()
            teacher_state = teacher_vit.state_dict()
            
            # Copy matching weights (excluding patch_embed.proj which we'll adapt)
            for key in student_state.keys():
                if key in teacher_state:
                    if 'patch_embed.proj' not in key:  # Skip patch embedding (different input channels)
                        student_state[key] = teacher_state[key].clone()
            
            self.vit.load_state_dict(student_state)
            print("‚úÖ Student ViT initialized from teacher weights")
        
        def forward(self, x, return_features=False):
            """
            Args:
                x: [B, 2, H, W] - Audio spectrogram (2 channels: left, right)
                return_features: If True, return intermediate features
            
            Returns:
                depth: [B, 1, H, W] - Predicted depth map
                features: Dict of intermediate features (if return_features=True)
            """
            B = x.shape[0]
            original_size = x.shape[-1]
            
            # Resize to ViT input size (224x224) for processing
            if x.shape[-1] != 224:
                x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            
            # Normalize spectrogram (already in [0, 1] from dataset)
            # Use spectrogram-appropriate normalization
            mean = torch.tensor([0.5, 0.5]).view(1, 2, 1, 1).to(x.device)
            std = torch.tensor([0.5, 0.5]).view(1, 2, 1, 1).to(x.device)
            x = (x - mean) / std
            
            # ========== Spectrogram Patch Embedding ==========
            x_patches = self.audio_patch_embed(x)  # [B, embed_dim, H_patch, W_patch]
            B, C, H_patch, W_patch = x_patches.shape
            
            # Flatten spatial dimensions: [B, embed_dim, H_patch, W_patch] -> [B, H_patch*W_patch, embed_dim]
            x = x_patches.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
            
            # ========== ViT Transformer Blocks ==========
            # Add cls token if exists
            if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                cls_tokens = self.vit.cls_token.expand(B, -1, -1)
                x = torch.cat([cls_tokens, x], dim=1)
            
            # Add positional embedding
            if hasattr(self.vit, 'pos_embed') and self.vit.pos_embed is not None:
                if x.shape[1] == self.vit.pos_embed.shape[1]:
                    x = x + self.vit.pos_embed
                else:
                    x = x + self.vit.pos_embed[:, :x.shape[1], :]
            
            # Apply pos_drop if exists
            if hasattr(self.vit, 'pos_drop'):
                x = self.vit.pos_drop(x)
            
            # Pre-norm if exists
            if hasattr(self.vit, 'norm_pre'):
                x = self.vit.norm_pre(x)
            
            # Extract features at specified transformer blocks
            multi_scale_features = []
            features = {} if return_features else None
            
            if isinstance(self.vit.blocks, nn.ModuleList):
                for i, block in enumerate(self.vit.blocks):
                    x = block(x)
                    if (i + 1) in self.feature_layers:
                        # Remove cls token if exists before reshaping
                        if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                            x_patches = x[:, 1:, :]  # [B, N, C] (remove cls token)
                        else:
                            x_patches = x
                        
                        # Reshape to spatial format: [B, N, C] -> [B, C, H_patch, W_patch]
                        B, N, C = x_patches.shape
                        # Calculate spatial dimensions from patch grid
                        # For ViT-L patch16_224: N should be (224/16)^2 = 14^2 = 196
                        H_feat = W_feat = int(N ** 0.5)
                        
                        # Handle cases where N is not a perfect square (shouldn't happen, but just in case)
                        if H_feat * W_feat == N:
                            feat = x_patches.permute(0, 2, 1).view(B, C, H_feat, W_feat)
                            
                            # Store for feature KD
                            if return_features:
                                features[f'layer_{i+1}'] = feat
                            
                            # Project and store for depth decoder
                            layer_idx = self.feature_layers.index(i + 1)
                            feat_proj = self.depth_decoder['projection'][layer_idx](feat)
                            multi_scale_features.append(feat_proj)
                        else:
                            # If N is not perfect square, use interpolation
                            # This shouldn't happen for standard ViT, but handle it gracefully
                            feat_flat = x_patches.permute(0, 2, 1)  # [B, C, N]
                            # Reshape to approximate square
                            H_feat = W_feat = int(torch.sqrt(torch.tensor(N, dtype=torch.float32)).item())
                            # Pad or crop to make it square
                            target_size = H_feat * W_feat
                            if N > target_size:
                                feat_flat = feat_flat[:, :, :target_size]
                            else:
                                padding = torch.zeros(B, C, target_size - N, device=feat_flat.device, dtype=feat_flat.dtype)
                                feat_flat = torch.cat([feat_flat, padding], dim=2)
                            feat = feat_flat.view(B, C, H_feat, W_feat)
                            
                            if return_features:
                                features[f'layer_{i+1}'] = feat
                            
                            layer_idx = self.feature_layers.index(i + 1)
                            feat_proj = self.depth_decoder['projection'][layer_idx](feat)
                            multi_scale_features.append(feat_proj)
            else:
                # Fallback: process all blocks (if blocks is Sequential)
                x = self.vit.blocks(x)
                if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                    x_patches = x[:, 1:, :]
                else:
                    x_patches = x
                B, N, C = x_patches.shape
                H_feat = W_feat = int(N ** 0.5)
                if H_feat * W_feat == N:
                    feat = x_patches.permute(0, 2, 1).view(B, C, H_feat, W_feat)
                    if return_features:
                        features['layer_final'] = feat
                    # Use final layer for depth prediction (use first projection)
                    feat_proj = self.depth_decoder['projection'][0](feat)
                    multi_scale_features.append(feat_proj)
                else:
                    # Handle non-square case
                    feat_flat = x_patches.permute(0, 2, 1)
                    H_feat = W_feat = int(torch.sqrt(torch.tensor(N, dtype=torch.float32)).item())
                    target_size = H_feat * W_feat
                    if N > target_size:
                        feat_flat = feat_flat[:, :, :target_size]
                    else:
                        padding = torch.zeros(B, C, target_size - N, device=feat_flat.device, dtype=feat_flat.dtype)
                        feat_flat = torch.cat([feat_flat, padding], dim=2)
                    feat = feat_flat.view(B, C, H_feat, W_feat)
                    if return_features:
                        features['layer_final'] = feat
                    feat_proj = self.depth_decoder['projection'][0](feat)
                    multi_scale_features.append(feat_proj)
            
            # ========== Spatial Depth Decoder ==========
            if len(multi_scale_features) > 0:
                # Check if we have the expected number of features
                expected_channels = (self.embed_dim // 2) * len(self.feature_layers)
                if len(multi_scale_features) < len(self.feature_layers):
                    # If some features are missing, pad with the last available feature
                    while len(multi_scale_features) < len(self.feature_layers):
                        multi_scale_features.append(multi_scale_features[-1])
                
                # Fuse multi-scale features
                fused = torch.cat(multi_scale_features, dim=1)  # [B, C*num_layers, H, W]
                
                # Verify channel count matches
                if fused.shape[1] != expected_channels:
                    # If channel count doesn't match, use adaptive approach
                    # Project to expected channels if needed
                    if fused.shape[1] < expected_channels:
                        # Pad with zeros or repeat last feature
                        pad_channels = expected_channels - fused.shape[1]
                        padding = torch.zeros(B, pad_channels, fused.shape[2], fused.shape[3], 
                                             device=fused.device, dtype=fused.dtype)
                        fused = torch.cat([fused, padding], dim=1)
                    else:
                        # Take first expected_channels
                        fused = fused[:, :expected_channels, :, :]
                
                fused = self.depth_decoder['fusion'](fused)  # [B, C//2, H, W]
                
                # Upsample to output size
                depth = self.depth_decoder['upsample'](fused)  # [B, 1, 224, 224]
                
                # Resize to exact output size if needed
                if self.final_resize and depth.shape[-1] != self.output_size:
                    depth = F.interpolate(depth, size=(self.output_size, self.output_size), 
                                        mode='bilinear', align_corners=False)
            else:
                # Fallback: use final features from transformer
                if hasattr(self.vit, 'norm'):
                    x_final = self.vit.norm(x)
                else:
                    x_final = x
                
                # Remove cls token
                if hasattr(self.vit, 'cls_token') and self.vit.cls_token is not None:
                    x_patches = x_final[:, 1:, :]
                else:
                    x_patches = x_final
                
                B, N, C = x_patches.shape
                H_feat = W_feat = int(N ** 0.5)
                if H_feat * W_feat == N:
                    feat = x_patches.permute(0, 2, 1).view(B, C, H_feat, W_feat)
                    # Project to embed_dim // 2 for consistency
                    feat_proj = nn.Conv2d(C, self.embed_dim // 2, 1).to(feat.device)(feat)
                    # Use simple upsampling decoder
                    depth = F.interpolate(feat_proj, size=(self.output_size, self.output_size), 
                                        mode='bilinear', align_corners=False)
                    depth = nn.Conv2d(self.embed_dim // 2, 1, 3, padding=1).to(depth.device)(depth)
                else:
                    # Ultimate fallback: constant depth
                    depth = torch.zeros(B, 1, self.output_size, self.output_size).to(x.device)
            
            if return_features:
                return depth, features
            else:
                return depth
    
    # Create student model (initialized from teacher)
    if teacher_model is not None:
        student_model = AudioViTLStudent(
            teacher_model=teacher_model,
            feature_layers=FEATURE_KD_LAYERS,
            output_size=IMG_SIZE
        ).to(device)
        print("‚úÖ Student model: Audio ViT-L with Spectrogram-Specialized Architecture")
        print("   - Spectrogram-aware patch embedding (multi-layer conv)")
        print("   - Spatial depth decoder (multi-scale feature fusion)")
        print("   - Initialized from teacher ViT-L weights")
    else:
        raise ValueError("Teacher model must be loaded before creating student model")
    
    # Count parameters
    student_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
    print(f"\nStudent model: Audio ViT-L (Spectrogram-Specialized)")
    print(f"Trainable parameters: {student_params:,}")
    
    # Test student
    test_audio = torch.randn(1, 2, 256, 256).to(device)
    student_output, student_features = student_model(test_audio, return_features=True)
    print(f"\nStudent test output:")
    print(f"  Output shape: {student_output.shape}")
    print(f"  Output range: [{student_output.min():.3f}, {student_output.max():.3f}]")
    print(f"  Features extracted at layers: {list(student_features.keys())}")
    for k, v in student_features.items():
        print(f"    {k}: shape {v.shape}")
    print("\n‚úÖ Student model (Audio ViT-L with Spectrogram-Specialized Architecture) created successfully!")
else:
    # Fallback to UNet if not using feature KD
    from models.unetbaseline_model import define_G
    from config_loader import load_config
    
    cfg = load_config(dataset_name='batvisionv2', mode='train')
    student_model = define_G(
        cfg=cfg,
        input_nc=2,
        output_nc=1,
        ngf=STUDENT_BASE_CHANNELS,
        netG=STUDENT_MODEL,
        norm='batch',
        use_dropout=False,
        init_type='normal',
        init_gain=0.02,
        gpu_ids=[]
    ).to(device)
    
    student_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
    print(f"Student model: {STUDENT_MODEL}")
    print(f"Trainable parameters: {student_params:,}")
    
    test_audio = torch.randn(1, 2, 256, 256).to(device)
    student_output = student_model(test_audio)
    print(f"Student output shape: {student_output.shape}")
    print("‚úÖ Student model created successfully!")



Loading Student Model (Audio ViT-L with Spectrogram-Specialized Architecture)
‚úÖ Student ViT initialized from teacher weights
‚úÖ Student model: Audio ViT-L with Spectrogram-Specialized Architecture
   - Spectrogram-aware patch embedding (multi-layer conv)
   - Spatial depth decoder (multi-scale feature fusion)
   - Initialized from teacher ViT-L weights

Student model: Audio ViT-L (Spectrogram-Specialized)
Trainable parameters: 452,146,145

Student test output:
  Output shape: torch.Size([1, 1, 256, 256])
  Output range: [-0.943, 2.067]
  Features extracted at layers: ['layer_final']
    layer_final: shape torch.Size([1, 1024, 13, 13])

‚úÖ Student model (Audio ViT-L with Spectrogram-Specialized Architecture) created successfully!


## Distillation Loss Functions


In [8]:
from utils_loss import SIlogLoss

class FeatureKDLoss(nn.Module):
    """
    Feature-level Knowledge Distillation Loss
    Matches student features with teacher features at multiple layers
    """
    def __init__(self, loss_type='mse', lambda_feature=1.0):
        super().__init__()
        self.loss_type = loss_type
        self.lambda_feature = lambda_feature
        
        if loss_type == 'mse':
            self.loss_fn = nn.MSELoss()
        elif loss_type == 'cosine':
            self.cosine_sim = nn.CosineSimilarity(dim=1)
        else:
            raise ValueError(f"Unknown feature loss type: {loss_type}")
    
    def compute_feature_loss(self, student_feat, teacher_feat):
        """
        Compute feature matching loss between student and teacher features
        
        Args:
            student_feat: [B, C_s, H_s, W_s] - Student feature map
            teacher_feat: [B, C_t, H_t, W_t] - Teacher feature map
        
        Returns:
            loss: Scalar loss value
        """
        # Get batch size and spatial dimensions (define B early to avoid UnboundLocalError)
        B = student_feat.shape[0]
        
        # Align spatial dimensions (interpolate teacher to student size)
        if student_feat.shape[-2:] != teacher_feat.shape[-2:]:
            teacher_feat = F.interpolate(
                teacher_feat, 
                size=student_feat.shape[-2:], 
                mode='bilinear', 
                align_corners=False
            )
        
        # Align channel dimensions using proper projection
        if student_feat.shape[1] != teacher_feat.shape[1]:
            C_s, H, W = student_feat.shape[1], student_feat.shape[2], student_feat.shape[3]
            C_t = teacher_feat.shape[1]
            
            # Use simple channel selection/averaging to match dimensions
            # This preserves spatial information better than adaptive pooling
            if C_s > C_t:
                # Reduce student channels: take first C_t channels (or average groups)
                if C_s % C_t == 0:
                    # Group and average
                    group_size = C_s // C_t
                    student_feat = student_feat.reshape(B, C_t, group_size, H, W).mean(dim=2)
                else:
                    # Take first C_t channels
                    student_feat = student_feat[:, :C_t, :, :]
            else:
                # Reduce teacher channels: same strategy
                if C_t % C_s == 0:
                    group_size = C_t // C_s
                    teacher_feat = teacher_feat.reshape(B, C_s, group_size, H, W).mean(dim=2)
                else:
                    teacher_feat = teacher_feat[:, :C_s, :, :]
        
        # Normalize features for stable training (important for KD!)
        # Use reshape instead of view to handle non-contiguous tensors
        student_feat_flat = student_feat.reshape(B, -1)
        teacher_feat_flat = teacher_feat.reshape(B, -1)
        student_feat_norm = F.normalize(student_feat_flat, dim=1).reshape_as(student_feat)
        teacher_feat_norm = F.normalize(teacher_feat_flat, dim=1).reshape_as(teacher_feat)
        
        # Compute loss
        if self.loss_type == 'mse':
            # Use normalized features for better gradient flow
            loss = self.loss_fn(student_feat_norm, teacher_feat_norm)
        elif self.loss_type == 'cosine':
            # Flatten spatial dimensions
            student_flat = student_feat_norm.reshape(B, -1)  # [B, C*H*W]
            teacher_flat = teacher_feat_norm.reshape(B, -1)  # [B, C*H*W]
            # Compute cosine similarity
            cosine_sim = F.cosine_similarity(student_flat, teacher_flat, dim=1).mean()
            loss = 1.0 - cosine_sim  # Convert similarity to distance
        
        return loss
    
    def forward(self, student_features, teacher_features):
        """
        Args:
            student_features: Dict of student features {layer_name: feature_map}
            teacher_features: Dict of teacher features {layer_name: feature_map}
        
        Returns:
            total_loss: Total feature KD loss
            loss_dict: Dict of per-layer losses
        """
        # Initialize total_loss as None, will be set to first layer_loss tensor
        total_loss = None
        loss_dict = {}
        
        # Match features at corresponding layers
        # For now, we'll match by layer index
        student_keys = sorted(student_features.keys())
        teacher_keys = sorted(teacher_features.keys())
        
        # Match layers (simple 1-to-1 mapping)
        num_layers = min(len(student_keys), len(teacher_keys))
        for i in range(num_layers):
            student_key = student_keys[i] if i < len(student_keys) else student_keys[-1]
            teacher_key = teacher_keys[i] if i < len(teacher_keys) else teacher_keys[-1]
            
            student_feat = student_features[student_key]
            teacher_feat = teacher_features[teacher_key]
            
            layer_loss = self.compute_feature_loss(student_feat, teacher_feat)
            
            # Initialize total_loss with first layer_loss (to get device and dtype)
            if total_loss is None:
                total_loss = layer_loss
            else:
                total_loss = total_loss + layer_loss
            
            # Store layer loss value (convert to float for logging)
            if isinstance(layer_loss, torch.Tensor):
                loss_dict[f'feature_kd_{i}'] = layer_loss.item()
            else:
                loss_dict[f'feature_kd_{i}'] = float(layer_loss)
        
        if total_loss is None:
            # No layers matched, return zero loss
            # Get device from first student feature if available
            if student_keys and len(student_keys) > 0:
                feat_device = student_features[student_keys[0]].device
            elif teacher_keys and len(teacher_keys) > 0:
                feat_device = teacher_features[teacher_keys[0]].device
            else:
                feat_device = torch.device('cpu')
            total_loss = torch.tensor(0.0, device=feat_device, dtype=torch.float32)
            loss_dict['feature_kd_total'] = 0.0
        else:
            total_loss = total_loss / max(num_layers, 1)  # Average over layers
            # Store total loss value
            if isinstance(total_loss, torch.Tensor):
                loss_dict['feature_kd_total'] = total_loss.item()
            else:
                loss_dict['feature_kd_total'] = float(total_loss)
        
        # Ensure lambda_feature is a tensor on the same device
        if isinstance(total_loss, torch.Tensor):
            lambda_feature = torch.tensor(self.lambda_feature, device=total_loss.device, dtype=total_loss.dtype)
            return total_loss * lambda_feature, loss_dict
        else:
            return total_loss * self.lambda_feature, loss_dict


class DistillationLoss(nn.Module):
    """
    Cross-Modal Distillation Loss with GT Supervision
    - Distillation: Learn from teacher's pseudo-labels (SILog only - scale invariant)
    - Feature KD: Learn from teacher's features (for ViT-L teacher)
    - GT Supervision: Learn from ground truth depth (Combined L1 + SILog)
    - Weighted combination of all
    """
    
    def __init__(self, loss_type='combined', lambda_l1=0.5, lambda_silog=0.5, silog_lambda=0.85,
                 use_gt_supervision=True, lambda_distill=0.5, lambda_gt=0.5,
                 use_feature_kd=False, lambda_feature_kd=1.0, feature_kd_loss_type='mse'):
        super().__init__()
        self.loss_type = loss_type
        self.lambda_l1 = lambda_l1
        self.lambda_silog = lambda_silog
        self.use_gt_supervision = use_gt_supervision
        self.lambda_distill = lambda_distill
        self.lambda_gt = lambda_gt
        self.use_feature_kd = use_feature_kd
        
        self.l1_loss = nn.L1Loss()
        self.silog_loss = SIlogLoss(lambda_scale=silog_lambda)
        
        if use_feature_kd:
            self.feature_kd_loss = FeatureKDLoss(
                loss_type=feature_kd_loss_type,
                lambda_feature=lambda_feature_kd
            )
    
    def compute_loss(self, pred, target, valid_mask=None, use_silog_only=False):
        """
        Compute loss between prediction and target
        
        Args:
            pred: Prediction [B, 1, H, W]
            target: Target [B, 1, H, W]
            valid_mask: Valid pixel mask [B, 1, H, W]
            use_silog_only: If True, only use SILog (for pseudo-labels with scale mismatch)
        """
        if valid_mask is None:
            valid_mask = (target > 0) & (target < 100)  # Reasonable depth range
        
        pred_valid = pred[valid_mask]
        target_valid = target[valid_mask]
        
        if len(pred_valid) == 0:
            return torch.tensor(0.0).to(pred.device), {}
        
        # For pseudo-labels (teacher): use SILog only (scale invariant)
        if use_silog_only:
            loss = self.silog_loss(pred_valid, target_valid)
            loss_dict = {'silog': loss.item()}
            return loss, loss_dict
        
        # For GT: use specified loss type
        if self.loss_type == 'l1':
            loss = self.l1_loss(pred_valid, target_valid)
            loss_dict = {'l1': loss.item()}
        elif self.loss_type == 'silog':
            loss = self.silog_loss(pred_valid, target_valid)
            loss_dict = {'silog': loss.item()}
        elif self.loss_type == 'combined':
            l1 = self.l1_loss(pred_valid, target_valid)
            silog = self.silog_loss(pred_valid, target_valid)
            loss = self.lambda_l1 * l1 + self.lambda_silog * silog
            loss_dict = {'l1': l1.item(), 'silog': silog.item(), 'subtotal': loss.item()}
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
        
        return loss, loss_dict
    
    def forward(self, pred, teacher_target=None, gt_target=None, 
                student_features=None, teacher_features=None):
        """
        Args:
            pred: Student prediction [B, 1, H, W]
            teacher_target: Teacher pseudo-label [B, 1, H, W] (optional, for prediction-level KD)
            gt_target: Ground truth depth [B, 1, H, W] (optional)
            student_features: Dict of student features (optional, for feature-level KD)
            teacher_features: Dict of teacher features (optional, for feature-level KD)
        """
        # Initialize total_loss as None, will be set to first loss tensor
        total_loss = None
        loss_dict = {}
        
        # Feature-level KD loss (for ViT-L teacher)
        if self.use_feature_kd and student_features is not None and teacher_features is not None:
            feature_kd_loss, feature_kd_dict = self.feature_kd_loss(student_features, teacher_features)
            # Initialize total_loss with first loss tensor
            if total_loss is None:
                total_loss = feature_kd_loss
            else:
                total_loss = total_loss + feature_kd_loss
            loss_dict.update(feature_kd_dict)
        
        # Distillation loss (teacher pseudo-label) - SILog only (scale invariant)
        if teacher_target is not None:
            distill_loss, distill_dict = self.compute_loss(pred, teacher_target, use_silog_only=True)
            # Convert lambda to tensor if needed
            if isinstance(distill_loss, torch.Tensor):
                lambda_distill = torch.tensor(self.lambda_distill, device=distill_loss.device, dtype=distill_loss.dtype)
            else:
                lambda_distill = self.lambda_distill
            
            if total_loss is None:
                total_loss = lambda_distill * distill_loss
            else:
                total_loss = total_loss + lambda_distill * distill_loss
            loss_dict['distill'] = distill_loss.item() if isinstance(distill_loss, torch.Tensor) else float(distill_loss)
            for k, v in distill_dict.items():
                loss_dict[f'distill_{k}'] = v
        
        # GT supervision loss - Combined loss (L1 + SILog)
        if self.use_gt_supervision and gt_target is not None:
            gt_loss, gt_dict = self.compute_loss(pred, gt_target, use_silog_only=False)
            # Convert lambda to tensor if needed
            if isinstance(gt_loss, torch.Tensor):
                lambda_gt = torch.tensor(self.lambda_gt, device=gt_loss.device, dtype=gt_loss.dtype)
            else:
                lambda_gt = self.lambda_gt
            
            if total_loss is None:
                total_loss = lambda_gt * gt_loss
            else:
                total_loss = total_loss + lambda_gt * gt_loss
            loss_dict['gt'] = gt_loss.item() if isinstance(gt_loss, torch.Tensor) else float(gt_loss)
            for k, v in gt_dict.items():
                loss_dict[f'gt_{k}'] = v
        
        # If no losses were computed, return zero loss
        if total_loss is None:
            total_loss = torch.tensor(0.0, device=pred.device, dtype=pred.dtype)
            loss_dict['total'] = 0.0
        else:
            loss_dict['total'] = total_loss.item()
        
        return total_loss, loss_dict

# Create loss function
distill_loss_fn = DistillationLoss(
    loss_type=DISTILLATION_LOSS,
    lambda_l1=LAMBDA_L1,
    lambda_silog=LAMBDA_SILOG,
    silog_lambda=SILOG_LAMBDA,
    use_gt_supervision=USE_GT_SUPERVISION,
    lambda_distill=LAMBDA_DISTILL,
    lambda_gt=LAMBDA_GT,
    use_feature_kd=USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature',
    lambda_feature_kd=FEATURE_KD_LAMBDA,
    feature_kd_loss_type=FEATURE_KD_LOSS_TYPE
).to(device)

print("\n" + "="*80)
print("Distillation Loss Function")
print("="*80)
print(f"Loss type: {DISTILLATION_LOSS}")
if DISTILLATION_LOSS == 'combined':
    print(f"  Œª_L1: {LAMBDA_L1}")
    print(f"  Œª_SIlog: {LAMBDA_SILOG}")
    print(f"  SIlog_Œª: {SILOG_LAMBDA}")
print("")
print("Loss Strategy:")
if USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature':
    print(f"  üéØ Feature-level KD: {FEATURE_KD_LOSS_TYPE.upper()} loss (Œª={FEATURE_KD_LAMBDA})")
    print(f"    Layers: {FEATURE_KD_LAYERS}")
    print(f"  üéØ GT supervision: {DISTILLATION_LOSS.upper()} loss (Œª={LAMBDA_GT})")
else:
    print(f"  üéØ Distillation (pseudo-label): SILog only (scale invariant, Œª={LAMBDA_DISTILL})")
    print(f"  üéØ GT supervision: {DISTILLATION_LOSS.upper()} loss (Œª={LAMBDA_GT})")
if USE_GT_SUPERVISION:
    print(f"  Œª_GT (ground truth): {LAMBDA_GT}")
else:
    print(f"GT Supervision: Disabled")
print("‚úÖ Loss function created!")



Distillation Loss Function
Loss type: combined
  Œª_L1: 0.5
  Œª_SIlog: 0.5
  SIlog_Œª: 0.85

Loss Strategy:
  üéØ Feature-level KD: COSINE loss (Œª=2.0)
    Layers: [6, 12, 18, 24]
  üéØ GT supervision: COMBINED loss (Œª=1.0)
  Œª_GT (ground truth): 1.0
‚úÖ Loss function created!


# ========== Feature Extraction & Loss Calculation Check ==========
if USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature' and teacher_model is not None and student_model is not None:
    print("\n" + "="*80)
    print("üîç Feature Extraction & Loss Calculation Check")
    print("="*80)
    
    # Create test data
    test_img = torch.randn(1, 3, 256, 256).to(device)
    test_audio = torch.randn(1, 2, 256, 256).to(device)
    test_depth_gt = torch.randn(1, 1, 256, 256).to(device) * 10.0 + 5.0  # Reasonable depth range
    
    print("\n1Ô∏è‚É£ Teacher Feature Extraction Check:")
    print(f"   Expected feature layers: {FEATURE_KD_LAYERS}")
    teacher_model.eval()
    with torch.no_grad():
        teacher_features = teacher_model(test_img, return_features=True)
    print(f"   ‚úÖ Teacher features extracted: {list(teacher_features.keys())}")
    for k, v in teacher_features.items():
        print(f"      {k}: shape {v.shape}, range [{v.min():.3f}, {v.max():.3f}]")
    
    print("\n2Ô∏è‚É£ Student Feature Extraction Check:")
    print(f"   Expected feature layers: {FEATURE_KD_LAYERS}")
    student_model.eval()
    with torch.no_grad():
        student_output, student_features = student_model(test_audio, return_features=True)
    print(f"   ‚úÖ Student features extracted: {list(student_features.keys())}")
    for k, v in student_features.items():
        print(f"      {k}: shape {v.shape}, range [{v.min():.3f}, {v.max():.3f}]")
    
    print("\n3Ô∏è‚É£ Feature Matching Check:")
    student_keys = sorted(student_features.keys())
    teacher_keys = sorted(teacher_features.keys())
    print(f"   Student keys: {student_keys}")
    print(f"   Teacher keys: {teacher_keys}")
    
    matched_layers = []
    for layer_idx in FEATURE_KD_LAYERS:
        student_key = f'layer_{layer_idx}'
        teacher_key = f'layer_{layer_idx}'
        if student_key in student_features and teacher_key in teacher_features:
            matched_layers.append((student_key, teacher_key))
            print(f"   ‚úÖ Layer {layer_idx}: Matched!")
            s_feat = student_features[student_key]
            t_feat = teacher_features[teacher_key]
            print(f"      Student: {s_feat.shape}, Teacher: {t_feat.shape}")
        else:
            print(f"   ‚ùå Layer {layer_idx}: NOT MATCHED!")
            if student_key not in student_features:
                print(f"      Student key '{student_key}' not found!")
            if teacher_key not in teacher_features:
                print(f"      Teacher key '{teacher_key}' not found!")
    
    print(f"\n   Total matched layers: {len(matched_layers)}/{len(FEATURE_KD_LAYERS)}")
    
    print("\n4Ô∏è‚É£ Feature KD Loss Calculation Check:")
    if USE_FEATURE_KD:
        feature_kd_loss_fn = FeatureKDLoss(
            loss_type=FEATURE_KD_LOSS_TYPE,
            lambda_feature=FEATURE_KD_LAMBDA
        ).to(device)
        
        feature_kd_loss, feature_kd_dict = feature_kd_loss_fn(student_features, teacher_features)
        print(f"   ‚úÖ Feature KD loss calculated: {feature_kd_loss.item():.6f}")
        print(f"   Loss breakdown:")
        for k, v in feature_kd_dict.items():
            print(f"      {k}: {v:.6f}")
        
        if feature_kd_loss.item() == 0.0:
            print("   ‚ö†Ô∏è  WARNING: Feature KD loss is 0! Features may not be matched properly.")
        elif feature_kd_loss.item() > 100:
            print("   ‚ö†Ô∏è  WARNING: Feature KD loss is very large! May need normalization adjustment.")
        else:
            print("   ‚úÖ Feature KD loss is in reasonable range.")
    
    print("\n5Ô∏è‚É£ Full Distillation Loss Check:")
    # Forward pass
    student_model.train()
    student_output, student_features = student_model(test_audio, return_features=True)
    
    with torch.no_grad():
        teacher_features = teacher_model(test_img, return_features=True)
    
    total_loss, loss_dict = distill_loss_fn(
        pred=student_output,
        teacher_target=None,
        gt_target=test_depth_gt,
        student_features=student_features,
        teacher_features=teacher_features
    )
    
    print(f"   ‚úÖ Total loss: {total_loss.item():.6f}")
    print(f"   Loss breakdown:")
    for k, v in loss_dict.items():
        if isinstance(v, (int, float)):
            print(f"      {k}: {v:.6f}")
        else:
            print(f"      {k}: {v}")
    
    print("\n6Ô∏è‚É£ Loss Weight Summary:")
    print(f"   Œª_feature_KD: {FEATURE_KD_LAMBDA}")
    print(f"   Œª_GT: {LAMBDA_GT}")
    print(f"   Œª_distill: {LAMBDA_DISTILL}")
    if 'feature_kd_total' in loss_dict:
        feature_kd_contribution = loss_dict['feature_kd_total'] * FEATURE_KD_LAMBDA
        gt_contribution = loss_dict.get('gt', 0) * LAMBDA_GT
        print(f"   Feature KD contribution: {feature_kd_contribution:.6f}")
        print(f"   GT contribution: {gt_contribution:.6f}")
        if feature_kd_contribution > 0:
            ratio = gt_contribution / feature_kd_contribution if feature_kd_contribution > 0 else float('inf')
            print(f"   GT/Feature_KD ratio: {ratio:.3f}")
    
    print("\n" + "="*80)
    print("‚úÖ Feature Extraction & Loss Calculation Check Complete!")
    print("="*80)
else:
    print("‚ö†Ô∏è  Feature KD check skipped (not using feature KD or models not loaded)")

## üöÄ Depth Any Audio Training Loop



In [None]:
# ========== Feature Extraction & Loss Calculation Check ==========
if USE_FEATURE_KD and TEACHER_MODEL == 'vitl_feature' and teacher_model is not None and student_model is not None:
    print("\n" + "="*80)
    print("üîç Feature Extraction & Loss Calculation Check")
    print("="*80)
    
    # Create test data
    test_img = torch.randn(1, 3, 256, 256).to(device)
    test_audio = torch.randn(1, 2, 256, 256).to(device)
    test_depth_gt = torch.randn(1, 1, 256, 256).to(device) * 10.0 + 5.0  # Reasonable depth range
    
    print("\n1Ô∏è‚É£ Teacher Feature Extraction Check:")
    print(f"   Expected feature layers: {FEATURE_KD_LAYERS}")
    teacher_model.eval()
    with torch.no_grad():
        teacher_features = teacher_model(test_img, return_features=True)
    print(f"   ‚úÖ Teacher features extracted: {list(teacher_features.keys())}")
    for k, v in teacher_features.items():
        print(f"      {k}: shape {v.shape}, range [{v.min():.3f}, {v.max():.3f}]")
    
    print("\n2Ô∏è‚É£ Student Feature Extraction Check:")
    print(f"   Expected feature layers: {FEATURE_KD_LAYERS}")
    student_model.eval()
    with torch.no_grad():
        student_output, student_features = student_model(test_audio, return_features=True)
    print(f"   ‚úÖ Student features extracted: {list(student_features.keys())}")
    for k, v in student_features.items():
        print(f"      {k}: shape {v.shape}, range [{v.min():.3f}, {v.max():.3f}]")
    
    print("\n3Ô∏è‚É£ Feature Matching Check:")
    student_keys = sorted(student_features.keys())
    teacher_keys = sorted(teacher_features.keys())
    print(f"   Student keys: {student_keys}")
    print(f"   Teacher keys: {teacher_keys}")
    
    matched_layers = []
    for layer_idx in FEATURE_KD_LAYERS:
        student_key = f'layer_{layer_idx}'
        teacher_key = f'layer_{layer_idx}'
        if student_key in student_features and teacher_key in teacher_features:
            matched_layers.append((student_key, teacher_key))
            print(f"   ‚úÖ Layer {layer_idx}: Matched!")
            s_feat = student_features[student_key]
            t_feat = teacher_features[teacher_key]
            print(f"      Student: {s_feat.shape}, Teacher: {t_feat.shape}")
        else:
            print(f"   ‚ùå Layer {layer_idx}: NOT MATCHED!")
            if student_key not in student_features:
                print(f"      Student key '{student_key}' not found!")
            if teacher_key not in teacher_features:
                print(f"      Teacher key '{teacher_key}' not found!")
    
    print(f"\n   Total matched layers: {len(matched_layers)}/{len(FEATURE_KD_LAYERS)}")
    
    print("\n4Ô∏è‚É£ Feature KD Loss Calculation Check:")
    if USE_FEATURE_KD:
        feature_kd_loss_fn = FeatureKDLoss(
            loss_type=FEATURE_KD_LOSS_TYPE,
            lambda_feature=FEATURE_KD_LAMBDA
        ).to(device)
        
        feature_kd_loss, feature_kd_dict = feature_kd_loss_fn(student_features, teacher_features)
        print(f"   ‚úÖ Feature KD loss calculated: {feature_kd_loss.item():.6f}")
        print(f"   Loss breakdown:")
        for k, v in feature_kd_dict.items():
            print(f"      {k}: {v:.6f}")
        
        if feature_kd_loss.item() == 0.0:
            print("   ‚ö†Ô∏è  WARNING: Feature KD loss is 0! Features may not be matched properly.")
        elif feature_kd_loss.item() > 100:
            print("   ‚ö†Ô∏è  WARNING: Feature KD loss is very large! May need normalization adjustment.")
        else:
            print("   ‚úÖ Feature KD loss is in reasonable range.")
    
    print("\n5Ô∏è‚É£ Full Distillation Loss Check:")
    # Forward pass
    student_model.train()
    student_output, student_features = student_model(test_audio, return_features=True)
    
    with torch.no_grad():
        teacher_features = teacher_model(test_img, return_features=True)
    
    total_loss, loss_dict = distill_loss_fn(
        pred=student_output,
        teacher_target=None,
        gt_target=test_depth_gt,
        student_features=student_features,
        teacher_features=teacher_features
    )
    
    print(f"   ‚úÖ Total loss: {total_loss.item():.6f}")
    print(f"   Loss breakdown:")
    for k, v in loss_dict.items():
        if isinstance(v, (int, float)):
            print(f"      {k}: {v:.6f}")
        else:
            print(f"      {k}: {v}")
    
    print("\n6Ô∏è‚É£ Loss Weight Summary:")
    print(f"   Œª_feature_KD: {FEATURE_KD_LAMBDA}")
    print(f"   Œª_GT: {LAMBDA_GT}")
    print(f"   Œª_distill: {LAMBDA_DISTILL}")
    if 'feature_kd_total' in loss_dict:
        feature_kd_contribution = loss_dict['feature_kd_total'] * FEATURE_KD_LAMBDA
        gt_contribution = loss_dict.get('gt', 0) * LAMBDA_GT
        print(f"   Feature KD contribution: {feature_kd_contribution:.6f}")
        print(f"   GT contribution: {gt_contribution:.6f}")
        if feature_kd_contribution > 0:
            ratio = gt_contribution / feature_kd_contribution if feature_kd_contribution > 0 else float('inf')
            print(f"   GT/Feature_KD ratio: {ratio:.3f}")
    
    print("\n" + "="*80)
    print("‚úÖ Feature Extraction & Loss Calculation Check Complete!")
    print("="*80)
else:
    print("‚ö†Ô∏è  Feature KD check skipped (not using feature KD or models not loaded)")

## üöÄ Depth Any Audio Training Loop



In [None]:
# Create dataloaders
print("\n" + "="*80)
print("üì¶ Creating Dataloaders")
print("="*80)

# Determine locations based on USE_ALL_LOCATIONS
if USE_ALL_LOCATIONS:
    print(f"Loading ALL locations for distillation... ({len(locations)} locations)")
    train_locations = locations  # Use predefined locations list
    val_locations = locations
else:
    print(f"Loading single location: {LOCATION}")
    train_locations = LOCATION
    val_locations = LOCATION

# Create datasets (now supports multiple locations internally)
train_dataset = DepthAnyAudioDataset(
    root_dir=ROOT_DIR,
    locations=train_locations,
    split='train',
    max_depth=MAX_DEPTH,
    img_size=IMG_SIZE
)

val_dataset = DepthAnyAudioDataset(
    root_dir=ROOT_DIR,
    locations=val_locations,
    split='val',
    max_depth=MAX_DEPTH,
    img_size=IMG_SIZE
)

train_loader = DataLoader(train_dataset, batch_size=DISTILL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=False)
val_loader = DataLoader(val_dataset, batch_size=DISTILL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=False)

print(f"\n‚úÖ Dataloaders created:")
print(f"  Train: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f"  Val: {len(val_dataset)} samples, {len(val_loader)} batches")

# Optimizer and scheduler
optimizer = torch.optim.AdamW(student_model.parameters(), lr=DISTILL_LR, weight_decay=DISTILL_WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=DISTILL_EPOCHS)

print(f"\n‚úÖ Training setup complete!")
print(f"  Optimizer: AdamW (lr={DISTILL_LR}, wd={DISTILL_WEIGHT_DECAY})")
print(f"  Scheduler: CosineAnnealingLR")



üì¶ Creating Dataloaders
Loading ALL locations for distillation... (7 locations)
Loaded 431 samples from 2ndFloorLuxembourg (train)
Loaded 290 samples from 3rd_Floor_Luxembourg (train)
Loaded 37 samples from Attic (train)
Loaded 377 samples from Outdoor_Cobblestone_Path (train)
Loaded 116 samples from Salle_Chevalier (train)
Loaded 240 samples from Salle_des_Colonnes (train)
Loaded 420 samples from V119_Cake_Corridors (train)
‚úÖ Total: 1911 samples from 7 location(s)
Loaded 132 samples from 2ndFloorLuxembourg (val)
Loaded 98 samples from 3rd_Floor_Luxembourg (val)
Loaded 12 samples from Attic (val)
Loaded 125 samples from Outdoor_Cobblestone_Path (val)
Loaded 40 samples from Salle_Chevalier (val)
Loaded 80 samples from Salle_des_Colonnes (val)
Loaded 138 samples from V119_Cake_Corridors (val)
‚úÖ Total: 625 samples from 7 location(s)

‚úÖ Dataloaders created:
  Train: 1911 samples, 956 batches
  Val: 625 samples, 313 batches

‚úÖ Training setup complete!
  Optimizer: AdamW (lr=0.000

In [None]:
def normalize_depth_to_gt(teacher_depth, max_depth=30.0):
    """
    Normalize teacher depth to match GT scale [0, max_depth]
    Using min-max normalization per batch
    
    Args:
        teacher_depth: [B, 1, H, W] teacher depth predictions
        max_depth: Maximum depth value (default: 30.0m)
    
    Returns:
        Normalized depth in range [0, max_depth]
    """
    normalized = torch.zeros_like(teacher_depth)
    for b in range(teacher_depth.shape[0]):
        t_min = teacher_depth[b].min()
        t_max = teacher_depth[b].max()
        if t_max > t_min:
            # Normalize to [0, max_depth]
            normalized[b] = (teacher_depth[b] - t_min) / (t_max - t_min) * max_depth
        else:
            # If constant depth, set to 0
            normalized[b] = 0.0
    return normalized

def train_depth_any_audio(teacher, student, train_loader, val_loader, 
                          optimizer, scheduler, loss_fn, 
                          num_epochs, device, max_depth=30.0, print_every=5,
                          use_feature_kd=False, teacher_model_type='depthanything_v2_vitl',
                          grad_accum=1, use_mixed_precision=False):
    """
    Depth Any Audio Training
    - With distillation: Teacher generates proxy depth labels from RGB + GT supervision
    - With feature KD: Teacher extracts features from spectrogram + GT supervision
    - Without distillation: Only GT supervision (supervised learning)
    
    Teacher depth is normalized to GT scale using min-max normalization
    """
    
    use_teacher = teacher is not None
    use_feature_level_kd = use_feature_kd and teacher_model_type == 'vitl_feature'
    
    print("\n" + "="*80)
    if use_teacher:
        if use_feature_level_kd:
            print("üéµ Starting Depth Any Audio Training (Feature-Level KD)")
            print("  ‚úì Teacher: ViT-L (feature extraction from spectrogram)")
            print("  ‚úì Student: Audio U-Net (feature-level knowledge distillation)")
        else:
            print("üéµ Starting Depth Any Audio Training (Cross-Modal Distillation)")
            print("  ‚úì Teacher depth normalization: Enabled (min-max to GT scale)")
    else:
        print("üìö Starting Supervised Training (GT only)")
    print("="*80)
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_rmse': [], 'val_delta1': []}
    
    # Mixed precision scaler (fix deprecation warning)
    if use_mixed_precision:
        try:
            scaler = torch.amp.GradScaler('cuda')
        except AttributeError:
            # Fallback for older PyTorch versions
            scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None
    
    if grad_accum > 1:
        print(f"  üìä Gradient accumulation: {grad_accum} steps (effective batch size: {DISTILL_BATCH_SIZE * grad_accum})")
    if use_mixed_precision:
        print(f"  ‚ö° Mixed precision training: Enabled")
    
    for epoch in range(1, num_epochs + 1):
        # ========== Training ==========
        student.train()
        if use_teacher:
            teacher.eval()
        
        train_losses = []
        optimizer.zero_grad()  # Initialize gradients
        
        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f'Epoch {epoch}/{num_epochs} [Train]', leave=False)):
            images = batch['image'].to(device) if use_teacher else None
            audios = batch['audio'].to(device)
            depth_gt = batch['depth_gt'].to(device)
            
            # Step 1: Teacher processing (RGB Ïù¥ÎØ∏ÏßÄ ÏûÖÎ†•)
            # Clear cache before teacher forward
            torch.cuda.empty_cache()
            
            if use_teacher:
                with torch.no_grad():
                    if use_feature_level_kd:
                        # Feature-level KD: Teacher extracts features from RGB images
                        # Use autocast for teacher too to save memory
                        with torch.cuda.amp.autocast(enabled=use_mixed_precision):
                            teacher_features = teacher(images, return_features=True)  # Dict of features from RGB
                        teacher_depth = None
                    else:
                        # Prediction-level KD: Teacher generates depth from RGB
                        with torch.cuda.amp.autocast(enabled=use_mixed_precision):
                            teacher_depth = teacher(images)  # [B, 1, H, W]
                        # Normalize teacher depth to GT scale [0, max_depth]
                        teacher_depth = normalize_depth_to_gt(teacher_depth, max_depth)
                        teacher_features = None
            else:
                teacher_depth = None
                teacher_features = None
            
            # Step 2: Student predicts depth from audio (Ïä§ÌéôÌä∏Î°úÍ∑∏Îû® ÏûÖÎ†•)
            # Use autocast for mixed precision
            with torch.cuda.amp.autocast(enabled=use_mixed_precision):
                if use_feature_level_kd:
                    # Student also returns features for feature-level KD
                    student_depth, student_features = student(audios, return_features=True)
                else:
                    student_depth = student(audios)  # [B, 1, H, W]
                    student_features = None
                
                # Step 3: Compute loss
                if use_feature_level_kd:
                    # Feature-level KD: Use teacher and student features
                    loss, loss_dict = loss_fn(
                        pred=student_depth,
                        teacher_target=None,  # No prediction-level KD
                        gt_target=depth_gt,
                        student_features=student_features,
                        teacher_features=teacher_features
                    )
                else:
                    # Prediction-level KD or supervised: Use teacher depth prediction
                    loss, loss_dict = loss_fn(
                        pred=student_depth,
                        teacher_target=teacher_depth,
                        gt_target=depth_gt,
                        student_features=None,
                        teacher_features=None
                    )
            
            # Clear cache more frequently to free memory
            if batch_idx % 5 == 0:
                torch.cuda.empty_cache()
            
            # Delete intermediate variables to free memory
            del student_depth
            if student_features is not None:
                del student_features
            if teacher_features is not None:
                del teacher_features
            if teacher_depth is not None:
                del teacher_depth
            
            # Step 4: Backprop and optimize (with gradient accumulation and mixed precision)
            # Scale loss for gradient accumulation
            loss = loss / grad_accum
            
            if use_mixed_precision:
                # Mixed precision training
                scaler.scale(loss).backward()
                
                # Update weights every grad_accum steps
                if (batch_idx + 1) % grad_accum == 0 or (batch_idx + 1) == len(train_loader):
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                # Standard training
                loss.backward()
                
                # Update weights every grad_accum steps
                if (batch_idx + 1) % grad_accum == 0 or (batch_idx + 1) == len(train_loader):
                    torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
                    optimizer.step()
                    optimizer.zero_grad()
            
            train_losses.append(loss.item() * grad_accum)  # Scale back for logging
        
        train_loss = np.mean(train_losses)
        history['train_loss'].append(train_loss)
        
        # ========== Validation ==========
        student.eval()
        val_losses = []
        val_errors = []
        
        # Clear cache before validation
        torch.cuda.empty_cache()
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f'Epoch {epoch}/{num_epochs} [Val]', leave=False):
                images = batch['image'].to(device) if use_teacher else None
                audios = batch['audio'].to(device)
                depth_gt = batch['depth_gt'].to(device)
                
                # Teacher processing (RGB Ïù¥ÎØ∏ÏßÄ ÏûÖÎ†•)
                if use_teacher:
                    with torch.cuda.amp.autocast(enabled=use_mixed_precision):
                        if use_feature_level_kd:
                            teacher_features = teacher(images, return_features=True)  # RGBÏóêÏÑú feature Ï∂îÏ∂ú
                            teacher_depth = None
                        else:
                            teacher_depth = teacher(images)
                            teacher_depth = normalize_depth_to_gt(teacher_depth, max_depth)
                            teacher_features = None
                else:
                    teacher_depth = None
                    teacher_features = None
                
                # Student prediction (Audio Ïä§ÌéôÌä∏Î°úÍ∑∏Îû® ÏûÖÎ†•)
                with torch.cuda.amp.autocast(enabled=use_mixed_precision):
                    if use_feature_level_kd:
                        student_depth, student_features = student(audios, return_features=True)
                    else:
                        student_depth = student(audios)
                        student_features = None
                
                # Clear cache periodically during validation
                torch.cuda.empty_cache()
                
                # Compute loss
                if use_feature_level_kd:
                    loss, _ = loss_fn(
                        pred=student_depth,
                        teacher_target=None,
                        gt_target=depth_gt,
                        student_features=student_features,
                        teacher_features=teacher_features
                    )
                else:
                    loss, _ = loss_fn(
                        pred=student_depth,
                        teacher_target=teacher_depth,
                        gt_target=depth_gt,
                        student_features=None,
                        teacher_features=None
                    )
                val_losses.append(loss.item())
                
                # Metrics against GT (if available)
                if depth_gt.sum() > 0:
                    for i in range(student_depth.shape[0]):
                        pred = student_depth[i, 0].cpu().numpy()
                        gt = depth_gt[i, 0].cpu().numpy()
                        
                        if gt.max() > 0:
                            errors = compute_errors(gt, pred, min_depth_threshold=0.1)
                            val_errors.append(errors)
        
        val_loss = np.mean(val_losses)
        history['val_loss'].append(val_loss)
        
        # Compute metrics
        if len(val_errors) > 0:
            mean_errors = np.array(val_errors).mean(0)
            abs_rel, rmse, delta1 = mean_errors[0], mean_errors[1], mean_errors[2]
            history['val_rmse'].append(rmse)
            history['val_delta1'].append(delta1)
        else:
            abs_rel, rmse, delta1 = 0, 0, 0
        
        # Learning rate step
        scheduler.step()
        
        # Print progress
        if epoch % print_every == 0 or epoch == 1:
            print(f"\nEpoch [{epoch}/{num_epochs}]")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss:   {val_loss:.4f}")
            if len(val_errors) > 0:
                print(f"  RMSE:       {rmse:.3f}m")
                print(f"  ABS_REL:    {abs_rel:.4f}")
                print(f"  Delta1:     {delta1:.4f}")
            print(f"  LR:         {scheduler.get_last_lr()[0]:.6f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'student_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, 'best_depth_any_audio.pth')
            print(f"  üéØ Best model saved! (val_loss={val_loss:.4f})")
    
    print("\n" + "="*80)
    print("‚úÖ Training Complete!")
    print("="*80)
    print(f"Best Val Loss: {best_val_loss:.4f}")
    
    return history



In [None]:
# Start training

history = train_depth_any_audio(
    teacher=teacher_model,
    student=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=distill_loss_fn,
    num_epochs=DISTILL_EPOCHS,
    device=device,
    max_depth=MAX_DEPTH,
    print_every=PRINT_EVERY if 'PRINT_EVERY' in globals() else 5,
    use_feature_kd=USE_FEATURE_KD,
    teacher_model_type=TEACHER_MODEL,
    grad_accum=DISTILL_GRAD_ACCUM if 'DISTILL_GRAD_ACCUM' in globals() else 1,
    use_mixed_precision=USE_MIXED_PRECISION if 'USE_MIXED_PRECISION' in globals() else False
)



üéµ Starting Depth Any Audio Training (Feature-Level KD)
  ‚úì Teacher: ViT-L (feature extraction from spectrogram)
  ‚úì Student: Audio U-Net (feature-level knowledge distillation)
  üìä Gradient accumulation: 8 steps (effective batch size: 16)
  ‚ö° Mixed precision training: Enabled


  with torch.cuda.amp.autocast(enabled=use_mixed_precision):
  with torch.cuda.amp.autocast(enabled=use_mixed_precision):
  with torch.cuda.amp.autocast(enabled=use_mixed_precision):
  with torch.cuda.amp.autocast(enabled=use_mixed_precision):
                                                                    


Epoch [1/100]
  Train Loss: 2.5904
  Val Loss:   2.2294
  RMSE:       2.710m
  ABS_REL:    0.5560
  Delta1:     0.3668
  LR:         0.000100
  üéØ Best model saved! (val_loss=2.2294)


                                                                      

  üéØ Best model saved! (val_loss=2.1994)


                                                                      

  üéØ Best model saved! (val_loss=2.1026)


                                                                      


Epoch [5/100]
  Train Loss: 1.8569
  Val Loss:   2.0978
  RMSE:       2.524m
  ABS_REL:    0.5167
  Delta1:     0.3935
  LR:         0.000099
  üéØ Best model saved! (val_loss=2.0978)


                                                                      

  üéØ Best model saved! (val_loss=2.0359)


                                                                      

  üéØ Best model saved! (val_loss=1.9718)


                                                                       


Epoch [10/100]
  Train Loss: 1.6072
  Val Loss:   2.0033
  RMSE:       2.401m
  ABS_REL:    0.5154
  Delta1:     0.4428
  LR:         0.000098


                                                                       

  üéØ Best model saved! (val_loss=1.9629)


                                                                       


Epoch [15/100]
  Train Loss: 1.3849
  Val Loss:   1.9585
  RMSE:       2.321m
  ABS_REL:    0.4465
  Delta1:     0.4591
  LR:         0.000095
  üéØ Best model saved! (val_loss=1.9585)


                                                                       


Epoch [20/100]
  Train Loss: 1.1966
  Val Loss:   2.0264
  RMSE:       2.367m
  ABS_REL:    0.5258
  Delta1:     0.4481
  LR:         0.000090


                                                                       


Epoch [25/100]
  Train Loss: 1.0335
  Val Loss:   2.0191
  RMSE:       2.317m
  ABS_REL:    0.4600
  Delta1:     0.4736
  LR:         0.000085


Epoch 26/100 [Train]:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 857/956 [01:51<00:13,  7.31it/s]

## üìä Visualization & Evaluation


In [None]:
# Visualize predictions
print("\n" + "="*80)
print("üñºÔ∏è  Visualizing Predictions")
print("="*80)

student_model.eval()
teacher_model.eval()

with torch.no_grad():
    # Get a batch
    batch = next(iter(val_loader))
    images = batch['image'][:4].to(device)
    audios = batch['audio'][:4].to(device)
    depth_gt = batch['depth_gt'][:4]
    
    # Teacher prediction (from RGB)
    teacher_pred = teacher_model(images).cpu()
    
    # Student prediction (from Audio)
    student_pred = student_model(audios).cpu()
    
    # Plot
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    
    for i in range(4):
        # RGB Image
        axes[i, 0].imshow(images[i].cpu().permute(1, 2, 0))
        axes[i, 0].set_title('RGB Input (Teacher)')
        axes[i, 0].axis('off')
        
        # Teacher Depth
        axes[i, 1].imshow(teacher_pred[i, 0], cmap='magma', vmin=0, vmax=MAX_DEPTH)
        axes[i, 1].set_title('Teacher Depth (VFM)')
        axes[i, 1].axis('off')
        
        # Student Depth
        axes[i, 2].imshow(student_pred[i, 0], cmap='magma', vmin=0, vmax=MAX_DEPTH)
        axes[i, 2].set_title('Student Depth (Audio)')
        axes[i, 2].axis('off')
        
        # Ground Truth
        axes[i, 3].imshow(depth_gt[i, 0], cmap='magma', vmin=0, vmax=MAX_DEPTH)
        axes[i, 3].set_title('Ground Truth')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    plt.savefig('depth_any_audio_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

print("\n‚úÖ Predictions saved to: depth_any_audio_predictions.png")


In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curves
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# RMSE
if len(history['val_rmse']) > 0:
    axes[1].plot(history['val_rmse'], label='RMSE', color='orange', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('RMSE (meters)')
    axes[1].set_title('Validation RMSE')
    axes[1].grid(True, alpha=0.3)
    axes[1].legend()

# Delta1
if len(history['val_delta1']) > 0:
    axes[2].plot(history['val_delta1'], label='Delta1', color='green', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Delta1 (accuracy)')
    axes[2].set_title('Validation Delta1')
    axes[2].grid(True, alpha=0.3)
    axes[2].legend()

plt.tight_layout()
plt.savefig('depth_any_audio_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n‚úÖ Training curves saved to: depth_any_audio_training_curves.png")


## üìù Summary & Next Steps

### üéØ What We Built: **Depth Any Audio**

Inspired by [Depth AnyEvent (ICCV 2025)](https://github.com/bartn8/depthanyevent), we implemented cross-modal distillation for audio-based depth estimation:

#### Architecture:
- **Teacher**: Vision Foundation Model (Depth Anything V2) - processes RGB images
- **Student**: Audio U-Net - processes binaural audio spectrograms
- **Training**: Teacher generates proxy depth labels ‚Üí Student learns from audio

#### Key Innovation:
- **No depth annotations required** during training!
- Teacher VFM provides supervision from RGB images
- Student learns depth estimation from spatially-aligned binaural audio

#### Advantages:
1. ‚úÖ Leverages powerful VFMs trained on large-scale image data
2. ‚úÖ Eliminates need for expensive depth sensors during training
3. ‚úÖ Cross-modal knowledge transfer (vision ‚Üí audio)
4. ‚úÖ Works with BatvisionV2's naturally aligned RGB-Audio-Depth data

### üìä Expected Results:
- **Teacher (RGB)**: Near-perfect depth estimation (VFM)
- **Student (Audio)**: Competitive depth estimation without depth supervision
- **Gap**: Student < Teacher, but better than random initialization

### üöÄ Next Steps:
1. Train with full dataset (all locations)
2. Compare against supervised baseline (GT depth labels)
3. Add data augmentation (audio + image)
4. Try different teacher models (DINOv2, SAM, etc.)
5. Implement recurrent architecture (like Depth AnyEvent's RNN)

### üìö References:
- [Depth AnyEvent (ICCV 2025)](https://github.com/bartn8/depthanyevent)
- [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2)
- [BatVision Dataset (IROS 2023)](https://amandinebtto.github.io/Batvision-Dataset/)
