In [1]:
import os
# Expose only the GPU with index 1 to this process. This runs in the first
# cell, before any CUDA-using library is imported further down.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [2]:
from dataclasses import dataclass, field
import json
from pathlib import Path
from typing import Union, List, Optional
from itertools import chain

from datasets import Dataset
from transformers import (
    AutoTokenizer
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.tokenization_utils import PaddingStrategy

# Directory that holds context.json and the train/valid/test splits,
# resolved relative to the notebook's working directory.
data_dir = Path("..", "data")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Tokenizer for the Chinese BERT checkpoint; `use_fast=True` requests the
# fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-chinese",
    use_fast=True,
    # `revision` is the keyword `from_pretrained` uses to pin the hub
    # revision (the AutoConfig call in this notebook passes the same one).
    # The previous `model_revision=` is not a `from_pretrained` argument, so
    # it was only forwarded as an extra tokenizer init kwarg and pinned nothing.
    revision="main",
)

In [4]:
@dataclass
class DataManager:
    """Load the paragraph corpus and the data splits, tokenized for multiple choice.

    Each example pairs one question with ``num_para`` candidate paragraphs,
    stored as integer indices into the context list. ``preprocess`` tokenizes
    every (question, paragraph) pair and, when the ``rel_col`` column is
    present, stores the position of the relevant paragraph among the
    candidates under ``label_col``.

    The datasets are built eagerly in ``__post_init__`` for every split whose
    path is given; splits without a path stay ``None``.
    """

    # Tokenizer used to encode (question, paragraph) pairs.
    tokenizer: PreTrainedTokenizerBase
    # JSON file holding the list of paragraph texts.
    context_path: Union[str, Path]
    # Optional JSON files for each split; a split is built only when its path is set.
    train_path: Optional[Union[str, Path]] = None
    valid_path: Optional[Union[str, Path]] = None
    test_path: Optional[Union[str, Path]] = None
    # Upper bound on tokens per pair; clamped to the tokenizer's limit in __post_init__.
    max_seq_length: int = 1024
    # Column holding the question text.
    que_col: str = "question"
    # Column holding the list of candidate paragraph indices.
    para_col: str = "paragraphs"
    # Number of candidate paragraphs every example must have.
    num_para: int = 4
    # Column holding the index of the relevant paragraph (may be absent).
    rel_col: str = "relevant"
    # Name of the label column produced by `preprocess`.
    label_col: str = "labels"

    # Populated in __post_init__, not accepted by the constructor.
    context: List[str] = field(init=False)
    train_dataset: Optional[Dataset] = field(default=None, init=False)
    valid_dataset: Optional[Dataset] = field(default=None, init=False)
    test_dataset: Optional[Dataset] = field(default=None, init=False)

    def __post_init__(self):
        """Clamp the sequence length, load the corpus and build the requested splits."""
        # Never ask for more tokens than the tokenizer reports as its maximum.
        self.max_seq_length = min(self.tokenizer.model_max_length, self.max_seq_length)
        self.context = self.load_json(self.context_path)
        if self.train_path is not None:
            self.train_dataset = self.build_dataset(self.train_path)
        if self.valid_path is not None:
            self.valid_dataset = self.build_dataset(self.valid_path)
        if self.test_path is not None:
            self.test_dataset = self.build_dataset(self.test_path)

    def load_json(self, file_path: Union[str, Path]) -> List:
        """Read and return the parsed content of a UTF-8 encoded JSON file."""
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def preprocess(self, instances):
        """Tokenize a batch of examples into per-example groups of ``num_para`` pairs.

        Args:
            instances: A batch in column form (column name -> list of values),
                as passed by ``Dataset.map(batched=True)``.

        Returns:
            A dict with one entry per tokenizer output key, each a list (one
            item per example) of ``num_para`` encodings, plus ``label_col``
            when ``rel_col`` is present in the batch.

        Raises:
            ValueError: If an example does not have exactly ``num_para``
                candidate paragraphs, or if its relevant paragraph is not one
                of its candidates.
        """
        candidate_lists = instances[self.para_col]

        # Questions are repeated `num_para` times and the tokenizer output is
        # regrouped in chunks of `num_para`; any other candidate count would
        # pair questions with the wrong paragraphs, so fail loudly instead.
        for candidate_ids in candidate_lists:
            if len(candidate_ids) != self.num_para:
                raise ValueError(
                    f"Expected {self.num_para} paragraphs per example, "
                    f"got {len(candidate_ids)}"
                )

        # Flattened so that pair i of example j sits at j * num_para + i.
        questions = [
            question
            for question in instances[self.que_col]
            for _ in range(self.num_para)
        ]
        num_context = len(self.context)
        paragraphs = [
            # Out-of-range ids map to an empty paragraph. The `0 <=` bound
            # keeps negative ids from wrapping around to the end of the corpus.
            self.context[idx] if 0 <= idx < num_context else ""
            for candidate_ids in candidate_lists
            for idx in candidate_ids
        ]

        # Tokenize; padding is deferred to the data collator.
        tokenized_examples = self.tokenizer(
            questions,
            paragraphs,
            truncation=True,
            max_length=self.max_seq_length,
            padding=False
        )

        # Un-flatten: regroup every `num_para` consecutive encodings per example.
        result = {
            k: [
                v[i: i+self.num_para]
                for i in range(0, len(v), self.num_para)
            ]
            for k, v in tokenized_examples.items()
        }

        if self.rel_col in instances:
            # The label is the position of the relevant paragraph among the candidates.
            result[self.label_col] = [
                candidate_ids.index(relevant)
                for candidate_ids, relevant in zip(candidate_lists, instances[self.rel_col])
            ]

        return result


    def build_dataset(self, file_path: Union[str, Path]) -> Dataset:
        """Load a split from JSON and return it tokenized by ``preprocess``.

        The file is expected to hold a list of records; the column names are
        taken from the first record.

        Raises:
            ValueError: If the file contains no records.
        """
        data = self.load_json(file_path)
        if not data:
            # Without a first record there are no column names to build from.
            raise ValueError(f"No examples found in {file_path}")
        keys = data[0].keys()
        # Convert the list of records into column form.
        dataset = Dataset.from_dict({
            key: [item[key] for item in data]
            for key in keys
        })
        return dataset.map(
            self.preprocess,
            batched=True,
            num_proc=None,
            load_from_cache_file=True,
        )


In [6]:
# Build the corpus plus the train and validation splits; the test split is
# left disabled.
data_manager = DataManager(
    tokenizer=tokenizer,
    context_path=data_dir.joinpath("context.json"),
    train_path=data_dir.joinpath("train.json"),
    valid_path=data_dir.joinpath("valid.json"),
    # test_path=data_dir.joinpath("test.json"),
)

# Report how many paragraphs and examples were loaded.
print(f"Context Length: {len(data_manager.context)}")
print(f"Train Length: {data_manager.train_dataset.num_rows}")
print(f"Valid Length: {data_manager.valid_dataset.num_rows}")
# print(f"Test Length: {data_manager.test_dataset.num_rows}")

100%|██████████| 22/22 [00:25<00:00,  1.14s/ba]
100%|██████████| 4/4 [00:03<00:00,  1.14ba/s]

Context Length: 9013
Train Length: 21714
Valid Length: 3009





In [7]:
import torch

@dataclass
class DataCollator:
    """Pad multiple-choice features into ``(batch, num_choices, seq_len)`` tensors.

    Each incoming feature holds, per tokenizer key, a list with one encoding
    per choice. The choices of all features are flattened, padded together
    with ``tokenizer.pad`` and reshaped back. Labels are optional, so the
    collator also works on unlabeled data.
    """

    # Tokenizer whose `pad` method performs the padding.
    tokenizer: PreTrainedTokenizerBase
    # Forwarded to `tokenizer.pad` as `padding`.
    padding: Union[bool, str, PaddingStrategy] = True
    # Forwarded to `tokenizer.pad` as `max_length`.
    max_length: Optional[int] = None
    # Not read by `__call__`; kept so existing constructor calls keep working.
    para_col: str = "paragraphs"
    # Key under which labels are read from the features and written to the batch.
    label_col: str = "labels"

    def __call__(self, features):
        """Collate a list of feature dicts into a batch of tensors.

        Args:
            features: Non-empty list of dicts mapping each tokenizer key to a
                list of per-choice encodings, optionally with ``label_col``.

        Returns:
            A dict of tensors viewed as ``(batch_size, num_choices, -1)``,
            plus an int64 ``label_col`` tensor when the features carry labels.
        """
        # Labels are absent at inference time. Read them without `pop` so the
        # caller's feature dicts are left untouched.
        has_labels = self.label_col in features[0]
        labels = [feature[self.label_col] for feature in features] if has_labels else None

        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # Flatten to one dict per (example, choice), keeping the example-major
        # order that the reshape below relies on.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != self.label_col}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=None,
            return_tensors="pt",
        )

        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        if labels is not None:
            batch[self.label_col] = torch.tensor(labels, dtype=torch.int64)
        return batch

In [8]:
# Collator with its default settings; only the tokenizer is supplied.
data_collator = DataCollator(tokenizer=tokenizer)

In [9]:
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice
)

# Configuration of the checkpoint, taken from the "main" revision.
config = AutoConfig.from_pretrained(
    "bert-base-chinese",
    revision="main",
    cache_dir=None,
    use_auth_token=None,
)

# The model is instantiated from the configuration object alone.
# NOTE(review): per the transformers docs `from_config` does not load
# pretrained weights, which would match the near-chance evaluation recorded
# at the end of this notebook. Confirm this is the intended setup.
model = AutoModelForMultipleChoice.from_config(config)

# Alternative kept for reference: start from the pretrained checkpoint.
# model = AutoModelForMultipleChoice.from_pretrained(
#     "bert-base-chinese",
#     from_tf=False,
#     config=config,
#     cache_dir=None,
#     revision="main",
#     use_auth_token=None,
# )

In [10]:
import numpy as np
def compute_metrics(eval_predictions):
    """Return the accuracy of the arg-max choice against the gold labels.

    Args:
        eval_predictions: A pair ``(predictions, label_ids)``; the arg-max is
            taken over axis 1 of ``predictions``.

    Returns:
        ``{"accuracy": <float>}``: the fraction of rows whose arg-max equals
        the label.
    """
    scores, gold = eval_predictions
    chosen = np.argmax(scores, axis=1)
    hits = (chosen == gold).astype(np.float32)
    return {"accuracy": hits.mean().item()}

In [13]:
from transformers import (
    TrainingArguments,
    Trainer
)

# Wire the model, the tokenized splits, the collator and the accuracy metric
# into a single Trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=data_manager.train_dataset,
    eval_dataset=data_manager.valid_dataset,
    args=TrainingArguments(
        output_dir="./tmp",
        seed=1123,
        learning_rate=5e-5,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        # no_cuda=True
    ),
)

In [14]:
# Evaluate without arguments; the log below reports 3009 examples, the size
# of the validation split given to the Trainer as `eval_dataset`.
# No `trainer.train()` call appears before this cell, so the metrics describe
# the model as constructed.
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `BertForMultipleChoice.forward` and have been ignored: answer, relevant, id, question, paragraphs. If answer, relevant, id, question, paragraphs are not expected by `BertForMultipleChoice.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3009
  Batch size = 64


{'eval_loss': 1.3866825103759766,
 'eval_accuracy': 0.21103356778621674,
 'eval_runtime': 73.7698,
 'eval_samples_per_second': 40.789,
 'eval_steps_per_second': 0.651}