# Model Inference and Visualization Notebook

This notebook loads saved models and runs inference on evaluation data, then visualizes the results.

In [1]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import json
from TextUMC.model_dbscan import TextUMC, ModelConfig, Evidence, Claim, cluster_and_evaluate, visualize_clusters
import matplotlib.pyplot as plt
import os
from datetime import datetime

## Load Model

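Load a saved model from the outputs directory. By default, the checkpoint (`textumc_dbscan_model.pt`) from the most recent timestamped run directory under `../outputs` is used.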

In [2]:
# Initialize model with default config
model = TextUMC()

# Location of the per-run output directories and the checkpoint file expected in each
outputs_root = '../outputs'
checkpoint_name = 'textumc_dbscan_model.pt'

# List available model checkpoints.
# The run directories are named YYYYMMDD_HHMMSS, so sorting them as strings
# also orders them chronologically (newest last).
output_dirs = sorted(
    d for d in os.listdir(outputs_root)
    if os.path.isdir(os.path.join(outputs_root, d))
)
print("Available output directories:")
for d in output_dirs:
    print(f"- {d}")

# Load the latest model by default
latest_dir = output_dirs[-1] if output_dirs else None
if latest_dir:
    model_path = os.path.join(outputs_root, latest_dir, checkpoint_name)
    if os.path.exists(model_path):
        model.load_model(model_path)
        print(f"\nLoaded model from {model_path}")
    else:
        # Report the file name that was actually searched for
        print(f"\nNo {checkpoint_name} found in {latest_dir}")
else:
    # Make it explicit that no checkpoint was loaded into `model`
    print(f"\nNo output directories found in {outputs_root}; no checkpoint was loaded")

Available output directories:
- 20250114_232554
- 20250115_001858
- 20250115_160930
- 20250115_161613
- 20250115_162201
- 20250115_162315
- 20250115_162539
- 20250115_163156
- 20250115_163326
- 20250115_165133
- 20250115_170200
- 20250115_171836
- 20250115_172108
- 20250115_172948
- 20250115_173045
- 20250115_175745
- 20250115_175805
- 20250115_180112
- 20250115_180146
- 20250115_180230
- 20250115_204304
- 20250115_204409
- 20250115_204728
- 20250115_204921
- 20250115_205108
- 20250115_205818
- 20250115_205908
- 20250115_210556
- 20250115_210850
- 20250115_215300
- 20250116_004304
- 20250116_004831
- 20250116_004930
- 20250116_005208
- 20250116_051048
- 20250116_102540
- 20250116_102608
- 20250116_154842
- 20250116_155523
- 20250116_155600
- 20250116_155614
- 20250116_155628
- 20250116_155715
- 20250116_155725
- 20250120_160124
- 20250120_160132
- 20250120_161320
- 20250120_161432
- 20250120_161550
- 20250120_161659
- 20250120_161901
- 20250120_162006
- 20250120_162740
- 20250120_16322

## Load Evaluation Data

Load the evaluation data to run inference on.

In [3]:
def load_evaluation_data(eval_path='../dataset/test.json'):
    """Load the evaluation set from a JSON file and wrap it in Claim objects.

    The file must contain a JSON array of objects. Each object needs a
    'claim' entry and a 'reports' list whose entries carry 'report_id' and
    'content'. The keys 'event_id', 'label' and 'explain' are optional and
    fall back to an empty string.

    Args:
        eval_path: Path to the JSON file to read.

    Returns:
        A list of Claim objects, one per item, in file order. Each claim
        holds one Evidence object per entry of the item's 'reports' list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an item lacks 'claim' or 'reports', or a report lacks
            'report_id' or 'content'.
    """
    # JSON text is UTF-8; do not rely on the platform's default encoding
    with open(eval_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    claims = []
    for item in data:
        # Create Evidence objects
        evidences = [
            Evidence(report['report_id'], report['content'])
            for report in item['reports']
        ]

        # Create Claim object; event_id is stringified so claim_id is always a str
        claims.append(
            Claim(
                claim_id=str(item.get('event_id', '')),
                content=item['claim'],
                label=item.get('label', ''),
                explanation=item.get('explain', ''),
                evidences=evidences,
            )
        )

    return claims

# Load evaluation data from the function's default path ('../dataset/test.json')
# and report how many claims were read
eval_claims = load_evaluation_data()
print(f"Loaded {len(eval_claims)} claims for evaluation")

Loaded 1451 claims for evaluation


In [4]:
# Quick check: embed the evidence texts attached to the first evaluation claim
evidence_texts = [evidence.content for evidence in eval_claims[0].evidences]

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    embeddings = model(evidence_texts)

print(f"Embeddings shape: {embeddings.shape}")
# Bare last expression: rendered as the cell's output
embeddings

Embeddings shape: torch.Size([26, 128])


tensor([[-0.1395, -0.0268,  0.0509,  ..., -0.1060, -0.0308, -0.0174],
        [-0.0055,  0.0271, -0.1331,  ..., -0.0378, -0.1289,  0.0627],
        [ 0.1287,  0.0716, -0.0469,  ...,  0.0583,  0.0131,  0.0339],
        ...,
        [ 0.0084,  0.0929,  0.1319,  ...,  0.0017,  0.0391,  0.0515],
        [-0.1156,  0.1822, -0.0038,  ..., -0.0488,  0.0004, -0.0187],
        [ 0.0466,  0.1041, -0.0269,  ...,  0.0179, -0.1794,  0.1510]],
       device='cuda:0')

## Run Inference

Run the model on the evaluation data: embed each claim's evidences, cluster the embeddings, and save the per-claim clustering results.

In [14]:
def run_inference(model, claims, save_dir='../eval_outputs'):
    """Embed, cluster and plot the evidences of each claim, then save a summary CSV.

    Args:
        model: Callable that takes a list of evidence texts and returns
            their embeddings.
        claims: A single claim or a list of claims. Each claim must expose
            `claim_id` and `evidences`, and each evidence a `content`
            attribute.
        save_dir: Parent directory for results. A sub-directory named with
            the current timestamp (YYYYMMDD_HHMMSS) is created inside it.

    Returns:
        A tuple `(results_df, run_dir)`: a DataFrame with one row per claim
        (claim_id, n_clusters, eps, min_samples, pca_components,
        explained_variance_ratio) and the directory the run was written to.
        The same table is saved there as 'inference_results.csv'.
    """
    # Create output directory with timestamp
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = os.path.join(save_dir, timestamp)
    os.makedirs(run_dir, exist_ok=True)

    # Accept a single claim as well as a list of claims.
    # `claims is not list` would compare against the `list` type object and
    # always be True, wrapping real lists a second time.
    if not isinstance(claims, list):
        claims = [claims]

    results = []
    for claim in claims:
        # Get evidence texts
        evidence_texts = [ev.content for ev in claim.evidences]

        # Get embeddings (inference only, no gradient tracking)
        with torch.no_grad():
            embeddings = model(evidence_texts)

        # Perform clustering
        labels, metrics = cluster_and_evaluate(embeddings)

        # Visualize clusters; the run directory is handed over as save_path
        visualize_clusters(
            embeddings=embeddings,
            labels=labels,
            title=f'Claim_{claim.claim_id}',
            save_path=run_dir
        )

        # Store results
        results.append({
            'claim_id': claim.claim_id,
            'n_clusters': metrics['n_clusters'],
            'eps': metrics['eps'],
            'min_samples': metrics['min_samples'],
            'pca_components': metrics['pca_components'],
            'explained_variance_ratio': metrics['explained_variance_ratio']
        })

    # Save results
    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(run_dir, 'inference_results.csv'), index=False)

    return results_df, run_dir

# Run inference on the first evaluation claim only
results_df, run_dir = run_inference(model, eval_claims[0])
print(f"\nResults saved to {run_dir}")
# Bare last expression: the first rows are rendered as the cell's output
results_df.head()


Results saved to ../eval_outputs/20250131_174842


Unnamed: 0,claim_id,n_clusters,eps,min_samples,pca_components,explained_variance_ratio
0,11972.json,1,0.878845,3,7,0.518979


## Analyze Results

Analyze and visualize the inference results.

In [None]:
def plot_cluster_distribution(results_df):
    """Show a 20-bin histogram of the 'n_clusters' column of results_df."""
    _, ax = plt.subplots(figsize=(10, 6))
    cluster_counts = results_df['n_clusters']
    cluster_counts.hist(bins=20, ax=ax)
    ax.set_title('Distribution of Number of Clusters')
    ax.set_xlabel('Number of Clusters')
    ax.set_ylabel('Frequency')
    plt.show()

def plot_metrics_correlation(results_df):
    """Show a scatter plot of 'eps' (x) against 'n_clusters' (y) from results_df."""
    _, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(results_df['eps'], results_df['n_clusters'])
    ax.set_title('Epsilon vs Number of Clusters')
    ax.set_xlabel('Epsilon')
    ax.set_ylabel('Number of Clusters')
    plt.show()

# Plot analysis: histogram of cluster counts, then eps vs. cluster count
plot_cluster_distribution(results_df)
plot_metrics_correlation(results_df)

# Print summary statistics (pandas describe() output for results_df)
print("\nSummary Statistics:")
print(results_df.describe())