134 changes: 92 additions & 42 deletions chebai/result/generate_class_properties.py
@@ -1,9 +1,10 @@
import json
from pathlib import Path
from typing import Literal

import torch
import torchmetrics
from jsonargparse import CLI
from sklearn.metrics import multilabel_confusion_matrix
from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.result.utils import (
@@ -37,55 +38,85 @@ def load_class_labels(path: Path) -> list[str]:

@staticmethod
def compute_classwise_scores(
y_true: list[torch.Tensor],
y_pred: list[torch.Tensor],
raw_preds: torch.Tensor,
metrics_obj_dict: dict[str, torchmetrics.Metric],
class_names: list[str],
) -> dict[str, dict[str, float]]:
"""
Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
in a multi-label setting.
Compute per-class evaluation metrics for a multi-label classification task.

This method uses torchmetrics objects (MultilabelConfusionMatrix, F1 scores, etc.)
to compute the following metrics for each class:
- PPV (Positive Predictive Value or Precision, TP / (TP + FP))
- NPV (Negative Predictive Value, TN / (TN + FN))
- True Positives (TP)
- False Positives (FP)
- True Negatives (TN)
- False Negatives (FN)
- F1 score

Args:
y_true: List of binary ground-truth label tensors, one tensor per sample.
y_pred: List of binary prediction tensors, one tensor per sample.
class_names: Ordered list of class names corresponding to class indices.
metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
{
"cm": MultilabelConfusionMatrix,
"micro-f1": MultilabelF1Score (average=None)
}
class_names: List of class names in the same order as class indices.

Returns:
Dictionary mapping each class name to its PPV and NPV metrics:
Dictionary mapping each class name to a sub-dictionary of computed metrics:
{
"class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
"class_name_1": {
"PPV": float,
"NPV": float,
"TN": int,
"FP": int,
"FN": int,
"TP": int,
"f1": float,
},
...
}
"""
# Stack per-sample tensors into (n_samples, n_classes) numpy arrays
true_np = torch.stack(y_true).cpu().numpy().astype(int)
pred_np = torch.stack(y_pred).cpu().numpy().astype(int)
cm_tensor = metrics_obj_dict["cm"].compute() # Shape: (num_classes, 2, 2)
f1_tensor = metrics_obj_dict["f1"].compute() # shape: (num_classes,)

# Compute confusion matrix for each class
cm = multilabel_confusion_matrix(true_np, pred_np)
assert len(class_names) == cm_tensor.shape[0] == f1_tensor.shape[0], (
f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
f"confusion matrix has {cm_tensor.shape[0]}, "
f"F1 has {f1_tensor.shape[0]}, "
)

results: dict[str, dict[str, float]] = {}
for idx, cls_name in enumerate(class_names):
tn, fp, fn, tp = cm[idx].ravel()
tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
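# torchmetrics' MultilabelConfusionMatrix yields one 2x2 matrix per class,
# laid out as [[TN, FP], [FN, TP]] (rows = true label, cols = prediction).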
tn = cm_tensor[idx][0][0].item()
fp = cm_tensor[idx][0][1].item()
fn = cm_tensor[idx][1][0].item()
tp = cm_tensor[idx][1][1].item()

ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0 # Precision
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0 # Negative predictive value

# positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
# negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]

f1 = f1_tensor[idx]

results[cls_name] = {
"PPV": round(tpv, 4),
"PPV": round(ppv, 4),
"NPV": round(npv, 4),
"TN": int(tn),
"FP": int(fp),
"FN": int(fn),
"TP": int(tp),
"f1": round(f1.item(), 4),
# "positive_preds": positive_raw,
# "negative_preds": negative_raw,
}
return results
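
A minimal, self-contained sketch of how this static method is driven, assuming a hypothetical toy two-class setup (the tensor values and class names below are illustrative, not taken from this PR):

import torch
from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

# Toy example: 2 classes, 2 samples. torchmetrics binarizes float
# predictions at its default 0.5 threshold before counting.
metrics_obj_dict = {
    "cm": MultilabelConfusionMatrix(num_labels=2),
    "f1": MultilabelF1Score(num_labels=2, average=None),
}
preds = torch.tensor([[0.9, 0.2], [0.3, 0.8]])
targets = torch.tensor([[1, 0], [0, 1]])
for metric in metrics_obj_dict.values():
    metric.update(preds, targets)

scores = ClassesPropertiesGenerator.compute_classwise_scores(
    metrics_obj_dict, ["class_a", "class_b"]
)
# scores["class_a"] == {"PPV": 1.0, "NPV": 1.0, "TN": 1, "FP": 0, "FN": 0, "TP": 1, "f1": 1.0}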

def generate_props(
self,
data_partition: Literal["train", "val", "test"],
model_ckpt_path: str,
model_config_file_path: str,
data_config_file_path: str,
@@ -95,14 +126,13 @@ def generate_props(
Run inference on the selected data partition, compute PPV/NPV and F1 per class, and save to JSON.

Args:
data_partition: Partition of the dataset to use to generate class properties.
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
model_config_file_path: Path to yaml config file of the model.
data_config_file_path: Path to yaml config file of the data.
output_path: Optional path where to write the JSON metrics file.
Defaults to '<processed_dir_main>/classes_<data_partition>.json'.
"""
print("Extracting validation data for computation...")

data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
data_module: XYBaseDataModule = load_data_instance(
data_cls_path, data_cls_kwargs
@@ -128,32 +158,43 @@
model_ckpt_path, model_class_path, model_kwargs
)

val_loader = data_module.val_dataloader()
print("Running inference on validation data...")
if data_partition == "train":
data_loader = data_module.train_dataloader()
elif data_partition == "val":
data_loader = data_module.val_dataloader()
elif data_partition == "test":
data_loader = data_module.test_dataloader()
else:
raise ValueError(f"Unknown data partition: {data_partition}")
print(f"Running inference on {data_partition} data...")

y_true, y_pred = [], []
raw_preds = []
for batch_idx, batch in enumerate(val_loader):
data = model._process_batch( # pylint: disable=W0212
batch, batch_idx=batch_idx
)
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
class_names = self.load_class_labels(classes_file)
num_classes = len(class_names)
metrics_obj_dict: dict[str, torchmetrics.Metric] = {
"cm": MultilabelConfusionMatrix(num_labels=num_classes),
"f1": MultilabelF1Score(num_labels=num_classes, average=None),
}
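# The metric objects accumulate confusion counts across batches via update();
# compute_classwise_scores later reads them out once with compute().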

for batch_idx, batch in enumerate(data_loader):
data = model._process_batch(batch, batch_idx=batch_idx)
labels = data["labels"]
outputs = model(data, **data.get("model_kwargs", {}))
logits = outputs["logits"] if isinstance(outputs, dict) else outputs
preds = torch.sigmoid(logits) > 0.5
y_pred.extend(preds)
y_true.extend(labels)
raw_preds.extend(torch.sigmoid(logits))
raw_preds = torch.stack(raw_preds)
model_output = model(data, **data.get("model_kwargs", {}))
preds, targets = model._get_prediction_and_labels(
data, labels, model_output
)
for metric_obj in metrics_obj_dict.values():
metric_obj.update(preds, targets)

print("Computing metrics...")
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
if output_path is None:
output_file = Path(data_module.processed_dir_main) / "classes.json"
output_file = (
Path(data_module.processed_dir_main) / f"classes_{data_partition}.json"
)
else:
output_file = Path(output_path)

class_names = self.load_class_labels(classes_file)
metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names)
metrics = self.compute_classwise_scores(metrics_obj_dict, class_names)

with output_file.open("w") as f:
json.dump(metrics, f, indent=2)
@@ -167,6 +208,7 @@ class Main:

def generate(
self,
data_partition: Literal["train", "val", "test"],
model_ckpt_path: str,
model_config_file_path: str,
data_config_file_path: str,
@@ -176,14 +218,21 @@ def generate(
CLI command to generate a JSON file with per-class metrics on the selected data partition.

Args:
data_partition: Partition of the dataset to use to generate class properties.
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
model_config_file_path: Path to yaml config file of the model.
data_config_file_path: Path to yaml config file of the data.
output_path: Optional path where to write the JSON metrics file.
Defaults to '<processed_dir_main>/classes_<data_partition>.json'.
"""
assert data_partition in [
"train",
"val",
"test",
], f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` "
generator = ClassesPropertiesGenerator()
generator.generate_props(
data_partition,
model_ckpt_path,
model_config_file_path,
data_config_file_path,
@@ -193,6 +242,7 @@

if __name__ == "__main__":
# _generate_classes_props_json.py generate \
# --data_partition "val" \
# --model_ckpt_path "model/ckpt/path" \
# --model_config_file_path "model/config/file/path" \
# --data_config_file_path "data/config/file/path" \