Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import json
import os
import pickle
import numpy as np
from functools import lru_cache

app = Flask(__name__)

Expand All @@ -57,6 +59,41 @@
app.register_blueprint(cfg.SWAGGER_BLUEPRINT, url_prefix = cfg.SWAGGER_URL)
FRONT_LOG_FILE = 'front_log.json'

@lru_cache(maxsize=1)
def get_values_embedding_function():
"""
Getting the embedding function for the /values endpoint.
Cached to avoid reloading the model multiple times.

Returns:
Embedding function callable
"""
model_id, model_path = save_model.save_model()
return recommendation_handler.get_embedding_func(inference='local', model_id=model_path)

@lru_cache(maxsize=1)
def get_values_centroids():
"""
Getting the positive and negative value centroids for the /values endpoint.
Cached to avoid reloading the data multiple times.

Returns:
Dictionary with 'positive' and 'negative' centroid embeddings
"""
prompt_json = recommendation_handler.populate_json()
positive_category_centroid = {}
negative_category_centroid = {}

for category in prompt_json['positive_values']:
positive_category_centroid[category['label']] = np.array(category['centroid'])

for category in prompt_json['negative_values']:
negative_category_centroid[category['label']] = np.array(category['centroid'])

return {
'positive': positive_category_centroid,
'negative': negative_category_centroid
}

@app.route("/")
def index():
Expand Down Expand Up @@ -109,7 +146,7 @@ def get_thresholds():
@cross_origin()
def recommend_local():
model_id, _ = save_model.save_model()
prompt_json, _ = recommendation_handler.populate_json()
prompt_json = recommendation_handler.populate_json()
args = request.args
print("args list = ", args)
prompt = args.get("prompt")
Expand Down Expand Up @@ -166,6 +203,42 @@ def demo_inference():
return response
except:
return "Model Inference failed.", 500

@app.route("/values", methods=['GET'])
@cross_origin()
def get_values():
"""
Getting positive and negative values for a given prompt using cached embedding function and centroids for performance.
"""
args = request.args
prompt = args.get("prompt")

# validating input
if not prompt:
return jsonify({"error": "Missing required parameter: prompt"}), 400

if not isinstance(prompt, str):
return jsonify({"error": "Parameter 'prompt' must be a string"}), 400

if len(prompt.strip()) == 0:
return jsonify({"error": "Parameter 'prompt' cannot be empty"}), 400

try:
embedding_fn = get_values_embedding_function()
centroids = get_values_centroids()

values = recommendation_handler.get_values(
prompt,
centroids['positive'],
centroids['negative'],
embedding_fn
)

return jsonify(values)

except Exception as e:
logger.error(f'Error in /values endpoint: {str(e)}')
return jsonify({"error": "Internal server error processing prompt"}), 500

if __name__=='__main__':
debug_mode = os.getenv('FLASK_DEBUG', 'False').lower() in ['true', '1', 't']
Expand Down
79 changes: 79 additions & 0 deletions control/recommendation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,82 @@ def get_thresholds(
thresholds['remove_higher_threshold'] = round(remove_similarities_df.describe([.9]).loc['90%', 'similarity'], 1)

return thresholds

def get_values(
prompt,
positive_embeddings,
negative_embeddings,
embedding_fn = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Within the get_values function, please add a check for when the embedding_fn is None and create an embedding function. You can see get_thresholds and recommend_prompt to see how it is done.

I know is it not absolutely necessary currently. But its better to have this safeguard in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had asked for this confirmation earlier here: conversation

@santanavagner clarified we might not need a threshold yet as we are taking maximum similarity value every time. Also, I believe its good to have both in the root logic. In subsequent applications (multi-agent conversations lets say), we can have thresholds but this guarantees we get both values every time.

Thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Mystic-Slice
I have addressed the concerns, let me know if anything else needs to be changed.

( Just wanted to point out one thing. I was thinking of having a provision to clean up the cache - there are 2 options - manual API (additional) to clean it or having a TTL based implementation which would have required custom decorators. So, wanted your opinion on it before implementing.)

cc: @santanavagner, @cassiasamp

Copy link
Collaborator

@Mystic-Slice Mystic-Slice Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now, this is enough.
The cache works fine. The endpoint works quick with it. I see no need for cleanup as of now because the embeddings do not change mid-deployment. But maybe in the future.
We can add cleanup if the need arises. I'm just try to avoid unnecessarily complicating the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exactly the point.

thank you.

):
"""
Compute positive and negative value associations for each sentence in the input prompt.

Args:
prompt: Input prompt text.
positive_embeddings: Dictionary mapping positive value labels to centroid embeddings.
negative_embeddings: Dictionary mapping negative value labels to centroid embeddings.
embedding_fn: Function to generate embeddings from text.

Returns:
Dictionary containing sentences with their associated positive and negative values and similarity scores.
"""

if embedding_fn is None:
# using all-MiniLM-L6-v2 locally by default
embedding_fn = get_embedding_func('local', model_id='sentence-transformers/all-MiniLM-L6-v2')

sentences = split_into_sentences(prompt)

# bifurcating and filtering out empty sentences
sentences = [s for s in sentences if s.strip()]

values = {}
values["prompts"] = []

# returning if no valid sentences
if not sentences:
return values

# generating all sentence embeddings in a single call by batching all
sentence_embeddings = embedding_fn(sentences)
sentence_embeddings = np.array(sentence_embeddings)

# ensuring embeddings have correct shape - expanding embeddings of all sentences
if len(sentence_embeddings.shape) == 1:
sentence_embeddings = np.expand_dims(sentence_embeddings, axis=0)

# processing each sentence with its corresponding embedding
for idx, sentence in enumerate(sentences):

sentence_embedding = sentence_embeddings[idx]

max_similarity_positive = -1
positive_label = None
for label, centroid in positive_embeddings.items():
similarity = cosine_similarity(
np.expand_dims(sentence_embedding, axis=0),
np.array([centroid])
)[0, 0]
if similarity > max_similarity_positive:
max_similarity_positive = similarity
positive_label = label

max_similarity_negative = -1
negative_label = None
for label, centroid in negative_embeddings.items():
similarity = cosine_similarity(
np.expand_dims(sentence_embedding, axis=0),
np.array([centroid])
)[0, 0]
if similarity > max_similarity_negative:
max_similarity_negative = similarity
negative_label = label

values["prompts"].append({
"sentence": sentence,
"positive_value": {"label": positive_label, "similarity": float(max_similarity_positive)},
"negative_value": {"label": negative_label, "similarity": float(max_similarity_negative)}
})

return values

Loading