In [1]:
import math
from pathlib import Path

import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft.lib.python.list import List
from syft import logger
from syft import SyModule, SySequential

from transformers.models.distilbert.modeling_distilbert import DistilBertConfig, create_sinusoidal_embeddings


In [10]:
# Create client
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()
# torch namespace exposed by the plan builder's root client; the modules in this
# notebook call it for tensor ops inside their forward passes.
remote_torch = ROOT_CLIENT.torch

# Small DistilBERT hyper-parameters used throughout this notebook.
vocab_size = 10
dim = 256
max_length = 100
n_heads = 2
# Pass the variable rather than repeating the literal, so changing `n_heads`
# above actually changes the model configuration.
config = DistilBertConfig(
    vocab_size=vocab_size,
    dim=dim,
    max_position_embeddings=max_length,
    n_heads=n_heads,
)

# DistilSyBERT

- [x] PretrainedTokenizerFast
- [ ] DistilBertModel
    - [x] Embedding
    - [ ] MultiHeadSelfAttention
    - [ ] FFN
    - [ ] TransformerBlock
    - [ ] Transformer
- [ ] Load pretrained weights



## Embedding
Reference implementation (Hugging Face `Embeddings`): https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L82

In [3]:
# Implementation
from transformers.models.distilbert.modeling_distilbert import DistilBertConfig, create_sinusoidal_embeddings

class Embedding(SyModule):
    """Word + position embedding layer for DistilBERT, written as a ``SyModule``.

    Token embeddings and position embeddings are summed, layer-normalised and
    passed through dropout.
    """

    def __init__(self, config, **kwargs):
        """Create the embedding tables, LayerNorm and dropout.

        Parameters:
            config: configuration object; reads ``vocab_size``, ``dim``,
                ``pad_token_id``, ``max_position_embeddings``,
                ``sinusoidal_pos_embds`` and ``dropout``.
            **kwargs: forwarded unchanged to ``SyModule.__init__``.
        """
        super().__init__(**kwargs)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            # Writes the sinusoidal table into the position embedding weights.
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """Embed token ids and add position information.

        Parameters:
            input_ids: torch.tensor(bs, max_seq_length) of token ids
        Returns:
            embeddings: torch.tensor(bs, max_seq_length, dim)
        """
        seq_length = input_ids.size(1)
        # TODO setting device from input_ids from remotely created tensor throws KeyError: UID <...>.
        position_ids = remote_torch.arange(seq_length)  # (max_seq_length)
        position_ids = remote_torch.unsqueeze(position_ids, 0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        # Look up the position table by position index, not by token id.
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings

In [4]:
# Test: build a plan that runs one SGD step over the Embedding module,
# send it to alice and fetch the updated module back.

sy_embedding = Embedding(config, input_size=(10, 100), input_dtype=torch.long)

# Example input used as the plan's default argument: shape (5, 80), ids in [0, 10).
dummy_x = torch.LongTensor(5, 80).random_(10)


@make_plan
def train(model=sy_embedding, x=dummy_x):
    """Run one SGD step on ``model`` for input ``x`` and return ``[model]``."""
    optimizer = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
    outputs = model(input_ids=x)
    # NOTE(review): `outputs[0]` may be the first plan output rather than the
    # first batch row — confirm what SyModule returns here.
    objective = remote_torch.mean(outputs[0])
    objective.backward()
    optimizer.step()
    return [model]


train_ptr = train.send(alice_client)

# Fresh input for the remote run: shape (100, 90), ids in [0, 10).
runtime_x = torch.LongTensor(100, 90).random_(10)
model_ptr = train_ptr(model=sy_embedding, x=runtime_x)
updated_model = model_ptr.get()[0]
print(updated_model)

[2021-05-28T15:03:38.154265+0200][CRITICAL][logger]][12992] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 665bc3a9eb894f24999fb4a7a0954284>.


RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
RECOMPILING
Embedding(
  (word_embeddings): Embedding(10, 256, padding_idx=0)
  (position_embeddings): Embedding(100, 256)
  (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


## MultiHeadSelfAttention

Reference implementation (Hugging Face `MultiHeadSelfAttention`): https://github.com/huggingface/transformers/blob/master/src/transformers/models/distilbert/modeling_distilbert.py#L116

In [11]:
class MultiHeadSelfAttention(SyModule):
    """Multi-head scaled dot-product attention for DistilBERT, written as a ``SyModule``."""

    def __init__(self, config, **kwargs):
        """Create the query/key/value/output projections and attention dropout.

        Parameters:
            config: configuration object; reads ``n_heads``, ``dim`` and
                ``attention_dropout``. ``dim`` must be divisible by ``n_heads``.
            **kwargs: forwarded unchanged to ``SyModule.__init__``.
        """
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0, "dim must be divisible by n_heads"

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Head pruning is not supported in this port; always raises NotImplementedError."""
        raise NotImplementedError

    def forward(self, query, key, value, mask):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length); positions equal to 0 are masked out
        Returns:
            A 1-tuple ``(context,)`` with context: torch.tensor(bs, seq_length, dim).
            The attention weights are computed internally but not returned.
        """
        # Only the batch size is used below; the three-way unpack still requires
        # `query` to be rank 3.
        bs, _q_length, _dim = query.size()
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """group heads"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        # Scale queries by sqrt(dim_per_head) before the dot product.
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        # True where the input mask is 0, broadcast to the shape of `scores`.
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        # Masked positions get -inf so they receive zero weight after softmax.
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        return (context,)
    
# NOTE(review): the recorded output of this cell is a ValueError — SyModule asks for an
# `input_size` kwarg to trace the forward plan, and this call passes none.
MultiHeadSelfAttention(config)

ValueError: SyModule needs `input_size`: Tuple(Int) as kwarg to trace the forward plan.Also, make sure to call **super().__init__(**kwargs)** in ALL your SyModules

DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 256,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 100,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "transformers_version": "4.6.1",
  "vocab_size": 10
}