In [3]:
import math
import time
import warnings
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft.lib.python.list import List
from syft import logger
from syft import SyModule, SySequential

# transformers imports, not needed in AST
from transformers.models.distilbert.modeling_distilbert import (
    DistilBertConfig,
    DistilBertModel,
    create_sinusoidal_embeddings,
)

# Add in AST
from transformers.activations import gelu

# Remove syft's log handlers so plan building does not flood the notebook output
# (NOTE(review): assumes `syft.logger` is a loguru logger, where remove() drops all handlers).
logger.remove()

In [4]:
# Create client: an in-process syft VirtualMachine named "alice" and a client to it.
# The (commented-out) test cells send trained plans to `alice_client`.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()
# torch API of the plan builder's ROOT_CLIENT; SyModule.forward bodies below call
# torch functions through this handle (presumably so they are recorded into plans).
remote_torch = ROOT_CLIENT.torch

# DistilBERT

- [x] PretrainedTokenizerFast
- [ ] DistilBertModel
    - [x] Embedding
    - [x] MultiHeadSelfAttention
    - [x] FFN
        - [ ] Chunking
    - [x] TransformerBlock
    - [x] Transformer
        - [ ] SyModuleList
- [ ] Load pretrained weights

### TODO
- Defining dummy inputs by hand isn't a great interface.
    - Get inspiration from other libraries that pre-define inputs.
      Keras: https://keras.io/api/layers/core_layers/input/
    - Defining tensors: full control.
    - Defining a dummy `Input(shape, dtype, ...)`: easier to use and closer to other APIs.
- Support for `output_shape` or `make_dummy_output` in SyModule would be nice for nesting modules:
```
class Net(SyModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
        # Simple case
        self.layer1 = SyLayer(input_shape=kwargs["input_shape"])
        self.layer2 = SyLayer(input_shape=self.layer1.output_shape)
        
        # Advanced case
        self.layer1 = SyLayer(inputs=kwargs['inputs'])
        self.layer2 = SyLayer(inputs=self.layer1.make_dummy_output())
```


In [None]:
# Proposed API sketch for the TODO above. `Input` is not implemented in syft yet,
# so executing this raises NameError on a fresh kernel ("Run All" would stop here).
# Kept as a comment until a dummy-input helper exists.
# input_x = Input(
#     shape=(-1, -1, 10),
#     dtype=torch.long,
#     sparse=True
# )

In [8]:
# Small config for testing: much smaller than distilbert-base-uncased, so plans build quickly.
vocab_size = 10
dim = 256
max_length = 100
n_heads = 2
hidden_dim = 128
n_layers = 2
# MultiHeadSelfAttention splits `dim` evenly across heads; fail early with a clear message.
assert dim % n_heads == 0, f"dim ({dim}) must be divisible by n_heads ({n_heads})"

config = DistilBertConfig(
    vocab_size=vocab_size,
    dim=dim,
    max_position_embeddings=max_length,
    n_heads=n_heads,
    hidden_dim=hidden_dim,
    n_layers=n_layers,
)

# Reference HF model built from the same config (DistilBertModel comes from the imports cell).
model = DistilBertModel(config)
type(model.get_input_embeddings())

torch.nn.modules.sparse.Embedding

## Embedding
https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L82

In [4]:
class Embeddings(SyModule):
    """Word + position embeddings, then LayerNorm and dropout (port of HF DistilBERT `Embeddings`).

    Args:
        config: DistilBertConfig; reads vocab_size, dim, pad_token_id,
            max_position_embeddings, sinusoidal_pos_embds and dropout.
        **kwargs: forwarded to SyModule, e.g. ``inputs={"input_ids": dummy_ids}``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            # Overwrite the learned position table in place with fixed sinusoids.
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids: torch.LongTensor):
        """Embed token ids.

        Args:
            input_ids: token ids, presumably (bs, max_seq_length) -- dim 1 is the sequence axis.
        Returns:
            embeddings: (bs, max_seq_length, dim)
        """
        seq_length = input_ids.size(1)
        # TODO setting device from input_ids from remotely created tensor throws KeyError: UID <...>.
        position_ids = remote_torch.arange(seq_length)  # (max_seq_length)
        position_ids = remote_torch.unsqueeze(position_ids, 0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        # The position table is indexed by position, not by token id: it only has
        # max_position_embeddings rows, so token ids would be both wrong and possibly out of range.
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings


# dummy_x = torch.ones(10, 100).long()
# sy_embedding = Embeddings(config, inputs={'input_ids': dummy_x})

In [5]:
# Test

# @make_plan
# def train(model=sy_embedding, x=dummy_x):
#     opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
#     out = model(input_ids=x)
#     loss = remote_torch.mean(out[0])
#     loss.backward()
#     opt.step()
#     return [model]

# dummy_x = torch.LongTensor(5, 80).random_(10)


# train_ptr = train.send(alice_client)
# model_ptr = train_ptr(model=sy_embedding, x=torch.LongTensor(100, 90).random_(10))
# updated_model = model_ptr.get()[0]
# print(updated_model)

## MultiHeadSelfAttention

https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L116

In [6]:
from syft.lib.python.int import Int

class MultiHeadSelfAttention(SyModule):
    """Multi-head scaled dot-product attention (port of HF DistilBERT `MultiHeadSelfAttention`).

    Args:
        config: DistilBertConfig; reads n_heads, dim and attention_dropout.
        **kwargs: forwarded to SyModule, e.g. ``inputs={"query", "key", "value", "mask"}`` dummies.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        # Kept for interface parity; prune_heads is not implemented.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        raise NotImplementedError()

    def shape(self, x, bs, dim_per_head):
        """Separate heads: (bs, seq_length, dim) -> (bs, n_heads, seq_length, dim_per_head)."""
        return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

    def unshape(self, x, bs, dim_per_head):
        """Group heads: (bs, n_heads, seq_length, dim_per_head) -> (bs, seq_length, n_heads * dim_per_head)."""
        return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

    def forward(self, query, key, value, mask):
        """Attend from query positions to key/value positions.

        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length); key positions where mask == 0 are not attended to
        Returns:
            context: torch.tensor(bs, seq_length, dim), the attention output after out_lin.
            Attention weights are not returned.
        """
        bs = query.size(0)
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads
        mask_reshp = (bs, 1, 1, k_length)

        q = self.shape(self.q_lin(query), bs, dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        k = self.shape(self.k_lin(key), bs, dim_per_head)  # (bs, n_heads, k_length, dim_per_head)
        v = self.shape(self.v_lin(value), bs, dim_per_head)  # (bs, n_heads, k_length, dim_per_head)

        # Scale queries by 1/sqrt(dim_per_head) before the dot product.
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = remote_torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        # Broadcast the key-padding mask over heads and query positions.
        mask = (mask == 0).view(*mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)

        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
        weights = remote_torch.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        context = remote_torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = self.unshape(context, bs, dim_per_head)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        return context

# dummy inputs
# Dummy tensors used to trace MultiHeadSelfAttention into a plan.
bs = 10
seq_length = 15
dim = config.dim

dummy_q, dummy_k, dummy_v = (torch.randn(bs, seq_length, dim) for _ in range(3))
dummy_mask = torch.ones([bs, seq_length], dtype=torch.long)
dummy_inputs = dict(query=dummy_q, key=dummy_k, value=dummy_v, mask=dummy_mask)

# sy_mhsa = MultiHeadSelfAttention(config, inputs=dummy_inputs)


In [7]:
# Test

# @make_plan
# def train(model=sy_mhsa, query=dummy_q, key=dummy_k, value=dummy_v, mask=dummy_mask):
#     opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
#     out = model(query=query,
#                 key=key,
#                 value=value,
#                 mask=mask)
#     loss = remote_torch.mean(out[0])
#     loss.backward()
#     opt.step()
#     return [model]

# q = torch.randn(bs, seq_length, dim)
# k = torch.randn(bs, seq_length, dim)
# v = torch.randn(bs, seq_length, dim)
# mask = torch.ones([bs, seq_length], dtype=torch.long)
# batch = {
#     "query": q,
#     "key": k,
#     "value": v,
#     "mask": mask
# }

# train_ptr = train.send(alice_client)
# model_ptr = train_ptr(model=sy_mhsa, **batch)
# updated_model = model_ptr.get()[0]
# print(updated_model)

## FFN

In [8]:
class FFN(SyModule):
    """Feed-forward network on the last dimension: lin1 -> activation -> lin2 -> dropout.

    Args:
        config: DistilBertConfig; reads dropout, dim, hidden_dim and (if present) activation.
        **kwargs: forwarded to SyModule, e.g. ``inputs={"input": dummy_tensor}``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(p=config.dropout)
        # self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # chunking axis; unused until chunking is supported
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        # TODO transformers.activations.gelu in AST
        # NOTE torch F.gelu has a slightly different implementation, can't use that one.
        # assert config.activation in ["relu", "gelu"], f"activation ({config.activation}) must be in ['relu', 'gelu']"
        # self.activation = gelu if config.activation == "gelu" else nn.ReLU()
        # Until gelu is available, ReLU is always used. Make the substitution visible:
        # with a gelu config (e.g. pretrained DistilBERT) outputs will differ from the
        # reference model even when the weights are identical.
        requested_activation = getattr(config, "activation", "relu")
        if requested_activation != "relu":
            warnings.warn(
                f"FFN: config.activation={requested_activation!r} is not supported yet; using ReLU instead.",
                stacklevel=2,
            )
        self.activation = nn.ReLU()

    def forward(self, input):
        # TODO: Chunking
        # return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x

# bs = 10
# seq_length = 15
# dim = config.dim

# dummy_input = torch.randn(bs, seq_length, dim)
# ffn_inputs = {
#     "input": dummy_input
# }

# sy_FFN = FFN(config, inputs=ffn_inputs)
# print(sy_FFN)

In [9]:
# Test

# @make_plan
# def train(model=sy_FFN, input=dummy_input):
#     opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
#     out = model(input=input)
#     loss = remote_torch.mean(out[0])
#     loss.backward()
#     opt.step()
#     return [model]

# bx = torch.randn(bs, seq_length, dim)

# train_ptr = train.send(alice_client)
# model_ptr = train_ptr(model=sy_FFN, input=bx)
# updated_model = model_ptr.get()[0]
# print(updated_model)

## Transformer

In [10]:
class TransformerBlock(SyModule):
    """One DistilBERT layer: self-attention + residual LayerNorm, then FFN + residual LayerNorm.

    ``kwargs["inputs"]`` must contain dummy tensors under "x" and "attn_mask"; they are
    reused as the dummy inputs of the nested attention and FFN sub-modules.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        assert config.dim % config.n_heads == 0

        dummy = kwargs["inputs"]
        dummy_hidden = dummy["x"]

        # Self-attention receives the same tensor as query, key and value.
        self.attention = MultiHeadSelfAttention(
            config,
            inputs={
                "query": dummy_hidden,
                "key": dummy_hidden,
                "value": dummy_hidden,
                "mask": dummy["attn_mask"],
            },
        )
        self.ffn = FFN(config, inputs={"input": dummy_hidden})

        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self, x, attn_mask):
        # Self-attention sub-layer with a residual connection.
        attended = self.attention(query=x, key=x, value=x, mask=attn_mask)[0]
        attended = self.sa_layer_norm(attended + x)  # (bs, seq_length, dim)

        # Feed-forward sub-layer with a residual connection.
        fed_forward = self.ffn(input=attended)[0]  # (bs, seq_length, dim)
        return self.output_layer_norm(fed_forward + attended)  # (bs, seq_length, dim)

class Transformer(SyModule):
    """Stack of ``config.n_layers`` TransformerBlocks stored as attributes ``layer_0 ... layer_{n-1}``.

    Every block is built with the same ``kwargs["inputs"]`` dummies ("x", "attn_mask").
    NOTE: this attribute naming yields state_dict keys ``transformer.layer_<i>.…``,
    not the ``transformer.layer.<i>.…`` keys of a ModuleList-based model.
    Construction time and forward (plan) time are printed for profiling.
    """

    def __init__(self, config, **kwargs):
        started = time.time()

        super().__init__(**kwargs)
        self.n_layers = config.n_layers

        # TODO replace this with SyModuleList
        for idx in range(config.n_layers):
            block = TransformerBlock(config, inputs=kwargs["inputs"])
            setattr(self, f"layer_{idx}", block)

        print("transformer init:", time.time() - started)

    def forward(self, x, attn_mask):
        """Feed x through every block in order, passing the same attn_mask to each.

        Author's notes on why this is a hand-written loop:
        - SyModule does not work; multiple inputs
        - getattr does not work;
          (NOTE(review): the loop below does use getattr -- clarify what fails.)
        """
        started = time.time()
        hidden_state = x

        # TODO SyModuleList
        for idx in range(self.n_layers):
            block = getattr(self, f"layer_{idx}")
            hidden_state = block(x=hidden_state, attn_mask=attn_mask)[0]

        print("forward plan time", time.time() - started)
        return hidden_state

# Dummy tensors used to trace the Transformer stack into a plan.
bs = 10
seq_length = 15
dim = config.dim

dummy_x = torch.randn(bs, seq_length, dim)
dummy_mask = torch.ones([bs, seq_length], dtype=torch.long)
transformer_dummy_inputs = dict(x=dummy_x, attn_mask=dummy_mask)
# sy_transformer = Transformer(config, inputs=transformer_dummy_inputs)

In [11]:
# Test

# @make_plan
# def train(model=sy_transformer, x=dummy_x, attn_mask=dummy_mask):
#     opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
#     out = model(x=x,
#                 attn_mask=attn_mask)
#     loss = remote_torch.mean(out[0][0])
#     loss.backward()
#     opt.step()
#     return [model]

# x = torch.randn(bs, seq_length, dim)
# mask = torch.ones([bs, seq_length], dtype=torch.long)
# batch = {
#     "x": x,
#     "attn_mask": mask
# }

# train_ptr = train.send(alice_client)
# model_ptr = train_ptr(model=sy_transformer, **batch)
# updated_model = model_ptr.get()[0]
# print(updated_model)

In [12]:


class SyDistilBert(SyModule):
    """Syft port of DistilBertModel: Embeddings followed by a Transformer stack.

    ``kwargs["inputs"]`` must provide dummy "input_ids" and "attention_mask" tensors
    (presumably (bs, seq_length)). The Transformer is traced with a random float tensor
    of shape ``(*input_ids.size(), config.dim)``, matching the embedding output.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        dummy = kwargs["inputs"]
        dummy_ids = dummy["input_ids"]

        # Embeddings
        self.embeddings = Embeddings(config=config, inputs={"input_ids": dummy_ids})

        # Transformer: its dummy hidden state mirrors the embedding output shape.
        dummy_hidden = torch.rand(*dummy_ids.size(), config.dim)
        self.transformer = Transformer(
            config=config,
            inputs={"x": dummy_hidden, "attn_mask": dummy["attention_mask"]},
        )

    def forward(self, input_ids, attention_mask):
        hidden = self.embeddings(input_ids=input_ids)[0]
        return self.transformer(x=hidden, attn_mask=attention_mask)[0]

# dummy_x = torch.ones(1, 1, dtype=torch.long)
# dummy_mask = torch.ones(1, 1, dtype=torch.long)
# sydistilbert = SyDistilBert(config, inputs={'input_ids': dummy_x, 'attention_mask': dummy_mask})

In [13]:
# Test

# @make_plan
# def train(model=sydistilbert, x=dummy_x, attn_mask=dummy_mask):
#     opt = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
#     out = model(input_ids=x,
#                 attention_mask=dummy_mask)
#     loss = remote_torch.mean(out[0][0])
#     loss.backward()
#     opt.step()
#     return [model]

# x = torch.ones([bs, seq_length], dtype=torch.long)
# mask = torch.ones([bs, seq_length], dtype=torch.long)
# batch = {
#     "x": x,
#     "attn_mask": mask
# }

# train_ptr = train.send(alice_client)
# model_ptr = train_ptr(model=sydistilbert, **batch)
# updated_model = model_ptr.get()[0]
# print(updated_model)

In [14]:
# Load the pretrained HF reference model whose weights are copied into SyDistilBert below.
# (AutoConfig is imported but not used in this notebook.)
from transformers import AutoModel, AutoConfig

model_name = "distilbert-base-uncased"
distilbert_model = AutoModel.from_pretrained(model_name)
print(distilbert_model)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(i

In [None]:
# Build the syft model from the pretrained config, traced with a tiny (1, 2) dummy batch.
dummy_x = torch.ones(1, 2, dtype=torch.long)
dummy_mask = torch.ones(1, 2, dtype=torch.long)
sydistilbert_model = SyDistilBert(distilbert_model.config, inputs={'input_ids': dummy_x, 'attention_mask': dummy_mask})


def hf_to_sy_state_dict(hf_state_dict):
    """Rename HF DistilBERT state_dict keys to the SyDistilBert layout.

    HF keeps transformer blocks in a ModuleList (keys ``transformer.layer.<i>.…``);
    the syft Transformer stores them as attributes (``transformer.layer_<i>.…``).
    All other keys are passed through unchanged.
    """
    prefix = "transformer.layer."
    renamed = {}
    for key, value in hf_state_dict.items():
        if key.startswith(prefix):
            idx, _, rest = key[len(prefix):].partition(".")
            key = f"transformer.layer_{idx}.{rest}"
        renamed[key] = value
    return renamed


# Copy the pretrained weights. state_dict() returns the weights; load_state_dict()
# requires a state dict argument, so the previous call raised TypeError.
# NOTE: FFN currently always uses ReLU, so outputs can still differ from the HF model.
sydistilbert_model.load_state_dict(hf_to_sy_state_dict(distilbert_model.state_dict()))

RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
transformer init: 427.7230200767517
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
forward plan time 571.7766692638397


In [2]:
# HF reference classes, aliased so they do not shadow the SyModule ports named
# `Embeddings` and `Transformer` defined above (SyDistilBert looks those names up
# when it is constructed, so shadowing them would make re-running it build HF modules).
from transformers.models.distilbert.modeling_distilbert import (
    DistilBertModel,
    Embeddings as HFEmbeddings,
    Transformer as HFTransformer,
)