In [1]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/root/data/transformers/model_zoo'

In [2]:
import json
import torch
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import BertTokenizer, AlbertForSequenceClassification

In [3]:
data_path = Path('../../data/tnews/')

### 1. tokenizer

In [4]:
pretrained = 'voidful/albert_chinese_tiny'
tokenizer = BertTokenizer.from_pretrained(pretrained, mirror='tuna')

### 2. preprocess data

In [5]:
label_name2id = {}
label_id2name = {}
with open(data_path / 'labels.json', 'r') as h:
    for idx, line in enumerate(h):
        items = json.loads(line)
        label_name2id[items['label_desc']] = idx
        label_id2name[idx] = items['label_desc']
print(label_name2id)
print(label_id2name)

{'news_story': 0, 'news_culture': 1, 'news_entertainment': 2, 'news_sports': 3, 'news_finance': 4, 'news_house': 5, 'news_car': 6, 'news_edu': 7, 'news_tech': 8, 'news_military': 9, 'news_travel': 10, 'news_world': 11, 'news_stock': 12, 'news_agriculture': 13, 'news_game': 14}
{0: 'news_story', 1: 'news_culture', 2: 'news_entertainment', 3: 'news_sports', 4: 'news_finance', 5: 'news_house', 6: 'news_car', 7: 'news_edu', 8: 'news_tech', 9: 'news_military', 10: 'news_travel', 11: 'news_world', 12: 'news_stock', 13: 'news_agriculture', 14: 'news_game'}


In [6]:
def load_data(fname):
    texts = []
    labels = []
    with open(fname, 'r', encoding='utf8') as h:
        for line in h:
            items = json.loads(line)
            texts.append(items['sentence'])
            labels.append(label_name2id[items['label_desc']])
    return texts, labels

In [7]:
train_txt, train_label = load_data(data_path / 'train.json')
test_txt, test_label = load_data(data_path / 'dev.json')
print('train num:{0}\ntest num:{1}'.format(len(train_txt), len(test_txt)))

train num:53360
test num:10000


### 3. load model

In [8]:
num_labels = len(label_name2id)
model = AlbertForSequenceClassification.from_pretrained(pretrained, mirror='tuna', num_labels=num_labels)
model

Some weights of the model checkpoint at voidful/albert_chinese_tiny were not used when initializing AlbertForSequenceClassification: ['predictions.bias', 'predictions.LayerNorm.weight', 'predictions.LayerNorm.bias', 'predictions.dense.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.decoder.bias']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at voidful/albert_chinese_tiny and are newly initialized: ['classifier.weight', 

AlbertForSequenceClassification(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(21128, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=312, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=312, out_features=312, bias=True)
                (key): Linear(in_features=312, out_features=312, bias=True)
                (value): Linear(in_features=312, out_features=31

### 4. build dataset

In [9]:
class TNewsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [10]:
train_enc = tokenizer(train_txt, padding=True, truncation=True, max_length=512)
test_enc = tokenizer(test_txt, padding=True, truncation=True, max_length=512)
train_dataset = TNewsDataset(train_enc, train_label)
test_dataset = TNewsDataset(test_enc, test_label)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

### 5. random predict

In [13]:
# 模型顶层全连接层权重为随机初始化，所以是分类结果是随机的
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.eval()
model.to(device)
test_cnt = 0
test_correct_cnt = 0
for batch in tqdm(test_loader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        test_correct_cnt += (output.logits.argmax(1) == labels).sum().cpu().item()
        test_cnt += len(labels)
print("test acc:{0:.4f}".format(test_correct_cnt / test_cnt))

100%|██████████| 157/157 [00:09<00:00, 15.90it/s]

test acc:0.0771





### 6. finetune

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
eval_steps = 500
num_epochs = 2
step_idx = 0
train_cnt = 0
train_correct_cnt = 0
train_loss_sum = 0.0
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = output.loss
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.cpu().item()
        train_correct_cnt += (output.logits.argmax(1) == labels).sum().cpu().item()
        train_cnt += len(labels)
        
        step_idx += 1
        if step_idx % eval_steps == 0:
            model.eval()
            test_cnt = 0
            test_correct_cnt = 0
            test_loss_sum = 0.0
            for batch in test_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                with torch.no_grad():
                    output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = output.loss
                    test_loss_sum += loss.cpu().item()
                    test_correct_cnt += (output.logits.argmax(1) == labels).sum().cpu().item()
                    test_cnt += len(labels)
            print("epoch {0}, step {1}, train loss:{2:.6f}, train acc:{3:.4f}, test loss:{4:.6f}, test acc:{5:.4f}".format(
                epoch,
                step_idx,
                train_loss_sum / train_cnt,
                train_correct_cnt / train_cnt,
                test_loss_sum / test_cnt,
                test_correct_cnt / test_cnt))
            train_cnt = 0
            train_correct_cnt = 0
            train_loss_sum = 0.0
            model.train()

 60%|██████    | 501/834 [01:49<12:31,  2.26s/it]

epoch 0, step 500, train loss:0.027055, train acc:0.4501, test loss:0.024696, test acc:0.4904


100%|██████████| 834/834 [02:54<00:00,  4.77it/s]
 20%|██        | 167/834 [00:42<25:07,  2.26s/it]

epoch 1, step 1000, train loss:0.023373, train acc:0.5155, test loss:0.024123, test acc:0.5016


 80%|███████▉  | 667/834 [02:31<06:17,  2.26s/it]

epoch 1, step 1500, train loss:0.022137, train acc:0.5321, test loss:0.023908, test acc:0.5110


100%|██████████| 834/834 [03:04<00:00,  4.52it/s]
