In [1]:
from flair.embeddings import TransformerDocumentEmbeddings
from flair.data import Sentence

In [2]:
embedding = TransformerDocumentEmbeddings('roberta-base')

In [3]:
sentence = Sentence('The grass is green')
embedding.embed(sentence)

[Sentence: "The grass is green"]

In [4]:
print(sentence.get_embedding().size())

torch.Size([768])


In [5]:
import json

In [7]:
with open('instruction_set.json', 'r') as handler:
    instructions = json.load(handler)

In [8]:
from collections import defaultdict


embeddings = {"train": defaultdict(list), "test": defaultdict(list)}
for mode, benchmarks in instructions.items():
    for key, val in benchmarks.items():
        print(f'processing: {key}')
        for i in val:
            print(f'\tinstruction: {i}')
            sentence = Sentence(i)
            embedding.embed(sentence)

            embeddings[mode][key].append(sentence.get_embedding())


processing: reach
	instruction: reach
	instruction: lift up claw
processing: push
	instruction: push
	instruction: push object forward
	instruction: push object forward with claw
processing: pick-place
	instruction: pick and place
	instruction: pick object then place
processing: door-open
	instruction: open door
	instruction: pull back door
	instruction: grab door handle and pull back
processing: drawer-close
	instruction: close drawer
	instruction: push drawer forward
	instruction: grab drawer handle and push forward
	instruction: claw drawer handle and push forward
processing: button-press-topdown
	instruction: press button
	instruction: press object down
processing: peg-insert-side
	instruction: pick and insert
	instruction: pick object and place from left
processing: window-open
	instruction: open window
	instruction: grab window handle and slide
	instruction: sweep window handle from left
	instruction: sweep object from left to right
processing: basketball
	instruction: pick and p

In [9]:
embeddings

{'train': defaultdict(list,
             {'reach': [tensor([-5.6498e-02,  9.1725e-02, -3.0494e-03, -9.9331e-02,  7.7074e-02,
                       -9.2331e-02, -3.9850e-02,  4.0786e-02,  5.8545e-02, -6.0423e-02,
                       -4.5134e-03,  4.3720e-02,  4.3269e-02, -3.2881e-02,  8.0648e-02,
                        2.8199e-02, -7.1061e-02,  1.6550e-02,  2.9523e-02, -2.5358e-02,
                       -1.1265e-01,  2.7606e-02, -4.5263e-02,  1.1176e-01, -2.9239e-03,
                        1.5195e-02,  6.1748e-02,  7.4846e-02, -6.9118e-02, -3.1089e-02,
                       -2.5063e-02, -3.9352e-02,  5.0516e-02, -3.0511e-02,  3.2795e-02,
                        6.2434e-02,  3.4659e-02, -4.7221e-03, -1.0785e-01,  1.7888e-02,
                       -2.0268e-02,  5.8397e-02,  1.2740e-02,  1.3831e-02,  8.1842e-02,
                        3.0441e-02,  2.3627e-02,  2.2674e-03, -3.7224e-02,  5.3170e-03,
                        2.2361e-02,  9.3582e-02, -4.0727e-02,  1.2557e-02, -8.9593e

In [13]:
import pickle
with open('embeddings.pkl', 'wb') as handler:
    pickle.dump(embeddings, handler)

In [14]:
with open('embeddings.pkl', 'rb') as handler:
    x = pickle.load(handler)

In [15]:
len(x['train']['reach'])

2