<a href="https://colab.research.google.com/github/George-spider/BERT-NER/blob/main/BERT/Custom_Named_Entity_Recognition_with_BERT_only_first_wordpiece.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Fine-tuning BERT for named-entity recognition**

In this notebook, we are going to use **BertForTokenClassification** which is included in the [Transformers library](https://github.com/huggingface/transformers) by HuggingFace. This model has BERT as its base architecture, with a token classification head on top, allowing it to make predictions at the token level, rather than the sequence level. Named entity recognition is typically treated as a token classification problem, so that's what we are going to use it for.

This tutorial uses the idea of **transfer learning**, i.e. first pretraining a large neural network in an unsupervised way, and then fine-tuning that neural network on a task of interest. In this case, BERT is a neural network pretrained on 2 tasks: masked language modeling and next sentence prediction. Now, we are going to fine-tune this network on a NER dataset. Fine-tuning is supervised learning, so this means we will need a labeled dataset.

If you want to know more about BERT, I suggest the following resources:
* the original [paper](https://arxiv.org/abs/1810.04805)
* Jay Alammar's [blog post](http://jalammar.github.io/illustrated-bert/) as well as his [tutorial](http://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/)
* Chris McCormick's [YouTube channel](https://www.youtube.com/channel/UCoRX98PLOsaN8PtekB9kWrw)
* Abhishek Kumar Mishra's [YouTube channel](https://www.youtube.com/user/abhisheksvnit)

The following notebook largely follows the same structure as the tutorials by Abhishek Kumar Mishra. For his tutorials on the Transformers library, see his [GitHub repository](https://github.com/abhimishra91/transformers-tutorials).

NOTE: this notebook assumes basic knowledge about deep learning, BERT, and native PyTorch. If you want to learn more Python, deep learning and PyTorch, I highly recommend cs231n by Stanford University and the FastAI course by Jeremy Howard et al. Both are freely available on the web.  

Now, let's move on to the real stuff!

#### **Importing Python Libraries and preparing the environment**

This notebook assumes that you have the following libraries installed:
* pandas
* numpy
* sklearn
* pytorch
* transformers
* seqeval

As we are running this in Google Colab, the only libraries we need to additionally install are transformers and seqeval (GPU version):

In [63]:
!pip install transformers seqeval[gpu]



In [64]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertConfig, BertForTokenClassification

As deep learning can be accelerated a lot by using a GPU instead of a CPU, make sure you run this notebook in a GPU runtime (which Google Colab provides for free: check "Runtime" - "Change runtime type" and set the hardware accelerator to "GPU").

We can set the default device to GPU using the following code (if it prints "cuda", it means the GPU has been recognized):

In [65]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
print(device)

cuda


#### **Downloading and preprocessing the data**
Named entity recognition (NER) uses a specific annotation scheme, which is defined (at least for European languages) at the *word* level. A widely used scheme is **[IOB-tagging](https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging))**, which stands for Inside-Outside-Beginning. Each tag indicates whether the corresponding word is *inside*, *outside* or at the *beginning* of a named entity. The scheme is needed because named entities often span more than one word.

Let's have a look at an example. If you have a sentence like "Barack Obama was born in Hawaii", then the corresponding tags would be [B-PERS, I-PERS, O, O, O, B-GEO]. B-PERS means that the word "Barack" is the beginning of a person, I-PERS means that the word "Obama" is inside a person, "O" means that the word "was" is outside a named entity, and so on. So one typically has as many tags as there are words in a sentence.
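To make the scheme concrete, here is a minimal sketch that groups IOB-tagged words back into entity spans. The helper name `iob_to_spans` is made up for this illustration and is not part of any library used in this notebook.

```python
def iob_to_spans(words, tags):
    """Group IOB-tagged words into (entity_type, text) spans."""
    spans = []
    current_type, current_words = None, []
    for word, tag in zip(words, tags):
        if tag.startswith("B-"):
            # A new entity starts: flush the one in progress, if any.
            if current_words:
                spans.append((current_type, " ".join(current_words)))
            current_type, current_words = tag[2:], [word]
        elif tag.startswith("I-") and current_type == tag[2:]:
            # Continuation of the entity in progress.
            current_words.append(word)
        else:
            # "O" tag (or an I- tag without a matching B-): close any open entity.
            if current_words:
                spans.append((current_type, " ".join(current_words)))
            current_type, current_words = None, []
    # Flush an entity that runs until the end of the sentence.
    if current_words:
        spans.append((current_type, " ".join(current_words)))
    return spans


words = ["Barack", "Obama", "was", "born", "in", "Hawaii"]
tags = ["B-PERS", "I-PERS", "O", "O", "O", "B-GEO"]
spans = iob_to_spans(words, tags)
print(spans)  # [('PERS', 'Barack Obama'), ('GEO', 'Hawaii')]
```

The B- prefix is what lets two adjacent entities of the same type be told apart, which a plain "entity / no entity" tag could not do.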

So if you want to train a deep learning model for NER, you need your data in this IOB format (or a similar format such as [BILOU](https://stackoverflow.com/questions/17116446/what-do-the-bilou-tags-mean-in-named-entity-recognition)). There are many annotation tools that let you create this kind of annotation (such as spaCy's [Prodigy](https://prodi.gy/), [Tagtog](https://docs.tagtog.net/) or [Doccano](https://github.com/doccano/doccano)). You can also use spaCy's [biluo_tags_from_offsets](https://spacy.io/api/goldparse#biluo_tags_from_offsets) function to convert character-level annotations to word-level BILUO tags.

Here, we will use a NER dataset from [Kaggle](https://www.kaggle.com/namanj27/ner-dataset) that is already in IOB format. One has to go to this web page, download the dataset, unzip it, and upload the csv file to this notebook. Let's print out the first few rows of this csv file:

In [66]:
data = pd.read_csv("ner_datasetreference.csv", encoding='unicode_escape')
data.head()

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,,of,IN,O
2,,demonstrators,NNS,O
3,,have,VBP,O
4,,marched,VBN,O


Let's check how many sentences and words (and corresponding tags) there are in this dataset:

In [67]:
data.count()

Unnamed: 0,0
Sentence #,47959
Word,1048565
POS,1048575
Tag,1048575


As we can see, there are approximately 48,000 sentences in the dataset, comprising more than 1 million words and tags (quite huge!). This corresponds to roughly 22 words per sentence.

Let's have a look at the different NER tags, and their frequency:

In [68]:
print("Number of tags: {}".format(len(data.Tag.unique())))
frequencies = data.Tag.value_counts()
frequencies

Number of tags: 17


Unnamed: 0_level_0,count
Tag,Unnamed: 1_level_1
O,887908
B-geo,37644
B-tim,20333
B-org,20143
I-per,17251
B-per,16990
I-org,16784
B-gpe,15870
I-geo,7414
I-tim,6528


There are 8 category tags, each with a "beginning" and "inside" variant, and the "outside" tag. It is not really clear what these tags mean - "geo" probably stands for geographical entity, "gpe" for geopolitical entity, and so on. They do not seem to correspond with what the publisher says on Kaggle. Some tags seem to be underrepresented. Let's print them by frequency (highest to lowest):

In [69]:
tags = {}
for tag, count in zip(frequencies.index, frequencies):
    if tag != "O":
        # strip the "B-"/"I-" prefix so both variants count toward the same entity type
        entity = tag[2:5]
        tags[entity] = tags.get(entity, 0) + count

print(sorted(tags.items(), key=lambda x: x[1], reverse=True))

[('geo', 45058), ('org', 36927), ('per', 34241), ('tim', 26861), ('gpe', 16068), ('art', 699), ('eve', 561), ('nat', 252)]


Let's remove the "art", "eve" and "nat" named entities, as performance on them will probably not be comparable to that on the other named entities.

In [70]:
entities_to_remove = ["B-art", "I-art", "B-eve", "I-eve", "B-nat", "I-nat"]
data = data[~data.Tag.isin(entities_to_remove)]
data.head()

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,,of,IN,O
2,,demonstrators,NNS,O
3,,have,VBP,O
4,,marched,VBN,O
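Note that the filter above removes the *rows* carrying those tags, so the corresponding words disappear from their sentences. An alternative, shown here only as a sketch on a made-up toy frame (the words and tags below are invented, not taken from the dataset), is to keep every word and relabel the rare tags as "O":

```python
import pandas as pd

# Toy frame in the same column layout as the Kaggle CSV (values are made up).
toy = pd.DataFrame({
    "Word": ["The", "Nobel", "Prize", "went", "to", "Obama"],
    "Tag": ["O", "B-art", "I-art", "O", "O", "B-per"],
})
entities_to_remove = ["B-art", "I-art", "B-eve", "I-eve", "B-nat", "I-nat"]

# Option used in this notebook: drop the rows entirely.
dropped = toy[~toy.Tag.isin(entities_to_remove)]

# Alternative: keep every word and relabel the rare tags as "O".
relabeled = toy.copy()
relabeled.loc[relabeled.Tag.isin(entities_to_remove), "Tag"] = "O"

print(" ".join(dropped.Word))    # The went to Obama
print(" ".join(relabeled.Word))  # The Nobel Prize went to Obama
print(relabeled.Tag.tolist())    # ['O', 'O', 'O', 'O', 'O', 'B-per']
```

Which option is preferable depends on the use case; relabeling keeps the sentences intact, while dropping rows keeps the code shorter.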


We create 2 dictionaries: one that maps individual tags to indices, and one that maps indices to their individual tags. This is necessary in order to create the labels (as computers work with numbers = indices, rather than words = tags) - see further in this notebook.

In [71]:
labels_to_ids = {k: v for v, k in enumerate(data.Tag.unique())}
ids_to_labels = {v: k for v, k in enumerate(data.Tag.unique())}
labels_to_ids

{'O': 0,
 'B-geo': 1,
 'B-gpe': 2,
 'B-per': 3,
 'I-geo': 4,
 'B-org': 5,
 'I-org': 6,
 'B-tim': 7,
 'I-per': 8,
 'I-gpe': 9,
 'I-tim': 10}

As we can see, there are now only 11 different tags: 10 entity tags plus the "O" tag.

Now, we have to ask ourselves the question: what is a training example in the case of NER, i.e. what is provided to the model in a single forward pass? A training example is typically a **sentence**, with corresponding IOB tags. Let's group the words and corresponding tags by sentence:

In [72]:
# pandas has a very handy "forward fill" function to fill missing values based on the last upper non-nan value
# (DataFrame.ffill() replaces the deprecated fillna(method='ffill'))
data = data.ffill()
data.head()


Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,Sentence: 1,of,IN,O
2,Sentence: 1,demonstrators,NNS,O
3,Sentence: 1,have,VBP,O
4,Sentence: 1,marched,VBN,O


In [73]:
# let's create a new column called "sentence" which groups the words by sentence
data['sentence'] = data[['Sentence #','Word','Tag']].groupby(['Sentence #'])['Word'].transform(lambda x: ' '.join(x))
# let's also create a new column called "word_labels" which groups the tags by sentence
data['word_labels'] = data[['Sentence #','Word','Tag']].groupby(['Sentence #'])['Tag'].transform(lambda x: ','.join(x))
data.head()

Unnamed: 0,Sentence #,Word,POS,Tag,sentence,word_labels
0,Sentence: 1,Thousands,NNS,O,Thousands of demonstrators have marched throug...,"O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-..."
1,Sentence: 1,of,IN,O,Thousands of demonstrators have marched throug...,"O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-..."
2,Sentence: 1,demonstrators,NNS,O,Thousands of demonstrators have marched throug...,"O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-..."
3,Sentence: 1,have,VBP,O,Thousands of demonstrators have marched throug...,"O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-..."
4,Sentence: 1,marched,VBN,O,Thousands of demonstrators have marched throug...,"O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-..."


Let's only keep the "sentence" and "word_labels" columns, and drop duplicates:

In [74]:
data = data[["sentence", "word_labels"]].drop_duplicates().reset_index(drop=True)
data.head()

Unnamed: 0,sentence,word_labels
0,Thousands of demonstrators have marched throug...,"O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-..."
1,Families of soldiers killed in the conflict jo...,"O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,B-per,O,O,..."
2,They marched from the Houses of Parliament to ...,"O,O,O,O,O,O,O,O,O,O,O,B-geo,I-geo,O"
3,"Police put the number of marchers at 10,000 wh...","O,O,O,O,O,O,O,O,O,O,O,O,O,O,O"
4,The protest comes on the eve of the annual con...,"O,O,O,O,O,O,O,O,O,O,O,B-geo,O,O,B-org,I-org,O,..."
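The same grouping can also be written as a single `groupby(...).agg(...)` call, which yields one row per sentence id directly. The sketch below runs on a small made-up frame with the same column names as the dataset. Unlike `drop_duplicates()`, this variant keeps two different sentences that happen to have identical text and labels, so the row counts may differ slightly from the ones shown in this notebook.

```python
import pandas as pd

# Toy frame mimicking the CSV layout: the sentence id appears only on the first word.
toy = pd.DataFrame({
    "Sentence #": ["Sentence: 1", None, None, "Sentence: 2", None],
    "Word": ["Obama", "visited", "Hawaii", "Police", "arrived"],
    "Tag": ["B-per", "O", "B-geo", "O", "O"],
})
toy["Sentence #"] = toy["Sentence #"].ffill()

# One row per sentence: join the words with spaces and the tags with commas.
grouped = (
    toy.groupby("Sentence #", sort=False)
       .agg(sentence=("Word", " ".join), word_labels=("Tag", ",".join))
       .reset_index(drop=True)
)
print(grouped)
```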


In [75]:
len(data)

47571

Let's verify that a random sentence and its corresponding tags are correct:

In [76]:
data.iloc[41].sentence

'Bedfordshire police said Tuesday that Omar Khayam was arrested in Bedford for breaching the conditions of his parole .'

In [77]:
data.iloc[41].word_labels

'B-gpe,O,O,B-tim,O,B-per,I-per,O,O,O,B-geo,O,O,O,O,O,O,O,O'

#### **Preparing the dataset and dataloader**

Now that our data is preprocessed, we can turn it into PyTorch tensors such that we can provide it to the model. Let's start by defining some key variables that will be used later on in the training/evaluation process:

In [78]:
MAX_LEN = 128
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
EPOCHS = 1
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

A tricky part of NER with BERT is that BERT relies on **wordpiece tokenization**, rather than word tokenization. This means that we should also define the labels at the wordpiece-level, rather than the word-level!

For example, if you have a word like "Washington" which is labeled as "B-gpe", but it gets tokenized to "Wash", "##ing", "##ton", then one approach is to train the model only on the label of the first wordpiece of each word (i.e. only label "Wash" with "B-gpe"). This is what was done in the original BERT paper, see the GitHub discussion [here](https://github.com/huggingface/transformers/issues/64#issuecomment-443703063).

Note that this is a **design decision**. You could also decide to propagate the original label of the word to all of its word pieces and let the model train on this. In that case, the model should be able to produce the correct labels for each individual wordpiece. This was done in [this NER tutorial with BERT](https://github.com/chambliss/Multilingual_NER/blob/master/python/utils/main_utils.py#L118). Another design decision could be to give the first wordpiece of each word the original word label, and then use the label “X” for all subsequent subwords of that word. All of them seem to lead to good performance.
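The three options can be compared side by side with a small sketch. The function `label_wordpieces` and its `strategy` names are made up for this illustration, and the wordpieces are written by hand rather than produced by the tokenizer.

```python
IGNORE = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def label_wordpieces(pieces_per_word, word_labels, strategy="first"):
    """Expand word-level labels to wordpiece-level labels.

    strategy:
      "first" - label only the first piece of each word, ignore the rest
      "all"   - copy the word label to every piece of the word
      "x"     - first piece gets the word label, later pieces get "X"
    """
    if strategy not in ("first", "all", "x"):
        raise ValueError(f"unknown strategy: {strategy}")
    out = []
    for pieces, label in zip(pieces_per_word, word_labels):
        for position in range(len(pieces)):
            if position == 0 or strategy == "all":
                out.append(label)
            elif strategy == "first":
                out.append(IGNORE)
            else:  # strategy == "x"
                out.append("X")
    return out


# Hand-written wordpieces for "Washington is rainy" (illustrative only).
pieces_per_word = [["Wash", "##ing", "##ton"], ["is"], ["rainy"]]
word_labels = ["B-gpe", "O", "O"]

first_only = label_wordpieces(pieces_per_word, word_labels, "first")
propagated = label_wordpieces(pieces_per_word, word_labels, "all")
with_x = label_wordpieces(pieces_per_word, word_labels, "x")

print(first_only)  # ['B-gpe', -100, -100, 'O', 'O']
print(propagated)  # ['B-gpe', 'B-gpe', 'B-gpe', 'O', 'O']
print(with_x)      # ['B-gpe', 'X', 'X', 'O', 'O']
```

This notebook uses the first option, which is why most wordpieces end up with the label -100.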

Below, we define a regular PyTorch [dataset class](https://pytorch.org/docs/stable/data.html) (which transforms examples of a dataframe to PyTorch tensors). Here, each sentence gets tokenized, the special tokens that BERT expects are added, the tokens are padded or truncated based on the max length of the model, the attention mask is created and the labels are created based on the dictionary which we defined above. Word pieces that should be ignored have a label of -100 (which is the default `ignore_index` of PyTorch's [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)).

For more information about BERT's inputs, see [here](https://huggingface.co/transformers/glossary.html).








In [83]:
class dataset(Dataset):
  def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

  def __getitem__(self, index):
        # step 1: get the sentence and word labels
        sentence = self.data.sentence[index].strip().split()
        word_labels = self.data.word_labels[index].split(",")

        # step 2: use tokenizer to encode sentence (includes padding/truncation up to max length)
        # BertTokenizerFast provides a handy "return_offsets_mapping" functionality for individual tokens
        encoding = self.tokenizer(sentence,
                             # is_pretokenized=True,  # This argument is deprecated
                             is_split_into_words=True,  # Use this instead
                             return_offsets_mapping=True,
                             padding='max_length',
                             truncation=True,
                             max_length=self.max_len)

        # step 3: create token labels only for first word pieces of each tokenized word
        labels = [labels_to_ids[label] for label in word_labels]
        # code based on https://huggingface.co/transformers/custom_datasets.html#tok-ner
        # create an empty array of -100 of length max_length
        encoded_labels = np.ones(len(encoding["offset_mapping"]), dtype=int) * -100

        # set only labels whose first offset position is 0 and the second is not 0
        i = 0
        for idx, mapping in enumerate(encoding["offset_mapping"]):
          if mapping[0] == 0 and mapping[1] != 0:
            # overwrite label
            try:
                encoded_labels[idx] = labels[i]
            except IndexError:
                # 'i' ran past the end of 'labels': the tokenizer found more words than there are labels.
                # Stop labeling here; the remaining tokens keep the ignore value -100.
                # print(f"IndexError at index {i} for sentence: {sentence}")  # Optional: print for debugging
                # print(f"word_labels: {word_labels}, labels: {labels}, len(labels): {len(labels)}")
                break  # exit the loop; no further tokens are labeled
            i += 1

        # step 4: turn everything into PyTorch tensors
        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.as_tensor(encoded_labels)

        return item

  def __len__(self):
        return self.len
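As an alternative to the offset-mapping trick in step 3, fast tokenizers also expose `encoding.word_ids()`, which returns for every token the index of the word it came from (or `None` for special and padding tokens). The sketch below shows how labels could be aligned from such a list. The helper name `align_labels_with_word_ids` is hypothetical, and the `word_ids` list is written by hand to mirror the first training example rather than produced by a tokenizer call.

```python
def align_labels_with_word_ids(word_ids, word_label_ids, ignore_index=-100):
    """Give the first wordpiece of each word its label; ignore everything else."""
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            # Special tokens ([CLS], [SEP]) and padding have no word.
            aligned.append(ignore_index)
        elif word_id != previous:
            # First wordpiece of a new word.
            aligned.append(word_label_ids[word_id])
        else:
            # Subsequent wordpiece of the same word.
            aligned.append(ignore_index)
        previous = word_id
    return aligned


# Hand-written stand-in for encoding.word_ids() of
# "[CLS] za ##hee ##r khan was mar - 93 for india . [SEP] [PAD] [PAD]"
word_ids = [None, 0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6, None, None, None]
# Word-level label ids: B-per, I-per, O, O, O, B-geo, O
word_label_ids = [3, 8, 0, 0, 0, 1, 0]

aligned = align_labels_with_word_ids(word_ids, word_label_ids)
print(aligned)
# [-100, 3, -100, -100, 8, 0, 0, -100, -100, 0, 1, 0, -100, -100, -100]
```

Because the word index is read directly, no running counter is needed. A sentence whose number of words differs from its number of labels can still raise an `IndexError`, so checking the two lengths up front remains a good idea.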

Now, based on the class we defined above, we can create 2 datasets, one for training and one for testing. Let's use an 80/20 split:

In [84]:
train_size = 0.8
train_dataset = data.sample(frac=train_size,random_state=200)
test_dataset = data.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(data.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = dataset(train_dataset, tokenizer, MAX_LEN)
testing_set = dataset(test_dataset, tokenizer, MAX_LEN)

FULL Dataset: (47571, 2)
TRAIN Dataset: (38057, 2)
TEST Dataset: (9514, 2)


Let's have a look at the first training example:

In [85]:
training_set[0]

{'input_ids': tensor([  101, 23564, 21030,  2099,  4967,  2001,  9388,  1011,  6109,  2005,
          2634,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,   

Let's verify that the input ids and corresponding targets are correct:

In [86]:
for token, label in zip(tokenizer.convert_ids_to_tokens(training_set[0]["input_ids"]), training_set[0]["labels"]):
  print('{0:10}  {1}'.format(token, label))

[CLS]       -100
za          3
##hee       -100
##r         -100
khan        8
was         0
mar         0
-           -100
93          -100
for         0
india       1
.           0
[SEP]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[PAD]       -100
[

Now, let's define the corresponding PyTorch dataloaders:

In [87]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

#### **Defining the model**

Here we define the model, BertForTokenClassification, and load it with the pretrained weights of "bert-base-uncased". The only thing we need to additionally specify is the number of labels (as this will determine the architecture of the classification head).

Note that only the base layers are initialized with the pretrained weights. The token classification head on top has randomly initialized weights, which we will train, together with the pretrained weights, using our labelled dataset. This is also printed as a warning when you run the code cell below.

Then, we move the model to the GPU.

In [88]:
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=len(labels_to_ids))
model.to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

#### **Training the model**

Before training the model, let's perform a sanity check, which I learned thanks to Andrej Karpathy's wonderful [cs231n course](http://cs231n.stanford.edu/) at Stanford (see also his [blog post about debugging neural networks](http://karpathy.github.io/2019/04/25/recipe/)). The initial loss of your model should be close to -ln(1/number of classes). After removing the rare entities we have 11 labels left, so we expect -ln(1/11) ≈ 2.40.

Why? Because we are using cross entropy loss. The cross entropy loss is defined as -ln(probability score of the model for the correct class). In the beginning, the weights of the classification head are random, so the probability distribution over the classes for a given token will be roughly uniform, meaning that the probability for the correct class will be near 1/11. The loss for a given token will thus be about -ln(1/11). As PyTorch's [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) (which is used by `BertForTokenClassification`) uses *mean reduction* by default, it averages the loss over all tokens in the sequence for which a label is provided (tokens labeled -100 are ignored).
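The expected value can be worked out by hand for the 11 labels in `labels_to_ids`. The sketch below re-implements mean-reduced cross entropy in plain Python purely for illustration (it is not PyTorch's implementation) and applies it to perfectly uniform logits, skipping positions labeled -100:

```python
import math


def mean_cross_entropy(logits, labels, ignore_index=-100):
    """Mean of -ln(softmax(logits)[label]) over positions that have a label."""
    losses = []
    for token_logits, label in zip(logits, labels):
        if label == ignore_index:
            continue  # ignored positions do not contribute to the mean
        log_normalizer = math.log(sum(math.exp(z) for z in token_logits))
        losses.append(log_normalizer - token_logits[label])
    return sum(losses) / len(losses)


num_labels = 11
# Uniform logits for 5 tokens: every class gets probability 1/11.
logits = [[0.0] * num_labels for _ in range(5)]
labels = [-100, 3, -100, 8, 0]  # only 3 of the 5 tokens carry a label

loss = mean_cross_entropy(logits, labels)
print(round(loss, 4), round(math.log(num_labels), 4))  # 2.3979 2.3979
```

A freshly initialized model will not be exactly uniform, so a value in the neighborhood of 2.4 is what to look for.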

Let's verify this:



In [89]:
inputs = training_set[2]
input_ids = inputs["input_ids"].unsqueeze(0)
attention_mask = inputs["attention_mask"].unsqueeze(0)
labels = inputs["labels"].unsqueeze(0)

input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)

outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
initial_loss = outputs[0]
initial_loss

tensor(2.4312, device='cuda:0', grad_fn=<NllLossBackward0>)

This looks good. Let's also verify that the logits of the neural network have a shape of (batch_size, sequence_length, num_labels):

In [90]:
tr_logits = outputs[1]
tr_logits.shape

torch.Size([1, 128, 11])

Next, we define the optimizer. Here, we are just going to use Adam with the learning rate defined above (`LEARNING_RATE`). One can also decide to use a more advanced optimizer such as AdamW (Adam with weight decay fix), which is [included](https://huggingface.co/transformers/main_classes/optimizer_schedules.html) in the Transformers repository, together with a learning rate scheduler, but we are not going to do that here.

In [91]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
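For readers who want to try the AdamW-plus-scheduler variant mentioned above, here is a minimal sketch using only PyTorch. The tiny linear layer stands in for the BERT model, and the weight decay, warmup length and step count are arbitrary values assumed for illustration, not settings used in this notebook.

```python
import torch

toy_model = torch.nn.Linear(4, 2)  # stand-in for the BERT model
LEARNING_RATE = 1e-05
total_steps, warmup_steps = 10, 2  # assumed values for the demo

optimizer = torch.optim.AdamW(toy_model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)


def linear_warmup_then_decay(step):
    """Multiplier for the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup_then_decay)

lrs = []
for _ in range(total_steps):
    # In real training, loss.backward() would be called before this line.
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(lrs)
```

The learning rate rises linearly to `LEARNING_RATE` during warmup and then decays linearly to zero, similar in spirit to the linear schedule with warmup that ships with the Transformers library.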

Now let's define a regular PyTorch training function. It is partly based on [a really good repository about multilingual NER](https://github.com/chambliss/Multilingual_NER/blob/master/python/utils/main_utils.py#L344).

In [97]:
!pip install tqdm
def train(epoch):
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    total_steps = 0
    processed_samples = 0

    # add a progress bar
    from tqdm import tqdm
    progress_bar = tqdm(training_loader, desc=f"Epoch {epoch + 1}", leave=False)

    for batch_idx, batch in enumerate(progress_bar):
        try:
            # prepare the data
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
                'labels': batch['labels'].to(device)
            }

            # forward pass
            outputs = model(**inputs)
            loss = outputs.loss
            logits = outputs.logits

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()

            # track metrics
            current_loss = loss.item()
            total_loss += current_loss
            total_steps += 1
            processed_samples += inputs['input_ids'].size(0)

            # compute accuracy on labeled tokens only (label != -100)
            active_labels = inputs['labels'].view(-1) != -100
            predictions = logits.argmax(dim=-1).view(-1)
            true_labels = inputs['labels'].view(-1)

            active_preds = predictions[active_labels]
            active_true = true_labels[active_labels]

            batch_accuracy = 0.0  # fallback so the log below never sees an undefined value
            if len(active_true) > 0:  # make sure the batch has valid labels
                batch_accuracy = (active_preds == active_true).float().mean().item()
                total_accuracy += batch_accuracy

                # update the progress bar
                progress_bar.set_postfix({
                    'loss': f"{current_loss:.4f}",
                    'acc': f"{batch_accuracy:.4f}",
                    'avg_loss': f"{total_loss/total_steps:.4f}",
                    'avg_acc': f"{total_accuracy/total_steps:.4f}"
                })

            # print a detailed log every 100 steps
            if batch_idx % 100 == 0:
                tqdm.write(
                    f"Step {batch_idx} | "
                    f"Loss: {current_loss:.4f} | "
                    f"Acc: {batch_accuracy:.4f} | "
                    f"Samples: {processed_samples}/{len(training_loader.dataset)}"
                )

        except Exception as e:
            print(f"\nError processing batch {batch_idx}: {str(e)}")
            continue  # skip the failing batch

    # epoch summary
    avg_loss = total_loss / total_steps
    avg_acc = total_accuracy / total_steps

    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Average Accuracy: {avg_acc:.4f}")
    print(f"Processed Samples: {processed_samples}")
    print(f"Total Steps: {total_steps}\n")

    return avg_loss, avg_acc



And let's train the model!

In [98]:

for epoch in range(EPOCHS):
    train_loss, train_acc = train(epoch)
    # these values can be stored and used to plot learning curves

Epoch 1:   0%|          | 2/9515 [00:00<25:47,  6.15it/s, loss=2.3799, acc=0.1474, avg_loss=2.4200, avg_acc=0.0969]

Step 0 | Loss: 2.4600 | Acc: 0.0465 | Samples: 4/38057


Epoch 1:   1%|          | 102/9515 [00:14<25:35,  6.13it/s, loss=0.3955, acc=0.8898, avg_loss=0.8070, avg_acc=0.8148]

Step 100 | Loss: 0.1825 | Acc: 0.9556 | Samples: 404/38057


Epoch 1:   2%|▏         | 202/9515 [00:27<18:36,  8.34it/s, loss=0.3325, acc=0.9022, avg_loss=0.5946, avg_acc=0.8536]

Step 200 | Loss: 0.3522 | Acc: 0.9010 | Samples: 804/38057


Epoch 1:   3%|▎         | 302/9515 [00:39<18:09,  8.46it/s, loss=0.3344, acc=0.9200, avg_loss=0.4896, avg_acc=0.8763]

Step 300 | Loss: 0.2222 | Acc: 0.9688 | Samples: 1204/38057


Epoch 1:   4%|▍         | 402/9515 [00:50<18:03,  8.41it/s, loss=0.1604, acc=0.9412, avg_loss=0.4232, avg_acc=0.8920]

Step 400 | Loss: 0.2969 | Acc: 0.8952 | Samples: 1604/38057


Epoch 1:   5%|▌         | 502/9515 [01:03<20:02,  7.50it/s, loss=0.1065, acc=0.9700, avg_loss=0.3765, avg_acc=0.9034]

Step 500 | Loss: 0.2000 | Acc: 0.9211 | Samples: 2004/38057


Epoch 1:   6%|▋         | 602/9515 [01:16<17:47,  8.35it/s, loss=0.1078, acc=0.9423, avg_loss=0.3446, avg_acc=0.9111]

Step 600 | Loss: 0.0952 | Acc: 0.9625 | Samples: 2404/38057


Epoch 1:   7%|▋         | 702/9515 [01:28<17:38,  8.32it/s, loss=0.4064, acc=0.9062, avg_loss=0.3218, avg_acc=0.9165]

Step 700 | Loss: 0.1245 | Acc: 0.9681 | Samples: 2804/38057


Epoch 1:   8%|▊         | 802/9515 [01:40<17:14,  8.42it/s, loss=0.0894, acc=0.9700, avg_loss=0.3037, avg_acc=0.9208]

Step 800 | Loss: 0.1150 | Acc: 0.9886 | Samples: 3204/38057


Epoch 1:   9%|▉         | 902/9515 [01:52<16:58,  8.45it/s, loss=0.0985, acc=0.9756, avg_loss=0.2881, avg_acc=0.9245]

Step 900 | Loss: 0.3181 | Acc: 0.9231 | Samples: 3604/38057


Epoch 1:  11%|█         | 1002/9515 [02:03<16:46,  8.46it/s, loss=0.1801, acc=0.9474, avg_loss=0.2739, avg_acc=0.9277]

Step 1000 | Loss: 0.1799 | Acc: 0.9492 | Samples: 4004/38057


Epoch 1:  12%|█▏        | 1102/9515 [02:16<19:20,  7.25it/s, loss=0.0619, acc=0.9868, avg_loss=0.2623, avg_acc=0.9305]

Step 1100 | Loss: 0.1152 | Acc: 0.9556 | Samples: 4404/38057


Epoch 1:  13%|█▎        | 1202/9515 [02:27<16:28,  8.41it/s, loss=0.0663, acc=0.9747, avg_loss=0.2543, avg_acc=0.9323]

Step 1200 | Loss: 0.3107 | Acc: 0.9167 | Samples: 4804/38057


Epoch 1:  14%|█▎        | 1302/9515 [02:39<16:23,  8.35it/s, loss=0.2013, acc=0.9470, avg_loss=0.2464, avg_acc=0.9340]

Step 1300 | Loss: 0.2802 | Acc: 0.9145 | Samples: 5204/38057


Epoch 1:  15%|█▍        | 1402/9515 [02:51<16:05,  8.40it/s, loss=0.2215, acc=0.9359, avg_loss=0.2391, avg_acc=0.9356]

Step 1400 | Loss: 0.1190 | Acc: 0.9688 | Samples: 5604/38057


Epoch 1:  16%|█▌        | 1502/9515 [03:03<15:55,  8.38it/s, loss=0.0460, acc=0.9911, avg_loss=0.2327, avg_acc=0.9371]

Step 1500 | Loss: 0.1240 | Acc: 0.9565 | Samples: 6004/38057


Epoch 1:  17%|█▋        | 1602/9515 [03:15<15:38,  8.44it/s, loss=0.2712, acc=0.9315, avg_loss=0.2273, avg_acc=0.9385]

Step 1600 | Loss: 0.3690 | Acc: 0.9367 | Samples: 6404/38057


Epoch 1:  18%|█▊        | 1702/9515 [03:27<15:24,  8.46it/s, loss=0.0174, acc=1.0000, avg_loss=0.2226, avg_acc=0.9397]

Step 1700 | Loss: 0.1911 | Acc: 0.9111 | Samples: 6804/38057


Epoch 1:  19%|█▉        | 1802/9515 [03:39<15:13,  8.44it/s, loss=0.0563, acc=0.9740, avg_loss=0.2178, avg_acc=0.9408]

Step 1800 | Loss: 0.2068 | Acc: 0.9367 | Samples: 7204/38057


Epoch 1:  20%|█▉        | 1902/9515 [03:51<14:59,  8.46it/s, loss=0.3724, acc=0.9359, avg_loss=0.2139, avg_acc=0.9418]

Step 1900 | Loss: 0.0558 | Acc: 0.9815 | Samples: 7604/38057


Epoch 1:  21%|██        | 2002/9515 [04:02<14:53,  8.41it/s, loss=0.2507, acc=0.9412, avg_loss=0.2094, avg_acc=0.9430]

Step 2000 | Loss: 0.1848 | Acc: 0.9451 | Samples: 8004/38057


Epoch 1:  22%|██▏       | 2102/9515 [04:14<14:41,  8.41it/s, loss=0.2059, acc=0.9273, avg_loss=0.2062, avg_acc=0.9437]

Step 2100 | Loss: 0.1173 | Acc: 0.9490 | Samples: 8404/38057


Epoch 1:  23%|██▎       | 2202/9515 [04:26<14:36,  8.34it/s, loss=0.1515, acc=0.9670, avg_loss=0.2033, avg_acc=0.9444]

Step 2200 | Loss: 0.1065 | Acc: 0.9643 | Samples: 8804/38057


Epoch 1:  24%|██▍       | 2302/9515 [04:38<14:12,  8.46it/s, loss=0.0144, acc=1.0000, avg_loss=0.2004, avg_acc=0.9451]

Step 2300 | Loss: 0.1260 | Acc: 0.9315 | Samples: 9204/38057


Epoch 1:  25%|██▌       | 2402/9515 [04:50<14:04,  8.42it/s, loss=0.0094, acc=1.0000, avg_loss=0.1969, avg_acc=0.9459]

Step 2400 | Loss: 0.0217 | Acc: 0.9853 | Samples: 9604/38057


Epoch 1:  26%|██▋       | 2502/9515 [05:02<13:52,  8.43it/s, loss=0.1269, acc=0.9186, avg_loss=0.1943, avg_acc=0.9466]

Step 2500 | Loss: 0.1351 | Acc: 0.9787 | Samples: 10004/38057


Epoch 1:  27%|██▋       | 2602/9515 [05:14<13:38,  8.45it/s, loss=0.1188, acc=0.9451, avg_loss=0.1921, avg_acc=0.9471]

Step 2600 | Loss: 0.3031 | Acc: 0.9425 | Samples: 10404/38057


Epoch 1:  28%|██▊       | 2702/9515 [05:26<13:29,  8.42it/s, loss=0.1208, acc=0.9541, avg_loss=0.1895, avg_acc=0.9477]

Step 2700 | Loss: 0.0548 | Acc: 0.9722 | Samples: 10804/38057


Epoch 1:  29%|██▉       | 2802/9515 [05:37<13:21,  8.37it/s, loss=0.0722, acc=0.9877, avg_loss=0.1872, avg_acc=0.9482]

Step 2800 | Loss: 0.0261 | Acc: 1.0000 | Samples: 11204/38057


Epoch 1:  30%|███       | 2902/9515 [05:49<13:01,  8.46it/s, loss=0.0297, acc=1.0000, avg_loss=0.1851, avg_acc=0.9487]

Step 2900 | Loss: 0.1209 | Acc: 0.9818 | Samples: 11604/38057


Epoch 1:  32%|███▏      | 3002/9515 [06:01<12:58,  8.37it/s, loss=0.0241, acc=1.0000, avg_loss=0.1831, avg_acc=0.9492]

Step 3000 | Loss: 0.0774 | Acc: 0.9714 | Samples: 12004/38057


Epoch 1:  33%|███▎      | 3102/9515 [06:13<12:39,  8.45it/s, loss=0.1356, acc=0.9796, avg_loss=0.1811, avg_acc=0.9497]

Step 3100 | Loss: 0.0537 | Acc: 0.9802 | Samples: 12404/38057


Epoch 1:  34%|███▎      | 3202/9515 [06:25<12:27,  8.45it/s, loss=0.1593, acc=0.9600, avg_loss=0.1791, avg_acc=0.9502]

Step 3200 | Loss: 0.0881 | Acc: 0.9753 | Samples: 12804/38057


Epoch 1:  35%|███▍      | 3302/9515 [06:37<12:15,  8.45it/s, loss=0.0348, acc=1.0000, avg_loss=0.1774, avg_acc=0.9506]

Step 3300 | Loss: 0.1132 | Acc: 0.9643 | Samples: 13204/38057


Epoch 1:  36%|███▌      | 3402/9515 [06:49<12:09,  8.38it/s, loss=0.0288, acc=0.9908, avg_loss=0.1760, avg_acc=0.9509]

Step 3400 | Loss: 0.1268 | Acc: 0.9855 | Samples: 13604/38057


Epoch 1:  37%|███▋      | 3502/9515 [07:01<11:55,  8.40it/s, loss=0.0923, acc=0.9700, avg_loss=0.1746, avg_acc=0.9512]

Step 3500 | Loss: 0.0784 | Acc: 0.9706 | Samples: 14004/38057


Epoch 1:  38%|███▊      | 3602/9515 [07:13<11:37,  8.48it/s, loss=0.0438, acc=1.0000, avg_loss=0.1728, avg_acc=0.9518]

Step 3600 | Loss: 0.0207 | Acc: 1.0000 | Samples: 14404/38057


Epoch 1:  39%|███▉      | 3702/9515 [07:24<11:27,  8.45it/s, loss=0.0164, acc=1.0000, avg_loss=0.1713, avg_acc=0.9521]

Step 3700 | Loss: 0.0317 | Acc: 1.0000 | Samples: 14804/38057


Epoch 1:  40%|███▉      | 3802/9515 [07:37<11:17,  8.44it/s, loss=0.1204, acc=0.9775, avg_loss=0.1701, avg_acc=0.9524]

Step 3800 | Loss: 0.1345 | Acc: 0.9725 | Samples: 15204/38057


Epoch 1:  41%|████      | 3902/9515 [07:49<11:08,  8.40it/s, loss=0.0918, acc=0.9643, avg_loss=0.1686, avg_acc=0.9527]

Step 3900 | Loss: 0.0728 | Acc: 0.9725 | Samples: 15604/38057


Epoch 1:  42%|████▏     | 4002/9515 [08:01<10:55,  8.41it/s, loss=0.1859, acc=0.9307, avg_loss=0.1676, avg_acc=0.9530]

Step 4000 | Loss: 0.1575 | Acc: 0.9324 | Samples: 16004/38057


Epoch 1:  43%|████▎     | 4102/9515 [08:13<10:41,  8.44it/s, loss=0.1081, acc=0.9429, avg_loss=0.1666, avg_acc=0.9532]

Step 4100 | Loss: 0.0188 | Acc: 1.0000 | Samples: 16404/38057


Epoch 1:  44%|████▍     | 4202/9515 [08:25<10:33,  8.38it/s, loss=0.0373, acc=0.9912, avg_loss=0.1655, avg_acc=0.9534]

Step 4200 | Loss: 0.0483 | Acc: 0.9868 | Samples: 16804/38057


Epoch 1:  45%|████▌     | 4302/9515 [08:37<10:16,  8.45it/s, loss=0.2041, acc=0.9130, avg_loss=0.1645, avg_acc=0.9536]

Step 4300 | Loss: 0.1026 | Acc: 0.9694 | Samples: 17204/38057


Epoch 1:  46%|████▋     | 4402/9515 [08:49<10:13,  8.33it/s, loss=0.1195, acc=0.9571, avg_loss=0.1633, avg_acc=0.9538]

Step 4400 | Loss: 0.0413 | Acc: 0.9811 | Samples: 17604/38057


Epoch 1:  47%|████▋     | 4502/9515 [09:01<09:57,  8.39it/s, loss=0.0397, acc=0.9706, avg_loss=0.1621, avg_acc=0.9541]

Step 4500 | Loss: 0.0618 | Acc: 0.9681 | Samples: 18004/38057


Epoch 1:  48%|████▊     | 4602/9515 [09:13<09:41,  8.45it/s, loss=0.1913, acc=0.9333, avg_loss=0.1608, avg_acc=0.9544]

Step 4600 | Loss: 0.0758 | Acc: 0.9626 | Samples: 18404/38057


Epoch 1:  49%|████▉     | 4702/9515 [09:25<09:30,  8.44it/s, loss=0.0909, acc=0.9762, avg_loss=0.1597, avg_acc=0.9546]

Step 4700 | Loss: 0.0238 | Acc: 1.0000 | Samples: 18804/38057


Epoch 1:  50%|█████     | 4802/9515 [09:36<09:25,  8.34it/s, loss=0.0200, acc=1.0000, avg_loss=0.1586, avg_acc=0.9549]

Step 4800 | Loss: 0.1256 | Acc: 0.9556 | Samples: 19204/38057


Epoch 1:  52%|█████▏    | 4902/9515 [09:48<09:03,  8.48it/s, loss=0.0929, acc=0.9886, avg_loss=0.1578, avg_acc=0.9550]

Step 4900 | Loss: 0.3520 | Acc: 0.8778 | Samples: 19604/38057


Epoch 1:  53%|█████▎    | 5002/9515 [10:00<08:56,  8.42it/s, loss=0.2600, acc=0.9652, avg_loss=0.1571, avg_acc=0.9552]

Step 5000 | Loss: 0.1697 | Acc: 0.9407 | Samples: 20004/38057


Epoch 1:  54%|█████▎    | 5102/9515 [10:12<08:44,  8.42it/s, loss=0.0456, acc=0.9804, avg_loss=0.1564, avg_acc=0.9554]

Step 5100 | Loss: 0.1727 | Acc: 0.9587 | Samples: 20404/38057


Epoch 1:  55%|█████▍    | 5202/9515 [10:24<08:29,  8.46it/s, loss=0.1161, acc=0.9714, avg_loss=0.1556, avg_acc=0.9555]

Step 5200 | Loss: 0.2097 | Acc: 0.9275 | Samples: 20804/38057


Epoch 1:  56%|█████▌    | 5302/9515 [10:36<08:20,  8.42it/s, loss=0.3248, acc=0.9529, avg_loss=0.1548, avg_acc=0.9556]

Step 5300 | Loss: 0.0595 | Acc: 0.9844 | Samples: 21204/38057


Epoch 1:  57%|█████▋    | 5402/9515 [10:48<08:10,  8.39it/s, loss=0.0572, acc=0.9825, avg_loss=0.1541, avg_acc=0.9558]

Step 5400 | Loss: 0.0534 | Acc: 0.9863 | Samples: 21604/38057


Epoch 1:  58%|█████▊    | 5502/9515 [11:00<07:56,  8.43it/s, loss=0.1322, acc=0.9474, avg_loss=0.1535, avg_acc=0.9559]

Step 5500 | Loss: 0.1663 | Acc: 0.9271 | Samples: 22004/38057


Epoch 1:  59%|█████▉    | 5602/9515 [11:12<07:43,  8.44it/s, loss=0.1410, acc=0.9655, avg_loss=0.1532, avg_acc=0.9560]

Step 5600 | Loss: 0.0696 | Acc: 0.9780 | Samples: 22404/38057


Epoch 1:  60%|█████▉    | 5702/9515 [11:24<07:34,  8.39it/s, loss=0.2095, acc=0.9434, avg_loss=0.1526, avg_acc=0.9561]

Step 5700 | Loss: 0.0397 | Acc: 0.9865 | Samples: 22804/38057


Epoch 1:  61%|██████    | 5802/9515 [11:36<07:21,  8.40it/s, loss=0.0954, acc=0.9646, avg_loss=0.1519, avg_acc=0.9563]

Step 5800 | Loss: 0.0162 | Acc: 1.0000 | Samples: 23204/38057


Epoch 1:  62%|██████▏   | 5902/9515 [11:48<07:09,  8.40it/s, loss=0.0831, acc=0.9600, avg_loss=0.1513, avg_acc=0.9564]

Step 5900 | Loss: 0.2226 | Acc: 0.9579 | Samples: 23604/38057


Epoch 1:  63%|██████▎   | 6002/9515 [11:59<06:59,  8.38it/s, loss=0.1689, acc=0.9510, avg_loss=0.1506, avg_acc=0.9565]

Step 6000 | Loss: 0.0675 | Acc: 0.9659 | Samples: 24004/38057


Epoch 1:  64%|██████▍   | 6102/9515 [12:11<06:49,  8.33it/s, loss=0.0875, acc=0.9589, avg_loss=0.1501, avg_acc=0.9566]

Step 6100 | Loss: 0.0243 | Acc: 1.0000 | Samples: 24404/38057


Epoch 1:  65%|██████▌   | 6202/9515 [12:23<06:32,  8.44it/s, loss=0.0766, acc=0.9775, avg_loss=0.1494, avg_acc=0.9568]

Step 6200 | Loss: 0.1541 | Acc: 0.9500 | Samples: 24804/38057


Epoch 1:  66%|██████▌   | 6302/9515 [12:35<06:29,  8.25it/s, loss=0.1245, acc=0.9659, avg_loss=0.1488, avg_acc=0.9570]

Step 6300 | Loss: 0.0328 | Acc: 1.0000 | Samples: 25204/38057


Epoch 1:  67%|██████▋   | 6402/9515 [12:47<06:22,  8.13it/s, loss=0.0360, acc=1.0000, avg_loss=0.1482, avg_acc=0.9571]

Step 6400 | Loss: 0.0748 | Acc: 0.9667 | Samples: 25604/38057


Epoch 1:  68%|██████▊   | 6502/9515 [12:59<06:08,  8.17it/s, loss=0.0782, acc=0.9717, avg_loss=0.1477, avg_acc=0.9572]

Step 6500 | Loss: 0.1428 | Acc: 0.9528 | Samples: 26004/38057


Epoch 1:  69%|██████▉   | 6602/9515 [13:11<05:54,  8.22it/s, loss=0.1529, acc=0.9574, avg_loss=0.1472, avg_acc=0.9573]

Step 6600 | Loss: 0.0498 | Acc: 0.9775 | Samples: 26404/38057


Epoch 1:  70%|███████   | 6702/9515 [13:23<05:41,  8.23it/s, loss=0.1849, acc=0.9286, avg_loss=0.1466, avg_acc=0.9574]

Step 6700 | Loss: 0.1389 | Acc: 0.9626 | Samples: 26804/38057


Epoch 1:  71%|███████▏  | 6802/9515 [13:35<05:27,  8.29it/s, loss=0.2271, acc=0.9296, avg_loss=0.1460, avg_acc=0.9576]

Step 6800 | Loss: 0.0033 | Acc: 1.0000 | Samples: 27204/38057


Epoch 1:  73%|███████▎  | 6902/9515 [13:46<05:14,  8.31it/s, loss=0.1635, acc=0.9600, avg_loss=0.1452, avg_acc=0.9578]

Step 6900 | Loss: 0.0116 | Acc: 1.0000 | Samples: 27604/38057


Epoch 1:  74%|███████▎  | 7002/9515 [13:58<04:58,  8.41it/s, loss=0.0280, acc=0.9785, avg_loss=0.1447, avg_acc=0.9579]

Step 7000 | Loss: 0.0443 | Acc: 0.9851 | Samples: 28004/38057


Epoch 1:  75%|███████▍  | 7102/9515 [14:10<04:49,  8.32it/s, loss=0.0419, acc=0.9868, avg_loss=0.1442, avg_acc=0.9580]

Step 7100 | Loss: 0.1398 | Acc: 0.9510 | Samples: 28404/38057


Epoch 1:  76%|███████▌  | 7202/9515 [14:22<04:37,  8.34it/s, loss=0.0045, acc=1.0000, avg_loss=0.1438, avg_acc=0.9581]

Step 7200 | Loss: 0.0661 | Acc: 0.9605 | Samples: 28804/38057


Epoch 1:  77%|███████▋  | 7302/9515 [14:34<04:24,  8.38it/s, loss=0.0877, acc=0.9670, avg_loss=0.1432, avg_acc=0.9582]

Step 7300 | Loss: 0.0855 | Acc: 0.9583 | Samples: 29204/38057


Epoch 1:  78%|███████▊  | 7402/9515 [14:46<04:10,  8.43it/s, loss=0.4674, acc=0.8305, avg_loss=0.1427, avg_acc=0.9583]

Step 7400 | Loss: 0.0659 | Acc: 0.9762 | Samples: 29604/38057


Epoch 1:  79%|███████▉  | 7502/9515 [14:58<03:59,  8.40it/s, loss=0.2076, acc=0.9109, avg_loss=0.1424, avg_acc=0.9584]

Step 7500 | Loss: 0.1523 | Acc: 0.9457 | Samples: 30004/38057


Epoch 1:  80%|███████▉  | 7602/9515 [15:10<03:47,  8.39it/s, loss=0.0489, acc=0.9820, avg_loss=0.1419, avg_acc=0.9585]

Step 7600 | Loss: 0.0945 | Acc: 0.9718 | Samples: 30404/38057


Epoch 1:  81%|████████  | 7702/9515 [15:22<03:35,  8.43it/s, loss=0.0818, acc=0.9747, avg_loss=0.1415, avg_acc=0.9586]

Step 7700 | Loss: 0.0983 | Acc: 0.9821 | Samples: 30804/38057


Epoch 1:  82%|████████▏ | 7802/9515 [15:33<03:25,  8.35it/s, loss=0.1320, acc=0.9592, avg_loss=0.1411, avg_acc=0.9587]

Step 7800 | Loss: 0.1847 | Acc: 0.9600 | Samples: 31204/38057


Epoch 1:  83%|████████▎ | 7902/9515 [15:45<03:11,  8.40it/s, loss=0.1968, acc=0.9054, avg_loss=0.1406, avg_acc=0.9588]

Step 7900 | Loss: 0.0026 | Acc: 1.0000 | Samples: 31604/38057


Epoch 1:  84%|████████▍ | 8002/9515 [15:57<02:59,  8.45it/s, loss=0.0125, acc=1.0000, avg_loss=0.1401, avg_acc=0.9589]

Step 8000 | Loss: 0.5697 | Acc: 0.8816 | Samples: 32004/38057


Epoch 1:  85%|████████▌ | 8102/9515 [16:09<02:48,  8.38it/s, loss=0.0823, acc=0.9815, avg_loss=0.1396, avg_acc=0.9590]

Step 8100 | Loss: 0.0775 | Acc: 0.9643 | Samples: 32404/38057


Epoch 1:  86%|████████▌ | 8202/9515 [16:21<02:36,  8.38it/s, loss=0.3692, acc=0.9091, avg_loss=0.1392, avg_acc=0.9591]

Step 8200 | Loss: 0.0366 | Acc: 1.0000 | Samples: 32804/38057


Epoch 1:  87%|████████▋ | 8302/9515 [16:33<02:25,  8.33it/s, loss=0.7453, acc=0.8143, avg_loss=0.1391, avg_acc=0.9591]

Step 8300 | Loss: 0.1856 | Acc: 0.9518 | Samples: 33204/38057


Epoch 1:  88%|████████▊ | 8402/9515 [16:45<02:12,  8.40it/s, loss=0.2420, acc=0.9204, avg_loss=0.1388, avg_acc=0.9592]

Step 8400 | Loss: 0.1833 | Acc: 0.9506 | Samples: 33604/38057


Epoch 1:  89%|████████▉ | 8502/9515 [16:57<02:00,  8.41it/s, loss=0.0323, acc=0.9909, avg_loss=0.1383, avg_acc=0.9593]

Step 8500 | Loss: 0.0117 | Acc: 1.0000 | Samples: 34004/38057


Epoch 1:  90%|█████████ | 8602/9515 [17:09<01:48,  8.40it/s, loss=0.1976, acc=0.9355, avg_loss=0.1380, avg_acc=0.9593]

Step 8600 | Loss: 0.0620 | Acc: 1.0000 | Samples: 34404/38057


Epoch 1:  91%|█████████▏| 8702/9515 [17:21<01:36,  8.41it/s, loss=0.1510, acc=0.9551, avg_loss=0.1377, avg_acc=0.9594]

Step 8700 | Loss: 0.1980 | Acc: 0.9605 | Samples: 34804/38057


Epoch 1:  93%|█████████▎| 8802/9515 [17:33<01:24,  8.39it/s, loss=0.0934, acc=0.9762, avg_loss=0.1373, avg_acc=0.9595]

Step 8800 | Loss: 0.1019 | Acc: 0.9895 | Samples: 35204/38057


Epoch 1:  94%|█████████▎| 8902/9515 [17:44<01:13,  8.38it/s, loss=0.0389, acc=0.9901, avg_loss=0.1370, avg_acc=0.9596]

Step 8900 | Loss: 0.0548 | Acc: 0.9853 | Samples: 35604/38057


Epoch 1:  95%|█████████▍| 9002/9515 [17:56<01:00,  8.43it/s, loss=0.1225, acc=0.9759, avg_loss=0.1366, avg_acc=0.9597]

Step 9000 | Loss: 0.0749 | Acc: 0.9692 | Samples: 36004/38057


Epoch 1:  96%|█████████▌| 9102/9515 [18:08<00:49,  8.42it/s, loss=0.3937, acc=0.8778, avg_loss=0.1362, avg_acc=0.9598]

Step 9100 | Loss: 0.2035 | Acc: 0.9355 | Samples: 36404/38057


Epoch 1:  97%|█████████▋| 9202/9515 [18:20<00:37,  8.41it/s, loss=0.0314, acc=0.9792, avg_loss=0.1358, avg_acc=0.9599]

Step 9200 | Loss: 0.0652 | Acc: 0.9677 | Samples: 36804/38057


Epoch 1:  98%|█████████▊| 9302/9515 [18:32<00:25,  8.38it/s, loss=0.0785, acc=0.9688, avg_loss=0.1356, avg_acc=0.9599]

Step 9300 | Loss: 0.0276 | Acc: 1.0000 | Samples: 37204/38057


Epoch 1:  99%|█████████▉| 9402/9515 [18:44<00:13,  8.40it/s, loss=0.0982, acc=0.9855, avg_loss=0.1351, avg_acc=0.9600]

Step 9400 | Loss: 0.1592 | Acc: 0.9579 | Samples: 37604/38057


Epoch 1: 100%|█████████▉| 9502/9515 [18:56<00:01,  8.41it/s, loss=0.0448, acc=0.9891, avg_loss=0.1347, avg_acc=0.9601]

Step 9500 | Loss: 0.1368 | Acc: 0.9701 | Samples: 38004/38057


                                                                                                                      


Epoch 1 Summary:
Average Loss: 0.1347
Average Accuracy: 0.9601
Processed Samples: 38057
Total Steps: 9515
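
The summary figures are consistent with a training batch size of 4. This is an inference from the log, not something stated in this section. The sketch below checks it against two numbers from the output above: the total step count, and the sample counter printed at step 1700.

```python
import math

train_samples = 38057   # "Processed Samples" in the epoch summary
total_steps = 9515      # "Total Steps" in the epoch summary

# Assumption: a fixed batch size, with a smaller final batch that is not dropped.
batch_size = 4

expected_steps = math.ceil(train_samples / batch_size)
last_batch_size = train_samples - (expected_steps - 1) * batch_size

# The log line "Step 1700 | ... | Samples: 6804/38057" counts batches 0..1700.
samples_after_step_1700 = (1700 + 1) * batch_size

print(expected_steps, last_batch_size, samples_after_step_1700)
```

Under this assumption the last batch holds a single sentence, which is one reason individual per-batch accuracies in the log jump around so much.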





#### **Evaluating the model**

Now that the model is trained, we can evaluate it on the held-out test set (20% of the data). No gradient updates are performed here; the model only runs forward passes and outputs its logits.

In [101]:
def valid(model, testing_loader):
    """Evaluate the model and return the true and predicted label strings."""
    model.eval()
    total_loss = 0.0
    total_accuracy = 0.0
    total_steps = 0
    all_labels = []
    all_preds = []

    # Progress bar for live metric reporting
    from tqdm import tqdm
    progress_bar = tqdm(testing_loader, desc="Validating", leave=False)

    with torch.no_grad():
        for batch_idx, batch in enumerate(progress_bar):
            try:
                # Move the batch to the target device
                inputs = {
                    'input_ids': batch['input_ids'].to(device),
                    'attention_mask': batch['attention_mask'].to(device),
                    'labels': batch['labels'].to(device)
                }

                # Forward pass
                outputs = model(**inputs)
                loss = outputs.loss
                logits = outputs.logits

                # Accumulate the loss
                total_loss += loss.item()
                total_steps += 1

                # Accuracy is computed only over labelled tokens (label != -100)
                active_labels = inputs['labels'].view(-1) != -100
                predictions = logits.argmax(dim=-1).view(-1)
                true_labels = inputs['labels'].view(-1)

                active_preds = predictions[active_labels]
                active_true = true_labels[active_labels]

                # Default value, so the logging below never reads an undefined
                # or stale accuracy when a batch has no labelled tokens
                batch_accuracy = 0.0
                if len(active_true) > 0:
                    batch_accuracy = (active_preds == active_true).float().mean().item()
                    total_accuracy += batch_accuracy

                    # Collect predictions for the labelled (non-ignored) tokens
                    all_labels.extend(active_true.cpu().tolist())
                    all_preds.extend(active_preds.cpu().tolist())

                # Update the progress bar
                progress_bar.set_postfix({
                    'val_loss': f"{loss.item():.4f}",
                    'val_acc': f"{batch_accuracy:.4f}",
                    'avg_loss': f"{total_loss/total_steps:.4f}",
                    'avg_acc': f"{total_accuracy/total_steps:.4f}"
                })

                # Print a detailed log line every 50 steps
                if batch_idx % 50 == 0:
                    tqdm.write(
                        f"Validation Step {batch_idx} | "
                        f"Loss: {loss.item():.4f} | "
                        f"Acc: {batch_accuracy:.4f}"
                    )

            except Exception as e:
                print(f"\nValidation Error at batch {batch_idx}: {str(e)}")
                continue

    # Final metrics, averaged over batches (guard against an empty loader)
    avg_loss = total_loss / max(total_steps, 1)
    avg_accuracy = total_accuracy / max(total_steps, 1)

    # Convert label IDs to label strings
    label_texts = [ids_to_labels[i] for i in all_labels]
    pred_texts = [ids_to_labels[i] for i in all_preds]

    # Print the final results; note that "Processed Samples" counts labelled tokens
    print("\nValidation Complete:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Average Accuracy: {avg_accuracy:.4f}")
    print(f"Processed Samples: {len(all_labels)}")
    print(f"Total Steps: {total_steps}\n")

    return label_texts, pred_texts
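
The function above masks out positions labelled `-100` and then averages the per-batch accuracies. The sketch below reproduces that idea with plain Python lists and made-up toy values, to show how the batch-averaged accuracy can differ from the accuracy computed over all labelled tokens at once. The helper name `masked_accuracy` is hypothetical and not part of the notebook.

```python
IGNORE_INDEX = -100  # label value marking positions excluded from the metrics


def masked_accuracy(predictions, labels, ignore_index=IGNORE_INDEX):
    """Return (correct, total) over positions whose label is not ignore_index."""
    pairs = [(p, t) for p, t in zip(predictions, labels) if t != ignore_index]
    correct = sum(1 for p, t in pairs if p == t)
    return correct, len(pairs)


# Two toy batches of flattened (predictions, labels); the values are invented.
batches = [
    ([0, 1, 2, 0, 0], [-100, 1, 2, -100, -100]),            # 2 labelled, 2 correct
    ([0, 0, 1, 1, 2, 2, 0, 0], [0, 0, 1, 2, 2, 1, 0, 1]),   # 8 labelled, 5 correct
]

per_batch = []
total_correct = 0
total_tokens = 0
for preds, labels in batches:
    correct, total = masked_accuracy(preds, labels)
    if total > 0:
        per_batch.append(correct / total)
    total_correct += correct
    total_tokens += total

batch_mean = sum(per_batch) / len(batches)   # what valid() reports as the average
token_level = total_correct / total_tokens   # accuracy over all labelled tokens

print(f"batch-averaged: {batch_mean:.4f}, token-level: {token_level:.4f}")
```

Short batches weigh as much as long ones in the batch average, so the two numbers generally differ slightly.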

As we can see below, performance is quite good: the average accuracy on the test set is about 97% (0.9689).
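
To put the accuracy figure in context, the per-tag report further down lists how many test tokens carry each tag. The short calculation below uses those support counts to work out what a trivial model that always predicts `O` would score. It is only a sanity check on the numbers in the report, not part of the original evaluation.

```python
# Support counts copied from the per-tag report in this notebook's output.
support = {
    "B-geo": 7378, "B-gpe": 3021, "B-org": 3964, "B-per": 3367, "B-tim": 4070,
    "I-geo": 1481, "I-gpe": 37, "I-org": 3385, "I-per": 3425, "I-tim": 1346,
    "O": 176740,
}

total_tokens = sum(support.values())
majority_baseline = support["O"] / total_tokens   # accuracy of "always predict O"
entity_tokens = total_tokens - support["O"]

print(f"total labelled tokens: {total_tokens}")
print(f"always-O baseline accuracy: {majority_baseline:.4f}")
print(f"tokens inside entities: {entity_tokens}")
```

Because most tokens are `O`, token accuracy alone says little about the entity tags, which is why the per-tag precision, recall and F1 scores are the more informative numbers.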

In [102]:
true_labels, pred_labels = valid(model, testing_loader)

Validating:   0%|          | 4/4757 [00:00<02:41, 29.38it/s, val_loss=0.1775, val_acc=0.9405, avg_loss=0.0793, avg_acc=0.9728]

Validation Step 0 | Loss: 0.0075 | Acc: 1.0000


Validating:   1%|          | 51/4757 [00:01<01:38, 47.83it/s, val_loss=0.0055, val_acc=1.0000, avg_loss=0.1115, avg_acc=0.9679]

Validation Step 50 | Loss: 0.2788 | Acc: 0.9375


Validating:   2%|▏         | 111/4757 [00:02<01:25, 54.40it/s, val_loss=0.0007, val_acc=1.0000, avg_loss=0.0986, avg_acc=0.9703]

Validation Step 100 | Loss: 0.2526 | Acc: 0.9348


Validating:   3%|▎         | 159/4757 [00:03<01:19, 58.02it/s, val_loss=0.0275, val_acc=0.9756, avg_loss=0.1012, avg_acc=0.9701]

Validation Step 150 | Loss: 0.5871 | Acc: 0.8723


Validating:   4%|▍         | 207/4757 [00:04<01:17, 58.66it/s, val_loss=0.0813, val_acc=0.9348, avg_loss=0.0978, avg_acc=0.9700]

Validation Step 200 | Loss: 0.0617 | Acc: 1.0000


Validating:   5%|▌         | 261/4757 [00:05<01:16, 58.97it/s, val_loss=0.1101, val_acc=0.9672, avg_loss=0.0967, avg_acc=0.9704]

Validation Step 250 | Loss: 0.1357 | Acc: 0.9559


Validating:   7%|▋         | 311/4757 [00:06<01:16, 58.37it/s, val_loss=0.0530, val_acc=0.9730, avg_loss=0.0970, avg_acc=0.9690]

Validation Step 300 | Loss: 0.0424 | Acc: 0.9778


Validating:   8%|▊         | 359/4757 [00:06<01:14, 58.78it/s, val_loss=0.0677, val_acc=0.9600, avg_loss=0.0966, avg_acc=0.9695]

Validation Step 350 | Loss: 0.1360 | Acc: 0.9556


Validating:   9%|▊         | 407/4757 [00:07<01:15, 57.79it/s, val_loss=0.0443, val_acc=0.9714, avg_loss=0.0957, avg_acc=0.9698]

Validation Step 400 | Loss: 0.0016 | Acc: 1.0000


Validating:  10%|▉         | 461/4757 [00:08<01:14, 57.76it/s, val_loss=0.0140, val_acc=1.0000, avg_loss=0.0984, avg_acc=0.9693]

Validation Step 450 | Loss: 0.0122 | Acc: 1.0000


Validating:  11%|█         | 503/4757 [00:09<01:31, 46.45it/s, val_loss=0.2888, val_acc=0.8974, avg_loss=0.0979, avg_acc=0.9691]

Validation Step 500 | Loss: 0.0577 | Acc: 0.9714


Validating:  12%|█▏        | 556/4757 [00:10<01:23, 50.53it/s, val_loss=0.2353, val_acc=0.9592, avg_loss=0.0975, avg_acc=0.9689]

Validation Step 550 | Loss: 0.0049 | Acc: 1.0000


Validating:  13%|█▎        | 610/4757 [00:11<01:17, 53.29it/s, val_loss=0.1727, val_acc=0.9302, avg_loss=0.0977, avg_acc=0.9689]

Validation Step 600 | Loss: 0.0444 | Acc: 0.9836


Validating:  14%|█▍        | 658/4757 [00:12<01:11, 57.68it/s, val_loss=0.0279, val_acc=1.0000, avg_loss=0.0974, avg_acc=0.9694]

Validation Step 650 | Loss: 0.1162 | Acc: 0.9615


Validating:  15%|█▍        | 712/4757 [00:13<01:08, 58.69it/s, val_loss=0.0126, val_acc=1.0000, avg_loss=0.0962, avg_acc=0.9699]

Validation Step 700 | Loss: 0.0004 | Acc: 1.0000


Validating:  16%|█▌        | 760/4757 [00:14<01:09, 57.68it/s, val_loss=0.1357, val_acc=0.9630, avg_loss=0.0971, avg_acc=0.9693]

Validation Step 750 | Loss: 0.0682 | Acc: 0.9592


Validating:  17%|█▋        | 808/4757 [00:15<01:08, 57.99it/s, val_loss=0.0871, val_acc=0.9583, avg_loss=0.0967, avg_acc=0.9691]

Validation Step 800 | Loss: 0.0625 | Acc: 0.9636


Validating:  18%|█▊        | 862/4757 [00:16<01:07, 58.04it/s, val_loss=0.0034, val_acc=1.0000, avg_loss=0.0975, avg_acc=0.9691]

Validation Step 850 | Loss: 0.0052 | Acc: 1.0000


Validating:  19%|█▉        | 910/4757 [00:16<01:06, 57.83it/s, val_loss=0.5928, val_acc=0.8222, avg_loss=0.0989, avg_acc=0.9685]

Validation Step 900 | Loss: 0.0586 | Acc: 0.9697


Validating:  20%|██        | 959/4757 [00:17<01:05, 57.80it/s, val_loss=0.0089, val_acc=1.0000, avg_loss=0.0996, avg_acc=0.9683]

Validation Step 950 | Loss: 0.0169 | Acc: 1.0000


Validating:  21%|██        | 1007/4757 [00:18<01:04, 57.80it/s, val_loss=0.0853, val_acc=0.9286, avg_loss=0.0996, avg_acc=0.9682]

Validation Step 1000 | Loss: 0.2976 | Acc: 0.9524


Validating:  22%|██▏       | 1061/4757 [00:19<01:04, 57.51it/s, val_loss=0.1787, val_acc=0.9259, avg_loss=0.0986, avg_acc=0.9686]

Validation Step 1050 | Loss: 0.0072 | Acc: 1.0000


Validating:  23%|██▎       | 1109/4757 [00:20<01:03, 57.46it/s, val_loss=0.2814, val_acc=0.9535, avg_loss=0.0980, avg_acc=0.9686]

Validation Step 1100 | Loss: 0.1139 | Acc: 0.9565


Validating:  24%|██▍       | 1157/4757 [00:21<01:03, 56.63it/s, val_loss=0.0003, val_acc=1.0000, avg_loss=0.0983, avg_acc=0.9684]

Validation Step 1150 | Loss: 0.0098 | Acc: 1.0000


Validating:  25%|██▌       | 1211/4757 [00:22<01:05, 54.33it/s, val_loss=0.0074, val_acc=1.0000, avg_loss=0.0991, avg_acc=0.9683]

Validation Step 1200 | Loss: 0.1442 | Acc: 0.9459


Validating:  26%|██▋       | 1259/4757 [00:23<01:05, 53.01it/s, val_loss=0.0850, val_acc=0.9655, avg_loss=0.0989, avg_acc=0.9683]

Validation Step 1250 | Loss: 0.0586 | Acc: 0.9744


Validating:  27%|██▋       | 1307/4757 [00:24<01:07, 50.94it/s, val_loss=0.1193, val_acc=0.9623, avg_loss=0.0981, avg_acc=0.9686]

Validation Step 1300 | Loss: 0.0149 | Acc: 1.0000


Validating:  29%|██▊       | 1361/4757 [00:24<00:59, 56.99it/s, val_loss=0.0213, val_acc=1.0000, avg_loss=0.0974, avg_acc=0.9686]

Validation Step 1350 | Loss: 0.0143 | Acc: 1.0000


Validating:  30%|██▉       | 1409/4757 [00:25<00:58, 57.20it/s, val_loss=0.1363, val_acc=0.9459, avg_loss=0.0971, avg_acc=0.9688]

Validation Step 1400 | Loss: 0.0181 | Acc: 1.0000


Validating:  31%|███       | 1458/4757 [00:26<00:57, 57.29it/s, val_loss=0.0254, val_acc=0.9773, avg_loss=0.0970, avg_acc=0.9687]

Validation Step 1450 | Loss: 0.3662 | Acc: 0.8222


Validating:  32%|███▏      | 1512/4757 [00:27<00:57, 56.60it/s, val_loss=0.1575, val_acc=0.9286, avg_loss=0.0976, avg_acc=0.9685]

Validation Step 1500 | Loss: 0.1951 | Acc: 0.9231


Validating:  33%|███▎      | 1560/4757 [00:28<00:55, 57.35it/s, val_loss=0.0466, val_acc=0.9851, avg_loss=0.0975, avg_acc=0.9686]

Validation Step 1550 | Loss: 0.0447 | Acc: 1.0000


Validating:  34%|███▍      | 1608/4757 [00:29<00:55, 57.24it/s, val_loss=0.0259, val_acc=1.0000, avg_loss=0.0984, avg_acc=0.9682]

Validation Step 1600 | Loss: 0.2951 | Acc: 0.9608


Validating:  35%|███▍      | 1662/4757 [00:30<00:54, 56.93it/s, val_loss=0.0042, val_acc=1.0000, avg_loss=0.0986, avg_acc=0.9684]

Validation Step 1650 | Loss: 0.1928 | Acc: 0.9623


Validating:  36%|███▌      | 1710/4757 [00:31<00:53, 57.24it/s, val_loss=0.3597, val_acc=0.9000, avg_loss=0.0989, avg_acc=0.9682]

Validation Step 1700 | Loss: 0.2318 | Acc: 0.8966


Validating:  37%|███▋      | 1758/4757 [00:31<00:52, 56.84it/s, val_loss=0.0021, val_acc=1.0000, avg_loss=0.0987, avg_acc=0.9682]

Validation Step 1750 | Loss: 0.1350 | Acc: 0.9773


Validating:  38%|███▊      | 1812/4757 [00:32<00:51, 56.99it/s, val_loss=0.0139, val_acc=1.0000, avg_loss=0.0996, avg_acc=0.9680]

Validation Step 1800 | Loss: 0.0614 | Acc: 0.9778


Validating:  39%|███▉      | 1860/4757 [00:33<00:51, 56.32it/s, val_loss=0.3147, val_acc=0.8837, avg_loss=0.0989, avg_acc=0.9681]

Validation Step 1850 | Loss: 0.1225 | Acc: 0.9643


Validating:  40%|████      | 1908/4757 [00:34<00:52, 54.71it/s, val_loss=0.1018, val_acc=0.9677, avg_loss=0.0989, avg_acc=0.9681]

Validation Step 1900 | Loss: 0.1033 | Acc: 0.9355


Validating:  41%|████      | 1956/4757 [00:35<00:52, 53.35it/s, val_loss=0.0230, val_acc=1.0000, avg_loss=0.0995, avg_acc=0.9680]

Validation Step 1950 | Loss: 0.1528 | Acc: 0.9535


Validating:  42%|████▏     | 2010/4757 [00:36<00:51, 53.66it/s, val_loss=0.0017, val_acc=1.0000, avg_loss=0.0988, avg_acc=0.9682]

Validation Step 2000 | Loss: 0.0023 | Acc: 1.0000


Validating:  43%|████▎     | 2058/4757 [00:37<00:48, 55.76it/s, val_loss=0.0734, val_acc=0.9667, avg_loss=0.0983, avg_acc=0.9685]

Validation Step 2050 | Loss: 0.1600 | Acc: 0.9362


Validating:  44%|████▍     | 2112/4757 [00:38<00:47, 56.27it/s, val_loss=0.0353, val_acc=1.0000, avg_loss=0.0979, avg_acc=0.9684]

Validation Step 2100 | Loss: 0.2137 | Acc: 0.9200


Validating:  45%|████▌     | 2160/4757 [00:39<00:46, 56.16it/s, val_loss=0.2374, val_acc=0.9348, avg_loss=0.0983, avg_acc=0.9683]

Validation Step 2150 | Loss: 0.1111 | Acc: 0.9500


Validating:  46%|████▋     | 2208/4757 [00:40<00:45, 56.33it/s, val_loss=0.1168, val_acc=0.9474, avg_loss=0.0990, avg_acc=0.9683]

Validation Step 2200 | Loss: 0.2406 | Acc: 0.9574


Validating:  48%|████▊     | 2262/4757 [00:40<00:44, 55.67it/s, val_loss=0.0956, val_acc=0.9828, avg_loss=0.0991, avg_acc=0.9683]

Validation Step 2250 | Loss: 0.1621 | Acc: 0.9184


Validating:  49%|████▊     | 2310/4757 [00:41<00:43, 56.79it/s, val_loss=0.0069, val_acc=1.0000, avg_loss=0.0998, avg_acc=0.9681]

Validation Step 2300 | Loss: 0.4317 | Acc: 0.8800


Validating:  50%|████▉     | 2358/4757 [00:42<00:42, 55.90it/s, val_loss=0.2847, val_acc=0.8833, avg_loss=0.0996, avg_acc=0.9682]

Validation Step 2350 | Loss: 0.0972 | Acc: 0.9375


Validating:  51%|█████     | 2412/4757 [00:43<00:41, 56.14it/s, val_loss=0.0002, val_acc=1.0000, avg_loss=0.0997, avg_acc=0.9682]

Validation Step 2400 | Loss: 0.0484 | Acc: 1.0000


Validating:  52%|█████▏    | 2460/4757 [00:44<00:40, 56.09it/s, val_loss=0.0053, val_acc=1.0000, avg_loss=0.0998, avg_acc=0.9682]

Validation Step 2450 | Loss: 0.5877 | Acc: 0.7750


Validating:  53%|█████▎    | 2508/4757 [00:45<00:39, 56.38it/s, val_loss=0.0662, val_acc=0.9556, avg_loss=0.0993, avg_acc=0.9683]

Validation Step 2500 | Loss: 0.0038 | Acc: 1.0000


Validating:  54%|█████▍    | 2562/4757 [00:46<00:39, 55.84it/s, val_loss=0.0088, val_acc=1.0000, avg_loss=0.0991, avg_acc=0.9684]

Validation Step 2550 | Loss: 0.1727 | Acc: 0.9714


Validating:  55%|█████▍    | 2610/4757 [00:47<00:40, 53.18it/s, val_loss=0.0014, val_acc=1.0000, avg_loss=0.0987, avg_acc=0.9686]

Validation Step 2600 | Loss: 0.0521 | Acc: 0.9744


Validating:  56%|█████▌    | 2658/4757 [00:48<00:39, 52.84it/s, val_loss=0.0796, val_acc=0.9811, avg_loss=0.0993, avg_acc=0.9684]

Validation Step 2650 | Loss: 0.0496 | Acc: 0.9545


Validating:  57%|█████▋    | 2712/4757 [00:49<00:37, 54.39it/s, val_loss=0.0130, val_acc=1.0000, avg_loss=0.0993, avg_acc=0.9684]

Validation Step 2700 | Loss: 0.0249 | Acc: 1.0000


Validating:  58%|█████▊    | 2760/4757 [00:49<00:35, 56.09it/s, val_loss=0.0382, val_acc=0.9778, avg_loss=0.0990, avg_acc=0.9685]

Validation Step 2750 | Loss: 0.0416 | Acc: 0.9808


Validating:  59%|█████▉    | 2808/4757 [00:50<00:34, 56.11it/s, val_loss=0.4814, val_acc=0.9545, avg_loss=0.0992, avg_acc=0.9685]

Validation Step 2800 | Loss: 0.0382 | Acc: 1.0000


Validating:  60%|██████    | 2862/4757 [00:51<00:33, 56.15it/s, val_loss=0.0237, val_acc=1.0000, avg_loss=0.0991, avg_acc=0.9685]

Validation Step 2850 | Loss: 0.0833 | Acc: 0.9375


Validating:  61%|██████    | 2910/4757 [00:52<00:33, 55.75it/s, val_loss=0.0359, val_acc=0.9726, avg_loss=0.0991, avg_acc=0.9684]

Validation Step 2900 | Loss: 0.0525 | Acc: 0.9512


Validating:  62%|██████▏   | 2958/4757 [00:53<00:31, 56.23it/s, val_loss=0.0829, val_acc=0.9796, avg_loss=0.0990, avg_acc=0.9685]

Validation Step 2950 | Loss: 0.1072 | Acc: 0.9500


Validating:  63%|██████▎   | 3012/4757 [00:54<00:30, 56.48it/s, val_loss=0.2306, val_acc=0.9429, avg_loss=0.0994, avg_acc=0.9685]

Validation Step 3000 | Loss: 0.1875 | Acc: 0.9455


Validating:  64%|██████▍   | 3060/4757 [00:55<00:29, 56.66it/s, val_loss=0.0613, val_acc=0.9730, avg_loss=0.0995, avg_acc=0.9684]

Validation Step 3050 | Loss: 0.1735 | Acc: 0.9487


Validating:  65%|██████▌   | 3108/4757 [00:56<00:29, 56.39it/s, val_loss=0.0496, val_acc=0.9744, avg_loss=0.0991, avg_acc=0.9686]

Validation Step 3100 | Loss: 0.1533 | Acc: 0.9701


Validating:  66%|██████▋   | 3162/4757 [00:57<00:27, 57.13it/s, val_loss=0.1944, val_acc=0.9268, avg_loss=0.0992, avg_acc=0.9685]

Validation Step 3150 | Loss: 0.0585 | Acc: 0.9636


Validating:  67%|██████▋   | 3210/4757 [00:57<00:27, 56.75it/s, val_loss=0.0466, val_acc=0.9583, avg_loss=0.0990, avg_acc=0.9685]

Validation Step 3200 | Loss: 0.0373 | Acc: 1.0000


Validating:  68%|██████▊   | 3258/4757 [00:58<00:27, 53.57it/s, val_loss=0.6836, val_acc=0.9333, avg_loss=0.0996, avg_acc=0.9684]

Validation Step 3250 | Loss: 0.0006 | Acc: 1.0000


Validating:  69%|██████▉   | 3306/4757 [00:59<00:28, 51.50it/s, val_loss=0.0913, val_acc=0.9583, avg_loss=0.0993, avg_acc=0.9685]

Validation Step 3300 | Loss: 0.1970 | Acc: 0.9310


Validating:  71%|███████   | 3360/4757 [01:00<00:27, 49.93it/s, val_loss=0.0828, val_acc=0.9804, avg_loss=0.0993, avg_acc=0.9686]

Validation Step 3350 | Loss: 0.0532 | Acc: 0.9811


Validating:  72%|███████▏  | 3408/4757 [01:01<00:24, 56.15it/s, val_loss=0.0844, val_acc=0.9444, avg_loss=0.0995, avg_acc=0.9685]

Validation Step 3400 | Loss: 0.0595 | Acc: 0.9615


Validating:  73%|███████▎  | 3462/4757 [01:02<00:23, 55.57it/s, val_loss=0.1605, val_acc=0.9091, avg_loss=0.0996, avg_acc=0.9684]

Validation Step 3450 | Loss: 0.0423 | Acc: 0.9839


Validating:  74%|███████▍  | 3510/4757 [01:03<00:22, 55.91it/s, val_loss=0.4287, val_acc=0.8444, avg_loss=0.0994, avg_acc=0.9686]

Validation Step 3500 | Loss: 0.0021 | Acc: 1.0000


Validating:  75%|███████▍  | 3558/4757 [01:04<00:21, 55.58it/s, val_loss=0.0165, val_acc=1.0000, avg_loss=0.0992, avg_acc=0.9686]

Validation Step 3550 | Loss: 0.1229 | Acc: 0.9286


Validating:  76%|███████▌  | 3612/4757 [01:05<00:20, 56.03it/s, val_loss=0.1064, val_acc=0.9487, avg_loss=0.0991, avg_acc=0.9686]

Validation Step 3600 | Loss: 0.0352 | Acc: 1.0000


Validating:  77%|███████▋  | 3660/4757 [01:06<00:19, 55.44it/s, val_loss=0.0511, val_acc=0.9714, avg_loss=0.0994, avg_acc=0.9685]

Validation Step 3650 | Loss: 0.2557 | Acc: 0.9062


Validating:  78%|███████▊  | 3708/4757 [01:06<00:18, 55.95it/s, val_loss=0.0620, val_acc=1.0000, avg_loss=0.0992, avg_acc=0.9685]

Validation Step 3700 | Loss: 0.0009 | Acc: 1.0000


Validating:  79%|███████▉  | 3762/4757 [01:07<00:17, 56.45it/s, val_loss=0.1143, val_acc=0.9623, avg_loss=0.0993, avg_acc=0.9685]

Validation Step 3750 | Loss: 0.0002 | Acc: 1.0000


Validating:  80%|████████  | 3810/4757 [01:08<00:16, 57.35it/s, val_loss=0.0233, val_acc=1.0000, avg_loss=0.0991, avg_acc=0.9686]

Validation Step 3800 | Loss: 0.2070 | Acc: 0.9259


Validating:  81%|████████  | 3858/4757 [01:09<00:15, 57.60it/s, val_loss=0.2034, val_acc=0.9500, avg_loss=0.0993, avg_acc=0.9685]

Validation Step 3850 | Loss: 0.0503 | Acc: 0.9853


Validating:  82%|████████▏ | 3912/4757 [01:10<00:14, 57.78it/s, val_loss=0.1458, val_acc=0.9483, avg_loss=0.0993, avg_acc=0.9684]

Validation Step 3900 | Loss: 0.0012 | Acc: 1.0000


Validating:  83%|████████▎ | 3960/4757 [01:11<00:15, 52.69it/s, val_loss=0.1413, val_acc=0.9412, avg_loss=0.0992, avg_acc=0.9685]

Validation Step 3950 | Loss: 0.0315 | Acc: 1.0000


Validating:  84%|████████▍ | 4008/4757 [01:12<00:14, 50.53it/s, val_loss=0.0574, val_acc=0.9767, avg_loss=0.0991, avg_acc=0.9685]

Validation Step 4000 | Loss: 0.4715 | Acc: 0.8333


Validating:  85%|████████▌ | 4062/4757 [01:13<00:12, 54.82it/s, val_loss=0.1383, val_acc=0.9412, avg_loss=0.0989, avg_acc=0.9686]

Validation Step 4050 | Loss: 0.0025 | Acc: 1.0000


Validating:  86%|████████▋ | 4110/4757 [01:14<00:11, 56.51it/s, val_loss=0.0412, val_acc=0.9762, avg_loss=0.0987, avg_acc=0.9686]

Validation Step 4100 | Loss: 0.1075 | Acc: 0.9464


Validating:  87%|████████▋ | 4158/4757 [01:15<00:10, 57.33it/s, val_loss=0.0510, val_acc=1.0000, avg_loss=0.0987, avg_acc=0.9686]

Validation Step 4150 | Loss: 0.0230 | Acc: 1.0000


Validating:  89%|████████▊ | 4212/4757 [01:15<00:09, 56.92it/s, val_loss=0.0002, val_acc=1.0000, avg_loss=0.0986, avg_acc=0.9687]

Validation Step 4200 | Loss: 0.0345 | Acc: 0.9800


Validating:  90%|████████▉ | 4260/4757 [01:16<00:08, 57.37it/s, val_loss=0.0343, val_acc=1.0000, avg_loss=0.0986, avg_acc=0.9687]

Validation Step 4250 | Loss: 0.0030 | Acc: 1.0000


Validating:  91%|█████████ | 4308/4757 [01:17<00:07, 57.92it/s, val_loss=0.1053, val_acc=0.9722, avg_loss=0.0986, avg_acc=0.9687]

Validation Step 4300 | Loss: 0.0659 | Acc: 0.9623


Validating:  92%|█████████▏| 4362/4757 [01:18<00:06, 57.32it/s, val_loss=0.0002, val_acc=1.0000, avg_loss=0.0986, avg_acc=0.9687]

Validation Step 4350 | Loss: 0.1973 | Acc: 0.9483


Validating:  93%|█████████▎| 4410/4757 [01:19<00:06, 57.23it/s, val_loss=0.2790, val_acc=0.9048, avg_loss=0.0989, avg_acc=0.9687]

Validation Step 4400 | Loss: 0.0317 | Acc: 0.9831


Validating:  94%|█████████▎| 4458/4757 [01:20<00:05, 57.76it/s, val_loss=0.0611, val_acc=0.9836, avg_loss=0.0989, avg_acc=0.9686]

Validation Step 4450 | Loss: 0.1244 | Acc: 0.9388


Validating:  95%|█████████▍| 4512/4757 [01:21<00:04, 57.03it/s, val_loss=0.1062, val_acc=0.9355, avg_loss=0.0985, avg_acc=0.9687]

Validation Step 4500 | Loss: 0.0025 | Acc: 1.0000


Validating:  96%|█████████▌| 4560/4757 [01:22<00:03, 57.30it/s, val_loss=0.1462, val_acc=0.9348, avg_loss=0.0984, avg_acc=0.9687]

Validation Step 4550 | Loss: 0.0046 | Acc: 1.0000


Validating:  97%|█████████▋| 4608/4757 [01:22<00:02, 57.35it/s, val_loss=0.0033, val_acc=1.0000, avg_loss=0.0984, avg_acc=0.9687]

Validation Step 4600 | Loss: 0.0760 | Acc: 0.9667


Validating:  98%|█████████▊| 4656/4757 [01:23<00:01, 52.29it/s, val_loss=0.0355, val_acc=1.0000, avg_loss=0.0983, avg_acc=0.9688]

Validation Step 4650 | Loss: 0.0175 | Acc: 1.0000


Validating:  99%|█████████▉| 4710/4757 [01:24<00:00, 53.42it/s, val_loss=0.2187, val_acc=0.9492, avg_loss=0.0981, avg_acc=0.9688]

Validation Step 4700 | Loss: 0.0186 | Acc: 1.0000




Validation Step 4750 | Loss: 0.0013 | Acc: 1.0000

Validation Complete:
Average Loss: 0.0979
Average Accuracy: 0.9689
Processed Samples: 208214
Total Steps: 4757

              precision    recall  f1-score   support

       B-geo       0.86      0.89      0.88      7378
       B-gpe       0.97      0.93      0.95      3021
       B-org       0.75      0.70      0.73      3964
       B-per       0.82      0.85      0.84      3367
       B-tim       0.91      0.89      0.90      4070
       I-geo       0.80      0.74      0.77      1481
       I-gpe       1.00      0.35      0.52        37
       I-org       0.79      0.68      0.73      3385
       I-per       0.82      0.93      0.87      3425
       I-tim       0.87      0.72      0.79      1346
           O       0.99      0.99      0.99    176740

    accuracy                           0.97    208214
   macro avg       0.87      0.79      0.81    208214
weighted avg       0.97      0.97      0.97    208214



However, the accuracy metric is misleading, because most labels are "outside" (O), even after omitting predictions on the [PAD] tokens. What matters is the precision, recall and F1-score of the individual tags. The cell below computes them per tag, at the token level, with scikit-learn's `classification_report`; the seqeval Python library is an alternative that scores whole entity spans instead:

In [104]:
# Compute the classification report
from sklearn.metrics import classification_report
print(classification_report(true_labels, pred_labels))

              precision    recall  f1-score   support

       B-geo       0.86      0.89      0.88      7378
       B-gpe       0.97      0.93      0.95      3021
       B-org       0.75      0.70      0.73      3964
       B-per       0.82      0.85      0.84      3367
       B-tim       0.91      0.89      0.90      4070
       I-geo       0.80      0.74      0.77      1481
       I-gpe       1.00      0.35      0.52        37
       I-org       0.79      0.68      0.73      3385
       I-per       0.82      0.93      0.87      3425
       I-tim       0.87      0.72      0.79      1346
           O       0.99      0.99      0.99    176740

    accuracy                           0.97    208214
   macro avg       0.87      0.79      0.81    208214
weighted avg       0.97      0.97      0.97    208214
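
To see why accuracy alone says little here, the short sketch below recomputes two numbers from the report above: the accuracy of a trivial model that labels every token "O", and the macro F1 over entity tags only. The support and F1 values are copied from the report (F1 is rounded to two decimals there, so the macro value is approximate); the variable names are just illustrative.

```python
# Per-tag support (token counts) copied from the classification report above.
support = {
    "B-geo": 7378, "B-gpe": 3021, "B-org": 3964, "B-per": 3367, "B-tim": 4070,
    "I-geo": 1481, "I-gpe": 37, "I-org": 3385, "I-per": 3425, "I-tim": 1346,
    "O": 176740,
}

# Per-tag F1 scores from the same report (rounded to two decimals).
f1 = {
    "B-geo": 0.88, "B-gpe": 0.95, "B-org": 0.73, "B-per": 0.84, "B-tim": 0.90,
    "I-geo": 0.77, "I-gpe": 0.52, "I-org": 0.73, "I-per": 0.87, "I-tim": 0.79,
    "O": 0.99,
}

total_tokens = sum(support.values())

# Accuracy of a trivial baseline that predicts "O" for every token.
all_o_accuracy = support["O"] / total_tokens

# Macro F1 over entity tags only, leaving out the dominant "O" class.
entity_tags = [tag for tag in f1 if tag != "O"]
entity_macro_f1 = sum(f1[tag] for tag in entity_tags) / len(entity_tags)

print(f"total tokens:              {total_tokens}")
print(f"all-O baseline accuracy:   {all_o_accuracy:.4f}")
print(f"macro F1 over entity tags: {entity_macro_f1:.3f}")
```

The all-"O" baseline already reaches roughly 0.85 accuracy without recognizing a single entity, which is why the per-tag scores are the more informative numbers.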



Performance already seems quite good, but note that we've only trained for 1 epoch. A better approach would be to evaluate on a validation set after every epoch during training, so that you can see when the model starts to overfit and keep the checkpoint that generalizes best.
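
One common way to act on validation results during training is early stopping. The sketch below is not part of this notebook: the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the loss values are all hypothetical and only illustrate the bookkeeping. In practice, `step()` would be called once per epoch with the average validation loss.

```python
class EarlyStopping:
    """Stop training once the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, one per epoch (illustrative only).
val_losses = [0.30, 0.12, 0.10, 0.11, 0.13, 0.09]

stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print("stopped at epoch:", stopped_at)
print("best validation loss:", stopper.best_loss)
```

A real training loop would also save the model weights whenever `best_loss` improves, so the best checkpoint can be restored after stopping.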

#### **Inference**

The fun part is when we can quickly test the model on new, unseen sentences.
Here, we use the prediction of the **first word piece of every word** (which is how the model was trained).

*In other words, the code below ignores cases where the predictions for different word pieces of the same word disagree.*
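
To make the first-word-piece rule concrete, here is a small sketch that works on plain Python lists rather than on model output. It assumes you have one label per word piece plus a list mapping each piece to its word index (`None` for special and padding tokens), which is the shape returned by the `word_ids()` method of fast tokenizers. The function name `aggregate_word_labels` and the `"majority"` strategy are hypothetical; the notebook itself only uses the first word piece.

```python
from collections import Counter


def aggregate_word_labels(word_ids, piece_labels, strategy="first"):
    """Collapse per-word-piece labels into one label per word.

    word_ids:     word index for every word piece, None for [CLS]/[SEP]/[PAD]
    piece_labels: predicted label for every word piece
    strategy:     "first" keeps the first piece's label (as in this notebook),
                  "majority" takes the most frequent label among the pieces
    """
    grouped = {}
    for word_id, label in zip(word_ids, piece_labels):
        if word_id is None:  # skip special and padding tokens
            continue
        grouped.setdefault(word_id, []).append(label)

    word_labels = []
    for word_id in sorted(grouped):
        labels = grouped[word_id]
        if strategy == "first":
            word_labels.append(labels[0])
        elif strategy == "majority":
            # Ties resolve to the label that appeared first among the pieces.
            word_labels.append(Counter(labels).most_common(1)[0][0])
        else:
            raise ValueError(f"unknown strategy: {strategy}")
    return word_labels


# Toy example: word 0 is split into three pieces whose predictions disagree.
word_ids = [None, 0, 0, 0, 1, None]
piece_labels = ["O", "O", "B-org", "B-org", "O", "O"]

first_labels = aggregate_word_labels(word_ids, piece_labels, strategy="first")
majority_labels = aggregate_word_labels(word_ids, piece_labels, strategy="majority")
print(first_labels)
print(majority_labels)
```

Using the first piece matches how the model was trained here, so switching strategies at inference time is a design choice to validate, not a guaranteed improvement.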

In [106]:
def predict_entities(sentence, model, tokenizer, device, max_len=128):
    """Predict the named entities in a sentence."""
    # 1. Pre-process: split the sentence into words on whitespace
    words = sentence.split()

    # 2. Encode the input
    inputs = tokenizer(
        words,
        is_split_into_words=True,  # new parameter name, replaces is_pretokenized
        return_offsets_mapping=True,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )

    # 3. Move the tensors to the device
    ids = inputs["input_ids"].to(device)
    mask = inputs["attention_mask"].to(device)

    # 4. Run the model (eval mode disables dropout)
    model.eval()
    with torch.no_grad():
        outputs = model(ids, attention_mask=mask)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]

    # 5. Take the highest-scoring label id for every token
    predictions = torch.argmax(logits, dim=-1).squeeze(0).cpu().numpy()
    offsets = inputs["offset_mapping"].squeeze(0).tolist()

    # 6. Align words and predictions
    word_predictions = []
    for pred_idx, (start, end) in zip(predictions, offsets):
        # keep only the first word piece of each word: its offset starts at 0,
        # while special and padding tokens have the empty offset (0, 0)
        if start == 0 and end != 0:
            word_predictions.append(ids_to_labels[int(pred_idx)])

    # 7. Print the results
    print("\nInput Sentence:", sentence)
    print("Tokenized Words:", words)
    print("Predictions:", word_predictions)

    # 8. Show each word with its entity tag
    print("\nEntity Recognition Results:")
    for word, pred in zip(words, word_predictions):
        if pred != "O":
            print(f"{word} ({pred})", end=" ")
        else:
            print(word, end=" ")
    print("\n")

    return word_predictions

# Usage example
sentence = "@HuggingFace is a company based in New York, but is also has employees working in Paris"
predictions = predict_entities(sentence, model, tokenizer, device, MAX_LEN)


Input Sentence: @HuggingFace is a company based in New York, but is also has employees working in Paris
Tokenized Words: ['@HuggingFace', 'is', 'a', 'company', 'based', 'in', 'New', 'York,', 'but', 'is', 'also', 'has', 'employees', 'working', 'in', 'Paris']
Predictions: ['B-org', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'I-geo', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-geo']

Entity Recognition Results:
@HuggingFace (B-org) is a company based in New (B-geo) York, (I-geo) but is also has employees working in Paris (B-geo) 
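
The output above tags each word separately, so "New" and "York," appear as two items even though they form one location. A possible post-processing step, sketched below, merges consecutive `B-`/`I-` tags of the same type into entity spans. The helper `group_entities` is hypothetical and not part of the notebook; the words and tags are copied from the output above.

```python
def group_entities(words, tags):
    """Merge word-level BIO tags into (text, entity_type) spans."""
    entities = []
    current_words, current_type = [], None

    def flush():
        # Close the span that is currently open, if any.
        if current_words:
            entities.append((" ".join(current_words), current_type))

    for word, tag in zip(words, tags):
        prefix, _, ent_type = tag.partition("-")
        if prefix == "B" or (prefix == "I" and ent_type != current_type):
            # A new entity starts (an I- tag without a matching open span
            # is treated as the start of a new entity).
            flush()
            current_words, current_type = [word], ent_type
        elif prefix == "I":
            current_words.append(word)
        else:  # "O": close any open span
            flush()
            current_words, current_type = [], None
    flush()
    return entities


words = ['@HuggingFace', 'is', 'a', 'company', 'based', 'in', 'New', 'York,',
         'but', 'is', 'also', 'has', 'employees', 'working', 'in', 'Paris']
tags = ['B-org', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'I-geo',
        'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-geo']

entities = group_entities(words, tags)
print(entities)
```

Because the sentence is split on whitespace, punctuation stays attached to the word ("York,"); stripping it would need a separate tokenization step.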



In [114]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### **Saving the model for future use**

Finally, let's save the vocabulary (.txt) file, model weights (.bin) and the model's configuration (.json) to a directory, so that both the tokenizer and model can be re-loaded using the `from_pretrained()` class method.


In [107]:
import os

directory = "./model"

# create the output directory if it does not exist yet
os.makedirs(directory, exist_ok=True)

# save vocabulary of the tokenizer
tokenizer.save_vocabulary(directory)
# save the model weights and its configuration file
model.save_pretrained(directory)
print('All files saved')
print('This tutorial is completed')

All files saved
This tutorial is completed
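
Reloading the saved files could look like the sketch below. It assumes the directory was written as in the cell above and that the tokenizer was a `BertTokenizerFast`; the helper name `load_saved_ner_model` is hypothetical. Note that `save_vocabulary` only writes the vocabulary file, so tokenizer settings such as lower-casing are not stored; calling `tokenizer.save_pretrained(directory)` instead would keep them. The import sits inside the function so that defining it does not require Transformers to be loaded.

```python
def load_saved_ner_model(directory="./model"):
    """Reload the tokenizer and fine-tuned model from `directory` (a sketch)."""
    # Imported lazily so that merely defining this helper has no heavy dependencies.
    from transformers import BertForTokenClassification, BertTokenizerFast

    # Assumption: the directory holds the vocabulary, weights and config saved above.
    tokenizer = BertTokenizerFast.from_pretrained(directory)
    model = BertForTokenClassification.from_pretrained(directory)
    model.eval()  # inference mode: disables dropout
    return tokenizer, model
```

A call such as `tokenizer, model = load_saved_ner_model("./model")` would then give objects that can be passed to `predict_entities`, after moving the model to the desired device.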


## Legacy

The following code blocks were used during the development of this notebook, but are no longer part of the main workflow.

In [113]:
def prepare_sentence(sentence, tokenizer, maxlen):
      # step 1: tokenize the sentence
      tokenized_sentence = tokenizer.tokenize(sentence)

      # step 2: add special tokens
      tokenized_sentence = ["[CLS]"] + tokenized_sentence + ["[SEP]"]

      # step 3: truncating/padding
      if (len(tokenized_sentence) > maxlen):
        # truncate
        tokenized_sentence = tokenized_sentence[:maxlen]
      else:
        # pad
        tokenized_sentence = tokenized_sentence + ['[PAD]'for _ in range(maxlen - len(tokenized_sentence))]

      # step 4: obtain the attention mask
      attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokenized_sentence]

      # step 5: convert tokens to input ids
      ids = tokenizer.convert_tokens_to_ids(tokenized_sentence)

      return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(attn_mask, dtype=torch.long),
            #'token_type_ids': torch.tensor(token_ids, dtype=torch.long),
      }

In [None]:
def tokenize_and_preserve_labels(sentence, text_labels, tokenizer):
    """
    Word piece tokenization makes it difficult to match word labels
    back up with individual word pieces. This function tokenizes each
    word one at a time so that it is easier to preserve the correct
    label for each subword. It is, of course, a bit slower in processing
    time, but it will help our model achieve higher accuracy.
    """

    tokenized_sentence = []
    labels = []

    sentence = sentence.strip()

    for word, label in zip(sentence.split(), text_labels.split(",")):

        # Tokenize the word and count # of subwords the word is broken into
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        # Add the tokenized word to the final tokenized word list
        tokenized_sentence.extend(tokenized_word)

        # Add the same label to the new list of labels `n_subwords` times
        labels.extend([label] * n_subwords)

    return tokenized_sentence, labels

In [None]:
class dataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        # step 1: tokenize (and adapt corresponding labels)
        sentence = self.data.sentence[index]
        word_labels = self.data.word_labels[index]
        tokenized_sentence, labels = tokenize_and_preserve_labels(sentence, word_labels, self.tokenizer)

        # step 2: add special tokens (and corresponding labels)
        tokenized_sentence = ["[CLS]"] + tokenized_sentence + ["[SEP]"] # add special tokens
        labels.insert(0, "O") # add outside label for [CLS] token
        labels.append("O") # add outside label for [SEP] token (insert(-1, ...) would place it before the last word's label)

        # step 3: truncating/padding
        maxlen = self.max_len

        if (len(tokenized_sentence) > maxlen):
          # truncate
          tokenized_sentence = tokenized_sentence[:maxlen]
          labels = labels[:maxlen]
        else:
          # pad
          tokenized_sentence = tokenized_sentence + ['[PAD]'for _ in range(maxlen - len(tokenized_sentence))]
          labels = labels + ["O" for _ in range(maxlen - len(labels))]

        # step 4: obtain the attention mask
        attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokenized_sentence]

        # step 5: convert tokens to input ids
        ids = self.tokenizer.convert_tokens_to_ids(tokenized_sentence)

        label_ids = [labels_to_ids[label] for label in labels]
        # the following line is deprecated
        #label_ids = [label if label != 0 else -100 for label in label_ids]

        return {
              'ids': torch.tensor(ids, dtype=torch.long),
              'mask': torch.tensor(attn_mask, dtype=torch.long),
              #'token_type_ids': torch.tensor(token_ids, dtype=torch.long),
              'targets': torch.tensor(label_ids, dtype=torch.long)
        }

    def __len__(self):
        return self.len

In [None]:
sentence = "this is a test @huggingface".strip().split()

# is_split_into_words is the current name of the former is_pretokenized argument
inputs = tokenizer(sentence, is_split_into_words=True, return_offsets_mapping=True, padding='max_length', truncation=True)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"])
token_offsets = inputs["offset_mapping"]
print(tokens)
print(token_offsets)

['[CLS]', 'this', 'is', 'a', 'test', '@', 'hugging', '##face', '[SEP]', '[PAD]', '[PAD]', '[PAD]', ... (long run of '[PAD]' entries; output truncated)

In [None]:
word = "@huggingface"

inputs = tokenizer(word, return_offsets_mapping=True, padding='max_length', truncation=True)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"])
token_offsets = inputs["offset_mapping"]
print(tokens)
print(token_offsets)

['[CLS]', '@', 'hugging', '##face', '[SEP]', '[PAD]', '[PAD]', '[PAD]', ... (long run of '[PAD]' entries; output truncated)

In [None]:
# now, use mask to determine where we should compare predictions with targets (includes [CLS] and [SEP] token predictions)
active_accuracy = mask.view(-1) == 1 # active accuracy is also of shape (batch_size * seq_len,)
targets = torch.masked_select(flattened_targets, active_accuracy)
predictions = torch.masked_select(flattened_predictions, active_accuracy)