In [1]:
import torch

#### Sorting

In [2]:
def sort_matrix_by_nth_entry(matrix, n=0):
    """
    Return the rows of `matrix` reordered so that column `n` is ascending.

    The sort is stable: rows with equal values in column `n` keep their
    original relative order, so calls can be chained (least-significant
    column first) to build a multi-column sort.

    matrix: tensor indexed as matrix[:, n]; rows are selected along dim 0
    n: index of the column to sort on
    returns: a new tensor holding the same rows as `matrix`, reordered

    matrix: tensor([[1., 2.], [0., 3.], [1., 0.], [0., 1.]])
    result: tensor([[0., 3.], [0., 1.], [1., 2.], [1., 0.]])
    """
    # torch.sort only guarantees the order of equal elements with stable=True.
    _, row_order = torch.sort(matrix[:, n], stable=True)

    return matrix[row_order]

In [29]:
def group_ordering_based_on_sorting(matrix, grp_idx=0, val_idx=1):
    """
    Return the row indices that order `matrix` by its group column.

    Rows that share a group keep their original relative order (stable sort).
    The indices are read out of a default-dtype buffer, so they come back as
    a float tensor; cast with .long() before using them to index.
    `val_idx` is accepted so the signature matches the other helpers in this
    notebook, but it is not used.

    matrix: tensor([[ 1.,  2.], [ 0.,  3.], [ 1.,  0.]])
    b: tensor([[1., 0.], [0., 1.], [1., 2.]])   (group, original row index)
    c: tensor([[0., 1.], [1., 0.], [1., 2.]])   (b sorted by group)
    d: tensor([1., 0., 2.])
    """
    num_rows = matrix.shape[0]

    # Pair every group key with the index of the row it came from.
    b = torch.empty(num_rows, 2)
    b[:, 0] = matrix[:, grp_idx]  # honour grp_idx instead of a hard-coded column 0
    b[:, 1] = torch.arange(num_rows)

    # stable=True pins the order of rows that share a group key.
    _, order = torch.sort(b[:, 0], stable=True)
    c = b[order]
    d = c[:, 1]

    return d

In [3]:
def sort_matrix_by_group(matrix, grp_idx=0, val_idx=1):
    """
    Sort the rows of `matrix` by column `grp_idx`, breaking ties with column
    `val_idx` (both ascending).

    The ordering is built from two stable sorts (secondary key first, then
    primary key) rather than from a single combined numeric key, so the
    result does not depend on how large the values are relative to the
    number of rows, and any column index valid for `matrix` can be used.

    matrix: tensor([[1., 0.], [0., 11.], [0., 3.]])
    result: tensor([[0., 3.], [0., 11.], [1., 0.]])
    """
    # Pass 1: order the rows by the secondary key.
    _, by_value = torch.sort(matrix[:, val_idx], stable=True)
    # Pass 2: a stable sort on the primary key keeps pass-1 order inside each group.
    _, by_group = torch.sort(matrix[by_value, grp_idx], stable=True)

    return matrix[by_value[by_group]]

In [4]:
def sort_matrix_by_nth_and_mth_column(matrix, nth_col=0, mth_col=1):
    """
    First priority is nth-column, second priority is mth-column

    Rows are returned in ascending order of column `nth_col`; rows that tie
    there are in ascending order of column `mth_col`; rows that tie on both
    keep their original relative order.

    matrix: tensor([[1., 3.], [0., 7.], [1., 0.], [0., 2.]])
    result: tensor([[0., 2.], [0., 7.], [1., 0.], [1., 3.]])
    """
    # Sort on the lower-priority column first ...
    _, by_second = torch.sort(matrix[:, mth_col], stable=True)
    # ... then stable-sort on the higher-priority column, which leaves the
    # first pass's order intact among rows with equal nth-column values.
    _, by_first = torch.sort(matrix[by_second, nth_col], stable=True)

    return matrix[by_second[by_first]]

In [18]:
# Demo input: with the helpers' defaults (grp_idx=0, val_idx=1) column 0 is
# treated as the group key and column 1 as the value.
matrix = torch.tensor([[ 0.,  2.],
        [ 0.,  3.],
        [ 1.,  0.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 0.,  7.],
        [ 1.,  0.],
        [ 1.,  3.],
        [ 1.,  7.],
        [ 0., 11.]])

# Sort on column 0 only. sort_matrix_by_nth_and_mth_column(matrix) is the
# variant that additionally uses column 1 as the second-priority key.
sort_matrix_by_nth_entry(matrix)

tensor([[ 0.,  2.],
        [ 0.,  3.],
        [ 0.,  7.],
        [ 0., 11.],
        [ 1.,  0.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 1.,  0.],
        [ 1.,  3.],
        [ 1.,  7.]])

In [20]:
# Rebinds `matrix` to a small 3-row example; these are the same values as the
# `matrix` line in the group_ordering_based_on_sorting docstring.
matrix = torch.tensor([[ 1.,  2.], [ 0.,  3.], [ 1.,  0.]])

#### Maximum Values

In [14]:
def maximum_value_by_grp(matrix, grp_idx=0, val_idx=1):
    """
    For every group id 0..max_id, report the largest value found in that group.

    matrix: tensor whose column `grp_idx` holds group ids (cast with .long()
        and used as scatter indices, so they must be non-negative) and whose
        column `val_idx` holds the values to reduce
    grp_idx, val_idx: also select the columns of the 2-column result, so each
        must be 0 or 1
    returns: tensor of shape (num_groups, 2) with the same dtype and device as
        `matrix`; column `grp_idx` is the group id, column `val_idx` the
        maximum. A group id in 0..max_id that has no rows keeps the value 0.
        An empty `matrix` gives a (0, 2) result.

    matrix: torch.tensor([[ 0.,  2.], [ 0.,  3.], [ 1.,  0.], [ 1.,  3.], [ 1.,  4.], [ 2.,  23.], [ 3.,  4.], [ 0.,  7.], [ 1.,  0.], [ 1.,  3.], [ 1.,  7.], [ 0., 11.]])
    result: tensor([[ 0., 11.], [ 1.,  7.], [ 2., 23.], [ 3.,  4.]])
    """
    # .max() raises on an empty tensor; with no rows there are no groups.
    if matrix.shape[0] == 0:
        return torch.zeros(0, 2, dtype=matrix.dtype, device=matrix.device)

    num_groups = int(matrix[:, grp_idx].max().item()) + 1

    # Match the input's dtype/device: scatter_reduce needs input and src to agree.
    result = torch.zeros(num_groups, 2, dtype=matrix.dtype, device=matrix.device)

    result[:, grp_idx] = torch.arange(num_groups, device=matrix.device)

    result[:, val_idx] = torch.scatter_reduce(
        input=torch.zeros(num_groups, dtype=matrix.dtype, device=matrix.device),
        dim=0,
        index=matrix[:, grp_idx].long(),
        src=matrix[:, val_idx],
        reduce='amax',
        # Keep the zero placeholders out of the reduction; otherwise a group
        # whose values are all negative would report 0 as its maximum.
        include_self=False,
    )

    return result

In [15]:
# Same input as the example in the maximum_value_by_grp docstring; the cell's
# last expression displays the per-group maxima.
matrix = torch.tensor([[ 0.,  2.], [ 0.,  3.], [ 1.,  0.], [ 1.,  3.], [ 1.,  4.], [ 2.,  23.], [ 3.,  4.], [ 0.,  7.], [ 1.,  0.], [ 1.,  3.], [ 1.,  7.], [ 0., 11.]])
maximum_value_by_grp(matrix)

tensor([[ 0., 11.],
        [ 1.,  7.],
        [ 2., 23.],
        [ 3.,  4.]])