In [None]:
import os
import time
import re
import json
import google.generativeai as genai
from PIL import Image, ImageDraw, ImageFont
import typing_extensions as typing
from google.api_core import exceptions

# ---------------------------------------------------------
# Configuration
# ---------------------------------------------------------

# Gemini model tried first, and the one used when the first yields nothing.
PRIMARY_MODEL = "models/gemini-flash-latest"
FALLBACK_MODEL = "gemini-1.5-flash"

# Directory holding the input images and directory receiving the results.
INPUT_DIR = "./test_images"
OUTPUT_DIR = "./results_final_filtered"

# File names (relative to INPUT_DIR) processed by the script, in order.
target_images_list = [
    "sample2.png", "sample.png", "sample1.png",
    "IMG_4579.jpg", "IMG_4578.jpg", "IMG_4577.jpg",
]

# ---------------------------------------------------------

def clean_json_text(text):
    """Strip Markdown code fences and isolate the JSON array in *text*.

    All "```json" and "```" markers are removed. If both a "[" and a "]"
    remain, the substring from the first "[" through the last "]" is
    returned; otherwise the fence-stripped text is returned unchanged.
    """
    unfenced = text.replace("```json", "").replace("```", "")

    first_open = unfenced.find("[")
    last_close = unfenced.rfind("]")
    if first_open == -1 or last_close == -1:
        # No complete bracket pair: hand back whatever is left after unfencing.
        return unfenced

    # Drop any chatter before the first "[" and after the last "]".
    return unfenced[first_open:last_close + 1]

# Substrings (matched against the lower-cased mark name and description) that
# identify scores or text rather than grading marks.
_EXCLUDE_KEYWORDS = ("score", "number", "text", "digit", "点", "文字", "数字")


def _request_raw_marks(img, prompt):
    """Query the models in order and return the parsed JSON list, or None.

    Each model gets up to three attempts. A successfully parsed list is
    returned immediately, even when it is empty, so that "no marks found" is
    not confused with "request failed". None means no model produced a list.
    """
    for model_name in (PRIMARY_MODEL, FALLBACK_MODEL):
        model = genai.GenerativeModel(
            model_name=model_name,
            # JSON mode is deliberately not forced (flexibility first).
            generation_config={"temperature": 0.0},
        )

        for attempt in range(3):
            try:
                response = model.generate_content([prompt, img])
                text_response = response.text
            except exceptions.ResourceExhausted:
                print("   [待機] 429エラー。10秒待機...")
                time.sleep(10)
                continue
            except Exception as e:
                # Any other API error is treated as fatal for this model;
                # move on to the next one.
                print(f"   [エラー] {e}")
                break

            try:
                parsed = json.loads(clean_json_text(text_response))
            except json.JSONDecodeError:
                parsed = None

            if isinstance(parsed, list):
                print(f"   -> 成功！ (モデル: {model_name})")
                return parsed

            # Malformed JSON, or valid JSON that is not a list: retry.
            print(f"   [リトライ] JSON形式エラー。再試行... ({attempt+1})")

    return None


def _parse_box(box):
    """Return [ymin, xmin, ymax, xmax] from a list or dict box, or None if unusable."""
    if not box:
        return None

    if isinstance(box, list) and len(box) == 4:
        coords = list(box)
    elif isinstance(box, dict):
        coords = [box.get(key, 0) for key in ("ymin", "xmin", "ymax", "xmax")]
    else:
        return None

    # Reject non-numeric coordinates (bool is an int subclass, so exclude it
    # explicitly); they would otherwise crash the pixel scaling later.
    for value in coords:
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
    return coords


def _filter_marks(raw_data):
    """Keep only grading marks from the model output.

    Returns a tuple (marks, skipped_count). skipped_count covers elements
    dropped because the model tagged them "IGNORE" or because their name or
    description contains an exclusion keyword. Elements that are not dicts or
    have no usable box are dropped without being counted.
    """
    marks = []
    skipped_count = 0

    for item in raw_data:
        if not isinstance(item, dict):
            continue

        # 1. Drop whatever the model itself classified as IGNORE.
        if item.get("type") == "IGNORE":
            skipped_count += 1
            continue

        # 2. Keyword filter as a second line of defence. `or ""` guards
        #    against keys that are present but null in the JSON.
        name = str(item.get("mark_name") or "").lower()
        desc = str(item.get("description") or "").lower()
        if any(kw in name or kw in desc for kw in _EXCLUDE_KEYWORDS):
            skipped_count += 1
            continue

        # 3. Normalise the bounding box.
        box = _parse_box(item.get("box_2d"))
        if box is None:
            continue

        marks.append({
            "mark_type": item.get("mark_name"),
            "description": item.get("description"),
            "box": box,
        })

    return marks, skipped_count


def _draw_marks(img, marks):
    """Draw a green rectangle (and a label, when a font is available) for each mark on img."""
    draw = ImageDraw.Draw(img)
    width, height = img.size

    try:
        # Generic font path; point this at a Japanese-capable font
        # (e.g. msgothic.ttc on Windows) if labels need Japanese glyphs.
        font = ImageFont.truetype("arial.ttf", 20)
    except (OSError, ImportError):
        # Best effort: boxes are still drawn, only the labels are skipped.
        font = None
        print("   [注意] フォントを読み込めないためラベルを省略します")

    for mark in marks:
        ymin, xmin, ymax, xmax = mark["box"]

        # Scale from the 0-1000 grid to pixels. Sorting guards against
        # inverted boxes, which Pillow rejects with ValueError.
        top, bottom = sorted(((ymin / 1000) * height, (ymax / 1000) * height))
        left, right = sorted(((xmin / 1000) * width, (xmax / 1000) * width))

        # Outline: vivid green.
        draw.rectangle([(left, top), (right, bottom)], outline="#00FF00", width=3)

        label = mark["mark_type"]
        if font and label:
            label = str(label)
            text_pos = (left, max(0, top - 25))
            bbox = draw.textbbox(text_pos, label, font=font)
            draw.rectangle(bbox, fill="#00FF00")
            draw.text(text_pos, label, fill="black", font=font)


def extract_marks_robust(image_path, output_dir, api_key):
    """Detect handwritten grading marks in one image and save the results.

    The image is sent to Gemini (PRIMARY_MODEL, then FALLBACK_MODEL) with a
    prompt asking it to classify every coloured handwritten element. Elements
    classified as IGNORE, or whose name/description looks like a score or
    text, are filtered out here. The remaining marks are drawn on the image
    and written to ``<output_dir>/<base>_robust.jpg`` and
    ``<output_dir>/<base>_robust.json``.

    Args:
        image_path: Path of the image to analyse.
        output_dir: Directory for the outputs; created if missing.
        api_key: API key passed to ``genai.configure``.

    Returns:
        None. If the file is missing, the image cannot be read, or no model
        returns a JSON list, a message is printed and nothing is saved.
    """
    file_name = os.path.basename(image_path)
    base_name = os.path.splitext(file_name)[0]

    # Check the file before touching the API client.
    if not os.path.exists(image_path):
        print(f"[{file_name}] スキップ: ファイルなし")
        return

    try:
        # convert() returns a fully loaded copy, so the file handle can be
        # closed as soon as the with-block exits.
        with Image.open(image_path) as src:
            img = src.convert('RGB')
    except Exception as e:
        print(f"[{file_name}] 画像エラー: {e}")
        return

    genai.configure(api_key=api_key)

    # Strategy: have the model classify everything instead of forbidding
    # categories in the prompt; filtering happens afterwards in Python.
    prompt = """
    Analyze this image and detect ALL handwritten red/colored ink elements.
    Do NOT filter anything out yet. Instead, CLASSIFY each element into one of these types:

    - "TARGET": For grading marks like Circle (丸), Cross (バツ), Triangle (三角), Checkmark (チェック).
    - "IGNORE": For numbers (scores like 10, 50), text comments, underlines, or long corrections.

    Output a JSON list with these fields:
    - "type": "TARGET" or "IGNORE"
    - "mark_name": "Circle", "Cross", "Score", "Text", etc.
    - "description": Description in Japanese.
    - "box_2d": [ymin, xmin, ymax, xmax] (0-1000 scale)
    """

    print(f"--- [{file_name}] 解析開始 ---")

    raw_data = _request_raw_marks(img, prompt)
    if raw_data is None:
        print("   [失敗] データ取得不可")
        return

    # Filter on the Python side (scores and text are discarded here).
    final_marks, skipped_count = _filter_marks(raw_data)
    print(f"   -> 検出合計: {len(raw_data)} -> フィルタ後: {len(final_marks)}個 (除外: {skipped_count})")

    if final_marks:
        _draw_marks(img, final_marks)

    # Save the (possibly annotated) image and the mark list.
    os.makedirs(output_dir, exist_ok=True)

    out_img_path = os.path.join(output_dir, f"{base_name}_robust.jpg")
    img.save(out_img_path, "JPEG", quality=95)

    out_json_path = os.path.join(output_dir, f"{base_name}_robust.json")
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(final_marks, f, ensure_ascii=False, indent=2)

    print("   -> 保存完了")

if __name__ == '__main__':
    # Prefer a key supplied through the GOOGLE_API_KEY environment variable so
    # the secret does not have to live in the source file; the in-file value
    # remains as a fallback for quick local runs.
    API_KEY = os.environ.get("GOOGLE_API_KEY") or "あなたのAPIキー"

    print(f"対象枚数: {len(target_images_list)} 枚")
    print("------------------------------------------------")

    for index, filename in enumerate(target_images_list):
        if index:
            # Short pause between requests for stability; nothing to wait
            # for before the first image or after the last one.
            time.sleep(3)
        full_path = os.path.join(INPUT_DIR, filename)
        extract_marks_robust(full_path, OUTPUT_DIR, API_KEY)

    print("\nすべての処理が完了しました。")