# 服务器离线环境：数据加载

In [1]:
from datasets import load_from_disk

# Load the dataset from the local "./yelp_review" directory via load_from_disk;
# later cells index the result by split name ("train" / "test").
dataset = load_from_disk("./yelp_review")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [3]:
def show_random_elements(dataset, num_examples=10):
    """Display a random selection of distinct rows from ``dataset`` as an HTML table.

    Columns whose feature type is a ``datasets.ClassLabel`` are rendered with
    their label names instead of their integer ids.

    Args:
        dataset: Object supporting ``len()``, indexing with a list of integer
            row positions (the result is handed to ``pd.DataFrame``), and a
            ``features`` mapping of column name to feature type.
        num_examples: How many distinct rows to show. May be anything up to
            and including ``len(dataset)``.

    Raises:
        AssertionError: If ``num_examples`` is larger than ``len(dataset)``.
        ValueError: If ``num_examples`` is negative (raised by ``random.sample``).
    """
    total = len(dataset)
    # "<=" rather than "<": showing every row of the dataset is a valid request.
    assert num_examples <= total, (
        f"Can't pick {num_examples} distinct rows from a dataset of {total} rows."
    )

    # random.sample draws without replacement, so the indices are unique by
    # construction and no retry loop is needed.
    picks = random.sample(range(total), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            # Replace each integer class id with its human-readable name.
            df[column] = df[column].map(lambda i, names=typ.names: names[i])
    display(HTML(df.to_html()))

In [4]:
# Preview two randomly chosen rows of the "train" split.
show_random_elements(dataset['train'],2)

Unnamed: 0,label,text
0,4 stars,"Had a fantastic lunch there. Enjoyed the braised cod -- a fantastic choice for anyone eating gluten or dairy free. Friendly staff, generous portions. Everybody in my work group liked what they had and we all agree we will go back! They should offer helmets with horns for anyone ordering the Viking Burger. Even enjoyed a couple of rounds of pool while waiting for the food to come out. Great rock music but a bit loud four lunch. If it's as loud next time I'll ask to get it turned down some."
1,1 star,"This place has an identity problem. The should not pretend to be a \""vintage\"" clothing store because they are not. Their idea of vintage is a one year old pair of True Religion jeans. I brought in some really cool actual vintage 1960's and 1970's coats and a beautiful old wool suit to sell them The twelve year old behind the counter looked at them for about 30 seconds and said \""no thanks\"". I asked her if she knew of any \""vintage\"" clothing stores in the area, not getting the dig she said, \"" I think there are some downtown on Main st.\"". I happened to spot another place further south on Maryland pkwy called Mustang Xchange at 4800 s. Maryland pkwy. I brought my items in and the girl there informed me the owners were not in but she looked at my items an told me the owner would definitely be interested in them and asked me to bring them back in. Buffalo Exchange needs to eliminate the \""vintage\"" in their description and get a location in the mall because thats really where they want to be. Stop pretending to be something you are not. If you want to see a cool little place go to Mustang Xchange instead."


# 预处理数据

In [5]:
from transformers import AutoTokenizer

# Load the "bert-base-cased" tokenizer, using ./models/ as the cache directory.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", cache_dir='./models/')

def tokenize_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` with the module-level ``tokenizer``.

    Args:
        examples: Mapping with a ``"text"`` key holding the text to encode.

    Returns:
        Whatever ``tokenizer`` returns when called with
        ``padding="max_length"`` and ``truncation=True``.
    """
    encode_options = {"padding": "max_length", "truncation": True}
    return tokenizer(examples["text"], **encode_options)

# Apply tokenize_function to the dataset in batches (batched=True), then
# preview one random row of the tokenized "train" split.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,4 stars,"They Do accept Credit Cards and Debit!!\nWould be 5 stars but $$$. 6/main item ie burrito, quesadilla, each dish etc.","[101, 1220, 2091, 4392, 14032, 10103, 1116, 1105, 3177, 9208, 106, 106, 165, 183, 2924, 6094, 5253, 1129, 126, 2940, 1133, 109, 109, 109, 119, 127, 120, 1514, 8926, 178, 1162, 171, 2149, 20376, 117, 15027, 23417, 5878, 117, 1296, 10478, 3576, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]"


# 微调训练配置


In [6]:
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate

# Load "bert-base-cased" with a sequence-classification head configured for
# 5 labels (presumably one per star rating -- confirm against the dataset),
# using ./models/ as the cache directory.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5, cache_dir='./models/')


# Training outputs go under models/bert-base-cased/ and logs under its runs/
# subdirectory. Configured here: evaluation strategy "epoch", a per-device
# train batch size of 32, a single training epoch, and logging_steps=2000.
model_dir = "models/bert-base-cased"
training_args = TrainingArguments(output_dir=f"{model_dir}/",
                                  evaluation_strategy="epoch",
                                  logging_dir=f"{model_dir}/runs",
                                  per_device_train_batch_size=32,
                                  num_train_epochs=1,
                                  logging_steps=2000)
# Print the fully resolved arguments, including defaults that were not set above.
print(training_args)

# Accuracy metric object; compute_metrics calls its compute() method.
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score predictions with the module-level ``metric``.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        The result of ``metric.compute`` for the arg-max of the logits along
        the last axis, compared against the labels.
    """
    scores, references = eval_pred
    # The predicted class is the index of the highest score on the last axis.
    predicted_ids = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


TrainingArguments(
_n_gpu=2,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=epoch,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_

In [7]:
# Build the Trainer from the model and training arguments defined above.
# Both the "train" and "test" splits are shuffled with a fixed seed (42)
# before being handed over; compute_metrics supplies the evaluation metric.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"].shuffle(seed=42),
    eval_dataset=tokenized_datasets["test"].shuffle(seed=42),
    compute_metrics=compute_metrics,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Start fine-tuning with the trainer configured in the previous cell.
trainer.train()



Epoch,Training Loss,Validation Loss




# 看一下全量数据集训练的效果

In [10]:
# Spot-check the trained model on the first 100 rows of the "test" split after
# shuffling it with seed 64; the cell displays the resulting metrics dict.
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))
trainer.evaluate(small_test_dataset)



{'eval_loss': 0.799130380153656,
 'eval_accuracy': 0.68,
 'eval_runtime': 1.6428,
 'eval_samples_per_second': 60.872,
 'eval_steps_per_second': 4.261,
 'epoch': 1.0}