# BK Classification Inference Notebook

This notebook loads the best-performing 2-stage BART model and provides an interface for classifying bibliographic records with BK codes.

**Model**: Two-Stage BART (25.7% subset accuracy, 0.498 MCC)  
**Checkpoint**: `bart_classifier_bart-large_bs64_e15_sALL_2stage_bart_20250811_000132`

In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import os
from transformers import AutoTokenizer, BartModel
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [3]:
class BartWithClassifier(nn.Module):
    """BART classifier for multi-label BK classification"""
    
    def __init__(self, num_labels=1884, model_name="facebook/bart-large", dropout=0.1):
        super(BartWithClassifier, self).__init__()
        
        self.num_labels = num_labels
        self.bart = BartModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bart.config.hidden_size, num_labels)
        
    def forward(self, input_ids, attention_mask=None):
        outputs = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        cls_output = last_hidden_state[:, 0, :]  # Take [CLS] token representation
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

print("Model architecture defined.")

Model architecture defined.


In [4]:
# Model paths
MODEL_DIR = "results/bart_classifier_bart-large_bs64_e15_sALL_2stage_bart_20250811_000132"
MODEL_PATH = f"{MODEL_DIR}/checkpoints_stage2/best_model_15.pt"
LABEL_MAP_PATH = "data/label_map.json"

# Load label mapping
print("Loading label mapping...")
with open(LABEL_MAP_PATH, 'r') as f:
    label_map = json.load(f)

# Create reverse mapping (index -> label)
idx_to_label = {v: k for k, v in label_map.items()}
num_labels = len(label_map)

print(f"Loaded {num_labels} BK labels")
print(f"Sample labels: {list(label_map.keys())[:10]}")

Loading label mapping...
Loaded 1884 BK labels
Sample labels: ['01.00', '01.20', '01.22', '01.29', '01.30', '01.40', '02.00', '02.01', '02.02', '02.10']


In [None]:
# Initialize model
print("Initializing model...")
model = BartWithClassifier(num_labels=num_labels, model_name="facebook/bart-large")

# Load trained weights
print("Loading trained weights...")
checkpoint = torch.load(MODEL_PATH, map_location=device)

# Print checkpoint keys to debug
print("Checkpoint keys:", list(checkpoint.keys()))

# Handle different checkpoint formats
if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
elif 'model_state' in checkpoint:
    model.load_state_dict(checkpoint['model_state'])
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
elif 'state_dict' in checkpoint:
    model.load_state_dict(checkpoint['state_dict'])
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
else:
    # If it's just the raw state dict
    try:
        model.load_state_dict(checkpoint)
        print("Loaded raw state dict")
    except:
        print("ERROR: Could not determine checkpoint format")
        print("Available keys:", list(checkpoint.keys()))
        raise

model.to(device)
model.eval()
print("Model loaded and ready for inference!")

Initializing model...
Loading trained weights...
Checkpoint keys: ['epoch', 'model_state', 'optimizer_state', 'best_metric', 'monitor_metric']
Loaded model from epoch 15
Best metric: 0.21449027735603707
Monitor metric: f1_macro
Model loaded and ready for inference!


In [6]:
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
print("Tokenizer loaded.")

Loading tokenizer...
Tokenizer loaded.


In [10]:
def preprocess_text(title="", summary="", keywords="", loc_keywords="", rvk="", author=""):
    """
    Preprocess bibliographic fields into the format expected by the model.
    
    Args:
        title: Book title
        summary: Book summary/abstract
        keywords: Subject keywords
        loc_keywords: Library of Congress keywords
        rvk: RVK classification codes
        author: Author information (optional)
    
    Returns:
        Formatted text string for model input
    """
    # Combine fields in the same format as training
    input_text = f"""Title: {title or ''}
Summary: {summary or ''}
Keywords: {keywords or ''}
LOC_Keywords: {loc_keywords or ''}
RVK: {rvk or ''}"""

    
    return input_text.strip()

# Test preprocessing
test_text = preprocess_text(
    title="Machine Learning in Practice",
    summary="A comprehensive guide to machine learning applications",
    keywords="machine learning, artificial intelligence"
)
print("Sample preprocessed text:")
print(test_text)

Sample preprocessed text:
Title: Machine Learning in Practice
Summary: A comprehensive guide to machine learning applications
Keywords: machine learning, artificial intelligence
LOC_Keywords: 
RVK:


In [11]:
def predict_bk_codes(text: str, 
                     threshold: float = 0.5, 
                     top_k: int = 10,
                     max_length: int = 768) -> Dict:
    """
    Predict BK classification codes for input text.
    
    Args:
        text: Preprocessed input text
        threshold: Probability threshold for positive predictions
        top_k: Return top-k predictions regardless of threshold
        max_length: Maximum input sequence length
    
    Returns:
        Dictionary containing predictions and metadata
    """
    # Tokenize input
    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )
    
    # Move to device
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    
    # Make prediction
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.sigmoid(logits).cpu().numpy()[0]  # Get probabilities
    
    # Get predictions above threshold
    threshold_predictions = []
    for idx, prob in enumerate(probabilities):
        if prob >= threshold:
            threshold_predictions.append({
                'label': idx_to_label[idx],
                'probability': float(prob),
                'confidence': 'High' if prob > 0.8 else 'Medium' if prob > 0.6 else 'Low'
            })
    
    # Sort by probability
    threshold_predictions.sort(key=lambda x: x['probability'], reverse=True)
    
    # Get top-k predictions (regardless of threshold)
    top_indices = np.argsort(probabilities)[-top_k:][::-1]
    top_k_predictions = []
    for idx in top_indices:
        top_k_predictions.append({
            'label': idx_to_label[idx],
            'probability': float(probabilities[idx]),
            'confidence': 'High' if probabilities[idx] > 0.8 else 'Medium' if probabilities[idx] > 0.6 else 'Low'
        })
    
    return {
        'threshold_predictions': threshold_predictions,
        'top_k_predictions': top_k_predictions,
        'num_above_threshold': len(threshold_predictions),
        'max_probability': float(np.max(probabilities)),
        'threshold_used': threshold,
        'input_length': len(input_ids[0])
    }

print("Prediction functions ready.")

Prediction functions ready.


In [14]:
def display_predictions(predictions: Dict, show_top_k: int = 5):
    """
    Display prediction results in a formatted way.
    """
    print(f"\n{'='*60}")
    print("BK CLASSIFICATION RESULTS")
    print(f"{'='*60}")
    
    print(f"Input length: {predictions['input_length']} tokens")
    print(f"Max probability: {predictions['max_probability']:.4f}")
    print(f"Threshold used: {predictions['threshold_used']}")
    print(f"Predictions above threshold: {predictions['num_above_threshold']}")
    
    # Show threshold-based predictions
    if predictions['threshold_predictions']:
        print(f"\n📊 PREDICTIONS ABOVE THRESHOLD ({predictions['threshold_used']})")
        print("-" * 50)
        for i, pred in enumerate(predictions['threshold_predictions'][:show_top_k], 1):
            print(f"{i:2d}. {pred['label']:15s} | {pred['probability']:.4f} | {pred['confidence']}")
    else:
        print(f"\n⚠️  No predictions above threshold {predictions['threshold_used']}")
    
    # Show top-k predictions
    print(f"\n🔝 TOP {show_top_k} PREDICTIONS (regardless of threshold)")
    print("-" * 50)
    for i, pred in enumerate(predictions['top_k_predictions'][:show_top_k], 1):
        marker = "✓" if pred['probability'] >= predictions['threshold_used'] else " "
        print(f"{marker} {i:2d}. {pred['label']:15s} | {pred['probability']:.4f} | {pred['confidence']}")
    
    print(f"\n{'='*60}")

def get_bk_category_info(bk_code: str) -> str:
    """
    Get general category information for a BK code.
    """
    # Extract main category (before dot)
    main_cat = bk_code.split('.')[0]
    
    # General BK category mapping (simplified)
    bk_categories = {
        '00': 'Computer Science, Knowledge & General',
        '02': 'Librarianship, Information Science',
        '05': 'Communication, Mass Media',
        '10': 'Philosophy',
        '11': 'Theology',
        '15': 'Psychology',
        '17': 'Ethics',
        '18': 'Ancient Philosophy',
        '20': 'Education',
        '24': 'Education',
        '30': 'Sociology',
        '31': 'Politics',
        '33': 'Economics',
        '34': 'Law',
        '35': 'Public Administration',
        '38': 'Ethnology, Cultural Anthropology',
        '39': 'Folklore',
        '43': 'German Language & Literature',
        '50': 'Mathematics',
        '53': 'Physics',
        '54': 'Chemistry',
        '57': 'Biology',
        '58': 'Botany',
        '59': 'Zoology',
        '61': 'Medicine',
        '69': 'Architecture, Construction',
        '70': 'Agriculture',
        '76': 'Technology',
        '83': 'Economics',
        '85': 'Education',
        '86': 'Law',
        '89': 'Political Science'
    }
    
    return bk_categories.get(main_cat, f"Category {main_cat}")

print("Display functions ready.")

Display functions ready.


In [17]:
# Example 1: Computer Science
text1 = preprocess_text(
    title="Deep Learning with PyTorch",
    summary="This book provides a comprehensive introduction to deep learning using PyTorch framework. It covers neural networks, convolutional networks, and natural language processing applications.",
    keywords="deep learning, neural networks, PyTorch, machine learning, artificial intelligence",
    loc_keywords="Computer algorithms, Machine learning",
    author="Smith, John"
)

print("INPUT TEXT:")
print(text1)

# Make predictions with different thresholds
predictions1 = predict_bk_codes(text1, threshold=0.3, top_k=10)
display_predictions(predictions1)

# Show category information for top predictions
print("\n📚 CATEGORY INFORMATION:")
for pred in predictions1['top_k_predictions'][:3]:
    category = get_bk_category_info(pred['label'])
    print(f"  {pred['label']} → {category}")

INPUT TEXT:
Title: Deep Learning with PyTorch
Summary: This book provides a comprehensive introduction to deep learning using PyTorch framework. It covers neural networks, convolutional networks, and natural language processing applications.
Keywords: deep learning, neural networks, PyTorch, machine learning, artificial intelligence
LOC_Keywords: Computer algorithms, Machine learning
RVK:

BK CLASSIFICATION RESULTS
Input length: 78 tokens
Max probability: 0.9676
Threshold used: 0.3
Predictions above threshold: 1

📊 PREDICTIONS ABOVE THRESHOLD (0.3)
--------------------------------------------------
 1. 54.72           | 0.9676 | High

🔝 TOP 5 PREDICTIONS (regardless of threshold)
--------------------------------------------------
✓  1. 54.72           | 0.9676 | High
   2. 54.53           | 0.2489 | Low
   3. 31.73           | 0.0846 | Low
   4. 54.62           | 0.0403 | Low
   5. 54.80           | 0.0137 | Low


📚 CATEGORY INFORMATION:
  54.72 → Chemistry
  54.53 → Chemistry
  31.73 

In [19]:
# Example 2: Medicine
text2 = preprocess_text(
    title="Clinical Cardiology: A Modern Approach",
    summary="A comprehensive textbook covering contemporary approaches to cardiovascular disease diagnosis and treatment. Includes latest research on heart failure, arrhythmias, and interventional cardiology.",
    keywords="cardiology, heart disease, clinical medicine, cardiovascular",
    loc_keywords="Cardiology, Heart diseases"
)

print("INPUT TEXT:")
print(text2)

predictions2 = predict_bk_codes(text2, threshold=0.25, top_k=10)
display_predictions(predictions2)

# Show category information
print("\n📚 CATEGORY INFORMATION:")
for pred in predictions2['top_k_predictions'][:3]:
    category = get_bk_category_info(pred['label'])
    print(f"  {pred['label']} → {category}")

INPUT TEXT:
Title: Clinical Cardiology: A Modern Approach
Summary: A comprehensive textbook covering contemporary approaches to cardiovascular disease diagnosis and treatment. Includes latest research on heart failure, arrhythmias, and interventional cardiology.
Keywords: cardiology, heart disease, clinical medicine, cardiovascular
LOC_Keywords: Cardiology, Heart diseases
RVK:

BK CLASSIFICATION RESULTS
Input length: 74 tokens
Max probability: 0.9862
Threshold used: 0.25
Predictions above threshold: 1

📊 PREDICTIONS ABOVE THRESHOLD (0.25)
--------------------------------------------------
 1. 44.85           | 0.9862 | High

🔝 TOP 5 PREDICTIONS (regardless of threshold)
--------------------------------------------------
✓  1. 44.85           | 0.9862 | High
   2. 44.84           | 0.0369 | Low
   3. 44.38           | 0.0336 | Low
   4. 44.87           | 0.0180 | Low
   5. 44.37           | 0.0176 | Low


📚 CATEGORY INFORMATION:
  44.85 → Category 44
  44.84 → Category 44
  44.38 → Cate

In [20]:
# Interactive prediction - modify these fields
YOUR_TITLE = "Quantum Computing Fundamentals"
YOUR_SUMMARY = "An introduction to quantum computing principles, quantum algorithms, and quantum information theory."
YOUR_KEYWORDS = "quantum computing, quantum algorithms, quantum information"
YOUR_LOC_KEYWORDS = "Quantum theory, Computer science"
YOUR_RVK = ""
YOUR_AUTHOR = ""

# Preprocessing
your_text = preprocess_text(
    title=YOUR_TITLE,
    summary=YOUR_SUMMARY,
    keywords=YOUR_KEYWORDS,
    loc_keywords=YOUR_LOC_KEYWORDS,
    rvk=YOUR_RVK,
    author=YOUR_AUTHOR
)

print("YOUR INPUT TEXT:")
print(your_text)

# Make prediction
your_predictions = predict_bk_codes(your_text, threshold=0.25, top_k=15)
display_predictions(your_predictions, show_top_k=10)

# Category information
print("\n📚 CATEGORY INFORMATION:")
for pred in your_predictions['top_k_predictions'][:5]:
    category = get_bk_category_info(pred['label'])
    print(f"  {pred['label']} → {category}")

YOUR INPUT TEXT:
Title: Quantum Computing Fundamentals
Summary: An introduction to quantum computing principles, quantum algorithms, and quantum information theory.
Keywords: quantum computing, quantum algorithms, quantum information
LOC_Keywords: Quantum theory, Computer science
RVK:

BK CLASSIFICATION RESULTS
Input length: 54 tokens
Max probability: 0.9412
Threshold used: 0.25
Predictions above threshold: 2

📊 PREDICTIONS ABOVE THRESHOLD (0.25)
--------------------------------------------------
 1. 54.10           | 0.9412 | High
 2. 33.23           | 0.5561 | Low

🔝 TOP 10 PREDICTIONS (regardless of threshold)
--------------------------------------------------
✓  1. 54.10           | 0.9412 | High
✓  2. 33.23           | 0.5561 | Low
   3. 53.71           | 0.0558 | Low
   4. 54.51           | 0.0379 | Low
   5. 54.01           | 0.0277 | Low
   6. 33.06           | 0.0242 | Low
   7. 54.70           | 0.0180 | Low
   8. 54.38           | 0.0172 | Low
   9. 54.25           | 0.0120 