# Pipeline for locally accessing Inpainting API by stability AI.

## 0. Install Required Libraries

In [None]:
!pip install -q pillow requests ipython

## 1. Import Required Libraries

In [1]:
import os
from PIL import Image
import requests
import time

import getpass
STABILITY_KEY = getpass.getpass('Enter your API Key: ')
# STABILITY_KEY = os.environ['STABILITY_KEY']

## 2. Define Helper Functions

In [3]:
def send_generation_request(host, params):
    headers = {
        "Accept": "image/*",
        "Authorization": f"Bearer {STABILITY_KEY}"
    }

    # Encode parameters
    files = {}
    image = params.pop("image", None)
    mask = params.pop("mask", None)
    if image:
        files["image"] = open(image, 'rb')
    if mask:
        files["mask"] = open(mask, 'rb')
    if not files:
        files["none"] = ''

    # Send request
    print(f"Sending REST request to {host}...")
    response = requests.post(
        host,
        headers=headers,
        files=files,
        data=params
    )
    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")

    return response

def send_async_generation_request(host, params):
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {STABILITY_KEY}"
    }

    # Encode parameters
    files = {}
    if "image" in params:
        image = params.pop("image")
        files = {"image": open(image, 'rb')}

    # Send request
    print(f"Sending REST request to {host}...")
    response = requests.post(
        host,
        headers=headers,
        files=files,
        data=params
    )
    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")

    # Process async response
    response_dict = response.json()
    generation_id = response_dict.get("id")
    if not generation_id:
        raise Exception("Expected 'id' in response")

    # Loop until result or timeout
    timeout = int(os.getenv("WORKER_TIMEOUT", 500))
    start = time.time()
    status_code = 202
    while status_code == 202:
        response = requests.get(
            f"{host}/result/{generation_id}",
            headers={
                **headers,
                "Accept": "image/*"
            },
        )

        if not response.ok:
            raise Exception(f"HTTP {response.status_code}: {response.text}")
        status_code = response.status_code
        time.sleep(10)
        if time.time() - start > timeout:
            raise Exception(f"Timeout after {timeout} seconds")

    return response


## 3. Perform Batch Inpainting

In [14]:
def process_images(input_folder, mask_folder, output_folder, prompt, negative_prompt="", seed=0, output_format="png"):
    # Maximum allowed dimensions
    MAX_PIXELS = 9437184  # 3072 x 3072 pixels
    
    # Ensure output folder exists
    os.makedirs(output_folder, exist_ok=True)
    
    # Iterate through all files in the input folder
    for image_file in os.listdir(input_folder):
        # Construct the paths for the input image and the corresponding mask
        image_path = os.path.join(input_folder, image_file)
        mask_path = os.path.join(mask_folder, os.path.splitext(image_file)[0] + "_mask.png")  # Assuming mask has same name with "_mask" suffix
        
        # # Skip old files (getting index part)
        # try:
        #     index = int(image_file.split('_')[-1].split('.')[0])
        #     if index not in [16, 18, 23, 25, 26, 27, 29]:
        #         print(f"Skipping image {image_file}.")
        #         continue
        # except:
        #     print("The file name doesn't have index at the correct position.")
        #     pass
        
        # Skip if mask doesn't exist
        if not os.path.exists(mask_path):
            print(f"Mask not found for {image_file}, skipping...")
            continue
        
        # Open the image to check dimensions
        with Image.open(image_path) as img:
            width, height = img.size
            if width * height > MAX_PIXELS:
                scale_factor = (MAX_PIXELS / (width * height)) ** 0.5
                new_width = int(width * scale_factor)
                new_height = int(height * scale_factor)
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
                
                # Save the resized image temporarily
                resized_image_path = os.path.join(output_folder, f"resized_{image_file}")
                img.save(resized_image_path)
                print(f"Resized {image_file} to {new_width}x{new_height}")
            else:
                resized_image_path = image_path  # No resizing needed

        # Parameters for the API
        params = {
            "image": resized_image_path,
            "mask": mask_path,
            "negative_prompt": negative_prompt,
            "seed":seed,
            "mode": "mask",
            "output_format": output_format,
            "prompt": prompt,
            "denoising_strength": 0.6
        }
        
        # Call the API
        response = send_generation_request("https://api.stability.ai/v2beta/stable-image/edit/inpaint", params)
        
        # Decode response
        output_image = response.content
        finish_reason = response.headers.get("finish-reason")
        seed = response.headers.get("seed")
        
        # Check for NSFW classification
        if finish_reason == 'CONTENT_FILTERED':
            print(f"Generation for {image_file} failed due to NSFW classification.")
            continue
        
        # Save the result image
        filename_list = image_file.split('_')
        filename_list.insert(1, 'White') # add race
        edited_filename = '_'.join(filename_list)
        edited_path = os.path.join(output_folder, edited_filename)
        
        with open(edited_path, "wb") as f:
            f.write(output_image)
        
        print(f"Saved image to {edited_path}")


In [None]:
import random
# input_folder = "../images/Myanmar_Festivals/original_images"
# mask_folder = "../images/Myanmar_Festivals/masks"
# output_folder = "../images/Myanmar_Festivals/synthesized_images/White"
input_folder = "/Users/junseongkim/Desktop/Data/Korean_Clothes/original_images"
mask_folder = "/Users/junseongkim/Desktop/Data/Korean_Clothes/masks"
output_folder = "/Users/junseongkim/Desktop/Data/Korean_Clothes/synthesized_images/White"
prompt = "White Person wearing the clothes"
seed = random.randint(0, 1000000)
negative_prompt = "blurry, grey, monochrome, low quality, low detail, deformed, washed out"
output_format = "png"

process_images(input_folder, mask_folder, output_folder, prompt, negative_prompt, seed, output_format)

Sending REST request to https://api.stability.ai/v2beta/stable-image/edit/inpaint...
Saved image to /Users/junseongkim/Desktop/Data/Korean_Clothes/synthesized_images/White/Korea_White_Clothes_1.jpg
Sending REST request to https://api.stability.ai/v2beta/stable-image/edit/inpaint...
Saved image to /Users/junseongkim/Desktop/Data/Korean_Clothes/synthesized_images/White/Korea_White_Clothes_33.jpg
Sending REST request to https://api.stability.ai/v2beta/stable-image/edit/inpaint...


KeyboardInterrupt: 

## Individual Inpainting

In [None]:
input_folder = "/Volumes/JUN\'S\ DRIVE/Data/Korean_Clothes/original_images"
mask_folder = "/Volumes/JUN\'S\ DRIVE/Data/Korean_Clothes/masks"
output_folder = "/Volumes/JUN\'S\ DRIVE/Data/Korean_Clothes/synthesized_images"
prompt = "British young ladies dancing"
negative_prompt = ""
output_format = "png"

image_file = "Myanmar_festival_26.png"
MAX_PIXELS = 9437184  # 3072 x 3072 pixels
# seed = 50
# Construct the paths for the input image and the corresponding mask
image_path = os.path.join(input_folder, image_file)
mask_path = os.path.join(mask_folder, os.path.splitext(image_file)[0] + "_mask.png")  # Assuming mask has same name with "_mask" suffix

# # Skip if mask doesn't exist
# if not os.path.exists(mask_path):
#     print(f"Mask not found for {image_file}, skipping...")
#     continue

# Open the image to check dimensions
with Image.open(image_path) as img:
    width, height = img.size
    if width * height > MAX_PIXELS:
        scale_factor = (MAX_PIXELS / (width * height)) ** 0.5
        new_width = int(width * scale_factor)
        new_height = int(height * scale_factor)
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Save the resized image temporarily
        resized_image_path = os.path.join(output_folder, f"resized_{image_file}")
        img.save(resized_image_path)
        print(f"Resized {image_file} to {new_width}x{new_height}")
    else:
        resized_image_path = image_path  # No resizing needed

# Parameters for the API
params = {
    "image": resized_image_path,
    "mask": mask_path,
    "grow_mask":20,
    "steps":60,
    "negative_prompt": negative_prompt,
    "mode": "mask",
    "output_format": output_format,
    "prompt": prompt
}

# Call the API
response = send_generation_request("https://api.stability.ai/v2beta/stable-image/edit/inpaint", params)

# Decode response
output_image = response.content
finish_reason = response.headers.get("finish-reason")
seed = response.headers.get("seed")

# # Check for NSFW classification
# if finish_reason == 'CONTENT_FILTERED':
#     print(f"Generation for {image_file} failed due to NSFW classification.")
#     continue

# Save the result image
filename, _ = os.path.splitext(image_file)
edited_filename = f"{filename}_revised.{output_format}"
edited_path = os.path.join(output_folder, edited_filename)

with open(edited_path, "wb") as f:
    f.write(output_image)

print(f"Saved image to {edited_path}")

Sending REST request to https://api.stability.ai/v2beta/stable-image/edit/inpaint...
Saved image to ../images/Myanmar_Festivals/synthesized_images/White/Myanmar_festival_26_revised.png
