In [89]:
import numpy as np
import torch

def discretize(data, bins):
    """Map a value, or an array-like of values, onto the given bin edges.

    Bin indices follow ``np.digitize`` (``bins`` must be monotonic). For
    increasing bins:
    - index ``i`` means ``bins[i-1] <= x < bins[i]``;
    - 0 means below ``bins[0]``;
    - ``len(bins)`` means at or above ``bins[-1]``.

    The bin value reported for index ``i`` is the lower edge ``bins[i-1]``.
    Values below the first edge (index 0) are mapped to ``bins[0]``.

    Args:
        data: A scalar (Python/NumPy number or 0-d array), or an array-like
            of any shape.
        bins: Bin edges; converted with ``np.array``.

    Returns:
        For scalar input: ``(index, value)``.
        Otherwise: ``(indices, values)``, where ``indices`` is the
        ``np.digitize`` result (same shape as ``data``) and ``values`` is a
        list of bin values (nested lists when ``data`` has more than one
        dimension).
    """
    # Ensure the bins are a numpy array
    bins = np.array(bins)

    # np.ndim (rather than np.isscalar) also classifies 0-d arrays as scalars;
    # iterating a 0-d digitize result in the array branch would raise.
    if np.ndim(data) == 0:
        bin_index = np.digitize([data], bins)[0]
        bin_value = bins[bin_index - 1] if bin_index > 0 else bins[0]
        return bin_index, bin_value

    bin_indices = np.digitize(data, bins)
    # Vectorized edge lookup for any input rank: index 0 (below range) is
    # clamped so it maps to bins[0], every other index i maps to bins[i-1].
    bin_values = bins[np.maximum(bin_indices - 1, 0)]
    if bin_values.ndim == 1:
        # Flat input keeps its flat-list return type (list of NumPy scalars).
        return bin_indices, list(bin_values)
    return bin_indices, bin_values.tolist()


def discretize_tensor_slice(tensor_slice, bins):
    """Discretize every element of a tensor against ``bins``.

    Gives the same per-element result as applying ``np.digitize`` to each
    entry with the bin-value rule "lower edge of the bin, ``bins[0]`` when
    below range". The work is vectorized: the tensor is copied to the host
    once, instead of issuing one ``.item()`` transfer per element (each of
    which forces a device sync on GPU tensors). Tensors of any rank are
    accepted.

    Args:
        tensor_slice: Tensor of values to discretize (any shape, any device).
            Autograd history is ignored.
        bins: Monotonic bin edges accepted by ``np.digitize``.

    Returns:
        ``(indices, values)``: two tensors with the shape of ``tensor_slice``,
        placed on its device.
        - ``indices`` (int64) are ``np.digitize`` bin indices; 0 means below
          ``bins[0]``.
        - ``values`` hold each element's lower bin edge, with below-range
          elements mapped to ``bins[0]``. Their dtype follows the edges
          (float64 for float edges).
    """
    edges = np.array(bins)
    # Widen to float64 on the torch side so the comparisons match the
    # Python-float semantics of ``.item()``. Doing the cast in torch also
    # handles dtypes NumPy cannot represent (e.g. bfloat16).
    host_values = tensor_slice.detach().to(device="cpu", dtype=torch.float64).numpy()
    bin_indices = np.digitize(host_values, edges)
    # Index 0 (below range) is clamped so it maps to edges[0].
    bin_values = edges[np.maximum(bin_indices - 1, 0)]
    device = tensor_slice.device
    return (
        torch.as_tensor(bin_indices, dtype=torch.int64, device=device),
        torch.as_tensor(bin_values, device=device),
    )



def get_discrete(observation):
    """Discretize the continuous parts of a batch of observations.

    The observation is indexed as ``observation[:, columns]`` with this
    column layout:
    - 0-1: position
    - 2-3: velocity
    - 4-5: goal pose
    - 6-17: lidar/sensor data
    - 18 onward: communication data

    Bin edges:
    - position and goal pose: 5 edges over [-1, 1]
    - velocity: 9 edges over [-1, 1]
    - lidar: 21 edges over [0, 1]

    Communication columns are returned as a slice without discretization.

    Returns:
        dict holding ``discrete_<part>_indices`` and ``discrete_<part>_values``
        (as produced by ``discretize_tensor_slice``) for each of pos, vel,
        goal_pose and sensor_data, plus ``comms_data``.
    """
    position_edges = np.linspace(-1, 1, num=5)   # [-1, -0.5, 0, 0.5, 1]
    velocity_edges = np.linspace(-1, 1, num=9)   # [-1, -0.75, ..., 0.75, 1]
    lidar_edges = np.linspace(0, 1, num=21)      # [0, 0.05, 0.1, ..., 1]

    # (indices key, values key, observation columns, bin edges).
    # Goal pose deliberately reuses the position edges.
    layout = (
        ("discrete_pos_indices", "discrete_pos_values", slice(None, 2), position_edges),
        ("discrete_vel_indices", "discrete_vel_values", slice(2, 4), velocity_edges),
        ("discrete_goal_pose_indices", "discrete_goal_pose_values", slice(4, 6), position_edges),
        ("discrete_sensor_data_indices", "discrete_sensor_data_values", slice(6, 18), lidar_edges),
    )

    result = {}
    for indices_key, values_key, columns, edges in layout:
        result[indices_key], result[values_key] = discretize_tensor_slice(
            observation[:, columns], edges
        )
    # Communication data is passed through untouched.
    result["comms_data"] = observation[:, 18:]
    return result


# Example observation tensor: 2 position, 2 velocity, 2 goal-pose,
# 12 lidar and 4 communication features (the column layout get_discrete slices).
# Fall back to the CPU so the demo also runs on machines without a CUDA device.
demo_device = "cuda:0" if torch.cuda.is_available() else "cpu"
observation = torch.tensor([[ 0.2913, -0.2901,  0.0000,  0.0000,  0.8705,  0.5032,  0.0000,  0.0000,
                              0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                              0.33,  0.46,  0.0000,  0.0000,  0.0000,  0.0000]], device=demo_device)

observation_discrete = get_discrete(observation)

print("Discrete Position Indices:", observation_discrete["discrete_pos_indices"])
print("Discrete Position Values:", observation_discrete["discrete_pos_values"])
print("Discrete Velocity Indices:", observation_discrete["discrete_vel_indices"])
print("Discrete Velocity Values:", observation_discrete["discrete_vel_values"])
print("Discrete Goal Pose Indices:", observation_discrete["discrete_goal_pose_indices"])
print("Discrete Goal Pose Values:", observation_discrete["discrete_goal_pose_values"])
print("Discrete Sensor Data Indices:", observation_discrete["discrete_sensor_data_indices"])
print("Discrete Sensor Data Values:", observation_discrete["discrete_sensor_data_values"])
print("Communication Data:", observation_discrete["comms_data"])


Discrete Position Indices: tensor([[3, 2]], device='cuda:0')
Discrete Position Values: tensor([[ 0.0000, -0.5000]], device='cuda:0', dtype=torch.float64)
Discrete Velocity Indices: tensor([[5, 5]], device='cuda:0')
Discrete Velocity Values: tensor([[0., 0.]], device='cuda:0', dtype=torch.float64)
Discrete Goal Pose Indices: tensor([[4, 4]], device='cuda:0')
Discrete Goal Pose Values: tensor([[0.5000, 0.5000]], device='cuda:0', dtype=torch.float64)
Discrete Sensor Data Indices: tensor([[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  7, 10]], device='cuda:0')
Discrete Sensor Data Values: tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.3000, 0.4500]], device='cuda:0', dtype=torch.float64)
Communication Data: tensor([[0., 0., 0., 0.]], device='cuda:0')
