diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index da8561f45c..872819f1a9 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -163,8 +163,8 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou batch_size, n_class = y_pred.shape[:2] # convert to [BNS], where S is the number of pixels for one sample. # As for classification tasks, S equals to 1. - y_pred = y_pred.view(batch_size, n_class, -1) - y = y.view(batch_size, n_class, -1) + y_pred = y_pred.reshape(batch_size, n_class, -1) + y = y.reshape(batch_size, n_class, -1) tp = ((y_pred + y) == 2).float() tn = ((y_pred + y) == 0).float()