In [1]:
import torch
import numpy as np

# 1. Load your data
# NOTE(review): absolute, machine-specific path -- consider moving it into a config
# cell (e.g. a DATA_DIR constant) so the notebook can run on other machines.
file_path = '/remote-home/share/lisj/Workspace/SOTA_NAS/datasets/core/MUSK-feature/1909791.pt'

# map_location='cpu' places every loaded tensor on the CPU, whatever device it was saved from.
# NOTE(review): weights_only=False performs full unpickling, which can execute arbitrary
# code embedded in the file. Only use it on trusted files; a plain dict of tensors can
# usually be loaded with weights_only=True -- confirm against this file's contents.
try:
    data = torch.load(file_path, map_location='cpu', weights_only=False)
except FileNotFoundError as err:
    # Raise rather than print + exit(): inside a Jupyter kernel exit() does not halt the
    # cell (execution continues, then the kernel shuts down), and in a plain script it
    # reports success (status 0). Raising stops right here with a readable traceback.
    raise FileNotFoundError(f"Error: File not found at {file_path}") from err
# Any other loading error propagates unchanged, keeping its original traceback.

# Check that the loaded object is a dict with the expected keys.
# Errors are raised (not print + exit()) so the cell stops immediately and the kernel survives.
if not isinstance(data, dict):
    raise TypeError(f"Error: Expected the loaded file to contain a dict, got {type(data).__name__}.")
missing_keys = [key for key in ('bag_feats', 'coords') if key not in data]
if missing_keys:
    raise KeyError(f"Error: Loaded dictionary is missing required key(s): {missing_keys}")

feats = data['bag_feats']
coords = data['coords']

print(f"Original feats shape: {feats.shape}")
print(f"Original coords shape: {coords.shape}")
print(f"Original feats dtype: {feats.dtype}")
print(f"Original coords dtype: {coords.dtype}")

# coords must be an [N, 2] table, with one row per feature row.
if coords.ndim != 2 or coords.shape[1] != 2:
    raise ValueError(f"Error: Coords tensor has unexpected shape {coords.shape}. Expected [N, 2].")
if feats.shape[0] != coords.shape[0]:
    raise ValueError(
        f"Error: Mismatch between number of features ({feats.shape[0]}) "
        f"and coordinates ({coords.shape[0]})."
    )

# --- Sorting Logic ---
# Reorder the patches lexicographically: column 0 of `coords` is the primary sort key
# and column 1 breaks ties. Whether column 0 holds x or the row index is not established
# in this notebook -- TODO confirm against whatever produced the .pt file.
coords_np = coords.cpu().numpy()  # np.lexsort works on host-side NumPy arrays
primary_key = coords_np[:, 0]
secondary_key = coords_np[:, 1]

print("Calculating sorting indices using np.lexsort...")
# np.lexsort treats the LAST key it receives as the primary one, hence (secondary, primary).
sorted_indices_np = np.lexsort((secondary_key, primary_key))

# Back to a torch index tensor on the features' device so it can index `feats` directly.
sorted_indices = torch.from_numpy(sorted_indices_np).to(feats.device)
print(f"Indices calculated. Shape: {sorted_indices.shape}")

print("Reordering features and coordinates...")
# One permutation applied to both tensors keeps every feature row paired with its coordinate.
sorted_feats, sorted_coords = feats[sorted_indices], coords[sorted_indices]
# --- End Sorting Logic ---

# 6. Verification
print("\n--- Verification ---")
print(f"Sorted feats shape: {sorted_feats.shape}")
print(f"Sorted coords shape: {sorted_coords.shape}")

# Assert the ordering instead of only eyeballing it: each coordinate must be <= its
# successor, comparing column 0 first and column 1 on ties. Empty or single-row inputs
# have no adjacent pairs, and `.all()` of an empty tensor is True, so they pass.
prev_coords, next_coords = sorted_coords[:-1], sorted_coords[1:]
pairs_in_order = (prev_coords[:, 0] < next_coords[:, 0]) | (
    (prev_coords[:, 0] == next_coords[:, 0]) & (prev_coords[:, 1] <= next_coords[:, 1])
)
assert bool(pairs_in_order.all()), "sorted_coords is not in lexicographic (col 0, col 1) order"
# Sorting is a pure permutation, so shapes must be unchanged.
assert sorted_feats.shape == feats.shape and sorted_coords.shape == coords.shape

# Display the first few and last few sorted coordinates for a visual check.
print("\nFirst 10 sorted coordinates:")
print(sorted_coords[:10])
print("\nLast 10 sorted coordinates:")
print(sorted_coords[-10:])

# sorted_feats and sorted_coords are ready for use. To persist them:
# torch.save({'bag_feats': sorted_feats, 'coords': sorted_coords}, '/path/to/your/sorted_data.pt')

Original feats shape: torch.Size([4546, 1024])
Original coords shape: torch.Size([4546, 2])
Original feats dtype: torch.float16
Original coords dtype: torch.int32
Calculating sorting indices using np.lexsort...
Indices calculated. Shape: torch.Size([4546])
Reordering features and coordinates...

--- Verification ---
Sorted feats shape: torch.Size([4546, 1024])
Sorted coords shape: torch.Size([4546, 2])

First 10 sorted coordinates:
tensor([[  0, 168],
        [  0, 169],
        [  0, 170],
        [  0, 171],
        [  0, 172],
        [  0, 173],
        [  0, 174],
        [  0, 175],
        [  0, 176],
        [  0, 177]], dtype=torch.int32)

Last 10 sorted coordinates:
tensor([[ 58, 214],
        [ 58, 215],
        [ 58, 216],
        [ 58, 217],
        [ 58, 218],
        [ 59, 149],
        [ 59, 150],
        [ 59, 151],
        [ 59, 152],
        [ 59, 153]], dtype=torch.int32)
