# Setup

In [None]:
import json

import sagemaker
import boto3
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.huggingface import get_huggingface_llm_image_uri

from concurrent.futures import ThreadPoolExecutor
import concurrent
from tqdm import tqdm


# Throwaway default session, used only to discover the account's default S3 bucket.
sess = sagemaker.Session()

# Set to a bucket name to override; None falls back to the session's default bucket.
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the IAM role: prefer the execution role; when get_execution_role raises
# ValueError (presumably when not running inside SageMaker), look up a named role via IAM.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = (
        iam.get_role(RoleName='sagemaker_execution_role')
        ['Role']
        ['Arn']
    )

# Recreate the session so its default bucket is the one chosen above.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker session region: {sess.boto_region_name}")

# Deploy

In [None]:
# S3 prefix (under the session bucket) that holds the trained model artifacts.
# NOTE(review): "model_id" looks like a placeholder — set it to the real prefix before running.
model_id = "model_id"

# sagemaker config
instance_type = "ml.g5.12xlarge"
number_of_gpu = 4  # GPUs used per replica; NOTE(review): presumably must match the GPUs on instance_type
health_check_timeout = 300  # seconds (5 minutes) the container gets to load the model

# Environment for the Hugging Face LLM inference container. Environment values are
# strings, so the integers are serialized with json.dumps.
config = {
  'HF_MODEL_ID': "/opt/ml/model", # path to where sagemaker stores the model
  'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
  'MAX_INPUT_LENGTH': json.dumps(2048), # Max length of input text
  'MAX_TOTAL_TOKENS': json.dumps(4096) # Max length of the generation (including input text)
}

# create HuggingFaceModel with the image uri; the uncompressed artifacts are read
# from s3://<bucket>/<model_id>/output/model/
llm_model = HuggingFaceModel(
    role=role,
    model_data={'S3DataSource':{'S3Uri': f's3://{sagemaker_session_bucket}/{model_id}/output/model/','S3DataType': 'S3Prefix','CompressionType': 'None'}},
    image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.2"),
    env=config
)

# Deploy model to an endpoint
# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy
llm = llm_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  container_startup_health_check_timeout=health_check_timeout, # 300 s = 5 minutes to load the model; raise it if loading takes longer
)


# Prompts - Preparation

In [None]:
# Load the test prompts: one JSON object per line (JSON Lines).
# An explicit encoding avoids depending on the platform's default locale encoding.
with open('./data/prompts_test.jsonl', 'r', encoding='utf-8') as file:
    # Skip whitespace-only lines (e.g. an extra empty line at the end of the file)
    # instead of crashing in json.loads.
    data = [json.loads(line) for line in file if line.strip()]

In [None]:
# Prompt template for every request; the "{{ user_msg_1 }}" placeholder is replaced by
# each entry's transcript. NOTE(review): presumably this must match the template used
# at fine-tuning time, so the wording is left exactly as is.
# (Named prompt_template rather than `format` to avoid shadowing the builtin.)
prompt_template = """### Instruction
In the speaker diarization transcript below, some words are potentially misplaced. Please correct those words and move them to the right speaker. Directly show the corrected transcript without explaining what changes were made or why you made those changes.:

{{ user_msg_1 }}

### Answer

"""

# One request payload per test entry. "id" is client-side bookkeeping (the utterance id);
# "inputs"/"parameters" are what gets sent to the endpoint. The parameters literal is
# evaluated per entry, so each payload gets its own dict.
test_inputs = [
    {
        "id": test_entry['utterance_id'],
        "inputs": prompt_template.replace("{{ user_msg_1 }}", test_entry['prompt']),
        "parameters": {"max_new_tokens": 2048, "top_p": 0.5, "temperature": 0.2, "stop": ["</s>", "###"]},
    }
    for test_entry in data
]


# Prompts - Execution

In [None]:

workers = 5
print(f"workers used for load test: {workers}")
responses = {}  # utterance id -> generated text, for every request that succeeded
failed = {}     # utterance id -> last error message, for requests that exhausted their retries
max_retries = 3  # Maximum number of retries for each request (so at most max_retries + 1 attempts)


def submit_task(executor, index, payload, retries=0):
    """Submit one ``llm.predict`` call and register its future for tracking.

    Args:
        executor: The ThreadPoolExecutor running the requests.
        index: Request identifier (the utterance id); the key used in ``responses``.
        payload: Request body passed unchanged to ``llm.predict``.
        retries: How many times this request has already been retried. Callers
            resubmitting a failed request must pass the incremented count, otherwise
            the ``max_retries`` limit is never reached.

    The future is stored in the module-level ``future_to_index`` dict, which must
    exist before this function is called.
    """
    future = executor.submit(llm.predict, payload)
    future_to_index[future] = (index, payload, retries)


with ThreadPoolExecutor(max_workers=workers) as executor, tqdm(total=len(test_inputs)) as pbar:
    future_to_index = {}  # future -> (utterance id, payload, retries so far)

    for test_input in test_inputs:
        # Send only what the endpoint expects; "id" stays client-side for bookkeeping.
        payload = {
            "inputs": test_input['inputs'],
            "parameters": test_input['parameters']
        }
        submit_task(executor, test_input["id"], payload)

    # Retries are submitted while draining, so loop until nothing is outstanding.
    while future_to_index:
        # Snapshot the current futures; retries added below are picked up on the next pass.
        for future in concurrent.futures.as_completed(list(future_to_index)):
            index, payload, retries = future_to_index.pop(future)
            try:
                result = future.result()  # re-raises any exception raised by llm.predict
                # NOTE(review): a response of unexpected shape (IndexError/KeyError) is retried too.
                responses[index] = result[0]["generated_text"]
            except Exception as exc:
                print(f'Task {index} generated an exception: {exc}')
                if retries < max_retries:
                    print(f"Retrying task {index}, attempt {retries + 1}")
                    submit_task(executor, index, payload, retries + 1)
                    continue  # request not finished yet: don't advance the progress bar
                print(f"Task {index} failed after {retries} retries")
                failed[index] = str(exc)
            # The request is finished (succeeded or permanently failed).
            pbar.update(1)

if failed:
    print(f"{len(failed)} task(s) failed permanently: {list(failed)}")



In [None]:
import os

# Persist the generated texts (keyed by utterance id). Create the output directory first,
# so a missing ./results/ does not turn the end of a long inference run into a crash.
os.makedirs('./results', exist_ok=True)
with open('./results/model_predictions.json', 'w', encoding='utf-8') as file:
    json.dump(responses, file)

# Cleanup

In [None]:
# Tear down the deployed resources. A live endpoint keeps its instances running, so its
# deletion sits in `finally` and still happens even if deleting the model raises.
try:
    llm.delete_model()
finally:
    llm.delete_endpoint()