## Text Generation using Different Decoding Strategies with Amazon SageMaker JumpStart SDK and Falcon 40B Instruct Language Model

---
This Amazon SageMaker Studio notebook demonstrates how to use the SageMaker Python SDK to deploy the Falcon-40B-Instruct large language model with very little effort, and then generate text using different decoding methods.

This notebook has the following prerequisites:
- Select an AWS region where [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart) is available. 
- [Set up an Amazon SageMaker Domain](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html).
- [Available service quota for "ml.g5.12xlarge for endpoint usage"](https://docs.aws.amazon.com/general/latest/gr/sagemaker.html).
- Less than $10 per hour to spend on Amazon SageMaker JumpStart model deployment and Amazon SageMaker Studio notebook usage.  

This notebook is based on the following references:
- [Amazon SageMaker JumpStart SDK](https://sagemaker.readthedocs.io/en/v2.82.0/overview.html#use-prebuilt-models-with-sagemaker-jumpstart), providing pretrained models for a wide range of problem types to help you get started with machine learning.
- [Falcon-40B-Instruct](https://huggingface.co/tiiuae/falcon-40b-instruct), a top-performing open-source, causal decoder-only model with 40B parameters built by TII.
- Public articles ([Link 1](https://huggingface.co/blog/how-to-generate), [Link 2](https://huggingface.co/docs/transformers/generation_strategies), [Link 3](https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model)) published on Hugging Face, an open source community and data science platform for machine learning models and datasets. 
---

In [2]:
!pip install --upgrade pip --quiet --disable-pip-version-check --root-user-action=ignore
!pip install --upgrade sagemaker --quiet --root-user-action=ignore

In [3]:
from sagemaker.jumpstart.model import JumpStartModel

# Define SageMaker JumpStart Model using model id, instance type, and endpoint timeout
my_model = JumpStartModel(model_id="huggingface-llm-falcon-40b-instruct-bf16",
                          instance_type="ml.g5.12xlarge",
                          env={'ENDPOINT_SERVER_TIMEOUT':'300'})

# Take a look at the JumpStart Model parameters printed by this cell
print("Model id =", my_model.model_id)
print("Model name =", my_model.name)
print("Model version =", my_model.model_version)
print("Instance type =", my_model.instance_type)
print("Instance number of GPUs =", my_model.env["SM_NUM_GPUS"])
print("Model maximum input length =", my_model.env["MAX_INPUT_LENGTH"])
print("Model maximum total tokens =", my_model.env["MAX_TOTAL_TOKENS"])
print("Server endpoint timeout =", my_model.env["ENDPOINT_SERVER_TIMEOUT"], "seconds")

Model id = huggingface-llm-falcon-40b-instruct-bf16
Model name = hf-llm-falcon-40b-instruct-bf16-2023-07-03-00-44-29-841
Model version = *
Instance type = ml.g5.12xlarge
Instance number of GPUs = 4
Model maximum input length = 1024
Model maximum total tokens = 2048
Server endpoint timeout = 300 seconds
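
The two token limits printed above constrain every request. As a rough sketch, assuming the container applies `MAX_INPUT_LENGTH` to the prompt and `MAX_TOTAL_TOKENS` to the prompt plus the generated tokens, a hypothetical helper (not part of the SageMaker SDK) can check a request before it is sent. The prompt token count is taken as an argument because it depends on the model tokenizer, which is not loaded in this notebook.

```python
def fits_token_budget(input_tokens, max_new_tokens,
                      max_input_length=1024, max_total_tokens=2048):
    """Return True if a request stays within both token limits.

    The defaults mirror the values printed by the cell above; the meaning
    attached to each limit is an assumption, not confirmed by the notebook.
    """
    within_input_limit = input_tokens <= max_input_length
    within_total_limit = input_tokens + max_new_tokens <= max_total_tokens
    return within_input_limit and within_total_limit


# Hypothetical token counts, for illustration only
print(fits_token_budget(input_tokens=20, max_new_tokens=120))
print(fits_token_budget(input_tokens=1500, max_new_tokens=120))
```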


In [4]:
%%time

# Host the model on the instance and deploy an inference endpoint
# Because the model size is >80GB, expect deploy() to take about 15 minutes!
predictor = my_model.deploy()

--------------------------!
CPU times: user 178 ms, sys: 18.1 ms, total: 196 ms
Wall time: 13min 34s


---
**Decoding Strategies**

Large language models generate text one token at a time: having read all the previous tokens, the model produces a probability distribution over the next token, and a decoding strategy decides which token to pick. There are several decoding strategies we can configure, such as Greedy Search, Beam Search, and Contrastive Search.

We start by defining the prompt, which lets us test different decoding strategies on the same input text. We also fix the values of some hyperparameters across all decoding strategies to make the generated text easier to compare:
- *temperature* is used to control the randomness of predictions by scaling the logits before applying softmax. The softmax layer in the transformer architecture turns the logits into probabilities (between 0 and 1). A low temperature (below 1) sharpens the probabilities of the predicted words, resulting in more conservative and predictable text. A high temperature (above 1) flattens the probabilities, so the model generates more creative and diverse text that may include unusual or unexpected words.
- *stop* provides a list of stop sequences to the model. The generation stops as soon as one of them is generated.
- *max_new_tokens* defines the maximum number of tokens to be generated by the model.
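
The effect of temperature can be sketched in a few lines of plain Python. The logits below are made-up scores for three candidate tokens, and the function is an illustrative helper, not the code that runs inside the endpoint.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Divide the logits by the temperature, then apply softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [value / temperature for value in logits]
    # Subtracting the maximum keeps exp() numerically stable
    # and does not change the resulting probabilities.
    largest = max(scaled)
    exps = [math.exp(value - largest) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


toy_logits = [2.0, 1.0, 0.1]  # hypothetical scores for three tokens
for t in (0.5, 1.0, 1.05, 2.0):
    probs = softmax_with_temperature(toy_logits, t)
    print("temperature =", t, "->", [round(p, 3) for p in probs])
```

Running the loop shows the top token taking a larger share of the probability mass as the temperature drops, and a flatter distribution as it rises.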

In [5]:
prompt = "Why didn't the chicken cross the road? Was it afraid of cars?"
stop_keywords = ["<|endoftext|>", "</s>"]
max_new_tokens = 120
temperature = 1.05

**Strategy 1: Greedy Search**

Greedy search is a deterministic method that simply selects the word with the highest probability as its next word. To configure greedy decoding, we set the *do_sample* hyperparameter to false and make sure the *num_beams* hyperparameter is set to 1. 
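
To make the idea concrete, here is a minimal sketch of greedy decoding over a hypothetical bigram table, where the next-token probabilities depend only on the last token. The table and its probabilities are invented for illustration; a real model conditions on the whole context.

```python
# Hypothetical toy "model": next-token probabilities given the last token
TOY_MODEL = {
    "<s>": {"the": 0.5, "a": 0.4, "chicken": 0.1},
    "the": {"chicken": 0.4, "road": 0.35, "car": 0.25},
    "a": {"chicken": 0.9, "road": 0.1},
    "chicken": {"crossed": 0.6, "</s>": 0.4},
    "road": {"</s>": 1.0},
    "car": {"</s>": 1.0},
    "crossed": {"</s>": 1.0},
}


def greedy_decode(model, start="<s>", stop="</s>", max_new_tokens=10):
    """Always append the single most probable next token."""
    tokens = [start]
    for _ in range(max_new_tokens):
        distribution = model[tokens[-1]]
        next_token = max(distribution, key=distribution.get)
        tokens.append(next_token)
        if next_token == stop:
            break
    return tokens[1:]  # drop the start marker


print(greedy_decode(TOY_MODEL))
```

In this toy table the greedy path starts with "the" because it is the most probable first token, even though sequences starting with "a" can reach a higher overall probability.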

In [6]:
greedy_search_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "num_beams": 1,
    }
}
greedy_search_response = predictor.predict(greedy_search_payload)

print(prompt)
print("Greedy Search Response:", ">"*40, "\033[95m")
print(greedy_search_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Greedy Search Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m

We may never know why the chicken crossed the road. Some suggest it was to get to the other side, while others say it is a popular joke with no real answer. Regardless, we can be sure the chicken was not afraid of cars, as they did not exist at the time when the expression originated around the 21st century.</s>
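
The response above ends with the `</s>` stop sequence. If that is not wanted in downstream text, a small post-processing helper can trim it. The function below is a hypothetical utility written for this notebook, not part of the SageMaker SDK.

```python
def strip_stop_sequences(text, stop_sequences):
    """Cut the text at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut].rstrip()


example = "Some suggest it was to get to the other side.</s>"
print(strip_stop_sequences(example, ["<|endoftext|>", "</s>"]))
```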


**Strategy 2: Beam Search**

Beam search is another deterministic method that reduces the risk of missing hidden high probability word sequences by keeping a fixed number (beam) of active candidates at each time step and eventually choosing the hypothesis that has the overall highest probability. To configure beam decoding, we set the *do_sample* hyperparameter to false and set the *num_beams* hyperparameter to a value above 1. 
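
The following simplified sketch shows beam search on a hypothetical bigram table (invented probabilities, next token depends only on the last token). It omits refinements such as length normalization that production implementations typically add.

```python
import math

# Hypothetical toy "model": next-token probabilities given the last token
TOY_MODEL = {
    "<s>": {"the": 0.5, "a": 0.4, "chicken": 0.1},
    "the": {"chicken": 0.4, "road": 0.35, "car": 0.25},
    "a": {"chicken": 0.9, "road": 0.1},
    "chicken": {"crossed": 0.6, "</s>": 0.4},
    "road": {"</s>": 1.0},
    "car": {"</s>": 1.0},
    "crossed": {"</s>": 1.0},
}


def beam_search(model, num_beams=2, start="<s>", stop="</s>", max_new_tokens=10):
    """Return the best token sequence found and its probability."""
    beams = [(0.0, [start])]  # (cumulative log-probability, tokens)
    finished = []
    for _ in range(max_new_tokens):
        candidates = []
        for score, tokens in beams:
            for token, prob in model[tokens[-1]].items():
                candidates.append((score + math.log(prob), tokens + [token]))
        # Keep only the num_beams most probable hypotheses
        candidates.sort(key=lambda item: item[0], reverse=True)
        beams = []
        for score, tokens in candidates[:num_beams]:
            if tokens[-1] == stop:
                finished.append((score, tokens))
            else:
                beams.append((score, tokens))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by the token budget
    best_score, best_tokens = max(finished, key=lambda item: item[0])
    return best_tokens[1:], math.exp(best_score)


print(beam_search(TOY_MODEL, num_beams=1))  # behaves like greedy search
print(beam_search(TOY_MODEL, num_beams=2))
```

With one beam the search follows the locally best token at each step. With two beams it also keeps the hypothesis starting with "a", which ends up with a higher overall probability in this toy table.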

In [7]:
beam_search_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "num_beams": 4,
    }
}
beam_search_response = predictor.predict(beam_search_payload)

print(prompt)
print("Beam Search Response:", ">"*40, "\033[95m")
print(beam_search_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Beam Search Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m
 did a fox or other predator attack it? or did the chicken have an unknown phobia towards roads?
As an AI language model, I cannot determine the reason why the chicken did not cross the road, as it could be a multitude of reasons based on its behavior or the environment.


**Strategy 3: Multinomial Sampling**

Solely maximizing the output probability in deterministic methods can lead to dullness and repetitions. Conversely, stochastic methods try to solve the problem by introducing randomness to the sampling process. 

Multinomial sampling randomly selects the next token based on the probability distribution over the entire vocabulary given by the model. Every token with a non-zero probability has a chance of being selected, thus reducing the risk of repetition. To configure multinomial sampling, we set the *do_sample* hyperparameter to true and set the *num_beams* hyperparameter to 1. 
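
As a small illustration, the sketch below draws tokens from a made-up next-token distribution using Python's standard `random` module. The distribution is hypothetical; the point is only that every token with non-zero probability is eventually selected.

```python
import random


def sample_next_token(distribution, rng):
    """Draw one token in proportion to its probability."""
    tokens = list(distribution)
    weights = [distribution[token] for token in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Hypothetical next-token distribution
toy_distribution = {"road": 0.5, "car": 0.3, "fox": 0.15, "moon": 0.05}

rng = random.Random(0)  # fixed seed so repeated runs give the same draws
counts = {token: 0 for token in toy_distribution}
for _ in range(10_000):
    counts[sample_next_token(toy_distribution, rng)] += 1
print(counts)
```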

In [8]:
multinomial_sampling_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "num_beams": 1
    }
}
multinomial_sampling_response = predictor.predict(multinomial_sampling_payload)

print(prompt)
print("Multinomial Sampling Response:", ">"*40, "\033[95m")
print(multinomial_sampling_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Multinomial Sampling Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m

The reason behind the chicken crossing the road might vary depending on the situation. It could be due to the attraction towards the other side of the road, distraction, or fear of cars.


**Strategy 4: Top-k Sampling**

Top-k sampling sorts the tokens by probability and zeroes out the probabilities of everything below the k-th token, then samples from the remaining k tokens. To configure Top-k sampling, we set the *do_sample* hyperparameter to true and set the *top_k* hyperparameter to a value above 1. 
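
The filtering step can be sketched as follows, using a made-up next-token distribution. The helper is illustrative only and renormalizes the surviving probabilities so they sum to 1 before sampling.

```python
def top_k_filter(distribution, k):
    """Keep the k most probable tokens and renormalize their probabilities."""
    if k < 1:
        raise ValueError("k must be at least 1")
    ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
    kept = ranked[:k]
    total = sum(prob for _, prob in kept)
    return {token: prob / total for token, prob in kept}


# Hypothetical next-token distribution
toy_distribution = {"road": 0.5, "car": 0.3, "fox": 0.15, "moon": 0.05}
print(top_k_filter(toy_distribution, k=2))
```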

In [9]:
top_k_sampling_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_k": 6
    }
}
top_k_sampling_response = predictor.predict(top_k_sampling_payload)

print(prompt)
print("Top-k Search Response:", ">"*40, "\033[95m")
print(top_k_sampling_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Top-k Search Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m

There is actually no definitive answer to why the chicken didn't cross the road. The joke has been told many ways, with different punchlines. Some people say the chicken was too afraid to cross the road, while others say it didn't want to get hit by a car. Still others say it was simply a chicken and didn't feel like crossing the road.


**Strategy 5: Top-p Sampling**

Top-p sampling (or nucleus sampling) chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. To configure Top-p sampling, we set the *do_sample* hyperparameter to true and set the *top_p* hyperparameter to a value less than 1. 
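
A minimal sketch of the nucleus filter, again on a made-up distribution: tokens are added in order of decreasing probability until their cumulative probability reaches p, and the survivors are renormalized. Whether the boundary uses "reaches" or strictly "exceeds" is an implementation detail assumed here.

```python
def top_p_filter(distribution, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p."""
    if not 0 < p <= 1:
        raise ValueError("p must be in the interval (0, 1]")
    ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(prob for _, prob in kept)
    return {token: prob / total for token, prob in kept}


# Hypothetical next-token distribution
toy_distribution = {"road": 0.45, "car": 0.3, "fox": 0.2, "moon": 0.05}
print(top_p_filter(toy_distribution, p=0.8))
```

Unlike top-k, the number of surviving tokens is not fixed: a peaked distribution keeps few tokens, a flat one keeps many.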

In [10]:
nucleus_sampling_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": 0.8
    }
}
nucleus_sampling_response = predictor.predict(nucleus_sampling_payload)

print(prompt)
print("Nucleus search Response:", ">"*40, "\033[95m")
print(nucleus_sampling_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Nucleus search Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m

I'm sorry, I cannot answer that question as it is a well-known joke that cannot be accurately interpreted.


**Strategy 6: Contrastive Search**

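Contrastive search selects from the most probable candidates predicted by the model while taking into account a degeneration penalty computed from the previous context. This decoding strategy tries to maintain semantic coherence in the generated text while reducing repetitions. In the Hugging Face Transformers library, contrastive search is configured by setting the *top_k* hyperparameter to a value above 1 and the *penalty_alpha* hyperparameter to a value between 0 and 1; when *penalty_alpha* is close to zero, contrastive search degenerates to the greedy search method. The payload in this section instead sets *top_k* together with *repetition_penalty*, which is a different mechanism: a *repetition_penalty* of 1 means no penalty, and values below 1 make already-generated tokens more likely, which may explain the repetitive output of this cell.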
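
The selection rule behind contrastive search can be sketched on toy data. Each candidate gets a score equal to its model probability, weighted by (1 - alpha), minus alpha times a degeneration penalty, taken here as the maximum cosine similarity between the candidate's representation and those of the previous tokens. The two-dimensional vectors and probabilities below are invented for illustration, and this sketch does not reproduce what the endpoint computes.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrastive_pick(candidates, context_vectors, alpha):
    """Pick the candidate with the best probability-versus-degeneration trade-off.

    candidates maps each token to (probability, representation vector).
    """
    best_token, best_score = None, -math.inf
    for token, (prob, vector) in candidates.items():
        degeneration = max(cosine(vector, h) for h in context_vectors)
        score = (1 - alpha) * prob - alpha * degeneration
        if score > best_score:
            best_token, best_score = token, score
    return best_token


# Hypothetical representations of the tokens generated so far
context_vectors = [[1.0, 0.0], [0.9, 0.1]]
# Hypothetical top-k candidates: (probability, representation)
candidates = {
    "chicken": (0.6, [1.0, 0.05]),  # most probable, but very similar to the context
    "fox": (0.4, [0.0, 1.0]),       # less probable, but dissimilar to the context
}

print(contrastive_pick(candidates, context_vectors, alpha=0.0))  # greedy choice
print(contrastive_pick(candidates, context_vectors, alpha=0.6))
```

With alpha set to zero the penalty vanishes and the most probable candidate wins, matching greedy search. With a larger alpha the candidate that repeats the context is penalized.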

In [11]:
contrastive_search_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "repetition_penalty": 0.6,
        "top_k": 6
    }
}
contrastive_search_response = predictor.predict(contrastive_search_payload)

print(prompt)
print("Contrastive Search Response:", ">"*40, "\033[95m")
print(contrastive_search_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Contrastive Search Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m
 Was it afraid of the road? Was it afraid of the chicken's chicken? Was it afraid of the chicken's chicken's chicken? Was it afraid of the chicken's chicken's chicken's chicken? Was it afraid of the chicken's chicken's chicken's chicken's chicken? Was it afraid of the chicken's chicken's chicken's chicken's chicken's chicken? Was it afraid of the chicken's chicken's chicken's chicken's chicken's chicken's chicken? Was it afraid of the chicken's


**Strategy 7: Combining Multiple Methods**

Combining multiple methods (such as top-k and top-p sampling) can sometimes improve the diversity and fluency of the generated text. Try configuring your own decoding strategy, including changing the temperature hyperparameter, and see the results. 
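
One possible way to combine the two filters is sketched below: top-k is applied first, the survivors are renormalized, and top-p is then applied to that reduced distribution. The ordering is an assumption made for this example, and the distribution is made up; the endpoint may combine the filters differently.

```python
import random


def filter_top_k_top_p(distribution, top_k, top_p):
    """Apply top-k filtering, then top-p filtering, and renormalize."""
    if top_k < 1:
        raise ValueError("top_k must be at least 1")
    ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
    ranked = ranked[:top_k]
    total = sum(prob for _, prob in ranked)
    kept, cumulative = {}, 0.0
    for token, prob in ranked:
        prob = prob / total  # renormalize over the top-k survivors
        kept[token] = prob
        cumulative += prob
        if cumulative >= top_p:
            break
    total = sum(kept.values())
    return {token: prob / total for token, prob in kept.items()}


# Hypothetical next-token distribution
toy_distribution = {"road": 0.45, "car": 0.3, "fox": 0.2, "moon": 0.05}
filtered = filter_top_k_top_p(toy_distribution, top_k=3, top_p=0.7)
print(filtered)

rng = random.Random(0)  # fixed seed so repeated runs give the same draws
samples = rng.choices(list(filtered), weights=list(filtered.values()), k=5)
print(samples)
```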

In [12]:
multiple_strategies_payload = {
    "inputs": prompt,
    "parameters": {
        "stop": stop_keywords,
        "temperature": 1,
        "max_new_tokens": 120,
        "do_sample": True,
        "top_p": 0.8,
        "top_k": 6
    }
}
multiple_strategies_response = predictor.predict(multiple_strategies_payload)

print(prompt)
print("Multiple Strategies Response:", ">"*40, "\033[95m")
print(multiple_strategies_response[0]["generated_text"])

Why didn't the chicken cross the road? Was it afraid of cars?
Multiple Strategies Response: >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> [95m

The chicken crossed the road to get to the other side. There is no indication that it was afraid of cars.


### SageMaker Clean up 

In [13]:
# Delete the SageMaker endpoint
predictor.delete_endpoint()