<a href="https://colab.research.google.com/github/880121andy/CTG/blob/main/CTG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import base64
import io
import os
import queue
import sys
import threading
import tkinter as tk
from collections import defaultdict
from tkinter import filedialog

import customtkinter as ctk
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import stanza
from matplotlib.patches import FancyBboxPatch
from PIL import Image, ImageTk, ImageDraw, ImageOps  # ImageDraw and ImageOps are used for the rounded-corner mask

# --- ConstituencyTreeGenerator class (unchanged) ---
class ConstituencyTreeGenerator:
    """Parse English sentences with Stanza and render their constituency trees.

    The instance owns a Stanza pipeline plus the colour / full-name lookup
    tables used when drawing nodes and when listing categories in a legend.
    """

    def __init__(self):
        """Initialise the Stanza NLP pipeline (with constituency parsing).

        The English models are looked up in a ``stanza_models`` directory next
        to this file and downloaded there when that directory does not exist.

        Raises:
            Exception: whatever ``stanza.download`` or ``stanza.Pipeline``
                raises; a failed download is printed and re-raised.
        """
        print("正在初始化 Stanza 英文模型...")
        try:
            base_dir = os.path.abspath(os.path.dirname(__file__))
        except NameError:
            # __file__ is not defined in notebooks / interactive sessions;
            # keep the models next to the current working directory instead.
            base_dir = os.getcwd()
        model_dir = os.path.join(base_dir, 'stanza_models')
        if not os.path.exists(model_dir):
            print(f"模型目錄不存在: {model_dir}，嘗試下載 Stanza 模型...")
            try:
                stanza.download('en', model_dir=model_dir)
                print(f"Stanza 英文模型下載完成到: {model_dir}")
            except Exception as e:
                print(f"Stanza 模型下載失敗: {e}")
                raise

        self.nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,constituency', dir=model_dir, download_method=None)
        print("Stanza NLP 管道初始化完成！")

        # Colour tables are built here so they are available on every instance.
        self.phrase_colors = {
            'S': '#ff6b6b', 'NP': '#4ecdc4', 'VP': '#45b7d1',
            'PP': '#96ceb4', 'AP': '#feca57', 'ADJP': '#feca57',
            'ADVP': '#ff9ff3', 'SBAR': '#5f27cd', 'WHNP': '#54a0ff',
            'WHPP': '#00d2d3', 'default_phrase': '#ddd',
        }

        self.pos_colors = {
            'NN': '#e74c3c', 'NNS': '#e74c3c', 'NNP': '#e74c3c', 'NNPS': '#e74c3c',
            'VB': '#3498db', 'VBD': '#3498db', 'VBG': '#3498db', 'VBN': '#3498db', 'VBP': '#3498db', 'VBZ': '#3498db',
            'JJ': '#2ecc71', 'JJR': '#2ecc71', 'JJS': '#2ecc71',
            'RB': '#9b59b6', 'RBR': '#9b59b6', 'RBS': '#9b59b6',
            'DT': '#f39c12', 'PRP': '#e67e22', 'PRP$': '#e67e22',
            'IN': '#1abc9c', 'CC': '#34495e', 'CD': '#16a085',
            'default': '#95a5a6'
        }
        self.all_colors_map = {**self.phrase_colors, **self.pos_colors}
        # The two "default" entries are fallbacks, not real labels, so they
        # are left out of the combined map.
        self.all_colors_map.pop('default_phrase', None)
        self.all_colors_map.pop('default', None)

        # Human-readable names for the abbreviations above.
        self.full_names = {
            # Phrase-level categories
            'S': 'Sentence',
            'NP': 'Noun Phrase',
            'VP': 'Verb Phrase',
            'PP': 'Prepositional Phrase',
            'ADJP': 'Adjective Phrase',
            'ADVP': 'Adverb Phrase',
            'SBAR': 'Subordinate Clause',
            'WHNP': 'Wh- Noun Phrase',
            'WHPP': 'Wh- Prepositional Phrase',
            'PRN': 'Parenthetical',
            'QP': 'Quantifier Phrase',
            'FRAG': 'Fragment',
            'INTJ': 'Interjection',  # usually treated as a phrase category too
            'AP': 'Adjective Phrase',  # AP is coloured above, so it needs a name as well

            # Part-of-speech tags (common ones)
            'NN': 'Noun, singular or mass',
            'NNS': 'Noun, plural',
            'NNP': 'Proper noun, singular',
            'NNPS': 'Proper noun, plural',
            'VB': 'Verb, base form',
            'VBD': 'Verb, past tense',
            'VBG': 'Verb, gerund or present participle',
            'VBN': 'Verb, past participle',
            'VBP': 'Verb, non-3rd person singular present',
            'VBZ': 'Verb, 3rd person singular present',
            'JJ': 'Adjective',
            'JJR': 'Adjective, comparative',
            'JJS': 'Adjective, superlative',
            'RB': 'Adverb',
            'RBR': 'Adverb, comparative',
            'RBS': 'Adverb, superlative',
            'DT': 'Determiner',
            'PRP': 'Personal pronoun',
            'PRP$': 'Possessive pronoun',
            'IN': 'Preposition or subordinating conjunction',
            'CC': 'Coordinating conjunction',
            'CD': 'Cardinal number',
            'MD': 'Modal',
            'TO': 'to',
            'UH': 'Interjection',
            'EX': 'Existential there',
            'FW': 'Foreign word',
            'LS': 'List item marker',
            'PDT': 'Predeterminer',
            'POS': 'Possessive ending',
            'RP': 'Particle',
            'SYM': 'Symbol',
            'WDT': 'Wh-determiner',
            'WP': 'Wh-pronoun',
            'WP$': 'Possessive wh-pronoun',
            'WRB': 'Wh-adverb',
            '.': 'Sentence-final punctuation',
            ',': 'Comma',
            ':': 'Colon or semicolon',
            '(': 'Left parenthesis',
            ')': 'Right parenthesis',
            '``': 'Opening quotation mark',
            "''": 'Closing quotation mark',
            '$': 'Dollar sign',
            '#': 'Pound sign',
        }

    def parse_sentence(self, sentence):
        """Run the Stanza pipeline on ``sentence`` and return its first sentence.

        Raises:
            ValueError: if Stanza produced no sentence for the input.
        """
        doc = self.nlp(sentence)
        if not doc.sentences:
            raise ValueError("Stanza 無法解析此句子。")
        return doc.sentences[0]

    def parse_tree_string(self, tree_string):
        """Parse a bracketed constituency string into nested dicts.

        A leading ``(ROOT`` wrapper (and its closing parenthesis) is removed
        first.  Inner nodes become ``{'label': str, 'children': list}`` and
        words become plain strings.
        """
        tree_string = tree_string.strip()
        if tree_string.startswith('(ROOT'):
            tree_string = tree_string[5:-1].strip()

        return self._parse_subtree(tree_string)

    def _parse_subtree(self, s):
        """Recursively parse one bracketed subtree (or return a bare word)."""
        s = s.strip()
        if not s.startswith('('):
            return s  # leaf node (a word)

        # The label ends at the first space; without one, at the next '(' and,
        # failing that, at the closing parenthesis.
        label_end = s.find(' ')
        if label_end == -1:
            label_end = s.find('(', 1)
        if label_end == -1:
            label_end = len(s) - 1

        label = s[1:label_end]
        children = []

        # Everything after the label, minus this node's own closing ')'.
        remaining = s[label_end:].strip()
        if remaining.endswith(')'):
            remaining = remaining[:-1]

        pos = 0
        paren_count = 0
        start = 0

        while pos < len(remaining):
            char = remaining[pos]
            if char == '(':
                if paren_count == 0:
                    start = pos  # a top-level child subtree starts here
                paren_count += 1
            elif char == ')':
                paren_count -= 1
                if paren_count == 0:
                    child_str = remaining[start:pos+1]
                    children.append(self._parse_subtree(child_str))
            elif char == ' ' and paren_count == 0:
                # A top-level space ends a bare word; spans that start with
                # '(' were already consumed as subtrees above.
                if pos > start:
                    word = remaining[start:pos].strip()
                    if word and not word.startswith('('):
                        children.append(word)
                start = pos + 1
            pos += 1

        # Trailing bare word, if any.
        if start < len(remaining):
            word = remaining[start:].strip()
            if word and not word.startswith('('):
                children.append(word)

        return {'label': label, 'children': children}

    @staticmethod
    def leaf_key(parent, index):
        """Return the layout key of the leaf stored at ``parent['children'][index]``."""
        return (id(parent), index)

    def calculate_tree_layout(self, tree):
        """Compute an ``(x, y)`` position for every node of ``tree``.

        Returns a dict in which
        * every dict node is keyed by ``id(node)``;
        * every leaf word is keyed by ``leaf_key(parent, index)``.

        Equal words can be the very same string object (CPython shares
        one-character strings, for example), so ``id(word)`` cannot tell two
        occurrences of "a" or "," apart.  The ``id(word)`` keys are still
        written for backward compatibility, but when a word repeats they only
        hold the position of its last occurrence.
        """
        positions = {}

        def _assign_positions(node, x_offset=0, y=0, x_spacing=2):
            """Place ``node`` and its descendants; return the (left, right) x-extent."""
            if isinstance(node, str):
                # Only reached when the whole tree is a single bare word.
                positions[id(node)] = (x_offset, y)
                return x_offset, x_offset

            spans = []
            current_x = x_offset

            for index, child in enumerate(node['children']):
                if isinstance(child, str):
                    leaf_position = (current_x, y - 1)
                    positions[self.leaf_key(node, index)] = leaf_position
                    positions[id(child)] = leaf_position  # legacy key, see docstring
                    spans.append((current_x, current_x))
                    current_x += x_spacing
                else:
                    left_x, right_x = _assign_positions(child, current_x, y - 1, x_spacing)
                    spans.append((left_x, right_x))
                    current_x = right_x + x_spacing

            if not spans:
                positions[id(node)] = (x_offset, y)
                return x_offset, x_offset

            # Centre the parent above the full extent of its children.
            leftmost = min(span[0] for span in spans)
            rightmost = max(span[1] for span in spans)
            positions[id(node)] = ((leftmost + rightmost) / 2, y)
            return leftmost, rightmost

        _assign_positions(tree, 0, 0, 2)
        return positions

    def _label_color(self, label):
        """Pick the fill colour for an inner node from the colour tables."""
        if label in self.phrase_colors:
            return self.phrase_colors[label]
        if label in self.pos_colors:
            return self.pos_colors[label]
        if len(label) <= 3 and label.isupper():
            return self.phrase_colors.get('default_phrase', '#ddd')
        return self.pos_colors.get('default', '#95a5a6')

    def create_constituency_tree_plot(self, sentence):
        """Draw the constituency tree of a parsed Stanza sentence.

        Returns:
            ``(fig, all_colors_map)``: the Matplotlib figure and the combined
            label-to-colour map.  The caller is responsible for closing the
            figure.
        """
        tree_data = self.parse_tree_string(str(sentence.constituency))
        positions = self.calculate_tree_layout(tree_data)

        fig, ax = plt.subplots(1, 1, figsize=(18, 12))
        try:
            ax.set_aspect('equal')

            fig.patch.set_facecolor('#f8f9fa')
            ax.set_facecolor('#ffffff')

            def draw_edge(parent_pos, pos):
                """Connect a node to its parent (the root has none)."""
                if parent_pos is not None:
                    ax.plot([parent_pos[0], pos[0]], [parent_pos[1], pos[1]],
                            'k-', linewidth=2, alpha=0.7, zorder=1)

            def draw_leaf(word, pos, parent_pos):
                """Draw a word as a white box with dark text."""
                bbox = FancyBboxPatch((pos[0]-0.5, pos[1]-0.25), 1.0, 0.5,
                                      boxstyle="round,pad=0.05",
                                      facecolor='white', edgecolor='#333',
                                      linewidth=1, alpha=0.9, zorder=3)
                ax.add_patch(bbox)
                ax.text(pos[0], pos[1], word, fontsize=11, fontweight='bold',
                        ha='center', va='center', color='#333', zorder=4)
                draw_edge(parent_pos, pos)

            def draw_tree(node, parent_pos=None):
                """Draw ``node`` as a coloured label box, then recurse."""
                if isinstance(node, str):
                    # Whole tree is one bare word.
                    draw_leaf(node, positions[id(node)], parent_pos)
                    return

                pos = positions[id(node)]
                label = node['label']

                bbox = FancyBboxPatch((pos[0]-0.4, pos[1]-0.2), 0.8, 0.4,
                                      boxstyle="round,pad=0.05",
                                      facecolor=self._label_color(label), edgecolor='white',
                                      linewidth=2, alpha=0.9, zorder=2)
                ax.add_patch(bbox)
                ax.text(pos[0], pos[1], label, fontsize=10, fontweight='bold',
                        ha='center', va='center', color='white', zorder=4)
                draw_edge(parent_pos, pos)

                for index, child in enumerate(node['children']):
                    if isinstance(child, str):
                        # Look leaves up by (parent, index) so repeated words
                        # each get their own position.
                        draw_leaf(child, positions[self.leaf_key(node, index)], pos)
                    else:
                        draw_tree(child, pos)

            draw_tree(tree_data)

            all_x = [pos[0] for pos in positions.values()]
            all_y = [pos[1] for pos in positions.values()]

            margin = 1.0
            ax.set_xlim(min(all_x) - margin, max(all_x) + margin)
            ax.set_ylim(min(all_y) - margin, max(all_y) + margin)

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            original_sentence = ' '.join([token.text for token in sentence.tokens])
            # Address this figure's own axes instead of pyplot's "current"
            # figure, which may be a different one.
            ax.set_title(f'Constituency Parse Tree\n"{original_sentence}"',
                         fontsize=16, fontweight='bold', pad=20, color='#333')

            fig.tight_layout()
        except Exception:
            # Do not leave a half-drawn figure registered with pyplot.
            plt.close(fig)
            raise

        return fig, self.all_colors_map

    def generate_tree_image_base64(self, sentence_text):
        """Parse ``sentence_text`` and render its tree as a PNG.

        Returns:
            ``(image_base64, buffer, all_colors_map)``: the PNG as a base64
            string, the ``io.BytesIO`` holding the raw PNG (rewound to 0) and
            the label-to-colour map.

        Raises:
            ValueError: if Stanza produced no sentence for the input.
        """
        sentence = self.parse_sentence(sentence_text)
        fig, all_colors_map = self.create_constituency_tree_plot(sentence)

        buffer = io.BytesIO()
        try:
            fig.savefig(buffer, format='png', dpi=300, bbox_inches='tight',
                        facecolor='#f8f9fa', edgecolor='none')
        finally:
            # Close the figure even when saving fails.
            plt.close(fig)
        buffer.seek(0)
        image_base64 = base64.b64encode(buffer.getvalue()).decode()

        return image_base64, buffer, all_colors_map

# --- End of ConstituencyTreeGenerator class ---

class ScrollableFrame(ctk.CTkScrollableFrame):
    """Scrollable container that lists colour blocks next to their captions."""

    def __init__(self, master, **kwargs):
        super().__init__(master, **kwargs)
        # Give column 0 all spare width so the content can stretch.
        self.grid_columnconfigure(0, weight=1)

    def add_label_pair(self, text, color):
        """Append one row: a small block filled with ``color``, then ``text``."""
        # The row stays transparent so the scrollable frame's own background shows through.
        row = ctk.CTkFrame(self, fg_color="transparent")
        row.pack(fill="x", padx=5, pady=2)
        # Column 0 (colour block) keeps its size; column 1 (caption) stretches.
        for column, weight in ((0, 0), (1, 1)):
            row.grid_columnconfigure(column, weight=weight)

        swatch = ctk.CTkLabel(row, text="", width=20, height=20, fg_color=color, corner_radius=5)
        swatch.grid(row=0, column=0, padx=5, pady=2)

        caption = ctk.CTkLabel(row, text=text, font=("Times New Roman", 12), anchor="w")
        caption.grid(row=0, column=1, sticky="ew")


class App(ctk.CTk):
    """Main window of the constituency-tree viewer.

    Slow work (loading the Stanza model, parsing, rendering) runs on daemon
    worker threads.  Workers never touch widgets themselves: they post
    callbacks to ``self._ui_queue``, which the Tk main thread drains
    periodically, so every widget call happens on the thread that owns Tk.
    """

    # How often (in ms) the main thread checks for callbacks posted by workers.
    _UI_POLL_INTERVAL_MS = 50
    # Preview size used while the image frame has not been laid out yet.
    _FALLBACK_PREVIEW_SIZE = (1000, 700)
    # Corner radius of the preview mask; same value as image_frame's corner_radius.
    _PREVIEW_CORNER_RADIUS = 10

    def __init__(self):
        super().__init__()

        self.title("CTG")
        self.geometry("1200x800")

        # Background colour of the main window.
        self.configure(fg_color="#eda6b9")

        # Row 0 holds the controls (fixed height); row 1 holds the image and stretches.
        self.grid_rowconfigure(0, weight=0)
        self.grid_rowconfigure(1, weight=1)
        self.grid_columnconfigure(0, weight=1)

        # Top control frame (light grey).
        self.control_frame = ctk.CTkFrame(self, corner_radius=10, fg_color="#f0f0f0")
        self.control_frame.grid(row=0, column=0, padx=20, pady=20, sticky="ew")
        self.control_frame.grid_columnconfigure(0, weight=1)
        self.control_frame.grid_columnconfigure(1, weight=0)
        self.control_frame.grid_columnconfigure(2, weight=0)
        self.control_frame.grid_columnconfigure(3, weight=0)  # legend button

        self.sentence_input = ctk.CTkEntry(self.control_frame,
                                            placeholder_text="Enter a grammatical English sentence or phrase.",
                                            width=500, height=40, font=("Times New Roman", 16))
        self.sentence_input.grid(row=0, column=0, padx=10, pady=10, sticky="ew")

        self.analyze_button = ctk.CTkButton(self.control_frame,
                                            text="Generate",
                                            command=self.start_analysis_thread,
                                            width=120, height=40, font=("Times New Roman", 16),
                                            fg_color="#2e6653",
                                            hover_color="#1d4236")
        self.analyze_button.grid(row=0, column=1, padx=10, pady=10, sticky="e")

        self.download_button = ctk.CTkButton(self.control_frame,
                                             text="Save Image",
                                             command=self.download_image,
                                             width=120, height=40, font=("Times New Roman", 16),
                                             fg_color="#2e6653", hover_color="#1d4236",
                                             state="disabled")
        self.download_button.grid(row=0, column=2, padx=10, pady=10, sticky="e")

        self.legend_button = ctk.CTkButton(self.control_frame,
                                           text="Categories",
                                           command=self.show_legend_window,
                                           width=120, height=40, font=("Times New Roman", 16),
                                           fg_color="#2e6653", hover_color="#1d4236",
                                           state="disabled")
        self.legend_button.grid(row=0, column=3, padx=10, pady=10, sticky="e")

        self.status_label = ctk.CTkLabel(self.control_frame, text="Initializing...", font=("Times New Roman", 14), text_color="black")
        self.status_label.grid(row=1, column=0, columnspan=4, padx=10, pady=5, sticky="ew")

        # The image frame uses the same light grey as the control frame.
        self.image_frame = ctk.CTkFrame(self, corner_radius=10, fg_color="#f0f0f0")
        self.image_frame.grid(row=1, column=0, padx=20, pady=0, sticky="nsew")
        self.image_frame.grid_columnconfigure(0, weight=1)
        self.image_frame.grid_rowconfigure(0, weight=1)

        # Same background and corner radius as the frame, so the rounded
        # corners look continuous.
        self.image_label = ctk.CTkLabel(self.image_frame, text="", fg_color="#f0f0f0", corner_radius=10)
        self.image_label.grid(row=0, column=0, sticky="nsew")

        self.generator = None
        self.current_image_buffer = None
        self.all_categories_colors = {}

        # Callbacks posted by worker threads, executed on the main thread.
        self._ui_queue = queue.Queue()
        self._poll_ui_queue()

        self.initialize_generator_thread()

    # ------------------------------------------------------------------
    # Worker-thread -> main-thread hand-over
    # ------------------------------------------------------------------
    def _run_on_ui_thread(self, callback, *args):
        """Queue ``callback(*args)`` for execution on the Tk main thread."""
        self._ui_queue.put((callback, args))

    def _poll_ui_queue(self):
        """Run every queued callback, then re-schedule this poll."""
        while True:
            try:
                callback, args = self._ui_queue.get_nowait()
            except queue.Empty:
                break
            try:
                callback(*args)
            except Exception as e:
                # One failing callback must not stop the polling loop.
                print(f"UI callback error: {e}")
        self.after(self._UI_POLL_INTERVAL_MS, self._poll_ui_queue)

    def _set_status(self, text, text_color="black"):
        """Update the status line."""
        self.status_label.configure(text=text, text_color=text_color)

    # ------------------------------------------------------------------
    # Model initialisation
    # ------------------------------------------------------------------
    def initialize_generator_thread(self):
        """Disable the buttons and load the Stanza model on a worker thread."""
        self._set_status("Initializing Stanza model (downloading on first run may take some time)...")
        self.analyze_button.configure(state="disabled")
        self.download_button.configure(state="disabled")
        self.legend_button.configure(state="disabled")
        threading.Thread(target=self._initialize_generator_task, daemon=True).start()

    def _initialize_generator_task(self):
        """Worker: build the generator and report the outcome to the main thread."""
        try:
            generator = ConstituencyTreeGenerator()
        except Exception as e:
            print(f"Failed to initialize the Stanza model: {e}")
            self._run_on_ui_thread(self._on_generator_failed, e)
        else:
            self._run_on_ui_thread(self._on_generator_ready, generator)

    def _on_generator_ready(self, generator):
        """Main thread: publish the generator and enable the controls."""
        self.generator = generator
        self._set_status("Stanza model successfully initialized! Please enter a sentence.")
        self.analyze_button.configure(state="normal")
        self.legend_button.configure(state="normal")

    def _on_generator_failed(self, error):
        """Main thread: show the initialisation error and keep the controls disabled."""
        self._set_status(f"Error: Failed to initialize the Stanza model! {error}", "red")
        self.analyze_button.configure(state="disabled")
        self.download_button.configure(state="disabled")
        self.legend_button.configure(state="disabled")

    # ------------------------------------------------------------------
    # Analysis
    # ------------------------------------------------------------------
    def start_analysis_thread(self):
        """Validate the input and start parsing / rendering on a worker thread."""
        sentence_text = self.sentence_input.get()
        if not sentence_text.strip():
            self._set_status("Error: Please enter a sentence or phrase", "red")
            return

        if self.generator is None:
            self._set_status("Error: The Stanza model is still initializing. Please wait a moment.", "red")
            return

        self._set_status("Parsing and generating...")
        self.analyze_button.configure(state="disabled")
        self.download_button.configure(state="disabled")

        # Widget geometry is read here, on the main thread, and handed to the
        # worker so it does not need to query Tk.
        self.image_frame.update_idletasks()
        target_size = (self.image_frame.winfo_width(), self.image_frame.winfo_height())

        threading.Thread(target=self._perform_analysis_task,
                         args=(sentence_text, target_size), daemon=True).start()

    def _perform_analysis_task(self, sentence_text, target_size=None):
        """Worker: render the tree and build the preview image.

        Args:
            sentence_text: the text to parse.
            target_size: ``(width, height)`` available for the preview; a
                default size is used when it is missing or degenerate.
        """
        try:
            _, buffer, all_colors_map = self.generator.generate_tree_image_base64(sentence_text)
            preview = self._build_preview_image(buffer.getvalue(), target_size)
        except ValueError as ve:
            print(f"Parsing error: {ve}")
            self._run_on_ui_thread(self._on_analysis_failed, f"Error: {ve}")
        except Exception as e:
            print(f"Error while parsing: {e}")
            import traceback
            traceback.print_exc()
            self._run_on_ui_thread(self._on_analysis_failed, f"Unknown error while parsing: {e}")
        else:
            self._run_on_ui_thread(self._on_analysis_succeeded, buffer, all_colors_map, preview)

    @staticmethod
    def _build_preview_image(png_bytes, target_size):
        """Scale the PNG down to ``target_size`` and round its corners (pure PIL work)."""
        width, height = target_size if target_size else (0, 0)
        if width <= 1 or height <= 1:
            width, height = App._FALLBACK_PREVIEW_SIZE

        img = Image.open(io.BytesIO(png_bytes))
        img.thumbnail((width, height), Image.LANCZOS)

        # White rounded rectangle on black: used as the alpha channel below.
        mask = Image.new('L', img.size, 0)
        ImageDraw.Draw(mask).rounded_rectangle((0, 0, img.width, img.height),
                                               radius=App._PREVIEW_CORNER_RADIUS, fill=255)
        # The mask already has the image's size, so this only guarantees
        # matching dimensions for putalpha().
        img = ImageOps.fit(img, mask.size, centering=(0.5, 0.5))
        img.putalpha(mask)
        return img

    def _on_analysis_succeeded(self, buffer, all_colors_map, preview):
        """Main thread: show the preview and enable saving / the legend."""
        try:
            ctk_img = ctk.CTkImage(light_image=preview, dark_image=preview,
                                   size=(preview.width, preview.height))
            self.image_label.configure(image=ctk_img)
        except Exception as e:
            print(f"Error while parsing: {e}")
            self._on_analysis_failed(f"Unknown error while parsing: {e}")
            return

        # Keep a reference so the image is not garbage-collected.
        self.image_label.image = ctk_img
        self.current_image_buffer = buffer
        self.all_categories_colors = all_colors_map

        self._set_status("Constituency tree successfully generated")
        self.download_button.configure(state="normal")
        self.legend_button.configure(state="normal")
        self.analyze_button.configure(state="normal")

    def _on_analysis_failed(self, message):
        """Main thread: report the failure and allow another attempt."""
        self._set_status(message, "red")
        self.download_button.configure(state="disabled")
        self.analyze_button.configure(state="normal")

    # ------------------------------------------------------------------
    # Saving and legend
    # ------------------------------------------------------------------
    def download_image(self):
        """Ask for a file name and write the last rendered PNG to it.

        The file receives the PNG held in ``current_image_buffer``, not the
        scaled preview with rounded corners.
        """
        if not self.current_image_buffer:
            self._set_status("No image available for saving", "red")
            return

        try:
            file_path = filedialog.asksaveasfilename(
                defaultextension=".png",
                filetypes=[("PNG files", "*.png"), ("All files", "*.*")],
                title="Save image"
            )
            if not file_path:
                self._set_status("Image saving cancelled", "orange")
                return

            with open(file_path, 'wb') as f:
                f.write(self.current_image_buffer.getvalue())
            self._set_status(f"Image has been saved: {os.path.basename(file_path)}")
        except Exception as e:
            self._set_status(f"Error while saving image: {e}", "red")
            print(f"Image saving error: {e}")

    def show_legend_window(self):
        """Open a modal window listing every category with its colour."""
        colors = self.all_categories_colors
        if not colors and self.generator is not None:
            # The colour map does not depend on any sentence, so the legend is
            # available as soon as the generator exists.
            colors = self.generator.all_colors_map
        if not colors:
            self._set_status("Error: Please generate at least one sentence first", "red")
            return

        legend_window = ctk.CTkToplevel(self)
        legend_window.title("All syntactic categories")
        legend_window.geometry("400x600")
        legend_window.transient(self)
        legend_window.grab_set()

        legend_window.configure(fg_color="#f0f0f0")  # light grey

        scrollable_frame = ScrollableFrame(legend_window, width=380, height=550)
        scrollable_frame.configure(fg_color="transparent")
        scrollable_frame.pack(padx=10, pady=10, fill="both", expand=True)

        # List categories alphabetically, adding the full name when one is known.
        for category, color in sorted(colors.items(), key=lambda item: item[0]):
            full_name = self.generator.full_names.get(category, None)
            display_text = f"{category}"
            if full_name:
                display_text += f" ({full_name})"
            scrollable_frame.add_label_pair(display_text, color)

        close_button = ctk.CTkButton(legend_window, text="Close", command=legend_window.destroy)
        close_button.pack(pady=10)

        self.wait_window(legend_window)

# Script entry point: build the main window and run its event loop.
if __name__ == "__main__":
    app = App()
    app.mainloop()
