<a href="https://colab.research.google.com/github/anujdutt9/Talks_and_Presentations/blob/main/Decoding_the_Giants/Demo_1_Pretrained_LLM_Next_Word_Prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install Dependencies
!pip3 install -U transformers -q

In [2]:
# Import Dependencies
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [3]:
# Load pre-trained model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Ensure the model is in evaluation mode
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [4]:
# Define a prompt for the model
prompt = "Artificial intelligence is"

In [5]:
# Tokenize the input prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_ids

tensor([[8001, 9542, 4430,  318]])

In [6]:
# Decode each token to see what part of the text it represents
for token_id in input_ids.numpy()[0]:
    print(f"Token ID: {token_id}, Token: '{tokenizer.decode([token_id])}'")

Token ID: 8001, Token: 'Art'
Token ID: 9542, Token: 'ificial'
Token ID: 4430, Token: ' intelligence'
Token ID: 318, Token: ' is'


In [7]:
# Generate logits for the next token
with torch.no_grad():
    outputs = model(input_ids)
    logits = outputs.logits

# Extract logits for the last token
next_token_logits = logits[:, -1, :]

# Convert logits to probabilities using softmax
probs = F.softmax(next_token_logits, dim=-1)

In [8]:
# Get the top 5 tokens with the highest probabilities
top_k = 5
top_k_probs, top_k_indices = torch.topk(probs, top_k)
print(f'top_k_probs: {top_k_probs}, top_k_indices: {top_k_indices}')

top_k_probs: tensor([[0.1205, 0.0525, 0.0432, 0.0309, 0.0206]]), top_k_indices: tensor([[257, 262, 407, 281, 783]])


In [9]:
# Decode the top k tokens and their probabilities
for i in range(top_k):
    token = tokenizer.decode(top_k_indices[0, i])
    probability = top_k_probs[0, i].item()
    print(f"Token: {token}, Probability: {probability:.4f}")

Token:  a, Probability: 0.1205
Token:  the, Probability: 0.0525
Token:  not, Probability: 0.0432
Token:  an, Probability: 0.0309
Token:  now, Probability: 0.0206


In [10]:
# Display log probabilities
log_probs = torch.log(probs)
print("\nLog probabilities for the top 5 tokens:")
for i in range(top_k):
    token = tokenizer.decode(top_k_indices[0, i])
    log_probability = log_probs[0, top_k_indices[0, i]].item()
    print(f"Token: {token}, Log Probability: {log_probability:.4f}")


Log probabilities for the top 5 tokens:
Token:  a, Log Probability: -2.1165
Token:  the, Log Probability: -2.9462
Token:  not, Log Probability: -3.1411
Token:  an, Log Probability: -3.4763
Token:  now, Log Probability: -3.8815


# Continuously Generating the Next Word

In [11]:
# Define a prompt for the model
prompt = "Artificial intelligence is"
print(f"Initial Prompt: {prompt}")

# Number of words to generate
num_words_to_generate = 10 # @param {type:"slider", min:1, max:10, step:1}

# Lower temperature makes the model more deterministic
temperature = 0.9 # @param {type:"slider", min:0, max:1, step:0.1}

# Nucleus sampling threshold
top_p = 0.7 # @param {type:"slider", min:0, max:1, step:0.1}

for _ in range(num_words_to_generate):
    # Tokenize the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Generate logits for the next token
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Extract logits for the last token
    next_token_logits = logits[:, -1, :]

    # Apply temperature to the logits
    next_token_logits = next_token_logits / temperature

    # Convert logits to probabilities using softmax
    probs = F.softmax(next_token_logits, dim=-1)

    # Apply top-p (nucleus) sampling to filter the probabilities
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(sorted_indices_to_remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum()  # Renormalize probabilities

    # Sample from the filtered distribution
    next_token = torch.multinomial(sorted_probs, num_samples=1)
    next_token_id = sorted_indices[0, next_token.item()]

    # Decode the next token and append it to the prompt
    next_word = tokenizer.decode(next_token_id)
    prompt += f" {next_word}"

    # Print the updated prompt with the new word
    print(f"Updated Prompt: {prompt}\n")

    # Show the top 5 predictions for context
    top_k = 5
    top_k_probs, top_k_indices = torch.topk(probs, top_k)
    # print("Top 5 token predictions:")
    for i in range(top_k):
        token = tokenizer.decode(top_k_indices[0, i])
        probability = top_k_probs[0, i].item()
        # print(f"Token: '{token}', Probability: {probability:.4f}")

Initial Prompt: Artificial intelligence is
Updated Prompt: Artificial intelligence is  currently

Updated Prompt: Artificial intelligence is  currently  being

Updated Prompt: Artificial intelligence is  currently  being  used

Updated Prompt: Artificial intelligence is  currently  being  used  to

Updated Prompt: Artificial intelligence is  currently  being  used  to  help

Updated Prompt: Artificial intelligence is  currently  being  used  to  help  

Updated Prompt: Artificial intelligence is  currently  being  used  to  help    to

Updated Prompt: Artificial intelligence is  currently  being  used  to  help    to  

Updated Prompt: Artificial intelligence is  currently  being  used  to  help    to    control

Updated Prompt: Artificial intelligence is  currently  being  used  to  help    to    control  the

