In [1]:
import syft as sy
import transformers
import torch
import datasets

from syft.core.plan.plan_builder import make_plan, ROOT_CLIENT

In [2]:
# Hugging Face checkpoint identifier, wrapped in a syft String; it is later
# passed to the `train` plan as its `base_model_name` argument.
model_name = sy.lib.python.String(
    "cardiffnlp/twitter-xlm-roberta-base-sentiment"
)

In [3]:
# Number of examples per DataLoader batch.
batch_size = 5

# Create a syft VirtualMachine named "alice" and obtain its client.
alice = sy.VirtualMachine(name="alice")
alice_client = alice.get_client()

# Torch namespace exposed by ROOT_CLIENT; the `train` plan builds its
# optimizer, no_grad context and loss through this handle.
remote_torch = ROOT_CLIENT.torch

In [4]:
# IMDB training split loaded through Hugging Face `datasets`, served in
# shuffled mini-batches of `batch_size` examples.
# NOTE(review): no RNG seed is set in the visible cells, so the batch drawn
# from this loader is presumably not reproducible across runs — confirm.
train_set = datasets.load_dataset('imdb', split='train')
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    shuffle=True,
    batch_size=batch_size,
)

Reusing dataset imdb (/home/eelco/.cache/huggingface/datasets/imdb/plain_text/1.0.0/4ea52f2e58a08dbc12c2bd52d0d92b30b88c00230b4522801b3636782f625c5b)


In [7]:
# Local: load the full sequence-classification model and keep a handle on its
# `classifier` attribute, which is what gets passed to the `train` plan.
full_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    model_name
)
classifier = full_model.classifier

# A single batch pulled from the loader and wrapped in a syft List; it serves
# as the default `batches` argument of `train`.
dummy_batches = sy.lib.python.List(
    [next(iter(train_loader))]
)

@make_plan
def train(classifier=classifier, batches=dummy_batches, base_model_name=model_name):
    """
    Train `classifier` on `batches` and return the updated classifier.

    The XLM-RoBERTa encoder and its fast tokenizer are loaded from
    `base_model_name` through ROOT_CLIENT. The encoder is only run in eval
    mode under `no_grad`; the optimizer is built from
    `classifier.parameters()` alone, so only the classifier is updated.

    Args:
        classifier: module invoked as `classifier(features=...)` and exposing
            `parameters()` / `train()`.
        batches: iterable of batches, each indexed with 'text' and 'label'.
        base_model_name: checkpoint name given to both `from_pretrained` calls.

    Returns:
        A one-element list containing `classifier`.

    NOTE(review): `@make_plan` appears to trace this body with the default
    arguments, so the loop is presumably unrolled for the number of batches
    in `dummy_batches` — confirm against the syft plan documentation.
    """

    roberta_ptr = (
        ROOT_CLIENT.transformers.models.xlm_roberta.modeling_xlm_roberta
        .XLMRobertaModel.from_pretrained(base_model_name, add_pooling_layer=False)
    )
    tokenizer_ptr = (
        ROOT_CLIENT.transformers.models.xlm_roberta.tokenization_xlm_roberta_fast
        .XLMRobertaTokenizerFast.from_pretrained(base_model_name)
    )
    opt = remote_torch.optim.AdamW(classifier.parameters(), lr=1e-3)

    # Mode switches are loop-invariant: nothing below toggles them back, so
    # they are issued once instead of once per batch.
    classifier.train()
    roberta_ptr.eval()

    for batch in batches:
        opt.zero_grad()

        # Prepare data: tokenize the raw text (padded, truncated to 512 tokens).
        batch_x = tokenizer_ptr(
            batch['text'],
            padding=True,
            return_tensors='pt',
            truncation=True,
            max_length=512,
        )
        batch_y = batch['label']

        # Encoder forward pass without gradient tracking; with
        # return_dict=False the result is indexed positionally and the first
        # element is used as the classifier input.
        with remote_torch.no_grad():
            encoder_out = roberta_ptr(
                batch_x["input_ids"],
                batch_x["attention_mask"],
                return_dict=False,
            )
            hidden_state = encoder_out[0]

        # Classifier forward + backward, then one optimizer step.
        logits = classifier(features=hidden_state)
        loss = remote_torch.nn.functional.cross_entropy(logits, batch_y)
        loss.backward()
        opt.step()

    return [classifier]

[2021-06-18T14:39:19.296777+0200][CRITICAL][logger]][15826] __getattribute__ failed. If you are trying to access an EnumAttribute or a StaticAttribute, be sure they have been added to the AST. Falling back on__getattr__ to search in self.attrs for the requested field.
[2021-06-18T14:39:19.297382+0200][CRITICAL][logger]][15826] 'Class' object has no attribute 'from_pretrained'
[2021-06-18T14:39:24.118159+0200][CRITICAL][logger]][15826] __getattribute__ failed. If you are trying to access an EnumAttribute or a StaticAttribute, be sure they have been added to the AST. Falling back on__getattr__ to search in self.attrs for the requested field.
[2021-06-18T14:39:24.119351+0200][CRITICAL][logger]][15826] 'Class' object has no attribute 'from_pretrained'


In [8]:
# Pull one more batch from the (shuffled) loader to feed the remote call.
dummy_batches = sy.lib.python.List(
    [next(iter(train_loader))]
)

# Send the plan to alice's client, then invoke the returned handle with the
# local classifier, the batch list and the checkpoint name.
train_ptr = train.send(alice_client)
out_ptr = train_ptr(
    classifier=classifier,
    batches=dummy_batches,
    base_model_name=model_name,
)

In [9]:
# Fetch element 0 of the plan output (`train` returns `[classifier]`) and
# display it as the cell result.
out_ptr[0].get()

RobertaClassificationHead(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (out_proj): Linear(in_features=768, out_features=3, bias=True)
)