In [None]:
from functools import partial
import json
from multiprocessing import Pool
import os
import re

import pandas as pd

In [None]:
# Regex used to strip URLs out of the abstract sentences.
# The prefix must be a (non-capturing) alternation between "http://", "https://"
# and "www.". Written inside square brackets, "(https?:\/\/)|(www\.)" is only a
# set of single characters, which makes ordinary text such as "sp.value" match
# as if it were a link.
pattern = r"(?:https?:\/\/|www\.)[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)"


def get_dataset(data_path):
    """ Read a CSV file and turn every row into a plain dict.

    Args:
        data_path (str): Path to the CSV file; every column is read as str.
    Return:
        output (list of dict): one dict per row, in file order.
            When the row has a "Task 1" column:
                {
                    'Abstract': [sentence, sentence, ...],
                    'Label': [[0,0,0,1,1,0],[1,0,0,0,1,0],...]
                }
            Otherwise (test data):
                {
                    'Abstract': [sentence, sentence, ...],
                    'Id': value of the "Id" column
                }
    """
    frame = pd.read_csv(data_path, dtype=str)

    records = []
    for _, row in frame.iterrows():
        # Sentences are separated by "$$$"; whatever the module-level `pattern`
        # regex matches inside a sentence is replaced by a single space.
        sentences = [re.sub(pattern, " ", piece) for piece in row["Abstract"].split("$$$")]
        record = {"Abstract": sentences}

        if "Task 1" in row:
            # One space-separated label group per sentence.
            record["Label"] = [label_to_onehot(group) for group in row["Task 1"].split(" ")]
        else:
            record["Id"] = row["Id"]

        records.append(record)
    return records


def label_to_onehot(labels):
    """ Encode a "/"-separated label string as a 6-slot multi-hot list.
        Args:
            labels (string): one or more category names joined by "/",
                e.g. "METHODS/RESULTS".
        Return:
            outputs (list of int): length-6 list holding 1 in the slot of
                every named category and 0 elsewhere.
        Raises:
            KeyError: if a name is not one of the six known categories.
    """
    # Slot order defines the meaning of each position in the output.
    slots = ("BACKGROUND", "OBJECTIVES", "METHODS", "RESULTS", "CONCLUSIONS", "OTHERS")
    position = {name: index for index, name in enumerate(slots)}
    hot = {position[name] for name in labels.split("/")}
    return [int(index in hot) for index in range(len(slots))]

In [None]:
from tqdm import tqdm
import torch
from torch.utils.data import (
    DataLoader,
    RandomSampler,
    SequentialSampler,
    TensorDataset
)
from transformers import (
    AlbertTokenizer,
    BertTokenizer, 
    RobertaTokenizer, 
    XLNetTokenizer
)

In [None]:
def convert_examples_to_features(
    examples,
    max_length,
    tokenizer,
    pad_token_segment_id=0,
    pad_on_left=False,
    pad_token=0,
    mask_padding_with_zero=True,
    do_lower=False,
    is_test_data=False,
    extra_datas=None,
    model_mode="bert"
):
    """
    Encode (sentence, whole abstract) pairs into fixed-length model inputs.

    Every sentence is paired with the text of the abstract it belongs to, run
    through ``tokenizer.encode_plus`` and padded to ``max_length``.

    Args:
        examples (list of dict): items holding an "Abstract" list of sentence
            strings plus either "Label" (one label per sentence) or, when
            ``is_test_data`` is set, "Id".
        max_length (int): length every returned sequence is padded to; it is
            also forwarded to the tokenizer as ``max_length``.
        tokenizer: object exposing ``encode_plus(text, text_pair,
            add_special_tokens=..., max_length=...)`` that returns a mapping
            with "input_ids" and, optionally, "token_type_ids" (both lists).
        pad_token_segment_id (int): currently ignored, see the note in the body.
        pad_on_left (bool): currently ignored, see the note in the body.
        pad_token (int): id used to pad ``input_ids``.
        mask_padding_with_zero (bool): if True the attention mask is 1 for real
            tokens and 0 for padding; if False the two values are swapped.
        do_lower (bool): lowercase the text before tokenizing.
        is_test_data (bool): emit "abstract_id"/"seq_id" instead of "label".
        extra_datas (list of dict, optional): already paired items with the
            keys "Sent", "Abstract" and "Label"; they are encoded first.
        model_mode (str): "xlnet" pads on the left with segment id 4, any
            other value pads on the right with segment id 0.

    Return:
        list of dict: one entry per sentence. Each has "net_inputs" (a dict of
        "input_ids", "attention_mask", "token_type_ids", all of length
        ``max_length``) plus "label", or "abstract_id" and a 1-based "seq_id"
        for test data.

    Raises:
        AssertionError: if an example has no sentences, lacks the required
            "Label"/"Id" entry, or an encoded sequence does not end up exactly
            ``max_length`` long (e.g. the tokenizer returned more ids).
    """
    # NOTE(review): both settings are derived from model_mode, so the values
    # passed by the caller are discarded. Kept as-is to preserve behaviour.
    pad_on_left = model_mode == "xlnet"  # pad on the left for xlnet
    pad_token_segment_id = 4 if pad_on_left else 0

    real_flag, pad_flag = (1, 0) if mask_padding_with_zero else (0, 1)

    def _pad(values, filler, count):
        """Add ``count`` copies of ``filler`` on the configured side."""
        padding = [filler] * count
        return padding + values if pad_on_left else values + padding

    def process(sentence, abstract):
        """Tokenize one pair and pad the three sequences to ``max_length``."""
        inputs = tokenizer.encode_plus(
            sentence,
            abstract,
            add_special_tokens=True,
            max_length=max_length,
        )
        input_ids = inputs["input_ids"]
        # Not every tokenizer returns segment ids; indexing the key directly
        # raised KeyError in that case. Fall back to a single segment (zeros).
        token_type_ids = inputs.get("token_type_ids")
        if token_type_ids is None:
            token_type_ids = [0] * len(input_ids)
        attention_mask = [real_flag] * len(input_ids)

        # A negative count (sequence longer than max_length) adds nothing and
        # is reported by the assertions below.
        padding_length = max_length - len(input_ids)
        input_ids = _pad(input_ids, pad_token, padding_length)
        attention_mask = _pad(attention_mask, pad_flag, padding_length)
        token_type_ids = _pad(token_type_ids, pad_token_segment_id, padding_length)

        assert len(input_ids) == max_length, "input_ids length != max_length"
        assert len(attention_mask) == max_length, "attention_mask length != max_length"
        assert len(token_type_ids) == max_length, "token_type_ids length != max_length"

        return input_ids, attention_mask, token_type_ids

    def _net_inputs(sentence, abstract):
        """Wrap the encoded pair in the dict layout stored under "net_inputs"."""
        input_ids, attention_mask, token_type_ids = process(sentence, abstract)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids
        }

    features = []

    # Extra items already come as (sentence, abstract, label) triples.
    if extra_datas:
        for extra_data in tqdm(extra_datas, desc="convert examples to features"):
            abstract = extra_data.get("Abstract", [])
            sentence = extra_data.get("Sent", [])
            label = extra_data.get("Label", [])

            if do_lower:
                abstract = abstract.lower()
                sentence = sentence.lower()

            features.append(
                {
                    "net_inputs": _net_inputs(sentence, abstract),
                    "label": label
                }
            )

    for example in tqdm(examples, desc="convert examples to features"):
        sentences = example.get("Abstract", [])
        assert len(sentences) > 0, "no abstract data!"
        if do_lower:
            sentences = [sent.lower() for sent in sentences]

        # The context is all sentences concatenated without a separator.
        abstract = "".join(sentences)

        if is_test_data:
            abstract_id = example.get("Id", None)
            assert abstract_id is not None, "No abstract_id!"

            for seq_id, sentence in enumerate(sentences, start=1):  # start from 1
                features.append(
                    {
                        "abstract_id": abstract_id,
                        "seq_id": seq_id,
                        "net_inputs": _net_inputs(sentence, abstract),
                    }
                )
        else:
            labels = example.get("Label", None)
            assert labels is not None, "No label data!"

            for sentence, label in zip(sentences, labels):
                features.append(
                    {
                        "net_inputs": _net_inputs(sentence, abstract),
                        "label": label
                    }
                )
    return features

In [None]:
def convert_data_to_feature(k_id, data_dir, pretrained_weights, save_dir="data_bin", do_lower=False, max_length=512):
    """Tokenize the train and valid CSV of split ``k_id`` and store the features.

    Args:
        k_id: split identifier, substituted into "trainset_{}.csv" and
            "validset_{}.csv" and appended to the saved tag.
        data_dir (str): directory containing the two CSV files.
        pretrained_weights (str): name handed to ``from_pretrained``; the text
            before its first "-" selects the tokenizer class and the model mode.
        save_dir (str): directory passed on to ``save_feature_to_bin``.
        do_lower (bool): forwarded to ``convert_examples_to_features``.
        max_length (int): forwarded to ``convert_examples_to_features``.

    Raises:
        KeyError: if the model family is not albert, bert, roberta or xlnet.
    """
    tokenizer_classes = {
        "albert": AlbertTokenizer,
        "bert": BertTokenizer,
        "roberta": RobertaTokenizer,
        "xlnet": XLNetTokenizer,
    }

    # load tokenizer; the family name doubles as the model mode
    model_tag = pretrained_weights.split("-")[0]
    tokenizer = tokenizer_classes[model_tag].from_pretrained(pretrained_weights)

    split_files = (("train", "trainset_{}.csv"), ("valid", "validset_{}.csv"))

    # load both datasets first ...
    examples_by_split = [
        (split, get_dataset(os.path.join(data_dir, template.format(k_id))))
        for split, template in split_files
    ]

    # ... then encode both ...
    features_by_split = [
        (
            split,
            convert_examples_to_features(
                examples=examples,
                max_length=max_length,
                tokenizer=tokenizer,
                model_mode=model_tag,
                do_lower=do_lower,
            ),
        )
        for split, examples in examples_by_split
    ]

    # ... and only then write anything to disk.
    tag = pretrained_weights + "_{}".format(k_id)
    for split, features in features_by_split:
        save_feature_to_bin(features, save_dir, split=split, tag=tag, version="v2")


def save_feature_to_bin(features, save_dir, split="train", tag="", version="v1"):
    """Serialize ``features`` with ``torch.save``.

    The file is written to ``<save_dir>/<split>_<tag>_<version>.pt``.

    Args:
        features: object to serialize.
        save_dir (str): target directory; created if it does not exist yet.
        split (str): first component of the file name.
        tag (str): second component of the file name.
        version (str): third component of the file name.
    """
    # torch.save opens the target path directly and fails when the directory
    # is missing. exist_ok=True keeps this safe when several worker processes
    # reach this point at the same time. An empty save_dir means "current
    # directory" and must not be passed to makedirs.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(features, os.path.join(save_dir, "{}_{}_{}.pt".format(split, tag, version)))

In [None]:
data_dir = "datasets/k10"
pretrained_weights = "roberta-large"
# Number of dataset splits to process (k_id runs over 0 .. n_splits - 1);
# one worker process is started per split.
n_splits = 10

# Fix everything except k_id so the pool can map over the split ids.
func = partial(
    convert_data_to_feature,
    data_dir=data_dir,
    pretrained_weights=pretrained_weights,
)
with Pool(n_splits) as pool:
    pool.map(func, range(n_splits))