# BERT variant: ELECTRA

Efficiently Learning an Encoder that Classifies Token Replacements Accurately (ELECTRA) is another interesting variant of BERT. BERT is pre-trained using the masked language modeling (MLM) and next sentence prediction (NSP) tasks. In the MLM task, we randomly mask 15% of the tokens and train BERT to predict the masked tokens. Instead of using MLM as its pre-training objective, ELECTRA is pre-trained using a task called replaced token detection.

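The replaced token detection task:
The replaced token detection task is similar to MLM, but instead of masking a token with the [MASK] token, we replace it with a different token and train the model to classify whether each token in the input is an original token or a replaced one. In ELECTRA, the replacement tokens are sampled from a separate network, called the generator, which is itself trained with MLM; the model that classifies each token as original or replaced is called the discriminator.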
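A minimal sketch, in plain Python, of how one replaced-token-detection example could be labeled. The `replacements` dict stands in for the generator's samples and is filled in by hand here; `replaced_token_detection_example` is a hypothetical helper, not part of any library. Note that if the generator happens to sample the original token, that position is labeled as original.

```python
def replaced_token_detection_example(tokens, replacements):
    """Build a replaced-token-detection example.

    `replacements` maps a position to the token a generator sampled for it
    (hand-written here as a stand-in for a real generator).
    Returns the corrupted sequence and one label per position:
    1 = replaced, 0 = original.
    """
    corrupted = [replacements.get(i, tok) for i, tok in enumerate(tokens)]
    # A sampled token identical to the original counts as "original" (label 0)
    labels = [int(c != t) for c, t in zip(corrupted, tokens)]
    return corrupted, labels


tokens = ["the", "chef", "cooked", "the", "meal"]
# Positions 0 and 2 were selected; the stand-in generator produced "the" and "ate"
corrupted, labels = replaced_token_detection_example(tokens, {0: "the", 2: "ate"})
print(corrupted)  # ['the', 'chef', 'ate', 'the', 'meal']
print(labels)     # [0, 0, 1, 0, 0]
```

Every position receives a label, which is what lets the discriminator learn from the whole sequence rather than only from the selected positions.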

Why choose ELECTRA over BERT:
One advantage of ELECTRA over BERT is the amount of training signal it gets from each input. BERT's MLM objective masks only 15% of the tokens, and because the model predicts only those masked tokens, only those 15% of positions contribute to the training signal. In ELECTRA, the model classifies every input token as original or replaced, so all the tokens contribute to the training signal. The ELECTRA authors report that this makes pre-training considerably more compute-efficient than BERT's MLM.
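To make the difference in training signal concrete, here is a toy comparison in plain Python. The probabilities are made up (not real model outputs), and `positions_in_loss` is a hypothetical helper that counts how many positions actually affect a loss by perturbing one prediction at a time. The MLM cross-entropy ignores every unmasked position, while the replaced-token-detection loss, a binary cross-entropy over all positions, depends on every one of them.

```python
import math


def mlm_loss(p_true, masked_positions):
    """Cross-entropy over the masked positions only (BERT-style MLM).

    p_true[i] is the model's probability for the correct token at position i.
    """
    return -sum(math.log(p_true[i]) for i in masked_positions) / len(masked_positions)


def rtd_loss(p_replaced, labels):
    """Binary cross-entropy over every position (ELECTRA-style discriminator).

    p_replaced[i] is the predicted probability that token i was replaced;
    labels[i] is 1 if it was replaced, 0 if it is original.
    """
    return -sum(
        y * math.log(p) + (1 - y) * math.log(1 - p)
        for p, y in zip(p_replaced, labels)
    ) / len(labels)


def positions_in_loss(loss_fn, probs, *args):
    """Count positions whose prediction changes the loss value."""
    base = loss_fn(probs, *args)
    count = 0
    for i in range(len(probs)):
        perturbed = list(probs)
        perturbed[i] = probs[i] * 0.5  # nudge a single prediction
        if loss_fn(perturbed, *args) != base:
            count += 1
    return count


seq_len = 20
masked = [3, 11, 17]                               # 15% of 20 positions
labels = [int(i == 11) for i in range(seq_len)]    # only position 11 was replaced
probs = [0.3] * seq_len                            # made-up model probabilities

print(positions_in_loss(mlm_loss, probs, masked))  # 3
print(positions_in_loss(rtd_loss, probs, labels))  # 20
```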
    

In [1]:
# Suppressing "INFO" and "WARNING" messages by setting the verbosity of the Transformers library.
from transformers import logging
logging.set_verbosity_error()

# Import necessary libraries

In [2]:
from transformers import ElectraTokenizer, ElectraModel

# Load the pretrained model and tokenizer

In [3]:
# Load the ELECTRA-Small discriminator (the network trained on replaced token detection)
model = ElectraModel.from_pretrained('google/electra-small-discriminator')

In [4]:
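# Load the ELECTRA-Small generator (the MLM network that produces replacement tokens).
# This reassigns `model`, so the cells below use the generator's encoder.
model = ElectraModel.from_pretrained('google/electra-small-generator')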

In [5]:
model.config

ElectraConfig {
  "_name_or_path": "google/electra-small-generator",
  "architectures": [
    "ElectraForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "transformers_version": "4.30.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
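The configuration above is enough to estimate the size of the encoder. The sketch below assumes the standard BERT-style layer layout (query/key/value and output projections, a two-layer feed-forward block, two LayerNorms per layer) plus ELECTRA's projection from `embedding_size` up to `hidden_size`, and it excludes any task head. `electra_encoder_params` is a hypothetical helper written for illustration.

```python
def electra_encoder_params(cfg):
    """Parameter count of an ELECTRA encoder without any task head,
    assuming the standard BERT-style layer layout."""

    def linear(n_in, n_out):
        return n_in * n_out + n_out  # weight matrix + bias

    def layer_norm(n):
        return 2 * n  # scale + shift

    emb, hid, ffn = cfg["embedding_size"], cfg["hidden_size"], cfg["intermediate_size"]

    embeddings = (
        cfg["vocab_size"] * emb                 # token embeddings
        + cfg["max_position_embeddings"] * emb  # absolute position embeddings
        + cfg["type_vocab_size"] * emb          # segment (token type) embeddings
        + layer_norm(emb)
    )
    # Projection from the embedding size up to the hidden size, only when they differ
    projection = linear(emb, hid) if emb != hid else 0

    per_layer = (
        3 * linear(hid, hid)   # query, key and value projections
        + linear(hid, hid)     # attention output projection
        + layer_norm(hid)
        + linear(hid, ffn)     # feed-forward expansion
        + linear(ffn, hid)     # feed-forward contraction
        + layer_norm(hid)
    )
    return embeddings + projection + cfg["num_hidden_layers"] * per_layer


# Values copied from the generator's `model.config` output above
config = {
    "vocab_size": 30522,
    "embedding_size": 128,
    "hidden_size": 256,
    "intermediate_size": 1024,
    "num_hidden_layers": 12,
    "max_position_embeddings": 512,
    "type_vocab_size": 2,
}
print(electra_encoder_params(config))  # 13483008
```

In the notebook, `sum(p.numel() for p in model.parameters())` gives the actual count to compare against; if the two differ, the layout assumption is what to revisit.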

In [6]:
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")

In [7]:
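# Both ELECTRA-Small checkpoints use the same uncased WordPiece vocabulary,
# so this tokenizer should produce the same token IDs as the one above
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-generator")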

# Generate Tokens

In [8]:
inputs = tokenizer("The dog is cute", return_tensors="pt")

In [9]:
print(inputs)

{'input_ids': tensor([[  101,  1996,  3899,  2003, 10140,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
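The three fields in this output can be reproduced by hand for this sentence. The sketch below is a deliberately simplified stand-in for the real tokenizer: `toy_encode` and `TOY_VOCAB` are hypothetical, the vocabulary contains only the IDs printed above, and it splits on whitespace, whereas real WordPiece tokenization also breaks rare words into sub-word pieces and handles punctuation.

```python
# Toy vocabulary containing only the IDs that appear in the output above
TOY_VOCAB = {"[CLS]": 101, "[SEP]": 102, "the": 1996, "is": 2003, "dog": 3899, "cute": 10140}


def toy_encode(text, vocab=TOY_VOCAB):
    """Simplified single-sentence encoder mimicking the tokenizer's output fields."""
    # Lowercase (the vocabulary is uncased), split, and wrap in [CLS] ... [SEP]
    tokens = ["[CLS]"] + text.lower().split() + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    return {
        "input_ids": input_ids,
        "token_type_ids": [0] * len(input_ids),  # single sentence: every token in segment 0
        "attention_mask": [1] * len(input_ids),  # no padding: attend to every token
    }


print(toy_encode("The dog is cute"))
```

The `[CLS]` (101) and `[SEP]` (102) IDs at either end are why six IDs come back for a four-word sentence.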


# Generate Embeddings

In [10]:
# Forward pass; `last_hidden_state` has shape (batch_size, sequence_length, hidden_size)
objects = model(**inputs)
hidden_rep = objects.last_hidden_state
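`last_hidden_state` holds one `hidden_size`-dimensional vector per token. To obtain a single vector for the whole sentence, two common choices are the first ([CLS]) vector or the average over the real (non-padding) tokens. Below is a minimal pure-Python sketch of mask-aware mean pooling on made-up numbers; `mean_pool` is a hypothetical helper, not a Transformers function.

```python
def mean_pool(hidden_states, attention_mask):
    """Average the token vectors of one sequence, skipping padded positions.

    hidden_states: per-token vectors (sequence_length x hidden_size)
    attention_mask: one 0/1 flag per token (1 = real token, 0 = padding)
    """
    kept = [vec for vec, m in zip(hidden_states, attention_mask) if m == 1]
    hidden_size = len(hidden_states[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(hidden_size)]


# Made-up example: 3 tokens, hidden size 2, the last token is padding
states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
print(mean_pool(states, [1, 1, 0]))  # [2.0, 3.0]
```

With the tensors in this notebook, the same idea can be written as `mask = inputs["attention_mask"].unsqueeze(-1)` followed by `(hidden_rep * mask).sum(1) / mask.sum(1)`, which yields a tensor of shape `(1, 256)`.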

In [12]:
print(hidden_rep.shape)
# Vector for the first token of the first sequence, i.e. the [CLS] token
print(hidden_rep[0][0])

torch.Size([1, 6, 256])
tensor([-9.0066e-01,  1.6689e-01, -1.0052e+00, -2.0985e-01,  7.4740e-01,
         4.8050e-01,  8.9076e-01,  6.2886e-01, -1.7570e-01, -1.2542e+00,
        -1.5265e+00, -7.5520e-01,  4.2734e-01,  7.2674e-01,  1.0291e+00,
         4.0141e-01,  9.0369e-01, -2.0499e-01, -2.0861e-01, -4.8087e-01,
         6.7896e-01,  4.8444e-01,  1.0877e+00, -1.6838e-01,  2.2549e-01,
         4.9599e-02,  1.6649e+00, -5.5879e-01, -1.2861e+00,  1.4240e+00,
        -1.0735e+00, -2.0786e+00,  6.7480e-01,  1.6990e-01,  7.4186e-01,
        -8.1543e-01, -7.6766e+00,  3.8491e-01, -6.7032e-01, -8.5981e-01,
         7.3789e-01, -1.0038e-01,  2.8210e-01, -3.8917e-01, -1.0351e+00,
         9.6054e-01,  1.1270e-01,  2.5980e-01,  9.8965e-02, -1.9503e+00,
         4.4600e-01,  7.1412e-01,  6.2666e-01,  1.3863e-01, -5.6530e-03,
         1.7137e+00, -4.0496e-01, -4.6221e-01,  8.0276e-01,  2.6045e-01,
        -6.2693e-01, -8.1840e-01,  8.8280e-01, -1.2556e+00,  5.7909e-01,
         7.8505e-01,  1.150