In [None]:
!pip install gradio

In [6]:
from pathlib import Path

import gradio as gr
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

# from my_projects.modeling.train import main as train_main
# from my_projects.modeling.inference import run_inference


# Pre-rendered evaluation figures shown on the Results tab.
# NOTE(review): these relative paths resolve against the current working
# directory (typically the notebook's folder under Jupyter) — confirm.
_FIGURES_DIR = Path("../reports/figures")
calibration_path = _FIGURES_DIR / "calibration.png"
confusion_path = _FIGURES_DIR / "confusion_matrix.png"

def _draw_image(ax, path, title):
    """Draw the image stored at *path* onto *ax* under *title*.

    Returns True if the image was drawn. If the file does not exist, a
    placeholder message is drawn instead (so the dashboard still renders)
    and False is returned.
    """
    ax.set_title(title)
    ax.axis("off")
    try:
        image = mpimg.imread(path)
    except FileNotFoundError:
        ax.text(
            0.5,
            0.5,
            f"Missing file:\n{path}",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return False
    ax.imshow(image)
    return True


def train_and_visualize(dataset_choice, model_choice, seed, epochs):
    """Build the dashboard outputs for the chosen configuration.

    Training is not actually run yet: the figures are loaded from
    pre-rendered PNGs and the reported test accuracy is a fixed placeholder.

    Args:
        dataset_choice: Dataset label selected in the UI.
        model_choice: Model label selected in the UI.
        seed: Random seed from the UI (only echoed in the summary for now).
        epochs: Epoch count from the UI (only echoed in the summary for now).

    Returns:
        A ``(figure, metrics_markdown, status_markdown)`` tuple matching the
        ``gr.Plot`` / ``gr.Markdown`` / ``gr.Markdown`` outputs.
    """
    # A standalone Figure instead of pyplot: pyplot keeps every figure in a
    # global registry (never closed here, so it grows on every UI event) and
    # is not thread-safe, while this Figure can be garbage-collected once
    # Gradio has rendered it.
    fig = Figure(figsize=(10, 5))
    axes = fig.subplots(1, 2)

    found_calibration = _draw_image(axes[0], calibration_path, "Calibration Plot")
    found_confusion = _draw_image(axes[1], confusion_path, "Confusion Matrix")

    # Render as a Markdown list: plain lines separated by a single "\n"
    # would be merged into one paragraph by the Markdown renderer.
    metrics_text = "\n".join(
        [
            f"- **Dataset:** {dataset_choice}",
            f"- **Model:** {model_choice}",
            f"- **Seed:** {seed}",
            f"- **Epochs:** {epochs}",
            "- **Test Accuracy:** 0.85",  # placeholder until real training is wired in
        ]
    )

    status = "✅ Training simulated"
    if not (found_calibration and found_confusion):
        status += " (⚠️ some figures are missing)"

    return fig, metrics_text, status

# Dashboard layout: configuration on the Data/Model tabs, outputs on Results.
with gr.Blocks(title="Animal Classification Dashboard") as demo:
    with gr.Tab("Data"):
        dataset_choice = gr.Dropdown(
            ["Full dataset", "Reduced dataset"],
            value="Full dataset",
            label="Dataset",
        )
        # precision=0 makes Gradio round and deliver an int; without it the
        # handler receives a float (e.g. 123.0), which seeding code and
        # `range(epochs)` style loops cannot use.
        seed = gr.Number(value=123, label="Random Seed", precision=0)

    with gr.Tab("Model"):
        model_choice = gr.Dropdown(["VGG16", "VGG11"], value="VGG16", label="Model")
        epochs = gr.Number(value=5, label="Epochs", precision=0)

    with gr.Tab("Results"):
        plot_output = gr.Plot()
        metrics_output = gr.Markdown()
        status_output = gr.Markdown()

    # Order must match train_and_visualize's parameters / return tuple.
    inputs = [dataset_choice, model_choice, seed, epochs]
    outputs = [plot_output, metrics_output, status_output]

    # Recompute whenever any control changes, and once on page load so the
    # Results tab is populated immediately.
    for input_component in inputs:
        input_component.change(
            fn=train_and_visualize,
            inputs=inputs,
            outputs=outputs,
        )

    demo.load(fn=train_and_visualize, inputs=inputs, outputs=outputs)

demo.launch()


* Running on local URL:  http://127.0.0.1:7862
* To create a public link, set `share=True` in `launch()`.


