# PR Curve Analysis for CNN Model

This notebook:
1. Loads the CNN model from August 2025
2. Computes the Precision-Recall curve
3. Exports the data for TikZ plotting


In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score
import pickle
import os
from pathlib import Path
import sys
import pandas as pd
sys.path.append(os.path.abspath(".."))
from utils.project_classes import CNNClassifier
import utils.project_functions as pf

# Reproducibility: seed both NumPy's and PyTorch's global generators
np.random.seed(42)
torch.manual_seed(42)

# Where the trained weights are read from and where the PR data is written to
MODEL_PATH = "../Models/cnn_2025-08-12.pth"
OUTPUT_PATH = "../Predictions/pr_curve_data.dat"

# Which simulation runs make up the test set; these values are substituted
# into the data directory / file names when the VTK files are read
U = [40, 100]          # U values to test
L = [96]               # system size
exp = ["A"]            # experiment type
frames = range(1, 11)  # frames 1..10 (the upper bound of range() is exclusive)

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:
# Initialize and load model using the CNNClassifier from project_classes.py
model = CNNClassifier()
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved from a GPU still loads on a CPU-only machine. weights_only=True limits
# unpickling to tensors and plain containers, which is all load_state_dict()
# needs.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # Set to evaluation mode

print("Model loaded successfully!")
print(f"Model architecture:\n{model}")

# Create test dataset from VTK files
all_tensors = []
all_labels = []

print("\nProcessing VTK files to create test dataset...")
for u in U:
    for l in L:
        for e in exp:
            # Defects of the previous frame are passed along with the current
            # frame's defects; each (u, l, e) run starts with an empty (0, 2) array.
            previous_defects = np.array([]).reshape(0, 2)
            for frame in frames:
                # Built outside the try block so the error message below always
                # names the file that actually failed.
                file_path = f"../Data/MAI_U{str(u).zfill(3)}_L_{str(l).zfill(3)}_{e}/mpcd_{frame}.vtk"
                try:
                    _eigen_vals, _eigen_vecs = pf.load_and_pad_vtk(file_path, pad_width=4)
                    # NOTE(review): the meaning of the positional 0.3 is defined
                    # by pf.find_defects and is not visible from this notebook.
                    frame_defects = pf.find_defects(_eigen_vals, _eigen_vecs, 0.3)
                    _defects = np.concatenate([previous_defects, frame_defects])

                    # Get labeled data
                    samples = pf.predict_field(model, _eigen_vals, _eigen_vecs, _defects, device=device, filename="w")

                    # Only samples that carry a ground-truth label can be scored
                    for sample in samples:
                        if sample.label is not None:
                            all_tensors.append(sample.tensor)
                            all_labels.append(sample.label)

                    previous_defects = frame_defects

                # Best-effort: a frame that cannot be processed is reported and
                # skipped. The exception is bound to `err`, not `e`: Python
                # unbinds the `as` name when the handler exits, which would
                # delete the experiment loop variable `e` and make every later
                # frame fail with a NameError.
                except Exception as err:
                    print(f"Error processing {file_path}: {err}")
                    continue

# torch.stack() cannot stack an empty list; fail with an actionable message
# instead of an opaque one when no frame produced labelled samples.
if not all_tensors:
    raise RuntimeError(
        "No labelled samples were collected; check the data paths and the "
        "errors printed above."
    )

# Convert lists to tensors (as_tensor also accepts samples that already are tensors)
X_test = torch.stack([torch.as_tensor(t, dtype=torch.float32) for t in all_tensors]).to(device)
y_test = torch.tensor(all_labels, dtype=torch.long).to(device)

print(f"\nTest dataset created with {len(all_tensors)} samples")


  model.load_state_dict(torch.load(MODEL_PATH))


Model loaded successfully!
Model architecture:
CNNClassifier(
  (conv1): Conv2d(5, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=1568, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=3, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)


In [None]:
# Run the model without gradient tracking and turn its outputs into
# per-class scores with a softmax over dim 1
with torch.no_grad():
    logits = model(X_test)
    y_pred = torch.softmax(logits, dim=1)

# sklearn operates on NumPy arrays, so bring both tensors back to the host
y_pred_np = y_pred.cpu().numpy()
y_test_np = y_test.cpu().numpy()

# One-vs-rest PR curve and average precision for every column of the scores
n_classes = y_pred_np.shape[1]
precision, recall, average_precision = {}, {}, {}

print("\nComputing PR curves for each class...")
for cls in range(n_classes):
    is_cls = (y_test_np == cls).astype(int)  # 1 where the true label is `cls`
    cls_scores = y_pred_np[:, cls]
    precision[cls], recall[cls], _ = precision_recall_curve(is_cls, cls_scores)
    average_precision[cls] = average_precision_score(is_cls, cls_scores)

print("\nAverage Precision Scores:")
for cls, ap in average_precision.items():
    print(f"Class {cls}: {ap:.3f}")

# Report how many test samples carry each label
print("\nClass distribution in test set:")
n_samples = len(y_test_np)
for label, n_label in zip(*np.unique(y_test_np, return_counts=True)):
    print(f"Class {label}: {n_label} samples ({n_label/n_samples*100:.1f}%)")


KeyError: 'test_data'

In [None]:
# Export PR curve data for TikZ
def export_pr_data(precision_dict, recall_dict, output_path):
    """Write per-class precision/recall points to a plain-text data file.

    The file starts with two ``%`` comment lines and a blank line. Each data
    line is ``<class> <recall> <precision>`` with both values formatted to six
    decimals, and a blank line follows the points of every class.

    Args:
        precision_dict: Mapping from class key to a sequence of precision values.
        recall_dict: Mapping with the same keys to the matching recall values.
        output_path: Destination file path. Its parent directory is created if
            it does not exist; a bare file name writes to the current directory.

    Raises:
        KeyError: If a key of ``precision_dict`` is missing from ``recall_dict``.
    """
    # os.path.dirname() returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w') as f:
        # Write header
        f.write("% Precision-Recall curve data for TikZ plotting\n")
        f.write("% Format: Class Recall Precision\n\n")

        # Write data for each class
        for class_idx in precision_dict:
            # Values are paired by position; recall is written before precision
            for p, r in zip(precision_dict[class_idx], recall_dict[class_idx]):
                f.write(f"{class_idx} {r:.6f} {p:.6f}\n")
            # Add a blank line between classes for easier reading
            f.write("\n")

# Write the PR curves to disk
export_pr_data(precision, recall, OUTPUT_PATH)
print(f"PR curve data exported to: {OUTPUT_PATH}")

# Show the first ten lines of the file as a quick sanity check
with open(OUTPUT_PATH, 'r') as exported:
    preview_lines = [line for _, line in zip(range(10), exported)]
print("\nPreview of exported data:")
print("".join(preview_lines))


# Using the Data with TikZ

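The exported data file stores one PR curve point per line. Lines starting with `%` are header comments, a blank line separates consecutive classes, and each data line has three space-separated columns in the following order: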
```
Class Recall Precision
```

Example TikZ code to plot the curves (the `axis` environment requires the `pgfplots` package):
```latex
\begin{tikzpicture}
    \begin{axis}[
        xlabel=Recall,
        ylabel=Precision,
        grid=major,
        xmin=0, xmax=1,
        ymin=0, ymax=1,
    ]
    
    % Plot data for each class
    \addplot[blue] table[x index=1, y index=2] {pr_curve_data.dat};
    \addlegendentry{Class 0}
    
    % Add more \addplot commands for other classes
    
    \end{axis}
\end{tikzpicture}
```

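Note: As written, the `\addplot` above reads every row of the file, so the curve labelled "Class 0" also contains the points of the other classes. Filter the rows for each class separately when plotting — for example with pgfplots' `restrict expr to domain={\thisrowno{0}}{0:0}` for class 0 — or use a tool like `awk` to split the data file by class.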
