diff --git a/chebai/result/generate_class_properties.py b/chebai/result/generate_class_properties.py
index 2f1668f8..8c8f96bf 100644
--- a/chebai/result/generate_class_properties.py
+++ b/chebai/result/generate_class_properties.py
@@ -1,9 +1,10 @@
 import json
 from pathlib import Path
+from typing import Literal
 
-import torch
+import torchmetrics
 from jsonargparse import CLI
-from sklearn.metrics import multilabel_confusion_matrix
+from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.result.utils import (
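Aside, not part of the diff: the torchmetrics classes imported here are stateful. They accumulate statistics batch by batch through `update()` and only produce results at `compute()`, which is what lets the rewrite below drop the `y_true`/`y_pred`/`raw_preds` lists. A minimal sketch of that pattern on toy tensors (shapes and batch count are made up):

```python
# Toy illustration of the torchmetrics update()/compute() pattern; not project code.
import torch
from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

num_labels = 3
cm = MultilabelConfusionMatrix(num_labels=num_labels)
f1 = MultilabelF1Score(num_labels=num_labels, average=None)  # average=None -> per-class F1

for _ in range(4):  # stand-in for dataloader batches
    probs = torch.rand(8, num_labels)                # sigmoid outputs in [0, 1]
    targets = torch.randint(0, 2, (8, num_labels))   # binary multi-label targets
    cm.update(probs, targets)                        # float preds are thresholded at 0.5
    f1.update(probs, targets)

print(cm.compute().shape)  # torch.Size([3, 2, 2]); each slice is [[TN, FP], [FN, TP]]
print(f1.compute())        # tensor with one F1 score per class
```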
@@ -37,48 +38,77 @@ def load_class_labels(path: Path) -> list[str]:
 
     @staticmethod
     def compute_classwise_scores(
-        y_true: list[torch.Tensor],
-        y_pred: list[torch.Tensor],
-        raw_preds: torch.Tensor,
+        metrics_obj_dict: dict[str, torchmetrics.Metric],
         class_names: list[str],
     ) -> dict[str, dict[str, float]]:
         """
-        Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
-        in a multi-label setting.
+        Compute per-class evaluation metrics for a multi-label classification task.
+
+        This method uses torchmetrics objects (MultilabelConfusionMatrix, F1 scores, etc.)
+        to compute the following metrics for each class:
+        - PPV (Positive Predictive Value or Precision)
+        - NPV (Negative Predictive Value)
+        - True Positives (TP)
+        - False Positives (FP)
+        - True Negatives (TN)
+        - False Negatives (FN)
+        - F1 score
 
         Args:
-            y_true: List of binary ground-truth label tensors, one tensor per sample.
-            y_pred: List of binary prediction tensors, one tensor per sample.
-            class_names: Ordered list of class names corresponding to class indices.
+            metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
+                {
+                    "cm": MultilabelConfusionMatrix,
+                    "f1": MultilabelF1Score (average=None)
+                }
+            class_names: List of class names in the same order as class indices.
 
         Returns:
-            Dictionary mapping each class name to its PPV and NPV metrics:
+            Dictionary mapping each class name to a sub-dictionary of computed metrics:
             {
-                "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
+                "class_name_1": {
+                    "PPV": float,
+                    "NPV": float,
+                    "TN": int,
+                    "FP": int,
+                    "FN": int,
+                    "TP": int,
+                    "f1": float,
+                },
                 ...
             }
         """
-        # Stack per-sample tensors into (n_samples, n_classes) numpy arrays
-        true_np = torch.stack(y_true).cpu().numpy().astype(int)
-        pred_np = torch.stack(y_pred).cpu().numpy().astype(int)
+        cm_tensor = metrics_obj_dict["cm"].compute()  # Shape: (num_classes, 2, 2)
+        f1_tensor = metrics_obj_dict["f1"].compute()  # Shape: (num_classes,)
 
-        # Compute confusion matrix for each class
-        cm = multilabel_confusion_matrix(true_np, pred_np)
+        assert len(class_names) == cm_tensor.shape[0] == f1_tensor.shape[0], (
+            f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
+            f"confusion matrix has {cm_tensor.shape[0]}, "
+            f"F1 has {f1_tensor.shape[0]}."
+        )
 
         results: dict[str, dict[str, float]] = {}
         for idx, cls_name in enumerate(class_names):
-            tn, fp, fn, tp = cm[idx].ravel()
-            tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
-            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
+            tn = cm_tensor[idx][0][0].item()
+            fp = cm_tensor[idx][0][1].item()
+            fn = cm_tensor[idx][1][0].item()
+            tp = cm_tensor[idx][1][1].item()
+
+            ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # Precision
+            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0  # Negative predictive value
+
             # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
             # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
+
+            f1 = f1_tensor[idx]
+
             results[cls_name] = {
-                "PPV": round(tpv, 4),
+                "PPV": round(ppv, 4),
                 "NPV": round(npv, 4),
                 "TN": int(tn),
                 "FP": int(fp),
                 "FN": int(fn),
                 "TP": int(tp),
+                "f1": round(f1.item(), 4),
                 # "positive_preds": positive_raw,
                 # "negative_preds": negative_raw,
             }
@@ -86,6 +116,7 @@ def compute_classwise_scores(
 
     def generate_props(
         self,
+        data_partition: Literal["train", "val", "test"],
         model_ckpt_path: str,
         model_config_file_path: str,
         data_config_file_path: str,
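Aside, not part of the diff: the TN/FP/FN/TP unpacking above follows the torchmetrics layout where each per-class slice of the confusion matrix is [[TN, FP], [FN, TP]]. A small worked example with made-up counts, showing how PPV and NPV fall out of one such slice:

```python
# Made-up 2x2 slice for a single class; layout is [[TN, FP], [FN, TP]].
import torch

cm_for_one_class = torch.tensor([[90, 5],
                                 [3, 2]])
tn, fp = cm_for_one_class[0, 0].item(), cm_for_one_class[0, 1].item()
fn, tp = cm_for_one_class[1, 0].item(), cm_for_one_class[1, 1].item()

ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # precision: 2 / 7 ≈ 0.2857
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0  # 90 / 93 ≈ 0.9677
print(round(ppv, 4), round(npv, 4))  # 0.2857 0.9677
```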
""" - print("Extracting validation data for computation...") - data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path) data_module: XYBaseDataModule = load_data_instance( data_cls_path, data_cls_kwargs @@ -128,32 +158,43 @@ def generate_props( model_ckpt_path, model_class_path, model_kwargs ) - val_loader = data_module.val_dataloader() - print("Running inference on validation data...") + if data_partition == "train": + data_loader = data_module.train_dataloader() + elif data_partition == "val": + data_loader = data_module.val_dataloader() + elif data_partition == "test": + data_loader = data_module.test_dataloader() + else: + raise ValueError(f"Unknown data partition: {data_partition}") + print(f"Running inference on {data_partition} data...") - y_true, y_pred = [], [] - raw_preds = [] - for batch_idx, batch in enumerate(val_loader): - data = model._process_batch( # pylint: disable=W0212 - batch, batch_idx=batch_idx - ) + classes_file = Path(data_module.processed_dir_main) / "classes.txt" + class_names = self.load_class_labels(classes_file) + num_classes = len(class_names) + metrics_obj_dict: dict[str, torchmetrics.Metric] = { + "cm": MultilabelConfusionMatrix(num_labels=num_classes), + "f1": MultilabelF1Score(num_labels=num_classes, average=None), + } + + for batch_idx, batch in enumerate(data_loader): + data = model._process_batch(batch, batch_idx=batch_idx) labels = data["labels"] - outputs = model(data, **data.get("model_kwargs", {})) - logits = outputs["logits"] if isinstance(outputs, dict) else outputs - preds = torch.sigmoid(logits) > 0.5 - y_pred.extend(preds) - y_true.extend(labels) - raw_preds.extend(torch.sigmoid(logits)) - raw_preds = torch.stack(raw_preds) + model_output = model(data, **data.get("model_kwargs", {})) + preds, targets = model._get_prediction_and_labels( + data, labels, model_output + ) + for metric_obj in metrics_obj_dict.values(): + metric_obj.update(preds, targets) + print("Computing metrics...") - classes_file = Path(data_module.processed_dir_main) / "classes.txt" if output_path is None: - output_file = Path(data_module.processed_dir_main) / "classes.json" + output_file = ( + Path(data_module.processed_dir_main) / f"classes_{data_partition}.json" + ) else: output_file = Path(output_path) - class_names = self.load_class_labels(classes_file) - metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names) + metrics = self.compute_classwise_scores(metrics_obj_dict, class_names) with output_file.open("w") as f: json.dump(metrics, f, indent=2) @@ -167,6 +208,7 @@ class Main: def generate( self, + data_partition: Literal["train", "val", "test"], model_ckpt_path: str, model_config_file_path: str, data_config_file_path: str, @@ -176,14 +218,21 @@ def generate( CLI command to generate JSON with metrics on validation set. Args: + data_partition: Partition of dataset to use to generate class properties. model_ckpt_path: Path to the PyTorch Lightning checkpoint file. model_config_file_path: Path to yaml config file of the model. data_config_file_path: Path to yaml config file of the data. output_path: Optional path where to write the JSON metrics file. Defaults to '/classes.json'. 
""" + assert data_partition in [ + "train", + "val", + "test", + ], f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` " generator = ClassesPropertiesGenerator() generator.generate_props( + data_partition, model_ckpt_path, model_config_file_path, data_config_file_path, @@ -193,6 +242,7 @@ def generate( if __name__ == "__main__": # _generate_classes_props_json.py generate \ + # --data_partition "val" \ # --model_ckpt_path "model/ckpt/path" \ # --model_config_file_path "model/config/file/path" \ # --data_config_file_path "data/config/file/path" \