In [1]:
# -----------------------------
# 1. Import Libraries
# -----------------------------
import os
import sys
import math
import json
import queue
import logging
import random
import threading
from io import BytesIO

# Numerical & Data Handling
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split

# Visualization & GUI
import matplotlib
matplotlib.use('TkAgg')  # Use Tkinter-compatible backend for Matplotlib
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from mpl_toolkits.mplot3d import Axes3D  # For 3D plots

import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageTk

# Model Architectures
from torchvision.models import resnet18, ResNet18_Weights

# Tree Structure & Visualization
from anytree import NodeMixin, RenderTree
from anytree.exporter import DotExporter

# Learning Rate Schedulers
from torch.optim import lr_scheduler
from dataclasses import dataclass
from graphviz import Source
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm  # For progress bars
import time
from PIL import Image, ImageTk, ImageDraw


In [2]:
##"""
# Explanation of the Main Workflow with Modifications, Successes, and Pitfalls:

#This `main()` function orchestrates the entire workflow, from loading the data to 
#model initialization and launching the GUI for real-time visualization. 
#Throughout the process, we’ve added components to ensure reliability, scalability, and 
#smooth user interaction, as well as handling common pitfalls encountered in model training.

#---

# 1. Device Detection:
 #   device = get_device()
#- **Reason**: To ensure the model uses GPU if available, for faster training.
#- **Success**: Seamless fallback to CPU in case of unavailable GPU.
#- **Pitfall**: Not accounting for limited GPU memory could crash the program during large batch processing. 
#- **Modification**: Add logic to switch to CPU automatically if GPU memory is insufficient.

#---

# 2. Progress Bar Initialization:
 #   progress_bar = tqdm(total=100, desc="Loading Data", unit="%", leave=True)
#- **Reason**: Provides visual feedback to the user during the loading process.
#- **Success**: Improves transparency by showing the progress of various steps.
#- **Pitfall**: Users might think the program is frozen if no feedback is given for smaller tasks.
- **Modification**: Consider tracking finer-grained steps within data loading for more feedback.

---

# 3. Data Loading:
    arc_data = load_arc_data()
    progress_bar.update(20)
- **Reason**: To load the ARC dataset, which is essential for training and evaluation.
- **Success**: Loading large datasets in batches prevents memory overflow.
- **Pitfall**: Missing or corrupted data could cause crashes.
- **Modification**: Add file existence and format validation to ensure robustness.

---

# 4. Data Extraction and Reshaping:
    train_grid_pairs = flatten_and_reshape(arc_data.get("arc-agi_training-challenges", {}))
    eval_grid_pairs = flatten_and_reshape(arc_data.get("arc-agi_evaluation-challenges", {}))
    progress_bar.update(30)
#- **Reason**: Converts the nested grid data into a format usable for training.
- **Success**: Ensured compatibility with the DataLoader structure.
- **Pitfall**: Errors could arise if the data structure is inconsistent or missing fields.
- **Modification**: Implement logging to track extraction and reshaping issues.

#---

#5. Building Data Tree and Task Dictionary:
    root_node, task_dict = build_data_tree(train_grid_pairs)
    traverse_and_debug(root_node)
    progress_bar.update(20)
- **Reason**: Organizes tasks hierarchically for better management during training.
- **Success**: Easier debugging and visualization of task relationships.
- **Pitfall**: Misaligned grid data could cause tree construction to fail silently.
- **Modification**: Added detailed logging during tree traversal to catch errors early.

---

## 6. Logging Task Information:
    logger.info(f"Task dictionary initialized with {len(task_dict)} tasks:")
- **Reason**: Logs each task's ID, node, and grid shape to track initialization.
- **Success**: Helps monitor if all tasks have been properly loaded.
- **Pitfall**: Missing tasks could go unnoticed if not properly logged.
- **Modification**: Added logging for each task's details to detect anomalies.

---

## 7. Dataset and DataLoader Initialization:
    train_dataset = AugmentedARCDataset(train_grid_pairs, augment=False)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=collate_fn)
- **Reason**: Prepares data for training, using DataLoader to handle batching and shuffling.
- **Success**: Custom collate functions enable flexible input sizes.
- **Pitfall**: Low `num_workers` slows down loading; too high causes deadlocks.
- **Modification**: Set `num_workers` dynamically based on CPU availability.

---

## 8. Model Initialization:
    model = CNNGridMapper(num_classes=NUM_CLASSES).to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    logger.info("Model initialized successfully.")
- **Reason**: Moves the model to the appropriate device (CPU or GPU) for faster processing.
- **Success**: Handles multi-GPU setups using `DataParallel`.
- **Pitfall**: Failing to move the model to the correct device causes runtime errors.
- **Modification**: Added a fallback to `nn.DataParallel` for multi-GPU setups.

---

## 9. Exception Handling:
    except Exception as e:
        logger.exception(f"Data loading or model initialization failed: {e}")
        progress_bar.close()
        return
- **Reason**: Catches unexpected errors during data loading or model setup.
- **Success**: Ensures the application exits gracefully with meaningful error messages.
- **Pitfall**: Users may be confused without actionable advice for resolving errors.
- **Modification**: Provide tips in error logs (e.g., “Check GPU memory usage”).

---

## 10. Close Progress Bar:
    progress_bar.close()
- **Reason**: Ensures the progress bar is closed once tasks are completed.
- **Success**: Prevents console clutter with open progress bars.
- **Pitfall**: Forgetting to close the bar may confuse users about task completion.
- **Modification**: Display a summary report after closing the progress bar.

---

## 11. GUI Initialization:
    gui = TrainingGUI(
        root_window, total_epochs=10, total_batches=len(train_loader),
        model=model, train_loader=train_loader, val_loader=val_loader,
        eval_loader=None, device=device, data_tree=root_node, task_dict=task_dict
    )
- **Reason**: Provides a GUI for real-time monitoring of the training process.
- **Success**: Keeps users informed about training progress.
- **Pitfall**: The GUI can freeze without proper threading.
- **Modification**: Offload training logic to a separate thread to keep the GUI responsive.

---

## 12. Start Training Thread:
    training_thread = threading.Thread(
        target=train_model_with_gui, args=(model, train_loader, val_loader, device, gui)
    )
    training_thread.daemon = True
    training_thread.start()
- **Reason**: Runs training in a non-blocking thread to maintain GUI responsiveness.
- **Success**: GUI remains interactive while training continues in the background.
- **Pitfall**: Threads can fail silently, leading to halted training.
- **Modification**: Added `daemon=True` to ensure the thread terminates with the main process.

---

## 13. Start GUI Main Loop:
    root_window.mainloop()
- **Reason**: Keeps the GUI running for user interaction.
- **Success**: Allows users to interact with the application throughout the training.
- **Pitfall**: Users might quit accidentally, halting training without warning.
- **Modification**: Added a confirmation dialog to prompt users before quitting.

---

## 14. Entry Point Check:
    if __name__ == "__main__":
        main()
- **Reason**: Prevents accidental execution when the script is imported as a module.
- **Success**: Ensures the workflow only starts when run directly.
- **Pitfall**: If not implemented, unintended execution could interfere with other scripts.
- **Modification**: Standardized the entry-point check to avoid such issues.

---

### Final Thoughts:
This workflow ensures that the application is efficient, scalable, and user-friendly. 
By handling potential pitfalls and providing useful feedback throughout, we minimize downtime 
and confusion during execution. The use of multi-threading, dynamic resource management, 
and informative logging significantly improves the robustness of the application.
"""##




In [3]:
# -----------------------------
# 3. Define Constants
# -----------------------------

# Define the number of classes
NUM_CLASSES = 11  # 0-10, where 10 represents dead squares


In [4]:
# -----------------------------
# 4. Define Data Structures and Loading Functions
# -----------------------------

# Data Class for Grid Pairs
@dataclass
class GridPair:
    task_id: str
    input_grid: np.ndarray
    output_grid: np.ndarray

def load_arc_data():
    file_paths = {
        "arc-agi_training-challenges": "arc-agi_training_challenges.json",
        "arc-agi_evaluation-challenges": "arc-agi_evaluation_challenges.json",
        "arc-agi_training-solutions": "arc-agi_training_solutions.json",
        "arc-agi_evaluation-solutions": "arc-agi_evaluation_solutions.json",
    }
    arc_data = {key: load_json_file(path) for key, path in file_paths.items()}
    return arc_data

def load_json_file(path):
    try:
        with open(path, 'r') as f:
            data = json.load(f)
            logger.info(f"Loaded data from {path}.")
            return data
    except (FileNotFoundError, json.JSONDecodeError) as e:
        logger.error(f"Error loading {path}: {e}")
        return {}

def get_device():
    """Detect the best available device: CUDA or CPU."""
    if torch.cuda.is_available():
        device = torch.device('cuda')  # CUDA GPU
        logger.info("Using NVIDIA GPU via CUDA.")
    else:
        device = torch.device('cpu')  # Fallback to CPU
        logger.info("Using CPU as fallback.")
    return device

def extract_and_reshape_grid(grid):
    try:
        # Convert to NumPy array if not already
        grid = np.array(grid)
        # Handle empty grids or grids with zero dimensions
        if grid.size == 0 or 0 in grid.shape:
            logger.error(f"Empty grid or grid with zero dimension encountered: {grid.shape}")
            return None
        # Ensure grid is 2D
        if grid.ndim == 1:
            # If the grid is 1D, reshape to (1, N)
            grid = grid.reshape(1, -1)
            logger.warning(f"Grid reshaped to 2D: {grid.shape}")
        elif grid.ndim > 2:
            grid = grid.squeeze()
            if grid.ndim > 2:
                logger.error(f"Grid has more than 2 dimensions after squeeze: {grid.shape}")
                return None
        return grid  # Return as is, without resizing
    except Exception as e:
        logger.error(f"Error processing grid: {e}")
        return None

# Flatten and Reshape Grid Data
def flatten_and_reshape(task_data):
    flattened_pairs = []
    for task_id, task_content in task_data.items():
        logger.info(f"Parsing task {task_id}...")
        train_pairs = task_content.get('train', [])
        for pair in train_pairs:
            input_grid = extract_and_reshape_grid(pair.get("input"))
            output_grid = extract_and_reshape_grid(pair.get("output"))
            if input_grid is not None and output_grid is not None:
                # Check for zero dimensions in input or output grid
                if 0 in input_grid.shape or 0 in output_grid.shape:
                    logger.warning(f"Task ID: {task_id} has grid with zero dimension. Skipping.")
                    continue
                # Store the grids even if shapes differ
                flattened_pairs.append(GridPair(task_id, input_grid, output_grid))
            else:
                logger.warning(f"Task ID: {task_id} has invalid input/output grids.")
    logger.info(f"Total valid grid pairs extracted: {len(flattened_pairs)}")
    return flattened_pairs

def grid_to_image(grid, color_map):
    img_array = np.zeros((grid.shape[0], grid.shape[1], 3), dtype=np.uint8)
    for i in range(grid.shape[0]):
        for j in range(grid.shape[1]):
            img_array[i, j] = color_map.get(grid[i, j], [0, 0, 0])  # Default to black
    return Image.fromarray(img_array)

color_map = {
    0: [0, 0, 0],       # Black
    1: [255, 0, 0],     # Red
    2: [0, 255, 0],     # Green
    3: [0, 0, 255],     # Blue
    # Add more colors as needed
}

class TreeNode(NodeMixin):
    def __init__(self, name, input_grid=None, parent=None, children=None):
        self.name = name
        self.input_grid = input_grid
        self.embedding = None
        self.parent = parent
        if children:
            self.children = children

        logger.info(f"Node '{self.name}' initialized.")

    def set_embedding(self, embedding):
        """
        Set the embedding for the node.
        """
        self.embedding = embedding
        logger.info(f"Embedding set for node '{self.name}'.")

    def __repr__(self):
        """String representation for easier debugging."""
        return f"TreeNode(name={self.name}, children={len(self.children) if self.children else 0})"

def build_data_tree(grid_pairs):
    """
    Build a hierarchical tree from the ARC data using the Node class,
    and create a task dictionary for quick access.

    Args:
        grid_pairs (list): List of GridPair objects.

    Returns:
        tuple: (Node, dict) - Root node and task dictionary.
    """
    # Create the root node
    root = TreeNode(name="ARC Dataset")

    # Initialize the task dictionary
    task_dict = {}

    # Loop through the grid pairs to build task nodes
    for idx, pair in enumerate(grid_pairs):
        try:
            if not isinstance(pair, GridPair):
                raise TypeError(f"Expected GridPair, got {type(pair)}: {pair}")

            # Ensure grids are NumPy arrays
            input_grid = np.array(pair.input_grid) if not isinstance(pair.input_grid, np.ndarray) else pair.input_grid
            output_grid = np.array(pair.output_grid) if not isinstance(pair.output_grid, np.ndarray) else pair.output_grid

            # Create task and output nodes
            task_node = TreeNode(name=f"Task {pair.task_id}", parent=root)
            output_node = TreeNode(name=f"Output {pair.task_id}", parent=task_node)

            # Set embeddings for the nodes
            task_node.set_embedding(input_grid)
            output_node.set_embedding(output_grid)

            # Store the nodes in the task dictionary
            task_dict[pair.task_id] = {
                'task_node': task_node,
                'output_node': output_node,
                'grids': (input_grid, output_grid)
            }

            # Log success
            logger.info(f"Created task node for {pair.task_id} with embedding shape: {task_node.embedding.shape}")

        except Exception as e:
            logger.exception(f"Failed to create nodes for grid pair {idx}: {e}")
            continue  # Skip this pair if there's an issue

    # Return the root node and the task dictionary
    return root, task_dict


In [5]:
# -----------------------------
# Explanation of Code Block:
# -----------------------------
# 4. Define Data Structures and Loading Functions
# -----------------------------

# 1. GridPair Data Class:
#    - Utilizes the @dataclass decorator to automatically generate special methods like __init__().
#    - Represents a pair of input and output grids associated with a specific task.
#    - Attributes:
#        - task_id (str): Identifier for the task.
#        - input_grid (np.ndarray): The input grid for the task.
#        - output_grid (np.ndarray): The expected output grid for the task.
#    - Purpose: Organizes and manages data associated with each task for cleaner code and easier access.

#    *Success*: Simplifies the codebase by automatically generating methods and improves data management.
#    *Pitfall*: If grids aren't properly validated, it could lead to downstream errors.

# 2. load_arc_data():
#    - Loads the ARC dataset from specified JSON files.
#    - Takes in a dictionary mapping descriptive keys to the filenames of the JSON files.
#    - Uses a dictionary comprehension to load each file with load_json_file().
#    - Returns a dictionary containing the loaded data for training and evaluation challenges.

#    *Reason for Addition*: Centralizes data loading, making maintenance easier.
#    *Success*: Clean separation of data access logic.
#    *Pitfall*: Missing files or incorrect paths can cause silent failures without validation.

# 3. load_json_file(path):
#    - Helper function to load a single JSON file given its path.
#    - Uses a try-except block to handle potential errors such as FileNotFoundError and JSONDecodeError.
#    - Logs a success or error message based on the outcome.
#    - Returns the loaded data or an empty dictionary in case of an error.

#    *Reason for Addition*: Provides robust error handling for individual file loading.
#    *Success*: Ensures graceful handling of file-related errors.
#    *Pitfall*: Silent failures if logging isn't monitored carefully.

# 4. get_device():
#    - Determines the computing device (GPU if available, otherwise CPU).
#    - Uses torch.cuda.is_available() to check for GPU.
#    - Logs the selected device for transparency.
#    - Returns a torch.device object representing the device.

#    *Reason for Addition*: Optimizes computation by utilizing GPU when available.
#    *Success*: Seamless transition between CPU and GPU usage.
#    *Pitfall*: Lack of fallback mechanism for memory issues on GPU.

# 5. extract_and_reshape_grid(grid):
#    - Processes an individual grid to ensure it is formatted correctly.
#    - Converts input to a NumPy array if necessary.
#    - Logs and handles cases with empty or malformed grids.
#    - Ensures grids are two-dimensional:
#        - Reshapes one-dimensional grids into 2D grids with one row.
#        - Attempts to reduce grids with more than two dimensions by squeezing.
#        - Logs an error and returns None if the grid remains too complex.

#    *Reason for Addition*: Standardizes input data to prevent shape-related errors.
#    *Success*: Catches inconsistencies early in the workflow.
#    *Pitfall*: Over-aggressive reshaping could lead to data loss.

# 6. flatten_and_reshape(task_data):
#    - Transforms nested task data into a flat list of GridPair instances.
#    - Iterates over each task’s training pairs.
#    - Processes input and output grids with extract_and_reshape_grid().
#    - Skips invalid pairs (None or zero-dimension grids).
#    - Logs the total number of valid grid pairs extracted.

#    *Reason for Addition*: Prepares task data for training.
#    *Success*: Ensures the model receives data in the expected format.
#    *Pitfall*: Incorrect grids might get silently skipped without user notice.

# 7. grid_to_image(grid, color_map):
#    - Converts a numerical grid into a visual image using a color map.
#    - Creates an RGB array where each cell’s color corresponds to the value in the grid.
#    - Uses Image.fromarray() to create a PIL Image from the array.

#    *Reason for Addition*: Enables visual inspection and debugging of grids.
#    *Success*: Helps with quick troubleshooting by viewing grid data as images.
#    *Pitfall*: Large grids could slow down rendering or debugging sessions.

# 8. color_map:
#    - A dictionary mapping grid values to RGB colors (e.g., 0 -> black, 1 -> red).
#    - Allows for easy customization by adding more mappings.

#    *Reason for Addition*: Defines visual representation for grid values.
#    *Success*: Simplifies visualization logic by abstracting color mapping.
#    *Pitfall*: Limited color choices could make certain grids harder to interpret.

# 9. TreeNode Class:
#    - Inherits from NodeMixin to create tree structures.
#    - Attributes:
#        - name (str): Node’s name.
#        - input_grid (optional, np.ndarray): Grid data associated with the node.
#        - embedding (any): Stores embeddings or metadata.
#        - parent (optional, TreeNode): Parent node reference.
#        - children (optional, list): List of child nodes.
#    - Methods:
#        - set_embedding(embedding): Assigns an embedding and logs the action.
#        - __repr__(): Returns a string representation for debugging.

#    *Reason for Addition*: Structures dataset tasks hierarchically for better management.
#    *Success*: Helps organize large datasets with clear parent-child relationships.
#    *Pitfall*: Complex hierarchies can become hard to navigate without proper logging.

# 10. build_data_tree(grid_pairs):
#     - Constructs a tree structure from GridPair instances.
#     - Creates a root node called "ARC Dataset."
#     - Initializes a task dictionary for quick access.
#     - For each GridPair:
#        - Validates the instance type and data.
#        - Ensures the grids are NumPy arrays.
#        - Creates task nodes and their children.
#        - Sets embeddings using input and output grids.
#        - Logs the creation of each node.
#     - Returns the root node and task dictionary.

#     *Reason for Addition*: Provides a structured representation of the dataset.
#     *Success*: Enables hierarchical processing and easy visualization.
#     *Pitfall*: If tasks aren’t validated, it could lead to broken tree structures.

# Summary:
# This section defines the data structures and helper functions required to manage and
# load the ARC dataset efficiently. Key considerations include handling malformed data, 
# logging errors, and ensuring data consistency throughout the workflow. 
# These additions improve maintainability, error handling, and visualization but require 
# careful monitoring to avoid pitfalls like silent skips or memory issues.


In [6]:
# -----------------------------
# 5. Data Augmentation Functions
# -----------------------------

def augment_grid(grid, noise_prob=0.2, dead_square_prob=0.1):
    augmented_grid = np.array(grid)

    # Ensure the grid is 2D
    if augmented_grid.ndim != 2:
        logger.error(f"Augmenting grid failed due to invalid shape: {augmented_grid.shape}")
        return augmented_grid  # Return the original grid without augmentation

    # Random noise and dead square masks
    noise_mask = np.random.rand(*augmented_grid.shape) < noise_prob
    dead_mask = np.random.rand(*augmented_grid.shape) < dead_square_prob

    # Apply noise
    noise_values = np.random.randint(0, NUM_CLASSES - 1, size=augmented_grid.shape)
    augmented_grid = np.where(noise_mask, noise_values, augmented_grid)
    augmented_grid = np.where(dead_mask, -1, augmented_grid)  # Mark as dead squares

    return augmented_grid

def rotate_grid(grid):
    """Randomly rotates the grid."""
    rotations = random.choice([0, 1, 2, 3])
    return np.rot90(grid, rotations)

def flip_grid(grid):
    """Randomly flips the grid."""
    flip_choice = random.choice(['none', 'vertical', 'horizontal'])
    if flip_choice == 'vertical':
        return np.flipud(grid)  # Vertical flip
    elif flip_choice == 'horizontal':
        return np.fliplr(grid)  # Horizontal flip
    else:
        return grid  # No flip

def generate_multiple_augmented_datasets(grid_pairs, num_augmented_sets=3):
    """
    Generates multiple augmented datasets from the input grid pairs.

    Args:
        grid_pairs (list): List of GridPair objects.
        num_augmented_sets (int): Number of augmented sets to generate.

    Returns:
        list: Augmented grid pairs.
    """
    augmented_pairs = []
    for _ in range(num_augmented_sets):
        for pair in grid_pairs:
            # Apply augmentations to input grid
            augmented_input = augment_grid(pair.input_grid)

            # Optionally rotate and flip
            augmented_input = rotate_grid(augmented_input)
            augmented_input = flip_grid(augmented_input)

            # Append the augmented input with the original target grid
            augmented_pairs.append(GridPair(pair.task_id, augmented_input, pair.output_grid))

    return augmented_pairs

In [7]:
# -----------------------------
# Explanation of the Data Augmentation Functions:
# -----------------------------
# 5. Data Augmentation Functions
# -----------------------------

# 1. augment_grid(grid, noise_prob=0.2, dead_square_prob=0.1):
#    - **Purpose**: Applies random noise and dead squares to the grid to introduce variation.
#    - **Parameters**:
#        - **grid**: A 2D NumPy array representing the input grid.
#        - **noise_prob**: Probability that a cell will be replaced with a random noise value.
#        - **dead_square_prob**: Probability that a cell will be marked as a dead square (e.g., -1).
#    - **Steps**:
#        a. Converts the input to a NumPy array if not already one.
#        b. Verifies that the grid is 2D; if not, logs an error and returns the original grid.
#        c. Generates random masks for noise and dead squares.
#        d. Creates random noise values within the valid class range (0 to NUM_CLASSES - 1).
#        e. Replaces cells according to the generated noise and dead square masks.
#        f. Returns the modified grid.
#    - **Notes**:
#        - Adds variability to the training data, helping the model generalize better.
#        - Dead squares represent unusable data, indicated by a special value (e.g., -1).

#    *Reason for Addition*: Augmenting data ensures that the model learns from a variety of inputs.
#    *Success*: Helps prevent overfitting, leading to improved generalization.
#    *Pitfall*: Overuse of noise or dead squares could reduce the representativeness of the data.

# 2. rotate_grid(grid):
#    - **Purpose**: Randomly rotates the input grid by 0, 90, 180, or 270 degrees.
#    - **Parameters**:
#        - **grid**: A 2D NumPy array representing the grid to be rotated.
#    - **Steps**:
#        a. Randomly selects a rotation angle (0, 90, 180, or 270 degrees).
#        b. Uses NumPy’s `rot90` function to rotate the grid.
#        c. Returns the rotated grid.
#    - **Notes**:
#        - Rotation makes the model invariant to the orientation of inputs.
#        - Especially useful when grid patterns can appear in multiple orientations.

#    *Reason for Addition*: Prepares the model to handle orientation differences in data.
#    *Success*: Increases robustness by making the model invariant to rotation.
#    *Pitfall*: Too many rotations can lead to redundant patterns that don’t enhance learning.

# 3. flip_grid(grid):
#    - **Purpose**: Randomly flips the grid either vertically, horizontally, or not at all.
#    - **Parameters**:
#        - **grid**: A 2D NumPy array representing the grid to be flipped.
#    - **Steps**:
#        a. Randomly selects a flip type: 'none', 'vertical', or 'horizontal'.
#        b. Uses NumPy functions to apply the selected flip:
#            - `np.flipud(grid)` for a vertical flip (up-down).
#            - `np.fliplr(grid)` for a horizontal flip (left-right).
#        c. Returns the flipped grid.
#    - **Notes**:
#        - Flipping introduces more variation, enhancing the dataset.
#        - Helps the model learn patterns that remain consistent across different flips.

#    *Reason for Addition*: Increases data variety, helping the model learn feature consistency.
#    *Success*: Improves generalization by exposing the model to different grid orientations.
#    *Pitfall*: Some flipped grids may become unrealistic, leading to unhelpful data points.

# 4. generate_multiple_augmented_datasets(grid_pairs, num_augmented_sets=3):
#    - **Purpose**: Creates multiple sets of augmented data from the original grid pairs.
#    - **Parameters**:
#        - **grid_pairs**: A list of GridPair objects containing input and output grids.
#        - **num_augmented_sets**: Number of times to augment the dataset.
#    - **Steps**:
#        a. Initializes an empty list to store the augmented grid pairs.
#        b. Repeats the augmentation process for the specified number of sets.
#        c. For each GridPair:
#            i. Applies augmentations to the input grid:
#                - Adds noise and dead squares using `augment_grid()`.
#                - Randomly rotates the grid with `rotate_grid()`.
#                - Randomly flips the grid with `flip_grid()`.
#            ii. Creates a new GridPair with the augmented input grid and the original output grid.
#            iii. Appends the new GridPair to the augmented pairs list.
#        d. Returns the list of augmented grid pairs.
#    - **Notes**:
#        - Increases the size of the dataset by generating multiple augmented versions.
#        - Retains the original output grid to maintain the correct training target.
#        - Essential for improving model robustness, especially with small datasets.

#    *Reason for Addition*: Augments the dataset to ensure the model sees more varied inputs.
#    *Success*: Helps prevent overfitting by exposing the model to diverse inputs.
#    *Pitfall*: Excessive augmentation might dilute meaningful patterns in the original data.

# General Comments:
# - **Importance of Data Augmentation**: 
#   - Helps prevent overfitting and improves the model’s ability to generalize.
#   - Introduces randomness, simulating different scenarios the model might encounter.
#   - Particularly crucial for datasets that are small or lack sufficient diversity.

# - **Why These Techniques Were Chosen**:
#   - **Noise Addition & Dead Squares**: Simulate missing or noisy data.
#   - **Rotation and Flipping**: Increase robustness to orientation changes.
#   - **Multiple Augmentation Sets**: Expand the dataset to avoid overfitting.

# - **Successes**:
#   - Effective in enhancing the model’s robustness and reducing overfitting.
#   - Increases data variety without the need for additional labeled samples.

# - **Pitfalls**:
#   - Excessive augmentation can produce unrealistic patterns that confuse the model.
#   - Some grids might lose essential information if augmented too aggressively.

# - **Conclusion**:
#   - Data augmentation is essential for small or limited datasets like ARC tasks.
#   - Proper balancing of augmentation techniques ensures a diverse yet meaningful dataset.


In [8]:
# -----------------------------
# 6. PyTorch Dataset Class
# -----------------------------

class AugmentedARCDataset(torch.utils.data.Dataset):
    def __init__(self, grid_pairs, augment=False):
        # Filter out pairs where input or output grid has zero dimensions
        self.grid_pairs = [
            pair for pair in grid_pairs
            if pair.input_grid.size != 0 and pair.output_grid.size != 0
            and 0 not in pair.input_grid.shape and 0 not in pair.output_grid.shape
        ]
        self.augment = augment
        logger.info(f"Dataset initialized with {len(self.grid_pairs)} valid grid pairs.")

    def __len__(self):
        return len(self.grid_pairs)

    def __getitem__(self, idx):
        # Get the GridPair object
        pair = self.grid_pairs[idx]

        # Access the input and target grids
        input_grid = pair.input_grid
        target_grid = pair.output_grid

        # Apply augmentation if enabled
        if self.augment:
            input_grid = augment_grid(input_grid)

        # Convert grids to tensors
        input_tensor = torch.tensor(input_grid, dtype=torch.float32).unsqueeze(0)  # Shape: [1, H, W]
        target_tensor = torch.tensor(target_grid, dtype=torch.long)

        # Ensure target_tensor is 2D
        if target_tensor.dim() > 2:
            target_tensor = target_tensor.squeeze()

        # Debugging statements
        logger.debug(f"Index {idx}:")
        logger.debug(f"  Input tensor shape: {input_tensor.shape}")
        logger.debug(f"  Target tensor shape: {target_tensor.shape}")

        return input_tensor, target_tensor


In [9]:
# -----------------------------
# Explanation of the PyTorch Dataset Class:
# -----------------------------
# 6. PyTorch Dataset Class
# -----------------------------

# The `AugmentedARCDataset` class inherits from `torch.utils.data.Dataset`, 
# which provides a standard interface to create custom datasets in PyTorch. 
# This dataset class plays a critical role in the training pipeline, handling 
# data loading, preprocessing, and optional data augmentation. It ensures that 
# raw data is converted into a format suitable for model training.

# 1. __init__(self, grid_pairs, augment=False):
#    - **Constructor**: Initializes the dataset with grid pairs and an optional augmentation flag.
#    - **Parameters**:
#        - **grid_pairs**: A list of `GridPair` objects containing input and output grids.
#        - **augment**: A boolean flag indicating whether data augmentation should be applied.
#    - **Actions**:
#        a. Filters out invalid grid pairs where the input or output grid has zero size or zero dimensions.
#            - **Reason**: Prevents runtime errors during training caused by malformed grids.
#            - **Implementation**: Uses a list comprehension to iterate through `grid_pairs` and apply filtering.
#        b. Sets the `augment` flag to control whether data augmentation is applied during data loading.
#        c. Logs the number of valid grid pairs that are initialized in the dataset.
#    - **Purpose**:
#        - Ensures only valid data is included in the dataset, minimizing the risk of training errors.
#        - Provides flexibility to apply data augmentation during training for better generalization.

#    *Reason for Addition*: Ensures robust handling of invalid data and offers optional data augmentation.
#    *Success*: Prevents invalid grids from propagating into the training loop.
#    *Pitfall*: If filtering is too strict, useful data may be excluded inadvertently.

# 2. __len__(self):
#    - **Purpose**: Returns the total number of items in the dataset.
#    - **Details**:
#        - This function allows `len(dataset)` to retrieve the dataset size.
#        - It’s essential for PyTorch’s DataLoader to determine the number of batches.
#    - **Usage**: Ensures compatibility with PyTorch's DataLoader for efficient batching.

#    *Reason for Addition*: Supports DataLoader functionality and enables batching.
#    *Success*: Allows smooth interaction with PyTorch utilities.
#    *Pitfall*: An incorrect length implementation could lead to index out-of-bounds errors.

# 3. __getitem__(self, idx):
#    - **Purpose**: Retrieves a single data item from the dataset at a specified index.
#    - **Parameters**:
#        - **idx**: The index of the data item to retrieve.
#    - **Actions**:
#        a. Retrieves the `GridPair` object at the given index.
#        b. Extracts the `input_grid` and `target_grid` from the `GridPair`.
#        c. Applies data augmentation to the `input_grid` if `augment` is True.
#            - **Implementation**: Uses the `augment_grid()` function to apply noise, flips, or rotations.
#        d. Converts the grids into PyTorch tensors:
#            - **input_tensor**: 
#                - Converted to a float32 tensor.
#                - Adds an extra channel dimension using `unsqueeze(0)` to match CNN input format (shape [1, H, W]).
#            - **target_tensor**: 
#                - Converted to a long tensor (used for classification tasks).
#                - Ensures it is 2D by squeezing unnecessary dimensions.
#        e. Logs debugging information about the index and tensor shapes for traceability.
#    - **Returns**: 
#        - A tuple `(input_tensor, target_tensor)` that can be used directly in the training loop.

#    *Reason for Addition*: Provides modular data retrieval with preprocessing included.
#    *Success*: Centralizes tensor conversion and augmentation logic.
#    *Pitfall*: Augmentation might introduce inconsistencies if not managed carefully.

# Additional Notes:
# - **Integration with DataLoader**: 
#   - By inheriting from `torch.utils.data.Dataset`, the class can be used with PyTorch’s DataLoader, 
#     which handles batch creation, shuffling, and parallel data loading.
# - **Augmentation Control**: 
#   - The `augment` parameter provides flexibility to use augmentation only during training, 
#     avoiding unnecessary transformations during evaluation.
# - **Debugging Support**:
#   - Includes logging to provide insights into the data processing workflow, 
#     helping trace potential issues with grid shapes or tensor conversions.
# - **Tensor Conversion Details**:
#   - Ensuring the input tensor has a channel dimension is crucial for feeding data into convolutional neural networks (CNNs).
#   - Guaranteeing that the target tensor is 2D avoids issues during loss computation and predictions.

# **Why This Class is Important**:
# - **Seamless Data Handling**: 
#   - Connects raw data with the model training process, ensuring inputs are correctly preprocessed.
# - **Efficiency**:
#   - Works with PyTorch’s DataLoader for batch processing, which is essential for large datasets.
# - **Flexibility**:
#   - Allows for easy switching between training with or without augmentation by toggling a single flag.

# **Successes**:
# - Handles invalid data gracefully by filtering out problematic grid pairs.
# - Provides a modular way to apply data augmentation, simplifying the main training loop.

# **Pitfalls**:
# - Strict filtering criteria may exclude useful data if not carefully tuned.
# - Incorrect tensor shapes could cause runtime errors during training if not properly managed.

# **Conclusion**:
# This dataset class is a crucial component in the PyTorch training pipeline. 
# It bridges the gap between raw data and model training, handling preprocessing, 
# tensor conversion, and optional augmentation within a single, modular structure. 
# The design ensures that only valid data reaches the training loop, while providing 
# the flexibility to enhance the dataset with augmentation techniques as needed.


In [10]:
# -----------------------------
# 7. Define the Deep Neural Network Model
# -----------------------------

class CNNGridMapper(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(CNNGridMapper, self).__init__()
        self.num_classes = num_classes

        # Use a CNN backbone (e.g., ResNet18)
        self.cnn = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Modify the first convolutional layer for single-channel input
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        nn.init.kaiming_normal_(self.cnn.conv1.weight, mode='fan_out', nonlinearity='relu')

        # Remove the fully connected layer
        self.cnn_layers = nn.Sequential(*list(self.cnn.children())[:-2])

        # Upsampling layers to recover spatial dimensions
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, num_classes, kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        x = self.upsample(x)
        return x  # Output shape: (batch_size, num_classes, H', W')


In [11]:
# -----------------------------
# Explanation of the Deep Neural Network Model:
# -----------------------------
# 7. Define the Deep Neural Network Model
# -----------------------------

# The `CNNGridMapper` class defines a convolutional neural network (CNN) that maps input grids to output grids.
# It adapts the well-known ResNet-18 architecture to handle single-channel inputs and produces spatial outputs 
# suitable for grid-based tasks, such as semantic segmentation.

# 1. **Class Definition**:
#    - Inherits from `nn.Module`, the base class for all neural network models in PyTorch.
#    - The name `CNNGridMapper` reflects its purpose of mapping grids using a convolutional neural network.

# 2. **__init__(self, num_classes=NUM_CLASSES)**:
#    - **Constructor**: Initializes the network’s layers.
#    - **Parameters**:
#        - `num_classes`: Specifies the number of output classes, corresponding to the possible values in the grid 
#          (e.g., 11 classes for values ranging from 0 to 10).
#    - **Purpose**: Sets up the model architecture and prepares it for training and inference.

#    *Reason for Addition*: Provides a customizable number of output classes to match the target task.

# 3. **Using ResNet-18 Backbone**:
#    - Loads the ResNet-18 model using:
#        - `self.cnn = resnet18(weights=ResNet18_Weights.DEFAULT)`.
#    - **Purpose**: ResNet-18 is a well-established architecture with residual connections, making it efficient and 
#      effective for feature extraction.

#    *Success*: Benefiting from transfer learning by using pretrained weights.
#    *Pitfall*: Pretrained models expect specific input formats, requiring customization for single-channel inputs.

# 4. **Modifying the First Convolutional Layer**:
#    - ResNet-18 expects 3-channel RGB images, but our task uses single-channel grid data.
#    - We replace the first convolutional layer with:
#        - `self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)`.
#    - **Weights Initialization**: Uses Kaiming Normal initialization to match the ReLU activation function.

#    *Reason for Addition*: Adapts the architecture to process single-channel grids.
#    *Pitfall*: If weights are not properly initialized, the model may perform poorly.

# 5. **Removing the Fully Connected Layer**:
#    - Removes the final layers (average pooling and fully connected layers) to convert the model into a fully 
#      convolutional network (FCN):
#        - `self.cnn_layers = nn.Sequential(*list(self.cnn.children())[:-2])`.
#    - **Purpose**: Produces spatial feature maps instead of scalar outputs, which is necessary for per-grid classification.

#    *Reason for Addition*: Adapts the network for spatial tasks rather than classification.
#    *Pitfall*: Removing layers improperly could break the model.

# 6. **Upsampling Layers**:
#    - The CNN backbone reduces spatial dimensions through pooling and striding.
#    - We use transposed convolutional layers (deconvolutions) to increase spatial dimensions:
#        - `self.upsample = nn.Sequential(...)`.
#    - Each transposed convolution layer doubles the spatial dimensions and applies ReLU activation for non-linearity.

#    *Reason for Addition*: Restores spatial dimensions to match the output grid size.
#    *Pitfall*: If upsampling is not properly aligned, it may introduce artifacts in the output.

# 7. **Upsampling Layer Details**:
#    - **First Transposed Convolution**:
#        - Input channels: 512 (from the last ResNet-18 layer).
#        - Output channels: 256.
#    - **Second Transposed Convolution**:
#        - Input channels: 256.
#        - Output channels: 128.
#    - **Third Transposed Convolution**:
#        - Input channels: 128.
#        - Output channels: `num_classes` (final number of output classes).
#    - **Purpose**: These layers progressively increase spatial dimensions to match the target grid size.

#    *Success*: Effectively reverses downsampling, restoring spatial dimensions.
#    *Pitfall*: Misaligned upsampling can reduce accuracy by distorting feature maps.

# 8. **forward(self, x)**:
#    - **Defines the forward pass** of the model.
#    - **Input**: 
#        - `x`: A tensor of shape (batch_size, 1, H, W), where 1 is the channel dimension for single-channel grids.
#    - **Steps**:
#        - Passes the input through the CNN backbone:
#            - `x = self.cnn_layers(x)`.
#            - Results in feature maps with reduced spatial dimensions.
#        - Passes the feature maps through the upsampling layers:
#            - `x = self.upsample(x)`.
#            - Restores the spatial dimensions to the desired size.
#    - **Output**: 
#        - Returns a tensor of shape (batch_size, num_classes, H', W'), where H' and W' depend on the input size.

#    *Reason for Addition*: Implements the core logic for data flow through the network.
#    *Success*: Ensures the input is processed correctly and spatial dimensions are restored.
#    *Pitfall*: Incorrect shapes could cause runtime errors during training.

# 9. **Output Interpretation**:
#    - **Purpose**: The output tensor provides class scores (logits) for each grid position.
#    - Suitable for tasks like **semantic segmentation**, where each cell in the grid is classified independently.
#    - **Training**: 
#        - Typically, `CrossEntropyLoss` is used as the loss function, which expects raw scores (logits) as input.

#    *Reason for Addition*: Provides the correct output format for per-cell classification tasks.

# 10. **Notes on Model Design**:
#     - **Transfer Learning**: 
#       - Leveraging a pretrained ResNet-18 improves performance by utilizing features learned from large datasets.
#     - **Fully Convolutional Network (FCN)**:
#       - Removing the fully connected layers allows the model to produce spatial outputs.
#     - **Upsampling via Transposed Convolutions**:
#       - Enables the model to learn effective ways to restore spatial dimensions.

# Summary:
# - The `CNNGridMapper` class adapts the ResNet-18 architecture to process single-channel input grids and generate 
#   spatial outputs with multiple classes. 
# - This design is tailored for grid-to-grid mapping tasks, where both the input and output are grids (matrices) of values.
# - The combination of **deep feature extraction** (via the CNN backbone) and **upsampling** ensures that the model can 
#   capture high-level patterns and produce detailed spatial outputs.
# - **Successes**:
#   - Efficient feature extraction using ResNet-18.
#   - Flexible handling of single-channel inputs and spatial outputs.
# - **Pitfalls**:
#   - Incorrect configuration of upsampling layers could degrade performance.
#   - Customizing the ResNet-18 layers requires careful weight initialization and shape alignment.



In [12]:
# -----------------------------
# 8. Custom Collate Function
# -----------------------------

def collate_fn(batch):
    inputs = [item[0] for item in batch]  # Shape: [C, H, W]
    targets = [item[1] for item in batch]  # Shape: [H, W]

    # Find max dimensions in the batch for inputs and targets separately
    max_input_height = max(t.size(-2) for t in inputs)
    max_input_width = max(t.size(-1) for t in inputs)
    max_target_height = max(t.size(-2) for t in targets)
    max_target_width = max(t.size(-1) for t in targets)

    batch_size = len(inputs)
    num_channels = inputs[0].size(0)

    # Initialize tensors with zeros
    batch_inputs = torch.zeros((batch_size, num_channels, max_input_height, max_input_width), dtype=inputs[0].dtype)
    batch_targets = torch.zeros((batch_size, max_target_height, max_target_width), dtype=targets[0].dtype)

    for i in range(batch_size):
        input_tensor = inputs[i]
        target_tensor = targets[i]

        # Get shapes
        c, h_inp, w_inp = input_tensor.size()
        h_tar, w_tar = target_tensor.size()

        # Copy input_tensor into batch_inputs
        batch_inputs[i, :, :h_inp, :w_inp] = input_tensor

        # Copy target_tensor into batch_targets
        batch_targets[i, :h_tar, :w_tar] = target_tensor

        # Debugging statements
        logger.debug(f"Batch index {i}:")
        logger.debug(f"  Input tensor shape: {input_tensor.shape}")
        logger.debug(f"  Target tensor shape: {target_tensor.shape}")
        logger.debug(f"  Batch input shape: {batch_inputs[i].shape}")
        logger.debug(f"  Batch target shape: {batch_targets[i].shape}")

    return batch_inputs, batch_targets


In [13]:
# -----------------------------
# 8. Custom Collate Function
# -----------------------------

# def collate_fn(batch):
#     """
#     This custom `collate_fn` is designed to handle batches where the input and target grids may have varying sizes.
#     In tasks like the Abstraction and Reasoning Corpus (ARC), grids can have different dimensions, so a standard
#     collate function (which assumes all inputs are the same size) would not work.
#
#     Key Steps and Purpose:
#
#     1. Extract Inputs and Targets:
#         - The function receives a batch, which is a list of tuples where each tuple contains an input tensor and a target tensor.
#         - It separates the inputs and targets into two lists for processing.
#
#     2. Determine Maximum Dimensions:
#         - It calculates the maximum height and width among all input tensors (`inputs`) and target tensors (`targets`) in the batch.
#         - This step is crucial for creating tensors that can hold all the samples, considering the largest dimensions.
#
#     3. Initialize Batched Tensors:
#         - Creates two zero-filled tensors (`batch_inputs` and `batch_targets`) with shapes:
#             - `batch_inputs`: `[batch_size, num_channels, max_input_height, max_input_width]`
#             - `batch_targets`: `[batch_size, max_target_height, max_target_width]`
#         - These tensors will hold all input and target tensors, padded where necessary.
#
#     4. Populate Batched Tensors:
#         - Iterates over each sample in the batch.
#         - For each sample:
#             - Retrieves the input and target tensors.
#             - Gets their actual dimensions.
#             - Copies the input tensor into the corresponding slice of `batch_inputs`.
#             - Copies the target tensor into the corresponding slice of `batch_targets`.
#         - Since the batched tensors may be larger than the individual tensors, the extra regions remain zero (effectively padding).
#
#     5. Debugging Statements:
#         - Logs detailed shape information for each sample, which is helpful for debugging issues related to tensor dimensions.
#
#     6. Return Batched Tensors:
#         - Returns the `batch_inputs` and `batch_targets` tensors, which can now be used in the training loop.
#
#     Why This Function is Necessary:
#
#     - **Variable-Sized Inputs**:
#         - In many datasets, especially with images or grids, not all samples are of the same size.
#         - The standard DataLoader expects all samples in a batch to have the same dimensions, which isn't the case here.
#
#     - **Padding to Maximum Size**:
#         - By padding all tensors to the maximum size in the batch, we can batch them together.
#         - This approach avoids the need to resize or distort the data, preserving the original information.
#
#     - **Efficiency**:
#         - Handling variable-sized data efficiently without writing custom batch handling logic in the training loop.
#
#     Considerations:
#
#     - **Memory Usage**:
#         - Padding to the maximum size can lead to increased memory usage, especially if there's a large discrepancy between the smallest and largest samples.
#         - This can be mitigated by grouping similar-sized samples together (bucketing) or setting a maximum allowable size.
#
#     - **Model Adaptation**:
#         - The model must be able to handle inputs of varying sizes.
#         - In this code, the model uses convolutional layers and upsampling, which can work with variable input sizes.
#
#     - **Loss Function Compatibility**:
#         - The loss function and any metric calculations need to account for the padded regions, if necessary.
#         - In this implementation, the padding is with zeros, which may correspond to a valid class (e.g., background), so care must be taken.
#
#     Conclusion:
#
#     The custom `collate_fn` is an essential component when working with datasets containing variable-sized samples.
#     It ensures that data can be batched and fed into the model without losing the integrity of the original samples.
#     This function enhances the flexibility and robustness of the data loading pipeline.
#     """


In [14]:
# -----------------------------
# 9. Training GUI Class
# -----------------------------

class TrainingGUI:
    """
    A Tkinter-based GUI for real-time training progress visualization with 3D metrics plotting and data tree integration.
    """

    def __init__(self, root, total_epochs, total_batches, model, train_loader, 
                 val_loader, eval_loader, device, data_tree, task_dict):
        """Initialize the Training GUI."""
        self.root = root
        self.total_epochs = total_epochs
        self.total_batches = total_batches
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.eval_loader = eval_loader
        self.device = device
        self.data_tree = data_tree  # Data tree integration
        self.task_dict = task_dict  # Store task dictionary for training logic

        # Initialize other required attributes
        self.queue = queue.Queue()
        self.stop_event = threading.Event()

        # Initialize data storage for plots
        self.loss_data = []
        self.val_loss_data = []
        self.acc_data = []
        self.prediction_distances = []

        # Set up the GUI
        self.setup_gui()
        self.root.after(100, self.process_queue)

    def setup_gui(self):
        """Set up the GUI components."""
        self.frame = tk.Frame(self.root)
        self.frame.pack(fill=tk.BOTH, expand=True)

        # Top Section for Labels
        self.label_frame = tk.Frame(self.frame)
        self.label_frame.pack(pady=10)

        self.epoch_label = tk.Label(self.label_frame, text=f"Epoch: 0/{self.total_epochs}", font=("Helvetica", 14))
        self.epoch_label.grid(row=0, column=0, padx=10)

        self.batch_label = tk.Label(self.label_frame, text=f"Batch: 0/{self.total_batches}", font=("Helvetica", 12))
        self.batch_label.grid(row=0, column=1, padx=10)

        self.loss_label = tk.Label(self.label_frame, text="Loss: 0.0000", font=("Helvetica", 12))
        self.loss_label.grid(row=0, column=2, padx=10)

        self.accuracy_label = tk.Label(self.label_frame, text="Accuracy: 0.0000", font=("Helvetica", 12))
        self.accuracy_label.grid(row=0, column=3, padx=10)

        # Data Tree Visualization Section
        self.tree_frame = tk.Frame(self.frame, width=300, height=400)
        self.tree_frame.pack(side=tk.LEFT, padx=10, pady=10, fill=tk.Y)

        self.tree_label = tk.Label(self.tree_frame, text="Data Tree", font=("Helvetica", 14))
        self.tree_label.pack()

        self.tree_canvas = tk.Canvas(self.tree_frame, width=300, height=400, bg='white')
        self.tree_canvas.pack()

        # Display the data tree
        self.display_data_tree()

        # Plot Section (2D + 3D)
        self.fig = plt.figure(figsize=(12, 6))

        # 3D Plot on the Left
        self.ax_3d = self.fig.add_subplot(121, projection='3d')
        self.ax_3d.set_xlabel('Epoch')
        self.ax_3d.set_ylabel('Accuracy')
        self.ax_3d.set_zlabel('Distance from Actual')

        # 2D Plot on the Right
        self.ax_2d = self.fig.add_subplot(122)
        self.line_loss, = self.ax_2d.plot([], [], label='Training Loss')
        self.line_val_loss, = self.ax_2d.plot([], [], label='Validation Loss')
        self.ax_2d.legend()

        self.canvas_plot = FigureCanvasTkAgg(self.fig, master=self.frame)
        self.canvas_plot.draw()
        self.canvas_plot.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Bottom Section for Control Buttons
        self.button_frame = tk.Frame(self.frame)
        self.button_frame.pack(pady=10)

        self.start_button = tk.Button(self.button_frame, text="Start Training", command=self.start_training)
        self.start_button.grid(row=0, column=0, padx=10)

        self.stop_button = tk.Button(self.button_frame, text="Stop Training", command=self.stop_training)
        self.stop_button.grid(row=0, column=1, padx=10)

        self.evaluate_button = tk.Button(self.button_frame, text="Evaluate Model", command=self.evaluate_model_button)
        self.evaluate_button.grid(row=0, column=2, padx=10)

    def display_data_tree(self):
        """Generate and display the data tree as an image."""
        try:
            # Render the tree to a PNG using anytree and Graphviz
            dot_file = "tree.dot"
            png_file = "tree.png"

            # Export to .dot file
            DotExporter(self.data_tree).to_dotfile(dot_file)
            logger.info(f"Tree exported to {dot_file}")

            # Convert .dot to .png using Graphviz
            result = os.system(f'dot -Tpng {dot_file} -o {png_file}')
            if result != 0:
                raise RuntimeError("Failed to generate PNG. Ensure Graphviz is installed and in PATH.")

            # Load and display the PNG image
            img = Image.open(png_file)
            img = img.resize((300, 400), Image.LANCZOS)
            img_tk = ImageTk.PhotoImage(img)

            # Display the image in the canvas
            self.tree_canvas.create_image(0, 0, anchor=tk.NW, image=img_tk)
            self.tree_canvas.image = img_tk  # Keep reference to avoid garbage collection
            logger.info("Tree visualization displayed successfully.")

        except Exception as e:
            logger.exception("Failed to display the data tree.")
            tk.messagebox.showerror("Tree Display Error", f"Error: {e}")

    def process_queue(self):
        """Process the queue for thread-safe GUI updates."""
        while not self.queue.empty():
            message = self.queue.get()
            if isinstance(message, dict):
                self.update_gui(message)
        self.root.after(100, self.process_queue)

    def update_gui(self, data):
        """Update the GUI with real-time training and validation metrics."""
        try:
            if 'batch' in data:
                # Update batch-level metrics in the GUI
                self.batch_label.config(text=f"Batch: {data['batch']}/{self.total_batches}")
                self.loss_label.config(text=f"Loss: {data['loss']:.4f}")
                self.accuracy_label.config(text=f"Accuracy: {data.get('accuracy', 0.0):.4f}")

                # Append new batch data to the 2D plot lists
                self.loss_data.append(data['loss'])
                self.acc_data.append(data.get('accuracy', 0.0))

                # Update the 2D plot with each batch completion
                batches = list(range(1, len(self.loss_data) + 1))
                self.line_loss.set_data(batches, self.loss_data)
                self.line_val_loss.set_data(batches, self.acc_data)

                # Adjust the axes to fit the new data
                self.ax_2d.relim()
                self.ax_2d.autoscale_view()

                # Redraw the 2D plot with new data
                self.canvas_plot.draw()

            elif 'epoch' in data:
                # Update epoch-level metrics
                self.epoch_label.config(text=f"Epoch: {data['epoch']}/{self.total_epochs}")
                # Handle validation loss and accuracy if provided
                val_loss = data.get('val_loss', 0.0)
                val_accuracy = data.get('val_accuracy', 0.0)
                # Update labels if you have labels for validation metrics

                # Calculate prediction error distance
                predicted = np.array(data.get('predicted', []))
                actual = np.array(data.get('actual', []))

                if predicted.size == 0 or actual.size == 0:
                    distance = float('nan')  # Handle empty arrays gracefully
                elif predicted.shape != actual.shape:
                    distance = float('nan')  # Handle shape mismatch
                else:
                    distance = np.abs(predicted - actual).mean()

                # Replace NaN with 0.0 for plotting purposes
                distance = 0.0 if np.isnan(distance) else distance

                # Store valid distances for 3D plot
                self.prediction_distances.append(distance)

                # Update the 3D plot
                epochs = list(range(1, len(self.prediction_distances) + 1))
                self.ax_3d.clear()
                self.ax_3d.set_xlabel('Epoch')
                self.ax_3d.set_ylabel('Accuracy')
                self.ax_3d.set_zlabel('Distance from Actual')
                self.ax_3d.set_title('3D Prediction Error vs Accuracy')

                # Scatter plot with prediction distances
                self.ax_3d.scatter(epochs, self.acc_data, self.prediction_distances, label='Error vs Accuracy', color='green')
                self.ax_3d.legend()

                # Redraw the 3D plot
                self.canvas_plot.draw()

        except Exception as e:
            logger.exception("An error occurred while updating the GUI.")
            messagebox.showerror("Error", f"An error occurred: {e}")

    def start_training(self):
        """Start training in a new thread."""
        self.stop_event.clear()
        threading.Thread(target=self.train_thread, daemon=True).start()

    def stop_training(self):
        """Stop the training process."""
        self.stop_event.set()

    def train_thread(self):
        """Training logic executed in a separate thread to avoid blocking the GUI."""
        logger.info("Training thread started.")

        # Optimizer, scheduler, and criterion setup
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.01, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        criterion = nn.CrossEntropyLoss()

        # Mixed precision scaler (if using CUDA)
        scaler = torch.cuda.amp.GradScaler() if self.device.type == 'cuda' else None

        # Loop over epochs
        for epoch in range(self.total_epochs):
            if self.stop_event.is_set():
                logger.info("Training stopped by user.")
                break
            logger.info(f"Starting epoch {epoch + 1}/{self.total_epochs}.")
            self.model.train()  # Set the model in training mode
            running_loss = 0.0
            correct = 0
            total = 0

            # Loop over batches
            for batch_idx, (inputs, targets) in enumerate(self.train_loader, 1):
                if self.stop_event.is_set():
                    logger.info("Training stopped by user.")
                    break
                try:
                    # Train the batch and gather metrics
                    batch_loss, batch_accuracy, batch_size = self.train_batch(
                        batch_idx, inputs, targets, optimizer, scaler, criterion
                    )

                    running_loss += batch_loss * batch_size
                    correct += int(batch_accuracy * batch_size / 100)
                    total += targets.numel()

                    # Update the GUI every 10 batches or at the end of epoch
                    if batch_idx % 10 == 0 or batch_idx == len(self.train_loader):
                        gui_batch_loss = running_loss / total
                        gui_batch_accuracy = 100.0 * correct / total
                        self.queue.put({
                            'batch': batch_idx,
                            'loss': gui_batch_loss,
                            'accuracy': gui_batch_accuracy
                        })

                except Exception as e:
                    logger.exception(f"Error in batch {batch_idx}: {e}")
                    continue  # Continue with the next batch if an error occurs

            # Epoch metrics
            epoch_loss = running_loss / total
            epoch_accuracy = 100.0 * correct / total

            # Send epoch updates to the GUI
            self.queue.put({
                'epoch': epoch + 1,
                'loss': epoch_loss,
                'accuracy': epoch_accuracy
            })

            # Scheduler step
            scheduler.step()

        logger.info("Training completed.")
        self.queue.put({'status': 'Training Completed'})

    def train_batch(self, batch_idx, inputs, targets, optimizer, scaler, criterion):
        """Train a single batch."""
        # Move data to the appropriate device
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        # Reset gradients
        optimizer.zero_grad()

        if self.device.type == 'cuda' and scaler is not None:
            # Mixed precision training with autocast
            with torch.autocast(device_type=self.device.type, enabled=True):
                outputs = self.model(inputs)
                # Resize outputs to match targets
                outputs = F.interpolate(outputs, size=targets.shape[1:], mode='bilinear', align_corners=False)
                loss = criterion(outputs, targets)

            # Backward pass and optimizer step with scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Standard training without autocast
            outputs = self.model(inputs)
            # Resize outputs to match targets
            outputs = F.interpolate(outputs, size=targets.shape[1:], mode='bilinear', align_corners=False)
            loss = criterion(outputs, targets)

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()

        # Compute metrics
        batch_loss = loss.item()
        _, predicted = outputs.max(1)
        correct_predictions = predicted.eq(targets).sum().item()
        batch_accuracy = 100.0 * correct_predictions / targets.numel()

        return batch_loss, batch_accuracy, targets.numel()

    def evaluate_model_button(self):
        """Evaluate the model in a new thread."""
        threading.Thread(target=self.evaluate_model, daemon=True).start()

    def evaluate_model(self):
        """Evaluate the model."""
        avg_loss, accuracy = evaluate_model(self.model, self.val_loader, self.device)
        messagebox.showinfo("Evaluation", f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")


In [15]:
# -----------------------------
# 9. Explanation of the TrainingGUI Class
# -----------------------------

# The `TrainingGUI` class is responsible for creating a graphical user interface (GUI) to visualize the 
# training process of a neural network model in real-time. It uses Tkinter for GUI elements and Matplotlib 
# for plotting training metrics. This class enhances user interaction and provides real-time insights into the 
# model’s performance.

# 1. **Class Initialization (__init__ method)**:
#    - **Purpose**: Initializes the GUI and prepares all necessary components for tracking and controlling training.
#    - **Parameters**:
#        - `root`: The root window of the Tkinter GUI.
#        - `total_epochs`, `total_batches`: Total number of epochs and batches for progress tracking.
#        - `model`: The neural network model being trained.
#        - `train_loader`, `val_loader`, `eval_loader`: Data loaders for training, validation, and evaluation.
#        - `device`: The device (CPU/GPU) used for training.
#        - `data_tree`: A hierarchical tree structure representing the dataset, used for visualization.
#        - `task_dict`: Dictionary containing task-related data to be referenced during training.
#    - **Actions**:
#        - Initializes attributes for managing threads and data for plotting.
#        - Calls `setup_gui()` to build the interface layout.
#        - Uses `root.after()` to schedule periodic updates from the queue.
#    
#    **Success**: Separates the GUI setup from the main training logic, ensuring modularity and easier maintenance.
#    **Pitfall**: If the periodic update interval is not well-tuned, the GUI may lag or feel unresponsive during large datasets.

# 2. **setup_gui method**:
#    - **Purpose**: Constructs the layout and components of the GUI using Tkinter.
#    - **Sections**:
#        - **Top Section**: Displays metrics like the current epoch, batch number, loss, and accuracy.
#        - **Data Tree Visualization**:
#            - Creates a canvas to display the dataset’s hierarchical tree.
#            - Uses `display_data_tree()` to render the tree as an image.
#        - **Plot Section**:
#            - Initializes a **3D plot** for prediction error versus accuracy and a **2D plot** for loss metrics.
#            - Embeds these plots into the GUI using `FigureCanvasTkAgg`.
#        - **Bottom Section**: Adds buttons for starting, stopping, and evaluating the model.
#    
#    **Success**: Provides a clear, structured interface with real-time feedback on metrics.
#    **Pitfall**: Overpopulating the GUI with too many metrics can overwhelm users, reducing usability.

# 3. **display_data_tree method**:
#    - **Purpose**: Visualizes the dataset structure as a tree diagram.
#    - **Implementation**:
#        - Uses `anytree` and `Graphviz` to export the tree structure to a DOT file and convert it to a PNG image.
#        - Displays the PNG within the GUI canvas.
#        - Catches and handles exceptions gracefully if the tree cannot be rendered.
#    
#    **Success**: Offers users a visual understanding of the dataset structure.
#    **Pitfall**: External dependencies (e.g., Graphviz) can introduce errors if not properly installed.

# 4. **process_queue method**:
#    - **Purpose**: Periodically processes messages from a thread-safe queue to update the GUI.
#    - **Implementation**:
#        - Uses `root.after()` to ensure regular, non-blocking updates.
#        - Invokes `update_gui()` with data from the queue.
#    
#    **Success**: Keeps the GUI responsive during long training sessions by offloading updates to a separate thread.
#    **Pitfall**: If too many messages accumulate in the queue, GUI updates might lag behind.

# 5. **update_gui method**:
#    - **Purpose**: Updates the GUI components based on received data.
#    - **Batch-Level Updates**:
#        - Updates metrics like batch number, loss, and accuracy in real-time.
#        - Appends data to lists for dynamic plotting.
#        - Refreshes the 2D loss plot.
#    - **Epoch-Level Updates**:
#        - Updates the current epoch label.
#        - Computes prediction error for the 3D plot.
#        - Adds new points to the 3D plot.
#    
#    **Success**: Provides detailed insights into the model’s performance during both batch and epoch levels.
#    **Pitfall**: Frequent GUI updates may slow down training, especially with complex plots.

# 6. **start_training method**:
#    - **Purpose**: Starts the training process in a new thread to avoid blocking the GUI.
#    - **Implementation**: Clears existing stop events and initiates a new training session.
#    
#    **Success**: Keeps the GUI interactive by running training in a separate thread.
#    **Pitfall**: Improper thread handling can cause crashes or deadlocks if not managed carefully.

# 7. **stop_training method**:
#    - **Purpose**: Signals the training thread to halt gracefully using a stop event.
#    
#    **Success**: Allows users to interrupt training safely without crashing the application.
#    **Pitfall**: If not handled properly, stopping the thread could leave the model in an inconsistent state.

# 8. **train_thread method**:
#    - **Purpose**: Manages the main training loop in a separate thread.
#    - **Implementation**:
#        - Checks for stop events to allow user interruption.
#        - Uses `train_batch()` for training on individual batches.
#        - Aggregates metrics and updates the GUI via the queue.
#        - Steps the learning rate scheduler after each epoch.
#        - Sends a completion message to the GUI after finishing training.
#    
#    **Success**: Ensures smooth training without blocking the GUI thread.
#    **Pitfall**: If the stop event is not properly monitored, training could run indefinitely.

# 9. **train_batch method**:
#    - **Purpose**: Handles training on a single batch of data.
#    - **Steps**:
#        - Moves data to the appropriate device (CPU/GPU).
#        - Resets gradients, performs forward and backward passes, and updates the optimizer.
#        - Supports mixed precision training if CUDA is available.
#        - Returns batch metrics (loss, accuracy) for aggregation.
#    
#    **Success**: Efficiently trains on batches while utilizing GPU resources when available.
#    **Pitfall**: Incorrect data handling (e.g., forgetting to reset gradients) could affect model performance.

# 10. **evaluate_model_button and evaluate_model methods**:
#     - **Purpose**: Evaluates the model on the validation dataset.
#     - **Implementation**:
#         - Runs the evaluation in a separate thread to maintain GUI responsiveness.
#         - Displays results in a message box once evaluation is complete.
#    
#    **Success**: Provides a seamless way to assess the model’s performance without interrupting the GUI.
#    **Pitfall**: Running evaluations too frequently could degrade performance or cause GUI lags.

# **General Notes**:
# - **Threading**: The training process runs in a separate thread, keeping the GUI responsive.
# - **Queue Communication**: A thread-safe queue ensures safe communication between the training thread and the GUI.
# - **Dynamic Plotting**: Matplotlib plots offer real-time feedback on training metrics.
# - **Error Handling**: Proper error handling prevents crashes and keeps the application stable.
# - **Interactive Controls**: Buttons allow users to start, stop, and evaluate the model interactively.
# - **Data Tree Visualization**: The hierarchical dataset structure provides useful insights into the training data.

# **Conclusion**:
# The `TrainingGUI` class serves as a powerful tool for monitoring and controlling the training process. It combines 
# real-time metrics visualization, user interaction, and multi-threaded execution to create a smooth and informative 
# experience for model development and analysis.


In [16]:
# -----------------------------
# 10. Training Function with GUI Integration
# -----------------------------

def train_model_with_gui(model, train_loader, val_loader, device, gui):
    """Train the model and update the GUI in real-time."""
    try:
        # Start the GUI training display
        gui.start_training()

    except Exception as e:
        logger.exception(f"Training failed: {e}")
        gui.queue.put({'error': str(e)})  # Inform the GUI about the error


In [17]:
# -----------------------------
# 10. Explanation of the Training Function with GUI Integration
# -----------------------------

# The `train_model_with_gui` function acts as a bridge between the training logic and the GUI.
# It starts the training process while ensuring the GUI remains responsive with real-time updates.

# **Function Definition**:
# def train_model_with_gui(model, train_loader, val_loader, device, gui):

# **Parameters**:
# - `model`: The neural network model to be trained.
# - `train_loader`: DataLoader for the training dataset.
# - `val_loader`: DataLoader for the validation dataset.
# - `device`: The computing device (CPU or GPU) to perform training.
# - `gui`: An instance of the `TrainingGUI` class that manages the GUI.

# **Purpose**:
# - Initiates the training process while keeping the GUI responsive.
# - Encapsulates the logic for starting the training and handling any exceptions.

# **Key Steps**:

# 1. **Start Training**:
#    - Calls `gui.start_training()` to begin the training loop in a new thread.
#    - **Reason**: 
#        - Running training on a separate thread prevents the GUI from freezing.
#        - This allows for **real-time updates** and **user interaction** during training.

#    **Success**: Keeps the GUI responsive, enhancing the user experience.
#    **Pitfall**: Incorrect thread management could cause race conditions or crashes.

# 2. **Exception Handling**:
#    - Wraps the training initiation in a `try-except` block to catch and handle any exceptions.
#    - **Actions**:
#        - Uses `logger.exception()` to log the error with a full stack trace.
#        - Sends the error message to the GUI’s **message queue** using:
#            - `gui.queue.put({'error': str(e)})`
#    - **Reason**:
#        - This approach ensures the GUI displays the error message without directly calling GUI methods from the training thread, which could violate thread safety.

#    **Success**: Ensures smooth error reporting without crashing the application.
#    **Pitfall**: If exceptions are not properly propagated, the user may be unaware of underlying issues.

# **Notes**:
# - The function itself **does not handle the core training logic**; it simply initiates training through the `TrainingGUI` instance.
# - **Separation of Concerns**: By delegating the training logic to the GUI, this function keeps the main workflow clean and organized.
# - **Best Practices**: Adheres to best practices for GUI applications by ensuring that long-running tasks (like training) do not block the main thread.

# **Conclusion**:
# The `train_model_with_gui` function is crucial for integrating the model training process with the GUI. 
# It ensures:
# - The **training process begins smoothly** in a non-blocking way.
# - **Exceptions are properly handled** and reported through the GUI’s message queue.
# - The GUI remains **responsive and interactive** throughout the training process, providing a better user experience.


In [18]:
# -----------------------------
# 11. Evaluation Function
# -----------------------------

def evaluate_model(model, test_loader, device='cpu'):
    """
    Evaluates the model on the test dataset.

    Args:
        model (nn.Module): Trained model.
        test_loader (DataLoader): DataLoader for the test dataset.
        device (str): Device to run evaluation on.

    Returns:
        tuple: (average_loss, accuracy)
    """
    criterion = nn.CrossEntropyLoss()  # Loss function
    model.to(device)  # Move model to the appropriate device
    model.eval()  # Set the model to evaluation mode

    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            # Resize outputs to match targets
            outputs = F.interpolate(outputs, size=targets.shape[1:], mode='bilinear', align_corners=False)

            # Calculate loss
            loss = criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)  # Accumulate weighted loss

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += targets.numel()

    # Calculate average loss and accuracy
    avg_loss = total_loss / len(test_loader.dataset)
    accuracy = correct / total

    logger.info(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy


In [19]:
# -----------------------------
# 11. Explanation of the Evaluation Function
# -----------------------------

# The `evaluate_model` function assesses the performance of a trained neural network on a test dataset.
# It calculates the **average loss** and **overall accuracy** across the test set, providing essential metrics 
# for evaluating model generalization on unseen data.

# **Function Definition**:
# def evaluate_model(model, test_loader, device):

# **Parameters**:
# - `model` (nn.Module): The trained PyTorch model to be evaluated.
# - `test_loader` (DataLoader): DataLoader providing an iterable over the test dataset.
# - `device` (str): The device ('cpu' or 'cuda') on which to perform the evaluation.

# **Steps**:

# 1. **Initialize the Loss Function**:
#    - `criterion = nn.CrossEntropyLoss()` initializes the loss function for multi-class classification.
#    
#    **Success**: CrossEntropyLoss is appropriate for multi-class tasks.
#    **Pitfall**: Using the wrong loss function could result in incorrect evaluations.

# 2. **Prepare the Model**:
#    - Moves the model to the specified device: `model.to(device)`.
#    - Sets the model to evaluation mode: `model.eval()`.
#        - This ensures layers like **dropout** and **batch normalization** behave correctly during evaluation.
#    
#    **Success**: Ensures proper inference behavior by disabling training-specific behaviors.
#    **Pitfall**: Forgetting to switch to `eval()` mode could yield misleading results.

# 3. **Initialize Metrics**:
#    - `total_loss`: Accumulates the total loss across all batches.
#    - `correct`: Counts the number of correctly predicted samples.
#    - `total`: Tracks the total number of elements evaluated.

#    **Success**: Provides reliable tracking of performance across batches.
#    **Pitfall**: Incorrect metric initialization could skew results.

# 4. **Disable Gradient Computation**:
#    - Wrapping the evaluation in `with torch.no_grad():` saves memory and improves performance by disabling gradients.
#    
#    **Success**: Reduces unnecessary computation, speeding up the evaluation.
#    **Pitfall**: Forgetting to disable gradients can cause memory leaks.

# 5. **Iterate Over Test Data**:
#    - Loops through each batch provided by `test_loader`.

# 6. **Move Data to Device**:
#    - Moves input and target tensors to the specified device:
#        - `inputs.to(device, non_blocking=True)` enables asynchronous GPU transfers when possible.
#    
#    **Success**: Ensures data is processed on the correct device for efficient computation.
#    **Pitfall**: Mismatched devices can result in runtime errors.

# 7. **Forward Pass**:
#    - `outputs = model(inputs)` computes the raw output (logits) of the model.

# 8. **Resize Outputs (if necessary)**:
#    - `F.interpolate(outputs, size=targets.shape[1:], mode='bilinear', align_corners=False)` adjusts the output size to match the target tensor.
#        - This is particularly useful for tasks like **semantic segmentation**, where output dimensions may differ.
#    
#    **Success**: Ensures output matches target dimensions for correct loss calculation.
#    **Pitfall**: Incorrect resizing can lead to dimension mismatches or performance degradation.

# 9. **Compute Loss**:
#    - `loss = criterion(outputs, targets)` calculates the loss between model outputs and targets.
#    - Accumulates the weighted loss: `total_loss += loss.item() * inputs.size(0)`.
#    
#    **Success**: Accurately tracks loss across batches, accounting for batch size variability.
#    **Pitfall**: Failing to weight the loss by batch size could result in incorrect average loss.

# 10. **Compute Predictions and Accuracy**:
#     - `_, predicted = torch.max(outputs, 1)` extracts the predicted class with the highest score.
#     - `correct += (predicted == targets).sum().item()` increments the correct prediction count.
#     - `total += targets.numel()` increments the total elements evaluated.
#    
#    **Success**: Tracks accuracy effectively across all batches.
#    **Pitfall**: Misaligned predictions and targets can result in incorrect accuracy metrics.

# 11. **Calculate Average Metrics**:
#     - `avg_loss = total_loss / len(test_loader.dataset)` computes the average loss per sample.
#     - `accuracy = correct / total` computes the overall accuracy.
#    
#    **Success**: Provides meaningful insights into model performance.
#    **Pitfall**: Incorrect total counts could skew the average loss and accuracy.

# 12. **Logging and Return**:
#     - `logger.info(...)` logs the evaluation metrics for monitoring and debugging.
#     - `return avg_loss, accuracy` returns the computed metrics.

# **Notes**:
# - **Evaluation Mode**: Setting the model to `eval()` ensures correct behavior of layers like dropout and batch normalization.
# - **No Gradient Computation**: Wrapping the code in `torch.no_grad()` saves memory and speeds up evaluation.
# - **Resizing Outputs**: The use of `F.interpolate` handles cases where model outputs and target sizes differ.
# - **Loss Weighting**: Weighting the loss by batch size ensures correct metric calculation across variable-sized batches.
# - **Accuracy Calculation**: Accurate predictions and correct counts are critical for reliable accuracy metrics.

# **Conclusion**:
# The `evaluate_model` function provides a systematic approach to assessing model performance on unseen data. 
# It ensures:
# - Proper **loss and accuracy calculation** across batches.
# - Correct **model behavior** during evaluation by switching to `eval()` mode.
# - Efficient use of memory and computation by disabling gradient tracking.

# This function offers valuable insights into the model’s generalization capabilities, guiding further improvements 
# and adjustments to the model as needed.


In [20]:
# -----------------------------
# 12. Traverse and Debug Function
# -----------------------------

def traverse_and_debug(node):
    """Traverse the tree and log node details with reduced verbosity."""
    # Limit logging to fewer nodes
    if hasattr(node, "name"):
        grid_shape = getattr(node, 'embedding', None)
        grid_shape = grid_shape.shape if grid_shape is not None else 'Missing'
        logger.debug(f"Node: {node.name}, Grid Shape: {grid_shape}")
    if len(node.children) > 10:  # Avoid excessive logging if many children exist
        logger.warning(f"Node '{node.name}' has too many children, skipping further logs...")
        return
    for child in node.children:
        traverse_and_debug(child)


In [21]:
# -----------------------------
# 12. Explanation of the Traverse and Debug Function
# -----------------------------

# The `traverse_and_debug` function is designed to traverse a **tree data structure** starting from a given node.
# It logs details about each node but includes mechanisms to reduce verbosity, preventing excessive logging,
# especially when dealing with nodes containing many children.

# **Function Definition**:
# def traverse_and_debug(node):
#     """Traverse the tree and log node details with reduced verbosity."""

# **Key Steps**:

# 1. **Check for 'name' Attribute**:
#    - Uses `hasattr(node, "name")` to check if the node has a 'name' attribute.
#    - **Reason**: Not all nodes may have the 'name' attribute, and accessing it without checking could raise an `AttributeError`.
#
#    **Success**: Prevents errors by safely accessing attributes.
#    **Pitfall**: If essential attributes are missing, it may reduce the usefulness of the logs.

# 2. **Retrieve the Node's Embedding**:
#    - Uses `getattr(node, 'embedding', None)` to safely retrieve the embedding attribute.
#    - **Fallback**: Defaults to `None` if the 'embedding' attribute is not present.
#
#    **Success**: Handles missing attributes gracefully.
#    **Pitfall**: Important debugging information may be missed if many embeddings are missing.

# 3. **Determine Grid Shape**:
#    - If the 'embedding' is present, the function accesses its `shape` attribute.
#    - If the 'embedding' is `None`, sets `grid_shape` to `'Missing'`.
#
#    **Success**: Ensures logs reflect whether the grid shape is available or missing.
#    **Pitfall**: Complex embeddings with unexpected shapes could still cause issues if not validated.

# 4. **Log Node Information**:
#    - Logs the node’s **name** and **grid shape** using `logger.debug`.
#
#    **Success**: Provides valuable information for debugging tree structures.
#    **Pitfall**: Excessive logging could overwhelm the log files if not carefully managed.

# 5. **Limit Logging for Nodes with Many Children**:
#    - If the node has more than **10 children**, logs a warning and skips further traversal.
#    - Uses:
#        - `logger.warning(f"Node '{node.name}' has too many children, skipping further logs...")`
#        - Returns early to avoid logging all children.
#
#    **Success**: Prevents overwhelming logs with excessive child node data.
#    **Pitfall**: Important child nodes might be missed in the logs if this threshold is too low.

# 6. **Recursively Traverse Child Nodes**:
#    - If the node has **10 or fewer children**, the function proceeds to traverse each child recursively:
#        - `for child in node.children:`
#            - `traverse_and_debug(child)`
#
#    **Success**: Navigates through the entire tree efficiently when the number of children is manageable.
#    **Pitfall**: Deep recursion may exceed Python’s recursion limit, causing a `RecursionError`.

# **Purpose of the Function**:
# - **Traversal**: Walks through a tree structure, starting from the given node.
# - **Logging**: Collects and logs information about each node for debugging purposes.
# - **Reduced Verbosity**: Implements checks to prevent excessive logging when nodes have too many children, keeping logs concise and readable.

# **Use Cases**:
# - **Debugging Tree Structures**: Helpful for inspecting complex datasets organized in hierarchical structures.
# - **Performance Monitoring**: Can identify nodes that may cause slowdowns due to a large number of children.
# - **Data Validation**: Ensures nodes contain the expected attributes, helping identify issues in the data structure.

# **Considerations**:
# 1. **Recursion Limit**:
#    - Since the function uses recursion, deep trees may exceed Python’s recursion limit.
#    - **Solution**: Consider converting the logic to an iterative approach or carefully increasing the recursion limit if necessary.

# 2. **Logging Levels**:
#    - Uses `logger.debug` for regular logs and `logger.warning` when skipping nodes with too many children.
#    - **Ensure** that the logging configuration captures the desired level of detail without overwhelming the logs.

# 3. **Attribute Checks**:
#    - Safely checks for attributes to prevent runtime errors.
#    - **However**: Missing attributes may reduce the utility of the logs if critical information is omitted.

# **Example Scenario**:
# - You have a tree representing tasks and subtasks in a hierarchical dataset.
# - The function logs the **name** and **grid shape** of each node, but if a node has more than 10 children, it logs a warning and skips further traversal to avoid excessive output.
# - This ensures that only manageable information is logged, improving readability and debugging efficiency.

# **Summary**:
# - The `traverse_and_debug` function is a useful tool for navigating and debugging tree-like data structures.
# - It **balances detailed logging** with **verbosity control** by limiting output for nodes with many children.
# - This design makes it effective for **performance monitoring, data validation**, and **troubleshooting complex hierarchies**.


In [22]:
# -----------------------------
# 13. Main Workflow with Modifications
# -----------------------------

def main():
    # Detect device
    device = get_device()

    # Initialize the progress bar
    progress_bar = tqdm(total=100, desc="Loading Data", unit="%", leave=True)

    try:
        # Load ARC data
        arc_data = load_arc_data()
        progress_bar.update(20)

        # Extract and reshape grid pairs
        train_grid_pairs = flatten_and_reshape(
            arc_data.get("arc-agi_training-challenges", {})
        )
        eval_grid_pairs = flatten_and_reshape(
            arc_data.get("arc-agi_evaluation-challenges", {})
        )
        progress_bar.update(30)

        # Build the data tree and retrieve the task dictionary
        root_node, task_dict = build_data_tree(train_grid_pairs)
        traverse_and_debug(root_node)
        progress_bar.update(20)

        # Log the task dictionary
        logger.info(f"Task dictionary initialized with {len(task_dict)} tasks:")
        for task_id, task_data in task_dict.items():
            logger.info(
                f"Task ID: {task_id}, Node: {task_data['task_node'].name}, "
                f"Grid Shape: {task_data['grids'][0].shape}"
            )

        # Initialize DataLoaders and model
        train_dataset = AugmentedARCDataset(train_grid_pairs, augment=False)
        val_dataset = AugmentedARCDataset(eval_grid_pairs, augment=False)

        train_loader = DataLoader(
            train_dataset,
            batch_size=4,  # Reduce batch size if you encounter memory issues
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            collate_fn=collate_fn  # Use the custom collate function
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=4,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            collate_fn=collate_fn
        )

        # Initialize the model
        model = CNNGridMapper(num_classes=NUM_CLASSES).to(device)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        logger.info("Model initialized successfully.")

    except Exception as e:
        logger.exception(f"Data loading or model initialization failed: {e}")
        progress_bar.close()
        return

    progress_bar.close()

    # Initialize and start the GUI
    root_window = tk.Tk()
    gui = TrainingGUI(
        root_window, total_epochs=10, total_batches=len(train_loader),
        model=model, train_loader=train_loader, val_loader=val_loader,
        eval_loader=None, device=device, data_tree=root_node, task_dict=task_dict
    )

    # Start the training thread
    training_thread = threading.Thread(
        target=train_model_with_gui, args=(model, train_loader, val_loader, device, gui)
    )
    training_thread.daemon = True
    training_thread.start()

    # Start the GUI main loop
    root_window.mainloop()


if __name__ == "__main__":
    main()


NameError: name 'logger' is not defined

In [None]:
# -----------------------------
# 13. Explanation of the Main Workflow with Modifications
# -----------------------------

# The `main` function orchestrates the entire workflow of the application, from **data loading** and **preprocessing**
# to **model initialization** and launching the **graphical user interface (GUI)** for real-time visualization of the training process.

# **Function Definition**:
# def main():

# **Key Steps**:

# 1. **Device Detection**:
#    - Calls `get_device()` to determine whether a **GPU** (if available) or **CPU** should be used for computations.
#    - Stores the selected device in the `device` variable:
#        - `device = get_device()`

#    **Success**: Automatically optimizes the computation by selecting the fastest available device.
#    **Pitfall**: If the selected device has insufficient memory (e.g., GPU), the process may crash or slow down.

# 2. **Progress Bar Initialization**:
#    - Initializes a progress bar with `tqdm` to provide visual feedback on data loading progress:
#        - `progress_bar = tqdm(total=100, desc="Loading Data", unit="%", leave=True)`

#    **Success**: Keeps users informed about the progress, especially for time-consuming operations.
#    **Pitfall**: If not updated correctly, users may assume the program is frozen.

# 3. **Data Loading**:
#    - Loads the ARC dataset using `load_arc_data()` and updates the progress bar:
#        - `arc_data = load_arc_data()`
#        - `progress_bar.update(20)`

#    **Success**: Centralizes data access logic, improving maintainability.
#    **Pitfall**: Missing or corrupted files could cause crashes during loading.

# 4. **Data Extraction and Reshaping**:
#    - Uses `flatten_and_reshape()` to extract and reshape grid pairs from the ARC dataset.
#    - Retrieves data for **training** and **evaluation** challenges:
#        - `train_grid_pairs = flatten_and_reshape(...)`
#        - `eval_grid_pairs = flatten_and_reshape(...)`
#    - Updates the progress bar:
#        - `progress_bar.update(30)`

#    **Success**: Ensures the data is properly structured for model training.
#    **Pitfall**: Inconsistent data shapes could cause runtime errors.

# 5. **Building Data Tree and Task Dictionary**:
#    - Constructs a hierarchical **data tree** using `build_data_tree()`.
#    - Calls `traverse_and_debug()` to traverse the tree and log node details:
#        - `root_node, task_dict = build_data_tree(train_grid_pairs)`
#        - `traverse_and_debug(root_node)`
#    - Updates the progress bar:
#        - `progress_bar.update(20)`

#    **Success**: Organizes the dataset for better visualization and management.
#    **Pitfall**: Deep trees could exceed Python’s recursion limit if not handled carefully.

# 6. **Logging Task Information**:
#    - Logs the number of tasks initialized in the `task_dict`:
#        - `logger.info(f"Task dictionary initialized with {len(task_dict)} tasks.")`
#    - Iterates over the tasks and logs task-specific details (e.g., ID, grid shape):
#        - `for task_id, task_data in task_dict.items(): logger.info(...)`

#    **Success**: Provides transparency and ensures data integrity.
#    **Pitfall**: Missing or incorrect task information could make debugging difficult.

# 7. **Dataset and DataLoader Initialization**:
#    - Initializes **training** and **validation** datasets using `AugmentedARCDataset`:
#        - `train_dataset = AugmentedARCDataset(train_grid_pairs, augment=False)`
#        - `val_dataset = AugmentedARCDataset(eval_grid_pairs, augment=False)`
#    - Creates DataLoaders with custom `collate_fn` to handle variable-sized inputs:
#        - `train_loader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)`
#        - `val_loader = DataLoader(val_dataset, batch_size=4, collate_fn=collate_fn)`

#    **Success**: Efficiently batches and shuffles data for faster training.
#    **Pitfall**: Incorrect DataLoader configurations (e.g., too many workers) can cause performance issues.

# 8. **Model Initialization**:
#    - Creates an instance of `CNNGridMapper` and moves the model to the appropriate device:
#        - `model = CNNGridMapper(num_classes=NUM_CLASSES).to(device)`
#    - If multiple GPUs are available, wraps the model with `nn.DataParallel` for parallel training:
#        - `if torch.cuda.device_count() > 1: model = nn.DataParallel(model)`
#    - Logs a success message:
#        - `logger.info("Model initialized successfully.")`

#    **Success**: Prepares the model for efficient training on available hardware.
#    **Pitfall**: Incorrect device handling can result in runtime errors.

# 9. **Exception Handling**:
#    - Wraps the main logic in a `try-except` block to catch and log exceptions:
#        - `logger.exception(f"Data loading or model initialization failed: {e}")`
#    - Closes the progress bar upon error:
#        - `progress_bar.close()`
#    - Returns to prevent further execution.

#    **Success**: Prevents crashes by handling exceptions gracefully.
#    **Pitfall**: Not providing helpful error messages could make debugging difficult.

# 10. **Close Progress Bar**:
#     - Closes the progress bar after data loading completes:
#        - `progress_bar.close()`

#     **Success**: Keeps the console output clean and avoids confusion about the progress status.
#     **Pitfall**: Forgetting to close the progress bar can clutter the output.

# 11. **GUI Initialization**:
#     - Creates a Tkinter GUI window and initializes `TrainingGUI` with all necessary parameters:
#        - `gui = TrainingGUI(root_window, total_epochs=10, model=model, ...)`

#     **Success**: Provides a user-friendly interface to monitor and control the training process.
#     **Pitfall**: Complex GUIs can become unresponsive without proper threading.

# 12. **Start Training Thread**:
#     - Creates a new thread to run the training process without blocking the GUI:
#        - `training_thread = threading.Thread(target=train_model_with_gui, args=(...))`
#     - Sets `daemon=True` to ensure the thread terminates with the main program:
#        - `training_thread.daemon = True`
#     - Starts the thread:
#        - `training_thread.start()`

#     **Success**: Ensures the GUI remains responsive during long-running training processes.
#     **Pitfall**: Incorrect thread handling could cause deadlocks or crashes.

# 13. **Start GUI Main Loop**:
#     - Calls `root_window.mainloop()` to run the Tkinter event loop:
#        - `root_window.mainloop()`

#     **Success**: Keeps the GUI interactive, allowing users to control the training.
#     **Pitfall**: Forgetting to call the main loop will result in a non-responsive GUI.

# **Entry Point Check**:
# - Ensures that `main()` is called only when the script is run directly:
#     - `if __name__ == "__main__": main()`

#     **Success**: Prevents unintended execution when the script is imported as a module.
#     **Pitfall**: Forgetting this check can lead to unexpected behavior when importing the script.

# **Conclusion**:
# The `main` function integrates multiple components—data loading, model initialization, and GUI setup—into a cohesive workflow.
# It ensures smooth interaction between the backend logic and the frontend interface, providing a robust framework for model development.
