## Random Forest Embeddings

random_forest_embed.ipynb

This notebook examines how well MPNet vector embeddings predict
depression symptom severity (PHQ-8 scores) using Random Forest regression. It includes
preprocessing steps, dimensionality reduction, and hyperparameter tuning.
It was developed as part of a study on multilingual lexical markers and depression severity.


**Usage**:
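- Place your input CSV file in the desired directory and set its path in the data-loading cell.
- Make sure the file includes columns whose names start with 'Mpnet_' containing the MPNet embeddings (this is the prefix the code filters on).
- The code also reads the columns 'Age', 'Education_Years', 'Gender', 'Recording_Date', 'PHQ8' and 'participant_ID'.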

**Author**: Anastasiia Tokareva


### Models tested:
1. Full embeddings
2. Embeddings + TSVD dimensionality reduction

In [1]:
## Load libraries
import pandas as pd
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler, FunctionTransformer
from sklearn.compose import ColumnTransformer

from sklearn.ensemble import RandomForestRegressor

from sklearn.metrics import make_scorer, mean_squared_error, r2_score, mean_absolute_error

from sklearn.model_selection import GroupKFold, cross_validate, GridSearchCV 
from sklearn.model_selection import StratifiedKFold

from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.decomposition import PCA
from sklearn.decomposition import TruncatedSVD


In [None]:
## 1. Clean your data
# NOTE: placeholder path -- replace it with the location of your own CSV file.
data = pd.read_csv("C:/Users/your/file/name/here.csv")

## 2. Binarise COVID data
# define COVID lockdown start and end dates (dates based on Leightley et al. (2021), https://pubmed.ncbi.nlm.nih.gov/34488697/)
covid_start = pd.to_datetime('2020-03-23')
covid_end = pd.to_datetime('2021-05-11')

# Build the cleaned frame in a single chain instead of assigning columns onto the
# result of dropna(): assigning onto that result can raise SettingWithCopyWarning,
# and the chain leaves `data` untouched so the cell can be re-run safely.
data_cleaned = (
    data
    .dropna(axis=0)  # drop every row that has a missing value in any column
    .assign(Recording_Date=lambda df: pd.to_datetime(df['Recording_Date']))
    .assign(
        # 1 when the recording falls inside the lockdown window (both ends inclusive), else 0
        COVID=lambda df: (
            (df['Recording_Date'] >= covid_start) & (df['Recording_Date'] <= covid_end)
        ).astype(int)
    )
)

# COVID now added as the last column (0/1)
data_cleaned.head(n=5)

#### 1. Full Embeddings

In [5]:
# Define MPNet columns
mpnet = [col for col in data_cleaned.columns if col.startswith('Mpnet_')]  # extract MPNet column names

# Fail early if the prefix does not match the input file: with an empty list the
# pipeline would still run, but only on the demographic variables.
if not mpnet:
    raise ValueError("No columns starting with 'Mpnet_' were found in data_cleaned.")

# Define column transformer:
#   - 'num_scaler' standardises the numerical features and the embeddings
#   - 'num_raw' forwards the dummy variables unchanged; the built-in 'passthrough'
#     replaces an identity FunctionTransformer(lambda ...), which cannot be pickled
preprocessor = ColumnTransformer([
    ('num_scaler', StandardScaler(), ['Age', 'Education_Years'] + mpnet),
    ('num_raw', 'passthrough', ['Gender', 'COVID'])
])

# Define the pipeline; random_state is fixed so repeated runs give the same forest
pipeline_1 = Pipeline([
    ('preprocessing', preprocessor),
    ('regressor', RandomForestRegressor(random_state=42))
])


#### Grid search parameters

In [6]:

# Hyperparameter grid for the grid search; the 'regressor__' prefix addresses
# the pipeline step named 'regressor'.
param_grid = dict(
    regressor__n_estimators=[50, 100, 200, 500],
    regressor__max_depth=[None, 20, 50, 70],
    regressor__min_samples_split=[2, 5, 10],
)


#### 2. Embeddings + TSVD dimensionality reduction

In [7]:
# Define column transformer:
#   - 'num_scaler' standardises the numerical features and the embeddings
#   - 'num_raw' forwards the dummy variables unchanged; the built-in 'passthrough'
#     replaces an identity FunctionTransformer(lambda ...), which cannot be pickled
preprocessor = ColumnTransformer([
    ('num_scaler', StandardScaler(), ['Age', 'Education_Years'] + mpnet),
    ('num_raw', 'passthrough', ['Gender', 'COVID'])
])

# Define the pipeline.
# NOTE: the reduction step keeps the name 'pca' so that any existing references to
# it keep working, but the transformer it runs is TruncatedSVD.
# random_state is fixed on both stochastic steps so repeated runs are reproducible.
pipeline_2 = Pipeline([
    ('preprocessing', preprocessor),
    ('pca', TruncatedSVD(n_components=100, random_state=42)),
    ('regressor', RandomForestRegressor(random_state=42))
])


### Custom RMSE

In [8]:
def rmse(y_true, y_pred):
    """Return the root mean squared error between `y_true` and `y_pred`."""
    mse = mean_squared_error(y_true, y_pred)
    return np.sqrt(mse)

# Metrics to report, keyed by the name that will appear in the results.
# make_scorer is left at its default (greater_is_better=True), so RMSE and MAE
# are returned as plain positive errors. These scorers are meant for reporting;
# an error metric used to rank models would need greater_is_better=False.
scorers = dict(
    rmse=make_scorer(rmse),
    r2=make_scorer(r2_score),
    mae=make_scorer(mean_absolute_error),
)

### Set up nested CV

In [9]:
# Predictors: the demographic variables and the COVID flag, followed by every embedding column
X = data_cleaned.loc[:, ['Age', 'Education_Years', 'Gender', 'COVID', *mpnet]]

# Target, selected as a single column (Series) rather than a one-column DataFrame
y = data_cleaned['PHQ8']

# Participant identifier, kept separately so it can be supplied as the CV grouping variable
groups = data_cleaned['participant_ID']

In [10]:
# Cross-validation splitters for the nested scheme, both 5-fold GroupKFold:
# one for the outer evaluation loop, one for the inner hyperparameter search.
outer_cv = GroupKFold(n_splits=5)
inner_cv = GroupKFold(n_splits=5)

#### 1. Full embeddings

In [None]:
# Inner loop: grid search over param_grid, ranked on (negated) RMSE.
# `scoring` has to be set explicitly: when it is left unset, GridSearchCV ranks
# candidates with the estimator's default score (R² for a regressor) and a string
# passed as `refit` is not used as a metric name.
Inner_Grid = GridSearchCV(pipeline_1,
                          param_grid,
                          scoring='neg_root_mean_squared_error',
                          verbose=1,
                          cv=inner_cv,
                          refit=True,               # refit the best candidate on the whole outer-training fold
                          return_train_score=True
                         )

# Outer loop: evaluate the tuned model on held-out groups
nested_results = cross_validate(Inner_Grid, X, y,
                                cv=outer_cv,
                                groups=groups,              # groups for the outer split
                                params={'groups': groups},  # pass group information to inner split
                                scoring=scorers,
                                return_train_score=True)    # optionally return train scores

# NOTE: the 'train_*' entries are the refitted model's scores on the outer
# training folds; 'test_*' entries are its scores on the outer held-out folds.
print(f"Average Inner RMSE: {np.mean(nested_results['train_rmse']):.2f}")
print(f"Average Inner R²: {np.mean(nested_results['train_r2']):.2f}")
print(f"Average Outer RMSE: {np.mean(nested_results['test_rmse']):.2f}")
print(f"Average Outer R²: {np.mean(nested_results['test_r2']):.2f}")

#### 2. Embeddings + TSVD dimensionality reduction

In [None]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      | # Inner Loop
Inner_Grid = GridSearchCV(pipeline_2,
                          param_grid,
                          verbose = 1,
                          cv=inner_cv,
                          refit='rmse',
                          return_train_score=True  
                         )

# Outer Loop
nested_results = cross_validate(Inner_Grid, X, y, 
                                cv=outer_cv,
                                groups=groups,
                                params={'groups': groups},  # pass group information to inner split 
                                scoring=scorers,
                                return_train_score=True)    # optionally return train scores

print(f"Average Inner RMSE: {np.mean(nested_results['train_rmse']):.2f}")
print(f"Average Inner R²: {np.mean(nested_results['train_r2']):.2f}")
print(f"Average Outer RMSE: {np.mean(nested_results['test_rmse']):.2f}")
print(f"Average Outer R²: {np.mean(nested_results['test_r2']):.2f}")