In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
from syft import make_plan
from syft.lib.python import String, List
from syft import logger
import torch
import transformers

sy.load('transformers')

# Create client and test string
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()

serde = lambda x: x.send(alice_client).get()

In [3]:
string = String("test")
string_ptr = string.send(alice_client)

batch = List([1])
batch_ptr = batch.send(alice_client)

batch = List(["test"])
batch_ptr = batch.send(alice_client)

In [11]:
# Load Tokenizer and Model
from transformers import AutoTokenizer

model_name = 'distilbert-base-uncased'
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

print(hf_tokenizer)

tok_serde = serde(hf_tokenizer)

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [12]:
# Test batchencoding serde

be_list = hf_tokenizer(["test", "test2"], padding=True, truncation=True, max_length=100)
be_serde = serde(be_list)
assert be_serde == be_list

be_tensor = hf_tokenizer(["test", "test2"], padding=True, return_tensors='pt', truncation=True, max_length=100)
be_serde = serde(be_tensor)
for k, v in be_tensor.items():
    assert torch.eq(v, be_serde[k]).all()

In [7]:
# syft relative
from syft.proto.lib.transformers.tokenizerfast_pb2 import TokenizerFast as TokenizerFast_PB
from syft.generate_wrapper import GenerateWrapper
from syft.lib.python.primitive_factory import PrimitiveFactory
from syft import serialize, deserialize
from syft.lib.python.util import upcast

# Third party
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer


def object2proto(obj: PreTrainedTokenizerFast) -> TokenizerFast_PB:
    # TODO Tokenizer.to_str serializes to a json string.
    # Protobuf has protobuf.json_format, which might be better than sending a raw string.

    tokenizer_str = obj._tokenizer.to_str()
    tokenizer_str = PrimitiveFactory.generate_primitive(value=tokenizer_str)
    
    kwargs = obj.special_tokens_map
    kwargs["name_or_path"] = obj.name_or_path
    kwargs["padding_side"] = obj.padding_side
    kwargs["model_max_length"] = obj.model_max_length
    kwargs = PrimitiveFactory.generate_primitive(value=kwargs)

    protobuf_tokenizer = TokenizerFast_PB(
        id=kwargs.id._object2proto(),
        tokenizer=tokenizer_str._object2proto(),
        kwargs=kwargs._object2proto()
    )
    return protobuf_tokenizer

def proto2object(proto: TokenizerFast_PB) -> PreTrainedTokenizerFast:
    _tokenizer = Tokenizer.from_str(proto.tokenizer.data)
    kwargs = deserialize(proto.kwargs)
    kwargs = upcast(kwargs)
    
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=_tokenizer,
        **kwargs
    )
    return tokenizer

In [8]:
proto = object2proto(hf_tokenizer)
proto2object(proto)

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='left', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [9]:
hf_tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [20]:
dir(hf_tokenizer)

['SPECIAL_TOKENS_ATTRIBUTES',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_tokens',
 '_additional_special_tokens',
 '_batch_encode_plus',
 '_bos_token',
 '_cls_token',
 '_convert_encoding',
 '_convert_id_to_token',
 '_convert_token_to_id_with_added_voc',
 '_decode',
 '_decode_use_source_tokenizer',
 '_encode_plus',
 '_eos_token',
 '_eventual_warn_about_too_long_sequence',
 '_from_pretrained',
 '_get_padding_truncation_strategies',
 '_mask_token',
 '_pad',
 '_pad_token',
 '_pad_token_type_id',
 '_push_to_hub',
 '_save_pretrained',
 '_sep_token',
 '_tokenizer',
 '_unk_token',
 'add_special_tokens',
 'add_tokens',
 'ad

In [27]:
from syft.lib.python.util import upcast
string = String("test")
type(upcast(string))

str