# Dataset Testing Notebook

This notebook tests each dataset individually (currently only the RID dataset is exercised; the others are listed under Next Steps):
1. RID Dataset (Segmentation)
2. Roofline Dataset (Line Detection)
3. AIRS Dataset (Building Outlines)

Each dataset gets its own implementation in `src/datasets/`, sharing common functionality through a base class.

In [None]:
# Add project root to path
import sys
import os
from pathlib import Path

# Resolve the project root as the parent of the current working directory
# (the working directory is taken to be the notebook's own folder).
notebook_dir = Path.cwd()
project_root = notebook_dir.parent
print(f"Project root: {project_root}")

# Make the `src` package importable. The membership check keeps this cell
# idempotent: re-running it no longer appends duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import dependencies
import torch
import matplotlib.pyplot as plt
import numpy as np
import logging

# Import our modules
from src.datasets import create_dataloaders
from src.models import create_model

# Set up logging and plotting
# Configure the root logger so that messages at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

## Common Configuration

In [None]:
# Shared settings used by the dataset and model smoke tests in this notebook
config = dict(
    batch_size=2,     # passed to create_dataloaders
    num_workers=2,    # passed to create_dataloaders
    max_samples=10,   # passed to create_dataloaders
    image_size=512,   # NOTE(review): not referenced by the RID cells - confirm it is still needed
    num_classes=12,   # passed to create_model
)

# On-disk location of each dataset, all rooted at
# <project_root>/Reference Materials/data
data_paths = {
    'rid': project_root.joinpath('Reference Materials', 'data', 'RID', 'm1655470', 'RID_dataset'),
    'roofline': project_root.joinpath('Reference Materials', 'data', 'Roofline-Extraction'),
    'airs': project_root.joinpath('Reference Materials', 'data', 'AIRS'),
}

# Verify paths exist and show the top-level contents of each dataset folder
print("\nVerifying dataset paths:")
for name, path in data_paths.items():
    print(f"\n{name} dataset:")
    print(f"Path: {path}")
    print(f"Exists: {path.exists()}")
    # Guard on is_dir() rather than exists(): a path that exists but is a
    # regular file would make iterdir() raise NotADirectoryError.
    if path.is_dir():
        print("Contents:")
        # iterdir() yields entries in arbitrary order; sort so the output is
        # the same on every run.
        for item in sorted(path.iterdir()):
            suffix = "/" if item.is_dir() else ""
            print(f"  - {item.name}{suffix}")

# Select the compute device: the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'\nUsing device: {device}')

## Visualization Functions

In [None]:
def show_batch(images, targets, dataset_type):
    """Display a batch of images and their targets.

    Draws one row per sample with four panels: the input image, the
    segmentation mask, an RGB composite of the line channels, and the
    depth map.

    Args:
        images: Batch tensor; ``images[i]`` is permuted with ``(1, 2, 0)``
            before plotting (assumed channel-first, i.e. (C, H, W)).
        targets: Mapping with the keys ``'segments'``, ``'lines'`` and
            ``'depth'``, each indexable by sample. ``targets['lines'][i]``
            must have at least 4 channels on its first axis.
        dataset_type: Dataset identifier (e.g. ``'rid'``). Kept for interface
            compatibility; the plotting code does not read it.

    Returns:
        The matplotlib ``Figure`` containing the grid.
    """
    batch_size = images.shape[0]
    # squeeze=False keeps `axes` two-dimensional even when the batch holds a
    # single sample; otherwise axes[i, j] would raise for batch_size == 1.
    fig, axes = plt.subplots(batch_size, 4, figsize=(20, 5*batch_size), squeeze=False)

    for i in range(batch_size):
        # Original image. detach().cpu() makes the conversion work for
        # tensors that live on the GPU or track gradients.
        img = images[i].permute(1, 2, 0).detach().cpu().numpy()
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image')

        # Segmentation mask
        mask = targets['segments'][i].detach().cpu().numpy()
        axes[i, 1].imshow(mask, cmap='tab20')
        axes[i, 1].set_title('Segmentation Mask')

        # Line detection: pack three of the line channels into an RGB image
        lines = targets['lines'][i].detach().cpu().numpy()
        line_vis = np.zeros((*lines.shape[1:], 3))
        line_vis[..., 0] = lines[0]  # Ridge lines (red)
        line_vis[..., 1] = lines[1]  # Valley lines (green)
        # NOTE(review): channel 2 is skipped and channel 3 is shown as the
        # outline - confirm this matches the dataset's channel layout.
        line_vis[..., 2] = lines[3]  # Building outline (blue)
        axes[i, 2].imshow(line_vis)
        axes[i, 2].set_title('Line Detection')

        # Depth map
        depth = targets['depth'][i].detach().cpu().numpy()
        axes[i, 3].imshow(depth, cmap='viridis')
        axes[i, 3].set_title('Depth Map')

    plt.tight_layout()
    return fig

def show_predictions(images, targets, predictions, dataset_type):
    """Display predictions alongside ground truth.

    Draws one row per sample with four panels: the input image, the predicted
    and true segmentation masks overlaid, channel 0 of the predicted and true
    line maps as a red/green composite, and the absolute depth error.

    Args:
        images: Batch tensor; ``images[i]`` is permuted with ``(1, 2, 0)``
            before plotting (assumed channel-first, i.e. (C, H, W)).
        targets: Mapping with the keys ``'segments'``, ``'lines'`` and
            ``'depth'`` holding the ground truth, each indexable by sample.
        predictions: Mapping with the same keys holding the model outputs.
            ``predictions['segments'][i]`` is reduced with ``argmax`` over
            its first axis (presumably per-class scores - confirm).
        dataset_type: Dataset identifier (e.g. ``'rid'``). Kept for interface
            compatibility; the plotting code does not read it.

    Returns:
        The matplotlib ``Figure`` containing the grid.
    """
    batch_size = images.shape[0]
    # squeeze=False keeps `axes` two-dimensional even when the batch holds a
    # single sample; otherwise axes[i, j] would raise for batch_size == 1.
    fig, axes = plt.subplots(batch_size, 4, figsize=(20, 5*batch_size), squeeze=False)

    for i in range(batch_size):
        # Original image. detach().cpu() makes the conversion work for
        # tensors that live on the GPU or track gradients.
        img = images[i].permute(1, 2, 0).detach().cpu().numpy()
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image')

        # Segmentation prediction: the true mask is drawn on top of the
        # predicted one, both semi-transparent.
        pred_mask = torch.argmax(predictions['segments'][i], dim=0).detach().cpu().numpy()
        true_mask = targets['segments'][i].detach().cpu().numpy()
        axes[i, 1].imshow(pred_mask, cmap='tab20', alpha=0.7)
        axes[i, 1].imshow(true_mask, cmap='tab20', alpha=0.3)
        axes[i, 1].set_title('Segmentation (Pred/True)')

        # Line detection: only channel 0 of each line map is compared
        pred_lines = predictions['lines'][i].detach().cpu().numpy()
        true_lines = targets['lines'][i].detach().cpu().numpy()
        line_vis = np.zeros((*pred_lines.shape[1:], 3))
        line_vis[..., 0] = pred_lines[0]  # Predicted (red)
        line_vis[..., 1] = true_lines[0]  # True (green)
        axes[i, 2].imshow(line_vis)
        axes[i, 2].set_title('Lines (Pred/True)')

        # Depth prediction: squeeze() drops size-1 axes from the prediction
        # before it is compared with the target.
        pred_depth = predictions['depth'][i].squeeze().detach().cpu().numpy()
        true_depth = targets['depth'][i].detach().cpu().numpy()
        depth_diff = np.abs(pred_depth - true_depth)
        axes[i, 3].imshow(depth_diff, cmap='viridis')
        axes[i, 3].set_title('Depth Error')

    plt.tight_layout()
    return fig

## Test RID Dataset

In [None]:
# Create RID dataset and dataloader
print("Creating RID dataloaders...")
rid_train_loader, rid_val_loader = create_dataloaders(
    data_paths['rid'],
    'rid',
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    max_samples=config['max_samples']
)

# Pull a single batch from the training loader as a smoke test
print("\nGetting a batch of data...")
images, targets = next(iter(rid_train_loader))

print("\nBatch information:")
print(f"Images shape: {images.shape}")
print(f"Number of images: {len(images)}")
print(f"Target keys: {list(targets.keys())}")

# Display the batch
print("\nDisplaying batch...")
# Trailing semicolon: show_batch returns the Figure, and leaving it as the
# cell's last expression would render the same figure a second time (once by
# the inline backend, once as the cell's output value).
show_batch(images, targets, 'rid');

## Test Model with RID Data

In [None]:
# Create model
model, criterion = create_model(num_classes=config['num_classes'])
model = model.to(device)
# This cell only runs a forward pass for inspection (no gradients, no
# optimizer step), so put the model in evaluation mode: train() would keep
# dropout active and update BatchNorm running statistics during the pass.
model.eval()

# Move batch to device (non-tensor target entries are passed through as-is)
images = images.to(device)
device_targets = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                 for k, v in targets.items()}

# Get predictions
with torch.no_grad():
    predictions = model(images)

# Calculate loss
loss, losses_dict = criterion(predictions, device_targets)
print(f'Total loss: {loss.item():.4f}')
for k, v in losses_dict.items():
    print(f'{k} loss: {v.item():.4f}')

# Move everything back to CPU for visualization
images = images.cpu()
predictions = {k: v.cpu() for k, v in predictions.items()}
# `targets` itself was not moved in this cell (device_targets holds the
# device copies); the .cpu() call is kept as a safeguard.
targets = {k: v.cpu() if isinstance(v, torch.Tensor) else v
          for k, v in targets.items()}

# Show predictions
# Trailing semicolon: show_predictions returns the Figure, and leaving it as
# the cell's last expression would render the same figure a second time.
show_predictions(images, targets, predictions, 'rid');

## Next Steps

After verifying that the RID dataset works:
1. Implement RooflineDataset class
2. Test Roofline dataset loading and inference
3. Implement AIRSDataset class
4. Test AIRS dataset loading and inference
5. Prepare for AWS SageMaker training