In [45]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import math

import itertools

In [14]:
# Checkpoint for the first experiment; change the name to try another Pythia size.
model_name = "EleutherAI/pythia-70m-deduped"

# Tokenizer and weights are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


# Define a function to generate text until a period is encountered
def generate_until_period(input_text, temperature = 0.5, max_length=50):
    """Sample a continuation of ``input_text`` token by token until a sentence ends.

    Uses the module-level ``model`` and ``tokenizer``. At every step the logits
    of the last position are divided by ``temperature``, converted to a
    distribution with softmax, and one token is drawn from that distribution.
    Generation stops after a token whose decoded text ends in ``.``, ``!`` or
    ``?``, or once the sequence is longer than ``max_length``.

    Args:
        input_text: Prompt string to continue.
        temperature: Strictly positive softmax temperature; smaller values
            concentrate probability on the highest-scoring tokens.
        max_length: Sampling stops once the total sequence (prompt plus
            generated tokens) exceeds this many tokens.

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Zero would divide the logits by 0 (NaN probabilities); a negative
        # value would silently invert the distribution.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    generated_ids = input_ids
    # Pure inference: without no_grad every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        while True:
            # Temperature-scaled logits for the next position only
            logits = model(generated_ids).logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # Draw one token id from the next-token distribution
            next_token_id = torch.multinomial(probs, num_samples=1).squeeze().item()
            generated_ids = torch.cat(
                [generated_ids, torch.tensor([[next_token_id]], device=device)],
                dim=-1,
            )

            # endswith (rather than next_token[-1]) tolerates tokens that
            # decode to an empty string.
            next_token = tokenizer.decode(next_token_id)
            if next_token.endswith(('.', '!', '?')):
                break

            # Stop if the sequence length exceeds max_length
            if generated_ids.shape[1] > max_length:
                break

    # Decode the full sequence (prompt included)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

# Draw 100 low-temperature completions of a single prompt, echoing each one
# and keeping all of them in gen_sens.
input_text = "The number of bones in the human body is"
gen_sens = list()

for i in range(100):
    generated_text = generate_until_period(
        input_text,
        temperature=0.25,
    )
    print(generated_text)
    gen_sens += [generated_text]

The number of bones in the human body is significant.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is the number of bones in the human body.
The number of bones in the human body is about 1.
The number of bones in the human body is the number of bones in the body.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is the number of bones in the body.
The number of bones in the human body is estimated to be 1.
The number of bones in the human body is not the number of bones in the body.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is estimated to be around 1.
The number of bones in the human body is the number of bones in the body.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is a measure of the number of bones in 

In [56]:
# Rebind the global model/tokenizer to a larger Pythia checkpoint; the
# generation helpers read these globals, so later calls use this model.
model_name = "EleutherAI/pythia-1.4b-deduped"

# Tokenizer and weights are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def generate_until_period2(input_text, n, max_length=50):
    """Top-n sample a continuation of ``input_text`` and record each choice's probability.

    Uses the module-level ``model`` and ``tokenizer``. At every step only the
    ``n`` highest-scoring next tokens are kept, their probabilities are
    renormalised to sum to 1, and one of them is drawn. Generation stops after
    a token whose decoded text ends in ``.``, ``!`` or ``?``, or once the
    sequence is longer than ``max_length``.

    Args:
        input_text: Prompt string to continue.
        n: Number of top-scoring candidate tokens to sample from at each step
            (values larger than the vocabulary are clamped to its size).
        max_length: Sampling stops once the total sequence (prompt plus
            generated tokens) exceeds this many tokens.

    Returns:
        A ``(generated_text, prob)`` tuple: the decoded prompt plus
        continuation (special tokens skipped), and a list holding, for each
        generated token, its probability under the renormalised top-n
        distribution it was drawn from.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    generated_ids = input_ids
    prob = []
    # Pure inference: without no_grad every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        while True:
            # Next-token logits for the single sequence in the batch
            logits = model(generated_ids).logits[0, -1, :]

            # Keep the n best candidates
            top_logits, top_indices = torch.topk(logits, min(n, logits.shape[-1]))

            # Softmax over only the kept logits equals the full softmax
            # renormalised over those same tokens.
            top_probs = F.softmax(top_logits, dim=-1)

            # `choice` indexes into the top-n list, not into the vocabulary
            choice = torch.multinomial(top_probs, num_samples=1).item()
            prob.append(top_probs[choice].item())

            next_token_id = top_indices[choice].item()
            generated_ids = torch.cat(
                [generated_ids, torch.tensor([[next_token_id]], device=device)],
                dim=-1,
            )

            # endswith (rather than next_token[-1]) tolerates tokens that
            # decode to an empty string.
            next_token = tokenizer.decode(next_token_id)
            if next_token.endswith(('.', '!', '?')):
                break

            # Stop if the sequence length exceeds max_length
            if generated_ids.shape[1] > max_length:
                break

    # Decode the full sequence (prompt included)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, prob

In [57]:
# Experiment setup: prompt, number of sampled trajectories, top-n cutoff.
input_text = "The spinal cord is"
num_traj, n = 1000, 5

# Per-trajectory results: the decoded sentence and the product of the
# per-token sampling probabilities returned with it.
gen_sens, p_vals = [], []

for i in range(num_traj):
    generated_text, p = generate_until_period2(input_text, n)
    print(generated_text)
    gen_sens.append(generated_text)
    val = np.prod(p)
    p_vals.append(val)

The spinal cord is a complex system of nerves that controls all movement of the entire body, but is not involved in sensory perception or motor control.


In [49]:
def entropy(p, base=2):
    """Return the Shannon entropy of the distribution ``p`` in units of ``base``.

    ``p`` is an iterable of probabilities (expected to sum to 1). Entries that
    are not strictly positive contribute nothing to the result.
    """
    if base == math.e:
        log_of = math.log
    else:
        def log_of(value):
            return math.log(value, base)

    total = 0.0
    for mass in p:
        # Skip zero (and any non-positive) mass: its term is taken as 0
        if not mass > 0:
            continue
        total = total - mass * log_of(mass)
    return total

def mutual_information(samples, keyword1, keyword2):
    """Estimate the mutual information between two keywords' occurrence in ``samples``.

    Each sample is reduced to two binary events, "contains ``keyword1``" and
    "contains ``keyword2``", using the ``in`` operator (a substring match when
    the samples are strings). Probabilities are empirical frequencies over the
    samples actually passed in, and the result is
    ``H(A) + H(B) - H(A, B)`` computed with the module-level ``entropy``
    (its default base is used).

    Args:
        samples: Iterable of sentences to scan.
        keyword1: First keyword.
        keyword2: Second keyword.

    Returns:
        A tuple ``(mi, h_a, h_b, joint_entropy)``. For an empty ``samples``
        every entry is ``0.0``.
    """
    # Tally the four cells of the 2x2 contingency table in a single pass.
    both = only1 = only2 = neither = 0
    total = 0
    for sen in samples:
        total += 1
        has1 = keyword1 in sen
        has2 = keyword2 in sen
        if has1 and has2:
            both += 1
        elif has1:
            only1 += 1
        elif has2:
            only2 += 1
        else:
            neither += 1

    if total == 0:
        # Nothing observed: no uncertainty and no shared information.
        return 0.0, 0.0, 0.0, 0.0

    # Normalise by the number of samples seen, not by a notebook global, so
    # the estimate is correct for any sample list.
    prob1 = (both + only1) / total
    prob2 = (both + only2) / total
    joint_prob = [both / total, only1 / total, only2 / total, neither / total]

    joint_entropy = entropy(joint_prob)
    h_a = entropy([prob1, 1 - prob1])
    h_b = entropy([prob2, 1 - prob2])

    mi = h_a + h_b - joint_entropy

    return mi, h_a, h_b, joint_entropy



In [50]:
# Keywords whose pairwise co-occurrence is scored across the generated sentences.
terms = [
    'cortex',
    'frontal',
    'occipital',
    'temporal',
    'cerebellum',
    'basal ganglia',
    'nuclei',
    'thinking',
    'hippocampus',
    'membrane',
    'spinal cord',
]

In [54]:
# Score every unordered pair of terms; only the mutual information (the first
# element of the returned tuple) is kept alongside the pair.
miis = [
    ((w1, w2), mutual_information(gen_sens, w1, w2)[0])
    for w1, w2 in itertools.combinations(terms, 2)
]

# Highest mutual information first
sorted_miis = sorted(miis, key=lambda pair_and_mi: pair_and_mi[1], reverse=True)

# Print sorted results
for (w1, w2), mi in sorted_miis:
    print(f"{w1} & {w2}: {mi}")


frontal & temporal: 0.8711705334550619
cortex & temporal: 0.6484066440162966
cortex & frontal: 0.6371206852362756
occipital & temporal: 0.5825237870552784
frontal & occipital: 0.5482914806687367
frontal & cerebellum: 0.5234368691943054
temporal & cerebellum: 0.51991980606729
temporal & basal ganglia: 0.5182527419439857
frontal & basal ganglia: 0.5092904023082714
cortex & basal ganglia: 0.5057593087109642
cortex & cerebellum: 0.5047964458016523
cortex & occipital: 0.4477457147782786
occipital & cerebellum: 0.36683246944333314
occipital & basal ganglia: 0.35620882143846044
cerebellum & basal ganglia: 0.3106398309551628
temporal & hippocampus: 0.2053523450931365
frontal & hippocampus: 0.20230369636418155
cortex & hippocampus: 0.20109580216580314
cerebellum & hippocampus: 0.16409866522642047
cerebellum & spinal cord: 0.1567251642743639
occipital & hippocampus: 0.14681204624140598
temporal & spinal cord: 0.13134787408777937
frontal & spinal cord: 0.12927540895836098
cortex & spinal cord: 0.