In [2]:
from transformers import ElectraModel, ElectraConfig, ElectraForPreTraining, AutoTokenizer, BertForMaskedLM

# Set up Generator + Discriminator [First]
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

generator = BertForMaskedLM.from_pretrained('bert-base-uncased')

discriminator = ElectraForPreTraining.from_pretrained('google/electra-base-discriminator')
discriminator_config = discriminator.config
discriminator = ElectraForPreTraining(discriminator_config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
from transformers import ElectraModel, ElectraConfig, ElectraForPreTraining, AutoTokenizer, BertForMaskedLM

# Set up Generator + Discriminator [Continue Training]
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
generator = BertForMaskedLM.from_pretrained('bert-base-uncased')
discriminator = ElectraForPreTraining.from_pretrained('model/electra-5p')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
from datasets import (load_dataset)
from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm

import os
import html
import json
import torch.nn.functional as F
import copy

def mask_input(input_ids, attention_mask):
  
  masked = input_ids.clone()

  for i in range(input_ids.shape[0]):
    rand = torch.rand(input_ids[i].shape).to(input_ids.device)
    mask_arr = (rand < 0.15) * (input_ids[i] != 101) * (input_ids[i] != 102) * (attention_mask[i] == 1)
    #print(mask_arr)
    selection = torch.flatten((mask_arr).nonzero()).tolist()
    #print("selection is", selection)
    masked[i, selection] = 103
    #print("====")

  return masked

# Params:
accum_iter = 8
batch_size = 32

# Set up Device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set up dataset
text_dataset = load_dataset("text", data_files="data/en-5.txt")['train']
tokenized_dataset = text_dataset.map(
    lambda examples: tokenizer(examples["text"], truncation=True, return_special_tokens_mask=True), 
    batched = True
)

data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding=True
)

# Train DataLoader
train_dataset = tokenized_dataset.remove_columns(['text'])
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, collate_fn = data_collator, shuffle=True)

# Move Model To Device
generator = generator.to(device)
discriminator = discriminator.to(device)

# Generator only used for inferrence, so, freezed
for params in generator.parameters():
  params.requires_grad = False

optimizer = AdamW(discriminator.parameters(), lr=5e-5)

# Set Up Progress bar + Gradient Accumulation, ignore "loss_step"
if accum_iter == None or accum_iter < 1:
    accum_iter = 1
num_training_steps = int(1 / accum_iter * len(train_dataloader))
progress_bar = tqdm(range(num_training_steps))
discriminator.train()
loss_step = 0
print(f"""Starting to train:
batch size      = {batch_size}
gradient accum  = {accum_iter}
dataloader size = {len(train_dataloader)}
""")
for epoch in range(1):
  for (i, batch) in enumerate(train_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}

    # Set a new batch (batch for discriminator)
    input_ids = batch['input_ids']
    attention_mask = batch["attention_mask"]
    token_type_ids = batch["token_type_ids"]
    labels = input_ids.clone()
    labels = torch.where(attention_mask == 1, labels, -100)

    masked_input = mask_input(input_ids, attention_mask)

    generator_output = generator(input_ids = masked_input, 
                                 attention_mask = attention_mask, 
                                 token_type_ids = token_type_ids) #Labels not needed, freezing it

    softmaxValue = torch.nn.functional.softmax(generator_output.logits, dim = 1)
    optToken = torch.argmax(softmaxValue, dim = 2)
    new_inputs = torch.where((masked_input == 103), optToken, input_ids)

    if labels is not None:
      labels = torch.where((labels != -100), (new_inputs != labels).type_as(labels), labels)

    discriminator_output = discriminator(input_ids = new_inputs, 
                                         attention_mask = attention_mask, 
                                         token_type_ids = token_type_ids, 
                                         labels=labels)
    loss = discriminator_output.loss
    loss.backward()

    # Gradient Accumulation
    if (((i % accum_iter) == 0) or (i + 1 == len(train_dataloader)) and (i != 0)):
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        loss_step += 1
        
        if (loss_step % 100 == 0 and loss_step != 0):
            print(f"loss at {loss_step} = {loss}")
        
discriminator.save_pretrained("model/electra-5p-epoch_2")

Using custom data configuration default-b215a6c16c929d49
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-b215a6c16c929d49/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-b215a6c16c929d49/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8/cache-4fe48c5930fdda04.arrow


  0%|          | 0/40258 [00:00<?, ?it/s]

Starting to train:
batch size      = 32
gradient accum  = 8
dataloader size = 322064

loss at 100 = 0.16511881351470947
loss at 200 = 0.15457402169704437
loss at 300 = 0.16783452033996582
loss at 400 = 0.14803257584571838
loss at 500 = 0.1818133145570755
loss at 600 = 0.17545990645885468
loss at 700 = 0.18486371636390686
loss at 800 = 0.16348466277122498
loss at 900 = 0.17506049573421478
loss at 1000 = 0.1657273769378662
loss at 1100 = 0.18806779384613037
loss at 1200 = 0.15661698579788208
