# Finetune and Deploy BLIP-2 with Amazon SageMaker for Visual Question Answering

In this notebook, we fine-tune BLIP-2 for visual question answering (VQA). The fine-tuned model extracts product attributes from images, which are then used to generate fashion product descriptions.

In our example, we are going to leverage Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [PEFT](https://github.com/huggingface/peft) for the finetuning.

## 1. Setup development environment

Select the `Data Science 3.0` image with the `ml.t3.medium` instance type.

Please make sure the IAM role being used has the following permissions:
  - S3 bucket access
  - SageMaker access
  - Amazon Bedrock access (`bedrock:InvokeModel`) for step 6

You can use the following IAM policies for S3 and SageMaker:
 - `arn:aws:iam::aws:policy/AmazonS3FullAccess`
 - `arn:aws:iam::aws:policy/AmazonSageMakerFullAccess`


In [None]:
! pip install -q --upgrade "scikit-image" "sagemaker>=2.190.0"

Here we set up the default session and bucket. To use a different bucket, set `sagemaker_session_bucket` to your preferred bucket name instead of `None`.

In [None]:
import sagemaker
import boto3

sagemaker_session = sagemaker.Session()
sagemaker_session_bucket = None

if sagemaker_session_bucket is None and sagemaker_session is not None:
    # fall back to the default bucket if no bucket name is given; SageMaker creates it if it does not exist
    sagemaker_session_bucket = sagemaker_session.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sagemaker_session = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session.default_bucket()}")
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

## 2. Load and prepare data

For the following demo we will be using the Kaggle [Fashion Product Images Dataset](https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-dataset). We use the shirt categories `Tshirts` and `Shirts` as an example to fine-tune the model. The main files in the dataset are:
1. `styles.csv`: all products and some of their key categories. We use this file to filter on the products that we're interested in.
2. `images.csv`: the link to all the images.
3. `images/product_id.jpg`: image of the product of id `product_id`.
4. `styles/product_id.json`: complete metadata of the product of id `product_id`.


#### Steps to load the data
1. Sign in to Kaggle and download the dataset
2. Unzip the dataset and move the dataset into the folder `data`

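For the preprocessing code below to work, the following structure is expected (the code reads from `../data/`, so the `data` folder is expected in the parent directory of the notebook):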
- `data/styles.csv`
- `data/images.csv`
- `data/styles/`
- `data/images/`
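A quick check of this layout before running the preprocessing can save a failed run. A minimal sketch, assuming the `data` folder sits in the parent directory of the notebook as the `../data/` paths imply; `missing_dataset_paths` is a hypothetical helper, not part of the original notebook:

```python
from pathlib import Path

# Files and folders the preprocessing cells expect inside the data folder
EXPECTED_FILES = ["styles.csv", "images.csv"]
EXPECTED_DIRS = ["styles", "images"]


def missing_dataset_paths(data_dir):
    """Return the expected dataset paths that do not exist under data_dir (hypothetical helper)."""
    root = Path(data_dir)
    missing = [str(root / name) for name in EXPECTED_FILES if not (root / name).is_file()]
    missing += [str(root / name) for name in EXPECTED_DIRS if not (root / name).is_dir()]
    return missing


# An empty list means the layout matches what the notebook expects
print(missing_dataset_paths("../data"))
```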

The processing below creates the train and test datasets in two steps:
1. Extract all the attributes from the product JSON files into one CSV file containing all products and their attributes of interest. This file can also be used for data exploration.
2. Format the train and test datasets as question-answer pairs, one per available product attribute, and split them 80/20 per product.

### Combine all JSON attribute files into one dataset and filter the ids based on image availability.

In [40]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import json

In [None]:
image_summary = pd.read_csv('../data/images.csv')
attr_summary = pd.read_csv('../data/styles.csv', on_bad_lines='skip')

In [None]:
# We can see that some images don't have a link so we should remove them from the dataset
undefined_images = image_summary[image_summary.link == 'undefined']
undefined_images

In [None]:
# remove these records from the attribute summary
undefined_images_id = undefined_images.filename.apply(lambda x: x.replace('.jpg', '')).values
attr_summary['id'] = attr_summary['id'].astype(str)
attr_summary = attr_summary[~attr_summary.id.isin(undefined_images_id)]

In [None]:
# keep only shirts and T-shirts for fine-tuning
tops = attr_summary[attr_summary['articleType'].isin(['Shirts', 'Tshirts'])]
tops.shape

In [None]:
rows = []

for product_id in tqdm(tops['id']):
    # read the complete metadata of the product
    with open(f"../data/styles/{product_id}.json", "r") as f:
        jsn = json.load(f)

    attr = jsn['data']['articleAttributes']
    descr = jsn['data']['productDescriptors']

    # merge attributes and descriptors into one flat row
    attr.update(descr)
    attr['id'] = product_id
    rows.append(pd.json_normalize(attr, sep='_'))

# concatenate once instead of growing the DataFrame inside the loop
dataset = pd.concat(rows, ignore_index=True)

# also add the base colour from styles.csv to the dataset
dataset['id'] = dataset['id'].astype(str)
dataset = dataset.merge(tops[['id', 'baseColour']], on='id')

dataset.to_csv("../data/dataset.csv", index=False)

In [71]:
dataset.columns

Index(['Unnamed: 0', 'Fit', 'Pattern', 'Body or Garment Size', 'Sleeve Length',
       'Fabric', 'Collar', 'id', 'description_descriptorType',
       'description_value', 'Occasion', 'Fabric 2', 'Fabric 3', 'Neck',
       'materials_care_desc_descriptorType', 'materials_care_desc_value',
       'size_fit_desc_descriptorType', 'size_fit_desc_value',
       'style_note_descriptorType', 'style_note_value', 'Sleeve Styling',
       'Business Unit', 'Multipack Set', 'Main Trend', 'Print or Pattern Type',
       'Number of Components', 'Sport Team', 'Segment', 'Sport', 'Technology',
       'Players', 'Character', 'Wash Care', 'Number of Pockets',
       'Surface Styling', 'Length', 'Brick', 'Family', 'Class',
       'Pattern Coverage', 'Processing Time', 'Brand Fit Name', 'Plating',
       'Hemline', 'Placket', 'Placket Length', 'Transparency', 'Pocket Type',
       'Weave Pattern', 'Cuff', 'Stitch', 'Brand', 'Colour Shade Name'],
      dtype='object')
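Column names such as `description_descriptorType` and `description_value` come from `pd.json_normalize(..., sep='_')` flattening the nested entries of `productDescriptors`. A stdlib-only sketch of that flattening, for illustration; the record below is hypothetical and only mimics the nesting implied by the column names:

```python
def flatten(record, parent_key="", sep="_"):
    """Flatten nested dicts into one level, joining keys with sep (similar to pd.json_normalize(..., sep='_'))."""
    flat = {}
    for key, value in record.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # recurse into nested dicts, prefixing their keys with the parent key
            flat.update(flatten(value, new_key, sep))
        else:
            flat[new_key] = value
    return flat


# Hypothetical record shaped like articleAttributes merged with productDescriptors
record = {
    "Fabric": "Cotton",
    "Fit": "Regular Fit",
    "description": {"descriptorType": "description", "value": "Example description text"},
    "id": "38642",
}
print(flatten(record))
```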

In [72]:
# formulate questions and answers for finetuning
cols = ["id", "Fabric", "Fit", "Collar", "Pattern", "Sleeve Length", "Sleeve Styling", "Neck", "baseColour"]
dataset = dataset[cols]

product_attributes = [
    {
        "Attribute": "Fabric",
        "Prompt": "What is the fabric of the shirt in this picture?",
    },
    {
        "Attribute": "Fit",
        "Prompt": "What is the Fit of the shirt in this picture?",
    },
    {
        "Attribute": "Collar",
        "Prompt": "What is the collar of the shirt in this picture?",
    },
    {
        "Attribute": "Pattern",
        "Prompt": "What is the pattern of the shirt in this picture?",
    },
    {
        "Attribute": "Neck",
        "Prompt": "What is the neck type of the shirt in this picture?",
    },
    {
        "Attribute": "Sleeve Length",
        "Prompt": "What is the sleeve length of the shirt in this picture?",
    },
    {
        "Attribute": "Sleeve Styling",
        "Prompt": "What is the sleeve styling of the shirt in this picture?",
    },
    {
        "Attribute": "baseColour",
        "Prompt": "What is the colour of the shirt in this picture?",
    },
]

In [74]:
vqa_data = []

for index, row in dataset.iterrows():
    for attribute in product_attributes:
        if pd.notna(row[attribute['Attribute']]):
            item_data = {}
            item_data['id'] = row['id']
            item_data['Question'] = attribute['Prompt'] 
            item_data['Answer'] = f"{attribute['Attribute']}: {row[attribute['Attribute']]}"
            vqa_data.append(item_data)

vqa = pd.DataFrame.from_records(vqa_data)
print(vqa.shape)
vqa.sample(1)
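Checking for missing values by identity (`value is not np.nan`) is unreliable, because a NaN read from a DataFrame is not necessarily the `np.nan` object; `pd.notna(value)` is the robust check. A stdlib sketch of the same question-answer filtering, assuming rows are plain dicts; `is_missing` and `build_qa_pairs` are hypothetical helpers:

```python
import math

# Two NaN objects are never identical, so an identity check can let NaNs through
nan_a, nan_b = float("nan"), float("nan")
print(nan_a is nan_b)  # False


def is_missing(value):
    """True for None or a float NaN (numpy.float64 subclasses float, so it is covered too)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def build_qa_pairs(row, attributes):
    """Turn one product row (a dict) into question/answer records, skipping missing attributes."""
    pairs = []
    for attribute in attributes:
        value = row.get(attribute["Attribute"])
        if is_missing(value):
            continue
        pairs.append({
            "id": row["id"],
            "Question": attribute["Prompt"],
            "Answer": f"{attribute['Attribute']}: {value}",
        })
    return pairs


# Hypothetical row: Fit is NaN and Collar is absent, so only the Fabric pair is produced
example_attributes = [
    {"Attribute": "Fabric", "Prompt": "What is the fabric of the shirt in this picture?"},
    {"Attribute": "Fit", "Prompt": "What is the Fit of the shirt in this picture?"},
    {"Attribute": "Collar", "Prompt": "What is the collar of the shirt in this picture?"},
]
example_row = {"id": "38642", "Fabric": "Cotton", "Fit": float("nan")}
print(build_qa_pairs(example_row, example_attributes))
```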

In [76]:
# sample 80% of the questions of each product for training; the remaining questions go to test
train = vqa.groupby("id").sample(frac=0.8, random_state=200)
test = vqa.drop(train.index)

In [None]:
# we don't want to fine-tune on colour, but we do want to extract colour during testing/inference
print(train.shape[0])
train = train[train['Question'] != "What is the colour of the shirt in this picture?"]
print(train.shape[0])
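`groupby("id").sample(frac=0.8)` samples questions *within* each product, so a product usually appears in both train and test with different questions; it is not a split by product. Colour questions that land in train are dropped by the filter above rather than moved to test. A stdlib sketch of the same per-product idea, for illustration only (pandas' random stream differs, so the selected rows will not match); `split_per_product` is a hypothetical helper:

```python
import random
from collections import defaultdict


def split_per_product(records, frac=0.8, seed=200):
    """Split QA records so each product keeps about frac of its questions in train and the rest in test."""
    by_id = defaultdict(list)
    for i, record in enumerate(records):
        by_id[record["id"]].append(i)

    rng = random.Random(seed)
    train_idx = set()
    for indices in by_id.values():
        # number of questions of this product that go to train
        k = round(frac * len(indices))
        train_idx.update(rng.sample(indices, k))

    train = [r for i, r in enumerate(records) if i in train_idx]
    test = [r for i, r in enumerate(records) if i not in train_idx]
    return train, test


# Two hypothetical products with 8 questions each
records = [{"id": pid, "Question": f"q{j}"} for pid in ("a", "b") for j in range(8)]
train_records, test_records = split_per_product(records)
print(len(train_records), len(test_records))
```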

In [78]:
train.sample()

| | id | Question | Answer |
|---|---|---|---|
| 44385 | 38642 | What is the fabric of the clothing in this pic... | Fabric: Cotton |


### Upload data to S3 for the finetuning job

In [79]:
# upload train and test sets (pandas needs the s3fs package to write to s3:// paths)
train.to_csv(f"s3://{sagemaker_session_bucket}/data/vqa_train.csv", index=False)
test.to_csv(f"s3://{sagemaker_session_bucket}/data/vqa_test.csv", index=False)

# upload images to S3 under data/images/ (no leading slash in the object key)
s3_client = boto3.client('s3')

for product_id in tqdm(tops['id']):
    s3_client.upload_file(f'../data/images/{product_id}.jpg', sagemaker_session_bucket, f'data/images/{product_id}.jpg')
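An S3 object key that starts with `/` does not fall under the `s3://<bucket>/data/images` prefix that the training job reads, and a missing local file makes `upload_file` fail mid-loop. A minimal sketch that plans the uploads first; `plan_image_uploads` is a hypothetical helper and the `data/images` prefix is taken from `image_s3_uri`:

```python
from pathlib import Path


def plan_image_uploads(product_ids, local_dir, key_prefix="data/images"):
    """Pair each existing local image with its S3 key and collect the missing files separately."""
    uploads, missing = [], []
    for product_id in product_ids:
        local_path = Path(local_dir) / f"{product_id}.jpg"
        # relative key without a leading '/' so it lines up with s3://<bucket>/data/images
        key = f"{key_prefix.strip('/')}/{product_id}.jpg"
        if local_path.is_file():
            uploads.append((str(local_path), key))
        else:
            missing.append(str(local_path))
    return uploads, missing


# Usage in the notebook (commented out here because it needs AWS credentials):
# uploads, missing = plan_image_uploads(tops['id'], "../data/images")
# for local_path, key in tqdm(uploads):
#     s3_client.upload_file(local_path, sagemaker_session_bucket, key)
```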

## 3. Fine-Tune BLIP-2 with LoRA on Amazon SageMaker

In [85]:
## Specify inputs to Training Jobs
inputs = f"s3://{sagemaker_session_bucket}/data/"
image_s3_uri = f"s3://{sagemaker_session_bucket}/data/images"
output_path = f"s3://{sagemaker_session_bucket}/training"

In [86]:
from sagemaker.inputs import TrainingInput

input_file = TrainingInput(s3_data=inputs, input_mode="File")
images_input = TrainingInput(s3_data=image_s3_uri, input_mode="FastFile", content_type="image/jpeg")
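The training script `entrypoint_vqa_finetuning.py` is not shown in this notebook. As a hedged sketch of how such a script typically receives its inputs: SageMaker passes the estimator's `hyperparameters` as command-line arguments (`--epochs 10 --file-name vqa_train.csv`) and makes each channel given to `estimator.fit` (here `images` and `input_file`) available under `/opt/ml/input/data/<channel>`, exposing the path in the `SM_CHANNEL_<CHANNEL>` environment variable. The argument defaults below are assumptions, not the script's actual values:

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the hyperparameters SageMaker passes to the entry point as --key value arguments."""
    parser = argparse.ArgumentParser()
    # hypothetical defaults; argparse stores --file-name as args.file_name
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--file-name", type=str, default="vqa_train.csv")
    return parser.parse_args(argv)


def channel_dir(name):
    """Local directory of an input channel, read from SM_CHANNEL_<NAME> with the standard path as fallback."""
    return os.environ.get(f"SM_CHANNEL_{name.upper()}", f"/opt/ml/input/data/{name}")


args = parse_args(["--epochs", "10", "--file-name", "vqa_train.csv"])
train_csv = os.path.join(channel_dir("input_file"), args.file_name)
print(args.epochs, train_csv)
```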

In [117]:
from sagemaker.huggingface import HuggingFace

hyperparameters = {
    'epochs': 10,
    'file-name': "vqa_train.csv",
}

estimator = HuggingFace(
    entry_point="entrypoint_vqa_finetuning.py",
    source_dir="../src",
    role=role,
    instance_count=1,
    instance_type="ml.g5.2xlarge", 
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    hyperparameters = hyperparameters,
    base_job_name="VQA",
    sagemaker_session=sagemaker_session,
    output_path=f"{output_path}/models",
    code_location=f"{output_path}/code",
    volume_size=60,
    metric_definitions=[
        {'Name': 'batch_loss', 'Regex': 'Loss: ([0-9\\.]+)'},
        {'Name': 'epoch_loss', 'Regex': 'Epoch Loss: ([0-9\\.]+)'}
    ],
)

estimator.fit({"images": images_input, "input_file": input_file})

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: VQA-2024-02-07-13-03-54-497


2024-02-07 13:03:55 Starting - Starting the training job...
2024-02-07 13:03:58 Pending - Training job waiting for capacity......
2024-02-07 13:05:07 Pending - Preparing the instances for training......
2024-02-07 13:06:14 Downloading - Downloading input data...................................................
2024-02-07 13:14:52 Training - Training image download completed. Training in progress.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2024-02-07 13:14:54,024 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2024-02-07 13:14:54,044 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2024-02-07 13:14:54,053 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2024-02-07 13:14:54,057 sagemaker_pytorch_container.training INFO     Invoking user training script.
2024-02-07 13:14:54,258 sagemaker-training-toolk
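SageMaker applies each `Regex` in `metric_definitions` to the training log. As written, the `batch_loss` pattern also matches `Epoch Loss:` lines, so epoch values may be recorded as batch losses. A quick check with Python's `re` on hypothetical log lines (the real format depends on the training script); the lookbehind variant is an assumption, and you should verify that SageMaker's metric parser accepts it before using it:

```python
import re

# Same patterns as in metric_definitions ('\\.' in a normal string equals '\.' in a raw string)
batch_regex = r"Loss: ([0-9\.]+)"
epoch_regex = r"Epoch Loss: ([0-9\.]+)"
# Hypothetical alternative: a negative lookbehind so epoch lines are not counted as batch losses
batch_regex_strict = r"(?<!Epoch )Loss: ([0-9\.]+)"

# Hypothetical log lines
epoch_line = "Epoch Loss: 0.4210"
batch_line = "Loss: 0.5123"

print(re.search(batch_regex, epoch_line).group(1))  # the batch pattern matches the epoch line too
print(re.search(batch_regex_strict, epoch_line))  # None
print(re.search(batch_regex_strict, batch_line).group(1))
```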

## 4. Deploy Fine-tuned BLIP-2 on Amazon SageMaker

In [95]:
from sagemaker.huggingface import HuggingFaceModel

# create Hugging Face Model Class
model = HuggingFaceModel(
    model_data=estimator.model_data,
    role=role,
    transformers_version="4.28",
    pytorch_version="2.0",
    py_version="py310",
    model_server_workers=1,
    sagemaker_session=sagemaker_session
)

In [96]:
endpoint_name = "endpoint-finetuned-blip2"

model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name=endpoint_name
)

INFO:sagemaker:Creating model with name: huggingface-pytorch-inference-2024-02-07-12-45-48-758
INFO:sagemaker:Creating endpoint-config with name endpoint-finetuned-blip2-final
INFO:sagemaker:Creating endpoint with name endpoint-finetuned-blip2-final


----------!

<sagemaker.huggingface.model.HuggingFacePredictor at 0x168a36e50>

## 5. Run inference on the model

In [None]:
# sample one test image
sample_image_id = test.sample(1)['id'].values[0]
test_image = f"../data/images/{sample_image_id}.jpg"
test_image

In [101]:
import base64
import re 

smr_client = boto3.client("sagemaker-runtime")

def encode_image(img_file):
    # read the image file and return it as a base64-encoded string for the JSON payload
    with open(img_file, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

def run_inference(endpoint_name, inputs):
    # send the JSON payload to the endpoint and return the decoded response body
    response = smr_client.invoke_endpoint(
        EndpointName=endpoint_name, Body=json.dumps(inputs), ContentType="application/json"
    )
    return response["Body"].read().decode("utf-8")

base64_string = encode_image(test_image)


attributes = []

# ask one question per attribute, including colour (excluded from training but extracted at inference)
for product_attribute in product_attributes:
    inputs = {
        "prompt": f"Question: {product_attribute['Prompt']} Answer: ",
        "image": base64_string
    }
    response = run_inference(endpoint_name, inputs)
    # remove the quotes around the returned text
    attributes.append(re.sub("[\"']", "", response))
    print(response)


'Sleeve Length: Long Sleeves'
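Because the model was fine-tuned to answer in the form `Attribute: value`, the collected answers can also be turned into a dictionary, for example to store them with the product or check them before prompting Bedrock. A minimal sketch; `parse_answers` is a hypothetical helper that assumes answers follow that format and skips any that do not:

```python
def parse_answers(responses):
    """Turn responses like 'Sleeve Length: Long Sleeves' into {'Sleeve Length': 'Long Sleeves'}."""
    parsed = {}
    for response in responses:
        # drop surrounding whitespace and quotes, then split on the first colon
        text = response.strip().strip("\"'")
        name, sep, value = text.partition(":")
        if sep:  # skip answers that do not follow the "Attribute: value" format
            parsed[name.strip()] = value.strip()
    return parsed


print(parse_answers(["'Sleeve Length: Long Sleeves'", '"Fabric: Cotton"', "unknown"]))
```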

## 6. Generate product description with Amazon Bedrock

In [38]:
prompt = """
You are an expert in writing product descriptions for shirts. Use the data below to create a product description for a website.
The product description should contain all given attributes.
Provide some inspirational sentences, for example on how the fabric moves. Think about what a potential customer wants to know about the shirts.
"""

print(prompt)


You are an expert in writing product descriptions for shirts. Use the data below to create a product description for a website.
The product description should contain all given attributes.
Provide some inspirational sentences, for example on how the fabric moves. Think about what a potential customer wants to know about the shirts.



In [27]:
attributes_content =  {"role": "user", "content": f"Here are the facts you need to create the product descriptions: <product_attributes>{', '.join(attributes)}</product_attributes>"}

In [34]:
bedrock = boto3.client(
    service_name='bedrock-runtime',
    region_name='us-west-2',
)

model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

# Claude 3 models on Bedrock use the Messages API request format
body = json.dumps({
    "system": prompt,
    "messages": [attributes_content],
    "max_tokens": 400,
    "temperature": 0.1,
    "anthropic_version": "bedrock-2023-05-31",
})

accept = 'application/json'
contentType = 'application/json'

response = bedrock.invoke_model(
    body=body,
    modelId=model_id,
    accept=accept,
    contentType=contentType
)

response_body = json.loads(response.get('body').read())

print(response_body['content'][0]['text'])

Classic Striped Shirt Relax into comfortable casual style with this classic collared striped shirt. With a regular fit that is neither too slim nor too loose, this versatile top layers perfectly under sweaters or jackets.
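`response_body['content'][0]['text']` reads only the first content block. A slightly more defensive sketch joins every text block and also returns `stop_reason`, which is `"max_tokens"` when the 400-token limit cut the description short. `extract_text` is a hypothetical helper and the example body is made up:

```python
import json


def extract_text(response_body):
    """Join all text blocks of a Claude Messages API response and return them with the stop reason."""
    text = "".join(
        block.get("text", "")
        for block in response_body.get("content", [])
        if block.get("type") == "text"
    )
    return text, response_body.get("stop_reason")


# Hypothetical response body in the Messages API shape
example_body = json.loads(
    '{"content": [{"type": "text", "text": "Classic Striped Shirt"}], "stop_reason": "end_turn"}'
)
text, stop_reason = extract_text(example_body)
if stop_reason == "max_tokens":
    print("The description was truncated; consider raising max_tokens.")
print(text)
```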


## 7. Delete resources

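Delete the SageMaker endpoint and the endpoint configuration. `model.deploy` gives the endpoint configuration the same name as the endpoint, so `endpoint_name` is used for both.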

In [None]:
client = boto3.client('sagemaker')

response = client.delete_endpoint(
    EndpointName=endpoint_name
)

response = client.delete_endpoint_config(
    EndpointConfigName=endpoint_name
)