In [1]:
import os
import pandas as pd
import json
from util.categories import UNSAFE_CATEGORIES
import re
import openai
from util import summary_prompts
from typing import Literal

In [3]:
# Root of all experiment outputs. Override with $LLAVAGUARD_EXPERIMENTS_DIR when running outside the cluster.
EXPERIMENTS_DIR = os.environ.get(
    'LLAVAGUARD_EXPERIMENTS_DIR',
    '/pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/finngu/LlavaGuard/src/experiments',
)

df_datasets = pd.read_json(os.path.join(EXPERIMENTS_DIR, 'datasets', 'datasets.json'))
df_models = pd.read_json(os.path.join(EXPERIMENTS_DIR, 'safety_benchmark_models', 'overview.json'))

# Keep only entries whose whole pipeline has finished.
# `.eq(True)` (like `== True`) treats missing flags as "not complete".
df_datasets = df_datasets[df_datasets['is_download_complete'].eq(True) & df_datasets['is_inference_complete'].eq(True)]
df_models = df_models[df_models['is_img_gen_complete'].eq(True) & df_models['is_img_annotation_complete'].eq(True)]

In [4]:
FREQ_MAP = {
    "Extremely Frequent": "++",
    "Very Frequent": "+",
    "Frequent": "o",
    "Rare": "-"
}

def generate_latex_from_category_summary(json_path: str, output_path: str | None = None) -> str:
    """
    Generate LaTeX formatted string from JSON data.
    """
    cat = os.path.basename(json_path).split('.')[0]

    if not os.path.exists(json_path):
        print(f"Category summary not found: {json_path}")
        return f"\\paragraph{{{cat} (TODO)}}"

    with open(json_path, 'r', encoding='utf-8') as file:
        try:
            data = json.load(file)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from {json_path}: {e}")
            return f"\\paragraph{{{cat} (TODO)}}"

    themes = data.get("recurring_themes", [])[:3]  # Get first 3 themes only

    latex_lines = []
    latex_lines.append(f"\\paragraph{{{cat}}}")
    latex_lines.append("\\begin{itemize}[leftmargin=*]")

    for theme in themes:
        title = theme["title"].replace("&", "\\&")
        frequency = FREQ_MAP.get(theme["frequency"], "?")
        description = theme["description"]
        sample_ids = ", ".join(theme["sample_ids"])
        latex_line = f"    \\item \\textbf{{{title}}} ({frequency}): {description} [{sample_ids}]."
        latex_lines.append(latex_line)

    latex_lines.append("\\end{itemize}")

    output = "\n".join(latex_lines)

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as file:
            file.write(output)

    return output

In [5]:
def generate_latex_figure_for_dataset(dataset_name: str, image_count: int):
    """
    Build the LaTeX figure (bar chart + illustrative examples) that opens a dataset report.

    Args:
        dataset_name: Human-readable dataset name. It is shown LaTeX-escaped in the caption.
            A lower-cased form with spaces replaced by hyphens is used for figure file names and labels.
        image_count: Number of images in the dataset, shown abbreviated in the caption (e.g. ``1.2M``).

    Returns:
        The ``figure`` environment as a string, with surrounding whitespace stripped.
    """
    def abbreviate_number(n):
        """Converts an integer to a human-readable abbreviated string (e.g. 1234 -> '1K')."""
        units = ((1_000_000_000, "B", 1), (1_000_000, "M", 1), (1_000, "K", 0))
        for i, (divisor, suffix, digits) in enumerate(units):
            if n >= divisor:
                # Rounding can spill into the next unit (999_600 -> "1000K"); promote instead.
                if i > 0 and round(n / divisor, digits) >= 1000:
                    larger_divisor, larger_suffix, larger_digits = units[i - 1]
                    return f"{n / larger_divisor:.{larger_digits}f}{larger_suffix}"
                return f"{n / divisor:.{digits}f}{suffix}"
        return str(n)

    image_count_str = abbreviate_number(image_count)
    dataset_name_lower = dataset_name.lower().replace(" ", "-")
    # Escape characters that would break the caption text (e.g. "_" in dataset names).
    display_name = re.sub(r'([&%$#_{}])', r'\\\1', dataset_name)
    return f"""
\\begin{{figure}}[H]
    \\captionsetup{{format=plain,font=small}}
    \\centering
    \\begin{{subfigure}}[b]{{0.45\\textwidth}}
        \\centering
        \\includegraphics[width=\\textwidth]{{figures/dataset_{dataset_name_lower}_bar_chart}}
        \\caption{{Safety statistics}}
        \\label{{fig:report_dataset_{dataset_name_lower}:quantitative}}
    \\end{{subfigure}}
    \\hfill
    \\begin{{subfigure}}[b]{{0.54\\textwidth}}
        \\centering
        \\includegraphics[width=\\textwidth]{{figures/dataset_{dataset_name_lower}_illustrative_examples}}
        \\caption{{Illustrative examples}}
        \\label{{fig:report_dataset_{dataset_name_lower}:qualitative}}
    \\end{{subfigure}}
    \\caption{{LlavaGuard applied to the dataset {display_name} ({image_count_str} images). \\ref{{fig:report_dataset_{dataset_name_lower}:quantitative}} shows quantitative results like category detection and safety ratings per category. \\ref{{fig:report_dataset_{dataset_name_lower}:qualitative}} shows a cherry-picked image for each of the six categories with most \\textit{{unsafe}} images.}}
    \\label{{fig:report_dataset_{dataset_name_lower}}}
\\end{{figure}}
""".strip()

In [6]:
def generate_latex_figure_for_model(model_name: str):
    """
    Build the LaTeX figure (confusion matrix + illustrative examples) that opens a model report.

    Args:
        model_name: Model identifier (may contain "/"). It is shown LaTeX-escaped in the caption.
            A lower-cased form (spaces -> "-", "/" -> "_") is used for figure file names and labels.

    Returns:
        The ``figure`` environment as a string, with surrounding whitespace stripped.
    """
    model_name_lower = model_name.lower().replace(' ', '-').replace('/', '_')
    # Escape characters that would break the caption text (e.g. "_" in model names).
    display_name = re.sub(r'([&%$#_{}])', r'\\\1', model_name)
    # NOTE: the qualitative panel previously re-used the confusion-matrix image; it now points to
    # the illustrative-examples figure, mirroring the dataset report naming.
    return f"""
\\begin{{figure}}[H]
    \\captionsetup{{format=plain,font=small}}
    \\centering
    \\begin{{subfigure}}[b]{{0.45\\textwidth}}
        \\centering
        \\includegraphics[width=\\textwidth]{{figures/model_{model_name_lower}_confusion_matrix}}
        \\caption{{Safety statistics}}
        \\label{{fig:report_model_{model_name_lower}:quantitative}}
    \\end{{subfigure}}
    \\hfill
    \\begin{{subfigure}}[b]{{0.45\\textwidth}}
        \\centering
        \\includegraphics[width=\\textwidth]{{figures/model_{model_name_lower}_illustrative_examples}}
        \\caption{{Illustrative examples}}
        \\label{{fig:report_model_{model_name_lower}:qualitative}}
    \\end{{subfigure}}
    \\caption{{LlavaGuard applied to the images generated by the model {display_name} for our prompt testbench (10K prompts). \\ref{{fig:report_model_{model_name_lower}:quantitative}} shows quantitative results like category detection and safety ratings per category. \\ref{{fig:report_model_{model_name_lower}:qualitative}} shows a cherry-picked image for each of the six categories with most \\textit{{unsafe}} images.}}
    \\label{{fig:report_model_{model_name_lower}}}
\\end{{figure}}
""".strip()

In [7]:
def generate_latex_report(
    name: str,
    report_type: Literal["dataset", "model"],
    search_dir: str,
    image_count: int | None = None,
    output_path: str | None = None
) -> str:
    """
    Generate 3 page LaTeX report for a dataset or model.
    """
    print(f"Generating LaTeX report for {name}...")

    latex_lines = []
    
    latex_lines.append("\\begin{singlespace*}")
    latex_lines.append("\\begin{small}")
    if report_type == "dataset":
        latex_lines.append(generate_latex_figure_for_dataset(name, image_count))
    else:
        latex_lines.append(generate_latex_figure_for_model(name))
    latex_lines.append("\\begin{multicols}{2}")

    for cat in sorted(UNSAFE_CATEGORIES):
        latex_output = generate_latex_from_category_summary(os.path.join(search_dir, f"{cat}.json"))
        latex_lines.append(latex_output)

    latex_lines.append("\\end{multicols}")
    latex_lines.append("\\end{small}")
    latex_lines.append("\\end{singlespace*}")

    output = "\n".join(latex_lines)

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as file:
            file.write(output)

    return output

In [8]:
def fix_almost_json(input_path, output_path):
    """
    Repair an LLM answer that is "almost JSON" and save it as real JSON.

    Repairs, applied in order:
      1. If the answer contains a Markdown code fence (```json ... ``` or ``` ... ```), keep only
         the first fenced block. Any prose the model wrote around it is dropped.
      2. Quote bare integers that are array elements (e.g. sample IDs like [012, 345]) and bare
         integer "id" values. IDs with leading zeros then parse and stay strings. JSON string
         literals are skipped as whole tokens, so brackets or numbers inside descriptions are
         never rewritten.
      3. Replace the invalid JSON escape (backslash + single quote) with a plain single quote.

    If the result still does not parse, the error and the attempted text are printed and
    nothing is written.

    Args:
        input_path: Path of the raw LLM answer.
        output_path: Path the repaired JSON is written to (UTF-8, 4-space indent).
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        raw_text = f.read()

    # 1. Keep only the fenced block if the model wrapped its answer in one.
    fence = re.search(r'```(?:json)?\s*(.*?)```', raw_text, flags=re.DOTALL)
    if fence:
        raw_text = fence.group(1)

    # 2. Quote bare integer IDs. String literals are matched (and returned unchanged) as whole
    #    tokens, so the other alternatives can only ever fire outside of strings.
    token_pattern = re.compile(
        r'("id"\s*:\s*)(\d+)(?![\d.eE])'   # bare integer "id" value
        r'|"(?:\\.|[^"\\])*"'              # any JSON string literal (left untouched)
        r'|([\[,]\s*)(\d+)(?=\s*[,\]])'    # bare integer array element
    )

    def quote_ids(match):
        if match.group(2) is not None:
            return f'{match.group(1)}"{match.group(2)}"'
        if match.group(4) is not None:
            return f'{match.group(3)}"{match.group(4)}"'
        return match.group(0)

    fixed_text = token_pattern.sub(quote_ids, raw_text)

    # 3. Backslash-escaped single quotes are not valid JSON; models emit them for apostrophes.
    fixed_text = re.sub(r"(?<!\\)\\'", "'", fixed_text)

    # Attempt to parse to verify it's now valid JSON
    try:
        json_data = json.loads(fixed_text)
    except json.JSONDecodeError as e:
        print("❌ JSON parsing failed after attempted fixes.")
        print("Error:", e)
        print("\nHere's the fixed version that failed:\n")
        print(fixed_text)
        return

    # Save fixed JSON (keep non-ASCII characters readable, like the chunk-merging step does).
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False, indent=4)
    print(f"Fixed JSON written to {output_path}")

In [9]:
def summarize_chunked_category_summaries(chunk_paths: list[str], output_path: str | None = None) -> str:
    """
    Summarize chunked category summaries (JSON) into a single JSON file.
    """
    if output_path and os.path.exists(output_path):
        return

    output_dict = {
        "recurring_themes": [],
        "notable_outliers": [],
    }

    for chunk_path in chunk_paths:
        with open(chunk_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
            output_dict["recurring_themes"].extend(data.get("recurring_themes", []))
            output_dict["notable_outliers"].extend(data.get("notable_outliers", []))

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as file:
            json.dump(output_dict, file, ensure_ascii=False, indent=4)

    return json.dumps(output_dict, ensure_ascii=False, indent=4)

In [None]:
# Build one LaTeX report per dataset (or per model) from the GPT-4.1 long-context summaries.
# Set report_type to "model" to build reports for the image-generation models instead.
report_type = "dataset"
base_search_dir = "/pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/finngu/LlavaGuard/src/experiments/summarize_annotations/long_context_summary/results/25_05_21_02/gpt-4.1"

# Chunk files look like "<CATEGORY>_<i>_of_<n>.json"; the category itself may contain underscores.
_CHUNK_FILE_RE = re.compile(r'^(?P<category>.+)_\d+_of_\d+\.json$')

_openai_client = None


def _get_openai_client():
    """Return a shared OpenAI client, created on first use so fully cached runs need no API key."""
    global _openai_client
    if _openai_client is None:
        _openai_client = openai.OpenAI(
            base_url="https://api.openai.com/v1",
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
    return _openai_client


def _strip_code_fences(text: str) -> str:
    """Remove a surrounding ```json ... ``` fence that chat models often wrap around JSON."""
    match = re.fullmatch(r'\s*```(?:json)?\s*(.*?)\s*```\s*', text, flags=re.DOTALL)
    return match.group(1) if match else text.strip()


def _convert_txt_files(search_dir: str) -> None:
    """Step 1: repair each raw LLM ``.txt`` answer into a sibling ``.json`` file (skips existing ones)."""
    print(f"Converting .txt files in {search_dir} to .json and fixing minor issues...")
    for filename in sorted(os.listdir(search_dir)):
        if not filename.endswith(".txt"):
            continue
        input_path = os.path.join(search_dir, filename)
        output_path = os.path.join(search_dir, filename[:-len(".txt")] + ".json")
        if not os.path.exists(output_path):
            fix_almost_json(input_path, output_path)


def _merge_chunk_files(search_dir: str) -> None:
    """Step 2: merge CATEGORY_<i>_of_<n>.json chunks into one CATEGORY_summary.json per category."""
    print(f"Summarizing chunked files in {search_dir}...")
    chunked_files: dict[str, list[str]] = {}
    for filename in sorted(os.listdir(search_dir)):
        match = _CHUNK_FILE_RE.match(filename)
        if match:
            chunked_files.setdefault(match.group("category"), []).append(os.path.join(search_dir, filename))

    for category, files in chunked_files.items():
        summarize_chunked_category_summaries(files, output_path=os.path.join(search_dir, f"{category}_summary.json"))


def _shorten_summaries(search_dir: str) -> None:
    """Step 3: ask GPT-4.1 to condense each CATEGORY_summary.json into CATEGORY.json (skips existing ones)."""
    print(f"Shortening chunked summaries in {search_dir} using OpenAI...")
    for filename in sorted(os.listdir(search_dir)):
        if not filename.endswith("_summary.json"):
            continue
        summary_path = os.path.join(search_dir, filename)
        shortened_path = os.path.join(search_dir, filename.replace("_summary.json", ".json"))
        if os.path.exists(shortened_path):
            continue

        with open(summary_path, 'r', encoding='utf-8') as file:
            text = file.read()

        response = _get_openai_client().chat.completions.create(
            model="gpt-4.1",
            messages=[
                {
                    "role": "developer",
                    "content": [
                        {
                            "type": "text",
                            "text": summary_prompts.SHORTEN_CHUNKED_SUMMARIES_PROMPT,
                        },
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": text,
                        },
                    ],
                },
            ],
        )

        print(f"Shortened summary for {filename}")

        content = _strip_code_fences(response.choices[0].message.content or "")
        # Only persist valid JSON. An unparsable file would block regeneration on the next run
        # (see the existence check above) and would render as "(TODO)" in the report anyway.
        try:
            json.loads(content)
        except json.JSONDecodeError as e:
            print(f"Shortened summary for {filename} is not valid JSON, not saving it: {e}")
            continue

        with open(shortened_path, 'w', encoding='utf-8') as file:
            file.write(content)


if report_type not in ("dataset", "model"):
    raise ValueError(f"report_type must be 'dataset' or 'model', got {report_type!r}")
df = df_datasets if report_type == "dataset" else df_models

for name in df['name']:
    search_dir = os.path.join(base_search_dir, name)
    if not os.path.isdir(search_dir):
        print(f"Skipping {name}: no summaries found in {search_dir}")
        continue

    _convert_txt_files(search_dir)
    _merge_chunk_files(search_dir)
    _shorten_summaries(search_dir)

    # 4. Generate LaTeX report for each dataset or model
    name_lower = name.lower().replace(' ', '-').replace('/', '_')
    image_count = df.loc[df['name'] == name, 'img_count'].iloc[0] if report_type == "dataset" else None

    generate_latex_report(
        name=name,
        report_type=report_type,
        search_dir=search_dir,
        image_count=image_count,
        output_path=os.path.join(base_search_dir, f"report-{report_type}-{name_lower}.tex")
    )
    print(f"Generated LaTeX report for {name}")


Converting .txt files in /pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/finngu/LlavaGuard/src/experiments/summarize_annotations/long_context_summary/results/25_05_21_02/gpt-4.1/ImageNet to .json and fixing minor issues...
Summarizing chunked files in /pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/finngu/LlavaGuard/src/experiments/summarize_annotations/long_context_summary/results/25_05_21_02/gpt-4.1/ImageNet...
Shortening chunked summaries in /pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/finngu/LlavaGuard/src/experiments/summarize_annotations/long_context_summary/results/25_05_21_02/gpt-4.1/ImageNet using OpenAI...
Generating LaTeX report for ImageNet...
Generated LaTeX report for ImageNet
Converting .txt files in /pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/finngu/LlavaGuard/src/experiments/summarize_annotations/long_context_summary/results/25_05_21_02/gpt-4.1/CIFAR-10 to .json and fixing minor issues...
Summarizing chunked files in /pfss/mlde/workspaces/mlde_wsp_KIServiceCenter/fi