In [None]:
# Gradio UI definition
import tempfile
from pathlib import Path

import gradio as gr

def _display_name(f):
    """Return a short label (file basename) for one item from ``gr.File``.

    Depending on the Gradio version / ``type`` setting, items are either plain
    path strings or file-like objects exposing ``.name`` — handle both.
    """
    raw = f if isinstance(f, str) else getattr(f, 'name', None)
    # Gradio stores uploads under temp paths; only the basename is meaningful to users.
    return Path(raw).name if raw else 'uploaded'


def _collect_inputs(uploaded_files, raw_text):
    """Gather inputs as a list of ``(name, text, read_error)`` tuples.

    Uploaded files come first (in upload order), then the pasted text if it is
    non-blank. For a file that could not be read, ``text`` is ``None`` and
    ``read_error`` holds a message to display instead of a summary.
    """
    items = []
    for f in uploaded_files or []:
        name = _display_name(f)
        try:
            txt = load_text_from_uploaded(f)
        except Exception as e:
            # Best effort: keep processing other files, report this one's error.
            items.append((name, None, f'<<ERROR reading file: {e}>>'))
            continue
        items.append((name, clean_text(txt), None))
    if raw_text and raw_text.strip():
        items.append(('pasted_text', clean_text(raw_text), None))
    return items


def _apply_selection(items, selected_indices):
    """Filter ``items`` by a comma-separated 0-based index string like ``'0,2'``.

    Empty/blank selection, a selection with no indices, or one that cannot be
    parsed as integers keeps all items. Out-of-range (including negative)
    indices are dropped.
    """
    if not selected_indices or not selected_indices.strip():
        return items
    try:
        # int() tolerates surrounding whitespace, so ' 2' parses fine.
        idxs = [int(p) for p in selected_indices.split(',') if p.strip()]
    except ValueError:
        return items
    if not idxs:
        return items
    return [items[i] for i in idxs if 0 <= i < len(items)]


def summarize_interface(uploaded_files, raw_text, selected_indices,
                        max_length, min_length, num_return_sequences, temperature, num_beams, use_fp16):
    """Summarize uploaded files and/or pasted text, streaming progress to Gradio.

    This is a generator. Each ``yield`` is a ``(display_text, download_path)``
    tuple matching the two Gradio outputs. ``download_path`` is ``None`` while
    work is in progress. Once every input is processed, it points to a ``.txt``
    file (single input) or a ``.zip`` file (several inputs) in a fresh
    per-request temporary directory.

    Args:
        uploaded_files: Items from ``gr.File(file_count='multiple')`` (path
            strings or file-like objects), or ``None``.
        raw_text: Pasted text; ignored when empty or whitespace.
        selected_indices: Comma-separated 0-based indices into the combined
            input list. Uploaded files come first, then the pasted text.
            Empty means all inputs.
        max_length, min_length, num_return_sequences, num_beams: Generation
            sizes. Gradio sliders deliver floats, so these are cast to ``int``;
            ``min_length`` is clamped so it never exceeds ``max_length``.
        temperature: Forwarded unchanged to ``summarize_texts``.
        use_fp16: Forwarded to ``load_model(use_fp16_on_cuda=...)``.
    """
    items = _collect_inputs(uploaded_files, raw_text)
    if not items:
        yield ('No inputs provided. Upload files or paste text.', None)
        return

    items = _apply_selection(items, selected_indices)
    if not items:
        yield ('No inputs match selected_indices.', None)
        return

    try:
        load_model(use_fp16_on_cuda=use_fp16)
    except Exception as e:
        yield (f'Error loading model: {e}', None)
        return

    max_len = int(max_length)
    gen_kwargs = dict(
        max_length=max_len,
        min_length=min(int(min_length), max_len),
        num_return_sequences=int(num_return_sequences),
        temperature=temperature,
        num_beams=int(num_beams),
    )

    names, summaries, sections = [], [], []
    # Process each input and yield the accumulated display so Gradio shows progress.
    for name, text, read_error in items:
        if read_error is not None:
            summary = read_error
        else:
            try:
                res = summarize_texts([text], **gen_kwargs)
                summary = '\n\n'.join(res[0])
            except Exception as e:
                summary = f'<<Error during summarization: {e}>>'
        names.append(name)
        summaries.append(summary)
        sections.append(f'--- {name} ---\n{summary}')
        yield ('\n\n'.join(sections), None)

    # A unique directory per request keeps concurrent users from overwriting
    # each other's downloads.
    out_dir = Path(tempfile.mkdtemp(prefix='summaries_'))
    if len(summaries) == 1:
        out_path = out_dir / 'summary.txt'
        out_path.write_text(summaries[0], encoding='utf-8')
        yield ('All done', str(out_path))
    else:
        zip_path = str(out_dir / 'summaries.zip')
        make_zip_from_texts(summaries, names, zip_path)
        yield ('All done', zip_path)

# Build Gradio components
with gr.Blocks() as demo:
    gr.Markdown('Upload notes (.txt, .md, .pdf) or paste text. Select files to summarize and press Summarize.')

    # Input sources: uploaded files and/or a block of pasted text.
    with gr.Row():
        file_input = gr.File(label='Upload note files', file_count='multiple')
        text_input = gr.Textbox(label='Raw text', placeholder='Paste note text here', lines=8)

    # Length bounds for each generated summary.
    with gr.Row():
        max_length = gr.Slider(minimum=16, maximum=1024, step=8, value=128, label='max_length')
        min_length = gr.Slider(minimum=8, maximum=512, step=1, value=30, label='min_length')

    # Decoding controls forwarded to the summarizer.
    with gr.Row():
        num_return_sequences = gr.Slider(minimum=1, maximum=5, step=1, value=1, label='num_return_sequences')
        temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label='temperature')
        num_beams = gr.Slider(minimum=1, maximum=8, step=1, value=4, label='num_beams')

    use_fp16 = gr.Checkbox(label='use_fp16_on_cuda (if GPU)', value=True)
    selected_indices = gr.Textbox(
        label='selected_indices',
        placeholder='e.g. 0,2 to pick first and third uploaded files (or leave empty)',
        lines=1,
    )
    summarize_btn = gr.Button('Summarize')
    output = gr.Textbox(label='Summaries (display)')
    download_output = gr.File(label='Download results')

    # Order must match summarize_interface's positional parameters.
    _summarize_inputs = [
        file_input, text_input, selected_indices,
        max_length, min_length,
        num_return_sequences, temperature, num_beams,
        use_fp16,
    ]
    summarize_btn.click(
        fn=summarize_interface,
        inputs=_summarize_inputs,
        outputs=[output, download_output],
    )

# Bare expression: renders the app inline when it is the last line of a notebook cell.
demo