In [1]:
import json
import os
import pickle
import shutil
from glob import glob

import sentencepiece as spm
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
from tensor2tensor.layers import modalities
import tensorflow as tf
from tqdm import tqdm

In [2]:
# SentencePiece model shared by the `inputs` and `targets` encoders below
# (the file name suggests a cased, T5-style vocabulary). The path is relative
# to the notebook's working directory.
vocab = 'sp10m.cased.t5.model'
sp = spm.SentencePieceProcessor()
sp.Load(vocab)


class Encoder:
    """Adapter that presents a SentencePiece processor as a t2t text encoder.

    Exposes the ``encode`` / ``decode`` / ``vocab_size`` surface that
    ``Grammar.feature_encoders`` hands to tensor2tensor for both the
    ``inputs`` and ``targets`` features.

    Args:
        sp: a loaded SentencePiece processor (anything providing
            ``GetPieceSize``, ``EncodeAsIds`` and ``DecodeIds``).
    """

    def __init__(self, sp):
        self.sp = sp
        # 100 ids are reserved on top of the SentencePiece pieces (presumably
        # the T5 extra/sentinel ids — confirm against the model file).
        piece_count = sp.GetPieceSize()
        self.vocab_size = piece_count + 100

    def encode(self, s):
        """Return the SentencePiece id list for the string ``s``."""
        ids = self.sp.EncodeAsIds(s)
        return ids

    def decode(self, ids, strip_extraneous = False):
        """Turn an iterable of ids back into text.

        ``strip_extraneous`` exists only for t2t interface compatibility and
        is ignored.
        """
        id_list = [token_id for token_id in ids]
        return self.sp.DecodeIds(id_list)

In [3]:
# Error-class catalogue used by `Tatabahasa` to encode `targets_error_tag`.
# Each entry has:
#   'class'       - integer id; equal to the entry's position in this list
#   'Description' - Malay label; Tatabahasa maps it to the entry's position
#   'salah'       - example containing the error ("wrong")
#   'betul'       - the corrected example ("correct")
# The comment above each entry is an English gloss of its Description.
d = [
    # 0: padding
    {'class': 0, 'Description': 'PAD', 'salah': '', 'betul': ''},
    # 1: subword continuation
    {
        'class': 1,
        'Description': 'kesambungan subwords',
        'salah': '',
        'betul': '',
    },
    # 2: no error
    {
        'class': 2,
        'Description': 'tiada kesalahan',
        'salah': '',
        'betul': '',
    },
    # 3: noun-phrase error — the thing being described must precede the modifier
    {
        'class': 3,
        'Description': 'kesalahan frasa nama, Perkara yang diterangkan mesti mendahului "penerang"',
        'salah': 'Cili sos',
        'betul': 'sos cili',
    },
    # 4: plural-form error
    {
        'class': 4,
        'Description': 'kesalahan kata jamak',
        'salah': 'mereka-mereka',
        'betul': 'mereka',
    },
    # 5: intensifier error
    {
        'class': 5,
        'Description': 'kesalahan kata penguat',
        'salah': 'sangat tinggi sekali',
        'betul': 'sangat tinggi',
    },
    # 6: adjective with the "ter-" prefix must not take an intensifier
    {
        'class': 6,
        'Description': 'kata adjektif dan imbuhan "ter" tanpa penguat.',
        'salah': 'Sani mendapat markah yang tertinggi sekali.',
        'betul': 'Sani mendapat markah yang tertinggi.',
    },
    # 7: conjunction error
    {
        'class': 7,
        'Description': 'kesalahan kata hubung',
        'salah': 'Sally sedang membaca bila saya tiba di rumahnya.',
        'betul': 'Sally sedang membaca apabila saya tiba di rumahnya.',
    },
    # 8: numeral / quantity-word error
    {
        'class': 8,
        'Description': 'kesalahan kata bilangan',
        'salah': 'Beribu peniaga tidak membayar cukai pendapatan.',
        'betul': 'Beribu-ribu peniaga tidak membayar cukai pendapatan',
    },
    # 9: preposition error
    {
        'class': 9,
        'Description': 'kesalahan kata sendi',
        'salah': 'Umar telah berpindah daripada sekolah ini bulan lalu.',
        'betul': 'Umar telah berpindah dari sekolah ini bulan lalu.',
    },
    # 10: numeral-classifier error
    {
        'class': 10,
        'Description': 'kesalahan penjodoh bilangan',
        'salah': 'Setiap orang pelajar',
        'betul': 'Setiap pelajar.',
    },
    # 11: personal-pronoun error
    {
        'class': 11,
        'Description': 'kesalahan kata ganti diri',
        'salah': 'Pencuri itu telah ditangkap. Beliau dibawa ke balai polis.',
        'betul': 'Pencuri itu telah ditangkap. Dia dibawa ke balai polis.',
    },
    # 12: passive-sentence error
    {
        'class': 12,
        'Description': 'kesalahan ayat pasif',
        'salah': 'Cerpen itu telah dikarang oleh saya.',
        'betul': 'Cerpen itu telah saya karang.',
    },
    # 13: question-word error
    {
        'class': 13,
        'Description': 'kesalahan kata tanya',
        'salah': 'Kamu berasal dari manakah ?',
        'betul': 'Kamu berasal dari mana ?',
    },
    # 14: punctuation error
    {
        'class': 14,
        'Description': 'kesalahan tanda baca',
        'salah': 'Kamu berasal dari manakah .',
        'betul': 'Kamu berasal dari mana ?',
    },
    # 15: intransitive-verb error
    {
        'class': 15,
        'Description': 'kesalahan kata kerja tak transitif',
        'salah': 'Dia kata kepada saya',
        'betul': 'Dia berkata kepada saya',
    },
    # 16: transitive-verb error
    {
        'class': 16,
        'Description': 'kesalahan kata kerja transitif',
        'salah': 'Dia suka baca buku',
        'betul': 'Dia suka membaca buku',
    },
    # 17: imprecise / wrong word choice
    {
        'class': 17,
        'Description': 'penggunaan kata yang tidak tepat',
        'salah': 'Tembuk Besar negeri Cina dibina oleh Shih Huang Ti.',
        'betul': 'Tembok Besar negeri Cina dibina oleh Shih Huang Ti',
    },
]


class Tatabahasa:
    """Encoder between error-class descriptions and integer tag ids.

    Args:
        d: list of dicts, each with a ``'Description'`` key. The id assigned
            to a description is its position in ``d``. If a description is
            repeated, the later position wins.
    """

    def __init__(self, d):
        self.d = d
        description_to_id = {}
        for position, entry in enumerate(d):
            description_to_id[entry['Description']] = position
        self.kesalahan = description_to_id
        self.reverse_kesalahan = {
            tag_id: description
            for description, tag_id in description_to_id.items()
        }
        self.vocab_size = len(d)

    def encode(self, s):
        """Map an iterable of descriptions to ids (KeyError on unknown ones)."""
        return list(map(self.kesalahan.__getitem__, s))

    def decode(self, ids, strip_extraneous = False):
        """Map an iterable of ids back to descriptions.

        ``strip_extraneous`` exists only for t2t interface compatibility and
        is ignored.
        """
        return list(map(self.reverse_kesalahan.__getitem__, ids))

In [4]:
def get_xy(row, encoder):
    """Build aligned (inputs, targets, error-tags) id lists for one example.

    ``row[0]`` and ``row[1]`` are token sequences indexed in parallel; each
    token is indexable as ``(text, tag)``. Target ids come from ``row[0]``
    texts, input ids from ``row[1]`` texts. Every target id is labelled with
    the tag of the ``row[1]`` token at the same position, so ``targets`` and
    the tag list always have equal length. Raises ``IndexError`` if ``row[1]``
    is shorter than ``row[0]``; extra ``row[1]`` tokens are ignored.

    Args:
        row: the ``(row[0], row[1])`` token pair described above.
        encoder: object whose ``encode(text)`` returns a list of ids.

    Returns:
        ``(inputs, targets, error_tags)``, each terminated by the EOS entry
        (id 1 for the text sequences, tag 0 for the tag sequence).
    """
    target_tokens = row[0]
    source_tokens = row[1]
    inputs, targets, error_tags = [], [], []

    for position, target_token in enumerate(target_tokens):
        target_ids = encoder.encode(target_token[0])
        source_token = source_tokens[position]
        # Each target sub-word inherits the tag of the aligned source token.
        error_tags += [source_token[1]] * len(target_ids)
        targets += target_ids
        inputs += encoder.encode(source_token[0])

    # EOS: id 1 in the text sequences, tag 0 in the tag sequence.
    inputs.append(1)
    targets.append(1)
    error_tags.append(0)

    return inputs, targets, error_tags

In [5]:
# Module-level accumulator: Grammar.generate_samples extends it with every
# emitted error-tag id so the tag distribution can be inspected after data
# generation. Re-run this cell before regenerating to avoid double counting.
tags = []

In [6]:
@registry.register_problem
class Grammar(text_problems.Text2TextProblem):
    """grammatical error correction.

    Looked up by tensor2tensor as ``'grammar'``. Each example holds three
    integer sequences built by ``get_xy``:

    * ``inputs`` - text to be corrected,
    * ``targets`` - corrected text,
    * ``targets_error_tag`` - one error-class id (see ``d``) per target id.
    """

    # Glob of pickled example rows. It is a class attribute so the location can
    # be overridden (on a subclass or instance) without editing the method.
    data_glob = '/home/husein/pure-text/*tatabahasa-*.pkl'

    def feature_encoders(self, data_dir):
        """Return the per-feature encoders; ``data_dir`` is unused."""
        encoder = Encoder(sp)
        t = Tatabahasa(d)
        return {'inputs': encoder, 'targets': encoder, 'targets_error_tag': t}

    def hparams(self, defaults, model_hparams):
        """Add the error-tag model hparams and declare the tag feature.

        Hparams already present in ``model_hparams`` keep their values.
        The ``targets_error_tag`` modality and vocab size are only declared
        when ``use_error_tags`` is true.
        """
        super(Grammar, self).hparams(defaults, model_hparams)
        extra_hparam_defaults = (
            ('use_error_tags', True),
            ('middle_prediction', False),
            ('middle_prediction_layer_factor', 2),
            ('ffn_in_prediction_cascade', 1),
            ('error_tag_embed_size', 12),
        )
        for name, default in extra_hparam_defaults:
            if name not in model_hparams:
                model_hparams.add_hparam(name, default)
        if model_hparams.use_error_tags:
            defaults.modality[
                'targets_error_tag'
            ] = modalities.ModalityType.SYMBOL
            # NOTE(review): relies on the t2t Problem base class having filled
            # self._encoders (from feature_encoders) before calling hparams().
            error_tag_vocab_size = self._encoders[
                'targets_error_tag'
            ].vocab_size
            defaults.vocab_size['targets_error_tag'] = error_tag_vocab_size

    def example_reading_spec(self):
        """Add the variable-length int64 ``targets_error_tag`` feature.

        Returns:
            ``(data_fields, None)`` - the base Text2Text fields plus the tag
            field, with no decoders override.
        """
        # Must name this class: super() needs a class in Grammar's own MRO.
        data_fields, _ = super(Grammar, self).example_reading_spec()
        data_fields['targets_error_tag'] = tf.compat.v1.VarLenFeature(tf.compat.v1.int64)
        return data_fields, None

    @property
    def approx_vocab_size(self):
        # Presumably SentencePiece pieces + the 100 ids Encoder reserves.
        return 32100

    @property
    def is_generate_per_split(self):
        return False

    @property
    def dataset_splits(self):
        return [
            {'split': problem.DatasetSplit.TRAIN, 'shards': 200},
            {'split': problem.DatasetSplit.EVAL, 'shards': 1},
        ]

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        """Yield already-encoded examples from every file matching ``data_glob``.

        ``data_dir``, ``tmp_dir`` and ``dataset_split`` are unused; the same
        files feed every split because ``is_generate_per_split`` is False.
        As a side effect, every emitted tag id is appended to the module-level
        ``tags`` list.

        Yields:
            dicts with integer-list ``inputs``, ``targets`` and
            ``targets_error_tag`` entries.
        """
        encoder = Encoder(sp)

        # glob() returns paths in arbitrary filesystem order; sorting keeps
        # the generated example order reproducible across runs.
        for path in sorted(glob(self.data_glob)):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # data_glob at trusted, locally produced files.
            with open(path, 'rb') as fopen:
                data = pickle.load(fopen)

            for row in tqdm(data):
                x, y, tag = get_xy(row, encoder)
                # Defensive only: get_xy emits one tag per target id.
                if len(y) != len(tag):
                    continue
                tags.extend(tag)
                yield {
                    'inputs': x,
                    'targets': y,
                    'targets_error_tag': tag,
                }

    def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        """Pass samples through unchanged: they are already integer ids."""
        yield from self.generate_samples(data_dir, tmp_dir, dataset_split)

In [7]:
import os
import tensorflow as tf

# DATA_DIR receives the generated dataset shards; TMP_DIR is passed to t2t as
# its tmp directory. Both are relative to the notebook's working directory.
DATA_DIR = os.path.expanduser('t2t-tatabahasa/data')
TMP_DIR = os.path.expanduser('t2t-tatabahasa/tmp')

# Start from an empty DATA_DIR so shards from a previous run are not mixed
# with the new ones. This is equivalent to `rm -rf` (a missing directory is
# ignored) but spawns no shell and deletes exactly the configured path.
shutil.rmtree(DATA_DIR, ignore_errors=True)

In [8]:
# makedirs (unlike mkdir) also creates the missing parent 't2t-tatabahasa/'.
# It succeeds if the directory already exists (e.g. TMP_DIR, which is never
# deleted), so the notebook survives Restart & Run All.
tf.compat.v1.io.gfile.makedirs(DATA_DIR)
tf.compat.v1.io.gfile.makedirs(TMP_DIR)

In [9]:
from tensor2tensor.utils import registry
from tensor2tensor import problems

In [10]:
# 'grammar' is the registry name under which @registry.register_problem
# exposed the Grammar class. generate_data writes the TRAIN/EVAL shards
# declared in Grammar.dataset_splits into DATA_DIR, using TMP_DIR as t2t's
# tmp directory.
PROBLEM = 'grammar'
t2t_problem = problems.problem(PROBLEM)
t2t_problem.generate_data(DATA_DIR, TMP_DIR)

  0%|          | 0/99939 [00:00<?, ?it/s]

INFO:tensorflow:Generating case 0.


INFO:tensorflow:Generating case 0.
100%|██████████| 99939/99939 [00:38<00:00, 2587.49it/s]
  0%|          | 0/99898 [00:00<?, ?it/s]

INFO:tensorflow:Generating case 100000.


INFO:tensorflow:Generating case 100000.
100%|██████████| 99898/99898 [00:43<00:00, 2299.05it/s]
  0%|          | 0/99949 [00:00<?, ?it/s]

INFO:tensorflow:Generating case 200000.


INFO:tensorflow:Generating case 200000.
100%|██████████| 99949/99949 [00:46<00:00, 2170.53it/s]
  0%|          | 0/37236 [00:00<?, ?it/s]

INFO:tensorflow:Generating case 300000.


INFO:tensorflow:Generating case 300000.
100%|██████████| 37236/37236 [00:17<00:00, 2111.24it/s]
 63%|██████▎   | 62880/99933 [00:25<00:16, 2293.53it/s]

INFO:tensorflow:Generating case 400000.


INFO:tensorflow:Generating case 400000.
100%|██████████| 99933/99933 [00:39<00:00, 2528.48it/s]
 63%|██████▎   | 62952/99938 [00:27<00:17, 2138.41it/s]

INFO:tensorflow:Generating case 500000.


INFO:tensorflow:Generating case 500000.
100%|██████████| 99938/99938 [00:42<00:00, 2341.33it/s]
 63%|██████▎   | 62884/99955 [00:23<00:12, 3063.73it/s]

INFO:tensorflow:Generating case 600000.


INFO:tensorflow:Generating case 600000.
100%|██████████| 99955/99955 [00:38<00:00, 2618.91it/s]
 63%|██████▎   | 63113/99907 [00:27<00:16, 2192.06it/s]

INFO:tensorflow:Generating case 700000.


INFO:tensorflow:Generating case 700000.
100%|██████████| 99907/99907 [00:44<00:00, 2241.94it/s]
 63%|██████▎   | 63198/99956 [00:25<00:15, 2381.59it/s]

INFO:tensorflow:Generating case 800000.


INFO:tensorflow:Generating case 800000.
100%|██████████| 99956/99956 [00:39<00:00, 2501.46it/s]
 63%|██████▎   | 63053/99954 [00:24<00:14, 2485.16it/s]

INFO:tensorflow:Generating case 900000.


INFO:tensorflow:Generating case 900000.
100%|██████████| 99954/99954 [00:38<00:00, 2592.45it/s]
 63%|██████▎   | 63291/99942 [00:27<00:15, 2305.21it/s]

INFO:tensorflow:Generating case 1000000.


INFO:tensorflow:Generating case 1000000.
100%|██████████| 99942/99942 [00:43<00:00, 2291.56it/s]
 63%|██████▎   | 63048/99898 [00:27<00:12, 3022.30it/s]

INFO:tensorflow:Generating case 1100000.


INFO:tensorflow:Generating case 1100000.
100%|██████████| 99898/99898 [00:43<00:00, 2306.03it/s]
 63%|██████▎   | 63303/99922 [00:26<00:14, 2496.53it/s]

INFO:tensorflow:Generating case 1200000.


INFO:tensorflow:Generating case 1200000.
100%|██████████| 99922/99922 [00:42<00:00, 2327.32it/s]
 63%|██████▎   | 63428/99955 [00:28<00:17, 2082.24it/s]

INFO:tensorflow:Generating case 1300000.


INFO:tensorflow:Generating case 1300000.
100%|██████████| 99955/99955 [00:43<00:00, 2292.16it/s]
 63%|██████▎   | 63453/99935 [00:27<00:15, 2354.98it/s]

INFO:tensorflow:Generating case 1400000.


INFO:tensorflow:Generating case 1400000.
100%|██████████| 99935/99935 [00:43<00:00, 2300.44it/s]
 64%|██████▎   | 63666/99919 [00:28<00:15, 2295.85it/s]

INFO:tensorflow:Generating case 1500000.


INFO:tensorflow:Generating case 1500000.
100%|██████████| 99919/99919 [00:43<00:00, 2274.73it/s]
 64%|██████▎   | 63620/99880 [00:28<00:16, 2218.53it/s]

INFO:tensorflow:Generating case 1600000.


INFO:tensorflow:Generating case 1600000.
100%|██████████| 99880/99880 [00:44<00:00, 2262.17it/s]
 64%|██████▍   | 63797/99968 [00:25<00:16, 2217.92it/s]

INFO:tensorflow:Generating case 1700000.


INFO:tensorflow:Generating case 1700000.
100%|██████████| 99968/99968 [00:38<00:00, 2566.83it/s]
 64%|██████▍   | 63776/99906 [00:27<00:18, 1976.94it/s]

INFO:tensorflow:Generating case 1800000.


INFO:tensorflow:Generating case 1800000.
100%|██████████| 99906/99906 [00:42<00:00, 2323.89it/s]
 64%|██████▍   | 63997/99965 [00:24<00:11, 3077.95it/s]

INFO:tensorflow:Generating case 1900000.


INFO:tensorflow:Generating case 1900000.
100%|██████████| 99965/99965 [00:37<00:00, 2654.81it/s]
 64%|██████▍   | 63844/99926 [00:26<00:16, 2247.96it/s]

INFO:tensorflow:Generating case 2000000.


INFO:tensorflow:Generating case 2000000.
100%|██████████| 99926/99926 [00:42<00:00, 2357.70it/s]


INFO:tensorflow:Generated 2035881 Examples


INFO:tensorflow:Generated 2035881 Examples


INFO:tensorflow:Shuffling data...


INFO:tensorflow:Shuffling data...


Instructions for updating:
Use eager execution and: 
`tf.compat.v1.data.TFRecordDataset(path)`


Instructions for updating:
Use eager execution and: 
`tf.compat.v1.data.TFRecordDataset(path)`


INFO:tensorflow:Data shuffled.


INFO:tensorflow:Data shuffled.


In [12]:
import numpy as np

# Distribution of error-tag ids over every generated target position.
# - Tag 0 is the EOS tag that get_xy appends once per example, so its count
#   matches the "Generated ... Examples" total in the log above.
# - Tag 1 ('kesambungan subwords') never appears because get_xy never
#   emits it.
tag_ids, tag_counts = np.unique(np.asarray(tags), return_counts=True)
tag_ids, tag_counts

(array([ 0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17]),
 array([ 2035881, 43535622,  1310695,   295299,    13388,   114688,
          251414,    11864,   651255,   132101,   161983,    80888,
           11598,  1210119,   175657,    32830,   255391]))