In [1]:
import math
import torch
from torch import Tensor
import geoopt


def pairwise_inner(x: Tensor, y: Tensor, curv: float | Tensor = 1.0):
    x_time = torch.sqrt(1 / curv + torch.sum(x**2, dim=-1, keepdim=True))
    y_time = torch.sqrt(1 / curv + torch.sum(y**2, dim=-1, keepdim=True))
    xyl = x @ y.T - x_time @ y_time.T
    return xyl


def pairwise_dist(
    x: Tensor, y: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-8
) -> Tensor:
    c_xyl = -curv * pairwise_inner(x, y, curv)
    _distance = torch.acosh(torch.clamp(c_xyl, min=1 + eps))
    return _distance / curv**0.1


def elementwise_inner(x: Tensor, y: Tensor, curv: float | Tensor = 1.0):
    x_time = torch.sqrt(1 / curv + torch.sum(x**2, dim=-1))
    y_time = torch.sqrt(1 / curv + torch.sum(y**2, dim=-1))
    xyl = torch.sum(x * y, dim=-1) - x_time * y_time
    return xyl


def elementwise_dist(
    x: Tensor, y: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-8
) -> Tensor:
    c_xyl = -curv * elementwise_inner(x, y, curv)
    _distance = torch.acosh(torch.clamp(c_xyl, min=1 + eps))
    return _distance / curv**0.1


def exp_map0(x: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-8) -> Tensor:
    if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf detected in input to exp_map0")

    x_norm = torch.norm(x, dim=-1, keepdim=True)
    rc_xnorm = curv**0.1 * x_norm

    sinh_input = torch.clamp(rc_xnorm, min=eps, max=math.asinh(2**15))
    rc_xnorm_clamped = torch.clamp(rc_xnorm, min=eps)

    _output = torch.sinh(sinh_input) * x / rc_xnorm_clamped

    if torch.isnan(_output).any() or torch.isinf(_output).any():
        print("NaN or Inf detected in output of exp_map0")

    return _output


def log_map0(x: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-5) -> Tensor:
    rc_x_time = torch.sqrt(1 + curv * torch.sum(x**2, dim=-1, keepdim=True))
    _distance0 = torch.acosh(torch.clamp(rc_x_time, min=1 + eps))

    rc_xnorm = curv**0.1 * torch.norm(x, dim=-1, keepdim=True)
    _output = _distance0 * x / torch.clamp(rc_xnorm, min=eps)
    return _output


def half_aperture(
    x: Tensor, curv: float | Tensor = 1.0, min_radius: float = 0.1, eps: float = 1e-5
) -> Tensor:
    asin_input = 2 * min_radius / (torch.norm(x, dim=-1) * curv**0.1 + eps)
    _half_aperture = torch.asin(torch.clamp(asin_input, min=-1 + eps, max=1 - eps))

    return _half_aperture


def oxy_angle(x: Tensor, y: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-5):
    # Calculate time components of inputs (multiplied with `sqrt(curv)`):
    x_time = torch.sqrt(1 / curv + torch.sum(x**2, dim=-1))
    y_time = torch.sqrt(1 / curv + torch.sum(y**2, dim=-1))

    # Calculate lorentzian inner product multiplied with curvature. We do not use
    # the `pairwise_inner` implementation to save some operations (since we only
    # need the diagonal elements).
    c_xyl = curv * (torch.sum(x * y, dim=-1) - x_time * y_time)

    # Make the numerator and denominator for input to arc-cosh, shape: (B, )
    acos_numer = y_time + c_xyl * x_time
    acos_denom = torch.sqrt(torch.clamp(c_xyl**2 - 1, min=eps))

    acos_input = acos_numer / (torch.norm(x, dim=-1) * acos_denom + eps)
    _angle = torch.acos(torch.clamp(acos_input, min=-1 + eps, max=1 - eps))

    return _angle


def hyperbolic_distance(
    x: Tensor, y: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-8
) -> Tensor:
    inner_prod = -x[0] * y[0] + torch.dot(x[1:], y[1:])
    val = torch.clamp(-inner_prod, min=1.0 + eps)
    dist = torch.sqrt(torch.tensor(curv)) * torch.acosh(val)
    return dist


# def batch_hyperbolic_distance(
#     x: Tensor, y: Tensor, curv: float | Tensor = 1.0, eps: float = 1e-8
# ) -> Tensor:
#     if x.shape[0] != y.shape[0]:
#         raise ValueError("Input tensors must have the same batch size.")
#     distances = []
#     for i in range(x.shape[0]):
#         distances.append(hyperbolic_distance(x[i], y[i], curv, eps))

#     return torch.stack(distances)


def lorentz_inner_product(x, y):
    # x: (..., d+1), y: (..., d+1) or (1, d+1)
    return -x[..., 0] * y[..., 0] + torch.sum(x[..., 1:] * y[..., 1:], dim=-1)


def batch_hyperbolic_distance(x, y, curv=1.0, eps=1e-5, max_acosh=1e6):
    ip = lorentz_inner_product(x, y)
    # Clamp both lower and upper bounds
    val = torch.clamp(-ip, min=1.0 + eps, max=max_acosh)
    dist = torch.sqrt(torch.tensor(curv, device=x.device, dtype=x.dtype)) * torch.acosh(
        val
    )
    return dist


def is_lorentz_point(x, curv=1.0, tol=1e-4):
    # Returns True if x is (almost) on the Lorentz hyperboloid
    norm = -x[..., 0] ** 2 + torch.sum(x[..., 1:] ** 2, dim=-1)
    return (torch.abs(norm - 1.0 / curv) < tol).all()


def project_to_lorentz(x, curv=1.0):
    space = x[..., 1:]
    t = torch.sqrt(1.0 / curv + torch.sum(space**2, dim=-1, keepdim=True))
    return torch.cat([t, space], dim=-1)

In [2]:
# load the hyperbolic embeddings from the file
hyperbolic_path = "/mnt/ssd1/mary/Diffusion-Models-Embedding-Space-Defense/hyperbolic_safe_clip/visu_validation/03f7a6e1816195a039adf08998aa1691_all_embeddings.pt"
hyperbolic_points = torch.load(hyperbolic_path)


# get only the points whose class is 'benign'
bening_point = []
for point in hyperbolic_points:
    if point[1] == "benign":
        bening_point.append(point[0])


benign_points = torch.stack(bening_point)
print(f"Number of benign points: {benign_points.shape}")

Number of benign points: torch.Size([158700, 768])


In [32]:
import geoopt
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.set_default_dtype(torch.float64)


class LorentzHyperbolicSVDD:
    def __init__(
        self,
        curvature=1.0,
        radius_init=1.0,
        center_lr=0.02,
        radius_lr=0.01,
        nu=0.1,
        device="cpu",
    ):
        self.curvature = curvature
        self.radius = radius_init
        self.center_lr = center_lr
        self.radius_lr = radius_lr
        self.device = device
        self.nu = nu

    def loss_SVDD(self, x, center, radius):
        center_batch = center.unsqueeze(0).expand(x.shape[0], -1)
        distances_sq = (
            batch_hyperbolic_distance(x, center_batch, curv=self.curvature) ** 2
        )
        penalty = torch.relu(distances_sq - radius**2)
        loss = radius**2 + torch.mean(penalty) / self.nu
        return loss

    def fit(
        self,
        x,
        epochs: int = 100,
        batch_size: int = 32,
        center_lr: float = 0.02,
        radius_lr: float = 0.01,
    ):
        # Prepare data with time component (in minibatches)
        mean_center = torch.mean(x, dim=0)
        print(f"Mean center before adding time component: {mean_center.shape}")
        x = torch.cat(
            [torch.sqrt(1 / self.curvature + torch.sum(x**2, dim=-1, keepdim=True)), x],
            dim=-1,
        )
        x = x.to(self.device)
        print("data after adding time component:", x.shape)
        mean_center = torch.cat(
            [
                torch.sqrt(
                    1 / self.curvature + torch.sum(mean_center**2, dim=-1, keepdim=True)
                ),
                mean_center,
            ],
            dim=-1,
        )

        dataloader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=True)

        self.center_param = geoopt.ManifoldParameter(
            mean_center.clone().detach().to(self.device),
            manifold=geoopt.Lorentz(k=self.curvature),
        )

        radius_init = torch.tensor(self.radius, device=self.device)
        self.radius_param = torch.nn.Parameter(
            radius_init.clone().detach().to(self.device)
        )

        center_optimizer = geoopt.optim.RiemannianSGD(
            params=[self.center_param], lr=center_lr
        )
        radius_optimizer = torch.optim.SGD(
            [{"params": self.radius_param, "lr": radius_lr}]
        )

        for epoch in range(epochs):
            epoch_loss = 0.0
            total_inside = 0
            total_seen = 0
            for batch in dataloader:
                batch_x = batch[0]
                center_optimizer.zero_grad()
                radius_optimizer.zero_grad()
                loss = self.loss_SVDD(batch_x, self.center_param, self.radius_param)
                loss.backward()
                center_optimizer.step()
                radius_optimizer.step()
                epoch_loss += loss.item() * batch_x.size(
                    0
                )  # accumulate (not average) for the epoch

                # Minibatch stats
                center_batch = self.center_param.unsqueeze(0).expand(
                    batch_x.shape[0], -1
                )
                distances = batch_hyperbolic_distance(
                    batch_x, center_batch, curv=self.curvature
                )
                inside_count = torch.sum(distances <= self.radius_param).item()
                total_inside += inside_count
                total_seen += batch_x.size(0)

            avg_loss = epoch_loss / total_seen
            print(
                f"Epoch [{epoch+1}/{epochs}], Avg Loss: {avg_loss:.4f}, Points inside radius (minibatch stats): {total_inside}/{total_seen}, center norm: {self.center_param.norm().item():.4f}, radius: {self.radius_param.item():.4f}"
            )

    def fit_alternatively(
        self,
        x,
        epochs: int = 100,
        batch_size: int = 1024,
        epoch_center: int = 10,
        epoch_radius: int = 5,
        center_lr: float = 0.02,
        radius_lr: float = 0.01,
    ):
        # Compute mean center before time component
        mean_center = torch.mean(x, dim=0)
        print(f"Mean center before adding time component: {mean_center.shape}")
        # Add time component to dataset and mean center
        x = torch.cat(
            [torch.sqrt(1 / self.curvature + torch.sum(x**2, dim=-1, keepdim=True)), x],
            dim=-1,
        )
        x = x.to(self.device)
        print("data after adding time component:", x.shape)
        mean_center = torch.cat(
            [
                torch.sqrt(
                    1 / self.curvature + torch.sum(mean_center**2, dim=-1, keepdim=True)
                ),
                mean_center,
            ],
            dim=-1,
        )

        dataloader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=True)
        # Use mean center as initialization
        self.center_param = geoopt.ManifoldParameter(
            mean_center.clone().detach().to(self.device),
            manifold=geoopt.Lorentz(k=self.curvature),
        )
        radius_init = torch.tensor(self.radius, device=self.device)
        self.radius_param = torch.nn.Parameter(
            radius_init.clone().detach().to(self.device)
        )

        center_optimizer = geoopt.optim.RiemannianSGD(
            params=[self.center_param], lr=center_lr
        )
        radius_optimizer = torch.optim.SGD(
            [{"params": self.radius_param, "lr": radius_lr}]
        )

        for epoch in range(epochs):
            epoch_loss = 0.0
            total_inside = 0
            total_seen = 0
            if epoch % (epoch_center + epoch_radius) < epoch_center:
                # Optimize center only
                for batch in dataloader:
                    batch_x = batch[0]
                    center_optimizer.zero_grad()
                    loss = self.loss_SVDD(batch_x, self.center_param, self.radius_param)
                    loss.backward()
                    center_optimizer.step()
                    epoch_loss += loss.item() * batch_x.size(0)
                    # Minibatch stats
                    center_batch = self.center_param.unsqueeze(0).expand(
                        batch_x.shape[0], -1
                    )
                    distances = batch_hyperbolic_distance(
                        batch_x, center_batch, curv=self.curvature
                    )
                    inside_count = torch.sum(distances <= self.radius_param).item()
                    total_inside += inside_count
                    total_seen += batch_x.size(0)
            else:
                # Optimize radius only
                for batch in dataloader:
                    batch_x = batch[0]
                    radius_optimizer.zero_grad()
                    loss = self.loss_SVDD(batch_x, self.center_param, self.radius_param)
                    loss.backward()
                    radius_optimizer.step()
                    epoch_loss += loss.item() * batch_x.size(0)
                    # Minibatch stats
                    center_batch = self.center_param.unsqueeze(0).expand(
                        batch_x.shape[0], -1
                    )
                    distances = batch_hyperbolic_distance(
                        batch_x, center_batch, curv=self.curvature
                    )
                    inside_count = torch.sum(distances <= self.radius_param).item()
                    total_inside += inside_count
                    total_seen += batch_x.size(0)

            avg_loss = epoch_loss / total_seen
            # Optionally print gradient norms if you want
            center_grad_norm = (
                self.center_param.grad.norm().item()
                if self.center_param.grad is not None
                else 0.0
            )
            radius_grad_norm = (
                self.radius_param.grad.norm().item()
                if self.radius_param.grad is not None
                else 0.0
            )
            print(
                f"Epoch [{epoch+1}/{epochs}], Avg Loss: {avg_loss:.4f}, center: {self.center_param.norm().item():.4f}, radius: {self.radius_param.item():.4f}, inside: {total_inside}/{total_seen}, center_grad_norm: {center_grad_norm:.4f}, radius_grad_norm: {radius_grad_norm:.4f}"
            )

    def predict(self, x):
        with torch.no_grad():
            distances = batch_hyperbolic_distance(
                x, self.center_param, curv=self.curvature
            )
            predictions = (distances <= self.radius_param).int()
        return predictions

In [34]:
def test_svdd_fit(hyper_points, nu, curvature=1.0, epochs=500):
    num_tot = hyper_points.shape[0]
    model = LorentzHyperbolicSVDD(
        curvature=curvature, center_lr=0.1, radius_lr=0.2, nu=nu
    )

    print("Before fit:")

    model.fit(hyper_points, epochs=epochs)

    print("After fit:")
    print("Center:", model.center_param)
    print("Radius:", model.radius_param.item())

    # add the time component to the hyperbolic points
    hyper_points = torch.cat(
        [
            torch.sqrt(
                1 / model.curvature + torch.sum(hyper_points**2, dim=-1, keepdim=True)
            ),
            hyper_points,
        ],
        dim=-1,
    )
    center_batch = model.center_param.expand(hyper_points.shape[0], -1)

    dists = batch_hyperbolic_distance(hyper_points, center_batch, curv=model.curvature)
    print("Distances to center:", dists)
    print("Max distance:", dists.max().item())
    print("Radius:", model.radius_param.item())
    # assert (dists <= model.radius_param.item() + 1e-2).all(), "Not all points inside radius after fit"
    inner_points = (dists <= model.radius_param.item()).float()
    count_inner = inner_points.sum().item()
    print(f"Number of points inside radius: {count_inner}/{num_tot}")

    return model


def test_svdd_fit_alternatively(hyper_points, nu, curvature=1.0, epochs=500):
    num_tot = hyper_points.shape[0]
    model = LorentzHyperbolicSVDD(
        curvature=curvature, center_lr=0.1, radius_lr=0.2, nu=nu
    )

    print("Before fit:")
    model.fit_alternatively(hyper_points, epochs=epochs)

    print("After fit:")
    print("Center:", model.center_param)
    print("Radius:", model.radius_param.item())

    # add the time component to the hyperbolic points
    hyper_points = torch.cat(
        [
            torch.sqrt(
                1 / model.curvature + torch.sum(hyper_points**2, dim=-1, keepdim=True)
            ),
            hyper_points,
        ],
        dim=-1,
    )
    center_batch = model.center_param.expand(hyper_points.shape[0], -1)

    dists = batch_hyperbolic_distance(hyper_points, center_batch, curv=model.curvature)
    print("Distances to center:", dists)
    print("Max distance:", dists.max().item())
    print("Radius:", model.radius_param.item())
    # assert (dists <= model.radius_param.item() + 1e-2).all(), "Not all points inside radius after fit"
    inner_points = (dists <= model.radius_param.item()).float()
    count_inner = inner_points.sum().item()
    print(f"Number of points inside radius: {count_inner}/{num_tot}")

    return model

In [11]:
curvature = 2.3026
epochs = 50

# fit the SVDD model on the benign points
test_svdd_fit(hyper_points=benign_points, curvature=curvature, nu=0.05, epochs=epochs)

Before fit:
Mean center before adding time component: torch.Size([768])
data after adding time component: torch.Size([158700, 769])
Epoch [1/50], Avg Loss: 0.4335, Points inside radius (minibatch stats): 153660/158700, center norm: 1.9184, radius: 0.6231
Epoch [2/50], Avg Loss: 0.4315, Points inside radius (minibatch stats): 153754/158700, center norm: 1.9430, radius: 0.6033
Epoch [3/50], Avg Loss: 0.4315, Points inside radius (minibatch stats): 153578/158700, center norm: 1.9242, radius: 0.6134
Epoch [4/50], Avg Loss: 0.4316, Points inside radius (minibatch stats): 153639/158700, center norm: 1.9347, radius: 0.6091
Epoch [5/50], Avg Loss: 0.4317, Points inside radius (minibatch stats): 153763/158700, center norm: 1.9399, radius: 0.6417
Epoch [6/50], Avg Loss: 0.4316, Points inside radius (minibatch stats): 153694/158700, center norm: 1.9268, radius: 0.6006
Epoch [7/50], Avg Loss: 0.4318, Points inside radius (minibatch stats): 153709/158700, center norm: 1.9225, radius: 0.6137
Epoch [

<__main__.LorentzHyperbolicSVDD at 0x704a18cdd690>

In [15]:
# get only the points whose class is 'malicious'
malicious_points = []
for point in hyperbolic_points:
    if point[1] == "malicious":
        malicious_points.append(point[0])


malicious_points = torch.stack(malicious_points)
print(f"Number of malicious points: {malicious_points.shape}")

Number of malicious points: torch.Size([158700, 768])


In [16]:
# add the time component to the malicious points
curvature = 2.3026
malicious_points = torch.cat(
    [
        torch.sqrt(
            1 / curvature + torch.sum(malicious_points**2, dim=-1, keepdim=True)
        ),
        malicious_points,
    ],
    dim=-1,
)



In [17]:
print("Malicious points after adding time component:", malicious_points.shape)

Malicious points after adding time component: torch.Size([158700, 769])


In [28]:
# get the model trained on benign points and predict on malicious points
benign_model = test_svdd_fit(
    hyper_points=benign_points, curvature=curvature, nu=0.1, epochs=epochs
)
# predict on malicious points
malicious_predictions = benign_model.predict(malicious_points)
print(f"Malicious predictions: {malicious_predictions}")

Before fit:
Mean center before adding time component: torch.Size([768])
data after adding time component: torch.Size([158700, 769])
Epoch [1/50], Avg Loss: 0.4026, Points inside radius (minibatch stats): 147419/158700, center norm: 1.9162, radius: 0.6141
Epoch [2/50], Avg Loss: 0.4008, Points inside radius (minibatch stats): 147293/158700, center norm: 1.9219, radius: 0.6137
Epoch [3/50], Avg Loss: 0.4009, Points inside radius (minibatch stats): 147348/158700, center norm: 1.9126, radius: 0.5922
Epoch [4/50], Avg Loss: 0.4009, Points inside radius (minibatch stats): 147491/158700, center norm: 1.9065, radius: 0.6022
Epoch [5/50], Avg Loss: 0.4008, Points inside radius (minibatch stats): 147278/158700, center norm: 1.9148, radius: 0.6081
Epoch [6/50], Avg Loss: 0.4008, Points inside radius (minibatch stats): 147279/158700, center norm: 1.9211, radius: 0.6050
Epoch [7/50], Avg Loss: 0.4007, Points inside radius (minibatch stats): 147308/158700, center norm: 1.9156, radius: 0.6006
Epoch [

In [30]:
# print the number of malicious points classified as benign
num_malicious_benign = (malicious_predictions == 1).sum().item()
print(f"Number of malicious points classified as benign: {num_malicious_benign}")
# print the accuracy of the model on malicious points
accuracy_malicious = (malicious_predictions == 0).sum().item() / malicious_predictions.shape[0]
print(f"Accuracy on malicious points: {accuracy_malicious:.4f}")
# get the model trained on benign points and predict on beign points
benign_points_with_time = torch.cat(
    [
        torch.sqrt(
            1 / curvature + torch.sum(benign_points**2, dim=-1, keepdim=True)
        ),
        benign_points,
    ],
    dim=-1,
)

benign_predictions = benign_model.predict(benign_points_with_time)
print(f"Benign predictions: {benign_predictions}")
# print the number of benign points classified as benign
num_benign_benign = (benign_predictions == 1).sum().item()
print(f"Number of benign points classified as benign: {num_benign_benign}")
# print the accuracy of the model on benign points
accuracy_benign = (benign_predictions == 1).sum().item() / benign_predictions.shape[0]
print(f"Accuracy on benign points: {accuracy_benign:.4f}")

Number of malicious points classified as benign: 2522
Accuracy on malicious points: 0.9841
Benign predictions: tensor([0, 1, 1,  ..., 1, 1, 0], dtype=torch.int32)
Number of benign points classified as benign: 132261
Accuracy on benign points: 0.8334


In [38]:
curvature = 2.3026
epochs = 50

# fit the SVDD model on the benign points
benign_alt_model = test_svdd_fit_alternatively(hyper_points=benign_points, curvature=curvature, nu=0.1, epochs=epochs)

Before fit:
Mean center before adding time component: torch.Size([768])
data after adding time component: torch.Size([158700, 769])
Epoch [1/50], Avg Loss: 1.0000, center: 1.6047, radius: 1.0000, inside: 158698/158700, center_grad_norm: 0.0000, radius_grad_norm: 309.9609
Epoch [2/50], Avg Loss: 1.0000, center: 1.6050, radius: 1.0000, inside: 158699/158700, center_grad_norm: 0.0000, radius_grad_norm: 619.9219
Epoch [3/50], Avg Loss: 1.0000, center: 1.6051, radius: 1.0000, inside: 158700/158700, center_grad_norm: 0.0000, radius_grad_norm: 929.9023
Epoch [4/50], Avg Loss: 1.0000, center: 1.6051, radius: 1.0000, inside: 158700/158700, center_grad_norm: 0.0000, radius_grad_norm: 1239.9023
Epoch [5/50], Avg Loss: 1.0000, center: 1.6051, radius: 1.0000, inside: 158700/158700, center_grad_norm: 0.0000, radius_grad_norm: 1549.9023
Epoch [6/50], Avg Loss: 1.0000, center: 1.6051, radius: 1.0000, inside: 158700/158700, center_grad_norm: 0.0000, radius_grad_norm: 1859.9023
Epoch [7/50], Avg Loss: 1

In [39]:
malicious_predictions = benign_alt_model.predict(malicious_points)
print(f"Malicious predictions: {malicious_predictions}")

Malicious predictions: tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.int32)


In [41]:
# print the number of malicious points classified as benign
num_malicious_benign = (malicious_predictions == 1).sum().item()
print(f"Number of malicious points classified as benign: {num_malicious_benign}")
# print the accuracy of the model on malicious points
accuracy_malicious = (malicious_predictions == 0).sum().item() / malicious_predictions.shape[0]
print(f"Accuracy on malicious points: {accuracy_malicious:.4f}")
# get the model trained on benign points and predict on beign points
benign_points_with_time = torch.cat(
    [
        torch.sqrt(
            1 / curvature + torch.sum(benign_points**2, dim=-1, keepdim=True)
        ),
        benign_points,
    ],
    dim=-1,
)

benign_predictions = benign_alt_model.predict(benign_points_with_time)
print(f"Benign predictions: {benign_predictions}")
# print the number of benign points classified as benign
num_benign_benign = (benign_predictions == 1).sum().item()
print(f"Number of benign points classified as benign: {num_benign_benign}")
# print the accuracy of the model on benign points
accuracy_benign = (benign_predictions == 1).sum().item() / benign_predictions.shape[0]
print(f"Accuracy on benign points: {accuracy_benign:.4f}")

Number of malicious points classified as benign: 5845
Accuracy on malicious points: 0.9632
Benign predictions: tensor([1, 1, 1,  ..., 1, 1, 1], dtype=torch.int32)
Number of benign points classified as benign: 150963
Accuracy on benign points: 0.9512
