# Transaction Tagging Inference Demo

This notebook demonstrates how to:
1. Load the trained model and golden record index
2. Run batch predictions on a CSV file
3. Display predicted labels and retrieved similar transactions
4. Identify potentially mislabeled golden records

## Prerequisites

Ensure you have:
- ✅ Trained model: `experiments/.../fusion_encoder_best.pth`
- ✅ Training artifacts: `training_artifacts/training_artifacts.pkl`
- ✅ FAISS index: `golden_records.faiss`
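
Before loading anything, it can help to confirm that the files listed above are actually present. The following is a minimal sketch, not part of the pipeline: `find_missing_files` is a hypothetical helper, and the model checkpoint is left out of the list because its full path depends on your experiment directory.

```python
from pathlib import Path


def find_missing_files(paths):
    """Return the subset of `paths` that do not exist as files on disk."""
    return [str(p) for p in map(Path, paths) if not p.is_file()]


# Hypothetical usage: list the same paths you configure in section 2.
required = [
    "training_artifacts/training_artifacts.pkl",
    "golden_records.faiss",
]

missing = find_missing_files(required)
if missing:
    print("Missing prerequisite files:")
    for path in missing:
        print(f"  - {path}")
else:
    print("All listed prerequisite files found")
```

Add your own `MODEL_PATH` to `required` once you know which experiment folder holds the checkpoint.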

## 1. Setup & Imports

In [None]:
import pandas as pd
import numpy as np
from src.inference_pipeline import TransactionInferencePipeline, print_prediction_result
import warnings
warnings.filterwarnings('ignore')

print("✅ Imports successful")

## 2. Initialize Pipeline

Load the trained model and golden record index.

In [None]:
# Configuration
ARTIFACTS_PATH = "training_artifacts/training_artifacts.pkl"
MODEL_PATH = "experiments/tagger_proj256_final256_freeze-gradual_bs2048_lr5.66e-05/fusion_encoder_best.pth"
INDEX_PATH = "golden_records.faiss"
TOP_K = 5  # Number of similar transactions to retrieve

# Initialize pipeline
pipeline = TransactionInferencePipeline(
    artifacts_path=ARTIFACTS_PATH,
    model_path=MODEL_PATH,
    index_path=INDEX_PATH
)

print("\n✅ Pipeline initialized successfully!")
print(f"   Index size: {pipeline.index.ntotal} golden records")
print(f"   Categories: {len(pipeline.label_mapping)}")

## 3. Load Your Test CSV

**Upload your CSV file** or specify the path below.

Required columns:
- `tran_partclr` (transaction description)
- `tran_mode` (transaction mode)
- `dr_cr_indctor` (debit/credit indicator)
- `sal_flag` (salary flag)
- `tran_amt_in_ac` (transaction amount)
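
A quick check that the CSV carries all five required columns can save a confusing failure later in the notebook. This is a sketch under the assumption that a missing column should stop the run; `check_required_columns` is a hypothetical helper and the one-row frame is toy data standing in for your file.

```python
import pandas as pd

REQUIRED_COLUMNS = [
    "tran_partclr",
    "tran_mode",
    "dr_cr_indctor",
    "sal_flag",
    "tran_amt_in_ac",
]


def check_required_columns(frame, required=REQUIRED_COLUMNS):
    """Raise ValueError naming any required column absent from `frame`."""
    missing = [col for col in required if col not in frame.columns]
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")


# Toy one-row frame (illustrative values only)
example = pd.DataFrame([{
    "tran_partclr": "POS PURCHASE GROCERY",
    "tran_mode": "POS",
    "dr_cr_indctor": "D",
    "sal_flag": "N",
    "tran_amt_in_ac": 42.5,
}])

check_required_columns(example)  # passes silently when nothing is missing
```

In the notebook you would call `check_required_columns(df)` right after `pd.read_csv`.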

In [None]:
# Load your CSV file
CSV_PATH = "data/sample_txn.csv"  # ← Change this to your test file

df = pd.read_csv(CSV_PATH)

print(f"✅ Loaded {len(df)} transactions from {CSV_PATH}")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nFirst 3 rows:")
df.head(3)

## 4. Run Batch Prediction

Predict labels for all transactions in the CSV.

In [None]:
# Convert DataFrame to list of dictionaries
transactions = df.to_dict('records')

# Run batch prediction
print(f"Running predictions on {len(transactions)} transactions...\n")

results = pipeline.predict_batch(transactions, top_k=TOP_K)

print("✅ Batch prediction complete!")
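
The result fields used later in this notebook (`predicted_category`, `confidence`, `vote_distribution`) suggest a majority vote over the labels of the `TOP_K` retrieved golden records. The sketch below illustrates that idea only. It assumes confidence is the winning label's share of the votes, which matches how the notebook derives the vote count from confidence but is not confirmed by the pipeline source; `majority_vote` is a hypothetical function.

```python
from collections import Counter


def majority_vote(neighbor_labels):
    """Pick the most frequent label among the retrieved neighbors.

    Confidence is the winning label's share of all votes. On a tie,
    Counter returns the tied label that appeared first in the input.
    """
    votes = Counter(neighbor_labels)
    label, count = votes.most_common(1)[0]
    return {
        "predicted_category": label,
        "confidence": count / len(neighbor_labels),
        "vote_distribution": dict(votes),
    }


# Toy labels for five retrieved neighbors
example = majority_vote(["FOOD", "FOOD", "RENT", "FOOD", "SALARY"])
print(example)
```

Under this reading, a confidence of 60% with `TOP_K = 5` means three of the five neighbors agreed on the label.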

## 5. View Results Summary

Quick overview of all predictions.

In [None]:
# Create summary DataFrame
summary_data = []

for i, (txn, result) in enumerate(zip(transactions, results)):
    summary_data.append({
        'Index': i,
        'Description': txn['tran_partclr'][:50] + '...' if len(txn['tran_partclr']) > 50 else txn['tran_partclr'],
        'Amount': txn['tran_amt_in_ac'],
        'Predicted_Category': result['predicted_category'],
        'Confidence': f"{result['confidence']:.1%}",
        'Votes': f"{round(result['confidence'] * TOP_K)}/{TOP_K}"  # round() avoids float truncation
    })

summary_df = pd.DataFrame(summary_data)

print("\n" + "="*80)
print("PREDICTION SUMMARY")
print("="*80)
summary_df

## 6. Detailed View: Individual Predictions

Examine each prediction with retrieved similar transactions.

In [None]:
# Choose which transaction to inspect (change this number)
TRANSACTION_INDEX = 0  # ← Change this to view different transactions

txn = transactions[TRANSACTION_INDEX]
result = results[TRANSACTION_INDEX]

# Print detailed result
print_prediction_result(result, txn, TOP_K)

## 7. Interactive Viewer

Loop through the first `NUM_TO_DISPLAY` transactions and print the query, prediction, vote distribution, and retrieved golden records for each.

In [None]:
# View all transactions with details
NUM_TO_DISPLAY = 10  # ← Change to see more/fewer

for i in range(min(NUM_TO_DISPLAY, len(transactions))):
    txn = transactions[i]
    result = results[i]
    
    print(f"\n{'#'*80}")
    print(f"TRANSACTION {i+1}/{len(transactions)}")
    print(f"{'#'*80}")
    
    print(f"\nQuery:")
    print(f"  Description: {txn['tran_partclr']}")
    print(f"  Amount: ${txn['tran_amt_in_ac']:.2f}")
    print(f"  Mode: {txn['tran_mode']} | DR/CR: {txn['dr_cr_indctor']}")
    
    print(f"\nPrediction:")
    print(f"  Category: {result['predicted_category']}")
    print(f"  Confidence: {result['confidence']:.2%} ({round(result['confidence']*TOP_K)}/{TOP_K} votes)")
    
    print(f"\nVote Distribution:")
    for category, count in sorted(result['vote_distribution'].items(), key=lambda x: x[1], reverse=True):
        print(f"  {category}: {count} vote(s)")
    
    print(f"\nRetrieved Similar Transactions:")
    for j, similar in enumerate(result['similar_transactions'], 1):
        similar_txn = similar['transaction']
        print(f"\n  {j}. Golden Record Index: {similar['index']}")
        print(f"     Label: {similar['label']}")
        print(f"     Description: {similar_txn['description']}")
        print(f"     Amount: ${similar_txn['amount']:.2f}")
        print(f"     Similarity Distance: {similar['similarity_distance']:.4f}")
    
    print(f"\n{'-'*80}")

## 8. Identify Low Confidence Predictions

Find transactions that may need manual review.

In [None]:
# Set confidence threshold
CONFIDENCE_THRESHOLD = 0.6  # 60%

low_confidence = []

for i, (txn, result) in enumerate(zip(transactions, results)):
    if result['confidence'] < CONFIDENCE_THRESHOLD:
        low_confidence.append({
            'Index': i,
            'Description': txn['tran_partclr'],
            'Predicted': result['predicted_category'],
            'Confidence': result['confidence'],
            'Votes': result['vote_distribution']
        })

print(f"\n⚠️ Found {len(low_confidence)} low-confidence predictions (< {CONFIDENCE_THRESHOLD:.0%})\n")

if low_confidence:
    print("="*80)
    print("LOW CONFIDENCE PREDICTIONS - REVIEW RECOMMENDED")
    print("="*80)
    
    for item in low_confidence[:10]:  # Show first 10
        print(f"\nIndex {item['Index']}:")
        description = item['Description']
        print(f"  Description: {description[:60]}{'...' if len(description) > 60 else ''}")
        print(f"  Predicted: {item['Predicted']} (Confidence: {item['Confidence']:.1%})")
        print(f"  Vote breakdown: {item['Votes']}")
else:
    print("✅ All predictions have high confidence!")
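
For larger batches, a sorted table can be easier to scan than printed blocks. The following sketch assumes each result exposes `predicted_category` and `confidence` as used above; `low_confidence_table` is a hypothetical helper and the three results are toy data.

```python
import pandas as pd


def low_confidence_table(results, threshold):
    """Return predictions below `threshold`, least confident first."""
    rows = [
        {
            "Index": i,
            "Predicted": r["predicted_category"],
            "Confidence": r["confidence"],
        }
        for i, r in enumerate(results)
    ]
    table = pd.DataFrame(rows, columns=["Index", "Predicted", "Confidence"])
    below = table[table["Confidence"] < threshold]
    return below.sort_values("Confidence").reset_index(drop=True)


# Toy results (illustrative values only)
toy_results = [
    {"predicted_category": "FOOD", "confidence": 1.0},
    {"predicted_category": "RENT", "confidence": 0.4},
    {"predicted_category": "FUEL", "confidence": 0.2},
]

review_queue = low_confidence_table(toy_results, threshold=0.6)
print(review_queue)
```

Sorting by confidence puts the most doubtful predictions at the top of the review queue.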

## 9. Check for Potential Mislabeled Golden Records

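If your CSV has ground-truth labels in a `category` column, use them to flag golden records worth a second look: for each misclassified query, any retrieved golden record whose label differs from the expected label is listed as a candidate. This is a heuristic, so a flagged record may still be labeled correctly.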

In [None]:
# If your CSV has a 'category' column with ground truth
if 'category' in df.columns:
    suspicious_golden_records = []
    
    for i, (txn, result) in enumerate(zip(transactions, results)):
        expected_label = txn['category']
        predicted_label = result['predicted_category']
        
        # If prediction is wrong
        if expected_label != predicted_label:
            # Check if any retrieved golden records have wrong labels
            for similar in result['similar_transactions']:
                if similar['label'] != expected_label:
                    suspicious_golden_records.append({
                        'golden_index': similar['index'],
                        'golden_description': similar['transaction']['description'],
                        'golden_label': similar['label'],
                        'query_description': txn['tran_partclr'],
                        'expected_label': expected_label,
                        'distance': similar['similarity_distance']
                    })
    
    # Remove duplicates by golden_index so each golden record is counted once
    unique_suspicious = {item['golden_index']: item for item in suspicious_golden_records}

    print(f"\n🔍 Found {len(unique_suspicious)} potentially mislabeled golden records\n")
    
    if unique_suspicious:
        
        print("="*80)
        print("POTENTIALLY MISLABELED GOLDEN RECORDS")
        print("="*80)
        
        for golden_idx, item in list(unique_suspicious.items())[:10]:
            print(f"\nGolden Record Index: {item['golden_index']}")
            print(f"  Description: {item['golden_description']}")
            print(f"  Current Label: {item['golden_label']}")
            print(f"  Similar to: {item['query_description'][:50]}...")
            print(f"  Expected Label: {item['expected_label']}")
            print(f"  Distance: {item['distance']:.4f}")
            print("  ⚠️ Consider reviewing this label!")
else:
    print("ℹ️ No ground truth 'category' column found in CSV")
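
A golden record flagged by several different misclassified queries is a stronger candidate for review than one flagged once. The sketch below ranks flagged records by how often they appear; `rank_flagged_records` is a hypothetical helper, and it assumes a list shaped like `suspicious_golden_records` before deduplication.

```python
from collections import Counter


def rank_flagged_records(flagged):
    """Return (golden_index, times_flagged) pairs, most flagged first."""
    counts = Counter(item["golden_index"] for item in flagged)
    return counts.most_common()


# Toy flagged entries (only the key used here is shown)
toy_flagged = [
    {"golden_index": 17},
    {"golden_index": 4},
    {"golden_index": 17},
]

for golden_index, times_flagged in rank_flagged_records(toy_flagged):
    print(f"Golden record {golden_index}: flagged {times_flagged} time(s)")
```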

## 10. Export Results to CSV

Save predictions for further analysis.

In [None]:
# Prepare export data
export_data = []

for i, (txn, result) in enumerate(zip(transactions, results)):
    # Base record
    record = {
        'transaction_index': i,
        'description': txn['tran_partclr'],
        'amount': txn['tran_amt_in_ac'],
        'mode': txn['tran_mode'],
        'dr_cr': txn['dr_cr_indctor'],
        'predicted_category': result['predicted_category'],
        'confidence': result['confidence'],
        'vote_distribution': str(result['vote_distribution']),
    }
    
    # Add ground truth if available
    if 'category' in txn:
        record['ground_truth'] = txn['category']
        record['correct'] = txn['category'] == result['predicted_category']
    
    # Add top-3 similar transactions
    for j, similar in enumerate(result['similar_transactions'][:3], 1):
        record[f'similar_{j}_index'] = similar['index']
        record[f'similar_{j}_description'] = similar['transaction']['description']
        record[f'similar_{j}_label'] = similar['label']
        record[f'similar_{j}_distance'] = similar['similarity_distance']
    
    export_data.append(record)

# Create DataFrame and export
export_df = pd.DataFrame(export_data)
output_path = 'inference_results.csv'
export_df.to_csv(output_path, index=False)

print(f"✅ Results exported to: {output_path}")
print(f"   Total predictions: {len(export_df)}")
print(f"\nPreview:")
export_df.head()
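
The export above stores `vote_distribution` as a Python `str()` of a dict, which is awkward to parse when the CSV is read back. One alternative, sketched here as an assumption rather than the notebook's method, is to serialize the dict as JSON. This works when category labels are strings, since JSON object keys must be strings. The example round-trips through an in-memory buffer instead of a file.

```python
import io
import json

import pandas as pd

# Toy vote distribution (illustrative values only)
votes = {"FOOD": 3, "RENT": 2}

frame = pd.DataFrame([
    {"transaction_index": 0, "vote_distribution": json.dumps(votes)},
])

# Write to an in-memory CSV, then read it back
buffer = io.StringIO()
frame.to_csv(buffer, index=False)
buffer.seek(0)

restored = pd.read_csv(buffer)
restored["vote_distribution"] = restored["vote_distribution"].apply(json.loads)

print(restored.loc[0, "vote_distribution"]["FOOD"])
```

After `json.loads`, each cell is a regular dict again, so vote counts can be queried directly.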

## 11. Statistics & Analysis

In [None]:
print("="*80)
print("INFERENCE STATISTICS")
print("="*80)

# Confidence distribution
confidences = [r['confidence'] for r in results]
print(f"\nConfidence Scores:")
print(f"  Mean: {np.mean(confidences):.2%}")
print(f"  Median: {np.median(confidences):.2%}")
print(f"  Min: {np.min(confidences):.2%}")
print(f"  Max: {np.max(confidences):.2%}")

# Category distribution
predicted_categories = [r['predicted_category'] for r in results]
category_counts = pd.Series(predicted_categories).value_counts()
print(f"\nPredicted Category Distribution:")
for category, count in category_counts.items():
    print(f"  {category}: {count} ({count/len(results):.1%})")

# Accuracy (if ground truth available)
if 'category' in df.columns:
    correct = sum(1 for txn, result in zip(transactions, results) 
                  if txn['category'] == result['predicted_category'])
    accuracy = correct / len(results)
    print(f"\nAccuracy: {accuracy:.2%} ({correct}/{len(results)})")

print("\n" + "="*80)
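
Overall accuracy can hide categories that the model handles poorly. When ground truth is available, a per-category breakdown is a small extension. This is a sketch: `per_category_accuracy` is a hypothetical helper and the labels are toy data.

```python
import pandas as pd


def per_category_accuracy(true_labels, predicted_labels):
    """Return accuracy and support for each ground-truth category."""
    frame = pd.DataFrame({
        "truth": list(true_labels),
        "predicted": list(predicted_labels),
    })
    frame["correct"] = frame["truth"] == frame["predicted"]
    summary = frame.groupby("truth")["correct"].agg(["mean", "count"])
    return summary.rename(columns={"mean": "accuracy", "count": "support"})


# Toy labels (illustrative values only)
toy_truth = ["FOOD", "FOOD", "RENT", "RENT"]
toy_predicted = ["FOOD", "RENT", "RENT", "RENT"]

breakdown = per_category_accuracy(toy_truth, toy_predicted)
print(breakdown)
```

In the notebook, the inputs would be `[t['category'] for t in transactions]` and `[r['predicted_category'] for r in results]`.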

## 12. Visualizations (Optional)

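Uncomment the cell below to plot the confidence and predicted-category distributions. It requires `matplotlib` and reuses `results` and `category_counts` from the earlier cells.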

In [None]:
# import matplotlib.pyplot as plt
# import seaborn as sns

# # Confidence distribution
# plt.figure(figsize=(10, 5))
# plt.hist([r['confidence'] for r in results], bins=20, edgecolor='black')
# plt.xlabel('Confidence')
# plt.ylabel('Count')
# plt.title('Confidence Distribution')
# plt.axvline(x=CONFIDENCE_THRESHOLD, color='r', linestyle='--', label=f'Threshold ({CONFIDENCE_THRESHOLD:.0%})')
# plt.legend()
# plt.show()

# # Category distribution
# plt.figure(figsize=(12, 6))
# category_counts.plot(kind='bar')
# plt.xlabel('Category')
# plt.ylabel('Count')
# plt.title('Predicted Category Distribution')
# plt.xticks(rotation=45, ha='right')
# plt.tight_layout()
# plt.show()

## Summary

✅ **What you learned:**
1. How to load the inference pipeline
2. How to run batch predictions on a CSV file
3. How to view predicted labels and similar transactions
4. How to identify low-confidence predictions
5. How to spot potentially mislabeled golden records
6. How to export results for further analysis

**Next Steps:**
- Review low-confidence predictions manually
- Investigate potentially mislabeled golden records
- Add corrected labels to your dataset
- Rebuild golden record index with `run_inference.py`