In [2]:
from collections import Counter, defaultdict

import numpy as np
import torch
from pymilvus.client.types import ExtraList
from success_prediction.rag_components.embeddings import EmbeddingHandler

from success_prediction.config import DATA_DIR

from success_prediction.scraper.crawler_config import value_keywords, esg_keywords, team_keywords, product_keywords

VALUE_PROPOSITION_KEYWORDS = list({k for k in value_keywords + esg_keywords})
LEADERSHIP_KEYWORDS = list({k for k in team_keywords})
RESPONSIBILITY_KEYWORDS = list({k for k in product_keywords})

In [228]:
db_client = MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
embedding_creator = EmbeddingHandler()

dim2query = {
    "Value Proposition & Innovation": [
        "What solutions, services, or products does the company provide to customers?",
        "What products and services does the company advertise on its website?",
        "Which innovative features or technologies are highlighted in the company's offerings?",
        "What benefits or outcomes does the company promise or deliver through its solutions, services, products, or platforms?",
        "How does the company differentiate its products or services from competitors, and what specific customer needs are addressed?"
    ],
    "Purpose & Responsibility": [
        "What is the stated mission, purpose, or long-term vision of the company?",
        "Which ethical, social, or environmental commitments does the company emphasize?",
        "Does the company value sustainability, diversity, inclusion, or in general ESG-related goals?",
        "What values or principles guide the company's operations and decisions?",
        "Does the company participate in any charitable initiatives, community outreach, or global impact programs?"
    ],
    "Leadership & People": [
        "Who are the founders or key leaders of the company, and what roles do they hold?",
        "What are the professional backgrounds or credentials of the company's executive team?",
        "Who makes up the leadership team, and how is the company structured in terms of people and roles?",
        "What experience or expertise does the management bring to the company?",
        "Are there biographies or personal stories of team members or executives available on the website?"
    ]
}

# Convert to list
dim2embedding = {dimension: [embedding_creator.embed([q], prefix='query:') for q in queries] for dimension, queries in dim2query.items()}

[EmbeddingHandler] Using model on `mps`.


In [308]:
index_params = [{
    'field_name': 'embedding',
    'metric_type': 'IP',  # Use inner product because E5 embeddings are normalized (||v|| = 1)
    'index_type': 'FLAT',
    # 'params': {'nlist': 1024}
}]

db_client.create_index(
    collection_name='current_websites',
    index_params=index_params,
)

db_client.load_collection(collection_name='current_websites')

In [238]:
def ensemble_top_passages(company_data, query_embeddings, top_k_per_query=15, final_top_k=15):
    """
    Returns the passages that appear most frequently in the top-k across query ensemble.

    Parameters:
        company_data: List of dicts with fields 'id', 'embedding', 'text', 'url'
        query_embeddings: List of embedded queries for the dimension
        top_k_per_query: Number of top results to keep for each query
        final_top_k: Final number of consensus passages to return

    Returns:
        List of dicts: top passages by consensus frequency
    """
    passage_scores = defaultdict(list)

    for query_vec in query_embeddings:
        query_vec = np.array(query_vec[0])
        scored_entries = []

        for entry in company_data:
            score = np.dot(query_vec, entry['embedding'])
            scored_entries.append({**entry, 'score': score})

        top_k = sorted(scored_entries, key=lambda x: x['score'], reverse=True)[:top_k_per_query]

        for passage in top_k:
            passage_scores[passage['id']].append((passage['score'], passage))

    frequency_counter = Counter({pid: len(scores) for pid, scores in passage_scores.items()})

    # Sort by frequency (how often in top 15), then by best score (descending)
    sorted_passages = sorted(
        passage_scores.items(),
        key=lambda x: (frequency_counter[x[0]], max(s[0] for s in x[1])),
        reverse=True
    )

    final_passages = [x[1][0][1] for x in sorted_passages[:final_top_k]]
    return final_passages


def ensemble_rerank(top_n_entries, query_texts):
    """
    """
    for entry in top_n_entries:
        pairs = [(query, entry['text']) for query in query_texts]
        relevancy_score = embedding_creator.calculate_relevancy_scores(sentence_pairs=pairs).median()
        entry.update({'attention_score': relevancy_score})

    sorted_entries = sorted(
        top_n_entries,
        key=lambda entry: (float(entry['attention_score'])),
        reverse=True
    )
    scores = np.array([entry['attention_score'] for entry in sorted_entries])
    z_scores = (scores - np.mean(scores)) / np.std(scores)
    return [entry for z_score, entry in zip(z_scores, sorted_entries) if z_score >= 0]

def get_dimension_vec(dimension: str, company_data: ExtraList, dim2embedding: dict, dim2query: dict):
    """
    """
    # Get top 15 based on cosine / IP similarity
    top_15 = ensemble_top_passages(
        company_data=company_data,
        query_embeddings=dim2embedding[dimension]
    )

    # Get the most relevant by reranking them via cross encoder
    most_relevant = ensemble_rerank(
        top_n_entries=top_15,
        query_texts=dim2query[dimension]
    )
    
    # combine the remaining into one vector by using the quasi attention score from the ensemble rerank
    dim_vec = embedding_creator.waggregate_embeddings([torch.tensor(entry['embedding']) for entry in most_relevant], [entry['attention_score'] for entry in most_relevant])
    return most_relevant, dim_vec

In [242]:
vec_results = []
for ehraid in [1251382, 1433629]:
    company_data = db_client.query(collection_name='current_websites', filter=f"ehraid == {ehraid}")
    dim_vectors = {}
    for dim in dim2query.keys():
        dim_vectors[dim] = {}
        most_relevant, dim_vec = get_dimension_vec(dim, company_data, dim2embedding, dim2query)
        dim_vectors[dim]['entries'] = most_relevant
        dim_vectors[dim]['vectors'] = dim_vec
    vec_results.append({ehraid: dim_vectors})

In [260]:
vp_embeddings = torch.stack([values['Value Proposition & Innovation']['vectors'] for entry in vec_results for values in entry.values()])
pr_embeddings = torch.stack([values['Purpose & Responsibility']['vectors'] for entry in vec_results for values in entry.values()])
lt_embeddings = torch.stack([values['Leadership & People']['vectors'] for entry in vec_results for values in entry.values()])

In [None]:
# whitening without dimensionality reduction
vp_whitened = embedding_creator.whitening_k(embeddings=vp_embeddings)
pr_whitened = embedding_creator.whitening_k(embeddings=pr_embeddings)
lt_whitened = embedding_creator.whitening_k(embeddings=lt_embeddings)

# whitening with dimensionality reduction
vp_whitened_red = embedding_creator.whitening_k(embeddings=vp_embeddings, k=256)
pr_whitened_red = embedding_creator.whitening_k(embeddings=pr_embeddings, k=256)
lt_whitened_red = embedding_creator.whitening_k(embeddings=lt_embeddings, k=256)

In [317]:
results = db_client.search(
    collection_name='current_websites',
    data = dim2embedding['Value Proposition & Innovation'][0],
    limit=5,
    output_fields=['ehraid', 'text', 'embedding']
)

In [319]:
[(result['distance'], result['entity']['text']) for result in results[0]]

[(0.8554950952529907, '# What your customers get'),
 (0.8554592132568359, '# our\n### Products & Services'),
 (0.846657931804657, '# Products & Services'),
 (0.846523642539978,
  '# Les solutions que nous\navons apportés**à nos clients**\n[IMAGE:]'),
 (0.8435544967651367, '# Services & solutions\n[IMAGE:]')]

In [311]:
for res in results:
    print(res)

[{'id': 458020495605245965, 'distance': 0.8554950952529907, 'entity': {'ehraid': 1404201, 'embedding': [0.00935094989836216, 0.06173628941178322, -0.012480995617806911, -0.014509416185319424, 0.023455725982785225, -0.03747225180268288, -0.031206220388412476, -0.03700743243098259, 0.009418120607733727, 0.05569257214665413, 0.030055789276957512, 0.0008119480917230248, 0.10840526968240738, 0.02062780037522316, -0.013503074645996094, -0.020758887752890587, 0.016581006348133087, -0.017259934917092323, 0.050341494381427765, 0.010214443318545818, 0.034114621579647064, -0.022402489557862282, 0.0285525843501091, -0.018894493579864502, 0.024753719568252563, -0.010103227570652962, -0.011464154347777367, 0.0056061106733977795, -0.03250887989997864, 0.06061761453747749, 0.051924750208854675, -0.008361113257706165, 0.007553390692919493, 0.04715234786272049, 0.041041210293769836, 0.07957469671964645, -0.017063690349459648, -0.03401299566030502, 0.024757005274295807, -0.014549962244927883, -0.02924067

In [295]:
results

data: ["[{'id': 458020495605245965, 'distance': 0.8554950952529907, 'entity': {}}, {'id': 458095152321040896, 'distance': 0.8554592132568359, 'entity': {}}, {'id': 458094982300670109, 'distance': 0.846657931804657, 'entity': {}}, {'id': 458097981354771799, 'distance': 0.846523642539978, 'entity': {}}, {'id': 458095582308940111, 'distance': 0.8435544967651367, 'entity': {}}]"]