In [1]:
import pandas as pd
from chromadb import (
    Documents,
    EmbeddingFunction,
    Embeddings,
    PersistentClient,
    Collection,
)
from langchain_huggingface import HuggingFaceEmbeddings

  from .autonotebook import tqdm as notebook_tqdm


# Load Datasets

In [2]:
# Load the held-out test split, keeping only the sentence and its label, and
# normalise the schema to `text` (str) / `label` (bool). Built as one chained
# expression (no `inplace=True`) so re-running the cell always starts from the CSV.
test_dataset = (
    pd.read_csv('./dataset/PURE_test.csv', usecols=['Requirement', 'Req/Not Req'])
    .rename(columns={'Requirement': 'text', 'Req/Not Req': 'label'})
    .replace({'label': {'Req': True, 'Not Req': False}})
)

# `replace` leaves any value other than 'Req' / 'Not Req' untouched, and such rows
# would match neither True nor False during evaluation. Fail fast instead.
assert test_dataset['label'].isin([True, False]).all(), (
    "PURE_test.csv contains label values other than 'Req' / 'Not Req'"
)
test_dataset.head()

Unnamed: 0,text,label
0,System Initialization performs those functions...,True
1,"Whenever a power-on reset occurs, System Initi...",True
2,"As part of System Initialization , the Boot RO...",True
3,System Initialization shall [SRS014] initiate ...,True
4,System Initialization shall [SRS292] enable an...,True


In [3]:
# Load the validation split, keeping only the sentence and its label, and
# normalise the schema to `text` (str) / `label` (bool). Built as one chained
# expression (no `inplace=True`) so re-running the cell always starts from the CSV.
validation_dataset = (
    pd.read_csv('./dataset/PURE_valid.csv', usecols=['Requirement', 'Req/Not Req'])
    .rename(columns={'Requirement': 'text', 'Req/Not Req': 'label'})
    .replace({'label': {'Req': True, 'Not Req': False}})
)

# `replace` leaves any value other than 'Req' / 'Not Req' untouched, and such rows
# would match neither True nor False during evaluation. Fail fast instead.
assert validation_dataset['label'].isin([True, False]).all(), (
    "PURE_valid.csv contains label values other than 'Req' / 'Not Req'"
)
validation_dataset.head()

Unnamed: 0,text,label
0,Any operation requiring the user to supply a f...,True
1,For any operation where the user is prompted t...,True
2,When collecting generated output files from HA...,True
3,"For example, given a transformation language p...",True
4,If X.tlp.parsed existed prior to executing the...,True


# Create the ChromaDB Client and Collection

In [4]:
class CustomEmbeddingFunction(EmbeddingFunction):
    """Adapter that exposes a ``HuggingFaceEmbeddings`` model as a ChromaDB embedding function.

    ChromaDB invokes the instance like a function with a batch of texts; the
    call is forwarded to the wrapped model's ``embed_documents`` method.
    """

    def __init__(self, embedding_model: HuggingFaceEmbeddings):
        """Keep a reference to the model that will produce the embeddings.

        Args:
            embedding_model: Object providing ``embed_documents(texts)``.
        """
        self.embedding_model = embedding_model

    def __call__(self, texts: Documents) -> Embeddings:
        """Embed a batch of texts.

        Args:
            texts: The texts to embed.

        Returns:
            Whatever ``embedding_model.embed_documents(texts)`` returns
            (presumably one vector per input text -- verify against the model).
        """
        model = self.embedding_model
        vectors = model.embed_documents(texts)
        return vectors

In [5]:
# Instantiate the sentence-embedding model used for every query in this notebook.
# The checkpoint is loaded in shards; the recorded run took roughly a minute here.
# NOTE(review): the model must match the one used when the collection was
# populated -- presumably so, given the collection name; confirm.
embedding_model = HuggingFaceEmbeddings(model_name="Qwen/Qwen3-Embedding-4B")

Loading checkpoint shards: 100%|██████████| 2/2 [01:08<00:00, 34.47s/it]


In [6]:
# Open the on-disk ChromaDB store located at ./chroma_data (relative to the
# notebook's working directory).
chroma_client = PersistentClient(path="./chroma_data")

# Fetch the labelled-requirements collection, wiring in the embedding function
# so that `query_texts` are embedded with the model loaded above.
# NOTE(review): none of the cells shown here add documents to the collection;
# it is presumably populated by a separate indexing step. If the collection does
# not exist yet, `get_or_create_collection` will hand back a new one -- confirm
# that it is populated before trusting the evaluation cells.
chroma_collection = chroma_client.get_or_create_collection(
    name="requirements_collection_qwen3_4b",
    embedding_function=CustomEmbeddingFunction(embedding_model=embedding_model),
)

In [7]:
# Smoke test: fetch the 5 nearest stored documents for an ad-hoc sentence to
# check that the collection returns documents, `is_req` metadata and distances.
chroma_collection.query(
    query_texts=["Watchdog timers are used to detect and recover from malfunctions."],
    n_results=5,
)

{'ids': [['91df9cd8-1883-49e0-84e0-add4fc6ac591',
   '5fffa911-8c90-46eb-af00-a83b5383c066',
   'f27942ee-dbeb-4f08-bafe-cf7a36417d47',
   'ca1b8ac4-fa54-48bb-86af-e50420320c34',
   '2efd59e6-58a6-451b-bdd5-fa77fc8304f8']],
 'embeddings': None,
 'documents': [['The following table illustrates the outcome of timers to be used based on the possible combinations:',
   'Monitoring, troubleshooting, and controlling server performance.',
   'The old service provider who supports the short timers will have to recognize that the long timers are being used instead of the expected short timers.',
   'Monitoring, troubleshooting, and controlling services, ports, and application programming interfaces.',
   'The two types are long and short timers.']],
 'uris': None,
 'included': ['metadatas', 'documents', 'distances'],
 'data': None,
 'metadatas': [[{'is_req': False},
   {'is_req': False},
   {'is_req': False},
   {'is_req': False},
   {'is_req': False}]],
 'distances': [[0.7440860271453857,
   0

## Function for Determining the Classification Scores

In [8]:
def get_range_vote_for_query(query: str, collection: Collection, k: int = 5) -> bool:
    """
    Classify a single text as requirement / not requirement from its nearest neighbours.

    The ``k`` nearest stored documents each cast one vote through their
    ``is_req`` metadata flag. A strict majority decides the outcome. When the
    votes are split evenly, every neighbour instead votes with the weight
    ``1 / (1 + distance)`` and the class with the larger total wins; if the
    weighted totals are equal as well, the result is False.

    Args:
        query: The text to classify
        collection: ChromaDB collection containing labeled requirements
        k: Number of nearest neighbors to retrieve for voting

    Returns:
        bool: True if classified as requirement, False otherwise
    """
    neighbours = collection.query(query_texts=[query], n_results=k)

    # A single query text was sent, so only the first result row is relevant.
    neighbour_distances = neighbours["distances"][0]
    neighbour_metadatas = neighbours["metadatas"][0]

    # Stage 1: plain head count.
    votes_for = sum(1 for meta in neighbour_metadatas if meta["is_req"])
    votes_against = len(neighbour_metadatas) - votes_for
    if votes_for != votes_against:
        return votes_for > votes_against

    # Stage 2 (tie): closer neighbours carry more weight.
    weights = [1 / (1 + distance) for distance in neighbour_distances]
    weighted_totals = {True: 0.0, False: 0.0}
    for weight, meta in zip(weights, neighbour_metadatas):
        weighted_totals[bool(meta["is_req"])] += weight

    return weighted_totals[True] > weighted_totals[False]


In [9]:
def _hybrid_vote(distances, metadatas) -> bool:
    """Predict one sample's label from its retrieved neighbours.

    A strict majority of the neighbours' ``is_req`` flags decides. On a tie,
    each neighbour votes with weight ``1 / (1 + distance)`` and the class with
    the larger total wins; an exact weighted tie yields False.
    """
    req_votes = sum(1 for metadata in metadatas if metadata["is_req"])
    not_req_votes = len(metadatas) - req_votes
    if req_votes != not_req_votes:
        return req_votes > not_req_votes

    # Tie: convert distances to similarity weights (lower distance = higher weight).
    similarities = [1 / (1 + dist) for dist in distances]
    req_score = 0.0
    not_req_score = 0.0
    for similarity, metadata in zip(similarities, metadatas):
        if metadata["is_req"]:
            req_score += similarity
        else:
            not_req_score += similarity
    return req_score > not_req_score


def get_confusion_matrix(
    true_values: list[str], true_labels: list[bool], vector_collection: Collection, k: int = 5
) -> tuple[int, int, int, int]:
    """
    Generate confusion matrix using hybrid voting classification.
    Uses majority voting first, then weighted Range Voting for ties.
    Uses batch querying for efficiency.

    Args:
        true_values: List of text samples to classify
        true_labels: List of ground truth labels, one per entry of
            ``true_values``; each must equal True or False
        vector_collection: ChromaDB collection to query
        k: Number of nearest neighbors for voting (default: 5)

    Returns:
        Tuple of (true_positive, false_positive, false_negative, true_negative).
        Note that this order differs from the parameter order of
        ``calculate_metrics``; pass the values on by keyword.

    Raises:
        ValueError: If ``true_values`` and ``true_labels`` differ in length, or
            if a label is neither True nor False (for example a string that was
            never mapped to a boolean, or NaN). Such samples would otherwise be
            left out of every cell of the matrix without any warning.
    """
    if len(true_values) != len(true_labels):
        raise ValueError(
            f"true_values has {len(true_values)} items but true_labels has {len(true_labels)}"
        )

    # Nothing to classify: skip the collection query entirely.
    if len(true_values) == 0:
        return 0, 0, 0, 0

    # Validate labels up front so a bad dataset fails before the costly embedding query.
    for position, label in enumerate(true_labels):
        if label not in (True, False):
            raise ValueError(
                f"true_labels[{position}] is {label!r}; expected True or False"
            )

    # Batch query all texts at once
    batch_results = vector_collection.query(
        query_texts=true_values,
        n_results=k,
    )
    all_distances = batch_results["distances"]
    all_metadatas = batch_results["metadatas"]

    true_positive = 0
    false_positive = 0
    true_negative = 0
    false_negative = 0

    for i, actual_label in enumerate(true_labels):
        predicted_label = _hybrid_vote(all_distances[i], all_metadatas[i])

        if predicted_label and actual_label:
            true_positive += 1
        elif predicted_label:
            false_positive += 1
        elif actual_label:
            false_negative += 1
        else:
            true_negative += 1

    return true_positive, false_positive, false_negative, true_negative


In [10]:
def calculate_metrics(true_positive: int, true_negative: int, false_positive: int, false_negative: int) -> dict:
    """
    Derive accuracy, precision, recall and F1 from the four confusion-matrix counts.

    Any ratio whose denominator is not positive is reported as 0.0 rather than
    raising a division error.

    Note: the parameter order here is (TP, TN, FP, FN), which is not the order
    in which ``get_confusion_matrix`` returns its tuple.

    Args:
        true_positive: Number of true positives
        true_negative: Number of true negatives
        false_positive: Number of false positives
        false_negative: Number of false negatives

    Returns:
        Dictionary containing accuracy, precision, recall, and f1_score
    """
    def _ratio(numerator, denominator):
        # Guarded division shared by all four metrics.
        return numerator / denominator if denominator > 0 else 0.0

    sample_count = true_positive + true_negative + false_positive + false_negative
    predicted_positive = true_positive + false_positive
    actual_positive = true_positive + false_negative

    # precision = TP / (TP + FP); recall = TP / (TP + FN)
    precision = _ratio(true_positive, predicted_positive)
    recall = _ratio(true_positive, actual_positive)

    return {
        # accuracy = (TP + TN) / total
        "accuracy": _ratio(true_positive + true_negative, sample_count),
        "precision": precision,
        "recall": recall,
        # F1 = harmonic mean of precision and recall
        "f1_score": _ratio(2 * (precision * recall), precision + recall),
    }


# Test Dataset Performance

In [11]:
# Classify every test sentence with k=10 neighbours. k is even, so split votes
# can occur; they are resolved by the distance-weighted tie-break.
# NOTE(review): the recorded run of this cell reported TN=0 and FP=0, i.e. no
# sample with a False label was counted. Verify that the test CSV really has no
# 'Not Req' rows, or that its label values map cleanly to True/False.
true_positive, false_positive, false_negative, true_negative = get_confusion_matrix(
    true_values=test_dataset["text"].tolist(),
    true_labels=test_dataset["label"].tolist(),
    vector_collection=chroma_collection,
    k=10,
)

print(
    f"TP: {true_positive}, TN: {true_negative}, FP: {false_positive}, FN: {false_negative}"
)

# Calculate and display metrics. Passed by keyword because calculate_metrics
# takes (TP, TN, FP, FN) while get_confusion_matrix returns (TP, FP, FN, TN).
metrics = calculate_metrics(
    true_positive=true_positive,
    true_negative=true_negative,
    false_positive=false_positive,
    false_negative=false_negative,
)
print("\nMetrics:")
print(f"Accuracy:  {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall:    {metrics['recall']:.4f}")
print(f"F1 Score:  {metrics['f1_score']:.4f}")

TP: 697, TN: 0, FP: 0, FN: 361

Metrics:
Accuracy:  0.6588
Precision: 1.0000
Recall:    0.6588
F1 Score:  0.7943


# Validation Dataset Performance

In [None]:
# Classify every validation sentence with k=10 neighbours. k is even, so split
# votes can occur; they are resolved by the distance-weighted tie-break.
# Note: this reuses (overwrites) the count variables from the test-set cell.
true_positive, false_positive, false_negative, true_negative = get_confusion_matrix(
    true_values=validation_dataset["text"].tolist(),
    true_labels=validation_dataset["label"].tolist(),
    vector_collection=chroma_collection,
    k=10,
)

print(
    f"TP: {true_positive}, TN: {true_negative}, FP: {false_positive}, FN: {false_negative}"
)

# Calculate and display metrics. Passed by keyword because calculate_metrics
# takes (TP, TN, FP, FN) while get_confusion_matrix returns (TP, FP, FN, TN).
metrics = calculate_metrics(
    true_positive=true_positive,
    true_negative=true_negative,
    false_positive=false_positive,
    false_negative=false_negative,
)
print("\nMetrics:")
print(f"Accuracy:  {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall:    {metrics['recall']:.4f}")
print(f"F1 Score:  {metrics['f1_score']:.4f}")