In [1]:
import re
import random

punctuation = '!()&%{}[];:\'",./?\\<>'

def remove_punc(string):
    string = re.sub('[^A-Za-z0-9 ]+', ' ', string)
    string = re.sub(r'[ ]+', ' ', string).strip()
    return string

def remove_punc_random(string, threshold = 0.3):
    result = []
    for c in string:
        if c in punctuation and random.random() >= threshold:
            continue
            
        result.append(c)
    return ''.join(result)

def add_punc_random(string, threshold = 0.7):
    string = string.split()
    for i in range(len(string)):
        if random.random() >= threshold and string[i][-1] not in punctuation:
            string[i] = string[i] + random.choice(punctuation)
    return ' '.join(string)

def add_space_after_punc(string, threshold = 0.3):
    string = string.split()
    for i in range(len(string)):
        if random.random() >= threshold and string[i][-1] in punctuation:
            string[i] = string[i][:-1] + ' ' + string[i][-1]
    return ' '.join(string)

def replace_random_punc(string):
    string = string.split()
    for i in range(len(string)):
        if string[i][-1] in punctuation:
            string[i] = string[i][:-1] + random.choice(punctuation)
    return ' '.join(string)

def random_upper(string, threshold = 0.5):
    string = string.split()
    for i in range(len(string)):
        if random.random() >= threshold:
            string[i] = string[i].upper()
    return ' '.join(string)

def random_title(string, threshold = 0.5):
    string = string.split()
    for i in range(len(string)):
        if random.random() >= threshold:
            string[i] = string[i].title()
    return ' '.join(string)

def random_char_upper(string, threshold = 0.7):
    result = []
    for c in string:
        if random.random() >= threshold:
            c = c.upper()
        result.append(c)
    return ''.join(result)

def remove_space(string, threshold = 0.5):
    string = [s for s in string if not (s == ' ' and random.random() >= threshold)]
    return ''.join(string)

funcs = [remove_punc, remove_punc_random, add_punc_random, add_space_after_punc, 
         random_upper, random_title, random_char_upper, replace_random_punc]

chain_funcs = [remove_punc_random, add_punc_random, add_space_after_punc, 
         random_upper, random_title, random_char_upper, replace_random_punc]

In [2]:
def package(string, repeat = 2, repeat_chain = 5, threshold = 0.5, maintain_space = False):
    
    result = [string]
    result.append(string.lower())
    result.append(string.upper())
    result.append(string.title())
    
    for _ in range(repeat):
    
        for func in funcs:
            result.append(func(string))

        for func in funcs:
            result.append(func(string.lower()))

        for func in funcs:
            result.append(func(string.upper()))

        for func in funcs:
            result.append(func(string.title()))
    
    for _ in range(repeat_chain):
        s = string[:]
        for func in chain_funcs:
            if random.random() > threshold:
                s = func(s)
        result.append(s)
        
    result = list(set(result))
    if not maintain_space:
        result = [remove_space(r) if random.random() >= 0.6 else r for r in result]
        result = list(set(result))
        
    return [(r, string) for r in result]


In [3]:
package('saya suka makan ayam dan daging')

[('SAYA SUKA MAKAN AYAM DAN DAGING', 'saya suka makan ayam dan daging'),
 ('sayasuka MakanAyam dan daging', 'saya suka makan ayam dan daging'),
 ('SAYA SUKA& MAKAN/ AYAM. DAN} DAGING', 'saya suka makan ayam dan daging'),
 ('SAYASUKA< MAKANAYAM] DANDAGING', 'saya suka makan ayam dan daging'),
 ('SAYA SUKA MAKAN AYAM DAN Daging', 'saya suka makan ayam dan daging'),
 ('Saya! Suka makan Ayam dan Daging', 'saya suka makan ayam dan daging'),
 ('saya Sukamakanayam Dan Daging', 'saya suka makan ayam dan daging'),
 ('Saya suka maKan ayaM Dan daGing', 'saya suka makan ayam dan daging'),
 ('Saya Suka MAKaN AYAm Dan DaginG', 'saya suka makan ayam dan daging'),
 ('SayA SUKA Makan Ayam DAn DAGiNg', 'saya suka makan ayam dan daging'),
 ('saya suka makan ayam dan daging', 'saya suka makan ayam dan daging'),
 ('Saya suka MakanAyamdandaging', 'saya suka makan ayam dan daging'),
 ('saya suka MAKAN AYAM DAN DAGING', 'saya suka makan ayam dan daging'),
 ('SAya suKamakaN aYam dAndaGInG', 'saya suka makan ay

In [4]:
from tqdm import tqdm

def loop(strings):
    results = []
    for i in tqdm(range(len(strings))):
        p = package(strings[i])
        results.extend(p)
    return results

def slide(strings, n = 2):
    result = []
    for i in range(0, len(strings), len(strings) - (n - 1)):
        result.append(strings[i: i + n])
    return result

In [5]:
files = ['/home/husein/pure-text/filtered-dumping-wiki.txt',
        '/home/husein/pure-text/dumping-cleaned-news.txt']

In [6]:
with open(files[0]) as fopen:
    data = fopen.read().split('\n')
    
results, result = [], []
for i in data:
    if len(i) and i[-1] != '.':
        i = i + '.'
    if not len(i) and len(result):
        results.append(result)
        result = []
    else:
        if len(i):
            result.append(i)
        
if len(result):
    results.append(result)
    
len(results)

363578

In [7]:
from tqdm import tqdm

def loop(strings):
    results = []
    for i in tqdm(range(len(strings))):
        try:
            slided = slide(strings[i])
            slided = [s for s in slided if len(s) > 1]
            for s in slided:
                s = ' '.join(s)
                p = package(s)
                results.extend(p)
        except:
            pass
    return results

In [8]:
import cleaning

results1 = cleaning.multiprocessing(random.sample(results, 70000), loop)

 94%|█████████▍| 4124/4375 [00:05<00:00, 792.72it/s]
100%|██████████| 4375/4375 [00:05<00:00, 790.97it/s]
100%|██████████| 4375/4375 [00:05<00:00, 760.16it/s]
100%|██████████| 4375/4375 [00:05<00:00, 756.20it/s]
 98%|█████████▊| 4295/4375 [00:05<00:00, 853.56it/s]
 99%|█████████▉| 4334/4375 [00:05<00:00, 906.80it/s]
 95%|█████████▍| 4143/4375 [00:05<00:00, 783.28it/s]
100%|██████████| 4375/4375 [00:05<00:00, 762.32it/s]
100%|██████████| 4375/4375 [00:05<00:00, 770.92it/s]
100%|██████████| 4375/4375 [00:06<00:00, 724.72it/s]
100%|██████████| 4375/4375 [00:06<00:00, 721.45it/s]
100%|██████████| 4375/4375 [00:05<00:00, 746.25it/s]
100%|██████████| 4375/4375 [00:06<00:00, 706.22it/s]
100%|██████████| 4375/4375 [00:05<00:00, 749.44it/s]
100%|██████████| 4375/4375 [00:05<00:00, 732.38it/s]
100%|██████████| 4375/4375 [00:05<00:00, 735.83it/s]


In [9]:
results1[:10], len(results1)

([("Saint-Broing-Les-Moines . Saint-Broing-Les-Moines Ialah Komun Di Jabatan Cote-D'Or Di Timur Perancis .",
   "Saint-Broing-les-Moines. Saint-Broing-les-Moines ialah komun di jabatan Cote-d'Or di timur Perancis."),
  ("SAINT-BROING-LES-MOINES. SAINT-BROING-LES-MOINESIALAHKOMUN,DI JABATAN&COTE-D'OR DITIMUR' PERANCIS.",
   "Saint-Broing-les-Moines. Saint-Broing-les-Moines ialah komun di jabatan Cote-d'Or di timur Perancis."),
  ('saint-broing-les-moines saint-broing-les-moines ialah komun di jabatan cote-dor di timur perancis',
   "Saint-Broing-les-Moines. Saint-Broing-les-Moines ialah komun di jabatan Cote-d'Or di timur Perancis."),
  ("Saint-Broing-Les-Moines.saint-broing-les-moines Ialahkomundi Jabatan cote-d'orDi TimurPerancis.",
   "Saint-Broing-les-Moines. Saint-Broing-les-Moines ialah komun di jabatan Cote-d'Or di timur Perancis."),
  ('SAINT-BROING-LES-MOINES SAINT-BROING-LES-MOINES IALAH KOMUN DI JABATAN COTE-DOR DI TIMUR PERANCIS.',
   "Saint-Broing-les-Moines. Saint-Broing-l

In [10]:
def cleaning_row(string):
    string = string.replace('\n', ' ').replace('\t', ' ')
    string = re.sub(r'[ ]+', ' ', string).strip()
    return string

In [11]:
import tensorflow as tf

with tf.compat.v1.io.gfile.GFile('true-case-multisentences-wiki.tsv', "w") as outfile:
    for i in tqdm(range(len(results1))):
        l = cleaning_row(results1[i][0])
        r = cleaning_row(results1[i][1])
        outfile.write("%s\t%s\n" % (l, r))

100%|██████████| 2584138/2584138 [01:13<00:00, 35009.00it/s]


In [12]:
with open(files[1]) as fopen:
    data = fopen.read().split('\n')
    
len(data)

3656919

In [13]:
results, result = [], []
for i in data:
    if len(i) and i[-1] != '.':
        i = i + '.'
    if not len(i) and len(result):
        results.append(result)
        result = []
    else:
        if len(i):
            result.append(i)
        
if len(result):
    results.append(result)

In [14]:
results = random.sample(results, 30000)
results1 = cleaning.multiprocessing(results, loop)

100%|██████████| 1875/1875 [00:04<00:00, 403.89it/s]
100%|██████████| 1875/1875 [00:04<00:00, 408.04it/s]
100%|██████████| 1875/1875 [00:04<00:00, 385.80it/s]
100%|██████████| 1875/1875 [00:05<00:00, 370.15it/s]
100%|██████████| 1875/1875 [00:04<00:00, 393.01it/s]
100%|██████████| 1875/1875 [00:04<00:00, 408.61it/s]
100%|██████████| 1875/1875 [00:04<00:00, 408.48it/s]
100%|██████████| 1875/1875 [00:04<00:00, 407.65it/s]
100%|██████████| 1875/1875 [00:04<00:00, 378.81it/s]
100%|██████████| 1875/1875 [00:04<00:00, 394.50it/s]
100%|██████████| 1875/1875 [00:04<00:00, 396.54it/s]
100%|██████████| 1875/1875 [00:04<00:00, 391.33it/s]
100%|██████████| 1875/1875 [00:05<00:00, 369.01it/s]
100%|██████████| 1875/1875 [00:04<00:00, 390.46it/s]
100%|██████████| 1875/1875 [00:05<00:00, 355.89it/s]
100%|██████████| 1875/1875 [00:04<00:00, 376.53it/s]


In [15]:
with tf.compat.v1.io.gfile.GFile('true-case-multisentences-news.tsv', "w") as outfile:
    for i in tqdm(range(len(results1))):
        l = cleaning_row(results1[i][0])
        r = cleaning_row(results1[i][1])
        outfile.write("%s\t%s\n" % (l, r))

100%|██████████| 1798495/1798495 [00:58<00:00, 30947.92it/s]


In [16]:
def generate_short(string):
    splitted = string.split()
    random_length = random.randint(2, min(len(splitted), 20))
    end = random.randint(0 + random_length, len(splitted))
    return ' '.join(splitted[end - random_length: end])

In [17]:
with open(files[0]) as fopen:
    data = list(filter(None, fopen.read().split('\n')))
    
data = [i for i in data if len(i) >= 2]

len(data)


2037249

In [18]:
package(generate_short(data[0]))

[('"Klaas-Jan" Huntelaar (Lahir 12 Ogos 1983) Merupakan Pemain Bola Sepak Belanda Yang Bermain Di Posisi',
  '"Klaas-Jan" Huntelaar (lahir 12 Ogos 1983) merupakan pemain bola sepak Belanda yang bermain di posisi'),
 ('KLAAS JANHUNTELAARLAHIR12 OGOS1983 MERUPAKANPEMAINBOLA SEPAK BELANDAYANGBERMAINDIPOSISI',
  '"Klaas-Jan" Huntelaar (lahir 12 Ogos 1983) merupakan pemain bola sepak Belanda yang bermain di posisi'),
 ('"Klaas-Jan" HUNTELAAR (LAHIR 12 OGOS 1983) MERUPAKAN Pemain BOLA SEPAK BELANDA YANG BERMAIN Di POSISI',
  '"Klaas-Jan" Huntelaar (lahir 12 Ogos 1983) merupakan pemain bola sepak Belanda yang bermain di posisi'),
 ('Klaas Jan Huntelaar Lahir 12 Ogos 1983 Merupakan Pemain Bola Sepak Belanda Yang Bermain Di Posisi',
  '"Klaas-Jan" Huntelaar (lahir 12 Ogos 1983) merupakan pemain bola sepak Belanda yang bermain di posisi'),
 ('"klaas-jan" huntelaar (lahir 12 ogos 1983) merupakan pemain bola sepak belanda yang bermain di posisi',
  '"Klaas-Jan" Huntelaar (lahir 12 Ogos 1983) merup

In [19]:
data = random.sample(data, 200000)

In [20]:
def loop(strings):
    results = []
    for i in tqdm(range(len(strings))):
        try:
            p = package(generate_short(strings[i]))
            results.extend(p)
        except:
            pass
    return results

In [21]:
results1 = cleaning.multiprocessing(data, loop)

100%|██████████| 12500/12500 [00:09<00:00, 1296.80it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1268.32it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1271.07it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1278.68it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1261.42it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1268.90it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1263.46it/s]
100%|██████████| 12500/12500 [00:09<00:00, 1258.26it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1248.70it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1224.54it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1222.27it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1221.50it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1205.80it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1206.38it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1197.89it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1180.26it/s]


In [22]:
results1[:10], len(results1)

([('Mengadakan[ Pemilihan', 'mengadakan Pemilihan'),
  ('MENGADAKANPEmiLihAn', 'mengadakan Pemilihan'),
  ('MENGADAKANPEmIlihan', 'mengadakan Pemilihan'),
  ('mengadakan Pemilihan', 'mengadakan Pemilihan'),
  ('MENGADAKAN PEMILIHAN', 'mengadakan Pemilihan'),
  ('MengadakaN PemilIhaN', 'mengadakan Pemilihan'),
  ('mengadakan% Pemilihan', 'mengadakan Pemilihan'),
  ('MeNGADaKAn PeMiLihaN', 'mengadakan Pemilihan'),
  ('mengADakaN peMIlihaN', 'mengadakan Pemilihan'),
  ('MeNgaDAkAN pemIlIhAn', 'mengadakan Pemilihan')],
 8083198)

In [23]:
with tf.compat.v1.io.gfile.GFile('true-case-short-wiki.tsv', "w") as outfile:
    for i in tqdm(range(len(results1))):
        l = cleaning_row(results1[i][0])
        r = cleaning_row(results1[i][1])
        outfile.write("%s\t%s\n" % (l, r))

100%|██████████| 8083198/8083198 [01:55<00:00, 69714.75it/s]


In [24]:
with open(files[1]) as fopen:
    data = list(filter(None, fopen.read().split('\n')))
    
data = random.sample(data, 200000)

In [25]:
results1 = cleaning.multiprocessing(data, loop)

100%|██████████| 12500/12500 [00:10<00:00, 1225.42it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1233.36it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1227.01it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1219.79it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1219.64it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1204.49it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1214.88it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1194.62it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1188.04it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1186.63it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1188.94it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1186.01it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1169.45it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1169.36it/s]
100%|██████████| 12500/12500 [00:10<00:00, 1161.64it/s]
 99%|█████████▉| 12427/12500 [00:11<00:00, 1301.21it/s]
100%|██████████| 12500/12500 [00:11<00:00, 1117.72it/s]

In [26]:
len(results1)

8075957

In [27]:
with tf.compat.v1.io.gfile.GFile('true-case-short-news.tsv', "w") as outfile:
    for i in tqdm(range(len(results1))):
        l = cleaning_row(results1[i][0])
        r = cleaning_row(results1[i][1])
        outfile.write("%s\t%s\n" % (l, r))

100%|██████████| 8075957/8075957 [01:59<00:00, 67302.78it/s]


In [28]:
with open(files[0]) as fopen:
    data = list(filter(None, fopen.read().split('\n')))
    
data = [i for i in data if len(i) >= 2]

In [29]:
def loop(strings):
    results = []
    for i in tqdm(range(len(strings))):
        p = package(strings[i])
        results.extend(p)
    return results

In [30]:
data = random.sample(data, 70000)
results1 = cleaning.multiprocessing(data, loop)

100%|██████████| 4375/4375 [00:06<00:00, 701.37it/s]
100%|██████████| 4375/4375 [00:06<00:00, 703.23it/s]
100%|██████████| 4375/4375 [00:06<00:00, 685.55it/s]
100%|██████████| 4375/4375 [00:06<00:00, 697.34it/s]
100%|██████████| 4375/4375 [00:06<00:00, 657.95it/s]
100%|██████████| 4375/4375 [00:06<00:00, 656.78it/s]
100%|██████████| 4375/4375 [00:06<00:00, 670.20it/s]
100%|██████████| 4375/4375 [00:06<00:00, 658.90it/s]
100%|██████████| 4375/4375 [00:06<00:00, 673.16it/s]
100%|██████████| 4375/4375 [00:06<00:00, 648.82it/s]
100%|██████████| 4375/4375 [00:06<00:00, 642.41it/s]
100%|██████████| 4375/4375 [00:06<00:00, 655.86it/s]
100%|██████████| 4375/4375 [00:06<00:00, 628.27it/s]
100%|██████████| 4375/4375 [00:06<00:00, 643.76it/s]
100%|██████████| 4375/4375 [00:06<00:00, 641.13it/s]
100%|██████████| 4375/4375 [00:07<00:00, 614.44it/s]


In [31]:
with tf.compat.v1.io.gfile.GFile('true-case-wiki.tsv', "w") as outfile:
    for i in tqdm(range(len(results1))):
        l = cleaning_row(results1[i][0])
        r = cleaning_row(results1[i][1])
        outfile.write("%s\t%s\n" % (l, r))

100%|██████████| 3752297/3752297 [01:18<00:00, 47658.98it/s]


In [32]:
with open(files[1]) as fopen:
    data = list(filter(None, fopen.read().split('\n')))
    
data = random.sample(data, 70000)
results1 = cleaning.multiprocessing(data, loop)

100%|██████████| 4375/4375 [00:07<00:00, 602.73it/s]
100%|██████████| 4375/4375 [00:07<00:00, 588.61it/s]
100%|██████████| 4375/4375 [00:07<00:00, 585.87it/s]
100%|██████████| 4375/4375 [00:07<00:00, 583.82it/s]
100%|██████████| 4375/4375 [00:07<00:00, 578.28it/s]
100%|██████████| 4375/4375 [00:07<00:00, 566.07it/s]
100%|██████████| 4375/4375 [00:07<00:00, 560.36it/s]
100%|██████████| 4375/4375 [00:07<00:00, 561.83it/s]
100%|██████████| 4375/4375 [00:07<00:00, 549.89it/s]
100%|██████████| 4375/4375 [00:08<00:00, 544.98it/s]
100%|██████████| 4375/4375 [00:08<00:00, 544.07it/s]
100%|██████████| 4375/4375 [00:08<00:00, 544.38it/s]
100%|██████████| 4375/4375 [00:08<00:00, 544.64it/s]
100%|██████████| 4375/4375 [00:08<00:00, 538.31it/s]
100%|██████████| 4375/4375 [00:08<00:00, 534.94it/s]
100%|██████████| 4375/4375 [00:08<00:00, 537.18it/s]


In [33]:
with tf.compat.v1.io.gfile.GFile('true-case-news.tsv', "w") as outfile:
    for i in tqdm(range(len(results1))):
        l = cleaning_row(results1[i][0])
        r = cleaning_row(results1[i][1])
        outfile.write("%s\t%s\n" % (l, r))

100%|██████████| 3828676/3828676 [01:31<00:00, 41663.44it/s]


In [34]:
import os
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'mesolitica-tpu.json'

In [35]:
from google.cloud import storage

client = storage.Client()
bucket = client.bucket('mesolitica-tpu-general')

In [36]:
from glob import glob

files = glob('true-case*.tsv')
files

['true-case-multisentences-wiki.tsv',
 'true-case-news.tsv',
 'true-case-short-wiki.tsv',
 'true-case-wiki.tsv',
 'true-case-short-news.tsv',
 'true-case-multisentences-news.tsv']

In [37]:
for file in files:
    print(file)
    blob = bucket.blob(f't5-data-v2/{file}')
    blob.upload_from_filename(file)

true-case-multisentences-wiki.tsv
true-case-news.tsv
true-case-short-wiki.tsv
true-case-wiki.tsv
true-case-short-news.tsv
true-case-multisentences-news.tsv


In [38]:
b2_application_key_id = os.environ['b2_application_key_id']
b2_application_key = os.environ['b2_application_key']

In [39]:
from b2sdk.v1 import *
info = InMemoryAccountInfo()
b2_api = B2Api(info)
application_key_id = b2_application_key_id
application_key = b2_application_key
b2_api.authorize_account("production", application_key_id, application_key)
file_info = {'how': 'good-file'}
b2_bucket = b2_api.get_bucket_by_name('malay-dataset')

In [40]:
for file in files:
    print(file)
    b2_bucket.upload_local_file(
    local_file=file,
    file_name=f'true-case/{file}',
    file_infos=file_info,
)

true-case-multisentences-wiki.tsv
true-case-news.tsv
true-case-short-wiki.tsv
true-case-wiki.tsv
true-case-short-news.tsv
true-case-multisentences-news.tsv


In [41]:
for file in files:
    os.remove(file)

In [42]:
import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data import preprocessors as prep
import functools
import t5
import gin
import sentencepiece as spm
from glob import glob
import os

gin.parse_config_file('pretrained_models_base_operative_config.gin')
vocab = 'sp10m.cased.ms-en.model'
sp = spm.SentencePieceProcessor()
sp.Load(vocab)

True

In [43]:
def segmentation_dataset(split, shuffle_files = False):
    del shuffle_files
    ds = tf.compat.v1.data.TextLineDataset(
        [
            'gs://mesolitica-tpu-general/t5-data-v2/true-case-short-wiki.tsv'
        ]
    )

    ds = ds.map(
        functools.partial(
            tf.compat.v1.io.decode_csv,
            record_defaults = ['', ''],
            field_delim = '\t',
            use_quote_delim = False,
        ),
        num_parallel_calls = tf.compat.v1.data.experimental.AUTOTUNE,
    )
    ds = ds.map(lambda *ex: dict(zip(['question', 'answer'], ex)))
    return ds

def segmentation_preprocessor(ds):
    def to_inputs_and_targets(ex):
        return {
            'inputs': tf.compat.v1.strings.join(['kes benar: ', ex['question']]),
            'targets': ex['answer'],
        }

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls = tf.compat.v1.data.experimental.AUTOTUNE,
    )

In [44]:
t5.data.TaskRegistry.remove('segmentation_dataset')
t5.data.TaskRegistry.add(
    'segmentation_dataset',
    dataset_fn = segmentation_dataset,
    splits = ['train'],
    text_preprocessor = [segmentation_preprocessor],
    sentencepiece_model_path = vocab,
    metric_fns = [t5.evaluation.metrics.accuracy],
)

In [45]:
nq_task = t5.data.TaskRegistry.get("segmentation_dataset")
ds = nq_task.get_dataset(split='knowledge-graph.tsv', sequence_length={"inputs": 256, "targets": 256})
r = tfds.as_numpy(ds)

In [46]:
next(r)

{'inputs_plaintext': b'kes benar: Baru Padang Serai merupakan sebuah kampung YANG terletak DALAM Mukim NAGA LILIT, DAERAH',
 'inputs': array([  744,   679,    31,  2128,  7353, 26241,   157,   136,  1350,
         4897,   282, 11031,  9937,   517, 21337,   876,  4510,  5094,
           14,    13,  4541,  4056,  5168,     1]),
 'targets_plaintext': b'Baru Padang Serai merupakan sebuah kampung yang terletak dalam Mukim Naga Lilit, Daerah',
 'targets': array([ 2128,  7353, 26241,   157,   136,  1350,    17,   282,    36,
         9937, 17366,  2015,  3813,    14,  2684,     1])}