In [None]:
import os
import io
import json
import random
import pandas as pd
from PIL import Image as Img
from IPython.display import display, Markdown, Image
import matplotlib.pyplot as plt

In [None]:
def convert_path(path):
    """Normalise ``path`` and spell it with forward slashes on every OS.

    Redundant separators and ``.``/``..`` segments are collapsed by
    ``os.path.normpath``; any platform-specific separator is then swapped for
    ``'/'`` so the result renders identically on Windows and POSIX.
    """
    normalised = os.path.normpath(path)
    return '/'.join(normalised.split(os.sep))

def resize_image(image, base_width=None, base_height=None):
    """Scale ``image`` proportionally to a target width or height.

    Only one dimension drives the scale: ``base_width`` wins when both are
    given; the other side follows the image's original aspect ratio.

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.resize``).
        base_width: Target width in pixels; ignored when falsy.
        base_height: Target height in pixels; used only when ``base_width``
            is falsy.

    Returns:
        A new resized image, or ``image`` itself when no target is given.
    """
    width, height = image.size
    if base_width:
        # round() instead of int(): the float product can land just below an
        # integer (1 / 49 * 98 == 1.9999999999999998) and truncating would drop
        # a pixel. max(1, ...) stops extreme aspect ratios from producing a
        # zero-sized target, which PIL's resize rejects.
        new_height = max(1, round(height * base_width / float(width)))
        return image.resize((base_width, new_height), Img.LANCZOS)
    if base_height:
        new_width = max(1, round(width * base_height / float(height)))
        return image.resize((new_width, base_height), Img.LANCZOS)
    return image

def combine_images(image_paths, base_width=180, max_images_per_row=5):
    """Tile image files into one RGB grid, left-to-right then top-to-bottom.

    Each image is first resized to ``base_width`` pixels wide (aspect ratio
    preserved via ``resize_image``). A row holds at most
    ``max_images_per_row`` images. Every row is as tall as the tallest resized
    image, so shorter images leave the canvas background below them.

    Args:
        image_paths: Paths of the image files to tile, in display order.
        base_width: Width in pixels every image is resized to.
        max_images_per_row: Maximum number of images placed on one row.

    Returns:
        A new RGB ``PIL.Image`` containing the grid.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    if not image_paths:
        raise ValueError("combine_images() needs at least one image path")

    images = []
    for path in image_paths:
        # Close each source file once it has been resized instead of leaking
        # one open file handle per image until garbage collection.
        with Img.open(path) as source:
            resized = resize_image(source, base_width=base_width)
            # resize_image hands back its input unchanged when no target size
            # applies; copy it so the pixel data outlives the closed source.
            images.append(resized.copy() if resized is source else resized)

    columns = min(len(images), max_images_per_row)
    num_rows = (len(images) + max_images_per_row - 1) // max_images_per_row
    row_height = max(img.size[1] for img in images)
    combined_image = Img.new('RGB', (columns * base_width, row_height * num_rows))
    for index, img in enumerate(images):
        row, col = divmod(index, max_images_per_row)
        # Every resized image is exactly base_width wide, so cells line up.
        combined_image.paste(img, (col * base_width, row * row_height))
    return combined_image

def display_image(img):
    """Render a PIL image inline in the notebook by encoding it as PNG bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()
    display(Image(data=png_bytes))

def create_label(bn_folder, mm_folder):
    """Build the ground-truth table mapping file names to class labels.

    Every entry that ``os.listdir`` returns for ``bn_folder`` is labelled
    ``'Benign'`` and every entry of ``mm_folder`` ``'Melanoma'``.

    Returns:
        DataFrame with columns ``Image`` and ``Label``, sorted by ``Image``.
        The integer index from construction order is kept, not reset.
    """
    sources = ((bn_folder, 'Benign'), (mm_folder, 'Melanoma'))
    records = [
        {'Image': name, 'Label': label}
        for folder, label in sources
        for name in os.listdir(folder)
    ]
    return pd.DataFrame(records).sort_values(by='Image')

def get_label(image_path, gt_df):
    """Look up the ground-truth label of an image by its file name.

    Args:
        image_path: Path (or bare file name) of the image; only its base name
            is matched against ``gt_df['Image']``.
        gt_df: DataFrame with ``Image`` and ``Label`` columns.

    Returns:
        The ``Label`` of the first row whose ``Image`` equals the base name.

    Raises:
        IndexError: If no row matches. The message names the missing image
            rather than numpy's bare "index 0 is out of bounds".
    """
    image_name = os.path.basename(image_path)
    matches = gt_df.loc[gt_df['Image'] == image_name, 'Label']
    if matches.empty:
        raise IndexError(f"No ground-truth label found for image {image_name!r}")
    return matches.iloc[0]


In [None]:
# Example usage: show randomly chosen saved responses together with the
# reference examples and the query image recorded in each result file.
k = 5        # Number of reference examples per category
n = 1        # Number of AI responses displayed per repetition
m = 1        # Number of repetitions displayed
reps = 5     # Total number of repetitions
process = None      # Optional 2-item sequence: both items go into the task name,
                    # the second also suffixes image_directory
knn = True          # Selects the "KNN" vs "Random" result folder name
prompt_version = 'v3.0'
seed = None         # Set to an int to make the sampled reps/files reproducible
query_ext = '.jpg'  # Extension of the query images in image_directory
res_dir = './result'
bn_dir = './data/bn_resized_label'
mm_dir = './data/mm_resized_label'
image_directory = './data/all_resized'

if seed is not None:
    random.seed(seed)

task = f'{k}_shot_{prompt_version}_{"KNN" if knn else "Random"}{("_" + process[0] + "_" + process[1]) if process else ""}'
image_directory = f'{image_directory}{("_" + process[1]) if process else ""}'
# min() keeps random.sample from raising when more reps are requested than exist.
save_dirs = [os.path.join(res_dir, task, f'rep{i}') for i in random.sample(range(1, reps + 1), min(m, reps))]
gt = create_label(bn_dir, mm_dir)


def display_examples(title, paths):
    """Show a markdown heading, then the images at ``paths`` tiled in a grid."""
    display(Markdown(f"### {title}:"))
    display_image(combine_images([convert_path(path) for path in paths]))


# Process and display results
for save_dir in save_dirs:
    json_files = [name for name in os.listdir(save_dir) if name.endswith('.json')]
    selected_files = random.sample(json_files, min(n, len(json_files)))
    for file_name in selected_files:
        json_file = os.path.join(save_dir, file_name)
        # Swap only the trailing extension; str.replace('json', 'jpg') would
        # also rewrite any "json" that appears inside the file stem.
        stem = os.path.splitext(file_name)[0]
        query_image_path = [os.path.join(image_directory, stem + query_ext)]

        with open(json_file, 'r') as file:
            data = json.load(file)

        # Display thoughts and answer as Markdown
        display(Markdown(f"#### Displaying the result: {convert_path(json_file)}"))
        display(Markdown(f"### Thoughts:\n{data['thoughts']}"))
        display(Markdown(f"### Answer:\n{data['answer']}"))

        # Display the reference examples of each category
        display_examples("Benign Examples", data['bn_examples'])
        display_examples("Melanoma Examples", data['mm_examples'])

        # Display query image with its ground-truth label
        query_label = get_label(query_image_path[0], gt)
        display(Markdown(f"### Query Image: ({query_label})"))
        query_image = combine_images(query_image_path)
        display_image(query_image)
