In [1]:
import pandas as pd
import re
import torch
from transformers import AutoTokenizer, AutoModel
from kiwipiepy import Kiwi
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection
import time # 시간 측정을 위해 추가 (선택적)

# --- 1. Configuration ---

# Input data
CSV_FILE_PATH = 'naver_news_data.csv'  # <<<< set the path of the CSV file here
TEXT_COLUMN = 'content'  # column whose text is chunked and embedded
METADATA_COLUMNS = ['title', 'datetime', 'summary', 'url']  # columns stored next to each chunk

# Embedding model
MODEL_NAME = "klue/bert-base"

# Chunking: upper bound on tokens per chunk, kept below the model's maximum input length.
# Chunks are assembled from whole Kiwi-split sentences, so there is no token overlap between them.
MAX_TOKENS_PER_CHUNK = 400

# Milvus connection / collection
MILVUS_HOST, MILVUS_PORT = "localhost", "19530"
COLLECTION_NAME = "news_embeddings"  # using a fresh collection name is recommended
VECTOR_DIM = 768  # must equal the size of the mean-pooled embedding produced by MODEL_NAME

# --- 2. Load the embedding model, its tokenizer, and the Kiwi sentence splitter ---
print("Loading models and tokenizer...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.to() returns the module itself, so the move can be chained onto the load.
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference mode (e.g. dropout disabled)
print(f"Using device: {device}")

print("Initializing Kiwi...")
kiwi = Kiwi()

# --- 3. Connect to Milvus and prepare the collection ---
print(f"Connecting to Milvus ({MILVUS_HOST}:{MILVUS_PORT})...")
try:
    connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
except Exception as e:
    print(f"Failed to connect to Milvus: {e}")
    # exit() with no argument ends with status 0 (i.e. reports success) and, in a notebook
    # kernel, may shut the kernel down. SystemExit(1) stops the run with a failure status;
    # IPython reports it without killing the kernel.
    raise SystemExit(1) from e
else:
    print("Successfully connected to Milvus.")

# Collection schema: each chunk is stored with its article's metadata.
# Field order matters: rows are inserted column-by-column in this order, with the
# auto_id primary key omitted. VARCHAR max_length is checked by Milvus against the
# UTF-8 byte length of the value.
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="chunk_text", dtype=DataType.VARCHAR, max_length=65535),  # chunk text
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=VECTOR_DIM),
    FieldSchema(name="original_article_id", dtype=DataType.INT64),  # row index of the source CSV
    FieldSchema(name="chunk_seq_id", dtype=DataType.INT64),  # chunk order within the article
    # Article metadata fields
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=1024),  # adjustable
    FieldSchema(name="datetime", dtype=DataType.VARCHAR, max_length=64),  # stored as text
    FieldSchema(name="summary", dtype=DataType.VARCHAR, max_length=2048),  # adjustable
    FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=2048),  # adjustable
]
schema = CollectionSchema(fields, description="Naver News Embeddings with Metadata")

# Reuse the collection if it exists, otherwise create it.
if utility.has_collection(COLLECTION_NAME):
    print(f"Collection '{COLLECTION_NAME}' already exists.")
    collection = Collection(COLLECTION_NAME)
    # An existing collection is used as-is. Verify its field layout matches the positional
    # column insert so a stale schema fails here with a clear message instead of at
    # insert time (or by storing columns in the wrong fields).
    expected_field_names = [field.name for field in fields]
    existing_field_names = [field.name for field in collection.schema.fields]
    if existing_field_names != expected_field_names:
        print(
            f"Schema mismatch in '{COLLECTION_NAME}': expected fields {expected_field_names}, "
            f"found {existing_field_names}. Drop the collection or use a different COLLECTION_NAME."
        )
        raise SystemExit(1)
    # Primary keys are auto-generated, so re-running this notebook appends a second copy.
    print(
        f"Note: '{COLLECTION_NAME}' already holds {collection.num_entities} entities; "
        "new rows are appended and duplicates are not detected."
    )
else:
    print(f"Creating collection '{COLLECTION_NAME}'...")
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    print(f"Collection '{COLLECTION_NAME}' created successfully.")

# --- 4. Read the source articles ---
print(f"Reading CSV file: {CSV_FILE_PATH}")
try:
    # Mind the CSV encoding: try 'utf-8', or e.g. 'cp949' if that fails.
    df = pd.read_csv(CSV_FILE_PATH, encoding='utf-8')
except FileNotFoundError as e:
    print(f"Error: CSV file not found at {CSV_FILE_PATH}")
    # exit() would end with status 0 (success); SystemExit(1) reports the failure.
    raise SystemExit(1) from e
except Exception as e:
    print(f"Error reading CSV file: {e}")
    raise SystemExit(1) from e
print(f"Successfully read {len(df)} articles from CSV.")

# Fail fast on a wrong CSV layout instead of with a KeyError in the middle of processing.
_missing_columns = [
    column for column in [TEXT_COLUMN, *METADATA_COLUMNS] if column not in df.columns
]
if _missing_columns:
    print(f"Error: CSV file is missing required columns: {_missing_columns}")
    raise SystemExit(1)

# Column buffers for the positional Milvus insert: one list per non-auto_id schema
# field, all growing in lockstep (one entry per embedded chunk).
all_chunk_texts = []
all_embeddings = []
all_original_article_ids = []
all_chunk_seq_ids = []
all_titles = []
all_datetimes = []
all_summaries = []
all_urls = []

# UTF-8 byte budgets of the VARCHAR fields declared in the collection schema; keep these
# in sync with the FieldSchema max_length values. Milvus checks max_length against the
# encoded byte length (Hangul is 3 bytes per character), so an untrimmed long
# title/summary -- or a NaN float in a text column -- makes the insert request holding
# that row fail.
_CHUNK_TEXT_MAX_BYTES = 65535
_TITLE_MAX_BYTES = 1024
_DATETIME_MAX_BYTES = 64
_SUMMARY_MAX_BYTES = 2048
_URL_MAX_BYTES = 2048


def _to_bounded_str(value, max_bytes):
    """Return *value* as a str of at most *max_bytes* UTF-8 bytes.

    None/NaN become "" (a NaN cell is a float, not a string). Truncation drops
    any multi-byte character that would otherwise be cut in half.
    """
    if value is None or (not isinstance(value, str) and pd.isna(value)):
        return ""
    text = str(value)
    encoded = text.encode("utf-8")
    if len(encoded) <= max_bytes:
        return text
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


def _split_into_chunks(sentences, max_tokens):
    """Greedily pack consecutive sentences into chunks of at most *max_tokens* tokens.

    Token counts use the embedding tokenizer without special tokens, and one extra
    token is budgeted for each joining space -- the same way for every chunk. A
    sentence that alone exceeds *max_tokens* becomes a standalone chunk; anything past
    the embedding call's max_length is then truncated.
    """
    chunks = []
    current = []
    current_len = 0
    for sentence in sentences:
        n_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        cost = n_tokens + (1 if current else 0)  # +1 budget for the joining space
        if current_len + cost <= max_tokens:
            current.append(sentence)
            current_len += cost
            continue
        # Sentence does not fit: close the current chunk and start a new one.
        if current:
            chunks.append(" ".join(current))
        if n_tokens <= max_tokens:
            current, current_len = [sentence], n_tokens
        else:
            # Oversized sentence: keep it as its own chunk, accepting truncation.
            chunks.append(sentence)
            current, current_len = [], 0
    if current:
        chunks.append(" ".join(current))
    return chunks


def _embed_text(text):
    """Return the attention-mask-weighted mean of the token embeddings of *text*.

    Input is truncated to 512 tokens (special tokens included); the result is a
    plain list of floats.
    """
    encoded = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)
    with torch.no_grad():
        token_embeddings = model(**encoded)[0]
        mask = encoded['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return (summed / counts).cpu().squeeze().tolist()


start_time = time.time()
print("Starting data processing (Chunking, Embedding)...")

total_articles = len(df)
for position, (idx, row) in enumerate(df.iterrows(), start=1):
    article_content = row[TEXT_COLUMN]

    if isinstance(article_content, str) and article_content.strip():
        # Metadata is identical for every chunk of this article: normalize it once.
        article_title = _to_bounded_str(row['title'], _TITLE_MAX_BYTES)
        article_datetime = _to_bounded_str(row['datetime'], _DATETIME_MAX_BYTES)
        article_summary = _to_bounded_str(row['summary'], _SUMMARY_MAX_BYTES)
        article_url = _to_bounded_str(row['url'], _URL_MAX_BYTES)

        # Collapse whitespace runs, then split into sentences with Kiwi.
        cleaned_content = re.sub(r'\s+', ' ', article_content).strip()
        sentences = [s.text for s in kiwi.split_into_sents(cleaned_content)]
        article_chunks = _split_into_chunks(sentences, MAX_TOKENS_PER_CHUNK)

        for seq_id, chunk_text in enumerate(article_chunks):
            if not chunk_text:  # skip empty chunks
                continue
            try:
                embedding_vector = _embed_text(chunk_text)
            except Exception as e:
                print(f"Error embedding chunk {seq_id} for article {idx}: {e}")
                continue

            # Append only after a successful embedding so all columns stay aligned.
            all_chunk_texts.append(_to_bounded_str(chunk_text, _CHUNK_TEXT_MAX_BYTES))
            all_embeddings.append(embedding_vector)
            all_original_article_ids.append(idx)  # original DataFrame index label
            all_chunk_seq_ids.append(seq_id)
            all_titles.append(article_title)
            all_datetimes.append(article_datetime)
            all_summaries.append(article_summary)
            all_urls.append(article_url)
    else:
        print(f"Skipping article index {idx} due to empty content.")

    # Count rows by position (not index label) so skipped rows still advance the log.
    if position % 50 == 0:  # log every 50 articles
        print(f"Processed {position}/{total_articles} articles...")

processing_time = time.time() - start_time
print(f"Data processing finished. Time taken: {processing_time:.2f} seconds.")
print(f"Total chunks to insert: {len(all_chunk_texts)}")

# --- 5. Insert into Milvus ---
# Rows are sent in slices: a single insert request is subject to the server's maximum
# request size, so one request holding the whole dataset fails once the CSV grows.
_INSERT_BATCH_SIZE = 1000

if not all_chunk_texts:
    print("No valid data to insert into Milvus.")
else:
    # Column-based insert: order must match the non-auto_id fields of the schema.
    entities_to_insert = [
        all_chunk_texts,
        all_embeddings,
        all_original_article_ids,
        all_chunk_seq_ids,
        all_titles,
        all_datetimes,
        all_summaries,
        all_urls,
    ]
    total_entities = len(all_chunk_texts)
    inserted_count = 0

    try:
        print(f"Inserting {total_entities} entities into '{COLLECTION_NAME}'...")
        for batch_start in range(0, total_entities, _INSERT_BATCH_SIZE):
            batch_end = batch_start + _INSERT_BATCH_SIZE
            batch = [column[batch_start:batch_end] for column in entities_to_insert]
            insert_result = collection.insert(batch)
            inserted_count += len(insert_result.primary_keys)
        print(f"Insertion successful. Primary keys count: {inserted_count}")

        print("Flushing data...")
        collection.flush()
        print("Data flushed.")

    except Exception as e:
        print(f"Failed to insert data into Milvus: {e}")
        if inserted_count:
            # Earlier batches are already stored; say so, since a re-run would duplicate them.
            print(f"{inserted_count} of {total_entities} entities were inserted before the failure.")

    # --- 6. Create the vector index and load the collection ---
    INDEX_PARAM = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 128},  # tune to the data volume
    }

    if not collection.has_index():
        print(f"Creating index ({INDEX_PARAM['index_type']}) for embedding field...")
        try:
            collection.create_index(field_name="embedding", index_params=INDEX_PARAM)
            print("Index created successfully.")
            print("Waiting for index to build...")
            utility.wait_for_index_building_complete(COLLECTION_NAME)
            print("Index building complete.")
        except Exception as e:
            print(f"Failed to create index: {e}")
    else:
        print("Index already exists.")

    print("Loading collection into memory...")
    try:
        collection.load()
        print(f"Collection '{COLLECTION_NAME}' loaded successfully.")
    except Exception as e:
        print(f"Failed to load collection: {e}")

# --- 7. Close the connection (optional) ---
# connections.disconnect("default")

print("\n--- Milvus processing finished ---")

Loading models and tokenizer...
Using device: cpu
Initializing Kiwi...
Connecting to Milvus (localhost:19530)...
Successfully connected to Milvus.
Creating collection 'news_embeddings'...
Collection 'news_embeddings' created successfully.
Reading CSV file: naver_news_data.csv
Successfully read 100 articles from CSV.
Starting data processing (Chunking, Embedding)...
Processed 50/100 articles...
Processed 100/100 articles...
Data processing finished. Time taken: 91.87 seconds.
Total chunks to insert: 215
Inserting 215 entities into 'news_embeddings'...
Insertion successful. Primary keys count: 215
Flushing data...
Data flushed.
Creating index (IVF_FLAT) for embedding field...
Index created successfully.
Waiting for index to build...
Index building complete.
Loading collection into memory...
Collection 'news_embeddings' loaded successfully.

--- Milvus processing finished ---
