In [1]:
import torch
import datasets
import os

BASE_DIR = os.getcwd()
print(BASE_DIR)


/Users/cxq/code/py_cloud/huggingface_exp


In [2]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = datasets.load_dataset(
            "json",
            data_files=f"{BASE_DIR}/data/ChnSentiCorp.json",
            split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text, label


In [3]:
ds_data = MyDataset("train")
print(len(ds_data))
print(ds_data[0])

Found cached dataset json (/Users/cxq/.cache/huggingface/datasets/json/default-abe2463bdbc7aa6b/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


9600
('选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般', 1)


In [4]:
from transformers import BertTokenizer
token = BertTokenizer.from_pretrained('bert-base-chinese')

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sents,
        truncation=True,
        padding='max_length',
        max_length=500,
        return_tensors='pt',
        return_length=True
    )

    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels

In [5]:
ds_loader = torch.utils.data.DataLoader(
    dataset=ds_data,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True
)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(ds_loader):
    print(len(ds_loader))
    print(input_ids.shape)
    print(attention_mask.shape)
    print(token_type_ids.shape)
    print(labels)
    break

600
torch.Size([16, 500])
torch.Size([16, 500])
torch.Size([16, 500])
tensor([1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0])


In [7]:
from transformers import BertModel


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = BertModel.from_pretrained('bert-base-chinese')
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = self.pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )

        out = self.fc(out.last_hidden_state[:, 0])

        return out.softmax(dim=1)

In [8]:
from transformers import AdamW

model = MyModel()
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(ds_loader):
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    )

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)

        print(i, loss.item(), accuracy)

    if i == 30:
        break

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


0 0.7347027063369751 0.3125
5 0.6768629550933838 0.5
10 0.6672183871269226 0.5625
15 0.6262333393096924 0.75
20 0.6419984102249146 0.75
25 0.6250054240226746 0.8125
30 0.6562329530715942 0.625


In [9]:
def save_model(net, opt, ckpt_file):
    torch.save({
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': opt.state_dict()
    }, ckpt_file)


def load_model(net, opt, ckpt_file):
    ckpt = torch.load(ckpt_file)
    net.load_state_dict(ckpt['model_state_dict'])
    opt.load_state_dict(ckpt['optimizer_state_dict'])


def evaluation(net, ds_loader):
    net.eval()
    correct = 0
    total = 0

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(ds_loader):
        if i == 5:
            break

        with torch.no_grad():
            out = net(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )

        print(i, out, out.argmax(dim=1))

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)


def inference(ds_loader):
    net = MyModel()
    opt = AdamW(model.parameters(), lr=5e-4)
    load_model(net, opt, f"{BASE_DIR}/logs/checkpoint/bert_cls.bin")

    net.eval()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(ds_loader):
        if i == 5:
            break

        with torch.no_grad():
            out = net(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )

        print(i, out, out.argmax(dim=1))

In [10]:
save_model(model, optimizer, f"{BASE_DIR}/logs/checkpoint/bert_cls.bin")

In [13]:
ts0_loader = torch.utils.data.DataLoader(
    dataset=ds_data,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True
)
evaluation(model, ts0_loader)
print("-" * 72)
ts1_loader = torch.utils.data.DataLoader(
    dataset=ds_data,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True
)
inference(ts1_loader)

0 tensor([[0.2561, 0.7439]]) tensor([1])
1 tensor([[0.2111, 0.7889]]) tensor([1])
2 tensor([[0.8142, 0.1858]]) tensor([0])
3 tensor([[0.8196, 0.1804]]) tensor([0])
4 tensor([[0.7313, 0.2687]]) tensor([0])
0.8
------------------------------------------------------------------------


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


0 tensor([[0.2561, 0.7439]]) tensor([1])
1 tensor([[0.2111, 0.7889]]) tensor([1])
2 tensor([[0.8142, 0.1858]]) tensor([0])
3 tensor([[0.8196, 0.1804]]) tensor([0])
4 tensor([[0.7313, 0.2687]]) tensor([0])


In [21]:
torch.save(model, f"{BASE_DIR}/logs/checkpoint/bert_cls2.bin")

In [22]:
# MyModel需要先被import
net2 = torch.load(f"{BASE_DIR}/logs/checkpoint/bert_cls2.bin") 

In [23]:
ts2_loader = torch.utils.data.DataLoader(
    dataset=ds_data,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True
)

net2.eval()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(ts2_loader):
    if i == 5:
        break

    with torch.no_grad():
        out = net2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

    print(i, out, out.argmax(dim=1))

0 tensor([[0.2561, 0.7439]]) tensor([1])
1 tensor([[0.2111, 0.7889]]) tensor([1])
2 tensor([[0.8142, 0.1858]]) tensor([0])
3 tensor([[0.8196, 0.1804]]) tensor([0])
4 tensor([[0.7313, 0.2687]]) tensor([0])
