In [None]:
import os

# Keras reads KERAS_BACKEND once, when the package is first imported, so the
# variable must be set BEFORE `import keras` / `import keras_nlp` below.
# Setting it afterwards has no effect and the default backend is used instead.
os.environ["KERAS_BACKEND"] = "jax"

import glob
import io
import re

import keras
import keras_nlp
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import PIL
import requests
from PIL import Image
from tqdm import tqdm

# Use bfloat16 as the Keras default float type; this has to happen before the
# models are loaded in the following cells.
keras.config.set_floatx("bfloat16")

In [None]:
# Load the pre-trained PaliGemma model from its KerasNLP preset.
# generate_captions() uses it to caption each rendered view individually.
# NOTE(review): the "224" in the preset name presumably means 224x224 inputs,
# which would match the default target_size of read_image() - confirm.
paligemma = keras_nlp.models.PaliGemmaCausalLM.from_preset("pali_gemma_3b_mix_224")

In [None]:
# Load the pre-trained Gemma 2 model from its KerasNLP preset.
# fuse_captions() prompts it to merge the per-view captions of one object
# into a single caption.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")

In [None]:
# Input and output locations used by generate_captions():
#   inputDir  - root folder; view N is read from
#               <inputDir>/rendered_imgs_view<N>/*.png
#   outputDir - per-view caption files are written to
#               <outputDir>/2D_captions_view<N>.txt
inputDir = '/kaggle/input/test-rendered-imgs/rendered_imgs'
outputDir = '/kaggle/working/2D_captions_prompt'

def crop_and_resize(image, target_size):
    """
    Centre-crops an image to a square and scales it to the requested size.

    The side of the square is the shorter of the image's two dimensions, so
    only the longer dimension is trimmed.

    Args:
        image (PIL.Image): The input image.
        target_size (tuple): Output size, forwarded unchanged to
            ``image.resize`` (PIL interprets it as (width, height)).

    Returns:
        PIL.Image: The cropped and resized image.
    """
    img_w, img_h = image.size
    side = min(img_w, img_h)

    # Top-left corner of the centred square, using integer arithmetic.
    x0 = img_w // 2 - side // 2
    y0 = img_h // 2 - side // 2
    crop_box = (x0, y0, x0 + side, y0 + side)

    square = image.crop(crop_box)
    return square.resize(target_size)

def read_image(filepath, target_size=(224, 224)):
    """
    Reads an image from a local file, centre-crops/resizes it and returns RGB pixels.

    Every image mode is normalised to three-channel RGB after resizing, so
    grayscale, palette and alpha-carrying images all come back with the same
    layout (an alpha channel is dropped).

    Args:
        filepath (str): Path to the image file.
        target_size (tuple): Output size passed on to crop_and_resize
            (PIL interprets it as (width, height)).

    Returns:
        numpy.ndarray | None: Array of shape (height, width, 3), or None if
        the file could not be read (the error is printed).
    """
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once the pixels have been copied out.
        with Image.open(filepath) as source:
            resized = crop_and_resize(source, target_size)
            # Convert after resizing so RGBA inputs give the same pixels as
            # slicing off the alpha channel, while 2-D modes (L, P, ...) gain
            # the channel axis the model input needs.
            return np.array(resized.convert("RGB"))
    except Exception as e:
        # Best effort: report the unreadable file and let the caller skip it.
        print(f"Error reading {filepath}: {str(e)}")
        return None

def clean_output_folder(output_folder):
    """
    Cleans the output folder by removing the files directly inside it.

    Sub-directories are left untouched (and reported), because ``os.remove``
    cannot delete a directory and would otherwise abort the cleanup midway.
    A folder that does not exist yet is treated as already clean.

    Args:
        output_folder (str): Path to the folder to be cleaned.
    """
    for path in glob.glob(f"{output_folder}/*"):
        # Symlinks are removed as links, even when they point at a directory.
        if os.path.isdir(path) and not os.path.islink(path):
            print(f"Skipped directory {path}")
            continue
        os.remove(path)
        print(f"Deleted {path}")

def generate_captions(view_number, target_size=(224, 224)):
    """
    Generates captions for images in a specific view folder.

    Every ``*.png`` in ``<inputDir>/rendered_imgs_view<view_number>/`` is
    captioned with PaliGemma and appended to
    ``<outputDir>/2D_captions_view<view_number>.txt`` as one
    ``"<image name> : <caption>"`` record per line. Images that cannot be
    read are skipped.

    Args:
        view_number (int): The view number of the images.
        target_size (tuple): Size the images are resized to before captioning
            (forwarded to read_image).
    """
    outfilename = f'{outputDir}/2D_captions_view{view_number}.txt'
    infolder = f'{inputDir}/rendered_imgs_view{view_number}/*.png'

    # The glob pattern already restricts matches to .png; sorting makes the
    # processing (and therefore the output) order reproducible.
    all_imgs = sorted(glob.glob(infolder))
    print(f"Number of .png images: {len(all_imgs)}")

    # Loop-invariant setup: make sure the folder and the per-view file exist,
    # even when no image ends up being captioned.
    os.makedirs(outputDir, exist_ok=True)
    with open(outfilename, 'a'):
        pass

    caption_prompt = 'caption en.\n'

    for filename in tqdm(all_imgs):
        current_image = read_image(filename, target_size)
        if current_image is None:
            continue

        output = paligemma.generate(
            inputs={
                "images": current_image,
                "prompts": caption_prompt,
            },
            max_length=300,
        )
        # Drop the echoed prompt and collapse all whitespace, so a caption
        # containing line breaks still occupies exactly one line of the file.
        output = " ".join(output.removeprefix(caption_prompt).split())
        print(output)

        # splitext removes only the extension; split('.')[0] would truncate
        # names that contain dots.
        image_name = os.path.splitext(os.path.basename(filename))[0]

        # Append per image so finished captions survive an interrupted run.
        with open(outfilename, 'a+') as f:
            f.write(f"{image_name} : {output}\n")

def read_captions_files(file_paths):
    """
    Reads captions from multiple text files and groups them by the base image name.

    Each line is expected to look like ``"<image name> : <caption>"``. The base
    image name is the image name with its last ``_<suffix>`` part removed.
    Blank lines and lines without a ``": "`` separator are ignored, and caption
    files that do not exist are skipped with a message.

    Args:
        file_paths (list): List of paths to the caption files.

    Returns:
        dict: A dictionary where keys are base image names and values are lists of captions.
    """
    captions = {}
    for file_path in file_paths:
        if not os.path.isfile(file_path):
            print(f"Skipping missing captions file: {file_path}")
            continue

        with open(file_path, 'r') as f:
            for line in f:
                # Split on the first ": " only, so colons inside the caption
                # text are preserved.
                image_name, separator, caption = line.partition(": ")
                # The writer emits "name : caption", which leaves a trailing
                # space on the name part; strip it before deriving the key.
                image_name = image_name.strip()
                if not separator or not image_name:
                    continue

                base_image_name = image_name.rsplit('_', 1)[0]
                captions.setdefault(base_image_name, []).append(caption.strip())
    return captions

def clean_fused_caption(caption, prompt_text, caption_list):
    """
    Clean up the generated caption by removing the echoed prompt and the
    chat-template tags, leaving only the model's answer.

    The original per-view captions are embedded in the prompt, so removing the
    prompt already removes them. They are deliberately NOT stripped from the
    remaining text: doing so would mutilate an answer that legitimately reuses
    an input caption (e.g. "a chair with four legs" -> "with four legs").

    Args:
        caption (str): The fused caption returned by the model.
        prompt_text (str): The original prompt used in the caption generation.
        caption_list (list): The list of original image captions. Kept for
            interface compatibility; not used for the cleanup.

    Returns:
        str: The cleaned fused caption.
    """
    # Match the prompt without surrounding whitespace: callers strip the model
    # output, so the prompt's trailing newline may be missing when the model
    # produced nothing after it.
    prompt_core = prompt_text.strip()
    if prompt_core:
        caption = caption.replace(prompt_core, "")

    # Remove any chat-template tags left in the answer.
    for tag in ("<start_of_turn>user", "<start_of_turn>model", "<end_of_turn>"):
        caption = caption.replace(tag, "")

    # Return the cleaned caption
    return caption.strip()

def fuse_captions(captions, max_length=300):
    """
    Fuses multiple captions into a single concise caption for each image.

    For every object the per-view captions are joined with "; ", wrapped in a
    Gemma chat-template prompt and sent to ``gemma_lm``; the raw output is then
    passed through clean_fused_caption.

    Args:
        captions (dict): A dictionary of image names and their corresponding list of captions.
        max_length (int): Length limit forwarded to ``gemma_lm.generate``.
            Defaults to 300, the previously hard-coded value.
            NOTE(review): in KerasNLP this limit appears to cover the prompt
            plus the generated tokens, so long caption lists may leave little
            room for the answer - confirm and raise if needed.

    Returns:
        dict: A dictionary of image names and their fused captions.
    """
    start_of_turn_user = "<start_of_turn>user\n"
    start_of_turn_model = "<start_of_turn>model\n"
    end_of_turn = "<end_of_turn>\n"

    # Loop-invariant parts of the prompt, built once.
    prompt_head = start_of_turn_user + \
        "Given a set of descriptions about the same 3D object, distill these descriptions into one coherent caption. Pay attention to all necessary details. Do not provide any explanation. The descriptions are as follows: '"
    prompt_tail = "'." + end_of_turn + start_of_turn_model

    fused_captions = {}

    for image_name, caption_list in tqdm(captions.items()):
        # Create the text to be passed to the model
        text = "; ".join(caption_list)
        prompt = prompt_head + text + prompt_tail

        # Generate the fused caption (the output echoes the prompt)
        fused_caption = gemma_lm.generate(prompt, max_length=max_length).strip()

        # Clean the fused caption by removing the prompt and template tags
        cleaned_caption = clean_fused_caption(fused_caption, prompt, caption_list)

        # Store the cleaned caption
        fused_captions[image_name] = cleaned_caption

    return fused_captions

In [None]:
# Clean the output folder before generating new captions.
# generate_captions() appends to its per-view files, so leftovers from an
# earlier run would otherwise end up duplicated in the output.
clean_output_folder(outputDir)

# Generate captions for images in all views (views 0-7).
# The view count must match the range used when the caption files are read back.
for i in range(8):
    generate_captions(i)

In [None]:
# Read captions from the generated files: one file per view (views 0-7),
# using the same naming scheme that generate_captions() writes.
file_paths = [f"{outputDir}/2D_captions_view{i}.txt" for i in range(8)]
# Maps each base image name to the list of its per-view captions.
captions = read_captions_files(file_paths)

# Fuse the captions into one caption per base image name
fused_captions = fuse_captions(captions)

In [None]:
# Print the final fused captions, one object per block, with a 40-dash
# separator line after each.
for image_name, fused_caption in fused_captions.items():
    print(f"Object: {image_name}\nFused Caption: {fused_caption}\n" + "-"*40 + "\n")

In [None]:
# Print and save the final fused captions, one "<object> : <caption>" record
# per line.
# The file name must not contain a backslash: '\.' is not a valid escape
# sequence, so Python would keep the backslash as part of the name.
# The file is opened in write mode because fused_captions already holds every
# object; appending would duplicate all records each time the cell is re-run.
with open('/kaggle/working/3D_captions_prompt.txt', 'w') as f:
    for object_name, fused_caption in fused_captions.items():
        print(f"Object: {object_name}\nFused Caption: {fused_caption}\n" + "-"*40 + "\n")
        f.write(f"{object_name} : {fused_caption}\n")