# Imports & Set Up

In [1]:
from difflib import get_close_matches
import time
from tqdm import tqdm
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from neo4j import GraphDatabase
from helpers.enhanced_evaluation_metrics import EnhancedEvaluationMetrics
from sentence_transformers.util import dot_score
import numpy as np
import torch

from sklearn.metrics.pairwise import cosine_similarity

uri = "bolt://localhost:7687"
username = "neo4j"
password = "password1234"

driver = GraphDatabase.driver(uri, auth=(username, password))

# Load Test df

In [2]:
test_df = pd.read_excel("Test cross-references.xlsx")
test_df

Unnamed: 0,PN1,Manufacturer1,PN2,Manufacturer2,Category
0,MX25L6433FM2I-08G,Macronix,W25Q64JVSSIQ,Winbond Electronics,Memory
1,MT25QL512ABB1EW9-0SIT,Micron Technology Inc.,W25Q512JVEIQ,Winbond Electronics,Memory
2,MT25QU256ABA1EW7-0SIT,Micron Technology Inc.,W25Q256JWPIQ,Winbond Electronics,Memory
3,MT25QL128ABA1EW7-0SIT,Micron Technology Inc.,W25Q128JVPIQ,Winbond Electronics,Memory
4,MX25L51245GMI-08G,Macronix,W25Q512JVFIM,Winbond Electronics,Memory
5,MT25QU128ABA1EW7-0SIT,Micron Technology Inc,W25Q128JWPIQ,Winbond Electronics,Memory
6,MT25QL512ABB8E12-0SIT,Micron Technology Inc,W25Q512JVBIQ,Winbond Electronics,Memory
7,IS25LP064A-JKLE,"ISSI, Integrated Silicon Solution Inc",MX25L6433FZNI-08G,Macronix,Memory
8,W25Q256JWEIQ,Winbond Electronics,MX25U25645GZ4I00,Macronix,Memory
9,SST26VF016B-104I/SN,Microchip Technology,W25Q16JVSSIQ TR,Winbond Electronics,Memory


In [3]:
def find_similar_manufacturers(driver, manufacturer_name, category_name, n=3):
    """
    Function to find closest manufacturer name in database
    """
    with driver.session() as session:
        result = session.run(
            "MATCH (p:Product)-[:BELONGS_TO]->(c:Category {name: $category_name}) "
            "RETURN DISTINCT p.manufacturer AS manufacturer",
            category_name=category_name
        ).data()
        
        manufacturers = [record["manufacturer"] for record in result]
        
        matches = get_close_matches(manufacturer_name, manufacturers, n=n, cutoff=0.2)
        # If no matches found, return the original as first choice
        if not matches:
            return [manufacturer_name]
        return matches

def check_part_numbers_exist(driver, part_numbers, manufacturers):
    """Check if part numbers exist in database
    """
    missing_entries = []
    
    with driver.session() as session:
        for pn, manufacturer in zip(part_numbers, manufacturers):
            result = session.run(
                "MATCH (p:Product) WHERE p.name = $pn RETURN count(p) AS count",
                pn=pn
            ).single()
            
            if result and result["count"] == 0:
                missing_entries.append((pn, manufacturer))
    return missing_entries


def evaluate_cross_referencing_accuracy(driver, test_df):
    """
    Enhanced evaluation of cross-referencing accuracy with multiple metrics
    """
    # Ensure all required columns exist
    required_cols = ['PN1', 'Manufacturer1', 'PN2', 'Manufacturer2', 'Category']
    if not all(col in test_df.columns for col in required_cols):
        raise ValueError(f"Test dataframe must contain columns: {required_cols}")
    
    # Check for missing part numbers
    print("Checking if all part numbers exist in database...")
    missing_pn1 = check_part_numbers_exist(driver, test_df['PN1'].tolist(), test_df['Manufacturer1'].tolist())
    missing_pn2 = check_part_numbers_exist(driver, test_df['PN2'].tolist(), test_df['Manufacturer2'].tolist())
    
    if missing_pn1:
        print(f"Warning: Found {len(missing_pn1)} missing entries from first set")
    if missing_pn2:
        print(f"Warning: Found {len(missing_pn2)} missing entries from second set")
    
    # Build a set of missing part numbers for filtering
    missing_pns = set([pn for pn, _ in missing_pn1 + missing_pn2])
    
    # Filter out rows with missing part numbers
    valid_df = test_df[~(test_df['PN1'].isin(missing_pns) | test_df['PN2'].isin(missing_pns))].copy()
    print(f"Testing on {len(valid_df)} valid pairs out of {len(test_df)} total pairs")
    
    evaluator = EnhancedEvaluationMetrics()
    
    results = {
        'total_pairs': len(valid_df) * 2,
        'hit_rate_1': 0,
        'hit_rate_3': 0,
        'hit_rate_5': 0,
        'hit_rate_10': 0,
        'avg_position': 0,
        'median_position': 0,
        'details': [],
        'by_category': {}
    }
    
    categories = valid_df['Category'].unique()
    
    for category in categories:
        results['by_category'][category] = {
            'total_pairs': 0,
            'hit_rate_1': 0,
            'hit_rate_3': 0,
            'hit_rate_5': 0,
            'hit_rate_10': 0,
            'positions': []
        }
    
    positions = []
    
    # Process each test pair in both directions
    for idx, row in tqdm(valid_df.iterrows(), total=len(valid_df), desc="Evaluating forward direction"):
        category_name = row['Category']
        
        # Forward direction: PN1 -> PN2
        evaluate_single_direction(
            driver, row['PN1'], row['Manufacturer1'], row['PN2'], row['Manufacturer2'], 
            category_name, results, positions, "forward"
        )
    
    # Process each test pair in reverse direction
    for idx, row in tqdm(valid_df.iterrows(), total=len(valid_df), desc="Evaluating reverse direction"):
        category_name = row['Category']
        
        # Reverse direction: PN2 -> PN1
        evaluate_single_direction(
            driver, row['PN2'], row['Manufacturer2'], row['PN1'], row['Manufacturer1'], 
            category_name, results, positions, "reverse"
        )
    
    # Calculate overall accuracy metrics
    if positions:
        results['avg_position'] = sum(positions) / len(positions)
        results['median_position'] = sorted(positions)[len(positions) // 2]
    
    # Calculate overall accuracy percentages
    total = results['total_pairs']
    if total > 0:
        results['hit_rate_1'] = results['hit_rate_1'] / total * 100 
        results['hit_rate_3'] = results['hit_rate_3'] / total * 100
        results['hit_rate_5'] = results['hit_rate_5'] / total * 100
        results['hit_rate_10'] = results['hit_rate_10'] / total * 100
    
    # Calculate per-category accuracy metrics
    for category in categories:
        category_results = results['by_category'][category]
        category_total = category_results['total_pairs']
        
        if category_total > 0:
            category_results['hit_rate_1'] = category_results['hit_rate_1'] / category_total * 100
            category_results['hit_rate_3'] = category_results['hit_rate_3'] / category_total * 100
            category_results['hit_rate_5'] = category_results['hit_rate_5'] / category_total * 100 
            category_results['hit_rate_10'] = category_results['hit_rate_10'] / category_total * 100 
            
            if category_results['positions']:
                category_results['avg_position'] = sum(category_results['positions']) / len(category_results['positions'])
                category_results['median_position'] = sorted(category_results['positions'])[len(category_results['positions']) // 2]
    
    # Calculate enhanced metrics using the evaluator
    enhanced_metrics = evaluator.calculate_metrics(results['details'])
    results.update({
        'mrr': enhanced_metrics['mrr'],
        'rbp': enhanced_metrics['rbp'],
        'vrbp': enhanced_metrics['vrbp'],
        'nfrr': enhanced_metrics['nfrr']
    })
    
    
    # Calculate per-category enhanced metrics
    for category in categories:
        category_details = [detail for detail in results['details'] if detail['category'] == category]
        category_metrics = evaluator.calculate_metrics(category_details)
        results['by_category'][category].update({
            'mrr': category_metrics['mrr'],
            'rbp': category_metrics['rbp'],
            'vrbp': enhanced_metrics['vrbp'],
            'nfrr': category_metrics['nfrr']
        })
    
    return results

def evaluate_single_direction(driver, source_pn, source_manufacturer, target_pn, target_manufacturer, 
                             category_name, results, positions, direction):
    """
    helper function to evaluate a single direction
    """
    # Get corrected source manufacturer
    source_manufacturers = find_similar_manufacturers(driver, source_manufacturer, category_name, n=1)
    corrected_source_manufacturer = source_manufacturers[0]
    
    # Get top 3 similar target manufacturers
    target_manufacturers = find_similar_manufacturers(driver, target_manufacturer, category_name, n=2)
    
    # Initialize variables for best result across all manufacturers
    best_position = -1
    best_found_pns = []
    best_manufacturer = None
    max_candidates = 0
    
    # Try each potential target manufacturer
    for curr_target_manufacturer in target_manufacturers:
        cross_refs = find_cross_referencing_products(
            driver,
            product_name=source_pn,
            manufacturer_name=corrected_source_manufacturer,
            category_name=category_name,
            exclude_target_manufacturer=False,
            include_manufacturer=curr_target_manufacturer
        )
        
        max_candidates = max(max_candidates, len(cross_refs))
        
        found_pns = [cr[0] for cr in cross_refs]
        
        # Check if expected part number is in results and at what position
        if target_pn in found_pns:
            position = found_pns.index(target_pn) + 1
            if best_position == -1 or position < best_position:
                best_position = position
                best_found_pns = found_pns
                best_manufacturer = curr_target_manufacturer
    
    # If we found a match in any manufacturer, record the position
    if best_position != -1:
        positions.append(best_position)
        results['by_category'][category_name]['positions'].append(best_position)
        
        if best_position == 1:
            results['hit_rate_1'] += 1
            results['by_category'][category_name]['hit_rate_1'] += 1
        if best_position <= 3:
            results['hit_rate_3'] += 1
            results['by_category'][category_name]['hit_rate_3'] += 1 
        if best_position <= 5:
            results['hit_rate_5'] += 1 
            results['by_category'][category_name]['hit_rate_5'] += 1
        if best_position <= 10:
            results['hit_rate_10'] += 1
            results['by_category'][category_name]['hit_rate_10'] += 1 
    
    results['by_category'][category_name]['total_pairs'] += 1
    
    max_rank = max(10, max_candidates)
    
    results['details'].append({
        'source_pn': source_pn,
        'source_manufacturer': source_manufacturer,
        'corrected_source_manufacturer': corrected_source_manufacturer,
        'expected_pn': target_pn,
        'expected_manufacturer': target_manufacturer,
        'matched_manufacturer': best_manufacturer,
        'tried_manufacturers': target_manufacturers,
        'found_at_position': best_position,
        'top_10_results': best_found_pns[:10],
        'category': category_name,
        'direction': direction,
        'max_rank': max_rank
    })

    
def get_package_embedding(driver, product_name, manufacturer_name):
    with driver.session() as session:
        result = session.run(
            "MATCH (p:Product {name: $product_name})-[:PACKAGED_IN]->(pkg:Package) "
            "WHERE p.manufacturer = $manufacturer_name "
            "RETURN pkg.embedding AS embedding",
            product_name=product_name,
            manufacturer_name=manufacturer_name
        )
        record = result.single()
        return record["embedding"] if record else None
    
def get_feature_vector_prioritized(driver, product_name, manufacturer_name):
    with driver.session() as session:
        result = session.run(
            "MATCH (p:Product {name: $product_name, manufacturer: $manufacturer_name}) "
            "RETURN p.feature_vector_prioritized_1 AS feature_vector_prioritized",
            product_name=product_name,
            manufacturer_name=manufacturer_name
        )
        record = result.single()
        return record["feature_vector_prioritized"] if record else None


def get_feature_vector(driver, product_name, manufacturer_name):
    with driver.session() as session:
        result = session.run(
            "MATCH (p:Product {name: $product_name}) "
            "WHERE p.manufacturer = $manufacturer_name "
            "RETURN p.feature_vector_1 AS feature_vector", # Change to feature_vector_2 for just values or to feature_vector_3 for standardized
            product_name=product_name,
            manufacturer_name=manufacturer_name
        )
        record = result.single()
        return record["feature_vector"] if record else None


def find_cross_referencing_products(driver, product_name, manufacturer_name, category_name, exclude_target_manufacturer=True, include_manufacturer=None):
    # Get target product's embeddings
    package_embedding = get_package_embedding(driver, product_name, manufacturer_name)
    feature_vector = get_feature_vector(driver, product_name, manufacturer_name)
    # feature_vector_prioritized = get_feature_vector_prioritized(driver, product_name, manufacturer_name)
    if not package_embedding or not feature_vector:
        print(f"Missing embeddings for {product_name}")
        return []

    query = """
    MATCH (p:Product {name: $product_name})-[:BELONGS_TO]->(c:Category {name: $category_name})
    WHERE p.manufacturer = $manufacturer_name
    MATCH (similar:Product)-[:BELONGS_TO]->(c)
    WHERE similar <> p
      AND ( $include_manufacturer IS NOT NULL AND similar.manufacturer = $include_manufacturer
            OR $include_manufacturer IS NULL AND ( NOT $exclude_target_manufacturer OR similar.manufacturer <> $manufacturer_name ) )
    MATCH (similar)-[:PACKAGED_IN]->(s_pkg:Package)
    RETURN similar.name AS name, similar.manufacturer AS manufacturer, s_pkg.embedding AS pkg_embedding, similar.feature_vector_1 AS feature_vector,
    """
    
    # , similar.feature_vector_prioritized_1 AS feature_vector_prioritized
    

    with driver.session() as session:
        result = session.run(
            query,
            product_name=product_name,
            manufacturer_name=manufacturer_name,
            category_name=category_name,
            include_manufacturer=include_manufacturer,
            exclude_target_manufacturer=exclude_target_manufacturer
        ).data()
        

    # Compute similarities
    similarities = []
    for record in result:
        similar_name = record["name"]
        similar_manufacturer = record["manufacturer"]
        similar_pkg_embedding = record["pkg_embedding"]
        similar_feature_vector = record["feature_vector"]
        # similar_feature_vector_prioritized = record["feature_vector_prioritized"]

        # Compute cosine similarities
        pkg_similarity = cosine_similarity([package_embedding], [similar_pkg_embedding])[0][0]
        attr_similarity = cosine_similarity([feature_vector], [similar_feature_vector])[0][0]
        # attr_similarity = dot_score(torch.tensor([feature_vector]), torch.tensor([similar_feature_vector]))[0][0].item()
        # prior_similarity = cosine_similarity([feature_vector_prioritized], [similar_feature_vector_prioritized])[0][0]

        # Combine scores
        total_score = 0.5 * pkg_similarity + 0.5 * attr_similarity
        # total_score = 0.3 * pkg_similarity + 0.3 * prior_similarity + 0.4 * attr_similarity
        similarities.append((similar_name, similar_manufacturer, total_score))

    # Sort by score descending and return top N
    similarities.sort(key=lambda x: x[2], reverse=True)
    return similarities

def generate_accuracy_report(results):
    """
    Generate a comprehensive report of cross-reference accuracy with advanced metrics:
    """
    print("\n" + "="*60)
    print("CROSS-REFERENCE EVALUATION RESULTS")
    print("="*60)
    print(f"Total test pairs evaluated: {results['total_pairs']}")
    
    print("\nPRIMARY HIT RATE METRICS:")
    print(f"  Hit Rate@1: {results.get('hit_rate_1', 0):.2f}%")
    print(f"  Hit Rate@3: {results.get('hit_rate_3', 0):.2f}%")
    print(f"  Hit Rate@5: {results.get('hit_rate_5', 0):.2f}%")
    print(f"  Hit Rate@10: {results.get('hit_rate_10', 0):.2f}%")
    
    # Add the advanced metrics to the report
    print("\nADVANCED RANKING METRICS:")
    
    
    print(f"  Mean Reciprocal Rank (MRR): {results.get('mrr', 0):.4f}")

    print(f" Rank-Biased Precision (RBP): {results.get('rbp', 0):.4f}")
    
    print(f" Variable Rank-Biased Precision (VRBP): {results.get('vrbp', 0):.4f}")
    
    print(f"  'Normalized First Relevant Rank (NFRR): {results.get('nfrr', 0):.4f}")
    
    if results.get('avg_position'):
        print(f"\nAverage position of correct match: {results['avg_position']:.2f}")
        print(f"Median position of correct match: {results['median_position']}")
    
    # Print category-specific results
    print("\n" + "="*60)
    print("RESULTS BY CATEGORY")
    print("="*60)
    
    for category, cat_results in results['by_category'].items():
        print(f"\nCategory: {category}")
        print(f"  Total pairs: {cat_results['total_pairs']}")
        print(f"  Hit Rate@1: {cat_results.get('hit_rate_1', 0):.2f}%")
        print(f"  Hit Rate@3: {cat_results.get('hit_rate_3', 0):.2f}%")
        print(f"  Hit Rate@5: {cat_results.get('hit_rate_5', 0):.2f}%")
        print(f"  Hit Rate@10: {cat_results.get('hit_rate_10', 0):.2f}%")
        
        # Add advanced metrics for each category
        print(f"  MRR: {cat_results.get('mrr', 0):.4f}")
        print(f"  RBP: {cat_results.get('rbp', 0):.4f}")
        print(f"  VRBP: {cat_results.get('vrbp', 0):.4f}")
        print(f"  NFRR: {cat_results.get('nfrr', 0):.4f}")
        
        if cat_results.get('avg_position'):
            print(f"  Average position: {cat_results['avg_position']:.2f}")
            print(f"  Median position: {cat_results['median_position']}")
    
    # Manufacturer correction insights
    source_manufacturer_corrections = sum(1 for d in results['details'] 
                                         if d['source_manufacturer'] != d['corrected_source_manufacturer'])
    target_manufacturer_matches = sum(1 for d in results['details'] 
                                     if d['matched_manufacturer'] is not None and 
                                     d['matched_manufacturer'] != d['expected_manufacturer'])
    
    print("\n" + "="*60)
    print("MANUFACTURER CORRECTION INSIGHTS")
    print("="*60)
    print(f"  Source manufacturer corrections: {source_manufacturer_corrections} ({source_manufacturer_corrections/results['total_pairs']*100:.2f}%)")
    print(f"  Target manufacturer alternative matches: {target_manufacturer_matches} ({target_manufacturer_matches/results['total_pairs']*100:.2f}%)")
    
    print("\n" + "="*60)
    print("FAILURE ANALYSIS")
    print("="*60)
    not_found = sum(1 for detail in results['details'] if detail['found_at_position'] == -1)
    print(f"  Matches not found at all: {not_found} ({not_found/results['total_pairs']*100:.2f}%)")
    
    # List problematic cases
    problem_cases = [d for d in results['details'] if d['found_at_position'] == -1 or d['found_at_position'] > 5]
    if problem_cases:
        print("\nTop 3 problematic cases:")
        for i, case in enumerate(problem_cases[:3]):
            print(f"\n  Case {i+1}:")
            print(f"    Category: {case['category']}")
            print(f"    Source: {case['source_pn']} (Original: {case['source_manufacturer']}, Corrected: {case['corrected_source_manufacturer']})")
            print(f"    Expected: {case['expected_pn']} (Original: {case['expected_manufacturer']})")
            print(f"    Tried target manufacturers: {case['tried_manufacturers']}")
            print(f"    Best matched manufacturer: {case['matched_manufacturer']}")
            print(f"    Found at position: {case['found_at_position']}")
            print(f"    Top 10 results: {case['top_10_results']}")
    
    # List successful cases
    corrected_successes = [d for d in results['details'] 
                         if (d['source_manufacturer'] != d['corrected_source_manufacturer'] or 
                             (d['matched_manufacturer'] is not None and d['matched_manufacturer'] != d['expected_manufacturer'])) 
                         and d['found_at_position'] <= 5]
    
    if corrected_successes:
        print("\nTop 3 successful cases with manufacturer corrections:")
        for i, case in enumerate(corrected_successes[:3]):
            print(f"\n  Case {i+1}:")
            print(f"    Category: {case['category']}")
            print(f"    Source: {case['source_pn']} (Original: {case['source_manufacturer']}, Corrected: {case['corrected_source_manufacturer']})")
            print(f"    Expected: {case['expected_pn']} (Original: {case['expected_manufacturer']})")
            print(f"    Matched manufacturer: {case['matched_manufacturer']}")
            print(f"    Found at position: {case['found_at_position']}")
    return


print("Starting evaluation...")
start_time = time.time()
results = evaluate_cross_referencing_accuracy(driver, test_df)
end_time = time.time()

# Generate report
generate_accuracy_report(results)
print(f"\nEvaluation completed in {end_time - start_time:.2f} seconds")

# Save detailed results
details_df = pd.DataFrame(results['details'])
details_df.to_csv("cr_results_attribute_names_big_vector_mpnet.csv", index=False)
print("Detailed results saved to 'cross_reference_evaluation_details.csv'")

driver.close()

Starting evaluation...
Checking if all part numbers exist in database...
Testing on 0 valid pairs out of 50 total pairs


Evaluating forward direction: 0it [00:00, ?it/s]
Evaluating reverse direction: 0it [00:00, ?it/s]


CROSS-REFERENCE EVALUATION RESULTS
Total test pairs evaluated: 0

PRIMARY HIT RATE METRICS:
  Hit Rate@1: 0.00%
  Hit Rate@3: 0.00%
  Hit Rate@5: 0.00%
  Hit Rate@10: 0.00%

ADVANCED RANKING METRICS:
  Mean Reciprocal Rank (MRR): 0.0000
 Rank-Biased Precision (RBP): 0.0000
 Variable Rank-Biased Precision (VRBP): 0.0000
  'Normalized First Relevant Rank (NFRR): 0.0000

RESULTS BY CATEGORY

MANUFACTURER CORRECTION INSIGHTS





ZeroDivisionError: division by zero