# Create training dataset

In [1]:
from datasets import load_dataset

dataset = load_dataset("pythainlp/han-corf-dataset-v1.0")

Found cached dataset parquet (/home/poomphob/.cache/huggingface/datasets/pythainlp___parquet/pythainlp--han-corf-dataset-v1.0-cffcdc2a501e26f3/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
dataset["train"][0]["text"][0:1]

'ศ'

In [7]:
import json

def flatten_list_of_lists(lst):
    return [elem for sublst in lst for elem in sublst]

def change_index_to_word_level(example):
    tokens = []
    split_indexes = sorted(set(flatten_list_of_lists(flatten_list_of_lists(example["clusters"])) + [0, len(example["text"])]))
    cluster_map = {}
    for start_index, stop_index in zip(split_indexes, split_indexes[1:]):
        tokens.append(example["text"][start_index:stop_index])
        cluster_map[start_index] = len(tokens) - 1
    cluster_map[stop_index] = len(tokens)
    new_clusters = [[(cluster_map[start], cluster_map[stop]) for start, stop in cluster] for cluster in example["clusters"] ]
    return tokens, new_clusters


# Each line is a json object that represents a single example
# The json object contains the following fields:
#   doc_key: document key
#   sentences: list of sentences, each sentence is a list of words
#   clusters: list of clusters, each cluster is a list of mentions, each mention is a list of [sentence_idx, start_token_idx, end_token_idx]

with open("test.jsonl", 'w') as f:
    for doc_key, example in enumerate(dataset["test"]):
        tokens = []
        tokens, new_clusters = change_index_to_word_level(example)
        
        new_cluster_strings = sorted([tuple(sorted(["".join(tokens[start:stop]) for start, stop in cluster])) for cluster in new_clusters])
        old_cluster_strings = sorted([tuple(sorted(clusters_string)) for clusters_string in example["clusters_strings"]])

        assert new_cluster_strings == old_cluster_strings, f"new_cluster_strings: {new_cluster_strings}\nold_cluster_strings: {old_cluster_strings}"
        final_json = {
            "doc_key": doc_key,
            "sentences": [tokens],
            "clusters": new_clusters
        }
        f.write(json.dumps(final_json, ensure_ascii=False) + "\n")

In [None]:
dataset

In [50]:
old_cluster_strings

[('ตน',
  'ตน',
  'ตน',
  'ตน',
  'ตน',
  'ตน',
  'ตัวเอง',
  'นายอนุทิน',
  'นายอนุทิน',
  'นายอนุทิน',
  'นายอนุทิน ชาญวีรกูล',
  'ผม',
  'ผม',
  'ผม',
  'อนุทิน',
  'ไอ้หนู'),
 ('นายประเดิมชัย',
  'นายประเดิมชัย',
  'นายประเดิมชัย',
  'นายประเดิมชัย บุญช่วยเหลือ',
  'ประเดิมชัย บุญช่วยเหลือ',
  'เขา',
  'เขา',
  'เขา',
  'เขา',
  'เขา',
  'เขา',
  'เขา',
  'เขา',
  'เขา')]

In [47]:
new_cluster_strings

[('ตัวเอง', 'ศรี จันทน')]

In [29]:
example["text"]

'ศรี จันทน รองประธานสหพันธ์มวยกัมพูชา ออกมาโพสต์ผ่านเฟซบุ๊กของตัวเอง หลังจากที่ เซียน เลวี นักชกกุน แขมร์ ขวัญใจเจ้าถิ่น พ่ายน็อกยก 4 ให้กับ แซมมี่ บัญชาเมฆ นักชกไทย ในการแข่งขันกุน แขมร์ ที่เสียมราฐ ประเทศกัมพูชา เมื่อวันที่ 15 เมษายนที่ผ่านมา'

## Try to create coref dataset

In [2]:
from data import CorefDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("pythainlp/han-coref-v1.0")

In [3]:
from collections import namedtuple
import json
import pickle
from tqdm.auto import tqdm

def flatten_list_of_lists(lst):
    return [elem for sublst in lst for elem in sublst]

def parse_jsonlines(file_path):
    """Parse the jsonlines file into a list of examples.
    Args:
        file_path: path to the jsonlines file
    Returns:
        examples: list of doc_key, input_words, clusters
        max_mention_num: maximum number of mentions in a single example
        max_cluster_size: maximum number of mentions in a single cluster
        max_num_clusters: maximum number of clusters in a single example
        """
    examples = []
    max_mention_num = -1
    max_cluster_size = -1
    max_num_clusters = -1
    with open(file_path, 'r') as f:
        for line in f:
            # Each line is a json object that represents a single example
            # The json object contains the following fields:
            #   doc_key: document key
            #   sentences: list of sentences, each sentence is a list of words
            #   clusters: list of clusters, each cluster is a list of mentions, each mention is a list of [sentence_idx, start_token_idx, end_token_idx]

            d = json.loads(line.strip())
            doc_key = d["doc_key"]
            input_words = flatten_list_of_lists(d["sentences"])
            clusters = d["clusters"]

            # Max mention num is the maximum number of mentions in a single example
            max_mention_num = max(max_mention_num, len(flatten_list_of_lists(clusters)))
            # Max cluster size is the maximum number of mentions in a single cluster
            max_cluster_size = max(max_cluster_size, max(len(cluster) for cluster in clusters) if clusters else 0)
            # Max num clusters is the maximum number of clusters in a single example
            max_num_clusters = max(max_num_clusters, len(clusters) if clusters else 0)
            examples.append((doc_key, input_words, clusters))
    return examples, max_mention_num, max_cluster_size, max_num_clusters

examples, max_mention_num, max_cluster_size, max_num_clusters = parse_jsonlines("/home/poomphob/Desktop/Thesis/s2e_coref/data/train.jsonl")

CorefExample = namedtuple("CorefExample", ["token_ids", "clusters"])
def tokenize(examples, tokenizer, max_seq_length):
        """Tokenize the examples."""
        coref_examples = []
        lengths = []
        num_examples_filtered = 0
        for doc_key, words, clusters in tqdm(examples):
            word_idx_to_start_token_idx = dict()
            word_idx_to_end_token_idx = dict()
            end_token_idx_to_word_idx = [0]  # for <s>

            token_ids = []
            for idx, word in enumerate(words):
                word_idx_to_start_token_idx[idx] = len(token_ids) + 1  # +1 for <s>
                tokenized = tokenizer.encode(" " + word, add_special_tokens=False)
                print("word: ", word)
                print("tokenized: ", tokenized)
                for _ in range(len(tokenized)):
                    end_token_idx_to_word_idx.append(idx)
                token_ids.extend(tokenized)
                word_idx_to_end_token_idx[idx] = len(token_ids)  # old_seq_len + 1 (for <s>) + len(tokenized_word) - 1 (we start counting from zero) = len(token_ids)

            if 0 < max_seq_length < len(token_ids):
                num_examples_filtered += 1
                continue
            
            print("clusters: ", clusters)
            print("word_idx_to_start_token_idx: ", word_idx_to_start_token_idx)
            print("word_idx_to_end_token_idx: ", word_idx_to_end_token_idx)
            new_clusters = [
                [(word_idx_to_start_token_idx[start], word_idx_to_end_token_idx[end-1]) for start, end in cluster] for
                cluster in clusters]
            lengths.append(len(token_ids))

            # CorefExample = namedtuple("CorefExample", ["token_ids", "clusters"])
            # Example: 
            # Text = "John Smith is a nice guy. He lives in London."
            # CorefExample = {
            #  token_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            #  clusters: [[(0, 1), (6, 6)]]
            coref_examples.append(((doc_key, end_token_idx_to_word_idx), CorefExample(token_ids=token_ids, clusters=new_clusters)))
        return coref_examples, lengths, num_examples_filtered

tokenize(examples, tokenizer, max_seq_length=512)

  0%|          | 0/1039 [00:00<?, ?it/s]

word:  ศรี จันทน
tokenized:  [401318, 465378, 382883]
word:   รองประธานสหพันธ์มวยกัมพูชา ออกมาโพสต์ผ่านเฟซบุ๊กของ
tokenized:  [399194, 409722, 421185, 403599, 416303, 442050, 443302, 401341, 385112, 453216, 382519]
word:  ตัวเอง
tokenized:  [6, 386675]
word:   หลังจากที่ เซียน เลวี นักชกกุน แขมร์ ขวัญใจเจ้าถิ่น พ่ายน็อกยก 4 ให้กับ แซมมี่ บัญชาเมฆ นักชกไทย ในการแข่งขันกุน แขมร์ ที่เสียมราฐ ประเทศกัมพูชา เมื่อวันที่ 15 เมษายนที่ผ่านมา
tokenized:  [416733, 6, 464674, 412415, 389515, 399517, 458370, 495355, 416946, 446496, 385866, 6, 512877, 521209, 6, 468383, 382883, 420720, 390064, 184, 466698, 6, 404685, 488550, 6, 500683, 435817, 399517, 458370, 383615, 6, 489864, 495355, 416946, 446496, 385866, 382993, 386250, 382875, 386364, 395124, 399580, 442050, 403018, 328, 400649, 394474]
clusters:  [[[0, 1], [2, 3]]]
word_idx_to_start_token_idx:  {0: 1, 1: 4, 2: 15, 3: 17}
word_idx_to_end_token_idx:  {0: 3, 1: 14, 2: 16, 3: 63}
word:  เพียร์ส มอร์แกน พิธีกรฝีปากกล้าชาวอังกฤษซึ่งสนิทสนมกับ
token

([((0,
    [0,
     0,
     0,
     0,
     1,
     1,
     1,
     1,
     1,
     1,
     1,
     1,
     1,
     1,
     1,
     2,
     2,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3,
     3]),
   CorefExample(token_ids=[401318, 465378, 382883, 399194, 409722, 421185, 403599, 416303, 442050, 443302, 401341, 385112, 453216, 382519, 6, 386675, 416733, 6, 464674, 412415, 389515, 399517, 458370, 495355, 416946, 446496, 385866, 6, 512877, 521209, 6, 468383, 382883, 420720, 390064, 184, 466698, 6, 404685, 488550, 6, 500683, 435817, 399517, 458370, 383615, 6, 489864, 495355, 416946, 446496, 385866, 382993, 386250, 382875, 386364, 395124, 399580, 442050, 403018, 328, 40064

In [7]:
examples[9]

(9,
 ['อีเจี๊ยบ เลียบด่วน เล่าเหตุ',
  'วัยรุ่นแถวบ้าน',
  'ถูกรุมทำร้าย หลังไปขโขมยหอมแก้ม',
  'ผู้หญิง',
  ' สุดท้ายผัว',
  'เขา',
  'และเพื่อนมารุมสั่งสอน บอกสงสาร',
  'จิ้งเหลนน้อย'],
 [[[1, 2], [7, 8]], [[3, 4], [5, 6]]])

In [None]:
class CorefDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_seq_length=-1):
        self.tokenizer = tokenizer
        logger.info(f"Reading dataset from {file_path}")
        examples, self.max_mention_num, self.max_cluster_size, self.max_num_clusters = self._parse_jsonlines(file_path)
        self.max_seq_length = max_seq_length
        self.examples, self.lengths, self.num_examples_filtered = self._tokenize(examples)
        logger.info(
            f"Finished preprocessing Coref dataset. {len(self.examples)} examples were extracted, {self.num_examples_filtered} were filtered due to sequence length.")



    def _tokenize(self, examples):
        """Tokenize the examples."""
        coref_examples = []
        lengths = []
        num_examples_filtered = 0
        for doc_key, words, clusters in examples:
            word_idx_to_start_token_idx = dict()
            word_idx_to_end_token_idx = dict()
            end_token_idx_to_word_idx = [0]  # for <s>

            token_ids = []
            for idx, word in enumerate(words):
                word_idx_to_start_token_idx[idx] = len(token_ids) + 1  # +1 for <s>
                tokenized = self.tokenizer.encode(" " + word, add_special_tokens=False)
                for _ in range(len(tokenized)):
                    end_token_idx_to_word_idx.append(idx)
                token_ids.extend(tokenized)
                word_idx_to_end_token_idx[idx] = len(token_ids)  # old_seq_len + 1 (for <s>) + len(tokenized_word) - 1 (we start counting from zero) = len(token_ids)

            if 0 < self.max_seq_length < len(token_ids):
                num_examples_filtered += 1
                continue

            new_clusters = [
                [(word_idx_to_start_token_idx[start], word_idx_to_end_token_idx[end]) for start, end in cluster] for
                cluster in clusters]
            lengths.append(len(token_ids))

            # CorefExample = namedtuple("CorefExample", ["token_ids", "clusters"])
            # Example: 
            # Text = "John Smith is a nice guy. He lives in London."
            # CorefExample = {
            #  token_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            #  clusters: [[(0, 1), (6, 6)]]
            coref_examples.append(((doc_key, end_token_idx_to_word_idx), CorefExample(token_ids=token_ids, clusters=new_clusters)))
        return coref_examples, lengths, num_examples_filtered

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return self.examples[item]

    def pad_clusters_inside(self, clusters):
        return [cluster + [(NULL_ID_FOR_COREF, NULL_ID_FOR_COREF)] * (self.max_cluster_size - len(cluster)) for cluster
                in clusters]

    def pad_clusters_outside(self, clusters):
        return clusters + [[]] * (self.max_num_clusters - len(clusters))

    def pad_clusters(self, clusters):
        clusters = self.pad_clusters_outside(clusters)
        clusters = self.pad_clusters_inside(clusters)
        return clusters

    def pad_batch(self, batch, max_length):
        max_length += 2  # we have additional two special tokens <s>, </s>
        padded_batch = []
        for example in batch:
            encoded_dict = self.tokenizer.encode_plus(example[0],
                                                      add_special_tokens=True,
                                                      pad_to_max_length=True,
                                                      max_length=max_length,
                                                      return_attention_mask=True,
                                                      return_tensors='pt')
            clusters = self.pad_clusters(example.clusters)
            example = (encoded_dict["input_ids"], encoded_dict["attention_mask"]) + (torch.tensor(clusters),)
            padded_batch.append(example)
        tensored_batch = tuple(torch.stack([example[i].squeeze() for example in padded_batch], dim=0) for i in range(len(example)))
        return tensored_batch


def get_dataset(args, tokenizer, evaluate=False):
    read_from_cache, file_path = False, ''
    if evaluate and os.path.exists(args.predict_file_cache):
        file_path = args.predict_file_cache
        read_from_cache = True
    elif (not evaluate) and os.path.exists(args.train_file_cache):
        file_path = args.train_file_cache
        read_from_cache = True

    if read_from_cache:
        logger.info(f"Reading dataset from {file_path}")
        with open(file_path, 'rb') as f:
            return pickle.load(f)

    file_path, cache_path = (args.predict_file, args.predict_file_cache) if evaluate else (args.train_file, args.train_file_cache)

    coref_dataset = CorefDataset(file_path, tokenizer, max_seq_length=args.max_seq_length)
    with open(cache_path, 'wb') as f:
        pickle.dump(coref_dataset, f)

    return coref_dataset




# Train data

In [None]:
export OUTPUT_DIR=/home/poomphob/Desktop/Thesis/s2e_coref/output
export CACHE_DIR=/home/poomphob/Desktop/Thesis/s2e_coref/cache
export DATA_DIR=/home/poomphob/Desktop/Thesis/s2e_coref/data
export MODEL_NAME=pythainlp/han-coref-v1.0

In [8]:
python3 run_coref.py \
        --output_dir=$OUTPUT_DIR \
        --cache_dir=$CACHE_DIR \
        --model_type=xlm-roberta \
        --model_name_or_path=$MODEL_NAME \
        --tokenizer_name=$MODEL_NAME \
        --config_name=$MODEL_NAME  \
        --train_file=$DATA_DIR/train.jsonl \
        --predict_file=$DATA_DIR/dev.jsonl \
        --do_train \
        --do_eval \
        --num_train_epochs=129 \
        --logging_steps=500 \
        --save_steps=3000 \
        --eval_steps=1000 \
        --max_seq_length=4096 \
        --train_file_cache=$DATA_DIR/train.4096.pkl \
        --predict_file_cache=$DATA_DIR/dev.4096.pkl \
        --gradient_accumulation_steps=1 \
        --normalise_loss \
        --max_total_seq_len=5000 \
        --experiment_name="s2e-model" \
        --warmup_steps=5600 \
        --adam_epsilon=1e-6 \
        --head_learning_rate=3e-4 \
        --learning_rate=1e-5 \
        --adam_beta2=0.98 \
        --weight_decay=0.01 \
        --dropout_prob=0.3 \
        --save_if_best \
        --top_lambda=0.4  \
        --tensorboard_dir=$OUTPUT_DIR/tb \
        --conll_path_for_eval=$DATA_DIR/dev.english.v4_gold_conll \
        --overwrite_output_dir

/bin/bash: line 1: python: command not found


In [4]:
with open("/home/poomphob/Desktop/Thesis/s2e_coref/data/train.4096.pkl", 'rb') as f:
    data = pickle.load(f)

In [6]:
data.examples

[((0,
   [0,
    0,
    0,
    0,
    1,
    1,
    1,
    1,
    1,
    1,
    1,
    1,
    1,
    1,
    1,
    2,
    2,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3,
    3]),
  CorefExample(token_ids=[401318, 465378, 382883, 399194, 409722, 421185, 403599, 416303, 442050, 443302, 401341, 385112, 453216, 382519, 6, 386675, 416733, 6, 464674, 412415, 389515, 399517, 458370, 495355, 416946, 446496, 385866, 6, 512877, 521209, 6, 468383, 382883, 420720, 390064, 184, 466698, 6, 404685, 488550, 6, 500683, 435817, 399517, 458370, 383615, 6, 489864, 495355, 416946, 446496, 385866, 382993, 386250, 382875, 386364, 395124, 399580, 442050, 403018, 328, 400649, 394474], clusters=[[(1, 3), (15, 16)]])),
 ((1,
   [0,
    0,
 