In [1]:
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
import os
from dotenv import load_dotenv
import pandas as pd
import json

def create_skills_matrix(skill_list, distance_threshold=1.0):
    """
    Create a DataFrame showing which entities possess which skills based on ChromaDB similarity search.
    
    Args:
        skill_list (list): List of skills to query
        distance_threshold (float): Maximum distance to consider a skill match (default: 1.0)
        
    Returns:
        pandas.DataFrame: Matrix of entities and their skills
    """
    # Load environment variables
    load_dotenv()

    # Initialize ChromaDB client with persistence
    client = chromadb.PersistentClient(path="../entity_skills_db")

    # Initialize the OpenAI embedding function
    embedding_function = OpenAIEmbeddingFunction(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name="text-embedding-3-large"
    )

    try:
        # Get existing collection
        collection = client.get_collection(
            name="entity_skills",
            embedding_function=embedding_function
        )
        
        # Dictionary to store all results
        all_entity_skills = {}
        
        # Query each skill
        for skill in skill_list:
            results = collection.query(
                query_texts=[skill],
                n_results=1000,  # Large number to get all potential matches
                include=["documents", "metadatas", "distances"]
            )
            
            # Process results for this skill
            for entity_metadata, distance in zip(
                results['metadatas'][0],
                results['distances'][0]
            ):
                entity_id = entity_metadata.get('entity_name')
                
                # Initialize entity in dictionary if not present
                if entity_id not in all_entity_skills:
                    all_entity_skills[entity_id] = {skill: False for skill in skill_list}
                
                # Mark skill as True if distance is below threshold
                if distance < distance_threshold:
                    all_entity_skills[entity_id][skill] = True
        
        # Convert to DataFrame
        df = pd.DataFrame.from_dict(all_entity_skills, orient='index')
        
        # Reset index and rename it to entity_id
        df.index.name = 'entity_id'
        df.reset_index(inplace=True)
        
        return df

    except Exception as e:
        print(f"Error accessing collection: {str(e)}")
        raise

# Example usage
if __name__ == "__main__":
    # Example list of skills to query
    skills_to_query = [
        "Python Programming",
        "Data Analysis",
        "AWS",
        "Machine Learning"
    ]
    
    # Create the skills matrix
    skills_df = create_skills_matrix(skills_to_query)
    
    # Display the first few rows
    print("\nSkills Matrix:")
    # print(skills_df.head())
    
    # Display some summary statistics
    print("\nSkill Distribution:")
    for skill in skills_to_query:
        count = skills_df[skill].sum()
        total = len(skills_df)
        percentage = (count / total) * 100
        print(f"{skill}: {count} entities ({percentage:.1f}%)")

Number of requested results 1000 is greater than number of elements in index 9, updating n_results = 9
Number of requested results 1000 is greater than number of elements in index 9, updating n_results = 9
Number of requested results 1000 is greater than number of elements in index 9, updating n_results = 9
Number of requested results 1000 is greater than number of elements in index 9, updating n_results = 9



Skills Matrix:
      entity_id  Python Programming  Data Analysis    AWS  Machine Learning
0    Alice Chen                True          False   True             False
1  Bob Martinez                True          False   True             False
2    Carol Wong                True          False  False              True

Skill Distribution:
Python Programming: 3 entities (100.0%)
Data Analysis: 0 entities (0.0%)
AWS: 2 entities (66.7%)
Machine Learning: 1 entities (33.3%)
