In [7]:
import random
import numpy as np
import torch
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from datasets import load_dataset
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import sklearn
from metatree.model_metatree import LlamaForMetaTree as MetaTree
from metatree.decision_tree_class import DecisionTree, DecisionTreeForest
from metatree.run_train import preprocess_dimension_patch
from transformers import AutoConfig
from sklearn.tree import DecisionTreeClassifier

# Load MetaTree model
model_name_or_path = "yzhuang/MetaTree"

config = AutoConfig.from_pretrained(model_name_or_path)
model = MetaTree.from_pretrained(
    model_name_or_path,
    config=config,
)
decision_tree_forest = DecisionTreeForest()

ensemble_size = 5
seed = 42
rng = np.random.default_rng(seed)

# Load datasets
dset_train = load_dataset('rotten_tomatoes', split='train')
dset_train = dset_train.select(rng.choice(len(dset_train), size=100, replace=False))
dset_test = load_dataset('rotten_tomatoes', split='test')
dset_test = dset_test.select(rng.choice(len(dset_test), size=100, replace=False))

# Convert text to TF-IDF features
tfidf_vectorizer = TfidfVectorizer(max_features=1000)
X_tfidf = tfidf_vectorizer.fit_transform(dset_train['text'] + dset_test['text']).toarray()
y = np.concatenate([dset_train['label'], dset_test['label']])

# Split into training and testing sets (70% train, 30% test)
train_X_tfidf, test_X_tfidf, train_y, test_y = train_test_split(X_tfidf, y, test_size=0.3, random_state=seed)

for i in range(ensemble_size):
    # Sample Train Data
    random.seed(seed + i + 1)
    sample_size = min(256, len(train_X_tfidf))
    subset_idx = random.sample(range(len(train_X_tfidf)), sample_size)
    train_X_subset, train_y_subset = train_X_tfidf[subset_idx], train_y[subset_idx]

    input_x = torch.tensor(train_X_subset, dtype=torch.float32)  # Convert numpy array to tensor
    input_y = F.one_hot(torch.tensor(train_y_subset, dtype=torch.long), num_classes=3).float()  # Convert numpy array to tensor and one-hot encode

    batch = {"input_x": input_x, "input_y": input_y, "input_y_clean": input_y}
    batch = preprocess_dimension_patch(batch, n_feature=10, n_class=3)
    model.depth = 3
    outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)
    decision_tree_forest.add_tree(DecisionTree(auto_dims=outputs.metatree_dimensions, auto_thresholds=outputs.tentative_splits, input_x=batch['input_x'], input_y=batch['input_y'], depth=model.depth))

    print("Decision Tree Features: ", [x.argmax(dim=-1) for x in outputs.metatree_dimensions])
    print("Decision Tree Thresholds: ", outputs.tentative_splits)

# Evaluate the MetaTree model
test_X_tensor = torch.tensor(test_X_tfidf, dtype=torch.float32)
tree_pred = decision_tree_forest.predict(test_X_tensor)

# Calculate accuracy
metatree_accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0))
print("MetaTree Test Accuracy: ", metatree_accuracy)

# CART model for comparison
cart_ensemble = []
for i in range(ensemble_size):
    random.seed(seed + i + 1)
    subset_idx = random.sample(range(len(train_X_tfidf)), sample_size)
    train_X_subset, train_y_subset = train_X_tfidf[subset_idx], train_y[subset_idx]

    clf = DecisionTreeClassifier(max_depth=3, random_state=seed + i + 1)
    clf.fit(train_X_subset, train_y_subset)
    cart_ensemble.append(clf)

overall_pred = np.zeros((test_X_tfidf.shape[0], len(set(test_y))))
for clf in cart_ensemble:
    overall_pred += clf.predict_proba(test_X_tfidf)
overall_pred = overall_pred / len(cart_ensemble)

accuracy = accuracy_score(test_y, overall_pred.argmax(axis=-1))
print("CART Test Accuracy: ", accuracy)


Decision Tree Features:  [tensor([8]), tensor([7]), tensor([8]), tensor([0]), tensor([7]), tensor([0]), tensor([8])]
Decision Tree Thresholds:  [tensor([[0.3316]]), tensor([[0.4207]]), tensor([[0.3316]]), tensor([[0.3480]]), tensor([[0.4207]]), tensor([[0.1740]]), tensor([[0.3316]])]
Decision Tree Features:  [tensor([8]), tensor([7]), tensor([8]), tensor([0]), tensor([7]), tensor([0]), tensor([8])]
Decision Tree Thresholds:  [tensor([[0.3316]]), tensor([[0.4207]]), tensor([[0.3316]]), tensor([[0.3480]]), tensor([[0.4207]]), tensor([[0.1740]]), tensor([[0.3316]])]
Decision Tree Features:  [tensor([8]), tensor([7]), tensor([8]), tensor([0]), tensor([7]), tensor([0]), tensor([8])]
Decision Tree Thresholds:  [tensor([[0.3316]]), tensor([[0.4207]]), tensor([[0.3316]]), tensor([[0.3480]]), tensor([[0.4207]]), tensor([[0.1740]]), tensor([[0.3316]])]
Decision Tree Features:  [tensor([8]), tensor([7]), tensor([8]), tensor([0]), tensor([7]), tensor([0]), tensor([8])]
Decision Tree Thresholds:  [

In [8]:
# Predict on test data
test_X_tensor = torch.tensor(test_X_tfidf, dtype=torch.float32)  # Convert test features to tensor
tree_pred = decision_tree_forest.predict(test_X_tensor)

# Compute accuracy
accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0).numpy())
print("MetaTree Test Accuracy: ", accuracy)

MetaTree Test Accuracy:  0.4666666666666667
