In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import tqdm
import os
import pickle
import re
import glob

import numpy as np

def cosine_similarity(vec1, vec2):
    """Compute the cosine similarity between two vectors.

    Both inputs are flattened first, so arrays of any shape with the same total
    number of elements are accepted (plain sequences work too).

    Returns:
        The cosine of the angle between the flattened vectors, in [-1, 1] up to
        floating-point error. If either vector has zero norm the angle is
        undefined; 0.0 is returned instead of NaN (and a divide-by-zero warning).
    """
    vec1 = np.asarray(vec1).reshape(-1)
    vec2 = np.asarray(vec2).reshape(-1)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        # Zero vectors (e.g. embeddings of empty strings) have no direction.
        return 0.0
    return np.dot(vec1, vec2) / norm_product

def similarity_to_probability(similarity):
    """Convert a cosine similarity to a probability-like score.

    Linearly rescales [-1, 1] to [0, 1]: -1 -> 0.0, 0 -> 0.5, 1 -> 1.0.
    Works elementwise on numpy arrays as well as on scalars.
    """
    return (1 + similarity) / 2

def joint_probability_two(A, B):
    """Pairwise "joint probability" score of two vectors.

    Defined as the cosine similarity of A and B, rescaled to [0, 1] by
    similarity_to_probability. This is a heuristic score rather than a true
    joint probability.
    """
    return similarity_to_probability(cosine_similarity(A, B))

def joint_probability_three(A, B, C):
    """Three-way "joint probability" score of A, B and C.

    Multiplies the pairwise joint_probability_two scores of (A, B), (B, C) and
    (C, A). This is a heuristic product of pairwise scores, not a true joint
    probability.
    """
    pairs = ((A, B), (B, C), (C, A))
    p_ab, p_bc, p_ca = (joint_probability_two(x, y) for x, y in pairs)
    return p_ab * p_bc * p_ca

def joint_vector(A, B):
    """Return the midpoint of A and B scaled by their cosine similarity.

    Computes ((A + B) / 2) * cosine_similarity(A, B). The midpoint is shrunk
    when the vectors are poorly aligned and flips sign when their cosine is negative.
    """
    midpoint = (A + B) / 2
    return midpoint * cosine_similarity(A, B)

In [None]:
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prefer Apple-silicon MPS, then CUDA, then CPU, so the notebook also runs on
# machines without an MPS backend (a hard-coded "mps" device fails there).
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load onto one explicit device. Mixing device_map="auto" (accelerate dispatch,
# possibly across several devices) with a later .to(device) is contradictory.
# low_cpu_mem_usage keeps peak host RAM down while loading the 7B checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
model = model.to(device)

In [None]:
def get_sentence_embeddings(sentence, model, tokenizer):
    """Return the mean input-token embedding of ``sentence`` as a 1-D numpy array.

    Only the model's input embedding table is used (no transformer layers run),
    so this is a bag-of-tokens representation of the text.

    Args:
        sentence: text to embed; tokenized without special tokens (no BOS).
        model: a Hugging Face model exposing ``get_input_embeddings()``.
        tokenizer: the tokenizer matching ``model``.

    Returns:
        float32 array whose length is the embedding width. Text that tokenizes
        to zero tokens (e.g. ``""``) yields a zero vector instead of the NaN
        vector that averaging an empty array would produce.
    """
    embedding_layer = model.get_input_embeddings()
    weight = embedding_layer.weight
    tokens = tokenizer(sentence, return_tensors="pt", add_special_tokens=False)
    # Put the ids wherever the embedding weights live rather than on a global
    # device, which may differ if the model was dispatched across devices.
    input_ids = tokens["input_ids"].to(weight.device)

    if input_ids.shape[-1] == 0:
        return np.zeros(weight.shape[-1], dtype=np.float32)

    with torch.no_grad():
        token_embeddings = embedding_layer(input_ids).squeeze(0)

    # .float(): half/bfloat16 weights cannot be converted to numpy directly.
    return np.mean(token_embeddings.float().cpu().numpy(), axis=0)

In [None]:
import glob, re

sentences = []
for f in glob.glob("PMC000xxxxxx/*.txt"):
    with open(f, "r", encoding="latin-1") as file:
        content = file.read()
    for sentence in content.split("\n"):
        if sentence.strip() and not sentence.startswith("==== "):
            sentence = sentence.strip()

            sentences+=re.split(r"[{. }!?;]\s+", sentence.replace("\t", " "))
sentences = [s for s in sentences if len(s)>0]
sentences = [s.lower() for s in sentences]

In [None]:
selected_sentence_embeds = []

# Sorted for a reproducible order: a later cell only uses the first 10000
# entries as targets, so arbitrary glob order would change that sample.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced.
for path in tqdm.tqdm(sorted(glob.glob("sentence_embeddings/*"))):
    with open(path, "rb") as file:
        selected_sentence_embeds.append(pickle.load(file))

In [None]:
# Cache: text -> mean token embedding, shared across all term sets below.
all_embeds = dict()

In [None]:
# Each entry is (technical_term_groups, lay_term_groups). Each group lists the
# surface forms of one term; the first form is its canonical name, used as the
# result key and in the output filename. The code below labels the first side
# "legal" (presumably meaning technical/medical vocabulary -- confirm).
# Forms must be lower-case because the corpus sentences are lower-cased.
term_lists = [
    [[["bulla", "bullae"]], [["blister", "blisters"],]],
    [[["candida albicans"], ["candidiasis"]], [["thrush",]]],
    [[["carbohydrate", "carbohydrates"]], [["carb", "carbs"]]],
    [[["chemotherapy", "chemotherapies"]], [["chemo", "chemos"]]],
    [[["chronic pain", "chronic pains"]], [["persistent pain", "persistent pains"]]],
    [[["comedo", "comedos"]], [["whitehead", "whiteheads"]]],
    [[["dermis",], ["epidermis",]], [["skin"]]],
    [[["dyspepsia"]], [["indigestion"]]],
    [[["erythrocyte", "erythrocytes"]], [["red blood cell", "red blood cells"]]],
    [[["febrile"]], [["feverish"]]],
    [[["haemorrhage"]], [["heavy bleeding"]]],
    # Herpes zoster is shingles; chickenpox is varicella (same virus, different disease).
    [[["herpes zoster"]], [["shingles"]]],
    [[["hypertension"]], [["high blood pressure", "high blood pressures"]]],
    [[["hypotension"]], [["low blood pressure", "low blood pressures"]]],
    [[["influenza"]], [["flu"]]],
    [[["inhaler"]], [["puffer"]]],
    [[["intestine"]], [["guts"]]],
    [[["lethargy"]], [["tiredness"]]],
    [[["leukocyte", "leukocytes"]], [["white blood cell", "white blood cells"]]],
    [[["myocardial infarction"]], [["heart attack"]]],
    [[["pneumonia"]], [["lung infection"]]],
    [[["renal failure"]], [["kidney failure"]]],
    [[["thrombocytopenia"]], [["low platelet count", "low platelet counts"]]],
    [[["liposuction"]], [["lipo"]]],
    [[["melanoma"]], [["skin cancer"]]],
]

OUTPUT_DIR = "semantic_surprisals_v3"
# Create the output directory up front. Otherwise the final open(save_path, "wb")
# raises FileNotFoundError only after all the expensive scoring for a term set.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _cached_embedding(text):
    """Return the mean token embedding of `text`, memoised in the global `all_embeds`."""
    if text not in all_embeds:
        all_embeds[text] = get_sentence_embeddings(text, model, tokenizer)
    return all_embeds[text]


for legal_terms, usual_terms in term_lists:
    terms = legal_terms + usual_terms
    print(terms)

    # One result file per term set, named after each term's canonical (first) form.
    save_path = os.path.join(OUTPUT_DIR, "|".join([term[0] for term in terms]))
    if os.path.isfile(save_path):
        continue  # already computed in an earlier run
    term_types = ["legal"] * len(legal_terms) + ["usual"] * len(usual_terms)

    # canonical form -> one probability per occurrence of the term in the corpus
    term_target_context_probs = {term[0]: [] for term in terms}

    for term in terms:
        assert isinstance(term, list), "Term must be a list"
        # Match any surface form of the term as a whole word.
        regex_str = r"\b(?:" + "|".join(re.escape(form) for form in term) + r")\b"

        # Adjacent sentence pairs whose second sentence contains the term yield a
        # context string (previous sentence + prefix up to the match). All other
        # pairs yield a non-context sentence plus the first word that followed it.
        non_context_sentences = []
        non_context_sentence_following_terms = []
        context_sentences = []
        term_list = []
        for sentence1, sentence2 in zip(sentences[:-1], sentences[1:]):
            matches = list(re.finditer(regex_str, sentence2))
            if matches:
                for match in matches:
                    context_sentences.append(sentence1 + "; " + sentence2[:match.start()])
                    term_list.append(term[0])
            else:
                non_context_sentences.append(sentence1)
                non_context_sentence_following_terms.append(sentence2.split(" ")[0])

        non_context_embeds = [
            _cached_embedding(text)
            for text in tqdm.tqdm(non_context_sentences, desc="Non-context sentences")
        ]
        non_context_following_term_embeds = [
            _cached_embedding(text)
            for text in tqdm.tqdm(non_context_sentence_following_terms, desc="Non-context sentence following terms")
        ]
        context_embeds = [
            _cached_embedding(text)
            for text in tqdm.tqdm(context_sentences, desc="Context sentences")
        ]
        for t in term:
            print(t)
            _cached_embedding(t)
        term_embeds = [all_embeds[t] for t in term]

        # A non-context sentence's term weight depends only on the word following
        # it and on the term's surface forms, not on the target context. Compute it
        # once per term instead of once per (target context, sentence) pair.
        # NOTE(review): this sums one probability per surface form, so for terms
        # with several forms it can exceed 1 and trip the ValueError below. Confirm
        # whether a max over forms or min(1, sum) was intended.
        non_context_term_counts = []
        if context_embeds:  # only needed when the term actually occurs
            non_context_term_counts = [
                sum(joint_probability_two(following_embed, t_embed) for t_embed in term_embeds)
                for following_embed in non_context_following_term_embeds
            ]

        pbar = tqdm.tqdm(list(zip(term_list, context_embeds)), desc="Counts")
        for this_term, target_context_embed in pbar:
            target_context_counts = []
            term_target_context_counts = []

            # Soft counts over sentences NOT followed by the term, weighted by how
            # similar their following word is to the term's forms.
            for non_context_embed, term_count in zip(non_context_embeds, non_context_term_counts):
                target_context_count = joint_probability_two(target_context_embed, non_context_embed)
                term_target_context_count = target_context_count * term_count
                target_context_counts.append(target_context_count)
                term_target_context_counts.append(term_target_context_count)
                if target_context_count < term_target_context_count:
                    raise ValueError(
                        f"term-weighted count {term_target_context_count} exceeds "
                        f"context count {target_context_count} for {this_term!r}"
                    )

            # Soft counts over contexts actually followed by the term (term weight 1).
            # This includes target_context_embed itself.
            for context_embed in context_embeds:
                target_context_count = joint_probability_two(target_context_embed, context_embed)
                target_context_counts.append(target_context_count)
                term_target_context_counts.append(target_context_count)

            # Ratio of term-weighted to total similarity mass: a soft estimate of
            # the probability of the term following this context.
            prob = sum(term_target_context_counts) / sum(target_context_counts)
            term_target_context_probs[this_term].append(prob)
            pbar.set_description(f"Prob: {prob:.8f}")

    term_surprisals = {term[0]: [] for term in terms}
    for term, probs in term_target_context_probs.items():
        for term_target_context_prob in probs:
            term_surprisal = -np.log(term_target_context_prob)
            if term_surprisal < 0:
                raise ValueError("Surprisal cannot be negative")
            term_surprisals[term].append(term_surprisal)

    with open(save_path, "wb") as f:
        pickle.dump((term_surprisals, term_types), f)

In [None]:
# Estimate the corpus's total similarity "mass". For each of the leading sample
# embeddings, sum joint_probability_two against every loaded embedding, then
# scale the mean per-target total up by the corpus size. Cached on disk because
# the double loop is O(sample x corpus).
OVERALL_COUNT_CACHE = "overall_sentence_count.pickle"
N_TARGET_SAMPLES = 10000  # number of leading embeddings used as targets

if not os.path.isfile(OVERALL_COUNT_CACHE):
    all_sentence_counts = []
    pbar = tqdm.tqdm(selected_sentence_embeds[:N_TARGET_SAMPLES])
    for target_sentence_embed in pbar:
        all_sentence_counts.append(sum(
            joint_probability_two(target_sentence_embed, context_sentence_embed)
            for context_sentence_embed in selected_sentence_embeds
        ))

    overall_sentence_count = np.mean(all_sentence_counts) * len(selected_sentence_embeds)
    with open(OVERALL_COUNT_CACHE, "wb") as f:
        pickle.dump(overall_sentence_count, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load a cache this notebook wrote.
    with open(OVERALL_COUNT_CACHE, "rb") as f:
        overall_sentence_count = pickle.load(f)