<a href="https://colab.research.google.com/github/Shun0212/CodeSearch-Crow/blob/main/CodeCrow_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --upgrade transformers
!pip install flash-attn lizard faiss-cpu

In [None]:
import os
import subprocess
import json
import lizard
import faiss
import torch
import numpy as np
from sentence_transformers import SentenceTransformer

# ===========================================
# Settings
# ===========================================
GITHUB_REPO_URL = "https://github.com/google-research/bert.git"  # 🔧 Change to any GitHub repo you want
# Sentence-transformers model used to embed both the code snippets and the queries.
MODEL_NAME = "Shuu12121/CodeSearch-ModernBERT-Crow-Plus"
MIN_FUNCTION_LENGTH = 3  # Only include functions/cells with 3+ lines (counted after stripping surrounding whitespace)
# Parent directory for clones; each repository is cloned into SAVE_DIR/<repo name>.
SAVE_DIR = "./cloned_repos"

# Cache file names; both files are written inside the cloned repository directory.
FUNCTIONS_FILE = "functions.json"
INDEX_FILE = "faiss_index.bin"


# ===========================================
# Helper Functions
# ===========================================

def clone_repository(repo_url, clone_dir):
    """
    Ensure a local copy of ``repo_url`` exists at ``clone_dir``.

    ``git clone`` runs only when ``clone_dir`` does not exist yet. Any existing
    path is reused as-is, without checking that it is actually a clone.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
    """
    if os.path.exists(clone_dir):
        print(f"ℹ️ Repository already exists at {clone_dir}. Skipping clone.")
        return
    git_command = ["git", "clone", repo_url, clone_dir]
    subprocess.run(git_command, check=True)
    print(f"✅ Repository cloned to {clone_dir}")


def extract_functions(repo_path):
    """
    Extract functions from .py files and code cells from .ipynb files.

    Walks ``repo_path`` recursively in sorted (stable) order. Hidden
    directories such as ``.git``, ``.github`` and ``.ipynb_checkpoints`` are
    skipped. Python functions are located with lizard, and lizard's
    ``long_name`` is used as the function name so the class name is included
    when available. Notebook code cells are named ``cell_<index>``, where the
    index counts all cells of the notebook. Only snippets with at least
    MIN_FUNCTION_LENGTH lines (after stripping leading/trailing whitespace)
    are kept.

    Returns:
        list of dicts with keys "file_path", "function_name" and "code".
        Files that cannot be read or parsed are reported and skipped.
    """
    functions = []
    print("📥 Extracting functions...")
    for root, dirs, files in os.walk(repo_path):
        # Prune hidden directories in place so os.walk never descends into them
        # (this also avoids traversing every object under .git). Matching
        # directory *names* instead of a substring of the full path means a
        # repo_path that itself contains ".git" is no longer skipped entirely.
        dirs[:] = sorted(d for d in dirs if not d.startswith("."))
        files.sort()  # Sort files for stable order
        for file in files:
            file_path = os.path.join(root, file)
            try:
                if file.endswith(".py"):
                    analysis = lizard.analyze_file(file_path)
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        lines = f.readlines()
                    for func in analysis.function_list:
                        if hasattr(func, 'start_line') and hasattr(func, 'end_line'):
                            # start_line is treated as 1-based and end_line as inclusive.
                            start, end = max(func.start_line - 1, 0), func.end_line
                            code = "".join(lines[start:end])
                            if len(code.strip().splitlines()) >= MIN_FUNCTION_LENGTH:
                                functions.append({
                                    "file_path": file_path,
                                    "function_name": func.long_name,  # Use long_name (with class if exists)
                                    "code": code
                                })
                elif file.endswith(".ipynb"):
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        data = json.load(f)
                    for idx, cell in enumerate(data.get("cells", [])):
                        if cell.get("cell_type") == "code":
                            # "source" is normally a list of lines; "".join also
                            # returns a plain-string source unchanged.
                            code = "".join(cell.get("source", []))
                            if len(code.strip().splitlines()) >= MIN_FUNCTION_LENGTH:
                                functions.append({
                                    "file_path": file_path,
                                    "function_name": f"cell_{idx}",
                                    "code": code
                                })
            except Exception as e:
                # Best effort: one unreadable/unparsable file must not abort the scan.
                print(f"⚠️ Warning: Could not process {file_path}: {e}")

    print(f"✅ Extracted {len(functions)} functions.")
    return functions


def embed_codes(codes, model):
    """
    Convert code snippets into dense vectors using ``model.encode``.

    Encoding runs in batches of 32 with a progress bar. It uses the GPU when
    CUDA is available and the CPU otherwise. The value returned by
    ``model.encode`` is passed through unchanged.
    """
    print("\n📈 Encoding function codes...")
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.encode(
        codes,
        batch_size=32,
        show_progress_bar=True,
        device=target_device,
    )


def build_faiss_index(embeddings):
    """
    Create an in-memory FAISS ``IndexFlatL2`` and add every embedding row.

    ``embeddings`` is read as a 2-D array whose second dimension is the vector
    size. Rows are added in order, so label i returned by a search refers to
    row i of ``embeddings``.
    """
    print("\n🏗️ Building FAISS index...")
    flat_index = faiss.IndexFlatL2(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index


def load_or_build_data_and_index(clone_path, model):
    """
    Load cached function data and the FAISS index, or build (and cache) them.

    The cache consists of FUNCTIONS_FILE and INDEX_FILE inside ``clone_path``.
    It is reused only when both files exist, both load successfully, and the
    index holds exactly one vector per stored function. Otherwise functions
    are re-extracted from the repository, embedded with ``model`` and indexed,
    and both files are rewritten. Save failures are reported but are not fatal.
    Changes to the repository do NOT invalidate the cache; delete the two
    files to force a rebuild.

    Returns:
        (functions, index), or ([], None) when no functions were found
        (nothing is cached in that case).
    """
    functions_path = os.path.join(clone_path, FUNCTIONS_FILE)
    index_path = os.path.join(clone_path, INDEX_FILE)

    # Check if both data and index files exist
    if os.path.exists(functions_path) and os.path.exists(index_path):
        print(f"\n🔄 Loading existing data and index from {clone_path}...")
        try:
            with open(functions_path, 'r', encoding='utf-8') as f:
                functions = json.load(f)
            index = faiss.read_index(index_path)
            # The two files are written separately, so a failed or interrupted
            # save can leave them out of sync. Search maps index row i to
            # functions[i], so a size mismatch means the cache is unusable.
            if index.ntotal != len(functions):
                raise ValueError(
                    f"index holds {index.ntotal} vectors but {len(functions)} functions are stored"
                )
            print("✅ Successfully loaded existing data and index.")
            return functions, index
        except Exception as e:
            # Corrupt or inconsistent cache: fall through and rebuild it.
            print(f"⚠️ Error loading existing data or index: {e}. Rebuilding...")

    # If data or index files do not exist, or loading failed, build them
    print("\n🏗️ No existing data or index found (or failed to load). Building a new one...")

    functions = extract_functions(clone_path)
    if not functions:
        print("❌ Error: No functions found. Cannot build index. Exiting.")
        return [], None  # The caller treats this as "nothing to search"

    # Save functions data (best effort; the in-memory data is still usable).
    try:
        with open(functions_path, 'w', encoding='utf-8') as f:
            # ensure_ascii=False keeps non-ASCII text (e.g. Japanese) readable instead of \u-escaped.
            json.dump(functions, f, indent=4, ensure_ascii=False)
        print(f"💾 Functions data saved at: {functions_path}")
    except Exception as e:
        print(f"⚠️ Warning: Could not save functions data to {functions_path}: {e}")

    # Embed codes; row i of the embeddings corresponds to functions[i].
    codes = [func["code"] for func in functions]
    embeddings = embed_codes(codes, model)

    index = build_faiss_index(embeddings)

    # Save FAISS index (best effort).
    try:
        faiss.write_index(index, index_path)
        print(f"💾 FAISS index saved at: {index_path}")
    except Exception as e:
        print(f"⚠️ Warning: Could not save FAISS index to {index_path}: {e}")

    return functions, index


def search_functions(index, model, query, functions, top_k=5):
    """
    Search for the top-k functions most relevant to a natural-language query.

    Args:
        index: FAISS index whose row i corresponds to ``functions[i]``.
        model: embedding model with an ``encode`` method (the one used to
            build the index).
        query: search text.
        functions: list of function records aligned with ``index``.
        top_k: maximum number of results; values <= 0 return an empty list.

    Returns:
        Matching records, closest first. The list is shorter than ``top_k``
        when the index holds fewer vectors.
    """
    # Prefer the model's own device; otherwise fall back to CUDA/CPU detection.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if hasattr(model, 'device'):
        device = model.device

    # FAISS pads the result with label -1 when k exceeds the number of stored
    # vectors and may reject k <= 0, so clamp k before searching.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    query_emb = model.encode([query], device=device)
    # FAISS works on C-contiguous float32 matrices.
    query_vec = np.ascontiguousarray(query_emb, dtype=np.float32)
    _, labels = index.search(query_vec, k)

    results = []
    for idx in labels[0]:
        if idx < 0:
            continue  # padding label: no neighbour for this slot
        if idx < len(functions):
            results.append(functions[idx])
        else:
            # The index and the function list are out of sync.
            print(f"⚠️ Warning: Invalid index {idx} returned from FAISS search. Skipping result.")

    return results


def pretty_print_results(results):
    """
    Print search results in a readable layout.

    For each result, prints its rank, file path and function name, then at
    most the first 100 code lines. If lines were cut off, a final note gives
    how many. An empty ``results`` prints a placeholder message instead.
    """
    print("\n🔍 Search Results:")
    if not results:
        print("No relevant functions found.")
        return

    max_preview = 100  # longest code preview shown per result
    for rank, hit in enumerate(results, start=1):
        print(f"\n=== Result {rank} ===")
        print(f"📄 File: {hit['file_path']}")
        print(f"🔧 Function: {hit['function_name']}")
        print("🧩 Code Preview:")
        code_lines = hit['code'].splitlines()
        for shown in code_lines[:max_preview]:
            print(shown)
        hidden = len(code_lines) - max_preview
        if hidden > 0:
            print(f"... ({hidden} more lines truncated) ...")


def get_repo_name(repo_url):
    """
    Derive a repository name from its clone URL.

    Trailing slashes are ignored and the last path segment is used.
    NOTE: every occurrence of ".git" in that segment is removed, not only a
    trailing suffix (e.g. "user.github.io.git" -> "userhub.io").
    """
    last_segment = repo_url.rstrip("/").rsplit("/", 1)[-1]
    return last_segment.replace(".git", "")

# ===========================================
# Main Execution
# ===========================================
if __name__ == "__main__":
    # Script entry point: clone the repository, load the embedding model, load
    # or build the cached function list + FAISS index, then answer queries
    # interactively until an empty line is entered.
    try:
        # 1. Clone Repository (skipped when SAVE_DIR/<repo name> already exists)
        repo_name = get_repo_name(GITHUB_REPO_URL)
        clone_path = os.path.join(SAVE_DIR, repo_name)
        os.makedirs(SAVE_DIR, exist_ok=True)
        clone_repository(GITHUB_REPO_URL, clone_path)

        # 2-6. Load or Build Data and Index
        # This step handles extraction, embedding, building, and saving/loading
        # of both functions data and the FAISS index.
        # The model is loaded even when the cache is reused, because queries
        # must be embedded with the same model as the indexed code.
        print("\n📦 Loading embedding model...")
        model = SentenceTransformer(MODEL_NAME)

        # Load or build functions data and FAISS index
        functions, index = load_or_build_data_and_index(clone_path, model)

        # Nothing to search (e.g. no functions could be extracted).
        if not functions or index is None:
             print("❌ Could not load or build data/index. Exiting.")
             # NOTE(review): bare exit() is the site-module helper, not sys.exit;
             # its behaviour inside a notebook kernel may differ — confirm.
             exit(3)

        # 7. Search
        while True: # Loop for multiple searches until empty query is entered
            query = input("\n💬 Enter your search query (in English, or any language the embedding model handles well - press Enter only to quit): ")
            if not query.strip():
                print("👋 Exiting search.")
                break

            # A failure for one query is reported and the loop keeps going.
            try:
                results = search_functions(index, model, query, functions)
                pretty_print_results(results)
            except Exception as search_err:
                print(f"❗ An error occurred during search: {search_err}")

    except Exception as e:
        # Any other failure (clone, model download, indexing) ends the script
        # via exit(99). KeyboardInterrupt/SystemExit are not Exception
        # subclasses and are not caught here.
        print(f"❗ Unexpected error occurred: {e}")
        exit(99)

In [None]:
# ===========================================
# ❗ ご注意
# このスクリプトは「Google Colab（L4 GPU推奨）」での実行を想定しています。
#
# ① GitHubリポジトリをクローンし、
# ② .py, .ipynbファイルから関数を抽出し（初回のみ）、
# ③ コードを埋め込み（Embedding）し（初回のみ）、
# ④ FAISSインデックスを作成して保存し（初回のみ）、
# ⑤ 日本語クエリをQwen3-8B-FP8で英訳してから検索します。
#
# 日本語で質問しても英語に翻訳して高精度に検索できる仕組みです！
# 初回実行時（または保存済みの functions.json / faiss_index.bin を削除した場合）には
# 抽出・埋め込み・インデックス構築が行われ、それ以降は保存されたファイルを読み込むため高速に検索できます！
# ※リポジトリのコードを更新しても自動では再構築されません。上記2ファイルを削除してから再実行してください。
# ===========================================

import os
import subprocess
import json
import lizard
import faiss
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# ===========================================
# Settings
# ===========================================
GITHUB_REPO_URL = "https://github.com/google-research/bert.git"  # 🔧 Change here
# Parent directory for clones; each repository is cloned into SAVE_DIR/<repo name>.
SAVE_DIR = "./cloned_repos"
# Sentence-transformers model used to embed both the code snippets and the (translated) queries.
MODEL_NAME = "Shuu12121/CodeSearch-ModernBERT-Crow-Plus"
# Causal LM used to translate Japanese queries into English before searching.
QWEN_MODEL = "Qwen/Qwen3-8B-FP8"
MIN_FUNCTION_LENGTH = 3  # Minimum lines for function (counted after stripping surrounding whitespace)

# Cache file names; both files are written inside the cloned repository directory.
FUNCTIONS_FILE = "functions.json"
INDEX_FILE = "faiss_index.bin"

# ===========================================
# Helper Functions
# ===========================================

def clone_repository(repo_url, clone_dir):
    """
    Clone ``repo_url`` into ``clone_dir`` unless that path already exists.

    An existing path is reused without verification.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
    """
    already_present = os.path.exists(clone_dir)
    if already_present:
        print(f"ℹ️ Repository already exists at {clone_dir}. Skipping clone.")
    else:
        subprocess.run(["git", "clone", repo_url, clone_dir], check=True)
        print(f"✅ Repository cloned to {clone_dir}")

def extract_functions(repo_path):
    """
    Extract functions from .py files and code cells from .ipynb files.

    Walks ``repo_path`` recursively in sorted (stable) order and skips hidden
    directories such as ``.git``, ``.github`` and ``.ipynb_checkpoints``.
    Python functions are located with lizard and named by lizard's
    ``long_name``. Notebook code cells are named ``cell_<index>``. Only
    snippets with at least MIN_FUNCTION_LENGTH lines (after stripping
    leading/trailing whitespace) are kept.

    Returns:
        list of dicts with keys "file_path", "function_name" and "code".
        Files that cannot be read or parsed are reported and skipped.
    """
    functions = []
    print("📥 Extracting functions...")
    for root, dirs, files in os.walk(repo_path):
        # Prune hidden directories by *name*, in place, so os.walk never enters
        # them. A substring test on the full path would also drop every file
        # when repo_path itself contains ".git", and it would still walk .git.
        dirs[:] = sorted(d for d in dirs if not d.startswith("."))
        files.sort()
        for file in files:
            file_path = os.path.join(root, file)
            try:
                if file.endswith(".py"):
                    analysis = lizard.analyze_file(file_path)
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        lines = f.readlines()
                    for func in analysis.function_list:
                        if hasattr(func, 'start_line') and hasattr(func, 'end_line'):
                            # start_line is treated as 1-based and end_line as inclusive.
                            start, end = max(func.start_line - 1, 0), func.end_line
                            code = "".join(lines[start:end])
                            if len(code.strip().splitlines()) >= MIN_FUNCTION_LENGTH:
                                functions.append({
                                    "file_path": file_path,
                                    "function_name": func.long_name,
                                    "code": code
                                })
                elif file.endswith(".ipynb"):
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        data = json.load(f)
                    for idx, cell in enumerate(data.get("cells", [])):
                        if cell.get("cell_type") == "code":
                            # "source" is normally a list of lines; "".join also
                            # returns a plain-string source unchanged.
                            code = "".join(cell.get("source", []))
                            if len(code.strip().splitlines()) >= MIN_FUNCTION_LENGTH:
                                functions.append({
                                    "file_path": file_path,
                                    "function_name": f"cell_{idx}",
                                    "code": code
                                })
            except Exception as e:
                # Best effort: one unreadable/unparsable file must not abort the scan.
                print(f"⚠️ Warning: Could not process {file_path}: {e}")
    print(f"✅ Extracted {len(functions)} functions.")
    return functions

def embed_codes(codes, model):
    """
    Encode ``codes`` into vectors with ``model.encode``.

    Uses batch size 32 and a progress bar, and runs on CUDA when available
    (CPU otherwise). Returns ``model.encode``'s result unchanged.
    """
    print("\n📈 Encoding function codes...")
    encode_options = {
        "batch_size": 32,
        "show_progress_bar": True,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }
    return model.encode(codes, **encode_options)

def build_faiss_index(embeddings):
    """
    Build an in-memory FAISS ``IndexFlatL2`` holding every row of ``embeddings``.

    ``embeddings`` is read as a 2-D array (second dimension = vector size).
    Rows are added in order, so search label i refers to row i.
    """
    print("\n🏗️ Building FAISS index...")
    vector_dim = embeddings.shape[1]
    flat_index = faiss.IndexFlatL2(vector_dim)
    flat_index.add(embeddings)
    return flat_index

def load_or_build_data_and_index(clone_path, model):
    """
    Load the cached functions data and FAISS index, or build (and cache) them.

    The cache is FUNCTIONS_FILE + INDEX_FILE inside ``clone_path``. It is
    reused only when both files exist, both load, and the index holds exactly
    one vector per stored function. Otherwise functions are re-extracted,
    embedded with ``model`` and indexed, and both files are rewritten. Save
    failures are reported but are not fatal. Repository changes do NOT
    invalidate the cache; delete the two files to force a rebuild.

    Returns:
        (functions, index), or ([], None) when no functions were found.
    """
    functions_path = os.path.join(clone_path, FUNCTIONS_FILE)
    index_path = os.path.join(clone_path, INDEX_FILE)

    # Reuse the cache only when both files are present.
    if os.path.exists(functions_path) and os.path.exists(index_path):
        print(f"\n🔄 Loading existing data and index from {clone_path}...")
        try:
            with open(functions_path, 'r', encoding='utf-8') as f:
                functions = json.load(f)
            index = faiss.read_index(index_path)
            # The files are written separately, so a failed or interrupted save
            # can leave them out of sync; row i of the index must be functions[i].
            if index.ntotal != len(functions):
                raise ValueError(
                    f"index holds {index.ntotal} vectors but {len(functions)} functions are stored"
                )
            print("✅ Successfully loaded existing data and index.")
            return functions, index
        except Exception as e:
            # Corrupt or inconsistent cache: fall through and rebuild it.
            print(f"⚠️ Error loading existing data or index: {e}. Rebuilding...")

    # Files missing, or loading failed: build everything from the sources.
    print("\n🏗️ No existing data or index found (or failed to load). Building a new one...")

    functions = extract_functions(clone_path)
    if not functions:
        print("❌ Error: No functions found. Cannot build index. Exiting.")
        return [], None  # The caller treats this as "nothing to search"

    # Save functions data (best effort; the in-memory data is still usable).
    try:
        with open(functions_path, 'w', encoding='utf-8') as f:
            # ensure_ascii=False keeps non-ASCII text (e.g. Japanese) readable instead of \u-escaped.
            json.dump(functions, f, indent=4, ensure_ascii=False)
        print(f"💾 Functions data saved at: {functions_path}")
    except Exception as e:
        print(f"⚠️ Warning: Could not save functions data to {functions_path}: {e}")

    # Embed the code; row i of the embeddings corresponds to functions[i].
    codes = [func["code"] for func in functions]
    embeddings = embed_codes(codes, model)

    index = build_faiss_index(embeddings)

    # Save the FAISS index (best effort).
    try:
        faiss.write_index(index, index_path)
        print(f"💾 FAISS index saved at: {index_path}")
    except Exception as e:
        print(f"⚠️ Warning: Could not save FAISS index to {index_path}: {e}")

    return functions, index


def search_functions(index, model, query, functions, top_k=5):
    """
    Return up to ``top_k`` functions most relevant to ``query``.

    ``query`` is embedded with ``model`` (on CUDA when available, else CPU)
    and looked up in ``index``, whose row i corresponds to ``functions[i]``.
    Results are ordered closest first. Fewer than ``top_k`` results are
    returned when the index holds fewer vectors, and ``top_k <= 0`` returns [].
    """
    # FAISS pads the result with label -1 when k exceeds the number of stored
    # vectors and may reject k <= 0, so clamp k before searching.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    query_emb = model.encode([query], device="cuda" if torch.cuda.is_available() else "cpu")
    # FAISS works on C-contiguous float32 matrices.
    query_vec = np.ascontiguousarray(query_emb, dtype=np.float32)
    _, labels = index.search(query_vec, k)

    results = []
    for idx in labels[0]:
        if idx < 0:
            continue  # padding label: no neighbour for this slot
        if idx < len(functions):
            results.append(functions[idx])
        else:
            # The index and the function list are out of sync.
            print(f"⚠️ Warning: Invalid index {idx} returned from FAISS search.")
    return results


def pretty_print_results(results):
    """
    Print each search result's file, function name and a code preview.

    Previews are limited to the first 100 lines; longer snippets end with a
    note giving the number of omitted lines. An empty ``results`` prints a
    placeholder message instead.
    """
    print("\n🔍 Search Results:")
    if not results:
        print("No relevant functions found.")
        return
    preview_limit = 100
    for position, entry in enumerate(results, start=1):
        print(f"\n=== Result {position} ===")
        print(f"📄 File: {entry['file_path']}")
        print(f"🔧 Function: {entry['function_name']}")
        print("🧩 Code Preview:")
        snippet_lines = entry['code'].splitlines()
        for text_line in snippet_lines[:preview_limit]:
            print(text_line)
        if len(snippet_lines) > preview_limit:
            omitted = len(snippet_lines) - preview_limit
            print(f"... ({omitted} more lines truncated) ...")

def _truncate_at_eos(token_ids, eos_token_id):
    """Return ``token_ids`` as a list, cut just before the first EOS id (single id, list/tuple of ids, or None)."""
    eos_ids = eos_token_id if isinstance(eos_token_id, (list, tuple)) else [eos_token_id]
    cut = min(
        (token_ids.index(eos) for eos in eos_ids if eos is not None and eos in token_ids),
        default=len(token_ids),
    )
    return list(token_ids[:cut])


def translate_to_english(qwen_model, qwen_tokenizer, japanese_text):
    """
    Translate a (Japanese) query into natural English with a Qwen3 chat model.

    The model is told to keep technical terms and code identifiers unchanged.
    Generation samples with explicit settings, so repeated calls may produce
    slightly different translations.

    Args:
        qwen_model: causal LM exposing ``generate`` and ``device``.
        qwen_tokenizer: matching tokenizer providing ``apply_chat_template``,
            ``decode`` and ``eos_token_id``.
        japanese_text: text to translate (surrounding whitespace is ignored).

    Returns:
        The decoded translation, with special tokens removed and surrounding
        whitespace stripped.
    """
    # Build the prompt line by line. An indented triple-quoted f-string would
    # leak the source-code indentation into every prompt line.
    prompt_translate = "\n".join([
        "以下の日本語の内容を、自然な英語に翻訳してください。",
        "・専門用語やコードの変数名はそのままにしてください。",
        "・正確かつ自然な英語にしてください。",
        "・翻訳対象:",
        "---",
        japanese_text.strip(),
        "---",
        "英訳:",
    ])
    messages = [{"role": "user", "content": prompt_translate}]

    # Explicit generation settings so behaviour does not depend on model defaults.
    generation_config = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 5,
        "min_p": 0,
        "pad_token_id": qwen_tokenizer.eos_token_id,
        "eos_token_id": qwen_tokenizer.eos_token_id
    }

    text = qwen_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False  # Qwen3: skip the reasoning block and answer directly
    )

    inputs = qwen_tokenizer([text], return_tensors="pt").to(qwen_model.device)
    generated_ids = qwen_model.generate(**inputs, **generation_config)

    # Keep only the newly generated tokens (drop the echoed prompt).
    prompt_len = len(inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_len:].tolist()
    # generate() normally stops at EOS already; trimming again guards against
    # anything emitted after it. eos_token_id may be one id or a list of ids.
    output_ids = _truncate_at_eos(output_ids, qwen_tokenizer.eos_token_id)

    return qwen_tokenizer.decode(output_ids, skip_special_tokens=True).strip()


def get_repo_name(repo_url):
    """
    Return the last path segment of ``repo_url``, ignoring trailing slashes.

    NOTE: every ".git" occurrence in that segment is removed, not only a
    trailing suffix (e.g. "user.github.io.git" -> "userhub.io").
    """
    trimmed = repo_url.rstrip("/")
    last_segment = trimmed[trimmed.rfind("/") + 1:]
    return last_segment.replace(".git", "")

# ===========================================
# Main Execution
# ===========================================
if __name__ == "__main__":
    # Script entry point: clone the repository, load the embedding and
    # translation models, load or build the cached function list + FAISS
    # index, then repeatedly translate Japanese queries to English and search
    # until an empty line is entered.
    try:
        # 1. Clone repository (skipped when SAVE_DIR/<repo name> already exists)
        repo_name = get_repo_name(GITHUB_REPO_URL)
        clone_path = os.path.join(SAVE_DIR, repo_name)
        os.makedirs(SAVE_DIR, exist_ok=True)
        clone_repository(GITHUB_REPO_URL, clone_path)

        # 2-6. Load or Build Data and Index
        # Function extraction, embedding and index building/saving (or loading
        # from the cache) happen below, in load_or_build_data_and_index.
        print("\n📦 Loading embedding model...")
        model = SentenceTransformer(MODEL_NAME)

        print("\n📦 Loading translation model (Qwen3-8B-FP8)...")
        # NOTE(review): trust_remote_code=True lets the model repository run its own code.
        qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL, trust_remote_code=True)
        # device_map="auto" places the model on the available device(s) automatically.
        # Both models are loaded before indexing, so both stay in memory while embedding.
        qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL, torch_dtype="auto", device_map="auto", trust_remote_code=True)


        # Load the functions data and FAISS index, or build them anew.
        functions, index = load_or_build_data_and_index(clone_path, model)

        # Nothing to search (e.g. no functions could be extracted).
        if not functions or index is None:
             print("❌ Could not load or build data/index. Exiting.")
             # NOTE(review): bare exit() is the site-module helper, not sys.exit;
             # its behaviour inside a notebook kernel may differ — confirm.
             exit(3)

        # 7. Search
        while True: # Keep searching until the user enters an empty line.
            japanese_query = input("\n💬 日本語で検索クエリを入力してください (終了するにはEnterキーのみを押す): ")
            if not japanese_query.strip():
                print("👋 検索を終了します。")
                break

            print("\n🔄 Translating query to English...")
            # A translation or search failure for one query is reported and the loop continues.
            try:
                english_query = translate_to_english(qwen_model, qwen_tokenizer, japanese_query)
                print(f"🌎 English Query: {english_query}")

                results = search_functions(index, model, english_query, functions)
                pretty_print_results(results)
            except Exception as search_err:
                print(f"❗ An error occurred during search or translation: {search_err}")


    except Exception as e:
        # Any other failure (clone, model download, indexing) ends the script via
        # exit(99). KeyboardInterrupt/SystemExit are not caught here.
        print(f"❗ Unexpected error occurred: {e}")
        exit(99)