### Initialize

In [9]:
# import csv
import json
import os
import random
import requests
import time

# import torch

from tqdm import tqdm
# from transformers import pipeline
# from transformers import AutoTokenizer, AutoModelForCausalLM

from constants import DIR_STORED_DATA
from crawler.utils import Logger
from news_with_rationale import NewsWithRationale
from rationale import Rationale
from summarized_news import SummarizedNews
from utils import *
from xai import XAI


# Notebook-wide logger, constructed with this module's `__name__`.
logger = Logger(__name__)

# DIR_STORED_DATA = 'stored_data'

# Make sure the storage directory exists. `exist_ok=True` makes this cell
# idempotent on re-run and removes the check-then-create race that a separate
# `os.path.exists` test has. If the path exists but is not a directory this
# now raises instead of silently continuing.
os.makedirs(DIR_STORED_DATA, exist_ok=True)



In [2]:

# pipe = pipeline(
#     "text-generation",
#     model="google/gemma-2-2b",
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     device="cuda",
# )

# model_id = "meta-llama/Llama-3.2-1B"

# pipe = pipeline(
#     "text-generation", 
#     model="meta-llama/Llama-3.2-1B",
#     torch_dtype=torch.bfloat16, 
#     # device_map="auto"
#     device="cuda"
# )


## generate rationale

### generate rationale by local model

In [3]:
# load news data

from crawler.crawler_base import News, NewsCrawlerBase

# NEWS_DIR = os.path.join(os.path.dirname(__file__), "crawler/saved_news")
NEWS_DIR = "crawler/saved_news"

def get_news_data() -> list[News]:
    """Load every crawled news item listed in the crawl index.

    Reads ``crawled_urls.json`` from ``NEWS_DIR``; each entry is unpacked as a
    ``(url, file_path)`` pair. Every ``file_path`` is then parsed with
    ``NewsCrawlerBase._parse_file``.

    Returns:
        The parsed news objects, in the same order as the index entries.
    """
    index_path = os.path.join(NEWS_DIR, "crawled_urls.json")
    with open(index_path, "r") as index_file:
        crawled: list[tuple[str, str]] = json.load(index_file)  # (url, file_path)

    # Collect all paths first so a malformed index entry fails before any
    # news file is parsed.
    saved_paths = [saved_path for _, saved_path in crawled]
    return [NewsCrawlerBase._parse_file(saved_path) for saved_path in saved_paths]

# news = get_news_data()
# print(news[0])



In [None]:
# Test rationale generation: build a prompt from the first crawled article,
# run it through the text-generation pipeline `pipe`, and time the call.
# NOTE(review): `pipe` is not assigned in any live cell visible in this
# notebook (the cells that construct it are commented out), so this cell
# raises NameError on a fresh kernel unless one of them is re-enabled.
import time
news = get_news_data()


# Use the body text of the first news item as the test document.
doc = news[0].content
# print(f'{len(doc)=}')

# Collapse each double newline into a single newline before building the prompt.
doc = doc.replace("\n\n", "\n")
print(f'{len(doc)=}')
# print(doc)

# prompt = get_rationale_prompt_no_gt(doc)
# The second argument is a placeholder string; its meaning depends on the
# helper's signature, which is not shown here — TODO confirm.
prompt = get_rationale_prompt_chinese(doc, "test")
# print(f'{len(prompt)=}')


# Wall-clock timing of a single generation call.
s_t = time.time()
output = pipe(prompt, max_length=2048)
e_t = time.time()
print(f'{e_t - s_t=}')
# The result is indexed as a list of dicts carrying a 'generated_text' entry.
print(output[0]['generated_text'])

### Convert raw data to SummarizedNews

In [53]:
# convert raw data to SummarizedNews


from opencc import OpenCC

# Shared OpenCC converter, configured with the 's2twp' conversion profile.
cc = OpenCC('s2twp')


def process_str(s: str) -> str:
    """Remove every space character from ``s`` and convert it with ``cc``.

    Args:
        s: Text to normalise.

    Returns:
        The space-free text after conversion by the shared OpenCC converter.
    """
    without_spaces = "".join(s.split(" "))
    return cc.convert(without_spaces)

# Convert every record of the CNewSum training file into a SummarizedNews
# object, then persist the whole list with SummarizedNews.save_all.
summarized_data: list[SummarizedNews] = []

with open('CNewSum_v2/train.simple.label.jsonl', 'r', encoding='utf-8') as f:
    # Read eagerly so tqdm knows the total and can render a real progress bar.
    lines = f.readlines()

# The file handle is already closed here; the (slow) conversion below only
# needs the in-memory lines.
for line in tqdm(lines):
    # Tolerate blank lines (e.g. a trailing newline at end of file), which
    # json.loads would otherwise reject.
    if not line.strip():
        continue
    data = json.loads(line)

    # The article is a list of strings; normalise each one and join them
    # with newlines into a single text.
    article: list[str] = data['article']
    article = [process_str(a) for a in article]
    article_str = '\n'.join(article)

    summary: str = process_str(data['summary'])
    # Named `news_id` rather than `id` so the builtin `id` is not shadowed for
    # the rest of the kernel session.
    news_id = data['id']
    # assert len(data['label']) == 1
    # label = data['label'][0]
    label = data['label']
    summarized_data.append(SummarizedNews(article_str, summary, news_id, label))


print(f'read {len(summarized_data)} summarized news data')

SummarizedNews.save_all(summarized_data)


100%|██████████| 275596/275596 [13:57<00:00, 329.02it/s]

read 275596 summarized news data





### Load data and generate rationale

In [6]:
# Load the SummarizedNews records (presumably the ones written by
# SummarizedNews.save_all in the conversion cell — confirm).
summarized_news = SummarizedNews.load_all()

# NOTE(review): an XAI instance is created here, but the two calls below are
# made on the XAI class itself rather than on `x_ai`. Confirm whether the
# constructor is needed for its side effects or the instance can be dropped.
x_ai = XAI()
responses_from_x_ai = XAI.get_responses()
rationales_from_x_ai = XAI.load_rationales_from_responses()


[SummarizedNews] [INFO] Loaded 275596 summarized news records
[XAI] [INFO] XAI instance created
[XAI] [INFO] loaded 207 responses from x_ai responses
[XAI] [INFO] loaded 207 rationales from x_ai responses


In [None]:

# Generate and save a NewsWithRationale for the summarized news at indices
# 210 and 211.
# NOTE(review): the index range is hard-coded and appears to be advanced by
# hand between runs — confirm, and consider a named constant.
for i in range(210, 212):
    print(f"processing {i}")
    # Rebinds the global `news` to a single element of `summarized_news`.
    news = summarized_news[i]
    news_with_rationale = XAI.get_newsWithRationale(news)
    # print(news_with_rationale)
    news_with_rationale.save()



processing 210
processing 211


### sample test


In [24]:
# Literal sample of a chat-completion response dict (model "grok-beta"),
# kept for manual testing of rationale extraction. The assistant message
# text lives at response["choices"][0]["message"]["content"].
# NOTE(review): presumably captured from a real API call — confirm. No live
# code in the visible cells reads `response`.
response = {
    "id": "0f6f9b5a-ba7c-4b01-8054-aa28e855bbac",
    "object": "chat.completion",
    "created": 1735058331,
    "model": "grok-beta",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "核心要素：\n1. 海基會參訪大陸被叫停\n2. 國臺辦的回應\n3. 兩岸交流的影響\n\n三元組：\n- [海基會 | 參訪被叫停 | 大陸]\n- [國臺辦 | 回應 | 海基會參訪被叫停]\n- [臺灣方面 | 原因 | 海基會取消來訪]\n- [兩岸交流 | 受到干擾 | 海基會參訪被叫停]\n\n生成摘要：\n國臺辦回應海基會參訪大陸被叫停，指出是因為臺灣方面的原因，導致海基會取消了這次來訪。我們不希望兩岸正常的交流交往受到干擾。",
                "refusal": None
            },
            "finish_reason": "stop"
        }
    ],
    "usage": {
        "prompt_tokens": 657,
        "completion_tokens": 182,
        "total_tokens": 839,
        "prompt_tokens_details": {
            "text_tokens": 657,
            "audio_tokens": 0,
            "image_tokens": 0,
            "cached_tokens": 0
        }
    },
    "system_fingerprint": "fp_efe33e0791"
}

In [None]:

# news_with_rationale = NewsWithRationale(summarized_news[201], XAI.extract_rationale(XAI.response))
# # print(news_with_rationale.__str__())
# print(news_with_rationale.__dict__)
# news_with_rationale.save()

# XAI.response = {"id": "e33efe41-e762-4f4d-9174-b1c8ce276b42", "object": "chat.completion", "created": 1735313831, "model": "grok-beta", "choices": [{"index": 0, "message": {"role": "assistant", "content": "核心要素：\n1. 中國和印度的軍工行業發展\n2. 中國軍工行業的效率和先進水平\n3. 印度軍工行業的發展狀況\n\n三元組：\n\n[中國 | 軍工行業發展 | 遠超印度]\n[中國軍工行業 | 效率 | 先進水平的象徵]\n[印度軍工行業 | 發展狀況 | 效率低下]\n[中國 | 軍事研發 | 成功]\n[中國 | 國防預算 | 穩步增加]\n[中國軍工企業 | 參展 | 蘭卡威國際海事及航空航天展]\n\n生成摘要：\n美媒稱，中國軍工發展水平遠超印度。與印度相比，中國軍工行業是效率和先進水平的象徵。中國在軍事研發方面取得了成功，並通過穩步增加國防預算來支持軍工行業的現代化。相比之下，印度軍工行業的發展狀況顯得效率低下。中國軍工企業還積極參加了蘭卡威國際海事及航空航天展，展示其精良裝備。", "refusal": False}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 1259, "completion_tokens": 284, "total_tokens": 1543, "prompt_tokens_details": {"text_tokens": 1259, "audio_tokens": 0, "image_tokens": 0, "cached_tokens": 0}}, "system_fingerprint": "fp_e1b909a5cb"}

# # print(XAI.response)
# rationale = XAI.extract_rationale(XAI.response)
# print(rationale)

### test local rationale generation

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model_from_pyTorch(model_name: str):
    """Load the tokenizer and causal-LM weights for ``model_name``.

    Args:
        model_name: Identifier passed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForCausalLM.from_pretrained(model_name),
    )

def generate_text_from_model(tokenizer, model, prompt, max_length=1024):
    """Run ``prompt`` through ``model`` and return the decoded first sequence.

    Args:
        tokenizer: Callable that encodes ``prompt`` (called with
            ``return_tensors="pt"``) and exposes ``decode``.
        model: Object exposing ``device`` and ``generate``.
        prompt: Text to feed to the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt")

    # Move every encoded tensor onto the model's device before generating.
    on_device = {}
    for key, value in encoded.items():
        on_device[key] = value.to(model.device)

    sequences = model.generate(**on_device, max_length=max_length)
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Smoke test of local generation: load the first crawled article, build a
# rationale prompt from it, and generate text with a locally loaded model.
news = get_news_data()
doc = news[0].content

model_name = "google/gemma-2-2b"
tokenizer, model = load_model_from_pyTorch(model_name)
# Show which device the model was placed on.
print(f'{model.device=}')
prompt = get_rationale_prompt_no_gt(doc)
# Length of the prompt in characters, not tokens.
print(f'{len(prompt)=}')
# max_length is forwarded to model.generate by generate_text_from_model.
output = generate_text_from_model(tokenizer, model, prompt, max_length=3000)
# Only the length of the decoded text is printed; the text itself is not shown.
print(f'{len(output)=}')