In [1]:
from adaptnlp import EasyWordEmbeddings, EasyStackedEmbeddings, EasyDocumentEmbeddings

# Easy Embeddings

### Pretrained model names (keys) are listed in the Hugging Face Transformers documentation and in Flair's tutorials

## Producing word embeddings with EasyWordEmbeddings

In [2]:
# Sample input used by every embedding example below (three short sentences)
example_text = (
    "This is Albert.  My last name is Einstein.  "
    "I like physics and atoms."
)

In [3]:
# Instantiate the word-embeddings helper (not a tagger); the language model
# is chosen per call through `model_name_or_path` in `embed_text`
embeddings = EasyWordEmbeddings()

In [4]:
# Get GPT2 embeddings of the example text.
# `embed_text` returns a list of flair Sentence objects.
sentences = embeddings.embed_text(example_text, model_name_or_path="gpt2")
# Display the embedding of the first token of the first Sentence
# (index directly instead of looping and breaking after one iteration)
sentences[0].tokens[0].get_embedding()

tensor([-0.1524, -0.0703,  0.5778,  ..., -0.3797, -0.3565,  2.4139],
       device='cuda:0')


In [5]:
# Same thing but with BERT embeddings.
# `embed_text` returns a list of flair Sentence objects.
sentences = embeddings.embed_text(example_text, model_name_or_path="bert-base-cased")
# Display the embedding of the first token of the first Sentence
sentences[0].tokens[0].get_embedding()

tensor([ 0.5918, -0.4142,  1.0203,  ...,  0.4004, -0.1586,  1.0107],
       device='cuda:0')


In [6]:
# Same thing but with roBERTa embeddings.
# `embed_text` returns a list of flair Sentence objects.
sentences = embeddings.embed_text(example_text, model_name_or_path="roberta-base")
# Display the embedding of the first token of the first Sentence
sentences[0].tokens[0].get_embedding()

tensor([ 3.3757e-02,  5.2783e-01, -6.2026e-02, -1.0129e-01,  8.1527e-01,
         5.0778e-01, -6.9540e-02,  7.9886e-02, -3.3002e-01,  8.0280e-02,
        -9.5056e-02, -3.7590e-01, -2.4488e-01,  1.2541e-02,  1.4148e-01,
         2.9701e-01, -4.6033e-01, -1.7297e-01,  9.4156e-03, -1.6250e-01,
        -2.4242e-01,  3.7214e-01,  8.8796e-03,  2.1160e-01, -2.4286e-01,
         1.3693e-01,  2.5863e-01, -2.0122e-01, -4.2645e-02,  1.0488e-01,
        -3.9044e-02, -1.1922e-01,  1.6976e-01,  3.5874e-01,  3.6035e-02,
        -1.4893e-02,  3.6945e-01,  1.3586e-01,  1.5267e-01,  1.8039e-02,
        -1.1443e-02, -3.8058e-01, -9.9154e-02,  2.1226e-01, -4.2931e-03,
         2.0853e-01, -4.1712e-01,  8.6358e-03,  2.5362e-02, -1.2906e-02,
        -2.4237e-01,  6.7392e-02, -8.2088e-02, -1.0465e-01, -2.8965e-01,
         4.4008e-01, -2.1734e-01, -5.6794e-02,  1.9137e-01, -1.8721e-01,
        -4.6830e-02,  1.7417e-01, -2.8660e-01,  9.6234e-02,  4.2202e-01,
        -1.3933e-01,  7.8239e-02,  2.0238e-01,  2.2

## Producing stacked embeddings with EasyStackedEmbeddings

In [7]:
# Stacked embeddings concatenate the vectors of several language models;
# the constructor takes any number of model names as positional arguments
stacked_model_names = ("bert-base-cased", "xlnet-base-cased")
embeddings = EasyStackedEmbeddings(*stacked_model_names)

May need a couple moments to instantiate...


In [8]:
# Get stacked/concatenated word embeddings.
# `embed_text` returns a list of flair Sentence objects.
sentences = embeddings.embed_text(example_text)
# Display the stacked embedding of the first token of the first Sentence
sentences[0].tokens[0].get_embedding()

tensor([ 0.5918, -0.4142,  1.0203,  ..., -0.1045, -1.2841,  0.0192],
       device='cuda:0')


## Producing document embeddings with EasyDocumentEmbeddings

In [9]:
# Build document embeddings on top of several language models; the
# constructor takes any number of model names as positional arguments
document_model_names = ("bert-base-cased", "xlnet-base-cased")
embeddings = EasyDocumentEmbeddings(*document_model_names)

May need a couple moments to instantiate...
Pooled embedding loaded
RNN embeddings loaded


In [10]:
# Document pool embedding: each returned Sentence carries one document-level vector
sentences = embeddings.embed_pool(example_text)
# Print the document embedding of every returned Sentence
for document in sentences:
    document_embedding = document.get_embedding()
    print(document_embedding)

tensor([ 0.4216,  0.0123,  0.3136,  ..., -0.0683, -0.3761, -0.0974],
       device='cuda:0', grad_fn=<CatBackward>)


In [11]:
# Same text, this time through the document RNN embedding.
# NOTE(review): the RNN appears not to be trained in this notebook, so these
# values may differ between runs — confirm before relying on them.
sentences = embeddings.embed_rnn(example_text)
# Print the document embedding of every returned Sentence
for document in sentences:
    document_embedding = document.get_embedding()
    print(document_embedding)

tensor([-0.7386, -0.0982,  0.5763,  0.5601, -0.2424,  0.5826,  0.2657, -0.2311,
        -0.1804, -0.0853,  0.1215, -0.2035, -0.1547,  0.7205, -0.0699, -0.2292,
         0.5321,  0.4236,  0.1684, -0.2774, -0.5438, -0.5345,  0.4484,  0.2757,
        -0.4520, -0.6414,  0.6564, -0.2088, -0.1627,  0.0736, -0.5191, -0.7877,
        -0.0214,  0.2722, -0.5574,  0.7462, -0.1740,  0.4776,  0.7506,  0.1386,
        -0.2258,  0.4511, -0.0340,  0.2455, -0.3466, -0.2498, -0.5203, -0.1051,
        -0.1231,  0.1517,  0.1652, -0.3973, -0.6091,  0.0515,  0.6699,  0.3800,
         0.0970,  0.6693, -0.0496,  0.0326, -0.6084, -0.5501, -0.0866, -0.4065,
        -0.0866,  0.1535, -0.3795, -0.2964,  0.3861, -0.0735, -0.0082, -0.5337,
         0.5857,  0.2368,  0.1144, -0.1635, -0.2295, -0.2913, -0.1221,  0.1203,
        -0.6785, -0.2888,  0.1972,  0.3157,  0.4167,  0.1574,  0.0319,  0.3106,
         0.2389, -0.2796,  0.2057, -0.3664,  0.4119,  0.6192, -0.0444, -0.2788,
         0.7154,  0.3747,  0.4265,  0.07