# Simplified Floor Plan Segmentation - Bare Minimum Approach

**Goal**: Train a simple U-Net to segment floor plans into 3 classes:
- **Background (0)**: Empty space
- **Wall (1)**: Wall pixels  
- **Room (2)**: Any room pixels (all room types combined)

Later we'll use **connected components** to separate individual rooms and **OCR** to label them.

This follows the **bare minimum** approach from our plan.


## 1. Setup & Imports


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# EXPERIMENT 3: Add Kornia for GPU augmentations
import kornia.augmentation as K

# Import our dataset
from cubicasa_dataset_v2 import CubiCasa5KDatasetV2

# Report the runtime environment up front so a missing GPU is noticed before training.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is queried; total_memory is divided by 1e9 to display GB.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


## 2. Simple Configuration


In [None]:
# Optimized configuration for better GPU utilization
CONFIG = dict(
    data_root='dataset cubicasa/cubicasa5k/cubicasa5k',
    batch_size=8,         # 16 was too slow and 4 too small; 8 balances speed and memory
    image_size=512,
    epochs=1,             # kept low for a quick training run
    learning_rate=0.001,
    num_classes=3,        # Background, Wall, Room
)

# Class index -> display label for the three segmentation targets:
# 0 is empty space, 1 is wall pixels, 2 is any room pixel (all room types combined).
CLASS_NAMES = dict(enumerate(('Background', 'Wall', 'Room')))

print("Simple 3-class segmentation:")
for class_id, label in CLASS_NAMES.items():
    print(f"  {class_id}: {label}")

# Make sure the checkpoint directory exists before anything tries to write to it.
os.makedirs('simple_checkpoints', exist_ok=True)


## 3. Simplified Dataset (3 Classes)


In [None]:
class SimplifiedDataset(Dataset):
    """
    Wrapper around CubiCasa5KDatasetV2 that collapses its labels into three classes.

    Each item of the wrapped dataset is a dict with 'image' and 'mask' entries.
    The mask is remapped as follows and returned alongside the untouched image:
    - original 0 and 1 -> 0 (Background; label 1 is treated as outdoor/background)
    - original 2       -> 1 (Wall)
    - original 3 and up -> 2 (Room, all room types combined)
    """
    def __init__(self, split_file, dataset_root, image_size=512, augment=False):
        # The wrapped dataset is asked for square images of side `image_size`.
        target_size = (image_size, image_size)
        self.original_dataset = CubiCasa5KDatasetV2(
            split_file=split_file,
            dataset_root=dataset_root,
            image_size=target_size,
            augment=augment,
        )

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        record = self.original_dataset[idx]
        image = record['image']
        raw_mask = record['mask']

        # Start from all-background (same dtype/device as the source mask), then
        # paint walls and rooms on top; labels 0 and 1 simply stay at 0.
        is_wall = raw_mask == 2
        is_room = raw_mask >= 3
        simple_mask = (
            torch.zeros_like(raw_mask)
            .masked_fill(is_wall, 1)
            .masked_fill(is_room, 2)
        )

        return image, simple_mask

# Load datasets: the split files train.txt / val.txt are looked up directly under the data root.
dataset_root = CONFIG['data_root']
train_split = os.path.join(dataset_root, 'train.txt')
val_split = os.path.join(dataset_root, 'val.txt')

print("Loading simplified datasets...")
# EXPERIMENT: dataset-level augmentation is switched off for training as well, to test
# whether CPU-side augmentation is the bottleneck.
train_dataset = SimplifiedDataset(train_split, dataset_root, CONFIG['image_size'], augment=False)  # augment=False for the experiment above
val_dataset = SimplifiedDataset(val_split, dataset_root, CONFIG['image_size'], augment=False)

print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")

# Data loaders - CRITICAL WINDOWS FIX: num_workers must stay 0 in Jupyter on Windows;
# num_workers>0 was observed to hang indefinitely in that environment.
# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0, pin_memory=True)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")


## 4. Quick Data Visualization


In [None]:
# Quick visualization of simplified data
def show_samples(dataset, n=2):
    """
    Plot the first ``n`` samples of ``dataset`` and print each mask's class distribution.

    The top row shows the image after undoing the mean/std normalisation (the constants
    are the standard ImageNet statistics) and clipping to [0, 1]; the bottom row shows
    the 3-class mask.

    Args:
        dataset: indexable collection whose items are ``(image, mask)`` tensor pairs.
            The image is permuted from channel-first to channel-last for display.
        n: number of samples to display, taken from index 0 upwards (default 2).
    """
    # squeeze=False keeps `axes` two-dimensional, so axes[row, col] also works for n == 1.
    fig, axes = plt.subplots(2, n, figsize=(12, 6), squeeze=False)

    norm_mean = np.array([0.485, 0.456, 0.406])
    norm_std = np.array([0.229, 0.224, 0.225])

    for i in range(n):
        image, mask = dataset[i]

        # Denormalize image
        img = image.permute(1, 2, 0).numpy()
        img = np.clip(img * norm_std + norm_mean, 0, 1)

        axes[0, i].imshow(img)
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        # Show mask with 3 distinct colors (fixed vmin/vmax keeps colours stable
        # even when a sample is missing one of the classes).
        axes[1, i].imshow(mask.numpy(), cmap='viridis', vmin=0, vmax=2)
        axes[1, i].set_title(f'3-Class Mask {i+1}')
        axes[1, i].axis('off')

        # Print class distribution relative to the mask's actual pixel count, so the
        # percentages stay correct when the image size is not 512x512.
        total_pixels = mask.numel()
        unique, counts = torch.unique(mask, return_counts=True)
        print(f"Sample {i+1} distribution:")
        for cls, count in zip(unique, counts):
            pct = count.item() / total_pixels * 100
            print(f"  {CLASS_NAMES[cls.item()]}: {pct:.1f}%")

    plt.tight_layout()
    plt.show()

# Preview the first training samples (show_samples defaults to n=2).
show_samples(train_dataset)


## 5. Simple U-Net Model


In [None]:
# Simple U-Net setup
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create simple U-Net model: ResNet-34 encoder with ImageNet weights. activation=None
# leaves the output as raw per-class scores, which nn.CrossEntropyLoss below consumes.
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights='imagenet',
    classes=CONFIG['num_classes'],  # Just 3 classes!
    activation=None
)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

# EXPERIMENT 3: GPU-based augmentations
# Only photometric transforms are allowed in this pipeline. The training loop passes
# the image batch alone through it, so a flip or rotation here would move the image
# pixels while the masks stay put, and the loss would be computed against misaligned
# labels. To reintroduce geometric augmentation, build the pipeline with
# K.AugmentationSequential(..., data_keys=["input", "mask"]) and pass the masks through
# it together with the images.
# NOTE(review): ColorJitter presumably expects images in [0, 1]; the visualisation code
# denormalises with mean/std, which suggests the dataset yields normalised images.
# Confirm the value range the dataset actually produces.
gpu_augmentation = nn.Sequential(
    K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
).to(device)

# Simple loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])

print("Simple model ready!")
print("GPU augmentation pipeline created!")


## 6. Simple Training Loop


In [None]:
def train_model():
    """
    Train the module-level ``model`` for ``CONFIG['epochs']`` epochs, validating after each.

    Uses the module-level ``model``, ``train_loader``, ``val_loader``, ``criterion``,
    ``optimizer``, ``gpu_augmentation``, ``device`` and ``CONFIG``. Batches are moved to
    the device with non-blocking transfers, and the image batch is passed through
    ``gpu_augmentation`` during training only. Whenever the mean validation loss
    improves, the model's state dict is written to
    ``simple_checkpoints/best_simple_model.pth``.

    Returns:
        dict with keys ``'train_loss'`` and ``'val_loss'``; each is a list containing
        the mean per-batch loss of every epoch, in order.
    """
    print(f"Starting training for {CONFIG['epochs']} epochs...")

    history = {'train_loss': [], 'val_loss': []}
    best_loss = float('inf')
    # CUDA memory counters are only meaningful when training actually runs on a GPU.
    on_cuda = device.type == 'cuda'

    for epoch in range(CONFIG['epochs']):
        # Train
        model.train()
        train_loss = 0.0

        for batch_idx, (images, masks) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1} Train')):
            # EXPERIMENT 2: Non-blocking transfers for better CPU/GPU overlap
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, dtype=torch.long, non_blocking=True)

            # EXPERIMENT 3: Apply GPU augmentations.
            # NOTE: only the images go through the pipeline, so it must not contain
            # geometric transforms (flip/rotate) or the masks will no longer line up.
            images = gpu_augmentation(images)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # GPU monitoring every 10 batches for more granular updates.
            # tqdm.write prints without corrupting the active progress bar.
            if on_cuda and batch_idx % 10 == 0:
                gpu_mem = torch.cuda.memory_allocated() / 1e9
                gpu_max = torch.cuda.max_memory_allocated() / 1e9
                tqdm.write(f"  Batch {batch_idx}: GPU Memory {gpu_mem:.1f}GB / Max {gpu_max:.1f}GB")

        # max(..., 1) avoids a ZeroDivisionError when a loader yields no batches.
        train_loss /= max(len(train_loader), 1)

        # Validate (no augmentation, no gradient tracking)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc=f'Epoch {epoch+1} Val'):
                images = images.to(device, non_blocking=True)
                masks = masks.to(device, dtype=torch.long, non_blocking=True)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()

        val_loss /= max(len(val_loader), 1)

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'simple_checkpoints/best_simple_model.pth')
            print(f"✓ New best model saved! (Val Loss: {val_loss:.4f})")

    # A real newline here; the previous "\\n" printed a literal backslash-n.
    print("\nTraining complete!")
    return history

# Start training (the loaders use num_workers=0, which prevents hanging on Windows).
# The returned per-epoch loss history is kept for the plots in the results section.
print("🚀 Starting training with Windows-compatible settings...")
history = train_model()


## 7. Quick Results & Predictions


In [None]:
# Plot training curves. Markers are drawn so that a single-epoch run (one point per
# curve) is still visible; a marker-less line through one point renders nothing.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], marker='o', label='Train Loss')
plt.plot(history['val_loss'], marker='o', label='Val Loss')
plt.title('Training Loss')
plt.legend()

# Show the prediction for the first validation sample. This uses the model as it is in
# memory after the last epoch, which is not necessarily the saved best checkpoint.
plt.subplot(1, 2, 2)
model.eval()
with torch.no_grad():
    sample_img, sample_mask = val_dataset[0]
    # Add a batch dimension, predict, then take the highest-scoring class per pixel.
    pred = model(sample_img.unsqueeze(0).to(device))
    pred_mask = torch.argmax(pred, dim=1).cpu().squeeze()

    plt.imshow(pred_mask.numpy(), cmap='viridis', vmin=0, vmax=2)
    plt.title('Sample Prediction')
    plt.axis('off')

plt.tight_layout()
plt.show()

# best_loss is a local variable of train_model() and does not exist at this level, so
# recover the best validation loss from the returned history instead.
best_loss = min(history['val_loss'], default=float('inf'))
print(f"Best validation loss: {best_loss:.4f}")
print("Model saved to: simple_checkpoints/best_simple_model.pth")

# Show what we got
print("\nPrediction classes:")
unique_pred = torch.unique(pred_mask)
for cls in unique_pred:
    print(f"  {CLASS_NAMES[cls.item()]}")

print("\n🎯 Next step: Use this model + connected components + OCR for room area calculation!")
