In [None]:
# %pip install evaluate

In [4]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import evaluate
# Load the accuracy metric from the `evaluate` library.
metric = evaluate.load("accuracy")


Downloading builder script: 0.00B [00:00, ?B/s]

In [None]:
# 1. Load data
# `data` is the sentence-level table (grouped by `post_id` and carrying a
# numeric `label` column further down); `summary` holds the reference
# summaries that the ROUGE evaluation reads.
data = pd.read_csv("/kaggle/input/vsoslcsum/contents.csv")
summary = pd.read_csv("/kaggle/input/vsoslcsum/summaries.csv")

# Binary view of the label: 1 when label >= 3, otherwise 0.
data["label_binary"] = (data["label"] >= 3).astype(int)

In [None]:
# 2. Generate pairwise training examples
# For every post, compare each ordered pair of its sentences. A pair is kept
# only when the two labels differ; the target is 1 when the first sentence has
# the higher label and 0 when it has the lower one.
# NOTE(review): sentence text is read from the "text" column here, while the
# inference cell reads "content" from the same frame -- confirm which column
# contents.csv actually provides.
pairs = []
for pid, group in data.groupby("post_id"):
    records = group.to_dict("records")
    for first in records:
        for second in records:
            if first["label"] > second["label"]:
                target = 1
            elif first["label"] < second["label"]:
                target = 0
            else:
                # Equal (or incomparable) labels carry no ranking signal.
                continue
            pairs.append({
                "post_id": pid,
                "sent_u": first["text"],
                "sent_v": second["text"],
                "label": target,
            })

pairs_df = pd.DataFrame(pairs)

In [None]:
# 3. Dataset class
class PairwiseDataset(Dataset):
    """Torch dataset that tokenizes (sent_u, sent_v) pairs for classification.

    Args:
        df: DataFrame with the columns ``sent_u``, ``sent_v`` and ``label``.
        tokenizer: Callable applied as ``tokenizer(sent_u, sent_v, ...)`` that
            returns a mapping with ``input_ids`` and ``attention_mask``
            tensors carrying a leading batch dimension of size 1.
        max_len: Length every encoded pair is padded/truncated to.
    """

    def __init__(self, df, tokenizer, max_len=128):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded pair at position ``idx`` as a dict of tensors."""
        # Positional lookup, so the frame's index labels do not matter.
        row = self.df.iloc[idx]
        encoded = self.tokenizer(
            row["sent_u"],
            row["sent_v"],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            # Drop only the leading batch dimension; a bare squeeze() would
            # also collapse the sequence dimension when it has size 1.
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(row["label"], dtype=torch.long)
        }

In [None]:
# 4. Load PhoBERT
# Tokenizer and model are loaded from the same checkpoint; the model gets a
# two-way classification head (num_labels=2).
tokenizer = AutoTokenizer.from_pretrained(
    "vinai/phobert-base",
    use_fast=False,
)
model = AutoModelForSequenceClassification.from_pretrained(
    "vinai/phobert-base",
    num_labels=2,
)

# Wrap the generated sentence pairs so they can be fed to the Trainer.
dataset = PairwiseDataset(df=pairs_df, tokenizer=tokenizer)

In [None]:
# 5. Training setup
# Evaluation and checkpointing both run once per epoch. The former
# `eval_steps=500` was dropped: it only takes effect with a step-based
# evaluation strategy, so it was dead configuration that contradicted the
# per-epoch schedule.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./phobert-ranking",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    save_total_limit=2  # keep only the two most recent checkpoints
)

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy from the (logits, labels) pair the Trainer supplies."""
    logits, labels = eval_pred
    # Predicted class = index of the largest logit.
    predictions = logits.argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)


# Hold out 10% of the pairs for evaluation. The split is done on the
# DataFrame because a torch Dataset has no `.sample`, and the held-out rows
# are removed from the training set so the reported accuracy is not measured
# on examples the model was trained on.
eval_pairs_df = pairs_df.sample(frac=0.1, random_state=42)
train_pairs_df = pairs_df.drop(eval_pairs_df.index)

train_dataset = PairwiseDataset(train_pairs_df, tokenizer)
eval_dataset = PairwiseDataset(eval_pairs_df, tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()

In [None]:
# 6. Generate summary (inference)
def predict_summary(post_id, top_k=6, batch_size=64, max_len=128):
    """Return the ``top_k`` highest-ranked sentences of a post.

    The model is trained on ordered sentence pairs (class 1 = the first
    sentence outranks the second), so each sentence is scored by summing the
    class-1 probability over all pairs in which it is the first element,
    rather than by feeding single sentences the model never saw in training.

    Args:
        post_id: Value of ``post_id`` selecting the rows of ``data``.
        top_k: Maximum number of sentences to return.
        batch_size: Number of sentence pairs per forward pass.
        max_len: Length each encoded pair is padded/truncated to.

    Returns:
        List of sentence strings, best first; empty if the post has no rows.
    """
    group = data[data["post_id"] == post_id]
    # NOTE(review): sentences are read from "content" here, while the pair
    # generation cell reads "text" from the same frame -- confirm which
    # column contents.csv actually provides.
    sentences = group["content"].tolist()
    n_sentences = len(sentences)
    if n_sentences == 0:
        return []
    if n_sentences == 1:
        # Nothing to compare against; the only sentence is the ranking.
        return sentences[:top_k]

    # After training the model may live on a GPU; inputs must follow it.
    device = next(model.parameters()).device
    # Disable dropout so the scores are deterministic.
    model.eval()

    index_pairs = [
        (i, j)
        for i in range(n_sentences)
        for j in range(n_sentences)
        if i != j
    ]
    scores = [0.0] * n_sentences

    for start in range(0, len(index_pairs), batch_size):
        chunk = index_pairs[start:start + batch_size]
        encoded = tokenizer(
            [sentences[i] for i, _ in chunk],
            [sentences[j] for _, j in chunk],
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors="pt"
        )
        with torch.no_grad():
            # Pass the same two inputs the training dataset provides.
            output = model(
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            )
            win_probs = torch.softmax(output.logits, dim=1)[:, 1].tolist()
        for (i, _), prob in zip(chunk, win_probs):
            scores[i] += prob

    ranked = sorted(range(n_sentences), key=lambda i: scores[i], reverse=True)
    return [sentences[i] for i in ranked[:top_k]]

In [None]:
# 7. ROUGE evaluation
# `load_metric` was never imported in this notebook; the already-imported
# `evaluate` library provides the same metric through `evaluate.load`.
rouge = evaluate.load("rouge")

all_preds, all_refs = [], []
for pid, group in summary.groupby("post_id"):
    # Predict as many sentences as the reference summary for this post has.
    pred_summary = " ".join(predict_summary(pid, top_k=len(group)))
    gold_summary = " ".join(group["content"].tolist())

    all_preds.append(pred_summary)
    all_refs.append(gold_summary)

results = rouge.compute(predictions=all_preds, references=all_refs)
print(results)