# 05 Vector + Graph RAG (NAVER)

`langchain_naver` 기반으로 Vector + Graph RAG를 실험합니다.
- Embedding: `ClovaXEmbeddings(model='bge-m3')`
- Generation: `ChatClovaX(model='HCX-005')`
- Graph: Neo4j `Paragraph` 노드 + `REF` 관계

실행 순서:
1. 연결/모델 초기화
2. Paragraph 임베딩 저장 (`embedding` 속성)
3. vector index 생성
4. vector + ref 검색
5. Clova 답변 생성


In [None]:
import os
import json
import time
from pathlib import Path

from dotenv import load_dotenv
from neo4j import GraphDatabase
from langchain_naver import ClovaXEmbeddings, ChatClovaX

# Load environment variables through python-dotenv before reading settings.
load_dotenv()

# Neo4j connection settings, whitespace-trimmed. USER and DB default to 'neo4j'.
URI, USER, PWD, DB = (
    os.getenv(env_name, fallback).strip()
    for env_name, fallback in (
        ('NEO4J_URI', ''),
        ('NEO4J_USER', 'neo4j'),
        ('NEO4J_PASSWORD', ''),
        ('NEO4J_DATABASE', 'neo4j'),
    )
)

# Fail fast when a mandatory setting is absent or blank.
if URI == '':
    raise ValueError('NEO4J_URI is empty')
if PWD == '':
    raise ValueError('NEO4J_PASSWORD is empty')

# Shared driver for every query in this notebook; verify the connection up front.
driver = GraphDatabase.driver(URI, auth=(USER, PWD))
driver.verify_connectivity()

# Embedding model and chat model used by the later cells.
embedder = ClovaXEmbeddings(model='bge-m3')
llm = ChatClovaX(max_tokens=1200, temperature=0.0, model='HCX-005')


In [None]:
def run_query(cypher: str, **params):
    """Run *cypher* on database ``DB`` through the shared ``driver``.

    Extra keyword arguments are forwarded unchanged to
    ``driver.execute_query``. Returns a ``(records, summary, keys)`` tuple.
    """
    result = driver.execute_query(cypher, database_=DB, **params)
    records, summary, keys = result
    return records, summary, keys


# Sanity check: how many Paragraph nodes does the graph contain?
records, _, _ = run_query('MATCH (p:Paragraph) RETURN count(p) AS c')
paragraph_total = records[0]['c']
print('Paragraph count:', paragraph_total)


In [None]:
# Step 1) Select the paragraphs to embed (safe to resume after an interruption)

LIMIT_PARAGRAPHS = None   # e.g. 100 for a quick test run
ONLY_MISSING = True       # True: only nodes without an embedding are processed

# Look up the known property keys first: the `embedding IS NULL` filter is only
# added when that key exists, to reduce warnings about an unknown property.
prop_records, _, _ = run_query('CALL db.propertyKeys() YIELD propertyKey RETURN collect(propertyKey) AS keys')
prop_keys = set(prop_records[0]['keys'] or [])
has_embedding_prop = 'embedding' in prop_keys

# Assemble the query from its clauses; each clause keeps its trailing space.
query_parts = ['MATCH (p:Paragraph) WHERE p.content IS NOT NULL AND trim(p.content) <> "" ']
if ONLY_MISSING and has_embedding_prop:
    query_parts.append('AND p.embedding IS NULL ')
query_parts.append('RETURN p.paragraph_key AS paragraph_key, p.content AS content ')
if LIMIT_PARAGRAPHS is not None:
    query_parts.append('LIMIT $n')
base_q = ''.join(query_parts)

params = {} if LIMIT_PARAGRAPHS is None else {'n': LIMIT_PARAGRAPHS}
records, summary, _ = run_query(base_q, **params)
# Plain dicts ({'paragraph_key', 'content'}) consumed by the embedding cell.
rows = [record.data() for record in records]

print('embedding property exists:', has_embedding_prop)
print('target paragraphs:', len(rows), 'time(ms)=', summary.result_available_after)


In [None]:
# Optional) Graph reset cell (destructive - use with care)
# Defaults to False so that nothing is deleted by accident.
RESET_GRAPH = False

if not RESET_GRAPH:
    print('RESET_GRAPH=False -> skip reset')
else:
    # Delete only this project's schema (Document/Article/Paragraph + REF),
    # relationships first, then the nodes.
    for reset_stmt in (
        'MATCH (a:Paragraph)-[r:REF]->(b:Paragraph) DELETE r',
        'MATCH (n:Paragraph) DETACH DELETE n',
        'MATCH (n:Article) DETACH DELETE n',
        'MATCH (n:Document) DETACH DELETE n',
    ):
        run_query(reset_stmt)

    # Drop the vector index when present; a failure here is only reported.
    try:
        run_query('DROP INDEX paragraph_embedding_index IF EXISTS')
    except Exception as drop_err:
        print('index drop warning:', drop_err)

    print('Graph reset done.')


In [None]:
# Step 2) Generate and store embeddings (retry on 429 + skip existing)

# Tuning knobs for the embedding cell.
BATCH_SIZE = 8  # number of rows taken from `rows` per batch
SLEEP_SEC = 10  # pause in seconds between embedding batches
MAX_RETRIES = 3  # additional attempts after a rate-limited call
BASE_BACKOFF = 1.0  # first retry wait in seconds; doubled per attempt, capped at 30s


def is_rate_limited(err: Exception) -> bool:
    """Return True when *err* looks like a rate-limit (HTTP 429) failure.

    Three signals are checked, any of which is sufficient:

    * the exception class (or one of its bases) is named ``RateLimitError``;
      the class name is normally absent from ``str(err)``, so testing only
      the message misses such errors;
    * the exception carries ``status_code == 429`` - presumably set by the
      HTTP client used under the embedder; harmless when the attribute is
      missing;
    * the message contains ``429``, ``rate exceeded`` (case-insensitive) or
      ``RateLimitError``.
    """
    # Class-name check across the MRO so that subclasses are recognised too.
    if any('RateLimitError' in cls.__name__ for cls in type(err).__mro__):
        return True
    if getattr(err, 'status_code', None) == 429:
        return True
    msg = str(err)
    return ('429' in msg) or ('rate exceeded' in msg.lower()) or ('RateLimitError' in msg)


def missing_keys_in_db(keys: list[str]) -> set[str]:
    """Return the subset of *keys* that still needs an embedding.

    A key qualifies when its ``Paragraph`` node has no ``embedding`` and a
    non-blank ``content``. An empty *keys* yields an empty set without
    querying the database.
    """
    if not keys:
        return set()

    lookup_q = (
        'UNWIND $keys AS k '
        'MATCH (p:Paragraph {paragraph_key:k}) '
        'WHERE p.embedding IS NULL AND p.content IS NOT NULL AND trim(p.content) <> "" '
        'RETURN p.paragraph_key AS paragraph_key'
    )
    found, _, _ = run_query(lookup_q, keys=keys)

    still_missing = set()
    for record in found:
        still_missing.add(record['paragraph_key'])
    return still_missing


def upsert_embeddings(batch_rows, vectors):
    """Store one embedding per row on the matching ``Paragraph`` node.

    Args:
        batch_rows: rows (dicts) that each carry a ``paragraph_key``.
        vectors: embeddings, positionally aligned with ``batch_rows``.

    Returns:
        The query summary of the successful write.

    Raises:
        ValueError: when the number of vectors differs from the number of
            rows. Pairing them anyway would silently drop the surplus and
            make the caller over-count written rows.
        Exception: whatever the write raises once it is not rate-limited or
            the retries are used up.

    An embedding already present on a node is kept (``coalesce``).
    """
    batch_rows = list(batch_rows)
    vectors = list(vectors)
    if len(batch_rows) != len(vectors):
        raise ValueError(
            f'embedding count mismatch: {len(batch_rows)} rows vs {len(vectors)} vectors'
        )

    payload = [
        {'paragraph_key': row['paragraph_key'], 'embedding': vec}
        for row, vec in zip(batch_rows, vectors)
    ]

    cypher = (
        'UNWIND $rows AS r '
        'MATCH (p:Paragraph {paragraph_key: r.paragraph_key}) '
        'SET p.embedding = coalesce(p.embedding, r.embedding)'
    )

    for attempt in range(MAX_RETRIES + 1):
        try:
            _, summary, _ = run_query(cypher, rows=payload)
            return summary
        except Exception as e:
            if is_rate_limited(e) and attempt < MAX_RETRIES:
                # Exponential backoff, capped at 30 seconds.
                wait = min(BASE_BACKOFF * (2 ** attempt), 30.0)
                print(f'upsert rate-limited, retry {attempt+1}/{MAX_RETRIES}, sleep={wait:.1f}s')
                time.sleep(wait)
                continue
            # Not retryable (or out of retries): propagate with the original traceback.
            raise


def embed_with_retry(texts: list[str]):
    """Embed *texts* with ``embedder``, retrying rate-limited calls.

    Up to ``MAX_RETRIES`` extra attempts are made, waiting
    ``BASE_BACKOFF * 2**attempt`` seconds (at most 30) before each one.
    Any other error, or a rate-limit error on the last attempt, is raised.
    """
    for attempt in range(MAX_RETRIES + 1):
        try:
            vectors = embedder.embed_documents(texts)
        except Exception as exc:
            retryable = is_rate_limited(exc) and attempt < MAX_RETRIES
            if not retryable:
                raise exc
            wait = min(BASE_BACKOFF * (2 ** attempt), 30.0)
            print(f'embed rate-limited, retry {attempt+1}/{MAX_RETRIES}, sleep={wait:.1f}s')
            time.sleep(wait)
        else:
            return vectors


# Counters and failure log for the run summary printed at the end.
written = 0
skipped_existing = 0
failed_batches = []

batch_starts = range(0, len(rows), BATCH_SIZE)
total_batches = len(batch_starts)

for batch_no, start in enumerate(batch_starts, start=1):
    batch = rows[start:start + BATCH_SIZE]
    # Pass the keys unmodified: the membership test below compares raw key
    # values, so coercing them to str here would make non-string keys never
    # match and be reported as "already embedded".
    keys = [x['paragraph_key'] for x in batch if x.get('paragraph_key') is not None]

    # Re-check the DB per batch so that rows embedded in the meantime are
    # skipped before the embedding call is made.
    missing = missing_keys_in_db(keys)
    run_batch = [x for x in batch if x.get('paragraph_key') in missing]

    if not run_batch:
        skipped_existing += len(batch)
        print(f'batch {batch_no}: skip all (already embedded)')
        continue

    texts = [str(x.get('content', '') or '') for x in run_batch]

    try:
        vectors = embed_with_retry(texts)
        summary = upsert_embeddings(run_batch, vectors)
        written += len(run_batch)
        skipped_existing += (len(batch) - len(run_batch))
        print(f'embedding write: +{len(run_batch)} (total_written={written}), skipped_existing={skipped_existing}, time(ms)={summary.result_available_after}')
    except Exception as e:
        # Best effort: record the failure and move on to the next batch.
        failed_batches.append({'batch_index': batch_no, 'error': str(e)[:500], 'size': len(run_batch)})
        print(f'[warn] batch {batch_no} failed:', e)

    # Throttle between batches; no wait is needed once the last one is done.
    if batch_no < total_batches:
        time.sleep(SLEEP_SEC)

print('embedding done')
print('written:', written)
print('skipped_existing:', skipped_existing)
print('failed_batches:', len(failed_batches))
if failed_batches:
    print('sample failed:', failed_batches[:3])


In [None]:
# Step 3) Create the vector index

# Name of the vector index; the retrieval cell queries the index by this name.
INDEX_NAME = 'paragraph_embedding_index'
# Vector dimension configured on the index.
# NOTE(review): presumably the output size of the bge-m3 embeddings - confirm.
EMBEDDING_DIM = 1024
# Similarity function configured on the index.
SIMILARITY = 'cosine'


def ensure_vector_index(index_name: str, dimensions: int, similarity: str):
    """Make sure a vector index on ``Paragraph.embedding`` exists.

    The ``CREATE VECTOR INDEX ... IF NOT EXISTS`` statement is tried first.
    If it raises, the error is printed and the
    ``db.index.vector.createNodeIndex`` procedure is called instead; an error
    from that fallback is printed and re-raised.
    """
    create_stmt = (
        f'CREATE VECTOR INDEX {index_name} IF NOT EXISTS '
        'FOR (p:Paragraph) ON (p.embedding) '
        'OPTIONS {indexConfig: {`vector.dimensions`: $dims, `vector.similarity_function`: $sim}}'
    )

    # Preferred path: Neo4j 5.x syntax.
    try:
        _, summary, _ = run_query(create_stmt, dims=dimensions, sim=similarity)
        print('vector index ensured (CREATE VECTOR INDEX), time(ms)=', summary.result_available_after)
        return
    except Exception as primary_err:
        print('create vector index failed, fallback to procedure:', primary_err)

    # Fallback for older versions: procedure-based index creation.
    fallback_args = {
        'name': index_name,
        'label': 'Paragraph',
        'prop': 'embedding',
        'dims': dimensions,
        'sim': similarity,
    }
    try:
        _, summary, _ = run_query(
            'CALL db.index.vector.createNodeIndex($name, $label, $prop, $dims, $sim)',
            **fallback_args,
        )
        print('vector index ensured (procedure), time(ms)=', summary.result_available_after)
    except Exception as fallback_err:
        print('fallback also failed:', fallback_err)
        raise


# Make sure the Paragraph embedding index exists before any vector search runs.
ensure_vector_index(INDEX_NAME, EMBEDDING_DIM, SIMILARITY)


In [None]:
# Step 4) Vector + Graph retrieval

# User question (Korean); it is embedded for the vector search and later placed in the prompt.
QUERY = '이 조건에 맞는 건축선을 알려줘'
# Number of seed paragraphs requested from the vector index.
TOP_K = 5
# Number of REF expansion rounds starting from the seed articles.
MAX_REF_HOPS = 1   # 1~2 recommended


def vector_ref_retrieve(query_text: str, top_k: int = 5, hops: int = 1):
    """Retrieve paragraphs by vector similarity, then expand along REF edges.

    Steps:
      1. Embed ``query_text`` and fetch ``top_k`` seed paragraphs from the
         vector index named ``INDEX_NAME``.
      2. Resolve the parent ``Article`` of each seed via ``HAS_PARAGRAPH``.
      3. For up to ``hops`` rounds, follow outgoing ``REF`` relationships
         from the current article frontier. Referenced articles not seen
         before form the next frontier; referenced paragraphs are collected
         directly.
      4. Load the seed paragraphs, the paragraphs of every referenced
         article and the directly referenced paragraphs as contexts.

    NOTE(review): REF is only followed from ``Article`` nodes here; whether
    paragraph-level REF edges should be traversed too depends on the graph
    schema - confirm.

    Returns:
        dict with
        ``seeds``     - vector hits (dicts including ``score``),
        ``ref_edges`` - every REF edge row visited,
        ``contexts``  - paragraph dicts with non-blank content, seed
                        paragraphs first, then referenced paragraphs in
                        discovery order. Keeping the seeds first means a
                        later truncation of the list cannot drop them in
                        favour of referenced paragraphs.
    """
    qvec = embedder.embed_query(query_text)

    # Seed: paragraph vector search.
    seed_q = (
        'CALL db.index.vector.queryNodes($index_name, $k, $query_vec) '
        'YIELD node, score '
        'RETURN node.paragraph_key AS paragraph_key, node.law_name AS law_name, '
        'node.article_num AS article_num, node.paragraph_num AS paragraph_num, '
        'node.content AS content, score'
    )
    seed_records, _, _ = run_query(seed_q, index_name=INDEX_NAME, k=top_k, query_vec=qvec)
    seeds = [r.data() for r in seed_records]
    if not seeds:
        return {'seeds': [], 'ref_edges': [], 'contexts': []}

    seed_para_keys = [s['paragraph_key'] for s in seeds]

    # Seed paragraph -> parent article.
    parent_q = (
        'UNWIND $keys AS k '
        'MATCH (a:Article)-[:HAS_PARAGRAPH]->(p:Paragraph {paragraph_key:k}) '
        'RETURN DISTINCT a.article_key AS article_key'
    )
    parent_records, _, _ = run_query(parent_q, keys=seed_para_keys)

    # Dicts are used as insertion-ordered sets so that discovery order is
    # preserved (a plain set would make the context order arbitrary).
    frontier = dict.fromkeys(r['article_key'] for r in parent_records)
    seen_articles = set(frontier)

    ref_edges = []
    target_article_keys = {}
    target_paragraph_keys = dict.fromkeys(seed_para_keys)

    ref_q = (
        'UNWIND $keys AS k '
        'MATCH (a:Article {article_key:k})-[r:REF]->(b) '
        'RETURN a.article_key AS from_article, labels(b) AS to_labels, '
        'coalesce(b.article_key, b.paragraph_key, b.law_key) AS to_key, '
        'r.scope AS scope, r.raw AS raw, r.target_level AS target_level'
    )
    for _hop in range(hops):
        if not frontier:
            break
        recs, _, _ = run_query(ref_q, keys=list(frontier))

        next_frontier = {}
        for rec in recs:
            edge = rec.data()
            ref_edges.append(edge)
            labels = edge.get('to_labels') or []
            key = edge.get('to_key')
            if not key:
                continue
            if 'Article' in labels:
                target_article_keys.setdefault(key)
                # Only articles not visited yet are expanded in the next round.
                if key not in seen_articles:
                    seen_articles.add(key)
                    next_frontier.setdefault(key)
            elif 'Paragraph' in labels:
                target_paragraph_keys.setdefault(key)
        frontier = next_frontier

    # Referenced articles -> their paragraphs.
    if target_article_keys:
        recs, _, _ = run_query(
            'UNWIND $keys AS k '
            'MATCH (a:Article {article_key:k})-[:HAS_PARAGRAPH]->(p:Paragraph) '
            'RETURN p.paragraph_key AS paragraph_key',
            keys=list(target_article_keys),
        )
        for r in recs:
            target_paragraph_keys.setdefault(r['paragraph_key'])

    # Context fetch.
    ctx_q = (
        'UNWIND $keys AS k '
        'MATCH (p:Paragraph {paragraph_key:k}) '
        'RETURN p.paragraph_key AS paragraph_key, p.law_name AS law_name, p.article_num AS article_num, '
        'p.paragraph_num AS paragraph_num, p.content AS content'
    )
    ordered_keys = list(target_paragraph_keys)
    ctx_records, _, _ = run_query(ctx_q, keys=ordered_keys)

    # `or ''` matters: a null content would otherwise stringify to 'None'
    # and slip through the blank-content filter.
    contexts = [
        row
        for row in (r.data() for r in ctx_records)
        if str(row.get('content') or '').strip()
    ]
    # Do not rely on the result order of the query; sort explicitly
    # (stable) by the position of each key in `ordered_keys`.
    rank = {k: pos for pos, k in enumerate(ordered_keys)}
    contexts.sort(key=lambda row: rank.get(row.get('paragraph_key'), len(rank)))

    return {'seeds': seeds, 'ref_edges': ref_edges, 'contexts': contexts}


# Run the retrieval once and report how many items each part contains.
retrieved = vector_ref_retrieve(QUERY, top_k=TOP_K, hops=MAX_REF_HOPS)
for section in ('seeds', 'ref_edges', 'contexts'):
    print(f'{section}:', len(retrieved[section]))


In [None]:
# Step 5) Clova 답변 생성

def build_context_text(contexts, max_items=20, max_chars=700):
    """Render paragraph dicts into one context string for the prompt.

    Args:
        contexts: sequence of dicts with the optional keys ``law_name``,
            ``article_num``, ``paragraph_num``, ``paragraph_key`` and
            ``content``.
        max_items: only the first ``max_items`` entries are rendered.
        max_chars: each content is cut to this many characters.

    Returns:
        One block per paragraph - a bracketed citation header followed by
        the (truncated) content - joined by blank lines; ``''`` when there
        is nothing to render.

    Missing or null fields are rendered as empty text, so the literal
    string ``None`` never ends up in the prompt.
    """
    def _text(value):
        # Null-safe stringification.
        return '' if value is None else str(value)

    blocks = []
    for c in contexts[:max_items]:
        header = (
            f"[{_text(c.get('law_name'))} 제{_text(c.get('article_num'))}조 "
            f"{_text(c.get('paragraph_num'))}항 | {_text(c.get('paragraph_key'))}]"
        )
        blocks.append(f"{header}\n{_text(c.get('content'))[:max_chars]}")
    return '\n\n'.join(blocks)


# Render the retrieved paragraphs (at most 20) into one context string.
context_text = build_context_text(retrieved['contexts'], max_items=20)

# Prompt (Korean): the model acts as a building-law analysis assistant and must
# answer in Korean using only the supplied legal context. Requirements stated in
# the prompt: cite at least two provisions (law name, article, paragraph), say so
# when uncertain and ask for the missing conditions, and show calculations or
# judgments step by step. The question and the context are appended at the end.
prompt = f'''당신은 건축법률 분석 어시스턴트입니다.
아래 질문에 대해, 제공된 법령 문맥만 근거로 한국어로 답하세요.

요구사항:
1) 근거 조문(법령명, 조, 항)을 2개 이상 명시
2) 불확실하면 불확실하다고 말하고 추가 필요 조건을 질문
3) 계산/판단이 있으면 단계적으로 작성

질문:
{QUERY}

문맥:
{context_text}
'''

# Call the chat model; fall back to str(resp) when the response has no `content` attribute.
resp = llm.invoke(prompt)
answer_text = getattr(resp, 'content', str(resp))
print(answer_text)


In [None]:
# Step 6) Debug output
print('[SEEDS]')
for s in retrieved['seeds']:
    print('-', s['paragraph_key'], '| score=', round(float(s.get('score', 0.0)), 4))

print('\n[REF EDGES]')
for e in retrieved['ref_edges'][:30]:
    # Edge rows store the source article under 'from_article'; reading
    # 'from_key' alone always printed None. 'from_key' stays as a fallback.
    source_key = e.get('from_article', e.get('from_key'))
    print('-', source_key, '--REF(', e.get('scope',''), ')->', e.get('to_key'), '|', e.get('raw',''))

print('\n[CONTEXT PREVIEW]')
for c in retrieved['contexts'][:5]:
    print(f"[{c.get('law_name','')} 제{c.get('article_num','')}조 {c.get('paragraph_num','')}항]")
    print(str(c.get('content',''))[:220])
    print('---')


In [None]:
# Optional) Schema / collision diagnostics
# Checks whether other data uses the same labels or keys.
#
# Result variables use cell-specific names on purpose: binding them to `rows`
# would overwrite the embedding target list that the embedding cell iterates.

label_records, _, _ = run_query('CALL db.labels() YIELD label RETURN label ORDER BY label')
print('labels:', [r['label'] for r in label_records])

# Per label: node count and number of key values shared by more than one node.
for label, key in [('Document','law_id'), ('Article','article_key'), ('Paragraph','paragraph_key')]:
    count_records, _, _ = run_query(f'MATCH (n:{label}) RETURN count(n) AS c')
    node_count = count_records[0]['c'] if count_records else 0
    dup_records, _, _ = run_query(f'MATCH (n:{label}) WITH n.{key} AS k, count(*) AS c WHERE k IS NOT NULL AND c > 1 RETURN count(*) AS d')
    dup_count = dup_records[0]['d'] if dup_records else 0
    print(label, 'count=', node_count, 'dup_keys=', dup_count)

# Paragraphs inside vs. outside the two law_id values used by this check.
scope, _, _ = run_query(
    "MATCH (p:Paragraph) RETURN count(*) AS total, "
    "count(CASE WHEN p.law_id IN ['001823','002118'] THEN 1 END) AS in_scope, "
    "count(CASE WHEN p.law_id IS NULL OR p.law_id NOT IN ['001823','002118'] THEN 1 END) AS out_scope"
)
print('paragraph scope:', scope[0] if scope else {})

# A few sample paragraphs that fall outside that scope.
outs, _, _ = run_query(
    "MATCH (p:Paragraph) WHERE p.law_id IS NULL OR p.law_id NOT IN ['001823','002118'] "
    "RETURN p.paragraph_key AS paragraph_key, p.law_id AS law_id, p.law_name AS law_name LIMIT 10"
)
print('out-scope samples:', [x.data() if hasattr(x, 'data') else x for x in outs])


In [None]:
# Shutdown
# The driver is left open so that cells can be re-run; close it manually when finished.
# driver.close()  # close manually if needed
print('done')
