In [None]:
!pip install datasets evaluate spacy transformers -q
!python -m spacy download en_core_web_sm -q
!pip install bert-score
!pip install rouge_score

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m72.5 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bert-score
Successfully installed bert-score-0.3.13
Collecting rouge_score
  Downloading rouge_

In [None]:
# Cell 1: 라이브러리 설치 및 기본 세팅

import torch
import torch.nn as nn
import torch.nn.functional as F
import spacy
import pandas as pd
import gc
import numpy as np
import re

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import evaluate
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

# Run on the GPU when one is available; everything below that takes `device`
# (SBERT, BERTScore, policy models) uses this choice.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# -----------------------------------------------
# 1) Sentence-level embedding model (SBERT)
#    → sentence/segment embeddings & sentence/chunk importance scoring
# -----------------------------------------------
sbert = SentenceTransformer("all-MiniLM-L6-v2", device=device)

# -----------------------------------------------
# 2) SpaCy (English)
# -----------------------------------------------
# The model must be downloaded beforehand with:
# !python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")

# -----------------------------------------------
# 3) List of n values (target word budgets per summary; edit as needed)
# -----------------------------------------------
N_LIST = [7, 9, 11, 13]

# -----------------------------------------------
# 4) Dataset splits and number of samples to use
# -----------------------------------------------
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"

NUM_TRAIN = 200   # number of articles used for RL training
NUM_TEST = 200    # number of articles used for final evaluation

# -----------------------------------------------
# 5) Evaluation metrics: BERTScore / ROUGE
# -----------------------------------------------
print("평가 지표 로딩 중...")
metric_bert = evaluate.load("bertscore")
metric_rouge = evaluate.load("rouge")
print("준비 완료!")

# -----------------------------------------------
# 6) Shared BERTScore helper (one fixed configuration for every experiment)
# -----------------------------------------------
def compute_bertscore(pred: str, ref: str) -> float:
    """
    Return the BERTScore F1 of a single prediction against a single reference.

    `lang`, `model_type` and `device` are pinned so that every experiment is
    scored under identical conditions. Any exception raised by the metric is
    printed as a warning and the pair is scored 0.0 instead of propagating.
    """
    try:
        bert_options = {
            "lang": "en",
            "model_type": "bert-base-uncased",
            "device": device,
        }
        result = metric_bert.compute(
            predictions=[pred],
            references=[ref],
            **bert_options,
        )
        f1_value = result["f1"][0]
    except Exception as err:
        print(f"[WARN] BERTScore 실패: {repr(err)}")
        f1_value = 0.0
    return float(f1_value)


Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

평가 지표 로딩 중...


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

준비 완료!


In [None]:
# Cell 2: SpaCy 기반 chunk / POS / density 유틸 함수

from typing import List, Dict, Any
from collections import Counter


def split_sentences(text: str) -> List[str]:
    """
    Split an article into sentences with spaCy's sentence segmenter.

    Non-string or blank input yields an empty list. Each sentence is
    whitespace-stripped, and sentences that are empty after stripping are dropped.
    """
    if not isinstance(text, str) or not text.strip():
        return []

    sentences: List[str] = []
    for span in nlp(text).sents:
        cleaned = span.text.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences


def get_sentence_embedding(text: str):
    """
    Encode text with SBERT, returning an L2-normalised embedding tensor.

    Shape is (d,) for a single string or (N, d) for a list of N strings.
    """
    return sbert.encode(text, normalize_embeddings=True, convert_to_tensor=True)


def get_chunks_spacy(sentence: str):
    """
    Split a sentence into non-overlapping chunks ("supervectors") with spaCy.

    Chunks are collected in three passes; each pass skips tokens already
    claimed by an earlier one:
      1. every noun chunk (noun phrase),
      2. the first token whose dependency label is ROOT (meant to be the
         main verb, though spaCy does not guarantee a verb there),
      3. every remaining content token (NOUN/PROPN/VERB/ADJ/ADV) as a
         one-token span.

    The result is sorted by position in the sentence. Downstream code
    (word-score propagation, contiguous n-word window selection and
    possessive merging) relies on words appearing in their original order.

    Returns:
        (doc, chunks): the parsed spaCy Doc and a list of Spans in sentence order.
    """
    doc = nlp(sentence)
    chunks = []
    used_token_idxs = set()

    # 1) Noun chunks first
    for nc in doc.noun_chunks:
        chunks.append(nc)
        used_token_idxs.update(range(nc.start, nc.end))

    # 2) The ROOT token (only the first one, if the parse yields several)
    roots = [t for t in doc if t.dep_ == "ROOT"]
    if roots:
        root = roots[0]
        if root.i not in used_token_idxs:
            span = doc[root.i:root.i+1]
            chunks.append(span)
            used_token_idxs.add(root.i)

    # 3) Remaining content tokens (NOUN/PROPN/VERB/ADJ/ADV)
    for token in doc:
        if token.i in used_token_idxs:
            continue
        if token.pos_ in {"NOUN", "PROPN", "VERB", "ADJ", "ADV"}:
            span = doc[token.i:token.i+1]
            chunks.append(span)
            used_token_idxs.add(token.i)

    # The passes above append out of order (noun phrases, then ROOT, then the
    # rest). Restore sentence order so the word sequence built from these
    # chunks follows the sentence. Spans never overlap, so their start index
    # alone orders them.
    chunks.sort(key=lambda span: span.start)

    return doc, chunks


def build_pos_dict_from_doc(doc) -> Dict[str, str]:
    """
    Map each lower-cased word in `doc` to its most frequent POS tag.

    Tokens whose text is blank after stripping are ignored. Ties go to the
    tag that was seen first. The mapping is approximate (a word gets one tag
    for the whole doc), which is good enough for filtering function words.
    """
    tags_by_word: Dict[str, List[str]] = {}
    for tok in doc:
        surface = tok.text.strip()
        if not surface:
            continue
        tags_by_word.setdefault(surface.lower(), []).append(tok.pos_)

    return {
        word: Counter(tags).most_common(1)[0][0]
        for word, tags in tags_by_word.items()
    }


# POS tags treated as function words: build_word_scores_from_chunks sets the
# word-level score of any word mapped to one of these tags to exactly 0.0.
FUNCTION_POS = {
    "ADP",    # adpositions (prepositions)
    "CCONJ",  # coordinating conjunctions
    "SCONJ",  # subordinating conjunctions
    "PART",   # particles such as "to", "'s"
    "PUNCT",
    "SPACE",
    "SYM",
    "DET",    # determiners (articles)
    "PRON",   # pronouns
}


def fix_possessives(words: List[str], original_sentence: str) -> List[str]:
    """
    Re-attach possessive suffixes that tokenization split off.

    Merges a word with a following possessive token into one word:
      - ['court', 's']   -> "court's"
      - ['court', "'s"]  -> "court's"  (spaCy's English tokenizer typically
        emits the suffix this way)
      - ['court', '’s']  -> "court’s"  (typographic apostrophe kept as-is)

    A merge only happens when the merged form occurs (case-insensitively) as a
    substring of `original_sentence`. Otherwise both words are kept unchanged.
    Returns a new list; `words` is not modified.
    """
    fixed: List[str] = []
    i = 0
    lower_orig = original_sentence.lower()

    while i < len(words):
        if i + 1 < len(words):
            nxt = words[i + 1]
            nxt_lower = nxt.lower()
            if nxt_lower == "s":
                # Bare "s": the apostrophe was lost, so re-insert an ASCII one.
                candidate = words[i] + "'s"
            elif nxt_lower in ("'s", "’s"):
                # Suffix token already carries its apostrophe; keep its exact form.
                candidate = words[i] + nxt
            else:
                candidate = None

            if candidate is not None and candidate.lower() in lower_orig:
                fixed.append(candidate)
                i += 2
                continue

        fixed.append(words[i])
        i += 1

    return fixed


def simple_tokenize(text: str) -> List[str]:
    """
    Lower-case `text` and return its runs of word characters.

    Non-string input yields an empty list.
    """
    if not isinstance(text, str):
        return []
    return re.findall(r"\w+", text.lower())


def compute_density(pred: str, original: str) -> float:
    """
    Fraction of the prediction's word tokens that also occur in the original.

    Args:
        pred: summary string produced by the model.
        original: full source article.

    Returns:
        (# tokens of `pred` whose word appears anywhere in `original`) /
        (# tokens of `pred`), in [0, 1]. Returns 0.0 when `pred` has no tokens.
        Repeated words in `pred` are counted once per occurrence in both the
        numerator and the denominator, so a fully extractive summary scores 1.0.
    """
    pred_tokens = simple_tokenize(pred)
    if not pred_tokens:
        return 0.0

    orig_vocab = set(simple_tokenize(original))
    hits = sum(1 for tok in pred_tokens if tok in orig_vocab)
    return hits / len(pred_tokens)


In [None]:
# Cell 3: chunk-level 점수를 word-level로 전파 + span 추출 함수 (Baseline / RL 공용)

def build_word_scores_from_chunks(
    sentence: str,
    doc,
    chunks,
    chunk_scores: np.ndarray
):
    """
    Propagate chunk-level scores down to individual words.

    Every non-blank token inside chunk i inherits `chunk_scores[i]`. Words
    are emitted chunk by chunk, in the order of `chunks`. A word's score is
    then forced to 0.0 when either:
      - it contains no alphanumeric character at all (pure punctuation or
        symbols; digits count as alphanumeric, so numbers keep their score), or
      - its lower-cased form maps to a FUNCTION_POS tag in the doc-wide
        majority-POS table built by build_pos_dict_from_doc.

    When there are no chunks, or the chunks contain only blank tokens, the
    whitespace-split words of `sentence` are returned with all-zero scores.

    Returns:
        (words, scores): parallel lists of word strings and float scores.

    Raises:
        ValueError: if len(chunk_scores) != len(chunks).
    """
    def _unscored_fallback():
        plain_words = sentence.split()
        return plain_words, [0.0] * len(plain_words)

    if len(chunks) == 0:
        return _unscored_fallback()

    if len(chunk_scores) != len(chunks):
        raise ValueError("chunk_scores 길이와 chunks 개수가 다릅니다.")

    # (word text, inherited chunk score) for every non-blank token, in chunk order
    scored_words = [
        (token.text, float(chunk_scores[chunk_id]))
        for chunk_id, chunk in enumerate(chunks)
        for token in chunk
        if token.text.strip()
    ]
    if not scored_words:
        return _unscored_fallback()

    pos_lookup = build_pos_dict_from_doc(doc)

    words = [text for text, _ in scored_words]
    scores = []
    for text, inherited in scored_words:
        symbols_only = not any(c.isalnum() for c in text)
        function_word = pos_lookup.get(text.lower()) in FUNCTION_POS
        scores.append(0.0 if (symbols_only or function_word) else inherited)

    return words, scores


def select_best_span(words: List[str], scores: List[float], n: int, sentence: str) -> str:
    """
    Pick the length-n window of consecutive words with the largest score sum.

    If there are at most n words, all of them are kept. Among equally scoring
    windows the earliest one wins. Possessives are re-attached via
    fix_possessives before the words are joined with single spaces.
    Returns "" for an empty word list.
    """
    word_count = len(words)
    if word_count == 0:
        return ""

    chosen = words
    if word_count > n:
        # max() keeps the first maximal start, matching a strict '>' scan.
        best_start = max(
            range(word_count - n + 1),
            key=lambda start: sum(scores[start:start + n]),
        )
        chosen = words[best_start:best_start + n]

    return " ".join(fix_possessives(chosen, sentence))


def _get_keywords_chunk_span_n_core(
    sentence: str,
    n: int,
    policy_model=None,
    use_rl: bool = False
) -> str:
    """
    Shared core: choose a run of n consecutive words from `sentence`.

    Steps:
      1. split the sentence into spaCy chunks (get_chunks_spacy);
      2. embed the sentence and every chunk with SBERT (normalised);
      3. score each chunk by cosine(sentence, chunk);
      4. if `use_rl` and a `policy_model` are given, add the policy's
         per-chunk adjustment (computed in eval mode under no_grad);
      5. propagate chunk scores to words and select the best n-word window.

    With no chunks at all, the first n whitespace-separated words are returned.
    `policy_model` is expected to map (num_chunks, d) -> (num_chunks,) or
    (num_chunks, 1).
    """
    doc, chunks = get_chunks_spacy(sentence)
    if not chunks:
        return " ".join(sentence.split()[:n])

    sentence_row = get_sentence_embedding(sentence).unsqueeze(0)  # (1, d)
    chunk_matrix = sbert.encode(
        [chunk.text for chunk in chunks],
        convert_to_tensor=True,
        normalize_embeddings=True
    )  # (num_chunks, d)

    # Embeddings are unit-norm, so the dot product is the cosine similarity.
    scores_t = torch.matmul(chunk_matrix, sentence_row.T).squeeze(1)  # (num_chunks,)

    if use_rl and (policy_model is not None):
        policy_model.eval()
        with torch.no_grad():
            adjustment = policy_model(chunk_matrix)
            if adjustment.dim() > 1:
                adjustment = adjustment.squeeze(-1)
        scores_t = scores_t + adjustment

    words, word_scores = build_word_scores_from_chunks(
        sentence,
        doc,
        chunks,
        scores_t.detach().cpu().numpy(),
    )
    return select_best_span(words, word_scores, n, sentence)


def get_keywords_chunk_span_n_baseline(sentence: str, n: int) -> str:
    """
    [Baseline] n-word span chosen from raw SBERT cosine chunk scores.

    No policy model is involved.
    """
    baseline_options = {"policy_model": None, "use_rl": False}
    return _get_keywords_chunk_span_n_core(sentence=sentence, n=n, **baseline_options)


def get_keywords_chunk_span_n_rl(
    sentence: str,
    n: int,
    policy_model
) -> str:
    """
    [RL] n-word span whose chunk scores are adjusted by `policy_model`.

    `policy_model` maps chunk embeddings (num_chunks, d) to a per-chunk
    score adjustment that is added to the SBERT cosine score.
    """
    rl_options = {"policy_model": policy_model, "use_rl": True}
    return _get_keywords_chunk_span_n_core(sentence=sentence, n=n, **rl_options)


In [None]:
# Cell 4: TextRank 기반 문장 중요도 + sentence-level policy + generic span 추출기

from typing import List, Dict, Any, Optional
import torch.nn as nn


def build_similarity_matrix(
    sentences: List[str],
    sim_threshold: float = 0.1,
    return_embs: bool = False,
):
    """
    Build a sentence-by-sentence cosine-similarity matrix from SBERT embeddings.

    Self-similarity (the diagonal) is set to 0, and any entry below
    `sim_threshold` is clamped to 0 to drop weak, noisy edges.

    Args:
        sentences: list of sentences.
        sim_threshold: similarities below this value become 0.
        return_embs: also return the normalised embedding tensor.

    Returns:
        The (N, N) numpy matrix, or a (matrix, embeddings) tuple when
        `return_embs` is True (embeddings is None for empty input).
    """
    if len(sentences) == 0:
        empty = np.zeros((0, 0), dtype=float)
        return (empty, None) if return_embs else empty

    embs = sbert.encode(
        sentences,
        convert_to_tensor=True,
        normalize_embeddings=True
    )  # (N, d)

    with torch.no_grad():
        sim_t = torch.matmul(embs, embs.T)  # unit-norm rows -> cosine similarity
    sim = sim_t.cpu().numpy()

    np.fill_diagonal(sim, 0.0)
    sim[sim < sim_threshold] = 0.0

    return (sim, embs) if return_embs else sim


def textrank_scores(sim_matrix: np.ndarray,
                    damping: float = 0.85,
                    max_iter: int = 50,
                    tol: float = 1e-4) -> np.ndarray:
    """
    Score sentences with TextRank (PageRank power iteration on a similarity graph).

    Args:
        sim_matrix: (N, N) non-negative similarity matrix.
        damping: weight on following graph edges vs. uniform teleportation.
        max_iter: maximum number of power-iteration steps.
        tol: stop once the L1 change between successive iterates is below this.

    Returns:
        (N,) scores, clipped at 0 and rescaled to sum to 1 when the sum is
        positive; an empty array when N == 0.
    """
    n_nodes = sim_matrix.shape[0]
    if n_nodes == 0:
        return np.array([])

    out_weight = sim_matrix.sum(axis=1, keepdims=True)
    # Rows without edges would divide by zero; dividing by 1 leaves them all-zero.
    out_weight[out_weight == 0] = 1.0
    transition_t = (sim_matrix / out_weight).T

    teleport = (1 - damping) / n_nodes
    rank = np.ones(n_nodes, dtype=float) / n_nodes

    for _ in range(max_iter):
        updated = teleport + damping * transition_t.dot(rank)
        converged = np.abs(updated - rank).sum() < tol
        rank = updated
        if converged:
            break

    rank = np.maximum(rank, 0.0)
    total = rank.sum()
    return rank / total if total > 0 else rank


def _allocate_word_budget(
    sel_scores: np.ndarray,
    target_n: int,
    min_words_per_sent: int,
) -> np.ndarray:
    """Split `target_n` words across selected sentences in proportion to their non-negative scores."""
    # Policy logits can be negative. Normalising raw negatives can hand one
    # sentence more than target_n words, so negatives are clipped to 0 weight.
    weights = np.clip(np.asarray(sel_scores), 0.0, None)
    if weights.size == 0:
        return np.zeros(0, dtype=int)

    total = weights.sum()
    if total == 0:
        weights = np.ones_like(weights) / len(weights)
    else:
        weights = weights / total

    base_ns = np.floor(weights * target_n).astype(int)
    remainder = int(target_n - base_ns.sum())

    # Hand out the words lost to flooring, highest-weight sentence first.
    order = np.argsort(-weights)
    for i in range(max(remainder, 0)):
        base_ns[order[i % len(order)]] += 1

    # Allocations smaller than min_words_per_sent are dropped (not redistributed).
    base_ns[base_ns < min_words_per_sent] = 0
    return base_ns


def extract_spans_textrank_generic(
    text: str,
    target_n: int,
    span_fn,
    top_k_sents: int = 3,
    min_words_per_sent: int = 1,
    span_kwargs: Dict[str, Any] = None,
    sent_policy_model: Optional[nn.Module] = None,
):
    """
    Extract about `target_n` words from `text`: pick top sentences, then an n-word span in each.

    Sentence selection:
      - default: the `top_k_sents` sentences with the highest TextRank score.
      - with `sent_policy_model`: features = [SBERT sentence embedding ⊕
        TextRank score]. The model predicts a delta, and logits are
        TextRank + delta. The top `top_k_sents` logits are chosen greedily,
        and the summed log-softmax probability of the chosen sentences is
        stored in meta["log_prob_sentence"] for REINFORCE.
        NOTE(review): selection is argmax, not sampled, so this log-prob
        belongs to a deterministic choice.

    The word budget is split across the chosen sentences in proportion to
    their (non-negative) scores. `span_fn(sentence, n, **span_kwargs) -> str`
    then extracts each sentence's span. Spans are joined with "; ".

    Side effect: when `sent_policy_model` is given it is switched to train()
    mode and its forward pass is tracked by autograd.

    Returns:
        (pred_text, spans, meta). For text with no sentences:
        ("", [], {"sentences": [], "scores": []}).
    """
    if span_kwargs is None:
        span_kwargs = {}

    sentences = split_sentences(text)
    if not sentences:
        return "", [], {"sentences": [], "scores": []}

    # 1) Similarity matrix + TextRank scores + sentence embeddings
    sim_mat, sent_embs = build_similarity_matrix(
        sentences,
        sim_threshold=0.1,
        return_embs=True,
    )
    tr_scores = textrank_scores(sim_mat)  # (N,)
    N = len(sentences)
    # Clamp at 0: a non-positive request selects no sentences instead of
    # slicing from the end or dividing by zero later.
    top_k = max(0, min(top_k_sents, N))

    # 2) Sentence selection
    log_prob_sentence = None  # only set when a sentence policy is used

    if sent_policy_model is None or sent_embs is None:
        # Pure TextRank
        final_scores = tr_scores
        selected_idx = np.argsort(-final_scores)[:top_k]  # descending
    else:
        # TextRank + sentence policy (hybrid)
        tr_tensor = torch.tensor(tr_scores, dtype=torch.float32, device=device).unsqueeze(1)  # (N, 1)
        feats = torch.cat([sent_embs.to(device), tr_tensor], dim=1)  # (N, emb_dim + 1)

        sent_policy_model.train()  # gradients are needed for RL updates
        delta = sent_policy_model(feats).squeeze(-1)  # (N,)
        logits = tr_tensor.squeeze(1) + delta         # (N,)

        # Greedy top-k (same rule as evaluation)
        selected_idx_t = torch.argsort(logits, descending=True)[:top_k]
        selected_idx = selected_idx_t.detach().cpu().numpy()

        # log_softmax is the numerically stable form of softmax(...).log(): a
        # probability underflowing to 0 would otherwise give -inf / NaN grads.
        log_probs = torch.log_softmax(logits, dim=0)
        log_prob_sentence = log_probs[selected_idx_t].sum()  # scalar tensor

        final_scores = logits.detach().cpu().numpy()

    # 3) Distribute target_n across the selected sentences
    base_ns = _allocate_word_budget(final_scores[selected_idx], target_n, min_words_per_sent)

    # 4) Extract an n_i-word span from each selected sentence
    spans = []
    used_word_count = 0

    for idx_sel, n_words in zip(selected_idx, base_ns):
        if n_words <= 0:
            continue
        sent = sentences[idx_sel]

        span = span_fn(sent, int(n_words), **span_kwargs)
        span = (span or "").strip()
        if not span:
            continue

        spans.append(span)
        used_word_count += len(span.split())

    pred_text = "; ".join(spans)

    meta = {
        "sentences": sentences,
        "textrank_scores": tr_scores,
        "final_scores": final_scores,
        "selected_idx": selected_idx,
        "allocated_words": base_ns,
        "used_word_count": used_word_count,
    }

    if log_prob_sentence is not None:
        meta["log_prob_sentence"] = log_prob_sentence

    return pred_text, spans, meta


In [None]:
# Cell 5: RL 정책 모델 정의 + TextRank span 추출 래퍼

import torch.nn as nn


class ChunkPolicyModel(nn.Module):
    """
    Chunk-level policy network.

    Maps SBERT chunk embeddings of shape (num_chunks, d) to one scalar score
    adjustment per chunk, shape (num_chunks,), through
    Linear(d, hidden) -> tanh -> Linear(hidden, 1).
    Intended to be trained with an RL algorithm (e.g. PPO).
    """

    def __init__(self, emb_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Attribute names fc1/fc2 define the state_dict keys; keep them stable
        # so previously saved weights still load.
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, chunk_embs: torch.Tensor) -> torch.Tensor:
        """Return per-chunk adjustments (num_chunks,) for embeddings (num_chunks, d)."""
        hidden = torch.tanh(self.fc1(chunk_embs))
        return self.fc2(hidden).squeeze(-1)


class SentencePolicyModel(nn.Module):
    """
    Sentence-level policy network for the TextRank hybrid.

    Input: per-sentence features [SBERT sentence embedding ⊕ TextRank score],
    shape (num_sents, emb_dim + 1).
    Output: a delta per sentence, shape (num_sents,), to be added to the
    TextRank score. Architecture: Linear -> tanh -> Linear(hidden, 1).
    """

    def __init__(self, emb_plus_tr_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Attribute names fc1/fc2 define the state_dict keys; keep them stable
        # so previously saved weights still load.
        self.fc1 = nn.Linear(emb_plus_tr_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Return TextRank deltas (num_sents,) for features (num_sents, emb_dim + 1)."""
        activated = torch.tanh(self.fc1(feats))
        return self.fc2(activated).squeeze(-1)


# Build the policy models sized to the SBERT embedding dimension.
emb_dim = sbert.get_sentence_embedding_dimension()

# 1) Chunk policy (chunk-level RL)
# NOTE(review): nothing here trains or loads this model (the load below is
# commented out), so its adjustments come from random initialisation unless
# weights are loaded elsewhere — confirm this is intended for the RL runs.
chunk_policy_model = ChunkPolicyModel(emb_dim=emb_dim).to(device)

# 2) Sentence policy (sentence-level RL, TextRank hybrid)
#    input dim = SBERT sentence-embedding dim + 1 TextRank score
sentence_policy_model = SentencePolicyModel(emb_plus_tr_dim=emb_dim + 1).to(device)

# (optional) alias kept for compatibility with older code
policy_model = chunk_policy_model

# TODO:
# After real RL training, load the weights like this:
# chunk_policy_model.load_state_dict(torch.load("chunk_policy.pt", map_location=device))
# sentence_policy_model.load_state_dict(torch.load("sent_policy.pt", map_location=device))


def extract_spans_textrank_baseline(
    text: str,
    target_n: int,
    top_k_sents: int = 3,
    min_words_per_sent: int = 1,
):
    """
    Baseline extractor: pure TextRank sentence choice + SBERT-cosine chunk spans.

    Neither the chunk-level nor the sentence-level policy is used.
    Returns the (pred_text, spans, meta) triple of extract_spans_textrank_generic.
    """
    generic_options = dict(
        top_k_sents=top_k_sents,
        min_words_per_sent=min_words_per_sent,
        span_kwargs=None,
        sent_policy_model=None,  # sentence-level policy disabled
    )
    return extract_spans_textrank_generic(
        text,
        target_n,
        get_keywords_chunk_span_n_baseline,
        **generic_options,
    )


def extract_spans_textrank_rl(
    text: str,
    target_n: int,
    policy_model: nn.Module,
    top_k_sents: int = 3,
    min_words_per_sent: int = 1,
):
    """
    Hybrid RL extractor.

    - Sentence choice: TextRank adjusted by the module-level `sentence_policy_model`.
    - Inside each sentence: get_keywords_chunk_span_n_rl with `policy_model`
      adjusting chunk scores.

    Args:
        policy_model: the chunk-level policy (e.g. chunk_policy_model); the
            parameter name is kept for compatibility with older code.

    Returns the (pred_text, spans, meta) triple of extract_spans_textrank_generic.
    """
    generic_options = dict(
        top_k_sents=top_k_sents,
        min_words_per_sent=min_words_per_sent,
        span_kwargs={"policy_model": policy_model},   # chunk policy for span scoring
        sent_policy_model=sentence_policy_model,      # global sentence policy
    )
    return extract_spans_textrank_generic(
        text,
        target_n,
        get_keywords_chunk_span_n_rl,
        **generic_options,
    )


In [None]:
# Cell 6: cnn_dailymail에서 Train/Test 샘플 로드 + 불량 데이터 필터링 (비스트리밍 버전)

def load_cnn_data_fast(
    split: str = "test",
    target_count: int = 300,      # number of samples finally kept
    buffer_size: int = 600,       # load a few more than needed, then filter
    min_article_len: int = 200,   # minimum article length (characters)
    min_summary_len: int = 20     # minimum highlights length (characters)
) -> pd.DataFrame:
    """
    Load the first `buffer_size` cnn_dailymail rows of `split` and filter them.

    A row is kept when article and highlights are both non-empty strings whose
    stripped lengths reach `min_article_len` / `min_summary_len`. Collection
    stops once `target_count` rows are kept. Loading uses streaming=False with
    cache reuse to reduce HTTP 429 errors.

    Returns:
        DataFrame with columns
          - article   : original article text
          - reference : highlights (summary) text
        An empty DataFrame with those columns is returned if loading fails or
        nothing passes the filter.
    """
    def _empty_frame() -> pd.DataFrame:
        return pd.DataFrame(columns=["article", "reference"])

    def _passes_filter(article, summary) -> bool:
        if not article or not summary:
            return False
        if not (isinstance(article, str) and isinstance(summary, str)):
            return False
        return (
            len(article.strip()) >= min_article_len
            and len(summary.strip()) >= min_summary_len
        )

    print(f"\n--- cnn_dailymail {split} split 로드 시작 ---")
    print(f"목표: {target_count}개 / 버퍼: {buffer_size}개")

    try:
        ds = load_dataset(
            "cnn_dailymail",
            "3.0.0",
            split=f"{split}[:{buffer_size}]",
            download_mode="reuse_cache_if_exists",
        )
    except Exception as e:
        print("[Fatal] cnn_dailymail 로드 중 오류 발생:")
        print(repr(e))
        return _empty_frame()

    print("원시 샘플 수:", len(ds))

    kept_rows = []
    for record in ds:
        article = record.get("article", None)
        summary = record.get("highlights", None)
        if not _passes_filter(article, summary):
            continue

        kept_rows.append({"article": article, "reference": summary})
        if len(kept_rows) >= target_count:
            break

    print(f"유효 샘플 수: {len(kept_rows)}개")

    if not kept_rows:
        print("유효 데이터가 없습니다. 필터 기준을 완화해보세요.")
        return _empty_frame()

    return pd.DataFrame(kept_rows)


# === Load training data ===
# buffer_size is twice the target so that rows removed by the filters can be replaced.
df_train = load_cnn_data_fast(
    split=TRAIN_SPLIT,          # e.g. "train"
    target_count=NUM_TRAIN,     # e.g. 200
    buffer_size=NUM_TRAIN * 2,
    min_article_len=200,
    min_summary_len=20,
)

# display() is not defined in this file; it is provided by the Jupyter/IPython environment.
print("\n[Train] 앞부분 샘플:")
display(df_train.head())
print("[Train] DataFrame 크기:", df_train.shape)

# === Load test data ===
df_test = load_cnn_data_fast(
    split=TEST_SPLIT,           # e.g. "test"
    target_count=NUM_TEST,      # e.g. 200
    buffer_size=NUM_TEST * 2,
    min_article_len=200,
    min_summary_len=20,
)

print("\n[Test] 앞부분 샘플:")
display(df_test.head())
print("[Test] DataFrame 크기:", df_test.shape)



--- cnn_dailymail train split 로드 시작 ---
목표: 200개 / 버퍼: 400개


README.md: 0.00B [00:00, ?B/s]

3.0.0/train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

3.0.0/validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

3.0.0/test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

원시 샘플 수: 400
유효 샘플 수: 200개

[Train] 앞부분 샘플:


Unnamed: 0,article,reference
0,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gets £20M f...
1,Editor's note: In our Behind the Scenes series...,Mentally ill inmates in Miami are housed on th...
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who we...","NEW: ""I thought I was going to die,"" driver sa..."
3,WASHINGTON (CNN) -- Doctors removed five small...,"Five small polyps found during procedure; ""non..."
4,(CNN) -- The National Football League has ind...,"NEW: NFL chief, Atlanta Falcons owner critical..."


[Train] DataFrame 크기: (200, 2)

--- cnn_dailymail test split 로드 시작 ---
목표: 200개 / 버퍼: 400개


Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

원시 샘플 수: 400
유효 샘플 수: 200개

[Test] 앞부분 샘플:


Unnamed: 0,article,reference
0,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...
1,(CNN)Never mind cats having nine lives. A stra...,"Theia, a bully breed mix, was apparently hit b..."
2,"(CNN)If you've been following the news lately,...",Mohammad Javad Zarif has spent more time with ...
3,(CNN)Five Americans who were monitored for thr...,17 Americans were exposed to the Ebola virus w...
4,(CNN)A Duke student has admitted to hanging a ...,Student is no longer on Duke University campus...


[Test] DataFrame 크기: (200, 2)


In [None]:
# Cell 6.5: RL 학습 (df_train 사용) — sentence-only REINFORCE

from torch.optim import AdamW

# BERTScore-based reward. Delegates to compute_bertscore so the RL reward and the
# evaluation metric can never drift apart in configuration.
def compute_bertscore_reward(pred: str, reference: str) -> float:
    """
    Return the BERTScore F1 of `pred` against `reference`, used as the RL reward.

    Non-string or blank inputs score 0.0 without calling the metric.
    Otherwise this delegates to compute_bertscore, which holds the single
    lang / model_type / device configuration. A metric failure is logged
    there and scored 0.0 instead of aborting a training or evaluation loop.
    """
    if not isinstance(pred, str) or not pred.strip():
        return 0.0
    if not isinstance(reference, str) or not reference.strip():
        return 0.0
    return compute_bertscore(pred, reference)


# Train only the sentence-level policy (the chunk policy is not given to the optimizer).
optimizer = AdamW(sentence_policy_model.parameters(), lr=1e-4)

NUM_EPOCHS_RL = 3          # example value; adjust as needed
baseline_reward = None     # EMA reward baseline; seeded with the first observed reward
baseline_momentum = 0.9    # EMA coefficient (0.9 ~ 0.99 recommended)

for epoch in range(NUM_EPOCHS_RL):
    print(f"\n[RL Train] Epoch {epoch+1}/{NUM_EPOCHS_RL}")

    for _, row in tqdm(df_train.iterrows(), total=len(df_train)):
        article = row["article"]
        reference = row["reference"]

        if not isinstance(article, str) or not isinstance(reference, str):
            continue

        # 1) Summarise with the current sentence policy. The chunk policy runs
        #    under no_grad and is not optimised, so it acts as a fixed head here.
        pred_rl, spans_rl, meta_rl = extract_spans_textrank_rl(
            text=article,
            target_n=9,   # train for a single n only (loop over N_LIST if desired)
            policy_model=chunk_policy_model,
            top_k_sents=3,
            min_words_per_sent=1,
        )

        if not pred_rl or not pred_rl.strip():
            continue

        log_prob_sentence = meta_rl.get("log_prob_sentence", None)
        if log_prob_sentence is None:
            # No sentence policy / no log-prob available -> nothing to train on
            continue

        # 2) Reward (BERTScore F1)
        reward = compute_bertscore_reward(pred_rl, reference)

        # 3) Advantage against the baseline of *previous* rewards, then update it.
        #    Seeding with the first reward avoids a 0-initialised baseline, which
        #    would give every early sample a large positive advantage regardless
        #    of quality (BERTScore F1 is far from 0 for almost any text).
        if baseline_reward is None:
            baseline_reward = reward
        advantage = reward - baseline_reward  # plain float: treated as a constant
        baseline_reward = baseline_momentum * baseline_reward + (1 - baseline_momentum) * reward

        # 4) REINFORCE: loss = -log_prob_sentence * advantage
        #    (gradient flows only through log_prob_sentence)
        optimizer.zero_grad()
        loss = - log_prob_sentence * advantage
        loss.backward()
        optimizer.step()

    # Optionally save the sentence policy after each epoch
    # torch.save(sentence_policy_model.state_dict(), f"sent_policy_epoch_{epoch+1}.pt")

print("RL 학습 완료 (sentence-only REINFORCE).")



[RL Train] Epoch 1/3


  0%|          | 0/200 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]


[RL Train] Epoch 2/3


  0%|          | 0/200 [00:00<?, ?it/s]


[RL Train] Epoch 3/3


  0%|          | 0/200 [00:00<?, ?it/s]

RL 학습 완료 (sentence-only REINFORCE).


In [None]:
# Cell 7: Baseline vs RL side-by-side experiment loop
#
# For each test article and each target length n in N_LIST, run the TextRank
# baseline extractor and the RL extractor (driven by chunk_policy_model), score
# both predictions against the reference summary, and collect one result row
# per (article, n, model_type). Predictions that are empty after stripping are
# dropped, so the two result lists may differ in length.

gc.collect()

results_baseline = []
results_rl = []


def _score_prediction(pred, meta, *, idx, n, article, reference, ref_len, model_type):
    """Score one extracted prediction and build its result row.

    Args:
        pred: Raw extractor output (may be None or padded with whitespace).
        meta: Extractor metadata; its "used_word_count" entry is copied into
            the row when present, otherwise the prediction length is used.
        idx: Label of the source row in df_test.
        n: Target word budget passed to the extractor.
        article: Source article text.
        reference: Reference summary text.
        ref_len: Whitespace word count of ``reference``.
        model_type: Tag stored in the row ("baseline" or "rl").

    Returns:
        The result-row dict, or None when ``pred`` is empty after stripping.
    """
    pred = (pred or "").strip()
    if not pred:
        return None

    pred_len = len(pred.split())

    # BERTScore via the shared helper function.
    bert_f1 = compute_bertscore_reward(pred, reference)

    # ROUGE-1 / ROUGE-2. A failure is logged and scored as 0 so one bad sample
    # does not abort the whole sweep.
    try:
        rouge_scores = metric_rouge.compute(
            predictions=[pred],
            references=[reference],
        )
        rouge1 = rouge_scores["rouge1"]
        rouge2 = rouge_scores["rouge2"]
    except Exception as e:
        print(f"[WARN] ROUGE({model_type}) 실패 idx={idx}, n={n}: {repr(e)}")
        rouge1, rouge2 = 0.0, 0.0

    return {
        "idx": idx,
        "target_n": n,
        "article": article,
        "reference": reference,
        "pred": pred,
        "pred_len": pred_len,
        "ref_len": ref_len,
        "bert_f1": bert_f1,
        "rouge1": rouge1,
        "rouge2": rouge2,
        # BERTScore per predicted word; max() guards against division by zero.
        "efficiency": bert_f1 / max(1, pred_len),
        "density": compute_density(pred, article),
        "used_word_count": (meta or {}).get("used_word_count", pred_len),
        "model_type": model_type,
    }


if not df_test.empty:
    print("병렬 실험 시작 (Baseline vs RL)...")

    # enumerate() gives a positional counter for the cleanup schedule; the
    # DataFrame index label (idx) may be non-integer or non-contiguous.
    for row_pos, (idx, row) in enumerate(tqdm(df_test.iterrows(), total=len(df_test))):
        article = row["article"]
        reference = row["reference"]

        if isinstance(article, str) and isinstance(reference, str):
            common = {
                "idx": idx,
                "article": article,
                "reference": reference,
                "ref_len": len(reference.split()),
            }

            for n in N_LIST:
                if n <= 0:
                    continue

                # ===== 1) Baseline version =====
                pred_b, spans_b, meta_b = extract_spans_textrank_baseline(
                    text=article,
                    target_n=n,
                    top_k_sents=3,
                    min_words_per_sent=1,
                )
                row_b = _score_prediction(pred_b, meta_b, n=n, model_type="baseline", **common)
                if row_b is not None:
                    results_baseline.append(row_b)

                # ===== 2) RL version =====
                pred_r, spans_r, meta_r = extract_spans_textrank_rl(
                    text=article,
                    target_n=n,
                    policy_model=chunk_policy_model,  # use the chunk policy explicitly
                    top_k_sents=3,
                    min_words_per_sent=1,
                )
                row_r = _score_prediction(pred_r, meta_r, n=n, model_type="rl", **common)
                if row_r is not None:
                    results_rl.append(row_r)

        # Periodic memory cleanup, scheduled by position (runs for skipped rows too).
        if row_pos % 50 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

df_res_baseline = pd.DataFrame(results_baseline)
df_res_rl = pd.DataFrame(results_rl)

print("Baseline 결과 크기:", df_res_baseline.shape)
print("RL 결과 크기      :", df_res_rl.shape)

display(df_res_baseline.head())
display(df_res_rl.head())


병렬 실험 시작 (Baseline vs RL)...


  0%|          | 0/200 [00:00<?, ?it/s]

Baseline 결과 크기: (800, 14)
RL 결과 크기      : (800, 14)


Unnamed: 0,idx,target_n,article,reference,pred,pred_len,ref_len,bert_f1,rouge1,rouge2,efficiency,density,used_word_count,model_type
0,0,7,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,Palestine 's decision; oppose counterproductiv...,7,36,0.394865,0.0,0.0,0.056409,0.857143,7,baseline
1,0,9,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,justice Palestine 's decision; continue oppose...,9,36,0.397334,0.0,0.0,0.044148,0.888889,9,baseline
2,0,11,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,justice Palestine 's decision; said continue o...,11,36,0.399341,0.0,0.0,0.036304,0.909091,11,baseline
3,0,13,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,international justice Palestine 's decision; s...,13,36,0.397228,0.0,0.0,0.030556,1.0,13,baseline
4,1,7,(CNN)Never mind cats having nine lives. A stra...,"Theia, a bully breed mix, was apparently hit b...",her apparent death; medical attention; true mi...,7,44,0.375221,0.08,0.041667,0.053603,1.0,7,baseline


Unnamed: 0,idx,target_n,article,reference,pred,pred_len,ref_len,bert_f1,rouge1,rouge2,efficiency,density,used_word_count,model_type
0,0,7,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,International Criminal Court; Hamas Gaza; Pale...,7,36,0.408985,0.0,0.0,0.058426,1.0,7,rl
1,0,9,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,International Criminal Court genocide; Hamas G...,9,36,0.412161,0.0,0.0,0.045796,1.0,9,rl
2,0,11,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,International Criminal Court genocide; Israel ...,11,36,0.444405,0.044444,0.0,0.0404,1.0,11,rl
3,0,13,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,International Criminal Court genocide crimes; ...,13,36,0.465259,0.085106,0.0,0.035789,1.0,13,rl
4,1,7,(CNN)Never mind cats having nine lives. A stra...,"Theia, a bully breed mix, was apparently hit b...","stray pooch Washington; only animal; 10,000 ta...",7,44,0.352276,0.0,0.0,0.050325,1.0,7,rl


In [None]:
# Cell 8: merged CSV export + mean statistics per target_n / model_type (TEST SET)

# ----------------------------------------------------------
# Step 1: stack Baseline and RL rows into one frame with a fresh
#         0..N-1 index (test-split results only).
# ----------------------------------------------------------
df_res_all = pd.concat([df_res_baseline, df_res_rl], ignore_index=True)

print("=== [TEST SET] 전체 결과 통합 DataFrame ===")
print("크기:", df_res_all.shape)
display(df_res_all.head())

# ----------------------------------------------------------
# Step 2: persist the raw per-sample comparison for the whole split.
# ----------------------------------------------------------
csv_compare = "cnn_textrank_chunk_rl_compare_{}.csv".format(TEST_SPLIT)
df_res_all.to_csv(csv_compare, index=False)
print(f"\n[저장 완료] Baseline vs RL 비교 원본 CSV → '{csv_compare}'")


# ==========================================================
# 3) Mean performance per (model_type, target_n)
# ==========================================================
def summarize_by_n_and_model(df_res: pd.DataFrame, split=None):
    """Average every numeric metric per (model_type, target_n) and save it as CSV.

    Args:
        df_res: Per-sample results. Must contain ``model_type`` and
            ``target_n`` columns. Every other numeric column except ``idx``
            is averaged.
        split: Split name used in the printed header and in the output file
            name ``cnn_textrank_chunk_rl_summary_{split}.csv``. When omitted,
            the notebook-level ``TEST_SPLIT`` is read at call time. A
            ``split=TEST_SPLIT`` default would be frozen when the def runs.

    Returns:
        A DataFrame with one row per (model_type, target_n) group holding the
        mean of each metric column, or None if ``df_res`` is empty.
    """
    if df_res.empty:
        print("[오류] df_res가 비어 있습니다.")
        return None

    if split is None:
        split = TEST_SPLIT

    # The row id and the group key are not metrics; keep them out of the mean.
    exclude_cols = {"idx", "target_n"}
    metric_cols = [
        c for c in df_res.select_dtypes(include="number").columns
        if c not in exclude_cols
    ]

    df_avg = (
        df_res
        .groupby(["model_type", "target_n"])[metric_cols]
        .mean()
        .reset_index()
    )

    # The header reflects the split actually summarized ("test" -> "TEST SET").
    print(f"\n=== [{str(split).upper()} SET] Baseline vs RL 평균 성능 ===")
    try:
        show = display  # builtin inside IPython / Jupyter
    except NameError:
        show = print  # plain-Python fallback
    show(df_avg)

    csv_name = f"cnn_textrank_chunk_rl_summary_{split}.csv"
    df_avg.to_csv(csv_name, index=False)

    print(f"\n[저장 완료] 요약 통계 CSV → '{csv_name}'")

    return df_avg


# Summarize the merged test-split results (prints, displays and saves the CSV).
df_avg_all = summarize_by_n_and_model(df_res_all, TEST_SPLIT)


=== [TEST SET] 전체 결과 통합 DataFrame ===
크기: (1600, 14)


Unnamed: 0,idx,target_n,article,reference,pred,pred_len,ref_len,bert_f1,rouge1,rouge2,efficiency,density,used_word_count,model_type
0,0,7,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,Palestine 's decision; oppose counterproductiv...,7,36,0.394865,0.0,0.0,0.056409,0.857143,7,baseline
1,0,9,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,justice Palestine 's decision; continue oppose...,9,36,0.397334,0.0,0.0,0.044148,0.888889,9,baseline
2,0,11,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,justice Palestine 's decision; said continue o...,11,36,0.399341,0.0,0.0,0.036304,0.909091,11,baseline
3,0,13,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,international justice Palestine 's decision; s...,13,36,0.397228,0.0,0.0,0.030556,1.0,13,baseline
4,1,7,(CNN)Never mind cats having nine lives. A stra...,"Theia, a bully breed mix, was apparently hit b...",her apparent death; medical attention; true mi...,7,44,0.375221,0.08,0.041667,0.053603,1.0,7,baseline



[저장 완료] Baseline vs RL 비교 원본 CSV → 'cnn_textrank_chunk_rl_compare_test.csv'

=== [TEST SET] Baseline vs RL 평균 성능 ===


Unnamed: 0,model_type,target_n,pred_len,ref_len,bert_f1,rouge1,rouge2,efficiency,density,used_word_count
0,baseline,7,6.975,34.735,0.431118,0.113954,0.030014,0.061895,0.901722,6.975
1,baseline,9,8.935,34.735,0.440931,0.128633,0.032468,0.049506,0.907192,8.935
2,baseline,11,10.915,34.735,0.450382,0.146661,0.03675,0.0414,0.899828,10.915
3,baseline,13,12.88,34.735,0.456271,0.159754,0.040736,0.035548,0.89787,12.88
4,rl,7,7.0,34.735,0.419212,0.09256,0.023982,0.059887,0.950298,7.0
5,rl,9,8.995,34.735,0.426878,0.10233,0.026963,0.04746,0.937336,8.995
6,rl,11,10.985,34.735,0.435951,0.119509,0.030269,0.039695,0.937863,10.985
7,rl,13,12.96,34.735,0.44037,0.134414,0.031742,0.033988,0.93067,12.96



[저장 완료] 요약 통계 CSV → 'cnn_textrank_chunk_rl_summary_test.csv'
