In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import numpy as np
from PIL import Image
import textwrap

def create_comparison_grid(images_dict, output_path, row_titles, num_methods):
    """
    Creates a grid of image comparisons with labels and arrows, loading and placing images directly.
    Images within each row will have no spacing between them.
    
    Args:
        images_dict: Dictionary containing paths to images for each row
        output_path: Path to save the final composite image
        row_titles: List of tuples containing (source_text, target_text) for each row
        num_methods: Number of methods/columns in the grid
    """
    # Calculate figure dimensions based on first row of images
    first_row_images = list(images_dict.values())[0]
    first_image = Image.open(first_row_images[0])
    img_width, img_height = first_image.size
    
    # Set up the figure with appropriate spacing
    # Scale figure size based on image dimensions and number of images
    fig_width = num_methods * (img_width / 145)  # Removed +1 to tighten spacing
    fig_height = len(row_titles) * (img_height / 110)
    fig = plt.figure(figsize=(fig_width, fig_height))
    
    height_ratios = [1] * len(row_titles)
    gs = gridspec.GridSpec(len(row_titles), 1, height_ratios=height_ratios, hspace=0.08, wspace=0.0)
    
    # For each row
    for row_idx, ((row_name, image_paths), (source_text, target_text)) in enumerate(zip(images_dict.items(), row_titles)):
        # Create subplot for this row
        ax = fig.add_subplot(gs[row_idx])
        
       # Wrap text if too long (adjust width as needed)
        wrapped_source = textwrap.fill(source_text, width=45)
        wrapped_target = textwrap.fill(target_text, width=45)
        
        if row_idx == 0:
            
            # Position text without creating the bar effect
            plt.text(0.02, 0.95, wrapped_source, transform=ax.transAxes, fontsize=20, 
                    verticalalignment='top', horizontalalignment='left',family='serif',style='italic',fontweight="bold")
            plt.text(0.48, 0.95, "→", transform=ax.transAxes, fontsize=20, 
                    horizontalalignment='center', verticalalignment='top',family='serif',style='italic',fontweight="bold")
            plt.text(0.5, 0.95, wrapped_target, transform=ax.transAxes, fontsize=20, 
                    verticalalignment='top', horizontalalignment='left',family='serif',style='italic',fontweight="bold")
        else:
            plt.text(0.02, 1.05, wrapped_source, transform=ax.transAxes, fontsize=20, 
                    verticalalignment='top', horizontalalignment='left',family='serif',style='italic',fontweight="bold")
            plt.text(0.48, 1.05, "→", transform=ax.transAxes, fontsize=20, 
                    horizontalalignment='center', verticalalignment='top',family='serif',style='italic',fontweight="bold")
            plt.text(0.5, 1.05, wrapped_target, transform=ax.transAxes, fontsize=20, 
                    verticalalignment='top', horizontalalignment='left',family='serif',style='italic',fontweight="bold")


        # # Position text without creating the bar effect
        # plt.text(0.02, 1.02, wrapped_source, transform=ax.transAxes, fontsize=20, 
        #         verticalalignment='top', horizontalalignment='left',family='serif',style='italic',fontweight="bold")
        # plt.text(0.45, 1.02, "→", transform=ax.transAxes, fontsize=20, 
        #         horizontalalignment='center', verticalalignment='top',family='serif',style='italic',fontweight="bold")
        # plt.text(0.5, 1.02, wrapped_target, transform=ax.transAxes, fontsize=20, 
        #         verticalalignment='top', horizontalalignment='left',family='serif',style='italic',fontweight="bold")


        # Place images with no spacing between them
        for i, img_path in enumerate(image_paths):
            if os.path.exists(img_path):
                # Calculate image position - removed spacing between images
                # if i==2 or i==4:
                #     spacing_reduction = 0.021
                # elif i==3:
                #     spacing_reduction = 0.018
                # else:
                #         spacing_reduction = 0.01
                spacing_reduction = 0.044
                img_pos = [i/num_methods - spacing_reduction * i, 0, 1/num_methods, 0.85]
                
                # Load and display image
                img = Image.open(img_path)
                img_array = np.array(img)
                
                # Create new axis for image within the subplot
                img_ax = plt.axes([ax.get_position().x0 + img_pos[0],
                                 ax.get_position().y0 + img_pos[1],
                                 ax.get_position().width * img_pos[2],
                                 ax.get_position().height * img_pos[3]])
                
                img_ax.imshow(img_array)
                img_ax.axis('off')
                
                # Add model labels below images only for the last row
                if row_idx == len(row_titles)-1:
                    labels = ['(a)\nSource Image', '(b)\nDDIM+P2P', '(c)\nDDIM+MasaCtrl', '(d)\nDDIM+Pix2Pix-Zero', '(e)\nDirectInversion+P2P', '(f)\nOurs+P2P']
                            #  '(c)\nSD2.1 Reconstruction', '(d)\nSD1.4 Edited Image',
                            #  '(e)\nSD2.1 Edited Image', '(f)\nOurs + P2P']
                    plt.text(0.5, -0.1, labels[i], ha='center', va='top',
                            transform=img_ax.transAxes, fontsize=15,fontweight="bold")
        
        ax.set_xticks([])
        ax.set_yticks([])
        # Set spines visibility
        for spine in ax.spines.values():
            spine.set_visible(False)
    
    plt.savefig(output_path, bbox_inches='tight',pad_inches=0.3,format='png')
    plt.close()

# Example usage FULL COMPARISION GRID:
if __name__ == "__main__":
    import os

    # Define the base directory paths
    base_annotation_path = "/home/abhi2358/code/project/PnPInversion"
    base_output_path = "/home/abhi2358/code/project/PnPInversion/output_images"

    # List of methods
    methods = [
        "1_ddim+p2p",
        "1_ddim+masactrl",
        "1_ddim+pix2pix-zero",
        "1_directinversion+p2p",
        "1_ddim+fgps+p2p"
    ]

    # List of subtypes
    subtypes = [
        "0_random_140",
        "1_change_object_80",
        # "2_add_object_80",
        "3_delete_object_80",
        "4_change_attribute_content_40",
        "7_change_attribute_material_40",
        # "8_change_background_80",
        # "9_change_style_80"
    ]
    
    specific_images = ["000000000018.jpg","122000000009.jpg",
                    #    "223000000002.jpg",
                      "324000000005.jpg", 
                      "412000000004.jpg", "712000000000.jpg"
                    #   "813000000004.jpg","922000000000.jpg"
                      ]

    # Read prompts from CSV
    import csv
    csv_file_path = "/home/abhi2358/code/project/PnPInversion/prompts.csv"
    row_titles = []

    base_path = []
    
    with open(csv_file_path, 'r', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(csvfile)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            row_titles.append((row[1], row[2]))  # original_prompt, editing_prompt
            base_path.append(row[0])
    
    # Generate rows and populate images_dict
    images_dict = {}
    for index, (subtype, image_path) in enumerate(zip(subtypes, specific_images)):
        # Path to the source image
        # if index ==2 or index>=6:
        #     index = index+1
        source_image_path = os.path.join(base_annotation_path, base_path[index])
        
        #find the "/" index and reove the first two from base_path
        immm = base_path[index][ base_path[index].find("/")+1:]
        immm = immm[immm.find("/")+1:]
        immm = immm[immm.find("/")+1:]

        
        # Generate paths for each method
        method_paths = []
        for method in methods:
            method_image_path = os.path.join(base_output_path, method, 
                                           "annotation_images", subtype, immm)
            method_paths.append(method_image_path)
        
        # Combine source image path and method paths
        #add base_annotation image path to base_path

        images_dict[f"row{index + 1}"] = [source_image_path] + method_paths

    
    
    # Create the grid with images
    create_comparison_grid(images_dict=images_dict, 
                         output_path="final_comparison_grid.png",
                         row_titles=row_titles,
                         num_methods=6)



In [None]:


# Example usage SD2:
if __name__ == "__main__":
    import os

    # Define the base directory paths
    base_annotation_path = "/home/abhi2358/code/project/PnPInversion"
    base_output_path_tgt_sd = "/home/abhi2358/code/project/PnPInversion/output_images_tgt1"
    base_output_path_src_sd = "/home/abhi2358/code/project/PnPInversion/output_images_src1"
    
    base_output_path_tgt_final = "/home/abhi2358/code/project/PnPInversion/output_images_tgt2"
    base_output_path_src_final = "/home/abhi2358/code/project/PnPInversion/output_images_src2"
    

    # List of methods
    methods = [
        "1_ddim+p2p",
        "1_ddim+masactrl",
        "1_ddim+pix2pix-zero",
        "1_directinversion+p2p",
    ]

    # List of subtypes
    subtypes = [
        "0_random_140",
        "5_change_attribute_pose_40",
        "9_change_style_80",
        "0_random_140"
    ]
    
    specific_images = ["000000000026.jpg","522000000003.jpg","911000000009.jpg","000000000015.jpg"]

    # Read prompts from CSV
    import csv
    csv_file_path = "/home/abhi2358/code/project/PnPInversion/prompts2.csv"
    row_titles = []

    base_path = []
    
    with open(csv_file_path, 'r', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(csvfile)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            row_titles.append((row[1], row[2]))  # original_prompt, editing_prompt
            base_path.append(row[0])
    
    # Generate rows and populate images_dict
    images_dict = {}
    for index, (subtype, image_path) in enumerate(zip(subtypes, specific_images)):
        # Path to the source image
        source_image_path = os.path.join(base_annotation_path, base_path[index])
        
        #find the "/" index and reove the first two from base_path
        immm = base_path[index][ base_path[index].find("/")+1:]
        immm = immm[immm.find("/")+1:]
        immm = immm[immm.find("/")+1:]

        
        # Generate paths for each method
        method_paths = []
        # for method in methods:
        method_image_path = os.path.join(base_output_path_src_final, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)
        # for method in methods:
        method_image_path = os.path.join(base_output_path_src_sd, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)
        
        # for method in methods:
        method_image_path = os.path.join(base_output_path_tgt_final, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)
        # for method in methods:
        method_image_path = os.path.join(base_output_path_tgt_sd, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)

        # Combine source image path and method paths
        #add base_annotation image path to base_path

        images_dict[f"row{index + 1}"] = [source_image_path] + method_paths

    
    
    # Create the grid with images
    create_comparison_grid(images_dict=images_dict, 
                         output_path="sd.pdf",
                         row_titles=row_titles,
                         num_methods=5)

In [None]:

# Example usage DDIM vs ours:
if __name__ == "__main__":
    import os

    # Define the base directory paths
    base_annotation_path = "/home/abhi2358/code/project/PnPInversion"
    base_output_path_tgt = "/home/abhi2358/code/project/PnPInversion/output_images_tgt"
    base_output_path_src = "/home/abhi2358/code/project/PnPInversion/output_images_src"
    

    # List of methods
    methods = [
        "1_ddim+p2p",
        "1_ddim+fgps+p2p"
    ]

    # List of subtypes
    subtypes = [
        "0_random_140",
        "0_random_140",
        "0_random_140",
        "0_random_140"
    ]
    
    specific_images = ["000000000002.jpg","000000000005.jpg","000000000016.jpg","000000000050.jpg"]

    # Read prompts from CSV
    import csv
    csv_file_path = "/home/abhi2358/code/project/PnPInversion/prompts1.csv"
    row_titles = []

    base_path = []
    
    with open(csv_file_path, 'r', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(csvfile)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            row_titles.append((row[1], row[2]))  # original_prompt, editing_prompt
            base_path.append(row[0])
    
    # Generate rows and populate images_dict
    images_dict = {}
    for index, (subtype, image_path) in enumerate(zip(subtypes, specific_images)):
        # Path to the source image
        source_image_path = os.path.join(base_annotation_path, base_path[index])
        
        #find the "/" index and reove the first two from base_path
        immm = base_path[index][ base_path[index].find("/")+1:]
        immm = immm[immm.find("/")+1:]
        immm = immm[immm.find("/")+1:]

        
        # Generate paths for each method
        method_paths = []
        for method in methods:
            method_image_path = os.path.join(base_output_path_src, method, 
                                           "annotation_images", subtype, immm)
            method_paths.append(method_image_path)
        for method in methods:
            method_image_path = os.path.join(base_output_path_tgt, method, 
                                           "annotation_images", subtype, immm)
            method_paths.append(method_image_path)
        
        # Combine source image path and method paths
        #add base_annotation image path to base_path

        images_dict[f"row{index + 1}"] = [source_image_path] + method_paths

    
    
    # Create the grid with images
    create_comparison_grid(images_dict=images_dict, 
                         output_path="cmp.pdf",
                         row_titles=row_titles,
                         num_methods=5)

In [180]:
row_titles

[('a cat sitting on a wooden chair', 'a dog sitting on a wooden chair'),
 ('an orange cat sitting on top of a fence',
  'an black cat sitting on top of a fence'),
 ('a plate with steak on it', 'a plate with salmon on it'),
 ('a bird standing on clods', 'a bird standing on eggs')]

In [None]:
# Example usage FULL COMPARISION GRID:
if __name__ == "__main__":
    import os

    # Define the base directory paths
    base_annotation_path = "/home/abhi2358/code/project/PnPInversion"
    base_output_path = "/home/abhi2358/code/project/PnPInversion/output_images"

    # List of methods
    methods = [
        "1_ddim+p2p",
        "1_ddim+masactrl",
        "1_ddim+pix2pix-zero",
        "1_directinversion+p2p",
        "1_ddim+fgps+p2p"
    ]

    # List of subtypes
    subtypes = [
        "0_random_140",
        "1_change_object_80",
        "2_add_object_80",
        "3_delete_object_80",
        "4_change_attribute_content_40",
        "7_change_attribute_material_40",
        "8_change_background_80",
        "9_change_style_80"
    ]
    
    specific_images = ["000000000018.jpg","122000000009.jpg","223000000002.jpg",
                      "324000000005.jpg", "412000000004.jpg", "712000000000.jpg",
                      "813000000004.jpg","922000000000.jpg"]

    # Read prompts from CSV
    import csv
    csv_file_path = "/home/abhi2358/code/project/PnPInversion/prompts.csv"
    row_titles = []

    base_path = []
    
    with open(csv_file_path, 'r', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(csvfile)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            row_titles.append((row[1], row[2]))  # original_prompt, editing_prompt
            base_path.append(row[0])
    
    # Generate rows and populate images_dict
    images_dict = {}
    for index, (subtype, image_path) in enumerate(zip(subtypes, specific_images)):
        # Path to the source image
        source_image_path = os.path.join(base_annotation_path, base_path[index])
        
        #find the "/" index and reove the first two from base_path
        immm = base_path[index][ base_path[index].find("/")+1:]
        immm = immm[immm.find("/")+1:]
        immm = immm[immm.find("/")+1:]

        
        # Generate paths for each method
        method_paths = []
        for method in methods:
            method_image_path = os.path.join(base_output_path, method, 
                                           "annotation_images", subtype, immm)
            method_paths.append(method_image_path)
        
        # Combine source image path and method paths
        #add base_annotation image path to base_path

        images_dict[f"row{index + 1}"] = [source_image_path] + method_paths

    
    
    # Create the grid with images
    create_comparison_grid(images_dict=images_dict, 
                         output_path="final_comparison_grid.pdf",
                         row_titles=row_titles,
                         num_methods=6)

In [None]:

# Example usage SD2:
if __name__ == "__main__":
    import os

    # Define the base directory paths
    base_annotation_path = "/home/abhi2358/code/project/PnPInversion"
    base_output_path_tgt_sd = "/home/abhi2358/code/project/PnPInversion/output_images_tgt1"
    base_output_path_src_sd = "/home/abhi2358/code/project/PnPInversion/output_images_src1"
    
    base_output_path_tgt_final = "/home/abhi2358/code/project/PnPInversion/output_images_tgt2"
    base_output_path_src_final = "/home/abhi2358/code/project/PnPInversion/output_images_src2"
    

    # List of methods
    methods = [
        "1_ddim+p2p",
        "1_ddim+masactrl",
        "1_ddim+pix2pix-zero",
        "1_directinversion+p2p",
    ]

    # List of subtypes
    subtypes = [
        "0_random_140",
        "5_change_attribute_pose_40",
        "9_change_style_80",
        "0_random_140"
    ]
    
    specific_images = ["000000000026.jpg","522000000003.jpg","911000000009.jpg","000000000015.jpg"]

    # Read prompts from CSV
    import csv
    csv_file_path = "/home/abhi2358/code/project/PnPInversion/prompts2.csv"
    row_titles = []

    base_path = []
    
    with open(csv_file_path, 'r', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(csvfile)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            row_titles.append((row[1], row[2]))  # original_prompt, editing_prompt
            base_path.append(row[0])
    
    # Generate rows and populate images_dict
    images_dict = {}
    for index, (subtype, image_path) in enumerate(zip(subtypes, specific_images)):
        # Path to the source image
        source_image_path = os.path.join(base_annotation_path, base_path[index])
        
        #find the "/" index and reove the first two from base_path
        immm = base_path[index][ base_path[index].find("/")+1:]
        immm = immm[immm.find("/")+1:]
        immm = immm[immm.find("/")+1:]

        
        # Generate paths for each method
        method_paths = []
        # for method in methods:
        method_image_path = os.path.join(base_output_path_src_final, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)
        # for method in methods:
        method_image_path = os.path.join(base_output_path_src_sd, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)
        
        # for method in methods:
        method_image_path = os.path.join(base_output_path_tgt_final, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)
        # for method in methods:
        method_image_path = os.path.join(base_output_path_tgt_sd, methods[index], 
                                        "annotation_images", subtype, immm)
        method_paths.append(method_image_path)

        # Combine source image path and method paths
        #add base_annotation image path to base_path

        images_dict[f"row{index + 1}"] = [source_image_path] + method_paths

    
    
    # Create the grid with images
    create_comparison_grid(images_dict=images_dict, 
                         output_path="sd.png",
                         row_titles=row_titles,
                         num_methods=5)

In [224]:
images_dict

{'row1': ['/home/abhi2358/code/project/PnPInversion/data/annotation_images/0_random_140/000000000015.jpg',
  '/home/abhi2358/code/project/PnPInversion/output_images_src2/1_ddim+p2p/annotation_images/0_random_140/000000000015.jpg',
  '/home/abhi2358/code/project/PnPInversion/output_images_src1/1_ddim+p2p/annotation_images/0_random_140/000000000015.jpg',
  '/home/abhi2358/code/project/PnPInversion/output_images_tgt2/1_ddim+p2p/annotation_images/0_random_140/000000000015.jpg',
  '/home/abhi2358/code/project/PnPInversion/output_images_tgt1/1_ddim+p2p/annotation_images/0_random_140/000000000015.jpg'],
 'row2': ['/home/abhi2358/code/project/PnPInversion/data/annotation_images/0_random_140/000000000026.jpg',
  '/home/abhi2358/code/project/PnPInversion/output_images_src2/1_ddim+masactrl/annotation_images/5_change_attribute_pose_40/000000000026.jpg',
  '/home/abhi2358/code/project/PnPInversion/output_images_src1/1_ddim+masactrl/annotation_images/5_change_attribute_pose_40/000000000026.jpg',
  '