# Qwen3-1.7B EAGLE head

This notebook demonstrates how to:
1. Train an EAGLE3 head for the [Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) model using an Amazon SageMaker AI optimization job
2. Deploy the `Qwen3-1.7B` model with EAGLE3-based speculative decoding


For more information, see this [blog post](https://aws.amazon.com/blogs/machine-learning/amazon-sagemaker-ai-introduces-eagle-based-adaptive-speculative-decoding-to-accelerate-generative-ai-inference/).


In [20]:
%pip install --upgrade --quiet --no-warn-conflicts boto3

Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import time
import re
import json
import boto3
from IPython.display import display, Markdown, clear_output

boto_session = boto3.Session()
region = boto_session.region_name

sm = boto3.client("sagemaker")  # client to interact with SageMaker
sm_runtime = boto3.client("sagemaker-runtime")  # client to interact with SageMaker endpoints
s3 = boto3.client("s3")

In [None]:
#
# Helper functions to remove dependency on SageMaker Python SDK
#
def get_sagemaker_role():
    sts = boto3.client('sts')
    response = sts.get_caller_identity()
    assumed_role = response['Arn']
    role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)
    return role


def wait_for_endpoint(endpoint_name: str, sleep_time: int = 60):
    ind = "."
    progress = f"Waiting for '{endpoint_name}': "
    print(progress)

    status = sm.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]

    while status == "Creating":
        time.sleep(sleep_time)

        status = sm.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]

        clear_output(wait=True)
        progress += ind
        print(progress)

    print(f"Endpoint: '{endpoint_name}', Status: '{status}'")

def wait_for_ic(ic_name: str, sleep_time: int = 60):
    ind = "."
    progress = f"Waiting for '{ic_name}': "
    print(progress)

    status = sm.describe_inference_component(InferenceComponentName = ic_name)["InferenceComponentStatus"]

    while status == "Creating":
        time.sleep(sleep_time)

        status = sm.describe_inference_component(InferenceComponentName = ic_name)["InferenceComponentStatus"]

        clear_output(wait=True)
        progress += ind
        print(progress)

    print(f"IC: '{ic_name}', Status: '{status}'")
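
The regular expression in `get_sagemaker_role` rewrites the STS assumed-role ARN returned by `get_caller_identity` into an IAM role ARN. The sketch below reproduces that substitution on a made-up ARN (the account ID and role name are placeholders, and `assumed_role_to_role_arn` is a hypothetical name) so the transformation can be inspected without calling AWS.

```python
import re

# Same pattern as in get_sagemaker_role(); the ARN below is a made-up example.
ASSUMED_ROLE_PATTERN = r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$"


def assumed_role_to_role_arn(arn: str) -> str:
    """Rewrite an STS assumed-role ARN as the corresponding IAM role ARN."""
    return re.sub(ASSUMED_ROLE_PATTERN, r"\1iam::\2:role/\3", arn)


example = "arn:aws:sts::123456789012:assumed-role/MySageMakerRole/SageMaker"
print(assumed_role_to_role_arn(example))
# arn:aws:iam::123456789012:role/MySageMakerRole
```

An ARN that does not match the pattern (for example, an IAM user ARN) is returned unchanged, so it is worth checking the printed role before using it. An assumed-role ARN also does not carry the role's IAM path, so a role created under a path such as `service-role/` would likely need to be set by hand in the next cell.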

In [None]:
#
# Overwrite with your role ARN if you are running this notebook outside of SageMaker Studio
#
role = None

if role is None:
    role = get_sagemaker_role()
print(role)

## Model and dataset preparation

In [20]:
model_id = "Qwen/Qwen3-1.7B"
bucket = "<YOUR_BUCKET>"
model_s3_key = f"model/{model_id}"
dataset_file = "gsm8k_200.jsonl"
dataset_s3_key = f"training-data/gsm8k200/{dataset_file}"

We download the model (`Qwen3-1.7B`) from the Hugging Face Hub and upload its weights to an Amazon S3 bucket.

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path

local_model_path = Path("./data")
local_model_path.mkdir(exist_ok=True)

snapshot_download(repo_id=model_id, local_dir=local_model_path)

In [None]:
# enumerate local files recursively
for root, dirs, files in os.walk(local_model_path):
    for filename in files:
        local_path = os.path.join(root, filename)

        # S3 keys always use "/" regardless of the local OS path separator
        relative_path = Path(local_path).relative_to(local_model_path).as_posix()
        s3_path = f"{model_s3_key}/{relative_path}"

        print(f"Uploading {s3_path}...")
        s3.upload_file(local_path, bucket, s3_path)

In [11]:
s3.upload_file(dataset_file, bucket, dataset_s3_key)

## Optimization Job

In [21]:
base_model = f"s3://{bucket}/{model_s3_key}/"
training_data = f"s3://{bucket}/training-data/gsm8k200/"
output_data = f"s3://{bucket}/{model_s3_key}-gsm200-eagle/"

In [None]:
job_name = f"opt-{time.strftime('%y%m%d-%H%M%S')}"

opt_job = sm.create_optimization_job(
    OptimizationJobName=job_name,
    RoleArn=role,
    ModelSource={
        'S3': {
            'S3Uri': base_model,
        }
    },
    DeploymentInstanceType='ml.g6.24xlarge',
    MaxInstanceCount=1,
    OptimizationConfigs=[
        {
            'ModelSpeculativeDecodingConfig': {
                'Technique': 'EAGLE',
                'TrainingDataSource': {
                    'S3Uri': training_data,
                    'S3DataType': 'S3Prefix'
                }
            }
        },
    ],
    OutputConfig={
        'S3OutputLocation': output_data,
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 432000,
    },
)

print(json.dumps(opt_job, indent=2))

**PLEASE NOTE THE JOB WILL TAKE ABOUT 2 HOURS TO COMPLETE**
---

In [None]:
job_status = sm.describe_optimization_job(
    OptimizationJobName=job_name
)
job_status
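
The cell above checks the job status once. Because the job runs for about two hours, a polling helper in the style of `wait_for_endpoint` may be convenient. The following is a minimal sketch: `wait_for_optimization_job` and its `max_polls` limit are hypothetical, and it assumes the terminal values of `OptimizationJobStatus` are `COMPLETED`, `FAILED`, and `STOPPED`.

```python
import time

# Assumed terminal states of an optimization job.
TERMINAL_STATES = {"COMPLETED", "FAILED", "STOPPED"}


def wait_for_optimization_job(sm_client, job_name, sleep_time=60, max_polls=600):
    """Poll describe_optimization_job until the job reaches a terminal state."""
    for _ in range(max_polls):
        desc = sm_client.describe_optimization_job(OptimizationJobName=job_name)
        status = desc["OptimizationJobStatus"]
        if status in TERMINAL_STATES:
            if status == "FAILED":
                # FailureReason is only present when the job has failed.
                print(f"Job failed: {desc.get('FailureReason', 'no reason given')}")
            return status
        time.sleep(sleep_time)
    raise TimeoutError(f"Job '{job_name}' did not finish after {max_polls} polls")
```

With the clients defined earlier in this notebook, it could be called as `wait_for_optimization_job(sm, job_name)`.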


**After the optimization job completes, you can access evaluation results and draft model weights in the output S3 path**
---

In [51]:
results = f"{output_data}results/"
draft_model_path = f"{output_data}draft/"
!aws s3 ls $output_data
!echo '------'
!aws s3 ls $draft_model_path
!echo '------'
!aws s3 ls $results

                           PRE draft/
                           PRE opt-260123-164926-48bf6945-5af6-42ea-92e1-61172fe33d42-4-of-4/
                           PRE results/
------
2026-01-23 16:49:37        862 config.json
2026-01-23 18:29:31  286614232 model.safetensors
------
2026-01-23 18:58:03       1380 benchmark_no_eagle_conc1.json
2026-01-23 18:58:03       1375 benchmark_no_eagle_conc16.json
2026-01-23 18:58:03       1376 benchmark_no_eagle_conc2.json
2026-01-23 18:58:03       1374 benchmark_no_eagle_conc4.json
2026-01-23 18:58:03       1376 benchmark_no_eagle_conc8.json
2026-01-23 18:58:03     250508 benchmark_report.html
2026-01-23 18:58:03       1384 benchmark_trained_eagle_conc1.json
2026-01-23 18:58:03       1374 benchmark_trained_eagle_conc16.json
2026-01-23 18:58:03       1382 benchmark_trained_eagle_conc2.json
2026-01-23 18:58:03       1372 benchmark_trained_eagle_conc4.json
2026-01-23 18:58:03       1381 benchmark_trained_eagle_conc8.json
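
The listing above uses the AWS CLI. If the CLI is not available, the same listing can be sketched with `boto3` only. The helper names `split_s3_uri` and `list_s3_objects` are hypothetical; the sketch relies on the `list_objects_v2` paginator of the S3 client.

```python
from urllib.parse import urlparse


def split_s3_uri(uri: str):
    """Split 's3://bucket/some/prefix/' into ('bucket', 'some/prefix/')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


def list_s3_objects(s3_client, uri: str):
    """Return (key, size) pairs for every object under the given S3 prefix."""
    bucket, prefix = split_s3_uri(uri)
    paginator = s3_client.get_paginator("list_objects_v2")
    found = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # Pages for an empty prefix have no "Contents" key.
        for obj in page.get("Contents", []):
            found.append((obj["Key"], obj["Size"]))
    return found
```

For example, `list_s3_objects(s3, results)` would return the benchmark files shown above, using the `s3` client and `results` prefix defined in this notebook.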


**Here is a snippet of the report**
---

![title](results.png)

## Deployment

We use the LMI v18 container; see the [list of available images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers) for more information.

In [None]:
CONTAINER_VERSION = "0.36.0-lmi18.0.0-cu128"
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"

instance = {"type": "ml.g5.2xlarge", "num_gpu": 1}

model_id = base_model
channel_name = "eagle"

model_name = f"model-{time.strftime('%y%m%d-%H%M%S')}"
endpoint_name = model_name
endpoint_config_name = model_name
timeout = 600

variant_name = "main"

spec_config = {
    "method": "eagle3",
    "model": f"/opt/ml/additional-model-data-sources/{channel_name}",
    "draft_tensor_parallel_size": 1,
    "num_speculative_tokens": 5
}

lmi_env = {
    "OPTION_MODEL": "/opt/ml/model",
    "SERVING_FAIL_FAST": "true",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_TENSOR_PARALLEL_DEGREE": json.dumps(instance["num_gpu"]),
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_MAX_MODEL_LEN": "16384",
    "OPTION_SPECULATIVE_CONFIG": json.dumps(spec_config),
}
env = lmi_env

model_data_source = {
    'S3DataSource': {
        'S3Uri': model_id,
        'S3DataType': 'S3Prefix',
        'CompressionType': 'None',
    }
}

add_data_source = {
    'ChannelName': channel_name,
    'S3DataSource': {
        'S3Uri': draft_model_path,
        'S3DataType': 'S3Prefix',
        'CompressionType': 'None',
    }
}
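
In the configuration above, the `model` path inside `spec_config` and the `ChannelName` of the additional data source have to agree: the notebook expects the draft weights to appear under `/opt/ml/additional-model-data-sources/<ChannelName>` inside the container. The sketch below is one way to check that link before creating the model; `draft_channel_is_declared` is a hypothetical helper and the S3 URI is a placeholder.

```python
import json

# Directory under which the notebook expects additional model data sources,
# one subfolder per ChannelName.
MOUNT_ROOT = "/opt/ml/additional-model-data-sources"


def draft_channel_is_declared(env: dict, additional_sources: list) -> bool:
    """Check that the draft model path points at a declared data channel."""
    spec = json.loads(env["OPTION_SPECULATIVE_CONFIG"])
    declared = {f"{MOUNT_ROOT}/{src['ChannelName']}" for src in additional_sources}
    return spec["model"] in declared


# Illustrative values mirroring the cell above; the S3 URI is a placeholder.
channel_name = "eagle"
example_env = {
    "OPTION_SPECULATIVE_CONFIG": json.dumps({
        "method": "eagle3",
        "model": f"{MOUNT_ROOT}/{channel_name}",
        "draft_tensor_parallel_size": 1,
        "num_speculative_tokens": 5,
    })
}
example_sources = [{
    "ChannelName": channel_name,
    "S3DataSource": {"S3Uri": "s3://<YOUR_BUCKET>/draft/", "S3DataType": "S3Prefix", "CompressionType": "None"},
}]
print(draft_channel_is_declared(example_env, example_sources))  # True
```

Note that every environment value is a string, which is why the notebook serializes `spec_config` with `json.dumps`.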

In [None]:
model_res = sm.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": inference_image,
        "Environment": env,
        "ModelDataSource": model_data_source,
        "AdditionalModelDataSources": [add_data_source],
    },
)
print(json.dumps(model_res, indent=2))

In [None]:
config_res = sm.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants = [
        {
            "VariantName": variant_name,
            "ModelName": model_name,
            "InstanceType": instance["type"],
            "InitialInstanceCount": 1,
            "ContainerStartupHealthCheckTimeoutInSeconds": timeout,
        },
    ],
)

endpoint_res = sm.create_endpoint(EndpointName = endpoint_name,
                                  EndpointConfigName = endpoint_config_name)

_ = wait_for_endpoint(endpoint_name)

Waiting for 'model-260123-202533': ..........
Endpoint: 'model-260123-202533', Status: 'InService'


### Test inference

In [None]:
payload={
    "messages": [
        {"role": "user", "content": "Name popular places to visit in London?"}
    ],
}

start_time = time.time()
res = sm_runtime.invoke_endpoint(EndpointName = endpoint_name,
                                 Body = json.dumps(payload),
                                 ContentType = "application/json")
response = json.loads(res["Body"].read().decode("utf8"))
end_time = time.time()

print(f"✅ Response time: {end_time-start_time:.2f}s\n")
display(Markdown(response["choices"][0]["message"]["content"]))

usage = response["usage"]
print(f'-----------------------\n{usage}')

✅ Response time: 32.00s



<think>
Okay, the user is asking for popular places to visit in London. Let me start by recalling the main attractions. London is a big city with a lot of iconic spots. First, the Tower of London comes to mind. It's a historic site with the Crown Jewels and the Tower Museum. Then there's the British Museum, which is a must-see for its vast collection of art and artifacts.

I should also mention the Westminster Abbey, which is a famous religious site and a symbol of the UK's monarchy. The London Eye is another big one, offering a panoramic view of the city. The British Museum is definitely a top spot, but maybe the National Gallery is also popular. Oh, and the Houses of Parliament and Big Ben are iconic landmarks.

Wait, the user might be interested in both historical and modern attractions. So I need to balance between old and new. The Shard is a modern skyscraper, and the London Eye is a modern landmark. Also, the Thames River is a big part of London's identity, so mentioning places along the river like the Tate Modern or the London Bridge could be useful.

I should also think about tourist attractions that are family-friendly or have unique experiences. The London Zoo is a good one, but maybe the Science Museum is another option. Oh, and the Tower of London is a bit of a must-visit, but sometimes people might not know about the nearby places like the National Gallery. Let me check if there are any other major attractions I'm missing. The Globe Theatre, the Natural History Museum, and the Victoria and Albert Museum are also popular. 

I need to make sure the list is comprehensive but not too long. Maybe group them into categories like historical sites, museums, modern landmarks, and natural attractions. Also, include some popular areas like the West End for theaters and shopping. Oh, and the City of London and the East End have different attractions. The East End has places like the South Bank and the Barbican. 

Wait, the user might not be familiar with the East End, so maybe mention that as a separate section. Also, the British Museum is in the South Bank, so that's a good point. I should also note that some places are closed during certain times, like the National Gallery during exhibitions, so maybe mention that. 

I should structure the answer in a way that's easy to follow, maybe with headings for each category. Make sure the names are correct and the descriptions are accurate. Avoid any outdated information. Check if the London Eye is considered a popular place, and if the Tower of London is still a top attraction. Also, mention the Royal Albert Hall for music events. 

Okay, putting it all together, the answer should list the main attractions with brief descriptions, maybe a few bullet points, and ensure it's clear and helpful for the user.
</think>

London is a vibrant city with a rich history, stunning architecture, and a wealth of cultural attractions. Here are some of the most popular places to visit:

### **Historical & Cultural Sites**  
1. **Tower of London**  
   - A medieval fortress with the Crown Jewels, the Crown Jewels Museum, and the Royal Mint.  
   - Explore its history, including the Black Death and the Tower's role in the English monarchy.  

2. **British Museum**  
   - One of the world’s largest and most comprehensive art museums, housing artifacts from around the globe.  
   - Must-see: The Rosetta Stone, the Egyptian collection, and the Hall of Bulls.  

3. **Westminster Abbey**  
   - A Gothic masterpiece, home to the Crown Jewels and the Abbey’s famous choir.  
   - A symbol of the UK’s monarchy and a site for weddings and royal events.  

4. **The Houses of Parliament**  
   - The iconic **Big Ben** (the clock tower) and the **Palace of Westminster**.  
   - Visit the National Gallery and the British Library nearby.  

### **Modern & Iconic Landmarks**  
5. **London Eye**  
   - A 135-meter observation wheel offering panoramic views of the city.  
   - A popular spot for photos and night views.  

6. **The Shard**  
   - A 310-meter skyscraper with a rooftop restaurant and a 360° view of London.  

7. **The London Eye**  
   - A 135-meter observation wheel offering panoramic views of the city.  

8. **The National Gallery**  
   - A world-renowned art museum with masterpieces by Van Gogh, Rembrandt, and Leonardo da Vinci.  

### **Natural & Leisure Attractions**  
9. **Tate Modern**  
   - A contemporary art museum located on the banks of the Thames, featuring works by contemporary artists.  

10. **The Thames River**  
    - Explore the River Thames, with landmarks like **London Bridge**, **Tower Bridge**, and **The South Bank**.  

11. **The London Eye**  
    - A 135-meter observation wheel offering panoramic views of the city.  

12. **The London Zoo**  
    - A wildlife sanctuary with a variety of animals and a scenic garden.  

### **Other Must-Visit Areas**  
- **The East End** (e.g., **South Bank**, **Barbican**, **The Royal Academy**)  
- **The Science Museum** (for interactive exhibits)  
- **The Royal Albert Hall** (for concerts and music events)  
- **The National Archives** (for historical documents)  

### Tips:  
- Check opening hours and ticket availability for major attractions.  
- Use apps like **Google Maps** or **Walking Tours** for guided itineraries.  

London’s mix of history, art, and modernity makes it a must-visit for travelers! 🌟

-----------------------
{'prompt_tokens': 16, 'total_tokens': 1228, 'completion_tokens': 1212, 'prompt_tokens_details': None}
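
The usage block and the response time printed above are enough for a rough end-to-end throughput estimate. The calculation below uses the numbers from this particular run. The wall-clock time also covers the network round trip and prompt processing, so treat the result as a lower bound on decoding speed rather than a benchmark.

```python
# Values copied from the output above.
usage = {"prompt_tokens": 16, "total_tokens": 1228, "completion_tokens": 1212}
response_time_s = 32.00

# Generated tokens divided by end-to-end wall-clock time.
tokens_per_second = usage["completion_tokens"] / response_time_s
print(f"{tokens_per_second:.1f} tokens/s")  # 37.9 tokens/s
```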


In [None]:
import io
import json
import time
import boto3
from IPython.display import clear_output

class LineIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])

def stream_response(endpoint_name, inputs, max_tokens=8189, temperature=0.7, top_p=0.9):
    body = {
      "messages": [
        {"role": "user", "content": [{"type": "text", "text": inputs}]}
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }

    resp = sm_runtime.invoke_endpoint_with_response_stream(
        EndpointName = endpoint_name,
        Body = json.dumps(body),
        ContentType = "application/json",
    )

    event_stream = resp["Body"]
    start_json = b"{"
    full_response = ""
    start_time = time.time()
    token_count = 0

    for line in LineIterator(event_stream):
        if line != b"" and start_json in line:
            data = json.loads(line[line.find(start_json):].decode("utf-8"))
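            # Some chunks carry no text (e.g. role-only or null content)
            token_text = data['choices'][0]['delta'].get('content') or ''
            full_response += token_text
            # Counts streamed chunks; a chunk may hold more than one token
            token_count += 1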

            # Calculate tokens per second
            elapsed_time = time.time() - start_time
            tps = token_count / elapsed_time if elapsed_time > 0 else 0

            # Clear the output and reprint everything
            clear_output(wait=True)
            print(full_response)
            print(f"\nTokens per Second: {tps:.2f}", end="")

    print("\n") # Add a newline after response is complete

    return full_response
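
The loop in `stream_response` parses each streamed line by locating the first `{` and decoding the rest as JSON. The sketch below isolates that step so it can be tried offline. It assumes the endpoint emits OpenAI-style server-sent events (`data: {...}` lines, ending with `data: [DONE]`); that format is an assumption, and `extract_delta_text` is a hypothetical helper.

```python
import json


def extract_delta_text(line: bytes) -> str:
    """Return the text carried by one streamed chat-completion line, or ''."""
    text = line.decode("utf-8").strip()
    # Drop the server-sent-events prefix if it is present.
    if text.startswith("data:"):
        text = text[len("data:"):].strip()
    # Blank keep-alive lines and the end-of-stream marker carry no text.
    if not text or text == "[DONE]":
        return ""
    event = json.loads(text)
    choices = event.get("choices") or []
    if not choices:
        return ""
    # "content" may be missing or null, e.g. in a role-only first chunk.
    return choices[0].get("delta", {}).get("content") or ""


sample_lines = [
    b'data: {"choices": [{"delta": {"role": "assistant", "content": ""}}]}',
    b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    b"",
    b'data: {"choices": [{"delta": {"content": " world"}}]}',
    b"data: [DONE]",
]
print("".join(extract_delta_text(line) for line in sample_lines))  # Hello world
```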

In [62]:
inputs = "What is greater 9.11 or 9.8?"
output = stream_response(endpoint_name, inputs, max_tokens=8000)

<think>
Okay, so I need to figure out whether 9.11 is greater than 9.8 or not. Let me think. Both numbers are decimals, right? So they have whole numbers and decimal parts. Let me write them down again to visualize better: 9.11 and 9.8. 

First, I remember that when comparing decimals, you start by looking at the digits from left to right. So, both numbers have the same whole number part, which is 9. That's easy. Now, the next part is the decimal part. 

For 9.11, the decimal part is 0.11, and for 9.8, it's 0.8. So, comparing the decimal parts. Let's break it down. The first decimal place is tenths. 9.11 has 1 in the tenths place, and 9.8 has 8 in the tenths place. Since 1 is less than 8, that would mean that 0.11 is less than 0.8. Therefore, 9.11 is less than 9.8. 

Wait, but maybe I should check if there's any other way to compare them. For example, converting them into fractions. Let me try that. 9.11 is the same as 9 + 0.11, and 9.8 is 9 + 0.8. So, subtracting the whole numbers, th

## Cleanup

In [63]:
_ = sm.delete_endpoint(EndpointName=endpoint_name)
_ = sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
_ = sm.delete_model(ModelName=model_name)