In [1]:
%cd ..

/mnt/c/Users/ankit/Desktop/Portfolio/kaggle/drawing-with-llms


In [None]:
import ast
import os
import random
import re
import time
from typing import List

import pandas as pd
from camel.agents import ChatAgent
from camel.datagen.cot_datagen import CoTDataGenerator, logger
from camel.models.model_factory import ModelFactory, ModelPlatformType, ModelType
from datasets import load_dataset
from dotenv import load_dotenv
from pydantic import BaseModel
from tqdm import tqdm

from utils.constraints import SVGConstraints
from utils.process_response import svg_to_png
from utils.verifier import VQAEvaluator

# Pull GEMINI_API_KEY (and any other secrets) from the local .env file.
load_dotenv(dotenv_path=".env")

True

In [4]:
# Gemini is reached through its OpenAI-compatible endpoint, so it is wired up
# via CAMEL's generic OPENAI_COMPATIBLE_MODEL platform.
gemini_model_settings = dict(
    model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
    model_type=ModelType.GEMINI_2_5_FLASH_PREVIEW,
    url="https://generativelanguage.googleapis.com/v1beta/openai/",
    api_key=os.getenv("GEMINI_API_KEY"),
)
model = ModelFactory.create(**gemini_model_settings)

In [42]:
# System prompt for SVG generation. It contains exactly one replacement field,
# `{description}`, so it must be filled by keyword:
#     PROMPT.format(description=...)
# Any literal brace added to this text must be doubled ("{{" / "}}").
# NOTE(review): constraint 2 does not list `id`, `offset` or `stop-color`, yet
# step 2.4 asks for gradient `id`s and `stop` offsets. Confirm which of these
# utils.constraints.SVGConstraints accepts and align the two sections.
PROMPT = """ 
# Role
You are a meticulous and highly skilled AI SVG Architect. Your primary function is to translate rich textual descriptions of sceneries into precise, well-structured SVG code.

# Objective
Generate SVG code that accurately and comprehensively depicts the scenery described in the input text. The generated SVG will undergo rigorous Visual Question Answering (VQA) evaluation. All VQA questions will be answerable *solely* from the information present in the original text description. Therefore, absolute fidelity to *all* details in the description is paramount.

# Strategy
Generating complex SVG directly is prone to errors. Employ a Chain of Thought (CoT) process to first decompose the described scenery into constituent entities, map these entities to fundamental SVG primitives (rectangles, circles, paths, etc.), plan their attributes and layout, and only then construct the final SVG code. This structured approach is critical for accuracy and for adhering to the constraints.

# Constraints
1.  Allowed SVG Elements: `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`.
2.  Allowed SVG Attributes: `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`. *No other elements or attributes are permitted.*
3.  Token Limit: Your entire response (CoT + SVG code) must not exceed 2000 tokens. Conciseness in both your reasoning and the generated SVG is key.
4.  Completeness: The SVG must visually represent *every* distinct object, property, and spatial relationship mentioned in the description.

Chain of Thought (CoT) Steps:

Phase 1: Scene Deconstruction & Entity Analysis
    1.1. Full Text Parsing: Read and internalize the entire description.
    1.2. Entity Identification: List every distinct visual entity mentioned (e.g., sun, specific tree type, house, car model if specified, clouds).
    1.3. Attribute Extraction per Entity: For each entity, meticulously detail its properties:
        *   Component Primitives: Identify how the entity can be constructed from one or more allowed SVG primitives (e.g., a "house" might be a `rect` for the body and a `polygon` or `path` for the roof). Be explicit about this decomposition.
        *   Visual Properties: Color (fill, stroke), stroke width, opacity. Specify exact color values (e.g., "blue", "#FF0000", "rgb(0,0,255)"). If a color is implied (e.g., "a grassy field"), infer a common color (e.g., green).
        *   Size & Scale: Note any described dimensions or relative sizes (e.g., "a tall tree," "a small window"). If not specified, use reasonable default proportions relative to other objects or the canvas.
        *   Position & Orientation: Note absolute (e.g., "in the top-left corner") or relative positioning (e.g., "the sun is above the mountains," "a car is parked next to the house"). Also note any rotation or skew if described and how it might be achieved with `transform`.
        *   Relationships: Document how entities relate to each other spatially (e.g., overlapping, adjacent, contained within).

Phase 2: SVG Mapping & Layout Planning
    2.1. Canvas Definition: Choose viewBox="0 0 368 368". This defines your coordinate space. All subsequent coordinates will be relative to this.
    2.2. Primitive Mapping & Attribute Specification: For each component primitive identified in 1.3:
        *   Select the precise SVG elements
        *   Translate the visual properties from 1.3 into specific SVG attribute values (e.g., `fill="blue"`, `r="10"`, `d="..."`).
        *   Calculate and assign coordinates (`cx`, `cy`, `x`, `y`, `points`, path commands) and dimensions (`r`, `width`, `height`) based on the `viewBox` and the entity's position/size from 1.3. Be explicit about these calculations if they are not trivial.
        *   Consider the z-ordering (drawing order): elements drawn later appear on top. Plan the sequence of elements accordingly (e.g., background elements first, foreground elements last).
    2.3. Grouping Strategy: Determine if `<g>` elements are beneficial for grouping components of complex entities or for applying shared transformations or styles. Plan any `transform` attributes for these groups or individual elements.
    2.4. Gradient Definitions: If the description implies gradients (e.g., "sky fading from blue to orange"), define `linearGradient` or `radialGradient` elements within `<defs>` with appropriate `stop` colors and offsets. Assign them unique `id`s. These `id`s will be referenced in `fill` attributes (e.g., `fill="url(#myGradient)"`).

Phase 3: Code Generation 
    3.1. SVG Construction: Systematically write the SVG code.

# Output Format
### COT
[Chain of Thought as a multi line single paragraph]
### SVG
[SVG Code]

### Input
Description: {description}

### Output
"""

In [None]:
class VQACoT(CoTDataGenerator):
    """CoT data generator whose answers are SVG drawings scored by a VQA model.

    Each "question" passed through the CoTDataGenerator machinery is a scene
    description from ``data/descriptions_with_vqa.csv``. Each "answer" is the
    generator's chain of thought followed by its SVG code. An answer is scored
    by converting its SVG with ``svg_to_png`` and asking ``VQAEvaluator`` the
    questions stored next to that description in the CSV.
    """

    # Only the first N stored VQA questions per description are asked.
    vqa_question_limit = 4
    # verify_answer accepts an answer whose VQA score is strictly above this.
    vqa_acceptance_threshold = 0.8

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.vqa_model = VQAEvaluator()
        self.dataset = load_dataset(
            "csv",
            data_files="data/descriptions_with_vqa.csv",
        )["train"]  # type: ignore
        # Fallback drawing used when no valid SVG can be extracted from an answer.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.svg_constraints = SVGConstraints()
        # description -> first row index; built lazily by _vqa_sample_for.
        # Reset to None if self.dataset is replaced.
        self._description_index = None

    def extract_response(self, text: str):
        """Split a raw model reply into ``(cot, svg)`` strings.

        The reply is expected to follow the prompt's output format:
        a "### COT" header, the reasoning, a "### SVG" header, then the SVG
        code (optionally inside a ```svg fence).

        Missing headers are tolerated instead of raising:
        - Without "### COT", the whole reply is used.
        - Without "### SVG", everything is treated as CoT and the SVG part
          is empty.
        """
        _, cot_header, after_cot = text.partition("### COT")
        body = after_cot if cot_header else text
        cot, _, svg = body.partition("### SVG")

        if "```svg" in svg:
            svg = svg.split("```svg")[1]
            svg = svg.split("```")[0]

        return cot.strip(), svg.strip()

    def extract_image(self, text: str):
        """Convert the last ``<svg>...</svg>`` element in ``text`` with ``svg_to_png``.

        Falls back to ``self.default_svg`` when no SVG element is present, or
        when the candidate fails ``SVGConstraints`` validation. This way a
        malformed answer still gets scored instead of aborting the run.
        """
        matches = re.findall(r"<svg.*?</svg>", text, re.DOTALL | re.IGNORECASE)
        svg = matches[-1] if matches else self.default_svg

        try:
            self.svg_constraints.validate_svg(svg)
        except Exception as exc:  # deliberate best-effort fallback; log instead of hiding
            logger.warning("SVG failed constraint validation, using default SVG: %s", exc)
            svg = self.default_svg

        return svg_to_png(svg)

    def get_answer(self, question: str, context: str = ""):
        """Ask the generator agent to draw ``question`` (a scene description).

        ``context`` is unused. Returns the chain of thought and the SVG code
        joined by a newline.
        """
        self.generator_agent.reset()
        # PROMPT uses a *named* placeholder ({description}), so it must be filled
        # by keyword; PROMPT.format(question) raises KeyError('description').
        response = self.generator_agent.step(PROMPT.format(description=question))
        cot, svg = self.extract_response(response.msgs[0].content)
        answer = cot + "\n" + svg
        logger.info("AI thought process:\n%s", answer)
        return answer

    def _vqa_sample_for(self, description: str):
        """Return the dataset row whose ``description`` column equals ``description``.

        A description -> row-index map is built once on first use, rather than
        materialising and linearly scanning the column on every call. As with
        ``list.index``, the first matching row wins and a missing description
        raises ValueError.
        """
        index = getattr(self, "_description_index", None)
        if index is None:
            index = {}
            for row, text in enumerate(self.dataset["description"]):  # type: ignore
                index.setdefault(text, row)
            self._description_index = index
        if description not in index:
            raise ValueError(f"description not found in dataset: {description!r}")
        return self.dataset[index[description]]  # type: ignore

    def _get_score(self, question: str, answer: str):
        """
        Calculate vqa score of the svg (answer) for the given question (description).

        The "question", "choices" and "answer" CSV cells hold Python-literal
        lists. Only the first ``vqa_question_limit`` entries of each are used.
        """
        sample = self._vqa_sample_for(question)
        limit = self.vqa_question_limit
        # ast.literal_eval only parses literals, so CSV contents cannot execute code.
        question_list = ast.literal_eval(sample["question"])[:limit]
        choices_list = ast.literal_eval(sample["choices"])[:limit]
        answer_list = ast.literal_eval(sample["answer"])[:limit]

        image = self.extract_image(answer)

        score = self.vqa_model.score(
            question_list, choices_list, answer_list, [image], 1
        )  # batch size 1
        return score

    def verify_answer(self, question: str, answer: str):
        """Return True when the answer's VQA score exceeds ``vqa_acceptance_threshold``."""
        # bool() keeps the return a plain Python bool even if score is a numpy scalar.
        is_correct = bool(self._get_score(question, answer) > self.vqa_acceptance_threshold)

        logger.info("Answer verification result: %s", is_correct)
        return is_correct

    def evaluate_partial_solution(self, question: str, partial_solution: str = ""):
        """Score a (possibly partial) answer with the same VQA metric as verify_answer."""
        return self._get_score(question, partial_solution)


In [None]:
# TODO: Build a more sophisticated SVG-creation workflow.