In [1]:
import pandas as pd
import numpy as np
import random

In [2]:
random_seed = 42
np.random.seed(random_seed)
random.seed(random_seed)

In [3]:
CLASSES = ['EW', 'SR', 'EA', 'RRAB', 'EB', 'ROT', 'RRC', 'HADS', 'M', 'DSCT']

In [4]:
def calc_threshold(df, goal):
    counts = df['target'].value_counts(ascending=True)
    num_classes = len(counts)
    
    i = 1
    diff = (num_classes - i) * (counts.iloc[i] - counts.iloc[i - 1])
    count = num_classes * counts.iloc[i - 1]
    
    while count + diff < goal:
        count += diff
        i += 1
        diff = (num_classes - i) * (counts.iloc[i] - counts.iloc[i - 1])
    
    threshold = (goal - count) // (num_classes - i) + counts.iloc[i - 1]

    return threshold

In [5]:
def limit(df, threshold):
    value_counts = df['target'].value_counts()
    classes_to_limit = value_counts[value_counts > threshold].index
    
    for class_type in classes_to_limit:
        class_indices = df[df['target'] == class_type].index
        indices_to_keep = np.random.choice(class_indices, size=threshold, replace=False)
        df = df.drop(index=set(class_indices) - set(indices_to_keep))

    return df

## SUB50

In [6]:
train = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_train_norm.csv')
val = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_val_norm.csv')
test = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_test_norm.csv')

In [7]:
train = train[train['target'].isin(CLASSES)]
val = val[val['target'].isin(CLASSES)]
test = test[test['target'].isin(CLASSES)]

In [8]:
train['target'].value_counts(), len(train)

(target
 EW      4969
 SR      3652
 EA      2367
 RRAB    1947
 EB      1585
 ROT     1464
 RRC      640
 HADS     229
 M        216
 DSCT     206
 Name: count, dtype: int64,
 17275)

In [9]:
train_threshold = calc_threshold(train, len(train) // 2)
train = limit(train, train_threshold)

In [10]:
train['target'].value_counts(), len(train)

(target
 ROT     1224
 SR      1224
 EW      1224
 EA      1224
 RRAB    1224
 EB      1224
 RRC      640
 HADS     229
 M        216
 DSCT     206
 Name: count, dtype: int64,
 8635)

In [11]:
val['target'].value_counts(), len(val)

(target
 EW      604
 SR      480
 EA      275
 RRAB    238
 EB      210
 ROT     189
 RRC      95
 M        30
 HADS     29
 DSCT     25
 Name: count, dtype: int64,
 2175)

In [12]:
val_threshold = calc_threshold(val, len(val) // 2)
val = limit(val, val_threshold)

In [13]:
val['target'].value_counts(), len(val)

(target
 RRAB    151
 SR      151
 EB      151
 EW      151
 ROT     151
 EA      151
 RRC      95
 M        30
 HADS     29
 DSCT     25
 Name: count, dtype: int64,
 1085)

In [14]:
test['target'].value_counts(), len(test)

(target
 EW      693
 SR      465
 EA      301
 RRAB    239
 ROT     198
 EB      198
 RRC      81
 HADS     26
 DSCT     24
 M        22
 Name: count, dtype: int64,
 2247)

In [15]:
test_threshold = calc_threshold(test, len(test) // 2)
test = limit(test, test_threshold)

In [16]:
test['target'].value_counts(), len(test)

(target
 EB      161
 ROT     161
 EW      161
 EA      161
 SR      161
 RRAB    161
 RRC      81
 HADS     26
 DSCT     24
 M        22
 Name: count, dtype: int64,
 1119)

In [17]:
train.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub50/spectra_and_v_train_norm.csv', index=False)
val.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub50/spectra_and_v_val_norm.csv', index=False)
test.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub50/spectra_and_v_test_norm.csv', index=False)

## SUB25

In [18]:
train = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_train_norm.csv')
val = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_val_norm.csv')
test = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_test_norm.csv')

train = train[train['target'].isin(CLASSES)]
val = val[val['target'].isin(CLASSES)]
test = test[test['target'].isin(CLASSES)]

In [19]:
print(train['target'].value_counts(), len(train))

train_threshold = calc_threshold(train, len(train) // 4)
train = limit(train, train_threshold)

print(train['target'].value_counts(), len(train))

target
EW      4969
SR      3652
EA      2367
RRAB    1947
EB      1585
ROT     1464
RRC      640
HADS     229
M        216
DSCT     206
Name: count, dtype: int64 17275
target
EW      523
ROT     523
RRAB    523
SR      523
RRC     523
EB      523
EA      523
HADS    229
M       216
DSCT    206
Name: count, dtype: int64 4312


In [20]:
print(val['target'].value_counts(), len(val))

val_threshold = calc_threshold(val, len(val) // 4)
val = limit(val, val_threshold)

print(val['target'].value_counts(), len(val))

target
EW      604
SR      480
EA      275
RRAB    238
EB      210
ROT     189
RRC      95
M        30
HADS     29
DSCT     25
Name: count, dtype: int64 2175
target
SR      65
RRC     65
RRAB    65
EB      65
ROT     65
EW      65
EA      65
M       30
HADS    29
DSCT    25
Name: count, dtype: int64 539


In [21]:
print(test['target'].value_counts(), len(test))

test_threshold = calc_threshold(test, len(test) // 4)
test = limit(test, test_threshold)

print(test['target'].value_counts(), len(test))

target
EW      693
SR      465
EA      301
RRAB    239
ROT     198
EB      198
RRC      81
HADS     26
DSCT     24
M        22
Name: count, dtype: int64 2247
target
EB      69
EA      69
RRC     69
ROT     69
SR      69
EW      69
RRAB    69
HADS    26
DSCT    24
M       22
Name: count, dtype: int64 555


In [22]:
train.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub25/spectra_and_v_train_norm.csv', index=False)
val.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub25/spectra_and_v_val_norm.csv', index=False)
test.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub25/spectra_and_v_test_norm.csv', index=False)

## SUB10

In [23]:
train = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_train_norm.csv')
val = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_val_norm.csv')
test = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full/spectra_and_v_test_norm.csv')

train = train[train['target'].isin(CLASSES)]
val = val[val['target'].isin(CLASSES)]
test = test[test['target'].isin(CLASSES)]

In [24]:
print(train['target'].value_counts(), len(train))

train_threshold = calc_threshold(train, len(train) // 10)
train = limit(train, train_threshold)

print(train['target'].value_counts(), len(train))

target
EW      4969
SR      3652
EA      2367
RRAB    1947
EB      1585
ROT     1464
RRC      640
HADS     229
M        216
DSCT     206
Name: count, dtype: int64 17275
target
RRC     169
EB      169
DSCT    169
RRAB    169
EW      169
EA      169
HADS    169
ROT     169
M       169
SR      169
Name: count, dtype: int64 1690


In [25]:
print(val['target'].value_counts(), len(val))

val_threshold = calc_threshold(val, len(val) // 10)
val = limit(val, val_threshold)

print(val['target'].value_counts(), len(val))

target
EW      604
SR      480
EA      275
RRAB    238
EB      210
ROT     189
RRC      95
M        30
HADS     29
DSCT     25
Name: count, dtype: int64 2175
target
RRC     21
ROT     21
HADS    21
EB      21
EA      21
EW      21
SR      21
DSCT    21
RRAB    21
M       21
Name: count, dtype: int64 210


In [26]:
print(test['target'].value_counts(), len(test))

test_threshold = calc_threshold(test, len(test) // 10)
test = limit(test, test_threshold)

print(test['target'].value_counts(), len(test))

target
EW      693
SR      465
EA      301
RRAB    239
ROT     198
EB      198
RRC      81
HADS     26
DSCT     24
M        22
Name: count, dtype: int64 2247
target
RRAB    22
SR      22
RRC     22
EW      22
M       22
HADS    22
DSCT    22
EB      22
EA      22
ROT     22
Name: count, dtype: int64 220


In [27]:
train.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub10/spectra_and_v_train_norm.csv', index=False)
val.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub10/spectra_and_v_val_norm.csv', index=False)
test.to_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/sub10/spectra_and_v_test_norm.csv', index=False)