In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import datasets
import tokenizers

import numpy as np
from pprint import pprint

# Minimal transformers pipeline for sentiment analysis

In [3]:
# Load Tokenizer and Model
model_name = 'distilbert-base-uncased'
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
hf_model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Load dataset
train_dataset = datasets.load_dataset('imdb', split='train')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifi

In [4]:
# Check out dataset
pprint(train_dataset)
print()
print("Unique labels:", np.unique(train_dataset['label']))
print()
pprint(train_dataset[0])

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

Unique labels: [0 1]

{'label': 1,
 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some '
         'other programs about school life, such as "Teachers". My 35 years in '
         "the teaching profession lead me to believe that Bromwell High's "
         'satire is much closer to reality than is "Teachers". The scramble to '
         'survive financially, the insightful students who can see right '
         "through their pathetic teachers' pomp, the pettiness of the whole "
         'situation, all remind me of the schools I knew and their students. '
         'When I saw the episode in which a student repeatedly tried to burn '
         'down the school, I immediately recalled ......... at .......... '
         "High. A classic line: INSPECTOR: I'm here to sack one of your "
         'teachers. STUDENT: Welcome to Bromwell High. I expect that many '
         'adults of my age think that Bromwe

In [5]:
hf_model.distilbert.transformer.layer[0].attention

MultiHeadSelfAttention(
  (dropout): Dropout(p=0.1, inplace=False)
  (q_lin): Linear(in_features=768, out_features=768, bias=True)
  (k_lin): Linear(in_features=768, out_features=768, bias=True)
  (v_lin): Linear(in_features=768, out_features=768, bias=True)
  (out_lin): Linear(in_features=768, out_features=768, bias=True)
)

## Train loop

In [6]:
# Train setup
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# lr from BERT paper
opt = torch.optim.AdamW([
        {'params': hf_model.distilbert.parameters(), 'lr': 2e-5},
        {'params': hf_model.pre_classifier.parameters(), 'lr': 1e-3},
        {'params': hf_model.classifier.parameters(), 'lr': 1e-3}
    ], lr=1e-2)

batch_size = 32
max_length = 100 # Keep text short for CPU. Should be 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [7]:
def train_iter(model, tokenizer, opt, batch, max_length):
    # Single train iter
    model.train()
    opt.zero_grad()
    bx = tokenizer(batch['text'], padding=True, return_tensors='pt', truncation=True, max_length=100)
    by = torch.LongTensor(batch['label'])
    out = model(**bx).logits
    loss = F.cross_entropy(out, by)

    loss.backward()
    opt.step()
    return loss
    
batch = next(iter(train_loader))
loss = train_iter(hf_model, hf_tokenizer, opt, batch, max_length)
print('loss', loss)

loss tensor(0.6874, grad_fn=<NllLossBackward>)


# Syft

TODO
- tokenizer, `tokenizer.__call__`
- BatchEncoding
- model
    - serializable DistilBERT

In [8]:
import syft as sy
from syft import make_plan
from syft.lib.python.string import String
from syft import logger

# sy.load('tokenizers')
# sy.load('transformers')

# Create client and test string
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()

text = "This is a test."
text_ptr = String(text).send(alice_client)

In [87]:
# Make tokenizer ptr
tokenizer = hf_tokenizer._tokenizer
encoding = tokenizer.encode(text)

tokenizer_ptr = tokenizer.send(alice_client)
encoding_ptr = tokenizer_ptr.encode(text_ptr)
encoding_2 = encoding_ptr.get()
assert encoding.__getstate__() == encoding_2.__getstate__()

In [88]:
# Reconstruct tokenizer from ptr
from transformers import PreTrainedTokenizerFast

hf_tokenizer_2 = PreTrainedTokenizerFast(tokenizer_object=tokenizer_ptr.get())
hf_tokenizer_2.add_special_tokens({'pad_token': '[PAD]'})

# Test if same
enc = hf_tokenizer(batch['text'], padding=True, return_tensors='pt', truncation=True, max_length=100)
enc_2 = hf_tokenizer_2(batch['text'], padding=True, return_tensors='pt', truncation=True, max_length=100)

# Issue: It forgets some params from wrapper (pad tokens, etc)

In [100]:
for k in enc.keys():
    print(torch.eq(enc[k], enc_2[k]).all())

tensor(True)
tensor(True)


In [11]:
import transformers
transformers.AutoTokenizer.from_pretrained

<bound method AutoTokenizer.from_pretrained of <class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>>

In [109]:
hf_tokenizer.save_pretrained()

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [14]:
sy.load('transformers')

transformers.models.auto.tokenization_auto.AutoTokenizer