In [1]:
import pandas as pd
from chromadb import Documents, EmbeddingFunction, Embeddings, PersistentClient, Collection
from langchain_huggingface import HuggingFaceEmbeddings


  from .autonotebook import tqdm as notebook_tqdm


# Load Datasets

In [2]:
# Load the raw test split, then map the 'Req' / 'Not Req' strings to booleans.
# Unlike DataFrame.replace (which leaves unknown values untouched), Series.map turns any
# unexpected value (typo, stray whitespace, missing cell) into NaN, so it is reported
# here instead of reaching the evaluation as a non-boolean label.
test_raw = pd.read_csv('./dataset/PURE_test.csv', usecols=['Requirement', 'Req/Not Req'])
test_labels = test_raw['Req/Not Req'].map({'Req': True, 'Not Req': False})
unmapped_test_labels = test_raw.loc[test_labels.isna(), 'Req/Not Req'].unique()
assert len(unmapped_test_labels) == 0, f"Unexpected 'Req/Not Req' values: {unmapped_test_labels!r}"

# Build the working frame without in-place mutation so re-running this cell is idempotent.
test_dataset = pd.DataFrame({'text': test_raw['Requirement'], 'label': test_labels.astype(bool)})
test_dataset.head()

Unnamed: 0,text,label
0,System Initialization performs those functions...,True
1,"Whenever a power-on reset occurs, System Initi...",True
2,"As part of System Initialization , the Boot RO...",True
3,System Initialization shall [SRS014] initiate ...,True
4,System Initialization shall [SRS292] enable an...,True


In [3]:
# Load the raw validation split and map 'Req' / 'Not Req' to booleans.
# Series.map yields NaN for any value outside the mapping; surface those here rather
# than letting a stray string label flow into the confusion-matrix computation.
validation_raw = pd.read_csv('./dataset/PURE_valid.csv', usecols=['Requirement', 'Req/Not Req'])
validation_labels = validation_raw['Req/Not Req'].map({'Req': True, 'Not Req': False})
unmapped_validation_labels = validation_raw.loc[validation_labels.isna(), 'Req/Not Req'].unique()
assert len(unmapped_validation_labels) == 0, (
    f"Unexpected 'Req/Not Req' values: {unmapped_validation_labels!r}"
)

# Build the working frame without in-place mutation so re-running this cell is idempotent.
validation_dataset = pd.DataFrame(
    {'text': validation_raw['Requirement'], 'label': validation_labels.astype(bool)}
)
validation_dataset.head()

Unnamed: 0,text,label
0,Any operation requiring the user to supply a f...,True
1,For any operation where the user is prompted t...,True
2,When collecting generated output files from HA...,True
3,"For example, given a transformation language p...",True
4,If X.tlp.parsed existed prior to executing the...,True


# Create the Embedding Model, ChromaDB Client, and Collection

In [4]:
class CustomEmbeddingFunction(EmbeddingFunction):
    """Adapter that exposes a LangChain ``HuggingFaceEmbeddings`` model through
    ChromaDB's ``EmbeddingFunction`` interface.

    An instance is handed to ``get_or_create_collection`` so that the collection
    embeds texts with the wrapped model.
    """

    def __init__(self, embedding_model: HuggingFaceEmbeddings):
        # All embedding work is delegated to this LangChain model.
        self.embedding_model = embedding_model

    def __call__(self, texts: Documents) -> Embeddings:
        """Embed ``texts`` with the wrapped model's ``embed_documents`` and return the vectors."""
        model = self.embedding_model
        vectors = model.embed_documents(texts)
        return vectors

In [5]:
# Sentence-transformer model wrapped by the collection's embedding function.
# NOTE(review): presumably this must be the same model the collection was populated
# with (the collection name ends in "mpnet_base_v2") — confirm.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

In [6]:
# Open the on-disk ChromaDB store and attach the labelled requirements collection,
# embedding queries with the same HuggingFace model via CustomEmbeddingFunction.
chroma_client = PersistentClient(path="./chroma_data")

chroma_collection = chroma_client.get_or_create_collection(
    name="requirements_collection_mpnet_base_v2",
    embedding_function=CustomEmbeddingFunction(embedding_model=embedding_model),
)

# This notebook only reads the collection; it never adds documents. If the path or
# name is wrong, get_or_create_collection silently creates an EMPTY collection, and
# every KNN vote below would fall through to the default (False) prediction.
# Fail fast instead.
assert chroma_collection.count() > 0, (
    "Collection 'requirements_collection_mpnet_base_v2' in ./chroma_data is empty; "
    "populate it before running this notebook"
)

In [7]:
# Sanity check: retrieve the five stored texts nearest to one ad-hoc sentence and
# display the raw ChromaDB result (ids, documents, metadata, distances).
sample_query = "Watchdog timers are used to detect and recover from malfunctions."
chroma_collection.query(query_texts=[sample_query], n_results=5)

{'ids': [['bef1b73c-0ac2-48cc-9e5b-bdf9823a434a',
   '4afb379b-8426-49de-a284-78c55994c3f6',
   '18f73123-9a39-4adb-a866-2ca60e2803b7',
   '05f41bf0-ed16-4306-9b32-64b3607d04f9',
   '611e6c04-4702-4d10-a349-844d248244b8']],
 'embeddings': None,
 'documents': [['The time critical functions include both control and supervision functions.',
   'Automatic procedures for detection of communication faults and for managing redun- dancy of system components shall be established. ',
   'System failures caused by user operation or network malfunctions have to be avoided. ',
   'Monitoring, troubleshooting, and controlling services, ports, and application programming interfaces.',
   'For support of service providers that have different needs for timers available for porting, two types of timers have been defined in the NPAC SMS. ']],
 'uris': None,
 'included': ['metadatas', 'documents', 'distances'],
 'data': None,
 'metadatas': [[{'is_req': True},
   {'is_req': True},
   {'is_req': True},
   {

## Functions for Weighted-KNN Classification and Evaluation Metrics

In [13]:
def get_range_vote_for_query(
    query: str, collection: Collection, k: int = 5, power: float = 1.0
) -> bool:
    """
    Classify a single query as requirement / not-requirement via distance-weighted KNN.

    Each of the ``k`` nearest neighbours votes for its stored ``is_req`` flag with
    weight ``1 / (1 + distance) ** power``; the class with the larger total wins.

    Args:
        query: The text to classify.
        collection: ChromaDB collection whose metadata carries an ``is_req`` flag.
        k: Number of nearest neighbours to retrieve for voting.
        power: Exponent applied to the similarity weight. ``1.0`` (default) keeps the
            ``1 / (1 + distance)`` weighting; ``2.0`` gives the squared weighting
            used by ``get_confusion_matrix``.

    Returns:
        bool: True if classified as a requirement. A tie (including an empty
        neighbour list) resolves to False.
    """
    # Only distances and metadata are used for voting; requesting them explicitly
    # guarantees both are present and avoids shipping the documents back.
    query_result = collection.query(
        query_texts=[query],
        n_results=k,
        include=["metadatas", "distances"],
    )

    # Results are batched per query text; we sent exactly one.
    distances = query_result["distances"][0]
    metadatas = query_result["metadatas"][0]

    req_score = 0.0
    not_req_score = 0.0
    for distance, metadata in zip(distances, metadatas):
        # Smaller distance -> larger weight, so closer neighbours dominate the vote.
        weight = 1 / (1 + distance) ** power
        if metadata["is_req"]:
            req_score += weight
        else:
            not_req_score += weight

    # Strict comparison: ties go to "not a requirement".
    return req_score > not_req_score


In [18]:
def get_confusion_matrix(
    true_values: list[str],
    true_labels: list[bool],
    vector_collection: Collection,
    k: int = 5,
    power: float = 2.0,
):
    """
    Classify each text with distance-weighted KNN and tally a binary confusion matrix.

    All texts are sent to ChromaDB in one batched query. For each text, its ``k``
    nearest neighbours vote for their stored ``is_req`` flag with weight
    ``1 / (1 + distance) ** power``; the class with the larger total wins, and a
    tie (including an empty neighbour list) is predicted as False.

    Args:
        true_values: Texts to classify.
        true_labels: Ground-truth labels aligned index-for-index with ``true_values``.
            Each must compare equal to ``True`` or ``False`` (so ``1``/``0`` and
            numpy booleans are accepted).
        vector_collection: ChromaDB collection whose metadata carries an ``is_req`` flag.
        k: Number of nearest neighbours for voting (default: 5).
        power: Exponent of the similarity weight (default: 2.0, i.e.
            ``1 / (1 + distance) ** 2``, the weighting this function has always used).

    Returns:
        Tuple of (true_positive, false_positive, false_negative, true_negative).
        Note that ``calculate_metrics`` takes its arguments in a different
        positional order (TP, TN, FP, FN).

    Raises:
        ValueError: If ``true_values`` and ``true_labels`` differ in length, or if a
            label is not boolean (e.g. an unmapped ``'Not Req'`` string). Such rows
            used to fall through every branch and silently vanish from all four
            counts, skewing the metrics.
    """
    if len(true_values) != len(true_labels):
        raise ValueError(
            f"true_values has {len(true_values)} items but true_labels has {len(true_labels)}"
        )

    # Validate labels up front so a bad label fails before the embedding/query pass.
    for index, label in enumerate(true_labels):
        if label not in (True, False):
            raise ValueError(f"true_labels[{index}] = {label!r} is not a boolean label")

    # Nothing to classify: skip the query entirely.
    if len(true_values) == 0:
        return 0, 0, 0, 0

    # One batched query for all texts; only distances and metadata are needed for voting.
    batch_results = vector_collection.query(
        query_texts=true_values,
        n_results=k,
        include=["metadatas", "distances"],
    )

    true_positive = false_positive = false_negative = true_negative = 0

    for i, actual_label in enumerate(true_labels):
        distances = batch_results["distances"][i]
        metadatas = batch_results["metadatas"][i]

        # Weighted KNN: smaller distance -> larger weight, so closer neighbours dominate.
        req_score = 0.0
        not_req_score = 0.0
        for distance, metadata in zip(distances, metadatas):
            weight = 1 / (1 + distance) ** power
            if metadata["is_req"]:
                req_score += weight
            else:
                not_req_score += weight

        predicted = req_score > not_req_score
        actual = bool(actual_label)

        if predicted and actual:
            true_positive += 1
        elif predicted:
            false_positive += 1
        elif actual:
            false_negative += 1
        else:
            true_negative += 1

    return true_positive, false_positive, false_negative, true_negative


In [19]:
def calculate_metrics(true_positive: int, true_negative: int, false_positive: int, false_negative: int) -> dict:
    """
    Derive standard binary-classification metrics from confusion-matrix counts.

    Any ratio whose denominator is zero is reported as 0.0 instead of raising.

    Args:
        true_positive: Count of positives predicted as positive.
        true_negative: Count of negatives predicted as negative.
        false_positive: Count of negatives predicted as positive.
        false_negative: Count of positives predicted as negative.

    Returns:
        Dict with keys "accuracy", "precision", "recall" and "f1_score", in that order.
    """
    def ratio(numerator, denominator):
        # Guarded division: empty denominators collapse to 0.0.
        return numerator / denominator if denominator > 0 else 0.0

    correct = true_positive + true_negative
    predicted_positive = true_positive + false_positive
    actual_positive = true_positive + false_negative
    everything = correct + false_positive + false_negative

    precision = ratio(true_positive, predicted_positive)
    recall = ratio(true_positive, actual_positive)

    return {
        "accuracy": ratio(correct, everything),
        "precision": precision,
        "recall": recall,
        # Harmonic mean of precision and recall.
        "f1_score": ratio(2 * (precision * recall), precision + recall),
    }


# Test Dataset Performance

In [20]:
# Score the held-out test split with 10-nearest-neighbour weighted voting.
test_counts = get_confusion_matrix(
    true_values=test_dataset["text"].tolist(),
    true_labels=test_dataset["label"].tolist(),
    vector_collection=chroma_collection,
    k=10,
)
# get_confusion_matrix returns its counts in (TP, FP, FN, TN) order.
true_positive, false_positive, false_negative, true_negative = test_counts

print(
    f"TP: {true_positive}, TN: {true_negative}, FP: {false_positive}, FN: {false_negative}"
)

# calculate_metrics expects (TP, TN, FP, FN) positionally, which differs from the tuple
# order above, so pass every count by keyword.
metrics = calculate_metrics(
    true_positive=true_positive,
    true_negative=true_negative,
    false_positive=false_positive,
    false_negative=false_negative,
)
print(f"\nMetrics:")
print(f"Accuracy:  {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall:    {metrics['recall']:.4f}")
print(f"F1 Score:  {metrics['f1_score']:.4f}")

TP: 561, TN: 0, FP: 0, FN: 497

Metrics:
Accuracy:  0.5302
Precision: 1.0000
Recall:    0.5302
F1 Score:  0.6930


# Validation Dataset Performance

In [21]:
# Score the validation split with the same 10-nearest-neighbour weighted voting.
validation_counts = get_confusion_matrix(
    true_values=validation_dataset["text"].tolist(),
    true_labels=validation_dataset["label"].tolist(),
    vector_collection=chroma_collection,
    k=10,
)
# get_confusion_matrix returns its counts in (TP, FP, FN, TN) order.
true_positive, false_positive, false_negative, true_negative = validation_counts

print(
    f"TP: {true_positive}, TN: {true_negative}, FP: {false_positive}, FN: {false_negative}"
)

# calculate_metrics expects (TP, TN, FP, FN) positionally, which differs from the tuple
# order above, so pass every count by keyword.
metrics = calculate_metrics(
    true_positive=true_positive,
    true_negative=true_negative,
    false_positive=false_positive,
    false_negative=false_negative,
)
print(f"\nMetrics:")
print(f"Accuracy:  {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall:    {metrics['recall']:.4f}")
print(f"F1 Score:  {metrics['f1_score']:.4f}")

TP: 181, TN: 0, FP: 0, FN: 74

Metrics:
Accuracy:  0.7098
Precision: 1.0000
Recall:    0.7098
F1 Score:  0.8303
