In [1]:
from abc import ABC

import os
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader

In [2]:
train_fullname = './data/train_data_public.csv'
test_fullname = './data/test_public.csv'

In [3]:
class AnnoDataSet(Dataset, ABC):
    def __init__(self,texts: list[str], labels: list[str] = None):
        super(Dataset, self).__init__()
        self.texts = texts
        self.labels = labels

    def __getitem__(self, item):
        text = self.texts[item]
        label = self.labels[item] if self.labels is not None else None
        return text, label

    def __len__(self):
        return len(self.labels)

In [4]:
def get_dataset() -> (AnnoDataSet, AnnoDataSet, AnnoDataSet):
    train_set, valid_set, test_set = None, None, None

    train_raw = pd.read_csv(train_fullname)
    train_texts: list[str] = train_raw['text'].to_list()
    train_labels: list[str] = train_raw['BIO_anno'].to_list()

    train_set = AnnoDataSet(train_texts, train_labels)

    test_raw = pd.read_csv(test_fullname)
    test_texts: list[str] = test_raw['text'].to_list()
    test_labels = None

    test_set = AnnoDataSet(test_texts, test_labels)

    return train_set, valid_set, test_set

In [5]:
# 编码labels
class SimpleVocab:
    def __init__(self):
        labels = pd.read_csv(train_fullname)['BIO_anno'].to_list()
        labels = [label_line.split() for label_line in labels]
        all_tokens = [token for label_line in labels for token in label_line] # 这个写法每次看一遍都觉得震撼
        self.token_dict = {}
        self.token_array = []

        # zero for none
        self.token_dict['None'] = len(self.token_array)
        self.token_array.append('None')

        for token in all_tokens:
            if token not in self.token_dict:
                self.token_dict[token] = len(self.token_array)
                self.token_array.append(token)

    def __call__(self, tokens):
        assert isinstance(tokens, (list, tuple, str))
        if isinstance(tokens, (list, tuple)):
            return [self(token) for token in tokens]
        else:
            return self.token_dict[tokens]

    def __len__(self):
        return len(self.token_array)

    def to_tokens(self, ids):
        assert isinstance(ids, (list, tuple, int))
        if isinstance(ids, (list, tuple)):
            return [self.to_tokens(idx) for idx in ids]
        else:
            return self.token_array[ids]

    def get_none(self):
        return self('None')

In [20]:
def collect(sample:list[tuple[str, str]], tokenizer, label_vocab) -> torch.Tensor:
    texts, labels = zip(*sample)
    tokens = [list(text) for text in texts]
    ret = tokenizer(tokens, padding=True, is_split_into_words=True, return_tensors='pt')

    if labels[0] is None:
        return ret


    labels = [label_line.split() for label_line in labels]

    # 对齐label
    none_anno = label_vocab.get_none()
    max_len = ret['input_ids'][0].shape[-1]
    print(max_len)
    print(len(ret['input_ids']))
    for idx in range(len(labels)):
        if len(labels[idx]) + 2 >= 89 :print(labels[idx])
        assert ret['attention_mask'][idx][len(labels[idx]) + 2] == 0  # 有效长度应该是label长度+2
        labels[idx] = [none_anno] + labels[idx] + [none_anno]
        labels[idx] = labels[idx] + [none_anno] * (max_len - len(labels[idx]))
        assert len(labels[idx]) == max_len

    ret['labels'] = torch.tensor(label_vocab(labels), dtype=torch.int32)
    return ret


In [7]:
model_name = 'bert-base-chinese'
bert_tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', model_name)

train_dataset, valid_dataset, test_dataset = get_dataset()

Using cache found in C:\Users\Justi/.cache\torch\hub\huggingface_pytorch-transformers_main


In [21]:
batch_size = 128
train_iter = DataLoader(
    train_dataset,
    batch_size,
    collate_fn=lambda x: collect(x, bert_tokenizer, SimpleVocab()),
    num_workers=0
)
next(iter(train_iter))

89
128
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-COMMENTS_ADJ', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-COMMENTS_N', 'I-COMMENTS_N', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-COMMENTS_N', 'I-COMMENTS_N', 'I-COMMENTS_N', 'B-COMMENTS_N', 'I-COMMENTS_N', 'O', 'O', 'O', 'O', 'O', 'O', 'B-COMMENTS_N', 'I-COMMENTS_N', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-COMMENTS_N', 'I-COMMENTS_N', 'O', 'O', 'O', 'O', 'B-COMMENTS_ADJ', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


IndexError: index 89 is out of bounds for dimension 0 with size 89

In [None]:
train_raw = pd.read_csv(train_fullname)