In [9]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

In [14]:
# Load pre-trained Pythia model and tokenizer
model_name = "EleutherAI/pythia-70m-deduped"  # Adjust as necessary
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Define a function to generate text until a period is encountered
def generate_until_period(input_text, temperature = 0.5, max_length=50):
    # Tokenize the input text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_ids = input_ids.to(device)

    # Generate tokens until a period is encountered
    generated_ids = input_ids
    while True:
        # Get model logits for the input sequence
        logits = model(generated_ids).logits[:, -1, :] / temperature
        
        # Apply softmax to get probabilities for the next token
        probs = F.softmax(logits, dim=-1)
        
        # Sample a token from the probability distribution
        next_token_id = torch.multinomial(probs, num_samples=1).squeeze().item()
        
        # Append the generated token to the sequence
        generated_ids = torch.cat([generated_ids, torch.tensor([[next_token_id]], device=device)], dim=-1)
        
        # Decode the generated token
        next_token = tokenizer.decode(next_token_id)
        
        # Check if the token ends with sentence-ending punctuation
        if next_token[-1] in ['.', '!', '?']:
            break
        
        # Stop if the sequence length exceeds max_length
        if generated_ids.shape[1] > max_length:
            break
    
    # Decode the full generated sequence
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

input_text = "The number of bones in the human body is"
gen_sens = []

for i in range(100):
    generated_text = generate_until_period(input_text, temperature=0.25)
    print(generated_text)
    gen_sens.append(generated_text)

The number of bones in the human body is significant.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is the number of bones in the human body.
The number of bones in the human body is about 1.
The number of bones in the human body is the number of bones in the body.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is the number of bones in the body.
The number of bones in the human body is estimated to be 1.
The number of bones in the human body is not the number of bones in the body.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is estimated to be around 1.
The number of bones in the human body is the number of bones in the body.
The number of bones in the human body is the number of the bones in the body.
The number of bones in the human body is a measure of the number of bones in 

In [19]:
# Load pre-trained Pythia model and tokenizer
model_name = "EleutherAI/pythia-1b-deduped"  # Adjust as necessary
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def generate_until_period2(input_text, max_length=50):
    # Tokenize the input text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_ids = input_ids.to(device)

    # Generate tokens until a period is encountered
    generated_ids = input_ids
    prob = []
    while True:
        # Get model logits for the input sequence
        logits = model(generated_ids).logits[:, -1, :]
        
        # Apply softmax to get probabilities for the next token
        probs = F.softmax(logits, dim=-1)

        # Get the indices of the two highest logits
        top2_indices = torch.topk(logits, 2, dim=-1).indices.squeeze(dim=0)
        
        # Extract the probabilities for the top 2 tokens
        top2_probs = probs[0, top2_indices]
        
        # Normalize the probabilities (they should sum to 1)
        top2_probs = top2_probs / top2_probs.sum()

        # Sample a token from the top 2 probabilities
        next_token_id = torch.multinomial(top2_probs, num_samples=1).item()
        prob.append(top2_probs[next_token_id].item())
        
        # Append the generated token to the sequence
        generated_ids = torch.cat([generated_ids, torch.tensor([[top2_indices[next_token_id]]], device=device)], dim=-1)
        
        # Decode the generated token
        next_token = tokenizer.decode(top2_indices[next_token_id].item())
        
        # Check if the token ends with sentence-ending punctuation
        if next_token[-1] in ['.', '!', '?']:
            break
        
        # Stop if the sequence length exceeds max_length
        if generated_ids.shape[1] > max_length:
            break
    
    # Decode the full generated sequence
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, prob

input_text = "The Basal Ganglia are a part of"
gen_sens = []

for i in range(100):
    generated_text, p = generate_until_period2(input_text)
    print(generated_text)
    print(p)
    gen_sens.append(generated_text)

The Basal Ganglia are a part of the limbic system of the brain.
[0.9738022685050964, 0.5706159472465515, 0.9991954565048218, 0.9889147877693176, 0.48343124985694885, 0.9863533973693848, 0.9778152108192444, 0.560214102268219]
The Basal Ganglia are a part of the brain that is responsible for the control of our body's movements and the ability to feel pain.
[0.9738022685050964, 0.4293840229511261, 0.7471218109130859, 0.35000166296958923, 0.7737951278686523, 0.9965533018112183, 0.551062285900116, 0.6155555844306946, 0.9497830271720886, 0.28369009494781494, 0.7914279103279114, 0.40423086285591125, 0.48300230503082275, 0.2859199345111847, 0.5254173278808594, 0.43943604826927185, 0.9733124375343323, 0.5802935361862183, 0.6452164053916931, 0.8700284957885742]
The Basal Ganglia are a part of the brain that is involved in controlling movement, emotion, and perception.
[0.9738022685050964, 0.4293840229511261, 0.7471218109130859, 0.35000166296958923, 0.22620487213134766, 0.8429732918739319, 0.2978

KeyboardInterrupt: 