In [None]:
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
import tempfile
import os
import cpuinfo
import zipfile
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons
from moviepy.editor import VideoFileClip
from matplotlib import animation
from sklearn.decomposition import PCA
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import shutil
from sklearn.utils import shuffle
from sklearn.decomposition import PCA
from matplotlib import animation
import matplotlib.pyplot as plt
import tempfile
from PIL import Image, ImageFile
import logging 

import os
def safe_output(path):
    """Return *path* unchanged unless it names an existing directory.

    Raises:
        ValueError: when ``path`` is an existing directory, since callers
            expect a file path.
    """
    if not os.path.isdir(path):
        return path
    raise ValueError(f"Expected a file, got a directory: {path}")


# Drop DEBUG-level records emitted through Pillow's TIFF plugin logger.
logging.getLogger("PIL.TiffImagePlugin").setLevel(logging.INFO)

# Absolute path of the folder receiving plots, animations and saved weights;
# it is created eagerly when this module is executed.
OUTPUT_DIR = os.path.abspath(os.path.join(os.curdir, "outputs"))
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the current CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Table mapping a layer-type name to a factory callable that builds the module.
# Most factories accept extra positional arguments via `*_` and ignore them.
# build_model() constructs most layer types itself with the user's settings and
# only falls back to this table (calling the factory with NO arguments) for
# types it does not special-case, e.g. ReLU, Tanh, Sigmoid and Flatten.
# NOTE(review): the key order may also feed a layer-type picker in the UI part
# of this file (not visible here) — confirm before reordering entries.
layer_map = {
    "Linear": lambda in_dim, out_dim: nn.Linear(int(in_dim), int(out_dim)),
    # Fixed kernel_size=3 / padding=1; configured Conv2d layers are built in build_model().
    "Conv2d": lambda in_dim, out_dim: nn.Conv2d(int(in_dim), int(out_dim), kernel_size=3, padding=1),  
    # Pooling factories ignore every argument and use a fixed kernel_size=2.
    "MaxPool2d": lambda *_: nn.MaxPool2d(kernel_size=2),  
    "AvgPool2d": lambda *_: nn.AvgPool2d(kernel_size=2),  
    # Parameterised factories read their hyper-parameter from the FIRST
    # positional argument (drop probability / negative slope / alpha).
    "Dropout": lambda p=0.5, *_: nn.Dropout(float(p)),
    "ReLU": lambda *_: nn.ReLU(),
    "Tanh": lambda *_: nn.Tanh(),
    "Sigmoid": lambda *_: nn.Sigmoid(),
    "Flatten": lambda *_: nn.Flatten(),
    "Softmax": lambda *_: nn.Softmax(dim=1),
    "LeakyReLU": lambda slope=0.01, *_: nn.LeakyReLU(negative_slope=float(slope)),
    "GELU": lambda *_: nn.GELU(),
    "ELU": lambda alpha=1.0, *_: nn.ELU(alpha=float(alpha))
}

def parse_int_or_tuple(val):
    """Parse a layer hyper-parameter that may be a single integer or a pair.

    Accepted forms: an ``int``; a numeric string such as ``"3"``; a
    comma-separated string such as ``"3,5"`` or ``"3, 5"``; or an already
    parsed ``tuple``/``list`` of integers.

    Returns:
        An ``int`` for a single value, otherwise a ``tuple`` of ``int``.

    Raises:
        ValueError: if any component cannot be converted to ``int``; the
            underlying conversion error is chained as ``__cause__``.
    """
    try:
        if isinstance(val, (tuple, list)):
            # Already split: str((3, 5)) == "(3, 5)" would not parse below.
            return tuple(int(v) for v in val)
        text = str(val)
        if ',' in text:
            return tuple(int(part) for part in text.split(','))
        return int(val)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(
            f"Invalid numeric input: '{val}'. Please enter an integer or comma-separated pair."
        ) from err

# Module-level list of the layers configured so far, in model order. It is
# shared, mutable state: every caller in this process edits the same list.
# Each entry is an 8-tuple unpacked by build_model() as
#   (description, layer_type, in_dim, out_dim, kernel, padding, stride, bias)
# Dropout / LeakyReLU / ELU store their scalar hyper-parameter in both the
# in_dim and out_dim slots; unused slots are None.
layer_configs = []

def _param_or_default(kwargs, key, default):
    """Return ``kwargs[key]`` unless it is absent, ``None`` or a blank string.

    Callers such as add_layer pass optional hyper-parameters explicitly as
    ``None`` (e.g. ``avgpool_kernel=None``), so ``kwargs.get(key, default)``
    alone would hand ``None`` to the parsers instead of the default.
    """
    value = kwargs.get(key)
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    return value


def _check_int_bounds(layer_type, positive=(), non_negative=()):
    """Return an error message for the first out-of-range value, else ``None``.

    Both arguments are sequences of ``(name, value)`` pairs where ``value`` is
    an ``int`` or a tuple of ints (as produced by parse_int_or_tuple). Values
    in *positive* must be >= 1; values in *non_negative* must be >= 0.
    """
    for name, value in positive:
        if any(v < 1 for v in (value if isinstance(value, tuple) else (value,))):
            return f"{layer_type} {name} must be positive"
    for name, value in non_negative:
        if any(v < 0 for v in (value if isinstance(value, tuple) else (value,))):
            return f"{layer_type} {name} must be non-negative"
    return None


def validate_layer_inputs(layer_type, **kwargs):
    """Check the hyper-parameters of one layer before it is stored.

    Keyword arguments use the names passed by add_layer / update_layer /
    insert_layer (``in_dim``, ``out_dim``, ``kernel_size``, ``padding``,
    ``stride``, ``pool_*``, ``avgpool_*``, ``leaky_slope``, ``elu_alpha``);
    only those relevant to ``layer_type`` are inspected. Dropout's probability
    is read from ``in_dim``. Missing, ``None`` or blank optional values fall
    back to their defaults.

    Returns:
        ``(True, None)`` if the inputs are acceptable, otherwise
        ``(False, message)``. Parse errors are reported the same way rather
        than raised.
    """
    try:
        if layer_type == "Linear":
            in_dim_int = int(kwargs.get("in_dim"))
            out_dim_int = int(kwargs.get("out_dim"))
            if in_dim_int <= 0 or out_dim_int <= 0:
                return False, f"{layer_type} dimensions must be positive integers"

        elif layer_type == "Conv2d":
            in_dim_int = int(kwargs.get("in_dim"))
            out_dim_int = int(kwargs.get("out_dim"))
            kernel = parse_int_or_tuple(_param_or_default(kwargs, "kernel_size", 3))
            padding = parse_int_or_tuple(_param_or_default(kwargs, "padding", 1))
            stride = parse_int_or_tuple(_param_or_default(kwargs, "stride", 1))
            # A zero kernel or stride yields a layer that cannot run; padding may be 0.
            err = _check_int_bounds(
                layer_type,
                positive=[("kernel_size", kernel), ("stride", stride)],
                non_negative=[("padding", padding)],
            )
            if err:
                return False, err
            if in_dim_int <= 0 or out_dim_int <= 0:
                return False, f"{layer_type} in/out dims must be positive integers"

        elif layer_type == "Dropout":
            p = float(kwargs.get("in_dim"))
            if not (0 <= p <= 1):
                return False, "Dropout probability must be between 0 and 1"

        elif layer_type in ("MaxPool2d", "AvgPool2d"):
            # MaxPool2d reads pool_*, AvgPool2d reads avgpool_* keyword arguments.
            prefix = "pool" if layer_type == "MaxPool2d" else "avgpool"
            kernel = parse_int_or_tuple(_param_or_default(kwargs, f"{prefix}_kernel", 2))
            stride = parse_int_or_tuple(_param_or_default(kwargs, f"{prefix}_stride", 2))
            padding = parse_int_or_tuple(_param_or_default(kwargs, f"{prefix}_padding", 0))
            err = _check_int_bounds(
                layer_type,
                positive=[("kernel", kernel), ("stride", stride)],
                non_negative=[("padding", padding)],
            )
            if err:
                return False, err

        elif layer_type == "LeakyReLU":
            slope = float(_param_or_default(kwargs, "leaky_slope", 0.01))
            if slope < 0:
                return False, "LeakyReLU negative_slope must be ≥ 0"

        elif layer_type == "ELU":
            alpha = float(_param_or_default(kwargs, "elu_alpha", 1.0))
            if alpha < 0:
                return False, "ELU alpha must be ≥ 0"

        return True, None
    except Exception as e:
        return False, f"Validation error in {layer_type}: {str(e)}"

def add_layer(
    layer_type, in_dim, out_dim,
    kernel_size=3, padding=1, stride=1, bias=1,
    pool_kernel="2", pool_stride="2", pool_padding="0",
    avgpool_kernel=None, avgpool_stride=None, avgpool_padding=None,
    leaky_slope="0.01", elu_alpha="1.0"
):
    """Validate a layer specification and append it to ``layer_configs``.

    Only the arguments relevant to ``layer_type`` are used:
    Linear/Conv2d read ``in_dim``/``out_dim`` (Conv2d also ``kernel_size``,
    ``padding``, ``stride``, ``bias``); Dropout reads its probability from
    ``in_dim``; MaxPool2d reads ``pool_*``; AvgPool2d reads ``avgpool_*``;
    LeakyReLU and ELU read ``leaky_slope`` / ``elu_alpha``. Missing or blank
    values fall back to the defaults (Conv2d kernel 3 / padding 1 / stride 1,
    pooling kernel 2 / stride 2 / padding 0, slope 0.01, alpha 1.0).

    Returns:
        The refreshed architecture listing; or an error message when the
        input is rejected, in which case nothing is appended.
    """
    def _or_default(value, default):
        # Only a missing/blank field means "use the default". A numeric 0
        # (e.g. padding=0) is falsy but meaningful and must be kept.
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    is_valid, err_msg = validate_layer_inputs(
        layer_type=layer_type,
        in_dim=in_dim,
        out_dim=out_dim,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        pool_kernel=pool_kernel,
        pool_stride=pool_stride,
        pool_padding=pool_padding,
        avgpool_kernel=avgpool_kernel,
        avgpool_stride=avgpool_stride,
        avgpool_padding=avgpool_padding,
        leaky_slope=leaky_slope,
        elu_alpha=elu_alpha
    )
    if not is_valid:
        return err_msg

    try:
        if layer_type == "Conv2d":
            i = int(in_dim)
            o = int(out_dim)
            k = parse_int_or_tuple(_or_default(kernel_size, 3))
            p = parse_int_or_tuple(_or_default(padding, 1))
            s = parse_int_or_tuple(_or_default(stride, 1))
            b = bool(bias)
            desc = f"Conv2d({i}, {o}, kernel={k}, padding={p}, stride={s}, bias={b})"
            config = (desc, layer_type, i, o, k, p, s, b)

        elif layer_type == "LeakyReLU":
            negative_slope = float(_or_default(leaky_slope, "0.01"))
            desc = f"LeakyReLU(negative_slope={negative_slope})"
            config = (desc, layer_type, negative_slope, negative_slope, None, None, None, None)

        elif layer_type == "ELU":
            alpha = float(_or_default(elu_alpha, "1.0"))
            desc = f"ELU(alpha={alpha})"
            config = (desc, layer_type, alpha, alpha, None, None, None, None)

        elif layer_type == "Softmax":
            config = ("Softmax(dim=1)", layer_type, None, None, None, None, None, None)

        elif layer_type == "GELU":
            config = ("GELU()", layer_type, None, None, None, None, None, None)

        elif layer_type == "Linear":
            i = int(in_dim)
            o = int(out_dim)
            config = (f"Linear({i}, {o})", layer_type, i, o, None, None, None, None)

        elif layer_type == "Dropout":
            # The drop probability travels in the in_dim field.
            p = float(in_dim)
            config = (f"Dropout({p})", layer_type, p, p, None, None, None, None)

        elif layer_type == "MaxPool2d":
            kernel_val = parse_int_or_tuple(_or_default(pool_kernel, "2"))
            stride_val = parse_int_or_tuple(_or_default(pool_stride, "2"))
            padding_val = parse_int_or_tuple(_or_default(pool_padding, "0"))
            desc = f"MaxPool2d(kernel={kernel_val}, stride={stride_val}, padding={padding_val})"
            config = (desc, layer_type, None, None, kernel_val, padding_val, stride_val, None)

        elif layer_type == "AvgPool2d":
            kernel_val = parse_int_or_tuple(_or_default(avgpool_kernel, "2"))
            stride_val = parse_int_or_tuple(_or_default(avgpool_stride, "2"))
            padding_val = parse_int_or_tuple(_or_default(avgpool_padding, "0"))
            desc = f"AvgPool2d(kernel={kernel_val}, stride={stride_val}, padding={padding_val})"
            config = (desc, layer_type, None, None, kernel_val, padding_val, stride_val, None)

        else:
            # Parameter-free layers, e.g. ReLU/Tanh/Sigmoid/Flatten.
            config = (layer_type, layer_type, None, None, None, None, None, None)

    except Exception as e:
        # Refuse the layer instead of storing a half-built entry that
        # build_model() could fail to instantiate later.
        return f"[Error Adding Layer: {e}]"

    layer_configs.append(config)
    return update_architecture_text()

def update_layer(
    index, layer_type, in_dim, out_dim,
    kernel_size=3, padding=1, stride=1, bias=True,
    pool_kernel="2", pool_stride="2", pool_padding="0",
    avgpool_kernel=None, avgpool_stride=None, avgpool_padding=None,
    leaky_slope="0.01", elu_alpha="1.0"
):
    """Replace the layer at *index* in ``layer_configs`` with a new specification.

    Arguments after ``index`` are interpreted as in add_layer; missing or
    blank values fall back to the defaults. An out-of-range ``index`` leaves
    the list unchanged.

    Returns:
        The refreshed architecture listing; or an error message when the
        input is rejected, in which case the existing layer is kept as-is.
    """
    def _or_default(value, default):
        # Only a missing/blank field means "use the default". A numeric 0
        # (e.g. padding=0) is falsy but meaningful and must be kept.
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    index = int(index)
    is_valid, err_msg = validate_layer_inputs(
        layer_type=layer_type,
        in_dim=in_dim,
        out_dim=out_dim,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        pool_kernel=pool_kernel,
        pool_stride=pool_stride,
        pool_padding=pool_padding,
        avgpool_kernel=avgpool_kernel,
        avgpool_stride=avgpool_stride,
        avgpool_padding=avgpool_padding,
        leaky_slope=leaky_slope,
        elu_alpha=elu_alpha
    )
    if not is_valid:
        return err_msg
    if index < 0 or index >= len(layer_configs):
        return update_architecture_text()

    try:
        if layer_type == "Conv2d":
            i = int(in_dim)
            o = int(out_dim)
            k = parse_int_or_tuple(_or_default(kernel_size, 3))
            p = parse_int_or_tuple(_or_default(padding, 1))
            s = parse_int_or_tuple(_or_default(stride, 1))
            b = bool(bias)
            desc = f"Conv2d({i}, {o}, kernel={k}, padding={p}, stride={s}, bias={b})"
            config = (desc, layer_type, i, o, k, p, s, b)

        elif layer_type == "Linear":
            i = int(in_dim)
            o = int(out_dim)
            config = (f"Linear({i}, {o})", layer_type, i, o, None, None, None, None)

        elif layer_type == "ELU":
            alpha = float(_or_default(elu_alpha, "1.0"))
            config = (f"ELU(alpha={alpha})", layer_type, alpha, alpha, None, None, None, None)

        elif layer_type == "GELU":
            config = ("GELU()", layer_type, None, None, None, None, None, None)

        elif layer_type == "LeakyReLU":
            negative_slope = float(_or_default(leaky_slope, "0.01"))
            desc = f"LeakyReLU(negative_slope={negative_slope})"
            config = (desc, layer_type, negative_slope, negative_slope, None, None, None, None)

        elif layer_type == "Softmax":
            config = ("Softmax(dim=1)", layer_type, None, None, None, None, None, None)

        elif layer_type == "Dropout":
            # The drop probability travels in the in_dim field.
            p = float(in_dim)
            config = (f"Dropout({p})", layer_type, p, p, None, None, None, None)

        elif layer_type == "AvgPool2d":
            kv = parse_int_or_tuple(_or_default(avgpool_kernel, "2"))
            sv = parse_int_or_tuple(_or_default(avgpool_stride, "2"))
            pv = parse_int_or_tuple(_or_default(avgpool_padding, "0"))
            desc = f"AvgPool2d(kernel={kv}, stride={sv}, padding={pv})"
            config = (desc, layer_type, None, None, kv, pv, sv, None)

        elif layer_type == "MaxPool2d":
            kv = parse_int_or_tuple(_or_default(pool_kernel, "2"))
            sv = parse_int_or_tuple(_or_default(pool_stride, "2"))
            pv = parse_int_or_tuple(_or_default(pool_padding, "0"))
            desc = f"MaxPool2d(kernel={kv}, stride={sv}, padding={pv})"
            config = (desc, layer_type, None, None, kv, pv, sv, None)

        else:
            # Parameter-free layers, e.g. ReLU/Tanh/Sigmoid/Flatten.
            config = (layer_type, layer_type, None, None, None, None, None, None)

    except Exception as e:
        # Keep the user's existing layer rather than overwriting it with a
        # placeholder that build_model() could fail to instantiate.
        return f"[Error Editing Layer: {e}]"

    layer_configs[index] = config
    return update_architecture_text()

def insert_layer(
    index, layer_type, in_dim, out_dim,
    kernel_size=3, padding=1, stride=1, bias=1,
    pool_kernel="2", pool_stride="2", pool_padding="0",
    avgpool_kernel=None, avgpool_stride=None, avgpool_padding=None,
    leaky_slope="0.01", elu_alpha="1.0"
):
    """Validate a layer specification and insert it before position *index*.

    Arguments after ``index`` are interpreted as in add_layer; missing or
    blank values fall back to the defaults. ``index`` follows ``list.insert``
    semantics (out-of-range values clamp to the ends, negatives count from
    the end).

    Returns:
        The refreshed architecture listing; or an error message when the
        input is rejected, in which case nothing is inserted.
    """
    def _or_default(value, default):
        # Only a missing/blank field means "use the default". A numeric 0
        # (e.g. padding=0) is falsy but meaningful and must be kept.
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    index = int(index)
    is_valid, err_msg = validate_layer_inputs(
        layer_type=layer_type,
        in_dim=in_dim,
        out_dim=out_dim,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        pool_kernel=pool_kernel,
        pool_stride=pool_stride,
        pool_padding=pool_padding,
        avgpool_kernel=avgpool_kernel,
        avgpool_stride=avgpool_stride,
        avgpool_padding=avgpool_padding,
        leaky_slope=leaky_slope,
        elu_alpha=elu_alpha
    )
    if not is_valid:
        return err_msg

    try:
        if layer_type == "Conv2d":
            i = int(in_dim)
            o = int(out_dim)
            k = parse_int_or_tuple(_or_default(kernel_size, "3"))
            p = parse_int_or_tuple(_or_default(padding, "1"))
            s = parse_int_or_tuple(_or_default(stride, "1"))
            b = bool(bias)
            desc = f"Conv2d({i}, {o}, kernel={k}, padding={p}, stride={s}, bias={b})"
            config = (desc, layer_type, i, o, k, p, s, b)

        elif layer_type == "Linear":
            i = int(in_dim)
            o = int(out_dim)
            config = (f"Linear({i}, {o})", layer_type, i, o, None, None, None, None)

        elif layer_type == "ELU":
            alpha = float(_or_default(elu_alpha, "1.0"))
            config = (f"ELU(alpha={alpha})", layer_type, alpha, alpha, None, None, None, None)

        elif layer_type == "LeakyReLU":
            negative_slope = float(_or_default(leaky_slope, "0.01"))
            desc = f"LeakyReLU(negative_slope={negative_slope})"
            config = (desc, layer_type, negative_slope, negative_slope, None, None, None, None)

        elif layer_type == "Softmax":
            config = ("Softmax(dim=1)", layer_type, None, None, None, None, None, None)

        elif layer_type == "Dropout":
            # The drop probability travels in the in_dim field.
            p = float(in_dim)
            config = (f"Dropout({p})", layer_type, p, p, None, None, None, None)

        elif layer_type == "MaxPool2d":
            kv = parse_int_or_tuple(_or_default(pool_kernel, "2"))
            sv = parse_int_or_tuple(_or_default(pool_stride, "2"))
            pv = parse_int_or_tuple(_or_default(pool_padding, "0"))
            desc = f"MaxPool2d(kernel={kv}, stride={sv}, padding={pv})"
            config = (desc, layer_type, None, None, kv, pv, sv, None)

        elif layer_type == "AvgPool2d":
            kv = parse_int_or_tuple(_or_default(avgpool_kernel, "2"))
            sv = parse_int_or_tuple(_or_default(avgpool_stride, "2"))
            pv = parse_int_or_tuple(_or_default(avgpool_padding, "0"))
            desc = f"AvgPool2d(kernel={kv}, stride={sv}, padding={pv})"
            config = (desc, layer_type, None, None, kv, pv, sv, None)

        elif layer_type == "GELU":
            config = ("GELU()", layer_type, None, None, None, None, None, None)

        else:
            # Parameter-free layers, e.g. ReLU/Tanh/Sigmoid/Flatten.
            config = (layer_type, layer_type, None, None, None, None, None, None)

    except Exception as e:
        # Refuse the layer instead of inserting a half-built entry that
        # build_model() could fail to instantiate later.
        return f"[Error Inserting Layer: {e}]"

    layer_configs.insert(index, config)
    return update_architecture_text()

def delete_layer(index):
    """Remove the layer at *index* if it exists and return the refreshed listing.

    ``index`` may be anything ``int()`` accepts; out-of-range (including
    negative) indices are ignored.
    """
    position = int(index)
    in_range = 0 <= position < len(layer_configs)
    if in_range:
        del layer_configs[position]
    return update_architecture_text()

def reset_layers():
    """Empty ``layer_configs`` in place and return an empty architecture listing."""
    del layer_configs[:]
    return ""

def update_architecture_text(highlight_index=None):
    """Render ``layer_configs`` as numbered lines ``"<index>: <description>"``.

    Args:
        highlight_index: optional index of a layer to prefix with a warning
            marker (train_model uses it to point at the layer that failed the
            dummy forward pass).

    Returns:
        The lines joined with newlines; an empty string when no layers exist.
    """
    lines = []
    for i, config in enumerate(layer_configs):
        desc = config[0]
        if i == highlight_index:
            desc = f"⚠️ {desc}"
        lines.append(f"{i}: {desc}")
    return "\n".join(lines)

def build_model():
    """Instantiate an ``nn.Sequential`` from the entries of ``layer_configs``.

    Each entry is unpacked as
    ``(description, layer_type, in_dim, out_dim, kernel, padding, stride, bias)``.
    LeakyReLU, ELU and Dropout take their scalar from the ``in_dim`` slot.
    Types without a dedicated branch are created through ``layer_map`` with no
    arguments, so an unknown type raises ``KeyError``.
    """
    modules = []
    for _desc, kind, first, second, kernel, padding, stride, bias in layer_configs:
        if kind == "Conv2d":
            module = nn.Conv2d(
                int(first), int(second),
                kernel_size=kernel,
                padding=padding,
                stride=stride,
                bias=bool(bias),
            )
        elif kind == "Linear":
            module = nn.Linear(int(first), int(second))
        elif kind in ("MaxPool2d", "AvgPool2d"):
            pool_cls = nn.MaxPool2d if kind == "MaxPool2d" else nn.AvgPool2d
            module = pool_cls(kernel_size=kernel, stride=stride, padding=padding)
        elif kind == "LeakyReLU":
            module = nn.LeakyReLU(negative_slope=float(first))
        elif kind == "ELU":
            module = nn.ELU(alpha=float(first))
        elif kind == "Dropout":
            module = nn.Dropout(float(first))
        elif kind == "GELU":
            module = nn.GELU()
        elif kind == "Softmax":
            module = nn.Softmax(dim=1)
        else:
            # Parameter-free layers such as ReLU, Tanh, Sigmoid, Flatten.
            module = layer_map[kind]()
        modules.append(module)
    return nn.Sequential(*modules)

def extract_zip_to_tempdir(file_like, custom_path=None):
    """Extract a zip archive into a fresh temporary directory and return its path.

    Args:
        file_like: path or binary file object accepted by ``zipfile.ZipFile``.
        custom_path: optional parent folder (created if missing) for the
            temporary directory; when falsy the system temp dir is used.

    Returns:
        Path of the new directory holding the extracted files. The caller
        owns it and is responsible for deleting it.

    Raises:
        zipfile.BadZipFile: if *file_like* is not a valid zip archive. The
            temporary directory is removed before any error propagates.
    """
    if custom_path:
        os.makedirs(custom_path, exist_ok=True)
    temp_dir = tempfile.mkdtemp(dir=custom_path or None)
    try:
        with zipfile.ZipFile(file_like, 'r') as zip_ref:
            # ZipFile.extractall strips absolute paths and ".." components,
            # so members cannot be written outside temp_dir.
            zip_ref.extractall(temp_dir)
    except BaseException:
        # Don't leak an empty or partially filled directory on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir

# Let Pillow load image files whose data ends early instead of raising.
# NOTE(review): this is a module-level Pillow flag, so it affects every image
# loaded in this process, not just the training datasets.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def safe_pil_loader(path, num_channels=3):
    """Load the image at *path* as RGB (``num_channels == 3``) or grayscale ``L``.

    Returns:
        A fully loaded ``PIL.Image.Image``, or ``None`` if the file cannot be
        opened or decoded (a message is printed and the file is skipped).
    """
    mode = 'RGB' if num_channels == 3 else 'L'
    try:
        with open(path, 'rb') as f, Image.open(f) as raw:
            # convert() returns a new in-memory image, so it stays usable
            # after both the file handle and `raw` are closed.
            img = raw.convert(mode)
            img.load()
            return img
    except Exception as e:
        print(f"Skipping corrupted image: {path} — {e}")
        return None

class SafeImageFolder(datasets.ImageFolder):
    """``torchvision`` ``ImageFolder`` that skips images Pillow cannot decode.

    Every sample is test-loaded once during construction with
    ``safe_pil_loader``. Files it rejects are removed from ``samples`` and
    from the ``imgs`` / ``targets`` lists ImageFolder exposes, so all three
    stay index-aligned.

    Args:
        root: directory with one sub-folder per class.
        transform: optional callable applied to each loaded PIL image.
        num_channels: 3 loads RGB images; any other value loads grayscale.
    """

    def __init__(self, root, transform=None, num_channels=3):
        def loader_with_channels(path):
            # ImageFolder calls loader(path); bind num_channels here.
            return safe_pil_loader(path, num_channels)

        super().__init__(root, transform=transform, loader=loader_with_channels)

        # Decode each file once up front and keep only the readable ones.
        self.samples = [s for s in self.samples if loader_with_channels(s[0]) is not None]
        self.imgs = self.samples
        # ImageFolder keeps labels in a separate `targets` list; filter it too
        # so it keeps matching `samples` position for position.
        self.targets = [target for _, target in self.samples]

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the sample at *index*.

        Raises:
            ValueError: if the file can no longer be decoded.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if img is None:
            raise ValueError(f"Could not load image at {path}")
        if self.transform is not None:
            img = self.transform(img)
        return img, target

def load_data(file, custom_path=None, batch_size=32, image_size=28,  num_channels=3):
    """Load a training set from a ``.csv`` table or a ``.zip`` of image folders.

    Args:
        file: path to the data file; the extension (case-insensitive) selects
            the loader.
        custom_path: optional parent folder for the zip extraction directory.
        batch_size: batch size of the image ``DataLoader``.
        image_size: images are resized to ``image_size x image_size``.
        num_channels: 3 for RGB, otherwise grayscale (image data only).

    Returns:
        A dict with keys:
          * ``"type"``: ``"tabular"`` or ``"image"``;
          * ``"train"``: ``(X, y)`` arrays (shuffled) for CSV, or a shuffling
            ``DataLoader`` for images;
          * ``"path"``: extraction directory to delete afterwards (``None``
            for CSV).

    Raises:
        ValueError: unsupported extension, a CSV without a ``y`` column, or
            an archive with no readable images.
    """
    ext = os.path.splitext(file)[1].lower()
    if ext == ".csv":
        df = pd.read_csv(file)
        if 'y' not in df.columns:
            raise ValueError("CSV must contain a 'y' column.")
        X = df.drop(columns=['y']).values
        y = df['y'].values
        X, y = shuffle(X, y)
        return {
            "type": "tabular",
            "train": (X, y),
            "path": None
        }

    if ext == ".zip":
        with open(file, 'rb') as f:
            data_dir = extract_zip_to_tempdir(f, custom_path)
        try:
            # Archives zipped by macOS Finder contain a "__MACOSX" metadata
            # folder; ImageFolder would treat it as an extra class, which can
            # shift the label indices of the real classes.
            shutil.rmtree(os.path.join(data_dir, "__MACOSX"), ignore_errors=True)
            transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor()
            ])
            dataset = SafeImageFolder(data_dir, transform=transform, num_channels=num_channels)
            if len(dataset) == 0:
                raise ValueError("The archive contains no readable images.")
        except BaseException:
            # The caller only learns data_dir from a successful return.
            shutil.rmtree(data_dir, ignore_errors=True)
            raise
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        return {
            "type": "image",
            "train": train_loader,
            "path": data_dir
        }

    raise ValueError("Unsupported file format. Provide a .csv or .zip path.")

def get_flat_weights(model):
    """Concatenate every parameter of *model* into one detached 1-D tensor.

    Parameters appear in ``model.parameters()`` order, and the result carries
    no autograd history.
    """
    pieces = [param.detach().reshape(-1) for param in model.parameters()]
    return torch.cat(pieces)

def generate_loss_plot(loss_history):
    """Plot *loss_history* (x axis labelled "Epoch") into a new temporary PNG.

    Returns:
        Path of the PNG. It is created with ``delete=False``, so the caller
        is responsible for removing it.
    """
    # Reserve a unique name, then close the handle before matplotlib writes
    # to the path (re-opening a still-open temp file is not portable).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        png_path = tmp.name
    fig, ax = plt.subplots()
    try:
        ax.plot(loss_history)
        ax.set_title("Loss over Epochs")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        # Save through the figure object: pyplot's "current figure" is
        # process-global and may belong to another concurrently running handler.
        fig.savefig(png_path)
    finally:
        plt.close(fig)
    return png_path

def generate_3d_animation_pca(weight_path, loss_history, output_path):
    """Animate the training trajectory over a 2-D PCA projection of weight space.

    The x/y axes are the first two principal components of the weight
    snapshots and the z axis is the loss. Frames are written to a temporary
    GIF next to ``output_path``, re-encoded to MP4 (libx264), then the GIF is
    deleted.

    Args:
        weight_path: array-like of flattened weight snapshots, one row per
            snapshot (e.g. ``np.array`` of ``get_flat_weights`` results).
        loss_history: sequence of loss values; one animation frame per value.
        output_path: where the final ``.mp4`` is written.

    If the number of snapshots differs from ``len(loss_history)`` (for example
    one snapshot per batch but one loss per epoch), snapshots are taken at
    evenly spaced indices so every loss value is paired with one weight point.
    With fewer than two snapshots PCA is undefined and the trajectory is drawn
    at the origin.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``weight_path`` or ``loss_history`` is empty.
    """
    losses = np.asarray(loss_history, dtype=float).ravel()
    weights = np.atleast_1d(np.asarray(weight_path, dtype=float))
    if losses.size == 0 or weights.size == 0:
        raise ValueError("Need at least one weight snapshot and one loss value to animate.")
    weights = weights.reshape(weights.shape[0], -1)

    # 1) Pair exactly one weight snapshot with each loss value.
    n_frames = losses.size
    if weights.shape[0] != n_frames:
        idx = np.linspace(0, weights.shape[0] - 1, num=n_frames).round().astype(int)
        weights = weights[idx]

    # 2) Project onto (up to) two principal components.
    n_samples, n_features = weights.shape
    reduced = np.zeros((n_samples, 2))
    n_components = min(2, n_samples, n_features)
    if n_samples >= 2:
        reduced[:, :n_components] = PCA(n_components=n_components).fit_transform(weights)

    base, _ = os.path.splitext(output_path)
    gif_path = f"{base}_temp.gif"
    try:
        # 3) Draw the 3-D path plus a moving marker and save it as a GIF.
        fig = plt.figure(figsize=(10, 7))
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.plot(reduced[:, 0], reduced[:, 1], losses, color='red')
            point, = ax.plot([reduced[0, 0]], [reduced[0, 1]], [losses[0]], 'ro')

            def update(i):
                point.set_data([reduced[i, 0]], [reduced[i, 1]])
                point.set_3d_properties([losses[i]])
                return point,

            ani = animation.FuncAnimation(
                fig, update,
                frames=n_frames,
                interval=100,
                blit=True
            )
            ani.save(gif_path, writer='pillow')
        finally:
            plt.close(fig)

        # 4) Re-encode the GIF as an H.264 MP4.
        clip = VideoFileClip(gif_path)
        try:
            clip.write_videofile(output_path, codec='libx264')
        finally:
            clip.close()
    finally:
        # Never leave the intermediate GIF behind, even when a step fails.
        if os.path.exists(gif_path):
            os.remove(gif_path)

    return output_path


def validate_model_forward_pass(model, data_type, image_size=28, num_features=None, num_channels=3):
    """Push one random dummy sample through *model* layer by layer to catch shape errors.

    Args:
        model: iterable of layers (e.g. the ``nn.Sequential`` from
            build_model), expected to already live on ``device``.
        data_type: ``"tabular"`` builds a ``(1, num_features)`` input; any
            other value builds a ``(1, num_channels, image_size, image_size)``
            image batch.
        image_size, num_channels: dummy image geometry (cast with ``int``, so
            numeric strings are accepted).
        num_features: number of input columns; required for tabular data.

    Returns:
        ``(True, None, None)`` on success, otherwise ``(False, message, idx)``
        where ``idx`` is the index of the failing layer, or ``None`` when the
        failure is not attributable to a single layer.
    """
    try:
        if data_type == "tabular":
            if num_features is None:
                # Explicit check rather than `assert`, which `python -O` strips.
                return False, "Unexpected validation error: Missing number of input features for tabular data", None
            shape = (1, int(num_features))
        else:
            shape = (1, int(num_channels), int(image_size), int(image_size))
        x = torch.randn(*shape, device=device)

        # Only output shapes matter here, so skip autograd bookkeeping.
        with torch.no_grad():
            for idx, layer in enumerate(model):
                try:
                    x = layer(x)
                except Exception as e:
                    return False, f"Shape mismatch at layer {idx}: {layer.__class__.__name__} — {str(e)}", idx
        return True, None, None
    except Exception as e:
        return False, f"Unexpected validation error: {str(e)}", None

def _train_one_epoch(model, loader, loss_fn, optimizer):
    """Run one optimisation pass over *loader*; return ``(summed_loss, num_batches)``."""
    use_mse = isinstance(loss_fn, nn.MSELoss)
    total_loss = 0.0
    num_batches = 0
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        out = model(X_batch)
        if use_mse:
            # MSE compares softmax probabilities with one-hot targets. The
            # one-hot width must equal the model's output width, not the
            # number of distinct labels that happen to appear in this batch.
            target = nn.functional.one_hot(y_batch, num_classes=out.shape[1]).float()
            loss = loss_fn(torch.softmax(out, dim=1), target)
        else:
            loss = loss_fn(out, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss, num_batches


def _save_loss_plot(loss_history, path):
    """Write the per-epoch loss curve to the image file *path*."""
    fig, ax = plt.subplots()
    try:
        ax.plot(loss_history)
        ax.set_title("Loss over Epochs")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        # Save via the figure object: pyplot's "current figure" is global
        # state shared with any concurrently running handler.
        fig.savefig(path)
    finally:
        plt.close(fig)


def train_model(loss_name, opt_name, lr, batch_size='32', image_size='28', file=None, custom_path=None, epochs='100', num_channels=3, generate_animation=False):
    """Train the model described by ``layer_configs`` and stream progress.

    This is a generator. Every item is a 5-tuple matching the five Gradio
    outputs: ``(loss_plot_path, animation_path, model_path,
    architecture_text, log_text)``. Per-epoch items carry only the
    architecture text and the running log (the path slots are ``None``); the
    last item carries the saved artefacts. Input problems are reported as a
    single item with a message in ``log_text``. A failing dummy forward pass
    also highlights the offending layer. Unexpected failures are re-raised as
    ``gr.Error``.

    Args:
        loss_name: ``"MSELoss"`` selects MSE between softmax outputs and
            one-hot targets; any other value selects ``CrossEntropyLoss``.
        opt_name: ``"SGD"``; any other value selects Adam.
        lr, batch_size, image_size, epochs, num_channels: numeric values that
            may arrive as strings; all are validated before training.
        file: path to a ``.csv`` (with a ``y`` label column) or a ``.zip`` of
            per-class image folders.
        custom_path: optional folder under which zip archives are extracted.
        generate_animation: if true, render the PCA weight-trajectory video;
            otherwise write a placeholder video.
    """
    bad_number = (TypeError, ValueError, OverflowError)

    def early_exit(message, highlight_index=None):
        # Shape of every "stop here" item: no artefacts, the (optionally
        # highlighted) architecture listing, and the message as the log.
        return None, None, None, update_architecture_text(highlight_index=highlight_index), message

    # 1) Cheap numeric checks that need no file access.
    try:
        channels = int(num_channels)
    except bad_number:
        yield early_exit("❌ Channels must be numeric.")
        return
    if channels not in (1, 3):
        yield early_exit("❌ Channels must be 1 or 3.")
        return

    try:
        max_epochs = int(epochs)
    except bad_number:
        yield early_exit("❌ Epochs must be a valid number.")
        return
    if max_epochs <= 0:
        yield early_exit("❌ Epochs must be a positive integer.")
        return

    data = None
    try:
        # 2) Model, input file and output folder.
        if not layer_configs:
            yield early_exit("❌ No model configured! Please add at least one trainable layer.")
            return
        if not file or not os.path.exists(file):
            yield early_exit(f"❌ File not found: {file}\nPlease check the path and try again.")
            return
        # Case-insensitive, matching load_data().
        ext = os.path.splitext(file)[1].lower()
        if ext not in (".csv", ".zip"):
            yield early_exit("❌ Invalid file type. Please provide a .csv or .zip file.")
            return
        if custom_path and not os.path.isdir(custom_path):
            try:
                os.makedirs(custom_path, exist_ok=True)
            except OSError as e:
                yield early_exit(f"❌ Could not create directory '{custom_path}': {e}")
                return

        # 3) Hyper-parameters, validated before anything below uses them.
        try:
            lr = float(lr)
        except bad_number:
            yield early_exit("❌ Learning rate must be a valid number.")
            return
        if not lr > 0:
            yield early_exit("❌ Learning rate must be a positive number.")
            return
        try:
            batch_size = int(batch_size)
        except bad_number:
            yield early_exit("❌ Batch size must be a valid number")
            return
        if batch_size <= 0:
            yield early_exit("❌ Batch size must be a positive integer")
            return
        try:
            image_size = int(image_size)
        except bad_number:
            yield early_exit("❌ Image size must be a valid number")
            return
        if image_size <= 0:
            yield early_exit("❌ Image size must be a positive integer")
            return

        # 4) Build the model and dry-run it on a dummy batch.
        model = build_model().to(device)
        if ext == ".csv":
            df = pd.read_csv(file)
            if 'y' not in df.columns:
                yield early_exit("❌ CSV missing 'y' column.")
                return
            is_valid, error_msg, bad_layer_idx = validate_model_forward_pass(
                model, "tabular", num_features=df.shape[1] - 1
            )
        else:
            is_valid, error_msg, bad_layer_idx = validate_model_forward_pass(
                model, "image", image_size=image_size, num_channels=channels
            )
        if not is_valid:
            # Report the reason as well as highlighting the offending layer.
            yield early_exit(f"❌ {error_msg}", highlight_index=bad_layer_idx)
            return
        if not any(p.requires_grad for p in model.parameters()):
            yield early_exit("⚠️ Model has no trainable parameters. Add a Linear or Conv2d layer.")
            return

        # 5) Data, loss and optimiser.
        data = load_data(file, custom_path, batch_size=batch_size, image_size=image_size, num_channels=channels)
        if data["type"] == "tabular":
            from torch.utils.data import TensorDataset
            X_train, y_train = data["train"]
            train_loader = DataLoader(
                TensorDataset(
                    torch.tensor(X_train, dtype=torch.float32),
                    torch.tensor(y_train, dtype=torch.long),
                ),
                batch_size=batch_size,
                shuffle=True,
            )
        else:
            train_loader = data["train"]
        loss_fn = nn.MSELoss() if loss_name == 'MSELoss' else nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=lr) if opt_name == 'SGD' else optim.Adam(model.parameters(), lr=lr)

        # 6) Train, streaming one log line per epoch.
        loss_history = []
        weight_path = []
        status_lines = []
        for epoch in range(1, max_epochs + 1):
            epoch_loss, num_batches = _train_one_epoch(model, train_loader, loss_fn, optimizer)
            if num_batches == 0:
                yield early_exit("❌ The dataset contains no training samples.")
                return
            avg_epoch_loss = epoch_loss / num_batches
            loss_history.append(avg_epoch_loss)
            if generate_animation:
                # One snapshot per epoch so the weight trajectory has exactly
                # one point per entry in loss_history.
                weight_path.append(get_flat_weights(model).cpu().numpy())
            status_lines.append(f"Epoch {epoch}/{max_epochs} — Loss: {avg_epoch_loss:.4f}")
            yield None, None, None, update_architecture_text(), "\n\n".join(status_lines)

        # 7) Artefacts: loss plot, animation (or placeholder), weights.
        loss_plot_path = os.path.join(OUTPUT_DIR, "loss_plot.png")
        _save_loss_plot(loss_history, loss_plot_path)

        animation_path = os.path.join(OUTPUT_DIR, "animation.mp4")
        if generate_animation:
            generate_3d_animation_pca(np.array(weight_path), loss_history, animation_path)
        else:
            create_dummy_video(animation_path)
        if not os.path.isfile(animation_path):
            # create_dummy_video reports ffmpeg failures by printing rather
            # than raising; don't hand the UI a path that does not exist.
            animation_path = None

        model_path = os.path.join(OUTPUT_DIR, "trained_model.pt")
        torch.save(model.state_dict(), model_path)

        for path in (loss_plot_path, animation_path, model_path):
            if path is not None and os.path.isdir(path):
                raise ValueError(f"❌ Output path is a directory, expected a file: {path}")

        # Same separator as the per-epoch updates so the log does not reflow.
        final_logs = "\n\n".join(status_lines)
        yield loss_plot_path, animation_path, model_path, update_architecture_text(), final_logs

    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"❌ Unexpected error: {str(e)}") from e
    finally:
        # Remove extracted images on every exit path (success, error, or the
        # consumer closing the generator early).
        if data is not None and data.get("path"):
            try:
                shutil.rmtree(data["path"])
            except OSError as e:
                print(f"Warning: Could not delete extracted folder: {e}")

def get_device_status():
    """Return a one-line Markdown status naming the compute device in use.

    Reports the name of CUDA device 0 when CUDA is available; otherwise a
    human-readable CPU model string. CPU detection is best-effort: if
    ``cpuinfo`` fails or has no brand field, the ``platform`` module is
    consulted and, as a last resort, ``"Unknown CPU"`` is shown.

    Returns:
        str: ``"<icon> GPU: <device name>"`` or ``"<icon> CPU: <cpu model>"``.
    """
    if torch.cuda.is_available():
        return f"üü¢ GPU: {torch.cuda.get_device_name(0)}"

    import platform  # local import: only needed for the CPU fallback

    try:
        info = cpuinfo.get_cpu_info()
    except Exception:  # purely cosmetic status text; never let it break the UI
        info = {}

    # get_cpu_info() returns a large dict; interpolating it whole dumps every
    # field (flags, cache sizes, ...) into the UI. Show only the model name.
    # py-cpuinfo >= 7 names the key "brand_raw"; older releases used "brand".
    cpu_name = (
        info.get("brand_raw")
        or info.get("brand")
        or platform.processor()
        or platform.machine()
        or "Unknown CPU"
    )
    return f"üî¥ CPU: {cpu_name}"


def get_default_writable_folder():
    """Create (if needed) and return ``<home>/my_gradio_data``.

    The folder lives under the current user's home directory, for example
    ``C:\\Users\\<USERNAME>\\my_gradio_data`` on Windows or
    ``/home/<USERNAME>/my_gradio_data`` on Linux. This keeps extracted data
    out of drive roots and other system-protected locations.

    Returns:
        str: Path of the (now existing) folder.
    """
    target = os.path.join(os.path.expanduser("~"), "my_gradio_data")
    os.makedirs(target, exist_ok=True)  # no-op when it already exists
    return target


def train_model_with_default_path(
    loss_name, opt_name, lr, batch_size, image_size,
    file, custom_path, epochs, num_channels, generate_animation
):
    """Run ``train_model``, substituting a default extraction folder if none is given.

    This is a thin Gradio-facing wrapper. Every argument is forwarded
    unchanged to ``train_model``, except ``custom_path``:
    - surrounding whitespace is stripped;
    - ``None`` or a blank value is replaced by ``get_default_writable_folder()``.

    Yields:
        Every item produced by the ``train_model`` generator, in order.
    """
    # Pasted paths often carry stray spaces; blank/None means "use the default".
    custom_path = (custom_path or "").strip()
    if not custom_path:
        custom_path = get_default_writable_folder()

    # `yield from` (unlike a manual for/yield loop) also forwards close() and
    # throw() to the inner generator. Cancelling the Gradio event therefore
    # finalizes train_model promptly instead of leaving it to the GC.
    yield from train_model(
        loss_name, opt_name, lr, batch_size, image_size,
        file, custom_path, epochs, num_channels, generate_animation,
    )

import subprocess


def create_dummy_video(output_path, duration=5, size="1280x720", ffmpeg_bin="ffmpeg"):
    """Write a solid-black placeholder MP4 to ``output_path`` using ffmpeg.

    Used when the 3D descent animation is disabled, so the Gradio Video
    output still receives a playable file. The call is best-effort: failures,
    including a missing ffmpeg executable, are reported with a printed
    message instead of being raised.

    Args:
        output_path: Destination file; overwritten if it exists (``-y``).
        duration: Clip length in seconds.
        size: Frame size as ``"WIDTHxHEIGHT"``. libx264 with yuv420p needs
            both dimensions to be even.
        ffmpeg_bin: Name or path of the ffmpeg executable to run.

    Returns:
        bool: True if ffmpeg exited successfully, False otherwise.
    """
    command = [
        ffmpeg_bin,
        "-hide_banner",
        "-loglevel", "error",  # keep stderr quiet unless something goes wrong
        "-f", "lavfi",
        "-i", f"color=c=black:s={size}:d={duration}",  # synthetic black source
        "-c:v", "libx264",
        "-t", str(duration),
        "-pix_fmt", "yuv420p",  # widest browser/player compatibility
        "-y", output_path,
    ]
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError:
        # The executable is not installed / not on PATH; previously this escaped.
        print(f"‚ùå Error creating dummy video: '{ffmpeg_bin}' not found on PATH")
        return False
    except subprocess.CalledProcessError as e:
        # ffmpeg output is not guaranteed to be valid UTF-8.
        print(f"‚ùå Error creating dummy video: {e.stderr.decode(errors='replace')}")
        return False
    print(f"‚úÖ Dummy video created at: {output_path}")
    return True

# --------------------------------
# Gradio wiring
# --------------------------------

# Declare the whole UI inside a single Blocks context. The resulting
# `dashboard` object is queued and launched after this block.
with gr.Blocks() as dashboard:

    # ------------------------------------------------------------------
    # Tab 1: build / edit the layer list that defines the model.
    # ------------------------------------------------------------------
    with gr.Tab("Build Model"):
        gr.Markdown("## Build a Layer")
        # Text view of the current architecture. Every builder button below
        # writes its handler's return value here. The Train tab also lists it
        # as an output, receiving the architecture text yielded by training.
        builder_arch = gr.Textbox(label="Architecture So Far", lines=6)

        # Layer-type choices mirror the keys of the module-level `layer_map`.
        layer_type_dropdown = gr.Dropdown(choices=list(layer_map.keys()), label="Layer Type", value="Linear")
        in_dim = gr.Textbox(label="Input Dim (Linear/Conv2d)")
        out_dim = gr.Textbox(label="Output Dim (Linear/Conv2d)")

        # Type-specific hyperparameter fields. All start hidden; toggle_fields()
        # below reveals only the group that matches the selected layer type.
        # Conv2d options:
        conv_kernel = gr.Textbox(label="Kernel Size", value="3", visible=False)
        conv_padding = gr.Textbox(label="Padding", value="1", visible=False)
        conv_stride = gr.Textbox(label="Stride", value="1", visible=False)
        conv_bias = gr.Checkbox(label="Include Bias", value=True, visible=False)

        # MaxPool2d options:
        pool_kernel = gr.Textbox(label="Pool Kernel Size", value="2", visible=False)
        pool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        pool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        # AvgPool2d options:
        avgpool_kernel = gr.Textbox(label="AvgPool Kernel Size", value="2", visible=False)
        avgpool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        avgpool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        # Activation parameters (LeakyReLU / ELU):
        leaky_relu_slope = gr.Textbox(label="Negative Slope", value="0.01", visible=False)
        elu_alpha = gr.Textbox(label="ELU Alpha", value="1.0", visible=False)

        # `add_layer` is defined elsewhere in this file. Gradio passes it these
        # 15 component values positionally, in exactly this order.
        add_btn = gr.Button("Add Layer")
        add_btn.click(
            fn=add_layer,
            inputs=[
                layer_type_dropdown, in_dim, out_dim,
                conv_kernel, conv_padding, conv_stride, conv_bias,
                pool_kernel, pool_stride, pool_padding,
                avgpool_kernel, avgpool_stride, avgpool_padding,
                leaky_relu_slope, elu_alpha
            ],
            outputs=builder_arch
        )

        def toggle_fields(layer_type):
            """Return 12 visibility updates for the "Add Layer" option fields.

            The list order must match the `outputs=` list of the
            `layer_type_dropdown.change(...)` wiring below:
            conv x4, max-pool x3, avg-pool x3, LeakyReLU slope, ELU alpha.
            """
            is_conv = (layer_type == "Conv2d")
            is_pool = (layer_type == "MaxPool2d")
            is_avgpool = (layer_type == "AvgPool2d")
            is_leaky = (layer_type == "LeakyReLU")
            is_elu = (layer_type == "ELU")
            return [
                gr.update(visible=is_conv),
                gr.update(visible=is_conv),
                gr.update(visible=is_conv),
                gr.update(visible=is_conv),
                gr.update(visible=is_pool),
                gr.update(visible=is_pool),
                gr.update(visible=is_pool),
                gr.update(visible=is_avgpool),
                gr.update(visible=is_avgpool),
                gr.update(visible=is_avgpool),
                gr.update(visible=is_leaky),
                gr.update(visible=is_elu),
            ]

        # Re-evaluate field visibility whenever the layer type changes.
        layer_type_dropdown.change(
            toggle_fields,
            inputs=[layer_type_dropdown],
            outputs=[
                conv_kernel, conv_padding, conv_stride, conv_bias,
                pool_kernel, pool_stride, pool_padding,
                avgpool_kernel, avgpool_stride, avgpool_padding,
                leaky_relu_slope, elu_alpha
            ]
        )

        # `reset_layers` (defined elsewhere) takes no inputs; its return value
        # replaces the architecture text.
        reset_btn = gr.Button("Reset Layers")
        reset_btn.click(fn=reset_layers, inputs=[], outputs=builder_arch)

        # --- Edit / delete / insert at a given position -------------------
        gr.Markdown("### Edit or Delete a Layer")
        # precision=0 makes Gradio deliver the value as a whole number.
        layer_index = gr.Number(label="Layer Index (0-based)", precision=0)
        new_layer_type = gr.Dropdown(list(layer_map.keys()), label="New Layer Type")
        new_in_dim = gr.Textbox(label="New Input Dim (if applicable)")
        new_out_dim = gr.Textbox(label="New Output Dim (if applicable)")

        # Hidden option fields for the edit/insert form; these parallel the
        # "Add Layer" fields above and are toggled by toggle_edit_fields().
        edit_kernel = gr.Textbox(label="Kernel Size", value="3", visible=False)
        edit_padding = gr.Textbox(label="Padding", value="1", visible=False)
        edit_stride = gr.Textbox(label="Stride", value="1", visible=False)
        edit_bias = gr.Checkbox(label="Include Bias", value=True, visible=False)

        edit_pool_kernel = gr.Textbox(label="Pool Kernel Size", value="2", visible=False)
        edit_pool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        edit_pool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        edit_avgpool_kernel = gr.Textbox(label="AvgPool Kernel Size", value="2", visible=False)
        edit_avgpool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        edit_avgpool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        edit_leaky_relu_slope = gr.Textbox(label="Negative Slope", value="0.01", visible=False)
        edit_elu_alpha = gr.Textbox(label="ELU Alpha", value="1.0", visible=False)

        edit_btn = gr.Button("Edit Layer")
        delete_btn = gr.Button("Delete Layer")
        insert_btn = gr.Button("Insert New Layer")

        # `update_layer` (defined elsewhere) receives these 16 values
        # positionally: the target index, then the same field layout as add_layer.
        edit_btn.click(
            fn=update_layer,
            inputs=[
                layer_index, new_layer_type, new_in_dim, new_out_dim,
                edit_kernel, edit_padding, edit_stride, edit_bias,
                edit_pool_kernel, edit_pool_stride, edit_pool_padding,
                edit_avgpool_kernel, edit_avgpool_stride, edit_avgpool_padding,
                edit_leaky_relu_slope, edit_elu_alpha
            ],
            outputs=builder_arch
        )

        # Delete needs only the index.
        delete_btn.click(fn=delete_layer, inputs=[layer_index], outputs=builder_arch)
        # `insert_layer` takes the same 16 inputs as update_layer above.
        insert_btn.click(
            fn=insert_layer,
            inputs=[
                layer_index, new_layer_type, new_in_dim, new_out_dim,
                edit_kernel, edit_padding, edit_stride, edit_bias,
                edit_pool_kernel, edit_pool_stride, edit_pool_padding,
                edit_avgpool_kernel, edit_avgpool_stride, edit_avgpool_padding,
                edit_leaky_relu_slope, edit_elu_alpha
            ],
            outputs=builder_arch
        )

        # NOTE(review): identical to toggle_fields(); a single shared helper
        # would avoid the two copies drifting apart.
        def toggle_edit_fields(layer_type):
            """Return 12 visibility updates for the edit/insert option fields.

            The order must match the `outputs=` list of
            `new_layer_type.change(...)` below.
            """
            is_conv = (layer_type == "Conv2d")
            is_pool = (layer_type == "MaxPool2d")
            is_avgpool = (layer_type == "AvgPool2d")
            is_leaky = (layer_type == "LeakyReLU")
            is_elu = (layer_type == "ELU")
            return [
                gr.update(visible=is_conv),
                gr.update(visible=is_conv),
                gr.update(visible=is_conv),
                gr.update(visible=is_conv),
                gr.update(visible=is_pool),
                gr.update(visible=is_pool),
                gr.update(visible=is_pool),
                gr.update(visible=is_avgpool),
                gr.update(visible=is_avgpool),
                gr.update(visible=is_avgpool),
                gr.update(visible=is_leaky),
                gr.update(visible=is_elu),
            ]

        new_layer_type.change(
            toggle_edit_fields,
            inputs=[new_layer_type],
            outputs=[
                edit_kernel, edit_padding, edit_stride, edit_bias,
                edit_pool_kernel, edit_pool_stride, edit_pool_padding,
                edit_avgpool_kernel, edit_avgpool_stride, edit_avgpool_padding,
                edit_leaky_relu_slope, edit_elu_alpha
            ]
        )

        # Device info row: the initial text is computed once when the UI is
        # built; the button re-runs get_device_status() on demand.
        with gr.Row():
            device_output = gr.Markdown(get_device_status())
            refresh_button = gr.Button("üîÑ Refresh Device Status")
            refresh_button.click(fn=get_device_status, inputs=[], outputs=device_output)

    # ------------------------------------------------------------------
    # Tab 2: configure and run training.
    # ------------------------------------------------------------------
    with gr.Tab("Train"):
        gr.Markdown("## Train the Model")

        # Training hyperparameters. Numeric values are entered as text; the
        # training handler receives them as strings.
        loss_dropdown = gr.Dropdown(['MSELoss', 'CrossEntropyLoss'], label='Loss Function')
        opt_dropdown = gr.Dropdown(['SGD', 'Adam'], label='Optimizer')
        lr_box = gr.Textbox(value="0.01", label="Learning Rate")
        batch_box = gr.Textbox(value="32", label="Batch Size")
        size_box = gr.Textbox(value="28", label="Image Resize (e.g. 28x28)")

        # Filesystem path to the training data (CSV or ZIP), typed as text.
        file_box = gr.Textbox(label="Path to CSV or ZIP")

        # Optional extraction folder; when left blank, the wrapper handler
        # substitutes get_default_writable_folder().
        custom_box = gr.Textbox(label="Custom Extraction Path (optional)")

        epochs_box = gr.Textbox(value="100", label="Epochs")
        generate_3d_checkbox = gr.Checkbox(label="Generate 3D Descent Animation (‚ö†Ô∏è Slower, CPU/RAM-intensive)", value=False)
        channel_dropdown = gr.Dropdown([1, 3], label="Input Channels (1 = Grayscale, 3 = RGB)", value=3)

        # Outputs, filled from each item yielded by the training generator.
        loss_curve = gr.Image(label="Loss Curve")
        animation_video = gr.Video(label="3D Descent Animation")
        model_file = gr.File(label="Download Trained Model")  # offers the saved state_dict file for download
        log_box = gr.Markdown(label="Log")

        # The wrapper substitutes a default path when custom_box is empty, then
        # re-yields the training generator's results.
        train_button = gr.Button("Start Training")
        train_button.click(
            fn=train_model_with_default_path,
            # Order must match the wrapper's positional parameters.
            inputs=[
                loss_dropdown,
                opt_dropdown,
                lr_box,
                batch_box,
                size_box,
                file_box,
                custom_box,
                epochs_box,
                channel_dropdown,
                generate_3d_checkbox
            ],
            # Five outputs, matching each 5-tuple yielded by training:
            # (loss plot path, animation path, model path, architecture text, logs).
            # builder_arch is the Textbox from the "Build Model" tab.
            outputs=[
                loss_curve,
                animation_video,
                model_file,
                builder_arch,
                log_box
            ]
        )

# Queue + Launch: enable Gradio's request queue, which the generator-based
# training handler streams through. Then start the server: debug=True keeps
# the main thread blocked and surfaces errors, share=True requests a public
# link. queue() returns the Blocks object itself, so the calls chain.
dashboard.queue().launch(debug=True, share=True)