In [None]:
import os
import random
import json
from transformers import AutoTokenizer
from datasets import load_from_disk, load_dataset, DatasetDict
from Libraries import Summarizer_Trainer as Trainer

In [None]:
# ============================================================
# 1. GLOBAL CONFIGURATION
# ============================================================
# Root folders, service name and artifact names used by the rest of the notebook.
DATABASE, DATASET, MODELS = "Database", "Datasets/vietnews", "Models"
Service, DataFile, ModelName = "Summarizer", "TrainData.jsonl", "bartpho-summarizer"
Checkpoint = "vinai/bartpho-syllable"

# Derived locations: <root>/<service>/<artifact>, always joined with forward slashes.
DataPath = "/".join((DATABASE, Service, DataFile))
ModelPath = "/".join((MODELS, Service, ModelName))

# Training hyperparameters — tuned for an RTX 4050 with 8GB of VRAM.
MAX_INPUT_LENGTH, MAX_TARGET_LENGTH = 512, 128   # input shortened to cut memory use & speed things up
NUM_TRAIN_EPOCHS = 3                             # roughly 8–10h with this configuration
LEARNING_RATE, WEIGHT_DECAY = 3e-5, 0.01
BATCH_SIZE = 2                                   # just fits an 8GB GPU with fp16 + grad_accum=4

In [None]:
# One-time download: when nothing exists at DATASET yet, fetch the
# "nam194/vietnews" dataset and persist it locally.
if not os.path.exists(DATASET):
    ds = load_dataset("nam194/vietnews")
    # Save to the very path that was just checked (and that the loading cell
    # reads). A separate hard-coded literal would silently drift if DATASET
    # were ever changed, causing a re-download on every run.
    ds.save_to_disk(DATASET)

In [None]:
# ============================================================
# 2. LOAD DATASET
# ============================================================
# Prefer the saved dataset directory at DATASET; otherwise fall back to the
# JSON file at DataPath; fail loudly when neither is present.
print("üîπ ƒêang t·∫£i dataset ...")
if os.path.isdir(DATASET):
    dataset = load_from_disk(DATASET)
elif os.path.exists(DataPath):
    # load_dataset comes from the notebook's import cell; no local re-import needed.
    dataset = load_dataset("json", data_files=DataPath)
else:
    raise FileNotFoundError(f"‚ùå Kh√¥ng t√¨m th·∫•y dataset t·∫°i {DATASET} ho·∫∑c {DataPath}")

print(f"‚úÖ Dataset loaded th√†nh c√¥ng: {dataset}")
if isinstance(dataset, DatasetDict):
    # Only a DatasetDict exposes named splits through .keys().
    print(f"üì¶ C√°c split: {list(dataset.keys())}")

In [None]:
# ============================================================
# 3. TOKENIZER & TRAINER INITIALISATION
# ============================================================
print(f"üîπ ƒêang t·∫£i tokenizer t·ª´ checkpoint: {Checkpoint}")
tokenizer = AutoTokenizer.from_pretrained(Checkpoint)

# The keyword arguments below are forwarded unchanged, and in the same order,
# to SummarizationTrainer; they are only grouped here for readability.

# ----- Data -----
data_options = {
    "Max_Input_Length": MAX_INPUT_LENGTH,
    "Max_Target_Length": MAX_TARGET_LENGTH,
    "prefix": "",                       # BARTPho does not need a task prefix
    "input_column": "article",
    "target_column": "abstract",        # VietNews stores the summary in the "abstract" column
}

# ----- Training -----
training_options = {
    "Learning_Rate": LEARNING_RATE,
    "Weight_Decay": WEIGHT_DECAY,
    "Batch_Size": BATCH_SIZE,
    "Num_Train_Epochs": NUM_TRAIN_EPOCHS,
    "gradient_accumulation_steps": 4,   # → effective batch = 2×4 = 8
    "warmup_ratio": 0.1,                # ramp the LR up over the first 10%
    "lr_scheduler_type": "linear",
    "seed": 42,
}

# ----- Generation, precision, early stopping & logging -----
runtime_options = {
    "num_beams": 4,                     # trade-off between quality & speed
    "fp16": True,                       # less VRAM, faster
    "early_stopping_patience": 1,       # stop early when there is no improvement
    "logging_steps": 500,
    "report_to": "none",                # disable external logging
}

summarizer_trainer = Trainer.SummarizationTrainer(
    **data_options,
    **training_options,
    **runtime_options,
)

In [None]:
# ============================================================
# 4. RUN TRAINING
# ============================================================
print("\nüöÄ B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán ...\n")

# NOTE(review): the DataPath argument receives the DATASET directory, not the
# DataPath JSONL file from the configuration cell — confirm this is intended.
run_arguments = {
    "Checkpoint": Checkpoint,
    "ModelPath": ModelPath,
    "DataPath": DATASET,
    "tokenizer": tokenizer,
}
trainer = summarizer_trainer.run(**run_arguments)

print("\n‚úÖ Hu·∫•n luy·ªán ho√†n t·∫•t.")

In [None]:
# ============================================================
# 5. QUICK TEST
# ============================================================
# Draw one random example, preferring the "test" split, then "validation",
# and falling back to "train" when neither is present.
split_name = next(
    (name for name in ("test", "validation") if name in dataset),
    "train",
)
sample = random.choice(dataset[split_name])

article_text = sample["article"]
print("\nüì∞ B√†i b√°o g·ªëc:\n")
# Show only the first 700 characters of the article as a preview.
preview = article_text[:700]
print(preview, "...")

# The full article (not the preview) is what gets summarised.
summary = summarizer_trainer.generate(article_text, max_new_tokens=160)
print("\nüßæ T√≥m t·∫Øt m√¥ h√¨nh sinh ra:\n")
print(summary)

In [None]:
# ============================================================
# 6. EVALUATE & SAVE RESULTS
# ============================================================
print("\nüìä ƒêang ƒë√°nh gi√° m√¥ h√¨nh ...")
eval_results = trainer.evaluate()

# Create the output folder before writing: if ModelPath is missing, open()
# would raise and the metrics from a long training run would be lost.
os.makedirs(ModelPath, exist_ok=True)
results_path = os.path.join(ModelPath, "eval_results.json")

# ensure_ascii=False keeps any non-ASCII text in the results human-readable.
with open(results_path, "w", encoding="utf-8") as f:
    json.dump(eval_results, f, ensure_ascii=False, indent=2)

print(f"üíæ K·∫øt qu·∫£ ROUGE ƒë√£ l∆∞u t·∫°i: {results_path}")
print("\nüéØ Pipeline ho√†n t·∫•t.")