In [1]:
# Automatically reload imported modules whose source has changed before executing each cell
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import datasets
import tokenizers

import numpy as np
from pprint import pprint

# Minimal transformers pipeline for sentiment analysis

In [3]:
# Load Tokenizer and Model
# Only the pretrained encoder comes from the checkpoint; the classification head
# (pre_classifier / classifier) is newly initialised -- see the warning printed below.
model_name = 'distilbert-base-uncased'
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)  # fast tokenizer; its backend object is shipped to syft later
hf_model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Load dataset
# IMDB train split: 25,000 reviews with a binary 'label' (see the next cell's output)
train_dataset = datasets.load_dataset('imdb', split='train')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.bias', 'pre_classifi

In [4]:
# Check out dataset: overall schema, the set of label values, and the first review
label_values = np.unique(train_dataset['label'])

pprint(train_dataset)
print()
print("Unique labels:", label_values)
print()
pprint(train_dataset[0])

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

Unique labels: [0 1]

{'label': 1,
 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some '
         'other programs about school life, such as "Teachers". My 35 years in '
         "the teaching profession lead me to believe that Bromwell High's "
         'satire is much closer to reality than is "Teachers". The scramble to '
         'survive financially, the insightful students who can see right '
         "through their pathetic teachers' pomp, the pettiness of the whole "
         'situation, all remind me of the schools I knew and their students. '
         'When I saw the episode in which a student repeatedly tried to burn '
         'down the school, I immediately recalled ......... at .......... '
         "High. A classic line: INSPECTOR: I'm here to sack one of your "
         'teachers. STUDENT: Welcome to Bromwell High. I expect that many '
         'adults of my age think that Bromwe

In [5]:
# Display the DistilBERT encoder (embeddings + transformer blocks); the classification
# head (pre_classifier / classifier) lives directly on hf_model, not in this submodule.
hf_model.distilbert

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(i

## Train loop

In [6]:
# Train setup
batch_size = 32
max_length = 100  # Keep text short for CPU. Should be 512 (DistilBERT's max position embeddings)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Discriminative learning rates: the pretrained encoder gets a small lr (2e-5, within the
# fine-tuning range recommended in the BERT paper) while the freshly initialised
# classification head gets a larger one. Every parameter group sets its own lr, so no
# optimizer-wide default is needed.
opt = torch.optim.AdamW([
    {'params': hf_model.distilbert.parameters(), 'lr': 2e-5},
    {'params': hf_model.pre_classifier.parameters(), 'lr': 1e-3},
    {'params': hf_model.classifier.parameters(), 'lr': 1e-3},
])

In [7]:
def train_iter(model, tokenizer, opt, batch, max_length):
    """Run a single optimisation step on one batch.

    Args:
        model: sequence-classification model whose output exposes ``.logits``.
        tokenizer: callable tokenizer (e.g. a Hugging Face tokenizer) returning
            a mapping of model keyword arguments.
        opt: optimizer over ``model``'s parameters.
        batch: mapping with ``'text'`` (list of strings) and ``'label'``
            (list of ints or an integer tensor).
        max_length: maximum number of tokens per example; longer texts are truncated.

    Returns:
        The scalar cross-entropy loss tensor for this batch (still attached to the graph).
    """
    model.train()
    opt.zero_grad()
    bx = tokenizer(batch['text'], padding=True, return_tensors='pt',
                   truncation=True, max_length=max_length)
    # Labels may arrive as a collated tensor or a plain list; normalise to int64 for cross_entropy
    by = torch.as_tensor(batch['label'], dtype=torch.long)
    out = model(**bx).logits
    loss = F.cross_entropy(out, by)

    loss.backward()
    opt.step()
    return loss
    
# Smoke test: one optimisation step on the first shuffled batch. With an untrained
# 2-way head the loss should be close to ln(2) ~= 0.69.
# `batch` is reused further down to compare the original and reconstructed tokenizers.
batch = next(iter(train_loader))
loss = train_iter(hf_model, hf_tokenizer, opt, batch, max_length)
print('loss', loss)

loss tensor(0.6980, grad_fn=<NllLossBackward>)


In [87]:
# Make tokenizer ptr
# NOTE: depends on `text`, `text_ptr` and `alice_client`, which are created in the syft
# setup cell further down -- run that cell first (as ordered, Restart & Run All fails here).
# `backend_tokenizer` is the public accessor for the underlying `tokenizers.Tokenizer`
# (the same object as the private `_tokenizer` attribute).
tokenizer = hf_tokenizer.backend_tokenizer
encoding = tokenizer.encode(text)

# Send the tokenizer to alice, encode the remote string there, and fetch the result back
tokenizer_ptr = tokenizer.send(alice_client)
encoding_ptr = tokenizer_ptr.encode(text_ptr)
encoding_2 = encoding_ptr.get()
# The local and remote encodings must be identical
assert encoding.__getstate__() == encoding_2.__getstate__()

In [88]:
# Reconstruct tokenizer from ptr
from transformers import PreTrainedTokenizerFast

# Wrapping the raw `tokenizers.Tokenizer` restores only the backend pipeline. Wrapper-level
# settings (special tokens such as [PAD], model_max_length, and which inputs are returned)
# are not part of it, so copy them explicitly from the original tokenizer.
hf_tokenizer_2 = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_ptr.get(),
    model_max_length=hf_tokenizer.model_max_length,
    model_input_names=hf_tokenizer.model_input_names,
    **hf_tokenizer.special_tokens_map,
)

# Test if same: encode the same batch with both tokenizers and list every output that differs
enc = hf_tokenizer(batch['text'], padding=True, return_tensors='pt', truncation=True, max_length=max_length)
enc_2 = hf_tokenizer_2(batch['text'], padding=True, return_tensors='pt', truncation=True, max_length=max_length)
mismatched_keys = sorted(
    key for key in set(enc.keys()) | set(enc_2.keys())
    if key not in enc or key not in enc_2 or not torch.equal(enc[key], enc_2[key])
)
mismatched_keys  # empty list => both tokenizers produce identical encodings

In [22]:
# Use `torch` (imported at the top); the `th` alias is only imported in a later cell
type(torch.long)

torch.dtype

# Syft

TODO
- tokenizer, `tokenizer.__call__`
- BatchEncoding
- model
    - serializable DistilBERT

In [11]:
from pathlib import Path

import numpy as np
import torch as th
from matplotlib import pyplot as plt

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft.lib.python.list import List
from syft import logger
from syft import SyModule, SySequential

In [12]:
from syft.lib.python.string import String

# Create client and test string
# A syft VirtualMachine named "alice" plus a client used to send objects to it
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()

# Keep the plain string locally and send a syft String copy to alice; `text_ptr` is the
# pointer returned by `.send`, used by the remote-encoding cell above.
text = "This is a test."
text_ptr = String(text).send(alice_client)

In [21]:
from transformers.models.distilbert.modeling_distilbert import DistilBertConfig, create_sinusoidal_embeddings

class Embedding(SyModule):
    """syft port of DistilBERT's embedding layer: word + position embeddings,
    followed by LayerNorm and dropout, mirroring the `embeddings` submodule of
    `hf_model.distilbert` shown earlier.

    Args:
        config: a `DistilBertConfig`; reads vocab_size, dim, pad_token_id,
            max_position_embeddings, sinusoidal_pos_embds and dropout.
        **kwargs: forwarded to `SyModule` (e.g. `input_size`, presumably used by
            syft to build dummy inputs for the forward plan -- TODO confirm).
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            # Fill the position table in place (out=weight) via HF's helper
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """Embed a batch of token ids.

        Args:
            input_ids: integer tensor of token ids, assumed (bs, max_seq_length).

        Returns:
            Tensor of shape (bs, max_seq_length, dim).
        """
        seq_length = input_ids.size(1)
        # NOTE(review): mixes the `th` and `torch` aliases (same module). This op sequence is
        # presumably what syft's plan builder traces, so change it with care.
        position_ids = th.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings

# Default DistilBERT hyper-parameters
config = DistilBertConfig()
# NOTE(review): this call raises `TypeError: __init__() got an unexpected keyword argument
# 'input_dtype'` (see output below): the SyModule constructor chain in this syft version does
# not accept `input_dtype`. An integer dummy input is presumably still required because
# nn.Embedding only accepts integer indices -- check SyModule's signature for the supported
# way to provide typed example inputs.
sy_embedding = Embedding(config, input_size=(10, 100), input_dtype=torch.long)

TypeError: __init__() got an unexpected keyword argument 'input_dtype'