In [None]:
!pip install rouge scikit-learn

In [None]:
!pip install huggingface_hub==0.33.1 transformers==4.52.4

In [None]:
import json
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
import torch
from rouge import Rouge
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from tqdm import tqdm
import re

# Load the PhoBERT tokenizer and encoder once at module level; both objects
# are read as globals by get_phobert_embedding.
PHOBERT_CHECKPOINT = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(PHOBERT_CHECKPOINT)
model = AutoModel.from_pretrained(PHOBERT_CHECKPOINT)
# Switch the encoder to evaluation mode; it is only used for inference here.
model.eval()

def get_phobert_embedding(text):
    """Encode *text* with PhoBERT and return its first-token hidden state.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    at most 256 tokens) and passed through the module-level ``model`` with
    gradient tracking disabled.

    Returns:
        torch.Tensor: the last-layer hidden state at position 0 of the first
        sequence in the batch, with a leading axis of size 1 added.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=256,
        padding=True,
    )
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    # First sequence in the batch, first token position.
    first_token_state = hidden_states[0][0]
    return first_token_state.unsqueeze(0)

# Vietnamese function words to drop before topic modelling. Multi-syllable
# entries use '_' as the joiner (e.g. 'có_thể').
vietnamese_stopwords = {
    'bị', 'bởi', 'cả', 'các', 'cái', 'cần', 'càng', 'chỉ', 'chiếc', 'cho', 'chứ',
    'chưa', 'chuyện', 'có', 'có_thể', 'cứ', 'của', 'cùng', 'cũng', 'đã', 'đang',
    'đây', 'để', 'đến_nỗi', 'đều', 'điều', 'do', 'đó', 'được', 'dưới', 'gì',
    'khi', 'không', 'là', 'lại', 'lên', 'lúc', 'mà', 'mỗi', 'này', 'nên', 'nếu',
    'ngay', 'nhiều', 'như', 'nhưng', 'những', 'nơi', 'nữa', 'phải', 'qua', 'ra',
    'rằng', 'rất', 'rồi', 'sau', 'sẽ', 'so', 'sự', 'tại', 'theo', 'thì', 'trên',
    'trước', 'từ', 'từng', 'và', 'vẫn', 'vào', 'vậy', 'vì', 'việc', 'với', 'vừa',
}

# Also treat each punctuation character as a stopword. Iterating the string
# adds one entry per character; note that '_' is not part of the string.
vietnamese_stopwords.update('!"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~…""\'\'')

def remove_stopwords(text):
    """Lower-case *text* and drop stopword and single-character tokens.

    Tokens are obtained by splitting on whitespace only, so punctuation that
    is attached to a word stays part of that token. The surviving tokens are
    joined back with single spaces.
    """
    kept = []
    for token in text.lower().split():
        if len(token) <= 1:
            continue
        if token in vietnamese_stopwords:
            continue
        kept.append(token)
    return ' '.join(kept)

def train_lda_model(clusters, n_components=4, max_iter=3):
    """Fit a TF-IDF vectorizer and an LDA topic model on cluster paragraphs.

    Each document's ``raw_text`` is split into paragraphs on blank lines
    (``'\\n\\n'``); every paragraph is passed through ``remove_stopwords`` and
    the resulting strings form the training corpus for both models.

    Args:
        clusters: iterable of dicts. Each may carry a ``single_documents``
            list of dicts with a ``raw_text`` string; clusters without that
            key and documents without text are skipped.
        n_components: number of LDA topics.
        max_iter: ``max_iter`` passed to ``LatentDirichletAllocation``.

    Returns:
        tuple: ``(tf, lda_model)`` - the fitted ``TfidfVectorizer`` and the
        fitted ``LatentDirichletAllocation``.

    Raises:
        ValueError: if no paragraph could be collected from ``clusters``.
            scikit-learn may also raise ``ValueError`` when the corpus is too
            small for the ``min_df`` / ``max_df`` settings.
    """
    paragraphs = []

    for cluster in clusters:
        # A cluster without documents contributes nothing instead of
        # aborting the whole training run with a KeyError.
        for doc in cluster.get('single_documents') or []:
            text = doc.get('raw_text', '')
            if not text:
                continue

            # Any text with non-whitespace content yields at least one
            # blank-line-delimited paragraph, so no finer fallback split
            # (by line or by sentence) can ever be needed here.
            paragraphs.extend(
                remove_stopwords(para.strip())
                for para in text.split('\n\n')
                if para.strip()
            )

    if not paragraphs:
        raise ValueError(
            "train_lda_model: no non-empty 'raw_text' paragraphs found in clusters"
        )

    tf = TfidfVectorizer(min_df=2, max_df=0.95, max_features=3000, sublinear_tf=True)
    X = tf.fit_transform(paragraphs)

    lda_model = LatentDirichletAllocation(
        n_components=n_components,
        learning_method='online',
        random_state=42,  # fixed seed so repeated runs give the same topics
        max_iter=max_iter,
    )
    lda_model.fit(X)

    return tf, lda_model

def divide_into_sections_lda(doc_texts, tf_model, lda_model):
    """Split documents into sentences and group them into LDA-topic sections.

    Each document is split into sentences on newlines. Consecutive runs of
    up to four sentences form a paragraph; every paragraph is
    stopword-filtered, vectorized with ``tf_model`` and assigned the topic
    with the highest ``lda_model`` probability (ties resolve to the lowest
    topic index). Within one document, all paragraphs sharing a topic form
    one section. Section ids are numbered globally, ordered by document
    index and then by topic index.

    Args:
        doc_texts: list of document strings.
        tf_model: fitted vectorizer exposing ``transform``.
        lda_model: fitted topic model exposing ``transform``.

    Returns:
        tuple: ``(doc_sect_mask, sect_sent_mask, all_sents)`` where
        ``doc_sect_mask`` is an int array of shape (n_docs, n_sections),
        ``sect_sent_mask`` is an int array of shape (n_sections, n_sentences)
        and ``all_sents`` is the flat list of sentences across all documents.
        When no sentence is found, the masks have zero sections and
        ``all_sents`` is empty; the models are not called in that case.
    """
    all_sents = []
    para_texts = []     # stopword-filtered paragraph strings
    para_doc_map = []   # owning document index of each paragraph
    para_sent_map = []  # global sentence indices covered by each paragraph

    for doc_idx, doc_text in enumerate(doc_texts):
        start_idx = len(all_sents)
        # A document containing any non-whitespace character always yields
        # at least one line here, so no further fallback split is needed.
        sents = [s.strip() for s in doc_text.split('\n') if s.strip()]
        all_sents.extend(sents)

        for offset in range(0, len(sents), 4):
            chunk = sents[offset:offset + 4]
            first = start_idx + offset
            para_texts.append(remove_stopwords(' '.join(chunk)))
            para_doc_map.append(doc_idx)
            para_sent_map.append(list(range(first, first + len(chunk))))

    if para_texts:
        topic_dist = np.asarray(lda_model.transform(tf_model.transform(para_texts)))
        # argmax returns the first maximum, so a uniform distribution maps
        # to topic 0 without special-casing.
        para_topics = [int(t) for t in np.argmax(topic_dist, axis=1)]
    else:
        # Nothing to classify; skip the models, which reject empty input.
        para_topics = []

    # One section per distinct (document, topic) pair. Sorting the pairs
    # orders section ids by document first and topic second.
    section_map = {
        pair: sect_id
        for sect_id, pair in enumerate(sorted(set(zip(para_doc_map, para_topics))))
    }

    doc_sect_mask = np.zeros((len(doc_texts), len(section_map)), dtype=int)
    sect_sent_mask = np.zeros((len(section_map), len(all_sents)), dtype=int)

    for p_doc, p_topic, sent_indices in zip(para_doc_map, para_topics, para_sent_map):
        global_sect_id = section_map[(p_doc, p_topic)]
        doc_sect_mask[p_doc][global_sect_id] = 1
        for sent_idx in sent_indices:
            sect_sent_mask[global_sect_id][sent_idx] = 1

    return doc_sect_mask, sect_sent_mask, all_sents

def mask_to_adj(doc_sect_mask, sect_sent_mask):
    """Build the adjacency matrix of the sentence/section/document/root graph.

    Nodes are laid out as: all sentences, then all sections, then all
    documents, then one root node (the last index).

    Edges written into the matrix:
      * section <-> sentence, copied from ``sect_sent_mask``;
      * document <-> section, copied from ``doc_sect_mask``;
      * sentence <-> sentence, incremented once for every section row in
        which both sentences are set (this also fills the diagonal);
      * section <-> section, incremented once for every document row in
        which both sections are set (this also fills the diagonal);
      * root <-> every document, plus a root self-loop.

    Args:
        doc_sect_mask: array-like of shape (n_docs, n_sections).
        sect_sent_mask: array-like of shape (n_sections, n_sentences).

    Returns:
        numpy.ndarray: square float matrix with
        n_sentences + n_sections + n_docs + 1 rows.
    """
    doc_sect = np.array(doc_sect_mask)
    sect_sent = np.array(sect_sent_mask)

    n_sect = sect_sent.shape[0]
    n_sent = sect_sent.shape[1]
    n_doc = doc_sect.shape[0]
    n_total = n_sent + n_sect + n_doc + 1

    # Index ranges of each node family inside the matrix.
    sent_rng = slice(0, n_sent)
    sect_rng = slice(n_sent, n_sent + n_sect)
    doc_rng = slice(n_sent + n_sect, n_sent + n_sect + n_doc)
    root = n_total - 1

    adj = np.zeros((n_total, n_total))

    adj[sect_rng, sent_rng] = sect_sent
    adj[sent_rng, sect_rng] = sect_sent.T

    adj[doc_rng, sect_rng] = doc_sect
    adj[sect_rng, doc_rng] = doc_sect.T

    # Sentences that share a section are linked, one increment per section.
    for sect_row in sect_sent:
        adj[sent_rng, sent_rng] += np.outer(sect_row, sect_row)

    # Sections that share a document are linked, one increment per document.
    for doc_row in doc_sect:
        adj[sect_rng, sect_rng] += np.outer(doc_row, doc_row)

    adj[root, doc_rng] = 1
    adj[doc_rng, root] = 1
    adj[root, root] = 1

    return adj

class Graph:
    """Sentence/section/document/root graph for one cluster and one summary.

    Nodes are ordered as: sentences, sections, documents, then a single root
    node, matching the layout produced by ``mask_to_adj``.
    """

    def __init__(self, sents, sentVecs, scores, doc_sec_mask, sec_sen_mask, golden, threds=0.5):
        """Build adjacency, node features and labels.

        Args:
            sents: sequence of sentence strings.
            sentVecs: one embedding per sentence; ``torch.Tensor`` values are
                squeezed and converted to NumPy, anything else is passed
                through ``np.asarray``.
            scores: one numeric score per sentence.
            doc_sec_mask: document-by-section membership mask.
            sec_sen_mask: section-by-sentence membership mask.
            golden: reference summary text; it is embedded with
                ``get_phobert_embedding``.
            threds: score threshold; sentences scoring at or above it get
                label 1.0 in ``score_onehot``.

        Raises:
            ValueError: if ``sents`` is empty or if ``sents``, ``sentVecs``
                and ``scores`` differ in length.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not (len(sentVecs) == len(scores) == len(sents)):
            raise ValueError(
                f"Mismatch: {len(sentVecs)} vecs, {len(scores)} scores, {len(sents)} sents"
            )
        if len(sents) == 0:
            raise ValueError("Graph requires at least one sentence")

        self.docnum = len(doc_sec_mask)
        self.secnum = len(sec_sen_mask)
        self.adj = torch.from_numpy(mask_to_adj(doc_sec_mask, sec_sen_mask)).float()

        sent_matrix = np.array([
            vec.detach().squeeze().cpu().numpy() if isinstance(vec, torch.Tensor)
            else np.asarray(vec)
            for vec in sentVecs
        ])
        vec_dim = sent_matrix.shape[1]
        # Sentence rows come first; section, document and root rows start as
        # zeros and are filled in by init_node_vec.
        self.feature = np.concatenate((
            sent_matrix,
            np.zeros((self.secnum + self.docnum + 1, vec_dim)),
        ))

        self.score = torch.from_numpy(np.array(scores)).float()
        self.score_onehot = (self.score >= threds).float()
        self.sents = np.array(sents)
        self.golden = golden
        self.goldenVec = get_phobert_embedding(golden).float()

        # init_node_vec works on the NumPy feature matrix, so convert to a
        # tensor only afterwards.
        self.init_node_vec()
        self.feature = torch.from_numpy(self.feature).float()

    def init_node_vec(self):
        """Fill section, document and root features by bottom-up averaging.

        Expects ``self.feature`` to still be a NumPy array. A section row
        becomes the mean of the sentence rows it is adjacent to; a document
        row becomes the mean of the sentence/section rows it is adjacent to
        (section rows are already filled at that point); the root row
        becomes the mean of all document rows. Nodes without any such
        neighbour keep their zero vector.
        """
        sent_num = len(self.sents)
        sect_end = sent_num + self.secnum
        doc_end = sect_end + self.docnum
        adj = self.adj.numpy()

        for node in range(sent_num, sect_end):
            members = adj[node, :sent_num] != 0
            if members.any():
                self.feature[node] = self.feature[:sent_num][members].mean(axis=0)

        # Sections must be filled before documents: documents average them.
        for node in range(sect_end, doc_end):
            members = adj[node, :sect_end] != 0
            if members.any():
                self.feature[node] = self.feature[:sect_end][members].mean(axis=0)

        self.feature[-1] = np.mean(self.feature[sect_end:doc_end], axis=0)

def _read_jsonl_slice(jsonl_path, start_line=0, end_line=None):
    """Parse the JSON objects on lines [start_line, end_line) of a JSONL file."""
    records = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            if idx < start_line:
                continue
            if end_line is not None and idx >= end_line:
                break
            # Blank lines still count toward the line index but carry no
            # record; json.loads would fail on them.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records


def _rouge2_precision_scores(rouge, sents, summary):
    """Return one ROUGE-2 precision per sentence, scored against *summary*."""
    scores = []
    for sent in sents:
        if not sent.strip():
            scores.append(0.0)
            continue
        try:
            scores.append(rouge.get_scores(sent, summary)[0]['rouge-2']['p'])
        except Exception:
            # Best effort: a sentence the scorer cannot handle counts as 0.0.
            scores.append(0.0)
    return scores


def process_jsonl_to_graphs(jsonl_path, use_pretrained_lda=None, n_components=4, max_iter=3,
                            start_line=0, end_line=None):
    """Convert a slice of a JSONL cluster file into ``Graph`` objects.

    For every cluster, the non-empty ``raw_text`` of its ``single_documents``
    is segmented with ``divide_into_sections_lda``, each sentence is embedded
    with ``get_phobert_embedding`` and, for each available summary field, the
    sentences are scored by ROUGE-2 precision against that summary and one
    ``Graph`` is built.

    Args:
        jsonl_path: path to a UTF-8 JSONL file, one cluster object per line.
        use_pretrained_lda: optional ``(tf_model, lda_model)`` pair. When
            ``None``, both models are trained on the clusters read here.
        n_components: number of LDA topics (used only when training).
        max_iter: LDA ``max_iter`` (used only when training).
        start_line: 0-based index of the first line to read.
        end_line: 0-based index one past the last line to read, or ``None``
            to read to the end of the file.

    Returns:
        tuple: ``(all_graphs, (tf_model, lda_model))`` where ``all_graphs``
        maps each of ``'summary_0'``, ``'summary_1'``, ``'s3_summary_0'`` and
        ``'s3_summary_1'`` to the list of graphs built for that field.

    A cluster that raises during processing is skipped and logged with its
    traceback; graphs already appended for that cluster are kept.
    """
    import logging

    logger = logging.getLogger(__name__)

    clusters = _read_jsonl_slice(jsonl_path, start_line, end_line)

    if use_pretrained_lda is None:
        tf_model, lda_model = train_lda_model(clusters, n_components, max_iter)
    else:
        tf_model, lda_model = use_pretrained_lda

    rouge = Rouge()

    summary_types = ['summary_0', 'summary_1', 's3_summary_0', 's3_summary_1']
    all_graphs = {stype: [] for stype in summary_types}

    for cluster_idx, cluster in enumerate(tqdm(clusters, desc="Processing clusters")):
        try:
            doc_texts = [doc.get('raw_text', '') for doc in cluster['single_documents']]
            doc_texts = [text for text in doc_texts if text]
            if not doc_texts:
                continue

            # Collect usable summaries first so that clusters without any
            # summary skip the expensive segmentation and embedding steps.
            summaries = {}
            for stype in summary_types:
                summary = cluster.get(stype, '')
                if not summary:
                    continue
                summary = summary.replace('– ', '').replace('- ', '').strip()
                if summary:
                    summaries[stype] = summary
            if not summaries:
                continue

            doc_sect_mask, sect_sent_mask, sents = divide_into_sections_lda(
                doc_texts, tf_model, lda_model
            )
            if not sents:
                continue

            # Sentence embeddings are shared by every summary of the cluster.
            sentVecs = [get_phobert_embedding(sent) for sent in sents]

            for stype, summary in summaries.items():
                scores = _rouge2_precision_scores(rouge, sents, summary)
                graph = Graph(sents, sentVecs, scores, doc_sect_mask, sect_sent_mask, summary)
                all_graphs[stype].append(graph)

        except Exception:
            # Keep the batch going, but leave a trace of what was dropped.
            logger.warning(
                "Skipping cluster %d of this slice after an error",
                cluster_idx,
                exc_info=True,
            )
            continue

    return all_graphs, (tf_model, lda_model)

In [None]:
import math
import sys
import pickle

# --- Run configuration ------------------------------------------------------
input_path = ''            # JSONL input file, one cluster per line
lda_model_path = None      # optional pickle of a (tf_model, lda_model) pair
save_prefix = ''           # prefix of every output pickle file name
samples_per_group = 500    # number of input lines handled per group
group_to_run = 1           # 1-based index of the group to process

lda_models_loaded = None
if lda_model_path:
    try:
        with open(lda_model_path, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code; only load model
            # files that come from a trusted source.
            lda_models_loaded = pickle.load(f)
    except Exception as exc:
        # Best effort: fall back to training, but say so, because graphs
        # built with a freshly trained model use different topics.
        print(
            f"WARNING: could not load LDA models from {lda_model_path!r} ({exc!r}); "
            "a new LDA model will be trained on this group.",
            file=sys.stderr,
        )

with open(input_path, 'r', encoding='utf-8') as f:
    total_lines = sum(1 for _ in f)

total_groups = math.ceil(total_lines / samples_per_group)

if not (1 <= group_to_run <= total_groups):
    # A string argument is printed to stderr and yields exit status 1.
    sys.exit(f"group_to_run={group_to_run} is outside the valid range 1..{total_groups}")

# Half-open line range [start_line, end_line) covered by the selected group.
group_idx0 = group_to_run - 1
start_line = group_idx0 * samples_per_group
end_line = min((group_idx0 + 1) * samples_per_group, total_lines)

all_graphs, lda_models = process_jsonl_to_graphs(
    input_path,
    use_pretrained_lda=lda_models_loaded,
    n_components=4,
    max_iter=3,
    start_line=start_line,
    end_line=end_line
)

# One pickle per summary type; summary types with no graphs produce no file.
for stype, graphs in all_graphs.items():
    if graphs:
        output_name = f"{save_prefix}_{stype}_{group_to_run}_numsample{samples_per_group}_numgroup{total_groups}.pkl"
        with open(output_name, 'wb') as f:
            pickle.dump(graphs, f)

# Persist the models only when they were trained in this run on group 1.
if group_to_run == 1 and lda_models_loaded is None:
    lda_save_path = f"{save_prefix}_lda_models.pkl"
    with open(lda_save_path, 'wb') as f:
        pickle.dump(lda_models, f)