## GOSDT

In [1]:
from gosdt import ThresholdGuessBinarizer, GOSDTClassifier
from sklearn.ensemble import GradientBoostingClassifier
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load the CSV; the 'cluster' column is the prediction target and every
# remaining column is used as an input feature.
data = pd.read_csv('average_expression_t_cluster.csv')
y = data['cluster']
X = data.drop(columns=['cluster'])

# Hold out 20% of the rows for testing. The fixed seed keeps the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Hyperparameters shared by the threshold guesser (step 1) and the warm-start
# booster (step 2): 40 estimators of depth 1, with a fixed seed.
stump_params = {
    'n_estimators': 40,
    'max_depth': 1,
    'random_state': 42,
}

# 1. Binarize continuous features
#    "ThresholdGuessBinarizer" uses a gradient boosting approach internally
#    to propose thresholds that transform continuous data into 0/1 features.
#    The binarizer is fitted on the training split only and then reused,
#    unchanged, to transform the test split.
binarizer = ThresholdGuessBinarizer(**stump_params)
X_train_bin = binarizer.fit_transform(X_train, y_train)
X_test_bin = binarizer.transform(X_test)

# 2. Warm start for GOSDT
#    Train a simple GBDT on the binarized training features; its predictions
#    on that same training data are the "warm labels" handed to GOSDT later.
warm_start_clf = GradientBoostingClassifier(**stump_params).fit(X_train_bin, y_train)
warm_labels = warm_start_clf.predict(X_train_bin)

# 3. Fit GOSDT
#    Solver options are collected in one place and unpacked into the
#    constructor.
gosdt_options = {
    'regularization': 0.001,
    'similar_support': False,
    'time_limit': 60,    # in seconds
    'depth_budget': 6,   # maximum depth of the decision tree
    'verbose': True,
}
clf = GOSDTClassifier(**gosdt_options)

# The warm-start predictions are supplied as reference labels via `y_ref`.
clf.fit(X_train_bin, y_train, y_ref=warm_labels)

# 4. Evaluate GOSDT
#    Predict on the binarized train and test splits, then score each set of
#    predictions against the true labels of the same split.
train_preds, test_preds = clf.predict(X_train_bin), clf.predict(X_test_bin)

train_acc, test_acc = (
    accuracy_score(y_train, train_preds),
    accuracy_score(y_test, test_preds),
)

# Report both accuracies to four decimal places.
print(f"GOSDT Training Accuracy: {train_acc:.4f}")
print(f"GOSDT Test Accuracy:     {test_acc:.4f}")

# 5. (Optional) Inspect the learned rules/tree
#    Print one extracted rule per line, but only when the fitted classifier
#    actually exposes a `rules_` attribute; otherwise this step is skipped
#    silently.
#    NOTE(review): it is not confirmed that GOSDTClassifier defines `rules_` --
#    check the installed gosdt version for the attribute that holds the
#    fitted tree.
if hasattr(clf, 'rules_'):
    print("\nExtracted Rules:")
    for rule in clf.rules_:
        print(rule)


ImportError: cannot import name 'check_X_y' from 'sklearn.base' (/Users/nicolaszhang/miniconda3/envs/data_science_env/lib/python3.12/site-packages/sklearn/base.py)