In [1]:
import numpy as np

def top_k_precision(y_true, y_scores, k):
    """
    Calculate mean top-k precision (precision@k) for multilabel classification.

    For every sample, the k classes with the highest predicted scores are
    selected and the fraction of them that are true labels (hits / k) is
    computed. The per-sample values are then averaged.

    Parameters:
    y_true (array-like): Binary matrix of true labels (shape: n_samples x n_classes).
    y_scores (array-like): Matrix of predicted scores (shape: n_samples x n_classes).
    k (int): Number of top elements to consider for calculating precision.
        Must be >= 1. If k exceeds n_classes, every class is selected but the
        denominator is still k.

    Returns:
    float: Mean top-k precision across all samples (NaN when there are no samples).

    Raises:
    ValueError: If k < 1, or if the inputs are not 2-D arrays of the same shape.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)

    if y_true.ndim != 2 or y_true.shape != y_scores.shape:
        raise ValueError(
            "y_true and y_scores must be 2-D arrays of identical shape, got "
            f"{y_true.shape} and {y_scores.shape}"
        )
    # k == 0 would slice [-0:] (i.e. every class) and then divide by zero;
    # a negative k would silently produce a negative "precision".
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")
    if y_true.shape[0] == 0:
        # Nothing to average over; return NaN explicitly instead of relying on
        # np.mean([]) and its RuntimeWarning.
        return float("nan")

    # Row-wise ascending sort of the scores: the last k columns hold the
    # indices of the k highest-scoring classes for each sample.
    top_k_indices = np.argsort(y_scores, axis=1)[:, -k:]

    # Look up the true labels at those indices and count the hits per sample.
    hits = np.take_along_axis(y_true, top_k_indices, axis=1).sum(axis=1)

    return float(np.mean(hits / k))

# Toy ground truth: 3 samples x 5 classes, where 1 marks a relevant class.
y_true = np.array([[1, 0, 0, 1, 0],
                   [0, 1, 1, 0, 0],
                   [1, 1, 0, 0, 0]])

# Toy model scores with the same layout (one row per sample, one column per class).
y_scores = np.array([[0.8, 0.3, 0.2, 0.9, 0.1],
                     [0.1, 0.7, 0.6, 0.3, 0.2],
                     [0.9, 0.8, 0.1, 0.4, 0.3]])

# Score the two highest-ranked classes of every sample against the ground truth.
k = 2
precision_at_k = top_k_precision(y_true, y_scores, k=k)
print("Top-2 Precision:", precision_at_k)


Top-2 Precision: 1.0


In [2]:
# argsort orders ascending, so the final two positions are the indices of the
# two largest scores for the first sample.
indices = y_scores[0].argsort()[-2:]
print(indices)

[0 3]


In [3]:
import os

import torch

# Location of the label matrix. The default is the original machine-specific
# path; set the HGNN_Y_PATH environment variable to run this notebook elsewhere
# without editing the cell.
Y_PATH = os.environ.get(
    "HGNN_Y_PATH",
    '/home/almusawiaf/MyDocuments/PhD_Projects/PSG_SURVIVAL_ANALYSIS/Data/203_Diagnoses/PathCount_Only/33333/HGNN_data/Y.pt',
)

# NOTE(review): torch.load unpickles the file, so only load files from a trusted
# source. The recorded output shows a NumPy array rather than a tensor; newer
# PyTorch releases default to weights_only=True, which may reject such a payload
# -- confirm against the installed PyTorch version.
Y = torch.load(Y_PATH)
Y

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [7]:
# Sum each row of Y (axis=1 collapses the columns).
# Presumably Y is a samples x labels binary matrix, making this the number of
# positive labels per sample -- TODO confirm.
row_sums = np.sum(Y, axis=1)

# Summary statistics of the per-row totals (np.std default: population std, ddof=0).
mean_row_sum = row_sums.mean()
std_row_sum = row_sums.std()

print(f"Mean of the row sums: {mean_row_sum}")
print(f"Standard deviation of the row sums: {std_row_sum}")

Mean of the row sums: 3.347220572122556
Standard deviation of the row sums: 4.967134281514312
