# Entity Embeddings with EntityEmbedder

This demonstration is available as a Jupyter notebook or Julia script
[here](https://github.com/FluxML/MLJFlux.jl/tree/dev/docs/src/common_workflows/entity_embeddings).

Entity embedding is a relatively recent deep learning approach to categorical encoding, introduced in 2016 by Cheng Guo and Felix Berkhahn.
It uses a set of embedding layers to map each categorical feature into a dense continuous vector, much as word embeddings are used in NLP architectures.

In MLJFlux, the `EntityEmbedder` provides a high-level interface for learning entity embeddings with any supervised MLJFlux model as the underlying learner.
The embedder can be used as a transformer in MLJ pipelines: it replaces categorical features with their learned embeddings, which downstream machine learning models can then use as ordinary continuous features.

In this tutorial, we will explore how to use the `EntityEmbedder` to learn and apply entity embeddings on the Google Play Store dataset.

## Learning Objectives
- Understand the concept of entity embeddings for categorical encoding
- Learn how to use `EntityEmbedder` from MLJFlux
- Apply entity embeddings to a real-world dataset
- Visualize the learned embedding spaces
- Build pipelines combining embeddings with downstream models

## Package Setup and Environment

**Julia version** is assumed to be 1.10.*

## Required Packages

We'll need several packages for this tutorial:
- **MLJ ecosystem**: Core machine learning framework, MLJFlux for neural networks, and ScientificTypes for type coercion
- **Flux**: Deep learning framework (with Optimisers) for building the embedding models
- **Data handling**: CSV, DataFrames, CategoricalArrays for data manipulation
- **Visualization**: Plots for visualizing the learned embeddings
- **Downstream models**: NearestNeighborModels for the KNN classifier used in the pipeline
- **Utilities**: Random, Tables, ProgressMeter, StatsBase for various helper functions

In [None]:
using Pkg
Pkg.activate(@__DIR__);
Pkg.instantiate();



# Import all required packages
using MLJ
using Flux
using Optimisers
using CategoricalArrays
using DataFrames
using Random
using Tables
using ProgressMeter
using Plots
using ScientificTypes
using CSV
using StatsBase  ## For countmap
import Plots: mm  ## For margin units

## Data Loading and Preprocessing

We'll use the Google Play Store dataset which contains information about mobile applications.
This dataset has several categorical features that are perfect for demonstrating entity embeddings:
- **Category**: App category (e.g., Games, Social, Tools)
- **Content Rating**: Age rating (e.g., Everyone, Teen, Mature)
- **Genres**: Primary genre of the app
- **Android Ver**: Required Android version
- **Type**: Free or Paid

In [None]:
# Load the Google Play Store dataset
df = CSV.read("./googleplaystore.csv", DataFrame)

### Data Cleaning and Type Conversion

The raw dataset requires significant cleaning. We'll handle:
1. **Reviews**: Convert to integers
2. **Size**: Parse size strings like "14M", "512k" to numeric values
3. **Installs**: Remove formatting characters and convert to integers
4. **Price**: Remove dollar signs and convert to numeric
5. **Genres**: Extract primary genre only

In [None]:
# Custom parsing function that returns missing instead of nothing
safe_parse(T, s) = something(tryparse(T, s), missing)

# Reviews: ensure integer
df.Reviews = safe_parse.(Int, string.(df.Reviews))

# Size: "14M", "512k", or "Varies with device"
function parse_size(s)
    if s == "Varies with device"
        return missing
    elseif occursin('M', s)
        return safe_parse(Float64, replace(s, "M" => "")) * 1_000_000
    elseif occursin('k', s)
        return safe_parse(Float64, replace(s, "k" => "")) * 1_000
    else
        return safe_parse(Float64, s)
    end
end
df.Size = parse_size.(string.(df.Size))

# Installs: strip '+' and ',' then parse
clean_installs = replace.(string.(df.Installs), r"[+,]" => "")
df.Installs = safe_parse.(Int, clean_installs)

# Price: strip leading '$'
df.Price = safe_parse.(Float64, replace.(string.(df.Price), r"^\$" => ""))

# Genres: take only the primary genre
df.Genres = first.(split.(string.(df.Genres), ';'))
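The cleaning code relies on two Base idioms: `tryparse` returns `nothing` when parsing fails, and `something(x, missing)` swaps that `nothing` for `missing`, which propagates through arithmetic and is removed later by `dropmissing!`. The standalone check below exercises these rules on a few sample strings of our own, modeled on the formats listed above; it is a sanity check, not part of the pipeline.

```julia
safe_parse(T, s) = something(tryparse(T, s), missing)

# `tryparse` signals failure with `nothing` ...
@assert tryparse(Float64, "Varies with device") === nothing

# ... and `nothing` cannot take part in arithmetic (this would break `* 1_000_000` in `parse_size`)
threw = try
    tryparse(Float64, "1,000+") * 1_000
    false
catch err
    err isa MethodError
end
@assert threw

# `missing` propagates instead, so the row can be dropped later by `dropmissing!`
@assert ismissing(safe_parse(Float64, "1,000+") * 1_000)
@assert safe_parse(Float64, replace("14M", "M" => "")) * 1_000_000 == 14_000_000.0

# The regex clean-ups used for Installs and Price
@assert replace("1,000,000+", r"[+,]" => "") == "1000000"
@assert replace("\$4.99", r"^\$" => "") == "4.99"
```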

### Storing Category Information for Visualization

We'll store the unique values of each categorical feature to use later when visualizing the embeddings.

In [None]:
# Store unique category names for visualization later
category_names = Dict(
    :Category => sort(unique(df.Category)),
    Symbol("Content Rating") => sort(unique(df[!, Symbol("Content Rating")])),
    :Genres => sort(unique(df.Genres)),
    Symbol("Android Ver") => sort(unique(df[!, Symbol("Android Ver")])),
)

println("Category names extracted:")
for (feature, names) in category_names
    println("$feature: $(length(names)) categories")
end

### Feature Selection and Missing Value Handling

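We'll select the most relevant features and remove every row with a missing value, which includes the entries our parsers could not convert.
We also drop any ratings that were read as `NaN` rather than `missing`; otherwise `rating_to_categorical` below would turn them into a spurious `"NaN"` class.

In [None]:
select!(
    df,
    [
        :Category, :Reviews, :Size, :Installs, :Type,
        :Price, Symbol("Content Rating"), :Genres, Symbol("Android Ver"), :Rating,
    ],
)
dropmissing!(df)
# Unrated apps may appear as NaN instead of missing; remove those rows as well
filter!(:Rating => !isnan, df)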

### Creating Categorical Target Variable

For this tutorial, we'll convert the continuous rating into a categorical classification problem.
This will allow us to use a classification model that can learn meaningful embeddings.

We'll create up to 11 rating categories by rounding to the nearest 0.5 (0.0, 0.5, 1.0, ..., 4.5, 5.0); only the values that actually occur in the data become classes.

In [None]:
# Round to the nearest 0.5, giving at most 11 classes: 0.0, 0.5, 1.0, ..., 4.5, 5.0
function rating_to_categorical(rating)
    # Clamp rating to valid range and round to nearest 0.5
    clamped_rating = clamp(rating, 0.0, 5.0)
    rounded_rating = round(clamped_rating * 2) / 2  ## Round to nearest 0.5
    return string(rounded_rating)
end

# Apply the transformation
df.RatingCategory = categorical([rating_to_categorical(r) for r in df.Rating])

# Check the distribution of categorical rating labels
println("Distribution of categorical rating labels:")
println(sort(countmap(df.RatingCategory)))
println("\nUnique rating categories: $(sort(unique(df.RatingCategory)))")

### Type Coercion for MLJ

MLJ requires explicit type coercion to understand which columns are categorical vs continuous.
This step is crucial for the `EntityEmbedder` to identify which features need embedding layers.

In [None]:
# Coerce types for MLJ compatibility
df = coerce(df,
    :Category => Multiclass,
    :Reviews => Continuous,
    :Size => Continuous,
    :Installs => Continuous,
    :Type => Multiclass,
    :Price => Continuous,
    Symbol("Content Rating") => Multiclass,
    :Genres => Multiclass,
    Symbol("Android Ver") => Multiclass,
    :Rating => Continuous,  ## Keep original for reference
    :RatingCategory => Multiclass,  ## New categorical target
)
schema(df)

### Data Splitting

We'll split our data into training (80%) and test (20%) sets, using stratified sampling so that both sets preserve the proportions of each rating category.

In [None]:
# Split into features and target
y = df[!, :RatingCategory]  ## Use categorical rating as target
X = select(df, Not([:Rating, :RatingCategory]))  ## Exclude both rating columns from features

# Split the data with stratification
(X_train, X_test), (y_train, y_test) = partition(
    (X, y),
    0.8,
    multi = true,
    shuffle = true,
    stratify = y,
    rng = Random.Xoshiro(41),
);

using MLJFlux
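Stratification means the split is carried out separately within each class. The sketch below shows the idea on a plain vector of labels; `stratified_split` is a hypothetical helper written for illustration, not MLJ's `partition` implementation, and the toy labels are made up.

```julia
using Random

# Hypothetical helper: split row indices so that each class keeps (about) the
# same train/test proportion as in the full data
function stratified_split(labels::AbstractVector, frac::Real; rng = Random.Xoshiro(41))
    train_idx, test_idx = Int[], Int[]
    for class in unique(labels)
        idx = Random.shuffle(rng, findall(==(class), labels))  # rows of this class, shuffled
        ntrain = round(Int, frac * length(idx))
        append!(train_idx, idx[1:ntrain])
        append!(test_idx, idx[ntrain+1:end])
    end
    return sort(train_idx), sort(test_idx)
end

# Toy labels with unequal class sizes
labels_demo = vcat(fill("4.0", 50), fill("4.5", 30), fill("2.0", 20))
train_idx, test_idx = stratified_split(labels_demo, 0.8)
count(==("2.0"), labels_demo[train_idx])  # 16 of the 20 "2.0" rows go to training
```

Without stratification, a rare class could by chance end up almost entirely in one of the two sets.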

## Building the EntityEmbedder Model

The `EntityEmbedder` works by wrapping a supervised learning model that will learn embeddings as part of its training process.

### Key Components:
1. **Base Model**: A neural network classifier that learns to predict our target
2. **Embedding Dimensions**: We specify how many dimensions each categorical feature should be embedded into
3. **Architecture**: The embeddings are learned jointly with the prediction task

### Why Entity Embeddings Work:
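- Categories that have a similar effect on the target tend to be mapped to nearby vectors in the embedding space
- The embedding can therefore capture relationships between categories that one-hot encoding cannot express
- Each feature is represented by a few dense dimensions instead of one column per level, which keeps the input compact for high-cardinality features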
- Learned representations often generalize better than one-hot encoding
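The contrast with one-hot encoding can be made concrete: the one-hot vectors of any two distinct levels are always the same distance apart, so the encoding itself says nothing about which categories resemble each other. A small check using only the standard library (the number of levels is arbitrary):

```julia
using LinearAlgebra

n_levels = 5
# One-hot encoding: the code for level j is the j-th column of the identity matrix
onehot_codes = Matrix{Float64}(I, n_levels, n_levels)

# Euclidean distance between every pair of distinct levels
pair_dists = [norm(onehot_codes[:, i] - onehot_codes[:, j]) for i in 1:n_levels for j in (i+1):n_levels]

# All 10 pairs are exactly √2 apart: no pair is "closer" than any other
@assert length(pair_dists) == 10
@assert all(d -> d ≈ sqrt(2), pair_dists)
```

A learned embedding, by contrast, is free to place some levels closer together than others.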

In [None]:
# Load the neural network classifier
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg = MLJFlux

### Configuring the Base Neural Network

We'll create a neural network classifier with custom embedding dimensions for each categorical feature.
Setting smaller embedding dimensions (like 2D) makes it easier to visualize the learned representations.

In [None]:
# Create the underlying supervised model that will learn the embeddings
base_clf = NeuralNetworkClassifier(
    builder = MLJFlux.Short(n_hidden = 14),
    optimiser = Optimisers.Adam(0.1),  ## learning rate 0.1
    batch_size = 20,
    epochs = 5,
    acceleration = CUDALibs(),  ## train on a CUDA GPU; use CPU1() to train on the CPU
    embedding_dims = Dict(
        :Category => 2,
        :Type => 2,
        Symbol("Content Rating") => 2,
        :Genres => 2,
        Symbol("Android Ver") => 2,
    ),
    rng = 39,
)

### Creating the EntityEmbedder

The `EntityEmbedder` wraps our neural network and can be used as a transformer in MLJ pipelines.
By default, it uses `min(n_categories - 1, 10)` dimensions for any categorical feature not explicitly specified.
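To see what this default means in practice, here is the arithmetic for a few hypothetical cardinalities, compared with the number of columns one-hot encoding would need. The helper name `default_embedding_dim` is ours; MLJFlux applies the rule internally, and the feature names and level counts below are not taken from the dataset.

```julia
# The default rule quoted above, written as a plain function
default_embedding_dim(n_levels::Integer) = min(n_levels - 1, 10)

# Hypothetical features and their number of levels
cardinalities = Dict(:small => 2, :medium => 6, :large => 40)
embedding_widths = Dict(f => default_embedding_dim(n) for (f, n) in cardinalities)

onehot_width = sum(values(cardinalities))        # one column per level: 2 + 6 + 40
embedding_width = sum(values(embedding_widths))  # 1 + 5 + 10
```

The cap at 10 is what keeps high-cardinality features cheap: the 40-level feature needs 10 columns instead of 40.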

In [None]:
# Create the EntityEmbedder using the neural network
embedder = EntityEmbedder(base_clf)

## Training the EntityEmbedder

Now we'll train the embedder on our training data. The model learns to predict app ratings while simultaneously learning meaningful embeddings for categorical features.

### What Happens During Training:
1. Each categorical value gets mapped to a learnable embedding vector
2. The neural network learns to predict ratings using these embeddings + continuous features
3. Similar categories that lead to similar predictions get similar embedding vectors
4. The embeddings capture semantic relationships in the data
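Step 1 is typically implemented as a table lookup: each level's integer code selects one column of a learnable matrix. That lookup is mathematically the same as multiplying the matrix by the level's one-hot vector, which is why an embedding layer can be read as a learned linear map applied to one-hot input. A standalone sketch with random weights (the sizes are arbitrary and the code does not call MLJFlux or Flux):

```julia
using Random

demo_rng = Random.Xoshiro(0)
n_levels, dim = 4, 3
W = randn(demo_rng, dim, n_levels)   # embedding matrix: one column per level

level_code = 3                       # integer code of some category
e_vec = zeros(n_levels)
e_vec[level_code] = 1.0              # its one-hot vector

# Looking up column `level_code` gives the same vector as the matrix product
@assert W * e_vec ≈ W[:, level_code]
```

The lookup is preferred in practice because it avoids materializing the one-hot vectors.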

In [None]:
# Create and train the machine
mach = machine(embedder, X_train, y_train)
MLJ.fit!(mach, force = true, verbosity = 1)

### Transforming Data with Learned Embeddings

After training, we can use the embedder as a transformer to convert categorical features into their learned embedding representations.

In [None]:
# Transform the data using the learned embeddings
X_train_embedded = MLJFlux.transform(mach, X_train)
X_test_embedded = MLJFlux.transform(mach, X_test)

# Check the schema transformation
println("Original schema:")
display(schema(X_train))  ## `display` is needed: a cell only shows its last value automatically
println("\nEmbedded schema:")
display(schema(X_train_embedded))
X_train_embedded
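Conceptually, the transformation replaces each categorical column with `d` continuous columns: every row's level code picks out the corresponding column of that feature's embedding matrix. The sketch below uses a made-up 2-D embedding matrix (illustrative numbers, not learned ones) and does not reproduce MLJFlux's naming of the output columns.

```julia
# Made-up 2-D embedding for a feature with 3 levels (one column per level)
W_feat = [0.1  0.5 -0.3;
          0.7 -0.2  0.4]

codes = [2, 1, 2, 3]                      # level codes of four rows
embedded = permutedims(W_feat[:, codes])  # 4×2 matrix: row i = embedding of row i's level

embedded[1, :]   # [0.5, -0.2], the embedding of level 2
```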

## Using Embeddings in ML Pipelines

One of the key advantages of entity embeddings is that they can be used as features in any downstream machine learning model.
Let's create a pipeline that combines our `EntityEmbedder` with a k-nearest neighbors classifier.

### Pipeline Benefits:
- **Modular**: Easy to swap out different downstream models
- **Reusable**: Embeddings learned once can be used with multiple models
- **Interpretable**: Can analyze embedding spaces separately from final predictions

In [None]:
# Load KNN classifier
KNNClassifier = @load KNNClassifier pkg = NearestNeighborModels

# Create a pipeline: EntityEmbedder -> KNNClassifier
pipe = embedder |> KNNClassifier(K = 5)

# Train the pipeline
pipe_mach = machine(pipe, X_train, y_train)
MLJ.fit!(pipe_mach, verbosity = 0)
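In this pipeline, `KNNClassifier` receives the embedded (now entirely continuous) features and predicts from the labels of the nearest training rows. Below is a minimal majority-vote sketch of that idea. `knn_predict` is a hypothetical helper; it returns a single label and ignores the probabilistic output, distance weighting and tie handling that a real implementation such as NearestNeighborModels may provide.

```julia
using LinearAlgebra

# Hypothetical helper: majority vote among the K training rows closest to `x`
function knn_predict(Xtr::AbstractMatrix, ytr::AbstractVector, x::AbstractVector; K::Int = 5)
    dists = [norm(Xtr[i, :] - x) for i in axes(Xtr, 1)]   # Euclidean distance to each row
    nn_idx = partialsortperm(dists, 1:K)                   # indices of the K smallest distances
    votes = Dict{eltype(ytr),Int}()
    for label in ytr[nn_idx]
        votes[label] = get(votes, label, 0) + 1
    end
    return argmax(label -> votes[label], collect(keys(votes)))  # most frequent label
end

# Toy "embedded" rows forming two well-separated groups (made up for illustration)
Xtr = [0.0 0.0; 0.1 0.2; 0.2 0.1; 5.0 5.0; 5.1 4.9; 4.8 5.2]
ytr = ["low", "low", "low", "high", "high", "high"]

knn_predict(Xtr, ytr, [0.1, 0.1]; K = 3)   # "low"
```

Because distances are computed on the embedded columns, categories that the embedder placed close together also count as close for the downstream KNN model.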

## Visualizing the Learned Embedding Spaces

One of the most powerful aspects of entity embeddings is their interpretability. Since we used 2D embeddings, we can visualize how the model has organized different categories in the embedding space.

### What to Look For:
- **Clustering**: Similar categories should be close together
- **Separation**: Different types of categories should be well-separated
- **Meaningful patterns**: The spatial arrangement should reflect semantic relationships

In [None]:
# Extract the learned embedding matrices from the fitted model
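mapping_matrices = fitted_params(mach)[4]

# Column j of each embedding matrix belongs to the j-th level of that feature, so take the
# labels from the levels of the cleaned, coerced data rather than from the raw CSV values
# (which still include values from rows later removed by `dropmissing!`)
category_names = Dict(feature => levels(df[!, feature]) for feature in keys(mapping_matrices))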

### Creating Embedding Visualization Function

We'll create a helper function to plot the 2D embedding space for each categorical feature.
Each point represents a category, and its position shows how the model learned to represent it.

In [None]:
# Function to create and display scatter plot for categorical embeddings
function plot_categorical_embeddings(feature_name, feature_categories, embedding_matrix)
    # Create scatter plot for this feature's embeddings
    p = scatter(embedding_matrix[1, :], embedding_matrix[2, :],
        title = "$(feature_name) Embeddings",
        xlabel = "Dimension 1",
        ylabel = "Dimension 2",
        label = "$(feature_name)",
        legend = :topright,
        markersize = 8,
        size = (1200, 600))

    # Annotate each point with the actual category name
    for (i, col) in enumerate(eachcol(embedding_matrix))
        if i <= length(feature_categories)
            cat_name = string(feature_categories[i])
            # Truncate long category names for readability
            display_name = length(cat_name) > 10 ? first(cat_name, 10) * "..." : cat_name  ## `first` is safe for non-ASCII names
            annotate!(p, col[1] + 0.02, col[2] + 0.02, text(display_name, :black, 5))
        end
    end

    # Display the plot
    display(p)
    println("Displayed embedding plot for: $(feature_name)")
    return p
end

### Generating Embedding Plots for Each Categorical Feature

Let's visualize the embedding space for each of our categorical features to understand what patterns the model learned.

In [None]:
# Create separate plots for each categorical feature's embeddings

# Plot 1: Category embeddings
if haskey(mapping_matrices, :Category)
    plot_categorical_embeddings(
        :Category,
        category_names[:Category],
        mapping_matrices[:Category],
    )
end

Notice that pairs such as social and entertainment, shopping and finance, and comics and art are closer together than others.
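Judging proximity by eye gets harder when many labels overlap. A complementary check is to rank pairs of levels by the Euclidean distance between their embedding columns. The sketch below runs on a made-up matrix; `closest_pairs` is a hypothetical helper, and in the notebook one could pass, for example, `mapping_matrices[:Category]` together with the matching level names, assuming (as the plotting function does) one column per level.

```julia
using LinearAlgebra

# Hypothetical helper: the `k` closest pairs of levels in an embedding matrix
# (one column per level), as (name_a, name_b, distance) tuples
function closest_pairs(E::AbstractMatrix, level_names::AbstractVector; k::Int = 3)
    pairs_found = Tuple{String,String,Float64}[]
    for i in axes(E, 2), j in axes(E, 2)
        i < j || continue
        push!(pairs_found, (string(level_names[i]), string(level_names[j]), norm(E[:, i] - E[:, j])))
    end
    sort!(pairs_found, by = last)   # smallest distance first
    return first(pairs_found, k)
end

# Made-up 2-D embeddings for four levels, for illustration only
E_demo = [0.9 1.0  -0.8 -0.6;
          0.2 0.25  0.5  0.3]
names_demo = ["ART_AND_DESIGN", "COMICS", "FINANCE", "SHOPPING"]

closest_pairs(E_demo, names_demo; k = 2)
```

With real embeddings, keep in mind that a 2-D space can only preserve a few of these relationships at once.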

In [None]:
# Plot 2: Content Rating embeddings
if haskey(mapping_matrices, Symbol("Content Rating"))
    plot_categorical_embeddings(
        Symbol("Content Rating"),
        category_names[Symbol("Content Rating")],
        mapping_matrices[Symbol("Content Rating")],
    )
end

The `Everyone` category is positioned far from all others.

In [None]:
# Plot 3: Genres embeddings
if haskey(mapping_matrices, :Genres)
    plot_categorical_embeddings(:Genres, category_names[:Genres], mapping_matrices[:Genres])
end

Here the results may be less interpretable; the idea is that, for the purpose of predicting the rating, the model treats categories that lie close together as similar.

In [None]:
# Plot 4: Android Ver embeddings
if haskey(mapping_matrices, Symbol("Android Ver"))
    plot_categorical_embeddings(
        Symbol("Android Ver"),
        category_names[Symbol("Android Ver")],
        mapping_matrices[Symbol("Android Ver")],
    )
end

Some clear patterns appear, such as the close proximity of "7.1 and up" and "7.0-7.1".

In [None]:
# Plot 5: Type embeddings (if it exists in the mapping)
if haskey(mapping_matrices, :Type)
    plot_categorical_embeddings(:Type, sort(unique(df.Type)), mapping_matrices[:Type])
end

As one might expect, `Free` and `Paid` end up far apart.

This demonstrates the power of entity embeddings as a modern approach to categorical feature encoding that goes beyond traditional methods like one-hot encoding or label encoding.

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*