### Notebook to test different LLM models for information extraction

For the basic information extraction of
1. Function
2. Style
3. Surroundings

We will be using a fuzzy words extraction of removing similar words <br>
Tentatively, we believe it is better to use a pre-set list of possible options, as we do not have a dataset, making it difficult for us to understand the possible variations and their explanation to use an LLM.<br>
It will be more efficient to simply extract my using word matching.

However, LLMs will be used to generate the schema of the elasticsearch database query. 

We will be exploring these models
1. OpenAI
2. Gemini
3. Llama2-7b

## Data Loading

In [1]:
# loading the pre-made mock_prompts dataset
import json

with open('../src/input/data/mock_prompts.json', 'r') as file:
    data = json.load(file)

print(len(data))
print(data[0])

100
{'prompt': 'Create a central precinct garden with lush green accents and a meadow style', 'function': 'Central Precinct Garden', 'colour': 'lush green', 'style': 'Meadow', 'surrounding': 'Walkway'}


In [2]:
data_list = {
    "function": ["Boundary",  "Playground", "Active Zone", "Central Precinct Garden", "Passive Zone", "Sitting Corner", "Pavilion", "Pergola", "Future Community Garden", "Butterfly Garden"],
    "style": ["Naturalistic", "Manicured",  "Meadow", "Ornamental", "Minimalist", "Formal", "Picturesque", "Rustic", "Plantation"],
    "surrounding": {
        "Boundary": "Road", 
        "Playground": "Walkway", 
        "Active Zone": "Walkway", 
        "Central Precinct Garden": "Walkway",
        "Passive Zone": "Walkway", 
        "Sitting Corner": "Walkway", 
        "Pavilion": "Walkway", 
        "Pergola": "Walkway", 
        "Future Community Garden": "Walkway", 
        "Butterfly Garden": "Walkway"
    }
}

## Exploration of the extraction of Function, Style & Surrounding

### Testing with Fuzzy Searching

In [3]:
from fuzzywuzzy import fuzz, process  # for fuzzy matching
from nltk.stem import WordNetLemmatizer
import time

def extract_keywords(data_list, input, scorer, threshold=60,):
    result = {}

    # Tokenise words
    input_words = input.split()

    # Get stems of input words
    lemmatizer = WordNetLemmatizer()
    base_words = [lemmatizer.lemmatize(word) for word in input_words]

    for key, targets in data_list.items():
        # Retrieve results for function and style
        if isinstance(targets, list):
            base_targets = [lemmatizer.lemmatize(word) for word in targets]
            match = process.extractOne(" ".join(base_words), base_targets, scorer=scorer, score_cutoff=threshold) #retrieve matching words
            result[key] = match[0] if match != None else None
        else:
            result[key] = targets.get(result['function'], None)
    
    return result




In [8]:
# Test all string matching algorithm

all_fuzz = [fuzz.partial_ratio, fuzz.token_set_ratio, fuzz.ratio, fuzz.token_sort_ratio, fuzz.QRatio, fuzz.UQRatio, fuzz.WRatio, fuzz.UWRatio, fuzz.partial_token_set_ratio, fuzz.partial_token_sort_ratio]

for scorer in all_fuzz:
    inaccuracy = 0
    for test_data in data:
        result = extract_keywords(data_list, test_data['prompt'], scorer)
        try:
            assert(result['function'] == test_data['function'])
            assert(result['style'] == test_data['style'])
            assert(result['surrounding'] == test_data['surrounding'])
        except:
            inaccuracy += 1
    
    print(scorer, 100-inaccuracy)


<function partial_ratio at 0x000001E29CA96160> 95
<function token_set_ratio at 0x000001E29CA965E0> 96
<function ratio at 0x000001E29CA45EE0> 2
<function token_sort_ratio at 0x000001E29CA963A0> 2
<function QRatio at 0x000001E29CA96700> 2
<function UQRatio at 0x000001E29CA96790> 2
<function WRatio at 0x000001E29CA96820> 95
<function UWRatio at 0x000001E29CA968B0> 95
<function partial_token_set_ratio at 0x000001E29CA96670> 73
<function partial_token_sort_ratio at 0x000001E29CA96430> 95


In [10]:
# Token Set ratio
wrong_values = []
start_time = time.time()

for test_data in data:
    result = extract_keywords(data_list, test_data['prompt'], fuzz.token_set_ratio)
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])


Time taken:, 0.08 seconds.
Total Accuracy: 96%
------------
Design a natural playground with wildflowers and native plants
{'function': 'Playground', 'style': None, 'surrounding': 'Walkway'}
------------
Design a minimalistic sitting corner with soft green hues
{'function': 'Sitting Corner', 'style': None, 'surrounding': 'Walkway'}
------------
Create a natural sitting corner with a red accent
{'function': 'Sitting Corner', 'style': None, 'surrounding': 'Walkway'}
------------
Design a formal sitting corner with manicured hedges
{'function': 'Sitting Corner', 'style': 'Manicured', 'surrounding': 'Walkway'}


In [9]:
# WRatio
wrong_values = []
start_time = time.time()

for test_data in data:
    result = extract_keywords(data_list, test_data['prompt'], fuzz.WRatio)
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])


Time taken:, 0.54 seconds.
Total Accuracy: 95%
------------
Design a natural playground with wildflowers and native plants
{'function': 'Playground', 'style': 'Plantation', 'surrounding': 'Walkway'}
------------
Design a formal sitting corner with manicured hedges
{'function': 'Sitting Corner', 'style': 'Manicured', 'surrounding': 'Walkway'}
------------
Design a formal garden boundary with a mix of red and white accents
{'function': 'Central Precinct Garden', 'style': 'Formal', 'surrounding': 'Walkway'}
------------
Plan a manicured landscape with soft pink and lilac tones
{'function': 'Active Zone', 'style': 'Manicured', 'surrounding': 'Walkway'}
------------
Design a rustic sitting corner with lavender plants and natural stones
{'function': 'Sitting Corner', 'style': 'Naturalistic', 'surrounding': 'Walkway'}


In [11]:
# UWRatio
wrong_values = []
start_time = time.time()

for test_data in data:
    result = extract_keywords(data_list, test_data['prompt'], fuzz.UWRatio)
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])


Time taken:, 0.54 seconds.
Total Accuracy: 95%
------------
Design a natural playground with wildflowers and native plants
{'function': 'Playground', 'style': 'Plantation', 'surrounding': 'Walkway'}
------------
Design a formal sitting corner with manicured hedges
{'function': 'Sitting Corner', 'style': 'Manicured', 'surrounding': 'Walkway'}
------------
Design a formal garden boundary with a mix of red and white accents
{'function': 'Central Precinct Garden', 'style': 'Formal', 'surrounding': 'Walkway'}
------------
Plan a manicured landscape with soft pink and lilac tones
{'function': 'Active Zone', 'style': 'Manicured', 'surrounding': 'Walkway'}
------------
Design a rustic sitting corner with lavender plants and natural stones
{'function': 'Sitting Corner', 'style': 'Naturalistic', 'surrounding': 'Walkway'}


### Testing with Zero-shot Models

In [12]:
from transformers import pipeline
import torch

def zero_shot_extract_keywords(data_list, model, input, threshold=0.35):
    result = {}
    for key, targets in data_list.items():
        if isinstance(targets, list):
            output = model(input, targets, multi_label=False)    
            if output['scores'][0] >= threshold:
                result[key] = output['labels'][0]
            else:
                result[key] = None
        else:
            result[key] = targets.get(result['function'], None)

    return result

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", device=device)

wrong_values = []
start_time = time.time()

for test_data in data:
    result = zero_shot_extract_keywords(data_list, classifier, test_data['prompt'])
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Time taken:, 29.04 seconds.
Total Accuracy: 91%
------------
Plan a rustic area
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Design a formal sitting corner with manicured hedges
{'function': 'Sitting Corner', 'style': None, 'surrounding': 'Walkway'}
------------
Plan a central garden in a naturalistic style with golden accent
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Design a central garden in naturalistic style with lavender accents
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Create a naturalistic sitting area with red and yellow flowers
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Plan a manicured sitting area
{'function': None, 'style': 'Manicured', 'surrounding': None}
------------
Design a manicured playground with tall grasses and native blooms
{'function': 'Playground', 'style': None, 'surrounding': 'Walkway'}
------------
Design a manicured pav

In [14]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)
wrong_values = []
start_time = time.time()

for test_data in data:
    result = zero_shot_extract_keywords(data_list, classifier, test_data['prompt'])
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Time taken:, 28.99 seconds.
Total Accuracy: 94%
------------
Plan a central garden in a naturalistic style with golden accent
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Design a central garden in naturalistic style with lavender accents
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Create a naturalistic sitting area with red and yellow flowers
{'function': None, 'style': 'Naturalistic', 'surrounding': None}
------------
Plan a manicured sitting area
{'function': None, 'style': 'Manicured', 'surrounding': None}
------------
Create a meadow-style playground
{'function': 'Playground', 'style': 'Naturalistic', 'surrounding': 'Walkway'}
------------
Create a meadow-style area with blue and yellow wildflowers
{'function': None, 'style': 'Naturalistic', 'surrounding': None}


In [15]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

classifier = pipeline("zero-shot-classification",  model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", device=device)
wrong_values = []
start_time = time.time()

for test_data in data:
    result = zero_shot_extract_keywords(data_list, classifier, test_data['prompt'])
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])

Time taken:, 55.71 seconds.
Total Accuracy: 73%
------------
Design a landscape that is rustic
{'function': 'Passive Zone', 'style': 'Rustic', 'surrounding': 'Walkway'}
------------
Create a meadow-style butterfly garden with soft yellow hues
{'function': 'Butterfly Garden', 'style': 'Ornamental', 'surrounding': 'Walkway'}
------------
Design a passive zone with tall ornamental grasses
{'function': 'Passive Zone', 'style': 'Naturalistic', 'surrounding': 'Walkway'}
------------
Plan a rustic area
{'function': 'Passive Zone', 'style': 'Rustic', 'surrounding': 'Walkway'}
------------
Create a manicured playground with soft purple blooms
{'function': 'Playground', 'style': 'Ornamental', 'surrounding': 'Walkway'}
------------
Create a manicured boundary with deep red flowers
{'function': 'Boundary', 'style': 'Ornamental', 'surrounding': 'Road'}
------------
Design a naturalistic landscape with blue flowers
{'function': 'Passive Zone', 'style': 'Naturalistic', 'surrounding': 'Walkway'}
-----

### Testing with LLMS

In [None]:
import os
import google.generativeai as genai
import typing_extensions as typing
from typing_extensions import TypedDict, Literal
import ast

os.environ["GEMINI_API_KEY"] = ""

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

class Response(typing.TypedDict):
    function: Literal["Boundary",  "Playground", "Active Zone", "Central Precinct Garden", "Passive Zone", "Sitting Corner", "Pavilion", "Pergola", "Future Community Garden", "Butterfly Garden", "None"]
    style: Literal["Naturalistic", "Manicured",  "Meadow", "Ornamental", "Minimalist", "Formal", "Picturesque", "Rustic", "Plantation", "None"]

system_instruction = """You are an information extraction model for a landscape architect.
Given user queries, you are to extract out the function (the purpose of the landscape area) and the style (planting style) from the input.
The extracted value must be in the possible function and styles, else assign None.

All possible function:
["Boundary",  "Playground", "Active Zone", "Central Precinct Garden", "Passive Zone", "Sitting Corner", "Pavilion", "Pergola", "Future Community Garden", "Butterfly Garden"]

All possible style:
["Naturalistic", "Manicured",  "Meadow", "Ornamental", "Minimalist", "Formal", "Picturesque", "Rustic", "Plantation"]
"""

model = genai.GenerativeModel(
  model_name="gemini-1.5-flash",
  system_instruction=system_instruction
)

In [17]:
def gemini_extract_keywords(model, response_json, data_list, input):
    result = {
        'function': None, 
        'style': None
    }

    response = model.generate_content(
        input,
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json", response_schema=Response
        ))

    response_json = ast.literal_eval(response.text)

    for key, targets in response_json.items():
        if targets != 'None'and targets != None:
            # Valid response
            result[key] = targets
        else:
            result[key] = None

    result['surrounding'] = data_list['surrounding'].get(result['function'], None)
    
    return result

In [18]:
wrong_values = []
start_time = time.time()

for i, test_data in enumerate(data, start=1):  
    result = gemini_extract_keywords(model, Response, data_list, test_data['prompt'])
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

    # Pausing for 1 min since API only support 15 calls every minute
    if i % 15 == 0:
        time.sleep(90)
    
    if i == 1:
        iteration_time = time.time() - start_time
        print(f"Time taken for 1 iteration:, {iteration_time:.2f} seconds.")

total_time = iteration_time*100
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])

Time taken for 1 iteration:, 1.92 seconds.
Time taken:, 191.59 seconds.
Total Accuracy: 99%
------------
Design a plantation-style pergola area with large palms
{'function': 'Pavilion', 'style': 'Plantation', 'surrounding': 'Walkway'}


In [None]:
from openai import OpenAI

os.environ["OPENAI_API_KEY"] = ""

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

system_instruction = """You are an information extraction model for a landscape architect.
Given user queries, you are to extract out the function (the purpose of the landscape area) and the style (planting style) from the input.
The extracted value must be in the possible function and styles, else assign None.

All possible function:
["Boundary",  "Playground", "Active Zone", "Central Precinct Garden", "Passive Zone", "Sitting Corner", "Pavilion", "Pergola", "Future Community Garden", "Butterfly Garden"]

All possible style:
["Naturalistic", "Manicured",  "Meadow", "Ornamental", "Minimalist", "Formal", "Picturesque", "Rustic", "Plantation"]
"""

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "response",
        "schema": {
            "type": "object",
            "properties": {
                "function": {
                    "type": "string",
                    "enum": [
                        "Boundary", 
                        "Playground", 
                        "Active Zone", 
                        "Central Precinct Garden", 
                        "Passive Zone", 
                        "Sitting Corner", 
                        "Pavilion", 
                        "Pergola", 
                        "Future Community Garden", 
                        "Butterfly Garden", 
                        "None"
                    ],
                },
                "style": {
                    "type": "string",
                    "enum": [
                        "Naturalistic", 
                        "Manicured", 
                        "Meadow", 
                        "Ornamental", 
                        "Minimalist", 
                        "Formal", 
                        "Picturesque", 
                        "Rustic", 
                        "Plantation", 
                        "None"
                    ],
                }
            },
            "required": ["function", "style"],
            "additionalProperties": False
        },
        "strict": True
    }
}

In [20]:
def openai_extract_keywords(system_instruction, response_format, data_list, input):
    result = {
        'function': None, 
        'style': None
    }

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": input}
        ],
        response_format=response_format,
    )

    response_json = ast.literal_eval(completion.choices[0].message.content)

    for key, targets in response_json.items():
        if targets != 'None'and targets != None:
            # Valid response
            result[key] = targets
        else:
            result[key] = None

    result['surrounding'] = data_list['surrounding'].get(result['function'], None)
    
    return result

In [21]:
wrong_values = []
start_time = time.time()

for test_data in data:
    result = openai_extract_keywords(system_instruction, response_format, data_list, test_data['prompt'])
    try:
        assert(result['function'] == test_data['function'])
        assert(result['style'] == test_data['style'])
        assert(result['surrounding'] == test_data['surrounding'])
    except:
        wrong_values.append([test_data['prompt'], result])

total_time = time.time() - start_time
print(f"Time taken:, {total_time:.2f} seconds.")
print(f"Total Accuracy: {len(data) - len(wrong_values)}%")

for error_data in wrong_values:
    print(("------------"))
    print(error_data[0])
    print(error_data[1])

Time taken:, 64.45 seconds.
Total Accuracy: 99%
------------
Design a formal sitting corner with manicured hedges
{'function': 'Sitting Corner', 'style': 'Manicured', 'surrounding': 'Walkway'}


## Exploration of generating ElasticSearch Query from prompt

### Base Variables

In [18]:
# Elastic Search Schema
dataset_schema = {
    "mappings": {
        "properties": {
            "Scientific Name" : {"type": "str"}, # Not important
            "Common Name": {"type": "str"}, # Not important
            "Species ID": {"type": "int"}, # Not important
            "Link": {"type": "str"}, # Not important
            "Plant Type": {"type": "keyword"}, # Tree / Shrub / Palm / Herbaceous Plant
            "Light Preference": {"type": "keyword"}, # Light Preference Input
            "Water Preference": {"type": "keyword"}, # Water Preference Input
            "Drought Tolerant": {"type": "bool"}, # Drought Tolerant Input
            "Native to SG": {"type": "bool"}, # Not important in querying, but important when picking best options
            "Fruit Bearing": {"type": "bool"}, # Useful Planting (Function)
            "Fragrant Plant": {"type": "bool"}, # Useful Planting (Function)
            "Maximum Height (m)": {"type": "float"}, # Tall / Short? Prioritise Height when picking
            "Flower Colour": {"type": "str"}, # Colour from prompt, especially if flower/bloom mentioned
            "Hazard": {"type": "str"}, # (Function)
            "Attracted Animals": {"type": "str"}, # (Function)
            "Native habitat": {"type": "str"}, # Not important in querying, but important when picking best options
            "Mature Leaf Colour": {"type": "str"}, # Colour from prompt, if general (not specific to flower)
            "Young Flush Leaf Colour": {"type": "str"}, # Colour from prompt, if general (not specific to flower)
            "Leaf Area Index": {"type": "str"}, # Not important
            "Growth Rate": {"type": "keyword"}, # Not important in querying, but important when picking best options
            "Trunk Colour": {"type": "str"}, # Colour from prompt, if general (not specific to flower)
            "Trunk Texture": {"type": "str"}, # Not important in querying, but important when picking best options (adding variation)
            "Leaf Texture": {"type": "keyword"}, # Not important in querying, but important when picking best options (adding variation)            
        }
    }
}

In [19]:
# Expected Input to Backend
import random

def generate_api_call():
    api_call = {
        "prompt": random.choice(data)['prompt'],
        "maximum_plant_count": random.randint(3, 8),
        "light_preference": random.choice(["Full Shade", "Semi Shade", "Full Sun"]),
        "water_preference": random.choice(["Lots of Water", "Moderate Water", "Little Water"]),
        "drought_tolerant": random.choice([True, False]),
        "fauna_attracted": random.sample(["Butterfly", "Bird", "Caterpillar Moth", "Bat", "Bee"], k=random.randint(0,2)),
        "ratio_native": round(random.uniform(0,1), 1)
    }
    return api_call

generate_api_call()

{'prompt': 'Plan a manicured sitting area',
 'maximum_plant_count': 8,
 'light_preference': 'Full Sun',
 'water_preference': 'Lots of Water',
 'drought_tolerant': True,
 'fauna_attracted': [],
 'ratio_native': 0.3}

## Using Function Description

### Base Variables

In [92]:
function_design_considerations = {
    "Boundary": "Provide two layers of trees within the buffer area where space permits, to offer more shade and greenery while buffering the roads", 
    "Playground": "Surrounding landscape can include colourful and fragrant plants to stimulate the senses. Minimial Hazards.", 
    "Active Zone": "Colourful and educational planting can be provided around playgrounds to enhance play experience. May include biodiversity-attracting or useful plants. Minimial Hazards.", 
    "Central Precinct Garden": "The main recreational area for residents.  Within this space, there are both active and passive zones",
    "Passive Zone": "Useful planting such as medicinal plants, spices, timber trees or fruit trees can be provided away from flats and activity spaces. Choose perennial plant species for durability", 
    "Sitting Corner": "Maintain sight lines between seating areas and facilities to facilitate supervision", 
    "Pavilion": "A green roof can be provided to soften the structure", 
    "Pergola": "Pergolas can be designed to allow creepers to grow over structure for shade and greenery", 
    "Future Community Garden": "Initial planting can be turf or hardy, edible plants", 
    "Butterfly Garden": "Green sanctuaries with butterfly attracting plants to support a butterfly ecosystem"
}

In [None]:
system_instruction = f"""You are a database querying model to retrieve the required plants for a landscape design. You will be given user requirements and the function of the landscape area.

Using the user requirements and function, return the value for each database key. 
If a key is invalid to the querying, let the value for that key be None.
For colours, unless explicitedly mentioned should be for flowers, apply the colour to all colour fields.

---------------------
Function definition:
{function_design_considerations}

---------------------
Database keys and possible values:
- Plant Type: list[str] (Tree, Shrub Palm, Herbaceous Plant, Creeper, Climber)
- Fruit Bearing: str (True, False, None)
- Fragrant Plant: str (True, False, None)
- Height: str (Tall, Short, None)
- Flower Colour: list[str]
- Attracted Animals: list[str] (Bird, Butterfly, Bee, Caterpillar Moth, Bat)
- Avoid Animals: list[str] (Bird, Butterfly, Bee, Caterpillar Moth, Bat)
- Mature Leaf Colour: list[str]
- Young Flush Leaf Colour: list[str]
- Trunk Colour: list[str]
- Trunk Texture: list[str]
- Leaf Texture: list[str] (Fine, Medium, Coarse)
"""

### Gemini Setup

In [None]:
import os
import google.generativeai as genai
import typing_extensions as typing
from typing_extensions import TypedDict, Literal
import ast

os.environ["GEMINI_API_KEY"] = ""

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

class Result(typing.TypedDict):
    Plant_Type: list[Literal["Palm", "Herbaceous Plant", "Grass", "Epiphyte", "Creeper", "Climber"]]
    Fruit_Bearing: Literal["True", "False", "None"]
    Fragrant_Plant: Literal["True", "False", "None"]
    Height: Literal["Tall", "Short", "None"]
    Flower_Colour: list[str]
    Attracted_Animals: list[Literal["Bird", "Butterfly", "Bee", "Caterpillar Moth", "Bat"]]
    Avoid_Animals: list[Literal["Bird", "Butterfly", "Bee", "Caterpillar Moth", "Bat"]]
    Mature_Leaf_Colour: list[str]
    Young_Flush_Leaf_Colour: list[str]
    Trunk_Colour: list[str]
    Trunk_Texture: list[str]
    Leaf_Texture: list[Literal["Fine", "Medium", "Coarse"]]


gemini = genai.GenerativeModel(
  model_name="gemini-1.5-flash",
  system_instruction=system_instruction
)

In [None]:
gemini_json = {"Gemini": []}

# Downloading Gemini Output
for test_data in data:
    try:
        parsed_data = extract_keywords(data_list, test_data['prompt'], fuzz.token_set_ratio)
        user_prompt = f"Prompt: {test_data}\nFunction: {parsed_data['function']}"
        # OPENAI
        response = gemini.generate_content(
            user_prompt,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json", response_schema=Result
        ))

        gemini_response_json = json.loads(response.text)
        gemini_response_json["Prompt"] = test_data['prompt']
        gemini_response_json["Function"] = parsed_data["function"]

        gemini_json["Gemini"].append(gemini_response_json)
    
    except:
        time.sleep(60)
        parsed_data = extract_keywords(data_list, test_data['prompt'], fuzz.token_set_ratio)
        user_prompt = f"Prompt: {test_data}\nFunction: {parsed_data['function']}"
        # OPENAI
        response = gemini.generate_content(
            user_prompt,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json", response_schema=Result
        ))

        gemini_response_json = json.loads(response.text)
        gemini_response_json["Prompt"] = test_data['prompt']
        gemini_response_json["Function"] = parsed_data["function"]

        gemini_json["Gemini"].append(gemini_response_json)

    # if i % 5 == 0:
    #     time.sleep(90)

with open('../src/input/data/Gemini_query.json', 'w') as f:
    json.dump(gemini_json, f, indent=4)

### GPT Setup

In [None]:
from openai import OpenAI

os.environ["OPENAI_API_KEY"] = ""

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "result",
        "schema": {
            "type": "object",
            "properties": {
                "Plant Type": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Palm", 
                            "Herbaceous Plant",
                            "Grass",
                            "Epiphyte",
                            "Creeper",
                            "Climber"
                        ]
                    }
                },
                "Fruit Bearing": {
                    "type": "string",
                    "enum": [
                        "True",
                        "False",
                        "None"
                    ]
                },
                "Fragrant Plant": {
                    "type": "string",
                    "enum": [
                        "True",
                        "False",
                        "None"
                    ]                
                },
                "Height" : {
                    "type": "string",
                    "enum": [
                        "Tall",
                        "Short",
                        "None"
                    ]
                },
                "Flower Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Attracted Animals": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Bird",
                            "Butterfly", 
                            "Bee", 
                            "Caterpillar Moth",
                            "Bat"
                        ]
                    }
                },
                "Avoid Animals": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Bird",
                            "Butterfly", 
                            "Bee", 
                            "Caterpillar",
                            "Bat"
                        ]
                    }
                },
                "Mature Leaf Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Young Flush Leaf Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Trunk Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }, 
                "Trunk Texture": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Leaf Texture": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Fine",
                            "Medium",
                            "Coarse"
                        ]
                    }
                }
            },
            "required": ["Plant Type", "Fruit Bearing", "Fragrant Plant", "Height", "Flower Colour", "Attracted Animals", "Avoid Animals", "Mature Leaf Colour", "Young Flush Leaf Colour", "Trunk Colour", "Trunk Texture", "Leaf Texture"],
            "additionalProperties": False
        },
        "strict": True
    }
}

In [None]:
# Downloading OpenAI Output
openai_json = {"GPT": []}

for test_data in data:
    parsed_data = extract_keywords(data_list, test_data['prompt'], fuzz.token_set_ratio)
    user_prompt = f"Prompt: {test_data}\nFunction: {parsed_data['function']}"
    # OPENAI
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": user_prompt}
        ],
        response_format=response_format,
    )
    gpt_response_json = json.loads(completion.choices[0].message.content)
    gpt_response_json["Prompt"] = test_data['prompt']
    gpt_response_json["Function"] = parsed_data["function"]

    openai_json["GPT"].append(gpt_response_json)

with open('../src/input/data/GPT_query.json', 'w') as f:
    json.dump(openai_json, f, indent=4)

### Comparing Results

In [82]:
for i in range(5):
    test_data = generate_api_call()
    parsed_data = extract_keywords(data_list, test_data['prompt'], fuzz.token_set_ratio)

    user_prompt = f"Prompt: {test_data['prompt']}\nFunction: {parsed_data['function']}"
    
    # GEMINI
    response = gemini.generate_content(
        user_prompt,
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json", response_schema=Result
    ))

    # OPENAI
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": user_prompt}
        ],
        response_format=response_format,
    )
    # response_json = ast.literal_eval(completion.choices[0].message.content)
    print(user_prompt)
    gpt_response_json = json.loads(completion.choices[0].message.content)
    gemini_response_json = json.loads(response.text)
    
    for key, values in gpt_response_json.items():
        print(f"{key}: GPT - {values}")
        print(f"{key}: GEMINI - {gemini_response_json[key.replace(' ', '_')]}")

    print("-------------")

Prompt: Design a manicured passive zone
Function: Passive Zone
Plant Type: GPT - ['Palm', 'Herbaceous Plants']
Plant Type: GEMINI - ['Herbaceous Plants', 'Climber', 'Creeper']
Fruit Bearing: GPT - None
Fruit Bearing: GEMINI - None
Fragrant Plant: GPT - None
Fragrant Plant: GEMINI - True
Height: GPT - None
Height: GEMINI - Tall
Flower Colour: GPT - []
Flower Colour: GEMINI - ['White']
Attracted Animals: GPT - []
Attracted Animals: GEMINI - ['Bird', 'Butterfly', 'Bee', 'Caterpillar']
Avoid Animals: GPT - []
Avoid Animals: GEMINI - ['Bat']
Mature Leaf Colour: GPT - []
Mature Leaf Colour: GEMINI - ['Green']
Young Flush Leaf Colour: GPT - []
Young Flush Leaf Colour: GEMINI - ['Green']
Trunk Colour: GPT - []
Trunk Colour: GEMINI - ['Brown']
Trunk Texture: GPT - []
Trunk Texture: GEMINI - ['Rough']
Leaf Texture: GPT - []
Leaf Texture: GEMINI - ['Medium']
-------------
Prompt: Design a landscape with picturesque style
Function: None
Plant Type: GPT - []
Plant Type: GEMINI - ['Palm', 'Herbaceou

## Using Hardcoded Function Details

### Base Variables

In [4]:
function_design_requirements = {
    "Boundary": {
        "Height": "Tall",
    },
    "Playground": {
        "Fragrant Plant": True,
        "Hazard": "N/A",
        "Avoid Animals": ["Bee"],
    },
    "Active Zone": {
        "Plant Type": ["Herbaceous Plant"],
        "Fragrant Plant": True,
        "Fruit Bearing": True,
        "Hazard": "N/A",
        "Avoid Animals": ["Bee"],
        "Attracted Animals": ["Bird", "Butterfly", "Caterpillar Moth"]
    },
    "Central Precinct Garden": {}, # Not sure what to put for this ngl
    "Passive Zone": {
        "Plant Type": ["Herbaceous Plant"],
    },
    "Sitting Corner": {
        "Height": "Short",
        "Hazard": "N/A",
        "Avoid Animals": ["Bee"]
    },
    "Pavilion": {
        "Plant Type": ["Creeper", "Climber"],
    },
    "Pergola": {
        "Plant Type": ["Creeper", "Climber"],
    },
    "Future Community Garden": {
        "Plant Type": ["Herbaceous Plant"],
        "Fruit Bearing": True,
        "Avoid Animals": ["Bee"],
    },
    "Butterfly Garden": {
        "Attracted Animals": ["Butterfly", "Caterpillar Moth"],
        "Avoid Animals": ["Bee"]
    }
}

### GPT Setup

In [5]:
system_instruction = f"""You are a database querying model to retrieve the required plants for a landscape design.

Using the user requirements, return the value for each database key. 
If a key is invalid to the querying, let the value for that key be None.
For colours, unless explicitedly mentioned should be for flowers, apply the colour to all colour fields.

---------------------
Database keys and possible values:
- Plant Type: list[str] (Palm, Herbaceous Plants)
- Fruit Bearing: str (True, False, None)
- Fragrant Plant: str (True, False, None)
- Height: str (Tall, Short, None)
- Flower Colour: list[str]
- Attracted Animals: list[str] (Bird, Butterfly, Bee, Caterpillar Moth, Bat)
- Avoid Animals: list[str] (Bird, Butterfly, Bee, Caterpillar Moth, Bat)
- Mature Leaf Colour: list[str]
- Young Flush Leaf Colour: list[str]
- Trunk Colour: list[str]
- Trunk Texture: list[str]
- Leaf Texture: list[str] (Fine, Medium, Coarse)
"""

In [None]:
from openai import OpenAI
import os 
os.environ["OPENAI_API_KEY"] = ""

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "result",
        "schema": {
            "type": "object",
            "properties": {
                "Plant Type": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Palm", 
                            "Herbaceous Plant",
                            "Grass",
                            "Epiphyte",
                            "Creeper",
                            "Climber"
                        ]
                    }
                },
                "Fruit Bearing": {
                    "type": "string",
                    "enum": [
                        "True",
                        "False",
                        "None"
                    ]
                },
                "Fragrant Plant": {
                    "type": "string",
                    "enum": [
                        "True",
                        "False",
                        "None"
                    ]                
                },
                "Height" : {
                    "type": "string",
                    "enum": [
                        "Tall",
                        "Short",
                        "None"
                    ]
                },
                "Flower Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Attracted Animals": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Bird",
                            "Butterfly", 
                            "Bee", 
                            "Caterpillar Moth",
                            "Bat"
                        ]
                    }
                },
                "Avoid Animals": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Bird",
                            "Butterfly", 
                            "Bee", 
                            "Caterpillar Moth",
                            "Bat"
                        ]
                    }
                },
                "Mature Leaf Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Young Flush Leaf Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Trunk Colour": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }, 
                "Trunk Texture": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "Leaf Texture": {
                    "type": "array",
                    "items": {
                        "type": "string",
                        "enum": [
                            "Fine",
                            "Medium",
                            "Coarse"
                        ]
                    }
                }
            },
            "required": ["Plant Type", "Fruit Bearing", "Fragrant Plant", "Height", "Flower Colour", "Attracted Animals", "Avoid Animals", "Mature Leaf Colour", "Young Flush Leaf Colour", "Trunk Colour", "Trunk Texture", "Leaf Texture"],
            "additionalProperties": False
        },
        "strict": True
    }
}

### Query Generation

In [8]:
elastic_search_requirements = {
    "Plant Type": "should",
    "Light Preference" : "must", 
    "Water Preference": "must", 
    "Drought Tolerant": "must", 
    "Attracted Animals": "should", 
    "Hazard": "must", 
    "Avoid Animals": "must_not",
    "Fruit Bearing": "should",
    "Fragrant Plant": "should", 
    "Flower Colour": "should", 
    "Mature Leaf Colour": "should", 
    "Young Flush Leaf Colour": "should", 
    "Trunk Colour": "should", 
    "Trunk Texture": "should", 
    "Leaf Texture": "should"
}

In [9]:
def generate_query(data_list, user_call):

    # Base Query from the other option values
    query = {
        "bool": {
            "must": [],
            "must_not": [],
            "should": []
        }
    }

    # Light Preference
    query["bool"][elastic_search_requirements["Light Preference"]].append({
        "terms": {"Light Preference.keyword": list(set(["Full Shade", user_call['light_preference']]))}
    })

    # Water Preference
    query["bool"][elastic_search_requirements["Water Preference"]].append({
        "term": {"Water Preference.keyword": user_call['water_preference']}
    })

    # Drought Tolerant
    query["bool"][elastic_search_requirements["Drought Tolerant"]].append({
        "term": {"Drought Tolerant": user_call['drought_tolerant']}
    })

    # Extracting Function, Style & Surrounding
    function_style_surrounding_extraction = extract_keywords(data_list, user_call['prompt'], fuzz.token_set_ratio)
    function_requirements = function_design_requirements.get(function_style_surrounding_extraction['function'], {})

    # Extraction of Colours and Requirements
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": user_call['prompt']}
        ],
        response_format=response_format,
    )
    gpt_response_json = json.loads(completion.choices[0].message.content)

    # Plant Types
    function_plant_requirements = function_requirements.get("Plant Type", [])
    all_plant_types = list(set(["Tree", "Shrub"] + function_plant_requirements + gpt_response_json["Plant Type"]))
    
    query["bool"][elastic_search_requirements["Plant Type"]].append({
        "terms": {"Plant Type.keywords": all_plant_types}
    })


    # Hazard
    hazard_requirements = function_requirements.get("Hazard", None)
    if hazard_requirements != None:
        query["bool"][elastic_search_requirements["Hazard"]].append({
            "term": {"Hazard": hazard_requirements}
        })


    # Other variables
    for key, items in gpt_response_json.items():
        
        # Ignore non querying terms
        if key == "Plant Type" or key == "Height":
            pass

        # Boolean Terms, Prioritise the gpt response > preset function requirements
        elif key == "Fruit Bearing" or key == "Fragrant Plant":
            gpt_value = True if items == "True" else False if items == "False" else None
            function_value = function_requirements.get(key, None)

            # only add query if none are None
            if gpt_value != None and function_value != None:
                query["bool"][elastic_search_requirements[key]].append({
                    "term": {key: gpt_value if gpt_value != None else function_value}
                })

        # Attracted Animals, need include option data
        elif key == "Attracted Animals":
            function_value = function_requirements.get(key, [])
            all_animal_values = list(set(items + function_value + user_call['fauna_attracted']))
            if len(all_animal_values) > 0:
                string_value = " ".join(all_animal_values)
                query["bool"][elastic_search_requirements[key]].append({
                    "match": {key: string_value}
                })

        # Avoid Animals is under Attraced Animals
        elif key == "Avoid Animals":
            function_value = function_requirements.get(key, [])
            all_values = list(set(items + function_value))
            if len(all_values) > 0:
                string_value = " ".join(all_values)
                query["bool"][elastic_search_requirements[key]].append({
                    "match": {"Attracted Animals": string_value}
                })

        else:
            # List of Str terms, convert to a long string
            function_value = function_requirements.get(key, [])
            all_values = list(set(items + function_value))
            if len(all_values) > 0:
                string_value = " ".join(all_values)
                query["bool"][elastic_search_requirements[key]].append({
                    "match": {key: string_value}
                })

    other_requirements = {
        "Maximum Plant Count": user_call['maximum_plant_count'],
        "Ratio Native": user_call['ratio_native'],
        "Plant Type": all_plant_types,
        "Attracted Animals": all_animal_values,
        "Light Preference": user_call['light_preference']
    }

    if gpt_response_json['Height'] != "None" and gpt_response_json['Height'] != None:
        other_requirements['Height'] = gpt_response_json['Height']

    return query, other_requirements

In [279]:
for i in range(3):
    user_call = generate_api_call()
    query, requirements = generate_query(data_list, user_call)
    print(user_call['prompt'])
    print(query)
    print(requirements)

Plan a rustic Active Zone with vibrant green foliage
{'bool': {'must': [{'terms': {'Light Preference.keyword': ['Full Shade', 'Semi Shade']}}, {'term': {'Water Preference.keyword': 'Little Water'}}, {'term': {'Drought Tolerant': False}}, {'term': {'Hazard': 'N/A'}}], 'must_not': [{'match': {'Attracted Animals': 'Bee'}}], 'should': [{'terms': {'Plant Type.keywords': ['Tree', 'Shrub', 'Herbaceous Plant']}}, {'match': {'Attracted Animals': 'Caterpillar Moth Butterfly Bird'}}, {'match': {'Mature Leaf Colour': 'green'}}, {'match': {'Young Flush Leaf Colour': 'green'}}]}}
{'Maximum Plant Count': 4, 'Ratio Native': 0.6, 'Plant Type': ['Tree', 'Shrub', 'Herbaceous Plant'], 'Attracted Animals': ['Caterpillar Moth', 'Butterfly', 'Bird'], 'Light Preference': 'Semi Shade'}
Design a rustic boundary area
{'bool': {'must': [{'terms': {'Light Preference.keyword': ['Full Shade']}}, {'term': {'Water Preference.keyword': 'Moderate Water'}}, {'term': {'Drought Tolerant': True}}], 'must_not': [], 'should':

## Querying, Ranking and retrieving results from ElasticSearch
Other than elastic search hits, we will also need to consider
1. Native Ratio
2. Maximum Plant Count
3. Height
4. Native Habitat
5. Growth Rate
6. Variation in Leaf Texture & Trunk Texture

### Imports

In [None]:
# Import and variables
import collections
import os
from typing import Union, List, Dict
import pandas as pd

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan, streaming_bulk

# Map common python types to ES Types
TYPE_MAP = {
    "int": "integer",
    "float": "float",
    "double": "double",
    "str": "text",
    "bool": "boolean",
    "datetime": "date",
    "list[int]": "integer",
    "list[str]": "text",
    "list[float]": "float",
    "list[double]": "double",
    "torch.tensor": "dense_vector",
    "numpy.ndarray": "dense_vector",
    "keyword": "keyword"
}

MAX_BULK_SIZE = 100

os.environ['ELASTIC_USERNAME'] = 'elastic'
os.environ['ELASTIC_PASSWORD'] = '' # To be filled
os.environ['ELASTIC_PORT'] = '9200'
os.environ['ELASTIC_HOST'] = 'localhost'

In [5]:
class ESManager():
    """
    Class to manage ElaticSearch
    """
    def __init__(self):
        self.url = f"http://{os.getenv('ELASTIC_HOST')}:{os.getenv('ELASTIC_PORT')}"
        self.username =  os.getenv('ELASTIC_USERNAME')
        self.password = os.getenv('ELASTIC_PASSWORD')

        self.client = Elasticsearch(self.url,
                                    verify_certs=False,
                                    basic_auth=(self.username, self.password), request_timeout=30, max_retries=10, retry_on_timeout=True)

        print(self.client.info())

        self.consolidated_actions = []

    def _check_data_type(self, var, var_type):
        try:
            assert type(var) == var_type
        except:
            return False
        return True

    def _check_valid_values(self, map_dict: dict) -> int:
        """
        Traverse mapping dictionary to ensure that all types are valid types within TYPE_MAP

        Args:
            map_dict (dict): Mapping to be checked

        Returns:
            int: 0 if there is invalid types, 1 otherwise

        """
        ret_val = 1
        for k, v in map_dict.items():
            if isinstance(v, dict):
                ret_val = self._check_valid_values(v)
            else:
                if not v in TYPE_MAP:
                    print(f"'{v}' type for '{k}' NOT FOUND")
                    return 0

        return ret_val * 1

    def _traverse_map(self, map_dict: Dict) -> Dict:
        """
        Traverse mapping dictionary to convert data type into framework specific type

        Args:
            map_dict (dict): Mapping to be used to create ES index

        Returns:
            dict: updated mapping dictionary

        """
        dictionary = {"properties": dict()}
        for k, v in map_dict.items():
            if isinstance(v, dict):
                dictionary['properties'][k] = self._traverse_map(v)
            else:
                dictionary['properties'][k] = {"type": TYPE_MAP[v]}
        return dictionary

    def _flush(self):
        errors = []
        list_of_es_ids = []
        for ok, item in streaming_bulk(self.client, self.consolidated_actions):
            if not ok:
                errors.append(item)
            else:
                list_of_es_ids.append(item['index']['_id'])
        if len(errors) != 0:
            print("List of faulty documents:", errors)
        self.consolidated_actions = []  # Reset List
        return list_of_es_ids

    def _flatten(self, d, parent_key='', sep='.'):
        """
        Flatten nested dictionary keys to dotted parameters because Elasticsearch. 
        """
        items = []
        for k, v in d.items():
            new_key = parent_key + sep + k if parent_key else k
            if isinstance(v, collections.MutableMapping):
                items.extend(self._flatten(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

    def create_collection(self, collection_name: str, schema: Dict, custom_schema: bool = False) -> Dict:
        """
        Create the index on ElasticSearch

        Args:
            collection_name (str): Index name of ES
            schema (dict): Mapping to be used to create ES index
            custom_schema (bool): If set to True, user may input schema that in accordance to ElasticSearch Mapping's format. The schema will not be parsed. 

        Returns:
            dict: response of error, or 200 if no errors caught

        """
        if not self._check_data_type(schema, dict):
            return {"response": "Type of 'schema' is not dict"}
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(custom_schema, bool):
            return {"response": "Type of 'custom_schema' is not bool"}
        if custom_schema:
            try:
                self.client.indices.create(
                    index=collection_name, mappings=schema)
            except Exception as e:
                return {"response": f"{e}"}
            return {"response": "200"}
        else:
            mapping_validity = self._check_valid_values(schema)
            if not mapping_validity:
                return {"response": "KeyError: data type not found in TYPE_MAP"}
            updated_mapping = self._traverse_map(schema)
            try:
                self.client.indices.create(
                    index=collection_name, mappings=updated_mapping)
            except Exception as e:
                return {"response": f"{e}"}
            return {"response": "200"}

    def delete_collection(self, collection_name: str) -> dict:
        """
        Create the index on ElasticSearch

        Args:
            collection_name (str): Index name of ES
            schema (dict): Mapping to be used to create ES index

        Returns:
            dict: response of error, or 200 if no errors caught

        """
        try:
            self.client.indices.delete(index=collection_name)
        except Exception as e:
            return {"response": f"{e}"}
        return {"response": "200"}

    def create_document(self, collection_name: str, documents: Union[list, dict], id_field: str = None) -> dict:
        """
        Upload document(s) in the specified index within ElasticSearch

        Args:
            collection_name (str): Index name of ES
            documents (dict, list): A dict of document objects to be ingested. A list of dict is accepted as well. 
            id_field (str, Optional): Specify the key amongst the document object to be the id field. If not specified, id will be generated by ES. 

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the ids of ingested document if no errors caught

        """
        if not self._check_data_type(documents, list):
            if not self._check_data_type(documents, dict):
                return {"response": "Type of 'documents' is not dict or a list"}
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not id_field is None:
            if not self._check_data_type(id_field, str):
                return {"response": "Type of 'id_field' is not str"}

        # If single document, wrap it in a list so it can be an iterable as it would be when a list of document is submitted
        if type(documents) == dict:
            documents = [documents]

        # If id_field is specified, verify that all documents possess the id_field.
        if id_field != None:
            for doc in documents:
                if not id_field in doc.keys():
                    print(
                        "Fix document, or set 'id_field' to None. No documents uploaded.")
                    return {"response": "Fix document, or set 'id_field' to None. No documents uploaded.",
                            "error_doc": doc}
                try:
                    doc[id_field] = str(doc[id_field])
                except Exception as e:
                    return {"response": "id cannot be casted to String type. No documents uploaded.",
                            "error_doc": doc}
        all_id = []
        for doc in documents:
            doc_copy = dict(doc)
            action_dict = {}
            action_dict['_op_type'] = 'index'
            action_dict['_index'] = collection_name
            if id_field != None:
                action_dict['_id'] = doc_copy[id_field]
                doc_copy.pop(id_field)
            action_dict['_source'] = doc_copy
            self.consolidated_actions.append(action_dict)
            if len(self.consolidated_actions) == MAX_BULK_SIZE:
                all_id = all_id+self._flush()

        all_id = all_id+self._flush()

        return {"response": "200", "ids": all_id}

    def delete_document(self, collection_name: str, doc_id: str) -> dict:
        """
        Delete document from index based on the specified document id. 

        Args:
            collection_name (str): Index name of ES
            doc_id (str): id of doc to be deleted

        Returns:
            dict: response of error along with the faulty document, or code 200 along with elastic API response

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(doc_id, str):
            return {"response": "Type of 'doc_id' is not str"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query={
                                           "match": {"_id": doc_id}})
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"Document '{doc_id}' not found!"}

        try:
            resp = self.client.delete(index=collection_name, id=doc_id)
        except Exception as e:
            return {"response": f"{e.__class__.__name__}. Document Deletion failed"}

        return {"response": "200", "api_resp": resp}

    def update_document(self, collection_name: str, doc_id: str, document: dict) -> dict:
        """
        Delete document from index based on the specified document id. 

        Args:
            collection_name (str): Index name of ES
            doc_id (str): id of doc to be updated
            document (dict): key and values of fields to be updated.

        Returns:
            dict: response of error along with the faulty document, or code 200 along with elastic API response

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(doc_id, str):
            return {"response": "Type of 'doc_id' is not str"}
        if not self._check_data_type(document, dict):
            return {"response": "Type of 'document' is not dict"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query={
                                           "match": {"_id": doc_id}})
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"Document '{doc_id}' not found, create document first"}

        try:
            for key in document.keys():

                q = {
                    "script": {
                        "source": f"ctx._source.{key}=params.infer",
                        "params": {
                            "infer": document[key]
                        },
                        "lang": "painless"
                    },
                    "query": {
                        "match": {
                            "_id": doc_id
                        }
                    }
                }
                resp = self.client.update_by_query(
                    body=q, index=collection_name)
        except Exception as e:
            return {"response": f"{e.__class__.__name__}. Document Update failed"}

        return {"response": "200", "api_resp": resp}

    def read_document(self, collection_name: str, doc_id: str) -> dict:
        """
        Read document from index based on the specified document id. 

        Args:
            collection_name (str): Index name of ES
            doc_id (str): id of doc to be read

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the retrieved document

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(doc_id, str):
            return {"response": "Type of 'doc_id' is not str"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query={
                                           "match": {"_id": doc_id}})
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"Document '{doc_id}' not found!"}

        doc_body = search_result['hits']['hits']

        return {"response": "200", "api_resp": doc_body}

    def query_collection(self, collection_name: str, field_value_dict: dict) -> dict:
        """
        Read document from index based on the specific key-value dictionary query. 

        Args:
            collection_name (str): Index name of ES
            field_value_dict (dict): A dictionary with the field to be queried as the key, and the value to be queried as the value of the dictionary. 
                                    example: {"field1":"query1", "field2", "query2"}

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the list of retrieved document

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(field_value_dict, dict):
            return {"response": "Type of 'field_value_dict' is not dict"}

        # Check for document's existence
        reorg_dict = {"bool":{
            "should":[]
            }
        }
        for field in field_value_dict:
            reorg_dict['bool']['should'].append({"match":{field:field_value_dict[field]}})

        search_result = self.client.search(index=collection_name, query=reorg_dict)
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"No documents found."}

        docs = search_result['hits']['hits']

        return {"response": "200", "api_resp": docs}

    def custom_query(self, collection_name: str, query: dict, size:int=10) -> dict:
        """
        Read document from index based on custom ES query syntax. 

        Args:
            collection_name (str): Index name of ES
            query (dict): Custom query for ES users who are familiar with the query format
            size (int): Number of results to return per query

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the list of retrieved document

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(query, dict):
            return {"response": "Type of 'field_value_dict' is not dict"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query=query, size=size)
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"No documents found."}

        docs = search_result['hits']['hits']

        return {"response": "200", "api_resp": docs}

    def get_all_documents(self, collection_name: str) -> dict:
        """
        Generator method to retrieve all documents within the index

        Args:
            collection_name (str): Index name of ES

        Returns:
            Generator Object: Iterable object containing all documents within index specified. 
        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        docs_response = scan(self.client, index=collection_name, query={
                             "query": {"match_all": {}}})
        for item in docs_response:
            yield item

In [None]:
esManager = ESManager()

### Database Querying

In [25]:
def retrieve_results(esManager, query, other_requirements):
    result = esManager.custom_query('flora', query, 9999)
    if result['response'] == '200' and len(result['api_resp']) >= other_requirements['Maximum Plant Count']:
        print(f"Retrieved {len(result['api_resp'])} results.")
        for data in result['api_resp']:
            print(data['_source'])
            print(data['_score'])
        return result
    else:
        print(f"Not enough data was retrieved from the dataset. Required {other_requirements['Maximum Plant Count']} but retrieved {len(result['api_resp']) if result['response'] == '200' else 0}")
        return None

In [27]:
user_call = generate_api_call()
query, requirements = generate_query(data_list, user_call)
print(user_call['prompt'])
print(query)
print(requirements)
result = retrieve_results(esManager, query, requirements)

Plan a picturesque landscape with pink, blue, and white blossoms
{'bool': {'must': [{'terms': {'Light Preference.keyword': ['Semi Shade', 'Full Shade']}}, {'term': {'Water Preference.keyword': 'Lots of Water'}}, {'term': {'Drought Tolerant': True}}], 'must_not': [], 'should': [{'terms': {'Plant Type.keywords': ['Shrub', 'Tree']}}, {'match': {'Flower Colour': 'pink blue white'}}]}}
{'Maximum Plant Count': 7, 'Ratio Native': 0.4, 'Plant Type': ['Shrub', 'Tree'], 'Attracted Animals': [], 'Light Preference': 'Semi Shade'}
Not enough data was retrieved from the dataset. Required 7 but retrieved 0


In [29]:
# Expected Input to Backend
import random
api_call = {
    "prompt": "Create a minimalist area with soft beige and light green tones",
    "maximum_plant_count": random.randint(3, 6),
    "light_preference": "Full Sun",
    "water_preference": "Moderate Water",
    "drought_tolerant": False,
    "fauna_attracted": [],
    "ratio_native": round(random.uniform(0,1), 1)
}

In [30]:
query, requirements = generate_query(data_list, api_call)
print(api_call['prompt'])
print(query)
print(requirements)
result = retrieve_results(esManager, query, requirements)

Create a minimalist area with soft beige and light green tones
{'bool': {'must': [{'terms': {'Light Preference.keyword': ['Full Sun', 'Full Shade']}}, {'term': {'Water Preference.keyword': 'Moderate Water'}}, {'term': {'Drought Tolerant': False}}], 'must_not': [], 'should': [{'terms': {'Plant Type.keywords': ['Shrub', 'Tree']}}, {'match': {'Flower Colour': 'Light Green Beige'}}, {'match': {'Mature Leaf Colour': 'Light Green Beige'}}, {'match': {'Young Flush Leaf Colour': 'Light Green Beige'}}, {'match': {'Trunk Colour': 'Light Green Beige'}}]}}
{'Maximum Plant Count': 5, 'Ratio Native': 0.7, 'Plant Type': ['Shrub', 'Tree'], 'Attracted Animals': [], 'Light Preference': 'Full Sun'}
Retrieved 21 results.
{'Scientific Name': "Radermachera 'Kunming'", 'Common Name': 'Dwarf Tree Jasmine', 'Species ID': 2381, 'Link': 'https://www.nparks.gov.sg/florafaunaweb/flora/2/3/2381', 'Plant Type': ['Shrub'], 'Light Preference': ['Full Sun', 'Semi Shade'], 'Water Preference': ['Moderate Water'], 'Drough

In [None]:
import pandas as pd
import numpy as np
from collections import Counter

def rerank_result(result, other_requirements):
    # Function to simply rerank the scores of all the results
    """
    Reranking ElasticSearch results with the following:
    1. Height, if Height = Tall, rerank from tall to short, with +0.2 incremental score in each plant type
    2. Plant Type: If provided and not Shrub & Tree, +1 score wise
    3. Attracted Animals: If provided, +2 score wise
    4. Native Habitat: Calculate all unique instances of Native Habitat, add score of counts / total (we want same habitats to be boosted)
    5. Leaf Texture: Calculate all unique instance of Leaf Texture, add 1/ counts (we want different textures)
    """

    # Convert response into a pandas dataframe for easy access
    processed_data = [
        {**{'_score': item['_score']}, **item['_source']}
        for item in result['api_resp']
    ]

    data = pd.DataFrame(processed_data)

    # Height
    height_value = other_requirements.get('Height', None)
    if height_value != None:
        sorted_data = data.sort_values(by="Maximum Height (m)", ascending= (False if height_value == 'Tall' else True))
        for plant_type in [["Shrub"], ["Tree", "Palm"]]:
            # Filter rows where Plant Type contains the target type, then apply cumulative scoring
            mask = sorted_data["Plant Type"].apply(lambda x: any(plant_t in x for plant_t in plant_type))
            num_items = mask.sum()
            incremental_values = np.arange(num_items - 1, -1, -1) * 0.2
            # Count from the bottom +0.25 * num of value from bottom for specific class
            sorted_data.loc[mask, "_score"] += incremental_values
            data = sorted_data
        
    # Plant Type
    prompt_plant_type = other_requirements['Plant Type'].copy()
    prompt_plant_type.remove('Tree')
    prompt_plant_type.remove('Shrub')
    # Unique plant type exists
    if len(prompt_plant_type) > 0:
        mask = data["Plant Type"].apply(lambda x: any(plant_t in x for plant_t in prompt_plant_type))
        data.loc[mask, "_score"] += 1

    # Attracted Animals
    prompt_attracted_animals = other_requirements['Attracted Animals']
    if len(prompt_attracted_animals) > 0:
        mask = data['Attracted Animals'].str.contains('|'.join(prompt_attracted_animals), case=False)
        data.loc[mask, "_score"] += 2

    # Native Habitat
    # Counting all habitat count
    habitat_counter = Counter()
    for habitat in data['Native habitat']:
        if "(" in habitat:
            # Get overaching and sub habitats
            overarching_habitat, sub_habitat_list = extract_overarching_subhabitat(habitat)
            # Counter
            habitat_counter[overarching_habitat] += 1
            for sub_value in sub_habitat_list:
                habitat_counter[f'{overarching_habitat} ({sub_value.strip()})'] += 1
        # Only overaching, no sub habitats
        else:
            habitat_counter[habitat.strip()] += 1

    # Applying score of habitat counts
    # If just overarching terrestial, ignore
    # Else we + the best score which is the maxcount / total count
    for index, row in data.iterrows():
        habitat = row['Native habitat']
        if "(" in habitat:
            # Get overaching and sub habitats
            overarching_habitat, sub_habitat_list = extract_overarching_subhabitat(habitat)
            # Get counter score
            sub_counts = [habitat_counter.get(f"{overarching_habitat} ({sub.strip()})", 0) for sub in sub_habitat_list]
            max_sub_count = max(sub_counts)
            score_ratio = max_sub_count / habitat_counter.get(overarching_habitat)
            data.at[index, '_score'] += score_ratio
        # Only overaching, no sub habitats, no score addition

    # Leaf Texture
    flattened_leaf_textures = data['Leaf Texture'].explode() # Remove list
    leaf_texture_counts = flattened_leaf_textures.value_counts()
    total_count = len(data)
    leaf_texture_ratios =  (1 - (leaf_texture_counts / total_count))/2
    data['_score'] += data['Leaf Texture'].apply(lambda x: sum(leaf_texture_ratios.get(texture, 0) for texture in x))

    return data.sort_values(by='_score', ascending=False) # Descending Score 

    
def extract_overarching_subhabitat(habitat):
    # Get overaching and sub habitats
    overarching_habitat = habitat[:habitat.index("(")].strip()
    sub_habitat = habitat[habitat.index("(")+1:habitat.index(")")]
    # Split sub habitats
    if "," in sub_habitat:
        sub_habitat_list = sub_habitat.split(",")
    else:
        sub_habitat_list = [sub_habitat]
    
    return overarching_habitat, sub_habitat_list

In [None]:
reranked_result = rerank_result(result, requirements)

Counter({'Terrestrial': 21, 'Terrestrial (Secondary Rainforest)': 10, 'Terrestrial (Primary Rainforest)': 7, 'Terrestrial (Riverine)': 5, 'Terrestrial (Freshwater Swamp Forest)': 4, 'Terrestrial (Monsoon Forest)': 3, 'Terrestrial (Grassland / Savannah/ Scrubland)': 2, 'Terrestrial (Coastal Forest)': 2})


### Selecting Data

In [1]:
import math

def select_data(reranked_df, other_requirements):
    num_native_species = math.ceil(other_requirements['Ratio Native'] * other_requirements['Maximum Plant Count'])
    light_requirements = other_requirements['Light Preference']
    # Full Sun + Full Shade
    if light_requirements != 'Full Shade':
        # 1-5 plants: 1 shade loving, 6 will be 2 shade loving (2,2,2)
        num_shade_shrubs = math.ceil(other_requirements['Maximum Plant Count'] / 5)
        # Remainder divided 2 or total num of trees frm db
        num_tree = min(math.floor((other_requirements['Maximum Plant Count']-num_shade_shrubs)/2), reranked_df['Plant Type'].apply(lambda x: 'Tree' in x or 'Palm' in x).sum())
        # Remainder
        num_shrub = other_requirements['Maximum Plant Count'] - num_shade_shrubs - num_tree

    # Light preference is already full Shade
    else:
        num_shade_shrubs = 0
        # Divided by 2 or total num of trees in db
        num_tree = min(math.floor((other_requirements['Maximum Plant Count'])/2), reranked_df['Plant Type'].apply(lambda x: 'Tree' in x or 'Palm' in x).sum())
        # Remainder
        num_shrub = other_requirements['Maximum Plant Count'] - num_tree

    # Priority list
    # 1. native species, 2. number of trees, shade, shrubs
    native_df = reranked_df[reranked_df['Native to SG'] == True]
    non_native_df = reranked_df[reranked_df['Native to SG'] == False]

    # Retrieving Native By Plant Type
    native_tree = native_df[native_df['Plant Type'].apply(lambda x: any(pt in ['Tree', 'Palm'] for pt in x))]
    native_shrub = native_df[native_df['Plant Type'].apply(lambda x: 'Shrub' in x)]
    native_shade_shrub = native_df[(native_df['Plant Type'].apply(lambda x: 'Shrub' in x)) &
                                    (native_df['Light Preference'].apply(lambda x: 'Full Shade' in x))]

    # Retrieving Non-Native By Plant Type
    non_native_tree = non_native_df[non_native_df['Plant Type'].apply(lambda x: any(pt in ['Tree', 'Palm'] for pt in x))]
    non_native_shrub = non_native_df[non_native_df['Plant Type'].apply(lambda x: 'Shrub' in x)]
    non_native_shade_shrub = non_native_df[(non_native_df['Plant Type'].apply(lambda x: 'Shrub' in x)) &
                                           (non_native_df['Light Preference'].apply(lambda x: 'Full Shade' in x))]

    # Not enough native plants, all native_df will go into result
    if len(native_df) <= num_native_species:
        # Calculate required trees, shrubs and shade_shrub
        # There is a bug for when the required are negative, the overall may be larger than the actual results...
        tree_required = max(num_tree - len(native_tree), 0)
        shrub_required = max(num_shrub - len(native_shrub), 0)
        shade_shrub_required = max(num_shade_shrubs - len(native_shade_shrub), 0)

        result = native_df
        # Sort based on the length of each dataframe
        non_native_plants = [[non_native_tree, tree_required], [non_native_shrub, shrub_required], [non_native_shade_shrub, shade_shrub_required]]
        sorted_non_native_plants = sorted(non_native_plants, key=lambda x: len(x[0]))

        diff = 0
        for index, [non_native_df, required_count] in enumerate(sorted_non_native_plants):
            if index == 0:
                best_k_data = non_native_df.head(required_count)
                result = pd.concat([result, best_k_data], ignore_index=True)
                if len(best_k_data) < required_count:
                    diff = len(best_k_data) - required_count
            
            elif index == 1:
                # Get 1/2 of the difference from prev
                new_required_count = required_count + diff//2
                best_k_data = non_native_df.head(new_required_count)
                result = pd.concat([result, best_k_data], ignore_index=True)

            else:
                # Get all the remaining
                new_required_count = other_requirements['Maximum Plant Count'] - len(result)
                best_k_data = non_native_df.head(new_required_count)
                result = pd.concat([result, best_k_data], ignore_index=True)

        return result
        
    if len(native_df) > num_native_species:
        # Extract num_trees, num_shade, num_shade_shrub data from native and non-native
        best_native_tree = native_tree.head(num_tree)
        best_non_native_tree = non_native_tree.head(num_tree)
        best_native_shrub = native_shrub.head(num_shrub)
        best_non_native_shrub = non_native_shrub.head(num_shrub)
        best_native_shade_shrub = native_shade_shrub.head(num_shade_shrubs)
        best_non_native_shade_shrub = non_native_shade_shrub.head(num_shade_shrubs)

        total_dataset = pd.concat([best_native_tree, best_non_native_tree, best_native_shrub, best_non_native_shrub, best_native_shade_shrub, best_non_native_shade_shrub], ignore_index=True)
        sorted_total_dataset = total_dataset.sort_values(by='_score', ascending=False)

        result = pd.DataFrame(columns=sorted_total_dataset.columns)
        unused_data = pd.DataFrame(columns=sorted_total_dataset.columns)
        num_native = 0
        num_non_native = 0
        current_trees = 0
        current_shrubs = 0
        current_shade_shrubs = 0

        count_dictionary = {
            "Tree": [current_trees, num_tree],
            "Shrub": [current_shrubs, num_shrub],
            "Shrub Full Shade": [current_shade_shrubs, num_shade_shrubs],
        }

        for _, row in sorted_total_dataset.iterrows():
            plant_type_key = ["Tree", "Palm", "Shrub"]
            plant_key = [plant for plant in row['Plant Type'] if plant in plant_type_key][0]
            full_shade = True if ('Full Shade' in row['Light Preference'] and other_requirements['Light Preference'] != 'Full Shade') else False

            if row['Native to SG'] and num_native < num_native_species:
                if plant_key == "Shrub" and full_shade and count_dictionary["Shrub Full Shade"][0] < count_dictionary["Shrub Full Shade"][1]:
                    current_shade_shrubs += 1
                    num_native += 1
                    result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                
                elif plant_key == "Shrub" and count_dictionary["Shrub"][0] < count_dictionary["Shrub"][1]:
                    current_shrubs += 1
                    num_native += 1
                    result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                
                elif plant_key == "Palm" or plant_key == "Tree" and count_dictionary["Tree"][0] < count_dictionary["Tree"][1]:
                    current_trees += 1
                    num_native += 1
                    result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                
            elif not row['Native to SG'] and num_non_native < (other_requirements['Maximum Plant Count'] - num_native_species):
                if plant_key == "Shrub" and full_shade and count_dictionary["Shrub Full Shade"][0] < count_dictionary["Shrub Full Shade"][1]:
                    current_shade_shrubs += 1
                    num_non_native += 1
                    result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                
                elif plant_key == "Shrub" and count_dictionary["Shrub"][0] < count_dictionary["Shrub"][1]:
                    current_shrubs += 1
                    num_non_native += 1
                    result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                
                elif plant_key == "Palm" or plant_key == "Tree" and count_dictionary["Tree"][0] < count_dictionary["Tree"][1]:
                    current_trees += 1
                    num_non_native += 1
                    result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)

            else:
                unused_data =  pd.concat([unused_data, pd.DataFrame([row])], ignore_index=True)

            if len(result) == other_requirements['Maximum Plant Count']:
                break
        
        if len(result) < other_requirements['Maximum Plant Count']:
            difference = other_requirements['Maximum Plant Count'] - len(result) 
            additional_data = unused_data.head(difference)
            result = pd.concat([result, additional_data], ignore_index=True)
        
        return result



In [81]:
print(requirements)
selected_result = select_data(reranked_result, requirements)

{'Maximum Plant Count': 8, 'Ratio Native': 0.7, 'Plant Type': ['Shrub', 'Tree'], 'Attracted Animals': [], 'Light Preference': 'Full Sun'}


In [82]:
requirements

{'Maximum Plant Count': 8,
 'Ratio Native': 0.7,
 'Plant Type': ['Shrub', 'Tree'],
 'Attracted Animals': [],
 'Light Preference': 'Full Sun'}

In [79]:
selected_result

Unnamed: 0,_score,Scientific Name,Common Name,Species ID,Link,Plant Type,Light Preference,Water Preference,Drought Tolerant,Native to SG,...,Hazard,Attracted Animals,Native habitat,Mature Leaf Colour,Young Flush Leaf Colour,Leaf Area Index,Growth Rate,Trunk Texture,Trunk Colour,Leaf Texture
0,4.667897,Sterculia macrophylla Vent.,Broad-leaved Sterculia,3138,https://www.nparks.gov.sg/florafaunaweb/flora/...,[Tree],[Full Sun],"[Lots of Water, Moderate Water]",False,True,...,,Bird-Attracting,"Terrestrial (Primary Rainforest, Secondary Rai...",Green,Red,3.0 (Tree - Intermediate Canopy),[Moderate],smooth,light grey,[N/A]
1,2.266846,Cratoxylum cochinchinense (Lour.) Blume,Derum Selunchor,2829,https://www.nparks.gov.sg/florafaunaweb/flora/...,"[Shrub, Tree]","[Full Sun, Semi Shade]",[Moderate Water],False,True,...,,"Butterfly Host Plant (Leaves, Associated with:...","Terrestrial (Secondary Rainforest, Primary Rai...",Green,Red,3.0 (Tree - Intermediate Canopy),[Moderate],"Peeling / Flaking / Papery, Smooth",Reddish-brown,[Medium]
2,2.243037,Aglaonema simplex (Blume) Blume,Malayan Sword,3740,https://www.nparks.gov.sg/florafaunaweb/flora/...,"[Herbaceous Plant, Shrub]","[Full Shade, Semi Shade]","[Lots of Water, Moderate Water]",False,True,...,,-,"Terrestrial (Primary Rainforest, Secondary Rai...",Green,,3.5 (Shrub & Groundcover - Monocot),[Slow],,,[Coarse]
3,2.10018,Horsfieldia irya (Gaertn.) Warb.,Pianggu,2964,https://www.nparks.gov.sg/florafaunaweb/flora/...,[Tree],[Full Sun],"[Lots of Water, Moderate Water]",False,True,...,,Bird-Attracting (Fruits),"Terrestrial (Primary Rainforest, Coastal Fores...",Green,,3.0 (Tree - Intermediate Canopy),[Moderate],"Fissured, Cracked",red,[N/A]
4,2.10018,Parkia speciosa Hassk.,Petai,3052,https://www.nparks.gov.sg/florafaunaweb/flora/...,[Tree],[Full Sun],[Moderate Water],False,True,...,,"Bird-Attracting, Butterfly Host Plant, Bat Food","Terrestrial (Primary Rainforest, Secondary Rai...",Green,,2.5 (Tree - Open Canopy),[Moderate],smooth,reddish-brown,[N/A]
5,1.623989,Oncosperma tigillarium,Nibung,2659,https://www.nparks.gov.sg/florafaunaweb/flora/...,[Palm],[Full Sun],[Moderate Water],False,True,...,Spines/Thorns - Trunk,-,Terrestrial,Green,,4.0 (Palm - Cluster),[Moderate],-,black,[N/A]
6,2.220807,Acrotrema costatum Jack,Yellow Jungle Star,3739,https://www.nparks.gov.sg/florafaunaweb/flora/...,"[Creeper, Herbaceous Plant, Shrub]","[Full Shade, Semi Shade]",[Moderate Water],False,False,...,,-,"Terrestrial (Secondary Rainforest, Monsoon For...","Green, Silver / Grey",,4.5 (Shrub & Groundcover - Dicot),[Moderate],,,[Coarse]
7,5.548068,Radermachera 'Kunming',Dwarf Tree Jasmine,2381,https://www.nparks.gov.sg/florafaunaweb/flora/...,[Shrub],"[Full Sun, Semi Shade]",[Moderate Water],False,False,...,,Butterfly-Attracting,Terrestrial,Green,"Green, Green - Light Green",4.5 (Shrub & Groundcover - Dicot),[Moderate],,,[Medium]


## Overall Code

In [None]:
# Import and variables
import collections
import os
from typing import Union, List, Dict
import pandas as pd

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan, streaming_bulk

# Map common python types to ES Types
TYPE_MAP = {
    "int": "integer",
    "float": "float",
    "double": "double",
    "str": "text",
    "bool": "boolean",
    "datetime": "date",
    "list[int]": "integer",
    "list[str]": "text",
    "list[float]": "float",
    "list[double]": "double",
    "torch.tensor": "dense_vector",
    "numpy.ndarray": "dense_vector",
    "keyword": "keyword"
}

MAX_BULK_SIZE = 100

os.environ['ELASTIC_USERNAME'] = 'elastic'
os.environ['ELASTIC_PASSWORD'] = '' # To be filled
os.environ['ELASTIC_PORT'] = '9200'
os.environ['ELASTIC_HOST'] = 'localhost'

class ESManager():
    """
    Class to manage ElaticSearch
    """
    def __init__(self):
        self.url = f"http://{os.getenv('ELASTIC_HOST')}:{os.getenv('ELASTIC_PORT')}"
        self.username =  os.getenv('ELASTIC_USERNAME')
        self.password = os.getenv('ELASTIC_PASSWORD')

        self.client = Elasticsearch(self.url,
                                    verify_certs=False,
                                    basic_auth=(self.username, self.password), request_timeout=30, max_retries=10, retry_on_timeout=True)

        print(self.client.info())

        self.consolidated_actions = []

    def _check_data_type(self, var, var_type):
        try:
            assert type(var) == var_type
        except:
            return False
        return True

    def _check_valid_values(self, map_dict: dict) -> int:
        """
        Traverse mapping dictionary to ensure that all types are valid types within TYPE_MAP

        Args:
            map_dict (dict): Mapping to be checked

        Returns:
            int: 0 if there is invalid types, 1 otherwise

        """
        ret_val = 1
        for k, v in map_dict.items():
            if isinstance(v, dict):
                ret_val = self._check_valid_values(v)
            else:
                if not v in TYPE_MAP:
                    print(f"'{v}' type for '{k}' NOT FOUND")
                    return 0

        return ret_val * 1

    def _traverse_map(self, map_dict: Dict) -> Dict:
        """
        Traverse mapping dictionary to convert data type into framework specific type

        Args:
            map_dict (dict): Mapping to be used to create ES index

        Returns:
            dict: updated mapping dictionary

        """
        dictionary = {"properties": dict()}
        for k, v in map_dict.items():
            if isinstance(v, dict):
                dictionary['properties'][k] = self._traverse_map(v)
            else:
                dictionary['properties'][k] = {"type": TYPE_MAP[v]}
        return dictionary

    def _flush(self):
        errors = []
        list_of_es_ids = []
        for ok, item in streaming_bulk(self.client, self.consolidated_actions):
            if not ok:
                errors.append(item)
            else:
                list_of_es_ids.append(item['index']['_id'])
        if len(errors) != 0:
            print("List of faulty documents:", errors)
        self.consolidated_actions = []  # Reset List
        return list_of_es_ids

    def _flatten(self, d, parent_key='', sep='.'):
        """
        Flatten nested dictionary keys to dotted parameters because Elasticsearch. 
        """
        items = []
        for k, v in d.items():
            new_key = parent_key + sep + k if parent_key else k
            if isinstance(v, collections.MutableMapping):
                items.extend(self._flatten(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

    def create_collection(self, collection_name: str, schema: Dict, custom_schema: bool = False) -> Dict:
        """
        Create the index on ElasticSearch

        Args:
            collection_name (str): Index name of ES
            schema (dict): Mapping to be used to create ES index
            custom_schema (bool): If set to True, user may input schema that in accordance to ElasticSearch Mapping's format. The schema will not be parsed. 

        Returns:
            dict: response of error, or 200 if no errors caught

        """
        if not self._check_data_type(schema, dict):
            return {"response": "Type of 'schema' is not dict"}
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(custom_schema, bool):
            return {"response": "Type of 'custom_schema' is not bool"}
        if custom_schema:
            try:
                self.client.indices.create(
                    index=collection_name, mappings=schema)
            except Exception as e:
                return {"response": f"{e}"}
            return {"response": "200"}
        else:
            mapping_validity = self._check_valid_values(schema)
            if not mapping_validity:
                return {"response": "KeyError: data type not found in TYPE_MAP"}
            updated_mapping = self._traverse_map(schema)
            try:
                self.client.indices.create(
                    index=collection_name, mappings=updated_mapping)
            except Exception as e:
                return {"response": f"{e}"}
            return {"response": "200"}

    def delete_collection(self, collection_name: str) -> dict:
        """
        Create the index on ElasticSearch

        Args:
            collection_name (str): Index name of ES
            schema (dict): Mapping to be used to create ES index

        Returns:
            dict: response of error, or 200 if no errors caught

        """
        try:
            self.client.indices.delete(index=collection_name)
        except Exception as e:
            return {"response": f"{e}"}
        return {"response": "200"}

    def create_document(self, collection_name: str, documents: Union[list, dict], id_field: str = None) -> dict:
        """
        Upload document(s) in the specified index within ElasticSearch

        Args:
            collection_name (str): Index name of ES
            documents (dict, list): A dict of document objects to be ingested. A list of dict is accepted as well. 
            id_field (str, Optional): Specify the key amongst the document object to be the id field. If not specified, id will be generated by ES. 

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the ids of ingested document if no errors caught

        """
        if not self._check_data_type(documents, list):
            if not self._check_data_type(documents, dict):
                return {"response": "Type of 'documents' is not dict or a list"}
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not id_field is None:
            if not self._check_data_type(id_field, str):
                return {"response": "Type of 'id_field' is not str"}

        # If single document, wrap it in a list so it can be an iterable as it would be when a list of document is submitted
        if type(documents) == dict:
            documents = [documents]

        # If id_field is specified, verify that all documents possess the id_field.
        if id_field != None:
            for doc in documents:
                if not id_field in doc.keys():
                    print(
                        "Fix document, or set 'id_field' to None. No documents uploaded.")
                    return {"response": "Fix document, or set 'id_field' to None. No documents uploaded.",
                            "error_doc": doc}
                try:
                    doc[id_field] = str(doc[id_field])
                except Exception as e:
                    return {"response": "id cannot be casted to String type. No documents uploaded.",
                            "error_doc": doc}
        all_id = []
        for doc in documents:
            doc_copy = dict(doc)
            action_dict = {}
            action_dict['_op_type'] = 'index'
            action_dict['_index'] = collection_name
            if id_field != None:
                action_dict['_id'] = doc_copy[id_field]
                doc_copy.pop(id_field)
            action_dict['_source'] = doc_copy
            self.consolidated_actions.append(action_dict)
            if len(self.consolidated_actions) == MAX_BULK_SIZE:
                all_id = all_id+self._flush()

        all_id = all_id+self._flush()

        return {"response": "200", "ids": all_id}

    def delete_document(self, collection_name: str, doc_id: str) -> dict:
        """
        Delete document from index based on the specified document id. 

        Args:
            collection_name (str): Index name of ES
            doc_id (str): id of doc to be deleted

        Returns:
            dict: response of error along with the faulty document, or code 200 along with elastic API response

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(doc_id, str):
            return {"response": "Type of 'doc_id' is not str"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query={
                                           "match": {"_id": doc_id}})
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"Document '{doc_id}' not found!"}

        try:
            resp = self.client.delete(index=collection_name, id=doc_id)
        except Exception as e:
            return {"response": f"{e.__class__.__name__}. Document Deletion failed"}

        return {"response": "200", "api_resp": resp}

    def update_document(self, collection_name: str, doc_id: str, document: dict) -> dict:
        """
        Delete document from index based on the specified document id. 

        Args:
            collection_name (str): Index name of ES
            doc_id (str): id of doc to be updated
            document (dict): key and values of fields to be updated.

        Returns:
            dict: response of error along with the faulty document, or code 200 along with elastic API response

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(doc_id, str):
            return {"response": "Type of 'doc_id' is not str"}
        if not self._check_data_type(document, dict):
            return {"response": "Type of 'document' is not dict"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query={
                                           "match": {"_id": doc_id}})
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"Document '{doc_id}' not found, create document first"}

        try:
            for key in document.keys():

                q = {
                    "script": {
                        "source": f"ctx._source.{key}=params.infer",
                        "params": {
                            "infer": document[key]
                        },
                        "lang": "painless"
                    },
                    "query": {
                        "match": {
                            "_id": doc_id
                        }
                    }
                }
                resp = self.client.update_by_query(
                    body=q, index=collection_name)
        except Exception as e:
            return {"response": f"{e.__class__.__name__}. Document Update failed"}

        return {"response": "200", "api_resp": resp}

    def read_document(self, collection_name: str, doc_id: str) -> dict:
        """
        Read document from index based on the specified document id. 

        Args:
            collection_name (str): Index name of ES
            doc_id (str): id of doc to be read

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the retrieved document

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(doc_id, str):
            return {"response": "Type of 'doc_id' is not str"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query={
                                           "match": {"_id": doc_id}})
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"Document '{doc_id}' not found!"}

        doc_body = search_result['hits']['hits']

        return {"response": "200", "api_resp": doc_body}

    def query_collection(self, collection_name: str, field_value_dict: dict) -> dict:
        """
        Read document from index based on the specific key-value dictionary query. 

        Args:
            collection_name (str): Index name of ES
            field_value_dict (dict): A dictionary with the field to be queried as the key, and the value to be queried as the value of the dictionary. 
                                    example: {"field1":"query1", "field2", "query2"}

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the list of retrieved document

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(field_value_dict, dict):
            return {"response": "Type of 'field_value_dict' is not dict"}

        # Check for document's existence
        reorg_dict = {"bool":{
            "should":[]
            }
        }
        for field in field_value_dict:
            reorg_dict['bool']['should'].append({"match":{field:field_value_dict[field]}})

        search_result = self.client.search(index=collection_name, query=reorg_dict)
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"No documents found."}

        docs = search_result['hits']['hits']

        return {"response": "200", "api_resp": docs}

    def custom_query(self, collection_name: str, query: dict, size:int=10) -> dict:
        """
        Read document from index based on custom ES query syntax. 

        Args:
            collection_name (str): Index name of ES
            query (dict): Custom query for ES users who are familiar with the query format
            size (int): Number of results to return per query

        Returns:
            dict: response of error along with the faulty document, or code 200 along with the list of retrieved document

        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        if not self._check_data_type(query, dict):
            return {"response": "Type of 'field_value_dict' is not dict"}

        # Check for document's existence
        search_result = self.client.search(index=collection_name, query=query, size=size)
        result_count = search_result['hits']['total']['value']

        if result_count == 0:
            return {"response": f"No documents found."}

        docs = search_result['hits']['hits']

        return {"response": "200", "api_resp": docs}

    def get_all_documents(self, collection_name: str) -> dict:
        """
        Generator method to retrieve all documents within the index

        Args:
            collection_name (str): Index name of ES

        Returns:
            Generator Object: Iterable object containing all documents within index specified. 
        """
        if not self._check_data_type(collection_name, str):
            return {"response": "Type of 'collection_name' is not str"}
        docs_response = scan(self.client, index=collection_name, query={
                             "query": {"match_all": {}}})
        for item in docs_response:
            yield item

In [2]:
from fuzzywuzzy import fuzz, process  # for fuzzy matching
from nltk.stem import WordNetLemmatizer
import os
import json
from openai import OpenAI

class ESPlantQueryGenerator():
    """
    Class to Generate ElasticSearch Query from input prompt and options
    """
    def __init__(self, 
                  gpt_extract:bool=False, 
                  function_requirements:str='../src/input/function_requirements.json',
                  es_query_requirements:str='../src/input/es_query_requirements.json',
                  keyword_schema:str='../src/input/schema/keyword_schema.json', 
                  query_schema:str='../src/input/schema/query_schema.json'):
        
        """
        Args:
            gpt_extract (bool, Optional): option to use GPT to extract keywords, if not string matching will automatically be used. Default to False.
            function_requirements (str, Optional): Filepath to JSON with all the key:value pairs for different functions
            es_query_requirements (str, Optional): Filepath to JSON containing all database key and their importance level (should be adhered to / must be adhered to)
            keyword_schema (str, Optional): Filepath to GPT's JSON schema for keyword extraction
            query_schema (str, Optional): Filepath to GPT's JSON schema for query generation
        """
        
        # GPT setup
        self.gpt_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self.keyword_schema_path =  keyword_schema
        self.query_schema_path = query_schema

        # Keyword Extraction Setup
        self.data_list = {
            "function": ["Boundary",  "Playground", "Active Zone", "Central Precinct Garden", "Passive Zone", "Sitting Corner", "Pavilion", "Pergola", "Future Community Garden", "Butterfly Garden"],
            "style": ["Naturalistic", "Manicured",  "Meadow", "Ornamental", "Minimalist", "Formal", "Picturesque", "Rustic", "Plantation"],
            "surrounding": {
                "Boundary": "Road", 
                "Playground": "Walkway", 
                "Active Zone": "Walkway", 
                "Central Precinct Garden": "Walkway",
                "Passive Zone": "Walkway", 
                "Sitting Corner": "Walkway", 
                "Pavilion": "Walkway", 
                "Pergola": "Walkway", 
                "Future Community Garden": "Walkway", 
                "Butterfly Garden": "Walkway"
            }
        }
        self.extract_keyword = self._extract_gpt_keyword if gpt_extract else self._extract_string_matching_keyword

        # Query Creation Setup
        self.function_design_requirements = self._load_JSON(function_requirements)
        self.es_query_requirements = self._load_JSON(es_query_requirements)   

    def _load_JSON(self, filepath:str):
        """
        Function to load a JSON object

        Args:
            filepath (str): filepath to json
        """
        with open(filepath, 'r') as file:
            data = json.load(file)
        return data


    def _extract_string_matching_keyword(self, prompt:str):
        """
        Function to extract function, style and surrounding from prompt using string matching
        These keywords are already pre-defined in self.data_list

        Args:
            prompt (str): user prompt from api call

        Returns:
            result (dictionary): dictionary in the format of {function: x, style:x, surrounding:x}, x will be from self.data_list or None.
        """
        result = {}
        # Tokenise words
        input_words = prompt.split()

        # Get stems of input words
        lemmatizer = WordNetLemmatizer()
        base_words = [lemmatizer.lemmatize(word) for word in input_words]

        for key, targets in self.data_list.items():
            # Retrieve results for function and style
            if isinstance(targets, list):
                base_targets = [lemmatizer.lemmatize(word) for word in targets]
                match = process.extractOne(" ".join(base_words), base_targets, scorer=fuzz.token_set_ratio, score_cutoff=60) #retrieve best matching word
                result[key] = match[0] if match != None else None
            # Extract surrounding data
            else:
                result[key] = targets.get(result['function'], None)
        
        return result
    

    def _extract_gpt_keyword(self, prompt:str):
        """
        Function to extract function, style and surrounding from prompt using gpt
        These keywords are already pre-defined in self.data_list

        Args:
            prompt (str): user prompt from api call

        Returns:
            result (dictionary): dictionary in the format of {function: x, style:x, surrounding:x}, x will be from self.data_list or None.
        """
        # GPT keyword extraction variables
        system_instruction = f"""You are an information extraction model for a landscape architect.
        Given user queries, you are to extract out the function (the purpose of the landscape area) and the style (planting style) from the input.
        The extracted value must be in the possible function and styles, else assign None.

        All possible function:
        {self.data_list['function']}
        
        All possible style:
        {self.data_list['style']}
        """
        # Load gpt return schema
        response_format = self._load_JSON(self.keyword_schema_path)

        # Base result
        result = {
            'function': None, 
            'style': None
        }
        # Query GPT to retrieve function and style
        completion = self.gpt_client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt}
            ],
            response_format=response_format,
        )
        # Convert string response to JSON
        response_json = json.loads(completion.choices[0].message.content)

        for key, targets in response_json.items():
            if targets != 'None'and targets != None:
                # Valid response, update result
                result[key] = targets

        result['surrounding'] = self.data_list['surrounding'].get(result['function'], None)
        
        return result 
    

    def _extract_query_values(self, prompt:str):
        """
        Function to extract elastic search query values from prompt using gpt
        Mainly the requirement of any specific plant type, attracted animals and colour

        Args:
            prompt (str): user prompt from api call

        Returns:
            gpt_response (dictionary): dictionary of GPT response, in the format in query_schema
        """
        
        # Query generation variables
        system_instruction = f"""You are a database querying model to retrieve the required plants for a landscape design.
        
        Using the user requirements, return the value for each database key. 
        If a key is invalid to the querying, let the value for that key be None.
        For colours, unless explicitedly mentioned should be for flowers, apply the colour to all colour fields.
        
        ---------------------
        Database keys and possible values:
        - Plant Type: list[str] (Palm, Herbaceous Plants)
        - Fruit Bearing: str (True, False, None)
        - Fragrant Plant: str (True, False, None)
        - Maximum Height (m): int 
        - Height: str (Tall, Short, None)
        - Flower Colour: list[str]
        - Attracted Animals: list[str] (Bird, Butterfly, Bee, Caterpillar Moth, Bat)
        - Avoid Animals: list[str] (Bird, Butterfly, Bee, Caterpillar Moth, Bat)
        - Mature Leaf Colour: list[str]
        - Young Flush Leaf Colour: list[str]
        - Trunk Colour: list[str]
        - Trunk Texture: list[str]
        - Leaf Texture: list[str] (Fine, Medium, Coarse)
        """
        # Load query return schema
        response_format = self._load_JSON(self.query_schema_path)
        # Extraction of Colours and Requirements
        completion = self.gpt_client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt}
            ],
            response_format=response_format,
        )
        return json.loads(completion.choices[0].message.content)


    def generate_query(self, user_call:dict):
        """
        Function to generate elasticsearch query from user call

        Args:
            user_call (dict): api call input in the format of 
        {
            "prompt": str, 
            "maximum_plant_count": int(3-8),
            "light_preference":"Full Shade"/ "Semi Shade"/ "Full Sun",
            "water_preference": "Lots of Water"/ "Moderate Water"/ "Little Water"/ "Occassional Misting",
            "drought_tolerant": True / False,
            "fauna_attracted": ["Butterfly", "Bird", "Caterpillar Moth", "Bat", "Bee"],
            "ratio_native": float(0-1)
        } 

        Returns:
            function_style_surrounding_extraction (dictionary): extracted function, style and surrounding
            query (dictionary): custom query for elastic search
            rerank_requirements (dictionary): key requirements for the reranking algorithm       
        """
        # Base Query from the other option values
        query = {
            "bool": {
                "must": [],
                "must_not": [],
                "should": []
            }
        }
        
        # Option Data ----------------------------------------------------------------------------------------
        # Light Preference, add Full Shade as well
        query["bool"][self.es_query_requirements["Light Preference"]].append({
            "terms": {"Light Preference.keyword": list(set(["Full Shade", user_call['light_preference']]))}
        })

        # Water Preference
        query["bool"][self.es_query_requirements["Water Preference"]].append({
            "term": {"Water Preference.keyword": user_call['water_preference']}
        })

        # Drought Tolerant
        query["bool"][self.es_query_requirements["Drought Tolerant"]].append({
            "term": {"Drought Tolerant": user_call['drought_tolerant']}
        })

        # Prompt Data ----------------------------------------------------------------------------------------
        # Extracting Function, Style & Surrounding from prompt and retrieve requirements
        function_style_surrounding_extraction = self.extract_keyword(user_call['prompt'])
        function_requirements = self.function_design_requirements.get(function_style_surrounding_extraction['function'], {})
        
        # Extract colour and other requirements from GPT
        gpt_response_json = self._extract_query_values(user_call['prompt'])

        # Plant Types, add shrub and tree
        function_plant_requirements = function_requirements.get("Plant Type", [])
        all_plant_types = list(set(["Tree", "Shrub"] + function_plant_requirements + gpt_response_json["Plant Type"]))
        query["bool"][self.es_query_requirements["Plant Type"]].append({
            "terms": {"Plant Type.keywords": all_plant_types}
        })

        # Hazard, if no hazard dont need to add to query, else Hazard must be N/A (no hazard)
        hazard_requirements = function_requirements.get("Hazard", None)
        if hazard_requirements != None:
            query["bool"][self.es_query_requirements["Hazard"]].append({
                "term": {"Hazard": hazard_requirements}
            })

        # Other variables
        for key, items in gpt_response_json.items():
            # Ignore non querying terms
            if key == "Plant Type" or key == "Height":
                pass

            # Maximum height, Prioritise the gpt response > preset function requirements
            elif key == "Maximum Height (m)":
                function_value = function_requirements.get(key, 0)
                # If there is a height constraint
                if function_value != 0 or items != 0:
                    height_constraints = items if items != 0 else function_value
                    query["bool"][self.es_query_requirements[key]].append({
                        "range": {
                            "Maximum Height (m)": {
                                "lt": height_constraints
                            }
                        }
                    })

            # Boolean Terms, Prioritise the gpt response > preset function requirements
            elif key == "Fruit Bearing" or key == "Fragrant Plant":
                gpt_value = True if items == "True" else False if items == "False" else None
                function_value = function_requirements.get(key, None)

                # only add query if both are not None
                if gpt_value != None and function_value != None:
                    query["bool"][self.es_query_requirements[key]].append({
                        "term": {key: gpt_value if gpt_value != None else function_value}
                    })

            # Attracted Animals, need include option data
            elif key == "Attracted Animals":
                function_value = function_requirements.get(key, [])
                all_animal_values = list(set(items + function_value + user_call['fauna_attracted']))
                if len(all_animal_values) > 0:
                    string_value = " ".join(all_animal_values)
                    query["bool"][self.es_query_requirements[key]].append({
                        "match": {key: string_value}
                    })

            # Avoid Animals is under Attraced Animals
            elif key == "Avoid Animals":
                function_value = function_requirements.get(key, [])
                all_values = list(set(items + function_value))
                if len(all_values) > 0:
                    string_value = " ".join(all_values)
                    query["bool"][self.es_query_requirements[key]].append({
                        "match": {"Attracted Animals": string_value}
                    })

            else:
                # List of Str terms, convert to a long string for a match query
                function_value = function_requirements.get(key, [])
                all_values = list(set(items + function_value))
                if len(all_values) > 0:
                    string_value = " ".join(all_values)
                    query["bool"][self.es_query_requirements[key]].append({
                        "match": {key: string_value}
                    })

        # Reranking Requirements, just for easy access
        rerank_requirements = {
            "Maximum Plant Count": user_call['maximum_plant_count'],
            "Ratio Native": user_call['ratio_native'],
            "Plant Type": all_plant_types,
            "Attracted Animals": all_animal_values,
            "Light Preference": user_call['light_preference']
        }
        # Add Height if not none
        if gpt_response_json['Height'] != "None" and gpt_response_json['Height'] != None:
            rerank_requirements['Height'] = gpt_response_json['Height']

        return function_style_surrounding_extraction, query, rerank_requirements



In [3]:
import pandas as pd
import numpy as np
from collections import Counter
import math

class PlantSelectionModel():
    """
    Class that contains entire pipeline to generate plant palette from user requirements
    """
    def __init__(self,
                 es_manager:ESManager,
                 plant_query_generator:ESPlantQueryGenerator,
                 collection_name:str='flora'
                 ):
        
        """
        Args:
            es_manager (ESManager): esManager instance
            plant_query_generator (ESPlantQueryGenerator): ESPlantQueryGenerator instance
            collection_name (str, Optional): 
            gpt_extract (bool, Optional): option to use GPT to extract keywords, if not string matching will automatically be used. Default to False.
            function_requirements (str, Optional): Filepath to JSON with all the key:value pairs for different functions
            es_query_requirements (str, Optional): Filepath to JSON containing all database key and their importance level (should be adhered to / must be adhered to)
            keyword_schema (str, Optional): Filepath to GPT's JSON schema for keyword extraction
            query_schema (str, Optional): Filepath to GPT's JSON schema for query generation
        """
        self.query_generator = plant_query_generator
        self.es_manager = es_manager
        self.collection_name = collection_name


    def retrieve_results(self, query:dict, light_preference:str, maximum_plant_count:int):
        """
        Function to query elasticSearch database to retrieve filtered plants
        Checks that the amount of plants returned is sufficient, or if user query was invalid

        Args:
            query (dict): Custom elasticsearch query
            light_preference (str): user light preference
            maximum_plant_count (int): maximum plant count for user plant palette

        Returns:
            results (list): list of results from elasticSearch, else empty list for no results
        """
        result = self.es_manager.custom_query(self.collection_name, query, 9999)
        if result['response'] == '200':
            print(f"Retrieved {len(result['api_resp'])} results.")
            # Enough data to meet maximum plant count
            if len(result['api_resp']) >= maximum_plant_count:
                return result['api_resp']
            
            # Not enough data, need to ensure at least 3 plants
            elif len(result['api_resp']) < 3:
                return []
            
            else:
                data = pd.DataFrame([item['_source'] for item in result])
                tree_count = data[data['Plant Type'].isin(['Tree', 'Palm'])].shape[0] # Get tree count
                shade_count = data[(data['Plant Type'] == 'Shrub')].shape[0] # Get shrub count
                shade_shrub_count = data[(data['Plant Type'] == 'Shrub') & (data['Light Preference'] == 'Full Shade')].shape[0] # Get shade loving count
                
                # Require 1 full shade and 1 light preference plant
                if light_preference != 'Full Shade':
                    non_shade_shrub_count = shade_count - shade_shrub_count
                else:
                    non_shade_shrub_count = 1 #Just to pass the checker 

                # Minimum 1 tree/palm, 1 shrub that is shade loving, 1 shrub for whatever light preference
                if tree_count >= 1 and shade_shrub_count >= 1 and non_shade_shrub_count >=1 :
                    return result['api_resp']
                else:
                    return []
        
        # No documents found
        else:
            print(result)
            return []        
        

    def rerank_results(self, results:list, rerank_requirements:dict):
        """
        Function to rerank the scores of all results
        Reranking ElasticSearch results with the following:
        1. Height, if Height = Tall, rerank from tall to short, with +0.2 incremental score in each plant type
        2. Plant Type: If provided and not Shrub & Tree, +1 score wise
        3. Attracted Animals: If provided, +2 score wise
        4. Native Habitat: Calculate all unique instances of Native Habitat, add score of counts / total (we want same habitats to be boosted)
        5. Leaf Texture: Calculate all unique instance of Leaf Texture, add (1-ratio)/2 (we want different textures)

        Args:
            results (list): results from database
            rerank_requirements (dict): reranked requirements to aid with the reranking scoring 
                {        
                "Maximum Plant Count": int,
                "Ratio Native": float,
                "Plant Type": list[str],
                "Attracted Animals": list[str],
                "Light Preference": str,
                "Height": str
                }
        
        Returns:
            reranked_results (pd.Dataframe)
        """
        # Convert response into a pandas dataframe for easy access
        processed_data = [
            {**{'_score': item['_score']}, **item['_source']}
            for item in results
        ]
        data = pd.DataFrame(processed_data)

        # Height
        height_value = rerank_requirements.get('Height', None)
        if height_value != None:
            sorted_data = data.sort_values(by="Maximum Height (m)", ascending= (False if height_value == 'Tall' else True))
            for plant_type in [["Shrub"], ["Tree", "Palm"]]:
                # Filter rows where Plant Type contains the target type, then apply cumulative scoring
                mask = sorted_data["Plant Type"].apply(lambda x: any(plant_t in x for plant_t in plant_type))
                num_items = mask.sum()
                incremental_values = np.arange(num_items - 1, -1, -1) * 0.2
                # Count from the bottom +0.2 * num of value from bottom for specific class
                sorted_data.loc[mask, "_score"] += incremental_values
                data = sorted_data

        # Plant Type
        prompt_plant_type = rerank_requirements['Plant Type'].copy()
        prompt_plant_type.remove('Tree')
        prompt_plant_type.remove('Shrub')
        # Unique plant type exists
        if len(prompt_plant_type) > 0:
            mask = data["Plant Type"].apply(lambda x: any(plant_t in x for plant_t in prompt_plant_type))
            data.loc[mask, "_score"] += 1

        # Attracted Animals
        prompt_attracted_animals = rerank_requirements['Attracted Animals']
        if len(prompt_attracted_animals) > 0:
            mask = data['Attracted Animals'].str.contains('|'.join(prompt_attracted_animals), case=False) # Finding any matching string
            data.loc[mask, "_score"] += 2

        # Native Habitat
        # Counting all habitat count
        habitat_counter = Counter()
        for habitat in data['Native habitat']:
            if "(" in habitat:
                # Get overaching and sub habitats
                overarching_habitat, sub_habitat_list = self._extract_overarching_subhabitat(habitat)
                # Counter
                habitat_counter[overarching_habitat] += 1
                for sub_value in sub_habitat_list:
                    habitat_counter[f'{overarching_habitat} ({sub_value.strip()})'] += 1
            # Only overaching, no sub habitats
            else:
                habitat_counter[habitat.strip()] += 1

        # Applying score of habitat counts
        # If just overarching terrestial, ignore
        # Else we + the best score which is the maxcount / total count
        for index, row in data.iterrows():
            habitat = row['Native habitat']
            if "(" in habitat:
                # Get overaching and sub habitats
                overarching_habitat, sub_habitat_list = self._extract_overarching_subhabitat(habitat)
                # Get counter score
                sub_counts = [habitat_counter.get(f"{overarching_habitat} ({sub.strip()})", 0) for sub in sub_habitat_list]
                max_sub_count = max(sub_counts)
                score_ratio = max_sub_count / habitat_counter.get(overarching_habitat)
                data.at[index, '_score'] += score_ratio
            # Only overaching, no sub habitats, no score addition

        # Leaf Texture
        flattened_leaf_textures = data['Leaf Texture'].explode() # Remove list
        leaf_texture_counts = flattened_leaf_textures.value_counts()
        total_count = len(data)
        leaf_texture_ratios =  (1 - (leaf_texture_counts / total_count))/2
        data['_score'] += data['Leaf Texture'].apply(lambda x: sum(leaf_texture_ratios.get(texture, 0) for texture in x))

        return data.sort_values(by='_score', ascending=False) # Descending Score 
    
    
    def _extract_overarching_subhabitat(self,habitat:str):
        """
        Extract all subhabitat and overaching habitat from string

        Args:
            habitat (str): Native Habitat of Data, in the format of overarchingHabitat (subHabitat)
        
        Returns:
            overarching_habitat (str): Overarching Habitat
            sub_habitat_list (list): List of all possible subHabitat, separated by ,
        """
        # Get overaching and sub habitats
        overarching_habitat = habitat[:habitat.index("(")].strip()
        sub_habitat = habitat[habitat.index("(")+1:habitat.index(")")]
        # Split sub habitats
        if "," in sub_habitat:
            sub_habitat_list = sub_habitat.split(",")
        else:
            sub_habitat_list = [sub_habitat]
        
        return overarching_habitat, sub_habitat_list
    

    def select_palette(self, reranked_result:pd.DataFrame, light_requirements:str, native_ratio:float, maximum_plant_count:int):
        """
        Function to select plant palette from reranked results, following the native ratio and maximum plant count

        Args:
            reranked_result (pd.DataFrame): reranked results
            light_preference (str): user light preference
            native_ratio (float): user native preference ratio
            maximum_plant_count (int): maximum plant count for user plant palette         

        Returns:
            result (pd.DataFrame): _description_
        """
        # Calculate required values by native & plant type
        num_native_species = math.ceil(native_ratio * maximum_plant_count)
        # 2 different light requirements 
        if light_requirements != 'Full Shade':
            # 1-5 plants: 1 shade loving, 6 will be 2 shade loving (2,2,2)
            num_shade_shrubs = math.ceil(maximum_plant_count / 5)
            # Remainder divided 2 or total num of trees frm db
            num_tree = min(math.floor((maximum_plant_count-num_shade_shrubs)/2), reranked_result['Plant Type'].apply(lambda x: 'Tree' in x or 'Palm' in x).sum())
            # Remainder
            num_non_shade_shrubs = maximum_plant_count - num_shade_shrubs - num_tree

        # Light preference is already full Shade, num_non_shade_shrub = 0
        else:
            num_non_shade_shrubs = 0
            # Divided by 2 or total num of trees in db
            num_tree = min(math.floor((maximum_plant_count)/2), reranked_result['Plant Type'].apply(lambda x: 'Tree' in x or 'Palm' in x).sum())
            # Remainder
            num_shade_shrubs = maximum_plant_count - num_tree

        # Priority list
        # 1. native species, 2. number of trees, shade, shrubs
        native_df = reranked_result[reranked_result['Native to SG'] == True]
        non_native_df = reranked_result[reranked_result['Native to SG'] == False]

        # Retrieving Native By Plant Type
        native_tree = native_df[native_df['Plant Type'].apply(lambda x: any(pt in ['Tree', 'Palm'] for pt in x))]
        native_non_shade_shrub = native_df[native_df['Plant Type'].apply(lambda x: 'Shrub' in x) & 
                                           (non_native_df['Light Preference'].apply(lambda x: 'Full Shade' not in x))]
        native_shade_shrub = native_df[(native_df['Plant Type'].apply(lambda x: 'Shrub' in x)) &
                                        (native_df['Light Preference'].apply(lambda x: 'Full Shade' in x))]

        # Retrieving Non-Native By Plant Type
        non_native_tree = non_native_df[non_native_df['Plant Type'].apply(lambda x: any(pt in ['Tree', 'Palm'] for pt in x))]
        non_native_non_shade_shrub = non_native_df[non_native_df['Plant Type'].apply(lambda x: 'Shrub' in x) &
                                                   (non_native_df['Light Preference'].apply(lambda x: 'Full Shade' not in x))]
        non_native_shade_shrub = non_native_df[(non_native_df['Plant Type'].apply(lambda x: 'Shrub' in x)) &
                                                (non_native_df['Light Preference'].apply(lambda x: 'Full Shade' in x))]
        

        # Not enough native plants, all native_df will go into result
        if len(native_df) <= num_native_species:
            # Calculate required trees, shrubs and shade_shrub
            # There is a bug for when the required are negative, the overall may be larger than the actual results...
            tree_required = max(num_tree - len(native_tree), 0)
            non_shade_shrub_required = max(num_non_shade_shrubs - len(native_non_shade_shrub), 0)
            shade_shrub_required = max(num_shade_shrubs - len(native_shade_shrub), 0)

            result = native_df
            # Sort based on the length of each dataframe
            non_native_plants = [[non_native_tree, tree_required], [non_native_non_shade_shrub, non_shade_shrub_required], [non_native_shade_shrub, shade_shrub_required]]
            sorted_non_native_plants = sorted(non_native_plants, key=lambda x: len(x[0]))

            diff = 0
            for index, [non_native_plant_df, required_count] in enumerate(sorted_non_native_plants):
                if index == 0:
                    best_k_data = non_native_plant_df.head(required_count)
                    result = pd.concat([result, best_k_data], ignore_index=True)
                    if len(best_k_data) < required_count:
                        diff = len(best_k_data) - required_count
                
                elif index == 1:
                    # Get 1/2 of the difference from prev
                    new_required_count = required_count + diff//2
                    best_k_data = non_native_plant_df.head(new_required_count)
                    result = pd.concat([result, best_k_data], ignore_index=True)

                else:
                    # Get all the remaining
                    new_required_count = maximum_plant_count - len(result)
                    best_k_data = non_native_plant_df.head(new_required_count)
                    result = pd.concat([result, best_k_data], ignore_index=True)

            return result
        
        # More than enough native plants, prioritise scoring when retrieving data
        if len(native_df) > num_native_species:
            # Extract best num_trees, num_non_shade_shrub, num_shade_shrub data from native and non-native
            best_native_tree = native_tree.head(num_tree)
            best_non_native_tree = non_native_tree.head(num_tree)
            best_native_non_shade_shrub = native_non_shade_shrub.head(num_non_shade_shrubs)
            best_non_native_non_shade_shrub = non_native_non_shade_shrub.head(num_non_shade_shrubs)
            best_native_shade_shrub = native_shade_shrub.head(num_shade_shrubs)
            best_non_native_shade_shrub = non_native_shade_shrub.head(num_shade_shrubs)

            # Merge and rank by scoring
            total_dataset = pd.concat([best_native_tree, best_non_native_tree, best_native_non_shade_shrub, best_non_native_non_shade_shrub, best_native_shade_shrub, best_non_native_shade_shrub], ignore_index=True)
            sorted_total_dataset = total_dataset.sort_values(by='_score', ascending=False)

            result = pd.DataFrame(columns=sorted_total_dataset.columns)
            unused_data = pd.DataFrame(columns=sorted_total_dataset.columns)
            # Base variable to keep track of all the data when we start getting the results
            num_native = 0
            num_non_native = 0
            current_trees = 0
            current_non_shade_shrubs = 0
            current_shade_shrubs = 0

            # Before iteration, we check to ensure that is enough data, if there is only total 2 shade shrubs, we must add both regardless of native or not, assuming there is enough other data
            if len(best_native_tree) + len(best_non_native_tree) <= num_tree:
                result = pd.concat([result, best_native_tree, best_non_native_tree], ignore_index=True)
                num_native += len(best_native_tree)
                num_non_native += len(best_non_native_tree)
                current_trees +=  len(best_native_tree) + len(best_non_native_tree) 
            # Non Shade Shrubs
            if len(best_native_non_shade_shrub) + len(best_non_native_non_shade_shrub) <= num_non_shade_shrubs:
                result = pd.concat([result, best_native_non_shade_shrub, best_non_native_non_shade_shrub], ignore_index=True)
                num_native += len(best_native_non_shade_shrub)
                num_non_native += len(best_non_native_non_shade_shrub)
                current_non_shade_shrubs +=  len(best_native_non_shade_shrub) + len(best_non_native_non_shade_shrub)
            # Shade Shrubs
            if len(best_native_shade_shrub) + len(best_non_native_shade_shrub) <= num_shade_shrubs:
                result = pd.concat([result, best_native_shade_shrub, best_non_native_shade_shrub], ignore_index=True)
                num_native += len(best_native_shade_shrub)
                num_non_native += len(best_non_native_shade_shrub)
                current_shade_shrubs += len(best_native_shade_shrub) + len(best_non_native_shade_shrub)

            for _, row in sorted_total_dataset.iterrows():
                # Determine which plant type and light preference
                plant_type_key = ["Tree", "Palm", "Shrub"]
                plant_key = [plant for plant in row['Plant Type'] if plant in plant_type_key][0]
                full_shade = True if ('Full Shade' in row['Light Preference']) else False
                # Adding SG native data when the counts lower than required
                if row['Native to SG'] and num_native < num_native_species:
                    if plant_key == "Shrub" and full_shade and current_shade_shrubs < num_shade_shrubs:
                        current_shade_shrubs += 1
                        num_native += 1
                        result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                    
                    elif plant_key == "Shrub" and current_non_shade_shrubs < num_non_shade_shrubs:
                        current_non_shade_shrubs += 1
                        num_native += 1
                        result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                    
                    elif plant_key == "Palm" or plant_key == "Tree" and current_trees < num_tree:
                        current_trees += 1
                        num_native += 1
                        result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                    
                    else:
                        unused_data =  pd.concat([unused_data, pd.DataFrame([row])], ignore_index=True)

                # Adding non native data when the counts lower than required
                elif not row['Native to SG'] and num_non_native < (maximum_plant_count - num_native_species):
                    if plant_key == "Shrub" and full_shade and current_shade_shrubs < num_shade_shrubs:
                        current_shade_shrubs += 1
                        num_non_native += 1
                        result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                    
                    elif plant_key == "Shrub" and current_non_shade_shrubs < num_non_shade_shrubs:
                        current_non_shade_shrubs += 1
                        num_non_native += 1
                        result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)
                    
                    elif plant_key == "Palm" or plant_key == "Tree" and current_trees < num_tree:
                        current_trees += 1
                        num_non_native += 1
                        result = pd.concat([result, pd.DataFrame([row])], ignore_index=True)

                    else:
                        unused_data =  pd.concat([unused_data, pd.DataFrame([row])], ignore_index=True)

                # If not adding, add to unused data
                else:
                    unused_data =  pd.concat([unused_data, pd.DataFrame([row])], ignore_index=True)

                # Once hit the size, break 
                if len(result) == maximum_plant_count:
                    break

            # In the scenario there wasn't enough specific plant type data and the total count is lacking after iterating all the data
            # Add the best unused data to make up the difference
            if len(result) < maximum_plant_count:
                native_diff = num_native_species - num_native
                non_native_diff = maximum_plant_count - num_native_species - num_non_native
                # Select additional native plants
                native_additional = unused_data.loc[unused_data['Native to SG'] == True].head(native_diff)
                # Select additional non-native plants
                non_native_additional = unused_data.loc[unused_data['Native to SG'] == False].head(non_native_diff)
                # Not enough native data, get more non_native data
                if len(native_additional) < native_diff:
                    non_native_additional = unused_data.loc[unused_data['Native to SG'] == False].head(non_native_diff + (native_diff - len(native_additional)))
                # Not enough non native data, get more native data
                elif len(non_native_additional) < non_native_diff:
                    native_additional = unused_data.loc[unused_data['Native to SG'] == True].head(native_diff + (non_native_diff - len(non_native_additional)))

                result = pd.concat([result, native_additional, non_native_additional], ignore_index=True)
            
            return result


    def generate_plant_palette(self, user_call:dict):
        """
        Function to retrieve plant palette and all possible plants from user_call

        Args:
            user_call (dict): api call input in the format of 
            {
                "prompt": str, 
                "maximum_plant_count": int(3-8),
                "light_preference":"Full Shade"/ "Semi Shade"/ "Full Sun",
                "water_preference": "Lots of Water"/ "Moderate Water"/ "Little Water"/ "Occassional Misting",
                "drought_tolerant": True / False,
                "fauna_attracted": ["Butterfly", "Bird", "Caterpillar Moth", "Bat", "Bee"],
                "ratio_native": float(0-1)
            } 

        Return:
            result (dict): API call result
            {
            "style": "Naturalistic" / "Manicured" /  "Meadow" / "Ornamental" / "Minimalist" / "Formal" / "Picturesque" / "Rustic" / "Plantation"
            "surrounding": Road / Walkway,
            "plant_palette": [Species_ID],
            "all_plants": [
                {plant data from ElasticSearch}, ...
                ] 
            }
        """
        # Generate query from user_call
        extracted_function_style_surrounding, es_query, rerank_requirements = self.query_generator.generate_query(user_call)
        # Retrieve data from elasticSearch
        results = self.retrieve_results(es_query, rerank_requirements['Light Preference'], rerank_requirements['Maximum Plant Count'])
        if len(results) == 0:
            return {
                "plant_palette": [],
                "all_plants": []
            } 
        # Rerank results
        reranked_results = self.rerank_results(results, rerank_requirements)
        # Selected results
        selected_plant_palette = self.select_palette(reranked_results, rerank_requirements['Light Preference'], rerank_requirements['Ratio Native'], rerank_requirements['Maximum Plant Count'])
        # Retrieve all species ID from selected_plant_palette
        selected_ids = selected_plant_palette['Species ID'].to_list()[:rerank_requirements['Maximum Plant Count']]
        return {
            "style": extracted_function_style_surrounding['style'],
            "surrounding": extracted_function_style_surrounding['surrounding'],
            "plant_palette": selected_ids,
            "all_plants": [data['_source'] for data in results]
        }

In [8]:
api_call = {
    "prompt": "Create a manicured butterfly garden with soft beige and light green tones",
    "maximum_plant_count": 6,
    "light_preference": "Full Sun",
    "water_preference": "Moderate Water",
    "drought_tolerant": False,
    "fauna_attracted": [],
    "ratio_native": 0.5
}

In [9]:
es_manager = ESManager()
query_generator = ESPlantQueryGenerator()
plant_selection_model = PlantSelectionModel(es_manager, query_generator)
plant_selection_model.generate_plant_palette(api_call)

{'name': 'f9de9f21a6c3', 'cluster_name': 'docker-cluster', 'cluster_uuid': 'lQ9maNGqThyqboUYSwPNJQ', 'version': {'number': '8.10.1', 'build_flavor': 'default', 'build_type': 'docker', 'build_hash': 'a94744f97522b2b7ee8b5dc13be7ee11082b8d6b', 'build_date': '2023-09-14T20:16:27.027355296Z', 'build_snapshot': False, 'lucene_version': '9.7.0', 'minimum_wire_compatibility_version': '7.17.0', 'minimum_index_compatibility_version': '7.0.0'}, 'tagline': 'You Know, for Search'}
Retrieved 16 results.


  native_non_shade_shrub = native_df[native_df['Plant Type'].apply(lambda x: 'Shrub' in x) &
  result = pd.concat([result, best_native_non_shade_shrub, best_non_native_non_shade_shrub], ignore_index=True)
  unused_data =  pd.concat([unused_data, pd.DataFrame([row])], ignore_index=True)


{'style': 'Manicured',
 'surrounding': 'Walkway',
 'plant_palette': [3427, 2010, 3740, 3739, 2607, 3138],
 'all_plants': [{'Scientific Name': 'Excoecaria cochinchinensis',
   'Common Name': 'Chinese Croton',
   'Species ID': 2010,
   'Link': 'https://www.nparks.gov.sg/florafaunaweb/flora/2/0/2010',
   'Plant Type': ['Shrub'],
   'Light Preference': ['Full Sun', 'Semi Shade'],
   'Water Preference': ['Moderate Water'],
   'Drought Tolerant': False,
   'Native to SG': False,
   'Fruit Bearing': False,
   'Fragrant Plant': False,
   'Maximum Height (m)': 1.0,
   'Flower Colour': 'Green',
   'Hazard': 'Toxic Upon Ingestion, Irritant - Sap',
   'Attracted Animals': '-',
   'Native habitat': 'Terrestrial',
   'Mature Leaf Colour': 'Green, Purple',
   'Young Flush Leaf Colour': 'Red',
   'Leaf Area Index': '4.5 (Shrub & Groundcover - Dicot)',
   'Growth Rate': ['-'],
   'Trunk Texture': 'N/A',
   'Trunk Colour': 'N/A',
   'Leaf Texture': ['Medium'],
   'Canopy Radius': 'N/A'},
  {'Scientific 