In [7]:
import torch

device = "cpu"

pos_threshold = 100.0  # Max position deviation in meters
vel_threshold = 50.0   # Max velocity in m/s
att_threshold = 6.0   # Max attitude (rad)
ang_vel_threshold = 2.0  # Max angular velocity (rad/s)

weights = torch.tensor([
    pos_threshold, pos_threshold, pos_threshold,     # Position (xyz)
    vel_threshold, vel_threshold, vel_threshold,        # Velocity (uvw)
    att_threshold, att_threshold, att_threshold,        # Attitude (euler angles)
    ang_vel_threshold, ang_vel_threshold, ang_vel_threshold         # Angular velocity
], device=device)

def generate_points(num_samples):
    half_samples = num_samples // 2
    safe_points = torch.rand((half_samples, 12), device=device) * weights.unsqueeze(0) * 0.5 # All below threshold
    safe_labels = torch.ones((half_samples, 1), device=device)

    unsafe_points = torch.rand((half_samples, 12), device=device)
    for i in range(half_samples):
        dim = torch.randint(0, 12, (1,)).item()
        unsafe_points[i, dim] = weights[dim] * (1.2 + torch.rand(1).item())  # Ensure above threshold
    unsafe_labels = torch.zeros((half_samples, 1), device=device)

    x = torch.cat([safe_points, unsafe_points], dim=0)
    y = torch.cat([safe_labels, unsafe_labels], dim=0)
    indices = torch.randperm(num_samples)
    x = x[indices]
    y = y[indices]

    return x, y

torch.set_printoptions(precision=2, linewidth=120, sci_mode=False)
x_train, y_train = generate_points(10)
x_test, y_test = generate_points(3000)
print(f"X Train: {x_train}")

X Train: tensor([[    0.42,     0.09,     0.48,     0.20,     0.00,    67.66,     0.68,     1.00,     0.94,     0.07,     0.11,
             0.58],
        [   36.17,    21.04,    11.91,    18.18,    23.36,    16.78,     0.90,     0.04,     1.56,     0.94,     0.84,
             0.66],
        [   45.51,    42.20,    26.25,    11.43,    13.79,    24.46,     2.59,     0.41,     2.79,     0.08,     0.34,
             0.79],
        [   46.10,    19.56,    42.73,    15.67,    13.05,     7.52,     2.79,     1.86,     1.49,     0.78,     0.99,
             0.58],
        [    6.44,    45.64,    36.17,    16.70,    11.74,     7.99,     0.66,     2.29,     1.07,     0.23,     0.50,
             0.67],
        [    0.02,    29.75,    46.84,    19.91,    17.90,     2.55,     0.44,     2.82,     2.91,     0.41,     0.23,
             0.70],
        [    0.10,     0.42,     0.89,     0.91,     0.73,    96.36,     0.15,     0.43,     0.03,     0.97,     0.04,
             0.05],
        [    0.87,

In [3]:
def normalize_data(x, thresholds):
    norm = torch.norm(thresholds)
    norm_thresh = thresholds / norm
    norm_x = x / norm

    return norm_x * 2 * torch.pi, norm_thresh * 2 * torch.pi

norm_x, norm_thresh = normalize_data(x_train, weights)
print("done")

done


In [9]:
def diff_flag(x, thresholds, beta=10.0):
    """
    x: tensor of shape (batch_size, num_features)
    thresholds: tensor of shape (num_features,) with the threshold for each feature
    beta: sharpness parameter; higher beta -> closer to hard max
    """
    # Compute the difference between features and their thresholds
    diff = x - thresholds  # shape: (batch_size, num_features)
    print(f"Diff: {diff}")
    
    # Apply log-sum-exp along the features dimension as a smooth max
    smooth_max = (1.0 / beta) * torch.log(torch.sum(torch.exp(beta * diff), dim=1, keepdim=True))
    
    # Use a sigmoid to get an output between 0 and 1
    flag = torch.sigmoid(smooth_max)
    return flag

flag = diff_flag(x_train, weights)
print(f"Flag: {flag}")

Diff: tensor([[-99.58, -99.91, -99.52, -49.80, -50.00,  17.66,  -5.32,  -5.00,  -5.06,  -1.93,  -1.89,  -1.42],
        [-63.83, -78.96, -88.09, -31.82, -26.64, -33.22,  -5.10,  -5.96,  -4.44,  -1.06,  -1.16,  -1.34],
        [-54.49, -57.80, -73.75, -38.57, -36.21, -25.54,  -3.41,  -5.59,  -3.21,  -1.92,  -1.66,  -1.21],
        [-53.90, -80.44, -57.27, -34.33, -36.95, -42.48,  -3.21,  -4.14,  -4.51,  -1.22,  -1.01,  -1.42],
        [-93.56, -54.36, -63.83, -33.30, -38.26, -42.01,  -5.34,  -3.71,  -4.93,  -1.77,  -1.50,  -1.33],
        [-99.98, -70.25, -53.16, -30.09, -32.10, -47.45,  -5.56,  -3.18,  -3.09,  -1.59,  -1.77,  -1.30],
        [-99.90, -99.58, -99.11, -49.09, -49.27,  46.36,  -5.85,  -5.57,  -5.97,  -1.03,  -1.96,  -1.95],
        [-99.13, -99.95, -99.29, -49.31, -49.84,  32.55,  -5.53,  -5.05,  -5.14,  -1.12,  -1.09,  -1.84],
        [ 48.07, -99.77, -99.76, -49.85, -49.17, -49.95,  -5.55,  -5.03,  -5.36,  -1.30,  -1.22,  -1.91],
        [-99.19, -99.72, -99.07, -49.89,