# In [None]:  (Jupyter cell marker left over from the notebook export)
# ===============================================================
# Optimized Implementation
# ===============================================================

# -----------------------------
# 1. Import Libraries
# -----------------------------
import math
import json
import logging
import numpy as np
import torch
import random
import threading
import queue
import sys
import matplotlib
matplotlib.use('TkAgg')  # Ensure the correct backend is used
import matplotlib.pyplot as plt
import os
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from sklearn.model_selection import train_test_split
from torchvision.models import resnet18, ResNet18_Weights
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from torch.optim import lr_scheduler

# Pin the default font family for every figure created after this point.
# The original author notes this is meant to reduce font scanning time.
matplotlib.rcParams['font.family'] = 'DejaVu Sans'

# -----------------------------
# 2. Configure Logging and Seed
# -----------------------------

# Configure the root logger to emit INFO-and-above records to stdout,
# then create the module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed once at import time, using the default seed.
set_seed()

# -----------------------------
# 3. Define Constants
# -----------------------------

GRID_SIZE = 30        # Side length of the fixed square grid every sample is padded/truncated to
NUM_CLASSES = 11      # Labels 0-10; label 10 stands for "dead" cells, which are stored as -1 in the grids

# -----------------------------
# 4. Define Data Structures and Loading Functions
# -----------------------------

# Data Class for Grid Pairs
@dataclass
class GridPair:
    """One input/output example belonging to a single task.

    Instances are built by ``flatten_and_reshape`` and
    ``generate_multiple_augmented_datasets`` in this file, which store
    nested Python lists in both grid fields.
    """

    # Identifier of the task this example was taken from.
    task_id: str
    # Input grid as nested lists; padding/dead cells are stored as -1.
    input_grid: list
    # Expected output grid, in the same representation as ``input_grid``.
    output_grid: list

# Load ARC Data
def load_arc_data():
    """Load the four ARC JSON files from the current working directory.

    Loading is best-effort: a file that is missing, unreadable or not valid
    JSON is logged as an error and represented by an empty dict, so callers
    always receive all four keys.

    Returns:
        dict: Maps each dataset key (e.g. ``"arc-agi_training-challenges"``)
        to the parsed JSON content, or to ``{}`` when loading failed.
    """
    file_paths = {
        "arc-agi_training-challenges": "arc-agi_training_challenges.json",
        "arc-agi_evaluation-challenges": "arc-agi_evaluation_challenges.json",
        "arc-agi_training-solutions": "arc-agi_training_solutions.json",
        "arc-agi_evaluation-solutions": "arc-agi_evaluation_solutions.json",
    }
    arc_data = {}
    for key, path in file_paths.items():
        try:
            # Explicit encoding: JSON is UTF-8, whereas the platform default
            # encoding is locale dependent.
            with open(path, 'r', encoding='utf-8') as f:
                arc_data[key] = json.load(f)
        except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
            # OSError covers FileNotFoundError as well as permission and
            # other I/O failures, keeping the load best-effort.
            logger.error(f"Error loading {path}: {e}")
            arc_data[key] = {}
        else:
            logger.info(f"Loaded {key} from {path}.")
    return arc_data

# Reshape to Fixed Square Grid
def reshape_to_square_grid(flat_list, grid_size=GRID_SIZE):
    """Fit a flat sequence of cells into a ``grid_size`` x ``grid_size`` grid.

    The sequence is truncated when it holds more than ``grid_size ** 2``
    cells and padded at the end with -1 otherwise. Cells are laid out in
    row-major order, so the row boundaries of the source grid are only
    preserved when its width already equals ``grid_size``.

    Args:
        flat_list: One-dimensional sequence (list or array) of cell values.
        grid_size: Side length of the square grid to produce.

    Returns:
        list: ``grid_size`` rows, each a list of ``grid_size`` cell values.
    """
    required_length = grid_size * grid_size

    # Convert up front: the truncation branch previously kept a plain Python
    # list, which has no ``reshape`` method and raised AttributeError.
    cells = np.asarray(flat_list)
    current_length = len(cells)

    if current_length > required_length:
        # Truncate if the grid is larger than grid_size x grid_size
        cells = cells[:required_length]
    else:
        # Pad with -1 to reach the required length
        cells = np.pad(cells, (0, required_length - current_length), 'constant', constant_values=-1)

    return cells.reshape(grid_size, grid_size).tolist()

# Extract and Reshape Grid
def extract_and_reshape_grid(grid, grid_size=GRID_SIZE):
    """Flatten ``grid`` if needed and reshape it to the fixed square size.

    ``grid`` may be a list of rows or an already flat list. Any failure
    (for example ``None`` or an empty grid) is logged and reported by
    returning ``None`` instead of raising.
    """
    try:
        is_nested = isinstance(grid[0], list)
        if is_nested:
            # Concatenate the rows into a single flat list of cells.
            cells = []
            for row in grid:
                cells.extend(row)
        else:
            cells = grid
        return reshape_to_square_grid(cells, grid_size)
    except Exception as e:
        logger.error(f"Error processing grid: {e}")
        return None

# Flatten and Reshape Grid Data
def flatten_and_reshape(task_data, grid_size=GRID_SIZE):
    """Turn the ``'train'`` examples of every task into ``GridPair`` objects.

    Each example's ``"input"`` and ``"output"`` grids are reshaped to
    ``grid_size`` x ``grid_size``. Examples for which either grid could not
    be processed are skipped with a warning.

    Returns:
        list: The valid ``GridPair`` instances, in task/example order.
    """
    collected = []
    for task_id, task_content in task_data.items():
        logger.info(f"Parsing task {task_id}...")
        for example in task_content.get('train', []):
            source, target = (
                extract_and_reshape_grid(example.get(side), grid_size)
                for side in ("input", "output")
            )
            if not (source and target):
                logger.warning(f"Task ID: {task_id} has invalid input/output grids.")
                continue
            collected.append(GridPair(task_id, source, target))
    logger.info(f"Total valid grid pairs extracted: {len(collected)}")
    return collected

# -----------------------------
# 5. Data Augmentation Functions
# -----------------------------

def augment_grid(grid, noise_prob=0.2, dead_square_prob=0.1):
    """Return a copy of ``grid`` with random colour noise and dead squares.

    Args:
        grid: 2-D grid (nested lists or array) of cell values.
        noise_prob: Per-cell probability of replacing the value with a
            random colour.
        dead_square_prob: Per-cell probability of marking the cell as dead
            (-1). Dead squares are applied after the noise, so they win
            where both masks select the same cell.

    Returns:
        list: The augmented grid as nested lists; ``grid`` is not modified.
    """
    augmented_grid = np.array(grid)

    # Generate random masks for noise and dead squares
    noise_mask = np.random.rand(*augmented_grid.shape) < noise_prob
    dead_square_mask = np.random.rand(*augmented_grid.shape) < dead_square_prob

    # Colours are the labels 0 .. NUM_CLASSES - 2; the last label is reserved
    # for dead squares. ``randint`` excludes its upper bound, so the bound is
    # NUM_CLASSES - 1 (the previous NUM_CLASSES - 2 never produced the
    # highest colour).
    noise_values = np.random.randint(0, NUM_CLASSES - 1, size=augmented_grid.shape)
    augmented_grid = np.where(noise_mask, noise_values, augmented_grid)
    augmented_grid = np.where(dead_square_mask, -1, augmented_grid)

    return augmented_grid.tolist()

def rotate_grid(grid):
    """Rotate the grid by a random multiple of 90 degrees (0 to 3 turns)."""
    quarter_turns = random.choice(range(4))
    rotated = np.rot90(grid, quarter_turns)
    return rotated.tolist()

def flip_grid(grid):
    """Flip the grid vertically, horizontally, or not at all, at random.

    When no flip is chosen the original ``grid`` object is returned as is;
    otherwise a new nested list is returned.
    """
    mode = random.choice(['none', 'vertical', 'horizontal'])
    if mode == 'none':
        return grid  # No flip

    # 'vertical' reverses the row order, 'horizontal' reverses each row.
    flipper = np.flipud if mode == 'vertical' else np.fliplr
    return flipper(grid).tolist()

# Generate Multiple Augmented Datasets
def generate_multiple_augmented_datasets(grid_pairs, num_augmented_sets=3):
    """Build ``num_augmented_sets`` augmented copies of every grid pair.

    Only the input grid is perturbed (noise/dead squares, then a random
    rotation, then a random flip); the output grid of each pair is reused
    unchanged.

    Returns:
        list: New ``GridPair`` objects, one per (set, pair) combination.
    """
    def _perturb(grid):
        # Same pipeline order as the dataset class: augment -> rotate -> flip.
        return flip_grid(rotate_grid(augment_grid(grid)))

    return [
        GridPair(pair.task_id, _perturb(pair.input_grid), pair.output_grid)
        for _ in range(num_augmented_sets)
        for pair in grid_pairs
    ]

# -----------------------------
# 6. Custom Collate Function (Removed if Not Necessary)
# -----------------------------

# Since all tensors are of the same size, the default collate function suffices.
# We can remove the custom collate_fn unless specific processing is required.

# -----------------------------
# 7. PyTorch Dataset Class
# -----------------------------

class AugmentedARCDataset(Dataset):
    """Dataset serving ``GridPair`` examples as (input, target) tensors.

    With ``augment=True`` the input grid is randomly perturbed on every
    access; the target grid is never augmented.
    """

    def __init__(self, grid_pairs, augment=False):
        self.grid_pairs = grid_pairs
        self.augment = augment

    def __len__(self):
        return len(self.grid_pairs)

    def __getitem__(self, idx):
        sample = self.grid_pairs[idx]

        source = sample.input_grid
        if self.augment:
            # Noise/dead squares first, then random rotation, then random flip.
            for transform in (augment_grid, rotate_grid, flip_grid):
                source = transform(source)

        # Float input with a leading channel axis: (1, GRID_SIZE, GRID_SIZE)
        input_tensor = torch.tensor(source, dtype=torch.float32).unsqueeze(0)

        # Dead cells (-1) become the last class index so the target is a valid label map.
        labels = np.array(sample.output_grid)
        dead_label = NUM_CLASSES - 1
        labels = np.where(labels == -1, dead_label, labels)
        output_tensor = torch.tensor(labels, dtype=torch.long)  # (GRID_SIZE, GRID_SIZE)

        return input_tensor, output_tensor

# -----------------------------
# 8. Define the Deep Neural Network Model
# -----------------------------

class CNNGridMapper(nn.Module):
    """Per-cell grid classifier: ResNet-18 features -> row-wise BiLSTM -> linear head.

    The forward pass takes a single-channel grid batch and returns one
    logit vector per grid cell, flattened over the grid:
    ``(batch_size, grid_size * grid_size, num_classes)``.

    NOTE(review): ResNet-18 downsamples its input heavily before the
    feature map is upsampled back to ``grid_size``; for a 30x30 input the
    feature map is presumably about 1x1, which would make the upsampled
    features identical for every cell -- TODO confirm.
    """

    def __init__(self, num_classes=NUM_CLASSES, grid_size=GRID_SIZE):
        """Build the backbone, the upsampler, the LSTM and the output layer.

        Args:
            num_classes: Number of output labels per cell.
            grid_size: Side length of the square grid the logits are produced for.
        """
        super(CNNGridMapper, self).__init__()
        self.grid_size = grid_size
        self.num_classes = num_classes

        # CNN Backbone: ResNet18 created with the default pretrained weights
        self.cnn = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Replace the first convolution so the network accepts single-channel input.
        # The new layer is freshly created, so conv1 does not keep pretrained weights.
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Initialize the new conv1 weights
        nn.init.kaiming_normal_(self.cnn.conv1.weight, mode='fan_out', nonlinearity='relu')
        # Keep every child except the last two (average pool and fully connected layer)
        self.cnn_layers = nn.Sequential(*list(self.cnn.children())[:-2])  # Output: (batch_size, 512, H, W)

        # Replace in-place ReLU activations with out-of-place versions
        def replace_relu(module):
            # Recursively walk the module tree and swap each nn.ReLU child.
            for child_name, child in module.named_children():
                if isinstance(child, nn.ReLU):
                    setattr(module, child_name, nn.ReLU(inplace=False))
                else:
                    replace_relu(child)
        replace_relu(self.cnn_layers)

        # Bilinear upsampling of the feature map to grid_size x grid_size
        self.interpolate = nn.Upsample(size=(grid_size, grid_size), mode='bilinear', align_corners=False)

        # RNN Module: LSTM
        # Treat each row as a sequence of cells
        self.rnn = nn.LSTM(input_size=512,  # Number of features per cell from CNN
                           hidden_size=128,
                           num_layers=2,
                           batch_first=True,
                           bidirectional=True)

        # Fully Connected Layer mapping LSTM outputs to class logits
        self.fc = nn.Linear(128 * 2, num_classes)  # *2 for bidirectional

    def forward(self, x):
        """Compute per-cell logits for a batch of grids.

        Args:
            x: Input batch; assumed to be (batch_size, 1, H, W) given the
                single-channel conv1 -- TODO confirm against callers.

        Returns:
            Tensor of shape (batch_size, grid_size * grid_size, num_classes).
        """
        batch_size = x.size(0)

        # Pass through CNN
        features = self.cnn_layers(x)  # Shape: (batch_size, 512, H, W)

        # Interpolate to (batch_size, 512, GRID_SIZE, GRID_SIZE)
        features = self.interpolate(features)  # Shape: (batch_size, 512, GRID_SIZE, GRID_SIZE)

        # Reshape for RNN: every grid row becomes one sequence of grid_size cells
        features = features.permute(0, 2, 3, 1)  # Shape: (batch_size, GRID_SIZE, GRID_SIZE, 512)
        features = features.contiguous().view(batch_size * self.grid_size, self.grid_size, 512)

        # Pass through RNN
        rnn_out, _ = self.rnn(features)  # Shape: (batch_size * GRID_SIZE, GRID_SIZE, hidden_size * 2)

        # Pass through Fully Connected layer
        logits = self.fc(rnn_out)  # Shape: (batch_size * GRID_SIZE, GRID_SIZE, num_classes)

        # Reshape logits back to (batch_size, GRID_SIZE * GRID_SIZE, num_classes)
        logits = logits.view(batch_size, self.grid_size * self.grid_size, self.num_classes)

        return logits  # (batch_size, GRID_SIZE * GRID_SIZE, num_classes)

# -----------------------------
# 9. Define a Simple Transformer-based LLM
# -----------------------------

class SimpleTransformer(nn.Module):
    """Small Transformer encoder over token ids with a linear vocabulary head.

    Inputs are batch-first, ``(batch_size, seq_len)`` integer token ids, and
    the output is ``(batch_size, seq_len, vocab_size)`` logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / model width.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward block.
    """

    def __init__(self, vocab_size=5000, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True keeps the encoder consistent with PositionalEncoding,
        # which indexes positions along dimension 1. Without it the encoder
        # treats dimension 0 as the sequence and attends across batch samples.
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        """Return vocabulary logits for every position of ``src``.

        Args:
            src: Integer token ids of shape (batch_size, seq_len).
        """
        # Scale embeddings by sqrt(d_model). The previous code used
        # src.size(-1) while ``src`` still held the token ids, i.e. the
        # sequence length rather than the model width.
        embedded = self.embedding(src) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        encoded = self.transformer_encoder(embedded)
        return self.decoder(encoded)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first embeddings.

    Args:
        d_model: Width of the embeddings (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)  # Apply sin to even indices
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the frequencies; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # Apply cos to odd indices
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Registered as a buffer: stored with the module, but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the positional table, followed by dropout.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model) with
                ``seq_len <= max_len``.
        """
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return self.dropout(x)

# -----------------------------
# 10. GUI Class
# -----------------------------

class TrainingGUI:
    """
    A Tkinter-based GUI that displays real-time training progress, including epoch, batch, loss, and accuracy.
    Provides buttons to load, save, evaluate, start/stop training, and ensemble models.

    Threading model: training and evaluation run in daemon worker threads.
    Workers never touch Tk widgets or message boxes; they put messages on
    ``self.queue`` and ``process_queue`` (polled every 100 ms on the Tk main
    loop) applies them. Messages understood by ``process_queue``:

    * ``'training_complete'`` / ``'training_stopped'`` (strings)
    * ``{'error': str}`` -- training failed
    * ``{'evaluation': (loss, accuracy)}`` / ``{'evaluation_error': str}``
    * any other dict -- a progress update handed to ``update_gui``
    """

    def __init__(self, root, total_epochs, total_batches, model, train_loader, val_loader, eval_loader, device):
        """Store the model/loaders, build the widgets and start polling the queue.

        Args:
            root: Tk root (or toplevel) window hosting the GUI.
            total_epochs: Number of epochs a training run performs.
            total_batches: Batches per epoch, used for the batch label.
            model: The ``nn.Module`` to train, evaluate, load and save.
            train_loader: Iterable of (inputs, targets) training batches.
            val_loader: Loader passed to evaluation and validation.
            eval_loader: Stored on the instance; not used by this class itself.
            device: Device the batches (and model) are moved to.
        """
        self.root = root
        self.root.title("Model Training and Evaluation")
        self.queue = queue.Queue()

        # Model and DataLoaders
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.eval_loader = eval_loader
        self.device = device

        # Training Parameters
        self.total_epochs = total_epochs
        self.total_batches = total_batches
        self.training_thread = None
        self.stop_event = threading.Event()

        # Initialize GUI
        self.setup_gui()

        # Start processing the queue
        self.root.after(100, self.process_queue)

    def setup_gui(self):
        """Set up the GUI components."""
        self.frame = tk.Frame(self.root)
        self.frame.pack(fill=tk.BOTH, expand=1)

        # Labels
        self.epoch_label = tk.Label(self.frame, text=f"Epoch: 0/{self.total_epochs}", font=("Helvetica", 14))
        self.epoch_label.pack(pady=5)

        self.batch_label = tk.Label(self.frame, text=f"Batch: 0/{self.total_batches}", font=("Helvetica", 12))
        self.batch_label.pack(pady=2)

        self.loss_label = tk.Label(self.frame, text="Loss: 0.0000", font=("Helvetica", 12))
        self.loss_label.pack(pady=2)

        self.accuracy_label = tk.Label(self.frame, text="Accuracy: 0.0000", font=("Helvetica", 12))
        self.accuracy_label.pack(pady=2)

        self.val_loss_label = tk.Label(self.frame, text="Validation Loss: 0.0000", font=("Helvetica", 12))
        self.val_loss_label.pack(pady=2)

        self.val_accuracy_label = tk.Label(self.frame, text="Validation Accuracy: 0.0000", font=("Helvetica", 12))
        self.val_accuracy_label.pack(pady=2)

        # Progress Bar
        self.progress_bar = ttk.Progressbar(self.frame, orient="horizontal", length=400, mode="determinate")
        self.progress_bar.pack(pady=10)

        # Button Frame
        self.button_frame = tk.Frame(self.frame)
        self.button_frame.pack(pady=10)

        # Buttons
        self.load_button = tk.Button(self.button_frame, text="Load Model", command=self.load_model)
        self.load_button.grid(row=0, column=0, padx=5)

        self.save_button = tk.Button(self.button_frame, text="Save Model", command=self.save_model)
        self.save_button.grid(row=0, column=1, padx=5)

        self.evaluate_button = tk.Button(self.button_frame, text="Evaluate Model", command=self.evaluate_model_button)
        self.evaluate_button.grid(row=0, column=2, padx=5)

        self.start_button = tk.Button(self.button_frame, text="Start Training", command=self.start_training)
        self.start_button.grid(row=0, column=3, padx=5)

        self.stop_button = tk.Button(self.button_frame, text="Stop Training", command=self.stop_training, state=tk.DISABLED)
        self.stop_button.grid(row=0, column=4, padx=5)

        self.ensemble_button = tk.Button(self.button_frame, text="Ensemble Models", command=self.ensemble_models)
        self.ensemble_button.grid(row=0, column=5, padx=5)

        # Real-time plot: losses on the left axis, accuracies on a twin right axis.
        self.fig, self.ax = plt.subplots(figsize=(6, 4))
        self.line_loss, = self.ax.plot([], [], label='Training Loss', color='blue')
        self.line_val_loss, = self.ax.plot([], [], label='Validation Loss', color='orange')
        self.ax.set_xlabel('Epochs')
        self.ax.set_ylabel('Loss')
        self.ax.grid(True)

        # update_gui plots accuracy series too, so their lines must exist.
        self.ax_acc = self.ax.twinx()
        self.line_acc, = self.ax_acc.plot([], [], label='Training Accuracy', color='green', linestyle='--')
        self.line_val_acc, = self.ax_acc.plot([], [], label='Validation Accuracy', color='red', linestyle='--')
        self.ax_acc.set_ylabel('Accuracy')

        # One legend covering the lines of both axes.
        plot_lines = [self.line_loss, self.line_val_loss, self.line_acc, self.line_val_acc]
        self.ax.legend(plot_lines, [line.get_label() for line in plot_lines])

        self.canvas_plot = FigureCanvasTkAgg(self.fig, master=self.frame)
        self.canvas_plot.draw()
        self.canvas_plot.get_tk_widget().pack()

        # Data for plots (one value per completed epoch)
        self.loss_data = []
        self.val_loss_data = []
        self.acc_data = []
        self.val_acc_data = []

    def load_model(self):
        """Load a model state dict from a user-selected ``.pth`` file."""
        model_path = filedialog.askopenfilename(title="Select Model File", filetypes=[("PyTorch Models", "*.pth")])
        if model_path:
            try:
                # SECURITY NOTE: torch.load unpickles the file; only load
                # checkpoints from sources you trust.
                state_dict = torch.load(model_path, map_location=self.device)
                self.model.load_state_dict(state_dict)
                self.model.to(self.device)
                logger.info(f"Model loaded from {model_path}")
                messagebox.showinfo("Load Model", f"Model loaded from {model_path}")
            except Exception as e:
                logger.error(f"Error loading model: {e}")
                messagebox.showerror("Error", f"Could not load model: {e}")

    def save_model(self):
        """Save the current model's state dict to a user-selected file."""
        model_path = filedialog.asksaveasfilename(title="Save Model As", defaultextension=".pth",
                                                  filetypes=[("PyTorch Models", "*.pth")])
        if model_path:
            try:
                torch.save(self.model.state_dict(), model_path)
                logger.info(f"Model saved to {model_path}")
                messagebox.showinfo("Save Model", f"Model saved to {model_path}")
            except Exception as e:
                logger.error(f"Error saving model: {e}")
                messagebox.showerror("Error", f"Could not save model: {e}")

    def _reset_training_buttons(self):
        """Put the Start/Stop buttons back into the idle state (main thread only)."""
        self.start_button.config(state=tk.NORMAL)
        self.stop_button.config(state=tk.DISABLED)

    def _launch_training(self, target):
        """Run ``target`` in a daemon thread unless a training thread is still alive."""
        if self.training_thread is not None and self.training_thread.is_alive():
            logger.warning("Training is already running; ignoring the request.")
            return
        self.stop_event.clear()
        self.start_button.config(state=tk.DISABLED)
        self.stop_button.config(state=tk.NORMAL)
        self.training_thread = threading.Thread(target=target, daemon=True)
        self.training_thread.start()

    def start_training(self):
        """Start the training process (``train_model``) in a new thread."""
        self._launch_training(self.train_model)

    def stop_training(self):
        """Ask the training thread to stop; it exits before its next batch."""
        self.stop_event.set()
        self._reset_training_buttons()

    def train_model(self):
        """Training logic executed in a separate thread.

        Runs ``total_epochs`` passes over ``train_loader`` and reports every
        batch through the queue. Ends with ``'training_complete'``,
        ``'training_stopped'`` (stop requested) or ``{'error': ...}``.
        """
        try:
            # Same optimizer settings as train_deep_model / train_thread (lr=0.01).
            optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9,
                                  nesterov=True, weight_decay=1e-4)
            self.model.to(self.device)
            num_batches = len(self.train_loader)

            for epoch in range(1, self.total_epochs + 1):
                if self.stop_event.is_set():
                    break

                running_loss = 0.0
                correct = 0
                total = 0
                for batch_idx, (inputs, targets) in enumerate(self.train_loader, 1):
                    if self.stop_event.is_set():
                        break

                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    # Re-assert train mode each batch in case an evaluation
                    # running concurrently switched the model to eval mode.
                    self.model.train()

                    optimizer.zero_grad()
                    outputs = self.model(inputs)  # (batch, cells, num_classes)
                    # cross_entropy expects (N, C) logits and (N,) targets, so
                    # flatten the per-cell predictions and the target grid.
                    logits_flat = outputs.reshape(-1, outputs.size(-1))
                    targets_flat = targets.reshape(-1)
                    loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    correct += (logits_flat.detach().argmax(dim=1) == targets_flat).sum().item()
                    total += targets_flat.numel()
                    progress = (epoch - 1 + batch_idx / num_batches) / self.total_epochs * 100

                    # Send progress updates to the queue
                    self.queue.put({
                        'epoch': epoch,
                        'batch': batch_idx,
                        'loss': running_loss / batch_idx,
                        'accuracy': correct / total if total else 0.0,
                        'progress': progress
                    })

            if self.stop_event.is_set():
                self.queue.put('training_stopped')
            else:
                self.queue.put('training_complete')

        except Exception as e:
            logger.exception("An error occurred during training.")
            self.queue.put({'error': str(e)})

    def retrain_model(self):
        """Start training via ``train_thread`` (``train_deep_model``) in a separate thread."""
        try:
            # Start the training thread as a daemon to avoid blocking the GUI
            self._launch_training(self.train_thread)
        except Exception as e:
            logger.exception("An error occurred while starting the training thread.")
            messagebox.showerror("Training Error", f"An error occurred: {e}")

    def evaluate_model_button(self):
        """Evaluate the model in a separate thread."""
        threading.Thread(target=self.evaluate_model, daemon=True).start()

    def evaluate_model(self):
        """Evaluate the model on ``val_loader`` and report the result via the queue.

        Calls the module-level ``evaluate_model`` function (defined outside
        this class). The result dialog is shown by ``process_queue`` on the
        main thread, because this method may run in a worker thread.
        """
        try:
            avg_loss, accuracy = evaluate_model(self.model, self.val_loader, device=self.device)
        except Exception as e:
            logger.exception("An error occurred during evaluation.")
            self.queue.put({'evaluation_error': str(e)})
        else:
            self.queue.put({'evaluation': (avg_loss, accuracy)})

    def ensemble_models(self):
        """Ensemble multiple models (placeholder: not implemented, does nothing)."""
        # Implementation for ensembling models (omitted for brevity)

    def _handle_message(self, message):
        """Apply a single worker message to the GUI (main thread only)."""
        if isinstance(message, dict):
            if 'error' in message:
                # A training failure ends the run, so return the buttons to idle.
                self._reset_training_buttons()
                messagebox.showerror("Error", message['error'])
            elif 'evaluation_error' in message:
                messagebox.showerror("Error", message['evaluation_error'])
            elif 'evaluation' in message:
                avg_loss, accuracy = message['evaluation']
                messagebox.showinfo("Evaluation", f"Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
            else:
                self.update_gui(message)
        elif message == 'training_complete':
            self._reset_training_buttons()
            messagebox.showinfo("Training", "Training completed successfully.")
        elif message == 'training_stopped':
            self._reset_training_buttons()

    def process_queue(self):
        """Process the queue for thread-safe GUI updates, then reschedule itself."""
        try:
            while True:
                self._handle_message(self.queue.get_nowait())
        except queue.Empty:
            pass
        finally:
            # Keep polling even if a handler raised.
            self.root.after(100, self.process_queue)

    def update_gui(self, data):
        """
        Updates the GUI elements with new training information.

        ``data`` is a dict. With a ``'batch'`` key it is a batch-level update
        (``'loss'`` required; ``'accuracy'``, ``'epoch'`` and ``'progress'``
        optional). Otherwise, with an ``'epoch'`` key, it is an epoch summary
        that may carry ``'epoch_loss'``, ``'val_loss'``, ``'epoch_acc'`` and
        ``'val_accuracy'`` for the plot.
        """
        try:
            if 'batch' in data:
                # Update batch-level information (for progress during each epoch)
                self.batch_label.config(text=f"Batch: {data['batch']}/{self.total_batches}")
                self.loss_label.config(text=f"Loss: {data['loss']:.4f}")
                self.accuracy_label.config(text=f"Accuracy: {data.get('accuracy', 0.0):.4f}")

                # Batch messages may also carry the epoch and overall progress.
                if 'epoch' in data:
                    self.epoch_label.config(text=f"Epoch: {data['epoch']}/{self.total_epochs}")
                if 'progress' in data:
                    self.progress_bar["value"] = data['progress']

            elif 'epoch' in data:
                # Update epoch-level information (after each epoch completion)
                self.epoch_label.config(text=f"Epoch: {data['epoch']}/{self.total_epochs}")
                self.val_loss_label.config(text=f"Validation Loss: {data.get('val_loss', 0.0):.4f}")
                self.val_accuracy_label.config(text=f"Validation Accuracy: {data.get('val_accuracy', 0.0):.4f}")

                # Update the progress bar based on completed epochs
                self.progress_bar["value"] = (data['epoch'] / self.total_epochs) * 100

                # Append whichever metrics this message carries and refresh their lines.
                series = (
                    ('epoch_loss', self.loss_data, self.line_loss),
                    ('val_loss', self.val_loss_data, self.line_val_loss),
                    ('epoch_acc', self.acc_data, self.line_acc),
                    ('val_accuracy', self.val_acc_data, self.line_val_acc),
                )
                for key, values, line in series:
                    if key in data:
                        values.append(data[key])
                        line.set_data(list(range(1, len(values) + 1)), values)

                # Rescale both axes to fit the new data
                for axis in (self.ax, self.ax_acc):
                    axis.relim()
                    axis.autoscale_view()
                self.canvas_plot.draw()

        except Exception as e:
            # Handle any unexpected errors gracefully
            logger.exception(f"Error updating GUI: {e}")
            messagebox.showerror("GUI Update Error", f"An error occurred while updating the GUI: {e}")

    # Training thread
    def train_thread(self):
        """Run ``train_deep_model`` in a worker thread and report the outcome via the queue."""
        logger.info("Training thread started.")
        try:
            train_deep_model(
                self.model, self.train_loader, self.val_loader,
                epochs=self.total_epochs, lr=0.01, device=self.device,
                gui=self, patience=5
            )
            logger.info("Training completed successfully.")
            self.queue.put('training_complete')

        except Exception as e:
            logger.exception("An error occurred during the training process.")
            # Buttons are reset by process_queue on the main thread; Tk
            # widgets must not be configured from this worker thread.
            self.queue.put({'error': str(e)})

    # Visualize predictions
    def visualize_predictions(self, model, loader, device='cpu', num_samples=5):
        """
        Visualizes predictions in the GUI.

        Opens one window per sample showing the input, predicted and actual
        grids side by side, for at most ``num_samples`` samples. Dead squares
        are drawn in black. Creates Tk windows, so call it from the main thread.
        """
        if num_samples <= 0:
            return

        model.to(device)
        model.eval()

        # Discrete colormap shared by all panels. pyplot.get_cmap replaces
        # matplotlib.cm.get_cmap, which newer Matplotlib releases removed.
        cmap = plt.get_cmap('viridis', NUM_CLASSES)
        cmap.set_bad(color='black')  # Set color for masked values (dead squares)
        dead_label = NUM_CLASSES - 1

        samples_visualized = 0

        with torch.no_grad():
            for inputs, targets in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs)  # Shape: (batch_size, grid_size * grid_size, num_classes)
                # Reshape outputs to (batch_size, GRID_SIZE, GRID_SIZE, NUM_CLASSES)
                outputs = outputs.view(-1, GRID_SIZE, GRID_SIZE, NUM_CLASSES)
                # Get predicted classes
                _, predicted = torch.max(outputs, dim=3)  # Shape: (batch_size, GRID_SIZE, GRID_SIZE)

                for i in range(inputs.size(0)):
                    input_grid = inputs[i].cpu().numpy().squeeze()  # Shape: (GRID_SIZE, GRID_SIZE)
                    predicted_grid = predicted[i].cpu().numpy()      # Shape: (GRID_SIZE, GRID_SIZE)
                    actual_grid = targets[i].cpu().numpy()           # Shape: (GRID_SIZE, GRID_SIZE)

                    # Mask dead squares: -1 in the input, the last class label in predictions/targets.
                    panels = (
                        ("Input Grid", np.ma.masked_where(input_grid == -1, input_grid)),
                        ("Predicted Grid", np.ma.masked_where(predicted_grid == dead_label, predicted_grid)),
                        ("Actual Grid", np.ma.masked_where(actual_grid == dead_label, actual_grid)),
                    )

                    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
                    for axis, (title, grid) in zip(axs, panels):
                        axis.imshow(grid, cmap=cmap, interpolation='nearest', vmin=0, vmax=NUM_CLASSES - 2)
                        axis.set_title(title)
                        axis.axis('off')

                    # Adjust layout to fill the space
                    fig.subplots_adjust(wspace=0.05, hspace=0)
                    fig.tight_layout()
                    # Display the plot in a new window
                    self.show_plot_in_new_window(fig)
                    plt.close(fig)  # Drop pyplot's reference to the figure

                    samples_visualized += 1
                    if samples_visualized >= num_samples:
                        return

    def show_plot_in_new_window(self, fig):
        """
        Displays the Matplotlib figure in a new Tkinter window.
        """
        new_window = tk.Toplevel(self.root)
        canvas = FigureCanvasTkAgg(fig, master=new_window)
        canvas.draw()
        # get_tk_widget() is the public accessor for the canvas widget; the
        # previous extra pack of the private ``_tkcanvas`` was redundant.
        canvas.get_tk_widget().pack()

# -----------------------------
# 11. Training Function with GUI Integration
# -----------------------------

def _forward_with_loss(model, inputs, targets, criterion, use_amp):
    """Run one forward pass and return ``(loss, flat_logits, flat_targets)``.

    The logits are viewed as ``(-1, NUM_CLASSES)`` and the targets as ``(-1,)``
    before the criterion is applied. When ``use_amp`` is true the whole
    computation runs inside ``autocast()``.
    """
    def _compute():
        outputs = model(inputs)
        targets_flat = targets.view(-1)
        outputs_flat = outputs.view(-1, NUM_CLASSES)
        return criterion(outputs_flat, targets_flat), outputs_flat, targets_flat

    if use_amp:
        with autocast():
            return _compute()
    return _compute()


def train_deep_model(model, train_loader, val_loader, epochs, lr, device, gui, patience=5):
    """
    Trains ``model`` with SGD, validating after every epoch and reporting to the GUI.

    Per batch, a dict with the running mean loss and running accuracy is put on
    ``gui.queue``; per epoch, a dict with the epoch and validation metrics is
    put on the same queue. Whenever the validation loss improves, the model's
    ``state_dict`` is written to ``best_deep_model.pth``.

    Args:
        model (nn.Module): Model whose output can be viewed as ``(-1, NUM_CLASSES)``.
        train_loader (DataLoader): Yields ``(inputs, targets)`` training batches.
        val_loader (DataLoader): Yields ``(inputs, targets)`` validation batches.
        epochs (int): Maximum number of epochs to run.
        lr (float): Initial learning rate for SGD.
        device (str | torch.device): Device the model and batches are moved to.
        gui: Object exposing a ``queue`` attribute with a ``put`` method.
        patience (int): Number of consecutive epochs without validation-loss
            improvement after which training stops early.

    Returns:
        None
    """
    logger.info("Starting the training process.")
    torch.autograd.set_detect_anomaly(False)  # Disable anomaly detection for better performance
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=1e-4)
    # Learning rate is multiplied by 0.1 every 10 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    model.to(device)

    # Mixed precision is keyed on the device actually used for training, not on
    # whether CUDA merely exists on the machine; otherwise CPU training on a
    # CUDA-capable host would go through the CUDA GradScaler path.
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler() if use_amp else None

    best_val_loss = float('inf')
    epochs_no_improve = 0

    total_batches = len(train_loader)

    for epoch in range(1, epochs + 1):
        logger.info(f"Starting epoch {epoch}/{epochs}.")
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader, 1):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad()

            loss, outputs_flat, targets_flat = _forward_with_loss(
                model, inputs, targets, criterion, use_amp
            )
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # Read the scalar once; it is reused for the running mean and the log line.
            loss_value = loss.item()
            running_loss += loss_value

            predicted = outputs_flat.detach().argmax(dim=1)
            correct += (predicted == targets_flat).sum().item()
            total += targets_flat.size(0)

            # Update GUI per batch
            gui.queue.put({
                'batch': batch_idx,
                'total_batches': total_batches,
                'loss': running_loss / batch_idx,
                'accuracy': correct / total if total else 0.0
            })

            if batch_idx % 10 == 0:
                logger.info(f"Epoch [{epoch}/{epochs}], Batch [{batch_idx}/{total_batches}], Loss: {loss_value:.4f}")

        # Guard against an empty training loader instead of dividing by zero.
        epoch_loss = running_loss / total_batches if total_batches else float('nan')
        epoch_acc = correct / total if total else 0.0
        logger.info(f"Epoch [{epoch}/{epochs}] completed. Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.4f}")

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_samples = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                loss, outputs_flat, targets_flat = _forward_with_loss(
                    model, inputs, targets, criterion, use_amp
                )

                # Weight each batch mean by its size so the final figure is a
                # per-sample mean even when the last batch is smaller.
                batch_samples = inputs.size(0)
                val_loss += loss.item() * batch_samples
                val_samples += batch_samples

                predicted = outputs_flat.argmax(dim=1)
                val_correct += (predicted == targets_flat).sum().item()
                val_total += targets_flat.size(0)

        if val_samples:
            # Normalise by the samples actually seen rather than the dataset length.
            avg_val_loss = val_loss / val_samples
        else:
            logger.warning("Validation loader produced no samples; validation loss is undefined.")
            # NaN never compares as an improvement, so early stopping still applies.
            avg_val_loss = float('nan')
        val_acc = val_correct / val_total if val_total else 0.0
        logger.info(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

        # Update GUI per epoch
        gui.queue.put({
            'epoch': epoch,
            'total_epochs': epochs,
            'epoch_loss': epoch_loss,
            'epoch_acc': epoch_acc,
            'val_loss': avg_val_loss,
            'val_accuracy': val_acc
        })

        # Check for improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # Save the best model
            torch.save(model.state_dict(), "best_deep_model.pth")
            logger.info(f"Epoch {epoch}/{epochs} - Validation loss decreased. Saving model.")
        else:
            epochs_no_improve += 1
            logger.info(f"Epoch {epoch}/{epochs} - No improvement in validation loss for {epochs_no_improve} epochs.")

        # Step the scheduler
        scheduler.step()

        # Early Stopping
        if epochs_no_improve >= patience:
            logger.info("Early stopping triggered.")
            break

    logger.info("Training completed.")

# -----------------------------
# 12. Evaluation and Prediction Functions
# -----------------------------

def evaluate_model(model, test_loader, device='cpu'):
    """
    Evaluates the model on the test dataset.

    Args:
        model (nn.Module): Trained model whose output can be viewed as
            ``(-1, NUM_CLASSES)``.
        test_loader (DataLoader): DataLoader for the test dataset, yielding
            ``(inputs, targets)`` batches.
        device (str): Device to run evaluation on.

    Returns:
        tuple: (average_loss, accuracy). The loss is the per-sample mean of the
        cross-entropy and the accuracy is the fraction of flattened target
        entries predicted correctly. If the loader yields no samples the result
        is ``(nan, 0.0)``.
    """
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    model.eval()

    total_loss = 0.0
    correct = 0
    total = 0
    samples_seen = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            outputs_flat = model(inputs).view(-1, NUM_CLASSES)
            targets_flat = targets.view(-1)
            loss = criterion(outputs_flat, targets_flat)

            # Weight the batch-mean loss by batch size so that a smaller final
            # batch does not skew the overall average.
            batch_samples = inputs.size(0)
            total_loss += loss.item() * batch_samples
            samples_seen += batch_samples

            predicted = outputs_flat.argmax(dim=1)
            correct += (predicted == targets_flat).sum().item()
            total += targets_flat.size(0)

    # Normalise by what was actually evaluated; an empty loader yields NaN loss
    # and zero accuracy instead of raising ZeroDivisionError.
    avg_loss = total_loss / samples_seen if samples_seen else float('nan')
    accuracy = correct / total if total else 0.0

    logger.info(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy

def _show_ensemble_sample(input_grid, predicted_grid, actual_grid, cmap):
    """Plot the input, ensembled prediction and actual grid side by side."""
    dead_label = NUM_CLASSES - 1

    # Map the dead square label back to -1 for display purposes
    predicted_display = np.where(predicted_grid == dead_label, -1, predicted_grid)
    actual_display = np.where(actual_grid == dead_label, -1, actual_grid)

    # Mask the dead squares so the colormap's "bad" colour is used for them
    panels = (
        ("Input Grid", np.ma.masked_where(input_grid == -1, input_grid)),
        ("Ensembled Prediction", np.ma.masked_where(predicted_display == -1, predicted_display)),
        ("Actual Grid", np.ma.masked_where(actual_display == -1, actual_display)),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, grid) in zip(axs, panels):
        ax.imshow(grid, cmap=cmap, interpolation='nearest', vmin=0, vmax=NUM_CLASSES - 2)
        ax.set_title(title)
        ax.axis('off')

    # Adjust layout to fill the space
    plt.subplots_adjust(wspace=0.05, hspace=0)
    plt.tight_layout()
    # Display the plot
    plt.show()
    plt.close(fig)  # Close the figure to free memory


def ensemble_predictions(base_model, model_paths, test_loader, device='cpu', num_samples=5):
    """
    Performs ensemble predictions using multiple models.

    One ``CNNGridMapper`` is built per checkpoint path, the models' raw outputs
    are averaged, and the arg-max class per cell is plotted next to the input
    and the actual grid until ``num_samples`` samples have been shown.

    Args:
        base_model: Unused; kept so existing callers keep working.
        model_paths (list): Paths to ``state_dict`` checkpoints to ensemble.
        test_loader (DataLoader): Yields ``(inputs, targets)`` batches.
        device (str): Device the models and batches are moved to.
        num_samples (int): Stop after this many samples have been visualized.

    Returns:
        None
    """
    models = []
    for path in model_paths:
        model = CNNGridMapper(num_classes=NUM_CLASSES, grid_size=GRID_SIZE)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        model.load_state_dict(torch.load(path, map_location=device))
        model.to(device)
        model.eval()
        models.append(model)
        logger.info(f"Loaded model from {path}")

    if not models:
        # Averaging over zero models is undefined; bail out instead of failing
        # later when the empty output list is combined.
        logger.warning("No model paths were provided; skipping ensemble predictions.")
        return

    # Built once for all samples. pyplot.get_cmap is used because
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('viridis', NUM_CLASSES)
    cmap.set_bad(color='black')  # Set color for masked values (dead squares)

    samples_visualized = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # Stack the per-model outputs along a new leading axis and average them
            outputs_avg = torch.stack([model(inputs) for model in models], dim=0).mean(dim=0)

            # Reshape to (batch, GRID_SIZE, GRID_SIZE, NUM_CLASSES) and take the
            # most likely class for each cell
            predicted = outputs_avg.view(-1, GRID_SIZE, GRID_SIZE, NUM_CLASSES).argmax(dim=3)

            for i in range(inputs.size(0)):
                _show_ensemble_sample(
                    inputs[i].cpu().numpy().squeeze(),
                    predicted[i].cpu().numpy(),
                    targets[i].cpu().numpy(),
                    cmap,
                )

                samples_visualized += 1
                if samples_visualized >= num_samples:
                    return

# -----------------------------
# 13. Main Workflow with Modifications
# -----------------------------

def main():
    """
    Runs the full workflow: load data, build loaders, smoke-test the model, start the GUI.

    The function loads the ARC data, augments the training pairs, splits them
    80/20 into training and validation sets, builds the three DataLoaders,
    verifies that one forward/backward pass works, and then hands control to
    the Tkinter main loop with training scheduled via ``gui.retrain_model``.
    It returns early (after logging) if no training pairs are available or if
    the smoke test fails.
    """
    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Load ARC data
    arc_data = load_arc_data()

    # Extract and reshape training and evaluation grid pairs
    train_grid_pairs = flatten_and_reshape(arc_data.get("arc-agi_training-challenges", {}), grid_size=GRID_SIZE)
    eval_grid_pairs = flatten_and_reshape(arc_data.get("arc-agi_evaluation-challenges", {}), grid_size=GRID_SIZE)

    logger.info(f"Number of training grid pairs: {len(train_grid_pairs)}")
    logger.info(f"Number of evaluation grid pairs: {len(eval_grid_pairs)}")

    if len(train_grid_pairs) == 0:
        # Without this guard the train/validation split below fails on an
        # empty list with a less helpful error.
        logger.error("No training grid pairs were loaded; aborting.")
        return

    # Generate multiple augmented datasets
    augmented_pairs = generate_multiple_augmented_datasets(train_grid_pairs, num_augmented_sets=3)

    # Combine all datasets
    combined_train_pairs = train_grid_pairs + augmented_pairs

    # Split into training and validation sets (e.g., 80-20 split)
    train_pairs, val_pairs = train_test_split(combined_train_pairs, test_size=0.2, random_state=42)

    # Create datasets
    train_dataset = AugmentedARCDataset(train_pairs, augment=False)  # Already augmented
    val_dataset = AugmentedARCDataset(val_pairs, augment=False)
    eval_dataset = AugmentedARCDataset(eval_grid_pairs, augment=False)

    batch_size = 64  # Adjust batch size as needed
    # Use multiprocessing except on Windows; os.cpu_count() may return None,
    # in which case loading falls back to the main process.
    num_workers = (os.cpu_count() or 0) if os.name != 'nt' else 0

    # Settings shared by all three loaders. Pinned memory only helps
    # host-to-GPU copies, so it is enabled only when training on CUDA.
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': device.type == 'cuda',
    }

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    eval_loader = DataLoader(eval_dataset, shuffle=False, **loader_kwargs)

    logger.info(f"Training DataLoader size: {len(train_loader)} batches")
    logger.info(f"Validation DataLoader size: {len(val_loader)} batches")
    logger.info(f"Number of training samples: {len(train_dataset)}")
    logger.info(f"Number of validation samples: {len(val_dataset)}")
    logger.info(f"Number of evaluation samples: {len(eval_dataset)}")

    # Initialize the model
    model = CNNGridMapper(num_classes=NUM_CLASSES, grid_size=GRID_SIZE).to(device)
    logger.info("Model initialized successfully.")

    # Test model forward and backward pass
    try:
        model.train()
        sample_inputs, sample_targets = next(iter(train_loader))
        sample_inputs = sample_inputs.to(device)
        sample_targets = sample_targets.to(device)

        outputs = model(sample_inputs)
        targets_flat = sample_targets.view(-1)
        outputs_flat = outputs.view(-1, NUM_CLASSES)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(outputs_flat, targets_flat)

        # Backward pass
        loss.backward()
        logger.info("Single batch forward and backward pass successful.")
    except Exception:
        # Top-level boundary: log the full traceback and stop before the GUI starts.
        logger.exception("Model forward or backward pass failed.")
        return

    # Discard the gradients produced by the smoke test so they cannot leak
    # into whatever uses the model next.
    model.zero_grad()

    # Initialize the GUI
    root = tk.Tk()
    total_epochs = 10  # Adjust as needed
    total_batches = len(train_loader)
    gui = TrainingGUI(root, total_epochs, total_batches, model, train_loader, val_loader, eval_loader, device)

    # Start the training without blocking the GUI
    root.after(100, gui.retrain_model)  # Schedule training to start without blocking

    # Start the GUI main loop
    root.mainloop()


# -----------------------------
# 14. Execute the Main Function
# -----------------------------

# Entry-point guard: run the workflow only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()





INFO:__main__:Using device: cpu
INFO:__main__:Loaded arc-agi_training-challenges from arc-agi_training_challenges.json.
INFO:__main__:Loaded arc-agi_evaluation-challenges from arc-agi_evaluation_challenges.json.
INFO:__main__:Loaded arc-agi_training-solutions from arc-agi_training_solutions.json.
INFO:__main__:Loaded arc-agi_evaluation-solutions from arc-agi_evaluation_solutions.json.
INFO:__main__:Parsing task 007bbfb7...
INFO:__main__:Parsing task 00d62c1b...
INFO:__main__:Parsing task 017c7c7b...
INFO:__main__:Parsing task 025d127b...
INFO:__main__:Parsing task 045e512c...
INFO:__main__:Parsing task 0520fde7...
INFO:__main__:Parsing task 05269061...
INFO:__main__:Parsing task 05f2a901...
INFO:__main__:Parsing task 06df4c85...
INFO:__main__:Parsing task 08ed6ac7...
INFO:__main__:Parsing task 09629e4f...
INFO:__main__:Parsing task 0962bcdd...
INFO:__main__:Parsing task 0a938d79...
INFO:__main__:Parsing task 0b148d64...
INFO:__main__:Parsing task 0ca9ddb6...
INFO:__main__:Parsing task 