In [1]:
import torch
# Show the installed PyTorch build string.
print(torch.__version__)
# Last expression in the cell, so the notebook displays whether CUDA is usable.
torch.cuda.is_available()

2.2.0.dev20230916+cu121


True

In [2]:
import sentencepiece as spm
# Record the SentencePiece version this notebook was run with.
print(spm.__version__)

0.1.99


In [3]:
def train_model(fname, prefix, vocab_size=16000):
    """Train a SentencePiece model on a plain-text corpus.

    Args:
        fname: Path of the text file handed to the trainer as ``input``.
        prefix: Value handed to the trainer as ``model_prefix``; the notebook
            later loads the tokenizer from ``prefix + ".model"``.
        vocab_size: Target vocabulary size. Defaults to 16000, the value that
            used to be hard-coded, so existing two-argument calls are unchanged.
    """
    spm.SentencePieceTrainer.train(
        input=fname,
        model_prefix=prefix,
        vocab_size=vocab_size,
    )
    
# Corpus text file and the prefix handed to the trainer; both names are
# reused by later cells (the tokenizer is loaded from prefix + ".model").
corpus = "bird_shooter.txt"
prefix = "bird_shooter"
# NOTE: training runs every time this cell is executed.
train_model(corpus, prefix)

In [5]:
def load_tokenizer(model_file):
    """Load a SentencePiece model and report whether loading succeeded.

    Args:
        model_file: Path of the serialized tokenizer model.

    Returns:
        ``(True, processor)`` when the model was loaded, ``(False, None)``
        otherwise, so callers can branch on the flag.
    """
    sp = spm.SentencePieceProcessor()
    try:
        loaded = sp.load(model_file=model_file)
    except (OSError, RuntimeError):
        # NOTE(review): the SentencePiece Python wrapper is understood to raise
        # (OSError for a missing file, RuntimeError for other load errors)
        # instead of returning False. Map that onto the failure flag so the
        # callers' "if not flag" handling is actually reachable.
        return False, None
    if not loaded:
        return False, None
    return True, sp

def load_file_into_splits(text_file, split_ratio, encoding="utf-8"):
    """Read a text file and cut it into a leading and a trailing part.

    The cut is made at a character offset, not at a line or sentence boundary.

    Args:
        text_file: Path of the text file to read.
        split_ratio: Fraction of the characters that go into the first part;
            must lie in ``[0, 1]``.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        ``(head, tail)`` where ``head`` holds the first
        ``int(len(text) * split_ratio)`` characters and ``tail`` the rest.

    Raises:
        ValueError: If ``split_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= split_ratio <= 1.0:
        # A negative ratio would index from the end of the text and a ratio
        # above 1 would silently leave the second part empty.
        raise ValueError(f"split_ratio must be within [0, 1], got {split_ratio}")
    with open(text_file, "r", encoding=encoding) as file:
        data = file.read()
    split_idx = int(len(data) * split_ratio)
    return data[:split_idx], data[split_idx:]

import numpy as np
def encode_and_save(sp, content, prefix):
    """Tokenize ``content`` and dump the ids to ``<prefix>.dat``.

    Args:
        sp: Tokenizer object exposing ``encode(text, out_type=int)``.
        content: Text to tokenize.
        prefix: Name of the split; used in the progress message and as the
            stem of the output file.

    The ids are written as a flat int32 array with no header.
    """
    ids = sp.encode(content, out_type=int)
    n_tokens = len(ids)
    print(f"data split of {prefix} has {n_tokens} tokens")

    out_path = "{}.dat".format(prefix)
    id_array = np.array(ids, dtype=np.int32)
    id_array.tofile(out_path)
    
import sys
def gen_dataset(text_file, model_file, split_ratio=0.9):
    """Tokenize a corpus and save it as train/test token files.

    Args:
        text_file: Path of the corpus text file.
        model_file: Path of the tokenizer model to load.
        split_ratio: Fraction of the corpus used for the "train" split; the
            remainder becomes the "test" split. Defaults to 0.9, the value
            that used to be hard-coded.

    The two splits are passed to ``encode_and_save`` under the names
    "train" and "test". If the tokenizer cannot be loaded, a message is
    printed and the function exits via ``sys.exit(1)``.
    """
    flag, sp = load_tokenizer(model_file)
    if not flag:
        print(f"load tokenizer model from: {model_file} failed")
        sys.exit(1)
    train_text, test_text = load_file_into_splits(text_file, split_ratio)
    encode_and_save(sp, train_text, "train")
    encode_and_save(sp, test_text, "test")
    
# Build the token files from the corpus, using the model written by the training cell.
gen_dataset(corpus, prefix+".model")

data split of train has 505009 tokens
data split of test has 58877 tokens


In [7]:
def get_batch(data, batch_size=4, win_len=10):
    """Sample random fixed-length windows and their one-step-shifted targets.

    Args:
        data: 1-D indexable sequence of token ids (the notebook passes a
            ``np.memmap``).
        batch_size: Number of windows to draw.
        win_len: Length of each window. Defaults to 10, the value that used
            to be hard-coded.

    Returns:
        ``(x, y)`` arrays built with ``np.stack``, one row per window; row
        ``k`` of ``y`` is the window starting one position after row ``k``
        of ``x``.

    Raises:
        ValueError: If ``data`` has ``win_len`` elements or fewer, since no
            window plus its shifted target would fit.
    """
    max_start = len(data) - win_len
    if max_start <= 0:
        # Fail with a readable message rather than the range error that
        # torch.randint would raise for a non-positive upper bound.
        raise ValueError(
            f"data has {len(data)} tokens; need more than win_len={win_len}"
        )
    # Plain Python ints keep the slicing independent of tensor-as-index support.
    starts = torch.randint(max_start, (batch_size,)).tolist()
    x = np.stack([data[i:i + win_len] for i in starts])
    y = np.stack([data[i + 1:i + 1 + win_len] for i in starts])
    return x, y

# Path of the trained tokenizer model; read as a global by gen_samples below.
model_file = prefix + ".model"

def gen_samples(fname):
    """Print a few decoded (features, targets) windows from a token file.

    Args:
        fname: Path of a raw int32 token file, opened read-only via
            ``np.memmap``.

    A batch is drawn with ``get_batch`` using its defaults, then the
    tokenizer is loaded from the module-level ``model_file``. If loading
    fails, a message is printed and the function exits via ``sys.exit(1)``.
    """
    tokens = np.memmap(fname, dtype=np.int32, mode='r')
    inputs, labels = get_batch(tokens)

    ok, tokenizer = load_tokenizer(model_file)
    if not ok:
        print(f"load tokenizer model from: {model_file} failed")
        sys.exit(1)

    # Decode each window next to its shifted counterpart for eyeballing.
    for input_row, label_row in zip(inputs, labels):
        input_text = tokenizer.decode(input_row.tolist())
        print("features: ", input_text)
        label_text = tokenizer.decode(label_row.tolist())
        print("targets: ", label_text)

# Spot-check the training split by decoding a few random windows back to text.
gen_samples("train.dat")

features:  他不会不给六王爷的面子。”完颜洪烈道
targets:  不会不给六王爷的面子。”完颜洪烈道:“
features:  女儿家的痴情呆想,这人哪里
targets:  家的痴情呆想,这人哪里是甚么
features:  得更低了。完颜康心中一荡,伸出左臂
targets:  更低了。完颜康心中一荡,伸出左臂去
features:  呢”郭靖道,“我接她到桃花岛上
targets:  ”郭靖道,“我接她到桃花岛上住
