In [27]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_scheduler,
)

# MRPC configuration of the GLUE benchmark, loaded through `datasets`.
raw_datasets = load_dataset("glue", "mrpc")

# The tokenizer and the model are loaded from the same checkpoint name.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the checkpoint with a sequence-classification head instead of the bare
# encoder: the plain `AutoModel` has no classifier and its forward() does not
# accept the `labels` key that this notebook's batches carry, so no loss could
# be computed. MRPC is a binary task (paraphrase / not paraphrase), hence
# num_labels=2.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)


def tokenize():
    """Tokenize the sentence pairs of ``raw_datasets``.

    Reads the module-level ``raw_datasets`` and ``tokenizer``; takes no
    arguments.

    Returns:
        Whatever ``raw_datasets.map`` returns for the pair-encoding function
        below (the original columns are not removed here).
    """

    def encode_pairs(batch):
        # Both sentence columns are handed to the tokenizer together, with
        # truncation switched on.
        first, second = batch["sentence1"], batch["sentence2"]
        return tokenizer(first, second, truncation=True)

    # batched=True: `map` calls the function on batches of rows rather than
    # on one row at a time.
    encoded = raw_datasets.map(encode_pairs, batched=True)
    return encoded


# Run the tokenization pass over the raw dataset.
preprocessed_tokenized_datasets = tokenize()

# Drop the raw text and index columns and rename the target column from
# "label" to "labels" (presumably to match the keyword the model's forward()
# uses for targets -- confirm against the model class in use).
tokenized_datasets = (
    preprocessed_tokenized_datasets
    .remove_columns(["sentence1", "sentence2", "idx"])
    .rename_column("label", "labels")
)

# Called for its effect on `tokenized_datasets`; the return value is not used.
tokenized_datasets.set_format("torch")

# Collator built from the tokenizer; it is used as `collate_fn` by both
# loaders below.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training loader: shuffled, batches of 8.
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)

# Validation loader: same batch size, no shuffling requested.
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=8,
    collate_fn=data_collator,
)

Found cached dataset glue (/home/phung/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

KeyboardInterrupt: 

In [24]:
# The manual way
# optimizer = AdamW(model.parameters(), lr=5e-5)
# num_epochs = 3
# num_training_steps = num_epochs * len(train_dataloader)
# lr_scheduler = get_scheduler(
#     'linear', optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
# )
# TODO: Add the training loop (forward pass, backward pass, optimizer and scheduler steps) and an evaluation loop.

Found cached dataset glue (/home/phung/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/phung/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-6f2bdda7e82181f7.arrow
Loading cached processed dataset at /home/phung/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-748380ffc2a3cc21.arrow
Loading cached processed dataset at /home/phung/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fa46886d6a800015.arrow
