In [None]:
%pip install torch scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


In [None]:
class DecisionTree(nn.Module):
    """One-hidden-layer MLP used as a member of the ensemble below.

    Despite the name, this is not a decision tree: it computes
    ``fc2(relu(fc1(x)))`` and returns raw, unnormalized class logits
    (no softmax is applied).

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # fc1 is created before fc2 so parameter initialization draws from the
        # RNG in a fixed order.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` to class logits."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class RandomForest(nn.Module):
    """Ensemble of ``DecisionTree`` MLPs whose logits are averaged.

    Despite the name, this is not a random forest. Every member receives the
    same full input (there is no bootstrap sampling and no feature
    subsampling). Because the model's output is a single average, optimizing
    a loss on it updates all members jointly.

    Args:
        input_dim: Number of input features passed to each member.
        hidden_dim: Hidden-layer width of each member.
        num_classes: Number of output logits per member.
        num_trees: Number of ensemble members; must be at least 1.

    Raises:
        ValueError: If ``num_trees`` is less than 1.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_trees):
        super(RandomForest, self).__init__()
        # Fail fast: with zero members, forward() would call torch.stack on an
        # empty list and raise a less helpful RuntimeError much later.
        if num_trees < 1:
            raise ValueError(f"num_trees must be at least 1, got {num_trees}")
        self.trees = nn.ModuleList(
            [DecisionTree(input_dim, hidden_dim, num_classes) for _ in range(num_trees)]
        )

    def forward(self, x):
        """Return the element-wise mean of all members' logits for ``x``."""
        tree_outputs = [tree(x) for tree in self.trees]
        return torch.mean(torch.stack(tree_outputs), dim=0)


In [None]:
# Generate a synthetic binary-classification dataset (1000 samples, 20 features).
# The same random_state drives both generation and the split, so runs are reproducible.
RANDOM_STATE = 14
X_np, y_np = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=RANDOM_STATE)

# float32 features for nn.Linear, int64 labels for nn.CrossEntropyLoss.
# torch.as_tensor with an explicit dtype replaces the legacy FloatTensor/LongTensor constructors.
X = torch.as_tensor(X_np, dtype=torch.float32)
y = torch.as_tensor(y_np, dtype=torch.long)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_STATE)


In [None]:
# Hyperparameters. input_dim and num_classes are read from the data so they
# cannot drift out of sync with the dataset cell.
input_dim = X.shape[1]
# CrossEntropyLoss requires targets in [0, num_classes); assumes labels are 0..K-1.
num_classes = int(y.max().item()) + 1
hidden_dim = 100
num_trees = 10
lr = 0.01
epochs = 200

# Seed torch so weight initialization is reproducible: re-running this cell
# rebuilds an identical untrained model.
torch.manual_seed(14)

# initialize the model
model = RandomForest(input_dim, hidden_dim, num_classes, num_trees)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


In [None]:
# Training loop: full-batch updates on the entire training set each epoch.
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = loss_fn(outputs, y_train)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# Evaluate on the held-out test set.
model.eval()

with torch.no_grad():
    y_pred = model(X_test).argmax(dim=1)
    # Pass sklearn plain NumPy arrays rather than relying on implicit tensor conversion.
    accuracy = accuracy_score(y_test.numpy(), y_pred.numpy())
    print(f'Test Accuracy: {accuracy:.4f}')

# Feature-importance heuristic: for each input feature, sum the absolute
# first-layer weights attached to it, then average over the ensemble members.
# This is a crude proxy, not an impurity- or permutation-based importance.
# Runs under no_grad because it only reads weights; averaging over model.trees
# (rather than the num_trees variable) keeps it tied to the actual model.
with torch.no_grad():
    feature_importance = torch.stack(
        [tree.fc1.weight.abs().sum(dim=0) for tree in model.trees]
    ).mean(dim=0)


In [None]:
# Print the highest-scoring features under the heuristic above.
top_k = 5
# Clamp so the heading never claims more features than exist.
k = min(top_k, feature_importance.numel())
top_features = torch.argsort(feature_importance, descending=True)[:k]

print(f"Top {k} most important features:")
for idx in top_features.tolist():
    print(f"Feature {idx}: {feature_importance[idx].item():.4f}")