# In [None]:
# Import all necessary libraries

import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
import tempfile
import os
import gc
import cpuinfo
import zipfile
from sklearn.model_selection import train_test_split
from moviepy.editor import VideoFileClip
from matplotlib import animation
from sklearn.decomposition import PCA
from torchvision import datasets, transforms
import shutil
from sklearn.utils import shuffle
from matplotlib import animation
import matplotlib.pyplot as plt
import tempfile
from PIL import Image, ImageFile
import logging 
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import subprocess
from torch.utils.data import DataLoader, TensorDataset

# This function checks that a path points to a file rather than a directory.
# Input: path
# Output: the same path, unchanged; a gr.Warning is shown if the path is an existing directory
def safe_output(path):
    """Return ``path`` unchanged, warning when it names a directory.

    The path is never rewritten or sanitised. The only effect besides the
    return value is a ``gr.Warning`` when ``path`` is an existing directory
    instead of a file.
    """
    is_directory = os.path.isdir(path)
    if is_directory:
        gr.Warning(f"Expected a file, got a directory: {path}")
    return path


# Absolute path of the "outputs" folder, resolved against the current working
# directory. The folder is created as soon as this code runs if it is missing.
OUTPUT_DIR = os.path.abspath("outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU ("cuda") when torch reports one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# --------------------------------
# Layer Functions
# --------------------------------


# The dictionary containing a map of all layers. Used in wiring the frontend
# Maps a layer name to a factory that returns a fresh nn.Module.
# Factories declared with ``*_`` ignore any extra positional arguments.
layer_map = {
    # Linear / Conv2d take (in_dim, out_dim); both are coerced with int().
    "Linear": lambda in_dim, out_dim: nn.Linear(int(in_dim), int(out_dim)),
    # This Conv2d factory always uses kernel_size=3 and padding=1.
    "Conv2d": lambda in_dim, out_dim: nn.Conv2d(int(in_dim), int(out_dim), kernel_size=3, padding=1),  
    # Pooling factories always use kernel_size=2.
    "MaxPool2d": lambda *_: nn.MaxPool2d(kernel_size=2),  
    "AvgPool2d": lambda *_: nn.AvgPool2d(kernel_size=2),  
    # First argument is the drop probability (default 0.5).
    "Dropout": lambda p=0.5, *_: nn.Dropout(float(p)),
    # Parameter-free layers.
    "ReLU": lambda *_: nn.ReLU(),
    "Tanh": lambda *_: nn.Tanh(),
    "Sigmoid": lambda *_: nn.Sigmoid(),
    "Flatten": lambda *_: nn.Flatten(),
    # Softmax is always applied over dim=1.
    "Softmax": lambda *_: nn.Softmax(dim=1),
    # First argument is the negative slope (default 0.01).
    "LeakyReLU": lambda slope=0.01, *_: nn.LeakyReLU(negative_slope=float(slope)),
    "GELU": lambda *_: nn.GELU(),
    # First argument is alpha (default 1.0).
    "ELU": lambda alpha=1.0, *_: nn.ELU(alpha=float(alpha))
}

# This helper function turns integers entered as csv into tuples
# Input: An integer, such as 5, or 3,4, or invalid pair such as 3,car
# Output: A tuple, such as 5 as an int or (3,4) or a gr.Warning respectively
def parse_int_or_tuple(val):
    """Convert ``val`` to an int, or to a tuple of ints when it contains commas.

    ``5`` and ``"5"`` give ``5``; ``"3,4"`` gives ``(3, 4)``. When anything
    fails to convert, a ``gr.Warning`` is emitted and the function falls
    through, so the caller receives ``None``.
    """
    try:
        text = str(val)
        if ',' not in text:
            return int(val)
        return tuple(int(piece) for piece in text.split(','))
    except Exception:
        gr.Warning(f"Invalid numeric input: '{val}'. Please enter an integer or comma-separated pair.")

# Ordered list of the layers making up the model being designed. Each entry is an
# 8-tuple: (description, layer_type, in_dim, out_dim, kernel, padding, stride, bias).
layer_configs = []

# This function checks that the parameters passed in for a layer are valid for that layer type.
# Input: A layer type and its parameters, passed in as keyword arguments.
# Output: (True, None) if the parameters are valid, otherwise (False, error message)
def validate_layer_inputs(layer_type, **kwargs):
    """Check the hyperparameters supplied for ``layer_type``.

    Args:
        layer_type: Name of the layer, e.g. "Linear", "Conv2d", "Dropout".
        **kwargs: Raw values from the frontend. Keys read here are
            ``in_dim``, ``out_dim``, ``kernel_size``, ``padding``, ``stride``,
            ``pool_kernel``, ``pool_stride``, ``pool_padding``,
            ``avgpool_kernel``, ``avgpool_stride``, ``avgpool_padding``,
            ``leaky_slope`` and ``elu_alpha``. A value that is missing,
            ``None`` or a blank string falls back to the layer's default.

    Returns:
        ``(True, None)`` when the inputs are acceptable, otherwise
        ``(False, message)``. This function never raises: any unexpected
        failure is reported through the message.
    """

    def _given(key, default):
        # Callers pass every key explicitly, so an untouched field arrives as
        # None or "" rather than being absent; kwargs.get(key, default) alone
        # would never apply the default in that case.
        value = kwargs.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    def _size_problem(keys_with_defaults):
        # Returns an error message for the first bad size value, else None.
        for key, default in keys_with_defaults:
            parsed = parse_int_or_tuple(_given(key, default))
            if parsed is None:
                # parse_int_or_tuple signals unparsable input with None.
                return f"{layer_type} {key} must be an integer or comma-separated integers"
            if isinstance(parsed, tuple):
                if any(v < 0 for v in parsed):
                    return f"{layer_type} tuple values must be non-negative"
            elif parsed < 0:
                return f"{layer_type} values must be non-negative"
        return None

    try:
        if layer_type == "Linear":
            in_dim_int = int(kwargs.get("in_dim"))
            out_dim_int = int(kwargs.get("out_dim"))
            if in_dim_int <= 0 or out_dim_int <= 0:
                return False, f"{layer_type} dimensions must be positive integers"

        elif layer_type == "Conv2d":
            in_dim_int = int(kwargs.get("in_dim"))
            out_dim_int = int(kwargs.get("out_dim"))
            problem = _size_problem((("kernel_size", 3), ("padding", 1), ("stride", 1)))
            if problem is not None:
                return False, problem
            if in_dim_int <= 0 or out_dim_int <= 0:
                return False, f"{layer_type} in/out dims must be positive integers"

        elif layer_type == "Dropout":
            # The drop probability is delivered through the in_dim field.
            p = float(kwargs.get("in_dim"))
            if not (0 <= p <= 1):
                return False, "Dropout probability must be between 0 and 1"

        elif layer_type in ("MaxPool2d", "AvgPool2d"):
            prefix = "pool" if layer_type == "MaxPool2d" else "avgpool"
            problem = _size_problem((
                (f"{prefix}_kernel", 2),
                (f"{prefix}_stride", 2),
                (f"{prefix}_padding", 0),
            ))
            if problem is not None:
                return False, problem

        elif layer_type == "LeakyReLU":
            slope = float(_given("leaky_slope", 0.01))
            if slope < 0:
                return False, "LeakyReLU negative_slope must be ≥ 0"

        elif layer_type == "ELU":
            alpha = float(_given("elu_alpha", 1.0))
            if alpha < 0:
                return False, "ELU alpha must be ≥ 0"

        return True, None
    except Exception as e:
        # Deliberately broad: the validator reports problems, it never raises.
        return False, f"Validation error in {layer_type}: {str(e)}"

# This function adds layers as the user requests on the frontend
# Input: Parameters and type of layer passed in from the frontend
# Output: A layer of the type requested is added or an error message is passed to the frontend
def add_layer(
    layer_type, in_dim, out_dim,
    kernel_size=3, padding=1, stride=1, bias=1,
    pool_kernel="2", pool_stride="2", pool_padding="0",
    avgpool_kernel=None, avgpool_stride=None, avgpool_padding=None,
    leaky_slope = "0.01", elu_alpha = "1.0"
):
    """Validate the requested layer and append it to ``layer_configs``.

    Returns the validation error message when the inputs are rejected,
    otherwise the refreshed architecture text.
    """
    params = dict(
        in_dim=in_dim, out_dim=out_dim,
        kernel_size=kernel_size, padding=padding, stride=stride,
        pool_kernel=pool_kernel, pool_stride=pool_stride, pool_padding=pool_padding,
        avgpool_kernel=avgpool_kernel, avgpool_stride=avgpool_stride,
        avgpool_padding=avgpool_padding,
        leaky_slope=leaky_slope, elu_alpha=elu_alpha,
    )
    is_valid, err_msg = validate_layer_inputs(layer_type=layer_type, **params)
    if not is_valid:
        return err_msg

    layer_configs.append(_make_layer_config(layer_type, "Adding", bias, **params))
    return update_architecture_text()


# This private helper returns a default for a value the frontend left empty.
def _param_or_default(value, default):
    # Only None and blank strings count as "not provided". A plain
    # ``value or default`` would also replace an explicit 0 (for example
    # padding=0 or a slope of 0) with the default.
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    return value


# This private helper builds the 8-tuple stored in layer_configs. It is shared by
# add_layer, update_layer and insert_layer so the three stay in sync.
def _make_layer_config(
    layer_type, action, bias, *,
    in_dim, out_dim,
    kernel_size, padding, stride,
    pool_kernel, pool_stride, pool_padding,
    avgpool_kernel, avgpool_stride, avgpool_padding,
    leaky_slope, elu_alpha
):
    """Return (description, layer_type, in_dim, out_dim, kernel, padding, stride, bias).

    ``action`` ("Adding", "Editing" or "Inserting") is only used in the
    description of the placeholder entry returned when building fails.
    Single-parameter layers (Dropout, LeakyReLU, ELU) store their value in
    both the in_dim and out_dim slots.
    """
    try:
        if layer_type == "Conv2d":
            i = int(in_dim)
            o = int(out_dim)
            k = parse_int_or_tuple(_param_or_default(kernel_size, 3))
            p = parse_int_or_tuple(_param_or_default(padding, 1))
            s = parse_int_or_tuple(_param_or_default(stride, 1))
            b = bool(bias)
            desc = f"Conv2d({i}, {o}, kernel={k}, padding={p}, stride={s}, bias={b})"
            return (desc, layer_type, i, o, k, p, s, b)

        if layer_type == "Linear":
            i = int(in_dim)
            o = int(out_dim)
            return (f"Linear({i}, {o})", layer_type, i, o, None, None, None, None)

        if layer_type == "Dropout":
            # The drop probability is delivered through the in_dim field.
            p = float(in_dim)
            return (f"Dropout({p})", layer_type, p, p, None, None, None, None)

        if layer_type == "LeakyReLU":
            negative_slope = float(_param_or_default(leaky_slope, "0.01"))
            desc = f"LeakyReLU(negative_slope={negative_slope})"
            return (desc, layer_type, negative_slope, negative_slope, None, None, None, None)

        if layer_type == "ELU":
            alpha = float(_param_or_default(elu_alpha, "1.0"))
            return (f"ELU(alpha={alpha})", layer_type, alpha, alpha, None, None, None, None)

        if layer_type in ("MaxPool2d", "AvgPool2d"):
            if layer_type == "MaxPool2d":
                raw_kernel, raw_stride, raw_padding = pool_kernel, pool_stride, pool_padding
            else:
                raw_kernel, raw_stride, raw_padding = avgpool_kernel, avgpool_stride, avgpool_padding
            kv = parse_int_or_tuple(_param_or_default(raw_kernel, "2"))
            sv = parse_int_or_tuple(_param_or_default(raw_stride, "2"))
            pv = parse_int_or_tuple(_param_or_default(raw_padding, "0"))
            desc = f"{layer_type}(kernel={kv}, stride={sv}, padding={pv})"
            # Slot order is kernel, padding, stride (matching build_model).
            return (desc, layer_type, None, None, kv, pv, sv, None)

        # Parameter-free layers: Softmax and GELU get a fixed description,
        # everything else (e.g. ReLU/Tanh/Sigmoid/Flatten) is shown by name.
        fixed_descriptions = {"Softmax": "Softmax(dim=1)", "GELU": "GELU()"}
        desc = fixed_descriptions.get(layer_type, layer_type)
        return (desc, layer_type, None, None, None, None, None, None)

    except Exception as e:
        # Best effort: record a visible placeholder instead of failing the UI callback.
        return (f"[Error {action} Layer: {e}]", layer_type, None, None, None, None, None, None)


# This function replaces an existing layer as the user requests on the frontend
# Input: Index of the layer to replace, plus parameters and type of layer passed in from the frontend
# Output: The updated architecture text, or an error message if validation fails. An out-of-range index leaves the layers unchanged
def update_layer(
    index, layer_type, in_dim, out_dim,
    kernel_size=3, padding=1, stride=1, bias=True,
    pool_kernel="2", pool_stride="2", pool_padding="0",
    avgpool_kernel=None, avgpool_stride=None, avgpool_padding=None,
    leaky_slope= "0.01",elu_alpha = "1.0"
):
    """Validate the requested layer and overwrite ``layer_configs[index]``."""
    index = int(index)
    params = dict(
        in_dim=in_dim, out_dim=out_dim,
        kernel_size=kernel_size, padding=padding, stride=stride,
        pool_kernel=pool_kernel, pool_stride=pool_stride, pool_padding=pool_padding,
        avgpool_kernel=avgpool_kernel, avgpool_stride=avgpool_stride,
        avgpool_padding=avgpool_padding,
        leaky_slope=leaky_slope, elu_alpha=elu_alpha,
    )
    is_valid, err_msg = validate_layer_inputs(layer_type=layer_type, **params)
    if not is_valid:
        return err_msg
    if index < 0 or index >= len(layer_configs):
        return update_architecture_text()

    layer_configs[index] = _make_layer_config(layer_type, "Editing", bias, **params)
    return update_architecture_text()

# This function inserts a layer at a given position as the user requests on the frontend
# Input: Index to insert at, plus parameters and type of layer passed in from the frontend
# Output: The updated architecture text, or an error message if validation fails
def insert_layer(
    index, layer_type, in_dim, out_dim,
    kernel_size=3, padding=1, stride=1, bias=1,
    pool_kernel="2", pool_stride="2", pool_padding="0",
    avgpool_kernel=None, avgpool_stride=None, avgpool_padding=None, 
    leaky_slope="0.01", elu_alpha = "1.0"
):
    """Validate the requested layer and insert it before position ``index``.

    The index is passed straight to ``list.insert``, so out-of-range values
    are clamped to the ends of the list.
    """
    index = int(index)
    params = dict(
        in_dim=in_dim, out_dim=out_dim,
        kernel_size=kernel_size, padding=padding, stride=stride,
        pool_kernel=pool_kernel, pool_stride=pool_stride, pool_padding=pool_padding,
        avgpool_kernel=avgpool_kernel, avgpool_stride=avgpool_stride,
        avgpool_padding=avgpool_padding,
        leaky_slope=leaky_slope, elu_alpha=elu_alpha,
    )
    is_valid, err_msg = validate_layer_inputs(layer_type=layer_type, **params)
    if not is_valid:
        return err_msg

    layer_configs.insert(index, _make_layer_config(layer_type, "Inserting", bias, **params))
    return update_architecture_text()

# This function deletes layers as the user requests on the frontend in the edit tab
def delete_layer(index):
    """Remove the layer at ``index`` and return the refreshed architecture text.

    ``index`` is coerced with ``int()``. An index outside the current list
    leaves the layers untouched.
    """
    position = int(index)
    in_range = 0 <= position < len(layer_configs)
    if in_range:
        del layer_configs[position]
    return update_architecture_text()
    
# This function clears all layers currently selected if the user hits reset on the frontend
def reset_layers():
    """Empty ``layer_configs`` in place and return an empty architecture text."""
    # Slice deletion keeps the same list object, so other references stay valid.
    del layer_configs[:]
    return ""

# This helper function shows the error messages if dimensions do not match. It does not calculate them it is simply called to show the visual error message. 
# Input: index to highlight at
# Output: A warning printed on the architecture at the mismatch layer
def update_architecture_text(highlight_index=None):
    """Render ``layer_configs`` as one "<index>: <description>" line per layer.

    The line whose position equals ``highlight_index`` gets a warning marker
    in front of its description. No dimension checking happens here.
    """
    return "\n".join(
        f"{position}: " + (f"⚠️ {entry[0]}" if position == highlight_index else entry[0])
        for position, entry in enumerate(layer_configs)
    )

# This function builds the model the user requests
# Input: Nothing
# Output: The model built and passed into training
def build_model():
    """Turn every entry of ``layer_configs`` into a module and chain them.

    Returns an ``nn.Sequential`` holding the modules in list order.
    Single-parameter layers (LeakyReLU, ELU, Dropout) read their value from
    the in_dim slot and fall back to 0.01, 1.0 and 0.5 respectively when
    that slot is None. Any other layer type is built through ``layer_map``
    with no arguments.
    """
    modules = []
    for _desc, kind, first, second, kernel, padding, stride, bias in layer_configs:
        if kind == "Conv2d":
            module = nn.Conv2d(
                int(first), int(second),
                kernel_size=kernel,
                padding=padding,
                stride=stride,
                bias=bool(bias),
            )
        elif kind == "Linear":
            module = nn.Linear(int(first), int(second))
        elif kind == "GELU":
            module = nn.GELU()
        elif kind == "LeakyReLU":
            module = nn.LeakyReLU(negative_slope=float(0.01 if first is None else first))
        elif kind == "ELU":
            module = nn.ELU(alpha=float(1.0 if first is None else first))
        elif kind == "Dropout":
            module = nn.Dropout(float(0.5 if first is None else first))
        elif kind == "Softmax":
            module = nn.Softmax(dim=1)
        elif kind in ("MaxPool2d", "AvgPool2d"):
            pool_cls = nn.MaxPool2d if kind == "MaxPool2d" else nn.AvgPool2d
            module = pool_cls(kernel_size=kernel, stride=stride, padding=padding)
        else:
            # ReLU, Tanh, Sigmoid, Flatten
            module = layer_map[kind]()
        modules.append(module)

    return nn.Sequential(*modules)


# --------------------------------
# Data Loaders
# --------------------------------

# This helper function extracts the zip file provided for images into a new temporary folder, created inside the custom path if the user gave one on the frontend, otherwise in the system temp directory
# Input: File and custom path
# Output: The path of the temporary folder the zip file was extracted into
def extract_zip_to_tempdir(file_like, custom_path=None):
    """Extract a zip archive into a newly created temporary directory.

    Args:
        file_like: Path or binary file object accepted by ``zipfile.ZipFile``.
        custom_path: Optional parent folder for the temporary directory. It
            is created if missing. When falsy, the system temp location is
            used.

    Returns:
        The path of the temporary directory holding the extracted files.

    Raises:
        Whatever ``zipfile`` raises for an unreadable archive (for example
        ``zipfile.BadZipFile``). The temporary directory is removed first so
        a failed upload leaves nothing behind.
    """
    if custom_path:
        os.makedirs(custom_path, exist_ok=True)
        temp_dir = tempfile.mkdtemp(dir=custom_path)
    else:
        temp_dir = tempfile.mkdtemp()

    try:
        with zipfile.ZipFile(file_like, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)
    except Exception:
        # Nobody would ever get this path back, so clean up before re-raising.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir

# Ask PIL to tolerate truncated image files when loading them. This is a
# process-wide PIL setting, not one limited to the loaders below.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# This helper function checks if any images are corrupted. To prevent a crash it skips them and makes a info bubble. 
# Input: Path and number of channels of the images
# Output: Info bubble if corrupt image
def safe_pil_loader(path, num_channels=3):
    """Open the image at ``path`` and return it fully loaded, or ``None``.

    The image is converted to 'RGB' when ``num_channels`` is 3 and to 'L'
    (single channel) for any other value. Any failure while opening,
    converting or loading is reported through ``gr.Info`` and turned into
    a ``None`` return instead of an exception.
    """
    try:
        with open(path, 'rb') as handle:
            mode = 'RGB' if num_channels == 3 else 'L'
            picture = Image.open(handle).convert(mode)
            # Force the pixel data to be read while the file is still open.
            picture.load()
            return picture
    except Exception as e:
        gr.Info(f"Skipping corrupted image: {path} — {e}")
        return None

# This class safely loads images from a folder, skipping invalid or unreadable files.
# Input: root (dataset folder path), transform (optional image transforms), num_channels (number of image channels)
# Output: Filters out invalid images during dataset initialization, ensuring only valid samples are used.
class SafeImageFolder(datasets.ImageFolder):
    """ImageFolder variant that drops unreadable images when it is created.

    Every file found by ``ImageFolder`` is opened once with
    ``safe_pil_loader`` during construction. Files for which the loader
    returns ``None`` are removed from the dataset.
    """

    def __init__(self, root, transform=None, num_channels=3):
        """
        Args:
            root: Dataset folder, laid out as ImageFolder expects.
            transform: Optional transform applied to each loaded image.
            num_channels: Passed to ``safe_pil_loader`` to pick the image mode.
        """
        loader_with_channels = lambda path: safe_pil_loader(path, num_channels)
        super().__init__(root, transform=transform, loader=loader_with_channels)

        # Keep only the samples that can actually be decoded.
        self.samples = [s for s in self.samples if loader_with_channels(s[0]) is not None]
        self.imgs = self.samples
        # ImageFolder builds ``targets`` from the unfiltered sample list;
        # rebuild it so it stays aligned with ``samples``.
        self.targets = [target for _, target in self.samples]

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        Raises:
            RuntimeError: If the image can no longer be loaded, for example
                because the file changed after the dataset was built.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if img is None:
            message = f"Could not load image at {path}"
            gr.Warning(message)
            # Fail here with a clear message instead of passing None on to
            # the transform.
            raise RuntimeError(message)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

# This function loads data. It calls the helper functions and loads clean, uncorrupted data. 
# Input: file, custom_path=None, batch_size=32, image_size=28,  num_channels=3, loss_fn=None
# Output: Clean data ready for training
def load_data(file, custom_path=None, batch_size=32, image_size=28,  num_channels=3, loss_fn=None):
    """Load a training set from a ``.csv`` table or a ``.zip`` of images.

    Args:
        file: Path to the uploaded file. The extension selects the loader.
        custom_path: Optional folder the zip is extracted under.
        batch_size: Batch size of the image ``DataLoader``.
        image_size: Images are resized to ``image_size`` x ``image_size``.
        num_channels: Forwarded to ``SafeImageFolder``.
        loss_fn: Loss instance. It decides the dtype of the CSV target
            column: int64 for CrossEntropyLoss, float32 for MSELoss and
            BCEWithLogitsLoss.

    Returns:
        A dict with keys ``"type"`` ("tabular" or "image"), ``"train"``
        (an ``(X, y)`` pair of arrays or a ``DataLoader``) and ``"path"``
        (the extraction folder for images, ``None`` for CSV). ``None`` is
        returned, after a warning, for an unsupported file extension.

    Raises:
        ValueError: If the CSV has no ``y`` column, or ``loss_fn`` is not
            one of the supported loss types.
    """
    ext = os.path.splitext(file)[1].lower()
    if ext == ".csv":
        df = pd.read_csv(file)
        if 'y' not in df.columns:
            message = "CSV must contain a 'y' column."
            gr.Warning(message)
            # Stop here; continuing would fail with a bare KeyError.
            raise ValueError(message)

        if isinstance(loss_fn, nn.CrossEntropyLoss):
            target_dtype = np.int64
        elif isinstance(loss_fn, (nn.MSELoss, nn.BCEWithLogitsLoss)):
            target_dtype = np.float32
        else:
            message = f"Unhandled loss function type: {type(loss_fn)}"
            gr.Warning(message)
            # Without a known loss there is no target dtype to convert to.
            raise ValueError(message)

        X = df.drop(columns=['y']).values.astype(np.float32)
        y = df['y'].values.astype(target_dtype)

        X, y = shuffle(X, y)

        return {
            "type": "tabular",
            "train": (X, y),
            "path": None
        }

    elif ext == ".zip":
        with open(file, 'rb') as f:
            data_dir = extract_zip_to_tempdir(f, custom_path)
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])
        dataset = SafeImageFolder(data_dir, transform=transform, num_channels=num_channels)
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        return {
            "type": "image",
            "train": train_loader,
            "path": data_dir
        }
    else:
        gr.Warning("Unsupported file format. Provide a .csv or .zip path.")
        return None


# --------------------------------
# Graph Functions
# --------------------------------


# This function flattens the model parameters into a vector which then appears as a ball on the 3d plot. 
# Input: model
# Output: flattened parameters
def get_flat_weights(model):
    """Return all parameters of ``model`` as one detached 1-D tensor.

    Parameters are flattened individually and concatenated in the order
    ``model.parameters()`` yields them.
    """
    pieces = []
    for param in model.parameters():
        pieces.append(param.detach().flatten())
    return torch.cat(pieces)

# This function plots the 2d graph showing the loss history of the model.
# Input: Loss_history
# Output: A 2d graph showing the loss history per epoch
def generate_loss_plot(loss_history):
    """Plot ``loss_history`` against its index and save the plot as a PNG.

    Args:
        loss_history: Sequence of loss values, one per epoch.

    Returns:
        Path of the PNG file. The file is created with ``delete=False``, so
        removing it is left to the caller.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(loss_history)
        ax.set_title("Loss over Epochs")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")

        # Reserve a unique file name, then release the handle before writing.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            plot_path = tmpfile.name

        # Save through the figure object: plt.savefig() targets pyplot's
        # global "current figure", which may not be this one.
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    return plot_path

# This function generates the 3d video of the loss
# Input: weight_path, loss_history, output_path, target_frames=300, frame_rate=10
# Output: A video showing the function descending to a minimum, or a warning
def generate_3d_animation_pca(weight_path, loss_history, output_path, target_frames=300, frame_rate=10):
    """Render the training trajectory as a 3D video and save it.

    The weight snapshots are projected onto their first two PCA components
    (x and y axes) and the loss is the z axis.

    Args:
        weight_path: Weight snapshots, one row per snapshot. Converted with
            ``np.asarray``, so a list of equal-length rows also works.
        loss_history: Sequence of loss values. It is linearly interpolated
            to one value per frame.
        output_path: File the video is written to.
        target_frames: Maximum number of frames. Capped at the number of
            snapshots.
        frame_rate: Frames per second of the saved video.

    Returns:
        ``output_path`` on success, or ``None`` after a ``gr.Warning`` when
        the inputs cannot be animated.
    """
    if len(loss_history) == 0 or np.any(np.isnan(loss_history)):
        gr.Warning("❌ loss_history is empty or contains NaNs — cannot animate.")
        return

    # Fancy indexing below needs an array, not a plain list.
    weight_path = np.asarray(weight_path)

    if len(weight_path) < target_frames:
        target_frames = len(weight_path)

    # PCA with two components needs at least two samples and two features.
    if weight_path.ndim != 2 or target_frames < 2 or weight_path.shape[1] < 2:
        gr.Warning("❌ Need at least two weight snapshots with at least two parameters each — cannot animate.")
        return

    indices = np.linspace(0, len(weight_path) - 1, target_frames).astype(int)
    weight_path_sampled = weight_path[indices]

    pca = PCA(n_components=2)
    reduced = pca.fit_transform(weight_path_sampled)

    # Resample the loss curve so there is exactly one value per frame.
    interpolated_loss = np.interp(
        np.linspace(0, len(loss_history) - 1, target_frames),
        np.arange(len(loss_history)),
        loss_history
    )

    fig = plt.figure(figsize=(10, 7))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title("Training Progression")
        ax.set_xlabel("PCA 1")
        ax.set_ylabel("PCA 2")
        ax.set_zlabel("Loss")

        # Static trajectory line plus a marker that moves along it.
        ax.plot(reduced[:, 0], reduced[:, 1], interpolated_loss, color='red', alpha=0.6)
        point, = ax.plot([reduced[0, 0]], [reduced[0, 1]], [interpolated_loss[0]], 'ro')

        def update(i):
            point.set_data([reduced[i, 0]], [reduced[i, 1]])
            point.set_3d_properties([interpolated_loss[i]])
            return point,

        interval = 1000 / frame_rate
        ani = animation.FuncAnimation(
            fig, update,
            frames=target_frames,
            interval=interval,
            blit=True
        )

        ani.save(output_path, writer='ffmpeg', fps=frame_rate, codec='libx264')
    finally:
        # Release the figure even when saving fails, e.g. ffmpeg is missing.
        plt.close(fig)

    return output_path

# --------------------------------
# Training Functions
# --------------------------------

# This function makes sure that the model being built is a valid model before training
# Input: model, data_type, image_size=28, num_features=None, num_channels=3
# Output: errors if the model is not valid and a three tuple if the model is valid
def validate_model_forward_pass(model, data_type, image_size=28, num_features=None, num_channels=3):
    """Dry-run a random single-sample input through ``model``, layer by layer.

    Args:
        model: Iterable of layers (e.g. ``nn.Sequential``), applied in order.
        data_type: "tabular" builds a ``(1, num_features)`` input. Any other
            value builds ``(1, num_channels, image_size, image_size)``.
        image_size: Height and width of the dummy image input.
        num_features: Number of input columns. Required for tabular data.
        num_channels: Channel count of the dummy image input.

    Returns:
        ``(True, None, None)`` when every layer accepts the output of the
        previous one. ``(False, message, layer_index)`` when a layer
        raises. ``(False, message, None)`` when the check itself could not
        be set up.
    """
    try:
        if data_type == "tabular":
            # Explicit check instead of ``assert``, which is stripped under ``python -O``.
            if num_features is None:
                raise ValueError("Missing number of input features for tabular data")
            dummy_input = torch.randn(1, num_features).to(device)
        else:
            dummy_input = torch.randn(1, num_channels, image_size, image_size).to(device)

        x = dummy_input
        # Only the shapes matter here, so skip building the autograd graph.
        with torch.no_grad():
            for idx, layer in enumerate(model):
                try:
                    x = layer(x)
                except Exception as e:
                    return False, f"Shape mismatch at layer {idx}: {layer.__class__.__name__} — {str(e)}", idx
        return True, None, None
    except Exception as e:
        return False, f"Unexpected validation error: {str(e)}", None

# This function validates the entire pipeline through a series of tests
# Input: X_tensor, y_tensor, model, loss_fn, batch_size=None, auto_fix=True
# Output: A gr.Warning if any of the tests is failed and X_tensor, y_tensor which are potentially modified tensors after validation and auto-fixes
def full_pipeline_validator(X_tensor, y_tensor, model, loss_fn, batch_size=None, auto_fix=True):
    """Run sanity checks on tabular tensors and the model before training.

    Problems are reported through ``gr.Warning`` / ``gr.Info``; this
    function does not abort on them. With ``auto_fix`` enabled, dtype and
    target-shape problems are corrected on the returned tensors.

    Args:
        X_tensor: Input features, expected to be 2D ``[batch_size, features]``.
        y_tensor: Targets.
        model: Model called on the first sample as a smoke test.
        loss_fn: Loss instance. It decides the expected target dtype and shape.
        batch_size: Optional batch size. Only compared against the dataset
            size; the caller has to clamp it itself.
        auto_fix: Cast or reshape the tensors instead of only warning.

    Returns:
        ``(X_tensor, y_tensor)``, possibly cast or reshaped.
    """
    if X_tensor.shape[0] == 0 or y_tensor.shape[0] == 0:
        gr.Warning("❌ Dataset is empty. Check your data loading pipeline.")

    has_nan = bool(torch.isnan(X_tensor).any() or torch.isnan(y_tensor).any())
    has_inf = bool(torch.isinf(X_tensor).any() or torch.isinf(y_tensor).any())
    if has_nan:
        gr.Warning("❌ NaN detected in input or target tensor.")
    if has_inf:
        gr.Warning("❌ Inf detected in input or target tensor.")
    # Only report a clean result when the data really is clean.
    if not (has_nan or has_inf):
        gr.Info("✅ No NaNs or Infs detected.")

    if X_tensor.dtype != torch.float32:
        if auto_fix:
            gr.Info(f"⚠️ Auto-fixing X_tensor dtype from {X_tensor.dtype} to torch.float32")
            X_tensor = X_tensor.float()
        else:
            gr.Warning(f"❌ X_tensor dtype must be float32, got {X_tensor.dtype}")

    if isinstance(loss_fn, nn.CrossEntropyLoss):
        if y_tensor.dtype != torch.long:
            if auto_fix:
                gr.Info(f"⚠️ Auto-fixing y_tensor dtype from {y_tensor.dtype} to torch.long")
                y_tensor = y_tensor.long()
            else:
                gr.Warning(f"❌ For CrossEntropyLoss, y_tensor dtype must be torch.long, got {y_tensor.dtype}")
    else:
        if y_tensor.dtype != torch.float32:
            if auto_fix:
                gr.Info(f"⚠️ Auto-fixing y_tensor dtype from {y_tensor.dtype} to torch.float32")
                y_tensor = y_tensor.float()
            else:
                gr.Warning(f"❌ y_tensor dtype must be float32, got {y_tensor.dtype}")

    if len(X_tensor.shape) != 2:
        gr.Warning(f"❌ X_tensor must be 2D [batch_size, features], got {X_tensor.shape}")

    if isinstance(loss_fn, (nn.MSELoss, nn.BCEWithLogitsLoss)):
        expected_y_shape = (X_tensor.shape[0], 1)
        if y_tensor.shape != expected_y_shape:
            if auto_fix:
                gr.Info(f"⚠️ Auto-reshaping y_tensor from {y_tensor.shape} to {expected_y_shape}")
                # reshape also copes with non-contiguous tensors, unlike view.
                y_tensor = y_tensor.reshape(-1, 1)
            else:
                gr.Warning(f"❌ y_tensor should have shape {expected_y_shape}, but got {y_tensor.shape}")

    elif isinstance(loss_fn, nn.CrossEntropyLoss):
        expected_y_shape = (X_tensor.shape[0],)
        if y_tensor.shape != expected_y_shape:
            if auto_fix:
                gr.Info(f"⚠️ Auto-reshaping y_tensor from {y_tensor.shape} to {expected_y_shape}")
                y_tensor = y_tensor.reshape(-1)
            else:
                gr.Warning(f"❌ y_tensor should have shape {expected_y_shape}, but got {y_tensor.shape}")

    if batch_size is not None and X_tensor.shape[0] < batch_size:
        # Informational only: batch_size is not part of the return value, so
        # the caller is responsible for clamping it.
        gr.Info(f"⚠️ Batch size {batch_size} is larger than dataset size {X_tensor.shape[0]}, adjusting.")

    # Remember the current mode so the smoke test does not leave the model
    # in eval mode right before training starts.
    was_training = getattr(model, "training", None)
    try:
        model.eval()
        with torch.no_grad():
            dummy_out = model(X_tensor[:1])
            if isinstance(loss_fn, nn.CrossEntropyLoss):
                if dummy_out.ndim != 2:
                    gr.Warning(f"❌ Model output for CrossEntropyLoss should be 2D [batch_size, num_classes], got {dummy_out.shape}")
            else:
                if dummy_out.shape != y_tensor[:1].shape:
                    gr.Warning(f"❌ Model output shape {dummy_out.shape} does not match target shape {y_tensor[:1].shape}")
    except Exception as e:
        gr.Warning(f"❌ Model forward pass failed: {e}")
    finally:
        if was_training is not None:
            model.train(was_training)

    return X_tensor, y_tensor

# This function gets the device status shown on the frontend.
# Input: Nothing
# Output: A status string for the Markdown widget. Without CUDA it is the CPU name; with CUDA it is
#         the GPU name, its memory figures in MB and a context health line, or the CUDA error text.
def get_device_status():
    if not torch.cuda.is_available():
        # get_cpu_info() returns a dict of many fields (including the full flag list);
        # show only the processor name instead of dumping the whole dict into the UI.
        cpu_info = cpuinfo.get_cpu_info()
        cpu_name = (
            cpu_info.get("brand_raw")   # key used by current py-cpuinfo releases
            or cpu_info.get("brand")    # key used by older releases
            or cpu_info.get("arch")
            or "Unknown CPU"
        )
        return f"🔴 CPU: {cpu_name}"

    bytes_per_mb = 1024 ** 2
    try:
        device_name = torch.cuda.get_device_name(0)
        mem_allocated = torch.cuda.memory_allocated(0) / bytes_per_mb
        mem_reserved = torch.cuda.memory_reserved(0) / bytes_per_mb
        total_mem = torch.cuda.get_device_properties(0).total_memory / bytes_per_mb
        # Only an estimate: it subtracts what this process has reserved from the card's total.
        mem_free = total_mem - mem_reserved

        # synchronize() raises if the CUDA context is broken, so it doubles as a health probe.
        try:
            torch.cuda.synchronize()
            sync_status = "✅ GPU context healthy."
        except Exception as sync_error:
            sync_status = f"❌ GPU context error: {sync_error}"

        return (
            f"✅ CUDA Device: {device_name}<br>"
            f"Memory allocated: {mem_allocated:.2f} MB<br>"
            f"Memory reserved: {mem_reserved:.2f} MB<br>"
            f"Free memory (estimated): {mem_free:.2f} MB<br>"
            f"{sync_status}"
        )

    except Exception as e:
        return f"❌ CUDA error: {e}"

# Parses a value coming from the frontend as a strictly positive integer.
# Input: raw (value to parse), not_numeric_msg (error text when it is not an integer),
#        not_positive_msg (error text when it is zero or negative)
# Output: (parsed int, None) on success, or (None, error text) on failure
def _parse_positive_int(raw, not_numeric_msg, not_positive_msg):
    try:
        value = int(raw)
    except (TypeError, ValueError, OverflowError):
        return None, not_numeric_msg
    if value <= 0:
        return None, not_positive_msg
    return value, None


# Builds the tuple sent to the frontend while no output files exist yet.
# Input: message (text for the log box)
# Output: (None, None, None, current architecture text, message)
def _status_update(message):
    return None, None, None, update_architecture_text(), message


# Runs one pass over the loader, stepping the optimizer after every batch.
# Input: model, loader (yields (X, y) batches), loss_fn, optimizer,
#        one_hot_targets (True -> one-hot the labels and softmax the outputs before the loss),
#        weight_log (list receiving a flat weight snapshot per batch, or None to skip snapshots)
# Output: The mean batch loss, or None when the loader produced no batches
def _train_one_epoch(model, loader, loss_fn, optimizer, one_hot_targets, weight_log):
    total_loss = 0.0
    num_batches = 0
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        out = model(X_batch)

        if one_hot_targets:
            # Size the one-hot targets from the model output. Counting the unique labels of
            # the batch gives the wrong width whenever a batch does not contain every class.
            y_batch = torch.nn.functional.one_hot(y_batch, num_classes=out.shape[1]).float()
            out = torch.softmax(out, dim=1)

        loss = loss_fn(out, y_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if weight_log is not None:
            weight_log.append(get_flat_weights(model).cpu().numpy())

    if num_batches == 0:
        return None
    return total_loss / num_batches


# This is the function that trains the model. It validates the parameters the user passes in from the
# frontend, builds and checks the model, loads the data, trains, and writes the loss plot, the animation
# (or a placeholder video) and the trained model into OUTPUT_DIR.
# Input: loss_name, opt_name, lr, batch_size='32', image_size='28', file=None, custom_path=None, epochs='100', num_channels=3, generate_animation=False,  target_frames=None, frame_rate=None
#        target_frames and frame_rate are only validated and used when generate_animation is set.
# Output: A generator. Every yield is (loss plot path, animation path, model path, architecture text, log text).
#         Validation failures and per-epoch progress yield None for the three paths; the last yield carries the
#         real paths. For image data the extracted folder is deleted when training ends, fails or is cancelled.
#         Unexpected failures are raised as gr.Error.
def train_model(loss_name, opt_name, lr, batch_size='32', image_size='28', file=None, custom_path=None, epochs='100', num_channels=3, generate_animation=False,  target_frames=None, frame_rate=None):

    # ---- Validate scalar inputs. No yield happens inside a try block here, so closing the
    # ---- generator at one of these yields can never be swallowed by an except clause.
    if generate_animation:
        target_frames, error = _parse_positive_int(
            target_frames,
            "❌ Target frames must be numeric.",
            "❌ Target frames must be a positive integer.",
        )
        if error:
            yield _status_update(error)
            return

        frame_rate, error = _parse_positive_int(
            frame_rate,
            "❌ Frame Rate must be numeric.",
            "❌ Frame Rate must be a positive integer.",
        )
        if error:
            yield _status_update(error)
            return

    try:
        channels = int(num_channels)
    except (TypeError, ValueError):
        yield _status_update("❌ Channels must be numeric.")
        return
    if channels not in (1, 3):
        yield _status_update("❌ Channels must be 1 or 3.")
        return

    max_epochs, error = _parse_positive_int(
        epochs,
        "❌ Epochs must be a valid number.",
        "❌ Epochs must be a positive integer.",
    )
    if error:
        yield _status_update(error)
        return

    # batch_size and image_size are checked before the model is built because the
    # forward-pass validation below already needs a usable image size.
    batch_size, error = _parse_positive_int(
        batch_size,
        "❌ Batch size must be a valid number",
        "❌ Batch size must be a positive integer",
    )
    if error:
        yield _status_update(error)
        return

    image_size, error = _parse_positive_int(
        image_size,
        "❌ Image size must be a valid number",
        "❌ Image size must be a positive integer",
    )
    if error:
        yield _status_update(error)
        return

    try:
        lr = float(lr)
    except (TypeError, ValueError):
        yield _status_update("❌ Learning rate must be a valid number.")
        return
    if not (0 < lr < float("inf")):  # also rejects NaN
        yield _status_update("❌ Learning rate must be a positive number.")
        return

    try:
        # 1) A model must have been configured
        if not layer_configs:
            yield _status_update("❌ No model configured! Please add at least one trainable layer.")
            return

        # 2) The data file must exist
        if not file or not os.path.exists(file):
            yield _status_update(f"❌ File not found: {file}\nPlease check the path and try again.")
            return

        # 3) Only .csv and .zip inputs are supported
        if not file.endswith(('.csv', '.zip')):
            yield _status_update("❌ Invalid file type. Please provide a .csv or .zip file.")
            return

        # 4) Create the custom directory if it does not exist yet
        if custom_path and not os.path.isdir(custom_path):
            try:
                os.makedirs(custom_path, exist_ok=True)
            except (OSError, ValueError) as e:
                yield _status_update(f"❌ Could not create directory '{custom_path}': {e}")
                return

        # 5) Build the model
        if torch.cuda.is_available():
            try:
                torch.cuda.synchronize()  # Will raise if CUDA state is bad
            except RuntimeError as e:
                raise gr.Error("❌ CUDA failure detected. Please restart the dashboard or kernel.") from e

        model = build_model().to(device)

        # 6) Validate the forward pass against the shape of the data
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if 'y' not in df.columns:
                yield _status_update("❌ CSV missing 'y' column.")
                return
            num_features = df.shape[1] - 1  # every column except the 'y' target
            is_valid, error_msg, bad_layer_idx = validate_model_forward_pass(model, "tabular", num_features=num_features)
        else:
            is_valid, error_msg, bad_layer_idx = validate_model_forward_pass(model, "image", image_size=image_size, num_channels=channels)

        if not is_valid:
            # Highlight the offending layer and surface the reason in the log box
            updated_view = update_architecture_text(highlight_index=bad_layer_idx)
            yield None, None, None, updated_view, (str(error_msg) if error_msg else "")
            return

        # 7) There must be something to train
        if not any(p.requires_grad for p in model.parameters()):
            yield _status_update("⚠️ Model has no trainable parameters. Add a Linear or Conv2d layer.")
            return

        # 8) Load data and set up the optimisation
        loss_fn = nn.MSELoss() if loss_name == 'MSELoss' else nn.CrossEntropyLoss()
        data = load_data(file, custom_path, batch_size=batch_size, image_size=image_size, num_channels=channels, loss_fn = loss_fn)
        optimizer = optim.SGD(model.parameters(), lr=lr) if opt_name == 'SGD' else optim.Adam(model.parameters(), lr=lr)

        loss_history = []
        status_lines = []
        # Weight snapshots are only consumed by the animation, so skip them otherwise.
        weight_path = [] if generate_animation else None

        if data["type"] == "tabular":
            X_train, y_train = data["train"]
            X_train = torch.tensor(X_train, dtype=torch.float32)
            if isinstance(loss_fn, nn.CrossEntropyLoss):
                # CrossEntropy expects 1D class indices
                y_train = torch.tensor(y_train, dtype=torch.long).view(-1)
            else:
                # MSE expects float targets shaped like the model output [batch, 1]
                y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
            train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
            one_hot_targets = False
            extracted_path = None
        else:
            train_loader = data["train"]
            # Image labels are class indices; MSE needs them one-hot against softmax outputs
            one_hot_targets = isinstance(loss_fn, nn.MSELoss)
            extracted_path = data["path"]

        # 9) Train. Earlier validation passes may have left the model in eval mode
        #    (a validation helper in this file calls model.eval()), so switch back explicitly.
        model.train()
        try:
            for epoch in range(1, max_epochs + 1):
                avg_epoch_loss = _train_one_epoch(model, train_loader, loss_fn, optimizer, one_hot_targets, weight_path)
                if avg_epoch_loss is None:
                    yield _status_update("❌ Training data is empty; nothing to train on.")
                    return
                loss_history.append(avg_epoch_loss)

                status_lines.append(f"Epoch {epoch}/{max_epochs} — Loss: {avg_epoch_loss:.4f}")
                yield _status_update("\n\n".join(status_lines))
        finally:
            # Cleanup extracted images, also when training failed or was cancelled
            if extracted_path:
                try:
                    shutil.rmtree(extracted_path)
                except OSError as e:
                    gr.Warning(f"Warning: Could not delete extracted folder: {e}")

        # 10) Write the outputs
        loss_plot_path = os.path.join(OUTPUT_DIR, "loss_plot.png")
        fig, ax = plt.subplots()
        ax.plot(loss_history)
        ax.set_title("Loss over Epochs")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        fig.savefig(loss_plot_path)
        plt.close(fig)

        animation_path = os.path.join(OUTPUT_DIR, "animation.mp4")
        if generate_animation:
            generate_3d_animation_pca(
                np.array(weight_path),
                loss_history,
                animation_path,
                target_frames=target_frames,
                frame_rate=frame_rate
            )
        else:
            create_dummy_video(animation_path)

        # Save the trained model in OUTPUT_DIR
        model_path = os.path.join(OUTPUT_DIR, "trained_model.pt")
        torch.save(model, model_path)

        # Blank lines between entries so the Markdown log keeps one epoch per line
        final_logs = "\n\n".join(status_lines)

        for path in [loss_plot_path, animation_path, model_path]:
            if os.path.isdir(path):
                gr.Warning(f"❌ Output path is a directory, expected a file: {path}")

        # create_dummy_video only warns when ffmpeg fails, so the video may not exist;
        # hand the frontend None rather than a path to a missing file.
        if not os.path.isfile(animation_path):
            animation_path = None

        del model
        gc.collect()
        torch.cuda.empty_cache()

        yield loss_plot_path,  animation_path, model_path, update_architecture_text(), final_logs

    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"❌ Unexpected error: {str(e)}") from e

# This function returns a data folder inside the user's home directory, creating it when needed
# Input: Nothing
# Output: The path of the "my_gradio_data" folder under the home directory
def get_default_writable_folder():
    # "~" expands to the current user's home, e.g. C:\\Users\\Alice on Windows
    data_folder = os.path.join(os.path.expanduser("~"), "my_gradio_data")
    os.makedirs(data_folder, exist_ok=True)
    return data_folder

# This function wraps the training operation, ensures a valid save path, and yields training outputs
# Input: loss_name, opt_name, lr, batch_size, image_size, file, custom_path, epochs, num_channels,
#        generate_animation, target_frames, frame_rate
# Output: A generator which returns training progress outputs from train_model
def train_model_with_default_path(loss_name, opt_name, lr, batch_size, image_size, 
                                  file, custom_path, epochs, num_channels, 
                                  generate_animation, target_frames, frame_rate
):
    # An empty or whitespace-only path means the user did not choose one
    if not custom_path or custom_path.strip() == "":
        custom_path = get_default_writable_folder()

    # Delegate with `yield from` so that closing or cancelling this generator is
    # forwarded straight to train_model instead of leaving it suspended.
    yield from train_model(loss_name, opt_name, lr,
                           batch_size, image_size, file,
                           custom_path, epochs, num_channels,
                           generate_animation, target_frames, frame_rate)


# This function creates a dummy video to prevent the gradio frontend from crashing if the user selects no video
# Input: Output_path to put the video in
# Output: Nothing is returned. A 5 second black 1280x720 video is written to output_path by ffmpeg;
#         if ffmpeg cannot be started or exits with an error, a gr.Warning is shown and no file is written.
def create_dummy_video(output_path):
    command = [
        "ffmpeg",
        "-f", "lavfi",
        "-i", "color=c=black:s=1280x720:d=5", 
        "-c:v", "libx264",
        "-t", "5",
        "-pix_fmt", "yuv420p",
        "-y", output_path  # -y: overwrite a video left over from an earlier run
    ]
    try:
        # stdout is not needed; stderr is captured so ffmpeg's own message can be shown on failure
        subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
    except OSError as e:
        # Raised when the ffmpeg executable is missing or cannot be executed
        gr.Warning(f"❌ Error creating dummy video: could not run ffmpeg ({e})")
    except subprocess.CalledProcessError as e:
        gr.Warning(f"❌ Error creating dummy video: {e.stderr.decode(errors='replace')}")

# --------------------------------
# Gradio wiring
# --------------------------------

# Layout and event wiring of the dashboard: a "Build Model" tab for editing the layer list
# and a "Train" tab that streams the outputs of train_model_with_default_path.
with gr.Blocks() as dashboard:

    with gr.Tab("Build Model"):
        gr.Markdown("## Build a Layer")
        builder_arch = gr.Textbox(label="Architecture So Far", lines=6)

        layer_type_dropdown = gr.Dropdown(choices=list(layer_map.keys()), label="Layer Type", value="Linear")
        in_dim = gr.Textbox(label="Input Dim (Linear/Conv2d)")
        out_dim = gr.Textbox(label="Output Dim (Linear/Conv2d)")

        # Layer-specific options start hidden; the dropdown handler below reveals the relevant group
        conv_kernel = gr.Textbox(label="Kernel Size", value="3", visible=False)
        conv_padding = gr.Textbox(label="Padding", value="1", visible=False)
        conv_stride = gr.Textbox(label="Stride", value="1", visible=False)
        conv_bias = gr.Checkbox(label="Include Bias", value=True, visible=False)

        pool_kernel = gr.Textbox(label="Pool Kernel Size", value="2", visible=False)
        pool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        pool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        avgpool_kernel = gr.Textbox(label="AvgPool Kernel Size", value="2", visible=False)
        avgpool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        avgpool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        leaky_relu_slope = gr.Textbox(label="Negative Slope", value="0.01", visible=False)
        elu_alpha = gr.Textbox(label="ELU Alpha", value="1.0", visible=False)

        add_btn = gr.Button("Add Layer")
        add_btn.click(
            fn=add_layer,
            inputs=[
                layer_type_dropdown, in_dim, out_dim,
                conv_kernel, conv_padding, conv_stride, conv_bias,
                pool_kernel, pool_stride, pool_padding,
                avgpool_kernel, avgpool_stride, avgpool_padding,
                leaky_relu_slope, elu_alpha
            ],
            outputs=builder_arch
        )

        # Shows the option widgets of the selected layer type and hides all the others.
        # Input: layer_type (the value of a layer type dropdown)
        # Output: 12 visibility updates, ordered like the `outputs` lists of the two
        #         dropdown handlers: 4 conv, 3 max-pool, 3 avg-pool, 1 LeakyReLU, 1 ELU
        def toggle_fields(layer_type):
            # (layer type, number of option widgets that belong to it)
            option_groups = (
                ("Conv2d", 4),
                ("MaxPool2d", 3),
                ("AvgPool2d", 3),
                ("LeakyReLU", 1),
                ("ELU", 1),
            )
            return [
                gr.update(visible=(layer_type == owner))
                for owner, widget_count in option_groups
                for _ in range(widget_count)
            ]

        layer_type_dropdown.change(
            toggle_fields,
            inputs=[layer_type_dropdown],
            outputs=[
                conv_kernel, conv_padding, conv_stride, conv_bias,
                pool_kernel, pool_stride, pool_padding,
                avgpool_kernel, avgpool_stride, avgpool_padding,
                leaky_relu_slope, elu_alpha
            ]
        )

        reset_btn = gr.Button("Reset Layers")
        reset_btn.click(fn=reset_layers, inputs=[], outputs=builder_arch)

        gr.Markdown("### Edit or Delete a Layer")
        layer_index = gr.Number(label="Layer Index (0-based)", precision=0)
        new_layer_type = gr.Dropdown(list(layer_map.keys()), label="New Layer Type")
        new_in_dim = gr.Textbox(label="New Input Dim (if applicable)")
        new_out_dim = gr.Textbox(label="New Output Dim (if applicable)")

        # Same option groups as above, for the edit/insert form
        edit_kernel = gr.Textbox(label="Kernel Size", value="3", visible=False)
        edit_padding = gr.Textbox(label="Padding", value="1", visible=False)
        edit_stride = gr.Textbox(label="Stride", value="1", visible=False)
        edit_bias = gr.Checkbox(label="Include Bias", value=True, visible=False)

        edit_pool_kernel = gr.Textbox(label="Pool Kernel Size", value="2", visible=False)
        edit_pool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        edit_pool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        edit_avgpool_kernel = gr.Textbox(label="AvgPool Kernel Size", value="2", visible=False)
        edit_avgpool_stride = gr.Textbox(label="Stride", value="2", visible=False)
        edit_avgpool_padding = gr.Textbox(label="Padding", value="0", visible=False)

        edit_leaky_relu_slope = gr.Textbox(label="Negative Slope", value="0.01", visible=False)
        edit_elu_alpha = gr.Textbox(label="ELU Alpha", value="1.0", visible=False)

        edit_btn = gr.Button("Edit Layer")
        delete_btn = gr.Button("Delete Layer")
        insert_btn = gr.Button("Insert New Layer")

        edit_btn.click(
            fn=update_layer,
            inputs=[
                layer_index, new_layer_type, new_in_dim, new_out_dim,
                edit_kernel, edit_padding, edit_stride, edit_bias,
                edit_pool_kernel, edit_pool_stride, edit_pool_padding,
                edit_avgpool_kernel, edit_avgpool_stride, edit_avgpool_padding,
                edit_leaky_relu_slope, edit_elu_alpha
            ],
            outputs=builder_arch
        )

        delete_btn.click(fn=delete_layer, inputs=[layer_index], outputs=builder_arch)
        insert_btn.click(
            fn=insert_layer,
            inputs=[
                layer_index, new_layer_type, new_in_dim, new_out_dim,
                edit_kernel, edit_padding, edit_stride, edit_bias,
                edit_pool_kernel, edit_pool_stride, edit_pool_padding,
                edit_avgpool_kernel, edit_avgpool_stride, edit_avgpool_padding,
                edit_leaky_relu_slope, edit_elu_alpha
            ],
            outputs=builder_arch
        )

        # The edit form toggles its widgets exactly like the add form, so it reuses the
        # same handler; the old name is kept as an alias.
        toggle_edit_fields = toggle_fields

        new_layer_type.change(
            toggle_edit_fields,
            inputs=[new_layer_type],
            outputs=[
                edit_kernel, edit_padding, edit_stride, edit_bias,
                edit_pool_kernel, edit_pool_stride, edit_pool_padding,
                edit_avgpool_kernel, edit_avgpool_stride, edit_avgpool_padding,
                edit_leaky_relu_slope, edit_elu_alpha
            ]
        )

        # Device info row
        with gr.Row():
            device_output = gr.Markdown(get_device_status())
            refresh_button = gr.Button("🔄 Refresh Device Status")
            refresh_button.click(fn=get_device_status, inputs=[], outputs=device_output)

    with gr.Tab("Train"):
        gr.Markdown("## Train the Model")

        # Inputs. All numeric settings are textboxes and are parsed and validated by train_model.
        loss_dropdown = gr.Dropdown(['MSELoss', 'CrossEntropyLoss'], label='Loss Function')
        opt_dropdown = gr.Dropdown(['SGD', 'Adam'], label='Optimizer')
        lr_box = gr.Textbox(value="0.01", label="Learning Rate")
        batch_box = gr.Textbox(value="32", label="Batch Size")
        size_box = gr.Textbox(value="28", label="Image Resize (e.g. 28x28)")
        file_box = gr.Textbox(label="Path to CSV or ZIP")
        custom_box = gr.Textbox(label="Custom Extraction Path (optional)")
        epochs_box = gr.Textbox(value="100", label="Epochs")
        generate_3d_checkbox = gr.Checkbox(label="Generate 3D Descent Animation (⚠️ Slower, CPU/RAM-intensive)", value=False)
        generate_3d_targetframes = gr.Textbox(value = "300", label = "Target Frames for Video")
        generate_3d_framerate = gr.Textbox(value="10", label="Frame Rate (Frames per Second)")
        channel_dropdown = gr.Dropdown([1, 3], label="Input Channels (1 = Grayscale, 3 = RGB)", value=3)

        # Outputs
        loss_curve = gr.Image(label="Loss Curve")
        animation_video = gr.Video(label="3D Descent Animation")
        model_file = gr.File(label="Download Trained Model")  
        log_box = gr.Markdown(label="Log")
        train_button = gr.Button("Start Training")
        # The order of `inputs` must match the parameter order of train_model_with_default_path,
        # and `outputs` must match the 5-tuples that train_model yields.
        train_button.click(
            fn=train_model_with_default_path,
            inputs=[
                loss_dropdown,
                opt_dropdown,
                lr_box,
                batch_box,
                size_box,
                file_box,
                custom_box,
                epochs_box,
                channel_dropdown,
                generate_3d_checkbox,
                generate_3d_targetframes,
                generate_3d_framerate

            ],
            outputs=[
                loss_curve,
                animation_video,
                model_file,
                builder_arch,
                log_box
            ]
        )

# The training handler is a generator, so the queue is enabled to stream its intermediate yields.
dashboard.queue()
# NOTE(review): share=True requests a public share link. The Train tab accepts server-side
# file and directory paths, so confirm that exposing it beyond localhost is intended.
dashboard.launch(share=True)

  from .autonotebook import tqdm as notebook_tqdm


* Running on local URL:  http://127.0.0.1:7861
