In [1]:
from transformers import BertTokenizer, BertForMaskedLM, BertTokenizer, BertForSequenceClassification

def load_model(name):
    model_path = name
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        output_attentions=False
    )
    return model.cuda()
tokenizer_path = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(tokenizer_path, do_lower_case=True)

models = [f"./finetuned-scrumbled-wikitext2/checkpoint-{i}" for i in range(500, 13501, 500)]

In [2]:
import pandas as pd
import torch
import datasets


def preprocess_function(examples):
    res = tokenizer(examples["text"], truncation=True, return_tensors="pt", padding=True)
    for elem in res:
        res[elem] = res[elem].cuda()
    return res

train_ds = datasets.Dataset.from_csv("target_problem/train.csv").map(preprocess_function, batched=True)
test_ds = datasets.Dataset.from_csv("target_problem/test.csv").map(preprocess_function, batched=True)
train_ds

Using custom data configuration default-893c33a0074ecc88
Found cached dataset csv (/home/sha43/.cache/huggingface/datasets/csv/default-893c33a0074ecc88/0.0.0)
Loading cached processed dataset at /home/sha43/.cache/huggingface/datasets/csv/default-893c33a0074ecc88/0.0.0/cache-e8426cbdf3a0f8a0.arrow
Using custom data configuration default-5194cde6cade5dda
Found cached dataset csv (/home/sha43/.cache/huggingface/datasets/csv/default-5194cde6cade5dda/0.0.0)
Loading cached processed dataset at /home/sha43/.cache/huggingface/datasets/csv/default-5194cde6cade5dda/0.0.0/cache-d9f72de92ccf7b02.arrow


Dataset({
    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 3468
})

In [3]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [4]:
import evaluate

accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    print(eval_pred)
    predictions, labels = eval_pred
    predictions = predictions[0].argmax(axis=1)
    print(predictions)
    res = accuracy.compute(predictions=predictions, references=labels)
    print(res)
    return res

In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

for model_name in models:
    model = load_model(model_name).cuda()

    training_args = TrainingArguments(
        output_dir=f"res/{model_name}",
        learning_rate=2e-5,
        num_train_epochs=2,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        eval_accumulation_steps=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=train_ds,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    res = trainer.train()
    break
print(res)

Some weights of the model checkpoint at ./finetuned-scrumbled-wikitext2/checkpoint-500 were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint a

Epoch,Training Loss,Validation Loss


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3468
  Batch size = 16


KeyboardInterrupt: 