In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import tqdm
import os
import pickle
import re
import glob
import torch.nn.functional as F
import numpy as np

In [None]:
# Checkpoint used for scoring; tokenizer and model are loaded from the same name.
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pick the best available accelerator instead of assuming Apple MPS, so the
# notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load without device_map="auto": that option lets Accelerate place (and
# possibly split or offload) the weights itself, which conflicts with moving
# the whole model to one explicit device right afterwards.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

In [None]:
def get_word_probability(context, word, model, tokenizer):
    """Return the model's probability that ``word`` directly follows ``context``.

    The probability is the product of the conditional probabilities of each
    token of ``word`` given ``context`` and the preceding tokens of ``word``.

    Args:
        context: Text that precedes the word.
        word: The continuation to score.
        model: Causal language model; calling it on a ``(1, seq_len)`` tensor
            of token ids must return an object with a ``logits`` attribute of
            shape ``(1, seq_len, vocab_size)``.
        tokenizer: Tokenizer providing ``encode(text, add_special_tokens=...)``.

    Returns:
        A Python float in ``[0, 1]``. ``1.0`` is returned when ``word``
        contributes no tokens after tokenization.

    Raises:
        ValueError: If ``context`` encodes to no tokens, so there is nothing
            to condition the first word token on.
    """
    # Tokenize the context (with special tokens) and the word (without).
    context_tokens = tokenizer.encode(context, add_special_tokens=True)
    # The first token of the stand-alone word encoding is discarded.
    # NOTE(review): presumably this is the word-boundary piece the Llama
    # tokenizer prepends to a bare string; confirm it never removes a real
    # piece of any scored term.
    word_tokens = tokenizer.encode(word, add_special_tokens=False)[1:]

    if not word_tokens:
        # Empty product: nothing to score.
        return 1.0
    if not context_tokens:
        raise ValueError("context must encode to at least one token")

    # Run on whatever device the model already lives on rather than moving the
    # model on every call or assuming a particular backend.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    # A causal LM's logits at position t are its prediction for token t + 1,
    # so a single pass over everything but the last token yields every
    # distribution needed to score the word.
    combined_tokens = context_tokens + word_tokens
    input_tensor = torch.tensor([combined_tokens[:-1]], device=device)
    first_word_position = len(context_tokens) - 1

    with torch.no_grad():
        logits = model(input_tensor).logits  # (batch_size, seq_len, vocab_size)
        word_logits = logits[0, first_word_position:, :].float()
        log_probs = F.log_softmax(word_logits, dim=-1)
        targets = torch.tensor(word_tokens, device=log_probs.device).unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, targets).squeeze(-1).cpu().tolist()

    # Accumulate in log space (Python floats are double precision) so that very
    # unlikely multi-token words do not underflow to exactly 0 prematurely.
    return float(np.exp(sum(token_log_probs)))

In [None]:
# Load the pre-selected sentences that the surprisal loop further down pairs
# up as (previous sentence, current sentence).
# Unpickling can execute arbitrary code, so only use a file you produced yourself.
with open("selected_sentences.pickle", "rb") as sentence_file:
    sentences = pickle.loads(sentence_file.read())

In [None]:
# Each entry is a (legal_terms, usual_terms) pair: the first list holds the
# terms labelled "legal" below, the second the terms labelled "usual".
term_lists = [
    (["兩造",], ["雙方"]),
    (["上揭", "上開", "前揭", "前開", "首揭"], ["前述", "上述"]),
    (["云云",], ["等陳述", "等語", "等等"]),
    (["可考", "可佐", "可按", "可稽", "可證", "足按", "足徵", "足稽", "足憑", "足證"], ["可以佐證", "可供證明", "可以證明", "足以佐證", "足以證明"]),
    (["迭",], ["接連", "多次"]),
    (["拘束",], ["限制",]),
    (["失所附麗",], ["失所依附",]),
    (["尚非無稽", "尚非無憑", "尚非無據", "尚非虛妄", "尚非臆造"], ["應可採信", "應屬事實", "並非全無依據", "並不是完全沒有依據"]),
    (["所載",], ["所記載",]),
    (["考諸", "徵諸", "觀諸", "稽之"], ["參考", "依照", "依據"]),
    (["質言之",], ["簡言之",]),
    (["相歧",], ["矛盾",]),
    (["即非法所不許", "依法即無不合",], ["符合法律規定",]),
    (["礙難採認",], ["難以認定", "難以採信", "不可採"]),
    (["矧",], ["況且",]),
    (["翻異"], ["推翻"]),
    (["乃",], ["於是",]),
    (["自白不諱", "供認不諱", "坦承不諱"], ["坦白承認"]),
    (["似無可採", "似屬無憑", "即無可採", "即屬無據", "尚無可採", "尚難憑採", "要非可信", "要屬虛言", "容非可採"], ["難以採信", "不可採信", "尚不足採信", "尚不足採證"]),
    (["顯有",], ["顯然有", "顯然屬於", "顯然是"])
]

# Create the output directory up front; otherwise the first save would fail
# only after the whole (slow) scoring pass for a group has finished.
output_dir = "llm_surprisals"
os.makedirs(output_dir, exist_ok=True)

for legal_terms, usual_terms in term_lists:

    terms = legal_terms + usual_terms
    print("|".join(terms))
    # One result file per group, named after its "|"-joined terms.
    save_path = os.path.join(output_dir, "|".join(terms))

    # Results are cached on disk: a group that already has a file is skipped.
    if os.path.isfile(save_path):
        continue
    term_types = ["legal"] * len(legal_terms) + ["usual"] * len(usual_terms)

    # Alternation over all terms of the group. Terms are escaped and ordered
    # longest-first so a short term can never shadow a longer one that starts
    # with it, and a regex metacharacter in a term cannot change the pattern.
    regex_str = "|".join(
        re.escape(term) for term in sorted(terms, key=len, reverse=True)
    )

    # For every occurrence of a term in a sentence, the context is the
    # previous sentence, a full-width comma, and the text before the match.
    context_sentences = []
    term_list = []
    for sentence1, sentence2 in zip(sentences[:-1], sentences[1:]):
        for match in re.finditer(regex_str, sentence2):
            context_sentences.append(sentence1 + "，" + sentence2[:match.start()])
            term_list.append(match.group())

    # De-duplicate while keeping list order so the floating-point summation
    # below is reproducible across runs (set iteration order is not).
    candidate_terms = list(dict.fromkeys(terms))

    term_surprisals = {term: [] for term in terms}
    for context_sentence, target_term in tqdm.tqdm(list(zip(context_sentences, term_list))):
        # Probability that any term of the group follows this context.
        term_prob = 0
        for term in candidate_terms:
            term_prob += get_word_probability(context_sentence, term, model, tokenizer)

        # Surprisal (natural log) of the group as a whole, recorded under the
        # term that actually occurred at this position.
        term_surprisal = -np.log(term_prob)
        term_surprisals[target_term].append(term_surprisal)

    # Write to a temporary file and rename it into place, so an interrupted
    # dump never leaves a truncated file that the cache check above would
    # mistake for a finished result.
    tmp_path = save_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump((term_surprisals, term_types), f)
        os.replace(tmp_path, save_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)