# Floor Plan AI - U-Net Model Training

This notebook implements the training pipeline for our custom U-Net model using the CubiCasa5K dataset.

## Project Overview
- **Goal**: Train a semantic segmentation model to identify rooms, walls, and other elements in floor plans
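- **Dataset**: CubiCasa5K with 8 semantic classes, collapsed to 3 classes for training in this notebook (see `CONFIG['num_classes']` below)
- **Architecture**: U-Net with ResNet34 encoder
- **Original classes**: Background(0), Outdoor(1), Wall(2), Kitchen(3), Living/Dining(4), Bedroom(5), Bath(6), Entry/Hall(7)
- **Classes trained here**: Background(0), Wall(1), Room(2) — all room types are combined into the single Room class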


## 1. Setup & Imports


In [None]:
import json
import os
import random
import sys
import warnings
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
warnings.filterwarnings('ignore')

# Import our custom dataset
from cubicasa_dataset_v2 import CubiCasa5KDatasetV2

# Report the installed PyTorch version and whether a CUDA device is usable from this kernel.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is queried; total_memory is divided by 1e9 for display in GB.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


## 2. Configuration & Hyperparameters


In [None]:
# Training Configuration
# Every tunable value of the run lives in this dict; it is also stored inside
# each checkpoint and in the final JSON report.
CONFIG = {
    # Data
    'data_root': 'dataset cubicasa/cubicasa5k/cubicasa5k',
    'batch_size': 4,
    'num_workers': 4,
    'image_size': 512,
    
    # Model - SIMPLIFIED TO 3 CLASSES
    'encoder_name': 'resnet34',
    'encoder_weights': 'imagenet',
    'num_classes': 3,  # Background, Wall, Room (simplified!)
    'activation': None,
    
    # Training
    'epochs': 30,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'scheduler_patience': 5,
    'scheduler_factor': 0.5,
    
    # Loss weights
    'dice_weight': 0.7,
    'ce_weight': 0.3,
    
    # Checkpointing
    'save_dir': 'model_checkpoints',
    'save_best_only': True,
    'early_stopping_patience': 10,

    # Reproducibility
    'seed': 42
}

# Simplified class names
CLASS_NAMES = {
    0: 'Background',
    1: 'Wall',
    2: 'Room'  # All room types combined!
}

# The class table and the model head size must describe the same label set.
assert len(CLASS_NAMES) == CONFIG['num_classes'], (
    "CLASS_NAMES and CONFIG['num_classes'] disagree"
)

# Seed the Python, NumPy and PyTorch generators so that shuffling, weight
# initialisation and augmentation start from the same state on every
# Restart & Run All (best effort: GPU kernels may still be non-deterministic).
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

# Create save directory
os.makedirs(CONFIG['save_dir'], exist_ok=True)

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 3. Dataset Loading & Visualization


In [None]:
# Build the training and validation datasets from the split files under data_root.
print("Loading datasets...")

dataset_root = CONFIG['data_root']
train_split, val_split = (
    os.path.join(dataset_root, split_name) for split_name in ('train.txt', 'val.txt')
)

# Arguments common to both splits; only the split file and the augmentation flag differ.
_shared_dataset_kwargs = dict(
    dataset_root=dataset_root,
    image_size=(CONFIG['image_size'], CONFIG['image_size']),
)

train_dataset = CubiCasa5KDatasetV2(split_file=train_split, augment=True, **_shared_dataset_kwargs)
val_dataset = CubiCasa5KDatasetV2(split_file=val_split, augment=False, **_shared_dataset_kwargs)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Create data loaders - the dataset yields dict samples, so batches are assembled by hand
def collate_fn(batch):
    """Turn a list of sample dicts into an (images, masks) pair of stacked tensors.

    Each element of `batch` must provide 'image' and 'mask' entries; the tensors
    are stacked along a new leading batch dimension.
    """
    image_list = []
    mask_list = []
    for sample in batch:
        image_list.append(sample['image'])
        mask_list.append(sample['mask'])
    return torch.stack(image_list), torch.stack(mask_list)

# Loader settings shared by both splits; only shuffling differs between them.
_shared_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
    collate_fn=collate_fn,
)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_shared_loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

In [None]:
# Visualize sample data
def visualize_batch(dataset, num_samples=4):
    """Show the first samples of a dataset and print their class distribution.

    The top row shows the de-normalised images, the bottom row the label masks.
    Every sample is fetched exactly once, so the printed statistics describe the
    masks that are actually displayed (a second `dataset[i]` call could return a
    different result if the dataset applies random augmentation).

    Args:
        dataset: indexable dataset whose items are dicts with 'image' and
            'mask' tensors (image laid out channel-first, as implied by the
            permute below).
        num_samples: number of leading samples to show; clamped to len(dataset).
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even when there is one column.
    fig, axes = plt.subplots(2, num_samples, figsize=(16, 8), squeeze=False)

    shown_masks = []
    for i in range(num_samples):
        # Get sample (dataset returns a dict)
        sample = dataset[i]
        image = sample['image']
        mask = sample['mask']

        # Convert tensor to numpy and undo the mean/std normalisation
        img_np = image.permute(1, 2, 0).numpy()
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)

        mask_np = mask.numpy()
        shown_masks.append(mask_np)

        # Plot image
        axes[0, i].imshow(img_np)
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        # Plot mask. tab10 has 10 colours; with vmin=-0.5 / vmax=9.5 integer
        # label k lands exactly on colour k. (The earlier vmin=0 / vmax=11
        # mapped labels 0 and 1 to the same colour.)
        axes[1, i].imshow(mask_np, cmap='tab10', vmin=-0.5, vmax=9.5)
        axes[1, i].set_title(f'Mask {i+1}')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

    # Show class distribution in the masks that were just displayed
    print("\nClass distribution in sample masks:")
    for i, mask_np in enumerate(shown_masks):
        unique, counts = np.unique(mask_np, return_counts=True)
        print(f"Sample {i+1}:")
        for class_id, count in zip(unique, counts):
            # Use the real pixel count of the mask rather than assuming a size.
            percentage = count / mask_np.size * 100
            print(f"  Class {class_id}: {percentage:.1f}%")

print("Training dataset samples:")
visualize_batch(train_dataset, num_samples=4)

In [None]:
# Create U-Net model
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Build the network from CONFIG and move it to the selected device in one step.
model = smp.Unet(
    classes=CONFIG['num_classes'],
    activation=CONFIG['activation'],
    encoder_name=CONFIG['encoder_name'],
    encoder_weights=CONFIG['encoder_weights'],
).to(device)

# Print model info
# One pass over the parameters, recording (element count, trainable?) pairs.
_param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in _param_counts)
trainable_params = sum(count for count, needs_grad in _param_counts if needs_grad)

print(f"\nModel: U-Net with {CONFIG['encoder_name']} encoder")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter.
print(f"Model size: ~{total_params * 4 / 1e6:.1f} MB")


In [None]:
# Define loss functions
class CombinedLoss(nn.Module):
    """Weighted sum of a multiclass Dice loss and a cross-entropy loss.

    Both terms receive the same (pred, target) pair; the result is
    dice_weight * dice + ce_weight * cross_entropy.
    """

    def __init__(self, dice_weight=0.7, ce_weight=0.3):
        super().__init__()
        self.dice_weight, self.ce_weight = dice_weight, ce_weight
        self.dice_loss = smp.losses.DiceLoss(mode='multiclass')
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, pred, target):
        """Return the weighted combination of both loss terms for one batch."""
        weighted_dice = self.dice_weight * self.dice_loss(pred, target)
        weighted_ce = self.ce_weight * self.ce_loss(pred, target)
        return weighted_dice + weighted_ce

# Initialize loss and optimizer
criterion = CombinedLoss(
    dice_weight=CONFIG['dice_weight'],
    ce_weight=CONFIG['ce_weight']
)

optimizer = optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# `verbose` is deprecated in recent PyTorch releases, so it is not passed here;
# the training loop prints the current learning rate after every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=CONFIG['scheduler_patience'],
    factor=CONFIG['scheduler_factor']
)


# Metrics
class RunningMeanIoU:
    """Mean intersection-over-union accumulated over batches of label maps.

    train_epoch / validate_epoch drive their metric through reset(), update()
    and compute(). The metric previously used here, smp.utils.metrics.IoU, is a
    plain callable module without those methods, so the first `metrics.reset()`
    call failed. This class provides that interface on top of a running
    confusion matrix.

    Args:
        num_classes: number of label ids; valid labels are 0 .. num_classes - 1.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Rows index the true class, columns the predicted class.
        self.confusion = torch.zeros((num_classes, num_classes), dtype=torch.int64)

    def reset(self):
        """Forget everything accumulated so far."""
        self.confusion.zero_()

    def update(self, preds, targets):
        """Add one batch of predicted and true class-index maps of equal shape."""
        preds = preds.detach().reshape(-1).to('cpu', dtype=torch.int64)
        targets = targets.detach().reshape(-1).to('cpu', dtype=torch.int64)
        # Pixels whose label or prediction is outside the class range are skipped.
        valid = (
            (targets >= 0) & (targets < self.num_classes)
            & (preds >= 0) & (preds < self.num_classes)
        )
        pair_index = targets[valid] * self.num_classes + preds[valid]
        counts = torch.bincount(pair_index, minlength=self.num_classes ** 2)
        self.confusion += counts.reshape(self.num_classes, self.num_classes)

    def compute(self):
        """Return the mean IoU over classes seen so far as a 0-dim tensor."""
        true_pos = torch.diagonal(self.confusion)
        union = self.confusion.sum(dim=0) + self.confusion.sum(dim=1) - true_pos
        seen = union > 0
        if not bool(seen.any()):
            return torch.tensor(0.0)
        # Classes that appear neither in labels nor in predictions are left out.
        return (true_pos[seen].double() / union[seen].double()).mean().float()


train_metrics = RunningMeanIoU(CONFIG['num_classes'])
val_metrics = RunningMeanIoU(CONFIG['num_classes'])

print("Loss function: Combined Dice + CrossEntropy")
print(f"Dice weight: {CONFIG['dice_weight']}, CE weight: {CONFIG['ce_weight']}")
print(f"Optimizer: Adam (lr={CONFIG['learning_rate']}, wd={CONFIG['weight_decay']})")
print(f"Scheduler: ReduceLROnPlateau (patience={CONFIG['scheduler_patience']})")


In [None]:
# Training history tracking
# One independent list per tracked quantity; each gains one entry per epoch.
history = {
    series_name: []
    for series_name in ('train_loss', 'val_loss', 'train_iou', 'val_iou', 'learning_rate')
}

# State for best-checkpoint selection and early stopping.
epochs_without_improvement = 0
best_val_loss = float('inf')

def train_epoch(model, train_loader, criterion, optimizer, metrics, device):
    """Run one optimisation pass over `train_loader`.

    For every batch the loss is back-propagated and the optimizer stepped; the
    arg-max class maps are fed to `metrics`.

    Returns:
        (mean batch loss over the epoch, metrics.compute())
    """
    model.train()
    metrics.reset()
    running_loss = 0.0

    progress = tqdm(train_loader, desc='Training')
    for batch_idx, (inputs, targets) in enumerate(progress):
        inputs = inputs.to(device)
        targets = targets.to(device, dtype=torch.long)

        # Standard step: clear gradients, forward, backward, update weights.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # The metric consumes hard class predictions, not logits.
        metrics.update(torch.argmax(logits, dim=1), targets)

        # Show the current batch loss and the running mean on the bar.
        progress.set_postfix({
            'loss': f'{loss.item():.4f}',
            'avg_loss': f'{running_loss/(batch_idx+1):.4f}'
        })

    return running_loss / len(train_loader), metrics.compute()

def validate_epoch(model, val_loader, criterion, metrics, device):
    """Evaluate the model on every batch of `val_loader` without gradient tracking.

    The model is switched to eval mode; no optimizer step is taken.

    Returns:
        (mean batch loss over the loader, metrics.compute())
    """
    model.eval()
    metrics.reset()
    running_loss = 0.0

    with torch.no_grad():
        progress = tqdm(val_loader, desc='Validation')
        for batch_idx, (inputs, targets) in enumerate(progress):
            inputs = inputs.to(device)
            targets = targets.to(device, dtype=torch.long)

            logits = model(inputs)
            loss = criterion(logits, targets)
            running_loss += loss.item()

            # The metric consumes hard class predictions, not logits.
            metrics.update(torch.argmax(logits, dim=1), targets)

            # Show the current batch loss and the running mean on the bar.
            progress.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{running_loss/(batch_idx+1):.4f}'
            })

    return running_loss / len(val_loader), metrics.compute()

# Status message only: this cell defines the epoch helpers and history state; no training has run yet.
print("Training setup complete. Ready to start training!")


In [None]:
# Main training loop
def _build_checkpoint(epoch_number, val_loss, val_iou):
    """Collect the state needed to resume or evaluate training at `epoch_number`.

    Reads the notebook-level `model`, `optimizer`, `scheduler`, `CONFIG` and
    `history` objects. `val_iou` is stored as a plain float.
    """
    return {
        'epoch': epoch_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'val_iou': val_iou,
        'config': CONFIG,
        'history': history
    }


print(f"Starting training for {CONFIG['epochs']} epochs...")
print(f"Model will be saved to: {CONFIG['save_dir']}")
print("="*60)

start_time = datetime.now()

for epoch in range(CONFIG['epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
    print("-" * 40)
    
    # Train
    train_loss, train_iou = train_epoch(
        model, train_loader, criterion, optimizer, train_metrics, device
    )
    
    # Validate
    val_loss, val_iou = validate_epoch(
        model, val_loader, criterion, val_metrics, device
    )
    
    # Convert the metric results to plain floats once; they are reused for the
    # history, the log lines and the checkpoints below.
    train_iou_value = float(train_iou)
    val_iou_value = float(val_iou)
    
    # Update learning rate (the scheduler monitors the validation loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_iou_value)
    history['val_iou'].append(val_iou_value)
    history['learning_rate'].append(current_lr)
    
    # Print epoch results
    print(f"Train Loss: {train_loss:.4f} | Train IoU: {train_iou_value:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val IoU: {val_iou_value:.4f}")
    print(f"Learning Rate: {current_lr:.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        
        torch.save(
            _build_checkpoint(epoch + 1, val_loss, val_iou_value),
            os.path.join(CONFIG['save_dir'], 'best_model.pth')
        )
        print(f"✓ New best model saved! (Val Loss: {val_loss:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epochs")
    
    # Early stopping
    if epochs_without_improvement >= CONFIG['early_stopping_patience']:
        print(f"\nEarly stopping triggered after {epochs_without_improvement} epochs without improvement")
        break
    
    # Periodic checkpoint every `checkpoint_every` epochs (default 5). It is
    # skipped when CONFIG['save_best_only'] is set: that flag was previously
    # never read, so periodic files were written even though the config asked
    # for the best model only.
    checkpoint_every = CONFIG.get('checkpoint_every', 5)
    if not CONFIG['save_best_only'] and (epoch + 1) % checkpoint_every == 0:
        torch.save(
            _build_checkpoint(epoch + 1, val_loss, val_iou_value),
            os.path.join(CONFIG['save_dir'], f'checkpoint_epoch_{epoch+1}.pth')
        )

end_time = datetime.now()
training_time = end_time - start_time

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print(f"Total training time: {training_time}")
print(f"Best validation loss: {best_val_loss:.4f}")
# history is empty when no epoch ran (e.g. CONFIG['epochs'] == 0); skip the line then.
if history['val_iou']:
    print(f"Final validation IoU: {history['val_iou'][-1]:.4f}")
print("="*60)


In [None]:
# Plot training history
def plot_training_history(history):
    """Draw a 2x2 dashboard of the recorded training curves and save it as a PNG.

    Panels: loss, IoU, learning rate (log scale) and the validation-minus-train
    loss gap. The figure is written to <save_dir>/training_curves.png and shown.
    """
    epochs = range(1, len(history['train_loss']) + 1)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    (loss_ax, iou_ax), (lr_ax, gap_ax) = axes

    # The loss and IoU panels share one layout: train curve in blue, validation in red.
    paired_panels = (
        (loss_ax, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
         'Training and Validation Loss', 'Loss'),
        (iou_ax, 'train_iou', 'val_iou', 'Train IoU', 'Val IoU',
         'Training and Validation IoU', 'IoU'),
    )
    for ax, train_key, val_key, train_label, val_label, title, y_label in paired_panels:
        ax.plot(epochs, history[train_key], 'b-', label=train_label)
        ax.plot(epochs, history[val_key], 'r-', label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    # Learning rate, on a log axis
    lr_ax.plot(epochs, history['learning_rate'], 'g-')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_yscale('log')
    lr_ax.grid(True)

    # Gap between validation and training loss, with a zero reference line
    loss_gap = np.array(history['val_loss']) - np.array(history['train_loss'])
    gap_ax.plot(epochs, loss_gap, 'purple')
    gap_ax.set_title('Overfitting Monitor (Val - Train Loss)')
    gap_ax.set_xlabel('Epoch')
    gap_ax.set_ylabel('Loss Difference')
    gap_ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    gap_ax.grid(True)

    plt.tight_layout()
    output_path = os.path.join(CONFIG['save_dir'], 'training_curves.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()

# Draw the curves recorded during training
plot_training_history(history)

# Print final statistics: last-epoch values plus the best epoch for loss and IoU
summary_lines = [
    "\nFinal Training Statistics:",
    f"Final train loss: {history['train_loss'][-1]:.4f}",
    f"Final val loss: {history['val_loss'][-1]:.4f}",
    f"Final train IoU: {history['train_iou'][-1]:.4f}",
    f"Final val IoU: {history['val_iou'][-1]:.4f}",
    f"Best val loss: {min(history['val_loss']):.4f} (epoch {np.argmin(history['val_loss'])+1})",
    f"Best val IoU: {max(history['val_iou']):.4f} (epoch {np.argmax(history['val_iou'])+1})",
]
print("\n".join(summary_lines))


In [None]:
# Load best model and visualize predictions
# map_location remaps the stored tensors onto the device used in this session,
# so a checkpoint written on a GPU can still be loaded on a CPU-only kernel.
# NOTE: torch.load unpickles the file; only load checkpoints written by this notebook.
best_checkpoint = torch.load(
    os.path.join(CONFIG['save_dir'], 'best_model.pth'),
    map_location=device
)
model.load_state_dict(best_checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model from epoch {best_checkpoint['epoch']}")
print(f"Best validation loss: {best_checkpoint['val_loss']:.4f}")
print(f"Best validation IoU: {best_checkpoint['val_iou']:.4f}")

# Visualize predictions
def visualize_predictions(model, dataset, device, num_samples=4):
    """Plot image, ground-truth mask and predicted mask for the first samples.

    The figure has three rows (image / ground truth / prediction) and is saved
    to <save_dir>/predictions_visualization.png before being shown.

    Args:
        model: segmentation network returning per-class logits.
        dataset: indexable dataset whose items are dicts with 'image' and
            'mask' tensors (the same format `collate_fn` consumes).
        device: device the model lives on.
        num_samples: number of leading samples to show; clamped to len(dataset).
    """
    model.eval()

    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even when there is one column.
    fig, axes = plt.subplots(3, num_samples, figsize=(16, 12), squeeze=False)

    with torch.no_grad():
        for i in range(num_samples):
            # Samples are dicts: tuple-unpacking one would yield its keys
            # (strings), not the tensors, so the entries are read by name.
            sample = dataset[i]
            image = sample['image']
            mask_true = sample['mask']

            # Predict: add a batch dimension, take the arg-max class per pixel
            image_batch = image.unsqueeze(0).to(device)
            pred_logits = model(image_batch)
            pred_mask = torch.argmax(pred_logits, dim=1)[0].cpu().numpy()

            # Undo the mean/std normalisation for display
            img_np = image.permute(1, 2, 0).numpy()
            img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img_np = np.clip(img_np, 0, 1)

            mask_true_np = mask_true.numpy()

            # Plot original image
            axes[0, i].imshow(img_np)
            axes[0, i].set_title(f'Original Image {i+1}')
            axes[0, i].axis('off')

            # Plot ground truth mask
            axes[1, i].imshow(mask_true_np, cmap='tab10', vmin=0, vmax=7)
            axes[1, i].set_title(f'Ground Truth {i+1}')
            axes[1, i].axis('off')

            # Plot prediction with the same colour scale as the ground truth
            axes[2, i].imshow(pred_mask, cmap='tab10', vmin=0, vmax=7)
            axes[2, i].set_title(f'Prediction {i+1}')
            axes[2, i].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['save_dir'], 'predictions_visualization.png'), dpi=300, bbox_inches='tight')
    plt.show()

# Visualize predictions on validation set
print("Visualizing model predictions:")
visualize_predictions(model, val_dataset, device, num_samples=4)


In [None]:
# Calculate detailed class-wise metrics
def calculate_class_metrics(model, val_loader, device, num_classes=8):
    """Compute per-class IoU and the confusion matrix over a validation loader.

    Args:
        model: segmentation network returning per-class logits.
        val_loader: iterable of (images, masks) batches.
        device: device the model lives on.
        num_classes: number of label ids; rows/columns of the confusion matrix.

    Returns:
        (class_ious, confusion_matrix): a list with one IoU per class (0.0 for a
        class that appears in neither labels nor predictions) and a float array
        of shape (num_classes, num_classes) indexed [true, predicted].

    Raises:
        IndexError: if a label or prediction lies outside [0, num_classes).
    """
    model.eval()

    confusion_matrix = np.zeros((num_classes, num_classes))

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc='Calculating metrics'):
            outputs = model(images.to(device))
            pred_masks = torch.argmax(outputs, dim=1)

            # Flatten to integer arrays; the cast also accepts float-typed masks.
            pred_np = pred_masks.cpu().numpy().ravel().astype(np.int64)
            true_np = masks.cpu().numpy().ravel().astype(np.int64)

            # Fail loudly on ids outside the matrix; a negative id would
            # otherwise silently index from the end.
            out_of_range = (
                (true_np < 0) | (true_np >= num_classes)
                | (pred_np < 0) | (pred_np >= num_classes)
            )
            if out_of_range.any():
                raise IndexError(
                    f"class id outside [0, {num_classes}) found in masks or predictions"
                )

            # Count every (true, predicted) pair at once instead of looping
            # over pixels in Python.
            pair_counts = np.bincount(
                true_np * num_classes + pred_np, minlength=num_classes ** 2
            )
            confusion_matrix += pair_counts.reshape(num_classes, num_classes)

    # IoU = TP / (TP + FP + FN); TP + FP is the column sum, TP + FN the row sum.
    true_positives = np.diag(confusion_matrix)
    unions = confusion_matrix.sum(axis=0) + confusion_matrix.sum(axis=1) - true_positives
    class_ious = [
        tp / union if union > 0 else 0.0
        for tp, union in zip(true_positives, unions)
    ]

    return class_ious, confusion_matrix

# Calculate metrics
print("Calculating detailed class metrics...")
# Pass the real class count: with the function's default of 8, five classes
# that do not exist in this 3-class setup get an IoU of 0.0 and pull the mean down.
class_ious, conf_matrix = calculate_class_metrics(
    model, val_loader, device, num_classes=CONFIG['num_classes']
)
# Plain floats keep the JSON dump independent of NumPy scalar types.
class_ious = [float(iou) for iou in class_ious]

# Print results
print("\nClass-wise IoU Scores:")
print("-" * 40)
for class_id, iou in enumerate(class_ious):
    class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")
    print(f"{class_name:15s}: {iou:.4f}")

mean_iou = float(np.mean(class_ious))
print(f"\nMean IoU: {mean_iou:.4f}")

# Save results
results = {
    'class_names': list(CLASS_NAMES.values()),
    'class_ious': class_ious,
    'mean_iou': mean_iou,
    'training_history': history,
    'config': CONFIG
}

with open(os.path.join(CONFIG['save_dir'], 'training_results.json'), 'w') as f:
    json.dump(results, f, indent=2)

print(f"\nResults saved to {CONFIG['save_dir']}/training_results.json")
