In [15]:
import pandas as pd
from chromadb import Documents, EmbeddingFunction, Embeddings, PersistentClient, Collection
from langchain_huggingface import HuggingFaceEmbeddings
from sklearn.metrics import confusion_matrix

# Load Datasets

In [16]:
test_dataset = pd.read_csv('./dataset/PURE_test.csv', usecols=['Requirement', 'Req/Not Req'])
test_dataset.rename(columns={'Requirement': 'text', 'Req/Not Req': 'label'}, inplace=True)
test_dataset.replace({'label': {'Req': True, 'Not Req': False}}, inplace=True)
test_dataset.head()

Unnamed: 0,text,label
0,System Initialization performs those functions...,True
1,"Whenever a power-on reset occurs, System Initi...",True
2,"As part of System Initialization , the Boot RO...",True
3,System Initialization shall [SRS014] initiate ...,True
4,System Initialization shall [SRS292] enable an...,True


In [17]:
validation_dataset = pd.read_csv('./dataset/PURE_valid.csv', usecols=['Requirement', 'Req/Not Req'])
validation_dataset.rename(columns={'Requirement': 'text', 'Req/Not Req': 'label'}, inplace=True)
validation_dataset.replace({'label': {'Req': True, 'Not Req': False}}, inplace=True)
validation_dataset.head()

Unnamed: 0,text,label
0,Any operation requiring the user to supply a f...,True
1,For any operation where the user is prompted t...,True
2,When collecting generated output files from HA...,True
3,"For example, given a transformation language p...",True
4,If X.tlp.parsed existed prior to executing the...,True


# Create the ChromaDB Client and Reader

In [18]:
class CustomEmbeddingFunction(EmbeddingFunction):
    def __init__(self, embedding_model: HuggingFaceEmbeddings):
        self.embedding_model = embedding_model

    def __call__(self, texts: Documents) -> Embeddings:
        return self.embedding_model.embed_documents(texts)

In [19]:
embedding_model = HuggingFaceEmbeddings(model_name="Qwen/Qwen3-Embedding-0.6B")

In [20]:
chroma_client = PersistentClient(path="./chroma_data")

chroma_collection = chroma_client.get_or_create_collection(
    name="requirements_collection_qwen3",
    embedding_function=CustomEmbeddingFunction(embedding_model=embedding_model),
)

In [29]:
# Test query
chroma_collection.query(
    query_texts=["Watchdog timers are used to detect and recover from malfunctions."],
    n_results=5,
)

{'ids': [['53e30d48-1277-4edd-bd5b-0bd4fb641a60',
   '9c7167b4-d011-459e-bba8-43ea45a38476',
   'fb71c5e9-e37a-4ec2-8e7b-18445374acb3',
   'b1b71423-c739-479d-ab7b-ca955ad817c8',
   'bf6acb83-c3ce-44bb-969e-14ebd94b04e5']],
 'embeddings': None,
 'documents': [['The time critical functions include both control and supervision functions.',
   'The two types are long and short timers.',
   'This for instance is the case for condition monitoring of components such as gearbox bearings.',
   'Monitoring, troubleshooting, and controlling server performance.',
   'Monitoring, troubleshooting, and controlling services, ports, and application programming interfaces.']],
 'uris': None,
 'included': ['metadatas', 'documents', 'distances'],
 'data': None,
 'metadatas': [[{'is_req': True},
   {'is_req': False},
   {'is_req': True},
   {'is_req': False},
   {'is_req': False}]],
 'distances': [[0.7377492189407349,
   0.780716061592102,
   0.8417319059371948,
   0.8508152961730957,
   0.852357387542724

# Debug: Check Collection Status

In [37]:
# Check if the collection has any data
collection_count = chroma_collection.count()
print(f"Total documents in collection: {collection_count}")

# If collection has data, check a sample
if collection_count > 0:
    sample = chroma_collection.peek(limit=3)
    print(f"\nSample metadata: {sample['metadatas']}")
    print(f"Sample documents: {sample['documents']}")
else:
    print("\n⚠️ WARNING: Collection is EMPTY! You need to run data ingestion first.")

Total documents in collection: 5306

Sample metadata: [{'is_req': True}, {'is_req': True}, {'is_req': True}]
Sample documents: ['The solution should provide detailed context-sensitive help material for all the possible actions and scenarios on all user interfaces in the application.', 'The help should be accessible to the users both in the offline and online mode.', 'The solution should provide an interface for the user to log any defects or enhancement requests on the application and track thereafter.']


In [38]:
# Test a single query to see what we actually get back
if collection_count > 0:
    test_query = test_dataset["text"].iloc[0]
    test_label = test_dataset["label"].iloc[0]
    
    print(f"Test query: {test_query}")
    print(f"True label: {test_label}")
    
    result = chroma_collection.query(
        query_texts=[test_query],
        n_results=5,
    )
    
    print(f"\nQuery results:")
    print(f"Distances: {result['distances'][0]}")
    print(f"Metadatas: {result['metadatas'][0]}")
    print(f"Documents: {result['documents'][0][:2]}")  # First 2 docs
    
    # Test the voting function
    predicted = get_range_vote_for_query(test_query, chroma_collection, k=5)
    print(f"\nPredicted label: {predicted}")
    print(f"Correct: {predicted == test_label}")

Test query: System Initialization performs those functions necessary to transform the hardware consisting of the FCP processors, network elements, and on-board I/O devices into a real time system executing tasks with fault tolerant message exchanges.
True label: True

Query results:
Distances: [0.8391311764717102, 0.9021110534667969, 0.9062187075614929, 0.9077016115188599, 0.9278020858764648]
Metadatas: [{'is_req': False}, {'is_req': True}, {'is_req': False}, {'is_req': True}, {'is_req': False}]
Documents: ['Identifies what is to be done by the system, what inputs should be transformed to what outputs, and what specific operations are required. ', 'Automatic procedures for detection of communication faults and for managing redun- dancy of system components shall be established. ']

Predicted label: False
Correct: False


## Function for Determining the Classification Scores

In [30]:
def get_range_vote_for_query(query: str, collection: Collection, k: int = 5) -> bool:
    """
    Implements Range/Score Voting to classify a query as requirement or not.
    Uses similarity-weighted voting from k nearest neighbors to minimize Bayesian Regret.
    
    Args:
        query: The text to classify
        collection: ChromaDB collection containing labeled requirements
        k: Number of nearest neighbors to retrieve for voting
        
    Returns:
        bool: True if classified as requirement, False otherwise
    """
    # Query the vector database for k nearest neighbors
    query_result = collection.query(
        query_texts=[query],
        n_results=k,
    )
    
    # Extract distances and metadata
    distances = query_result["distances"][0]
    metadatas = query_result["metadatas"][0]
    
    # Convert distances to similarity scores (lower distance = higher similarity)
    # Using 1/(1+distance) to convert distance to similarity weight
    similarities = [1 / (1 + dist) for dist in distances]
    
    # Accumulate weighted votes for each class
    req_score = 0.0
    not_req_score = 0.0
    
    for similarity, metadata in zip(similarities, metadatas):
        if metadata["is_req"]:
            req_score += similarity
        else:
            not_req_score += similarity
    
    # Return the class with higher weighted score
    return req_score > not_req_score


In [31]:
def get_confusion_matrix(
    true_values: list[str], true_labels: list[bool], vector_collection: Collection, k: int = 5
):
    """
    Generate confusion matrix using Range Voting classification.
    
    Args:
        true_values: List of text samples to classify
        true_labels: List of ground truth labels
        vector_collection: ChromaDB collection to query
        k: Number of nearest neighbors for voting (default: 5)
        
    Returns:
        Confusion matrix as a 2D array
    """
    y_true = []
    y_pred = []

    for i in range(len(true_values)):
        predicted_label = get_range_vote_for_query(
            query=true_values[i],
            collection=vector_collection,
            k=k
        )
        y_true.append(true_labels[i])
        y_pred.append(predicted_label)

    cm = confusion_matrix(y_true, y_pred, labels=[True, False])
    return cm


In [39]:
def calculate_metrics(true_positive: int, true_negative: int, false_positive: int, false_negative: int) -> dict:
    """
    Calculate classification metrics from confusion matrix values.
    
    Args:
        true_positive: Number of true positives
        true_negative: Number of true negatives
        false_positive: Number of false positives
        false_negative: Number of false negatives
        
    Returns:
        Dictionary containing accuracy, precision, recall, and f1_score
    """
    # Calculate total samples
    total = true_positive + true_negative + false_positive + false_negative
    
    # Accuracy: (TP + TN) / Total
    accuracy = (true_positive + true_negative) / total if total > 0 else 0.0
    
    # Precision: TP / (TP + FP)
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0.0
    
    # Recall (Sensitivity): TP / (TP + FN)
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0.0
    
    # F1 Score: 2 * (Precision * Recall) / (Precision + Recall)
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score
    }


# Test Dataset Performance

In [None]:
true_negative, false_positive, false_negative, true_positive = (
    get_confusion_matrix(
        true_values=test_dataset["text"].tolist(),
        true_labels=test_dataset["label"].tolist(),
        vector_collection=chroma_collection,
        k=1,
    )
    .ravel()
    .tolist()
)

print(
    f"TP: {true_positive}, TN: {true_negative}, FP: {false_positive}, FN: {false_negative}"
)

# Calculate and display metrics
metrics = calculate_metrics(
    true_positive, true_negative, false_positive, false_negative
)
print(f"\nMetrics:")
print(f"Accuracy:  {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall:    {metrics['recall']:.4f}")
print(f"F1 Score:  {metrics['f1_score']:.4f}")

TP: 0, TN: 0, FP: 0, FN: 0

Metrics:
Accuracy:  0.0000
Precision: 0.0000
Recall:    0.0000
F1 Score:  0.0000
