In [1]:
!pip install transformers



In [2]:
!pip install -U sagemaker



In [3]:
!pip install s3fs



In [4]:
!pip install loguru



In [5]:
import time
import json
import os
import sys
import sagemaker
import boto3
import s3fs
from sagemaker.huggingface import HuggingFaceModel

import pandas as pd
from datasets import Dataset

# Put the directory two levels above the current working directory at the front
# of sys.path so that the `src` package imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
from src.utils.data_generation import nested_split_dataset, generate_responses_concurrently_deployed
from src.prompts.llama_prompts import MathQAPrompt, ContextualQAPrompt

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


In [6]:
# Number of processes to use for data generation.
PROC_NUM = 1

# Which split to process: "train" or "test".
DATA_SPLIT = "train"

# Model identifier recorded in the metadata of every generated response.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

### Get execution role

In [7]:
# Resolve the IAM role ARN handed to the Hugging Face model below.
# First ask the SageMaker SDK for the current execution role; if that raises
# ValueError, fall back to looking the role up by name through IAM.
# Do not hard-code a role ARN here: it embeds the AWS account id in the notebook.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']


### Model setup

In [8]:
# Environment variables passed to the model at creation time.
env = {"LOGLEVEL": "INFO"}

# Hugging Face model definition: the packaged model artifact on S3 plus the
# framework versions (presumably used by the SDK to pick the serving container).
huggingface_model = HuggingFaceModel(
    role=role,
    model_data="s3://self-corrective-llm-data/initial_model/model.tar.gz",
    transformers_version="4.49",
    pytorch_version="2.6",
    py_version="py312",
    env=env,
)


In [9]:
# Deploy the model to SageMaker Inference on a single GPU instance.
# NOTE: `model` holds the object returned by deploy(); later cells pass it to
# the generation helper and call delete_model()/delete_endpoint() on it.
model = huggingface_model.deploy(
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    container_startup_health_check_timeout=300,
)

--

### SQUAD

In [None]:
# Load the raw SQuAD split from S3 and partition it for generation.
dataset_name = "rajpurkar_squad"
num_major_chunks = 8
s3 = s3fs.S3FileSystem()
path = f"s3://self-corrective-llm-data/dataset/raw_data/{dataset_name}/{DATA_SPLIT}.parquet"

try:
    print(f"Loading Parquet file from: {path}")
    with s3.open(path, 'rb') as f:
        df = pd.read_parquet(f)
except Exception as e:
    print(f"Failed to read Parquet file. Error: {e}")
    # Re-raise: continuing would hit a NameError on `df`, or silently reuse a
    # stale `df` left in the kernel by a previously executed cell.
    raise
else:
    print("File loaded successfully into pandas DataFrame.")

data = Dataset.from_pandas(df)
# Optional: uncomment to work on a small random sample while testing.
# data = data.shuffle(seed=42).select(range(100))

# Split into `num_major_chunks` groups, each divided into PROC_NUM minor chunks.
nested_data_split = nested_split_dataset(data, num_major_chunks=num_major_chunks, num_minor_chunks=PROC_NUM)


In [None]:
# Static metadata for the SQuAD run; passed to the generation helper as
# `response_dict_format`.
response_dict = {
    "task_info": {"type": "Contextual QA", "dataset": dataset_name},
    "additional_info": {"model": MODEL_NAME},
}

In [None]:
def process_data_chunk(data_chunk: dict) -> tuple[list[dict], list[dict]]:
    """Convert a column-oriented SQuAD chunk into per-example records.

    Args:
        data_chunk: Mapping with parallel sequences under the keys
            "question", "context", "answers" and "title". Each entry of
            "answers" must itself be indexable by "text".

    Returns:
        A pair ``(model_input, additional_info)`` with one dict per question:
        ``model_input`` holds the "query" and "context" for prompting, and
        ``additional_info`` holds the question, context, answer text and title.
    """
    model_input: list[dict] = []
    additional_info: list[dict] = []
    for row, question in enumerate(data_chunk["question"]):
        context = data_chunk["context"][row]
        model_input.append({"query": question, "context": context})
        additional_info.append(
            {
                "question": question,
                "context": context,
                "answer": data_chunk["answers"][row]["text"],
                "title": data_chunk["title"][row],
            }
        )
    return model_input, additional_info

In [None]:
final_results = []
start_time = time.time()
for i, data_chunks in enumerate(nested_data_split):
        print(f"Processing chunk {i+1} of {len(nested_data_split)}")
        all_results = await generate_responses_concurrently_deployed(
            model=model,
            prompt_class=ContextualQAPrompt,
            data_chunks=data_chunks,
            response_dict_format=response_dict,
            data_processing_function=process_data_chunk,
            prompt_repetitions=10,
        )
        final_results.extend(all_results)
        
end_time = time.time()
print(f"Time taken: {end_time - start_time}")

In [None]:
# Display the first two collected records as a quick sanity check.
final_results[:2]

In [None]:
# Write the collected SQuAD responses to S3 as one pretty-printed JSON document.
s3 = s3fs.S3FileSystem()
output_path = f"s3://self-corrective-llm-data/dataset/raw_model_responses/{DATA_SPLIT}/{DATA_SPLIT}_{dataset_name}.json"

# Serialise fully in memory before opening the remote file for writing.
json_string = json.dumps(final_results, indent=4)

print(f"Saving file to: {output_path}")
with s3.open(output_path, 'w') as out_file:
    out_file.write(json_string)

print("File saved successfully to S3!")

### UMWP

In [None]:
# Load the raw UMWP split from S3 and partition it for generation.
dataset_name = "UMWP"
num_major_chunks = 8
s3 = s3fs.S3FileSystem()
path = f"s3://self-corrective-llm-data/dataset/raw_data/{dataset_name}/{DATA_SPLIT}.json"

# Try JSON Lines first; on any failure, retry the same file as a regular JSON
# document (an error from the retry propagates).
try:
    with s3.open(path, 'r') as jsonl_file:
        df = pd.read_json(jsonl_file, lines=True)
except Exception as e:
    print(f"Failed to read as JSONL, trying as regular JSON. Error: {e}")
    with s3.open(path, 'r') as json_file:
        df = pd.read_json(json_file)

data = Dataset.from_pandas(df)
# Optional: uncomment to work on a small random sample while testing.
# data = data.shuffle(seed=42).select(range(100))

# Split into `num_major_chunks` groups, each divided into PROC_NUM minor chunks.
nested_data_split = nested_split_dataset(
    data,
    num_major_chunks=num_major_chunks,
    num_minor_chunks=PROC_NUM,
)

In [None]:
# Static metadata for the UMWP run; passed to the generation helper as
# `response_dict_format`.
response_dict = {
    "task_info": {"type": "QA", "dataset": dataset_name},
    "additional_info": {"model": MODEL_NAME, "domain": "Math"},
}

In [None]:
def process_data_chunk(data_chunk: dict) -> tuple[list[dict], list[dict]]:
    """Convert a column-oriented UMWP chunk into per-example records.

    NOTE: this rebinds the ``process_data_chunk`` name defined earlier in the
    notebook for SQuAD, so the cells must be executed in order.

    Args:
        data_chunk: Mapping with parallel sequences under the keys
            "question", "answer", "answerable" and "source".

    Returns:
        A pair ``(model_input, additional_info)`` with one dict per question:
        ``model_input`` holds only the "query" for prompting, and
        ``additional_info`` holds the question, answer, answerable flag and
        source.
    """
    model_input: list[dict] = []
    additional_info: list[dict] = []
    for row, question in enumerate(data_chunk["question"]):
        model_input.append({"query": question})
        additional_info.append(
            {
                "question": question,
                "answer": data_chunk["answer"][row],
                "answerable": data_chunk["answerable"][row],
                "source": data_chunk["source"][row],
            }
        )
    return model_input, additional_info

In [None]:
final_results = []
start_time = time.time()
for i, data_chunks in enumerate(nested_data_split):
        print(f"Processing chunk {i+1} of {len(nested_data_split)}")
        all_results = await generate_responses_concurrently_deployed(
            model=model,
            prompt_class=MathQAPrompt,
            data_chunks=data_chunks,
            response_dict_format=response_dict,
            data_processing_function=process_data_chunk,
            prompt_repetitions=10,
        )

        final_results.extend(all_results)

end_time = time.time()
print(f"Time taken: {end_time - start_time}")

In [None]:
# Display the first two collected records as a quick sanity check.
final_results[:2]

In [None]:
# Write the collected UMWP responses to S3 as one pretty-printed JSON document.
s3 = s3fs.S3FileSystem()
output_path = f"s3://self-corrective-llm-data/dataset/raw_model_responses/{DATA_SPLIT}/{DATA_SPLIT}_{dataset_name}.json"

# Serialise fully in memory before opening the remote file for writing.
json_string = json.dumps(final_results, indent=4)

print(f"Saving file to: {output_path}")
with s3.open(output_path, 'w') as out_file:
    out_file.write(json_string)

print("File saved successfully to S3!")

In [None]:
# Tear down the deployed resources once generation is finished.
# NOTE(review): keep this order — delete_model() presumably needs the endpoint
# configuration that delete_endpoint() removes; confirm against the SageMaker SDK.
model.delete_model()
model.delete_endpoint()

In [None]:
# prompt_1 = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a specialized question-answering AI. Your task is to give a concise answer to the question using *only* the provided context.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nContext:\n'''\nThe Pew Forum on Religion & Public Life ranks Egypt as the fifth worst country in the world for religious freedom. The United States Commission on International Religious Freedom, a bipartisan independent agency of the US government, has placed Egypt on its watch list of countries that require close monitoring due to the nature and extent of violations of religious freedom engaged in or tolerated by the government. According to a 2010 Pew Global Attitudes survey, 84% of Egyptians polled supported the death penalty for those who leave Islam; 77% supported whippings and cutting off of hands for theft and robbery; and 82% support stoning a person who commits adultery.\n'''\n\nQuestion: What percentage of Egyptians polled support death penalty for those leaving Islam?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

In [None]:
# # send request
# response = predictor.predict({"inputs": [prompt_1]*10, "parameters": {"temperature": 0.7, "max_new_tokens": 256}})

# # print(response["responses"])
# for response in response["responses"]:
#     print(response)