# StructGAN Data Exploration

This notebook explores the StructGAN dataset structure and visualizes sample data.

## Contents
1. Setup and Imports
2. Explore Dataset Structure
3. Visualize Sample Images
4. Analyze Data Distribution
5. Test DataLoader
6. Color Analysis of Structural Elements

## 1. Setup and Imports

In [None]:
import sys
from pathlib import Path

# Add project root to path so the `src.*` project imports below resolve.
# Path.cwd().parent assumes the notebook is launched from a direct
# subdirectory of the project root (e.g. notebooks/) — TODO confirm.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from collections import Counter

# Project imports (these rely on the sys.path change above, so they must
# stay after it)
from src.data_preprocessing.dataset import StructGANDataset, visualize_sample
from src.utils.visualization import tensor_to_image

print("Imports successful!")

## 2. Explore Dataset Structure

In [None]:
# Check StructGAN repository: report whether the clone and its datasets
# folder are present, then list every "Group*" sub-folder with its
# PNG + JPG image count.
structgan_path = project_root / "StructGAN_v1"
datasets_path = structgan_path / "0_datasets"

print(f"StructGAN repository exists: {structgan_path.exists()}")
print(f"Datasets folder exists: {datasets_path.exists()}")

if not datasets_path.exists():
    print("\nPlease run setup.sh first to clone the StructGAN repository.")
else:
    print("\nAvailable dataset groups:")
    group_dirs = [
        entry for entry in sorted(datasets_path.iterdir())
        if entry.is_dir() and "Group" in entry.name
    ]
    for group_dir in group_dirs:
        image_count = sum(
            1 for pattern in ("*.png", "*.jpg") for _ in group_dir.glob(pattern)
        )
        print(f"  {group_dir.name}: {image_count} images")

In [None]:
# List all files in a dataset group
dataset_name = "Group7-H2"  # Change this to explore different groups
dataset_dir = datasets_path / dataset_name

if dataset_dir.exists():
    # PNG and JPG files together, sorted by full path.
    image_files = sorted([*dataset_dir.glob("*.png"), *dataset_dir.glob("*.jpg")])
    print(f"Dataset: {dataset_name}")
    print(f"Total images: {len(image_files)}")
    print("\nFirst 5 files:")
    for f in image_files[:5]:
        print(f"  {f.name}")
else:
    # Reset the list: later cells check `'image_files' in dir()`, so without
    # this they would silently reuse files from a previously explored group.
    image_files = []
    print(f"Dataset {dataset_name} not found.")

## 3. Visualize Sample Images

In [None]:
def show_paired_image(image_path):
    """Display a paired image (input|target side by side) plus a 50/50 overlay.

    The left half of the file is shown as the input (architectural plan) and
    the right half as the target (structural layout).

    Args:
        image_path: Path (str or Path) to the side-by-side image file.

    Returns:
        Tuple ``(input_img, target_img)`` of equally sized RGB uint8 arrays.
    """
    # convert("RGB") makes palette/grayscale PNGs (which load as 2-D arrays)
    # and RGBA files all yield an (H, W, 3) array; `with` closes the file.
    with Image.open(image_path) as img:
        img_np = np.array(img.convert("RGB"))

    # Equal-width halves: with an odd width the right slice would be one
    # column wider, and cv2.addWeighted requires identically sized inputs.
    half = img_np.shape[1] // 2
    input_img = img_np[:, :half]
    target_img = img_np[:, half:2 * half]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (input_img, 'Input: Architectural Plan'),
        (target_img, 'Target: Structural Layout'),
        (cv2.addWeighted(input_img, 0.5, target_img, 0.5, 0), 'Overlay'),
    ]
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(Path(image_path).name)
    plt.tight_layout()
    plt.show()

    return input_img, target_img

# Show a sample image
if 'image_files' in dir() and len(image_files) > 0:
    input_img, target_img = show_paired_image(image_files[0])

In [None]:
# Display multiple samples: one row per image, input on the left and target
# on the right.
if 'image_files' in dir() and len(image_files) > 0:
    n_samples = min(6, len(image_files))
    # squeeze=False keeps `axes` 2-D even for a single row, so `axes[i, col]`
    # also works when the group contains only one image.
    fig, axes = plt.subplots(
        n_samples, 2, figsize=(10, 4 * n_samples), squeeze=False
    )

    for i, img_path in enumerate(image_files[:n_samples]):
        # convert("RGB") avoids IndexError on palette/grayscale PNGs;
        # `with` closes each file after reading.
        with Image.open(img_path) as img:
            img_np = np.array(img.convert("RGB"))
        half = img_np.shape[1] // 2

        input_img = img_np[:, :half]
        target_img = img_np[:, half:2 * half]

        for col, (label, panel) in enumerate((('Input', input_img), ('Target', target_img))):
            axes[i, col].imshow(panel)
            axes[i, col].set_title(f'{label} {i+1}')
            axes[i, col].axis('off')

    plt.tight_layout()
    plt.show()

## 4. Analyze Data Distribution

In [None]:
def analyze_image(image_path):
    """Measure structural-element pixel coverage in a paired image's target half.

    Strongly red pixels (R > 200, G < 100, B < 100) are counted as shear
    walls and strongly blue pixels (B > 200, R < 100, G < 100) as columns.
    This color mapping is an assumption about the dataset's encoding.

    Args:
        image_path: Path to a side-by-side (input|target) image, or an
            already-loaded array of shape (H, W) or (H, W, C) with values
            on a 0-255 scale.

    Returns:
        Dict with 'width' and 'height' (size of one half) and 'wall_ratio' /
        'column_ratio' (fraction of target pixels matching each color;
        0.0 when the target half is empty).
    """
    if isinstance(image_path, np.ndarray):
        img_np = image_path
    else:
        # convert("RGB") handles palette/grayscale PNGs, which would otherwise
        # load as 2-D arrays and break the channel indexing below.
        with Image.open(image_path) as img:
            img_np = np.array(img.convert("RGB"))
    if img_np.ndim == 2:
        img_np = np.repeat(img_np[:, :, None], 3, axis=2)

    half = img_np.shape[1] // 2
    target_img = img_np[:, half:2 * half]
    red, green, blue = target_img[:, :, 0], target_img[:, :, 1], target_img[:, :, 2]

    red_mask = (red > 200) & (green < 100) & (blue < 100)
    blue_mask = (blue > 200) & (red < 100) & (green < 100)

    total_pixels = target_img.shape[0] * target_img.shape[1]
    if total_pixels == 0:
        wall_ratio = column_ratio = 0.0
    else:
        wall_ratio = np.count_nonzero(red_mask) / total_pixels
        column_ratio = np.count_nonzero(blue_mask) / total_pixels

    return {
        'width': half,
        'height': img_np.shape[0],
        'wall_ratio': wall_ratio,
        'column_ratio': column_ratio,
    }

# Analyze all images
if 'image_files' in dir() and len(image_files) > 0:
    stats = [analyze_image(f) for f in image_files]
    wall_ratios = np.array([s['wall_ratio'] for s in stats])
    column_ratios = np.array([s['column_ratio'] for s in stats])
    # Files may differ in size, so report the most common size and its share.
    sizes = Counter((s['width'], s['height']) for s in stats)
    (common_w, common_h), n_common = sizes.most_common(1)[0]

    print(f"Dataset Statistics ({len(stats)} images):")
    print(f"Image size: {common_w} x {common_h} ({n_common}/{len(stats)} images)")
    for label, ratios in (("Shear Wall Ratio", wall_ratios), ("Column Ratio", column_ratios)):
        print(f"\n{label}:")
        print(f"  Mean: {ratios.mean():.4f}")
        print(f"  Std:  {ratios.std():.4f}")
        print(f"  Min:  {ratios.min():.4f}")
        print(f"  Max:  {ratios.max():.4f}")

In [None]:
# Plot distributions: one histogram per structural ratio collected in `stats`.
if 'stats' in dir():
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (stats key, subplot title, x-axis label) for each panel, left to right.
    panel_specs = (
        ('wall_ratio', 'Shear Wall Ratio Distribution', 'Wall Ratio'),
        ('column_ratio', 'Column Ratio Distribution', 'Column Ratio'),
    )
    for ax, (key, title, xlabel) in zip(axes, panel_specs):
        values = [entry[key] for entry in stats]
        ax.hist(values, bins=20, edgecolor='black')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    plt.tight_layout()
    plt.show()

## 5. Test DataLoader

In [None]:
# Test our custom dataset class
if 'dataset_dir' in dir() and dataset_dir.exists():
    dataset = StructGANDataset(
        root_dir=str(dataset_dir),
        split="train",
        image_size=256,
        paired_format="side_by_side"
    )

    print(f"Dataset size: {len(dataset)}")

    # Guard: indexing an empty dataset would raise IndexError.
    if len(dataset) == 0:
        print("Dataset is empty; nothing to inspect.")
    else:
        # Get a sample
        input_tensor, target_tensor = dataset[0]
        print(f"Input shape: {input_tensor.shape}")
        print(f"Target shape: {target_tensor.shape}")
        print(f"Input range: [{input_tensor.min().item():.2f}, {input_tensor.max().item():.2f}]")
        print(f"Target range: [{target_tensor.min().item():.2f}, {target_tensor.max().item():.2f}]")

In [None]:
# Visualize tensor samples (inputs in the top row, targets in the bottom row)
if 'dataset' in dir() and len(dataset) > 0:
    # Show at most 4 samples; indexing past the end would raise IndexError.
    n_show = min(4, len(dataset))
    # squeeze=False keeps `axes` 2-D even when only one column is shown.
    fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8), squeeze=False)

    for i in range(n_show):
        input_t, target_t = dataset[i]

        for row, (label, tensor) in enumerate((('Input', input_t), ('Target', target_t))):
            # Convert from [-1, 1] to [0, 1] and clamp so slight overshoot does
            # not trigger imshow's clipping warning. Assumes torch (C, H, W)
            # tensors, as implied by the permute below — TODO confirm.
            img = ((tensor.detach().cpu() + 1) / 2).clamp(0, 1).permute(1, 2, 0).numpy()
            if img.shape[-1] == 1:
                # imshow rejects (H, W, 1); drop the channel axis for grayscale.
                img = img[..., 0]
            axes[row, i].imshow(img, cmap='gray' if img.ndim == 2 else None)
            axes[row, i].set_title(f'{label} {i+1}')
            axes[row, i].axis('off')

    plt.tight_layout()
    plt.show()

## 6. Color Analysis of Structural Elements

In [None]:
def analyze_colors(image_path):
    """Count how many pixels of each distinct RGB color appear in the target half.

    Args:
        image_path: Path to a side-by-side (input|target) image, or an
            already-loaded array of shape (H, W) or (H, W, C).

    Returns:
        Dict mapping ``(r, g, b)`` tuples of Python ints to pixel counts,
        ordered by color value (as produced by ``np.unique``).
    """
    if isinstance(image_path, np.ndarray):
        img_np = image_path
    else:
        # convert("RGB") normalizes palette/grayscale/RGBA files to 3 channels.
        with Image.open(image_path) as img:
            img_np = np.array(img.convert("RGB"))
    if img_np.ndim == 2:
        img_np = np.repeat(img_np[:, :, None], 3, axis=2)

    half = img_np.shape[1] // 2
    # Keep only R, G, B: reshape(-1, 3) on an RGBA array would split pixels
    # across channel boundaries and report bogus colors.
    target = img_np[:, half:2 * half, :3]
    pixels = target.reshape(-1, 3)

    # One vectorized pass instead of a full-image mask per unique color,
    # which is O(n_colors * n_pixels) on anti-aliased images.
    unique_colors, counts = np.unique(pixels, axis=0, return_counts=True)

    # Python ints so keys print as (255, 0, 0) rather than np.uint8 reprs.
    return {
        tuple(int(c) for c in color): int(count)
        for color, count in zip(unique_colors, counts)
    }

# Analyze colors in first image
if 'image_files' in dir() and len(image_files) > 0:
    colors = analyze_colors(image_files[0])

    print("Top 10 colors in structural layout:")
    sorted_colors = Counter(colors).most_common(10)
    for color, count in sorted_colors:
        print(f"  {color}: {count} px")

    # squeeze=False keeps `axes` indexable even when only one color exists.
    fig, axes = plt.subplots(
        1, len(sorted_colors), figsize=(2 * len(sorted_colors), 2), squeeze=False
    )

    for ax, (color, count) in zip(axes[0], sorted_colors):
        patch = np.full((50, 50, 3), color, dtype=np.uint8)
        ax.imshow(patch)
        ax.set_title(f'{count} px\n{color}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## Summary

Key observations about the dataset:

1. **Format**: Images are paired side-by-side (input|target)
2. **Size**: Typically 512x256 (width x height) combined, i.e. two 256x256 halves
3. **Color Coding**:
   - Input: Room colors, black walls, door/window markers
   - Target: Red = shear walls, Blue = columns, White = background
4. **Distribution**: Wall coverage varies, columns are sparse

Next steps:
- Run baseline training: `python src/training/train_baseline.py`
- Monitor with TensorBoard: `tensorboard --logdir models/checkpoints`