Learn to do agglomerative clustering with reinforcement learning

Reinforced agglomerative clustering

To overcome the greediness of traditional linkage criteria in agglomerative clustering, we propose a reinforcement learning approach that learns a non-greedy merge policy by modeling agglomerative clustering as a Markov Decision Process.

Agglomerative clustering is a "bottom-up" approach to hierarchical clustering: each observation starts in its own cluster, and pairs of clusters are merged as one moves up the hierarchy. It is a sequential decision problem, because a merge made early constrains every later result. Traditional linkage criteria ignore this dependency and simply measure the similarity of clusters at the current step. This motivated us to model clustering as a Markov Decision Process and solve it with reinforcement learning. The agent should learn a non-greedy merge policy, so that each merge is chosen to maximize the long-term discounted reward.
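For contrast, the greedy baseline described above can be sketched in a few lines. This is an illustration only and is not code from this repository: the function names are hypothetical, and single linkage with Euclidean distance is an assumed choice of criterion. Each merge is picked purely from the current distances, with no regard for how it affects later merges.

```python
from itertools import combinations
import math


def single_linkage(a, b):
    """Smallest Euclidean distance between any point in a and any point in b."""
    return min(math.dist(p, q) for p in a for q in b)


def greedy_agglomerative(points, n_clusters, linkage=single_linkage):
    """Merge the closest pair of clusters until n_clusters remain."""
    clusters = [[tuple(p)] for p in points]  # every observation starts alone
    while len(clusters) > n_clusters:
        # Greedy step: choose the pair that looks best right now.
        i, j = min(
            combinations(range(len(clusters)), 2),
            key=lambda pair: linkage(clusters[pair[0]], clusters[pair[1]]),
        )
        clusters[i] = clusters[i] + clusters[j]
        del clusters[j]  # j > i, so index i is unaffected
    return clusters


points = [(0.0, 0.0), (0.0, 1.0), (5.0, 5.0), (5.0, 6.0), (10.0, 0.0)]
result = greedy_agglomerative(points, n_clusters=3)
```

A learned policy would replace the `min(...)` selection with a choice that also accounts for the expected future reward.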

The state is defined as a feature representation of the current clustering, obtained by pooling to aggregate the features of all clusters. The action is defined as merging cluster i and cluster j. We use Q-learning to estimate the value of each state-action pair. During training, the reward is computed from the ground-truth labels of the images. At test time, we evaluate the agent in a different domain to see how well it generalizes.
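The formulation above can be sketched as a toy Q-learning loop. This is a minimal illustration under several assumptions, not the implementation in this repository: it uses mean pooling, a linear Q-function rather than a neural network, a hypothetical reward of +1 for a label-pure merge and -1 otherwise, and assumed values for `GAMMA`, `ALPHA` and `epsilon`. All function names are made up for the example.

```python
import random

GAMMA = 0.9  # discount factor (assumed value)
ALPHA = 0.1  # learning rate (assumed value)


def mean_pool(vectors):
    """Average a list of equal-length feature vectors element-wise."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def state_features(clusters):
    """State: pool each cluster's members, then pool across clusters."""
    return mean_pool([mean_pool(c["x"]) for c in clusters])


def q_features(clusters, i, j):
    """Features of the pair (state, action 'merge cluster i and cluster j')."""
    merged = mean_pool(clusters[i]["x"] + clusters[j]["x"])
    return state_features(clusters) + merged


def q_value(w, phi):
    """Linear Q-function (an assumption made to keep the sketch small)."""
    return sum(wk * pk for wk, pk in zip(w, phi))


def merge(clusters, i, j):
    """Apply the action: return a new clustering with i and j merged."""
    merged = {"x": clusters[i]["x"] + clusters[j]["x"],
              "y": clusters[i]["y"] + clusters[j]["y"]}
    rest = [c for k, c in enumerate(clusters) if k not in (i, j)]
    return rest + [merged]


def reward(cluster):
    """Hypothetical reward: +1 if the merged cluster is label-pure, else -1."""
    return 1.0 if len(set(cluster["y"])) == 1 else -1.0


def actions(clusters):
    n = len(clusters)
    return [(i, j) for i in range(n) for j in range(i + 1, n)]


def q_learning_step(w, clusters, epsilon=0.1):
    """One epsilon-greedy merge followed by a Q-learning weight update."""
    acts = actions(clusters)
    if random.random() < epsilon:
        i, j = random.choice(acts)
    else:
        i, j = max(acts, key=lambda a: q_value(w, q_features(clusters, *a)))
    phi = q_features(clusters, i, j)
    nxt = merge(clusters, i, j)
    r = reward(nxt[-1])  # the merged cluster is appended last
    # Bootstrapped target: immediate reward plus discounted best next value.
    future = max((q_value(w, q_features(nxt, *a)) for a in actions(nxt)),
                 default=0.0)
    td_error = r + GAMMA * future - q_value(w, phi)
    w = [wk + ALPHA * td_error * pk for wk, pk in zip(w, phi)]
    return w, nxt


# Toy data: four 2-D feature vectors with ground-truth labels.
random.seed(0)
data = [([0.0, 0.1], 0), ([0.1, 0.0], 0), ([1.0, 0.9], 1), ([0.9, 1.0], 1)]
w = [0.0] * 4  # two state features plus two merged-cluster features
for _ in range(200):
    clusters = [{"x": [x], "y": [y]} for x, y in data]
    while len(clusters) > 2:
        w, clusters = q_learning_step(w, clusters)
```

The discounted `future` term is what separates this from a greedy linkage rule: a merge is scored by its immediate reward plus the estimated value of the best merges it leaves available.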


  1. Download the MNIST dataset
cd dataset/ & bash
  2. Install all the dependencies
pip install -r requirements.txt


  1. Train
python --train
  2. Test
python --test [MODEL_DIR]