

The FastEmbedding module is an alternative implementation of torch.nn.Embedding.

The two modules differ in how backpropagation is implemented. Instead of sorting indices and reducing gradients sequentially by key, FastEmbedding accumulates gradients naively using atomic addition. This approach is safe but produces nondeterministic results, since floating-point addition is not associative. However, in scenarios with many collisions (i.e. when multiple samples in a batch refer to the same embedding), using CUDA atomicAdd() yields a significant improvement in training time (up to 20x), especially when the embedding dimension is small or moderate. In the other tested scenarios, the GPU implementation of FastEmbedding is at least as fast as torch.nn.Embedding when the batch size is larger than 128.
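To make the difference concrete, here is a minimal pure-Python model of the two gradient-accumulation strategies for an embedding lookup. The function names backward_atomic and backward_sorted are hypothetical, and shuffling the loop order is only an assumed stand-in for the unpredictable scheduling of GPU threads; this is an illustration, not the CUDA code shipped with FastEmbedding.

```python
import random


def backward_atomic(indices, grad_out, num_embeddings, seed=None):
    """Scatter-add each gradient row into its embedding row in an arbitrary
    order, modelling unordered atomicAdd() calls (hypothetical sketch)."""
    dim = len(grad_out[0])
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    order = list(range(len(indices)))
    random.Random(seed).shuffle(order)  # stand-in for nondeterministic thread scheduling
    for i in order:
        row = grad_weight[indices[i]]
        for j, g in enumerate(grad_out[i]):
            row[j] += g
    return grad_weight


def backward_sorted(indices, grad_out, num_embeddings):
    """Sort samples by embedding index, then reduce each run of equal keys
    sequentially; the summation order is fixed, so the result is deterministic."""
    dim = len(grad_out[0])
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    for i in sorted(range(len(indices)), key=lambda k: indices[k]):  # stable sort
        row = grad_weight[indices[i]]
        for j, g in enumerate(grad_out[i]):
            row[j] += g
    return grad_weight


# Floating-point addition is not associative, so accumulation order matters:
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```

With indices [0, 0, 0] and gradients 0.1, 0.2 and 0.3 (three samples colliding on one embedding), backward_sorted always returns the same value, while backward_atomic can differ in the last bits from run to run depending on the order in which the additions happen.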

Check the performance benchmark results below for more details.


This module works with Python 3.6 and PyTorch 0.4.1 only. There's also a branch compatible with PyTorch 0.4.0.

To install FastEmbedding, run:

python3 setup.py install

After that, the module can be imported as follows:

from fast_embedding import FastEmbedding


The FastEmbedding class is an ordinary torch.nn.Module and can be used almost exactly like the original torch.nn.Embedding module. The constructor takes two arguments: the dictionary size and the embedding dimension.

FastEmbedding does not support:

  • padding,
  • renormalization,
  • scaling by frequency,
  • sparse weights.

FastEmbedding can be constructed with a custom initial weight by passing the _weight argument.


FastMultiEmbedding is another nn.Module, intended for cases with multiple different embeddings of very small dimension. It stores all the weights in a single tensor and performs all the lookups at once. In tests with 10 embeddings of dimensions 1 to 32, speedups ranging from 2x to 35x were observed.


The constructor takes two arguments: a list of dictionary sizes and a list of embedding dimensions.

The FastMultiEmbedding module performs the lookup and concatenation of n embeddings of dimensions d1, d2, ..., dn. The input tensor must be two-dimensional with shape batchSize x n, and the output tensor has shape batchSize x (d1 + d2 + ... + dn).

For example:

from fast_embedding import FastMultiEmbedding

emb = FastMultiEmbedding([100, 200, 300], [2, 4, 8])

is equivalent to:

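import torch
import torch.nn as nn

# One ordinary embedding table per input column
embeddings = [
    nn.Embedding(100, 2),
    nn.Embedding(200, 4),
    nn.Embedding(300, 8)
]

def emb(batch):
    # batch: LongTensor of shape (batchSize, 3); output: (batchSize, 2 + 4 + 8)
    return torch.cat([
        embeddings[0](batch[:, 0]),
        embeddings[1](batch[:, 1]),
        embeddings[2](batch[:, 2])
    ], 1)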
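One way to picture "all the weights in a single tensor" is to place each table in its own contiguous block of one flat buffer and to give each table a fixed column range in the output row. The sketch below is a hypothetical pure-Python model of that idea: the class name MultiTableSketch and the flat block layout are assumptions, not FastMultiEmbedding's actual memory layout, and only the forward lookup is modelled. It does reproduce the documented shape rule: a batchSize x n input yields a batchSize x (d1 + ... + dn) output.

```python
from itertools import accumulate


class MultiTableSketch:
    """Hypothetical model of n embedding tables stored in one flat buffer."""

    def __init__(self, num_embeddings, embedding_dims, init=0.0):
        assert len(num_embeddings) == len(embedding_dims)
        self.nums = list(num_embeddings)
        self.dims = list(embedding_dims)
        sizes = [n * d for n, d in zip(self.nums, self.dims)]
        # Start of table k inside the flat buffer: exclusive prefix sum of sizes.
        self.base = [0] + list(accumulate(sizes))[:-1]
        # Column where table k's vector begins in each output row.
        self.out_offset = [0] + list(accumulate(self.dims))[:-1]
        self.out_dim = sum(self.dims)
        self.weight = [init] * sum(sizes)  # the single shared buffer

    def forward(self, batch):
        """batch: batchSize rows of n indices -> batchSize rows of sum(d) floats."""
        out = []
        for row in batch:
            assert len(row) == len(self.dims)
            vec = [0.0] * self.out_dim
            for k, x in enumerate(row):  # one lookup per table, all in one pass
                assert 0 <= x < self.nums[k]
                d = self.dims[k]
                start = self.base[k] + x * d
                o = self.out_offset[k]
                vec[o:o + d] = self.weight[start:start + d]
            out.append(vec)
        return out


emb = MultiTableSketch([100, 200, 300], [2, 4, 8])
print(emb.out_dim)      # 14
print(len(emb.weight))  # 100*2 + 200*4 + 300*8 = 3400
```

Keeping every table in one buffer means a single pass (or, on a GPU, a single kernel launch) can serve all n lookups, instead of n separate nn.Embedding calls followed by torch.cat.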

Performance benchmarks


Plots below compare forward() + backward() execution times of nn.Embedding and FastEmbedding.

The y-axis shows the ratio of nn.Embedding execution time to FastEmbedding execution time, so values above 1 mean FastEmbedding is faster.
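The plotted quantity is a speedup, ratio = t(nn.Embedding) / t(FastEmbedding). A minimal timing helper in that spirit is sketched below; median_time and time_ratio are hypothetical names and are not taken from tests/. Because CUDA kernels are launched asynchronously, a callable that runs GPU work should call torch.cuda.synchronize() before returning, otherwise the clock stops before the work has finished; the sketch leaves that to the callables passed in.

```python
import time


def median_time(fn, repeats=21):
    """Median wall-clock time of fn() in seconds, after one warm-up call."""
    fn()  # warm-up run, excluded from the measurement
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()  # e.g. forward() + backward(); GPU code must synchronize inside fn
        samples.append(time.perf_counter() - start)
    samples.sort()
    return samples[len(samples) // 2]


def time_ratio(baseline, candidate, repeats=21):
    """Baseline-to-candidate execution time ratio; > 1 means candidate is faster."""
    return median_time(baseline, repeats) / median_time(candidate, repeats)
```

Using the median rather than the mean keeps a single slow outlier (for example, a one-off allocation) from skewing the ratio.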

For more details, please refer to the tests/ directory.

[Figure: FastEmbedding benchmark]


Plots below compare forward() + backward() execution times of nn.Embedding, FastEmbedding and FastMultiEmbedding.

The y-axes show the ratio of nn.Embedding execution time to FastEmbedding and to FastMultiEmbedding execution time, respectively.

[Figure: FastMultiEmbedding benchmark]
