In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from sklearn.metrics import f1_score, classification_report, multilabel_confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

#custom imports
from mpp.ml.models.sequence.vecset_transformer import ARMSTD
from mpp.constants import PATHS, VOCAB, INV_VOCAB
from mpp.ml.models.classifier.cadtostepset import ProcessClassificationTrsfmEncoderModule
from mpp.ml.datasets.fabricad import Fabricad
from mpp.ml.datasets.tkms import TKMS_Process_Dataset
from mpp.ml.datasets.datamodules import collate_fn


In [None]:

# Reproducibility: seed every RNG used in this notebook (Python, NumPy, PyTorch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to the CPU when no GPU is present; later cells move their inputs to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): absolute container path — consider making the checkpoint location configurable.
checkpoint_path = Path("/workspace/mpp/src/cadtoseq/ml/models/checkpoints/best_model/cadtostepset/cadtostepset-best-epoch=66-val_loss=0.4306.ckpt")

model = ProcessClassificationTrsfmEncoderModule.load_from_checkpoint(checkpoint_path=checkpoint_path.as_posix())
# Use the detected device instead of a hard-coded "cuda" so the notebook also runs on CPU-only machines.
model = model.to(device)
model.eval();


In [None]:
# Test split, iterated in a fixed order so that row i of every collected
# prediction array corresponds to dataset sample i (needed later for part names).
dataset = TKMS_Process_Dataset(mode="test")
test_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=58,
    shuffle=False,  # SHUFFLE False for eva with filenames
)

# Peek at the first batch to sanity-check tensor shapes.
v, p = next(iter(test_loader))
for caption, value in (
    ("Vector shape", v.shape),
    ("Label shape", p.shape),
    ("Label example", p[0]),
):
    print(f"{caption}: {value}")
print(f"Label example: {p[0]}")


In [None]:
# Smoke test: push one test batch through the model and inspect the raw outputs
# (presumably logits, since the evaluation below applies a sigmoid to them).
v, p = next(iter(test_loader))
v, p = v.to(device), p.to(device)

with torch.no_grad():
    outputs = model(v)

    shape_line = f"Model output shape: {outputs.shape}"
    example_line = f"Output example: {outputs[0]}"
    print(shape_line)
    print(example_line)

In [None]:
# Class order must match the model's output columns and the label columns.
class_names = ['Bohren', 'Drehen', 'Fräsen']  # drilling, turning, milling
threshold = 0.5  # decision threshold applied to each per-class sigmoid probability

# Evaluation loop: collect thresholded predictions, probabilities and labels as NumPy arrays.
all_predictions = []
all_labels = []
all_probabilities = []

model.eval()
with torch.no_grad():
    for vec_sets, labels in test_loader:
        # Only the model input needs to live on `device`; the labels are only
        # converted to NumPy, so copying them to the GPU and back is wasted work.
        outputs = model(vec_sets.to(device))

        # Independent sigmoid per class -> multi-label probabilities.
        probs = torch.sigmoid(outputs)

        # Convert to binary predictions
        predictions = (probs > threshold).float()

        all_predictions.append(predictions.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_probabilities.append(probs.cpu().numpy())

# Concatenate all batches
all_predictions = np.vstack(all_predictions)
all_labels = np.vstack(all_labels)
all_probabilities = np.vstack(all_probabilities)

# Guard: labels and predictions must have one column per class name, otherwise the
# per-class scores below would be attributed to the wrong process.
assert all_predictions.shape == all_labels.shape, (all_predictions.shape, all_labels.shape)
assert all_predictions.shape[1] == len(class_names), "class_names does not match the number of model outputs"

# zero_division=0 yields the same value sklearn's default uses for classes without
# any positives, but without emitting an UndefinedMetricWarning.
f1_micro = f1_score(all_labels, all_predictions, average='micro', zero_division=0)
f1_macro = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
f1_per_class = f1_score(all_labels, all_predictions, average=None, zero_division=0)

print(f"F1 Score (Micro): {f1_micro:.4f}")
print(f"F1 Score (Macro): {f1_macro:.4f}")
print("\nF1 Score per class:")
for name, score in zip(class_names, f1_per_class):
    print(f"{name}: {score:.4f}")

# Classification report
print("\nDetailed Classification Report:")
print(classification_report(all_labels, all_predictions,
                            target_names=class_names, zero_division=0))


In [None]:
# Multi-label confusion matrices (one per class). For each class sklearn returns a
# 2x2 matrix [[TN, FP], [FN, TP]]: rows are the actual value, columns the prediction.
cm_per_class = multilabel_confusion_matrix(all_labels, all_predictions)

# One panel per class; derived from class_names instead of a hard-coded 3.
# squeeze=False keeps `axes` 2-D, so indexing also works with a single class.
n_classes = len(class_names)
fig, axes = plt.subplots(1, n_classes, figsize=(5 * n_classes, 4), squeeze=False)
for ax, cm, class_name in zip(axes[0], cm_per_class, class_names):
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=['No', 'Yes'], yticklabels=['No', 'Yes'])
    ax.set_title(f'Confusion Matrix: {class_name}')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')

fig.tight_layout()
plt.show()

# Print confusion matrix details; ravel() of [[TN, FP], [FN, TP]] gives tn, fp, fn, tp.
for class_name, cm in zip(class_names, cm_per_class):
    tn, fp, fn, tp = cm.ravel()
    print(f"\n{class_name}:")
    print(f"  True Positives: {tp}")
    print(f"  False Positives: {fp}")
    print(f"  True Negatives: {tn}")
    print(f"  False Negatives: {fn}")


In [None]:
# "No process" parts: rows whose label (or prediction) vector sums to zero,
# i.e. none of the process classes is active.
empty_truth_mask = all_labels.sum(axis=1) == 0
empty_pred_mask = all_predictions.sum(axis=1) == 0

no_process_labels = empty_truth_mask.sum()
no_process_predictions = empty_pred_mask.sum()
# How well does the model identify "no process" parts? Count rows empty in both.
both_empty = (empty_truth_mask & empty_pred_mask).sum()

print("\nParts with no processes:")
for caption, count, total in (
    ("Ground truth", no_process_labels, len(all_labels)),
    ("Predicted", no_process_predictions, len(all_predictions)),
):
    print(f"  {caption}: {count} ({count/total*100:.1f}%)")
print(f"  Correctly identified as 'no process': {both_empty}")


In [None]:
# Re-run evaluation, this time keeping the per-class probabilities as well so that
# borderline decisions can be inspected.
threshold = 0.5

batch_probabilities = []
batch_labels = []

model.eval()
with torch.no_grad():
    for vec_sets, labels in test_loader:
        # Labels are only needed as NumPy, so they are never copied to `device`.
        outputs = model(vec_sets.to(device))
        batch_probabilities.append(torch.sigmoid(outputs).cpu().numpy())
        batch_labels.append(labels.cpu().numpy())

all_probabilities = np.vstack(batch_probabilities)
all_labels = np.vstack(batch_labels)
# Threshold once on the stacked array; identical to thresholding each batch and
# converting with .float() (float32 0/1 values).
all_predictions = (all_probabilities > threshold).astype(np.float32)

# Now analyze "no process" parts (all-zero label rows).
no_proc_indices = np.where(all_labels.sum(axis=1) == 0)[0]
print("\nProbabilities for 'no process' parts:")
for i, idx in enumerate(no_proc_indices, start=1):
    print(f"Sample {i}: {all_probabilities[idx]}")

In [None]:
# SHUFFLE = FALSE FOR REAL NAMES!!!
# Row i of all_labels / all_predictions / all_probabilities maps to dataset sample i
# only because test_loader iterates with shuffle=False.

# Get part names from dataset
dataset = TKMS_Process_Dataset(mode="test", target_type="step-set")
# This dataset is built with a different target_type than the one the predictions came
# from; fail loudly if the sample count differs, since indices would no longer line up.
assert len(dataset) == len(all_labels), (
    f"dataset has {len(dataset)} samples but {len(all_labels)} predictions were collected"
)

# Map indices to part names
no_proc_indices = np.where(all_labels.sum(axis=1) == 0)[0]

print("\n'No process' parts with details:")
for idx in no_proc_indices:
    part_name = dataset.samples[idx]  # Get part name from dataset
    probs = all_probabilities[idx]
    pred = all_predictions[idx]
    # Label each probability with its class name instead of hard-coding the order.
    prob_text = ", ".join(f"{name}={prob:.3f}" for name, prob in zip(class_names, probs))

    print(f"\nPart: {part_name}")
    print(f"  Probabilities: {prob_text}")
    print(f"  Predicted: {pred}")
    # Correct when the model also predicts no process at all (all-zero prediction).
    print(f"  Correct: {'✓' if pred.sum() == 0 else '✗'}")