In [1]:
import time
import math
from pprint import pprint
from itertools import islice

import numpy as np
import torch
from torch import nn

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import make_plan, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft import logger

import transformers
from transformers.models.distilbert.modeling_distilbert import DistilBertConfig
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerFast
import datasets

In [2]:
# Create client
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()
remote_torch = ROOT_CLIENT.torch

## Load DistilBert

In [3]:
# Load model
model = transformers.AutoModel.from_pretrained("distilbert-base-uncased")
tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")
config = model.config

# Turn off dropout
model = model.eval()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
%%time
model_ptr = model.send(alice_client)

[2021-06-14T17:46:25.330900+0200][CRITICAL][logger]][26185] Path FFN not present in the AST.


AttributeError: 'FFN' object has no attribute '_forward_plan'

In [5]:
%%time
tokenizer_ptr = tokenizer.send(alice_client)

CPU times: user 43.7 ms, sys: 1.68 ms, total: 45.4 ms
Wall time: 44.5 ms


## Load some data

In [6]:
train_set = datasets.load_dataset('sst', split='train')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

num_labels = len(set(train_set['label']))
config.num_labels = num_labels

bx = next(iter(train_loader))['sentence']
bx_ptr = sy.lib.python.List(bx).send(alice_client)

No config specified, defaulting to: sst/default
Reusing dataset sst (/home/eelco/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff)


In [7]:
# Tokenize on client
bt_ptr = tokenizer_ptr(bx_ptr, padding=True, return_tensors='pt', truncation=True)
print(bt_ptr)

<syft.proxy.transformers.tokenization_utils_base.BatchEncodingPointer object at 0x7f57441638e0>


In [8]:
# Forward on client
model_ptr.eval()
out_ptr = model_ptr(bt_ptr['input_ids'], bt_ptr['attention_mask'], return_dict=False)
client_out = out_ptr.get()

  grad = getattr(obj, "grad", None)


In [9]:
bt = tokenizer(bx, padding=True, return_tensors='pt', truncation=True)
local_out = model(bt['input_ids'], bt['attention_mask'], return_dict=False)

In [10]:
torch.mean(torch.abs(client_out[0] - local_out[0]))

tensor(0., grad_fn=<MeanBackward0>)