In [1]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import joblib

# Loading preprocessed data
X_train = pd.read_csv('X_train.csv')
X_test = pd.read_csv('X_test.csv')
y_train = pd.read_csv('y_train.csv').values.ravel()
y_test = pd.read_csv('y_test.csv').values.ravel()

# Defining and tuning Random Forest
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 10, 20],
    'min_samples_split': [2, 5]
}
rf = RandomForestClassifier(random_state=42)
grid_search = GridSearchCV(rf, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
grid_search.fit(X_train, y_train)

# Best model
best_rf = grid_search.best_estimator_
print("Best Random Forest Parameters:", grid_search.best_params_)
print("Best Random Forest CV Accuracy:", grid_search.best_score_)

# Saving the model
joblib.dump(best_rf, 'random_forest_model.pkl')

# Making predictions
y_pred = best_rf.predict(X_test)

# Calculating metrics
metrics = {
    'Accuracy': accuracy_score(y_test, y_pred),
    'Precision': precision_score(y_test, y_pred, average='weighted'),
    'Recall': recall_score(y_test, y_pred, average='weighted'),
    'F1-Score': f1_score(y_test, y_pred, average='weighted')
}
print("Random Forest Metrics:\n", metrics)

# Saving metrics
with open('random_forest_metrics.txt', 'w') as f:
    for metric, value in metrics.items():
        f.write(f"{metric}: {value:.4f}\n")

# Plotting confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Random Forest Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.savefig('random_forest_cm.png')
plt.close()

# Plotting ROC curve for multiclass
y_test_bin = label_binarize(y_test, classes=[0, 1, 2])
y_score = best_rf.predict_proba(X_test)
fpr, tpr, roc_auc = {}, {}, {}
for i in range(3):
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

plt.figure(figsize=(8, 6))
for i in range(3):
    plt.plot(fpr[i], tpr[i], label=f'ROC curve (class {i}, AUC = {roc_auc[i]:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Random Forest ROC Curve')
plt.legend(loc='lower right')
plt.savefig('random_forest_roc.png')
plt.close()

# Saving ROC AUC scores
with open('random_forest_roc_auc.txt', 'w') as f:
    for i in range(3):
        f.write(f"Class {i} AUC: {roc_auc[i]:.4f}\n")

# Feature importance
feature_importance = pd.DataFrame({
    'Feature': X_train.columns,
    'Importance': best_rf.feature_importances_
}).sort_values(by='Importance', ascending=False)
print("Random Forest Feature Importance:\n", feature_importance)
feature_importance.to_csv('feature_importance.csv', index=False)

# Plotting feature importance
plt.figure(figsize=(10, 6))
sns.barplot(x='Importance', y='Feature', data=feature_importance)
plt.title('Random Forest Feature Importance')
plt.savefig('random_forest_feature_importance.png')
plt.close()

print("Random Forest training and evaluation completed.")

Best Random Forest Parameters: {'max_depth': 10, 'min_samples_split': 2, 'n_estimators': 100}
Best Random Forest CV Accuracy: 0.5959987985767756


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Random Forest Metrics:
 {'Accuracy': 0.5950877192982457, 'Precision': 0.3554635006349324, 'Recall': 0.5950877192982457, 'F1-Score': 0.4450722053704433}
Random Forest Feature Importance:
                            Feature  Importance
17                 birth_weight_kg    0.148768
9                      apgar_score    0.119267
0                              age    0.113534
13  gestational_age_at_first_visit    0.094361
5       number_of_antenatal_visits    0.079213
8      gestational_age_at_delivery    0.065881
7                       occupation    0.062613
2                        gravidity    0.056245
1                           parity    0.050379
4                  education_level    0.044535
3                   marital_status    0.035815
6           household_income_level    0.032334
12                       residence    0.020003
14          previous_complications    0.018615
11             birth_complications    0.017638
16                has_hypertension    0.016524
15            