# ScanNet Instance Segmentation Data Exploration

This notebook explores the ScanNet dataset and analyzes various aspects of the point cloud data for instance segmentation.

In [None]:
import sys
sys.path.append('..')

import numpy as np
import torch
import h5py
import yaml
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import open3d as o3d
from tqdm.notebook import tqdm

from src.data import ScanNetDataset
from src.utils.data_utils import DataProcessor

# Set plotting style
plt.style.use('seaborn')
sns.set_palette('husl')
%matplotlib inline

## 1. Load Configuration

In [None]:
# Parse the data configuration for interactive inspection; the dataset below
# is handed the same file path directly.
config_file = '../configs/data_config.yaml'
with open(config_file, 'r') as f:
    config = yaml.safe_load(f)

# Training split of ScanNet, configured from the same YAML file
dataset = ScanNetDataset(root_dir='../data/scannet', split='train', config_path=config_file)

## 2. Dataset Statistics

In [None]:
def analyze_dataset_statistics(dataset):
    """Collect basic per-scene statistics over a whole dataset.

    Instance id 0 is treated as background: it is neither counted as an
    instance nor included in the per-instance point counts.

    Args:
        dataset: Indexable, sized dataset whose items are dicts holding
            'points', 'instance_labels' and 'semantic_labels' tensors.

    Returns:
        dict with keys
            'num_scenes'          -- len(dataset)
            'points_per_scene'    -- list, number of points in each scene
            'instances_per_scene' -- list, non-background instances per scene
            'semantic_classes'    -- set of every semantic label id seen
            'points_per_instance' -- list, point count of every
                                     non-background instance
    """
    stats = {
        'num_scenes': len(dataset),
        'points_per_scene': [],
        'instances_per_scene': [],
        'semantic_classes': set(),
        'points_per_instance': [],
    }

    for i in tqdm(range(len(dataset)), desc='Analyzing dataset'):
        data = dataset[i]
        stats['points_per_scene'].append(len(data['points']))

        # One pass yields every instance id with its point count, instead of
        # re-scanning the whole label tensor once per instance.
        inst_ids, inst_counts = torch.unique(data['instance_labels'], return_counts=True)
        foreground = inst_ids != 0
        # Background is not an instance; excluding it keeps this count
        # consistent with points_per_instance.
        stats['instances_per_scene'].append(int(foreground.sum()))
        stats['points_per_instance'].extend(inst_counts[foreground].tolist())

        stats['semantic_classes'].update(torch.unique(data['semantic_labels']).tolist())

    return stats

stats = analyze_dataset_statistics(dataset)

print("Dataset Statistics:")
print(f"Number of scenes: {stats['num_scenes']}")
print(f"Number of semantic classes: {len(stats['semantic_classes'])}")
print("\nPoints per scene:")
print(f"  Mean: {np.mean(stats['points_per_scene']):.2f}")
print(f"  Std: {np.std(stats['points_per_scene']):.2f}")
print(f"  Min: {np.min(stats['points_per_scene'])}")
print(f"  Max: {np.max(stats['points_per_scene'])}")

## 3. Visualize Point Cloud Distribution

In [None]:
# Histograms of the per-scene point counts and instance counts gathered above:
# (values, number of bins, title, x-axis label) for each figure.
for values, n_bins, title, xlabel in (
    (stats['points_per_scene'], 50, 'Distribution of Points per Scene', 'Number of Points'),
    (stats['instances_per_scene'], 30, 'Distribution of Instances per Scene', 'Number of Instances'),
):
    plt.figure(figsize=(10, 6))
    plt.hist(values, bins=n_bins)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Count')
    plt.show()

## 4. Analyze Feature Distributions

In [None]:
def analyze_features(data, feature_names=('geometric_features', 'contextual_features')):
    """Plot a histogram of every per-point feature column of a single scene.

    Each name in ``feature_names`` that is present in ``data`` and not None
    becomes one row of subplots, with one 50-bin histogram per feature column.
    All rows share one grid whose width is that of the widest feature group;
    unused cells are hidden. Nothing is plotted if no group is present.

    Args:
        data: Scene dict holding torch feature tensors (one row per point).
        feature_names: Keys of the feature tensors to plot.
    """
    groups = []
    for name in feature_names:
        if name in data and data[name] is not None:
            values = data[name].numpy()
            if values.ndim == 1:
                values = values[:, None]  # a 1-D tensor is a single feature column
            groups.append((name, values))

    n_cols = max((values.shape[1] for _, values in groups), default=0)
    if n_cols == 0:
        return

    # A single shared grid keeps every row aligned even when the feature
    # groups have different numbers of columns.
    _, axes = plt.subplots(len(groups), n_cols, figsize=(15, 5 * len(groups)), squeeze=False)
    for row, (name, values) in enumerate(groups):
        for col, ax in enumerate(axes[row]):
            if col >= values.shape[1]:
                ax.set_visible(False)  # this group is narrower than the grid
                continue
            ax.hist(values[:, col], bins=50)
            ax.set_title(f'{name}_{col}')
            ax.set_xlabel('Value')
            ax.set_ylabel('Count')

    plt.tight_layout()
    plt.show()

# Analyze features for a sample scene
sample_data = dataset[0]
analyze_features(sample_data)

## 5. Visualize 3D Point Clouds

In [None]:
def visualize_scene(points, instance_labels, semantic_labels=None, color_by='instance'):
    """Open an Open3D window showing a point cloud coloured by label.

    Each distinct label gets its own colour from the ``rainbow`` colormap,
    assigned in ascending label order.

    Args:
        points: (N, 3) torch tensor of point coordinates.
        instance_labels: (N,) torch tensor of instance ids.
        semantic_labels: Optional (N,) torch tensor of semantic class ids.
        color_by: ``'instance'`` (default) or ``'semantic'``; selects which
            labels determine the colours.

    Raises:
        ValueError: if ``color_by`` is unknown, or is ``'semantic'`` while
            ``semantic_labels`` is None.
    """
    if color_by == 'instance':
        labels = instance_labels
    elif color_by == 'semantic':
        if semantic_labels is None:
            raise ValueError("color_by='semantic' requires semantic_labels")
        labels = semantic_labels
    else:
        raise ValueError(f"color_by must be 'instance' or 'semantic', got {color_by!r}")

    # `inverse` maps every point to the position of its label among the sorted
    # unique labels, so one gather colours all points at once.
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    palette = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))[:, :3]
    point_colors = palette[inverse.cpu().numpy()]

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.numpy())
    pcd.colors = o3d.utility.Vector3dVector(point_colors)

    o3d.visualization.draw_geometries([pcd])

# Visualize sample scene
sample_data = dataset[0]
visualize_scene(
    sample_data['points'],
    sample_data['instance_labels'],
    sample_data['semantic_labels']
)

## 6. Analyze Instance Size Distribution

In [None]:
def analyze_instance_sizes(scenes=None):
    """Collect the size and semantic class of every non-background instance.

    Instance id 0 is treated as background and skipped. Each instance is
    assigned its most frequent semantic label (ties go to the smallest id),
    so a few stray points cannot decide its class.

    Args:
        scenes: Dataset to scan; defaults to the notebook-level ``dataset``.

    Returns:
        (instance_sizes, semantic_distribution): a list with the point count
        of every instance, and a dict mapping semantic class id -> number of
        instances of that class.
    """
    scenes = dataset if scenes is None else scenes
    instance_sizes = []
    semantic_distribution = {}

    for i in tqdm(range(len(scenes)), desc='Analyzing instances'):
        data = scenes[i]
        instance_labels = data['instance_labels']
        semantic_labels = data['semantic_labels']

        inst_ids, inst_counts = torch.unique(instance_labels, return_counts=True)
        for inst_id, size in zip(inst_ids.tolist(), inst_counts.tolist()):
            if inst_id == 0:  # Skip background
                continue

            # Majority vote over the instance's points.
            classes, votes = torch.unique(semantic_labels[instance_labels == inst_id], return_counts=True)
            sem_label = classes[votes.argmax()].item()

            instance_sizes.append(size)
            semantic_distribution[sem_label] = semantic_distribution.get(sem_label, 0) + 1

    return instance_sizes, semantic_distribution

instance_sizes, semantic_dist = analyze_instance_sizes()

# Plot instance size distribution
plt.figure(figsize=(10, 6))
plt.hist(np.log10(instance_sizes), bins=50)
plt.title('Distribution of Instance Sizes (log scale)')
plt.xlabel('Log10(Number of Points)')
plt.ylabel('Count')
plt.show()

# Plot semantic class distribution (classes in ascending id order)
class_ids = sorted(semantic_dist)
plt.figure(figsize=(12, 6))
plt.bar(class_ids, [semantic_dist[c] for c in class_ids])
plt.title('Distribution of Semantic Classes')
plt.xlabel('Semantic Class ID')
plt.ylabel('Number of Instances')
plt.xticks(rotation=45)
plt.show()

## 7. Analyze Spatial Distribution

In [None]:
def analyze_spatial_distribution(points):
    """Show an interactive Plotly 3-D scatter of a point cloud, coloured by Z.

    Args:
        points: 2-D array-like with at least three columns; columns 0, 1 and
            2 are plotted as X, Y and Z.
    """
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]

    scatter = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker={'size': 2, 'color': zs, 'colorscale': 'Viridis', 'opacity': 0.8},
    )
    axis_titles = {'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': 'Z'}

    fig = go.Figure(data=[scatter])
    fig.update_layout(title='Spatial Distribution of Points', scene=axis_titles)
    fig.show()

# Analyze spatial distribution for a sample scene
sample_data = dataset[0]
analyze_spatial_distribution(sample_data['points'])

## 8. Feature Correlation Analysis

In [None]:
def analyze_feature_correlations(data):
    """Plot and return the Pearson correlation matrix of a scene's feature columns.

    Every column of ``data['geometric_features']`` and
    ``data['contextual_features']`` (each used only if present and not None)
    is one variable, named ``geometric_<i>`` / ``contextual_<i>``.

    Returns:
        (F, F) np.ndarray of correlation coefficients, F being the total
        number of feature columns. Constant columns yield NaN entries. When
        the scene has no features, nothing is plotted and an empty (0, 0)
        array is returned.
    """
    feature_dict = {}
    for key, prefix in (('geometric_features', 'geometric'), ('contextual_features', 'contextual')):
        if key in data and data[key] is not None:
            block = np.asarray(data[key])
            if block.ndim == 1:
                block = block[:, None]  # a 1-D tensor is a single feature column
            for i in range(block.shape[1]):
                feature_dict[f'{prefix}_{i}'] = block[:, i]

    if not feature_dict:
        return np.empty((0, 0))  # nothing to correlate or plot

    names = list(feature_dict)
    # np.corrcoef treats rows as variables. For a single variable it returns a
    # scalar, which atleast_2d turns back into the 1x1 matrix heatmap needs.
    correlation_matrix = np.atleast_2d(np.corrcoef(np.stack(list(feature_dict.values()))))

    # Plot correlation matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        correlation_matrix,
        xticklabels=names,
        yticklabels=names,
        cmap='coolwarm',
        center=0,
        annot=True,
        fmt='.2f'
    )
    plt.title('Feature Correlation Matrix')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

    return correlation_matrix

# Analyze feature correlations for a sample scene
sample_data = dataset[0]
correlation_matrix = analyze_feature_correlations(sample_data)

## 9. Instance Boundary Analysis

In [None]:
def analyze_instance_boundaries(data, k=30):
    """Find and display points lying on instance boundaries.

    A point counts as a boundary point when the ``k`` nearest neighbours
    returned by an Open3D KD-tree query at that point carry more than one
    distinct instance label. The cloud is then shown in an Open3D window with
    boundary points in red and all other points in black.

    Args:
        data: Scene dict with a 'points' tensor (one xyz row per point) and an
            'instance_labels' tensor (one label per point).
        k: Neighbourhood size of the KNN query.

    Returns:
        np.ndarray (int64) of boundary point indices; empty if there are none.
    """
    points = data['points']
    # numpy copy of the labels so they can be indexed with Open3D index vectors
    labels = data['instance_labels'].cpu().numpy()

    # Create KD-tree for nearest neighbor search
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.numpy())
    kdtree = o3d.geometry.KDTreeFlann(pcd)

    boundary_points = []
    for i in tqdm(range(len(points)), desc='Analyzing boundaries'):
        _, idx, _ = kdtree.search_knn_vector_3d(pcd.points[i], k)
        if np.unique(labels[np.asarray(idx, dtype=np.int64)]).size > 1:
            boundary_points.append(i)

    # An explicit integer dtype keeps the array usable as an index even when it
    # is empty (np.array([]) is float64 and cannot index `colors`).
    boundary_points = np.array(boundary_points, dtype=np.int64)

    colors = np.zeros((len(points), 3))
    colors[boundary_points] = [1, 0, 0]  # Red for boundary points

    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd])

    return boundary_points

# Analyze boundaries for a sample scene
sample_data = dataset[0]
boundary_points = analyze_instance_boundaries(sample_data)

## 10. Instance Size vs. Feature Analysis

In [None]:
def analyze_size_feature_relationship(data):
    """Relate instance size to the mean geometric features of each instance.

    Instance id 0 is treated as background and skipped.

    Args:
        data: Scene dict with an 'instance_labels' tensor and, optionally, a
            'geometric_features' tensor (one row per point).

    Returns:
        (sizes, mean_features): ``sizes`` is an array with the point count of
        each instance, in ascending instance-id order. ``mean_features`` has
        one row per instance holding the mean of each geometric feature
        column. It is empty when the scene has no geometric features; when it
        is non-empty, a scatter of log10(size) against each feature column is
        shown.
    """
    instance_labels = data['instance_labels']
    geometric = data['geometric_features'] if 'geometric_features' in data else None
    inst_ids, inst_counts = torch.unique(instance_labels, return_counts=True)

    sizes = []
    mean_features = []
    for inst_id, count in zip(inst_ids.tolist(), inst_counts.tolist()):
        if inst_id == 0:  # Skip background
            continue
        sizes.append(count)
        if geometric is not None:
            mean_features.append(geometric[instance_labels == inst_id].mean(0).numpy())

    sizes = np.array(sizes)
    mean_features = np.array(mean_features)

    if len(mean_features) > 0:
        per_instance = mean_features.reshape(len(mean_features), -1)
        log_sizes = np.log10(sizes)
        # squeeze=False keeps `axes` 2-D; for a single feature column
        # plt.subplots(1, 1) would otherwise return a bare, unindexable Axes.
        _, axes = plt.subplots(1, per_instance.shape[1], figsize=(15, 5), squeeze=False)
        for i, ax in enumerate(axes[0]):
            ax.scatter(log_sizes, per_instance[:, i], alpha=0.6)
            ax.set_xlabel('Log10(Instance Size)')
            ax.set_ylabel(f'Feature {i}')
            ax.set_title(f'Size vs Feature {i}')

        plt.tight_layout()
        plt.show()

    return sizes, mean_features

# Relate instance size to averaged geometric features on the first training scene
sample_data = dataset[0]
sizes, mean_features = analyze_size_feature_relationship(data=sample_data)

## 11. Data Augmentation Visualization

In [None]:
def visualize_augmentations(data, augmentations=('random_rotation', 'random_scale', 'random_flip')):
    """Show a scene next to one augmented copy per augmentation.

    Draws one 3-D scatter panel per scene, coloured by instance label: the
    original first, then the output of ``AugmentationPipeline._augment_single``
    for each name in ``augmentations``.

    Args:
        data: Scene dict with 'points' and 'instance_labels'.
        augmentations: Augmentation names passed to ``_augment_single``.
    """
    import copy

    from src.data.augmentation import AugmentationPipeline

    augmenter = AugmentationPipeline('../configs/data_config.yaml')
    n_panels = 1 + len(augmentations)
    fig = plt.figure(figsize=(3.75 * n_panels, 5))

    def _draw(position, sample, title):
        """Add one 3-D scatter panel for ``sample`` at ``position``."""
        ax = fig.add_subplot(1, n_panels, position, projection='3d')
        pts = sample['points']
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                   c=sample['instance_labels'], cmap='tab20', s=1)
        ax.set_title(title)

    _draw(1, data, 'Original')
    for position, aug_name in enumerate(augmentations, start=2):
        # Deep copy: dict.copy() shares the tensors, so an augmentation that
        # modified them in place (not visible here) would also alter `data`
        # and every later panel.
        augmented_data = augmenter._augment_single(copy.deepcopy(data), aug_name)
        _draw(position, augmented_data, aug_name)

    plt.tight_layout()
    plt.show()

# Visualize augmentations for a sample scene
sample_data = dataset[0]
visualize_augmentations(sample_data)

## 12. Scene Complexity Analysis

In [None]:
def _scene_complexity(points, instance_labels, radius, num_samples):
    """Return the complexity metrics of one scene as a dict of scalars."""
    labels = instance_labels.cpu().numpy()

    # Axis-aligned bounding-box volume of the scene
    extent = torch.max(points, dim=0)[0] - torch.min(points, dim=0)[0]
    volume = extent.prod().item()

    # Instance id 0 is background and is not counted as an instance
    num_instances = int(np.count_nonzero(np.unique(labels) != 0))

    if volume > 0:
        instance_density = num_instances / volume
        point_density = len(points) / volume
    else:
        # Flat or degenerate bounding box: densities are undefined
        instance_density = point_density = float('nan')

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.numpy())
    kdtree = o3d.geometry.KDTreeFlann(pcd)

    # Distinct query points (np.random.choice samples with replacement by default)
    sample_idx = np.random.choice(len(points), size=min(num_samples, len(points)), replace=False)
    overlap_counts = []
    for idx in sample_idx:
        _, neighbors, _ = kdtree.search_radius_vector_3d(pcd.points[idx], radius)
        neighbor_ids = np.unique(labels[np.asarray(neighbors, dtype=np.int64)])
        overlap_counts.append(np.count_nonzero(neighbor_ids != 0))

    return {
        'num_instances': num_instances,
        'instance_density': instance_density,
        'point_density': point_density,
        'volume': volume,
        'instance_overlap': np.mean(overlap_counts),
    }


def _plot_complexity_metrics(complexity_metrics):
    """Show 2x2 histograms of the per-scene complexity metrics, skipping NaNs."""
    panels = (
        ('num_instances', 'Number of Instances'),
        ('instance_density', 'Instance Density'),
        ('point_density', 'Point Density'),
        ('instance_overlap', 'Average Instance Overlap'),
    )
    _, axes = plt.subplots(2, 2, figsize=(15, 15))
    for ax, (key, title) in zip(axes.flat, panels):
        values = np.asarray(complexity_metrics[key], dtype=float)
        # np.histogram cannot auto-range data containing NaN, so drop them
        ax.hist(values[np.isfinite(values)], bins=30)
        ax.set_title(title)

    plt.tight_layout()
    plt.show()


def analyze_scene_complexity(scenes=None, radius=0.5, num_samples=100):
    """Compute, plot and return per-scene complexity metrics.

    Args:
        scenes: Dataset to scan; defaults to the notebook-level ``dataset``.
        radius: Neighbourhood radius of the overlap metric, in point
            coordinate units (the labels assume metres).
        num_samples: Maximum number of distinct query points per scene for
            the overlap metric, drawn with ``np.random.choice``.

    Returns:
        dict of per-scene lists:
            'num_instances'    -- non-background instances (id 0 excluded)
            'instance_density' -- instances per unit volume (NaN if volume is 0)
            'point_density'    -- points per unit volume (NaN if volume is 0)
            'volume'           -- axis-aligned bounding-box volume
            'instance_overlap' -- mean number of distinct non-background
                                  instances within ``radius`` of a sampled point
    """
    scenes = dataset if scenes is None else scenes
    complexity_metrics = {
        'num_instances': [],
        'instance_density': [],  # instances per cubic meter
        'point_density': [],     # points per cubic meter
        'volume': [],            # scene volume
        'instance_overlap': []   # average number of instances in local neighborhoods
    }

    for i in tqdm(range(len(scenes)), desc='Analyzing scene complexity'):
        data = scenes[i]
        scene_metrics = _scene_complexity(data['points'], data['instance_labels'], radius, num_samples)
        for key, value in scene_metrics.items():
            complexity_metrics[key].append(value)

    _plot_complexity_metrics(complexity_metrics)
    return complexity_metrics

complexity_metrics = analyze_scene_complexity()

## 13. Summary and Conclusions

In [None]:
def print_dataset_summary(stats, complexity_metrics):
    """Print a human-readable summary of the dataset analysis.

    Args:
        stats: Dict with 'num_scenes', 'semantic_classes',
            'points_per_scene' and 'instances_per_scene'.
        complexity_metrics: Dict with per-scene lists 'instance_density',
            'point_density' and 'instance_overlap'.

    NaN entries in the complexity metrics (e.g. densities of scenes whose
    bounding box has zero volume) are ignored when averaging.
    """
    points = stats['points_per_scene']
    instances = stats['instances_per_scene']
    point_density = complexity_metrics['point_density']

    print("Dataset Summary:")
    print("===============")
    print(f"Number of scenes: {stats['num_scenes']}")
    print(f"Number of semantic classes: {len(stats['semantic_classes'])}")
    print("\nScene Statistics:")
    print("----------------")
    print(f"Average points per scene: {np.mean(points):.2f} ± {np.std(points):.2f}")
    print(f"Average instances per scene: {np.mean(instances):.2f} ± {np.std(instances):.2f}")

    print("\nComplexity Metrics:")
    print("-----------------")
    print(f"Average instance density: {np.nanmean(complexity_metrics['instance_density']):.2f} instances/m³")
    print(f"Average point density: {np.nanmean(point_density):.2f} points/m³")
    print(f"Average instance overlap: {np.nanmean(complexity_metrics['instance_overlap']):.2f} instances")

    print("\nKey Observations:")
    print("----------------")
    print("1. Instance Distribution:")
    # The 25th-75th percentile range covers the middle half of the scenes.
    print(f"   - The middle 50% of scenes have between {np.percentile(instances, 25):.0f} and "
          f"{np.percentile(instances, 75):.0f} instances")

    print("\n2. Point Cloud Density:")
    print("   - Point density varies significantly across scenes")
    # The 90th percentile is the lower bound of the densest 10%.
    print(f"   - The densest 10% of scenes have at least "
          f"{np.nanpercentile(point_density, 90):.0f} points/m³")

    print("\n3. Instance Complexity:")
    print("   - Average instance overlap suggests complex spatial relationships")
    print("   - Instance density indicates challenging segmentation scenarios")

# Final textual report combining the scene statistics and complexity metrics
print_dataset_summary(stats=stats, complexity_metrics=complexity_metrics)

## 14. Feature Importance Analysis

In [None]:
def analyze_feature_importance(dataset, num_samples=1000):
    """Score per-point feature columns by variance and instance separation.

    Features are gathered from the first ``num_samples`` scenes. Each scene
    contributes its 'geometric_features', with 'contextual_features' appended
    column-wise when present and not None. Scenes without geometric features
    are skipped.

    Two scores are computed per feature column:
      * ``variance``: variance over all collected points.
      * ``instance_separation``: variance of the per-instance means divided by
        the variance over all points that belong to an instance. This is a
        between-instance / total spread ratio. Background points (id 0) are
        excluded. Instance ids are made unique per scene, because ids are
        assumed to be scene-local (the same id in two scenes is taken to be a
        different object). The score is NaN for a column that is constant
        over those points.

    Both scores are shown as bar charts.

    Returns:
        dict with 'variance' (np.ndarray) and 'instance_separation' (list).

    Raises:
        ValueError: if none of the scanned scenes has geometric features.
    """
    all_features = []
    all_groups = []  # scene-unique instance ids, -1 for background
    next_group = 0

    for i in tqdm(range(min(num_samples, len(dataset))), desc='Collecting features'):
        data = dataset[i]
        if 'geometric_features' not in data or data['geometric_features'] is None:
            continue

        parts = [data['geometric_features']]
        if 'contextual_features' in data and data['contextual_features'] is not None:
            parts.append(data['contextual_features'])
        all_features.append(torch.cat(parts, dim=1).numpy())

        labels = data['instance_labels'].numpy().reshape(-1)
        scene_ids, local_ids = np.unique(labels, return_inverse=True)
        groups = local_ids.reshape(-1) + next_group
        groups[labels == 0] = -1  # background is not an instance
        next_group += len(scene_ids)
        all_groups.append(groups)

    if not all_features:
        raise ValueError('None of the scanned scenes has geometric_features')

    features = np.concatenate(all_features, axis=0)
    groups = np.concatenate(all_groups, axis=0)

    # Calculate feature statistics
    feature_stats = {
        'variance': np.var(features, axis=0),
        'instance_separation': []
    }

    in_instance = groups >= 0
    if in_instance.any():
        _, inst_index, inst_sizes = np.unique(groups[in_instance], return_inverse=True, return_counts=True)
        for column in features[in_instance].T:
            total_var = np.var(column)
            # Per-instance means in one pass via a weighted bincount
            inst_means = np.bincount(inst_index.reshape(-1), weights=column) / inst_sizes
            separation = np.var(inst_means) / total_var if total_var > 0 else np.nan
            feature_stats['instance_separation'].append(separation)
    else:
        feature_stats['instance_separation'] = [np.nan] * features.shape[1]

    # Plot feature importance
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Variance plot
    ax1.bar(range(len(feature_stats['variance'])), feature_stats['variance'])
    ax1.set_title('Feature Variance')
    ax1.set_xlabel('Feature Index')
    ax1.set_ylabel('Variance')

    # Instance separation plot
    ax2.bar(range(len(feature_stats['instance_separation'])),
            feature_stats['instance_separation'])
    ax2.set_title('Instance Separation Power')
    ax2.set_xlabel('Feature Index')
    ax2.set_ylabel('Separation Score')

    plt.tight_layout()
    plt.show()

    return feature_stats

# Analyze feature importance
feature_stats = analyze_feature_importance(dataset)

## 15. Instance Relationship Analysis

In [None]:
def analyze_instance_relationships(data):
    """Compute and plot pairwise spatial relationships between instances.

    Instance id 0 is treated as background and skipped. For every pair of
    remaining instances (smaller id first, pairs in ascending id order) this
    records the Euclidean distance between the instance centroids and whether
    their axis-aligned bounding boxes over the first three coordinates
    intersect. Touching boxes count as overlapping.

    Args:
        data: Scene dict with 'points' and 'instance_labels' tensors.

    Returns:
        list of dicts with keys 'instance1', 'instance2', 'distance' (float)
        and 'overlapping' (bool). It is empty when the scene has fewer than
        two instances; nothing is plotted in that case.
    """
    points = data['points']
    instance_labels = data['instance_labels']
    inst_ids = [i for i in torch.unique(instance_labels).tolist() if i != 0]

    if len(inst_ids) < 2:
        print('Fewer than two non-background instances; no relationships to plot.')
        return []

    # Per-instance centroid and bounding box
    centroids, mins, maxs = [], [], []
    for inst_id in inst_ids:
        inst_points = points[instance_labels == inst_id]
        centroids.append(inst_points.mean(0))
        mins.append(inst_points.min(0)[0][:3])
        maxs.append(inst_points.max(0)[0][:3])
    centroids, mins, maxs = torch.stack(centroids), torch.stack(mins), torch.stack(maxs)

    # All pairs (a, b) with a < b, row-major, evaluated at once
    first, second = torch.triu_indices(len(inst_ids), len(inst_ids), offset=1)
    distances = torch.norm(centroids[first] - centroids[second], dim=1)
    overlapping = ((mins[first] <= maxs[second]) & (maxs[first] >= mins[second])).all(dim=1)

    relationships = [
        {
            'instance1': inst_ids[a],
            'instance2': inst_ids[b],
            'distance': dist,
            'overlapping': overlap
        }
        for a, b, dist, overlap in zip(first.tolist(), second.tolist(),
                                       distances.tolist(), overlapping.tolist())
    ]

    # Visualize relationships
    plt.figure(figsize=(10, 5))

    # Distance distribution
    plt.subplot(121)
    plt.hist([r['distance'] for r in relationships], bins=30)
    plt.title('Instance Distance Distribution')
    plt.xlabel('Distance')
    plt.ylabel('Count')

    # Overlap statistics
    overlap_count = sum(r['overlapping'] for r in relationships)
    plt.subplot(122)
    plt.pie([overlap_count, len(relationships) - overlap_count],
            labels=['Overlapping', 'Non-overlapping'],
            autopct='%1.1f%%')
    plt.title('Instance Overlap Statistics')

    plt.tight_layout()
    plt.show()

    return relationships

# Analyze instance relationships for a sample scene
sample_data = dataset[0]
relationships = analyze_instance_relationships(sample_data)

## 16. Recommendations for Model Design

Based on the analysis above, here are key considerations for model design:

1. **Feature Engineering**:
   - Check whether geometric features show strong instance separation power (see the separation scores in Section 14)
   - Consider using multi-scale features due to varying instance sizes
   - Include local density information due to varying point density

2. **Architecture Choices**:
   - Need to handle varying number of instances per scene
   - Consider attention mechanisms for complex spatial relationships
   - Include multi-resolution processing for different instance sizes

3. **Training Strategy**:
   - Make the loss function aware of overlapping instances
   - Consider curriculum learning based on scene complexity
   - Use robust data augmentation to cover the spatial variation between scenes

4. **Evaluation Metrics**:
   - Include metrics for different instance sizes
   - Consider boundary quality metrics
   - Evaluate performance on overlapping instances separately