Name: Ryan Young

Date: 2024-05-15

# Molecular Scent Analysis
Hi!

Welcome to this exploratory analysis notebook, where we aim to predict whether a molecule might smell like a flower.

The goal of this notebook is to demonstrate one approach to molecular scent prediction, with a focus on exploratory data analysis, model evaluation, and visualization of results.

Towards the end, we will also explore some problems related to message passing and graph neural networks.

# Dependencies
The cells below import, and optionally re-clone, a private GitHub repo called `gin`, where I developed the code for this analysis.

Even with notebooks, I tend to modularize pieces into repos for
- reproducibility
- CI/CD
- testing

A "fine-grained" access token, scoped to this single repository, grants permission to pull the private repo.

In [None]:
# Imports and argparse
import argparse
import datetime
import importlib.util  # `import importlib` alone does not guarantee `importlib.util` is loaded
import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from imblearn.over_sampling import SMOTE
from sklearn import ensemble as sklearn_ensemble
from sklearn.model_selection import train_test_split

import gin

parser = argparse.ArgumentParser(
    description='Predict the presence of a specific odor descriptor.')
parser.add_argument('--archive',
                    type=str,
                    default='leffingwell',
                    help='Name of the Pyrfume data archive to use.')
parser.add_argument('--descriptor',
                    type=str,
                    default='floral',
                    help='The odor descriptor to predict.')
# parse_known_args ignores the extra flags Jupyter passes to the kernel (e.g. `-f <kernel>.json`)
args, _ = parser.parse_known_args()
desc = args.descriptor
args.script = "Pyrfume_RF_GNN_singleOdor.py"

print(" ------  ARGS -------- ")
print(args)
print(" --------------------- ")

# import seaborn as sns
plt.rcParams['figure.dpi'] = 150

# Figure output directory: <repo>/figures/<descriptor>/
figure_dir = os.path.join(os.path.dirname(gin.__file__), '..', 'figures', args.descriptor)
print("Figure directory:", figure_dir)
os.makedirs(figure_dir, exist_ok=True)  # Create the directory if it doesn't exist


def figure_path(name=""):
    """Figure file path; defaults to the current figure's suptitle when no name is given."""
    return os.path.join(figure_dir, f'{name or plt.gcf().get_suptitle()}.png')


def save_fig(name=""):
    """Save the current matplotlib figure into `figure_dir`."""
    plt.savefig(figure_path(name))


df_path = os.path.join(figure_dir, "..", "df.csv")  # WARNING: in the face of more analyses, may have to split this dataframe

# MLflow
This section sets up experiment tracking.

In [None]:
from gin.log.mlflow import start_run, log_params, log_metrics, log_artifacts, end_run

start_run(run_name="Pyrfume_RF_GNN_singleOdor")

# Log the run's parameters; metrics are logged later, once models are evaluated
log_params(args.__dict__)
# Example of logging metrics (score1/score2 are placeholders, not defined in this notebook):
# log_metrics({"metric1": score1, "metric2": score2})

## Clear and refresh `gin` (optional)

Check whether the `gin` package is installed.

In [None]:
import importlib.util

module_spec = importlib.util.find_spec('gin')
module_spec

If `gin` is installed and `refresh` is `True`, delete the local copy and re-clone the package.

In [None]:
refresh = False
if module_spec and refresh:
    gin_path = os.path.dirname(module_spec.origin)  # directory of the installed gin package
    shutil.rmtree(gin_path)
    # NOTE: the access token is fenced off for this specific private repository; it can only clone this single private repo.
    repo_url = 'https://github.com/synapticsage/gin.git'
    os.system(f'git clone {repo_url}')
    os.chdir('gin')
    # !pip install .
    # pip install the package
    # os.chdir('..')

In [None]:
os.chdir(os.path.dirname(os.path.dirname(module_spec.origin)))  # repo root: parent of the gin package directory
os.getcwd()

# Dataset

Here we will use data managed by [the Pyrfume project](https://pyrfume.org/).

The dataset provides [SMILES strings](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) representing the molecular structures, along with their corresponding binary labels.

In [None]:
# Load the data
data_df = gin.data.pyrfume.get_join(args.archive,
                                    types=["behavior", "molecules", "stimuli"])
# Keep the SMILES index and the target descriptor column (double brackets keep a DataFrame)
data_df = data_df.set_index('SMILES')[[desc]]

In [None]:
data_df


Now that we have the data loaded, what should we learn about this dataset?

In [None]:
data_df[desc].isnull().sum()


No missing values - reassuring!

Let's see the distribution of the labels in the dataset.

In [None]:
gin.explore.pyrfume.plot_desc_distribution(data_df, kind='pie', descriptor=desc)
save_fig(f'{desc}_distribution')
log_artifacts({"class_distribution": data_df[desc].value_counts().to_dict()})

👆 The large majority of the dataset is non-floral ❌💐. We should consider **class imbalance** downstream.

In [None]:
# Let's visualize some of the molecular structures in the dataset and see if we can spot any patterns.
gin.explore.pyrfume.plot_molecular_structures_w_label(data_df, num_samples=20, descriptor=desc)
save_fig(f'{desc}_molecular_structures_1')

And let's examine a few more samples.

In [None]:
gin.explore.pyrfume.plot_molecular_structures_w_label(data_df, num_samples=20, descriptor=desc)
save_fig(f'{desc}_molecular_structures_2')

## Hypotheses

Some things of note just from the visualization (hypotheses, or really wild guesses):

- The 🪻 floral molecules nearly all have an oxygen with a free electron pair. A double-bonded oxygen alone seems less often associated with floral molecules.
- Nitrogen-containing molecules are rarely floral, though, playing devil's advocate, I also see too few nitrogen-containing molecules to form an opinion.
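A quick way to put numbers on the nitrogen hypothesis is to compare the floral rate of molecules with and without nitrogen. The sketch below is a crude, string-based check on the SMILES; the regex and the helpers `has_nitrogen` and `rate_by_feature` are hypothetical, not part of `gin`. rdkit substructure matching (`Chem.MolFromSmarts` plus `Mol.HasSubstructMatch`) would be the robust way to test either hypothesis.

```python
import re

import numpy as np

# Crude heuristic (assumption): match aliphatic "N" or aromatic "n" atoms in a SMILES
# string while skipping element symbols such as Na, Ni, Zn or Sn.
_NITROGEN = re.compile(r"N(?![abdeiop])|(?<![MZSIR])n")


def has_nitrogen(smiles: str) -> bool:
    """True if the SMILES string appears to contain a nitrogen atom."""
    return _NITROGEN.search(smiles) is not None


def rate_by_feature(has_feature, labels) -> dict:
    """Compare the positive-label rate for molecules with and without a feature."""
    has_feature = np.asarray(has_feature, dtype=bool)
    labels = np.asarray(labels, dtype=bool)

    def rate(mask):
        # Fraction of positives within the mask; NaN when the group is empty
        return float(labels[mask].mean()) if mask.any() else float("nan")

    return {
        "n_with": int(has_feature.sum()),
        "rate_with": rate(has_feature),
        "n_without": int((~has_feature).sum()),
        "rate_without": rate(~has_feature),
    }
```

In this notebook it could be called as `rate_by_feature([has_nitrogen(s) for s in data_df.index], data_df[desc])`; `n_with` shows how much to trust `rate_with` when nitrogen-containing molecules are scarce.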

# Molecule Featurization

In the next step, we will try to "digitize" each molecule by creating a 1D numpy array based on its molecular structure. Can you create a molecular fingerprint with `rdkit` ([documentation](https://www.rdkit.org/docs/GettingStartedInPython.html#fingerprinting-and-molecular-similarity))?
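One way to sanity-check a fingerprint is the similarity measure from the linked rdkit page: similar molecules should share many set bits. A minimal NumPy sketch of Tanimoto similarity, assuming the fingerprints are 0/1 vectors (the `combined` output may not be purely binary, so treat this as an assumption); `tanimoto` is a hypothetical helper, and rdkit offers `DataStructs.TanimotoSimilarity` for its own bit-vector objects.

```python
import numpy as np


def tanimoto(fp_a, fp_b) -> float:
    """Tanimoto (Jaccard) similarity between two binary fingerprint vectors."""
    a = np.asarray(fp_a, dtype=bool)
    b = np.asarray(fp_b, dtype=bool)
    union = np.count_nonzero(a | b)
    if union == 0:
        return 0.0  # convention chosen here (assumption): two all-zero fingerprints score 0.0
    return np.count_nonzero(a & b) / union


# Toy 4-bit fingerprints; real Morgan/MACCS vectors are much longer
print(tanimoto([1, 1, 0, 1], [1, 0, 0, 1]))  # 2 shared set bits out of 3 set in either
```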

In [None]:
def featurize_smiles(smiles_str: str,
                     method: str = 'combined') -> np.ndarray:
  """Convert a molecule SMILES into a 1D feature vector."""
  if method == 'morgan':
    fingerprint = gin.features.get_morgan_fingerprint(smiles_str)
  elif method == 'maccs':
    fingerprint = gin.features.get_maccs_keys_fingerprint(smiles_str)
  elif method == 'combined':
    fingerprint = gin.features.get_combined_fingerprint(smiles_str)
  else:
    raise ValueError(f"Invalid method: {method}")
  return fingerprint

# Test the function
featurize_smiles('CC(C)CC(C)(O)C1CCCS1')

In [None]:
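# Construct the features `x` and labels `y` for the model
from sklearn.preprocessing import OrdinalEncoder

x = np.array([featurize_smiles(v) for v in data_df.index])
# Map each feature column's distinct values to ordinal codes 0..k-1
# (effectively a no-op for 0/1 fingerprint bits)
feature_encoder = OrdinalEncoder()
x = feature_encoder.fit_transform(x)
y = data_df[desc].values  # use the chosen descriptor rather than a hard-coded 'floral'
gin.explore.pyrfume.plot_feature_heatmap(x)
save_fig(f'{desc}_feature_heatmap')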

Having seen the above, we should keep the following in mind:

- Feature scaling: less necessary for tree-based models
- High feature dimensionality (many fingerprint bits): hopefully not an issue
- Class imbalance: a possible issue, but one we can address
  - SMOTE is one option for oversampling the minority class
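SMOTE (the `imblearn` class imported at the top) creates new minority-class points by interpolating between a minority sample and one of its nearest minority-class neighbours. Below is a simplified NumPy sketch of that idea, not imblearn's implementation; `smote_sketch` and its defaults are hypothetical.

```python
import numpy as np


def smote_sketch(X_min, n_synthetic, k=5, seed=0):
    """Create synthetic minority samples by interpolating toward k nearest minority neighbours."""
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)
    k = min(k, len(X_min) - 1)  # cannot have more neighbours than other minority points
    # Pairwise Euclidean distances within the minority class only
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbour
    neighbours = np.argsort(dists, axis=1)[:, :k]
    base = rng.integers(0, len(X_min), size=n_synthetic)           # seed point for each sample
    pick = neighbours[base, rng.integers(0, k, size=n_synthetic)]  # one random neighbour each
    lam = rng.random((n_synthetic, 1))                             # interpolation factor in [0, 1)
    return X_min[base] + lam * (X_min[pick] - X_min[base])
```

Two design notes: resampling should only touch the training split, since synthetic points derived from test samples would leak into evaluation; and with 0/1 fingerprint bits, interpolation yields fractional values, which tree models tolerate (imblearn's `SMOTEN` targets purely categorical features if that matters).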

## Splitting the data
We have to split the data into training and testing sets.

In [None]:
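X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42, stratify=y)  # stratify keeps the class ratio in both splits

# Resample only the training set: resampling before splitting can leak test information into training
smote = SMOTE(random_state=42)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)
# Resampled test set kept for reference only; report scores on the original X_test/y_test,
# because synthetic test points inflate the metrics
X_test_res, y_test_res = smote.fit_resample(X_test, y_test)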

## Train and evaluate a random forest (RF) model

We will use the RF implementation from [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html).

In [None]:
# What hyper-parameter should we use?
best_params = {'bootstrap': False, 
               'max_depth': None, 
               'max_features': 'log2', 
               'min_samples_leaf': 1, 
               'min_samples_split': 5, 
               'n_estimators': 300} # WARNING: tuned on Floral molecules -- may not apply to others

log_params({'rf_' + key:value for key,value in best_params.items()})

model = sklearn_ensemble.RandomForestClassifier(**best_params)
model_res = sklearn_ensemble.RandomForestClassifier(**best_params)

# How do we fit the model and run inference?
rf_y_pred = model.fit(X_train, y_train).predict(X_test)
# Train on the resampled data, but evaluate on the original (unresampled) test set
# ("res2uns": resampled training, unsampled test)
model_res.fit(X_train_res, y_train_res)
rf_y_pred_res2uns = model_res.predict(X_test)

And out of curiosity, let's also try an ensemble - even though for production-level models, this is likely overkill. A tiny performance boost often isn't worth the time and complexity.
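With `voting='hard'`, scikit-learn's `VotingClassifier` predicts the class that most of its estimators predict. A minimal NumPy sketch of that rule, assuming integer class labels 0..k-1 (`hard_vote` is a hypothetical helper); ties go to the lowest class index, which to my understanding matches scikit-learn's documented tie-breaking by ascending class order.

```python
import numpy as np


def hard_vote(predictions) -> np.ndarray:
    """Majority vote; rows are estimators, columns are samples, entries are class labels 0..k-1."""
    predictions = np.asarray(predictions, dtype=int)
    n_classes = int(predictions.max()) + 1
    # votes[c, j] = number of estimators predicting class c for sample j
    votes = np.stack([(predictions == c).sum(axis=0) for c in range(n_classes)])
    return votes.argmax(axis=0)  # argmax returns the lowest class index on ties
```

For a fitted ensemble, the rows could come from `[est.predict(X_test) for est in model_vote.estimators_]`.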

## Scoring / Evaluation 📝

How do we evaluate the model performance? What metrics are relevant here?

This is binary classification - we care about precision, recall, F1, and AUC-ROC.
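For reference, these metrics reduce to simple counts. A minimal sketch (the helper names are hypothetical, not the `gin.validate` API): precision, recall, and F1 come from hard 0/1 predictions, while AUC-ROC needs continuous scores such as `model.predict_proba(X_test)[:, 1]`, and equals the probability that a random positive is scored above a random negative.

```python
import numpy as np


def binary_metrics(y_true, y_pred) -> dict:
    """Precision, recall and F1 for the positive class from hard 0/1 predictions."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = np.count_nonzero(y_true & y_pred)
    fp = np.count_nonzero(~y_true & y_pred)
    fn = np.count_nonzero(y_true & ~y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


def roc_auc_by_ranking(y_true, scores) -> float:
    """AUC-ROC as P(score of a random positive > score of a random negative); ties count 1/2."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    wins = np.count_nonzero(pos[:, None] > neg[None, :])
    ties = np.count_nonzero(pos[:, None] == neg[None, :])
    return (wins + 0.5 * ties) / (len(pos) * len(neg))
```

The pairwise comparison uses memory proportional to positives times negatives, which is fine for a small test set; `sklearn.metrics.roc_auc_score` gives the same value.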

In [None]:
# What sort of visualization is needed here?
print("----------------")
print("Random Forest")
print("----------------")
suptitle = 'Random Forest'
gin.validate.evaluate_model(y_test, rf_y_pred)
gin.validate.plot_confusion_matrix(y_test, rf_y_pred, suptitle=suptitle)
save_fig(f'{desc}_confusion_matrix_rf')
log_metrics(gin.validate.get_metrics(y_test, rf_y_pred))

In [None]:
print("----------------")
print("Random Forest - Resampled")
print("----------------")
suptitle = 'Random Forest - Resampled'
gin.validate.evaluate_model(y_test, rf_y_pred_res2uns)
gin.validate.plot_confusion_matrix(y_test, rf_y_pred_res2uns, suptitle=suptitle)
save_fig(f'{desc}_confusion_matrix_rf_resampled')

In [None]:
import time
from sklearn.linear_model import LogisticRegression

start_time = time.time()
clf1 = LogisticRegression(max_iter=1000)
clf2 = sklearn_ensemble.RandomForestClassifier(**best_params)
clf3 = sklearn_ensemble.GradientBoostingClassifier()

# VotingClassifier with hard voting (each ensemble fits its own clones of clf1, clf2, clf3)
model_vote = sklearn_ensemble.VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('gb', clf3)], voting='hard')
model_vote_res = sklearn_ensemble.VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('gb', clf3)], voting='hard')

# Fit and predict
print("Fitting vote model")
model_vote.fit(X_train, y_train)
eclf_y_pred = model_vote.predict(X_test)

print("Fitting the res vote model")
model_vote_res.fit(X_train_res, y_train_res)
# Evaluate on the original (unresampled) test set, not on synthetic test samples
eclf_y_pred_res2uns = model_vote_res.predict(X_test)

print("Time taken: ", time.time() - start_time)

In [None]:
print("----------------")
print("Ensemble")
print("----------------")
suptitle = "Ensemble"
gin.validate.evaluate_model(y_test, eclf_y_pred)
gin.validate.plot_confusion_matrix(y_test, eclf_y_pred, suptitle=suptitle)
save_fig(f'{desc}_confusion_matrix_ensemble')

In [None]:
print("----------------")
print("Ensemble - Resampled")
print("----------------")
suptitle = "Ensemble - Resampled"
gin.validate.evaluate_model(y_test, eclf_y_pred_res2uns)
gin.validate.plot_confusion_matrix(y_test, eclf_y_pred_res2uns, suptitle=suptitle)
save_fig(f'{desc}_confusion_matrix_ensemble_resampled')

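By default, a random forest predicts the class with the highest averaged tree probability, which for binary labels amounts to a 0.5 threshold, but that may not be ideal. The threshold sets the trade-off between type I (false positive) and type II (false negative) errors, and the right balance depends on the situation.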

So let's examine how everything changes as a function of threshold.

> Note: for hyperparameter tuning, we usually want separate train, validation, and test sets. Since the SMOTE model above already works well, I'm going to forgo a validation set for this exercise 😈, with the caveat that a threshold picked on the test set gives optimistic scores.
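What a threshold sweep computes can be sketched in a few lines: binarize the positive-class scores at each threshold and recount precision and recall. This is a hypothetical helper, assuming `score >= t` counts as positive; `gin.validate.evaluate_thresholds` may use a different comparison or report other metrics.

```python
import numpy as np


def threshold_sweep(y_true, scores, thresholds) -> np.ndarray:
    """Precision and recall of the positive class at each decision threshold (score >= t)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    rows = []
    for t in thresholds:
        pred = scores >= t
        tp = np.count_nonzero(pred & y_true)
        fp = np.count_nonzero(pred & ~y_true)
        fn = np.count_nonzero(~pred & y_true)
        precision = tp / (tp + fp) if tp + fp else 0.0  # 0.0 when nothing is predicted positive
        recall = tp / (tp + fn) if tp + fn else 0.0
        rows.append((t, precision, recall))
    return np.array(rows)  # columns: threshold, precision, recall
```

With the random forest, the scores would be `model_res.predict_proba(X_test)[:, 1]`; raising the threshold typically trades recall for precision.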

In [None]:
thresholds = np.arange(0, 1, 0.01)
results_df_res = gin.validate.evaluate_thresholds(model_res, X_test,
                                                  y_test, thresholds)
results_df = gin.validate.evaluate_thresholds(model, X_test, y_test,
                                              thresholds)
gin.validate.plot_threshold_results(results_df_res, model_name='Random Forest', suptitle='Random Forest - Resampled')
save_fig(f'{desc}_threshold_results_rf_resampled')
gin.validate.plot_threshold_results(results_df, model_name='Random Forest', suptitle='Random Forest')
save_fig(f'{desc}_threshold_results_rf')

## Conclusion

Class imbalance correction creates a modest improvement.

The correct threshold depends on what we're optimizing for: do we want to balance precision and recall for floral molecules, or for non-floral ones?

Generally, we should pick a threshold somewhere in the goldilocks zone (shown in gray above).

# Multi-layer Perceptron

Now that we have tried modeling with an RF, let's train a simple neural network: the [multilayer perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron).

This exercise aims to see whether we can train a PyTorch neural network from end to end, so a simple sanity check is adequate, and a thorough evaluation is **not** required.

## Build the MLP module and model API
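`gin.model.MLP` is not shown in this notebook, so its exact architecture is unknown. The NumPy sketch below assumes one ReLU hidden layer (a hypothetical `hidden_dim=64`) followed by a sigmoid output, i.e. the kind of "final sigmoid output" thresholded later; in PyTorch the same stack would be `nn.Sequential(nn.Linear(d, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())`.

```python
import numpy as np


def init_params(input_dim, hidden_dim=64, seed=0):
    """Randomly initialised weights for a one-hidden-layer MLP (hypothetical sizes)."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0, np.sqrt(2 / input_dim), (input_dim, hidden_dim)),  # He init for ReLU
        "b1": np.zeros(hidden_dim),
        "W2": rng.normal(0, np.sqrt(1 / hidden_dim), (hidden_dim, 1)),
        "b2": np.zeros(1),
    }


def mlp_forward(x, params):
    """Forward pass: ReLU hidden layer, then a sigmoid giving P(positive class)."""
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])  # hidden activations
    logits = h @ params["W2"] + params["b2"]             # one logit per sample
    return 1.0 / (1.0 + np.exp(-logits))                 # sigmoid squashes logits into (0, 1)
```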

## Set up a simple data loader and train the model
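A data loader's job is to serve shuffled mini-batches so every sample is seen once per epoch. A minimal NumPy sketch of that behaviour (`iterate_minibatches` is hypothetical; whatever loader `MLP.fit` uses internally is not shown). In PyTorch the usual equivalent is `torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X_t, y_t), batch_size=32, shuffle=True)`.

```python
import numpy as np


def iterate_minibatches(X, y, batch_size=32, shuffle=True, seed=0):
    """Yield (X_batch, y_batch) pairs covering every sample exactly once per epoch."""
    X = np.asarray(X)
    y = np.asarray(y)
    idx = np.arange(len(X))
    if shuffle:
        np.random.default_rng(seed).shuffle(idx)  # new order each time a different seed is used
    for start in range(0, len(X), batch_size):
        batch = idx[start:start + batch_size]  # last batch may be smaller
        yield X[batch], y[batch]
```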

In [None]:
from gin.model import MLP
model = MLP(input_dim=X_train.shape[1])
model.fit(X_train, y_train)

model_res = MLP(input_dim=X_train_res.shape[1])
model_res.fit(X_train_res, y_train_res)

In [None]:
# Sanity check — how do we know the model has learned from the data?
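mlp_y_pred = model.predict(X_test)
# Resampled-training model scored on the original (unresampled) test set
mlp_y_pred_res = model_res.predict(X_test)
mlp_y_pred_res2uns = mlp_y_pred_res  # same predictions; alias matches the RF/ensemble naming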

fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].hist([mlp_y_pred, y_test], bins=20, label=['MLP', 'True'], color=['gray', 'red'])
axs[1].hist([mlp_y_pred_res2uns, y_test], bins=20, label=['MLP_res', 'True'], color=['black', 'red'])
axs[0].set_title('MLP')
axs[1].set_title('MLP - Resampled')
axs[0].legend()
axs[1].legend()
plt.show()

In [None]:
# As before, let's just explore the default threshold > 0.5
print("----------------")
print("MLP")
print("----------------")
suptitle = 'MLP'
gin.validate.evaluate_model(y_test, mlp_y_pred>0.5)
gin.validate.plot_confusion_matrix(y_test, mlp_y_pred>0.5, suptitle=suptitle)
save_fig(f'{desc}_confusion_matrix_mlp')

print("----------------")
print("MLP - Resampled")
print("----------------")
suptitle = 'MLP - Resampled'
gin.validate.evaluate_model(y_test, mlp_y_pred_res2uns > 0.5)
gin.validate.plot_confusion_matrix(y_test, mlp_y_pred_res2uns>0.5, suptitle=suptitle)
save_fig(f'{desc}_confusion_matrix_mlp_resampled')

Let's also sweep the decision threshold applied to the MLP's final sigmoid output.

In [None]:
results_df_mlp = gin.validate.evaluate_thresholds(model,
                                                  X_test,
                                                  y_test,
                                                  thresholds,
                                                  y_proba=mlp_y_pred)
results_df_mlp_res = gin.validate.evaluate_thresholds(model_res,
                                                      X_test,
                                                      y_test,
                                                      thresholds,
                                                      y_proba=mlp_y_pred_res)
gin.validate.plot_threshold_results(results_df_mlp, model_name='MLP', suptitle='MLP')
save_fig(f'{desc}_threshold_results_mlp')
gin.validate.plot_threshold_results(results_df_mlp_res, model_name='MLP - Resampled', suptitle='MLP - Resampled')
save_fig(f'{desc}_threshold_results_mlp_resampled')

## Conclusions

The `MLP` achieves performance very similar to the simpler random forest above.

Notably, resampling does **not** improve the `MLP`, unlike the `RandomForest` above. In practice, we could try other rebalancing methods and data augmentation techniques, given the sparse samples.

# Graph Neural Network, Bonus - Naive (👶) 

For fun, let's dovetail this section with a very naive message-passing GNN approach 🤖
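The core of message passing is that each atom updates its features from its own state plus an aggregate of its bonded neighbours, and a pooling step turns atom features into one molecule vector. A minimal NumPy sketch of one sum-aggregation layer, assuming a PyTorch Geometric-style `edge_index` of shape `(2, E)` (row 0 = source, row 1 = target); the layer in `gin.extra.gnn` is not shown and also takes `edge_attr`, which this sketch ignores.

```python
import numpy as np


def message_passing_layer(h, edge_index, W_self, W_neigh):
    """h_i' = ReLU(h_i W_self + sum over edges j->i of h_j W_neigh)."""
    src, dst = edge_index
    agg = np.zeros_like(h)
    np.add.at(agg, dst, h[src])  # sum each source node's features into its target node
    return np.maximum(0.0, h @ W_self + agg @ W_neigh)


def mean_pool(h):
    """Graph-level readout: average the node embeddings into one vector."""
    return h.mean(axis=0)


# Toy C-C-O chain with one-hot features [carbon, oxygen]; bonds listed in both directions
h = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
out = message_passing_layer(h, edge_index, np.eye(2), np.eye(2))
print(out, mean_pool(out))
```

Stacking several such layers lets information travel several bonds away, which is what a (less naive) GNN exploits.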

*NOTES OF INTEREST*  📝
- We are doing this without resampling or data augmentation, so we _may not match_ the performance above.
- Instead, we use a simpler class-imbalance reweighting of the cross-entropy objective.
- We may not have enough samples to utilize the capacity of a bigger model.
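Class-imbalance reweighting can be as simple as scaling the positive-class term of the binary cross-entropy by the negative-to-positive ratio. A NumPy sketch of that idea, with hypothetical helper names; it mirrors the `pos_weight` semantics of PyTorch's `torch.nn.BCEWithLogitsLoss` (which works on logits), but whether `train_gnn_model` weights its loss exactly this way is an assumption.

```python
import numpy as np


def pos_weight_from_labels(y) -> float:
    """Ratio of negatives to positives, used to up-weight the rare positive class."""
    y = np.asarray(y, dtype=float)
    n_pos = y.sum()
    return float((len(y) - n_pos) / n_pos)


def weighted_bce(y, p, pos_weight, eps=1e-12) -> float:
    """Mean binary cross-entropy with the positive-class term scaled by pos_weight."""
    y = np.asarray(y, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)  # avoid log(0)
    losses = -(pos_weight * y * np.log(p) + (1 - y) * np.log(1 - p))
    return float(losses.mean())
```

With one positive among four labels, `pos_weight` is 3, so a missed positive costs as much as three missed negatives.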

In [None]:
# Convert the SMILES strings to graph data and split into train/test sets
from sklearn.model_selection import train_test_split
import torch
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
import gin.extra.features
from gin.extra.features import smiles_to_graph
from gin.extra.gnn import train_gnn_model

# Convert the SMILES strings to graph data
data_list = []
for smile_string, label in zip(data_df.index, data_df[desc]):
    data = smiles_to_graph(smile_string)
    if data is not None:
        data.y = torch.tensor([label], dtype=torch.float)  # Assign target value
        data_list.append(data)

# Fail early, before normalization, if no SMILES string could be converted
if len(data_list) == 0:
    raise ValueError("No valid graph data could be generated from the provided SMILES strings.")

data_list = gin.extra.features.normalize_data_list(data_list)  # Normalize features

# Split the data into training and testing sets
train_data, test_data = train_test_split(data_list, test_size=0.2, random_state=42)

Train

In [None]:
# Train the GNN on the training graphs (no SMOTE; class imbalance is handled by loss reweighting)
model = train_gnn_model(train_data, num_epochs=250)

And now, let's run inference.

In [None]:
model.eval()
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        preds = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        all_preds.extend(preds.numpy().flatten())
        all_labels.extend(batch.y.numpy().flatten())

import matplotlib.pyplot as plt

plt.hist(all_preds, bins=20, alpha=0.75, label='Predictions')
plt.hist(all_labels, bins=20, alpha=0.75, label='True Labels')
plt.legend()
plt.title('Distribution of Predictions and True Labels')
plt.show()

In [None]:
# Evaluate performance
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
gin.validate.evaluate_model(all_labels, all_preds > 0.5)
gin.validate.plot_confusion_matrix(all_labels, all_preds > 0.5, suptitle='GNN Model')
save_fig(f'{desc}_confusion_matrix_gnn')

In [None]:
from gin.extra.validate import evaluate_thresholds_gnn
thresholds = np.arange(0.0,1.0,0.01)
results = evaluate_thresholds_gnn(model, test_data, thresholds)
results

In [None]:
gin.validate.plot_threshold_results(results, model_name="GNN")
save_fig(f'{desc}_threshold_results_gnn')

## Conclusions

The GNN model, while an interesting exercise, does not perform as well as the simpler models. This is common for neural networks trained on small datasets.

<h4> Better yet, pull a model from Hugging Face 🤗 that has been pre-trained on other molecules, to leverage what it learned from other data.</h4>

# The End