In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import tqdm
import os
import pickle
import re
import glob
import torch.nn.functional as F
import numpy as np

In [None]:
# Hugging Face model id; the tokenizer and the weights are loaded from the same id.
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA or CPU-only machines. MPS is tried first to
# keep the behaviour unchanged on the machine this notebook was written on.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Place the whole model on `device` at load time. The previous combination of
# device_map="auto" followed by model.to(...) moved the model twice and fails
# when the automatic map splits or offloads the model across devices.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)

In [None]:
def _model_device(model):
    """Return the device the model currently lives on."""
    device = getattr(model, "device", None)
    if device is None:
        # Plain torch modules have no `.device`; fall back to the first parameter.
        device = next(model.parameters()).device
    return device


def get_word_probability(context, word, model, tokenizer):
    """Return the probability the model assigns to `word` following `context`.

    The probability is the product, over the tokens of `word`, of each token's
    next-token probability given `context` and the preceding tokens of `word`.

    Args:
        context: Text preceding the word. Encoded with special tokens.
        word: The continuation to score. Encoded without special tokens.
        model: Callable causal LM; `model(ids).logits` must have shape
            (batch, seq_len, vocab_size).
        tokenizer: Object exposing `encode(text, add_special_tokens=...)`
            returning a list of token ids.

    Returns:
        The joint probability as a Python float. `1.0` if `word` encodes to no
        tokens (empty product).

    Raises:
        ValueError: If `context` encodes to no tokens, since there is then
            nothing to condition the first word token on.
    """
    # Tokenize the context and the word
    context_tokens = tokenizer.encode(context, add_special_tokens=True)
    # With add_special_tokens=False no BOS token is prepended, so every returned
    # id belongs to the word. (Slicing off the first id here would discard the
    # first sub-word and make single-token words score a probability of 1.0.)
    word_tokens = tokenizer.encode(word, add_special_tokens=False)

    if not word_tokens:
        return 1.0
    if not context_tokens:
        raise ValueError("context must encode to at least one token")

    # Combine context with word tokens for probability calculation
    combined_tokens = context_tokens + word_tokens

    # Run on whatever device the model already occupies rather than moving the
    # model on every call.
    device = _model_device(model)
    # The logits at position j predict token j + 1, so a single forward pass
    # over everything except the final token scores every word token at once.
    input_tensor = torch.tensor([combined_tokens[:-1]], device=device)
    target_tensor = torch.tensor(word_tokens, device=device)

    with torch.no_grad():
        logits = model(input_tensor).logits  # (batch_size, seq_len, vocab_size)
        first_prediction = len(context_tokens) - 1
        # Work in log space (and float32) so long words do not underflow.
        log_probs = F.log_softmax(logits[0, first_prediction:, :].float(), dim=-1)
        token_log_probs = log_probs.gather(1, target_tensor.unsqueeze(1)).squeeze(1)
        total_prob = token_log_probs.sum().exp().item()

    return total_prob

In [None]:
import glob, re

# Splits wherever one of the characters '{', '.', ' ', '}', '!', '?' or ';' is
# followed by at least one whitespace character.
_SENTENCE_SPLIT_RE = re.compile(r"[{. }!?;]\s+")


def load_sentences(pattern="PMC000xxxxxx/*.txt"):
    """Read every file matching `pattern` and return its lower-cased sentences.

    Each file is read as latin-1 and processed line by line. Blank lines and
    lines starting with "==== " are skipped; tabs inside a line are replaced by
    spaces before the line is split into sentences.

    Args:
        pattern: Glob pattern of the text files to read.

    Returns:
        A flat list of lower-cased sentence strings, in sorted file order and
        then in order of appearance within each file.
    """
    collected = []
    # glob returns paths in arbitrary, filesystem-dependent order. Sort them so
    # the sentence sequence (and anything pairing neighbouring sentences) is
    # reproducible between runs and machines.
    for path in sorted(glob.glob(pattern)):
        with open(path, "r", encoding="latin-1") as file:
            content = file.read()
        for line in content.split("\n"):
            # The "==== " marker is tested against the raw, unstripped line.
            if line.strip() and not line.startswith("==== "):
                cleaned = line.strip().replace("\t", " ")
                collected.extend(_SENTENCE_SPLIT_RE.split(cleaned))
    return [s.lower() for s in collected]


sentences = load_sentences()

In [None]:
# Pairs of term groups to compare. Each entry has the shape
#     [first_groups, second_groups]
# where every group is a list of surface forms (e.g. singular and plural) of a
# single term; the first form in a group is the one used as its key downstream.
# In every entry the first side holds the clinical wording and the second side
# the everyday wording.
# NOTE(review): "herpes zoster" is normally glossed as shingles, while
# chickenpox is varicella -- confirm this pairing is intended.
term_lists = [
    [
        [["bulla", "bullae"]],
        [["blister", "blisters"]],
    ],
    [
        [["candida albicans"], ["candidiasis"]],
        [["thrush"]],
    ],
    [
        [["carbohydrate", "carbohydrates"]],
        [["carb", "carbs"]],
    ],
    [
        [["chemotherapy", "chemotherapies"]],
        [["chemo", "chemoes"]],
    ],
    [
        [["chronic pain", "chronic pains"]],
        [["persistent pain", "persistent pains"]],
    ],
    [
        [["comedo", "comedos"]],
        [["whitehead", "whiteheads"]],
    ],
    [
        [["dermis"], ["epidermis"]],
        [["skin"]],
    ],
    [
        [["dyspepsia"]],
        [["indigestion"]],
    ],
    [
        [["erythrocyte", "erythrocytes"]],
        [["red blood cell", "red blood cells"]],
    ],
    [
        [["febrile"]],
        [["feverish"]],
    ],
    [
        [["haemorrhage"]],
        [["heavy bleeding"]],
    ],
    [
        [["herpes zoster"]],
        [["chickenpox"]],
    ],
    [
        [["hypertension"]],
        [["high blood pressure", "high blood pressures"]],
    ],
    [
        [["hypotension"]],
        [["low blood pressure", "low blood pressures"]],
    ],
    [
        [["influenza"]],
        [["flu"]],
    ],
    [
        [["inhaler"]],
        [["puffer"]],
    ],
    [
        [["intestine"]],
        [["guts"]],
    ],
    [
        [["lethargy"]],
        [["tiredness"]],
    ],
    [
        [["leukocyte", "leukocytes"]],
        [["white blood cell", "white blood cells"]],
    ],
    [
        [["myocardial infarction"]],
        [["heart attack"]],
    ],
    [
        [["pneumonia"]],
        [["lung infection"]],
    ],
    [
        [["renal failure"]],
        [["kidney failure"]],
    ],
    [
        [["thrombocytopenia"]],
        [["low platelet count", "low platelet counts"]],
    ],
    [
        [["liposuction"]],
        [["lipo"]],
    ],
    [
        [["melanoma"]],
        [["skin cancer"]],
    ],
]

# Create the output directory up front: without it the final open(..., "wb")
# fails only after all the model scoring for an entry has already been done.
os.makedirs("llm_surprisals", exist_ok=True)

# For every entry of term_lists, score each occurrence of each term in the
# corpus and pickle the per-term surprisals. Entries whose output file already
# exists are skipped, so the cell can be re-run to resume.
for legal_terms, usual_terms in term_lists:

    terms = legal_terms+usual_terms

    # Validate before any term is indexed or scored, so a malformed entry fails
    # immediately rather than part-way through the computation.
    for term in terms:
        if not isinstance(term, list): raise ValueError("Term must be a list")

    # The file name is the first surface form of every term, joined by "|".
    save_path = os.path.join("llm_surprisals", "|".join([term[0] for term in terms]))

    if os.path.isfile(save_path): continue
    # Parallel to `terms`: which side of the entry each term came from.
    term_types = ["legal"]*len(legal_terms)+["usual"]*len(usual_terms)

    # Maps a term's first surface form to one surprisal per corpus occurrence.
    term_surprisals = {term[0]: [] for term in terms}
    for term in terms:
        # Whole-word match of any surface form; escaped so a form containing
        # regex metacharacters is still matched literally.
        alternatives = "|".join(re.escape(variant) for variant in term)
        term_regex = re.compile(rf"\b(?:{alternatives})\b")

        # One context per occurrence: the previous sentence, "; ", and the
        # current sentence up to (not including) the matched term.
        # NOTE(review): the context usually ends in a space while the term is
        # tokenized separately -- confirm this is the intended tokenization.
        context_sentences = [
            previous + "; " + current[:match.start()]
            for previous, current in zip(sentences, sentences[1:])
            for match in term_regex.finditer(current)
        ]

        for context_sentence in tqdm.tqdm(context_sentences):
            # Probability of the term = sum over its surface forms.
            term_prob = 0
            for t in term:
                term_prob += get_word_probability(context_sentence, t, model, tokenizer)

            # Surprisal in nats; a zero probability yields inf.
            term_surprisal = -np.log(term_prob)
            term_surprisals[term[0]].append(term_surprisal)

    with open(save_path, "wb") as f:
        pickle.dump((term_surprisals, term_types), f)