In [31]:
from transformers import ElectraModel, ElectraConfig, ElectraForPreTraining, AutoTokenizer, BertForMaskedLM

# Hub checkpoint identifiers, kept in one place so they are easy to swap.
TOKENIZER_CHECKPOINT = 'bert-base-uncased'
GENERATOR_CHECKPOINT = 'prajjwal1/bert-small'
DISCRIMINATOR_CHECKPOINT = 'google/electra-base-discriminator'

# Tokenizer, masked-LM generator and ELECTRA pre-training discriminator.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
generator = BertForMaskedLM.from_pretrained(GENERATOR_CHECKPOINT)
discriminator = ElectraForPreTraining.from_pretrained(DISCRIMINATOR_CHECKPOINT)

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [32]:
from datasets import (load_dataset)
from transformers import DataCollatorForLanguageModeling
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm

import os
import html
import json
import torch.nn.functional as F
import copy

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collator that applies masked-LM masking to each batch (mlm_probability=0.15).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

# Load the raw text file, tokenize it in batches and keep only the train split.
text_dataset = load_dataset("text", data_files="data/en-1.txt")
tokenized_dataset = text_dataset.map(
    lambda examples: tokenizer(examples["text"], truncation=True, return_special_tokens_mask=True),
    batched=True,
)['train']

# The raw string column is dropped before the examples reach the collator.
train_dataset = tokenized_dataset.remove_columns(['text'])
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator,
)

# Both models live on the same device as the batches.
generator, discriminator = generator.to(device), discriminator.to(device)

# Generator is only used for inference: freeze its weights and keep it in
# eval mode so its predictions are not perturbed by dropout.
for params in generator.parameters():
  params.requires_grad = False
generator.eval()

optimizer = AdamW(discriminator.parameters(), lr=5e-5)

# Number of batches whose gradients are accumulated before each optimizer step.
accum_iter = 8
num_batches = len(train_dataloader)
# One optimizer step per full group of `accum_iter` batches, plus one more for
# a trailing partial group (ceil division), so the progress bar total is exact.
num_training_steps = -(-num_batches // accum_iter)
progress_bar = tqdm(range(num_training_steps))

discriminator.train()
# Drop gradients left behind by a previous, interrupted run of this cell.
optimizer.zero_grad()
loss_step = 0
for epoch in range(1):
  for (i, batch) in enumerate(train_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    # The collator's MLM labels hold the original token id at every position
    # it selected for corruption and -100 everywhere else. They are removed
    # from the batch so the generator does not compute an unused MLM loss.
    mlm_labels = batch.pop('labels')

    # Reconstruct the uncorrupted sentence. Restoring every selected position
    # (not only those showing [MASK]) also undoes the collator's random-token
    # substitutions, which would otherwise be mislabelled as "original".
    real_input_ids = torch.where(mlm_labels != -100, mlm_labels, batch['input_ids'])

    # Fill each [MASK] with the generator's most likely token.
    with torch.no_grad():
      logits = generator(**batch).logits
    is_mask = batch['input_ids'] == tokenizer.mask_token_id
    corrupted_input_ids = torch.where(is_mask, logits.argmax(dim=-1), batch['input_ids'])

    # Discriminator target: 1 where the token differs from the original,
    # 0 otherwise. Padding is identical in both tensors and so stays 0.
    discriminator_labels = (corrupted_input_ids != real_input_ids).long()

    discriminator_output = discriminator(input_ids=corrupted_input_ids,
                                         attention_mask=batch['attention_mask'],
                                         token_type_ids=batch['token_type_ids'],
                                         labels=discriminator_labels)
    loss = discriminator_output.loss
    # Scale so the accumulated gradient is the mean over the group rather than
    # the sum; the unscaled loss is what gets reported.
    (loss / accum_iter).backward()
    print(f"loss {i} = {loss.item()}", end='\r')

    # Gradient accumulation: step after every `accum_iter`-th batch and after
    # the final batch, so a trailing partial group is not discarded.
    if ((i + 1) % accum_iter == 0) or (i + 1 == num_batches):
      optimizer.step()
      optimizer.zero_grad()
      progress_bar.update(1)
      loss_step += 1

discriminator.save_pretrained("model/electra-1p")

Using custom data configuration default-a8c60a5cc6356686
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-a8c60a5cc6356686/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-a8c60a5cc6356686/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8/cache-bdaec421163bbd5c.arrow


  0%|          | 0/9711 [00:00<?, ?it/s]

loss 629 = 0.12290767580270767

KeyboardInterrupt: 

In [27]:
import torch

# Sanity check: mask one word, let the generator fill it in, then ask the
# discriminator which tokens of the resulting sentence look replaced.
real_sentence = "I like pie."
corrupted_sentence = "I [MASK] pie."
tokenized_corrupted_sentence = tokenizer(corrupted_sentence, return_tensors="pt").to(device)

# Pure inference: switch off dropout and gradient tracking for both models.
# (The training cell puts the discriminator in train mode.)
generator.eval()
discriminator.eval()

with torch.no_grad():
  logits = generator(**tokenized_corrupted_sentence).logits

# Positions of [MASK] in the (single) input row and the generator's top token there.
mask_token_index = (tokenized_corrupted_sentence.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)

# Substitute the prediction back into the text and re-tokenize.
# NOTE(review): this assumes a single [MASK] and a whole-word prediction; a
# word-piece prediction may re-tokenize to a different number of tokens.
corrupted_sentence = corrupted_sentence.replace('[MASK]', tokenizer.decode(predicted_token_id))
# corrupted_sentence = "The capital of Indonesia is USA."
tokenized_real_sentence = tokenizer(real_sentence, return_tensors="pt")
tokenized_corrupted_sentence = tokenizer(corrupted_sentence, return_tensors="pt")

with torch.no_grad():
  discriminator_outputs = discriminator(tokenized_corrupted_sentence.input_ids.to(device))

# Without labels the first output is the per-token logits; a positive logit
# means "replaced". Equivalent to rounding (sign(logit) + 1) / 2.
predictions = (discriminator_outputs[0] > 0).float()

real_ids = tokenized_real_sentence.input_ids.squeeze().tolist()
corrupted_ids = tokenized_corrupted_sentence.input_ids.squeeze().tolist()
if len(real_ids) != len(corrupted_ids):
  # zip() below stops at the shorter sequence, so the labels would be misaligned.
  print(f"Warning: token counts differ (real={len(real_ids)}, corrupted={len(corrupted_ids)}); labels are not aligned.")

print("Tokens      : ", tokenizer.tokenize(corrupted_sentence, add_special_tokens=True))
print("Predictions : ", predictions.squeeze().tolist(), "Predictions: ", (torch.sign(discriminator_outputs[0]) + 1) / 2)
print("Labels      : ", [float(0) if (i == j) else float(1) for (i, j) in zip(real_ids, corrupted_ids)])

Tokens      :  ['[CLS]', 'i', 'like', 'pie', '.', '[SEP]']
Predictions :  [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] Predictions:  tensor([[1., 1., 1., 1., 1., 1.]], device='cuda:0', grad_fn=<DivBackward0>)
Labels      :  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
