## Step 5: Image Inpainting

### Introduction

This notebook implements image inpainting with Amazon Bedrock's image generation APIs, letting users replace or restore selected regions of an image based on the surrounding context and a text prompt. It includes helper functions for saving and plotting images, an interactive masking interface with position and size controls for defining the area to inpaint, and S3 integration for retrieving images. The core is the `generate_image()` function, which takes a reference image, builds the request parameters (text prompts and a mask), and calls the Amazon Bedrock API.

The inpainting workflow follows a structured approach where users first select a source image, then define the area to replace using the interactive masking tools which create a binary mask with black pixels indicating the region to inpaint. Users provide text instructions for what should appear in the masked area, after which the AI model analyzes the unmasked portions and the prompt to generate contextually appropriate content. This implementation supports common applications such as removing unwanted objects, fixing damaged portions of images, or creatively modifying specific areas, with results displayed in a grid layout allowing for easy comparison between original and inpainted versions.

## Table of Contents

1. [Setup and Dependencies](#Setup-and-Dependencies)
2. [Save Base64 Encoded Image to File](#Save-Base64-Encoded-Image-to-File)
3. [Image Plotter for Comparing Generated and Reference Images](#Image-Plotter-for-Comparing-Generated-and-Reference-Images)
4. [Generate Mask Image](#Generate-Mask-Image)
5. [Retrieve Text Content from S3 Bucket](#Retrieve-Text-Content-from-S3-Bucket)
6. [Generate images using inpainting capabilities](#Generate-images-using-inpainting-capabilities)

***

<div class="alert alert-block alert-info">
<b>Note:</b> Ensure that you are using the Python kernel <b>conda_python3</b>.
</div>

## Setup and Dependencies

First, we'll import the necessary libraries and modules. To avoid code duplication, we also import `resize_and_encode` from the earlier `_00_image_processing` notebook; `nbimporter` makes notebooks importable as modules.

In [None]:
import base64
import io
import json
from io import BytesIO
import nbimporter

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
import ipywidgets as widgets
from IPython.display import display, clear_output

from _00_image_processing import resize_and_encode

## Save Base64 Encoded Image to File

This function converts a base64-encoded image string into an actual image file and saves it to the specified path.

Parameters:
- base64_image (str): A string containing the base64-encoded image data.
- output_file (str): The file path where the decoded image should be saved.

Steps:
1. Decodes the base64 string into binary image data
2. Creates an in-memory file object from the binary data
3. Opens the data as an image using PIL
4. Saves the image to the specified output path

In [None]:
def save_image(base64_image, output_file):
    """
    Saves a base64-encoded image to a file.
    
    Args:
        base64_image (str): The base64-encoded image data.
        output_file (str): The path where the image will be saved.
        
    Returns:
        None
    
    Requires:
        - base64
        - PIL.Image
        - io.BytesIO
    """
    image_bytes = base64.b64decode(base64_image)
    image = Image.open(io.BytesIO(image_bytes))
    image.save(output_file)
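
As a quick sanity check, the same decode steps can be traced on a tiny synthetic image. This is a minimal sketch: the 4x2 red image, the temporary directory, and the `roundtrip.png` file name are made up for illustration and stand in for a real API response.

```python
import base64
import io
import os
import tempfile

from PIL import Image

# Build a tiny 4x2 red image and base64-encode it, mimicking an API response.
buffer = io.BytesIO()
Image.new("RGB", (4, 2), color=(255, 0, 0)).save(buffer, format="PNG")
encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

# Same steps as save_image: decode, wrap in BytesIO, open with PIL, save.
out_path = os.path.join(tempfile.mkdtemp(), "roundtrip.png")
decoded_image = Image.open(io.BytesIO(base64.b64decode(encoded)))
decoded_image.save(out_path)

# Reload from disk to confirm the file is a valid image with the same content.
with Image.open(out_path) as reloaded:
    reloaded_size = reloaded.size
    reloaded_pixel = reloaded.convert("RGB").getpixel((0, 0))
print(reloaded_size, reloaded_pixel)
```

PIL infers the output format from the file extension, so the path passed to `save_image` should end in a recognized extension such as `.png`.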

## Image Plotter for Comparing Generated and Reference Images

This function visualizes generated images alongside an optional reference image in a grid layout. It accepts a list of PIL Image objects representing generated images and can display a reference image from either a local file path or an S3 bucket. The function automatically calculates an appropriate grid layout with a maximum of 3 images per row and creates a figure with consistent image sizes. Each image is displayed with a descriptive title, with the reference image labeled as "Reference Image" and generated images numbered sequentially. The function handles potential errors when loading the reference image and organizes the layout efficiently regardless of whether a reference image is provided.
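
The grid arithmetic can be checked in isolation. The sketch below copies the row and column calculation from the function that follows into a hypothetical helper named `grid_shape`; the guard for an empty input is an addition for the sketch and is not in the original function.

```python
def grid_shape(n_generated, has_reference):
    """Return (n_rows, n_cols) for the plot grid, at most 3 images per row."""
    n_total = n_generated + (1 if has_reference else 0)
    if n_total == 0:
        # Added guard: nothing to plot, and it avoids dividing by zero below.
        return 0, 0
    n_cols = min(3, n_total)
    n_rows = (n_total + n_cols - 1) // n_cols  # ceiling division
    return n_rows, n_cols


print(grid_shape(1, True))   # reference + 1 generated -> (1, 2)
print(grid_shape(3, True))   # reference + 3 generated -> (2, 3)
print(grid_shape(5, False))  # 5 generated, no reference -> (2, 3)
```

The expression `(n_total + n_cols - 1) // n_cols` rounds up, so a partially filled last row still gets its own row in the figure.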

In [None]:
def plot_images(generated_images, ref_image_path=None, s3_client=None, bucket_name=None, key_name=None):
    """
    Plot the reference image (if provided) and all generated images in a grid layout.
    Reference image can be a local path or from S3.
    
    Args:
        generated_images (list): List of PIL Image objects
        ref_image_path (str, optional): Path to local reference image
        s3_client (boto3.client, optional): S3 client for accessing S3 images
        bucket_name (str, optional): S3 bucket name containing reference image
        key_name (str, optional): S3 key for the reference image
    """
    
    # Determine if we should use S3 or local file
    use_s3 = s3_client is not None and bucket_name is not None and key_name is not None
    has_reference = ref_image_path is not None or use_s3
    
    # Calculate total number of images to display
    n_generated = len(generated_images)
    n_total = n_generated + (1 if has_reference else 0)
    
    # Calculate grid dimensions
    n_cols = min(3, n_total)  # Max 3 images per row
    n_rows = (n_total + n_cols - 1) // n_cols
    
    # Create figure
    plt.figure(figsize=(5*n_cols, 5*n_rows))
    
    # Plot reference image if provided
    current_idx = 1
    if has_reference:
        try:
            if use_s3:
                # Get image from S3
                response = s3_client.get_object(Bucket=bucket_name, Key=key_name)
                image_data = response['Body'].read()
                ref_image = Image.open(BytesIO(image_data))
            else:
                # Get image from local path
                ref_image = Image.open(ref_image_path)
                
            plt.subplot(n_rows, n_cols, current_idx)
            plt.imshow(ref_image)
            plt.axis('off')
            plt.title('Reference Image')
            current_idx += 1
        except Exception as e:
            print(f"Error loading reference image: {e}")
            has_reference = False
            current_idx = 1
    
    # Plot generated images
    for i, img in enumerate(generated_images):
        plt.subplot(n_rows, n_cols, current_idx + i)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f'Generated Image {i+1}')
    
    plt.tight_layout()
    plt.show()

## Generate Mask Image

This cell implements an interactive image masking tool for Jupyter notebooks. It lets users create a mask for image inpainting by selecting a rectangular region of an image. The tool loads the image from an AWS S3 bucket and provides sliders to adjust the position and size of the masked area.


The interface includes sliders for X and Y position coordinates and for width and height dimensions, enabling precise control over the masked region. A red rectangle visually indicates the selected area on the displayed image. The tool offers two action buttons: "Save Mask" to store the created mask alongside a resized version of the original image, and "Clear Selection" to reset the selection parameters.


When saving, the code creates a directory structure based on the image's base filename and stores both the mask and the original image at a standardized 1280x720 resolution. The mask is represented as a binary image where the selected area is black (0) and the rest is white (255), making it suitable for various image inpainting algorithms that require mask input to determine which areas need to be filled in.
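
The mask convention described above can be sketched with NumPy alone. The toy image size and selection coordinates below are made-up values, not the notebook's defaults.

```python
import numpy as np

height, width = 6, 8     # toy image size (rows, columns)
x, y, w, h = 2, 1, 3, 2  # hypothetical selection: left, top, width, height

# Start fully white (255 = keep), then paint the selection black (0 = inpaint).
mask = np.full((height, width), 255, dtype=np.uint8)
mask[y:y + h, x:x + w] = 0
masked_pixels = int((mask == 0).sum())

# A selection that runs past the image border is clipped by NumPy slicing.
edge_mask = np.full((height, width), 255, dtype=np.uint8)
edge_mask[4:4 + 5, 6:6 + 5] = 0
edge_pixels = int((edge_mask == 0).sum())

print(mask)
print(masked_pixels, edge_pixels)
```

Because slicing clips silently, a rectangle that extends beyond the image produces a smaller masked area than the slider values suggest.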

In [None]:
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import io
import os
import base64
from io import BytesIO
from PIL import Image
from typing import Tuple
import cv2

%matplotlib inline

class ImageMasking:
    def __init__(self, s3_client, bucket_name, file_key, base_file_name):
        # Store base_file_name as an instance variable
        self.base_file_name = base_file_name
        
        # Read image from S3 bucket
        response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
        image_content = response['Body'].read()
        
        # Convert the image content to numpy array using OpenCV
        nparr = np.frombuffer(image_content, np.uint8)
        self.original_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        self.original_image = cv2.cvtColor(self.original_image, cv2.COLOR_BGR2RGB)
        self.height, self.width, _ = self.original_image.shape
        
        # Create mask
        self.mask = np.full((self.height, self.width), 255, dtype=np.uint8)
        
        # Create output widget for dynamic display
        self.output = widgets.Output()
        
        # Create widgets and display initial image
        self.create_widgets()
        
    def create_widgets(self):
        # Create sliders for coordinates
        self.x_slider = widgets.IntSlider(
            value=2030,
            min=0,
            max=self.width-1,
            description='X position:',
            continuous_update=True
        )

        self.y_slider = widgets.IntSlider(
            value=85,
            min=0,
            max=self.height-1,
            description='Y position:',
            continuous_update=True
        )

        self.width_slider = widgets.IntSlider(
            value=750,
            min=1,
            max=self.width,
            description='Width:',
            continuous_update=True
        )

        self.height_slider = widgets.IntSlider(
            value=820,
            min=1,
            max=self.height,
            description='Height:',
            continuous_update=True
        )
        
        # Create buttons
        self.save_button = widgets.Button(
            description='Save Mask',
            button_style='success'
        )
        self.clear_button = widgets.Button(
            description='Clear Selection',
            button_style='danger'
        )
        
        # Add callbacks
        self.save_button.on_click(self.save_mask)
        self.clear_button.on_click(self.clear_selection)
        
        # Add observers for sliders
        self.x_slider.observe(self.on_value_change, names='value')
        self.y_slider.observe(self.on_value_change, names='value')
        self.width_slider.observe(self.on_value_change, names='value')
        self.height_slider.observe(self.on_value_change, names='value')
        
        # Display widgets and initial image
        display(widgets.VBox([
            self.x_slider,
            self.y_slider,
            self.width_slider,
            self.height_slider,
            widgets.HBox([self.save_button, self.clear_button]),
            self.output
        ]))
        
        # Show initial image
        self.update_display()
        
    def on_value_change(self, change):
        self.update_display()
        
    def update_display(self):
        with self.output:
            clear_output(wait=True)
            
            # Create a copy of the original image
            display_img = self.original_image.copy()
            
            # Get current coordinates
            x = self.x_slider.value
            y = self.y_slider.value
            w = self.width_slider.value
            h = self.height_slider.value
            
            # Draw a red rectangle; the image is RGB at this point, so red is (255, 0, 0)
            cv2.rectangle(display_img, (x, y), (x+w, y+h), (255, 0, 0), 2)
            
            # Update mask
            self.mask.fill(255)
            self.mask[y:y+h, x:x+w] = 0
            
            # Display image with rectangle
            plt.figure(figsize=(10, 10))
            plt.imshow(display_img)
            plt.axis('on')
            plt.show()
        
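    def clear_selection(self, button=None):
        # ipywidgets passes the clicked Button to on_click callbacks, hence `button`
        # Reset sliders to a default selection (a quarter of the image size)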
        self.x_slider.value = self.width//4
        self.y_slider.value = self.height//4
        self.width_slider.value = self.width//4
        self.height_slider.value = self.height//4
        self.mask.fill(255)
        self.update_display()

    def save_mask(self, button=None):
        # ipywidgets passes the clicked Button to on_click callbacks, hence `button`
        # Create directory if it doesn't exist
        save_dir = f'images/inpainting/{self.base_file_name}'
        os.makedirs(save_dir, exist_ok=True)
        
        # Create temporary image files to use with resize_and_encode
        temp_mask_path = os.path.join(save_dir, 'temp_mask.png')
        temp_image_path = os.path.join(save_dir, 'temp_original.png')
        
        # Save temporary files
        cv2.imwrite(temp_mask_path, self.mask)
        cv2.imwrite(temp_image_path, cv2.cvtColor(self.original_image, cv2.COLOR_RGB2BGR))
        
        # Use resize_and_encode for the mask
        try:
            # Resize and encode the mask
            mask_base64 = resize_and_encode(temp_mask_path, (1280, 720))
            
            # Decode the base64 string and save the mask
            mask_data = base64.b64decode(mask_base64)
            with open(os.path.join(save_dir, 'mask.png'), 'wb') as f:
                f.write(mask_data)
                
            # Resize and encode the original image
            image_base64 = resize_and_encode(temp_image_path, (1280, 720))
            
            # Decode the base64 string and save the image
            image_data = base64.b64decode(image_base64)
            with open(os.path.join(save_dir, 'original.png'), 'wb') as f:
                f.write(image_data)
                
            # Remove temporary files
            os.remove(temp_mask_path)
            os.remove(temp_image_path)
            
            with self.output:
                print(f"Mask saved to: {os.path.join(save_dir, 'mask.png')}")
                print(f"Resized original image saved to: {os.path.join(save_dir, 'original.png')}")
                
        except Exception as e:
            with self.output:
                print(f"Error saving images: {str(e)}")

## Retrieve Text Content from S3 Bucket

This function reads and returns the text content from a file stored in an Amazon S3 bucket. It takes three parameters: the name of the S3 bucket, the file key (path within the bucket), and an initialized S3 client object. The function attempts to retrieve the specified file using the S3 client's get_object method, decodes the content as UTF-8 text, and returns it as a string. If any errors occur during this process, such as file not found or permission issues, the function catches the exception, prints an error message, and returns None. The function requires the boto3 library and properly configured AWS credentials, and assumes that an S3 client object has already been created and passed as an argument.

In [None]:
def read_s3_text(bucket_name, file_key, s3_client):
    """
    Read text content from a file stored in an Amazon S3 bucket.
    
    Args:
        bucket_name (str): The name of the S3 bucket.
        file_key (str): The key (path) of the file within the bucket.
        s3_client (boto3.client): An initialized S3 client.
        
    Returns:
        str or None: The text content of the file if successful, None if an error occurs.
        
    Example:
        >>> s3_client = boto3.client('s3')
        >>> content = read_s3_text('my-bucket', 'documents/text.txt', s3_client)
        >>> if content:
        ...     print(f"File content length: {len(content)} characters")
        
    Requires:
        - boto3 (with properly configured AWS credentials)
        - An initialized s3_client object
    
    Note:
        This function assumes an s3_client object has been created with boto3.client('s3')
        before calling this function.
    """
    try:
        response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
        text_content = response['Body'].read().decode('utf-8')
        return text_content
    except Exception as e:
        print(f"Error reading file from S3: {e}")
        return None
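
The success and failure paths can be exercised without AWS credentials by passing a stand-in client. In this sketch, `FakeS3Client` is a hypothetical in-memory class (not part of boto3), the bucket and key names are invented, and `read_s3_text` is restated in condensed form only so the block runs on its own.

```python
import io


class FakeS3Client:
    """Hypothetical in-memory stand-in for boto3.client('s3')."""

    def __init__(self, objects):
        self._objects = objects  # maps (bucket, key) -> bytes

    def get_object(self, Bucket, Key):
        # Mirrors the shape read_s3_text relies on: a dict with a readable 'Body'.
        return {"Body": io.BytesIO(self._objects[(Bucket, Key)])}


def read_s3_text(bucket_name, file_key, s3_client):
    # Condensed copy of the function above.
    try:
        response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
        return response["Body"].read().decode("utf-8")
    except Exception as e:
        print(f"Error reading file from S3: {e}")
        return None


client = FakeS3Client({("demo-bucket", "images/ref.b64"): b"aGVsbG8="})
found = read_s3_text("demo-bucket", "images/ref.b64", client)
missing = read_s3_text("demo-bucket", "no/such/key", client)
print(found, missing)
```

Since the function returns `None` instead of raising, callers should check the result before passing it on to the image generation request.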

## Generate images using inpainting capabilities

This function uses Amazon Bedrock's image generation models to perform inpainting, that is, replacing portions of an image with generated content. It modifies an existing image given a text prompt and exactly one of two mask inputs: a mask prompt (a text description of the area to modify) or a mask image (a binary image defining the area).
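
The shape of the request body can be inspected without calling Bedrock. The sketch below uses the same field names as the function that follows; `build_inpainting_body` is a hypothetical helper, and the prompts, seed, and image placeholder are made-up values.

```python
import json


def build_inpainting_body(image_b64, text, negative_text, seed,
                          mask_prompt=None, mask_image_b64=None,
                          num_images=1, cfg_scale=10, quality="standard"):
    """Hypothetical helper: assemble the JSON body without invoking the model."""
    params = {"text": text, "negativeText": negative_text, "image": image_b64}
    if mask_prompt and mask_image_b64:
        raise ValueError("Provide either mask_prompt or mask_image_b64, not both")
    if mask_prompt:
        params["maskPrompt"] = mask_prompt
    elif mask_image_b64:
        params["maskImage"] = mask_image_b64
    return json.dumps({
        "taskType": "INPAINTING",
        "inPaintingParams": params,
        "imageGenerationConfig": {
            "numberOfImages": num_images,
            "cfgScale": cfg_scale,
            "seed": seed,
            "quality": quality,
        },
    })


body = build_inpainting_body(
    image_b64="<base64 image placeholder>",  # stands in for real image data
    text="a wooden bench",
    negative_text="people",
    seed=42,
    mask_prompt="the red car",
)
parsed = json.loads(body)
print(json.dumps(parsed, indent=2))
```

Printing the body this way is a cheap way to confirm which mask field is being sent before spending a model invocation.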

In [None]:
def generate_image(
    image_base64_s3_key,
    image_s3_key,
    bucket,
    s3_client,
    bedrock_runtime_client,
    image_generation_model_id,
    output_image_path,
    text, 
    negative_text, 
    seed,
    mask_prompt=None,
    mask_image_path=None,
    num_images=1,
    cfg_scale=10,
    quality="standard",
    save_all=False,
    plot_results=True
):
    """
    Generate images using Amazon Bedrock's image generation model with inpainting capabilities.
    
    Args:
        image_base64_s3_key (str): S3 key for the base64-encoded reference image
        image_s3_key (str): S3 key for the original reference image (for display)
        bucket (str): S3 bucket name
        s3_client (boto3.client): S3 client used to read the reference image
        bedrock_runtime_client: Bedrock runtime client
        image_generation_model_id (str): ID of the Bedrock image generation model to invoke
        output_image_path (str): Path to save the generated image(s)
        text (str): Text prompt for generation
        negative_text (str): Negative text prompt
        seed (int): Seed for generation (0-858993459)
        mask_prompt (str, optional): Text prompt for mask generation
        mask_image_path (str, optional): Path to mask image
        num_images (int, optional): Number of images to generate (1-5)
        cfg_scale (int, optional): How closely to follow the prompt (default: 10)
        quality (str, optional): Either "standard" or "premium"
        save_all (bool, optional): Whether to save all generated images
        plot_results (bool, optional): Whether to plot the results

    Returns:
        list or None: The generated images as PIL Image objects, or None if the request fails.
    """

    # Read the base64-encoded reference image from S3
    reference_image_base64 = read_s3_text(bucket, image_base64_s3_key, s3_client)

    # Prepare inpainting parameters
    inpainting_params = {
        "text": text,
        "negativeText": negative_text,
        "image": reference_image_base64,
    }

    # Add either mask_prompt or mask_image_path; exactly one is expected
    if mask_prompt and mask_image_path:
        raise ValueError("Please provide either mask_prompt or mask_image_path, not both")
    elif mask_prompt:
        inpainting_params["maskPrompt"] = mask_prompt
    elif mask_image_path:
        # Read and encode the mask image
        with open(mask_image_path, "rb") as image_file:
            mask_image_base64 = base64.b64encode(image_file.read()).decode("utf-8")
        inpainting_params["maskImage"] = mask_image_base64
    else:
        # Without a mask there is no region to inpaint, so fail early
        raise ValueError("Please provide either mask_prompt or mask_image_path")

    # Prepare request body
    body = json.dumps({
        "taskType": "INPAINTING",
        "inPaintingParams": inpainting_params,
        "imageGenerationConfig": {
            "numberOfImages": num_images,
            "cfgScale": cfg_scale,
            "seed": seed,
            "quality": quality,
        },
    })

    print("Generating image(s)...")

    try:
        response = bedrock_runtime_client.invoke_model(
            body=body,
            modelId=image_generation_model_id,
            accept="application/json",
            contentType="application/json",
        )

        response_body = json.loads(response.get("body").read())
        base64_images = response_body.get("images")

        # Save images
        if save_all:
            # Save all generated images with numbered suffixes
            # os.path.splitext also handles paths without a file extension
            root, ext = os.path.splitext(output_image_path)
            for i, base64_img in enumerate(base64_images):
                numbered_path = f"{root}_{i+1}{ext}"
                save_image(base64_img, numbered_path)
                print(f"Image saved to {numbered_path}")
        else:
            # Save only the first image
            save_image(base64_images[0], output_image_path)
            print(f"Image saved to {output_image_path}")

        # Convert base64 images to PIL Images
        response_images = [
            Image.open(io.BytesIO(base64.b64decode(base64_image)))
            for base64_image in base64_images
        ]

        # Plot the images if requested
        if plot_results:
            plot_images(
                generated_images=response_images, 
                s3_client=s3_client,
                bucket_name=bucket, 
                key_name=image_s3_key
            )

        return response_images

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None

<div class="alert alert-success">
<b>🎉 Congratulations!</b> You have successfully completed the inpainting notebook!

Key accomplishments:
- ✅ Defined the helper functions that the next notebook uses

<b>Note:</b> For this section, no media is generated.
</div>
