In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
from pathlib import Path
import h5py
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Environment report: confirm the imports above ran and show which PyTorch
# build and compute device this kernel will use.
print("✓ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device index 0 is queried for its name.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU (training will be slower)")

In [None]:
import os
from pathlib import Path
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

def process_hypersim_scene(scene_path, output_dir='processed_data'):
    """Convert a HyperSim-style scene folder into paired RGB / depth PNG files.

    Every immediate sub-directory of ``scene_path`` is searched for
    ``frame.XXXX.depth_meters.png`` files. A depth frame is used only when the
    matching ``frame.XXXX.color.jpg`` exists in the same sub-directory.

    Depth handling: NaN / -inf become 0 and +inf becomes 50; if the maximum
    value exceeds 1000 the map is divided by 1000 (treated as millimetres);
    the result is clipped to [0.1, 50.0] and stored as a 16-bit PNG where
    0 encodes 0.1 and 65535 encodes 50.0.

    Args:
        scene_path: Directory containing one sub-directory per frame set.
        output_dir: Destination root; ``rgb/`` and ``depth/`` are created
            inside it if missing.

    Returns:
        List of output file names (``<set>_frame_<num>.png``) written to both
        ``rgb/`` and ``depth/``. Empty when nothing could be processed.
        Directories and frames are visited in sorted order, so the list is
        the same on every run and on every filesystem.
    """
    # Depth range (same units as the cleaned depth map) used for clipping
    # and for the 16-bit encoding.
    depth_min, depth_max = 0.1, 50.0

    scene_path = Path(scene_path)
    output_dir = Path(output_dir)
    rgb_dir = output_dir / 'rgb'
    depth_dir = output_dir / 'depth'

    # Create output directories
    rgb_dir.mkdir(parents=True, exist_ok=True)
    depth_dir.mkdir(parents=True, exist_ok=True)

    print(f"Processing scene: {scene_path}")
    print(f"Output directory: {output_dir}")

    # Path.iterdir() yields entries in arbitrary order; sorting keeps the
    # returned sample list (and any seeded split built from it) reproducible.
    set_dirs = sorted(d for d in scene_path.iterdir() if d.is_dir())

    if not set_dirs:
        print("ERROR: No set directories found!")
        return []

    processed_samples = []
    skipped = 0

    for set_dir in tqdm(set_dirs, desc="Processing sets"):
        print(f"Processing set directory: {set_dir}")

        # Find all depth files (frame.XXXX.depth_meters.png)
        depth_files = sorted(set_dir.glob("frame.*.depth_meters.png"))
        print(f"Found {len(depth_files)} depth files in {set_dir}")

        if len(depth_files) == 0:
            print(f"ERROR: No depth files found in {set_dir}!")
            continue

        for depth_file in tqdm(depth_files, desc=f"Processing frames in {set_dir.name}"):
            try:
                # frame.0000.depth_meters.png -> stem "frame.0000.depth_meters" -> "0000"
                frame_num = depth_file.stem.split('.')[1]

                # Corresponding RGB image (frame.XXXX.color.jpg)
                rgb_file = set_dir / f"frame.{frame_num}.color.jpg"

                if not rgb_file.exists():
                    print(f"Warning: No RGB file for frame {frame_num}, skipping...")
                    skipped += 1
                    continue

                # convert() returns an independent image, so the source file
                # can be closed as soon as the block exits.
                with Image.open(rgb_file) as rgb_src:
                    rgb = rgb_src.convert('RGB')

                with Image.open(depth_file) as depth_src:
                    depth = np.array(depth_src).astype(np.float32)

                # Multi-channel depth image: keep the first channel only.
                if depth.ndim == 3:
                    depth = depth[:, :, 0]

                depth = np.nan_to_num(depth, nan=0.0, posinf=depth_max, neginf=0.0)

                # Heuristic: values this large are treated as millimetres.
                if depth.max() > 1000:
                    depth = depth / 1000.0

                depth = np.clip(depth, depth_min, depth_max)

                # Prefix with the set directory name so frames from different
                # sets cannot collide in the flat output folders.
                output_name = f"{set_dir.name}_frame_{frame_num}.png"

                rgb.save(rgb_dir / output_name)

                # Linear map [depth_min, depth_max] -> [0, 65535], 16-bit PNG.
                depth_normalized = (
                    (depth - depth_min) / (depth_max - depth_min) * 65535
                ).astype(np.uint16)
                Image.fromarray(depth_normalized).save(depth_dir / output_name)

                processed_samples.append(output_name)

            except Exception as e:
                # Best effort: report the bad frame and keep going.
                print(f"\nError processing {depth_file.name}: {e}")
                skipped += 1
                continue

    print(f"\n✓ Processing complete!")
    print(f"  Successfully processed: {len(processed_samples)} image pairs")
    print(f"  Skipped: {skipped} frames")

    if len(processed_samples) == 0:
        print("\n ERROR: No samples were processed successfully!")
        print("Please check:")
        print("  1. Depth files exist: frame.XXXX.depth_meters.png")
        print("  2. RGB files exist: frame.XXXX.color.jpg")
        print("  3. Frame numbers match between RGB and depth")

    return processed_samples


# Process the dataset: convert the raw scene folder into processed_data/,
# then write the train/val split files consumed by later cells.
dataset_path = Path(r"dataset")

print("="*70)
print("DATASET PROCESSING")
print("="*70)

# Check if dataset exists (fail early with the resolved absolute path)
if not dataset_path.exists():
    raise FileNotFoundError(f"Dataset not found at: {dataset_path.absolute()}")

print(f"✓ Dataset directory found: {dataset_path.absolute()}")

# List some files to verify structure.
# Note: this glob is non-recursive, so it only shows PNGs directly under
# dataset_path (at most 5); frames inside sub-directories are not listed here.
print("\nChecking dataset structure...")
sample_files = list(dataset_path.glob("*.png"))[:5]
print(f"Sample files found:")
for f in sample_files:
    print(f"  - {f.name}")

# Process and extract RGB and depth images
samples = process_hypersim_scene(dataset_path, output_dir='processed_data')

if len(samples) == 0:
    raise ValueError("No samples were processed! Check the error messages above.")

# Create train/val split (80/20)
# The empty case raised above, so the first branch runs only when exactly one
# sample exists; that sample is then used for both training and validation.
if len(samples) < 2:
    print(f"\n  Warning: Only {len(samples)} sample(s) found. Using all for training.")
    train_samples = samples
    val_samples = samples[:max(1, len(samples)//5)]  # At least 1 sample for validation
else:
    # Fixed random_state makes the shuffle repeatable for a given sample order.
    train_samples, val_samples = train_test_split(samples, test_size=0.2, random_state=42)

print(f"\n" + "="*70)
print("DATASET SPLIT")
print("="*70)
print(f"Total samples: {len(samples)}")
print(f"Training samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")

# Save split info: one sample file name per line, written to the working directory
with open('train_split.txt', 'w') as f:
    f.write('\n'.join(train_samples))

with open('val_split.txt', 'w') as f:
    f.write('\n'.join(val_samples))

print("\n✓ Split files saved:")
print("  - train_split.txt")
print("  - val_split.txt")
print("="*70)


In [None]:
class EdgeAttentionModule(nn.Module):
    """Learned per-pixel, per-channel gate applied to a feature map.

    Three 3x3 convolutions (the first two followed by BatchNorm + ReLU, the
    last by a Sigmoid) produce weights in (0, 1) with the same shape as the
    input; the input is multiplied element-wise by those weights.

    Up to eight filters of the first convolution start with a Sobel kernel on
    input channel 0. They remain ordinary trainable parameters.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution: same channel count, same H x W.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # Layer order fixes the Sequential indices (0..7) used by state_dict.
        stages = [
            conv3x3(), nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
            conv3x3(), nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
            conv3x3(), nn.Sigmoid(),
        ]
        self.edge_conv = nn.Sequential(*stages)

        self._init_as_edge_detector()

    def _init_as_edge_detector(self):
        """Overwrite slice [i, 0] of the first conv's weights with Sobel kernels.

        Even output filters get the x-gradient kernel, odd ones the y-gradient
        kernel; only the first ``min(8, out_channels)`` filters are touched and
        only their input-channel-0 slice.
        """
        sobel_x = torch.tensor([
            [-1., 0., 1.],
            [-2., 0., 2.],
            [-1., 0., 1.],
        ])
        sobel_y = torch.tensor([
            [-1., -2., -1.],
            [0., 0., 0.],
            [1., 2., 1.],
        ])

        stem = self.edge_conv[0]
        seeded_filters = min(8, stem.out_channels)

        # In-place edits of a parameter must happen outside autograd tracking.
        with torch.no_grad():
            for out_idx in range(seeded_filters):
                kernel = sobel_y if out_idx % 2 else sobel_x
                stem.weight[out_idx, 0, :, :] = kernel

    def forward(self, x):
        """Return ``x`` scaled element-wise by its learned attention weights."""
        return x * self.edge_conv(x)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped by an identity shortcut.

    Channel count and spatial size are preserved, so the input can be added
    directly to the transformed features before the final ReLU.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        # Shortcut is added before the last activation.
        return F.relu(hidden + x)


# Human-readable recap of EdgeAttentionModule. The layer count is kept in
# sync with the module, whose edge_conv stack holds three 3x3 convolutions.
print("✓ Edge Attention Module defined!")
print("\nHow it works:")
print("  1. Three Conv2d(3x3) layers learn edge patterns")
print("  2. Initialized similar to Sobel (but trainable)")
print("  3. Outputs attention weights (0=flat, 1=edge)")
print("  4. Multiplies features: amplifies edges, suppresses flat regions")


In [None]:
class DepthNormalNet(nn.Module):
    """U-Net style encoder/decoder with a depth head and a normal head.

    Channel widths through the encoder are 32 -> 64 -> 128 -> 256, with a
    512-channel bottleneck. Every stage after the first halves the spatial
    size with a stride-2 convolution; the decoder doubles it back with
    transposed convolutions and concatenates the matching encoder output.

    ``forward`` takes a ``[B, 3, H, W]`` tensor and returns
    ``(depth, normals)``: depth is ``[B, 1, H, W]`` after a Sigmoid, normals
    are ``[B, 3, H, W]`` after Tanh and L2 normalisation over channels.
    H and W need to be multiples of 16 for the skip concatenations to line up
    (four stride-2 downsamplings, each undone by an exact 2x upsampling).
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(c_in, c_out, kernel, **conv_kwargs):
            # Conv2d -> BatchNorm2d -> ReLU, returned as a list for unpacking.
            return [
                nn.Conv2d(c_in, c_out, kernel, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        def upconv_bn_relu(c_in, c_out):
            # Transposed conv (kernel 4, stride 2, padding 1) doubles H and W.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # ----- Encoder -----
        # Full resolution, 3 -> 32 channels.
        self.enc1 = nn.Sequential(
            *conv_bn_relu(3, 32, 7, padding=3),
            ResidualBlock(32),
        )
        # 1/2 resolution, 32 -> 64 channels, with edge attention.
        self.enc2 = nn.Sequential(
            *conv_bn_relu(32, 64, 3, stride=2, padding=1),
            ResidualBlock(64),
            EdgeAttentionModule(64),
        )
        # 1/4 resolution, 64 -> 128 channels, with edge attention.
        self.enc3 = nn.Sequential(
            *conv_bn_relu(64, 128, 3, stride=2, padding=1),
            ResidualBlock(128),
            EdgeAttentionModule(128),
        )
        # 1/8 resolution, 128 -> 256 channels.
        self.enc4 = nn.Sequential(
            *conv_bn_relu(128, 256, 3, stride=2, padding=1),
            ResidualBlock(256),
        )

        # ----- Bottleneck: 1/16 resolution, 256 -> 512 channels -----
        self.bottleneck = nn.Sequential(
            *conv_bn_relu(256, 512, 3, stride=2, padding=1),
            ResidualBlock(512),
        )

        # ----- Decoder -----
        # dec4 only upsamples: 512 -> 256 channels at 1/8 resolution.
        self.dec4 = nn.Sequential(*upconv_bn_relu(512, 256))

        # Input: dec4 output (256) + enc4 skip (256) = 512 channels.
        self.dec3 = nn.Sequential(
            *conv_bn_relu(512, 128, 3, padding=1),
            *upconv_bn_relu(128, 128),
        )

        # Input: dec3 output (128) + enc3 skip (128) = 256 channels.
        self.dec2 = nn.Sequential(
            *conv_bn_relu(256, 64, 3, padding=1),
            *upconv_bn_relu(64, 64),
        )

        # Input: dec2 output (64) + enc2 skip (64) = 128 channels.
        self.dec1 = nn.Sequential(
            *conv_bn_relu(128, 32, 3, padding=1),
            *upconv_bn_relu(32, 32),
        )

        # ----- Prediction heads -----
        # Input to both: dec1 output (32) + enc1 skip (32) = 64 channels.
        self.depth_head = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid(),
        )

        self.normal_head = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the network and return ``(depth, normals)``."""
        # Encoder; every stage output is kept for its skip connection.
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        # Bottleneck followed by the first upsampling step.
        up = self.dec4(self.bottleneck(e4))

        # Each remaining decoder stage consumes the previous output
        # concatenated (channel-wise) with the encoder map of equal size.
        for decoder, skip in ((self.dec3, e4), (self.dec2, e3), (self.dec1, e2)):
            up = decoder(torch.cat([up, skip], dim=1))

        fused = torch.cat([up, e1], dim=1)

        depth = self.depth_head(fused)
        # Rescale the Tanh output to unit length per pixel.
        normals = F.normalize(self.normal_head(fused), dim=1)

        return depth, normals


# Create model and show stats
# This instance stays on the default device; it is only used for counting.
model = DepthNormalNet()
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✓ Model architecture defined!")
print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
# Size estimate: 4 bytes per parameter, converted from bytes to MiB.
print(f"  Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB (FP32)")

In [None]:
class EdgeAwareLoss(nn.Module):
    """Weighted sum of four terms for joint depth / normal prediction.

    total = depth_l1 + edge_weight * edge_l1 + normal + smooth_weight * smooth

    * depth_l1: mean absolute error between predicted and target depth.
    * edge_l1:  mean absolute error between Sobel gradient magnitudes of the
                predicted and target depth.
    * normal:   mean of (1 - cosine similarity) over the channel dimension.
    * smooth:   mean absolute difference between neighbouring predicted
                depth values, averaged over the two spatial directions.
    """

    def __init__(self, edge_weight=2.0, smooth_weight=0.1):
        super().__init__()
        self.edge_weight = edge_weight
        self.smooth_weight = smooth_weight

        horizontal_taps = [[-1., 0., 1.],
                           [-2., 0., 2.],
                           [-1., 0., 1.]]
        vertical_taps = [[-1., -2., -1.],
                         [0., 0., 0.],
                         [1., 2., 1.]]

        # Buffers (not parameters): fixed kernels that still follow .to(device).
        # Shape [1, 1, 3, 3] = one output channel, one input channel.
        for buffer_name, taps in (('sobel_x', horizontal_taps),
                                  ('sobel_y', vertical_taps)):
            kernel = torch.tensor(taps).float().reshape(1, 1, 3, 3)
            self.register_buffer(buffer_name, kernel)

    def compute_gradients(self, img):
        """Return the Sobel gradient magnitude of a single-channel image batch."""
        gx, gy = (F.conv2d(img, kernel, padding=1)
                  for kernel in (self.sobel_x, self.sobel_y))
        # The small constant keeps sqrt differentiable where both gradients are 0.
        return torch.sqrt(gx**2 + gy**2 + 1e-8)

    def forward(self, pred_depth, pred_normal, gt_depth, gt_normal):
        """Return ``(total_loss, components)``.

        ``total_loss`` is a tensor suitable for ``backward()``; ``components``
        maps 'total', 'depth', 'edge', 'normal' and 'smooth' to Python floats.
        """
        # Overall pixel-wise accuracy.
        depth_loss = F.l1_loss(pred_depth, gt_depth)

        # Boundary agreement on gradient magnitudes.
        edge_loss = F.l1_loss(
            self.compute_gradients(pred_depth),
            self.compute_gradients(gt_depth),
        )

        # Surface orientation agreement.
        normal_loss = (1 - F.cosine_similarity(pred_normal, gt_normal, dim=1)).mean()

        # Total-variation style penalty on the prediction.
        diff_w = torch.abs(pred_depth[:, :, :, 1:] - pred_depth[:, :, :, :-1])
        diff_h = torch.abs(pred_depth[:, :, 1:, :] - pred_depth[:, :, :-1, :])
        smoothness = (diff_w.mean() + diff_h.mean()) / 2

        total_loss = (
            depth_loss
            + self.edge_weight * edge_loss
            + normal_loss
            + self.smooth_weight * smoothness
        )

        components = {
            'total': total_loss.item(),
            'depth': depth_loss.item(),
            'edge': edge_loss.item(),
            'normal': normal_loss.item(),
            'smooth': smoothness.item(),
        }
        return total_loss, components


# Recap of the loss terms. The weights printed here are literal text; they
# match EdgeAwareLoss's defaults (edge_weight=2.0, smooth_weight=0.1) and are
# not read from a criterion instance.
print("✓ Loss functions defined!")
print("\nLoss Components:")
print("  1. Depth L1 (weight=1.0): Overall accuracy")
print("  2. Edge Loss (weight=2.0): Boundary sharpness ← KEY")
print("  3. Normal Loss (weight=1.0): Surface orientation")
print("  4. Smoothness (weight=0.1): Noise reduction")
print("\nTotal = 1.0×depth + 2.0×edge + 1.0×normal + 0.1×smooth")

In [None]:
class HyperSimDataset(Dataset):
    """Paired RGB / depth samples read from ``<data_dir>/rgb`` and ``<data_dir>/depth``.

    Args:
        data_dir: Root folder holding the ``rgb`` and ``depth`` sub-folders.
        split_file: Text file with one sample file name per line; blank lines
            are ignored.
        img_size: Both images are resized to ``img_size x img_size``.
        augment: When true, each sample is mirrored left-right with
            probability 0.5.

    Each item is ``(rgb, depth, normals)``: float tensors of shape
    ``[3, S, S]`` scaled to [-1, 1], ``[1, S, S]`` scaled to [0, 1], and
    ``[3, S, S]`` unit vectors derived from the depth map.
    """

    def __init__(self, data_dir, split_file, img_size=256, augment=False):
        self.data_dir = Path(data_dir)
        self.img_size = img_size
        self.augment = augment

        # Read the split, dropping surrounding whitespace and empty lines.
        with open(split_file, 'r') as handle:
            stripped = (raw.strip() for raw in handle)
            self.samples = [name for name in stripped if name]

        print(f"✓ Loaded {len(self.samples)} samples from {split_file}")

    def depth_to_normals(self, depth):
        """Estimate unit surface normals ``[H, W, 3]`` from a 2-D depth map.

        The normal at each pixel is ``(-dz/dx, -dz/dy, 1)`` normalised to unit
        length, with the derivatives taken by 3x3 Sobel filters.
        """
        depth32 = depth.astype(np.float32)

        dz_dx = cv2.Sobel(depth32, cv2.CV_32F, 1, 0, ksize=3)
        dz_dy = cv2.Sobel(depth32, cv2.CV_32F, 0, 1, ksize=3)

        raw = np.stack([-dz_dx, -dz_dy, np.ones_like(depth32)], axis=-1)
        length = np.linalg.norm(raw, axis=-1, keepdims=True)

        # Epsilon guards the division.
        return (raw / (length + 1e-8)).astype(np.float32)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        name = self.samples[idx]
        target_size = (self.img_size, self.img_size)

        # RGB: 8-bit image -> float32 in [0, 1].
        rgb_img = Image.open(self.data_dir / 'rgb' / name).convert('RGB')
        rgb_img = rgb_img.resize(target_size, Image.BILINEAR)
        rgb = np.array(rgb_img).astype(np.float32) / 255.0

        # Depth: 16-bit PNG -> float32 in [0, 1].
        depth_img = Image.open(self.data_dir / 'depth' / name)
        depth_img = depth_img.resize(target_size, Image.BILINEAR)
        depth = np.array(depth_img).astype(np.float32) / 65535.0

        normals = self.depth_to_normals(depth)

        # Random horizontal flip; the draw only happens when augment is on.
        if self.augment and np.random.rand() > 0.5:
            rgb, depth, normals = (
                np.fliplr(arr).copy() for arr in (rgb, depth, normals)
            )
            # Mirroring the image reverses the x component of every normal.
            normals[:, :, 0] *= -1

        # HWC numpy -> CHW tensors; RGB is additionally mapped to [-1, 1].
        rgb_t = torch.from_numpy(rgb).permute(2, 0, 1).float() * 2 - 1
        depth_t = torch.from_numpy(depth).unsqueeze(0).float()
        normals_t = torch.from_numpy(normals).permute(2, 0, 1).float()

        return rgb_t, depth_t, normals_t


# Create datasets
# Both read from processed_data/ at 512x512; only the training set is augmented.
train_dataset = HyperSimDataset('processed_data', 'train_split.txt', 
                                img_size=512, augment=True)
val_dataset = HyperSimDataset('processed_data', 'val_split.txt', 
                              img_size=512, augment=False)

# Create dataloaders
# num_workers=0 loads batches in the main process; only training is shuffled.
BATCH_SIZE =4  # Adjust based on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, 
                         shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, 
                       shuffle=False, num_workers=0, pin_memory=True)

print(f"\n✓ Dataloaders ready!")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Callable returning ``(pred_depth, pred_normals)`` for an RGB batch.
        loader: Iterable of ``(rgb, depth, normals)`` batches.
        optimizer: Optimiser updating ``model``'s parameters.
        criterion: Callable returning ``(loss_tensor, components_dict)``.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        ``(mean_loss, mean_components)`` averaged over the number of batches,
        where ``mean_components`` has keys 'depth', 'edge', 'normal', 'smooth'.
    """
    model.train()

    tracked = ('depth', 'edge', 'normal', 'smooth')
    running_total = 0
    running = dict.fromkeys(tracked, 0)

    progress = tqdm(loader, desc=f'Epoch {epoch} - Training')
    for rgb, depth, normals in progress:
        rgb, depth, normals = rgb.to(device), depth.to(device), normals.to(device)

        pred_depth, pred_normals = model(rgb)
        loss, parts = criterion(pred_depth, pred_normals, depth, normals)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_total += batch_loss
        for key in tracked:
            running[key] += parts[key]

        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'edge': f'{parts["edge"]:.4f}',
        })

    num_batches = len(loader)
    averaged = {key: value / num_batches for key, value in running.items()}
    return running_total / num_batches, averaged


def validate(model, loader, criterion, device, epoch):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Takes the same ``model`` / ``loader`` / ``criterion`` contract as
    ``train_epoch`` and returns ``(mean_loss, mean_components)`` averaged over
    the number of batches. ``epoch`` is used only in the progress-bar label.
    """
    model.eval()

    tracked = ('depth', 'edge', 'normal', 'smooth')
    running_total = 0
    running = dict.fromkeys(tracked, 0)

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        progress = tqdm(loader, desc=f'Epoch {epoch} - Validation')
        for rgb, depth, normals in progress:
            rgb, depth, normals = rgb.to(device), depth.to(device), normals.to(device)

            pred_depth, pred_normals = model(rgb)
            loss, parts = criterion(pred_depth, pred_normals, depth, normals)

            batch_loss = loss.item()
            running_total += batch_loss
            for key in tracked:
                running[key] += parts[key]

            progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    num_batches = len(loader)
    averaged = {key: value / num_batches for key, value in running.items()}
    return running_total / num_batches, averaged


# Marks that train_epoch() and validate() are now defined in this kernel.
print("✓ Training functions defined!")


In [None]:
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

# Model (fresh instance, moved to the training device)
model = DepthNormalNet().to(device)
print(f"Model loaded to {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

# Loss and optimizer
NUM_EPOCHS = 15
MAX_LR = 3e-4  # Peak learning rate of the one-cycle schedule
criterion = EdgeAwareLoss(edge_weight=2.0, smooth_weight=0.1)
criterion = criterion.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# The scheduler is stepped once per epoch at the bottom of the loop below,
# so the cycle length must be counted in epochs. Sizing it in batches
# (epochs * steps_per_epoch) while stepping per epoch would leave the
# learning rate stuck near its warm-up floor for the whole run.
# OneCycleLR replaces the optimizer's lr with its own schedule.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    total_steps=NUM_EPOCHS,
    pct_start=0.3,  # Warmup for 30% of training
    anneal_strategy='cos'
)

# Training settings
best_val_loss = float('inf')
history = {
    'train_loss': [], 'val_loss': [],
    'train_edge': [], 'val_edge': [],
    'train_depth': [], 'val_depth': [],
    'train_normal': [], 'val_normal': []
}

print("="*70)
print("TRAINING CONFIGURATION")
print("="*70)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: one-cycle schedule, peak {MAX_LR:g}")
print(f"Optimizer: AdamW")
print(f"Scheduler: OneCycleLR (stepped once per epoch)")
print(f"Loss weights: depth=1.0, edge=2.0, normal=1.0, smooth=0.1")
print("="*70)
print("\nStarting training...\n")

# Training loop
for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch}/{NUM_EPOCHS}")
    print(f"{'='*70}")

    # Train
    train_loss, train_metrics = train_epoch(model, train_loader, optimizer,
                                           criterion, device, epoch)

    # Validate
    val_loss, val_metrics = validate(model, val_loader, criterion, device, epoch)

    # Print metrics (the LR shown is the one used during this epoch)
    print(f"\nEpoch {epoch} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"  Train Depth: {train_metrics['depth']:.4f} | Val Depth: {val_metrics['depth']:.4f}")
    print(f"  Train Edge: {train_metrics['edge']:.4f} | Val Edge: {val_metrics['edge']:.4f} ← KEY METRIC")
    print(f"  Train Normal: {train_metrics['normal']:.4f} | Val Normal: {val_metrics['normal']:.4f}")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Track history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    for component in ('edge', 'depth', 'normal'):
        history[f'train_{component}'].append(train_metrics[component])
        history[f'val_{component}'].append(val_metrics[component])

    # Save best model (lowest validation loss seen so far)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_loss': train_loss,
            'history': history
        }, 'best_model.pth')
        print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'checkpoint_epoch_{epoch}.pth')
        print(f"  ✓ Saved checkpoint")

    # Advance the one-cycle schedule by one (epoch-sized) step
    scheduler.step()

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)
print(f"Best validation loss: {best_val_loss:.4f}")
print("="*70)

In [None]:
# Plot the per-epoch curves stored in `history`: one panel per tracked
# quantity, each showing the training and the validation series.
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# (axis, history key suffix, legend name, y label, title, train colour, val colour)
# A colour of None leaves matplotlib's default colour cycle in charge.
panels = [
    (axes[0, 0], 'loss', 'Loss', 'Total Loss',
     'Training Progress - Total Loss', None, None),
    (axes[0, 1], 'edge', 'Edge Loss', 'Edge Loss',
     'Edge Loss Over Time (KEY METRIC)', 'red', 'darkred'),
    (axes[1, 0], 'depth', 'Depth Loss', 'Depth Loss',
     'Depth Loss Over Time', 'blue', 'darkblue'),
    (axes[1, 1], 'normal', 'Normal Loss', 'Normal Loss',
     'Normal Loss Over Time', 'green', 'darkgreen'),
]

for ax, key, legend_name, ylabel, title, train_color, val_color in panels:
    train_style = {'marker': 'o', 'linewidth': 2}
    val_style = {'marker': 's', 'linewidth': 2}
    if train_color is not None:
        train_style['color'] = train_color
    if val_color is not None:
        val_style['color'] = val_color

    ax.plot(history[f'train_{key}'], label=f'Train {legend_name}', **train_style)
    ax.plot(history[f'val_{key}'], label=f'Val {legend_name}', **val_style)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training curves saved as 'training_curves.png'")

# Print statistics
print("\nTraining Statistics:")
print(f"  Initial edge loss: {history['train_edge'][0]:.4f}")
print(f"  Final edge loss: {history['train_edge'][-1]:.4f}")
print(f"  Best validation loss: {best_val_loss:.4f}")


In [None]:
def save_output_rgba(pred_depth, pred_normals, output_path='output.png'):
    """Pack one prediction into an RGBA PNG and write it to ``output_path``.

    RGB channels hold the normal map (components mapped from [-1, 1] to
    [0, 255]); the alpha channel holds the depth map (mapped from [0, 1] to
    [0, 255]). Values outside those input ranges are clamped before the
    8-bit conversion.

    Args:
        pred_depth: Depth tensor for a single image; all size-1 dimensions
            are squeezed away, leaving ``[H, W]``.
        pred_normals: Normal tensor of shape ``[1, 3, H, W]`` (the leading
            batch dimension is squeezed).
        output_path: Destination PNG path.

    Returns:
        The ``[H, W, 4]`` uint8 array that was saved.
    """
    # detach() makes the conversion safe even if called outside no_grad().
    normals = pred_normals.detach().squeeze(0).cpu().numpy()  # [3, H, W]
    depth = pred_depth.detach().squeeze().cpu().numpy()  # [H, W]

    # Encode normals: [-1, 1] -> [0, 255].
    # Clamp BEFORE casting: astype(uint8) on out-of-range floats wraps (or is
    # undefined), so clipping afterwards cannot repair anything.
    normals = np.clip((normals + 1) * 0.5 * 255, 0, 255).astype(np.uint8)
    normals = np.transpose(normals, (1, 2, 0))  # [H, W, 3]

    # Encode depth: [0, 1] -> [0, 255], clamped the same way.
    depth = np.clip(depth * 255, 0, 255).astype(np.uint8)
    depth = np.expand_dims(depth, -1)  # [H, W, 1]

    # Combine RGBA
    rgba = np.concatenate([normals, depth], axis=-1)  # [H, W, 4]

    # Save as PNG
    img = Image.fromarray(rgba, mode='RGBA')
    img.save(output_path, compress_level=6)

    file_size = os.path.getsize(output_path) / 1024
    print(f"✓ Saved RGBA output: {output_path}")
    print(f"  Resolution: {rgba.shape[1]}×{rgba.shape[0]}")
    print(f"  File size: {file_size:.1f} KB")
    print(f"  Format: RGBA PNG (lossless)")
    print(f"  RGB channels: Normal map (tangent-space)")
    print(f"  Alpha channel: Depth map (normalized)")

    return rgba

In [None]:
# Final recap; relies on NUM_EPOCHS, best_val_loss and history left in the
# kernel by the training cell.
print("\n" + "="*70)
print("TRAINING SUMMARY")
print("="*70)

print("\n Training Results:")
print(f"  Total epochs: {NUM_EPOCHS}")
print(f"  Best validation loss: {best_val_loss:.4f}")
# Last entry of the per-epoch training-loss list.
print(f"  Final training loss: {history['train_loss'][-1]:.4f}")



In [None]:
def _default_vis_path(output_path):
    """Derive the visualization file name that belongs to ``output_path``.

    A trailing ``_output`` in the stem is replaced by ``_result``; otherwise
    ``_result`` is appended. The default 'custom_output.png' therefore maps to
    'custom_result.png', and distinct outputs get distinct visualizations.
    """
    out = Path(output_path)
    stem = out.stem
    if stem.endswith('_output'):
        stem = stem[:-len('_output')]
    return out.with_name(f"{stem}_result.png")


def test_custom_image(image_path, model, device, output_path='custom_output.png',
                      img_size=512, vis_path=None):
    """
    Run the trained model on one image, save the RGBA result and a 4-panel figure.

    Usage:
    test_custom_image('path/to/your/image.jpg', model, device)

    Args:
        image_path: Image file to load (converted to RGB).
        model: Network returning ``(pred_depth, pred_normals)``; it is put
            into eval mode.
        device: Device the input tensor is moved to.
        output_path: Where the RGBA PNG (normals + depth) is written.
        img_size: Side length the image is resized to before inference.
        vis_path: Where the comparison figure is written. When None it is
            derived from ``output_path`` so that repeated calls with different
            outputs do not overwrite each other's figure.
    """
    if vis_path is None:
        vis_path = _default_vis_path(output_path)

    print(f"\nProcessing custom image: {image_path}")

    # Load and preprocess; the file is closed once the resized copy exists.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    original_size = img.size
    img_resized = img.resize((img_size, img_size), Image.BILINEAR)

    img_array = np.array(img_resized).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float()
    img_tensor = img_tensor * 2 - 1  # Normalize to [-1, 1]
    img_tensor = img_tensor.unsqueeze(0).to(device)  # add batch dimension

    # Predict
    model.eval()
    with torch.no_grad():
        pred_depth, pred_normals = model(img_tensor)

    # Save RGBA
    rgba = save_output_rgba(pred_depth, pred_normals, output_path)

    # Visualize
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(img_resized)
    axes[0].set_title('Input Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(pred_depth.squeeze().cpu().numpy(), cmap='plasma')
    axes[1].set_title('Predicted Depth', fontsize=14, fontweight='bold')
    axes[1].axis('off')

    # Map normals from [-1, 1] to [0, 1] so they can be shown as colours.
    pred_normals_display = (pred_normals.squeeze().cpu().permute(1, 2, 0).numpy() + 1) / 2
    axes[2].imshow(pred_normals_display)
    axes[2].set_title('Predicted Normals', fontsize=14, fontweight='bold')
    axes[2].axis('off')

    axes[3].imshow(rgba)
    axes[3].set_title('RGBA Output', fontsize=14, fontweight='bold')
    axes[3].axis('off')

    plt.tight_layout()
    plt.savefig(vis_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"\n✓ Custom image processed!")
    print(f"  Input size: {original_size}")
    print(f"  Processed at: {img_size}×{img_size}")
    print(f"  Output saved: {output_path}")
    print(f"  Visualization saved: {vis_path}")

# Run inference on the sample images. Each image gets its own RGBA output
# file (e.g. brick_output.png); with the shared default output path every
# call would overwrite the previous image's result.
for image_name in ("brick.jpg", "brick2.jpg", "brick4.jpg", "brick5.jpg", "check.png"):
    test_custom_image(image_name, model, device,
                      output_path=f"{Path(image_name).stem}_output.png")