# Load Data

In [33]:
import json
from sklearn.model_selection import train_test_split

file_path = '/user_data/text_summarization/chinese_summary_dataset.json'
with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)
    
articles = [entry['article'] for entry in data]
summaries = [entry['summary'] for entry in data]

train_articles, test_articles, train_summaries, test_summaries = train_test_split(
    articles, summaries, test_size=0.2, random_state=42
)

train_articles, val_articles, train_summaries, val_summaries = train_test_split(
    train_articles, train_summaries, test_size=0.2, random_state=42
)

In [34]:
print("training data 筆數：", len(train_articles))
print("testing data 筆數：", len(test_articles))
print("question：", train_articles[0])
print("taget：", train_summaries[0])

training data 筆數： 15662
testing data 筆數： 4895
question： 多知網4月9日訊息，今日，尚德召開了主題為“重生—一場志在炫耀的釋出會”，披露3月份營收數字：7840萬，同比增長100%。尚德還公佈了以下資料：通過網際網路學習的人數突破了6萬，每天線上學習人數2萬5千人，看回放總人次超過3萬人，老師授課時長超過3萬小時，單天刷題量超過56萬，學生打分好評率達到99.8%，通過率60%-65%，重複消費率破40%⋯⋯轉型後的尚德，一改過去幾年“崇尚黑森林法則”的論調，愈發高調起來。用尚德創始人歐蓬的話說，“終於可以出來見人了”。那麼，尚德任性＂炫耀＂背後的底氣是什麼？單月營收額7840萬是如何實現的？**首先要轉變思想**2008年-2010年，是歐蓬的理論建設期。從2008年開始，他不再混教育圈，而是去見網際網路行業的人。“那個時間點，網際網路行業的人談的一些事情我完全不懂，我嘗試努力懂得他們的思維方式，懂得他們的邏輯，懂得他們的玩法、戰術、戰略。”基於之前的理論建設，2011年，尚德做了第一代試錯產品嗨學網，2013年做出了第二代試錯產品對啊網，2014年6月6日全面轉到狐邏學院做直播。歐蓬認為，風格是價值觀的外延，內心的轉變很重要。“蘇寧一直還穿西服去做網際網路，這個風格是有問題的，所以說對於任何一個傳統企業的老闆，他最大的問題是老闆他到底有沒有變化。”**狼性文化，果斷捨棄**2013年底，尚德決定將線下店全部關閉。對於這一決定，一些校長是＂拒絕＂的。＂我相信沒好的未來，我相信線上是對的，但我接受不了兩年辛辛苦苦經營的校一夜之間關了。”很多校長從情感上無法接受，忍痛離開。另一關鍵群體是老師。很多的傳統的老師是很牴觸或者不適應上直播課，要調動學員氛圍，要產生粉絲和口碑。在老師轉型的過程中，尚德的原則是，對於心態非常封閉保守的老師，如果實在不能轉型，尚德會果斷捨棄。“這樣的老師的基因更加適合面授，我們更想帶更多願意跟我們活在未來的人一起走，我們的風格是狼性的。＂直播形式推動教學質量尚德認為，教育1.0時代是面授時代，相對價效比比較高的價格保證相對還不錯的通過率；2.0時代是錄播時代，學員很難戰勝看視訊的孤獨感，錄播的完成率不高於5%，高度依賴於一個人的自學能力和自控能力，教學質量蠻難控制。3.0時代是直播時代，最重要

# Tokenizer

In [35]:
from transformers import  BartTokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

tokenize_train_data=tokenizer(train_articles,text_target=train_summaries,padding=True,truncation=True)
tokenize_val_data=tokenizer(val_articles,text_target=val_summaries,padding=True,truncation=True)
tokenize_test_data=tokenizer(test_articles,text_target=test_summaries,padding=True,truncation=True)

In [36]:
tokenize_train_data.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [37]:
tokenize_train_data["input_ids"][0]

[0,
 47983,
 15113,
 36714,
 4333,
 8210,
 36714,
 19002,
 14292,
 306,
 44636,
 23133,
 466,
 47954,
 8210,
 36484,
 11423,
 27969,
 37127,
 10172,
 10965,
 43251,
 4394,
 14285,
 46890,
 27969,
 47954,
 8210,
 43251,
 4394,
 14285,
 47842,
 15113,
 48412,
 18400,
 45262,
 11582,
 49035,
 13859,
 46499,
 27819,
 46015,
 2023,
 41907,
 5543,
 14285,
 36714,
 9264,
 3070,
 17,
 48,
 47994,
 8384,
 48998,
 578,
 48105,
 42393,
 21402,
 20024,
 48765,
 6800,
 46537,
 11423,
 36714,
 9264,
 4958,
 48991,
 7471,
 44574,
 47994,
 13859,
 48635,
 3070,
 44636,
 862,
 17,
 46,
 43251,
 4394,
 14285,
 37127,
 27969,
 4958,
 41907,
 48,
 14292,
 246,
 44636,
 23133,
 46890,
 10809,
 36714,
 6382,
 4333,
 37127,
 10674,
 49363,
 15722,
 18537,
 48823,
 6800,
 43251,
 4394,
 15113,
 5479,
 1749,
 36484,
 16948,
 11582,
 43251,
 4394,
 14285,
 47504,
 14285,
 37127,
 10965,
 10674,
 42393,
 7258,
 17772,
 41907,
 15722,
 18400,
 1866,
 207,
 45682,
 47842,
 15113,
 48412,
 18400,
 41907,
 9264,
 11

# Define Dataset

In [38]:
import torch
class ChineseDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_data =ChineseDataset(tokenize_train_data)
val_data =ChineseDataset(tokenize_val_data)
test_data =ChineseDataset(tokenize_test_data)

In [39]:
train_data[0]

{'input_ids': tensor([    0, 47983, 15113,  ...,  8210,  9470,     2]),
 'attention_mask': tensor([1, 1, 1,  ..., 1, 1, 1]),
 'labels': tensor([    0, 47842, 15113, 48412, 18400, 45262, 11582, 49035, 13859, 46499,
         27819, 46015,  2023, 41907,  5543, 14285, 36714,  9264,  3070,    17,
            48, 47994,  8384, 48998,   578, 48105, 42393, 21402, 20024, 48765,
          6800, 46537, 11423, 36714,  9264,  4958, 48991,  7471, 44574, 47994,
         13859, 48635,  3070, 44636,   862,    17,    46, 43251,  4394, 14285,
         37127, 27969,  4958, 41907,    48, 14292,   246, 44636, 23133, 46890,
         10809, 36714,  6382,  4333, 37127, 10674, 49363, 15722, 18537, 48823,
          6800, 43251,  4394, 15113,  5479,  1749, 36484, 16948, 11582, 43251,
          4394, 14285, 47504, 14285, 37127, 10965, 10674, 42393,  7258, 17772,
         41907, 15722, 18400,  1866,   207, 45682,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,    

In [40]:
import logging
import datasets
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange
import math
import gc

import transformers
from accelerate import Accelerator
from transformers import (
    AdamW,
    AutoConfig,
    default_data_collator,
    get_scheduler
)

In [41]:
train_batch_size = 5      # 設定 training batch size
eval_batch_size = 5     # 設定 eval batch size
num_train_epochs = 12      # 設定 epoch

In [42]:
data_collator = default_data_collator
train_dataloader = DataLoader(train_data, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
eval_dataloader = DataLoader(val_data, collate_fn=data_collator, batch_size=eval_batch_size)

In [43]:
print(len(train_dataloader))
print(len(train_data))

3133
15662


# Prepare Training

In [44]:
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# initialize optimizer
learning_rate=1e-5          # 設定 learning_rate
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

# Scheduler and math around the number of training steps.
gradient_accumulation_steps = 2   # 設定幾步後進行反向傳播

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
print('max_train_steps', max_train_steps)

# scheduler
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)



max_train_steps 18804


# Train

In [45]:
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state)
output_dir = 'model_bart/'


total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_articles)}")
logger.info(f"  Num Epochs = {num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {max_train_steps}")


best_epoch = {"epoch": 0, "rouge": 0 }

import evaluate
rouge_score = evaluate.load("rouge")
for epoch in trange(num_train_epochs, desc="Epoch"):
    model.train()
    for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if step % 100 == 0:
            print({'epoch': epoch, 'step': step, 'loss': loss.item()})


    logger.info("\n***** Running eval *****")
    model.eval()
    preds=[]
    answers=[]
    with torch.no_grad():
        for step, batch in enumerate(tqdm(eval_dataloader, desc="Eval Iteration")):
            outputs = model(**batch)
            output_ids = model.generate(batch['input_ids'], num_beams=2, max_length=150, early_stopping=True)
            preds.append(" ".join([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in output_ids]))
            answers.append(" ".join([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in batch['labels']]))
    
    
    scores = rouge_score.compute(
    predictions=preds, references=answers
    )
    print(scores)
    eval_precision=scores["rougeL"]
    logger.info(f"epoch :{epoch} \n rougeL:  {eval_precision}")
    if eval_precision > best_epoch["rouge"]:
        best_epoch['epoch'] = num_train_epochs
        best_epoch['rouge'] = eval_precision
    
    if output_dir is not None and epoch>5:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(output_dir + 'new_epoch_' + str(epoch) + '/', save_function=accelerator.save)

# print(f"best epoch:{best_epoch}")

Epoch:   0%|          | 0/12 [00:00<?, ?it/s]

{'epoch': 0, 'step': 0, 'loss': 12.889228820800781}




{'epoch': 0, 'step': 100, 'loss': 4.848992347717285}




{'epoch': 0, 'step': 200, 'loss': 3.045048475265503}




{'epoch': 0, 'step': 300, 'loss': 2.445564031600952}




{'epoch': 0, 'step': 400, 'loss': 1.097427487373352}




{'epoch': 0, 'step': 500, 'loss': 0.7416189312934875}




{'epoch': 0, 'step': 600, 'loss': 0.405283659696579}




{'epoch': 0, 'step': 700, 'loss': 0.6285397410392761}




{'epoch': 0, 'step': 800, 'loss': 0.5187814235687256}




{'epoch': 0, 'step': 900, 'loss': 0.4985881447792053}




{'epoch': 0, 'step': 1000, 'loss': 0.30311456322669983}




{'epoch': 0, 'step': 1100, 'loss': 0.13909849524497986}




{'epoch': 0, 'step': 1200, 'loss': 0.3101493716239929}




{'epoch': 0, 'step': 1300, 'loss': 0.17347228527069092}




{'epoch': 0, 'step': 1400, 'loss': 0.37995484471321106}




{'epoch': 0, 'step': 1500, 'loss': 0.3979204297065735}




{'epoch': 0, 'step': 1600, 'loss': 0.3643837869167328}




{'epoch': 0, 'step': 1700, 'loss': 0.29693394899368286}




{'epoch': 0, 'step': 1800, 'loss': 0.5244865417480469}




{'epoch': 0, 'step': 1900, 'loss': 0.15014050900936127}




{'epoch': 0, 'step': 2000, 'loss': 0.3011181354522705}




{'epoch': 0, 'step': 2100, 'loss': 0.3289421796798706}




{'epoch': 0, 'step': 2200, 'loss': 0.09713345021009445}




{'epoch': 0, 'step': 2300, 'loss': 0.43208181858062744}




{'epoch': 0, 'step': 2400, 'loss': 0.14751921594142914}




{'epoch': 0, 'step': 2500, 'loss': 0.06378649175167084}




{'epoch': 0, 'step': 2600, 'loss': 0.2023579478263855}




{'epoch': 0, 'step': 2700, 'loss': 0.21351391077041626}




{'epoch': 0, 'step': 2800, 'loss': 0.08588997274637222}




{'epoch': 0, 'step': 2900, 'loss': 0.16666048765182495}




{'epoch': 0, 'step': 3000, 'loss': 0.07759515196084976}




{'epoch': 0, 'step': 3100, 'loss': 0.08021101355552673}


Iteration: 100%|██████████| 3133/3133 [22:43<00:00,  2.30it/s]
Eval Iteration: 100%|██████████| 784/784 [13:46<00:00,  1.05s/it]
Epoch:   8%|▊         | 1/12 [36:30<6:41:31, 2190.11s/it]

{'rouge1': 0.3202266449909673, 'rouge2': 0.1731223274009348, 'rougeL': 0.31451807965172074, 'rougeLsum': 0.31562053997831235}




{'epoch': 1, 'step': 0, 'loss': 0.28170251846313477}




{'epoch': 1, 'step': 100, 'loss': 0.07619999349117279}




{'epoch': 1, 'step': 200, 'loss': 0.16184821724891663}




{'epoch': 1, 'step': 300, 'loss': 0.17447973787784576}




{'epoch': 1, 'step': 400, 'loss': 0.06714515388011932}




{'epoch': 1, 'step': 500, 'loss': 0.09379506856203079}




{'epoch': 1, 'step': 600, 'loss': 0.07512333244085312}




{'epoch': 1, 'step': 700, 'loss': 0.19536803662776947}




{'epoch': 1, 'step': 800, 'loss': 0.16478408873081207}




{'epoch': 1, 'step': 900, 'loss': 0.5126579999923706}




{'epoch': 1, 'step': 1000, 'loss': 0.11856036633253098}




{'epoch': 1, 'step': 1100, 'loss': 0.08785701543092728}




{'epoch': 1, 'step': 1200, 'loss': 0.08683167397975922}




{'epoch': 1, 'step': 1300, 'loss': 0.15304253995418549}




{'epoch': 1, 'step': 1400, 'loss': 0.03198062628507614}




{'epoch': 1, 'step': 1500, 'loss': 0.21564503014087677}




{'epoch': 1, 'step': 1600, 'loss': 0.06247692182660103}




{'epoch': 1, 'step': 1700, 'loss': 0.21011409163475037}




{'epoch': 1, 'step': 1800, 'loss': 0.28388023376464844}




{'epoch': 1, 'step': 1900, 'loss': 0.09377554804086685}




{'epoch': 1, 'step': 2000, 'loss': 0.18689505755901337}




{'epoch': 1, 'step': 2100, 'loss': 0.04069909825921059}




{'epoch': 1, 'step': 2200, 'loss': 0.3868943452835083}




{'epoch': 1, 'step': 2300, 'loss': 0.11891845613718033}




{'epoch': 1, 'step': 2400, 'loss': 0.10124387592077255}




{'epoch': 1, 'step': 2500, 'loss': 0.16596007347106934}




{'epoch': 1, 'step': 2600, 'loss': 0.2923198938369751}




{'epoch': 1, 'step': 2700, 'loss': 0.27554580569267273}




{'epoch': 1, 'step': 2800, 'loss': 0.564484179019928}




{'epoch': 1, 'step': 2900, 'loss': 0.13940295577049255}




{'epoch': 1, 'step': 3000, 'loss': 0.3687622547149658}




{'epoch': 1, 'step': 3100, 'loss': 0.05686414986848831}


Iteration: 100%|██████████| 3133/3133 [22:48<00:00,  2.29it/s]
Eval Iteration: 100%|██████████| 784/784 [12:35<00:00,  1.04it/s]
Epoch:  17%|█▋        | 2/12 [1:11:54<5:58:32, 2151.27s/it]

{'rouge1': 0.3338443567738659, 'rouge2': 0.17233873693385, 'rougeL': 0.3270126270520246, 'rougeLsum': 0.32736149574147755}




{'epoch': 2, 'step': 0, 'loss': 0.2810678780078888}




{'epoch': 2, 'step': 100, 'loss': 0.41227343678474426}




{'epoch': 2, 'step': 200, 'loss': 0.10326669365167618}




{'epoch': 2, 'step': 300, 'loss': 0.11005012691020966}




{'epoch': 2, 'step': 400, 'loss': 0.035861190408468246}




{'epoch': 2, 'step': 500, 'loss': 0.09151492267847061}




{'epoch': 2, 'step': 600, 'loss': 0.09536270797252655}




{'epoch': 2, 'step': 700, 'loss': 0.22248610854148865}




{'epoch': 2, 'step': 800, 'loss': 0.07962332665920258}




{'epoch': 2, 'step': 900, 'loss': 0.0494743175804615}




{'epoch': 2, 'step': 1000, 'loss': 0.574658989906311}




{'epoch': 2, 'step': 1100, 'loss': 0.24466300010681152}




{'epoch': 2, 'step': 1200, 'loss': 0.2402663677930832}




{'epoch': 2, 'step': 1300, 'loss': 0.6783084869384766}




{'epoch': 2, 'step': 1400, 'loss': 0.3384923040866852}




{'epoch': 2, 'step': 1500, 'loss': 0.3381820023059845}




{'epoch': 2, 'step': 1600, 'loss': 0.14836062490940094}




{'epoch': 2, 'step': 1700, 'loss': 0.1470152735710144}




{'epoch': 2, 'step': 1800, 'loss': 0.06603971123695374}




{'epoch': 2, 'step': 1900, 'loss': 0.25703078508377075}




{'epoch': 2, 'step': 2000, 'loss': 0.20344021916389465}




{'epoch': 2, 'step': 2100, 'loss': 0.05327630043029785}




{'epoch': 2, 'step': 2200, 'loss': 0.338457316160202}




{'epoch': 2, 'step': 2300, 'loss': 0.08273160457611084}




{'epoch': 2, 'step': 2400, 'loss': 0.08490206301212311}




{'epoch': 2, 'step': 2500, 'loss': 0.41270849108695984}




{'epoch': 2, 'step': 2600, 'loss': 0.07835432142019272}




{'epoch': 2, 'step': 2700, 'loss': 0.021220659837126732}




{'epoch': 2, 'step': 2800, 'loss': 0.17357461154460907}




{'epoch': 2, 'step': 2900, 'loss': 0.12679043412208557}




{'epoch': 2, 'step': 3000, 'loss': 0.42500177025794983}




{'epoch': 2, 'step': 3100, 'loss': 0.05095455050468445}


Iteration: 100%|██████████| 3133/3133 [22:49<00:00,  2.29it/s]
Eval Iteration: 100%|██████████| 784/784 [12:41<00:00,  1.03it/s]
Epoch:  25%|██▌       | 3/12 [1:47:26<5:21:21, 2142.40s/it]

{'rouge1': 0.3376609059307577, 'rouge2': 0.17264651556669386, 'rougeL': 0.3308700669037271, 'rougeLsum': 0.33097121712587346}




{'epoch': 3, 'step': 0, 'loss': 0.2272702306509018}




{'epoch': 3, 'step': 100, 'loss': 0.32647451758384705}




{'epoch': 3, 'step': 200, 'loss': 0.17710290849208832}




{'epoch': 3, 'step': 300, 'loss': 0.20669859647750854}




{'epoch': 3, 'step': 400, 'loss': 0.2304343432188034}




{'epoch': 3, 'step': 500, 'loss': 0.11456966400146484}




{'epoch': 3, 'step': 600, 'loss': 0.42438653111457825}




{'epoch': 3, 'step': 700, 'loss': 0.07273518294095993}




{'epoch': 3, 'step': 800, 'loss': 0.22723543643951416}




{'epoch': 3, 'step': 900, 'loss': 0.0780196338891983}




{'epoch': 3, 'step': 1000, 'loss': 0.06972237676382065}




{'epoch': 3, 'step': 1100, 'loss': 0.01895528845489025}




{'epoch': 3, 'step': 1200, 'loss': 0.29835838079452515}




{'epoch': 3, 'step': 1300, 'loss': 0.13243526220321655}




{'epoch': 3, 'step': 1400, 'loss': 0.20377250015735626}




{'epoch': 3, 'step': 1500, 'loss': 0.12104156613349915}




{'epoch': 3, 'step': 1600, 'loss': 0.14324980974197388}




{'epoch': 3, 'step': 1700, 'loss': 0.1473003476858139}




{'epoch': 3, 'step': 1800, 'loss': 0.10230182111263275}




{'epoch': 3, 'step': 1900, 'loss': 0.07654877752065659}




{'epoch': 3, 'step': 2000, 'loss': 0.1330135613679886}




{'epoch': 3, 'step': 2100, 'loss': 0.058559514582157135}




{'epoch': 3, 'step': 2200, 'loss': 0.22747811675071716}




{'epoch': 3, 'step': 2300, 'loss': 0.415569931268692}




{'epoch': 3, 'step': 2400, 'loss': 0.18141832947731018}




{'epoch': 3, 'step': 2500, 'loss': 0.07220377027988434}




{'epoch': 3, 'step': 2600, 'loss': 0.06304981559515}




{'epoch': 3, 'step': 2700, 'loss': 0.4319990277290344}




{'epoch': 3, 'step': 2800, 'loss': 0.36598941683769226}




{'epoch': 3, 'step': 2900, 'loss': 0.37332770228385925}




{'epoch': 3, 'step': 3000, 'loss': 0.21188116073608398}




{'epoch': 3, 'step': 3100, 'loss': 0.3323938250541687}


Iteration: 100%|██████████| 3133/3133 [22:53<00:00,  2.28it/s]
Eval Iteration: 100%|██████████| 784/784 [12:19<00:00,  1.06it/s]
Epoch:  33%|███▎      | 4/12 [2:22:39<4:44:07, 2130.94s/it]

{'rouge1': 0.34977738960064864, 'rouge2': 0.18246614831842536, 'rougeL': 0.3444768363990038, 'rougeLsum': 0.3448032465638634}




{'epoch': 4, 'step': 0, 'loss': 0.10007220506668091}




{'epoch': 4, 'step': 100, 'loss': 0.09779183566570282}




{'epoch': 4, 'step': 200, 'loss': 0.7881000638008118}




{'epoch': 4, 'step': 300, 'loss': 0.12749014794826508}




{'epoch': 4, 'step': 400, 'loss': 0.2343532145023346}




{'epoch': 4, 'step': 500, 'loss': 0.2968503534793854}




{'epoch': 4, 'step': 600, 'loss': 0.10090240836143494}




{'epoch': 4, 'step': 700, 'loss': 0.3399016559123993}




{'epoch': 4, 'step': 800, 'loss': 0.11228817701339722}




{'epoch': 4, 'step': 900, 'loss': 0.2692009508609772}




{'epoch': 4, 'step': 1000, 'loss': 0.1484474539756775}




{'epoch': 4, 'step': 1100, 'loss': 0.6076272130012512}




{'epoch': 4, 'step': 1200, 'loss': 0.09449080377817154}




{'epoch': 4, 'step': 1300, 'loss': 0.29226192831993103}




{'epoch': 4, 'step': 1400, 'loss': 0.0946866050362587}




{'epoch': 4, 'step': 1500, 'loss': 0.3113085627555847}




{'epoch': 4, 'step': 1600, 'loss': 0.16471029818058014}




{'epoch': 4, 'step': 1700, 'loss': 0.13222424685955048}




{'epoch': 4, 'step': 1800, 'loss': 0.1696138232946396}




{'epoch': 4, 'step': 1900, 'loss': 0.1617196798324585}




{'epoch': 4, 'step': 2000, 'loss': 0.11649040877819061}




{'epoch': 4, 'step': 2100, 'loss': 0.15478861331939697}




{'epoch': 4, 'step': 2200, 'loss': 0.1080678179860115}




{'epoch': 4, 'step': 2300, 'loss': 0.1514541059732437}




{'epoch': 4, 'step': 2400, 'loss': 0.3829285502433777}




{'epoch': 4, 'step': 2500, 'loss': 0.21576108038425446}




{'epoch': 4, 'step': 2600, 'loss': 0.24179181456565857}




{'epoch': 4, 'step': 2700, 'loss': 0.017901044338941574}




{'epoch': 4, 'step': 2800, 'loss': 0.024075962603092194}




{'epoch': 4, 'step': 2900, 'loss': 0.08580358326435089}




{'epoch': 4, 'step': 3000, 'loss': 0.05558772757649422}




{'epoch': 4, 'step': 3100, 'loss': 0.04003819078207016}


Iteration: 100%|██████████| 3133/3133 [22:52<00:00,  2.28it/s]
Eval Iteration: 100%|██████████| 784/784 [13:02<00:00,  1.00it/s]
Epoch:  42%|████▏     | 5/12 [2:58:34<4:09:38, 2139.76s/it]

{'rouge1': 0.37273158568614845, 'rouge2': 0.20115341823120608, 'rougeL': 0.36801143476842113, 'rougeLsum': 0.36828329987042563}




{'epoch': 5, 'step': 0, 'loss': 0.11311618238687515}




{'epoch': 5, 'step': 100, 'loss': 0.06562480330467224}




{'epoch': 5, 'step': 200, 'loss': 0.04704089090228081}




{'epoch': 5, 'step': 300, 'loss': 0.034588828682899475}




{'epoch': 5, 'step': 400, 'loss': 0.03715861216187477}




{'epoch': 5, 'step': 500, 'loss': 0.545900821685791}




{'epoch': 5, 'step': 600, 'loss': 0.0071481321938335896}




{'epoch': 5, 'step': 700, 'loss': 0.09671279788017273}




{'epoch': 5, 'step': 800, 'loss': 0.04199286550283432}




{'epoch': 5, 'step': 900, 'loss': 0.04029034078121185}




{'epoch': 5, 'step': 1000, 'loss': 0.3535187244415283}




{'epoch': 5, 'step': 1100, 'loss': 0.48773565888404846}




{'epoch': 5, 'step': 1200, 'loss': 0.15447388589382172}




{'epoch': 5, 'step': 1300, 'loss': 0.05225436016917229}




{'epoch': 5, 'step': 1400, 'loss': 0.17134901881217957}




{'epoch': 5, 'step': 1500, 'loss': 0.07126539945602417}




{'epoch': 5, 'step': 1600, 'loss': 0.27760955691337585}




{'epoch': 5, 'step': 1700, 'loss': 0.09076511859893799}




{'epoch': 5, 'step': 1800, 'loss': 0.021447692066431046}




{'epoch': 5, 'step': 1900, 'loss': 0.12552079558372498}




{'epoch': 5, 'step': 2000, 'loss': 0.0941518023610115}




{'epoch': 5, 'step': 2100, 'loss': 0.23126542568206787}




{'epoch': 5, 'step': 2200, 'loss': 0.18778786063194275}




{'epoch': 5, 'step': 2300, 'loss': 0.12806788086891174}




{'epoch': 5, 'step': 2400, 'loss': 0.08227572590112686}




{'epoch': 5, 'step': 2500, 'loss': 0.28860989212989807}




{'epoch': 5, 'step': 2600, 'loss': 0.3261863589286804}




{'epoch': 5, 'step': 2700, 'loss': 0.14407669007778168}




{'epoch': 5, 'step': 2800, 'loss': 0.7586187124252319}




{'epoch': 5, 'step': 2900, 'loss': 0.30822381377220154}




{'epoch': 5, 'step': 3000, 'loss': 0.05927598476409912}




{'epoch': 5, 'step': 3100, 'loss': 0.07916800677776337}


Iteration: 100%|██████████| 3133/3133 [22:51<00:00,  2.28it/s]
Eval Iteration: 100%|██████████| 784/784 [12:57<00:00,  1.01it/s]
Epoch:  50%|█████     | 6/12 [3:34:23<3:34:17, 2142.92s/it]

{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}




{'epoch': 6, 'step': 0, 'loss': 0.13046637177467346}




{'epoch': 6, 'step': 100, 'loss': 0.1458544284105301}




{'epoch': 6, 'step': 200, 'loss': 0.027803804725408554}




{'epoch': 6, 'step': 300, 'loss': 0.05549287796020508}




{'epoch': 6, 'step': 400, 'loss': 0.22771483659744263}




{'epoch': 6, 'step': 500, 'loss': 0.07787086069583893}




{'epoch': 6, 'step': 600, 'loss': 0.09776908159255981}




{'epoch': 6, 'step': 700, 'loss': 0.1110314279794693}




{'epoch': 6, 'step': 800, 'loss': 0.2570797801017761}




{'epoch': 6, 'step': 900, 'loss': 0.14705049991607666}




{'epoch': 6, 'step': 1000, 'loss': 0.32149967551231384}




{'epoch': 6, 'step': 1100, 'loss': 0.3119416832923889}




{'epoch': 6, 'step': 1200, 'loss': 0.15134170651435852}




{'epoch': 6, 'step': 1300, 'loss': 0.02920793555676937}




{'epoch': 6, 'step': 1400, 'loss': 0.08037473261356354}




{'epoch': 6, 'step': 1500, 'loss': 0.07737533748149872}




{'epoch': 6, 'step': 1600, 'loss': 0.1455351859331131}




{'epoch': 6, 'step': 1700, 'loss': 0.3174206018447876}




{'epoch': 6, 'step': 1800, 'loss': 0.13904038071632385}




{'epoch': 6, 'step': 1900, 'loss': 0.14677686989307404}




{'epoch': 6, 'step': 2000, 'loss': 0.1099444180727005}




{'epoch': 6, 'step': 2100, 'loss': 0.08184224367141724}




{'epoch': 6, 'step': 2200, 'loss': 0.28059834241867065}




{'epoch': 6, 'step': 2300, 'loss': 0.12967969477176666}




{'epoch': 6, 'step': 2400, 'loss': 0.5232438445091248}




{'epoch': 6, 'step': 2500, 'loss': 0.01894252933561802}




{'epoch': 6, 'step': 2600, 'loss': 0.16345682740211487}




{'epoch': 6, 'step': 2700, 'loss': 0.1960647702217102}




{'epoch': 6, 'step': 2800, 'loss': 0.15069624781608582}




{'epoch': 6, 'step': 2900, 'loss': 0.22282221913337708}




{'epoch': 6, 'step': 3000, 'loss': 0.24546101689338684}




{'epoch': 6, 'step': 3100, 'loss': 0.19134339690208435}


Iteration: 100%|██████████| 3133/3133 [22:49<00:00,  2.29it/s]
Eval Iteration: 100%|██████████| 784/784 [12:51<00:00,  1.02it/s]


{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}


Epoch:  58%|█████▊    | 7/12 [4:10:12<2:58:43, 2144.73s/it]

{'epoch': 7, 'step': 0, 'loss': 0.07034608721733093}




{'epoch': 7, 'step': 100, 'loss': 0.06936048716306686}




{'epoch': 7, 'step': 200, 'loss': 0.3558166027069092}




{'epoch': 7, 'step': 300, 'loss': 0.205229714512825}




{'epoch': 7, 'step': 400, 'loss': 0.16962802410125732}




{'epoch': 7, 'step': 500, 'loss': 0.19602124392986298}




{'epoch': 7, 'step': 600, 'loss': 0.3901159167289734}




{'epoch': 7, 'step': 700, 'loss': 0.21695683896541595}




{'epoch': 7, 'step': 800, 'loss': 0.2397952824831009}




{'epoch': 7, 'step': 900, 'loss': 0.10553165525197983}




{'epoch': 7, 'step': 1000, 'loss': 0.1046813502907753}




{'epoch': 7, 'step': 1100, 'loss': 0.08302082121372223}




{'epoch': 7, 'step': 1200, 'loss': 0.059668418020009995}




{'epoch': 7, 'step': 1300, 'loss': 0.14146234095096588}




{'epoch': 7, 'step': 1400, 'loss': 0.1576349288225174}




{'epoch': 7, 'step': 1500, 'loss': 0.3132774829864502}




{'epoch': 7, 'step': 1600, 'loss': 0.129167839884758}




{'epoch': 7, 'step': 1700, 'loss': 0.27413052320480347}




{'epoch': 7, 'step': 1800, 'loss': 0.16408702731132507}




{'epoch': 7, 'step': 1900, 'loss': 0.17399056255817413}




{'epoch': 7, 'step': 2000, 'loss': 0.05288201570510864}




{'epoch': 7, 'step': 2100, 'loss': 0.10895847529172897}




{'epoch': 7, 'step': 2200, 'loss': 0.09831514209508896}




{'epoch': 7, 'step': 2300, 'loss': 0.01664067432284355}




{'epoch': 7, 'step': 2400, 'loss': 0.31888216733932495}




{'epoch': 7, 'step': 2500, 'loss': 0.07260478287935257}




{'epoch': 7, 'step': 2600, 'loss': 0.09539788961410522}




{'epoch': 7, 'step': 2700, 'loss': 0.06405709683895111}




{'epoch': 7, 'step': 2800, 'loss': 0.14988528192043304}




{'epoch': 7, 'step': 2900, 'loss': 0.2871689200401306}




{'epoch': 7, 'step': 3000, 'loss': 0.14934749901294708}




{'epoch': 7, 'step': 3100, 'loss': 0.09113061428070068}


Iteration: 100%|██████████| 3133/3133 [22:40<00:00,  2.30it/s]
Eval Iteration: 100%|██████████| 784/784 [12:54<00:00,  1.01it/s]


{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}


Epoch:  67%|██████▋   | 8/12 [4:45:53<2:22:54, 2143.66s/it]

{'epoch': 8, 'step': 0, 'loss': 0.09128756821155548}




{'epoch': 8, 'step': 100, 'loss': 0.1092446818947792}




{'epoch': 8, 'step': 200, 'loss': 0.04242466762661934}




{'epoch': 8, 'step': 300, 'loss': 0.1266295313835144}




{'epoch': 8, 'step': 400, 'loss': 0.12225741147994995}




{'epoch': 8, 'step': 500, 'loss': 0.20009391009807587}




{'epoch': 8, 'step': 600, 'loss': 0.051290515810251236}




{'epoch': 8, 'step': 700, 'loss': 0.39788582921028137}




{'epoch': 8, 'step': 800, 'loss': 0.1483735293149948}




{'epoch': 8, 'step': 900, 'loss': 0.29253941774368286}




{'epoch': 8, 'step': 1000, 'loss': 0.5014610290527344}




{'epoch': 8, 'step': 1100, 'loss': 0.016470571979880333}




{'epoch': 8, 'step': 1200, 'loss': 0.04947686195373535}




{'epoch': 8, 'step': 1300, 'loss': 0.13360914587974548}




{'epoch': 8, 'step': 1400, 'loss': 0.0951535627245903}




{'epoch': 8, 'step': 1500, 'loss': 0.13516369462013245}




{'epoch': 8, 'step': 1600, 'loss': 0.21236975491046906}




{'epoch': 8, 'step': 1700, 'loss': 0.3434731662273407}




{'epoch': 8, 'step': 1800, 'loss': 0.08816797286272049}




{'epoch': 8, 'step': 1900, 'loss': 0.171311154961586}




{'epoch': 8, 'step': 2000, 'loss': 0.2316046953201294}




{'epoch': 8, 'step': 2100, 'loss': 0.2805195748806}




{'epoch': 8, 'step': 2200, 'loss': 0.38593411445617676}




{'epoch': 8, 'step': 2300, 'loss': 0.534032940864563}




{'epoch': 8, 'step': 2400, 'loss': 0.026824580505490303}




{'epoch': 8, 'step': 2500, 'loss': 0.11533966660499573}




{'epoch': 8, 'step': 2600, 'loss': 0.120765320956707}




{'epoch': 8, 'step': 2700, 'loss': 0.19434499740600586}




{'epoch': 8, 'step': 2800, 'loss': 0.2065795212984085}




{'epoch': 8, 'step': 2900, 'loss': 0.05719532072544098}




{'epoch': 8, 'step': 3000, 'loss': 0.09900042414665222}




{'epoch': 8, 'step': 3100, 'loss': 0.14217552542686462}


Iteration: 100%|██████████| 3133/3133 [22:39<00:00,  2.31it/s]
Eval Iteration: 100%|██████████| 784/784 [12:53<00:00,  1.01it/s]


{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}


Epoch:  75%|███████▌  | 9/12 [5:21:32<1:47:06, 2142.21s/it]

{'epoch': 9, 'step': 0, 'loss': 0.3296232223510742}




{'epoch': 9, 'step': 100, 'loss': 0.10217148065567017}




{'epoch': 9, 'step': 200, 'loss': 0.18546509742736816}




{'epoch': 9, 'step': 300, 'loss': 0.22796395421028137}




{'epoch': 9, 'step': 400, 'loss': 0.017187345772981644}




{'epoch': 9, 'step': 500, 'loss': 0.10194041579961777}




{'epoch': 9, 'step': 600, 'loss': 0.11842207610607147}




{'epoch': 9, 'step': 700, 'loss': 0.1267170011997223}




{'epoch': 9, 'step': 800, 'loss': 0.0956677496433258}




{'epoch': 9, 'step': 900, 'loss': 0.07317019253969193}




{'epoch': 9, 'step': 1000, 'loss': 0.327668696641922}




{'epoch': 9, 'step': 1100, 'loss': 0.10578616708517075}




{'epoch': 9, 'step': 1200, 'loss': 0.11260948330163956}




{'epoch': 9, 'step': 1300, 'loss': 0.3232831060886383}




{'epoch': 9, 'step': 1400, 'loss': 0.11540959775447845}




{'epoch': 9, 'step': 1500, 'loss': 0.0996178537607193}




{'epoch': 9, 'step': 1600, 'loss': 0.4793059229850769}




{'epoch': 9, 'step': 1700, 'loss': 0.019216788932681084}




{'epoch': 9, 'step': 1800, 'loss': 0.12269196659326553}




{'epoch': 9, 'step': 1900, 'loss': 0.3356671929359436}




{'epoch': 9, 'step': 2000, 'loss': 0.11338835209608078}




{'epoch': 9, 'step': 2100, 'loss': 0.3380527198314667}




{'epoch': 9, 'step': 2200, 'loss': 0.15992727875709534}




{'epoch': 9, 'step': 2300, 'loss': 0.18534179031848907}




{'epoch': 9, 'step': 2400, 'loss': 0.03757226839661598}




{'epoch': 9, 'step': 2500, 'loss': 0.17842812836170197}




{'epoch': 9, 'step': 2600, 'loss': 0.09059816598892212}




{'epoch': 9, 'step': 2700, 'loss': 0.1288367062807083}




{'epoch': 9, 'step': 2800, 'loss': 0.32323941588401794}




{'epoch': 9, 'step': 2900, 'loss': 0.053478922694921494}




{'epoch': 9, 'step': 3000, 'loss': 0.1964823603630066}




{'epoch': 9, 'step': 3100, 'loss': 0.02756020799279213}


Iteration: 100%|██████████| 3133/3133 [22:39<00:00,  2.31it/s]
Eval Iteration: 100%|██████████| 784/784 [12:49<00:00,  1.02it/s]


{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}


Epoch:  83%|████████▎ | 10/12 [5:57:07<1:11:19, 2139.99s/it]

{'epoch': 10, 'step': 0, 'loss': 0.14962710440158844}




{'epoch': 10, 'step': 100, 'loss': 0.2836413085460663}




{'epoch': 10, 'step': 200, 'loss': 0.06143704801797867}




{'epoch': 10, 'step': 300, 'loss': 0.11750835925340652}




{'epoch': 10, 'step': 400, 'loss': 0.1875179260969162}




{'epoch': 10, 'step': 500, 'loss': 0.05627031996846199}




{'epoch': 10, 'step': 600, 'loss': 0.15802554786205292}




{'epoch': 10, 'step': 700, 'loss': 0.07149188220500946}




{'epoch': 10, 'step': 800, 'loss': 0.08709468692541122}




{'epoch': 10, 'step': 900, 'loss': 0.3212118148803711}




{'epoch': 10, 'step': 1000, 'loss': 0.09255184233188629}




{'epoch': 10, 'step': 1100, 'loss': 0.24303753674030304}




{'epoch': 10, 'step': 1200, 'loss': 0.09189732372760773}




{'epoch': 10, 'step': 1300, 'loss': 0.07502859830856323}




{'epoch': 10, 'step': 1400, 'loss': 0.054900746792554855}




{'epoch': 10, 'step': 1500, 'loss': 0.24858155846595764}




{'epoch': 10, 'step': 1600, 'loss': 0.043185800313949585}




{'epoch': 10, 'step': 1700, 'loss': 0.43352484703063965}




{'epoch': 10, 'step': 1800, 'loss': 0.05389132350683212}




{'epoch': 10, 'step': 1900, 'loss': 0.0462263748049736}




{'epoch': 10, 'step': 2000, 'loss': 0.2887672185897827}




{'epoch': 10, 'step': 2100, 'loss': 0.07673341780900955}




{'epoch': 10, 'step': 2200, 'loss': 0.2031835913658142}




{'epoch': 10, 'step': 2300, 'loss': 0.06927520781755447}




{'epoch': 10, 'step': 2400, 'loss': 0.16037926077842712}




{'epoch': 10, 'step': 2500, 'loss': 0.1878155618906021}




{'epoch': 10, 'step': 2600, 'loss': 0.1462363451719284}




{'epoch': 10, 'step': 2700, 'loss': 0.078933946788311}




{'epoch': 10, 'step': 2800, 'loss': 0.10949352383613586}




{'epoch': 10, 'step': 2900, 'loss': 0.12258151173591614}




{'epoch': 10, 'step': 3000, 'loss': 0.07569026201963425}




{'epoch': 10, 'step': 3100, 'loss': 0.26778337359428406}


Iteration: 100%|██████████| 3133/3133 [22:38<00:00,  2.31it/s]
Eval Iteration: 100%|██████████| 784/784 [12:52<00:00,  1.01it/s]


{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}


Epoch:  92%|█████████▏| 11/12 [6:32:45<35:39, 2139.24s/it]  

{'epoch': 11, 'step': 0, 'loss': 0.06531437486410141}




{'epoch': 11, 'step': 100, 'loss': 0.20427097380161285}




{'epoch': 11, 'step': 200, 'loss': 0.09906531870365143}




{'epoch': 11, 'step': 300, 'loss': 0.0653161108493805}




{'epoch': 11, 'step': 400, 'loss': 0.05994882807135582}




{'epoch': 11, 'step': 500, 'loss': 0.0795726552605629}




{'epoch': 11, 'step': 600, 'loss': 0.14565376937389374}




{'epoch': 11, 'step': 700, 'loss': 0.07344433665275574}




{'epoch': 11, 'step': 800, 'loss': 0.10476371645927429}




{'epoch': 11, 'step': 900, 'loss': 0.08474799990653992}




{'epoch': 11, 'step': 1000, 'loss': 0.06301656365394592}




{'epoch': 11, 'step': 1100, 'loss': 0.049801748245954514}




{'epoch': 11, 'step': 1200, 'loss': 0.043753642588853836}




{'epoch': 11, 'step': 1300, 'loss': 0.2638190984725952}




{'epoch': 11, 'step': 1400, 'loss': 0.23056364059448242}




{'epoch': 11, 'step': 1500, 'loss': 0.10152462124824524}




{'epoch': 11, 'step': 1600, 'loss': 0.06498575210571289}




{'epoch': 11, 'step': 1700, 'loss': 0.15279459953308105}




{'epoch': 11, 'step': 1800, 'loss': 0.17894504964351654}




{'epoch': 11, 'step': 1900, 'loss': 0.07056690007448196}




{'epoch': 11, 'step': 2000, 'loss': 0.16970187425613403}




{'epoch': 11, 'step': 2100, 'loss': 0.07987633347511292}




{'epoch': 11, 'step': 2200, 'loss': 0.10932459682226181}




{'epoch': 11, 'step': 2300, 'loss': 0.06768417358398438}




{'epoch': 11, 'step': 2400, 'loss': 0.21936504542827606}




{'epoch': 11, 'step': 2500, 'loss': 0.11308334767818451}




{'epoch': 11, 'step': 2600, 'loss': 0.14592231810092926}




{'epoch': 11, 'step': 2700, 'loss': 0.005872334819287062}




{'epoch': 11, 'step': 2800, 'loss': 0.1788068264722824}




{'epoch': 11, 'step': 2900, 'loss': 0.05936203524470329}




{'epoch': 11, 'step': 3000, 'loss': 0.22672352194786072}




{'epoch': 11, 'step': 3100, 'loss': 0.05551273375749588}


Iteration: 100%|██████████| 3133/3133 [22:38<00:00,  2.31it/s]
Eval Iteration: 100%|██████████| 784/784 [12:48<00:00,  1.02it/s]


{'rouge1': 0.37864351505747285, 'rouge2': 0.2103563331041365, 'rougeL': 0.3730174215363994, 'rougeLsum': 0.37308553319130766}


Epoch: 100%|██████████| 12/12 [7:08:18<00:00, 2141.58s/it]


In [46]:
print(f"best epoch:{best_epoch}")

best epoch:{'epoch': 12, 'rouge': 0.3730174215363994}


# Test

In [47]:
test_dataloader = DataLoader(test_data, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)

In [48]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [49]:
from transformers import BartTokenizer, BartForConditionalGeneration ,BartConfig
from tqdm.auto import tqdm



config = BartConfig.from_pretrained("/user_data/text_summarization/model_bart/new_epoch_11/config.json")
model = BartForConditionalGeneration.from_pretrained("/user_data/text_summarization/model_bart/new_epoch_11/model.safetensors", config=config)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

model = model.to(device)
def Summarize_model(model, text):

    input_encodings = tokenizer(
                                    [text],
                                    max_length=1024,
                                    return_tensors="pt",
                                    truncation=True
                                ).to(device)


    output_ids = model.generate(
        **input_encodings, 
        num_beams=1, 
        max_length=150, 
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=False,
        no_repeat_ngram_size= 5, 
        num_return_sequences=1
    )

    return output_ids

In [50]:
# 假設您已經有 test_articles 和 test_summaries 變數，它們是待測試的文章和摘要
# 還有您之前定義好的 Summarize_model 函式

model.eval()  # 設置模型為評估模式

# 存儲模型生成的摘要和參考摘要
generated_summaries = []
reference_summaries = []

with torch.no_grad():  # 確保在測試過程中不會計算梯度
    for i, (article, reference_summary) in enumerate(tqdm(zip(test_articles, test_summaries), total=len(test_articles))):
        # 生成摘要
        output_ids = Summarize_model(model, article)
        
        # 解碼模型生成的 token IDs 為文本
        generated_summary = [tokenizer.decode(output_id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for output_id in output_ids]
        
        # 儲存生成的摘要和參考摘要
        generated_summaries.extend(generated_summary)
        reference_summaries.append(reference_summary)  # 假設參考摘要是一個字符串

        # 每測完 100 筆資料顯示一個檢查點
        if (i + 1) % 500 == 0:
            print(f"Processed {i + 1} samples")

import evaluate
rouge_score = evaluate.load("rouge")

scores = rouge_score.compute(
predictions=generated_summaries, references=reference_summaries
)
print(scores)
eval_precision=scores["rougeL"]
print(f'ROUGE-L score: {eval_precision}')

 10%|█         | 500/4895 [05:21<47:40,  1.54it/s]

Processed 500 samples


 20%|██        | 1000/4895 [10:43<40:45,  1.59it/s]

Processed 1000 samples


 31%|███       | 1500/4895 [16:04<37:40,  1.50it/s]

Processed 1500 samples


 41%|████      | 2000/4895 [21:28<32:37,  1.48it/s]

Processed 2000 samples


 51%|█████     | 2500/4895 [26:50<25:55,  1.54it/s]

Processed 2500 samples


 61%|██████▏   | 3000/4895 [32:13<20:08,  1.57it/s]

Processed 3000 samples


 72%|███████▏  | 3500/4895 [37:35<15:10,  1.53it/s]

Processed 3500 samples


 82%|████████▏ | 4000/4895 [42:56<09:33,  1.56it/s]

Processed 4000 samples


 92%|█████████▏| 4500/4895 [48:18<04:15,  1.54it/s]

Processed 4500 samples


100%|██████████| 4895/4895 [52:31<00:00,  1.55it/s]


{'rouge1': 0.14017096609634533, 'rouge2': 0.0437449856022438, 'rougeL': 0.13874884411076938, 'rougeLsum': 0.13864855563719203}
ROUGE-L score: 0.13874884411076938


In [51]:
import evaluate
rouge_score = evaluate.load("rouge")

scores = rouge_score.compute(
predictions=generated_summaries, references=reference_summaries
)
print(scores)
eval_precision=scores["rougeL"]
print(f'ROUGE-L score: {eval_precision}')

{'rouge1': 0.14017096609634533, 'rouge2': 0.0437449856022438, 'rougeL': 0.13874884411076938, 'rougeLsum': 0.13864855563719203}
ROUGE-L score: 0.13874884411076938


# Inference

In [53]:
from transformers import BartTokenizer, BartForConditionalGeneration ,BartConfig


config = BartConfig.from_pretrained("/user_data/text_summarization/model_bart/new_epoch_11/config.json")
model = BartForConditionalGeneration.from_pretrained("/user_data/text_summarization/model_bart/new_epoch_11/model.safetensors", config=config)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")


def Summarize_model(model, text):

    input_encodings = tokenizer(
                                    [text],
                                    max_length=1024,
                                    return_tensors="pt",
                                    truncation=True
                                )


    output_ids = model.generate(
        **input_encodings, 
        num_beams=5, 
        max_length=150, 
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=False,
        no_repeat_ngram_size= 5, 
        num_return_sequences=1
    )

    return output_ids

context="Hugging Face是一家美國公司，專門開發用於構建機器學習應用的工具。\
        該公司的代表產品是其為自然語言處理應用構建的transformers庫，\
        以及允許使用者共享機器學習模型和資料集的平台。"

summarize = Summarize_model(model, context)

print("Output:\n" + 100 * '-')
# for i, beam_output in enumerate(summarize):
#   print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

print(tokenizer.decode(summarize[0], skip_special_tokens=True))


Output:
----------------------------------------------------------------------------------------------------
KST）， 台灣樂團。成立於2013年，團名來自希臘哲學家赫拉克利特（Heraclitus）的名言 「Change is the only constant in life.」「唯一不變的就是永遠在變化」，也是佛說「無常」。
