In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import tqdm
import os
import pickle
import re
import glob

import numpy as np

def cosine_similarity(vec1, vec2):
    """Compute the cosine similarity between two vectors, clipped to be non-negative.

    Both inputs are flattened to 1-D before the comparison, so any array-like
    with the same total number of elements is accepted.

    Args:
        vec1: First vector (array-like).
        vec2: Second vector (array-like) with the same number of elements.

    Returns:
        The cosine similarity in [0, 1]. Negative similarities are clipped to
        0, and 0.0 is returned when either vector has zero norm (the cosine is
        undefined there and would otherwise come out as NaN).
    """
    vec1 = np.asarray(vec1).reshape(-1)
    vec2 = np.asarray(vec2).reshape(-1)

    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        # Avoid 0/0 -> NaN, which would silently poison every downstream sum.
        return 0.0

    similarity = np.dot(vec1, vec2) / norm_product
    if similarity < 0:
        similarity = 0.0
    return similarity

def similarity_to_probability(similarity):
    """Map a similarity score onto a probability-like score.

    The mapping is currently the identity: the value passed in is returned
    unchanged.
    """
    # Alternative kept for reference: ((1 + similarity) / 2) rescales to [0,1].
    probability = similarity
    return probability

def joint_probability_two(A, B):
    """Return the probability-like score of vectors A and B occurring together.

    The score is the cosine similarity of the two vectors passed through
    `similarity_to_probability`.
    """
    return similarity_to_probability(cosine_similarity(A, B))

def joint_probability_three(A, B, C):
    """Return the probability-like score of three vectors occurring together.

    The score is the product of the three pairwise scores (A, B), (B, C)
    and (C, A), each computed with `joint_probability_two`.
    """
    pairwise_scores = [
        joint_probability_two(left, right)
        for left, right in ((A, B), (B, C), (C, A))
    ]
    return pairwise_scores[0] * pairwise_scores[1] * pairwise_scores[2]

def joint_vector(A, B):
    """Return the midpoint of A and B scaled by their cosine similarity.

    The scaling factor comes from `cosine_similarity(A, B)`.
    """
    midpoint = (A + B) / 2
    weight = cosine_similarity(A, B)
    return midpoint * weight

In [None]:
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load without `device_map="auto"`: a model dispatched by accelerate may be
# sharded or offloaded, and then cannot be moved with `.to(device)`.
# The whole model is placed explicitly on the single device chosen above.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

In [None]:
def get_sentence_embeddings(sentence, model, tokenizer):
    """Embed a sentence as the mean of its input-token embeddings.

    Only the model's input embedding layer is used; no transformer layers are
    run. Special tokens are not added, so the average covers exactly the
    tokens produced for `sentence`.

    Args:
        sentence: Text to embed.
        model: Model exposing `get_input_embeddings()` (the standard Hugging
            Face accessor for the token-embedding layer).
        tokenizer: Callable returning a mapping with an `"input_ids"` tensor
            when called with `return_tensors="pt"`.

    Returns:
        A 1-D NumPy array holding the mean token embedding.

    Raises:
        ValueError: If `sentence` tokenizes to zero tokens. The mean of an
            empty set of embeddings is undefined and would become NaN.
    """
    embedding_layer = model.get_input_embeddings()

    tokens = tokenizer(sentence, return_tensors="pt", add_special_tokens=False)
    input_ids = tokens["input_ids"]
    if input_ids.shape[-1] == 0:
        raise ValueError(f"Sentence produced no tokens and cannot be embedded: {sentence!r}")

    # Send the ids to wherever the embedding weights live rather than relying
    # on a notebook-global `device` that may be stale or undefined.
    input_ids = input_ids.to(embedding_layer.weight.device)

    with torch.no_grad():
        embeddings = embedding_layer(input_ids).squeeze(0)

    # Cast to float32 first: NumPy cannot represent bfloat16 tensors.
    return embeddings.float().cpu().numpy().mean(axis=0)

In [None]:
# Load the pre-selected sentences used throughout the notebook
# (presumably a list of sentence strings; verify against the producer).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("selected_sentences.pickle", "rb") as sentences_file:
    sentences = pickle.load(sentences_file)

In [None]:
# (sentence, embedding) pairs consumed when estimating `overall_sentence_count`.
# Leaving this list empty makes that estimate NaN on a fresh kernel, so it is
# populated here. Embedding every sentence is expensive, so it is skipped when
# the cached estimate already exists on disk.
selected_sentence_embeds = []
if not os.path.isfile("overall_sentence_count.pickle"):
    selected_sentence_embeds = [
        (sentence, get_sentence_embeddings(sentence, model, tokenizer))
        for sentence in tqdm.tqdm(sentences)
    ]

In [None]:
# Estimate the normaliser used to turn similarity sums into probabilities:
# the mean, over (at most) the first 10000 sentences, of each sentence's summed
# similarity to every sentence, scaled up by the total number of sentences.
# The result is cached on disk because the double loop is expensive.
if os.path.isfile("overall_sentence_count.pickle"):
    with open("overall_sentence_count.pickle", "rb") as f:
        overall_sentence_count = pickle.load(f)

    # A cache written from an empty embedding list contains NaN; fail loudly
    # instead of letting it invalidate every probability computed later.
    if not np.isfinite(overall_sentence_count) or overall_sentence_count <= 0:
        raise ValueError(
            "Cached overall_sentence_count.pickle holds an invalid value "
            f"({overall_sentence_count!r}); delete the file and recompute."
        )

else:
    if len(selected_sentence_embeds) == 0:
        # np.mean([]) * 0 would be NaN and would be pickled as the cache.
        raise ValueError(
            "selected_sentence_embeds is empty; populate it with "
            "(sentence, embedding) pairs before computing overall_sentence_count."
        )

    all_sentence_counts = []
    pbar = tqdm.tqdm(selected_sentence_embeds[:10000])
    for target_sentence, target_sentence_embed in pbar:
        target_sentence_count = 0
        for context_sentence, context_sentence_embed in selected_sentence_embeds:
            target_sentence_count += joint_probability_two(target_sentence_embed, context_sentence_embed)

        all_sentence_counts.append(target_sentence_count)

    overall_sentence_count = np.mean(all_sentence_counts) * len(selected_sentence_embeds)
    if not np.isfinite(overall_sentence_count) or overall_sentence_count <= 0:
        raise ValueError(f"Computed an invalid overall_sentence_count: {overall_sentence_count!r}")

    with open("overall_sentence_count.pickle", "wb") as f:
        pickle.dump(overall_sentence_count, f)

In [None]:
# Groups of interchangeable expressions to compare. Each entry is a pair
# (legal_terms, usual_terms): the loop that consumes this list labels the first
# list "legal" and the second "usual". The terms are joined with "|" to build
# both the search regex and the output file name, so they must stay exactly
# as written.
term_lists = [
    (["兩造",], ["雙方"]),
    (["上揭", "上開", "前揭", "前開", "首揭"], ["前述", "上述"]),
    (["云云",], ["等陳述", "等語", "等等"]),
    (["可考", "可佐", "可按", "可稽", "可證", "足按", "足徵", "足稽", "足憑", "足證"], ["可以佐證", "可供證明", "可以證明", "足以佐證", "足以證明"]),
    (["迭",], ["接連", "多次"]),
    (["拘束",], ["限制",]),
    (["失所附麗",], ["失所依附",]),
    (["尚非無稽", "尚非無憑", "尚非無據", "尚非虛妄", "尚非臆造"], ["應可採信", "應屬事實", "並非全無依據", "並不是完全沒有依據"]),
    (["所載",], ["所記載",]),
    (["考諸", "徵諸", "觀諸", "稽之"], ["參考", "依照", "依據"]),
    (["質言之",], ["簡言之",]),
    (["相歧",], ["矛盾",]),
    (["即非法所不許", "依法即無不合",], ["符合法律規定",]),
    (["礙難採認",], ["難以認定", "難以採信", "不可採"]),
    (["矧",], ["況且",]),
    (["翻異"], ["推翻"]),
    (["乃",], ["於是",]),
    (["自白不諱", "供認不諱", "坦承不諱"], ["坦白承認"]),
    (["似無可採", "似屬無憑", "即無可採", "即屬無據", "尚無可採", "尚難憑採", "要非可信", "要屬虛言", "容非可採"], ["難以採信", "不可採信", "尚不足採信", "尚不足採證"]),
    (["顯有",], ["顯然有", "顯然屬於", "顯然是"])
]

# Make sure the output directory exists up front, so a missing folder cannot
# fail the final write after the expensive computation has already run.
os.makedirs("semantic_surprisals", exist_ok=True)

for legal_terms, usual_terms in term_lists:

    terms = legal_terms + usual_terms
    print("|".join(terms))
    # Results are cached per term group; the file name is the "|"-joined terms.
    save_path = os.path.join("semantic_surprisals", "|".join(terms))

    if os.path.isfile(save_path):
        continue
    term_types = ["legal"] * len(legal_terms) + ["usual"] * len(usual_terms)

    # Escape every term so it is always matched literally, even if a term with
    # regex metacharacters is added later.
    regex_str = "|".join(re.escape(term) for term in terms)
    term_pattern = re.compile(regex_str)

    # Walk over consecutive sentence pairs. When the second sentence contains
    # a term, the context is the first sentence plus the text of the second
    # sentence up to the match; otherwise the first sentence is kept as a
    # "non-context" sentence.
    non_context_sentences = []
    context_sentences = []
    term_list = []
    for sentence1, sentence2 in zip(sentences[:-1], sentences[1:]):
        matches = list(term_pattern.finditer(sentence2))
        if matches:
            for match in matches:
                context_sentences.append(sentence1 + "，" + sentence2[:match.start()])
                term_list.append(match.group())
        else:
            non_context_sentences.append(sentence1)

    non_context_embeds = [
        get_sentence_embeddings(non_context_sentence, model, tokenizer)
        for non_context_sentence in tqdm.tqdm(non_context_sentences)
    ]
    context_embeds = [
        get_sentence_embeddings(context_sentence, model, tokenizer)
        for context_sentence in tqdm.tqdm(context_sentences)
    ]

    target_context_probs = {term: [] for term in terms}
    term_target_context_probs = {term: [] for term in terms}

    for term, target_context_embed in tqdm.tqdm(zip(term_list, context_embeds), total=len(term_list)):
        # Similarity mass between this context and sentences NOT followed by a term.
        non_context_counts = [
            joint_probability_two(target_context_embed, non_context_embed)
            for non_context_embed in non_context_embeds
        ]
        # Similarity mass between this context and every term-preceding context.
        # NOTE(review): this sums over the contexts of ALL terms in the group,
        # not only those preceding `term`; confirm that this is intended.
        term_target_context_counts = [
            joint_probability_two(target_context_embed, context_embed)
            for context_embed in context_embeds
        ]
        # Same element order as the two sequential loops: non-context first.
        target_context_counts = non_context_counts + term_target_context_counts

        target_context_prob = np.sum(target_context_counts) / overall_sentence_count
        if np.isnan(target_context_prob) or target_context_prob > 1 or target_context_prob < 0:
            raise ValueError(
                f"Context probability out of range for term {term!r}: {target_context_prob!r}"
            )
        target_context_probs[term].append(target_context_prob)

        term_target_context_prob = np.sum(term_target_context_counts) / overall_sentence_count
        if np.isnan(term_target_context_prob) or term_target_context_prob > 1 or term_target_context_prob < 0:
            raise ValueError(
                f"Term-context probability out of range for term {term!r}: {term_target_context_prob!r}"
            )
        term_target_context_probs[term].append(term_target_context_prob)

    # Surprisal of the term given its context: -log(P(term, context) / P(context)).
    term_surprisals = {term: [] for term in terms}

    for term in terms:
        for target_context_prob, term_target_context_prob in zip(target_context_probs[term], term_target_context_probs[term]):
            term_surprisal = -np.log(term_target_context_prob / target_context_prob)

            term_surprisals[term].append(term_surprisal)

    with open(save_path, "wb") as f:

        pickle.dump((term_surprisals, term_types), f)