## Utilizing Llama 3.2 90B for data extraction on images for distillation to fine-tune Llama 3.2 11B (Build-your-own OCR Solution)

In this notebook, we walk you through using the vision and text extraction capabilities of a Llama 3.2 90B model to create an instruction fine-tuning dataset. We then use that dataset to perform distillation by fine-tuning a Llama 3.2 11B model.

Llama 3.2 90B's vision capabilities let us extract text directly from images, which simplifies building a robust OCR solution. Vision models offer several significant advantages for structured documents:

#### **Structured Information Extraction**

Unlike traditional OCR, which captures text as a flat string, vision models can identify specific fields such as "Employer Name" or "Wages" and export them directly in a labeled format (e.g., JSON). This avoids the additional processing steps needed to interpret unstructured OCR text.

#### **Enhanced Field-Level Accuracy**

Vision models can be fine-tuned to recognize specific fields with high precision. For instance, a fine-tuned vision model can reliably distinguish between two fields within a single bounding box, whereas traditional OCR might capture them as generic text that requires additional parsing.

#### **Improved Handling of Variability**

Vision models are generally better at handling document variability (e.g., different layouts, fonts, or minor distortions), with less loss of accuracy than traditional OCR.

#### **Semantic Understanding / Contextual Clustering**

Vision models trained to extract and label fields can group information meaningfully, making it easier to convert it into structured datasets for analysis or fine-tuning text-based models. Traditional OCR lacks this semantic understanding and would require post-processing pipelines to achieve similar structured output.

### In this notebook, we perform the following high-level steps:

1. We use a `Llama3.2-90b instruct` model on Amazon Bedrock to extract data from a dataset of W2 images in JSON format, which we then use to perform distillation by fine-tuning a Llama 3.2 11B model.

1. Fine-tune a `Llama3.2-11b instruct` model to improve its JSON extraction capabilities.

In [1]:
!pip install boto3==1.35.18 datasets pillow



In [2]:
import boto3
import json
from PIL import Image 
import re

In [3]:
boto3.__version__

'1.35.18'

## Comparing Llama 3.2 models

In the next few cells, we look at both Llama 3.2 vision models on Amazon Bedrock and compare their data extraction output. Before extracting any data, let's compare the pricing of Llama 3.2 11B and 90B, the newest Meta models for vision use cases such as extracting data from images.

| Meta Models              | Price per 1,000 input tokens | Price per 1,000 output tokens |
|--------------------------|------------------------------|-------------------------------|
| Llama 3.2 Instruct (11B) | 0.00035 USD                  | 0.00035 USD                   |
| Llama 3.2 Instruct (90B) | 0.002 USD                    | 0.002 USD                     |

The 90B model is significantly more expensive per input and output token: 1 million tokens (input or output) cost $2 with 90B, compared with roughly $0.35 (35 cents) with 11B, so 90B costs about 5.7 times as much. Now let's compare the outputs of each model.
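To make the per-million arithmetic explicit, here is a small sketch that scales the table's per-1,000-token prices and estimates the cost of a batch. The prices are copied from the table above; the helper names (`price_per_million`, `estimate_cost`) and the per-image token counts in the last loop are hypothetical placeholders, not measurements from this notebook.

```python
# Per-1,000-token prices copied from the pricing table (USD).
# Input and output tokens are priced the same for both models.
PRICE_PER_1K_TOKENS = {"11B": 0.00035, "90B": 0.002}


def price_per_million(price_per_1k: float) -> float:
    """Scale a per-1,000-token price to a per-1,000,000-token price."""
    return price_per_1k * 1_000


def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate USD cost for one request, assuming identical input and output prices."""
    return (input_tokens + output_tokens) / 1_000 * PRICE_PER_1K_TOKENS[model]


for model, price in PRICE_PER_1K_TOKENS.items():
    print(f"{model}: ${price_per_million(price):.2f} per 1M tokens")

ratio = PRICE_PER_1K_TOKENS["90B"] / PRICE_PER_1K_TOKENS["11B"]
print(f"90B costs about {ratio:.1f}x as much as 11B per token")

# Hypothetical workload: 100 images at 2,000 input + 500 output tokens each (not measured).
for model in PRICE_PER_1K_TOKENS:
    print(f"{model}: ~${100 * estimate_cost(model, 2_000, 500):.4f} for 100 images")
```

Because both prices scale linearly with token count, the roughly 5.7x ratio holds for any workload size; only the absolute dollar amounts change.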

## Data Extraction step: Image processing for 11b

First, we configure our Amazon Bedrock runtime client so that we can call a Llama 3.2 11B Instruct model.

We will use this model to process W2 images from the [fake-w2-us-tax-form-dataset](https://huggingface.co/datasets/singhsays/fake-w2-us-tax-form-dataset) on Hugging Face.

In [4]:
bedrock_client = boto3.client("bedrock-runtime", region_name="us-west-2")
MODEL_ID = "us.meta.llama3-2-11b-instruct-v1:0"

In [28]:
def process_w2(image_bytes, image_format):
    image_media_type = image_format.lower()

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "text": """
Analyze the attached W2 form, extracting all fields and bounding boxes, and return the data as a JSON object. 
Focus on capturing each field as labeled on the form, and be especially precise with multi-state information. 
For each field, ensure the following:

1. **Employee Information**: Extract 'Employee Name,' 'Employee Address,' 'Social Security Number,' etc.
2. **Employer Information**: Include 'Employer Name,' 'Employer EIN,' 'Employer Address,' and 'Zip Code.'
3. **Earnings and Tax Information**: Extract 'Wages,' 'Social Security Wages,' 'Medicare Wages and Tips,' 'Federal Income Tax Withheld,' 'State Income Tax,' 'Local Wages / Tips,' 'Local Income Tax,' etc.
4. **Benefits and Other Deductions**: Include fields like 'Dependent Care Benefits' and 'Nonqualified Plans.'
5. **Multi-state Employment Information**: Identify all states listed on the W2, capturing information for each:
   - Ensure each state's data is complete and correct, including 'Local Wages / Tips,' 'Local Income Tax,' and 'Locality Name.'
   - If multiple states are listed, each state should appear as a separate key in the JSON output (e.g., "NC", "UT").
   - Verify and cross-reference state abbreviations to ensure accuracy.

The JSON output should precisely reflect all information, especially multiple states, with each state’s information grouped under its corresponding abbreviation.
"""

                },
                {
                    "image": {
                        "format": image_media_type,
                        "source": {
                            "bytes": image_bytes
                        }
                    }
                }
            ]
        }
    ]
    return messages
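`process_w2` passes `image_format.lower()` straight into the Converse image block, which accepts only `png`, `jpeg`, `gif`, and `webp`. PIL's `Image.format` can be `None` (for images created in memory) or a format outside that set, such as `TIFF`. The sketch below uses a hypothetical helper, `plan_image_encoding`, to pick a safe encoding before calling the model. Falling back to PNG is an assumption on our part (it is lossless, which suits scanned forms), not a Bedrock requirement.

```python
# Image formats accepted by the Bedrock Converse API image block.
CONVERSE_IMAGE_FORMATS = {"png", "jpeg", "gif", "webp"}
# PIL reports JPEG files as "JPEG"; "jpg" is mapped defensively in case the
# name comes from a file extension instead of PIL.
FORMAT_ALIASES = {"jpg": "jpeg"}


def plan_image_encoding(pil_format):
    """Return (pil_save_format, converse_format) for an image whose PIL format is pil_format.

    Supported formats pass through unchanged; anything else (e.g. "TIFF", or
    None for in-memory images) is re-encoded as PNG.
    """
    fmt = (pil_format or "").lower()
    fmt = FORMAT_ALIASES.get(fmt, fmt)
    if fmt in CONVERSE_IMAGE_FORMATS:
        return fmt.upper(), fmt
    return "PNG", "png"
```

In the extraction loop, this would replace `image_format = image.format` with `save_as, image_format = plan_image_encoding(image.format)`, followed by `image.save(buffer, format=save_as)`.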

In [29]:
def _to_number(value, cast=float):
    """Convert a model-reported amount such as "$1,234.56" (or 1234.56) to a number.

    Dollar signs and thousands separators are stripped before casting, so a
    misplaced separator like "$1,2982.74" still parses; missing or empty values become 0.
    """
    if value is None or value == "":
        return cast(0)
    if isinstance(value, (int, float)):
        return cast(value)
    cleaned = str(value).replace("$", "").replace(",", "").strip()
    return cast(float(cleaned)) if cleaned else cast(0)


def transform_to_expected_format(w2_data):
    """Map the model's section/field names onto the schema used for the fine-tuning dataset."""
    # Transform Employee Information
    employee_info = w2_data.get("Employee Information", w2_data.get("employee", {}))
    employee = {
        "name": employee_info.get("Employee Name", employee_info.get("name", "")),
        "address": ", ".join([
            employee_info.get("Employee Address", employee_info.get("address", "")),
            employee_info.get("City", ""),
            employee_info.get("State", ""),
            employee_info.get("Zip Code", "")
        ]).strip(", "),
        "socialSecurityNumber": employee_info.get("Social Security Number", "[Redacted]"),
        "employeeIdNumber": "[Redacted]"
    }

    # Transform Employer Information
    employer_info = w2_data.get("Employer Information", w2_data.get("employer", {}))
    employer = {
        "name": employer_info.get("Employer Name", employer_info.get("name", "")),
        "ein": employer_info.get("Employer EIN", ""),
        "address": ", ".join([
            employer_info.get("Employer Address", employer_info.get("address", "")),
            employer_info.get("City", ""),
            employer_info.get("State", ""),
            employer_info.get("Zip Code", "")
        ]).strip(", ")
    }

    # Transform Earnings and Tax Information (amounts may arrive as "$141,194.15")
    earnings_info = w2_data.get("Earnings and Tax Information", w2_data.get("earnings", {}))
    earnings = {
        "wages": _to_number(earnings_info.get("Wages", earnings_info.get("wages", "0"))),
        "socialSecurityWages": _to_number(earnings_info.get("Social Security Wages", earnings_info.get("socialSecurityWages", "0"))),
        "medicareWagesAndTips": _to_number(earnings_info.get("Medicare Wages and Tips", earnings_info.get("medicareWagesAndTips", "0"))),
        "federalIncomeTaxWithheld": _to_number(earnings_info.get("Federal Income Tax Withheld", earnings_info.get("federalIncomeTaxWithheld", "0"))),
        "stateIncomeTax": _to_number(earnings_info.get("State Income Tax", earnings_info.get("stateIncomeTax", "0"))),
        "localWagesTips": _to_number(earnings_info.get("Local Wages / Tips", earnings_info.get("localWagesTips", "0"))),
        "localIncomeTax": _to_number(earnings_info.get("Local Income Tax", earnings_info.get("localIncomeTax", "0")))
    }

    # Transform Benefits and Other Deductions (whole-dollar amounts such as "$220")
    benefits_info = w2_data.get("Benefits and Other Deductions", w2_data.get("benefits", {}))
    benefits = {
        "dependentCareBenefits": _to_number(benefits_info.get("Dependent Care Benefits", benefits_info.get("dependentCareBenefits", "0")), int),
        "nonqualifiedPlans": _to_number(benefits_info.get("Nonqualified Plans", benefits_info.get("nonqualifiedPlans", "0")), int)
    }

    # Transform Multi-state Employment Information (one entry per state abbreviation)
    multi_state_info = w2_data.get("Multi-state Employment Information", w2_data.get("multiStateEmployment", {}))
    multiStateEmployment = {
        state: {
            "localWagesTips": _to_number(info.get("Local Wages / Tips", info.get("localWagesTips", "0"))),
            "localIncomeTax": _to_number(info.get("Local Income Tax", info.get("localIncomeTax", "0"))),
            "localityName": info.get("Locality Name", info.get("localityName", ""))
        } for state, info in multi_state_info.items()
    }

    # Construct the transformed data structure
    return {
        "employee": employee,
        "employer": employer,
        "earnings": earnings,
        "benefits": benefits,
        "multiStateEmployment": multiStateEmployment
    }


In [10]:
## 11b cell

from datasets import load_dataset
from io import BytesIO
import os
import json
import boto3

# Initialize S3 client
s3 = boto3.client('s3')
bucket_name = 'genai-accelerate-2024'  # Set your S3 bucket name
subdirectory = 'llama-3-2-vision-dataset'  # Set the subdirectory within the bucket

# Load the dataset
dataset = load_dataset("singhsays/fake-w2-us-tax-form-dataset", split="train", download_mode="force_redownload")
w2_dataset = []
print(f"Bucket name: {bucket_name}")
print(f"Subdirectory: {subdirectory}")

# Process the first 5 images
for i, item in enumerate(dataset):
    if i >= 5:
        break

    try:
        # Image processing steps
        image = item['image']
        image_format = image.format
        with BytesIO() as buffer:
            image.save(buffer, format=image_format)
            image_bytes = buffer.getvalue()

        # Define a unique file name for each image within the loop
        image_file_name = f"images/img_{i}.{image_format.lower()}"
        image_path = os.path.join("/home/ec2-user/SageMaker/vision-workshop/dataset", image_file_name)
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        with open(image_path, "wb") as img_file:
            img_file.write(image_bytes)

        # Upload image to S3
        s3.put_object(Bucket=bucket_name, Key=f"{subdirectory}/{image_file_name}", Body=image_bytes)

        # Model inference
        response = bedrock_client.converse(
            modelId=MODEL_ID,
            messages=process_w2(image_bytes, image_format),
            inferenceConfig={"maxTokens": 2048, "temperature": 0.0, "topP": 0.1},
        )

        # Extract response content
        response_content = response['output']['message']['content'][0]['text']
        print(f"Raw response content for image {i}: {response_content}")

        # Attempt JSON extraction
        json_match = re.search(r'```json\n(.*?)\n```', response_content, re.DOTALL) or re.search(r'\{.*\}', response_content, re.DOTALL)

        if json_match:
            json_str = json_match.group(1) if json_match.groups() else json_match.group(0)
            json_str = json_str.strip()
            print(f"Extracted JSON for image {i}: {json_str}")

            try:
                w2_data = json.loads(json_str)
                transformed_data = transform_to_expected_format(w2_data)

                # Construct metadata entry
                prompt = f"Process W2 data for {transformed_data.get('employee', {}).get('name', 'Unknown')}"
                completion = json.dumps(transformed_data, indent=2)

                # Append to dataset with unique file_name for each entry
                w2_dataset.append({"file_name": image_file_name, "prompt": prompt, "completion": completion})

            except json.JSONDecodeError as e:
                print(f"Error parsing JSON for image {i}: {str(e)}")
        else:
            print(f"No JSON found in the response for image {i}.")

    except Exception as e:
        print(f"Error processing image {i}: {str(e)}")

# Final logging after dataset creation
print(f"Total JSON objects created: {len(w2_dataset)}")

# Save as JSONL file
output_directory = "/home/ec2-user/SageMaker/vision-workshop/dataset"
output_file_path = os.path.join(output_directory, "11b_metadata.jsonl")
os.makedirs(output_directory, exist_ok=True)

with open(output_file_path, "w") as f:
    for entry in w2_dataset:
        json.dump(entry, f)
        f.write('\n')

print("Processing complete. JSON dataset with image paths saved as 11b_metadata.jsonl")

Generating train split: 100%|██████████| 1800/1800 [00:01<00:00, 1049.49 examples/s]
Generating validation split: 100%|██████████| 100/100 [00:00<00:00, 650.17 examples/s]
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 971.85 examples/s]


Bucket name: genai-accelerate-2024
Subdirectory: llama-3-2-vision-dataset
Raw response content for image 0: The attached W2 form contains the following information:

**Employee Information**

*   Employee Name: Michele Hebert
*   Employee Address: 9888 Zimmerman Roads Apt. 425, Moorestad, MO 77456-6485
*   Social Security Number: 412-88-2525

**Employer Information**

*   Employer Name: Bennett, Allen and Yang Inc
*   Employer EIN: 47-5592725
*   Employer Address: 40301 Cameron Village Suite 661, Aguirrebury, NH 36219-7671
*   Zip Code: 36219-7671

**Earnings and Tax Information**

*   Wages: $141,194.15
*   Social Security Wages: $169,708.96
*   Medicare Wages and Tips: $181,642.61
*   Federal Income Tax Withheld: $37,276.89
*   State Income Tax: $1,2982.74
*   Local Wages / Tips: $1,49192.1
*   Local Income Tax: $1,5871.27

**Benefits and Other Deductions**

*   Dependent Care Benefits: $220
*   Nonqualified Plans: $158

**Multi-state Employment Information**

*   NC:
    *   Local W

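## Model Output Comparison

Looking at the data extraction, the 11B model did a good job of pulling most fields out of the image data. However, the raw response shown for image 0 is a markdown list rather than the requested JSON, and several amounts have misplaced thousands separators (for example, `$1,2982.74` for State Income Tax).

Once extracted reliably, this information can be used to automate data extraction workflows and to derive business insights from image data.

The following errors were noted in the 11B extractions:

- **Daniel Jenkins**: SSN is incorrect
- **Ann Hill**: EIN and SSN are incorrect
- **Katie Green**: EIN and SSN are incorrect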


In [11]:
MODEL_ID = "us.meta.llama3-2-90b-instruct-v1:0"

In [34]:
w2_dataset = []
print(f"Bucket name: {bucket_name}")
print(f"Subdirectory: {subdirectory}")

# Process the first 5 images
for i, item in enumerate(dataset):
    if i >= 5:
        break

    try:
        # Image processing steps
        image = item['image']
        image_format = image.format
        with BytesIO() as buffer:
            image.save(buffer, format=image_format)
            image_bytes = buffer.getvalue()

        # Define a unique file name for each image within the loop
        image_file_name = f"images/img_{i}.{image_format.lower()}"
        image_path = os.path.join("/home/ec2-user/SageMaker/vision-workshop/dataset", image_file_name)
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        with open(image_path, "wb") as img_file:
            img_file.write(image_bytes)

        # Upload image to S3
        s3.put_object(Bucket=bucket_name, Key=f"{subdirectory}/{image_file_name}", Body=image_bytes)

        # Model inference
        response = bedrock_client.converse(
            modelId=MODEL_ID,
            messages=process_w2(image_bytes, image_format),
            inferenceConfig={"maxTokens": 2048, "temperature": 0.0, "topP": 0.1},
        )

        # Extract response content
        response_content = response['output']['message']['content'][0]['text']
        print(f"Raw response content for image {i}: {response_content}")

        # Attempt JSON extraction
        json_match = re.search(r'```json\n(.*?)\n```', response_content, re.DOTALL) or re.search(r'\{.*\}', response_content, re.DOTALL)

        if json_match:
            json_str = json_match.group(1) if json_match.groups() else json_match.group(0)
            json_str = json_str.strip()
            print(f"Extracted JSON string for image {i}: {json_str}")

            try:
                w2_data = json.loads(json_str)
                transformed_data = transform_to_expected_format(w2_data)
                print(f"Transformed JSON for image {i}: {transformed_data}")

                # Construct metadata entry
                prompt = f"Process W2 data for {transformed_data.get('employee', {}).get('name', 'Unknown')}"
                completion = json.dumps(transformed_data, indent=2)

                # Append to dataset with unique file_name for each entry
                w2_dataset.append({"file_name": image_file_name, "prompt": prompt, "completion": completion})

            except json.JSONDecodeError as e:
                print(f"Error parsing JSON for image {i}: {str(e)}")
        else:
            print(f"No JSON found in the response for image {i}.")

    except Exception as e:
        print(f"Error processing image {i}: {str(e)}")

# Final logging after dataset creation
print(f"Total JSON objects created: {len(w2_dataset)}")
print(f"Dataset contents before saving: {w2_dataset}")

# Save as JSONL file
output_directory = "/home/ec2-user/SageMaker/vision-workshop/dataset"
output_file_path = os.path.join(output_directory, "metadata_test.jsonl")
try:
    with open(output_file_path, "w") as f:
        for entry in w2_dataset:
            f.write(json.dumps(entry) + "\n")
    print("Data successfully written to metadata_test.jsonl")
except IOError as e:
    print(f"An error occurred while writing the JSONL file: {e}")

SyntaxError: invalid syntax (802865206.py, line 46)

## One-shot prompting to create fine-tuning dataset

We can see from the results above that the two datasets look nearly identical; however, 90B shows better accuracy and less degradation in its responses for fields such as Employer EIN and employee address. To push accuracy further, the next cell adds explicit field-validation rules and a one-shot example to the prompt.
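Statements such as "better accuracy on Employer EIN" are easiest to check field by field. A minimal sketch, assuming both extractions for the same image are nested dicts (for example, the output of `transform_to_expected_format`), flattens them and reports every field where the two disagree. Comparing each model against a hand-checked reference instead would turn this into an accuracy count. The helper names and the sample records are illustrative, not real model outputs.

```python
def flatten(record, prefix=""):
    """Flatten nested dicts into {"section.field": value} pairs."""
    flat = {}
    for key, value in record.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def diff_extractions(first, second):
    """Return {field_path: (first_value, second_value)} for every field that differs.

    A field present in only one record is reported with None on the other side.
    """
    flat_a, flat_b = flatten(first), flatten(second)
    return {
        path: (flat_a.get(path), flat_b.get(path))
        for path in sorted(flat_a.keys() | flat_b.keys())
        if flat_a.get(path) != flat_b.get(path)
    }


# Illustrative records only (not real model outputs).
from_11b = {"employee": {"name": "Ann Hill", "socialSecurityNumber": "192-67-326"},
            "employer": {"ein": "06-6105986"}}
from_90b = {"employee": {"name": "Ann Hill", "socialSecurityNumber": "192-67-3262"},
            "employer": {"ein": "06-6105986"}}

for field, (a, b) in diff_extractions(from_11b, from_90b).items():
    print(f"{field}: 11B={a!r} 90B={b!r}")
```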

In [24]:
def process_w2(image_bytes, image_format):
    image_media_type = image_format.lower()

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "text": """
                    Analyze the attached W2 form, extracting all fields and bounding boxes, and return the data as a JSON object. 
                    Focus on capturing each field as labeled on the form, and be especially precise with multi-state information. 
                    For each field, ensure the following:

                    1. **Employee Information**: Extract 'Employee Name,' 'Employee Address,' 'Social Security Number,' etc.
                    2. **Employer Information**: Include 'Employer Name,' 'Employer EIN,' 'Employer Address,' and 'Zip Code.'
                    3. **Earnings and Tax Information**: Extract 'Wages,' 'Social Security Wages,' 'Medicare Wages and Tips,' 'Federal Income Tax Withheld,' 'State Income Tax,' 'Local Wages / Tips,' 'Local Income Tax,' etc.
                    4. **Benefits and Other Deductions**: Include fields like 'Dependent Care Benefits' and 'Nonqualified Plans.'
                    5. **Multi-state Employment Information**: Identify all states listed on the W2, capturing information for each:
                       - Ensure each state's data is complete and correct, including 'Local Wages / Tips,' 'Local Income Tax,' and 'Locality Name.'
                       - If multiple states are listed, each state should appear as a separate key in the JSON output (e.g., "NC", "UT").
                       - Verify and cross-reference state abbreviations to ensure accuracy.
                    6. **Field Formatting and Validation**:
                       - **EIN**: Confirm that the Employer EIN follows the format "XX-XXXXXXX" with two digits, a hyphen, and seven digits. If not, set it to 'Manual Verification Needed' and log the error as 'Invalid EIN Format.'
                       - **SSN**: Ensure the Social Security Number matches the format "XXX-XX-XXXX." If not, flag with 'Manual Verification Needed' and log as 'Invalid SSN Format.'
                       - **Zip Code**: Confirm the Zip Code is in the format "XXXXX" or "XXXXX-XXXX." If it doesn't match, set to 'Manual Verification Needed' and log as 'Invalid Zip Code Format.'
                       - **Dollar Signs and Numeric Values**: Strip dollar signs and commas from amount fields, converting them into floating-point numbers. If conversion fails, set to 'Invalid Numeric Format' and log the field requiring verification.
                    7. **Error Logging**: Include an "error_log" list in the JSON output for fields flagged for manual review, listing each issue's field, reason, and the detected value for easy traceability.

                    Here's a sample JSON structure with an error log:
                    {
                        "Employee Information": {
                            "Employee Name": "Ann Hill",
                            "Employee Address": "39572 Jack Trail Apt 308, New Sarahside, MN 56848-7193",
                            "Social Security Number": "Manual Verification Needed"
                        },
                        "Employer Information": {
                            "Employer Name": "Bryant Ltd Group",
                            "Employer EIN": "Manual Verification Needed",
                            "Employer Address": "82582 William Cape Suite 370, Scottside, ND 93090-3134",
                            "Zip Code": "93090-3134"
                        },
                        "Earnings and Tax Information": {
                            "Wages": 238111.55,
                            "State Income Tax": "Invalid Numeric Format"
                        },
                        "Multi-state Employment Information": {
                            "AL": {
                                "Local Wages / Tips": 287711.19
                            }
                        },
                        "error_log": [
                            {"field": "Social Security Number", "reason": "Invalid SSN Format", "detected_value": "3272509"},
                            {"field": "Employer EIN", "reason": "Invalid EIN Format", "detected_value": "82582"},
                            {"field": "State Income Tax", "reason": "Invalid Numeric Format", "detected_value": "$14,9192.1"}
                        ]
                    }

                """

                },
                {
                    "image": {
                        "format": image_media_type,
                        "source": {
                            "bytes": image_bytes
                        }
                    }
                }
            ]
        }
    ]
    return messages
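Hand-typed JSON inside a prompt is easy to break. A trailing comma after the last field, for instance, makes the example invalid JSON, and the model may copy that mistake into its answer. If you are willing to restructure `process_w2`, one way to avoid this is to keep the one-shot example as a Python dict and serialize it with `json.dumps`. In this sketch, the instruction text is a shortened stand-in for the full prompt, and `build_one_shot_prompt` is a hypothetical helper.

```python
import json

# Example record based on the sample in the prompt, kept as a Python dict so that
# json.dumps always emits valid JSON (no trailing commas, numbers stay numbers).
EXAMPLE_OUTPUT = {
    "Employee Information": {"Employee Name": "Ann Hill"},
    "Employer Information": {"Employer EIN": "Manual Verification Needed"},
    "Earnings and Tax Information": {"Wages": 238111.55},
    "error_log": [
        {"field": "Employer EIN", "reason": "Invalid EIN Format", "detected_value": "82582"}
    ],
}

# Shortened stand-in for the full instruction text used in process_w2.
INSTRUCTIONS = "Analyze the attached W2 form and return the data as a JSON object."
EXAMPLE_HEADER = "Here's a sample JSON structure with an error log:"


def build_one_shot_prompt(instructions, example):
    """Append a serialized one-shot example to the instructions."""
    return f"{instructions}\n\n{EXAMPLE_HEADER}\n{json.dumps(example, indent=2)}"


prompt_text = build_one_shot_prompt(INSTRUCTIONS, EXAMPLE_OUTPUT)
# Round-trip check: the embedded example parses back to the same dict.
embedded = prompt_text.split(EXAMPLE_HEADER + "\n", 1)[1]
assert json.loads(embedded) == EXAMPLE_OUTPUT
```

The resulting string would go into the `"text"` content block in place of the hand-written triple-quoted prompt.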

In [25]:
def validate_and_correct_fields(data):
    # Define regex patterns for EIN, SSN, Zip Code, and numeric formatting
    EIN_PATTERN = r'^\d{2}-\d{7}$'
    SSN_PATTERN = r'^\d{3}-\d{2}-\d{4}$'
    ZIP_PATTERN = r'^\d{5}(-\d{4})?$'
    NUMERIC_PATTERN = r'^\$?[\d,]+\.\d{2}$'

    # Address-related terms to detect address-like strings in EIN fields
    ADDRESS_KEYWORDS = ['Street', 'Avenue', 'Boulevard', 'Drive', 'Road', 'Lane', 'Apt', 'Suite', 
                        'Circle', 'Place', 'Square', 'Parkway', 'Court', 'Terrace', 'Way', 
                        'North', 'South', 'East', 'West', 'NW', 'NE', 'SW', 'SE']

    # Helper function to detect address-like content in EIN
    def contains_address_like_content(value):
        return any(keyword in value for keyword in ADDRESS_KEYWORDS)

    # Validate Employer EIN format and ensure it’s not an address
    if 'Employer Information' in data:
        ein = data['Employer Information'].get('Employer EIN', '')
        if not re.match(EIN_PATTERN, ein):
            if contains_address_like_content(ein) or re.search(r'\d{5}.*[A-Za-z]{2} \d{5}', ein):
                print(f"Warning: EIN appears to be misclassified as an address: ({ein}). Setting to 'Manual Verification Needed'.")
                data['Employer Information']['Employer EIN'] = "Manual Verification Needed"
            else:
                print(f"Warning: Invalid EIN format detected ({ein}). Setting EIN to 'Manual Verification Needed'.")
                data['Employer Information']['Employer EIN'] = "Manual Verification Needed"

    # Validate Social Security Number format in Employee Information
    if 'Employee Information' in data:
        ssn = data['Employee Information'].get('Social Security Number', '')
        if not re.match(SSN_PATTERN, ssn):
            print(f"Warning: Invalid SSN format detected ({ssn}). Setting SSN to 'Manual Verification Needed'.")
            data['Employee Information']['Social Security Number'] = "Manual Verification Needed"

    # Validate Zip Code format in Employer Information
    if 'Employer Information' in data:
        zip_code = data['Employer Information'].get('Zip Code', '')
        if not re.match(ZIP_PATTERN, zip_code):
            print(f"Warning: Invalid Zip Code format detected ({zip_code}). Setting Zip Code to 'Manual Verification Needed'.")
            data['Employer Information']['Zip Code'] = "Manual Verification Needed"

    # Clean and validate numeric fields
    numeric_fields = [
        'Wages', 'Social Security Wages', 'Medicare Wages and Tips', 
        'Federal Income Tax Withheld', 'State Income Tax', 
        'Local Wages / Tips', 'Local Income Tax'
    ]

    for field in numeric_fields:
        if field in data.get('Earnings and Tax Information', {}):
            num_value = data['Earnings and Tax Information'][field]
            if isinstance(num_value, str):
                # Clean up dollar signs and commas for conversion
                clean_value = num_value.replace("$", "").replace(",", "")
                try:
                    data['Earnings and Tax Information'][field] = float(clean_value)
                except ValueError:
                    print(f"Warning: Unable to parse {field} value '{num_value}'. Setting to 'Invalid Format'.")
                    data['Earnings and Tax Information'][field] = "Invalid Format"

    # Validate multi-state employment numeric data
    if 'Multi-state Employment Information' in data:
        for state, state_data in data['Multi-state Employment Information'].items():
            for field in numeric_fields:
                if field in state_data:
                    num_value = state_data[field]
                    if isinstance(num_value, str):
                        clean_value = num_value.replace("$", "").replace(",", "")
                        try:
                            state_data[field] = float(clean_value)
                        except ValueError:
                            print(f"Warning: Unable to parse {field} in {state}. Setting to 'Invalid Format'.")
                            state_data[field] = "Invalid Format"

    return data
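The three format rules in `validate_and_correct_fields` can be exercised on their own, using values taken from the raw model output and from the prompt's sample error log. The last two assertions show a subtlety of these patterns: `$` also matches just before a trailing newline, so `re.match` accepts `"47-5592725\n"`, whereas `re.fullmatch` without anchors rejects it. Whether that matters depends on how the model's values are stripped before validation.

```python
import re

# Patterns copied from validate_and_correct_fields.
EIN_PATTERN = r'^\d{2}-\d{7}$'
SSN_PATTERN = r'^\d{3}-\d{2}-\d{4}$'
ZIP_PATTERN = r'^\d{5}(-\d{4})?$'

# (pattern, value, expected) using values seen in the raw outputs and the prompt sample.
CASES = [
    (EIN_PATTERN, "47-5592725", True),
    (EIN_PATTERN, "82582", False),
    (SSN_PATTERN, "412-88-2525", True),
    (SSN_PATTERN, "3272509", False),
    (ZIP_PATTERN, "36219-7671", True),
    (ZIP_PATTERN, "36219", True),
    (ZIP_PATTERN, "3621", False),
]

for pattern, value, expected in CASES:
    matched = re.match(pattern, value) is not None
    assert matched == expected, f"{pattern!r} on {value!r}"

# Caveat: "$" also matches just before a trailing newline, so re.match accepts
# "47-5592725\n"; re.fullmatch is stricter.
assert re.match(EIN_PATTERN, "47-5592725\n") is not None
assert re.fullmatch(r"\d{2}-\d{7}", "47-5592725\n") is None
```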

In [None]:
w2_dataset = []
print(f"Bucket name: {bucket_name}")
print(f"Subdirectory: {subdirectory}")

# Process the first 100 images
for i, item in enumerate(dataset):
    if i >= 100:
        break

    try:
        # Image processing steps
        image = item['image']
        image_format = image.format
        with BytesIO() as buffer:
            image.save(buffer, format=image_format)
            image_bytes = buffer.getvalue()

        # Define a unique file name for each image within the loop
        image_file_name = f"images/img_{i}.{image_format.lower()}"
        image_path = os.path.join("/home/ec2-user/SageMaker/vision-workshop/dataset", image_file_name)
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        with open(image_path, "wb") as img_file:
            img_file.write(image_bytes)

        # Upload image to S3
        s3.put_object(Bucket=bucket_name, Key=f"{subdirectory}/{image_file_name}", Body=image_bytes)

        # Model inference
        response = bedrock_client.converse(
            modelId=MODEL_ID,
            messages=process_w2(image_bytes, image_format),
            inferenceConfig={"maxTokens": 2048, "temperature": 0.0, "topP": 0.1},
        )

        # Extract response content
        response_content = response['output']['message']['content'][0]['text']
        print(f"Raw response content for image {i}: {response_content}")

        # Attempt JSON extraction
        json_match = re.search(r'```json\n(.*?)\n```', response_content, re.DOTALL) or re.search(r'\{.*\}', response_content, re.DOTALL)

        if json_match:
            json_str = json_match.group(1) if json_match.groups() else json_match.group(0)
            json_str = json_str.strip()
            print(f"Extracted JSON for image {i}: {json_str}")

            try:
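                w2_data = json.loads(json_str)
                # Apply the EIN/SSN/Zip and numeric checks defined above before reshaping
                w2_data = validate_and_correct_fields(w2_data)
                transformed_data = transform_to_expected_format(w2_data)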
                print(f"Transformed JSON for image {i}: {transformed_data}")

                # Construct metadata entry
                prompt = f"Process W2 data for {transformed_data.get('employee', {}).get('name', 'Unknown')}"
                completion = json.dumps(transformed_data, indent=2)

                # Append to dataset with unique file_name for each entry
                w2_dataset.append({"file_name": image_file_name, "prompt": prompt, "completion": completion})

            except json.JSONDecodeError as e:
                print(f"Error parsing JSON for image {i}: {str(e)}")
        else:
            print(f"No JSON found in the response for image {i}.")

    except Exception as e:
        print(f"Error processing image {i}: {str(e)}")

# Final logging after dataset creation
print(f"Total JSON objects created: {len(w2_dataset)}")
print(f"Dataset contents before saving: {w2_dataset}")

# Save as JSONL file
output_directory = "/home/ec2-user/SageMaker/vision-workshop/dataset"
output_file_path = os.path.join(output_directory, "one_shot_metadata_test.jsonl")
try:
    with open(output_file_path, "w") as f:
        for entry in w2_dataset:
            f.write(json.dumps(entry) + "\n")
    print("Data successfully written to one_shot_metadata_test.jsonl")
except IOError as e:
    print(f"An error occurred while writing the JSONL file: {e}")

Bucket name: genai-accelerate-2024
Subdirectory: llama-3-2-vision-dataset
Raw response content for image 0: The attached W2 form contains the following information:

**Employee Information:**

*   Employee Name: Michele Hebert
*   Employee Address: 9888 Zimmerman Roads Apt. 425, Moorestad, MO 77456-6485
*   Social Security Number: 412-88-2525

**Employer Information:**

*   Employer Name: Bennett, Allen and Yang Inc
*   Employer EIN: 47-5592725
*   Employer Address: 40301 Cameron Village Suite 661, Aguirrebury, NH 36219-7671
*   Zip Code: 36219-7671

**Earnings and Tax Information:**

*   Wages: $141194.15
*   Social Security Wages: $169708.96
*   Medicare Wages and Tips: $181642.61
*   Federal Income Tax Withheld: $37276.89
*   State Income Tax: $12982.74
*   Local Wages / Tips: $149192.1
*   Local Income Tax: $15871.27

**Benefits and Other Deductions:**

*   Dependent Care Benefits: $220
*   Nonqualified Plans: $158

**Multi-state Employment Information:**

*   NC:
    *   Local Wag

## Conclusion

We now have a dataset ready for knowledge distillation into our Llama 3.2 11B model. In this notebook, we used Llama 3.2 90B with a one-shot prompt to create an instruction-based dataset, which we can use to fine-tune 11B and distill the larger model's extraction behavior to improve the accuracy of 11B's responses.
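Before starting a fine-tuning job, it may be worth a quick sanity pass over the generated JSONL. The sketch below assumes the row shape written by the extraction loop (`file_name`, `prompt`, and a `completion` holding a JSON string). It checks that each completion parses and contains the top-level keys returned by `transform_to_expected_format`. `check_metadata_jsonl` is a hypothetical helper, not part of any fine-tuning API.

```python
import json

# Top-level keys produced by transform_to_expected_format.
REQUIRED_KEYS = {"employee", "employer", "earnings", "benefits", "multiStateEmployment"}


def check_metadata_jsonl(path):
    """Return (valid_rows, problems), where problems is a list of (line_number, reason)."""
    valid_rows, problems = 0, []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                entry = json.loads(line)
                completion = json.loads(entry["completion"])
            except (json.JSONDecodeError, KeyError, TypeError) as exc:
                problems.append((line_no, f"unparseable row: {exc}"))
                continue
            if not isinstance(completion, dict):
                problems.append((line_no, "completion is not a JSON object"))
                continue
            missing = REQUIRED_KEYS - set(completion)
            if missing:
                problems.append((line_no, f"missing keys: {sorted(missing)}"))
            elif not entry.get("file_name"):
                problems.append((line_no, "missing file_name"))
            else:
                valid_rows += 1
    return valid_rows, problems
```

For example, calling `check_metadata_jsonl` on `one_shot_metadata_test.jsonl` in the output directory returns the count of usable rows plus a list of `(line_number, reason)` pairs to review.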

## Contributors
- AWS
- Meta