In [1]:
# If needed, install Hugging Face `datasets` by uncommenting the line below.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# %pip install datasets

In [2]:
import time
import math
from pprint import pprint
from itertools import islice

import numpy as np
import torch
from torch import nn

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import make_plan, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft import logger
from syft import SyModule
from syft.lib.transformers.models.distilbert import SyDistilBert

from transformers.models.distilbert.modeling_distilbert import DistilBertConfig
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerFast
import datasets

# NOTE(review): presumably drops the handlers on syft's logger so that log
# messages do not clutter the notebook output — confirm.
logger.remove()
# NOTE(review): presumably registers syft's support for the `transformers`
# package (needed before tokenizers/models are passed to plans) — confirm.
sy.load('transformers')

In [3]:
# Create client
# `alice` is a local VirtualMachine; `alice_client` is the client the train
# plan is sent to later in the notebook.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()
# Torch namespace of ROOT_CLIENT. The classifier and the train plan call torch
# functions through `remote_torch` instead of through `torch` directly.
remote_torch = ROOT_CLIENT.torch

In [4]:
# Setup

# Use distilbert weights distilled from `bert-base-uncased`;
# other pretrained configurations can be found here:
# https://huggingface.co/transformers/pretrained_models.html

# Seed torch's global RNG so that DataLoader shuffling, weight initialization
# and dropout are repeatable under "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

batch_size = 64
model_name = 'distilbert-base-uncased'
# Configuration of the pretrained checkpoint named above.
config = AutoConfig.from_pretrained(model_name)

# Data

We will use the IMDB dataset for binary sentiment classification:

https://huggingface.co/datasets/imdb

In [5]:
# Load the IMDB training split and wrap it in a shuffling DataLoader.
train_set = datasets.load_dataset('imdb', split='train')
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)

# The number of classes is the number of distinct label values in the split;
# record it on the config as well.
num_labels = len(set(train_set['label']))
config.num_labels = num_labels

# Dataset summary, a blank line, then the first example with its label.
print(train_set, end="\n\n")
first_example = train_set[0]
print(first_example['text'])
print("label:", first_example['label'])

Reusing dataset imdb (/home/eelco/.cache/huggingface/datasets/imdb/plain_text/1.0.0/4ea52f2e58a08dbc12c2bd52d0d92b30b88c00230b4522801b3636782f625c5b)


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High's satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I'm here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn't!
label: 1


# Tokenizer

The `PreTrainedTokenizerFast` tokenizer used by most transformer models can now be serialized by Syft, so we can just load one here.

In [6]:
# Load the tokenizer that belongs to the pretrained checkpoint.
auto_tokenizer = AutoTokenizer.from_pretrained(model_name)

# Change tokenizer to a PreTrainedTokenizerFast,
# instead of the type returned by AutoTokenizer.
# `backend_tokenizer` is the public accessor for the wrapped `tokenizers` object,
# so the private `_tokenizer` attribute does not need to be touched.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=auto_tokenizer.backend_tokenizer,
    name_or_path=auto_tokenizer.name_or_path,
    padding_side=auto_tokenizer.padding_side,
    model_max_length=auto_tokenizer.model_max_length,
    **auto_tokenizer.special_tokens_map
)

# Set small model_max_length for faster testing
tokenizer.model_max_length = 128

print(tokenizer)

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=128, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [7]:
# Test tokenizer on a single batch

# Pull one batch from the loader and tokenize its texts into padded,
# truncated PyTorch tensors.
bx = next(iter(train_loader))
tokenized = tokenizer(
    bx['text'],
    padding=True,
    truncation=True,
    return_tensors='pt',
)
pprint(tokenized)
print()
# One size per tensor returned by the tokenizer.
print("sizes:", {name: tensor.size() for name, tensor in tokenized.items()})

{'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]),
 'input_ids': tensor([[ 101, 1045, 1010,  ..., 3286, 1010,  102],
        [ 101, 2049, 4234,  ..., 4569, 2007,  102],
        [ 101, 1999, 1996,  ..., 1007, 1010,  102],
        ...,
        [ 101, 2073, 2000,  ..., 2431, 2126,  102],
        [ 101, 3462, 1997,  ..., 2036, 3625,  102],
        [ 101, 2054, 2019,  ..., 1012, 1026,  102]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])}

sizes: {'input_ids': torch.Size([64, 128]), 'token_type_ids': torch.Size([64, 128]), 'attention_mask': torch.Size([64, 128])}


# DistilBERT for sentiment analysis

A simple classifier using DistilBERT, copied from [DistilBertForSequenceClassification](https://github.com/huggingface/transformers/blob/61c506349134db0a0a2fd6fb2eff8e29a2f84e79/src/transformers/models/distilbert/modeling_distilbert.py#L578).

In [8]:
class DistilBertClassifier(SyModule):
    """Sequence classifier: an encoder module followed by a small head.

    The head takes the encoder output at sequence position 0 and applies
    linear -> ReLU -> dropout -> linear to produce the logits.
    """

    def __init__(self, base_model: SyModule, config: DistilBertConfig, **kwargs):
        """Assemble the classifier around an existing encoder.

        Args:
            base_model: encoder module; `forward` calls it with `input_ids`
                and `attention_mask` and uses element 0 of its result.
            config: supplies `dim`, `num_labels` and `seq_classif_dropout`.
            **kwargs: passed through unchanged to the parent constructor.
        """
        super().__init__(**kwargs)

        self.config = config
        self.num_labels = config.num_labels

        hidden_size = config.dim
        self.distilbert = base_model
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        # ReLU is taken from `remote_torch`, not from `torch` directly.
        self.activation = remote_torch.nn.functional.relu

    def forward(self, input_ids, attention_mask):
        """Compute classification logits for a tokenized batch.

        Args:
            input_ids: token ids, forwarded to the encoder.
            attention_mask: attention mask, forwarded to the encoder.

        Returns:
            The output of the final linear layer, annotated (bs, num_labels).
        """
        encoder_result = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = encoder_result[0]

        # Keep only the first sequence position of every example.
        first_token = hidden_states[:, 0]  # (bs, dim)
        projected = self.pre_classifier(first_token)  # (bs, dim)
        activated = self.activation(projected)  # (bs, dim)
        regularized = self.dropout(activated)  # (bs, dim)

        return self.classifier(regularized)  # (bs, num_labels)

In [9]:
# TEST CONFIG

# Until Issue #5627 gets resolved, we use a tiny bert config to make the model serializable.
# When fixed, remove this cell to use a full model with pretrained weights.

# This object replaces the pretrained `config`, so the label count derived from
# the dataset has to be passed again; otherwise the classifier head would fall
# back to the library's default number of labels.
# NOTE(review): max_position_embeddings=129 is one more than the tokenizer's
# model_max_length of 128 — confirm that the extra position is intended.
config = DistilBertConfig(
    vocab_size=tokenizer.vocab_size,
    dim=10,
    max_position_embeddings=129,
    n_heads=2,
    hidden_dim=10,
    n_layers=1,
    num_labels=num_labels,
)

In [10]:
# Small model
# Build the base model from `config` (the tiny test config while that cell is present).
base_model = SyDistilBert.from_config(config)

# Large model with pretrained weights
# base_model = SyDistilBert.from_pretrained(model_name)

# The classifier receives the base model's `inputs` as its own `inputs` keyword,
# which DistilBertClassifier.__init__ forwards to the parent constructor via **kwargs.
classifier = DistilBertClassifier(base_model, config, inputs=base_model.inputs)
print(classifier)

transformer init: 0.21 s
transformer forward: 0.10 s
DistilBertClassifier(
  (distilbert): SyDistilBert(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 10, padding_idx=0)
      (position_embeddings): Embedding(129, 10)
      (LayerNorm): LayerNorm((10,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=10, out_features=10, bias=True)
            (k_lin): Linear(in_features=10, out_features=10, bias=True)
            (v_lin): Linear(in_features=10, out_features=10, bias=True)
            (out_lin): Linear(in_features=10, out_features=10, bias=True)
          )
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features=10, out_features=10, bias

In [11]:
# Test a local forward pass

batch = next(iter(train_loader))
batch_x = tokenizer(batch['text'], padding=True,
                    return_tensors='pt', truncation=True)
batch_y = batch['label']
# The tokenizer also returns `token_type_ids`, which `forward` does not declare.
# Pass only the two declared inputs, exactly as the train plan does, instead of
# unpacking the whole tokenizer output.
logits = classifier(input_ids=batch_x['input_ids'],
                    attention_mask=batch_x['attention_mask'])[0]
print("logits shape:", logits.shape)

logits shape: torch.Size([64, 2])


  grad = getattr(obj, "grad", None)


# Train Plan

In [12]:
# A single batch wrapped in a syft List; it is the default value of the plan's
# `batches` argument below.
# NOTE(review): make_plan presumably traces `train` using these default
# arguments as example inputs — confirm.
dummy_batches = sy.lib.python.List([next(iter(train_loader))])

@make_plan
def train(classifier=classifier, tokenizer=tokenizer, batches=dummy_batches):
    """
    Train classifier on batches, and return updated classifier

    Args:
        classifier: model to optimize; called with `input_ids` and `attention_mask`.
        tokenizer: applied to each batch's 'text' entry with padding and truncation,
            returning 'pt' tensors.
        batches: iterable of batches, each indexable by 'text' and 'label'.

    Returns:
        A one-element list containing the classifier.
    """
    # The optimizer comes from `remote_torch`, not from `torch` directly.
    opt = remote_torch.optim.AdamW(classifier.parameters(), lr=1e-3)
    
    # NOTE(review): this loop runs while the plan is being built with
    # `dummy_batches` (one batch) — confirm that the sent plan iterates over
    # every batch it is given at execution time.
    for batch in batches:
        classifier.train()
        opt.zero_grad()
        
        # Prepare data
        batch_x = tokenizer(batch['text'], padding=True, 
                            return_tensors='pt', truncation=True)
        batch_y = batch['label']
        
        # Forward, loss, backward
        # Element 0 of the classifier's result is used as the logits.
        out = classifier(input_ids=batch_x['input_ids'],
                         attention_mask=batch_x['attention_mask'])[0]
        loss = remote_torch.nn.functional.cross_entropy(out, batch_y)
        loss.backward()
        opt.step()
    # Returned as a list; the caller indexes element 0 after fetching the result.
    return [classifier]

## Train

In [13]:
# Train on 10 batches and return the updated model

# Take the first 10 batches from the loader and wrap them in a syft List.
train_batches = sy.lib.python.List(islice(iter(train_loader), 10))
# Send the plan to alice's client, then call the returned handle with the
# local classifier, tokenizer and batches.
train_ptr = train.send(alice_client)
classifier_ptr = train_ptr(classifier=classifier, tokenizer=tokenizer, batches=train_batches)

# Fetch the result; the plan returns `[classifier]`, hence the `[0]`.
classifier_updated = classifier_ptr.get()[0]

In [14]:
# Sanity check: check if the parameters have updated

def check_same_parameters(model1, model2) -> bool:
    """Return True only if both models hold identical parameters.

    Parameters are compared pairwise in the order given by `parameters()`.

    Args:
        model1: first model; must provide `parameters()`.
        model2: second model; must provide `parameters()`.

    Returns:
        False if the models have a different number of parameters, or if any
        pair differs in shape or in value; True otherwise.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # A bare zip() would stop at the shorter list and could report two models
    # of different size as identical.
    if len(params1) != len(params2):
        return False
    # torch.equal returns False on a shape mismatch instead of raising.
    return all(torch.equal(p1.data, p2.data) for p1, p2 in zip(params1, params2))

# `False` is the expected outcome here: it means training changed the parameters.
print("models have same parameters:", check_same_parameters(classifier, classifier_updated))

models have same parameters: False
