In [1]:
import numpy as np
import matplotlib.pyplot as plt

def stat(x, y):
    """Print background/foreground totals for the labels ``y`` and return per-label counts.

    Label ``0`` is counted as background; every other label counts as foreground.
    ``x`` is not used.

    Returns:
        dict mapping each distinct label in ``y`` (ascending order) to its count.
    """
    is_background = y == 0
    print({
        "total": y.size,
        "background": is_background.sum(),
        "foreground": (~is_background).sum(),
    })
    labels, counts = np.unique(y, return_counts=True)
    return dict(zip(labels, counts))

In [2]:
# Load the halo point cloud: all columns but the last are features, the last is the label.
point_cloud = np.load("./halos/62.npy")
# Shift labels by +1 so the background label becomes 0
# (presumably the raw background label is -1 — confirm against the data source).
x, y = point_cloud[:, :-1], point_cloud[:, -1] + 1

# Standardize each feature column to mean 0 and std 1.
x_mean = x.mean(axis=0)
x_std = x.std(axis=0)
# A constant column has std 0 and would turn into NaN; dividing it by 1 leaves it at 0 instead.
x_std = np.where(x_std == 0, 1.0, x_std)
x = (x - x_mean) / x_std

s0 = y == 0  # background
s1 = y != 0  # foreground
N, zeros, ones = len(y), s0.sum(), s1.sum()
# s0 and s1 are complementary masks over a 1-D label array, so this sanity check always holds.
assert zeros + ones == N, f"{zeros=} + {ones=} != {N=}"

print(x.shape, y.shape)

(18205, 7) (18205,)


In [3]:
# Class distribution before oversampling: prints totals and returns {label: count}.
stat(x, y)

{'total': 18205, 'background': 16263, 'foreground': 1942}


{0.0: 16263,
 1.0: 20,
 2.0: 1166,
 3.0: 184,
 4.0: 160,
 5.0: 146,
 6.0: 87,
 7.0: 62,
 8.0: 48,
 9.0: 27,
 10.0: 21,
 11.0: 21}

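## The background class (`0`) has far more samples than all other classes combined (16263 vs. 1942). To balance the dataset we oversample the foreground classes with a simplified, SMOTE-style scheme: each synthetic point is the midpoint of two randomly chosen points of the same class (classic SMOTE instead interpolates, at a random position, between a sample and one of its k nearest neighbours).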

In [4]:
def oversample_by_midpoints(x, y, minority_classes, target_count):
    """Append synthetic foreground points until ``(y != 0).sum()`` reaches ``target_count``.

    Each synthetic point is built in two steps:
      1. Pick a class from ``minority_classes`` with probability proportional to
         1 / (its current size), so smaller classes are topped up first.
      2. Take the midpoint of two distinct points of that class, drawn uniformly
         at random.

    Synthetic points are added to their class's pool, so later draws may pick them too.
    NOTE(review): classic SMOTE draws only from real samples. Re-using synthetic
    points pulls new points further toward the interior of each class; consider
    sampling from the original points only (this would change the recorded outputs).

    Randomness comes from the global ``np.random`` state: one ``np.random.choice``
    for the class and one for the pair, per synthetic point. Results are therefore
    reproducible after ``np.random.seed``.

    Args:
        x: feature matrix, one row per sample.
        y: 1-D label array aligned with ``x``; label 0 is background.
        minority_classes: foreground labels eligible for oversampling.
        target_count: desired number of foreground points.

    Returns:
        ``(x_out, y_out)``: the input rows followed by the synthetic rows. If no
        points are needed, ``x`` and ``y`` are returned unchanged.

    Raises:
        ValueError: points must be added but there is no minority class, or a
            minority class has fewer than two points (no midpoint is possible).
    """
    n_needed = int(target_count) - int((y != 0).sum())
    if n_needed <= 0:
        return x, y
    if len(minority_classes) == 0:
        raise ValueError("no foreground classes to oversample")

    # Per-class pools kept in row order, so index i in a pool is the i-th row of
    # x[y == cls] as it would look with the synthetic rows already appended.
    pools = {cls: list(x[y == cls]) for cls in minority_classes}
    too_small = [cls for cls, pool in pools.items() if len(pool) < 2]
    if too_small:
        raise ValueError(f"classes {too_small} have fewer than 2 points; cannot take a midpoint")

    new_points, new_labels = [], []
    for _ in range(n_needed):
        inverse_sizes = np.array([1 / len(pools[cls]) for cls in minority_classes])
        chosen_class = np.random.choice(minority_classes, p=inverse_sizes / np.sum(inverse_sizes))
        pool = pools[chosen_class]
        i, j = np.random.choice(len(pool), 2, replace=False)
        new_point = (pool[i] + pool[j]) / 2
        pool.append(new_point)
        new_points.append(new_point)
        new_labels.append(chosen_class)

    # Grow the arrays once at the end instead of reallocating them on every iteration.
    x_out = np.vstack([x, np.asarray(new_points)])
    y_out = np.append(y, new_labels)
    return x_out, y_out


majority_count = (y == 0).sum()
# Every foreground label; unlike deleting the first unique value, this does not assume 0 is the smallest label.
minority_classes = np.unique(y[y != 0])

np.random.seed(0)
x, y = oversample_by_midpoints(x, y, minority_classes, majority_count)

In [5]:
# Class distribution after oversampling: prints totals and returns {label: count}.
stat(x, y)

{'total': 32526, 'background': 16263, 'foreground': 16263}


{0.0: 16263,
 1.0: 1432,
 2.0: 1869,
 3.0: 1429,
 4.0: 1430,
 5.0: 1454,
 6.0: 1445,
 7.0: 1446,
 8.0: 1412,
 9.0: 1466,
 10.0: 1413,
 11.0: 1467}

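## The foreground and background now have an equal number of samples (16263 each), so the dataset is balanced. Each minority class ends up with roughly 1400–1500 points, except class `2` (1869), which already had 1166 real samples before oversampling. Note that most minority points are now synthetic (e.g. class `1` grew from 20 to 1432 points). Because every synthetic point is a midpoint of existing points of its class, all of them lie inside the convex hull of that class's original samples.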