In [None]:
import os
# Point the Hugging Face Transformers cache at a fixed directory.
# NOTE(review): presumably this has to be set before `transformers` is imported
# (it is imported in the next cell) for it to take effect - confirm.
# NOTE(review): this is an absolute, machine-specific path; adjust per environment.
os.environ['TRANSFORMERS_CACHE'] = '/root/data/transformers/model_zoo'

In [None]:
import json
import torch
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import BertTokenizer, AlbertForSequenceClassification

In [None]:
# Directory holding labels.json, train.json and dev.json (read by later cells).
# The path is relative, so it is resolved against the notebook's working directory.
data_path = Path('../../data/tnews/')

### 1. tokenizer

In [None]:
# Identifier of the pretrained checkpoint; reused later when loading the model.
pretrained = 'voidful/albert_chinese_tiny'
# The tokenizer class is BertTokenizer even though the model class is ALBERT.
# NOTE(review): presumably this checkpoint ships a BERT-style vocabulary - confirm.
# NOTE(review): `mirror='tuna'` looks like a download-mirror selector; whether the
# installed transformers version still accepts it should be verified.
tokenizer = BertTokenizer.from_pretrained(pretrained, mirror='tuna')

### 2. preprocess data

In [None]:
# Build the label-name <-> integer-id lookup tables from labels.json.
# The file is read as JSON lines; the id of a label is the zero-based index of
# the line it appears on, so ids are contiguous and follow file order.
label_name2id = {}
label_id2name = {}
# Decode explicitly as UTF-8 (as load_data does) instead of relying on the
# platform's locale-dependent default encoding.
with open(data_path / 'labels.json', 'r', encoding='utf8') as h:
    for idx, line in enumerate(h):
        items = json.loads(line)
        label_name2id[items['label_desc']] = idx
        label_id2name[idx] = items['label_desc']
print(label_name2id)
print(label_id2name)

In [None]:
def load_data(fname):
    """Read a JSON-lines file and return its sentences and label ids.

    Every non-blank line is parsed as a JSON object that must provide a
    ``sentence`` field and a ``label_desc`` field. The label description is
    converted to an integer id through the module-level ``label_name2id`` dict.

    Args:
        fname: Path of the file to read (anything accepted by ``open``).

    Returns:
        A ``(texts, labels)`` tuple of two lists of equal length, in file order.

    Raises:
        KeyError: If a record lacks ``sentence`` or ``label_desc``, or if its
            ``label_desc`` is not a key of ``label_name2id``.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    texts = []
    labels = []
    with open(fname, 'r', encoding='utf8') as h:
        for line in h:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no record and would make json.loads raise.
            if not line.strip():
                continue
            items = json.loads(line)
            texts.append(items['sentence'])
            labels.append(label_name2id[items['label_desc']])
    return texts, labels

In [None]:
# Load the training split, and use dev.json as the held-out "test" split
# for the rest of the notebook.
train_txt, train_label = load_data(data_path / 'train.json')
test_txt, test_label = load_data(data_path / 'dev.json')
# Report how many examples each split contains.
print('train num:{0}\ntest num:{1}'.format(len(train_txt), len(test_txt)))

### 3. load model

In [None]:
# One output class per entry in the label table built earlier.
num_labels = len(label_name2id)
# Load the checkpoint named by `pretrained` with a sequence-classification model class.
# NOTE(review): `mirror='tuna'` looks like a download-mirror selector; whether the
# installed transformers version still accepts it should be verified.
model = AlbertForSequenceClassification.from_pretrained(pretrained, mirror='tuna', num_labels=num_labels)
# Bare expression as the last line of the cell so the notebook displays the model.
model

### 4. build dataset

In [None]:
class TNewsDataset(torch.utils.data.Dataset):
    """Dataset that pairs per-sample encodings with their label ids.

    ``encodings`` is a mapping from field name to an indexable collection with
    one entry per sample; ``labels`` is an indexable collection of label ids
    with one entry per sample.
    """

    def __init__(self, encodings, labels):
        """Store the encodings mapping and the label sequence as given."""
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        The dict holds one tensor per encoding field, followed by the label
        tensor under the ``'labels'`` key.
        """
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

In [None]:
# Tokenize both splits with the same settings (padding and truncation enabled,
# at most 512 tokens per example).
tokenize_kwargs = dict(padding=True, truncation=True, max_length=512)
train_enc = tokenizer(train_txt, **tokenize_kwargs)
test_enc = tokenizer(test_txt, **tokenize_kwargs)

# Training data is reshuffled by its loader; the test loader keeps file order.
train_dataset = TNewsDataset(train_enc, train_label)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)

test_dataset = TNewsDataset(test_enc, test_label)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

### 5. random predict

In [None]:
# The fully-connected classification layer on top of the model is randomly
# initialised, so the predictions made here are effectively random.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

test_cnt = 0
test_correct_cnt = 0
# No gradients are needed for evaluation, so disable them for the whole pass.
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        predicted = output.logits.argmax(1)
        test_correct_cnt += (predicted == labels).sum().item()
        test_cnt += labels.size(0)

print("test acc:{0:.4f}".format(test_correct_cnt / test_cnt))

### 6. finetune

In [None]:
# Fine-tune every model parameter with Adam and, every `eval_steps` optimizer
# steps, report running training metrics plus a full pass over the test loader.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
eval_steps = 500
num_epochs = 2
step_idx = 0  # global optimizer-step counter; it is not reset between epochs
# Running training statistics accumulated since the previous report.
train_cnt = 0
train_correct_cnt = 0
train_loss_sum = 0.0
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = output.loss
        loss.backward()
        optimizer.step()

        # `output.loss` is already averaged over the batch. Weight it by the
        # batch size so that dividing the sum by the sample count below yields
        # a true per-sample mean (and a smaller last batch is weighted correctly).
        batch_size = len(labels)
        train_loss_sum += loss.item() * batch_size
        train_correct_cnt += (output.logits.argmax(1) == labels).sum().item()
        train_cnt += batch_size

        step_idx += 1
        if step_idx % eval_steps == 0:
            model.eval()
            test_cnt = 0
            test_correct_cnt = 0
            test_loss_sum = 0.0
            with torch.no_grad():
                # Distinct names keep this loop from overwriting the training
                # loop's `batch`, `labels`, `output` and `loss`.
                for test_batch in test_loader:
                    test_input_ids = test_batch['input_ids'].to(device)
                    test_attention_mask = test_batch['attention_mask'].to(device)
                    test_labels = test_batch['labels'].to(device)
                    test_output = model(
                        input_ids=test_input_ids,
                        attention_mask=test_attention_mask,
                        labels=test_labels,
                    )
                    test_batch_size = len(test_labels)
                    test_loss_sum += test_output.loss.item() * test_batch_size
                    test_correct_cnt += (test_output.logits.argmax(1) == test_labels).sum().item()
                    test_cnt += test_batch_size
            # Guard against an empty test loader; train_cnt is always >= 1 here.
            test_denominator = max(test_cnt, 1)
            print("epoch {0}, step {1}, train loss:{2:.6f}, train acc:{3:.4f}, test loss:{4:.6f}, test acc:{5:.4f}".format(
                epoch,
                step_idx,
                train_loss_sum / train_cnt,
                train_correct_cnt / train_cnt,
                test_loss_sum / test_denominator,
                test_correct_cnt / test_denominator))
            # Start a fresh reporting window and go back to training mode.
            train_cnt = 0
            train_correct_cnt = 0
            train_loss_sum = 0.0
            model.train()