# STAC Detection on Google Colab

这个 notebook 用于在 Colab 上运行 STAC (Situation, Task, Action, Consequence) 检测。

## 使用说明

1. **挂载 Google Drive**：确保你的 `llm_model` 和 `pre_data_process` 文件夹已经在 Google Drive 中（在项目根目录下）
2. **选择模型**：支持 Hugging Face 聊天模型（如 Qwen2.5 系列，见下方"支持的模型"）
3. **配置参数**：设置输入输出路径和模型参数
4. **运行检测**：对单个文件或批量文件进行 STAC 检测

## 支持的模型

- `Qwen/Qwen2.5-7B-Instruct` （推荐，需要较大内存）
- `Qwen/Qwen2.5-3B-Instruct` （较小模型，适合 Colab 免费版）
- `Qwen/Qwen2-7B-Instruct`
- 其他 Hugging Face 聊天模型

## 第一步：挂载 Google Drive

In [None]:
# Mount Google Drive at /content/drive so files stored there can be accessed by path
from google.colab import drive
drive.mount('/content/drive')

print("✓ Google Drive mounted")

## 第二步：安装依赖包

In [None]:
# Install the required packages (IPython %pip magic; -q reduces pip's output)
%pip install -q transformers torch accelerate sentencepiece

print("✓ Dependencies installed")

## 第三步：配置路径和模型

In [None]:
import os
import sys
from pathlib import Path

# ========== Paths ==========
# Point PROJECT_ROOT at the project root inside your Google Drive,
# e.g. /content/drive/MyDrive/fairytales_resarch
# Expected layout:
#   fairytales_resarch/
#     llm_model/
#     pre_data_process/
PROJECT_ROOT = "/content/drive/MyDrive/fairytales_resarch"

project_path = Path(PROJECT_ROOT)
llm_model_path = project_path / "llm_model"
pre_data_process_path = project_path / "pre_data_process"

if not project_path.exists():
    print(f"⚠ Warning: {PROJECT_ROOT} not found. Please check the path.")
else:
    # Put the project root first on sys.path so that both llm_model and
    # pre_data_process can be imported as top-level packages.
    sys.path.insert(0, str(project_path))
    print(f"✓ Added project root to path: {project_path}")

    # Report whether each expected sub-directory is present.
    print(
        f"✓ Found llm_model: {llm_model_path}"
        if llm_model_path.exists()
        else f"⚠ Warning: {llm_model_path} not found"
    )
    print(
        f"✓ Found pre_data_process: {pre_data_process_path}"
        if pre_data_process_path.exists()
        else f"⚠ Warning: {pre_data_process_path} not found"
    )

# ========== Model ==========
# Hugging Face model id; prefer a smaller model (e.g. 3B) when Colab memory is tight.
HF_MODEL = "Qwen/Qwen2.5-3B-Instruct"  # or "Qwen/Qwen2.5-7B-Instruct"
HF_DEVICE = "auto"  # "cuda", "cpu", or "auto" (auto-detect)

# ========== STAC analysis options ==========
USE_CONTEXT = True  # pass the full story as context
USE_NEIGHBORING_SENTENCES = False  # pass adjacent sentences as auxiliary context

# Export the settings as environment variables
# (presumably read by the llm_model package — verify against its loader).
os.environ.update(
    {
        "LLM_PROVIDER": "huggingface",
        "HF_MODEL": HF_MODEL,
        "HF_DEVICE": HF_DEVICE,
        "HF_TEMPERATURE": "0.2",
        "HF_TOP_P": "0.9",
        "HF_MAX_NEW_TOKENS": "2048",
    }
)

print(f"✓ Configuration set")
print(f"  Model: {HF_MODEL}")
print(f"  Device: {HF_DEVICE}")
print(f"  Use Context: {USE_CONTEXT}")
print(f"  Use Neighboring Sentences: {USE_NEIGHBORING_SENTENCES}")

## 第四步：定义句子分割函数（内嵌代码）

In [None]:
# 句子分割函数（从 pre_data_process 复制，避免导入问题）
import re
from typing import List

def split_sentences(text: str) -> List[str]:
    """Split raw text into a flat list of sentences, paragraph by paragraph.

    Runs of spaces/tabs are collapsed, the text is cut into paragraphs at
    blank lines, line breaks inside a paragraph become spaces, and each
    paragraph is handed to ``split_sentences_with_quotes``.

    Args:
        text: The text to split.

    Returns:
        The sentences in reading order, each stripped of surrounding
        whitespace (including the ideographic space U+3000); empty results
        are dropped. Empty or whitespace-only input yields ``[]``.
    """
    if not (text and text.strip()):
        return []

    # Normalize horizontal whitespace first, then cap blank-line runs.
    normalized = re.sub(r'\n{3,}', '\n\n', re.sub(r'[ \t]+', ' ', text))

    raw_sentences = []
    for block in re.split(r'\n\s*\n', normalized):
        block = block.strip()
        if block:
            # A paragraph is treated as one line of text.
            flattened = re.sub(r'\n+', ' ', block).strip()
            if flattened:
                raw_sentences.extend(split_sentences_with_quotes(flattened))

    edge_whitespace = re.compile(r'^[\s\u3000]+|[\s\u3000]+$')
    trimmed = (edge_whitespace.sub('', piece.strip()) for piece in raw_sentences)
    return [piece for piece in trimmed if piece]


def split_sentences_with_quotes(text: str) -> List[str]:
    """Split a single paragraph into sentences, keeping quoted dialogue together.

    The text is scanned character by character while a stack of currently
    open quotes is maintained. A sentence terminator only ends a sentence
    when no quote is open, so punctuation inside dialogue never splits it.
    Note that after a quote closes, scanning simply continues: the quoted
    part stays attached to whatever follows until the next terminator that
    is outside all quotes.

    Behaviour worth knowing:
      * A run of terminators ("?!", "...", "？！") ends one sentence instead
        of producing one-character fragments.
      * Closing quotes/brackets directly after a terminator stay with the
        sentence they close.
      * An apostrophe between two Latin letters ("don't") is part of the
        word, not a quote.
      * Every input character except inter-sentence whitespace is kept,
        including unmatched closing quotes.

    Args:
        text: One paragraph of text (no paragraph breaks expected).

    Returns:
        The stripped, non-empty sentences in order; ``[]`` for empty input.
    """
    if not text:
        return []

    double_quotes = {'\u0022', '\u201C', '\u201D'}
    single_quotes = {'\u0027', '\u2018', '\u2019'}
    # Only these two single-quote forms are also used as apostrophes.
    apostrophes = {'\u0027', '\u2019'}
    chinese_quote_pairs = {'「': '」', '『': '』'}
    chinese_closers = set(chinese_quote_pairs.values())
    sentence_end_chars = {'。', '！', '？', '.', '!', '?', '؟'}
    # The original list repeated the ASCII quotes; the curly closing forms
    # are listed here so they are handled like their ASCII counterparts.
    closing_chars = {'」', '』', '\u0022', '\u0027', '\u201D', '\u2019', ')', '）'}

    def is_latin_letter(ch: str) -> bool:
        # Code points below U+0250 cover Basic Latin through Latin Extended-B.
        # CJK characters are alphabetic too, so isalpha() alone is not enough.
        return ch.isalpha() and ord(ch) < 0x250

    sentences = []
    current_sentence = []
    quote_stack = []
    length = len(text)
    i = 0

    while i < length:
        char = text[i]

        # Opening corner bracket: remember which closer we expect.
        if char in chinese_quote_pairs:
            quote_stack.append(char)
            current_sentence.append(char)
            i += 1
            continue

        # Closing corner bracket: pop only when it matches the innermost
        # opener, but always keep the character so no input is lost.
        if char in chinese_closers:
            if quote_stack and chinese_quote_pairs.get(quote_stack[-1]) == char:
                quote_stack.pop()
            current_sentence.append(char)
            i += 1
            continue

        # Double quotes toggle: opening and closing forms are not
        # distinguished, the innermost open double quote is closed if any.
        if char in double_quotes:
            if quote_stack and quote_stack[-1] == 'double':
                quote_stack.pop()
            else:
                quote_stack.append('double')
            current_sentence.append(char)
            i += 1
            continue

        if char in single_quotes:
            is_apostrophe = (
                char in apostrophes
                and 0 < i < length - 1
                and is_latin_letter(text[i - 1])
                and is_latin_letter(text[i + 1])
            )
            if not is_apostrophe:
                if quote_stack and quote_stack[-1] == 'single':
                    quote_stack.pop()
                else:
                    quote_stack.append('single')
            current_sentence.append(char)
            i += 1
            continue

        if char in sentence_end_chars:
            current_sentence.append(char)
            i += 1
            if quote_stack:
                # Inside a quotation: punctuation never ends the sentence.
                continue

            # Keep "?!", "..." and similar runs in the same sentence.
            while i < length and text[i] in sentence_end_chars:
                current_sentence.append(text[i])
                i += 1
            # Trailing closers belong to the sentence just finished. The
            # cursor moves past them so they are never appended twice.
            while i < length and text[i] in closing_chars:
                current_sentence.append(text[i])
                i += 1
            # Whitespace between sentences is discarded.
            while i < length and text[i] in ' \t\n':
                i += 1

            # At end of text the final flush below emits the sentence.
            if i < length:
                sentence = ''.join(current_sentence).strip()
                if sentence:
                    sentences.append(sentence)
                current_sentence = []
            continue

        current_sentence.append(char)
        i += 1

    # Whatever remains (including text inside an unclosed quote) is the
    # last sentence.
    sentence = ''.join(current_sentence).strip()
    if sentence:
        sentences.append(sentence)

    return sentences


def split_sentences_advanced(text: str) -> List[str]:
    """Split text into sentences without breaking numbers, URLs, e-mails or abbreviations.

    Spans whose periods must not be read as sentence ends (decimal numbers,
    http(s)/www URLs, e-mail addresses and a fixed list of English
    abbreviations) are swapped for ``<PROTECTED_n>`` placeholders, the text
    is split with ``split_sentences``, and the placeholders are then
    replaced by the original spans.

    Args:
        text: The text to split.

    Returns:
        The sentences in order with all protected spans restored; ``[]``
        for empty or whitespace-only input.
    """
    if not text or not text.strip():
        return []

    # placeholder -> original span, in creation order
    protected_map = {}

    def protect(match):
        placeholder = f"<PROTECTED_{len(protected_map)}>"
        protected_map[placeholder] = match.group(0)
        return placeholder

    # Order matters: a later pattern can swallow an earlier placeholder
    # (e.g. a URL containing an already-protected decimal number).
    case_sensitive_patterns = (
        r'\d+\.\d+',
        r'https?://[^\s]+',
        r'www\.[^\s]+',
        r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',
    )
    for pattern in case_sensitive_patterns:
        text = re.sub(pattern, protect, text)

    abbreviations = (
        r'\bMr\.', r'\bMrs\.', r'\bMs\.', r'\bDr\.', r'\bProf\.',
        r'\bSr\.', r'\bJr\.', r'\bInc\.', r'\bLtd\.', r'\bCo\.',
        r'\betc\.', r'\bi\.e\.', r'\be\.g\.', r'\bvs\.', r'\bU\.S\.',
    )
    for abbr in abbreviations:
        text = re.sub(abbr, protect, text, flags=re.IGNORECASE)

    sentences = split_sentences(text)

    # Restore newest-first: a placeholder's saved text can only contain
    # placeholders created before it, so this order also expands nested
    # ones. Creation order would leave "<PROTECTED_n>" in the output.
    newest_first = list(protected_map.items())
    newest_first.reverse()

    restored_sentences = []
    for sent in sentences:
        for placeholder, original in newest_first:
            sent = sent.replace(placeholder, original)
        restored_sentences.append(sent)

    return restored_sentences

# Confirmation message printed once the definitions above have been executed
print("✓ Sentence splitting functions defined")

## 第五步：导入 STAC 分析模块

In [None]:
# Import the project modules. Only llm_model is needed here; the sentence
# splitting functions were defined in the previous cell.
try:
    from llm_model.llm_router import LLMConfig
    from llm_model.huggingface_client import HuggingFaceConfig
    from llm_model.stac_analyzer import STACAnalyzerConfig, analyze_stac

    print("✓ All modules imported successfully")
    print("  - LLMConfig, HuggingFaceConfig")
    print("  - STACAnalyzerConfig, analyze_stac")
    print("  - split_sentences_advanced (defined in previous cell)")
except ModuleNotFoundError as e:
    print(f"✗ Import error: {e}")
    print()
    print("Troubleshooting:")
    print("1. Make sure you ran the 'Configuration' cell (step 3) before this one")
    print("2. Check that PROJECT_ROOT points to the correct directory")
    print("3. Verify 'llm_model' folder exists in Google Drive")
    print()
    print("You can manually add the path with:")
    print(f"   import sys")
    # project_path is created by the configuration cell. If that cell was
    # skipped, which is the most likely cause of this import error, reading
    # it directly would raise NameError and cut this help text short.
    path_hint = globals().get("project_path", "<PROJECT_ROOT>")
    print(f"   sys.path.insert(0, '{path_hint}')")
    raise

## 第六步：初始化 STAC 分析器

In [None]:
# Generation settings for the Hugging Face backend
hf_config = HuggingFaceConfig(
    model=HF_MODEL,
    device=HF_DEVICE,
    temperature=0.2,
    top_p=0.9,
    max_new_tokens=2048,
)

# Build the LLM config, but only when this copy of LLMConfig declares a
# `huggingface` dataclass field; otherwise explain how to update the file.
try:
    has_huggingface = 'huggingface' in getattr(LLMConfig, '__dataclass_fields__', ())

    if not has_huggingface:
        raise TypeError(
            "LLMConfig does not support 'huggingface' parameter. "
            "Please update your llm_model/llm_router.py file in Google Drive. "
            "The file should have 'huggingface: HuggingFaceConfig = HuggingFaceConfig()' in the LLMConfig class."
        )

    llm_config = LLMConfig(
        provider="huggingface",
        huggingface=hf_config,
    )
    print("✓ LLMConfig created with huggingface support")

except TypeError as e:
    # Covers both the explicit check above and a constructor that rejects
    # the keyword arguments.
    print(f"✗ {e}")
    print()
    print("Solution:")
    print("1. Open llm_model/llm_router.py in your Google Drive")
    print("2. Make sure the LLMConfig class includes:")
    print("   huggingface: HuggingFaceConfig = HuggingFaceConfig()")
    print("3. Re-upload the file to Google Drive if needed")
    raise
except Exception as e:
    # Anything unexpected: show the full traceback before re-raising.
    print(f"✗ Error creating LLMConfig: {e}")
    import traceback
    traceback.print_exc()
    raise

stac_config = STACAnalyzerConfig(llm=llm_config)

print("✓ STAC analyzer configured")
print("  Note: Model will be downloaded on first use (this may take a few minutes)")

## 使用方式

### 方式一：分析单个句子（测试用）

In [None]:
# Example: analyze a single sentence (quick smoke test of the configured analyzer)
test_sentence = "王子来到了森林里。"
story_context = "从前有一个王子，他非常勇敢。王子来到了森林里。他在那里遇到了一个仙女。"

# The story context is only passed when USE_CONTEXT is enabled
result = analyze_stac(
    sentence=test_sentence,
    story_context=story_context if USE_CONTEXT else None,
    use_context=USE_CONTEXT,
    config=stac_config,
)

print("Sentence:", test_sentence)
print("Result:")
# json is used here to pretty-print the result; ensure_ascii=False keeps
# non-ASCII text readable instead of \u escapes
import json
print(json.dumps(result, ensure_ascii=False, indent=2))

### 方式二：分析整个故事文件

In [None]:
# json is imported inside this cell so the cell runs on its own. Without it,
# skipping the single-sentence demo cell makes the final save fail with
# NameError after every sentence has already been analyzed.
import json

# ========== Input / output paths ==========
# Change these to your story file and the desired output location
STORY_FILE = "/content/drive/MyDrive/path/to/your/story.txt"
OUTPUT_FILE = "/content/drive/MyDrive/path/to/output/story_stac.json"

# Load the story
story_content = Path(STORY_FILE).read_text(encoding="utf-8")
print(f"✓ Loaded story file: {STORY_FILE}")
print(f"  Length: {len(story_content)} characters")

# Split into sentences
sentences = split_sentences_advanced(story_content)
print(f"✓ Split into {len(sentences)} sentences")

# Analyze every sentence; a failure is recorded and the loop carries on
results = []
for idx, sentence in enumerate(sentences, start=1):
    print(f"Processing sentence {idx}/{len(sentences)}: {sentence[:50]}...")

    # Neighboring sentences (only when enabled). idx is 1-based, so the
    # 0-based position of the current sentence is idx - 1.
    previous_sentence = None
    next_sentence = None
    if USE_NEIGHBORING_SENTENCES:
        sent_idx = idx - 1
        if sent_idx > 0:
            previous_sentence = sentences[sent_idx - 1]
        if sent_idx < len(sentences) - 1:
            next_sentence = sentences[sent_idx + 1]

    try:
        analysis = analyze_stac(
            sentence=sentence,
            story_context=story_content if USE_CONTEXT else None,
            use_context=USE_CONTEXT,
            previous_sentence=previous_sentence,
            next_sentence=next_sentence,
            use_neighboring_sentences=USE_NEIGHBORING_SENTENCES,
            config=stac_config,
        )
    except Exception as e:
        # Best effort: keep the failure in the output and continue
        print(f"  ✗ Error analyzing sentence {idx}: {e}")
        results.append({
            "sentence_index": idx,
            "sentence": sentence,
            "analysis": None,
            "error": str(e),
        })
        continue

    results.append({
        "sentence_index": idx,
        "sentence": sentence,
        "analysis": analysis,
    })

    # Progress line every 10 sentences
    if idx % 10 == 0:
        print(f"  ✓ Analyzed {idx}/{len(sentences)} sentences")

# Entries with an "error" key failed and must not be reported as analyzed
succeeded = sum(1 for item in results if "error" not in item)
failed = len(results) - succeeded

# Save the results
output_data = {
    "source_file": STORY_FILE,
    "use_context": USE_CONTEXT,
    "use_neighboring_sentences": USE_NEIGHBORING_SENTENCES,
    "model": HF_MODEL,
    "total_sentences": len(sentences),
    "analyzed_sentences": succeeded,
    "failed_sentences": failed,
    "sentences": results,
}

Path(OUTPUT_FILE).parent.mkdir(parents=True, exist_ok=True)
Path(OUTPUT_FILE).write_text(
    json.dumps(output_data, ensure_ascii=False, indent=2),
    encoding="utf-8"
)

print(f"\n✓ Analysis complete!")
print(f"  Output saved to: {OUTPUT_FILE}")
print(f"  Analyzed: {succeeded}/{len(sentences)} sentences")
if failed:
    print(f"  Failed: {failed} sentences (see the 'error' fields in the output)")

### 方式三：批量处理多个文件

In [None]:
# json is imported inside this cell so the cell runs on its own. Without it,
# skipping the single-sentence demo cell makes every save fail with
# NameError after the file's sentences have already been analyzed.
import json

# ========== Batch configuration ==========
INPUT_DIR = "/content/drive/MyDrive/path/to/story_files"  # input folder (contains .txt files)
OUTPUT_DIR = "/content/drive/MyDrive/path/to/stac_output"  # output folder

input_path = Path(INPUT_DIR)
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

# glob() yields files in arbitrary order; sorting makes runs reproducible
story_files = sorted(input_path.glob("*.txt"))
print(f"Found {len(story_files)} story files")

# Process each file; a failure in one file does not stop the batch
for file_idx, story_file in enumerate(story_files, 1):
    print(f"\n[{file_idx}/{len(story_files)}] Processing: {story_file.name}")

    try:
        # Load and split the story
        story_content = story_file.read_text(encoding="utf-8")
        sentences = split_sentences_advanced(story_content)
        print(f"  Split into {len(sentences)} sentences")

        # Analyze every sentence; a failure is recorded and the loop carries on
        results = []
        for idx, sentence in enumerate(sentences, start=1):
            # Neighboring sentences (only when enabled). idx is 1-based,
            # so the 0-based position of the current sentence is idx - 1.
            previous_sentence = None
            next_sentence = None
            if USE_NEIGHBORING_SENTENCES:
                sent_idx = idx - 1
                if sent_idx > 0:
                    previous_sentence = sentences[sent_idx - 1]
                if sent_idx < len(sentences) - 1:
                    next_sentence = sentences[sent_idx + 1]

            try:
                analysis = analyze_stac(
                    sentence=sentence,
                    story_context=story_content if USE_CONTEXT else None,
                    use_context=USE_CONTEXT,
                    previous_sentence=previous_sentence,
                    next_sentence=next_sentence,
                    use_neighboring_sentences=USE_NEIGHBORING_SENTENCES,
                    config=stac_config,
                )
            except Exception as e:
                print(f"    ✗ Error at sentence {idx}: {e}")
                results.append({
                    "sentence_index": idx,
                    "sentence": sentence,
                    "analysis": None,
                    "error": str(e),
                })
                continue

            results.append({
                "sentence_index": idx,
                "sentence": sentence,
                "analysis": analysis,
            })

            # Progress line every 10 sentences
            if idx % 10 == 0:
                print(f"    Analyzed {idx}/{len(sentences)} sentences...")

        # Entries with an "error" key failed and must not be reported as analyzed
        succeeded = sum(1 for item in results if "error" not in item)
        failed = len(results) - succeeded

        # Save the results next to the other outputs as <stem>_stac.json
        output_file = output_path / f"{story_file.stem}_stac.json"
        output_data = {
            "source_file": str(story_file),
            "use_context": USE_CONTEXT,
            "use_neighboring_sentences": USE_NEIGHBORING_SENTENCES,
            "model": HF_MODEL,
            "total_sentences": len(sentences),
            "analyzed_sentences": succeeded,
            "failed_sentences": failed,
            "sentences": results,
        }

        output_file.write_text(
            json.dumps(output_data, ensure_ascii=False, indent=2),
            encoding="utf-8"
        )

        print(f"  ✓ Saved to: {output_file.name}")
        if failed:
            print(f"  ⚠ {failed}/{len(sentences)} sentences failed (see the 'error' fields)")

    except Exception as e:
        # Best effort: report the file-level failure and move on
        print(f"  ✗ Error processing {story_file.name}: {e}")

print(f"\n✓ Batch processing complete!")
print(f"  Output directory: {OUTPUT_DIR}")

## 注意事项

1. **内存限制**：Colab 免费版内存有限，建议使用较小的模型（如 `Qwen2.5-3B-Instruct`）
2. **首次运行**：模型首次加载时需要从 Hugging Face 下载，可能需要几分钟
3. **GPU 使用**：`HF_DEVICE` 设为 `"auto"`（默认）时，如果有 GPU，会自动使用 CUDA 加速
4. **处理时间**：每个句子的分析需要几秒钟，大批量处理需要较长时间
5. **中断恢复**：如果中断，可以修改代码跳过已处理的文件继续处理