In [1]:
from transformers import set_seed
# Pin the random seed before any model is instantiated.
set_seed(seed=916)

In [2]:
from bert_reduced import BertReducedForSequenceClassification

# Build two classifiers from the same Hub repository: the "main" revision
# becomes mlm_model and the "pretrain" revision becomes nsp_model
# (loaded in that order).
mlm_model, nsp_model = (
    BertReducedForSequenceClassification.from_pretrained(
        "cayjobla/bert-base-uncased-reduced", revision=branch
    )
    for branch in ("main", "pretrain")
)

Some weights of BertReducedForSequenceClassification were not initialized from the model checkpoint at cayjobla/bert-base-uncased-reduced and are newly initialized: ['bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertReducedForSequenceClassification were not initialized from the model checkpoint at cayjobla/bert-base-uncased-reduced and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Freeze the base-model parameters of both classifiers so that training
# only updates whatever lives outside `base_model`.
# NOTE(review): the load warning for the "main" revision lists
# bert.pooler.dense.* as newly initialized; if the pooler is part of
# base_model it is frozen here at its random initial values — confirm
# that this is intended.
for frozen in (mlm_model, nsp_model):
    for param in frozen.base_model.parameters():
        param.requires_grad = False

In [4]:
from transformers import AutoTokenizer

# Tokenizer taken from the "main" revision of the same Hub repository
# the models are loaded from.
tokenizer = AutoTokenizer.from_pretrained(
    "cayjobla/bert-base-uncased-reduced",
    revision="main",
)

### CoLA (Corpus of Linguistic Acceptability)

In [5]:
from datasets import load_dataset

# Select the GLUE configuration to work on and fetch its splits.
task = "cola"
raw_datasets = load_dataset(path="glue", name=task)

Found cached dataset glue (/home/cayjobla/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
def tokenize(batch):
    """Tokenize the "sentence" column of a batch, with max-length padding and truncation.

    NOTE(review): padding="max_length" is used without an explicit
    max_length, so every example is presumably padded to the tokenizer's
    model maximum — confirm this is intended rather than dynamic padding.
    """
    return tokenizer(batch["sentence"], padding="max_length", truncation=True)


# Apply the tokenizer to every split in batches; the tokenized dataset
# replaces `raw_datasets`.
raw_datasets = raw_datasets.map(
    tokenize,
    batched=True,
    desc="Running tokenizer on dataset",
)

Loading cached processed dataset at /home/cayjobla/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-0ee71e98608de039.arrow


Running tokenizer on dataset:   0%|          | 0/1043 [00:00<?, ? examples/s]

Loading cached processed dataset at /home/cayjobla/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-94ff257a8be70cdb.arrow


In [7]:
import evaluate
import numpy as np

# GLUE metric for the configuration named by `task`; compute_metrics reads it.
metric = evaluate.load(path="glue", config_name=task)

def compute_metrics(p):
    """Score a batch of predictions with the module-level ``metric``.

    Args:
        p: Object exposing ``predictions`` (an array of per-class scores, or
            a tuple whose first element is that array) and ``label_ids``
            (the reference labels).

    Returns:
        The dict produced by ``metric.compute``. When it holds more than one
        entry, a ``"combined_score"`` key with the mean of all values is added.
    """
    # A model that returns extra outputs yields a tuple here; the class
    # scores are its first element. argmax on the raw tuple would fail.
    scores = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(scores, axis=1)
    result = metric.compute(predictions=preds, references=p.label_ids)
    if len(result) > 1:
        result["combined_score"] = np.mean(list(result.values())).item()
    return result

In [8]:
from transformers import TrainingArguments, Trainer, default_data_collator

def _build_trainer(model, variant):
    """Create a Trainer for one model variant with its own output location.

    The original cell gave both trainers the same ``output_dir`` and
    ``run_name``, so the second run wrote into the first run's directory.
    Each variant now gets a distinct sub-directory and run name; every other
    setting is unchanged.

    Args:
        model: The sequence-classification model to train and evaluate.
        variant: Short tag (e.g. "mlm" or "nsp") used in the output
            directory and the run name.

    Returns:
        A Trainer bound to the "train" and "validation" splits of
        ``raw_datasets``.
    """
    args = TrainingArguments(
        # Nested under the original directory name so that one variant's
        # checkpoints cannot overwrite the other's.
        output_dir=f"bert-base-uncased-reduced/{task}-{variant}",
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=2e-4,
        num_train_epochs=10,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        push_to_hub=False,
        logging_steps=10,
        run_name=f"glue-{task}-{variant}",
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=raw_datasets["train"],
        eval_dataset=raw_datasets["validation"],
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )


mlm_trainer = _build_trainer(mlm_model, "mlm")
nsp_trainer = _build_trainer(nsp_model, "nsp")

# Kept so that any later cell that still reads `training_args` keeps working.
training_args = mlm_trainer.args

In [9]:
# Baseline: score mlm_model on the validation split before any training.
mlm_predictions = mlm_trainer.predict(test_dataset=raw_datasets["validation"])
mlm_predictions.metrics



{'test_loss': 0.6926127672195435,
 'test_matthews_correlation': 0.03755820756538691,
 'test_runtime': 21.3736,
 'test_samples_per_second': 48.798,
 'test_steps_per_second': 1.029}

In [10]:
# Baseline: score nsp_model on the validation split before any training.
nsp_predictions = nsp_trainer.predict(test_dataset=raw_datasets["validation"])
nsp_predictions.metrics



{'test_loss': 0.6931502223014832,
 'test_matthews_correlation': -0.0463559874942472,
 'test_runtime': 11.8901,
 'test_samples_per_second': 87.72,
 'test_steps_per_second': 1.85}

In [11]:
# Train mlm_model. Its base-model parameters were frozen earlier, so only
# the parameters outside `base_model` are updated. The returned TrainOutput
# is displayed as the cell result.
mlm_trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcayjobla[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Matthews Correlation
1,0.599,0.616445,0.0
2,0.5179,0.529401,0.319413
3,0.56,0.514348,0.347756
4,0.5354,0.529312,0.349696
5,0.5386,0.516906,0.366382
6,0.5166,0.536658,0.347968
7,0.529,0.508242,0.388052
8,0.5041,0.529188,0.368743
9,0.5379,0.525705,0.364825
10,0.4909,0.509662,0.374359




TrainOutput(global_step=1790, training_loss=0.5386588861156443, metrics={'train_runtime': 1403.9236, 'train_samples_per_second': 60.908, 'train_steps_per_second': 1.275, 'total_flos': 2.264780297988096e+16, 'train_loss': 0.5386588861156443, 'epoch': 10.0})

In [12]:
# Re-score mlm_model on the validation split now that training has finished.
mlm_predictions = mlm_trainer.predict(test_dataset=raw_datasets["validation"])
mlm_predictions.metrics



{'test_loss': 0.5096619725227356,
 'test_matthews_correlation': 0.3743591779398503,
 'test_runtime': 11.3268,
 'test_samples_per_second': 92.083,
 'test_steps_per_second': 1.942}

In [13]:
# Train nsp_model. Its base-model parameters were frozen earlier, so only
# the parameters outside `base_model` are updated. The returned TrainOutput
# is displayed as the cell result.
nsp_trainer.train()



Epoch,Training Loss,Validation Loss,Matthews Correlation
1,0.5948,0.618346,0.0
2,0.5848,0.607151,0.0
3,0.5923,0.576861,0.13847
4,0.5667,0.588046,0.161881
5,0.5191,0.584617,0.227931
6,0.5273,0.601997,0.227936
7,0.5859,0.573096,0.249421
8,0.5682,0.55921,0.245942
9,0.5483,0.579262,0.253767
10,0.6002,0.561268,0.265572




TrainOutput(global_step=1790, training_loss=0.5761748817379915, metrics={'train_runtime': 1367.2836, 'train_samples_per_second': 62.54, 'train_steps_per_second': 1.309, 'total_flos': 2.264780297988096e+16, 'train_loss': 0.5761748817379915, 'epoch': 10.0})

In [14]:
# Re-score nsp_model on the validation split now that training has finished.
nsp_predictions = nsp_trainer.predict(test_dataset=raw_datasets["validation"])
nsp_predictions.metrics



{'test_loss': 0.5612677335739136,
 'test_matthews_correlation': 0.2655724830811597,
 'test_runtime': 10.8295,
 'test_samples_per_second': 96.311,
 'test_steps_per_second': 2.031}