![Choose a Supervised Learning Algorithm banner](./images/4_choose_a_supervised_learning_algorithm.png)

# 4. Choose an algorithm

Select an appropriate supervised learning algorithm based on the problem type (classification or regression), data characteristics, interpretability requirements, training time, and other practical considerations.

## 4.1. Consider algorithm categories

![Choose supervised learning algorithm cheat sheet](./images/4_choose_a_supervised_learning_algorithm_scikit_learn_cheat_sheet.png)

Machine learning algorithms can be broadly categorized by how they use labeled data. Three types are relevant here:

**Supervised Learning**: These algorithms learn from labeled data, where the input data has corresponding output labels or target variables. The goal is to learn a mapping function from the input features to the output labels. Common supervised learning tasks include classification (predicting a categorical label) and regression (predicting a continuous value).

- **Some algorithms**: Linear Regression, Logistic Regression, Decision Trees, Random Forests, Support Vector Machines (SVMs), Neural Networks.

Examples by dataset:
  - House Prices (Regression) - Linear Regression for interpretability, XGBoost for competition performance
  - Email Spam (Classification) - Naive Bayes is a strong baseline for text data, with reported accuracy around 98% on the Enron dataset
  - MNIST Digits - Neural Networks (CNNs) achieve 99.7% accuracy, SVMs reach 98.5%
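
The two supervised task types can be sketched with scikit-learn's bundled toy datasets. This is a minimal illustration, not a benchmark: the choice of datasets (`load_iris`, `load_diabetes`), the 25% test split, and the variable names are assumptions made for the example.

```python
from sklearn.datasets import load_diabetes, load_iris
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import train_test_split

# Classification: predict a categorical label (iris species).
X_cls, y_cls = load_iris(return_X_y=True)
Xc_train, Xc_test, yc_train, yc_test = train_test_split(
    X_cls, y_cls, test_size=0.25, random_state=0, stratify=y_cls
)
clf = LogisticRegression(max_iter=1000).fit(Xc_train, yc_train)
cls_accuracy = clf.score(Xc_test, yc_test)  # mean accuracy on held-out data

# Regression: predict a continuous value (a disease progression score).
X_reg, y_reg = load_diabetes(return_X_y=True)
Xr_train, Xr_test, yr_train, yr_test = train_test_split(
    X_reg, y_reg, test_size=0.25, random_state=0
)
reg = LinearRegression().fit(Xr_train, yr_train)
reg_r2 = reg.score(Xr_test, yr_test)  # R^2 on held-out data

print(f"classification accuracy: {cls_accuracy:.3f}")
print(f"regression R^2: {reg_r2:.3f}")
```

Both estimators share the same `fit` / `score` interface; only the target type and the metric differ.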

**Unsupervised Learning**: These algorithms learn from unlabeled data, where there are no predefined output labels. The goal is to discover patterns, structures, or relationships within the data. Common unsupervised learning tasks include clustering (grouping similar data points), dimensionality reduction (reducing the number of features), and association rule mining.

- **Some algorithms**: K-Means Clustering, Hierarchical Clustering, Principal Component Analysis (PCA), Association Rule Mining.

Examples by dataset:
  - Customer Segmentation - K-Means on RFM (Recency, Frequency, Monetary) features to identify customer groups
  - Gene Expression Data - PCA reduces thousands of genes to principal components for visualization
  - Market Basket Analysis - Association rules find "beer and diapers" patterns in retail transaction data
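
Clustering and dimensionality reduction can be sketched as follows. The data here is synthetic (`make_blobs`), standing in for real unlabeled data such as customer features; the number of clusters and components are assumptions chosen for the example.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.decomposition import PCA

# Synthetic stand-in for unlabeled data: 300 points, 5 features, 3 groups.
# The true group labels are discarded on purpose.
X, _ = make_blobs(n_samples=300, n_features=5, centers=3, random_state=0)

# Clustering: assign each point to one of 3 clusters without using labels.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
cluster_ids = kmeans.fit_predict(X)

# Dimensionality reduction: project 5 features onto 2 principal components.
pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)
variance_kept = float(pca.explained_variance_ratio_.sum())

print("clusters found:", sorted(set(cluster_ids.tolist())))
print("reduced shape:", X_2d.shape)
print(f"variance kept by 2 components: {variance_kept:.2f}")
```

In practice the number of clusters is rarely known in advance, so it is usually chosen by inspecting the data or comparing several values.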

**Semi-Supervised Learning**: These algorithms combine a small amount of labeled data with a large amount of unlabeled data. They leverage the strengths of both supervised and unsupervised learning techniques to improve model performance, especially when labeled data is scarce or expensive to obtain.

- **Some algorithms**: Self-Training, Co-Training, semi-supervised variants of Generative Adversarial Networks (GANs).

Examples by dataset:
  - Medical Imaging - Only 100 labeled X-rays but 10,000 unlabeled; semi-supervised learning improves accuracy
  - Text Classification - Using 1,000 labeled documents + 100,000 unlabeled web pages for sentiment analysis
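
Self-training can be sketched with scikit-learn's `SelfTrainingClassifier`. This is a minimal illustration under stated assumptions: the digits dataset, the choice to hide roughly 90% of the training labels, and the 0.9 confidence threshold are all picked for the example, and the sketch does not show that self-training beats a purely supervised baseline.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import SelfTrainingClassifier

X, y = load_digits(return_X_y=True)
X = X / 16.0  # pixel values in this dataset range from 0 to 16
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

# Hide about 90% of the training labels; scikit-learn marks unlabeled rows with -1.
rng = np.random.RandomState(0)
y_partial = y_train.copy()
y_partial[rng.rand(len(y_train)) < 0.9] = -1
n_labeled_before = int((y_partial != -1).sum())

# The base classifier is fit on the labeled rows, then its confident
# predictions (probability >= threshold) are added as pseudo-labels.
# The estimator is passed positionally to avoid depending on the keyword name.
model = SelfTrainingClassifier(LogisticRegression(max_iter=2000), threshold=0.9)
model.fit(X_train, y_partial)

n_labeled_after = int((model.transduction_ != -1).sum())
test_accuracy = model.score(X_test, y_test)

print(f"labeled rows: {n_labeled_before} -> {n_labeled_after}")
print(f"test accuracy: {test_accuracy:.3f}")
```

To judge whether the unlabeled data helps, compare against the same base classifier trained on the labeled rows only.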

## 4.2. Evaluate algorithm characteristics

When evaluating different supervised learning algorithms, consider the following characteristics:

- **Interpretability**: Some algorithms, like linear models and decision trees, are more interpretable and provide insights into the relationship between input features and the target variable. Others, like neural networks, are more complex and can be treated as "black boxes."

  Examples:
    - Medical Diagnosis - Decision Trees let doctors trace why a patient is flagged as high-risk
    - Credit Scoring - Logistic Regression coefficients help justify loan denials where regulations require an explanation
    - Image Recognition - A deep learning black box is often acceptable when accuracy matters more than explanation

- **Training Time**: Some algorithms, like linear models, are computationally efficient and can be trained quickly, even on large datasets. Others, like ensemble methods (e.g., Random Forests) and neural networks, may require more computational resources and longer training times.

  Examples:
    - Real-time Fraud Detection - Logistic Regression trains in seconds on millions of transactions
    - Kaggle Competitions - XGBoost may take hours to train and tune, but a 1% accuracy gain can be worth it
    - GPT Models - Months of training on thousands of GPUs for state-of-the-art NLP

- **Prediction Speed**: After training, some algorithms can make predictions very quickly (e.g., linear models), while others may be slower (e.g., instance-based methods like k-Nearest Neighbors).

  Examples:
    - High-Frequency Trading - Linear models make predictions in microseconds
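    - Recommendation Systems - Exact kNN is often too slow for real-time use on large catalogs; matrix factorization is a common alternative
    - Mobile Apps - Small decision trees run easily on-device and offline; large neural networks may need compression or cloud inference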

- **Data Type Handling**: Some algorithms can handle different data types (e.g., categorical, numerical, text) natively, while others may require additional data preprocessing or feature engineering.

  Examples:
    - Mixed Data (Titanic) - Tree-based models need no feature scaling, though scikit-learn's Random Forests still require categorical features to be encoded
    - Text Classification - Naive Bayes naturally works with word counts
    - Tabular Data - XGBoost handles missing values without imputation

- **Robustness to Outliers**: Some algorithms, like decision trees and ensemble methods, are more robust to outliers in the data, while others, like linear models, can be heavily influenced by outliers.

  Examples:
    - Income Prediction - Tree-based models handle millionaire outliers better than Linear Regression
    - Sensor Data - Random Forest robust to occasional sensor malfunctions
    - Student Grades - One bad test score skews Linear Regression predictions

- **Scalability**: As the size of the dataset grows, some algorithms may become computationally expensive or require specialized techniques (e.g., online learning, distributed computing) to handle large-scale data.
 
  Examples:
    - Google Search - SGD (Stochastic Gradient Descent) scales to billions of examples
    - IoT Sensors - Online learning updates model with each new data point
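
The robustness-to-outliers point can be illustrated with a toy experiment: fit the same model on clean data and on data with one corrupted target, then measure how much the predictions for the *other* points move. The data, the size of the outlier, and the helper `prediction_shift` are hypothetical choices for this sketch. It covers only an outlier in the target with a fully grown tree; shallow trees or outliers in the features can behave differently.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor

# Toy data on an exact line, y = 2x + 1.
X = np.arange(20, dtype=float).reshape(-1, 1)
y_clean = 2.0 * X.ravel() + 1.0

# Corrupt a single target value (true value 39) to create an outlier.
y_outlier = y_clean.copy()
y_outlier[-1] = 500.0


def prediction_shift(make_model):
    """Largest change in predictions for the 19 uncorrupted points."""
    clean = make_model().fit(X, y_clean).predict(X[:-1])
    dirty = make_model().fit(X, y_outlier).predict(X[:-1])
    return float(np.max(np.abs(clean - dirty)))


linear_shift = prediction_shift(LinearRegression)
tree_shift = prediction_shift(lambda: DecisionTreeRegressor(random_state=0))

print(f"linear regression shift: {linear_shift:.1f}")
print(f"decision tree shift:     {tree_shift:.1f}")
```

The linear fit is global, so one bad point tilts the whole line. The fully grown tree isolates the outlier in its own leaf, leaving predictions for the other training points unchanged.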

Evaluating these characteristics can help narrow down the choices and select algorithms that are well-suited for your specific problem and data characteristics.
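
As one way to inspect interpretability in practice, the sketch below prints the largest coefficients of a logistic regression and the if-then rules of a shallow decision tree. The breast cancer dataset, the tree depth of 2, and the variable names are assumptions for the example.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_breast_cancer()
X, y = data.data, data.target
names = list(data.feature_names)

# Linear model: after standardizing, coefficient size hints at feature influence.
linear = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)).fit(X, y)
coefs = linear.named_steps["logisticregression"].coef_[0]
top_features = sorted(zip(names, coefs), key=lambda p: abs(p[1]), reverse=True)[:5]
for name, coef in top_features:
    print(f"{name:25s} {coef:+.2f}")

# Decision tree: a shallow tree can be read as a short list of if-then rules.
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
rules = export_text(tree, feature_names=names)
print(rules)
```

Coefficients are only comparable across features when the inputs share a scale, which is why the linear model is wrapped in a pipeline with `StandardScaler`.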

## 4.3. Try multiple algorithms

Since it's difficult to know the best algorithm upfront, it's recommended to try multiple algorithms from different families (e.g., linear models, tree-based models, instance-based models, etc.) and compare their performance on your data.

This process is often referred to as "model selection" or "algorithm selection".
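
A minimal sketch of this comparison, assuming a binary classification task and 5-fold cross-validation, is shown below. The dataset (`load_breast_cancer`), the candidate models, and their hyperparameters are illustrative choices rather than recommendations.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)

# One candidate per family; scale-sensitive models get a StandardScaler.
candidates = {
    "logistic_regression": make_pipeline(
        StandardScaler(), LogisticRegression(max_iter=1000)
    ),
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=0),
    "knn": make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5)),
    "gaussian_nb": GaussianNB(),
}

# Evaluate every candidate with the same folds and the same metric.
results = {}
for name, model in candidates.items():
    scores = cross_val_score(model, X, y, cv=5, scoring="accuracy")
    results[name] = (float(scores.mean()), float(scores.std()))

for name, (mean, std) in sorted(
    results.items(), key=lambda kv: kv[1][0], reverse=True
):
    print(f"{name:20s} {mean:.3f} +/- {std:.3f}")

best_name = max(results, key=lambda n: results[n][0])
print("best by mean accuracy:", best_name)
```

Putting the scaler inside the pipeline means it is refit on each training fold, which avoids leaking information from the validation fold.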

Here are some common algorithm families and examples:

- **Linear Models**: Linear Regression, Logistic Regression, linear Support Vector Machines (SVMs).

  Best for:
    - Boston Housing - Linear Regression achieves R² of 0.74 with interpretable coefficients (note: this dataset was removed from scikit-learn in version 1.2)
    - Iris Classification - Logistic Regression gets about 97% accuracy because the species are almost linearly separable
    - Text Classification - Linear SVM excels on high-dimensional sparse text features

- **Tree-Based Models**: Decision Trees, Random Forests, Gradient Boosting Machines.

  Best for:
    - Titanic Survival - Random Forest captures complex interactions (gender × class × age)
    - House Prices Advanced - XGBoost is a frequent top performer in Kaggle tabular competitions, reaching roughly 0.12 RMSE on log prices
    - Credit Default - Decision Trees provide clear if-then rules for loan officers

- **Instance-Based Models**: k-Nearest Neighbors (kNN).

  Best for:
    - Digit Recognition - kNN achieves 97% on MNIST by comparing pixel similarities
    - Recommendation Systems - "Users who liked X also liked Y" based on similarity
    - Anomaly Detection - Points far from k neighbors are likely anomalies

- **Bayesian Models**: Multinomial Naive Bayes, Gaussian Naive Bayes.

  Best for:
    - Spam Detection - Naive Bayes is well suited to modeling word probabilities in spam/ham emails
    - Document Classification - 20 Newsgroups dataset with 90% accuracy
    - Real-time Prediction - Extremely fast training and prediction

- **Neural Networks**: Feedforward Neural Networks, Convolutional Neural Networks (CNNs), Recurrent Neural Networks (RNNs).

  Best for:
    - ImageNet - CNNs achieve 90%+ top-5 accuracy on 1000 classes
    - Stock Price Prediction - LSTMs capture temporal patterns in time series
    - Natural Language - Transformers (BERT, GPT) revolutionized text understanding

- **Ensemble Methods**: Random Forests, Gradient Boosting Machines, Bagging, Boosting.

  Best for:
    - Kaggle Competitions - Winning solutions are very often ensembles of several models
    - Imbalanced Datasets - Balanced Random Forest handles class imbalance
    - Production Systems - Combine multiple models for robustness

By trying multiple algorithms from different families, you can compare their performance metrics (e.g., accuracy, precision, recall, F1-score, mean squared error) and select the one that performs best on your data.
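
The metrics named above can be computed with `sklearn.metrics`. The labels and predictions below are small made-up arrays, chosen so the values are easy to verify by hand (3 true positives, 1 false negative, 1 false positive, 5 true negatives).

```python
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    mean_squared_error,
    precision_score,
    recall_score,
)

# Hypothetical binary classification results.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

accuracy = accuracy_score(y_true, y_pred)    # (3 + 5) / 10 = 0.8
precision = precision_score(y_true, y_pred)  # 3 / (3 + 1) = 0.75
recall = recall_score(y_true, y_pred)        # 3 / (3 + 1) = 0.75
f1 = f1_score(y_true, y_pred)                # harmonic mean of the two = 0.75

# Hypothetical regression results.
mse = mean_squared_error([3.0, 5.0], [2.0, 7.0])  # (1 + 4) / 2 = 2.5

print(accuracy, precision, recall, f1, mse)
```

Which metric should drive the final choice depends on the problem: on imbalanced data, for example, accuracy alone can hide poor recall on the minority class.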