In [9]:
# ===============================================================
# 1. Import Libraries
# ===============================================================

# Standard Libraries
import json
import logging
import math
import os
import queue
import random
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass

# Data Processing and Scientific Computing
import numpy as np
from sklearn.model_selection import train_test_split

from typing import Any, Dict, Optional, Set, Tuple


# PyTorch and Deep Learning
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import nn, optim
from torch.cuda.amp import GradScaler, autocast
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchvision.models import ResNet18_Weights, resnet18

# Visualization and GUI
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from PIL import Image, ImageTk

# Tree and Graph Processing
from anytree import RenderTree
from anytree.exporter import DotExporter

# Custom Modules
from Augment import GridAugmentor
from grid_pair import GridPair
from DNN import DeepModelTrainer
from cnn_grid_mapper import CNNGridMapper
from data_structures import GridPair
from grid_pair import GridPair
from data_tree import DataTree
from datasets import AugmentedARCDataset
from embedding import TransformerEmbeddings, GridTransformerEmbeddings, EmbeddingConfig
from ensemble_training import (
    EnsembleModel, DualModelTraining,
    EnsembleTrainer, EnsembleConfig
)

from evaluation import ModelEvaluator
from grid_utils import (
    GridPair, create_data_dict, flatten_and_reshape,
    get_device, grid_to_image, handle_batch_size,
    handle_channels, interpolate_tensor
)
from gui import TrainingGUI
from logger_setup import setup_logger
from node import Node
from reward_based_model import RewardBasedModel
from simple_transformer import SimpleTransformer
from train_model_with_gui import train_model_with_gui

# Global Configuration
logger = setup_logger(__name__)
matplotlib.rcParams['font.family'] = 'DejaVu Sans'

ImportError: cannot import name 'create_data_dict' from 'grid_utils' (C:\Users\Owner\grid_utils.py)

In [None]:
class SimpleTransformer(nn.Module):
    """Enhanced transformer model with proper embeddings support."""
    
    def __init__(
        self,
        vocab_size: int = 5000,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        is_grid: bool = True
    ):
        """Initialize transformer.
        
        Args:
            vocab_size: Size of vocabulary
            d_model: Model dimension
            nhead: Number of attention heads
            num_layers: Number of transformer layers
            dim_feedforward: Feedforward dimension
            dropout: Dropout rate
            is_grid: Whether input is grid-structured
        """
        super().__init__()
        
        # Create embedding configuration
        embedding_config = EmbeddingConfig(
            vocab_size=vocab_size,
            d_model=d_model,
            dropout=dropout,
            use_position_embedding=True,
            learn_position=False
        )
        
        # Initialize embeddings
        self.embedding = (
            GridTransformerEmbeddings(embedding_config)
            if is_grid else
            TransformerEmbeddings(embedding_config)
        )
        
        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        
        # Output decoder
        self.decoder = nn.Linear(d_model, vocab_size)
        
    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass.
        
        Args:
            src: Input tensor
            src_mask: Optional attention mask
            src_key_padding_mask: Optional key padding mask
            
        Returns:
            Output tensor
        """
        # Embed input
        embedded = self.embedding(src)
        
        # Apply transformer encoder
        encoded = self.transformer_encoder(
            embedded,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask
        )
        
        # Decode to vocabulary
        output = self.decoder(encoded)
        
        return F.log_softmax(output, dim=-1)

In [None]:
from typing import Dict, Optional, List, Any
import json
from pathlib import Path
import hashlib
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
import gzip
import shutil
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class ARCDataConfig:
    """Configuration for ARC data loading.
    
    Attributes:
        data_dir: Directory containing data files
        cache_dir: Directory for caching data
        use_cache: Whether to use data caching
        validate_data: Whether to validate loaded data
        parallel_loading: Whether to load files in parallel
        compression: Whether to use compression
        checksum_verification: Whether to verify file checksums
    """
    data_dir: str = "data"
    cache_dir: str = "cache"
    use_cache: bool = True
    validate_data: bool = True
    parallel_loading: bool = True
    compression: bool = True
    checksum_verification: bool = True

class ARCDataLoader:
    """Loader for ARC dataset with advanced features."""
    
    FILE_PATHS = {
        "training_challenges": "arc-agi_training_challenges.json",
        "evaluation_challenges": "arc-agi_evaluation_challenges.json",
        "training_solutions": "arc-agi_training_solutions.json",
        "evaluation_solutions": "arc-agi_evaluation_solutions.json"
    }
    
    CHECKSUMS = {
        # Add expected SHA-256 checksums for each file
        "arc-agi_training_challenges.json": "...",
        "arc-agi_evaluation_challenges.json": "...",
        "arc-agi_training_solutions.json": "...",
        "arc-agi_evaluation_solutions.json": "..."
    }
    
    def __init__(
        self,
        config: Optional[ARCDataConfig] = None
    ):
        """Initialize data loader.
        
        Args:
            config: Optional configuration
        """
        self.config = config or ARCDataConfig()
        self._setup_directories()
        self.data_cache: Dict[str, Any] = {}
        
    def _setup_directories(self) -> None:
        """Create necessary directories."""
        Path(self.config.data_dir).mkdir(parents=True, exist_ok=True)
        if self.config.use_cache:
            Path(self.config.cache_dir).mkdir(parents=True, exist_ok=True)
            
    def load_data(self) -> Dict[str, Dict]:
        """Load ARC dataset.
        
        Returns:
            Dictionary containing dataset components
        """
        try:
            if self.config.use_cache and self._cache_exists():
                return self._load_from_cache()
                
            if self.config.parallel_loading:
                data = self._load_parallel()
            else:
                data = self._load_sequential()
                
            if self.config.use_cache:
                self._save_to_cache(data)
                
            return data
            
        except Exception as e:
            logger.error(f"Error loading ARC data: {e}")
            raise
            
    def _load_sequential(self) -> Dict[str, Dict]:
        """Load files sequentially.
        
        Returns:
            Loaded data dictionary
        """
        data = {}
        for key, filename in self.FILE_PATHS.items():
            try:
                file_path = Path(self.config.data_dir) / filename
                data[key] = self._load_single_file(file_path)
                logger.info(
                    f"Loaded {key} with {len(data[key])} items "
                    f"from {filename}"
                )
            except Exception as e:
                logger.error(f"Error loading {filename}: {e}")
                data[key] = {}
                
        return data
        
    def _load_parallel(self) -> Dict[str, Dict]:
        """Load files in parallel.
        
        Returns:
            Loaded data dictionary
        """
        data = {}
        with ThreadPoolExecutor() as executor:
            future_to_key = {
                executor.submit(
                    self._load_single_file,
                    Path(self.config.data_dir) / filename
                ): key
                for key, filename in self.FILE_PATHS.items()
            }
            
            for future in future_to_key:
                key = future_to_key[future]
                try:
                    data[key] = future.result()
                    logger.info(
                        f"Loaded {key} with {len(data[key])} items"
                    )
                except Exception as e:
                    logger.error(f"Error loading {key}: {e}")
                    data[key] = {}
                    
        return data
        
    def _load_single_file(self, file_path: Path) -> Dict:
        """Load and validate single file.
        
        Args:
            file_path: Path to file
            
        Returns:
            Loaded data dictionary
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
            
        if self.config.checksum_verification:
            self._verify_checksum(file_path)
            
        if file_path.suffix == '.gz':
            with gzip.open(file_path, 'rt') as f:
                data = json.load(f)
        else:
            with open(file_path, 'r') as f:
                data = json.load(f)
                
        if self.config.validate_data:
            self._validate_data(data, file_path.name)
            
        return data
        
    def _verify_checksum(self, file_path: Path) -> None:
        """Verify file checksum.
        
        Args:
            file_path: Path to file
            
        Raises:
            ValueError: If checksum verification fails
        """
        if file_path.name not in self.CHECKSUMS:
            logger.warning(f"No checksum found for {file_path.name}")
            return
            
        with open(file_path, 'rb') as f:
            file_hash = hashlib.sha256(f.read()).hexdigest()
            
        if file_hash != self.CHECKSUMS[file_path.name]:
            raise ValueError(
                f"Checksum verification failed for {file_path.name}"
            )
            
    def _validate_data(
        self,
        data: Dict,
        filename: str
    ) -> None:
        """Validate loaded data.
        
        Args:
            data: Loaded data
            filename: Source filename
            
        Raises:
            ValueError: If validation fails
        """
        if not isinstance(data, dict):
            raise ValueError(f"Invalid data format in {filename}")
            
        required_keys = {'train', 'test'} if 'challenges' in filename else {'input', 'output'}
        
        for item in data.values():
            if not all(key in item for key in required_keys):
                raise ValueError(
                    f"Missing required keys in {filename}"
                )
                
    def _cache_exists(self) -> bool:
        """Check if cache exists.
        
        Returns:
            Whether cache exists
        """
        cache_file = Path(self.config.cache_dir) / "arc_data_cache.json.gz"
        return cache_file.exists()
        
    def _load_from_cache(self) -> Dict[str, Dict]:
        """Load data from cache.
        
        Returns:
            Cached data dictionary
        """
        cache_file = Path(self.config.cache_dir) / "arc_data_cache.json.gz"
        with gzip.open(cache_file, 'rt') as f:
            data = json.load(f)
            
        logger.info("Loaded data from cache")
        return data
        
    def _save_to_cache(self, data: Dict[str, Dict]) -> None:
        """Save data to cache.
        
        Args:
            data: Data to cache
        """
        cache_file = Path(self.config.cache_dir) / "arc_data_cache.json.gz"
        with gzip.open(cache_file, 'wt') as f:
            json.dump(data, f)
            
        logger.info("Saved data to cache")
        
    def get_dataset_stats(
        self,
        data: Dict[str, Dict]
    ) -> Dict[str, Any]:
        """Get dataset statistics.
        
        Args:
            data: Dataset dictionary
            
        Returns:
            Statistics dictionary
        """
        stats = {
            'total_examples': 0,
            'challenges': {
                'training': len(data.get('training_challenges', {})),
                'evaluation': len(data.get('evaluation_challenges', {}))
            },
            'solutions': {
                'training': len(data.get('training_solutions', {})),
                'evaluation': len(data.get('evaluation_solutions', {}))
            }
        }
        
        stats['total_examples'] = sum(stats['challenges'].values())
        
        return stats


# Create configuration
config = ARCDataConfig(
    data_dir="data",
    cache_dir="cache",
    use_cache=True,
    validate_data=True,
    parallel_loading=True
)

# Initialize loader
loader = ARCDataLoader(config)

# Load data
try:
    arc_data = loader.load_data()
    
    # Get statistics
    stats = loader.get_dataset_stats(arc_data)
    print(f"Dataset statistics: {stats}")
    
except Exception as e:
    print(f"Error loading data: {e}")

In [None]:
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass
from pathlib import Path
import json
import numpy as np
from collections import defaultdict
import pandas as pd
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class GridDimensions:
    """Container for grid dimensions and statistics.
    
    Attributes:
        height: Grid height
        width: Grid width
        min_value: Minimum value in grid
        max_value: Maximum value in grid
        unique_values: Number of unique values
        sparsity: Percentage of zero values
    """
    height: int
    width: int
    min_value: int
    max_value: int
    unique_values: int
    sparsity: float

@dataclass
class TaskAnalysis:
    """Container for task analysis results.
    
    Attributes:
        input_dims: Input grid dimensions
        output_dims: Output grid dimensions
        transformation_type: Type of grid transformation
        complexity_score: Estimated task complexity
        patterns: Detected patterns
    """
    input_dims: GridDimensions
    output_dims: GridDimensions
    transformation_type: str
    complexity_score: float
    patterns: Dict[str, Any]

class GridAnalyzer:
    """Analyzer for ARC grid data with advanced analytics."""
    
    def __init__(self):
        """Initialize grid analyzer."""
        self.results = defaultdict(lambda: defaultdict(list))
        self.statistics = {}
        
    def analyze_file(
        self,
        file_path: Path
    ) -> Dict[str, Dict[str, Any]]:
        """Analyze grids in file.
        
        Args:
            file_path: Path to data file
            
        Returns:
            Analysis results dictionary
        """
        try:
            # Load and validate data
            data = self._load_data(file_path)
            
            # Analyze each task
            for task_id, task_content in data.items():
                self._analyze_task(task_id, task_content)
                
            # Compute global statistics
            self._compute_statistics()
            
            logger.info(
                f"Analyzed {len(data)} tasks from {file_path}"
            )
            
            return dict(self.results)
            
        except Exception as e:
            logger.error(f"Error analyzing file {file_path}: {e}")
            raise
            
    def _load_data(
        self,
        file_path: Path
    ) -> Dict[str, Any]:
        """Load and validate data file.
        
        Args:
            file_path: Path to data file
            
        Returns:
            Loaded data dictionary
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
            
        with open(file_path, 'r') as f:
            data = json.load(f)
            
        if not isinstance(data, dict):
            raise ValueError("Invalid data format")
            
        return data
        
    def _analyze_task(
        self,
        task_id: str,
        task_content: Dict[str, Any]
    ) -> None:
        """Analyze single task.
        
        Args:
            task_id: Task identifier
            task_content: Task data
        """
        try:
            # Analyze train and test sections
            for section_name, sections in task_content.items():
                self._analyze_sections(
                    task_id,
                    section_name,
                    sections
                )
                
            # Analyze relationships between input and output
            self._analyze_relationships(task_id)
            
        except Exception as e:
            logger.error(f"Error analyzing task {task_id}: {e}")
            
    def _analyze_sections(
        self,
        task_id: str,
        section_name: str,
        sections: List[Dict[str, Any]]
    ) -> None:
        """Analyze sections of task.
        
        Args:
            task_id: Task identifier
            section_name: Section name
            sections: Section data
        """
        for section in sections:
            input_data = section.get("input", [])
            output_data = section.get("output", [])
            
            # Analyze dimensions
            self._analyze_dimensions(
                task_id,
                section_name,
                input_data,
                output_data
            )
            
            # Analyze patterns
            if "train" in section_name:
                self._analyze_patterns(
                    task_id,
                    input_data,
                    output_data
                )
                
    def _analyze_dimensions(
        self,
        task_id: str,
        section_name: str,
        input_data: List[Any],
        output_data: List[Any]
    ) -> None:
        """Analyze grid dimensions.
        
        Args:
            task_id: Task identifier
            section_name: Section name
            input_data: Input data
            output_data: Output data
        """
        # Process input data
        if isinstance(input_data, list):
            input_dims = self._get_grid_dimensions(input_data)
            self.results[task_id][f"{section_name}_input_dims"].append(
                input_dims
            )
            
        # Process output data
        if isinstance(output_data, list):
            output_dims = self._get_grid_dimensions(output_data)
            self.results[task_id][f"{section_name}_output_dims"].append(
                output_dims
            )
            
    def _get_grid_dimensions(
        self,
        grid: List[Any]
    ) -> Optional[GridDimensions]:
        """Calculate grid dimensions and statistics.
        
        Args:
            grid: Grid data
            
        Returns:
            Grid dimensions object
        """
        try:
            if not grid or not isinstance(grid[0], list):
                return None
                
            np_grid = np.array(grid)
            
            return GridDimensions(
                height=np_grid.shape[0],
                width=np_grid.shape[1],
                min_value=np.min(np_grid),
                max_value=np.max(np_grid),
                unique_values=len(np.unique(np_grid)),
                sparsity=np.mean(np_grid == 0)
            )
            
        except Exception:
            return None
            
    def _analyze_patterns(
        self,
        task_id: str,
        input_data: List[Any],
        output_data: List[Any]
    ) -> None:
        """Analyze patterns in grids.
        
        Args:
            task_id: Task identifier
            input_data: Input grid
            output_data: Output grid
        """
        patterns = {
            'symmetry': self._check_symmetry(input_data, output_data),
            'repetition': self._check_repetition(input_data, output_data),
            'transformation': self._identify_transformation(
                input_data,
                output_data
            )
        }
        
        self.results[task_id]['patterns'].append(patterns)
        
    def _check_symmetry(
        self,
        input_grid: List[Any],
        output_grid: List[Any]
    ) -> Dict[str, bool]:
        """Check for symmetry in grids.
        
        Args:
            input_grid: Input grid
            output_grid: Output grid
            
        Returns:
            Dictionary of symmetry properties
        """
        try:
            in_grid = np.array(input_grid)
            out_grid = np.array(output_grid)
            
            return {
                'horizontal': np.array_equal(
                    in_grid,
                    np.fliplr(in_grid)
                ),
                'vertical': np.array_equal(
                    in_grid,
                    np.flipud(in_grid)
                ),
                'diagonal': np.array_equal(
                    in_grid,
                    in_grid.T
                )
            }
            
        except Exception:
            return {}
            
    def _check_repetition(
        self,
        input_grid: List[Any],
        output_grid: List[Any]
    ) -> Dict[str, float]:
        """Check for repetitive patterns.
        
        Args:
            input_grid: Input grid
            output_grid: Output grid
            
        Returns:
            Dictionary of repetition metrics
        """
        try:
            in_grid = np.array(input_grid)
            
            unique, counts = np.unique(
                in_grid,
                return_counts=True
            )
            
            return {
                'unique_ratio': len(unique) / in_grid.size,
                'max_repetition': np.max(counts) / in_grid.size
            }
            
        except Exception:
            return {}
            
    def _identify_transformation(
        self,
        input_grid: List[Any],
        output_grid: List[Any]
    ) -> str:
        """Identify transformation type.
        
        Args:
            input_grid: Input grid
            output_grid: Output grid
            
        Returns:
            Transformation type string
        """
        try:
            in_grid = np.array(input_grid)
            out_grid = np.array(output_grid)
            
            if in_grid.shape == out_grid.shape:
                if np.array_equal(out_grid, np.rot90(in_grid)):
                    return "rotation_90"
                elif np.array_equal(out_grid, np.fliplr(in_grid)):
                    return "horizontal_flip"
                elif np.array_equal(out_grid, np.flipud(in_grid)):
                    return "vertical_flip"
                    
            return "unknown"
            
        except Exception:
            return "error"
            
    def _compute_statistics(self) -> None:
        """Compute global statistics."""
        stats = defaultdict(list)
        
        for task_results in self.results.values():
            for key, values in task_results.items():
                if 'dims' in key:
                    dims = [d for d in values if d is not None]
                    if dims:
                        stats[f"{key}_height"].extend(
                            [d.height for d in dims]
                        )
                        stats[f"{key}_width"].extend(
                            [d.width for d in dims]
                        )
                        
        self.statistics = {
            key: {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values)
            }
            for key, values in stats.items()
        }
        
    def get_summary(self) -> pd.DataFrame:
        """Generate summary DataFrame.
        
        Returns:
            Summary DataFrame
        """
        summary_data = []
        
        for task_id, results in self.results.items():
            summary = {
                'task_id': task_id,
                'input_dims': len(results['train_input_dims']),
                'output_dims': len(results['train_output_dims']),
                'patterns_found': len(results['patterns']),
                'complexity': self._calculate_complexity(results)
            }
            summary_data.append(summary)
            
        return pd.DataFrame(summary_data)
        
    def _calculate_complexity(
        self,
        task_results: Dict[str, Any]
    ) -> float:
        """Calculate task complexity score.
        
        Args:
            task_results: Task results dictionary
            
        Returns:
            Complexity score
        """
        complexity = 0.0
        
        # Consider grid sizes
        for dims in task_results.get('train_input_dims', []):
            if dims:
                complexity += dims.height * dims.width
                
        # Consider pattern complexity
        for pattern in task_results.get('patterns', []):
            if pattern.get('symmetry', {}).get('diagonal', False):
                complexity *= 1.5
            if pattern.get('repetition', {}).get('unique_ratio', 1.0) < 0.5:
                complexity *= 1.2
                
        return complexity

In [None]:
from typing import Set, Dict, Any, Optional, List, Union
from collections import Counter
from dataclasses import dataclass
import numpy as np
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class ClassStatistics:
    """Statistics about dataset classes.
    
    Attributes:
        num_classes: Number of unique classes
        class_distribution: Distribution of classes
        min_class: Minimum class value
        max_class: Maximum class value
        class_frequencies: Frequency of each class
        rare_classes: Classes with low frequency
        common_classes: Most common classes
    """
    num_classes: int
    class_distribution: Dict[int, float]
    min_class: int
    max_class: int
    class_frequencies: Counter
    rare_classes: Set[int]
    common_classes: Set[int]

class ClassAnalyzer:
    """Analyzer for determining and analyzing classes in ARC data."""
    
    def __init__(
        self,
        rare_threshold: float = 0.01,
        common_threshold: float = 0.1
    ):
        """Initialize class analyzer.
        
        Args:
            rare_threshold: Threshold for rare class determination
            common_threshold: Threshold for common class determination
        """
        self.rare_threshold = rare_threshold
        self.common_threshold = common_threshold
        self.class_cache = {}
        
    def analyze_classes(
        self,
        arc_data: Dict[str, Any]
    ) -> ClassStatistics:
        """Analyze classes in ARC data.
        
        Args:
            arc_data: ARC dataset dictionary
            
        Returns:
            Class statistics object
        """
        try:
            # Extract all classes
            classes = self._extract_classes(arc_data)
            
            # Calculate statistics
            stats = self._calculate_statistics(classes)
            
            logger.info(
                f"Found {stats.num_classes} unique classes "
                f"(min: {stats.min_class}, max: {stats.max_class})"
            )
            
            return stats
            
        except Exception as e:
            logger.error(f"Error analyzing classes: {e}")
            raise
            
    def _extract_classes(
        self,
        data: Dict[str, Any]
    ) -> Set[int]:
        """Extract unique classes from data.
        
        Args:
            data: Input data dictionary
            
        Returns:
            Set of unique classes
        """
        classes = set()
        
        for key, content in data.items():
            try:
                # Extract from dictionary structure
                if isinstance(content, dict):
                    classes.update(
                        self._extract_from_dict(content)
                    )
                # Extract from list structure
                elif isinstance(content, list):
                    classes.update(
                        self._extract_from_list(content)
                    )
                    
            except Exception as e:
                logger.warning(
                    f"Error extracting classes from {key}: {e}"
                )
                continue
                
        return classes
        
    def _extract_from_dict(
        self,
        content: Dict[str, Any]
    ) -> Set[int]:
        """Extract classes from dictionary structure.
        
        Args:
            content: Dictionary content
            
        Returns:
            Set of classes
        """
        classes = set()
        
        for mode in ["train", "test"]:
            entries = content.get(mode, [])
            if isinstance(entries, list):
                classes.update(
                    self._extract_from_entries(entries)
                )
                
        return classes
        
    def _extract_from_list(
        self,
        content: List[Any]
    ) -> Set[int]:
        """Extract classes from list structure.
        
        Args:
            content: List content
            
        Returns:
            Set of classes
        """
        return self._extract_from_entries(content)
        
    def _extract_from_entries(
        self,
        entries: List[Dict[str, Any]]
    ) -> Set[int]:
        """Extract classes from entries.
        
        Args:
            entries: List of data entries
            
        Returns:
            Set of classes
        """
        classes = set()
        
        for entry in entries:
            output = entry.get("output")
            if output is not None:
                classes.update(
                    self._process_output(output)
                )
                
        return classes
        
    def _process_output(
        self,
        output: Union[int, List[Any], np.ndarray]
    ) -> Set[int]:
        """Process output to extract classes.
        
        Args:
            output: Output data
            
        Returns:
            Set of classes
        """
        if isinstance(output, (int, float)):
            return {int(output)}
        elif isinstance(output, (list, np.ndarray)):
            if isinstance(output, np.ndarray):
                output = output.flatten().tolist()
            return set(map(int, self._flatten_list(output)))
        return set()
        
    def _flatten_list(
        self,
        lst: List[Any]
    ) -> List[Any]:
        """Flatten nested list.
        
        Args:
            lst: Input list
            
        Returns:
            Flattened list
        """
        flattened = []
        for item in lst:
            if isinstance(item, list):
                flattened.extend(self._flatten_list(item))
            else:
                flattened.append(item)
        return flattened
        
    def _calculate_statistics(
        self,
        classes: Set[int]
    ) -> ClassStatistics:
        """Calculate class statistics.
        
        Args:
            classes: Set of classes
            
        Returns:
            Class statistics object
        """
        # Calculate frequencies
        frequencies = Counter(classes)
        total = sum(frequencies.values())
        
        # Calculate distribution
        distribution = {
            cls: count/total
            for cls, count in frequencies.items()
        }
        
        # Identify rare and common classes
        rare_classes = {
            cls for cls, freq in distribution.items()
            if freq < self.rare_threshold
        }
        
        common_classes = {
            cls for cls, freq in distribution.items()
            if freq > self.common_threshold
        }
        
        return ClassStatistics(
            num_classes=len(classes),
            class_distribution=distribution,
            min_class=min(classes),
            max_class=max(classes),
            class_frequencies=frequencies,
            rare_classes=rare_classes,
            common_classes=common_classes
        )
        
    def analyze_class_transitions(
        self,
        arc_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Analyze class transitions in sequences.
        
        Args:
            arc_data: ARC dataset dictionary
            
        Returns:
            Transition analysis dictionary
        """
        transitions = defaultdict(Counter)
        
        for content in arc_data.values():
            if isinstance(content, dict):
                for mode in ["train", "test"]:
                    entries = content.get(mode, [])
                    self._analyze_transitions(entries, transitions)
                    
        return dict(transitions)
        
    def _analyze_transitions(
        self,
        entries: List[Dict[str, Any]],
        transitions: Dict[int, Counter]
    ) -> None:
        """Analyze transitions in entries.
        
        Args:
            entries: Data entries
            transitions: Transition counter dictionary
        """
        for entry in entries:
            output = entry.get("output")
            if isinstance(output, list):
                for i in range(len(output) - 1):
                    current = output[i]
                    next_class = output[i + 1]
                    transitions[current][next_class] += 1


In [None]:
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
from dataclasses import dataclass
from enum import Enum
import random
from collections import defaultdict
from logger_setup import setup_logger

logger = setup_logger(__name__)

class GridPattern(Enum):
    """Types of grid patterns."""
    RANDOM = "random"
    DIAGONAL = "diagonal"
    CHECKERBOARD = "checkerboard"
    BORDER = "border"
    GRADIENT = "gradient"
    SYMMETRIC = "symmetric"

@dataclass
class GridConfig:
    """Configuration for grid creation.
    
    Attributes:
        num_classes: Number of possible values
        min_size: Minimum grid size
        max_size: Maximum grid size
        patterns: Enabled pattern types
        sparsity: Grid sparsity (0-1)
        symmetry_probability: Probability of symmetry
        noise_level: Level of noise to add
    """
    num_classes: int
    min_size: int = 2
    max_size: int = 10
    patterns: List[GridPattern] = None
    sparsity: float = 0.2
    symmetry_probability: float = 0.3
    noise_level: float = 0.1

class GridCreator:
    """Creator for grid patterns with advanced features."""
    
    def __init__(
        self,
        config: Optional[GridConfig] = None
    ):
        """Initialize grid creator.
        
        Args:
            config: Optional grid configuration
        """
        self.config = config or GridConfig(num_classes=10)
        
        if not self.config.patterns:
            self.config.patterns = list(GridPattern)
            
        # Initialize statistics
        self.stats = defaultdict(int)
        
    def create_grids(
        self,
        task_metadata: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Dict[str, List[List[List[int]]]]]:
        """Create grids based on metadata.
        
        Args:
            task_metadata: Task metadata dictionary
            
        Returns:
            Dictionary of created grids
        """
        try:
            grids_by_task = {}
            
            for task_id, sections in task_metadata.items():
                grids_by_task[task_id] = self._create_task_grids(
                    task_id,
                    sections
                )
                
            logger.info(
                f"Created grids for {len(grids_by_task)} tasks"
            )
            
            return grids_by_task
            
        except Exception as e:
            logger.error(f"Error creating grids: {e}")
            raise
            
    def _create_task_grids(
        self,
        task_id: str,
        sections: Dict[str, Any]
    ) -> Dict[str, List[List[List[int]]]]:
        """Create grids for single task.
        
        Args:
            task_id: Task identifier
            sections: Section data
            
        Returns:
            Dictionary of created grids
        """
        task_grids = {}
        
        for section, data in sections.items():
            task_grids[section] = []
            
            for idx, lengths in data.items():
                for length, count in lengths.items():
                    task_grids[section].extend(
                        self._create_section_grids(length, count)
                    )
                    
        return task_grids
        
    def _create_section_grids(
        self,
        size: int,
        count: int
    ) -> List[List[List[int]]]:
        """Create grids for section.
        
        Args:
            size: Grid size
            count: Number of grids
            
        Returns:
            List of created grids
        """
        grids = []
        
        for _ in range(count):
            # Select random pattern
            pattern = random.choice(self.config.patterns)
            
            # Create grid with pattern
            grid = self._create_pattern_grid(
                size,
                pattern
            )
            
            # Apply transformations
            grid = self._apply_transformations(grid)
            
            grids.append(grid.tolist())
            self.stats[pattern.value] += 1
            
        return grids
        
    def _create_pattern_grid(
        self,
        size: int,
        pattern: GridPattern
    ) -> np.ndarray:
        """Create grid with specific pattern.
        
        Args:
            size: Grid size
            pattern: Pattern type
            
        Returns:
            Created grid
        """
        if pattern == GridPattern.RANDOM:
            return self._create_random_grid(size)
        elif pattern == GridPattern.DIAGONAL:
            return self._create_diagonal_grid(size)
        elif pattern == GridPattern.CHECKERBOARD:
            return self._create_checkerboard_grid(size)
        elif pattern == GridPattern.BORDER:
            return self._create_border_grid(size)
        elif pattern == GridPattern.GRADIENT:
            return self._create_gradient_grid(size)
        elif pattern == GridPattern.SYMMETRIC:
            return self._create_symmetric_grid(size)
        else:
            return self._create_random_grid(size)
            
    def _create_random_grid(
        self,
        size: int
    ) -> np.ndarray:
        """Create random grid.
        
        Args:
            size: Grid size
            
        Returns:
            Random grid
        """
        grid = np.random.randint(
            0,
            self.config.num_classes,
            size=(size, size)
        )
        
        # Apply sparsity
        if self.config.sparsity > 0:
            mask = np.random.random(size=(size, size)) < self.config.sparsity
            grid[mask] = 0
            
        return grid
        
    def _create_diagonal_grid(
        self,
        size: int
    ) -> np.ndarray:
        """Create diagonal pattern grid.
        
        Args:
            size: Grid size
            
        Returns:
            Diagonal grid
        """
        grid = np.zeros((size, size), dtype=int)
        value = random.randint(1, self.config.num_classes - 1)
        
        for i in range(size):
            grid[i, i] = value
            if random.random() < 0.5:
                grid[i, size-1-i] = value
                
        return grid
        
    def _create_checkerboard_grid(
        self,
        size: int
    ) -> np.ndarray:
        """Create checkerboard pattern grid.
        
        Args:
            size: Grid size
            
        Returns:
            Checkerboard grid
        """
        grid = np.zeros((size, size), dtype=int)
        values = random.sample(
            range(1, self.config.num_classes),
            k=2
        )
        
        for i in range(size):
            for j in range(size):
                grid[i, j] = values[(i + j) % 2]
                
        return grid
        
    def _create_border_grid(
        self,
        size: int
    ) -> np.ndarray:
        """Create border pattern grid.
        
        Args:
            size: Grid size
            
        Returns:
            Border grid
        """
        grid = np.zeros((size, size), dtype=int)
        value = random.randint(1, self.config.num_classes - 1)
        
        grid[0, :] = value
        grid[-1, :] = value
        grid[:, 0] = value
        grid[:, -1] = value
        
        return grid
        
    def _create_gradient_grid(
        self,
        size: int
    ) -> np.ndarray:
        """Create gradient pattern grid.
        
        Args:
            size: Grid size
            
        Returns:
            Gradient grid
        """
        x = np.linspace(0, 1, size)
        y = np.linspace(0, 1, size)
        X, Y = np.meshgrid(x, y)
        
        gradient = (X + Y) / 2
        return (gradient * (self.config.num_classes - 1)).astype(int)
        
    def _create_symmetric_grid(
        self,
        size: int
    ) -> np.ndarray:
        """Create symmetric pattern grid.
        
        Args:
            size: Grid size
            
        Returns:
            Symmetric grid
        """
        half_size = (size + 1) // 2
        half_grid = self._create_random_grid(half_size)
        
        # Mirror horizontally
        grid = np.concatenate(
            [half_grid, np.fliplr(half_grid[:, :size//2])],
            axis=1
        )
        
        return grid[:size, :size]
        
    def _apply_transformations(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Apply random transformations.
        
        Args:
            grid: Input grid
            
        Returns:
            Transformed grid
        """
        # Add noise
        if self.config.noise_level > 0:
            noise_mask = np.random.random(grid.shape) < self.config.noise_level
            noise = np.random.randint(
                0,
                self.config.num_classes,
                size=grid.shape
            )
            grid[noise_mask] = noise[noise_mask]
            
        # Apply symmetry
        if random.random() < self.config.symmetry_probability:
            if random.random() < 0.5:
                # Horizontal symmetry
                grid = np.concatenate(
                    [grid, np.flipud(grid)],
                    axis=0
                )
            else:
                # Vertical symmetry
                grid = np.concatenate(
                    [grid, np.fliplr(grid)],
                    axis=1
                )
                
        return grid
        
    def get_statistics(self) -> Dict[str, Any]:
        """Get creation statistics.
        
        Returns:
            Statistics dictionary
        """
        return {
            'pattern_counts': dict(self.stats),
            'total_grids': sum(self.stats.values()),
            'pattern_distribution': {
                pattern: count/sum(self.stats.values())
                for pattern, count in self.stats.items()
            }
        }

In [None]:
from typing import List, Union, Optional, Tuple
import numpy as np
from dataclasses import dataclass
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class GridShape:
    """Container for grid shape information.
    
    Attributes:
        original_shape: Original grid dimensions
        flattened_size: Size after flattening
        padding: Padding added (if any)
        dtype: Data type of grid elements
    """
    original_shape: Tuple[int, ...]
    flattened_size: int
    padding: Optional[int] = None
    dtype: np.dtype = np.int32

class GridFlattener:
    """Utility for flattening and processing grid data."""
    
    def __init__(
        self,
        pad_value: int = 0,
        normalize: bool = False,
        preserve_shape: bool = True
    ):
        """Initialize grid flattener.
        
        Args:
            pad_value: Value to use for padding
            normalize: Whether to normalize values
            preserve_shape: Whether to store shape information
        """
        self.pad_value = pad_value
        self.normalize = normalize
        self.preserve_shape = preserve_shape
        self.shape_cache = {}
        
    def flatten(
        self,
        grid: Union[List[List[int]], np.ndarray],
        grid_id: Optional[str] = None
    ) -> Union[List[int], np.ndarray]:
        """Flatten 2D grid to 1D.
        
        Args:
            grid: Input grid
            grid_id: Optional identifier for shape caching
            
        Returns:
            Flattened grid
            
        Raises:
            ValueError: If grid is invalid
        """
        try:
            # Convert to numpy array if needed
            grid_array = self._to_array(grid)
            
            # Validate grid
            self._validate_grid(grid_array)
            
            # Store shape information if needed
            if self.preserve_shape and grid_id:
                self._store_shape(grid_id, grid_array)
                
            # Flatten grid
            flattened = grid_array.flatten()
            
            # Normalize if requested
            if self.normalize:
                flattened = self._normalize(flattened)
                
            return flattened.tolist() if isinstance(grid, list) else flattened
            
        except Exception as e:
            logger.error(f"Error flattening grid: {e}")
            raise
            
    def unflatten(
        self,
        flat_data: Union[List[int], np.ndarray],
        shape: Optional[Tuple[int, ...]] = None,
        grid_id: Optional[str] = None
    ) -> Union[List[List[int]], np.ndarray]:
        """Reconstruct 2D grid from flattened data.
        
        Args:
            flat_data: Flattened data
            shape: Target shape (optional)
            grid_id: Grid identifier for shape lookup
            
        Returns:
            Reconstructed grid
            
        Raises:
            ValueError: If shape information is missing
        """
        try:
            # Get shape information
            if shape is None and grid_id:
                shape = self._get_shape(grid_id)
            if shape is None:
                raise ValueError("Shape information required for unflattening")
                
            # Convert to array
            flat_array = np.array(flat_data)
            
            # Pad if necessary
            required_size = np.prod(shape)
            if len(flat_array) < required_size:
                flat_array = self._pad_array(
                    flat_array,
                    required_size
                )
                
            # Reshape array
            grid = flat_array.reshape(shape)
            
            return grid.tolist() if isinstance(flat_data, list) else grid
            
        except Exception as e:
            logger.error(f"Error unflattening data: {e}")
            raise
            
    def _to_array(
        self,
        grid: Union[List[List[int]], np.ndarray]
    ) -> np.ndarray:
        """Convert grid to numpy array.
        
        Args:
            grid: Input grid
            
        Returns:
            Numpy array
        """
        if isinstance(grid, list):
            return np.array(grid, dtype=np.float32 if self.normalize else np.int32)
        return grid
        
    def _validate_grid(
        self,
        grid: np.ndarray
    ) -> None:
        """Validate grid data.
        
        Args:
            grid: Input grid
            
        Raises:
            ValueError: If grid is invalid
        """
        if grid.ndim < 2:
            raise ValueError("Grid must be at least 2-dimensional")
            
        if not np.issubdtype(grid.dtype, np.number):
            raise ValueError("Grid must contain numeric values")
            
        if np.any(np.isnan(grid)):
            raise ValueError("Grid contains NaN values")
            
    def _store_shape(
        self,
        grid_id: str,
        grid: np.ndarray
    ) -> None:
        """Store grid shape information.
        
        Args:
            grid_id: Grid identifier
            grid: Input grid
        """
        self.shape_cache[grid_id] = GridShape(
            original_shape=grid.shape,
            flattened_size=grid.size,
            dtype=grid.dtype
        )
        
    def _get_shape(
        self,
        grid_id: str
    ) -> Optional[Tuple[int, ...]]:
        """Get stored shape information.
        
        Args:
            grid_id: Grid identifier
            
        Returns:
            Grid shape tuple
        """
        if grid_id in self.shape_cache:
            return self.shape_cache[grid_id].original_shape
        return None
        
    def _normalize(
        self,
        array: np.ndarray
    ) -> np.ndarray:
        """Normalize array values.
        
        Args:
            array: Input array
            
        Returns:
            Normalized array
        """
        if len(array) == 0:
            return array
            
        min_val = np.min(array)
        max_val = np.max(array)
        
        if min_val == max_val:
            return np.zeros_like(array, dtype=np.float32)
            
        return (array - min_val) / (max_val - min_val)
        
    def _pad_array(
        self,
        array: np.ndarray,
        target_size: int
    ) -> np.ndarray:
        """Pad array to target size.
        
        Args:
            array: Input array
            target_size: Desired size
            
        Returns:
            Padded array
        """
        if len(array) >= target_size:
            return array
            
        padding_size = target_size - len(array)
        return np.pad(
            array,
            (0, padding_size),
            mode='constant',
            constant_values=self.pad_value
        )
        
    def get_statistics(
        self,
        grid: Union[List[List[int]], np.ndarray]
    ) -> Dict[str, float]:
        """Calculate grid statistics.
        
        Args:
            grid: Input grid
            
        Returns:
            Dictionary of statistics
        """
        try:
            grid_array = self._to_array(grid)
            
            return {
                'min': float(np.min(grid_array)),
                'max': float(np.max(grid_array)),
                'mean': float(np.mean(grid_array)),
                'std': float(np.std(grid_array)),
                'sparsity': float(np.mean(grid_array == 0)),
                'unique_values': int(len(np.unique(grid_array)))
            }
            
        except Exception as e:
            logger.error(f"Error calculating grid statistics: {e}")
            return {}
            
    def batch_flatten(
        self,
        grids: List[Union[List[List[int]], np.ndarray]]
    ) -> List[Union[List[int], np.ndarray]]:
        """Flatten multiple grids.
        
        Args:
            grids: List of grids
            
        Returns:
            List of flattened grids
        """
        flattened = []
        for i, grid in enumerate(grids):
            try:
                flat = self.flatten(grid, grid_id=f"batch_{i}")
                flattened.append(flat)
            except Exception as e:
                logger.error(f"Error flattening grid {i}: {e}")
                continue
        return flattened

In [None]:
def extract_grid_metadata(arc_data):
    """
    Extract task IDs, list lengths, and organize grids by their sizes.
    """
    task_metadata = {}

    for key, tasks in arc_data.items():
        for task_id, content in tasks.items():
            task_metadata[task_id] = {"sizes": {}, "grids": {}}

            for mode in ["train", "test"]:
                if mode in content:
                    for entry in content[mode]:
                        input_grid = entry.get("input", [])
                        output_grid = entry.get("output", [])

                        input_size = (len(input_grid), len(input_grid[0])) if input_grid else (0, 0)
                        output_size = (len(output_grid), len(output_grid[0])) if output_grid else (0, 0)

                        if input_size not in task_metadata[task_id]["sizes"]:
                            task_metadata[task_id]["sizes"][input_size] = 0
                        if output_size not in task_metadata[task_id]["sizes"]:
                            task_metadata[task_id]["sizes"][output_size] = 0

                        task_metadata[task_id]["sizes"][input_size] += 1
                        task_metadata[task_id]["sizes"][output_size] += 1

                        # Store grids
                        if input_size not in task_metadata[task_id]["grids"]:
                            task_metadata[task_id]["grids"][input_size] = {'grids': []}
                        task_metadata[task_id]["grids"][input_size]['grids'].append(input_grid)

                        if output_size not in task_metadata[task_id]["grids"]:
                            task_metadata[task_id]["grids"][output_size] = {'grids': []}
                        task_metadata[task_id]["grids"][output_size]['grids'].append(output_grid)

    return task_metadata


In [2]:
from typing import List, Optional, Dict, Any, Iterator
from dataclasses import dataclass, field
import numpy as np
from collections import deque
import json
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class GridMetadata:
    """Metadata for grid data.
    
    Attributes:
        shape: Grid dimensions
        dtype: Data type
        timestamp: Creation timestamp
        source: Data source
        transformations: Applied transformations
    """
    shape: tuple
    dtype: np.dtype
    timestamp: str
    source: str
    transformations: List[str] = field(default_factory=list)

@dataclass
class NodeStats:
    """Statistics for node data.
    
    Attributes:
        total_grids: Total number of grids
        total_children: Total number of child nodes
        depth: Node depth in tree
        leaf_count: Number of leaf nodes
        grid_sizes: List of grid sizes
    """
    total_grids: int
    total_children: int
    depth: int
    leaf_count: int
    grid_sizes: List[tuple]

class DataNode:
    """Enhanced node class for hierarchical data storage.
    
    This class represents a node in a tree structure, supporting
    grid data storage, metadata tracking, and tree operations.
    """
    
    def __init__(
        self,
        name: str,
        task_id: Optional[str] = None,
        parent: Optional['DataNode'] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Initialize data node.
        
        Args:
            name: Node name
            task_id: Optional task identifier
            parent: Optional parent node
            metadata: Optional metadata dictionary
        """
        self.name = name
        self.task_id = task_id
        self.parent = parent
        self.children: List['DataNode'] = []
        self.grids: List[np.ndarray] = []
        self.grid_metadata: List[GridMetadata] = []
        self.metadata = metadata or {}
        self.processed = False
        
        logger.info(f"Created node: {self.name} (ID: {self.task_id})")
        
    def add_child(
        self,
        child: 'DataNode'
    ) -> None:
        """Add child node.
        
        Args:
            child: Child node to add
        """
        child.parent = self
        self.children.append(child)
        logger.debug(f"Added child {child.name} to {self.name}")
        
    def add_grid(
        self,
        grid: np.ndarray,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add grid with metadata.
        
        Args:
            grid: Grid data to add
            metadata: Optional grid metadata
        """
        try:
            # Convert to numpy array if needed
            if not isinstance(grid, np.ndarray):
                grid = np.array(grid)
                
            # Create grid metadata
            grid_meta = GridMetadata(
                shape=grid.shape,
                dtype=grid.dtype,
                timestamp=self._get_timestamp(),
                source=metadata.get('source', 'unknown') if metadata else 'unknown',
                transformations=[]
            )
            
            self.grids.append(grid)
            self.grid_metadata.append(grid_meta)
            
            logger.debug(
                f"Added grid to {self.name} "
                f"with shape {grid.shape}"
            )
            
        except Exception as e:
            logger.error(f"Error adding grid to {self.name}: {e}")
            raise
            
    def remove_child(
        self,
        child: 'DataNode'
    ) -> None:
        """Remove child node.
        
        Args:
            child: Child node to remove
        """
        if child in self.children:
            child.parent = None
            self.children.remove(child)
            logger.debug(f"Removed child {child.name} from {self.name}")
            
    def remove_grid(
        self,
        index: int
    ) -> None:
        """Remove grid by index.
        
        Args:
            index: Index of grid to remove
        """
        if 0 <= index < len(self.grids):
            del self.grids[index]
            del self.grid_metadata[index]
            logger.debug(f"Removed grid {index} from {self.name}")
            
    def get_path(self) -> List[str]:
        """Get path from root to this node.
        
        Returns:
            List of node names in path
        """
        path = []
        current = self
        while current:
            path.append(current.name)
            current = current.parent
        return list(reversed(path))
        
    def find_node(
        self,
        name: str
    ) -> Optional['DataNode']:
        """Find node by name.
        
        Args:
            name: Name to search for
            
        Returns:
            Matching node or None
        """
        if self.name == name:
            return self
            
        for child in self.children:
            result = child.find_node(name)
            if result:
                return result
                
        return None
        
    def get_ancestors(self) -> List['DataNode']:
        """Get list of ancestor nodes.
        
        Returns:
            List of ancestor nodes
        """
        ancestors = []
        current = self.parent
        while current:
            ancestors.append(current)
            current = current.parent
        return ancestors
        
    def get_descendants(self) -> List['DataNode']:
        """Get list of descendant nodes.
        
        Returns:
            List of descendant nodes
        """
        descendants = []
        for child in self.children:
            descendants.append(child)
            descendants.extend(child.get_descendants())
        return descendants
        
    def get_siblings(self) -> List['DataNode']:
        """Get list of sibling nodes.
        
        Returns:
            List of sibling nodes
        """
        if not self.parent:
            return []
        return [
            child for child in self.parent.children
            if child is not self
        ]
        
    def get_stats(self) -> NodeStats:
        """Get node statistics.
        
        Returns:
            Node statistics object
        """
        stats = NodeStats(
            total_grids=len(self.grids),
            total_children=len(self.children),
            depth=self._get_depth(),
            leaf_count=self._count_leaves(),
            grid_sizes=[grid.shape for grid in self.grids]
        )
        return stats
        
    def _get_depth(self) -> int:
        """Calculate node depth.
        
        Returns:
            Node depth
        """
        depth = 0
        current = self
        while current.parent:
            depth += 1
            current = current.parent
        return depth
        
    def _count_leaves(self) -> int:
        """Count leaf nodes in subtree.
        
        Returns:
            Number of leaf nodes
        """
        if not self.children:
            return 1
        return sum(child._count_leaves() for child in self.children)
        
    def traverse_breadth_first(self) -> Iterator['DataNode']:
        """Traverse tree breadth-first.
        
        Yields:
            Nodes in breadth-first order
        """
        queue = deque([self])
        while queue:
            node = queue.popleft()
            yield node
            queue.extend(node.children)
            
    def traverse_depth_first(self) -> Iterator['DataNode']:
        """Traverse tree depth-first.
        
        Yields:
            Nodes in depth-first order
        """
        yield self
        for child in self.children:
            yield from child.traverse_depth_first()
            
    def to_dict(self) -> Dict[str, Any]:
        """Convert node to dictionary.
        
        Returns:
            Dictionary representation
        """
        return {
            'name': self.name,
            'task_id': self.task_id,
            'metadata': self.metadata,
            'grid_count': len(self.grids),
            'children': [
                child.to_dict() for child in self.children
            ]
        }
        
    def save_to_file(
        self,
        filepath: str
    ) -> None:
        """Save node to file.
        
        Args:
            filepath: Output file path
        """
        try:
            data = self.to_dict()
            with open(filepath, 'w') as f:
                json.dump(data, f, indent=2)
            logger.info(f"Saved node data to {filepath}")
        except Exception as e:
            logger.error(f"Error saving node data: {e}")
            
    @classmethod
    def load_from_file(
        cls,
        filepath: str
    ) -> 'DataNode':
        """Load node from file.
        
        Args:
            filepath: Input file path
            
        Returns:
            Loaded node
        """
        try:
            with open(filepath, 'r') as f:
                data = json.load(f)
            return cls._from_dict(data)
        except Exception as e:
            logger.error(f"Error loading node data: {e}")
            raise
            
    @classmethod
    def _from_dict(
        cls,
        data: Dict[str, Any]
    ) -> 'DataNode':
        """Create node from dictionary.
        
        Args:
            data: Input dictionary
            
        Returns:
            Created node
        """
        node = cls(
            name=data['name'],
            task_id=data['task_id'],
            metadata=data.get('metadata')
        )
        
        for child_data in data.get('children', []):
            child = cls._from_dict(child_data)
            node.add_child(child)
            
        return node
        
    def __repr__(self) -> str:
        """String representation of node."""
        return (
            f"{self.name} "
            f"(ID: {self.task_id}, "
            f"Grids: {len(self.grids)}, "
            f"Children: {len(self.children)})"
        )

In [3]:
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
import numpy as np
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class TreeConfig:
    """Configuration for data tree building.
    
    Attributes:
        max_depth: Maximum tree depth
        parallel_processing: Whether to use parallel processing
        validate_data: Whether to validate data
        store_metadata: Whether to store metadata
        cache_enabled: Whether to enable caching
    """
    max_depth: int = 5
    parallel_processing: bool = True
    validate_data: bool = True
    store_metadata: bool = True
    cache_enabled: bool = True

class DataTreeBuilder:
    """Builder for hierarchical data trees."""
    
    GRID_CATEGORIES = [
        "train_input_grids",
        "train_output_grids",
        "test_input_grids",
        "test_output_grids"
    ]
    
    LIST_CATEGORIES = [
        "train_input_lists",
        "train_output_lists",
        "test_input_lists",
        "test_output_lists"
    ]
    
    def __init__(
        self,
        config: Optional[TreeConfig] = None
    ):
        """Initialize tree builder.
        
        Args:
            config: Optional builder configuration
        """
        self.config = config or TreeConfig()
        self.cache = {} if self.config.cache_enabled else None
        self.stats = defaultdict(int)
        
    def build_tree(
        self,
        task_metadata: Dict[str, Dict[str, Any]]
    ) -> DataNode:
        """Build data tree from task metadata.
        
        Args:
            task_metadata: Task metadata dictionary
            
        Returns:
            Root node of built tree
        """
        try:
            logger.info("Starting tree construction")
            
            # Create root node
            root = DataNode(
                "ARC Data",
                metadata={'total_tasks': len(task_metadata)}
            )
            
            # Build tree
            if self.config.parallel_processing:
                self._build_parallel(root, task_metadata)
            else:
                self._build_sequential(root, task_metadata)
                
            # Log statistics
            self._log_tree_stats(root)
            
            return root
            
        except Exception as e:
            logger.error(f"Error building tree: {e}")
            raise
            
    def _build_sequential(
        self,
        root: DataNode,
        task_metadata: Dict[str, Dict[str, Any]]
    ) -> None:
        """Build tree sequentially.
        
        Args:
            root: Root node
            task_metadata: Task metadata
        """
        for task_id, metadata in task_metadata.items():
            try:
                self._process_task(root, task_id, metadata)
            except Exception as e:
                logger.error(f"Error processing task {task_id}: {e}")
                continue
                
    def _build_parallel(
        self,
        root: DataNode,
        task_metadata: Dict[str, Dict[str, Any]]
    ) -> None:
        """Build tree using parallel processing.
        
        Args:
            root: Root node
            task_metadata: Task metadata
        """
        with ThreadPoolExecutor() as executor:
            futures = []
            for task_id, metadata in task_metadata.items():
                future = executor.submit(
                    self._process_task,
                    root,
                    task_id,
                    metadata
                )
                futures.append((task_id, future))
                
            # Process results
            for task_id, future in futures:
                try:
                    future.result()
                except Exception as e:
                    logger.error(
                        f"Error processing task {task_id}: {e}"
                    )
                    
    def _process_task(
        self,
        root: DataNode,
        task_id: str,
        metadata: Dict[str, Any]
    ) -> None:
        """Process single task.
        
        Args:
            root: Root node
            task_id: Task identifier
            metadata: Task metadata
        """
        # Create task node
        task_node = DataNode(
            "Task",
            task_id,
            metadata=self._extract_task_metadata(metadata)
        )
        root.add_child(task_node)
        
        # Process grid categories
        self._process_categories(
            task_node,
            metadata,
            self.GRID_CATEGORIES,
            is_grid=True
        )
        
        # Process list categories
        self._process_categories(
            task_node,
            metadata,
            self.LIST_CATEGORIES,
            is_grid=False
        )
        
    def _process_categories(
        self,
        task_node: DataNode,
        metadata: Dict[str, Any],
        categories: List[str],
        is_grid: bool
    ) -> None:
        """Process data categories.
        
        Args:
            task_node: Task node
            metadata: Task metadata
            categories: Category list
            is_grid: Whether processing grids
        """
        for category in categories:
            try:
                data = metadata.get(category, [])
                if data:
                    category_node = DataNode(
                        category,
                        metadata={'type': 'grid' if is_grid else 'list'}
                    )
                    task_node.add_child(category_node)
                    
                    self._add_data_to_node(
                        category_node,
                        data,
                        is_grid
                    )
                    
            except Exception as e:
                logger.error(
                    f"Error processing category {category}: {e}"
                )
                
    def _add_data_to_node(
        self,
        node: DataNode,
        data: List[Any],
        is_grid: bool
    ) -> None:
        """Add data to node.
        
        Args:
            node: Target node
            data: Data to add
            is_grid: Whether data is grid
        """
        for item in data:
            try:
                processed_item = (
                    self._process_grid(item)
                    if is_grid else
                    self._process_list(item)
                )
                
                if self.config.validate_data:
                    self._validate_data(processed_item)
                    
                node.add_grid(
                    processed_item,
                    metadata={'is_grid': is_grid}
                )
                
            except Exception as e:
                logger.error(f"Error adding data: {e}")
                
    def _process_grid(
        self,
        grid: List[List[Any]]
    ) -> np.ndarray:
        """Process grid data.
        
        Args:
            grid: Input grid
            
        Returns:
            Processed grid array
        """
        if self.config.cache_enabled:
            grid_hash = hash(str(grid))
            if grid_hash in self.cache:
                return self.cache[grid_hash]
                
        processed = np.array(grid)
        
        if self.config.cache_enabled:
            self.cache[grid_hash] = processed
            
        return processed
        
    def _process_list(
        self,
        lst: List[Any]
    ) -> np.ndarray:
        """Process list data.
        
        Args:
            lst: Input list
            
        Returns:
            Processed array
        """
        return np.array(lst)
        
    def _validate_data(
        self,
        data: np.ndarray
    ) -> None:
        """Validate data array.
        
        Args:
            data: Data array to validate
            
        Raises:
            ValueError: If validation fails
        """
        if not isinstance(data, np.ndarray):
            raise ValueError("Data must be numpy array")
            
        if data.size == 0:
            raise ValueError("Empty data array")
            
        if np.any(np.isnan(data)):
            raise ValueError("Data contains NaN values")
            
    def _extract_task_metadata(
        self,
        metadata: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Extract task metadata.
        
        Args:
            metadata: Raw metadata
            
        Returns:
            Processed metadata dictionary
        """
        if not self.config.store_metadata:
            return {}
            
        return {
            'grid_counts': {
                category: len(metadata.get(category, []))
                for category in self.GRID_CATEGORIES
            },
            'list_counts': {
                category: len(metadata.get(category, []))
                for category in self.LIST_CATEGORIES
            }
        }
        
    def _log_tree_stats(
        self,
        root: DataNode
    ) -> None:
        """Log tree statistics.
        
        Args:
            root: Root node
        """
        stats = root.get_stats()
        logger.info(
            f"Tree built with {stats.total_grids} grids, "
            f"{stats.total_children} nodes, "
            f"depth {stats.depth}"
        )
        
    def get_tree_summary(
        self,
        root: DataNode
    ) -> Dict[str, Any]:
        """Get tree summary.
        
        Args:
            root: Root node
            
        Returns:
            Summary dictionary
        """
        return {
            'total_tasks': len(root.children),
            'total_grids': sum(
                len(node.grids)
                for node in root.traverse_depth_first()
            ),
            'max_depth': max(
                node._get_depth()
                for node in root.traverse_depth_first()
            ),
            'leaf_nodes': sum(
                1 for node in root.traverse_depth_first()
                if not node.children
            ),
            'grid_distribution': self._get_grid_distribution(root)
        }
        
    def _get_grid_distribution(
        self,
        root: DataNode
    ) -> Dict[str, int]:
        """Get grid distribution by category.
        
        Args:
            root: Root node
            
        Returns:
            Distribution dictionary
        """
        distribution = defaultdict(int)
        
        for node in root.traverse_depth_first():
            if node.name in self.GRID_CATEGORIES:
                distribution[node.name] = len(node.grids)
                
        return dict(distribution)

In [4]:
from typing import List, Tuple, Dict, Any, Optional
import torch
from torch.utils.data import Dataset
import numpy as np
from dataclasses import dataclass
from collections import defaultdict
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class DatasetConfig:
    """Configuration for dynamic grid dataset.
    
    Attributes:
        augmentation_enabled: Whether to use augmentation
        cache_enabled: Whether to enable caching
        normalize_grids: Whether to normalize grid values
        min_grid_size: Minimum grid size to include
        max_grid_size: Maximum grid size to include
        include_metadata: Whether to include metadata
    """
    augmentation_enabled: bool = False
    cache_enabled: bool = True
    normalize_grids: bool = True
    min_grid_size: Optional[Tuple[int, int]] = None
    max_grid_size: Optional[Tuple[int, int]] = None
    include_metadata: bool = True

class DynamicGridDataset(Dataset):
    """Dynamic dataset for grid data with enhanced features."""
    
    def __init__(
        self,
        data_tree: Any,
        config: Optional[DatasetConfig] = None,
        transform: Optional[callable] = None
    ):
        """Initialize dataset.
        
        Args:
            data_tree: Input data tree
            config: Optional dataset configuration
            transform: Optional transform function
        """
        super().__init__()
        self.config = config or DatasetConfig()
        self.transform = transform
        
        # Initialize containers
        self.data: List[Tuple[np.ndarray, str]] = []
        self.metadata: List[Dict[str, Any]] = []
        self.cache: Dict[int, torch.Tensor] = {}
        self.class_mapping: Dict[str, int] = {}
        
        # Process data tree
        self._process_tree(data_tree)
        
        logger.info(
            f"Initialized dataset with {len(self.data)} samples "
            f"and {self.num_classes} classes"
        )
        
    def _process_tree(
        self,
        data_tree: Any
    ) -> None:
        """Process data tree to build dataset.
        
        Args:
            data_tree: Input data tree
        """
        try:
            # Determine classes
            self.num_classes = self._determine_classes(data_tree)
            if self.num_classes <= 0:
                raise ValueError("No classes found in dataset")
                
            # Build dataset
            for task_node in data_tree.children.values():
                self._process_task_node(task_node)
                
            # Calculate statistics
            self._calculate_statistics()
            
        except Exception as e:
            logger.error(f"Error processing data tree: {e}")
            raise
            
    def _determine_classes(
        self,
        data_tree: Any
    ) -> int:
        """Determine number of classes.
        
        Args:
            data_tree: Input data tree
            
        Returns:
            Number of classes
        """
        classes = set()
        
        for task_node in data_tree.children.values():
            for mode_node in task_node.children.values():
                for grid_node in mode_node.children.values():
                    for grid in grid_node.children.values():
                        if hasattr(grid, 'grid') and grid.grid is not None:
                            classes.update(np.unique(grid.grid))
                            
        # Create class mapping
        self.class_mapping = {
            str(cls): idx
            for idx, cls in enumerate(sorted(classes))
        }
        
        return len(classes)
        
    def _process_task_node(
        self,
        task_node: Any
    ) -> None:
        """Process task node.
        
        Args:
            task_node: Task node to process
        """
        for mode_node in task_node.children.values():
            for grid_node in mode_node.children.values():
                self._process_grid_node(
                    grid_node,
                    task_node.name
                )
                
    def _process_grid_node(
        self,
        grid_node: Any,
        task_id: str
    ) -> None:
        """Process grid node.
        
        Args:
            grid_node: Grid node to process
            task_id: Task identifier
        """
        for grid in grid_node.children.values():
            if hasattr(grid, 'grid') and grid.grid is not None:
                if self._validate_grid(grid.grid):
                    self._add_grid(grid.grid, task_id, grid)
                    
    def _validate_grid(
        self,
        grid: np.ndarray
    ) -> bool:
        """Validate grid dimensions.
        
        Args:
            grid: Grid to validate
            
        Returns:
            Whether grid is valid
        """
        if not isinstance(grid, np.ndarray):
            return False
            
        if grid.size == 0:
            return False
            
        if self.config.min_grid_size:
            if any(s < m for s, m in zip(grid.shape, self.config.min_grid_size)):
                return False
                
        if self.config.max_grid_size:
            if any(s > m for s, m in zip(grid.shape, self.config.max_grid_size)):
                return False
                
        return True
        
    def _add_grid(
        self,
        grid: np.ndarray,
        task_id: str,
        node: Any
    ) -> None:
        """Add grid to dataset.
        
        Args:
            grid: Grid data
            task_id: Task identifier
            node: Grid node
        """
        # Process grid
        if self.config.normalize_grids:
            grid = self._normalize_grid(grid)
            
        # Store data
        self.data.append((grid, task_id))
        
        # Store metadata if enabled
        if self.config.include_metadata:
            self.metadata.append({
                'task_id': task_id,
                'shape': grid.shape,
                'unique_values': len(np.unique(grid))
            })
            
    def _normalize_grid(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Normalize grid values.
        
        Args:
            grid: Input grid
            
        Returns:
            Normalized grid
        """
        if grid.size == 0:
            return grid
            
        grid_min = np.min(grid)
        grid_max = np.max(grid)
        
        if grid_min == grid_max:
            return np.zeros_like(grid, dtype=np.float32)
            
        return ((grid - grid_min) / (grid_max - grid_min)).astype(np.float32)
        
    def _calculate_statistics(self) -> None:
        """Calculate dataset statistics."""
        self.statistics = {
            'total_samples': len(self.data),
            'grid_shapes': defaultdict(int),
            'class_distribution': defaultdict(int)
        }
        
        for grid, task_id in self.data:
            self.statistics['grid_shapes'][grid.shape] += 1
            self.statistics['class_distribution'][task_id] += 1
            
    def __len__(self) -> int:
        """Get dataset length."""
        return len(self.data)
        
    def __getitem__(
        self,
        index: int
    ) -> Tuple[torch.Tensor, str]:
        """Get dataset item.
        
        Args:
            index: Item index
            
        Returns:
            Tuple of (grid tensor, task ID)
        """
        if self.config.cache_enabled and index in self.cache:
            return self.cache[index]
            
        grid, task_id = self.data[index]
        
        # Convert to tensor
        grid_tensor = torch.tensor(
            grid,
            dtype=torch.float32
        ).unsqueeze(0)  # Add channel dimension
        
        # Apply transform if provided
        if self.transform is not None:
            grid_tensor = self.transform(grid_tensor)
            
        # Cache if enabled
        if self.config.cache_enabled:
            self.cache[index] = (grid_tensor, task_id)
            
        return grid_tensor, task_id
        
    def get_grid_shapes(self) -> List[Tuple[int, int]]:
        """Get list of unique grid shapes.
        
        Returns:
            List of grid shapes
        """
        return list(self.statistics['grid_shapes'].keys())
        
    def get_class_distribution(self) -> Dict[str, int]:
        """Get class distribution.
        
        Returns:
            Class distribution dictionary
        """
        return dict(self.statistics['class_distribution'])
        
    def get_metadata(
        self,
        index: int
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for index.
        
        Args:
            index: Item index
            
        Returns:
            Metadata dictionary or None
        """
        if 0 <= index < len(self.metadata):
            return self.metadata[index]
        return None
        
    def get_statistics(self) -> Dict[str, Any]:
        """Get dataset statistics.
        
        Returns:
            Statistics dictionary
        """
        return self.statistics

In [6]:
from typing import List, Tuple, Dict, Any, Optional, Union
import torch
import numpy as np
from dataclasses import dataclass
import random
from enum import Enum
from logger_setup import setup_logger

logger = setup_logger(__name__)

class AugmentationType(Enum):
    """Types of grid augmentation."""
    NOISE = "noise"
    ROTATION = "rotation"
    FLIP = "flip"
    SHIFT = "shift"
    MASK = "mask"
    PERMUTE = "permute"

@dataclass
class AugmentationConfig:
    """Configuration for data augmentation.
    
    Attributes:
        augmentation_factor: Number of augmentations per sample
        enabled_augmentations: List of enabled augmentation types
        noise_range: Range for noise values
        rotation_angles: List of rotation angles
        shift_range: Range for shifting
        mask_probability: Probability of masking
        preserve_class_balance: Whether to preserve class balance
    """
    augmentation_factor: int = 2
    enabled_augmentations: List[AugmentationType] = None
    noise_range: Tuple[int, int] = (-1, 1)
    rotation_angles: List[int] = None
    shift_range: Tuple[int, int] = (-1, 1)
    mask_probability: float = 0.1
    preserve_class_balance: bool = True

class AugmentedDynamicGridDataset(DynamicGridDataset):
    """Enhanced dataset with advanced augmentation capabilities."""
    
    def __init__(
        self,
        data_tree: Any,
        num_classes: int,
        config: Optional[AugmentationConfig] = None,
        transform: Optional[callable] = None
    ):
        """Initialize augmented dataset.
        
        Args:
            data_tree: Input data tree
            num_classes: Number of classes
            config: Optional augmentation configuration
            transform: Optional transform function
        """
        super().__init__(data_tree, transform=transform)
        
        self.num_classes = num_classes
        self.config = config or AugmentationConfig()
        
        # Set default augmentations if none specified
        if not self.config.enabled_augmentations:
            self.config.enabled_augmentations = list(AugmentationType)
            
        if not self.config.rotation_angles:
            self.config.rotation_angles = [0, 90, 180, 270]
            
        # Augment data
        self.data = self._augment_data()
        
        logger.info(
            f"Created augmented dataset with "
            f"{len(self.data)} samples"
        )
        
    def _augment_data(self) -> List[Tuple[np.ndarray, str]]:
        """Augment dataset.
        
        Returns:
            List of augmented samples
        """
        augmented_data = []
        class_counts = defaultdict(int)
        
        for grid, task_id in self.data:
            # Original sample
            augmented_data.append((grid, task_id))
            class_counts[task_id] += 1
            
            # Generate augmentations
            for _ in range(self.config.augmentation_factor):
                try:
                    augmented_grid = self._apply_augmentations(grid)
                    
                    if self.config.preserve_class_balance:
                        if class_counts[task_id] >= max(class_counts.values()):
                            continue
                            
                    augmented_data.append((augmented_grid, task_id))
                    class_counts[task_id] += 1
                    
                except Exception as e:
                    logger.warning(
                        f"Failed to augment grid for task {task_id}: {e}"
                    )
                    continue
                    
        return augmented_data
        
    def _apply_augmentations(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Apply multiple augmentations.
        
        Args:
            grid: Input grid
            
        Returns:
            Augmented grid
        """
        augmented = grid.copy()
        
        # Randomly select and apply augmentations
        for aug_type in random.sample(
            self.config.enabled_augmentations,
            k=random.randint(1, len(self.config.enabled_augmentations))
        ):
            augmented = self._apply_single_augmentation(
                augmented,
                aug_type
            )
            
        return augmented
        
    def _apply_single_augmentation(
        self,
        grid: np.ndarray,
        aug_type: AugmentationType
    ) -> np.ndarray:
        """Apply single augmentation.
        
        Args:
            grid: Input grid
            aug_type: Augmentation type
            
        Returns:
            Augmented grid
        """
        if aug_type == AugmentationType.NOISE:
            return self._add_noise(grid)
        elif aug_type == AugmentationType.ROTATION:
            return self._rotate_grid(grid)
        elif aug_type == AugmentationType.FLIP:
            return self._flip_grid(grid)
        elif aug_type == AugmentationType.SHIFT:
            return self._shift_grid(grid)
        elif aug_type == AugmentationType.MASK:
            return self._mask_grid(grid)
        elif aug_type == AugmentationType.PERMUTE:
            return self._permute_values(grid)
        else:
            return grid
            
    def _add_noise(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Add random noise to grid.
        
        Args:
            grid: Input grid
            
        Returns:
            Noisy grid
        """
        noise = np.random.randint(
            self.config.noise_range[0],
            self.config.noise_range[1] + 1,
            size=grid.shape
        )
        noisy = grid + noise
        return np.mod(noisy, self.num_classes)
        
    def _rotate_grid(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Rotate grid.
        
        Args:
            grid: Input grid
            
        Returns:
            Rotated grid
        """
        k = np.random.choice(self.config.rotation_angles) // 90
        return np.rot90(grid, k=k)
        
    def _flip_grid(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Flip grid.
        
        Args:
            grid: Input grid
            
        Returns:
            Flipped grid
        """
        if random.random() < 0.5:
            return np.fliplr(grid)
        return np.flipud(grid)
        
    def _shift_grid(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Shift grid values.
        
        Args:
            grid: Input grid
            
        Returns:
            Shifted grid
        """
        shift = np.random.randint(
            self.config.shift_range[0],
            self.config.shift_range[1] + 1
        )
        shifted = np.roll(grid, shift, axis=random.randint(0, 1))
        return shifted
        
    def _mask_grid(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Randomly mask grid values.
        
        Args:
            grid: Input grid
            
        Returns:
            Masked grid
        """
        mask = np.random.random(grid.shape) < self.config.mask_probability
        masked = grid.copy()
        masked[mask] = random.randint(0, self.num_classes - 1)
        return masked
        
    def _permute_values(
        self,
        grid: np.ndarray
    ) -> np.ndarray:
        """Permute grid values.
        
        Args:
            grid: Input grid
            
        Returns:
            Permuted grid
        """
        unique_values = np.unique(grid)
        permutation = np.random.permutation(unique_values)
        value_map = dict(zip(unique_values, permutation))
        return np.vectorize(value_map.get)(grid)
        
    def get_augmentation_stats(self) -> Dict[str, Any]:
        """Get augmentation statistics.
        
        Returns:
            Statistics dictionary
        """
        stats = {
            'total_samples': len(self.data),
            'augmentation_factor': self.config.augmentation_factor,
            'enabled_augmentations': [
                aug.value for aug in self.config.enabled_augmentations
            ],
            'class_distribution': defaultdict(int)
        }
        
        for _, task_id in self.data:
            stats['class_distribution'][task_id] += 1
            
        return stats

In [8]:
import os
import sys
import tkinter as tk
from tkinter import messagebox
import torch
import torch.multiprocessing as mp
from pathlib import Path
from dataclasses import dataclass
import json
import argparse
from datetime import datetime
from logger_setup import setup_logger

logger = setup_logger(__name__)

@dataclass
class AppConfig:
    """Application configuration.
    
    Attributes:
        data_dir: Directory containing data files
        model_dir: Directory for model checkpoints
        log_dir: Directory for logs
        batch_size: Training batch size
        num_workers: Number of data loading workers
        use_gpu: Whether to use GPU
        debug_mode: Whether to enable debug mode
    """
    data_dir: str = "data"
    model_dir: str = "models"
    log_dir: str = "logs"
    batch_size: int = 32
    num_workers: int = 4
    use_gpu: bool = True
    debug_mode: bool = False

class ARCTrainingApp:
    """Main application for ARC training and visualization."""
    
    def __init__(
        self,
        config: AppConfig
    ):
        """Initialize application.
        
        Args:
            config: Application configuration
        """
        self.config = config
        self.setup_directories()
        
        # Initialize components
        self.device = None
        self.data_tree = None
        self.gui = None
        
        logger.info("Initialized ARC Training Application")
        
    def setup_directories(self) -> None:
        """Create necessary directories."""
        for directory in [
            self.config.data_dir,
            self.config.model_dir,
            self.config.log_dir
        ]:
            Path(directory).mkdir(parents=True, exist_ok=True)
            
    def setup_multiprocessing(self) -> None:
        """Configure multiprocessing."""
        if os.name == 'nt':  # Windows
            mp.set_start_method('spawn', force=True)
        else:
            mp.set_start_method('fork', force=True)
            
    def setup_device(self) -> None:
        """Configure computation device."""
        if self.config.use_gpu and torch.cuda.is_available():
            self.device = torch.device('cuda')
            # Enable cuDNN autotuner
            torch.backends.cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')
            
        logger.info(f"Using device: {self.device}")
        
    def load_data(self) -> None:
        """Load and process ARC data."""
        try:
            # Load raw data
            self.arc_data = load_arc_data()
            
            # Analyze dimensions
            self.task_metadata = analyze_grid_dimensions(
                "arc-agi_training_challenges.json"
            )
            
            # Build data tree
            self.data_tree = build_data_tree(self.task_metadata)
            
            logger.info("Data loading completed successfully")
            
        except Exception as e:
            logger.exception("Data loading failed")
            raise RuntimeError(f"Failed to load data: {e}")
            
    def setup_gui(self) -> None:
        """Initialize GUI components."""
        try:
            self.root = tk.Tk()
            self.root.title("ARC Training and Visualization GUI")
            
            # Configure window
            self.root.geometry("1200x800")
            self.root.protocol(
                "WM_DELETE_WINDOW",
                self.on_closing
            )
            
            # Create GUI instance
            self.gui = TrainingGUI(
                root=self.root,
                model=self.model,
                train_loader=self.train_loader,
                val_loader=self.val_loader,
                device=self.device,
                data_tree=self.data_tree
            )
            
            logger.info("GUI setup completed")
            
        except Exception as e:
            logger.exception("GUI setup failed")
            raise RuntimeError(f"Failed to setup GUI: {e}")
            
    def setup_model(self) -> None:
        """Initialize model and training components."""
        try:
            # Create dataset
            self.dataset = AugmentedDynamicGridDataset(
                self.data_tree,
                config=DatasetConfig(
                    augmentation_enabled=True,
                    cache_enabled=True
                )
            )
            
            # Create data loaders
            self.train_loader, self.val_loader = self.create_data_loaders()
            
            # Initialize model
            self.model = SimpleTransformer(
                config=TransformerConfig(
                    vocab_size=self.dataset.num_classes,
                    d_model=512
                )
            ).to(self.device)
            
            logger.info("Model setup completed")
            
        except Exception as e:
            logger.exception("Model setup failed")
            raise RuntimeError(f"Failed to setup model: {e}")
            
    def create_data_loaders(self) -> Tuple[DataLoader, DataLoader]:
        """Create train and validation data loaders.
        
        Returns:
            Tuple of (train_loader, val_loader)
        """
        # Split dataset
        train_size = int(0.8 * len(self.dataset))
        val_size = len(self.dataset) - train_size
        
        train_dataset, val_dataset = torch.utils.data.random_split(
            self.dataset,
            [train_size, val_size]
        )
        
        # Create loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=True
        )
        
        return train_loader, val_loader
        
    def save_session(self) -> None:
        """Save current session state."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        session_dir = Path(self.config.model_dir) / f"session_{timestamp}"
        session_dir.mkdir(parents=True, exist_ok=True)
        
        # Save model
        torch.save(
            self.model.state_dict(),
            session_dir / "model.pth"
        )
        
        # Save configuration
        with open(session_dir / "config.json", 'w') as f:
            json.dump(vars(self.config), f, indent=2)
            
        logger.info(f"Session saved to {session_dir}")
        
    def on_closing(self) -> None:
        """Handle application closing."""
        if messagebox.askokcancel("Quit", "Do you want to save before quitting?"):
            self.save_session()
        self.root.destroy()
        
    def run(self) -> None:
        """Run the application."""
        try:
            # Setup components
            self.setup_multiprocessing()
            self.setup_device()
            self.load_data()
            self.setup_model()
            self.setup_gui()
            
            # Start GUI
            self.root.mainloop()
            
        except Exception as e:
            logger.exception("Application failed")
            messagebox.showerror(
                "Error",
                f"Application failed to start: {e}"
            )
            sys.exit(1)
            
def parse_args() -> argparse.Namespace:
    """Parse command line arguments.
    
    Returns:
        Parsed arguments
    """
    parser = argparse.ArgumentParser(
        description="ARC Training Application"
    )
    
    parser.add_argument(
        "--data-dir",
        default="data",
        help="Data directory"
    )
    
    parser.add_argument(
        "--model-dir",
        default="models",
        help="Model directory"
    )
    
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Training batch size"
    )
    
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode"
    )
    
    return parser.parse_args()

def main() -> None:
    """Main entry point."""
    # Parse arguments
    args = parse_args()
    
    # Create configuration
    config = AppConfig(
        data_dir=args.data_dir,
        model_dir=args.model_dir,
        batch_size=args.batch_size,
        debug_mode=args.debug
    )
    
    # Create and run application
    app = ARCTrainingApp(config)
    app.run()

if __name__ == "__main__":
    main()

usage: ipykernel_launcher.py [-h] [--data-dir DATA_DIR] [--model-dir MODEL_DIR]
                             [--batch-size BATCH_SIZE] [--debug]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\Owner\AppData\Roaming\jupyter\runtime\kernel-e4b0464c-e810-4653-99eb-be10565d3cd6.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
