In [1]:
# Move up one directory so that the relative paths used later in this notebook
# (".env", "data/descriptions_with_vqa.csv") resolve against the parent
# directory — presumably the project root.
%cd ..

/mnt/c/Users/ankit/Desktop/Portfolio/kaggle/drawing-with-llms


In [63]:
import ast
import io
import os
import random
import re
import statistics
import string
import time
from dataclasses import dataclass, field
from typing import List

import cairosvg
import pandas as pd
import torch
from camel.agents import ChatAgent
from camel.datagen.cot_datagen import CoTDataGenerator, logger
from camel.models.model_factory import ModelFactory, ModelPlatformType, ModelType
from datasets import load_dataset
from defusedxml import ElementTree as etree
from dotenv import load_dotenv
from more_itertools import chunked
from PIL import Image
from pydantic import BaseModel
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
)

load_dotenv(".env")

True

In [4]:
# Generator LLM: a Gemini 2.5 Flash (preview) model type, addressed through an
# OpenAI-compatible platform pointed at Google's generativelanguage endpoint.
# The API key is read from the GEMINI_API_KEY environment variable; os.getenv
# returns None when it is unset (presumably it is provided by the .env file
# loaded in the import cell — confirm).
model = ModelFactory.create(
    model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
    model_type=ModelType.GEMINI_2_5_FLASH_PREVIEW,
    url="https://generativelanguage.googleapis.com/v1beta/openai/",
    api_key=os.getenv("GEMINI_API_KEY"),
)

In [64]:
@dataclass(frozen=True)
class SVGConstraints:
    """Defines constraints for validating SVG documents.

    Attributes
    ----------
    max_svg_size : int, default=10000
        Maximum allowed size of an SVG file in bytes.
    allowed_elements : dict[str, set[str]]
        Mapping of the allowed elements to the allowed attributes of each element.
        The special key ``"common"`` is not an element: it lists the attributes
        that are accepted on every element.
    """

    max_svg_size: int = 10000
    allowed_elements: dict[str, set[str]] = field(
        default_factory=lambda: {
            "common": {
                "id",
                "clip-path",
                "clip-rule",
                "color",
                "color-interpolation",
                "color-interpolation-filters",
                "color-rendering",
                "display",
                "fill",
                "fill-opacity",
                "fill-rule",
                "filter",
                "flood-color",
                "flood-opacity",
                "lighting-color",
                "marker-end",
                "marker-mid",
                "marker-start",
                "mask",
                "opacity",
                "paint-order",
                "stop-color",
                "stop-opacity",
                "stroke",
                "stroke-dasharray",
                "stroke-dashoffset",
                "stroke-linecap",
                "stroke-linejoin",
                "stroke-miterlimit",
                "stroke-opacity",
                "stroke-width",
                "transform",
            },
            "svg": {
                "width",
                "height",
                "viewBox",
                "preserveAspectRatio",
            },
            "g": {"viewBox"},
            "defs": set(),
            "symbol": {"viewBox", "x", "y", "width", "height"},
            "use": {"x", "y", "width", "height", "href"},
            "marker": {
                "viewBox",
                "preserveAspectRatio",
                "refX",
                "refY",
                "markerUnits",
                "markerWidth",
                "markerHeight",
                "orient",
            },
            "pattern": {
                "viewBox",
                "preserveAspectRatio",
                "x",
                "y",
                "width",
                "height",
                "patternUnits",
                "patternContentUnits",
                "patternTransform",
                "href",
            },
            "linearGradient": {
                "x1",
                "x2",
                "y1",
                "y2",
                "gradientUnits",
                "gradientTransform",
                "spreadMethod",
                "href",
            },
            "radialGradient": {
                "cx",
                "cy",
                "r",
                "fx",
                "fy",
                "fr",
                "gradientUnits",
                "gradientTransform",
                "spreadMethod",
                "href",
            },
            "stop": {"offset"},
            "filter": {
                "x",
                "y",
                "width",
                "height",
                "filterUnits",
                "primitiveUnits",
            },
            "feBlend": {"result", "in", "in2", "mode"},
            "feFlood": {"result"},
            "feOffset": {"result", "in", "dx", "dy"},
            "path": {"d"},
            "rect": {"x", "y", "width", "height", "rx", "ry"},
            "circle": {"cx", "cy", "r"},
            "ellipse": {"cx", "cy", "rx", "ry"},
            "line": {"x1", "y1", "x2", "y2"},
            "polyline": {"points"},
            "polygon": {"points"},
        }
    )

    def validate_svg(self, svg_code: str) -> None:
        """Validates an SVG string against a set of predefined constraints.

        Parameters
        ----------
        svg_code : str
            The SVG string to validate.

        Raises
        ------
        ValueError
            If the SVG violates any of the defined constraints (size, element
            or attribute allow-list, embedded data, non-internal ``href``).
        Exception
            Whatever the XML parser raises for malformed or forbidden input is
            propagated unchanged.
        """
        svg_bytes = svg_code.encode("utf-8")

        # Check file size
        if len(svg_bytes) > self.max_svg_size:
            raise ValueError("SVG exceeds allowed size")

        # Parse XML
        tree = etree.fromstring(
            svg_bytes,
            forbid_dtd=True,
            forbid_entities=True,
            forbid_external=True,
        )

        # "common" is an attribute bucket, not a tag: it must not make a
        # literal <common> element pass the element allow-list.
        common_attributes = self.allowed_elements.get("common", set())
        element_names = set(self.allowed_elements.keys()) - {"common"}

        # Check elements and attributes
        for element in tree.iter():
            # Check for disallowed elements (namespace prefix stripped)
            tag_name = element.tag.split("}")[-1]
            if tag_name not in element_names:
                raise ValueError(f"Disallowed element: {tag_name}")

            element_attributes = self.allowed_elements[tag_name]

            # Check attributes
            for attr, attr_value in element.attrib.items():
                # Check for disallowed attributes (namespace prefix stripped)
                attr_name = attr.split("}")[-1]
                if (
                    attr_name not in element_attributes
                    and attr_name not in common_attributes
                ):
                    raise ValueError(f"Disallowed attribute: {attr_name}")

                # Check for embedded data; both markers are matched
                # case-insensitively so "DATA:" / ";BASE64" cannot slip through.
                lowered_value = attr_value.lower()
                if "data:" in lowered_value:
                    raise ValueError("Embedded data not allowed")
                if ";base64" in lowered_value:
                    raise ValueError("Base64 encoded content not allowed")

                # Check that href attributes are internal references
                if attr_name == "href":
                    if not attr_value.startswith("#"):
                        raise ValueError(
                            f'Invalid href attribute in <{tag_name}>. Only internal references (starting with "#") are allowed. Found: "{attr_value}"'
                        )

In [None]:
class VQAEvaluator:
    def __init__(self):
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.letters = string.ascii_uppercase
        # self.model_path = kagglehub.model_download(
        #     "google/paligemma-2/transformers/paligemma2-3b-mix-448"
        # )
        self.model_path = "google/paligemma2-3b-pt-448"
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            quantization_config=self.quantization_config,
        ).to("cuda:1")  # type: ignore

    def score(self, questions, choices, answers, images, n=4):
        # TODO change this !
        scores = []
        batches = (chunked(qs, n) for qs in [questions, choices, answers])
        for question_batch, choice_batch, answer_batch in zip(*batches, strict=True):
            scores.extend(
                self.score_batch(
                    images,  # list of images
                    question_batch,
                    choice_batch,
                    answer_batch,
                )
            )
        return statistics.mean(scores)

    def score_batch(
        self,
        images: list[Image.Image],
        questions: list[str],
        choices_list: list[list[str]],
        answers: list[str],
    ) -> list[float]:
        prompts = [
            self.format_prompt(question, choices)
            for question, choices in zip(questions, choices_list, strict=True)
        ]
        batched_choice_probabilities = self.get_choice_probability(
            images,  # send list of images
            prompts,
            choices_list,
        )

        scores = []
        for i, _ in enumerate(questions):
            choice_probabilities = batched_choice_probabilities[i]
            answer = answers[i]
            answer_probability = 0.0
            for choice, prob in choice_probabilities.items():
                if choice == answer:
                    answer_probability = prob
                    break
            scores.append(answer_probability)

        return scores

    def format_prompt(self, question: str, choices: list[str]) -> str:
        prompt = f"<image>answer en Question: {question}\nChoices:\n"
        for i, choice in enumerate(choices):
            prompt += f"{self.letters[i]}. {choice}\n"
        return prompt

    def mask_choices(self, logits, choices_list):
        batch_size = logits.shape[0]
        masked_logits = torch.full_like(logits, float("-inf"))

        for batch_idx in range(batch_size):
            choices = choices_list[batch_idx]
            for i in range(len(choices)):
                letter_token = self.letters[i]

                first_token = self.processor.tokenizer.encode(
                    letter_token, add_special_tokens=False
                )[0]
                first_token_with_space = self.processor.tokenizer.encode(
                    " " + letter_token, add_special_tokens=False
                )[0]

                if isinstance(first_token, int):
                    masked_logits[batch_idx, first_token] = logits[
                        batch_idx, first_token
                    ]
                if isinstance(first_token_with_space, int):
                    masked_logits[batch_idx, first_token_with_space] = logits[
                        batch_idx, first_token_with_space
                    ]

        return masked_logits

    def get_choice_probability(self, images, prompts, choices_list) -> list[dict]:
        inputs = self.processor(
            # images=[image] * len(prompts),
            images=images,  # batch different images
            text=prompts,
            return_tensors="pt",
            padding="longest",
        ).to("cuda:1")

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits[:, -1, :]  # Logits for the last (predicted) token
            masked_logits = self.mask_choices(logits, choices_list)
            probabilities = torch.softmax(masked_logits, dim=-1)

        batched_choice_probabilities = []
        for batch_idx in range(len(prompts)):
            choice_probabilities = {}
            choices = choices_list[batch_idx]
            for i, choice in enumerate(choices):
                letter_token = self.letters[i]
                first_token = self.processor.tokenizer.encode(
                    letter_token, add_special_tokens=False
                )[0]
                first_token_with_space = self.processor.tokenizer.encode(
                    " " + letter_token, add_special_tokens=False
                )[0]

                prob = 0.0
                if isinstance(first_token, int):
                    prob += probabilities[batch_idx, first_token].item()
                if isinstance(first_token_with_space, int):
                    prob += probabilities[batch_idx, first_token_with_space].item()
                choice_probabilities[choice] = prob

            # Renormalize probabilities for each question
            total_prob = sum(choice_probabilities.values())
            if total_prob > 0:
                renormalized_probabilities = {
                    choice: prob / total_prob
                    for choice, prob in choice_probabilities.items()
                }
            else:
                renormalized_probabilities = (
                    choice_probabilities  # Avoid division by zero if total_prob is 0
                )
            batched_choice_probabilities.append(renormalized_probabilities)

        return batched_choice_probabilities

    def ocr(self, images, free_chars=4):
        inputs = (
            self.processor(
                text=["<image>ocr\n"] * len(images),
                images=images,
                return_tensors="pt",
            )
            .to(torch.float16)
            .to(self.model.device)
        )
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, max_new_tokens=32, do_sample=False)
            out_list = self.processor.batch_decode(
                outputs[:, input_len:], skip_special_tokens=True
            )

        scores = [1.0 if len(decoded) < free_chars else -1.0 for decoded in out_list]

        return scores


def svg_to_png(svg_code: str, size: tuple = (384, 384)):
    """Render SVG markup to an RGB PIL image of the requested size.

    Parameters
    ----------
    svg_code : str
        SVG markup to rasterize.
    size : tuple, default=(384, 384)
        ``(width, height)`` of the returned image. It is also used as the
        extent of the default ``viewBox`` when the markup declares none.

    Returns
    -------
    PIL.Image.Image
        The rendered image, converted to RGB and resized to ``size``.
    """
    if "viewBox" not in svg_code:
        # count=1: only the root <svg> tag receives the default viewBox; any
        # later "<svg" occurrence is left untouched.
        svg_code = svg_code.replace(
            "<svg", f'<svg viewBox="0 0 {size[0]} {size[1]}"', 1
        )

    png_data = cairosvg.svg2png(bytestring=svg_code.encode("utf-8"))
    return Image.open(io.BytesIO(png_data)).convert("RGB").resize(size)  # type: ignore

In [42]:
# Prompt template sent to the generator agent. It contains exactly one named
# replacement field, {description}, so it must be filled with
# PROMPT.format(description=...); there are no other braces in the template.
# The expected reply layout is a "### COT" section followed by a "### SVG"
# section.
# NOTE(review): the prompt fixes viewBox="0 0 368 368" while svg_to_png renders
# at 384x384 by default, and the allowed-attribute list (constraint 2) omits
# `id`, `offset` and `stop-color`, which the gradient step (2.4) relies on —
# confirm both are intentional.
PROMPT = """ 
# Role
You are a meticulous and highly skilled AI SVG Architect. Your primary function is to translate rich textual descriptions of sceneries into precise, well-structured SVG code.

# Objective
Generate SVG code that accurately and comprehensively depicts the scenery described in the input text. The generated SVG will undergo rigorous Visual Question Answering (VQA) evaluation. All VQA questions will be answerable *solely* from the information present in the original text description. Therefore, absolute fidelity to *all* details in the description is paramount.

# Strategy
Generating complex SVG directly is prone to errors. Employ a Chain of Thought (CoT) process to first decompose the described scenery into constituent entities, map these entities to fundamental SVG primitives (rectangles, circles, paths, etc.), plan their attributes and layout, and only then construct the final SVG code. This structured approach is critical for accuracy and for adhering to the constraints.

# Constraints
1.  Allowed SVG Elements: `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`.
2.  Allowed SVG Attributes: `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`. *No other elements or attributes are permitted.*
3.  Token Limit: Your entire response (CoT + SVG code) must not exceed 2000 tokens. Conciseness in both your reasoning and the generated SVG is key.
4.  Completeness: The SVG must visually represent *every* distinct object, property, and spatial relationship mentioned in the description.

Chain of Thought (CoT) Steps:

Phase 1: Scene Deconstruction & Entity Analysis
    1.1. Full Text Parsing: Read and internalize the entire description.
    1.2. Entity Identification: List every distinct visual entity mentioned (e.g., sun, specific tree type, house, car model if specified, clouds).
    1.3. Attribute Extraction per Entity: For each entity, meticulously detail its properties:
        *   Component Primitives: Identify how the entity can be constructed from one or more allowed SVG primitives (e.g., a "house" might be a `rect` for the body and a `polygon` or `path` for the roof). Be explicit about this decomposition.
        *   Visual Properties: Color (fill, stroke), stroke width, opacity. Specify exact color values (e.g., "blue", "#FF0000", "rgb(0,0,255)"). If a color is implied (e.g., "a grassy field"), infer a common color (e.g., green).
        *   Size & Scale: Note any described dimensions or relative sizes (e.g., "a tall tree," "a small window"). If not specified, use reasonable default proportions relative to other objects or the canvas.
        *   Position & Orientation: Note absolute (e.g., "in the top-left corner") or relative positioning (e.g., "the sun is above the mountains," "a car is parked next to the house"). Also note any rotation or skew if described and how it might be achieved with `transform`.
        *   Relationships:** Document how entities relate to each other spatially (e.g., overlapping, adjacent, contained within).

Phase 2: SVG Mapping & Layout Planning
    2.1. Canvas Definition: Choose viewBox="0 0 368 368". This defines your coordinate space. All subsequent coordinates will be relative to this.
    2.2. Primitive Mapping & Attribute Specification: For each component primitive identified in 1.3:
        *   Select the precise SVG elements
        *   Translate the visual properties from 1.3 into specific SVG attribute values (e.g., `fill="blue"`, `r="10"`, `d="..."`).
        *   Calculate and assign coordinates (`cx`, `cy`, `x`, `y`, `points`, path commands) and dimensions (`r`, `width`, `height`) based on the `viewBox` and the entity's position/size from 1.3. Be explicit about these calculations if they are not trivial.
        *   Consider the z-ordering (drawing order): elements drawn later appear on top. Plan the sequence of elements accordingly (e.g., background elements first, foreground elements last).
    2.3. Grouping Strategy: Determine if `<g>` elements are beneficial for grouping components of complex entities or for applying shared transformations or styles. Plan any `transform` attributes for these groups or individual elements.
    2.4. Gradient Definitions: If the description implies gradients (e.g., "sky fading from blue to orange"), define `linearGradient` or `radialGradient` elements within `<defs>` with appropriate `stop` colors and offsets. Assign them unique `id`s. These `id`s will be referenced in `fill` attributes (e.g., `fill="url(#myGradient)"`).

Phase 3: Code Generation 
    3.1. SVG Construction: Systematically write the SVG code.

# Output Format
### COT
[Chain of Thought as a multi line single paragraph]
### SVG
[SVG Code]

### Input
Description: {description}

### Output
"""

In [None]:
class VQACoT(CoTDataGenerator):
    """CoT data generator whose answers are SVG drawings judged by VQA.

    The "question" handled by the generator is a scene description. Answers
    are produced by prompting the generator agent with ``PROMPT`` and are
    scored by rendering the SVG they contain and asking the VQA model the
    multiple-choice questions stored for that description in
    ``data/descriptions_with_vqa.csv``.
    """

    # Number of stored VQA questions used per description when scoring.
    _QUESTION_LIMIT = 4
    # An answer is accepted when its mean VQA score is strictly above this.
    _ACCEPT_THRESHOLD = 0.8

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.vqa_model = VQAEvaluator()
        self.dataset = load_dataset(
            "csv",
            data_files="data/descriptions_with_vqa.csv",
        )["train"]  # type: ignore
        # Rendered whenever a response has no usable SVG.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.svg_constraints = SVGConstraints()

    def extract_response(self, text: str):
        """Split a model reply into its chain-of-thought and SVG parts.

        The reply is expected to contain a ``### COT`` section followed by a
        ``### SVG`` section. Missing markers are tolerated: without
        ``### COT`` the whole text is treated as the section body, and without
        ``### SVG`` the SVG part is empty. A ```` ```svg ```` fence around the
        SVG is stripped.

        Returns
        -------
        tuple[str, str]
            ``(cot, svg)``, both stripped of surrounding whitespace.
        """
        _, cot_marker, after_cot = text.partition("### COT")
        body = after_cot if cot_marker else text
        # partition splits on the first marker only, so a repeated "### SVG"
        # cannot break the unpacking.
        cot, _, svg = body.partition("### SVG")

        if "```svg" in svg:
            svg = svg.split("```svg")[1]
            svg = svg.split("```")[0]

        return cot.strip(), svg.strip()

    def extract_image(self, text: str):
        """Render the last ``<svg>...</svg>`` found in ``text`` to an image.

        Falls back to ``self.default_svg`` when no SVG is found, when the SVG
        fails validation, or when it cannot be rendered.
        """
        matches = re.findall(r"<svg.*?</svg>", text, re.DOTALL | re.IGNORECASE)
        svg = matches[-1] if matches else self.default_svg

        try:
            self.svg_constraints.validate_svg(svg)
            return svg_to_png(svg)
        except Exception as exc:
            # Deliberate best effort: a bad drawing should score poorly, not
            # abort the generation run.
            logger.warning("Invalid SVG, using default image: %s", exc)
            return svg_to_png(self.default_svg)

    def get_answer(self, question: str, context: str = ""):
        """Ask the generator agent to draw ``question`` (a scene description).

        ``context`` is accepted for interface compatibility and is not used.
        Returns the chain of thought and the SVG joined by a newline.
        """
        self.generator_agent.reset()
        # PROMPT uses the named field {description}; a positional argument
        # would raise KeyError.
        response = self.generator_agent.step(PROMPT.format(description=question))
        cot, svg = self.extract_response(response.msgs[0].content)
        answer = cot + "\n" + svg
        logger.info("AI thought process:\n%s", answer)
        return answer

    def _score_solution(self, question: str, solution: str):
        """Return the mean VQA score of the SVG contained in ``solution``.

        ``question`` must be one of the dataset's descriptions; its stored
        questions, choices and answers (Python-literal strings) are parsed and
        truncated to ``_QUESTION_LIMIT`` entries.
        """
        sample = self.dataset[self.dataset["description"].index(question)]  # type: ignore
        limit = self._QUESTION_LIMIT
        question_list = ast.literal_eval(sample["question"])[:limit]
        choices_list = ast.literal_eval(sample["choices"])[:limit]
        answer_list = ast.literal_eval(sample["answer"])[:limit]

        image = self.extract_image(solution)

        return self.vqa_model.score(
            question_list, choices_list, answer_list, [image], 1
        )  # batch size 1

    def verify_answer(self, question: str, answer: str):
        """Return True when the answer's VQA score exceeds the acceptance threshold."""
        is_correct = self._score_solution(question, answer) > self._ACCEPT_THRESHOLD
        logger.info("Answer verification result: %s", is_correct)
        return is_correct

    def evaluate_partial_solution(self, question: str, partial_solution: str = ""):
        """Return the VQA score of ``partial_solution`` for ``question``."""
        return self._score_solution(question, partial_solution)


ChatAgentResponse(msgs=[BaseMessage(role_name='assistant', role_type=<RoleType.ASSISTANT: 'assistant'>, meta_dict={}, content='Worlds drawn bright,\nHeroes take flight.\nStories unfold, brave, bold.\nEmotions shown,\nMagic known.\nAnime dreams,\nVibrant gleams.', video_bytes=None, image_list=None, image_detail='auto', video_detail='low', parsed=None)], terminated=False, info={'id': 'n584aNSMMoe1kdUPr4CR-Ag', 'usage': {'prompt_tokens': 10, 'completion_tokens': 36, 'total_tokens': 638}, 'termination_reasons': ['stop'], 'num_tokens': 15, 'tool_calls': [], 'external_tool_call_requests': None})

In [None]:
# TODO : Run with GPU