# Novel Writer Studio - Web UI (Colab)

Run a full story-writing web UI on Colab with GPU acceleration.

**Workflow:**
1. **Cloud AI** (Gemini / GPT) develops your story idea into a detailed plot outline
2. **Your fine-tuned LoRA model** writes each chapter in the trained literary style

### Requirements
- Colab with GPU (A100 recommended for 32B model, T4 works for 8B)
- A trained LoRA adapter (from the training notebook)
- A Gemini or OpenAI API key (for plot generation)

### How to use
1. Run all cells below
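2. Upload your LoRA adapter (.zip) when prompted — or set `LORA_MODE` to `google_drive` in the Configuration cell to load it from Google Drive instead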
3. Click the Gradio public link to open the UI
4. Enter your API key, write your story idea, and generate!

In [None]:
#@title Configuration { display-mode: "form" }
# Colab form cell: the `#@title`, `#@markdown` and `#@param` annotations drive
# the form UI, so keep each `#@param` on the same line as its assignment.

#@markdown ### Model Selection
#@markdown Choose the base model that matches your LoRA adapter.
# Key into MODEL_CONFIGS below; must name the base model the adapter was trained on.
MODEL_CHOICE = "qwen3_32b" #@param ["qwen3_4b", "qwen3_8b", "qwen3_14b", "qwen3_32b", "llama31_8b", "gemma2_9b", "mistral_nemo_12b"]

#@markdown ### LoRA Upload Mode
#@markdown - **upload_zip**: Upload your LoRA adapter as a .zip file
#@markdown - **google_drive**: Load from Google Drive path
# 'upload_zip' prompts for a browser upload; 'google_drive' mounts Drive and uses DRIVE_LORA_PATH.
LORA_MODE = "upload_zip" #@param ["upload_zip", "google_drive"]

#@markdown ### Google Drive path (only if LORA_MODE = google_drive)
# Directory (after Drive is mounted at /content/drive) that holds the adapter files.
DRIVE_LORA_PATH = "/content/drive/MyDrive/qwen3_32b_novel_lora" #@param {type:"string"}

# Form choice -> Hugging Face repo id of the base model the LoRA is applied to.
MODEL_CONFIGS = dict(
    qwen3_4b='unsloth/Qwen3-4B',
    qwen3_8b='unsloth/Qwen3-8B',
    qwen3_14b='unsloth/Qwen3-14B',
    qwen3_32b='unsloth/Qwen3-32B',
    llama31_8b='unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit',
    gemma2_9b='unsloth/gemma-2-9b-it-bnb-4bit',
    mistral_nemo_12b='unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit',
)

# Resolve the selected base model and echo the effective configuration.
BASE_MODEL = MODEL_CONFIGS[MODEL_CHOICE]
print(f'Base model: {BASE_MODEL}', f'LoRA mode: {LORA_MODE}', sep='\n')

In [None]:
#@title Install dependencies
!pip install -q unsloth
!pip install -q --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git
!pip install -q gradio google-genai openai

import torch
print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
print('Setup complete!')

In [None]:
#@title Upload / locate LoRA adapter
import os, zipfile
from pathlib import Path

LORA_PATH = None


def _find_adapter_dir(root):
    """Return the shallowest directory under *root* (inclusive) containing adapter_config.json.

    Returns None when no such file exists. Preferring the shallowest match means a
    zip that also bundles nested adapters (e.g. intermediate checkpoints) resolves
    to its top-level adapter.
    """
    configs = list(Path(root).rglob('adapter_config.json'))
    if not configs:
        return None
    best = min(configs, key=lambda p: (len(p.parts), str(p)))
    return str(best.parent)


if LORA_MODE == 'upload_zip':
    from google.colab import files as colab_files
    print('Upload your LoRA adapter zip file:')
    uploaded = colab_files.upload()
    for name in uploaded:
        if name.endswith('.zip'):
            # Extract into a folder dedicated to this upload and search it
            # recursively: works whether the adapter files sit at the zip root or
            # inside a folder, and never picks up a stale/unrelated adapter that
            # happens to live elsewhere under /content.
            extract_dir = Path('/content/lora_zip') / Path(name).stem
            with zipfile.ZipFile(name, 'r') as z:
                z.extractall(extract_dir)
            LORA_PATH = _find_adapter_dir(extract_dir)
        else:
            # Maybe they uploaded the raw adapter files one by one.
            os.makedirs('/content/lora_adapter', exist_ok=True)
            os.rename(name, f'/content/lora_adapter/{name}')
            LORA_PATH = '/content/lora_adapter'

elif LORA_MODE == 'google_drive':
    from google.colab import drive
    drive.mount('/content/drive')
    LORA_PATH = DRIVE_LORA_PATH

if LORA_PATH and Path(LORA_PATH).exists():
    print(f'LoRA adapter found: {LORA_PATH}')
    for f in sorted(Path(LORA_PATH).iterdir()):
        print(f'  {f.name} ({f.stat().st_size / 1024:.0f} KB)')
    if not (Path(LORA_PATH) / 'adapter_config.json').exists():
        print('WARNING: adapter_config.json is missing; loading the adapter will likely fail.')
else:
    # Stop "Run all" here instead of letting the next cell download the base model first.
    raise FileNotFoundError(f'ERROR: LoRA adapter not found at {LORA_PATH}')

In [None]:
#@title Load model + LoRA
from unsloth import FastLanguageModel
import torch

from pathlib import Path

# Fail fast: the base-model download/load is slow and memory-hungry, so make sure
# the adapter cell actually produced a directory before starting it.
if not LORA_PATH or not Path(LORA_PATH).is_dir():
    raise FileNotFoundError(
        f'LoRA adapter directory not found: {LORA_PATH!r}. '
        'Re-run the "Upload / locate LoRA adapter" cell (or fix DRIVE_LORA_PATH) first.'
    )

print(f'Loading base model: {BASE_MODEL}')
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    # NOTE(review): chapter prompts embed the whole outline plus up to 2000 chars of
    # earlier text, and up to 4096 new tokens may be requested — confirm this limit
    # is large enough for long outlines.
    max_seq_length=4096,
    dtype=None,  # let Unsloth pick the compute dtype (its documented auto setting)
    load_in_4bit=True,
)

print(f'Applying LoRA from: {LORA_PATH}')
from peft import PeftModel
model = PeftModel.from_pretrained(model, LORA_PATH)
# Switch Unsloth into inference mode before any generate() calls.
FastLanguageModel.for_inference(model)

# Bytes currently held by tensors on the default CUDA device, reported in GB.
vram = torch.cuda.memory_allocated() / 1e9
print(f'\nModel loaded! VRAM used: {vram:.1f} GB')

In [None]:
#@title Launch Web UI
import gradio as gr
import re, time, json
from pathlib import Path

# ---- System prompts (matching training) ----
# Chinese system prompt for the local model; chosen whenever the text is detected as 'zh'.
# Gist (English): "You are an experienced Chinese novelist skilled at immersive scenes.
# Continue the story: 1) keep the original POV and style; 2) advance the plot via
# concrete action, dialogue and setting; 3) keep characters consistent with their
# personality and situation; 4) use sensory detail for atmosphere; 5) natural,
# in-character dialogue; 6) avoid empty inner monologue — show it through action."
# NOTE: these strings must stay identical to the training-time prompts (see the
# section header above); edit them only together with the training data.
ZH_SYSTEM = (
    '你是一位经验丰富的中文小说作家，擅长构建沉浸式的叙事场景。请根据给定的上下文续写故事，要求：\n'
    '1. 保持与原文一致的叙事视角和文风\n'
    '2. 通过具体的动作、对话和环境描写推动情节发展\n'
    '3. 角色的言行应符合其性格特征和当前情境\n'
    '4. 善用感官细节（视觉、听觉、触觉、嗅觉）营造氛围\n'
    '5. 对话要自然生动，符合角色身份和说话习惯\n'
    '6. 避免空洞的心理独白，用行动和细节展现人物内心'
)

# English system prompt for the local model; used for any text not detected as 'zh'.
# Same training-time constraint as ZH_SYSTEM: keep byte-identical to training.
EN_SYSTEM = (
    'You are an accomplished fiction author with a gift for immersive storytelling. '
    'Continue the narrative following these principles:\n'
    '1. Maintain the established point of view, voice, and tonal register\n'
    '2. Advance the plot through concrete action, dialogue, and environmental detail\n'
    '3. Show character emotion through behavior, body language, and subtext — not exposition\n'
    '4. Engage multiple senses (sight, sound, touch, smell, taste) to ground scenes\n'
    '5. Write dialogue that reveals character, creates tension, and sounds natural\n'
    '6. Vary sentence rhythm — mix short punchy lines with longer flowing passages'
)


def detect_language(text):
    """Classify *text* as Chinese ('zh') or English ('en').

    Only the first 300 characters are sampled. The text counts as Chinese when
    more than 15% of the sampled characters are CJK Unified Ideographs
    (U+4E00..U+9FFF); everything else, including the empty string, is 'en'.
    """
    sample = text[:300]
    han_count = len([ch for ch in sample if '\u4e00' <= ch <= '\u9fff'])
    if han_count > 0.15 * len(sample):
        return 'zh'
    return 'en'


# ---- Cloud API plot generation ----
def _build_plot_prompt(idea, num_chapters, lang):
    if lang == 'zh':
        system = (
            '你是一位资深的小说策划编辑和故事架构师。你擅长从简单的故事构思中发展出完整、'
            '引人入胜的小说大纲。你的大纲应该包含极其丰富的细节，足以直接指导AI模型逐章生成高质量的小说内容。'
            '每个章节的大纲都应该详细到可以独立作为写作指南。'
        )
        prompt = f'''请基于以下故事构思，创作一个非常详细的小说大纲：\n\n故事构思：{idea}\n\n请严格按照以下格式输出（使用中文）：\n\n## 小说标题\n[一个引人入胜的标题]\n\n## 故事背景\n[详细的世界观和背景设定，至少200字。包括：时代背景、地理环境、社会体制、文化风俗、特殊设定]\n\n## 主要人物\n（至少4个主要角色，每个角色需包含详细信息）\n- **[角色全名]**（[年龄/外貌简述]）：[性格特点——至少3个性格关键词]，[身份背景]，[核心动机]，[人物弧光]，[与其他角色的关键关系]\n\n## 核心冲突\n[故事的主要矛盾和驱动力，包括外部冲突和内部冲突，至少100字]\n\n## 章节大纲\n（共{num_chapters}章，每章需要非常具体的情节描述）\n'''
        for i in range(1, num_chapters + 1):
            prompt += f'''\n### 第{i}章：[章节标题]\n- **开场场景**：[具体的时间、地点、氛围描写]\n- **主要事件**：[本章发生的1-3个关键事件]\n- **人物互动**：[哪些角色出场，对话和冲突要点]\n- **情感节奏**：[情感基调变化]\n- **关键细节**：[需要着重描写的场景细节]\n- **章末转折**：[悬念、伏笔或转折点]\n'''
        prompt += '\n## 伏笔与线索\n[列出3-5个贯穿全文的伏笔和线索]\n\n## 写作风格指导\n[对本小说整体风格的建议]\n\n请确保章节之间有清晰的因果关系，整体故事有完整的起承转合，每章大纲足够详细。'
    else:
        system = (
            'You are a senior fiction editor and story architect. You excel at developing '
            'simple story concepts into complete, compelling novel outlines with rich detail.'
        )
        prompt = f'''Based on the following story idea, create a very detailed novel outline:\n\nStory idea: {idea}\n\n## Title\n[A compelling title]\n\n## Setting\n[Detailed world-building, at least 200 words]\n\n## Main Characters\n(At least 4, each with detailed profiles)\n\n## Central Conflict\n[Main tension, at least 100 words]\n\n## Chapter Outline\n({num_chapters} chapters, each with specific plot details)\n'''
        for i in range(1, num_chapters + 1):
            prompt += f'''\n### Chapter {i}: [Title]\n- **Opening scene**: [Time, place, atmosphere]\n- **Key events**: [1-3 major events]\n- **Character interactions**: [Dialogue and conflict points]\n- **Emotional rhythm**: [Tone arc]\n- **Key details**: [Important elements to emphasize]\n- **Chapter-end hook**: [Cliffhanger or turning point]\n'''
        prompt += '\n## Foreshadowing & Threads\n[3-5 narrative threads]\n\n## Style Guide\n[Recommendations for style]\n\nEnsure clear progression and complete narrative arc.'
    return system, prompt


def generate_plot_gemini(api_key, idea, num_chapters, lang):
    """Ask Gemini for a novel outline and return the response text.

    Args:
        api_key: Gemini API key supplied by the user.
        idea: Free-form story idea.
        num_chapters: Number of chapter sections to request.
        lang: 'zh' for the Chinese prompt, anything else for English.

    Returns:
        The ``text`` of the Gemini response (may be None if no text was produced).
    """
    from google import genai
    from google.genai import types

    gemini = genai.Client(api_key=api_key)
    system_text, user_text = _build_plot_prompt(idea, num_chapters, lang)
    generation_config = types.GenerateContentConfig(
        system_instruction=system_text,
        temperature=0.9,
        max_output_tokens=8192,
    )
    result = gemini.models.generate_content(
        model='gemini-3-pro-preview',
        contents=user_text,
        config=generation_config,
    )
    return result.text


def generate_plot_gpt(api_key, idea, num_chapters, lang):
    """Ask OpenAI's chat completions API (gpt-4o) for a novel outline.

    Args:
        api_key: OpenAI API key supplied by the user.
        idea: Free-form story idea.
        num_chapters: Number of chapter sections to request.
        lang: 'zh' for the Chinese prompt, anything else for English.

    Returns:
        The content of the first completion choice (may be None).
    """
    from openai import OpenAI

    openai_client = OpenAI(api_key=api_key)
    system_text, user_text = _build_plot_prompt(idea, num_chapters, lang)
    conversation = [
        {'role': 'system', 'content': system_text},
        {'role': 'user', 'content': user_text},
    ]
    completion = openai_client.chat.completions.create(
        model='gpt-4o',
        messages=conversation,
        temperature=0.9,
        max_tokens=8192,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def develop_plot_api(idea, num_chapters, provider, api_key):
    """Gradio handler: turn a story idea into an outline via the chosen cloud provider.

    Never raises: validation problems and provider errors are returned as
    human-readable strings so they show up in the outline textbox.

    Args:
        idea: Story idea text (None is treated as empty).
        num_chapters: Chapter count from the slider (converted with int()).
        provider: 'Gemini' selects Gemini; any other value selects GPT.
        api_key: API key for that provider (None is treated as empty;
            surrounding whitespace from copy/paste is stripped).

    Returns:
        The generated outline, or an explanatory message.
    """
    idea = idea or ''
    api_key = (api_key or '').strip()
    if not idea.strip():
        return 'Please enter a story idea first.'
    if not api_key:
        return f'Please enter your {provider} API key above.'

    lang = detect_language(idea)
    generate = generate_plot_gemini if provider == 'Gemini' else generate_plot_gpt
    try:
        outline = generate(api_key, idea, int(num_chapters), lang)
    except Exception as e:  # UI boundary: surface any provider/network error to the user
        return f'API Error ({provider}): {e}'

    # Both SDKs can hand back None (e.g. a blocked or truncated response); don't
    # leave the user staring at an empty textbox with no explanation.
    if not outline:
        return (f'API Error ({provider}): the model returned no text '
                '(the response may have been blocked or truncated). Please try again.')
    return outline


# ---- Local model generation ----
def _render_chat(system_prompt, user_prompt):
    """Render a system + user exchange with the tokenizer's chat template (ready for generation).

    Some chat templates (Gemma 2's, for example) raise a TemplateError for a
    'system' turn; in that case the system prompt is folded into the user turn
    so every base model offered in the configuration can be driven the same way.
    """
    from jinja2.exceptions import TemplateError

    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except TemplateError:
        merged = [{'role': 'user', 'content': f'{system_prompt}\n\n{user_prompt}'}]
        return tokenizer.apply_chat_template(merged, tokenize=False, add_generation_prompt=True)


def generate_text(prompt, system_prompt='', max_new_tokens=2048, temperature=0.8,
                  top_p=0.9, top_k=50, repetition_penalty=1.1):
    """Generate a sampled completion for *prompt* with the loaded LoRA model.

    Args:
        prompt: User message text.
        system_prompt: System message; when empty, ZH_SYSTEM or EN_SYSTEM is
            chosen from detect_language(prompt).
        max_new_tokens, top_p, top_k, repetition_penalty: passed to model.generate.
        temperature: Sampling temperature, clamped to at least 0.01 because
            sampling is always enabled (do_sample=True).

    Returns:
        Only the newly generated text, decoded without special tokens.
    """
    if not system_prompt:
        system_prompt = ZH_SYSTEM if detect_language(prompt) == 'zh' else EN_SYSTEM

    text = _render_chat(system_prompt, prompt)
    # The rendered template already contains the model's BOS/special tokens; letting
    # the tokenizer add them again duplicates BOS for templates that emit it
    # (e.g. Llama 3.1, Gemma 2, Mistral Nemo).
    inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            temperature=max(temperature, 0.01), top_p=top_p, top_k=top_k,
            repetition_penalty=repetition_penalty, do_sample=True,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def _extract_chapter_outline(plot_outline, ch):
    """Return the outline section for chapter *ch* (heading included, stripped), or None.

    A section starts at a '##'/'###' heading reading '第{ch}章' or 'Chapter {ch}' and
    ends at the next chapter heading, the next '## ' (level-2) section such as
    the foreshadowing/style-guide blocks, or the end of the outline. The
    '(?!\\d)' guard keeps chapter 1 from matching a 'Chapter 10' heading.
    """
    pattern = (
        rf'###?\s*(?:第{ch}章|Chapter\s*{ch}(?!\d))[：:\s]*.+?'
        r'(?=###?\s*(?:第\d+章|Chapter\s*\d+)|\n##\s|\Z)'
    )
    match = re.search(pattern, plot_outline, re.DOTALL)
    return match.group(0).strip() if match else None


def generate_chapter(plot_outline, chapter_num, previous_text, style_notes,
                     max_tokens, temperature, top_p, rep_penalty):
    """Write chapter *chapter_num* of the novel with the local model.

    The prompt contains the full outline, the last 2000 characters of
    *previous_text* (for continuity), this chapter's own outline section, the
    optional style notes and fixed writing requirements, in the outline's
    detected language.

    Returns:
        The generated chapter text.
    """
    lang = detect_language(plot_outline)
    ch = int(chapter_num)

    chapter_outline = _extract_chapter_outline(plot_outline, ch)
    if chapter_outline is None:
        # No matching heading: fall back to a bare label in the prompt's language.
        chapter_outline = f'第{ch}章' if lang == 'zh' else f'Chapter {ch}'

    # Only the tail of earlier text is needed for continuity; it keeps the prompt short.
    ctx = previous_text[-2000:] if previous_text else ''

    if lang == 'zh':
        prompt = f'## 小说大纲\n{plot_outline}\n\n'
        if ctx:
            prompt += f'## 上一章结尾\n{ctx}\n\n'
        prompt += f'## 当前任务\n请根据以上大纲，撰写第{ch}章的完整内容。\n本章大纲：{chapter_outline}\n\n'
        if style_notes:
            prompt += f'风格要求：{style_notes}\n\n'
        prompt += f'要求：\n1. 以具体的场景描写开头\n2. 通过对话和动作推动情节\n3. 注意人物性格的一致性\n4. 章节结尾要有悬念或转折\n5. 写出完整的章节内容\n\n第{ch}章正文：\n'
    else:
        prompt = f'## Novel Outline\n{plot_outline}\n\n'
        if ctx:
            prompt += f'## End of Previous Chapter\n{ctx}\n\n'
        prompt += f'## Current Task\nWrite the complete text of Chapter {ch}.\nChapter outline: {chapter_outline}\n\n'
        if style_notes:
            prompt += f'Style notes: {style_notes}\n\n'
        prompt += f'Requirements:\n1. Open with vivid scene-setting\n2. Drive the plot through dialogue and action\n3. Maintain consistent characterization\n4. End with a hook or turning point\n5. Write the complete chapter\n\nChapter {ch}:\n'

    system = ZH_SYSTEM if lang == 'zh' else EN_SYSTEM
    return generate_text(prompt, system_prompt=system, max_new_tokens=int(max_tokens),
                         temperature=temperature, top_p=top_p, repetition_penalty=rep_penalty)


def continue_writing(existing_text, instruction, max_tokens, temperature, top_p, rep_penalty):
    """Continue *existing_text* with the local model, optionally steered by *instruction*.

    The language (and therefore the system prompt and default instruction) is
    detected from the existing text, or from the instruction when there is no
    existing text. A blank instruction falls back to a generic "continue in the
    established style" request.

    Returns:
        The generated continuation.
    """
    existing_text = existing_text or ''
    instruction = (instruction or '').strip()
    # With an empty context, detecting from '' would always say 'en'; use the
    # instruction instead so a Chinese-only request gets the Chinese system prompt.
    lang = detect_language(existing_text or instruction)
    if not instruction:
        instruction = '续写这段叙事，保持原文的风格和节奏。' if lang == 'zh' else 'Continue the narrative in the established style.'
    prompt = instruction + '\n\n' + existing_text
    system = ZH_SYSTEM if lang == 'zh' else EN_SYSTEM
    return generate_text(prompt, system_prompt=system, max_new_tokens=int(max_tokens),
                         temperature=temperature, top_p=top_p, repetition_penalty=rep_penalty)


# ---- Build Gradio UI ----
with gr.Blocks(title='Novel Writer Studio') as app:
    gr.Markdown('# Novel Writer Studio (Colab)')
    gr.Markdown('*Cloud AI develops your plot outline — your fine-tuned model writes the chapters*')

    # API Settings: provider + key, used only by the "Develop Plot" step.
    with gr.Accordion('API Settings (for Plot Generation)', open=True):
        gr.Markdown('Enter your Gemini or OpenAI API key. Keys are only used in-memory.')
        with gr.Row():
            api_provider = gr.Radio(['Gemini', 'GPT'], value='Gemini', label='Provider')
            api_key_input = gr.Textbox(label='API Key', type='password', placeholder='Paste your API key...', scale=3)

    # Sampling settings for the local model; shared by every tab below.
    with gr.Accordion('Chapter Generation Settings', open=False):
        with gr.Row():
            max_tokens_slider = gr.Slider(128, 4096, value=2048, step=128, label='Max New Tokens')
            temperature_slider = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label='Temperature')
        with gr.Row():
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label='Top-P')
            rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label='Repetition Penalty')

    with gr.Tabs():
        # Tab 1: Story Workshop — idea -> cloud outline -> locally generated chapters
        with gr.Tab('Story Workshop'):
            gr.Markdown('### Step 1: Your Story Idea')
            with gr.Row():
                with gr.Column(scale=3):
                    story_idea = gr.Textbox(
                        label='Story Idea', lines=4,
                        placeholder='例：在一个武林高手辈出的时代，一个失忆的少年在雪山醒来...\n\nOr: In a world where magic is fueled by music...',
                    )
                with gr.Column(scale=1):
                    num_chapters = gr.Slider(3, 20, value=8, step=1, label='Chapters')
                    develop_btn = gr.Button('Develop Plot (Cloud AI)', variant='primary', size='lg')

            gr.Markdown('### Step 2: Plot Outline')
            gr.Markdown('Generated by Gemini/GPT. Review and edit freely.')
            plot_outline = gr.Textbox(label='Plot Outline (editable)', lines=25,
                                     placeholder='Click Develop Plot to generate...')
            develop_btn.click(develop_plot_api, [story_idea, num_chapters, api_provider, api_key_input], plot_outline)

            gr.Markdown('---')
            gr.Markdown('### Step 3: Generate Chapters (Local Model)')
            with gr.Row():
                with gr.Column(scale=1):
                    chapter_num = gr.Slider(1, 20, value=1, step=1, label='Chapter Number')
                    style_notes = gr.Textbox(label='Style Notes (optional)', lines=2, placeholder='e.g., 多用对话')
                    gen_chapter_btn = gr.Button('Generate Chapter', variant='primary', size='lg')
                with gr.Column(scale=3):
                    chapter_output = gr.Textbox(label='Generated Chapter', lines=25)

            # Accumulated text of every chapter generated so far.
            all_chapters = gr.State('')

            def gen_and_accumulate(outline, ch_num, prev, style, max_t, temp, tp, rp):
                """Generate one chapter and append it, with a header, to the running story.

                Returns (chapter text, updated accumulated story).
                """
                # Don't spend minutes of GPU time writing a chapter with no outline to follow.
                if not (outline or '').strip():
                    raise gr.Error('Please develop or paste a plot outline first.')
                result = generate_chapter(outline, ch_num, prev, style, max_t, temp, tp, rp)
                ch = int(ch_num)
                header = f'{"="*40}\n第{ch}章 / Chapter {ch}\n{"="*40}\n\n'
                # Every chapter — including the first one accumulated — gets a
                # header, so the exported story is consistently delimited.
                new_acc = f'{prev}\n\n{header}{result}' if prev else header + result
                return result, new_acc

            gen_chapter_btn.click(
                gen_and_accumulate,
                [plot_outline, chapter_num, all_chapters, style_notes,
                 max_tokens_slider, temperature_slider, top_p_slider, rep_penalty_slider],
                [chapter_output, all_chapters],
            )

            with gr.Accordion('All Generated Chapters', open=False):
                all_chapters_display = gr.Textbox(label='Full Story So Far', lines=30, interactive=False)
                refresh_btn = gr.Button('Refresh')
                refresh_btn.click(lambda x: x, all_chapters, all_chapters_display)

                export_btn = gr.Button('Export Story to File')
                export_file = gr.File(label='Download')

                def export_story(text):
                    """Write the accumulated story to a timestamped UTF-8 .txt under /content.

                    Returns the file path for download, or None when there is nothing to export.
                    """
                    if not text:
                        return None
                    fpath = f'/content/story_{time.strftime("%Y%m%d_%H%M%S")}.txt'
                    Path(fpath).write_text(text, encoding='utf-8')
                    return fpath

                export_btn.click(export_story, all_chapters, export_file)

        # Tab 2: Free Write — continue arbitrary pasted text
        with gr.Tab('Free Write'):
            gr.Markdown('### Direct Generation')
            with gr.Row():
                with gr.Column():
                    free_context = gr.Textbox(label='Context / Previous Text', lines=10,
                                             placeholder='Paste existing text for continuation...')
                    free_instruction = gr.Textbox(label='Instruction (optional)', lines=2,
                                                 placeholder='e.g., 续写这段战斗场景')
                    free_gen_btn = gr.Button('Generate', variant='primary', size='lg')
                with gr.Column():
                    free_output = gr.Textbox(label='Generated Text', lines=20)
                    append_btn = gr.Button('Append to Context')

            free_gen_btn.click(
                continue_writing,
                [free_context, free_instruction, max_tokens_slider, temperature_slider, top_p_slider, rep_penalty_slider],
                free_output,
            )
            # Append the generated text to the context (newline-separated) for iterative writing.
            append_btn.click(lambda ctx, out: ctx + '\n' + out if ctx else out, [free_context, free_output], free_context)

        # Tab 3: Quick Test — single prompt, system prompt picked by language detection
        with gr.Tab('Quick Test'):
            gr.Markdown('### Test the model with a single prompt')
            test_prompt = gr.Textbox(label='Prompt', lines=4,
                                    placeholder='月色如霜，照在悬崖边两道对峙的身影上...')
            test_btn = gr.Button('Generate', variant='primary')
            test_output = gr.Textbox(label='Output', lines=15)

            def run_test(prompt, max_t, temp, tp, rp):
                """Generate directly from *prompt* using the shared sampling settings."""
                return generate_text(prompt, max_new_tokens=int(max_t), temperature=temp, top_p=tp, repetition_penalty=rp)

            test_btn.click(run_test, [test_prompt, max_tokens_slider, temperature_slider, top_p_slider, rep_penalty_slider], test_output)

print('Launching Web UI...')
app.launch(share=True)