In [1]:
import time
import math
from pathlib import Path

import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
from pprint import pprint

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft.lib.python.list import List
from syft import logger
from syft import SyModule

# transformers imports, not needed in AST
from transformers.models.distilbert.modeling_distilbert import DistilBertConfig, create_sinusoidal_embeddings
from transformers import AutoModel

# Add in AST
from transformers.activations import gelu

logger.remove()

In [2]:
# Create a VirtualMachine named "alice" and obtain a client connected to it;
# plans built later in this notebook are sent to this client.
alice = VirtualMachine(name="alice")
alice_client = alice.get_client()

# `torch` namespace of the plan builder's ROOT_CLIENT; used inside the
# SyModule forward methods and plans in place of the local `torch` module.
remote_torch = ROOT_CLIENT.torch

# DistilBERT

- [x] PretrainedTokenizerFast
- [ ] DistilBertModel
    - [x] Embedding
    - [x] MultiHeadSelfAttention
        - [ ] gelu in AST (torch._C issue)
    - [x] FFN
        - [ ] Chunking
    - [x] TransformerBlock
    - [X] Transformer
        - [X] SyModuleList
    - [ ] Shape issue in SyDistilBert when batch size changes
- [ ] Fix `nn.ModuleList.__iter__` in AST: incorrect Pointer type
- [X] Load pretrained weights
- [ ] Training example on real dataset
- [ ] Classifier (DistilBertForSequenceClassification)
- [ ] Schedulers
- [ ] Metrics

In [3]:
# Deliberately tiny hyper-parameters so the model can be built and sent quickly while testing.
vocab_size, max_length = 10, 100
dim, hidden_dim = 256, 128
n_heads = n_layers = 2

# Any DistilBertConfig field not listed here keeps its library default.
config = DistilBertConfig(
    vocab_size=vocab_size,
    max_position_embeddings=max_length,
    dim=dim,
    hidden_dim=hidden_dim,
    n_heads=n_heads,
    n_layers=n_layers,
)

## Embedding
https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L82

In [4]:
class Embeddings(SyModule):
    """Token + position embeddings followed by LayerNorm and dropout.

    The output for each token is the sum of its word embedding and the
    embedding of its position in the sequence.
    """

    def __init__(self, config, **kwargs):
        """
        Parameters:
            config: DistilBertConfig providing vocab_size, dim, pad_token_id,
                max_position_embeddings, sinusoidal_pos_embds and dropout.
            **kwargs: forwarded to SyModule (the callers in this notebook pass
                `inputs={'input_ids': ...}` as dummy inputs).
        """
        super().__init__(**kwargs)
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            # Overwrite the position table in place with sinusoidal values.
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids: torch.LongTensor):
        """
        Parameters:
            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
        Returns:
            torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings).
        """
        seq_length = input_ids.size(1)
        # TODO setting device from input_ids from remotely created tensor throws KeyError: UID <...>.
        position_ids = remote_torch.arange(seq_length)  # (max_seq_length)
        position_ids = remote_torch.unsqueeze(position_ids, 0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        # Look up positions by position index, NOT by token id: indexing the position
        # table with input_ids yields wrong embeddings and goes out of range as soon
        # as a token id is >= max_position_embeddings.
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings

# dummy_x = torch.ones(10, 100).long()
# sy_embedding = Embeddings(config, inputs={'input_ids': dummy_x})

## MultiHeadSelfAttention

https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L116

In [5]:
from syft.lib.python.int import Int

class MultiHeadSelfAttention(SyModule):
    """Multi-head scaled dot-product attention written as a SyModule.

    Only the attended context is returned; attention weights are not exposed
    and head pruning is not implemented.
    """

    def __init__(self, config, **kwargs):
        """
        Parameters:
            config: DistilBertConfig providing n_heads, dim and attention_dropout.
            **kwargs: forwarded to SyModule (the callers in this notebook pass
                `inputs={...}` with dummy query/key/value/mask tensors).
        """
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        # The model dimension is split evenly across the heads.
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Head pruning is not supported; always raises NotImplementedError."""
        raise NotImplementedError()

    def shape(self, x, bs, dim_per_head):
        """separate heads for linear layers"""
        return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

    def unshape(self, x, bs, dim_per_head):
        """group heads"""
        return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

    def forward(self, query, key, value, mask):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length) Positions where mask == 0 are excluded from attention.
        Returns:
            context: torch.tensor(bs, seq_length, dim) Contextualized layer output.
        """
        bs = query.size(0)
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads
        mask_reshp = (bs, 1, 1, k_length)

        q = self.shape(self.q_lin(query), bs, dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        k = self.shape(self.k_lin(key), bs, dim_per_head)  # (bs, n_heads, k_length, dim_per_head)
        v = self.shape(self.v_lin(value), bs, dim_per_head)  # (bs, n_heads, k_length, dim_per_head)

        # Scale the queries before the dot product (scaled dot-product attention).
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = remote_torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(*mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)

        # Masked positions get -inf so that softmax assigns them zero weight.
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
        weights = remote_torch.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        context = remote_torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = self.unshape(context, bs, dim_per_head)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        return context

# dummy inputs
# bs = 10
# seq_length = 15
# dim = config.dim

# dummy_q = torch.randn(bs, seq_length, dim)
# dummy_k = torch.randn(bs, seq_length, dim)
# dummy_v = torch.randn(bs, seq_length, dim)
# dummy_mask = torch.ones([bs, seq_length], dtype=torch.long)
# dummy_inputs = {
#     "query": dummy_q,
#     "key": dummy_k,
#     "value": dummy_v,
#     "mask": dummy_mask
# }

# sy_mhsa = MultiHeadSelfAttention(config, inputs=dummy_inputs)

## FFN

https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L203

In [6]:
class FFN(SyModule):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, config, **kwargs):
        """Build the two projections (dim -> hidden_dim -> dim), the activation and the dropout."""
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(config.dropout)
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(config.dim, config.hidden_dim)
        self.lin2 = nn.Linear(config.hidden_dim, config.dim)

        # TODO GeLU torch._C issue, not in AST.
        # Once gelu is available the activation should be selected from
        # `config.activation` ('relu' or 'gelu', raising ValueError otherwise),
        # e.g. remote_torch.nn.functional.gelu / remote_torch.nn.functional.relu.
        # Until then ReLU is always used, whatever the config says.
        self.activation = nn.ReLU()

    def forward(self, input):
        """Apply lin1, the activation, lin2 and dropout to `input`, in that order."""
        # TODO Chunking
        hidden = self.activation(self.lin1(input))
        return self.dropout(self.lin2(hidden))

# bs = 10
# seq_length = 15
# dim = config.dim

# dummy_input = torch.randn(bs, seq_length, dim)
# ffn_inputs = {
#     "input": dummy_input
# }

# sy_FFN = FFN(config, inputs=ffn_inputs)
# print(sy_FFN)

## Transformer

https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L273

In [7]:
class TransformerBlock(SyModule):
    """One encoder layer: self-attention and a feed-forward network, each followed
    by a residual connection and LayerNorm."""

    def __init__(self, config, **kwargs):
        """
        Parameters:
            config: DistilBertConfig shared with the sub-modules.
            **kwargs: must contain `inputs` with keys "x" and "attn_mask"; these dummy
                tensors are reused as the dummy inputs of the attention and FFN sub-modules.
        """
        super().__init__(**kwargs)
        assert config.dim % config.n_heads == 0

        block_inputs = kwargs["inputs"]
        dummy_x = block_inputs["x"]

        # Self-attention: query, key and value are all the same tensor.
        self.attention = MultiHeadSelfAttention(
            config,
            inputs={
                "query": dummy_x,
                "key": dummy_x,
                "value": dummy_x,
                "mask": block_inputs["attn_mask"],
            },
        )
        self.ffn = FFN(config, inputs={"input": dummy_x})

        self.sa_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.output_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)

    def forward(self, x, attn_mask):
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)
        Returns:
            torch.tensor(bs, seq_length, dim) The output of the block.
        """
        # Self-Attention; the first element of the sub-module's result is used.
        attended = self.attention(query=x, key=x, value=x, mask=attn_mask)[0]
        attended = self.sa_layer_norm(attended + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        projected = self.ffn(input=attended)[0]  # (bs, seq_length, dim)
        return self.output_layer_norm(projected + attended)  # (bs, seq_length, dim)

class Transformer(SyModule):
    """Stack of `config.n_layers` TransformerBlocks applied one after another."""

    def __init__(self, config, **kwargs):
        """Build the layer stack; every block receives the same dummy `inputs` from kwargs.

        Prints the wall-clock time spent in construction.
        """
        started = time.time()

        super().__init__(**kwargs)
        self.n_layers = config.n_layers

        blocks = []
        for _ in range(self.n_layers):
            blocks.append(TransformerBlock(config, inputs=kwargs["inputs"]))
        self.layer = nn.ModuleList(blocks)

        elapsed = time.time() - started
        print(f"transformer init: {elapsed:.2f} s")

    def forward(self, x, attn_mask):
        """
        Run `x` through every layer in order and return the final hidden state.

        Known limitations:
        - SyModule does not work; multiple inputs
        - getattr does not work;
        """
        started = time.time()
        hidden_state = x

        # TODO fix ModuleList.__iter__; items in iter need to be ModulePointer
        # Until then the layers are fetched by index instead of iterating the ModuleList.
        for idx in range(self.n_layers):
            hidden_state = self.layer[idx](x=hidden_state, attn_mask=attn_mask)[0]

        elapsed = time.time() - started
        print(f"transformer forward: {elapsed:.2f} s")
        return hidden_state

# bs = 10
# seq_length = 15
# dim = config.dim

# dummy_x = torch.randn(bs, seq_length, dim)
# dummy_mask = torch.ones([bs, seq_length], dtype=torch.long)
# transformer_dummy_inputs = {
#     "x": dummy_x,
#     "attn_mask": dummy_mask
# }
# sy_transformer = Transformer(config, inputs=transformer_dummy_inputs)

## SyDistilBert

In [8]:
class SyDistilBert(SyModule):
    """DistilBERT encoder (embeddings + transformer stack) built from SyModules."""

    def __init__(self, config: DistilBertConfig, **kwargs):
        """
        SyDistilBert is a re-implementation of huggingface DistilBert in pysyft,
        with all non-torch-native submodules rewritten as SyModules.

        Use the `from_pretrained` and `from_config` classmethods to instantiate this
        model from an existing HuggingFace pretrained model.

        Parameters:
            config: DistilBertConfig describing the architecture.
            **kwargs: must contain `inputs` with dummy "input_ids" and
                "attention_mask" tensors; they are used to derive the dummy
                inputs of the sub-modules.
        """
        super().__init__(**kwargs)
        self.config = config

        # Embeddings
        embedding_inputs = {
            'input_ids': kwargs['inputs']['input_ids']
        }
        self.embeddings = Embeddings(config=config, inputs=embedding_inputs)

        # Transformer: its dummy hidden state has the shape of input_ids plus a trailing `dim` axis.
        transformer_x = torch.rand(*kwargs['inputs']['input_ids'].size(), config.dim)
        transformer_mask = kwargs['inputs']['attention_mask']
        transformer_inputs = {
            "x": transformer_x,
            "attn_mask": transformer_mask
        }

        self.transformer = Transformer(config=config,
                                       inputs=transformer_inputs)

    def forward(self, input_ids, attention_mask):
        """
        Parameters:
            input_ids: torch.tensor(bs, seq_length) Token ids.
            attention_mask: torch.tensor(bs, seq_length) Mask passed to every attention layer.
        Returns:
            torch.tensor(bs, seq_length, dim) Output of the last transformer layer.
        """
        input_embeds = self.embeddings(input_ids=input_ids)[0]

        out = self.transformer(x=input_embeds,
                               attn_mask=attention_mask)[0]
        # Return the encoder output; returning `input_embeds` here would discard
        # the whole transformer stack.
        return out

    @classmethod
    def from_pretrained(cls, model_name: str) -> "SyDistilBert":
        """Build the model from a HuggingFace checkpoint name and copy its weights.

        Parameters:
            model_name: name or path accepted by `AutoModel.from_pretrained`.
        Returns:
            A SyDistilBert whose state dict was loaded from the HuggingFace model.
        """
        # Load huggingface model
        hf_model = AutoModel.from_pretrained(model_name)

        # Construct model with the same architecture (and the same dummy inputs) as from_config
        model = cls.from_config(hf_model.config)

        # Load weights
        model.load_state_dict(hf_model.state_dict())

        return model

    @classmethod
    def from_config(cls, config: DistilBertConfig) -> "SyDistilBert":
        """Build a model with freshly initialised weights from `config`.

        Parameters:
            config: DistilBertConfig describing the architecture.
        Returns:
            A new SyDistilBert constructed with (1, 1)-shaped dummy inputs.
        """
        # Make dummy inputs
        dummy_x = torch.ones(1, 1, dtype=torch.long)
        dummy_mask = torch.ones(1, 1, dtype=torch.long)

        dummy_inputs = {
            "input_ids": dummy_x,
            "attention_mask": dummy_mask
        }

        # Construct model
        model = cls(config, inputs=dummy_inputs)

        return model

## Test with small config

In [10]:
sydistilbert = SyDistilBert.from_config(config)

# Example inputs: used as default arguments when building the plan and as the
# actual arguments when the sent plan is executed below.
dummy_x = torch.ones(2, 2, dtype=torch.long)
dummy_mask = torch.ones(2, 2, dtype=torch.long)

@make_plan
def train(model=sydistilbert, x=dummy_x, attn_mask=dummy_mask):
    """Single training iteration"""
    opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
    # Use the plan argument `attn_mask`, not the global `dummy_mask`, so the
    # mask supplied when the plan is called is the one the model actually sees.
    out = model(input_ids=x,
                attention_mask=attn_mask)[0]
    loss = remote_torch.mean(out[0])
    loss.backward()
    opt.step()
    return [model]

# Send the plan to alice, run it there, and fetch the updated model back.
train_ptr = train.send(alice_client)
model_ptr = train_ptr(model=sydistilbert, x=dummy_x, attn_mask=dummy_mask)
updated_model = model_ptr.get()[0]
print(updated_model)

RECOMPILING Embeddings
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING Transformer
RECOMPILING Embeddings
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING M

## Load full DistilBERT

In [10]:
# Build SyDistilBert from the pretrained "distilbert-base-uncased" checkpoint.
# This takes a while (every SyModule sub-module is recompiled, see the log below),
# and loading the state_dict sometimes crashes the kernel.
sydistilbert = SyDistilBert.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
transformer init: 44.28 s
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FF

## Memory issue

In [9]:
# Memory issue when sending plan with distilbert-sized model
from transformers import AutoConfig

# Fully sized distilbert without pretrained weights (takes too long)
distilbert_config = AutoConfig.from_pretrained("distilbert-base-uncased")
sydistilbert = SyDistilBert.from_config(distilbert_config)

# Example inputs: used as default arguments when building the plan and as the
# actual arguments when the sent plan is executed below.
dummy_x = torch.ones(2, 2, dtype=torch.long)
dummy_mask = torch.ones(2, 2, dtype=torch.long)

@make_plan
def train(model=sydistilbert, x=dummy_x, attn_mask=dummy_mask):
    """Single training iteration on the full-size model."""
    opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
    # Use the plan argument `attn_mask`, not the global `dummy_mask`, so the
    # mask supplied when the plan is called is the one the model actually sees.
    out = model(input_ids=x,
                attention_mask=attn_mask)[0]
    loss = remote_torch.mean(out[0])
    loss.backward()
    opt.step()
    return [model]

# Sending/running the plan serializes the model; at this size it fails with the
# protobuf 2GB limit error shown in the cell output.
train_ptr = train.send(alice_client)
model_ptr = train_ptr(model=sydistilbert, x=dummy_x, attn_mask=dummy_mask)
updated_model = model_ptr.get()[0]
print(updated_model)

RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING FFN
transformer init: 46.83 s
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING TransformerBlock
RECOMPILING MultiHeadSelfAttention
RECOMPILING FFN
RECOMPILING MultiHeadSelfAttention
RECOMPILING FF

ValueError: Message syft.lib.torch.Module exceeds maximum protobuf size of 2GB: 3104940090