In [1]:
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer
import torch
import numpy as np
import random

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

In [3]:
# leverage checkpoints for Bert2Bert model...
# use BERT's cls token as BOS token and sep token as EOS token
encoder = BertGenerationEncoder.from_pretrained("bert-base-uncased", bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationEncoder: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.embeddings.token_type_embeddings.weight']
- This IS expected if you are initializing BertGenerationEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertGenerationEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationDecoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.

In [4]:
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

In [5]:
import pandas as pd
from sklearn.model_selection import train_test_split

with open('./rt-polaritydata/rt-polarity.neg') as f:
    train_text = f.read().split('\n')
with open('./rt-polaritydata/rt-polarity.pos') as f:
    test_text = f.read().split('\n')

texts = train_text + test_text
train_texts, test_texts, train_labels, test_labels = train_test_split(texts, texts, test_size=.2)

In [6]:
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
train_label_encodings = tokenizer(train_labels, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)
test_label_encodings = tokenizer(test_labels, truncation=True, padding=True)

In [7]:
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {'input_ids': torch.tensor(self.encodings['input_ids'][idx]),
                'decoder_input_ids': torch.tensor(self.encodings['input_ids'][idx]),
                'labels': torch.tensor(self.encodings['input_ids'][idx]),}
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])
    
train_dataset = IMDbDataset(train_encodings)
test_dataset = IMDbDataset(test_encodings)

In [8]:
from transformers import Trainer, TrainingArguments

In [9]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)

trainer = Trainer(
    model=bert2bert,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

In [10]:
trainer.train()



Step,Training Loss
500,4.634454
1000,1.916913
1500,1.540519
2000,0.971506
2500,0.569443
3000,0.361985
3500,0.230035
4000,0.157719
4500,0.111135
5000,0.083153




TrainOutput(global_step=10670, training_loss=0.5096384979791472)

In [35]:
bert2bert.save_pretrained('./do-nothing_bert.pt')

In [11]:
import os
os.listdir('./results/')

['checkpoint-500',
 'checkpoint-1000',
 'checkpoint-1500',
 'checkpoint-2000',
 'checkpoint-2500',
 'checkpoint-3000',
 'checkpoint-3500',
 'checkpoint-4000',
 'checkpoint-4500',
 'checkpoint-5000',
 'checkpoint-5500',
 'checkpoint-6000',
 'checkpoint-6500',
 'checkpoint-7000',
 'checkpoint-7500',
 'checkpoint-8000',
 'checkpoint-8500',
 'checkpoint-9000',
 'checkpoint-9500',
 'checkpoint-10000',
 'checkpoint-10500']

In [17]:
sentence = 'to be or not to be, that is the question'
content_tokens = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0).cuda()

In [42]:
bert2bert.decoder.save_pretrained('./do-nothing_decoder')

In [18]:
content_tokens.device

device(type='cuda', index=0)

In [31]:
output = bert2bert.generate(content_tokens)

In [32]:
output

tensor([[ 101, 2000, 2022, 2030, 2025, 2030, 2025, 2000, 2022, 1010, 2008, 2008,
         2003, 2008, 2008, 2003, 2000, 2022, 1010, 2008]], device='cuda:0')

In [34]:
tokenizer.decode(output[0])

'[CLS] to be or not or not to be, that that is that that is to be, that'

In [37]:
bert2bert2 = EncoderDecoderModel.from_pretrained('./do-nothing_bert.pt/')

In [43]:
sentence = 'and thus all ends with the death of flame'
content_tokens = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0)
output = bert2bert2.generate(content_tokens)
tokenizer.decode(output[0])

'[CLS] and thus all all ends with the death ends with the death ends with the death ends with the'