[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Datacompintensive/WignerCamp2025/blob/master/ALT_intro/FeatureExtraction/feature_extraction_solution.ipynb)

# Feature Extraction with PyTorch

This exercise introduces basic tensor operations in PyTorch through statistical feature extraction.

## Task

You need to implement **two functions** in `feature_extraction.py`:

1. `mean_all(x: torch.Tensor) -> float`:  
   Computes the mean of **all elements** in a 2D tensor.  
   Hint: [torch.mean](https://pytorch.org/docs/stable/generated/torch.mean.html).

2. `mean_percentile(x: torch.Tensor, q: float) -> float`:  
   For each row of the 2D tensor `x`, compute the `q`-quantile using `torch.quantile`, where `q` is a fraction between 0 and 1 (e.g., `q = 0.05` for the 5th percentile), then return the mean of these row-wise values.  
   Hint: [torch.quantile](https://pytorch.org/docs/stable/generated/torch.quantile.html).
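
The `dim` argument of `torch.quantile` decides whether the quantile is taken per row or per column, which is easy to mix up. The sketch below illustrates the difference on a small made-up tensor; the name `demo` and its values are assumptions for illustration and are not part of the exercise files.

```python
import torch

# Hypothetical 2x3 tensor, chosen only to make rows and columns easy to tell apart.
demo = torch.tensor([[1.0, 2.0, 3.0],
                     [10.0, 20.0, 30.0]])

# dim=1 reduces across the columns: one median per row (2 values).
per_row = torch.quantile(demo, 0.5, dim=1)

# dim=0 reduces across the rows: one median per column (3 values).
per_col = torch.quantile(demo, 0.5, dim=0)

print(per_row.shape, per_row.tolist())  # row medians: 2.0 and 20.0
print(per_col.shape, per_col.tolist())  # column medians: 5.5, 11.0 and 16.5
```

With only two values per column, the median falls between them, so `torch.quantile` interpolates linearly (its default behavior).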

### Files

- `feature_extraction.py`: Script to implement your functions.
- `test_feature_extraction.py`: Pytest-based test suite.
- `solution.py`: Contains a correct solution for reference.

In [1]:
import torch

In [14]:
def mean_all(x: torch.Tensor) -> float:
    """
    Compute the mean of all elements in a 2D tensor.

    Args:
        x (torch.Tensor): A 2D tensor.

    Returns:
        float: The mean of all elements.
    """
    # .item() converts the 0-dim result tensor to a Python float.
    return torch.mean(x).item()

In [20]:
def mean_percentile(x: torch.Tensor, q: float) -> float:
    """
    For each row of x, compute the q-quantile, then return the mean of these values.

    Args:
        x (torch.Tensor): A 2D floating-point tensor.
        q (float): The quantile to compute (between 0 and 1).

    Returns:
        float: Mean of the row-wise q-quantiles.
    """
    # dim=1 reduces across the columns, giving one quantile per row.
    percentiles = torch.quantile(x, q, dim=1)
    # .item() converts the 0-dim result tensor to a Python float.
    return torch.mean(percentiles).item()

In [22]:
matrix = torch.tensor([[2.0, 5.0], [5.0, -3.0]])
q = 0.15
print(f"Mean of all elements: {mean_all(matrix)}")

print(f"Mean of q={q} percentiles: {mean_percentile(matrix, q)}")

Mean of all elements: 2.25
Mean of q=0.15 percentiles: 0.3250000476837158
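
The example matrix above is symmetric, so its rows and columns contain the same values and it cannot reveal whether quantiles were taken along the right dimension. As a quick sanity check, the sketch below uses a non-symmetric tensor; the name `asym` and its values are made up for illustration.

```python
import torch

# Hypothetical non-symmetric tensor: both rows are [0, 10],
# while the columns are [0, 0] and [10, 10].
asym = torch.tensor([[0.0, 10.0],
                     [0.0, 10.0]])
q = 0.25

# Row-wise (dim=1): each row gives 0 + 0.25 * (10 - 0) = 2.5.
row_wise = torch.quantile(asym, q, dim=1).mean().item()

# Column-wise (dim=0): the columns give 0.0 and 10.0, so the mean is 5.0.
col_wise = torch.quantile(asym, q, dim=0).mean().item()

print(f"row-wise mean: {row_wise}, column-wise mean: {col_wise}")
```

A row-wise implementation of `mean_percentile` should match `row_wise` on this input, not `col_wise`.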
