<a href="https://colab.research.google.com/github/Mrbold8/mongolian_text_summarization/blob/main/text_summarization_cp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



```
# Setup
```

In [None]:
# Setup cell: check the GPU, install the libraries, and print their versions.
!nvidia-smi # the notebook expects a T4 GPU runtime

# --- Install the main libraries (quiet, upgrading to the latest releases)
!pip -q install -U "transformers[torch]" datasets accelerate peft sentencepiece evaluate rouge_score bitsandbytes wandb

# --- Smoke-test the imports and print the installed versions
import torch, transformers, datasets, peft, evaluate
print("Torch:", torch.__version__)
print("Transformers:", transformers.__version__)
print("Datasets:", datasets.__version__)
print("PEFT:", peft.__version__)

In [None]:
# --- Log in to the Hugging Face Hub (interactive token prompt)
import huggingface_hub

huggingface_hub.login()

# preprocess хийхийн өмнө schema -г бататгах
Load & inspect dataset

In [None]:
from datasets import load_dataset

# 1) Load every split of the translated CNN/DailyMail dataset
ds = load_dataset("amaraaa/mn_translated_cnn")
print(ds)

# 2) Inspect the first two train rows to confirm the schema (id / article / highlights)
for i in range(2):
    row = ds["train"][i]
    # Leading "\n" puts a blank line between the per-row reports
    print(f"\nRow {i} keys:", row.keys())
    print("id:", row["id"])
    print("article snippet:", (row["article"] or "")[:200].replace("\n", " "))
    print("highlights (type/len):", type(row["highlights"]).__name__, len(row["highlights"]))
    print("first highlight:", row["highlights"][0] if row["highlights"] else "-")
# 3) Sanity checks - Алдаанаас сэргийлж шалгалт хийх (data sanity checks)
def summarize_split(name, dataset=None):
    """Count rows and empty fields in one split (data sanity check).

    Args:
        name: split name, e.g. "train".
        dataset: mapping of split name -> iterable of row dicts with "article"
            and "highlights" keys; defaults to the global ``ds``.

    Returns:
        dict with "rows", "empty_articles" and "empty_highlights" counts.
    """
    d = (ds if dataset is None else dataset)[name]

    def _highlights_blank(value):
        # Missing/empty value, a blank string, or a list with no non-blank string
        if not value:
            return True
        if isinstance(value, str):
            return not value.strip()
        return not any(isinstance(h, str) and h.strip() for h in value)

    # Articles that are missing or whitespace-only
    empty_articles = sum(1 for r in d if not r["article"] or not r["article"].strip())
    # Highlights with no usable text
    empty_highlights = sum(1 for r in d if _highlights_blank(r["highlights"]))
    return {
        "rows": len(d),
        "empty_articles": empty_articles,
        "empty_highlights": empty_highlights,
    }

# Report row and empty-field counts for each split
print("\nSanity:")
for split in ds:
    stats = summarize_split(split)
    print(split, stats)


# Build preprocessing - (model training -д зориулж өгөгдлийг бэлдэх)
* mT5 нь single target string хүлээж авдаг
* highlights -д мөр бүрт string list байгаа

1) Иймээс list -ийн агуулгыг newline -тай нэг text болгоно
2) inputs/targets -ийг tokenize хийнэ (consistent max lengths)

In [None]:
from datasets import DatasetDict
from transformers import AutoTokenizer

# --- Configuration
model_name = "google/mt5-small"
task_prefix = "summarize: " # T5-style models receive the task to perform as a text prefix
max_source_length = 768  # input tokens kept per (prefixed) article after truncation
max_target_length = 128  # label tokens kept per joined summary after truncation

# 1) Join each record's list of highlight bullets into a single target string
def join_highlights(example):
    """Join an example's highlight bullets into one newline-separated string.

    mT5 is trained here on a single target string, but "highlights" holds a
    list of bullets per row. Non-string and blank bullets are dropped and each
    kept bullet is stripped. A plain string value is treated as one bullet;
    iterating it directly would split it into characters.

    Args:
        example: one dataset record (dict). It is updated in place.

    Returns:
        The same record with a new "highlights_str" field.
    """
    bullets = example.get("highlights") or []
    if isinstance(bullets, str):
        bullets = [bullets]
    cleaned = [b.strip() for b in bullets if isinstance(b, str) and b.strip()]
    example["highlights_str"] = "\n".join(cleaned)
    return example

# Run the join on every record of every split (adds the "highlights_str" column)
ds = ds.map(join_highlights, desc="Joining highlight bullets")

# 2) Tokenizer that matches the model checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 3) Tokenize inputs (article) and targets (joined highlights)
def preprocess_batch(batch):
    """Tokenize a batch of articles as model inputs and joined highlights as labels.

    Each article gets ``task_prefix`` prepended (a missing article becomes "").
    Inputs are truncated to ``max_source_length`` tokens and targets to
    ``max_target_length``. Padding is left to the data collator.

    Returns:
        The tokenizer output for the inputs plus a "labels" entry holding the
        target token IDs.
    """
    sources = [task_prefix + (text or "") for text in batch["article"]]
    encoded = tokenizer(
        sources,
        max_length=max_source_length,
        truncation=True,
    )
    target_ids = tokenizer(
        text_target=batch["highlights_str"],
        max_length=max_target_length,
        truncation=True,
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded

# Drop all raw text columns (including "id") so only token fields remain.
# Alternative that keeps "id": list(set(ds["train"].column_names) - {"id"})
cols_to_remove = ["id", "article", "highlights", "highlights_str"]

tokenized_ds = ds.map(
    preprocess_batch,
    batched=True,
    remove_columns=cols_to_remove,
    desc="Tokenizing",
)
print(tokenized_ds)

# --- Inspect the first tokenized pair
sample = tokenized_ds["train"][0]
print("Example ids:", sample["input_ids"][:20], "...")
print("Example labels:", sample["labels"][:20], "...")
print(tokenizer.decode(sample["input_ids"][:128], skip_special_tokens=True))

# Defensive: map the ignore index (-100) to the pad id before decoding labels
safe_labels = [tokenizer.pad_token_id if tid == -100 else tid for tid in sample["labels"]]
print(tokenizer.decode(safe_labels[:64], skip_special_tokens=True))



# Load mT5-small & attach LoRA adapters
LoRA гэх мэт PEFT method -оор model сургахдаа бэлдэх (peft configuration)

LoRA -> суурь model -ийг бүхэлд нь биш зөвхөн тодорхой цөөн weight -үүдийг сургана.
*   Resource бага шаардана
*   Илүү efficient байх боломжтой
*   Model -ийн цөөн тооны parameter -ийг үр дүнтэй fine-tune хийнэ.
*   Computational болон storage cost багасгана.


In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, TaskType

#1) Load the pretrained base model
base_model_name = "google/mt5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

#2) The KV cache is not needed for training; gradient checkpointing is enabled later
model.config.use_cache = False

#3) LoRA configuration
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,              # rank of the low-rank update (adapter capacity)
    lora_alpha=32,     # scaling factor for the LoRA update
    lora_dropout=0.1,  # dropout on the adapter path (regularization)
    target_modules=["q", "v"],  # T5 attention query/value projections
)

#4) Wrap the base model so that only the LoRA adapter weights are trainable
model = get_peft_model(model, peft_config)

# Needed for gradient checkpointing when the (frozen) embeddings produce no grads
model.enable_input_require_grads()

# Move to the GPU when available; a hard-coded "cuda" crashes on CPU-only runtimes
model.to("cuda" if torch.cuda.is_available() else "cpu")

#5) Sanity check: how many parameters will actually be trained
model.print_trainable_parameters()

In [None]:
# 2) Generation blocklist of the T5/mT5 sentinel tokens <extra_id_0>..<extra_id_99>.
# IDs are looked up by token text instead of assuming they are the last 100 IDs
# below `tokenizer.vocab_size`, which depends on how the tokenizer counts added tokens.
extra_ids = sorted(
    tid
    for tid in tokenizer.convert_tokens_to_ids([f"<extra_id_{i}>" for i in range(100)])
    if tid is not None and tid != tokenizer.unk_token_id
)
bad_words_ids = [[i] for i in extra_ids]
if extra_ids:
    print(f"Blocking {len(bad_words_ids)} sentinel IDs: {extra_ids[0]}..{extra_ids[-1]}")
else:
    print("No sentinel tokens found in the tokenizer; nothing to block.")

# Trainer тохируулга (metrics, collator, hyperparams)
Rouge metrics, batching/padding, hyperparameters тохируулах


In [None]:
import numpy as np
import evaluate
from transformers import(
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
)

#1) Metric: ROUGE, loaded through the `evaluate` library
rouge = evaluate.load("rouge")

def postprocess_text(preds, labels):
    """Clean decoded texts for ROUGE: strip them and put one sentence per line.

    ``rougeLsum`` splits summaries on newlines. The decoded text may not keep
    the "\\n" separators of the joined highlights (presumably the tokenizer
    normalizes them away; confirm). Without this step ``rougeLsum`` would then
    degenerate to ``rougeL``. Sentences are split after ".", "!" or "?"
    followed by whitespace.

    Returns:
        ``(preds, labels)`` as new lists of cleaned strings.
    """
    import re

    def _one_sentence_per_line(text):
        sentences = re.split(r"(?<=[.!?])\s+", text.strip())
        return "\n".join(s for s in sentences if s)

    return (
        [_one_sentence_per_line(p) for p in preds],
        [_one_sentence_per_line(lab) for lab in labels],
    )

def rouge_word_tokenize(text):
    """Lowercase ``text`` and split it into Unicode word tokens for ROUGE.

    rouge_score's default tokenizer keeps only [a-z0-9] characters, which
    erases Cyrillic (Mongolian) words entirely. This keeps any word character.
    """
    import re
    return re.findall(r"\w+", text.lower())


def compute_metrics(eval_pred):
    """Decode generated and reference token IDs and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the Trainer. The
            predictions may be a tuple whose first element is the token-ID matrix.

    Returns:
        dict mapping each ROUGE key returned by ``rouge.compute`` to its score
        scaled by 100 and rounded to 4 decimals.
    """
    preds, labels = eval_pred
    # Some trainers return a tuple; the generated sequences come first
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id

    def _unmask(ids):
        # The ignore index (-100) cannot be decoded; replace it with the pad id
        if isinstance(ids, np.ndarray):
            return np.where(ids != -100, ids, pad_id)
        return [[pad_id if tok == -100 else tok for tok in seq] for seq in ids]

    decoded_preds = tokenizer.batch_decode(_unmask(preds), skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(_unmask(labels), skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # A custom tokenizer is required for non-Latin text. rouge_score ignores
    # `use_stemmer` when a tokenizer is given, and its Porter stemmer is
    # English-only anyway.
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        tokenizer=rouge_word_tokenize,
    )
    return {k: round(v * 100, 4) for k, v in result.items()}

#2) Dynamic padding per batch; labels are padded with -100 so the loss ignores them
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8 if torch.cuda.is_available() else None,  # tensor-core friendly shapes on GPU
)

#3) Memory saver: recompute activations during backward to reduce VRAM
model.gradient_checkpointing_enable()

#4) Training arguments (tuned for a small dataset + LoRA)
training_args = Seq2SeqTrainingArguments(
    output_dir="mt5_mncnn_lora",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch of 16 per device
    learning_rate=5e-4,
    num_train_epochs=8,
    eval_strategy="epoch",
    save_strategy="epoch",          # must match eval_strategy for load_best_model_at_end
    predict_with_generate=True,     # score ROUGE on generated text
    generation_max_length=128,
    generation_num_beams=4,
    warmup_ratio=0.03,
    weight_decay=0.01,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="eval_rougeLsum",
    greater_is_better=True,
    # Full precision on purpose: fp16 AMP is commonly reported to give NaN losses with (m)T5
    fp16=False,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Early stopping after 2 evaluations without eval_rougeLsum improvement
# (guards against overfitting a small training set)
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=2))

print("Ready to train.")
print("Train rows:", len(tokenized_ds["train"]), "| Val rows:", len(tokenized_ds["validation"]))

# Sanity check before full training

In [None]:
# Evaluate a small validation slice to smoke-test the metric and decoding path.
# Clamp the size: select() raises if the split has fewer rows than requested.
n_check = min(20, len(tokenized_ds["validation"]))
small_eval = tokenized_ds["validation"].select(range(n_check))
eval_result = trainer.evaluate(eval_dataset=small_eval)

print(eval_result)


In [None]:
# Same smoke test on a small slice of the TRAIN split (metrics still carry the "eval_" prefix).
# Clamp the size: select() raises if the split has fewer rows than requested.
n_check = min(20, len(tokenized_ds["train"]))
small_eval = tokenized_ds["train"].select(range(n_check))
eval_result = trainer.evaluate(eval_dataset=small_eval)

print(eval_result)


# Training эхлүүлэх - (LoRA on mT5-small)

In [None]:
# Evaluate -> Train -> Evaluate
print("Eval (before training):")
metrics_before = trainer.evaluate()
print(metrics_before)

# Training (early stopping may end it before num_train_epochs)
train_result = trainer.train()

# load_best_model_at_end=True means the trainer now holds the best checkpoint; save it
trainer.save_model("mt5_mncnn_lora/best")  # for a PEFT model: LoRA adapter weights + config
tokenizer.save_pretrained("mt5_mncnn_lora/best")

# Short summary of the training run
print(train_result)

# Evaluate again after training
print("Eval (after training):")
metrics_after = trainer.evaluate()
print(metrics_after)

# Lock the best checkpoint, evaluate on test & inspect generations (with sentinel-blocking)

*   Best checkpoint -> хадгалах
*   Test dataset дээр evaluate хийх
*   Generations -ийг шалгах (block **sentinel tokens**)


In [None]:
# 1) Reload the best checkpoint for evaluation on the test split
import re

import torch, numpy as np
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_ckpt = getattr(trainer.state, "best_model_checkpoint", None)
print("Best checkpoint:", best_ckpt)

# Reload a clean base model + best adapter; without a recorded best checkpoint use trainer.model as-is
if best_ckpt:
    base_eval = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
    model_eval = PeftModel.from_pretrained(base_eval, best_ckpt).to(device)
else:
    model_eval = trainer.model.to(device)

model_eval.eval()

# 2) Sentinel blocklist (<extra_id_0>..<extra_id_99>), looked up by token text
extra_ids = sorted(
    tid
    for tid in tokenizer.convert_tokens_to_ids([f"<extra_id_{i}>" for i in range(100)])
    if tid is not None and tid != tokenizer.unk_token_id
)
bad_words_ids = [[i] for i in extra_ids]
if extra_ids:
    print(f"Blocking {len(bad_words_ids)} sentinel IDs: {extra_ids[0]}..{extra_ids[-1]}")

# 3) Test-split DataLoader; the collator pads each batch and pads labels with -100
test_loader = DataLoader(
    tokenized_ds["test"],
    batch_size=4,
    shuffle=False,
    collate_fn=data_collator,
)

# 4) Generate with decoding constraints and compute ROUGE on the test set
rouge = evaluate.load("rouge")
all_preds, all_refs = [], []

for batch in test_loader:
    labels = batch["labels"].numpy()
    refs = tokenizer.batch_decode(
        np.where(labels != -100, labels, tokenizer.pad_token_id),
        skip_special_tokens=True,
    )
    with torch.no_grad():
        gen = model_eval.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            num_beams=4,
            max_new_tokens=128,
            no_repeat_ngram_size=3,  # reduce repetition loops
            repetition_penalty=1.15,
            bad_words_ids=bad_words_ids or None,  # block sentinel tokens
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    preds = tokenizer.batch_decode(gen, skip_special_tokens=True)

    all_refs.extend([r.strip() for r in refs])
    all_preds.extend([p.strip() for p in preds])

# rouge_score's default tokenizer drops every character outside [a-z0-9], which
# erases Cyrillic text; tokenize on Unicode word characters instead.
test_metrics = rouge.compute(
    predictions=all_preds,
    references=all_refs,
    use_stemmer=False,
    tokenizer=lambda s: re.findall(r"\w+", s.lower()),
)
test_metrics = {k: round(v * 100, 2) for k, v in test_metrics.items()}
print("TEST ROUGE:", test_metrics)

# 5) Show a few qualitative examples
def show_pairs(n=5, preds=None, refs=None):
    """Print up to ``n`` prediction/reference pairs, each truncated to 400 characters.

    Args:
        n: maximum number of pairs to print. Fewer are printed if fewer exist,
            instead of raising IndexError.
        preds: decoded predictions; defaults to the global ``all_preds``.
        refs: decoded references; defaults to the global ``all_refs``.
    """
    preds = all_preds if preds is None else preds
    refs = all_refs if refs is None else refs
    for i, (pred, ref) in enumerate(zip(preds[:n], refs[:n]), start=1):
        print(f"\n--- Example{i} ---")
        print("Pred:", pred[:400])
        print("Ref:", ref[:400])

# Print the first five test-set prediction/reference pairs
show_pairs(5)




In [None]:
# Save the trained model to Google Drive
from google.colab import drive

# Destination folder on Drive; Drive must be mounted before writing there
SAVE_DIR = "/content/drive/MyDrive/mt5_mncnn_lora_best_v1"
drive.mount('/content/drive')

# With a PEFT model, save_model writes the LoRA adapter weights and adapter config
trainer.save_model(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)     # tokenizer files, so the folder is self-contained
model.config.save_pretrained(SAVE_DIR)  # model config.json
print("Saved to:", SAVE_DIR)


In [None]:
#Сургасан model -оо ачааллах.
!pip -q install transformers peft accelerate sentencepiece

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import torch, numpy as np

ADAPTER_DIR = "/content/drive/MyDrive/mt5_mncnn_lora_best_v1"
BASE = "google/mt5-small"

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)
base = AutoModelForSeq2SeqLM.from_pretrained(BASE, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)

# Safety: pad/decoder ids
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "<pad>"
pad_id = tokenizer.pad_token_id
model.config.pad_token_id = pad_id
if getattr(model.config, "decoder_start_token_id", None) is None:
    model.config.decoder_start_token_id = pad_id

model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()


# Input оруулж model -ийн үр дүн хураангуйг шалгах

In [None]:
# === Single-input inference helper ===
# Sample Mongolian news text (about fuel supply) to summarize
text = "Сүүлийн өдрүүдэд нийслэл болон аймгуудын шатахуун түгээх станцууд дээр автомашиндаа бензин авах гэсэн иргэдийн урт дараалал үүсч, зарим газарт шатахуун дууссанаа мэдэгдэх болсон. Иргэд машинынхаа банкийг дүүртэл шатахуун авч байгаа нь дотоодын хэрэглээний зохицуулалтад хүндрэл үүсгэж, богино хугацаанд нийлүүлэлтийн ачааллыг нэмэгдүүлж байна гэж холбогдох албаны хүмүүс тайлбарласан. Аж үйлдвэр, эрдэс баялгийн сайд Г.Дамдинням энэ талаар өөрийн “Х” хуудастаа “130 вагон АИ-92 шатахуун ОХУ-аас ачигдан манай улс руу тээвэрлэгдэж байна. Үүнээс өнөөдөр 62 вагон нь хилээр нэвтэрч ирэх хуваарьтай бөгөөд энэ нь нийтдээ 4000 гаруй тонн шатахуун гэсэн үг. Улаанбаатар хотын ердийн өдрийн хэрэглээ ойролцоогоор 1000 тонн байдаг тул ирж буй нийлүүлэлт дөрөв дахин илүү хэмжээгээр хангалт үзүүлэх боломжтой” гэжээ. Мөн тэрбээр,“Орой гэхэд нөхцөл байдал тогтворжиж, шатахуун түгээх станцуудын үйл ажиллагаа хэвийн горимд шилжинэ гэж үзэж байна” хэмээн мэдээлсэн байна."

# 2) Blocklist of the T5/mT5 sentinel tokens <extra_id_0>..<extra_id_99>,
#    looked up by token text rather than assuming they are the last 100 IDs below vocab_size
bad_words_ids = [
    [tid]
    for tid in tokenizer.convert_tokens_to_ids([f"<extra_id_{i}>" for i in range(100)])
    if tid is not None and tid != tokenizer.unk_token_id
]

# Optional (disabled): also block the stray token sequence "langsung" if it tokenizes cleanly
# bad_lang = tokenizer.encode("langsung", add_special_tokens=False)
# if isinstance(bad_lang, list) and len(bad_lang) > 0:
#     bad_words_ids.append(bad_lang)

def mn_summarize(
    text: str,
    num_beams: int = 4,
    max_new_tokens: int = 128,
    no_repeat_ngram_size: int = 4,
    repetition_penalty: float = 1.2,
    length_penalty: float = 1.0,
    max_input_length: int = 768,
) -> str:
    """Summarize one text with the fine-tuned mT5 + LoRA model.

    The input is prefixed with "summarize: " (as in training) and truncated to
    ``max_input_length`` tokens (768 matches the training source length).
    Token sequences in the global ``bad_words_ids`` are blocked during beam search.

    Uses the globals ``tokenizer``, ``model_eval`` and ``bad_words_ids``.
    Inputs are moved to the device of the model's parameters, so no separate
    ``device`` global is needed.

    Returns:
        The decoded summary without special tokens.
    """
    prompt = "summarize: " + (text or "")
    model_device = next(model_eval.parameters()).device
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    ).to(model_device)

    with torch.no_grad():
        gen = model_eval.generate(
            **inputs,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            no_repeat_ngram_size=no_repeat_ngram_size,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            bad_words_ids=bad_words_ids or None,  # generate() rejects an empty list
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.decode(gen[0], skip_special_tokens=True)

# 3) Try the helper on the first test-split article.
#    To summarize the sample `text` above instead: print(mn_summarize(text))
sample_text = ds["test"][0]["article"]
preview = " ".join(sample_text[:300].split("\n"))
print("INPUT (snippet):", preview)
print("\nSUMMARY:", mn_summarize(sample_text))


In [None]:
# print(ds)
# print(ds["train"])
# print(ds["train"][1]["highlights"])

# def square(number):
#   return number * number

# number = [1, 2, 3, 4 , 5]
# square_nums = map(square, number)
# print(list(square_nums))