#### Run All to Initialize Dashboard

#### Dashboard code

##### Libraries

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import io
import torch.nn.functional as F

In [None]:
# --- SETUP PATHS & DATA ---
# All locations are resolved relative to the parent of the current working
# directory, which is treated as the project root.
project_root = Path.cwd().parent
model_dir = project_root.joinpath("models")
data_dir = project_root.joinpath("data", "processed", "resized224p")

# One class per sub-directory of the dataset folder; sorting gives a stable
# index -> label mapping.
class_names = sorted(entry.name for entry in data_dir.iterdir() if entry.is_dir())

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- MODEL LOADER ---
def load_selected_model(model_name, num_classes):
    """Build the architecture named in *model_name* and load its saved weights.

    The architecture is picked by substring match on the checkpoint file name
    (``resnet50``, ``efficientnet_b0`` or ``convnext_tiny``). The final
    classification layer is replaced so that it emits ``num_classes`` logits,
    then the state dict stored at ``model_dir / model_name`` is loaded.

    Args:
        model_name: File name of a checkpoint inside ``model_dir``.
        num_classes: Number of outputs of the replaced classifier layer.

    Returns:
        A ``(model, target_layer)`` tuple: the model moved to ``device`` and
        put in eval mode, and a one-element list with the layer that is
        handed to Grad-CAM.

    Raises:
        ValueError: If *model_name* contains none of the supported
            architecture names.
    """
    if "resnet50" in model_name:
        model = models.resnet50()
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        target_layer = [model.layer4[-1]]
    elif "efficientnet_b0" in model_name:
        model = models.efficientnet_b0()
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        target_layer = [model.features[-1]]
    elif "convnext_tiny" in model_name:
        model = models.convnext_tiny()
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
        target_layer = [model.features[-1]]
    else:
        # Without this branch `model` would be unbound below and the failure
        # would surface as a confusing UnboundLocalError.
        raise ValueError(
            f"Unsupported model architecture in checkpoint name: {model_name!r} "
            "(expected 'resnet50', 'efficientnet_b0' or 'convnext_tiny')"
        )

    path = model_dir / model_name
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be placed in model_dir.
    model.load_state_dict(torch.load(path, map_location=device))
    return model.to(device).eval(), target_layer

# --- UI STYLING & WIDGETS ---
# Shared styling for the input controls: un-truncated descriptions and a
# common width/margin.
header_style = {"description_width": "initial"}
layout_padding = widgets.Layout(width="400px", margin="10px 0")

# Checkpoint picker: one option per *.pth file found in the models directory.
model_dropdown = widgets.Dropdown(
    description="Select Model",
    options=[checkpoint.name for checkpoint in model_dir.glob("*.pth")],
    layout=layout_padding,
    style=header_style,
)

# multiple=True lets several images be uploaded and analysed in one run.
file_upload = widgets.FileUpload(
    description="Upload Artwork(s)",
    accept="image/*",
    multiple=True,
    layout=layout_padding,
)

# Status line shown under the upload button; rewritten by update_status.
file_status = widgets.HTML(
    layout=widgets.Layout(margin="-5px 0 10px 0"),
    value="<span style='color: #7f8c8d;'>No files uploaded</span>",
)

# Button that starts the analysis of all uploaded files.
analyze_btn = widgets.Button(
    description="Analyze All Paintings",
    icon="search",
    button_style="success",
    layout=widgets.Layout(width="400px", height="40px", margin="20px 0"),
)

# Output area that receives everything the analysis displays.
out = widgets.Output()

# --- LOGIC & OBSERVERS ---
def update_status(change):
    """Refresh the status line shown under the upload button.

    Args:
        change: Trait-change dict delivered by ``observe``. Its ``'new'``
            entry is indexed as a sequence of uploaded-file records, each of
            which carries a ``'name'`` key.
    """
    from html import escape  # stdlib; imported locally so the function is self-contained

    uploads = change['new']
    num_files = len(uploads)
    if num_files == 0:
        file_status.value = "<span style='color: #7f8c8d;'>No files uploaded</span>"
    elif num_files == 1:
        # The file name is user-controlled and rendered as HTML, so escape it
        # to keep markup in a name from being interpreted by the browser.
        fname = escape(uploads[0]['name'])
        file_status.value = f"<span style='color: #27ae60;'>1 file uploaded: <b>{fname}</b></span>"
    else:
        file_status.value = f"<span style='color: #27ae60;'><b>{num_files} files</b> uploaded</span>"

# Keep the status line in sync: update_status is called whenever the upload
# widget's `value` trait changes.
file_upload.observe(update_status, names='value')

def on_analyze_clicked(b):
    """Classify every uploaded image and render the results into ``out``.

    For each file in ``file_upload.value`` the handler displays the top
    prediction with its confidence, a side-by-side plot of the original image
    and its Grad-CAM overlay, and a list of the highest-scoring classes.
    Files that cannot be decoded as images are reported and skipped so that
    one bad upload does not abort the rest of the batch.

    Args:
        b: The clicked button, as passed by ``on_click`` (unused).
    """
    from html import escape  # stdlib; imported locally so the handler is self-contained

    with out:
        clear_output()
        if not file_upload.value:
            display(widgets.HTML("<h4 style='color: #e74c3c;'>Please upload images first!</h4>"))
            return
        if not model_dropdown.value:
            # Nothing is selected, e.g. because no checkpoint was found;
            # bail out here instead of crashing inside the loader.
            display(widgets.HTML("<h4 style='color: #e74c3c;'>No model checkpoint available!</h4>"))
            return

        # Load Model once for all images
        model, target_layers = load_selected_model(model_dropdown.value, len(class_names))

        # Preprocessing setup: resize + tensor conversion, then channel-wise
        # normalisation (these are the usual ImageNet mean/std values).
        transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        norm_transform = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # topk raises when asked for more entries than there are classes.
        top_k = min(3, len(class_names))

        # --- LOOP THROUGH ALL UPLOADED FILES ---
        for file_info in file_upload.value:
            # The file name is user-controlled and rendered as HTML below.
            safe_name = escape(file_info['name'])
            try:
                img = Image.open(io.BytesIO(file_info['content'])).convert('RGB')
            except OSError:
                # PIL signals unreadable/corrupt image data with OSError subclasses.
                display(widgets.HTML(
                    f"<h4 style='color: #e74c3c;'>Could not read image: {safe_name}</h4>"
                ))
                continue

            # Predict (the normalised tensor is reused for Grad-CAM below)
            input_tensor = transform(img).unsqueeze(0).to(device)
            normalized = norm_transform(input_tensor)
            with torch.no_grad():
                probs = F.softmax(model(normalized), dim=1)

            top_probs, top_idxs = torch.topk(probs, top_k)
            scores = top_probs[0].tolist()
            indices = top_idxs[0].tolist()
            best_idx = indices[0]
            best_artist = class_names[best_idx].replace('_', ' ')
            best_conf = scores[0] * 100

            # --- INDIVIDUAL RESULT DISPLAY ---
            display(widgets.HTML(f"""
                <hr style="border: 1px solid #eee; margin: 40px 0;">
                <div style="background-color: #2c3e50; color: white; padding: 15px; border-radius: 10px; text-align: center;">
                    <h2 style="margin: 0; font-size: 22px;">File: {safe_name}</h2>
                    <h1 style="margin: 5px 0; font-size: 28px;">Prediction: {best_artist}</h1>
                    <p style="font-size: 16px; opacity: 0.9;">Confidence: {best_conf:.2f}%</p>
                </div>
            """))

            # Grad-CAM for the predicted class. The context manager releases
            # the hooks GradCAM registers on the model once the map is
            # computed, so they do not pile up across images.
            targets = [ClassifierOutputTarget(best_idx)]
            with GradCAM(model=model, target_layers=target_layers) as cam:
                grayscale_cam = cam(input_tensor=normalized, targets=targets)[0, :]
            rgb_img = np.float32(img.resize((224, 224))) / 255
            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

            # Visualization
            fig, ax = plt.subplots(1, 2, figsize=(12, 5))
            ax[0].imshow(img)
            ax[0].set_title("Original", fontsize=10)
            ax[0].axis('off')
            ax[1].imshow(cam_image)
            ax[1].set_title("Explainability (Grad-CAM)", fontsize=10)
            ax[1].axis('off')
            plt.show()
            # Free the figure explicitly; not every backend closes it on show().
            plt.close(fig)

            # Probabilities
            list_items = [
                f"<li><b>{class_names[idx].replace('_', ' ')}</b>: {score * 100:.1f}%</li>"
                for idx, score in zip(indices, scores)
            ]
            display(widgets.HTML(f"<ul>{''.join(list_items)}</ul>"))

# Run the analysis whenever the button is clicked.
analyze_btn.on_click(on_analyze_clicked)

# --- RENDER DASHBOARD ---
# Stack the title, the controls and the output area vertically, centred.
# `border_radius` was dropped from the Layout: it is not a Layout trait, so it
# had no visual effect and only triggered an unrecognized-argument warning.
dashboard_box = widgets.VBox([
    widgets.HTML("<h2>🎨 Tim! (Art Classifier with Grad-Cam Explainability)</h2>"),
    model_dropdown, 
    file_upload, 
    file_status,
    analyze_btn, 
    out
], layout=widgets.Layout(align_items='center', padding='20px'))

# Display Tim's Dashboard

In [None]:
# Render the assembled dashboard as this cell's output.
display(dashboard_box)