In [2]:
import warnings
warnings.filterwarnings("ignore")

import re
import json
import torch
import spacy
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset 
from transformers import BertTokenizer, BartForConditionalGeneration, BartTokenizer

In [3]:
# Display the installed PyTorch version as this cell's output.
torch.__version__

'2.5.0+cu124'

In [4]:
# Pick the compute device: CUDA when it is available, otherwise the CPU.
# (Selecting Apple "mps" was left disabled by the original author.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Show the GPU model when CUDA is available. The notebook explicitly allows a
# CPU fallback when choosing `device`, so guard the CUDA query instead of
# letting this cell raise on a machine without a usable GPU.
torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu (no CUDA device)"

'NVIDIA GeForce RTX 4090'

In [6]:
def split_text_with_spacy(text, max_tokens=512):
    """
    Split `text` into chunks of whole sentences using spaCy sentence boundaries.

    Sentences (from the module-level `nlp` pipeline) are appended to the
    current chunk, separated by a single space, for as long as the chunk's
    token count (measured with the module-level `tokenizer`) stays at or
    below `max_tokens`. When adding a sentence would exceed the budget, the
    current chunk is closed and the sentence starts a new one.

    Parameters
    ----------
    text (str) --------- the text to split.
    max_tokens (int) --- token budget per chunk, default 512.

    Returns
    -------
    list[str]: the non-empty, whitespace-stripped chunks in document order.
    A single sentence that is longer than `max_tokens` on its own is kept
    whole as its own chunk (sentences are never cut), so such a chunk can
    exceed the budget.
    """
    doc = nlp(text)
    chunks = []
    current_chunk = ""

    for sent in doc.sents:
        sentence = sent.text
        # Measure exactly the string that would be stored (including the
        # joining space), rather than a space-less concatenation.
        candidate = f"{current_chunk} {sentence}" if current_chunk else sentence
        # Only close the current chunk when it actually holds text; otherwise
        # an oversized first sentence would emit an empty chunk.
        if current_chunk.strip() and len(tokenizer(candidate)['input_ids']) > max_tokens:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk = candidate

    # Flush whatever is left, skipping whitespace-only residue.
    if current_chunk.strip():
        chunks.append(current_chunk.strip())

    return chunks

In [7]:
def fix_summary_punctuation(summary):
    """
    Ensure a non-empty summary ends with sentence punctuation.

    If `summary` is empty (falsy) or already ends with '.', '!' or '?', it is
    returned unchanged; otherwise a '.' is appended.
    """
    if not summary or summary.endswith((".", "!", "?")):
        return summary
    return summary + "."

### **這裡是用來載入模型的**

In [8]:
# Module-level resources shared by the helper functions in this notebook:
# the spaCy English pipeline (sentence splitting), the BART tokenizer, and the
# BART-large-CNN summarization model.
nlp = spacy.load("en_core_web_sm")
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
model = model.to(device)  # move the model onto the selected device

In [9]:
def generate_summary(article, min_summary_length=130, max_summary_length=512):
    """
    Summarize a single piece of text with the module-level BART `model`.

    The text is encoded with truncation at 1024 tokens (the original author
    notes that this is BART's maximum input length), moved to `device`, and
    decoded from the first sequence returned by beam search.

    Parameters
    ----------
    article (str) --------------- text to summarize.
    min_summary_length (int) ---- passed to `generate` as `min_length`.
    max_summary_length (int) ---- passed to `generate` as `max_length`.

    Returns
    -------
    str: the decoded summary with special tokens removed.
    """
    encoded = tokenizer.encode(
        article,
        max_length=1024,
        return_tensors="pt",
        truncation=True,
    )
    input_ids = encoded.to(device)

    generation_options = dict(
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
        min_length=min_summary_length,
        max_length=max_summary_length,
    )
    generated = model.generate(input_ids, **generation_options)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def count_token(text):
    """Return how many token ids the module-level `tokenizer` produces for
    `text`, with special tokens included."""
    encoded = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
    first_row = encoded[0]  # the encoding is batched; take the only row
    return len(first_row)

In [10]:
def merge_summaries(summaries, max_tokens=512):
    """
    Greedily pack consecutive summaries into chunks of at most `max_tokens`.

    Summaries are joined with a single space. A chunk is closed as soon as
    adding the next summary would push its token count (per `count_token`)
    above `max_tokens`.

    Parameters
    ----------
    summaries (iterable of str) --- summaries in document order.
    max_tokens (int) -------------- token budget per merged chunk, default 512.

    Returns
    -------
    list[str]: the non-empty, whitespace-stripped merged chunks. A single
    summary that exceeds `max_tokens` on its own is kept as its own chunk.
    """
    merged_chunks = []
    current_chunk = ""

    for summary in summaries:
        if not current_chunk.strip():
            # Nothing buffered yet: start the chunk with this summary instead
            # of closing (and emitting) an empty chunk when it is oversized.
            current_chunk = summary
        elif count_token(current_chunk + " " + summary) > max_tokens:
            merged_chunks.append(current_chunk.strip())
            current_chunk = summary
        else:
            current_chunk += " " + summary

    # Flush the last chunk, skipping empty / whitespace-only residue.
    if current_chunk.strip():
        merged_chunks.append(current_chunk.strip())

    return merged_chunks

In [11]:
def bart_summarize_article(
        article,
        max_bart_input_length=512,
        min_summary_length=128,
        max_summary_length=512,
        final_min_summary_length=128,
        final_max_summary_length=256,
        log=False
    ):
    """Summarize a long article with BART by hierarchical (layered) summarization.

    The article is split into parts, each part is summarized, the summaries
    are merged into fewer parts, and the process repeats until a single part
    remains; that part is then summarized once more with the "final" limits.

    Parameters
    ----------
    article (str)-------------------- the article text to summarize.
    max_bart_input_length (int)------ token budget per part when splitting/merging, default 512.
    min_summary_length (int)--------- minimum token length of intermediate summaries, default 128.
    max_summary_length (int)--------- maximum token length of intermediate summaries, default 512.
    final_min_summary_length (int)--- minimum token length of the final summary, default 128.
    final_max_summary_length (int)--- maximum token length of the final summary, default 256.
    log (bool)----------------------- print per-layer token counts of parts and summaries, default False.

    Returns
    -------
    summarized_article (str): the final summary, or "" when the article
    yields no non-empty part to summarize.
    """
    # Split the article into parts that fit the per-part token budget; drop
    # empty parts so the model is never asked to summarize nothing.
    article_parts = [
        part
        for part in split_text_with_spacy(article, max_tokens=max_bart_input_length)
        if part
    ]
    if not article_parts:
        return ""

    # Layer counter, used only for logging (named `layer` so the builtin
    # `round` is not shadowed).
    layer = 1
    while len(article_parts) > 1:
        # Log the token count of every input part of this layer.
        if log:
            print(f"\noriginal layer-{layer:<2}({len(article_parts):<2} parts): {' '*layer*3}|", end='')
            for article_part in article_parts:
                print(f" {count_token(article_part):<4}|", end='')
            print(f"\n summary layer-{layer:<2}({len(article_parts):<2} parts): {' '*layer*3}|", end='')

        # Summarize every part of this layer.
        summaries = []
        for article_part in article_parts:
            summary = generate_summary(article_part,
                                       min_summary_length=min_summary_length,
                                       max_summary_length=max_summary_length)
            # Generated text may lack closing punctuation; add it so merged
            # summaries keep clean sentence boundaries.
            summary = fix_summary_punctuation(summary)
            summaries.append(summary)
            if log:
                print(f" {count_token(summary):<4}|", end='')

        # Pack the summaries into fewer parts for the next layer.
        merged_parts = merge_summaries(summaries, max_tokens=max_bart_input_length)
        if len(merged_parts) >= len(article_parts):
            # No reduction (e.g. every summary is longer than half the
            # budget): the loop could otherwise run forever. Collapse all
            # summaries into one part; generate_summary truncates over-long
            # input when encoding.
            merged_parts = [" ".join(summaries)]
        article_parts = [part for part in merged_parts if part.strip()]
        layer += 1

    if not article_parts:
        return ""

    # Final pass over the single remaining part, using the final length limits.
    final_summary = generate_summary(article_parts[0],
                                     min_summary_length=final_min_summary_length,
                                     max_summary_length=final_max_summary_length)
    # Log the token counts of the last layer.
    if log:
        print(f"\noriginal layer-{layer:<3}({1:<2} parts): {' '*layer*3}| {count_token(article_parts[0]):<4}|", end='')
        print(f"\nsummary  layer-{layer:<3}({1:<2} parts): {' '*layer*3}| {count_token(final_summary):<4}|", end='')
    return final_summary

## 這裡用 arXiv 抓下來的文章做測試

In [12]:
# Load the sample article and its reference abstract. Context managers close
# the file handles (the bare open(...).read() form left them open), and the
# explicit encoding keeps decoding independent of the platform default.
with open("./test/article.md", encoding="utf-8") as article_file:
    test_article = article_file.read()
with open("./test/abstract.md", encoding="utf-8") as abstract_file:
    test_abstract = abstract_file.read()

# Aim the final summary at roughly the abstract's token length (90%-120%).
goal_token_count = count_token(test_abstract)
bart_summarize_article(test_article,
                  max_bart_input_length    = 512,
                  min_summary_length       = 230, # ~ (512/2) * 0.9
                  max_summary_length       = 256, # ~ (512/2) 
                  final_min_summary_length = int(goal_token_count*0.9),
                  final_max_summary_length = int(goal_token_count*1.2),
                  log=True)


original layer-1 (13 parts):    | 500 | 473 | 489 | 500 | 505 | 508 | 497 | 505 | 459 | 507 | 461 | 411 | 478 |
 summary layer-1 (13 parts):    | 245 | 255 | 245 | 246 | 238 | 241 | 242 | 256 | 237 | 240 | 231 | 239 | 256 |
original layer-2 (7  parts):       | 498 | 489 | 477 | 495 | 475 | 468 | 256 |
 summary layer-2 (7  parts):       | 239 | 245 | 240 | 241 | 234 | 253 | 256 |
original layer-3 (4  parts):          | 482 | 479 | 485 | 256 |
 summary layer-3 (4  parts):          | 249 | 254 | 238 | 256 |
original layer-4 (2  parts):             | 501 | 492 |
 summary layer-4 (2  parts):             | 248 | 248 |
original layer-5  (1  parts):                | 494 |
summary  layer-5  (1  parts):                | 286 |

'Coarse maps of the distribution of galaxies can be constructed by measuring redshifts. Results of the study will be published in the forthcoming issue of The Astrophysics of Light and Space, published by The Astronomical Society of America. The entire Universe can be considered a patchwork of abutting BoAs, just as the terrestrial landscape is separated into watersheds. A BoA is generally not gravitationally bound, as the relative motion of distant points within it are usually dominated by the cosmic expansion. Any source position in a BoA leads via a streamline to a "sink" near the potential minimum within the BoA. Streamlines diverge out of the local maxima of the velocity potential and converge onto its local minima - namely they \'stream\' away from the underdense to the dense regions of the Universe. Overall, there is reasonable coverage outside the zone of Milky Way obscuration across the sky within $z=0.05$ with a slight deficiency south of the Milky Way in the celestial north.

In [13]:
def bart_summarize_article(
        article,
        max_bart_input_length=512,
        min_summary_length=128,
        max_summary_length=512,
        final_min_summary_length=128,
        final_max_summary_length=256
    ):
    """Summarize a long article with BART by hierarchical summarization (no logging).

    The article is split into parts, each part is summarized, the summaries
    are merged into fewer parts, and the process repeats until one part
    remains; that part is summarized once more with the "final" limits.

    Parameters
    ----------
    article (str)-------------------- the article text to summarize.
    max_bart_input_length (int)------ token budget per part when splitting/merging, default 512.
    min_summary_length (int)--------- minimum token length of intermediate summaries, default 128.
    max_summary_length (int)--------- maximum token length of intermediate summaries, default 512.
    final_min_summary_length (int)--- minimum token length of the final summary, default 128.
    final_max_summary_length (int)--- maximum token length of the final summary, default 256.

    Returns
    -------
    str: the final summary, or "" when the article yields no non-empty part.
    """
    # Drop empty parts so the model is never asked to summarize nothing.
    article_parts = [
        part
        for part in split_text_with_spacy(article, max_tokens=max_bart_input_length)
        if part
    ]
    if not article_parts:
        return ""

    while len(article_parts) > 1:
        summaries = [
            fix_summary_punctuation(
                generate_summary(article_part,
                                 min_summary_length=min_summary_length,
                                 max_summary_length=max_summary_length)
            )
            for article_part in article_parts
        ]
        merged_parts = merge_summaries(summaries, max_tokens=max_bart_input_length)
        if len(merged_parts) >= len(article_parts):
            # Merging did not reduce the number of parts, so the loop could
            # run forever. Collapse everything into one part;
            # generate_summary truncates over-long input when encoding.
            merged_parts = [" ".join(summaries)]
        article_parts = [part for part in merged_parts if part.strip()]

    if not article_parts:
        return ""

    # Final pass over the single remaining part, using the final length limits.
    return generate_summary(article_parts[0],
                            min_summary_length=final_min_summary_length,
                            max_summary_length=final_max_summary_length)

## 這裡使用 ollama 生成的文章做 bart 摘要

In [15]:
# Choose which variant of the CNN/DailyMail test CSV to summarize.
# Other values that appeared as alternatives in this cell:
#   'original', 'highlighted', 'summarized', 'compressed', 'abstracted',
#   'ollama_highlight_token_ratio', 'ollama_highlight_words'
task_name = 'ollama_highlight_words_ratio'

# The 'original' variant has no task suffix in its file name; every other
# variant is stored as cnn_dailymail_test-<task_name>.csv.
df_by_llama = pd.read_csv(
    "./dat/cnn_dailymail_test.csv"
    if task_name == 'original'
    else f"./dat/cnn_dailymail_test-{task_name}.csv"
)

# Token counts that are later used as length targets for the final summaries.
df_token_count = pd.read_csv("./dat/cnn_dailymail_test-token_count.csv")

# Show both row counts side by side as the cell output.
df_by_llama.shape[0], df_token_count.shape[0]

(11490, 11490)

In [16]:
# Columns consumed by the batch summarization loop: article ids and texts come
# from the task CSV, target lengths come from the token-count CSV.
# NOTE(review): the two frames are later zipped by position — presumably their
# rows are in the same order; verify.
article_ids, article_need_summarize = df_by_llama['id'], df_by_llama['article']
highlight_token_counts = df_token_count['highlights']

## **這裡使用的 min length 和 max length 分別是 CNN daily mail 中 highlights 的 token 數量的 $100\%$ 與 $120\%$**

In [15]:
# import time
# time.sleep(600)

In [17]:
# Batch-summarize every article and append "<id>,<json-encoded summary>" lines
# to the result file, buffering rows in memory between flushes.
SAVE_RESULT_PATH = f'./dat/cnn_dailymail_test-{task_name}-bart.dat'
SAVE_EVERRY_NUM = 5 # flush the buffer whenever the row index is a multiple of this
TEMPORARY_SAVE_STRING = "" # rows waiting to be written to the result file

processed_index = 0
summarized_articles = []  # never populated in this cell; kept in case other cells reference it
for article_id, goal_token_count, article in tqdm(zip(article_ids, highlight_token_counts, article_need_summarize), total=len(article_ids), desc='BART summarize'):
    try:
        summarized_article = bart_summarize_article(article,
                  max_bart_input_length    = 512,
                  min_summary_length       = 230, # ~ (512/2) * 0.9
                  max_summary_length       = 256, # ~ (512/2) 
                  final_min_summary_length = goal_token_count,
                  final_max_summary_length = int(goal_token_count*1.2))
        # json.dumps keeps the summary on one line (newlines/quotes escaped).
        TEMPORARY_SAVE_STRING += f"{article_id},{json.dumps(summarized_article)}\n"
        if not processed_index%SAVE_EVERRY_NUM :
            with open(SAVE_RESULT_PATH, "a") as file:
                file.write(TEMPORARY_SAVE_STRING)
            TEMPORARY_SAVE_STRING = ""
    except Exception as e:
        # Best effort: report the failing row and continue with the next one.
        print(f"Error processing row {processed_index}: {e}")
    processed_index += 1
# Write whatever is still buffered after the last periodic flush.
if TEMPORARY_SAVE_STRING:
    with open(SAVE_RESULT_PATH, "a") as file:
        file.write(TEMPORARY_SAVE_STRING)
    # Clear the buffer once written; otherwise any later flush of the same
    # buffer would append these trailing rows to the file a second time.
    TEMPORARY_SAVE_STRING = ""
print("summarize complete => save to csv complete")

BART summarize: 100%|██████████████████████████████████████████████████████████| 11490/11490 [4:35:08<00:00,  1.44s/it]

summarize complete => save to csv complete





In [17]:
# Manual flush of any rows still buffered (e.g. after interrupting the loop).
if TEMPORARY_SAVE_STRING:
    with open(SAVE_RESULT_PATH, "a") as file:
        file.write(TEMPORARY_SAVE_STRING)
    # Clear the buffer so re-running this cell cannot append the same rows twice.
    TEMPORARY_SAVE_STRING = ""