In [1]:
import numpy as np
import torch
from tqdm import tqdm
import torchtext
import numpy as np
import torch.optim as optim
import torch.utils.data as data
from torchtext.data import get_tokenizer

from tokenizers import Tokenizer # RWKV https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py#L11
# Build the tokenizer from the local "20B_tokenizer.json" definition file.
# The file must be present in the working directory; the approach follows the
# RWKV reference script linked below.
tokenizer = Tokenizer.from_file("20B_tokenizer.json") # RWKV https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py#L11

In [2]:
filename = "wonderland.txt"
# Read the corpus through a context manager so the file handle is closed
# as soon as the text has been loaded (the bare open(...).read() leaked it).
with open(filename, 'r', encoding='utf-8') as text_file:
    raw_text = text_file.read()
raw_text = raw_text.lower()

# Tokenize the whole text exactly once and reuse the result; encoding the
# full book twice doubled the (slow) tokenization work for no benefit.
temp = tokenizer.encode(raw_text)
# From here on `raw_text` holds the list of integer token ids, not a string.
raw_text = temp.ids
# Output-layer size: largest token id that occurs in this text, plus one.
# (Original note: 50276 + 1 for 20B_tokenizer.)
n_vocab = max(raw_text) + 1

print(str(temp.ids))

[41497, 187, 50274, 783, 2199, 305, 10284, 4978, 299, 3305, 273, 355, 547, 434, 27420, 275, 4282, 1373, 187, 50274, 187, 2520, 299, 3305, 310, 323, 253, 897, 273, 3780, 9825, 275, 253, 25536, 3054, 285, 209, 187, 2252, 643, 4243, 273, 253, 1533, 387, 642, 2105, 285, 342, 2761, 642, 13133, 209, 187, 5371, 23330, 15, 368, 778, 3491, 352, 13, 1918, 352, 1977, 390, 294, 14, 2327, 352, 762, 253, 2426, 209, 187, 1171, 253, 2199, 305, 10284, 4978, 7981, 2908, 342, 436, 299, 3305, 390, 3909, 209, 187, 255, 8280, 15, 72, 10284, 4978, 15, 2061, 15, 604, 368, 403, 417, 4441, 275, 253, 25536, 3054, 13, 209, 187, 5658, 588, 452, 281, 2451, 253, 5323, 273, 253, 2586, 835, 368, 403, 4441, 209, 187, 9131, 970, 436, 299, 3305, 15, 5429, 50274, 187, 50270, 5564, 27, 355, 547, 434, 27420, 275, 4282, 1373, 187, 50270, 187, 50270, 7582, 27, 458, 88, 261, 1113, 1811, 535, 50270, 187, 50270, 16690, 3522, 27, 480, 2517, 3435, 13, 4695, 544, 20328, 1852, 883, 62, 2252, 4102, 9300, 27, 14172, 1884, 13, 43425, 1

In [3]:
#chars = sorted(list(set(raw_text)))
#char_to_int = dict((c, i) for i, c in enumerate(chars))

In [4]:
# Length of the token-id sequence; `raw_text` was rebound to the list of ids
# when the text was tokenized.
n_tokens = len(raw_text)
# Leftover from a character-level variant (see the commented-out `chars` code):
#n_vocab = len(chars)
print("Total Tokens: ", n_tokens)

Total Tokens:  45138


In [5]:
seq_length = 100
# Every start offset that still leaves room for a full input window plus the
# single token that follows it (the prediction target).
window_starts = range(n_tokens - seq_length)
# Inputs: `seq_length` consecutive token ids starting at each offset.
dataX = [raw_text[offset:offset + seq_length] for offset in window_starts]
# Targets: the token id immediately after each input window.
dataY = [raw_text[offset + seq_length] for offset in window_starts]
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

Total Patterns:  45038


In [6]:
# Inputs become a float tensor of shape (n_patterns, seq_length, 1) whose
# values are token ids scaled by the vocabulary size.
X = (
    torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1)
    / float(n_vocab)
)
# Targets stay as raw integer token ids (used as class indices by the loss).
y = torch.tensor(dataY)

In [7]:
# NOTE(review): `lookback` is not referenced anywhere else in the visible
# notebook — confirm whether it is still needed.
lookback = 1
# Sanity check of the tensor shapes built in the previous cell.
print(X.shape, y.shape)

torch.Size([45038, 100, 1]) torch.Size([45038])


In [8]:
import torch.nn as nn

class LSTMModel(nn.Module):
    """LSTM over a sequence of scalar inputs, followed by a linear layer.

    The input is a batch-first tensor of shape (batch, seq_len, 1). Only the
    LSTM output at the last time step is fed to the linear layer, so the
    result has shape (batch, vocab_size) and holds one logit per token id.

    Args:
        vocab_size: Number of output logits. When omitted, the notebook-level
            global ``n_vocab`` is used, which preserves the original
            ``LSTMModel()`` call.
        hidden_size: Width of the LSTM hidden state (default 100, as before).
        num_layers: Number of stacked LSTM layers (default 1, as before).
    """
    def __init__(self, vocab_size=None, hidden_size=100, num_layers=1):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: fall back to the global computed
            # when the text was tokenized.
            vocab_size = n_vocab
        # Attribute names `lstm` and `linear` are kept so that previously
        # saved state dicts still load.
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
    def forward(self, x):
        """Return logits of shape (batch, vocab_size) for input (batch, seq_len, 1)."""
        outputs, _ = self.lstm(x)
        # Keep only the last time step of every sequence in the batch.
        last_step = outputs[:, -1, :]
        return self.linear(last_step)

In [12]:
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = LSTMModel().to(device)
optimizer = optim.Adam(model.parameters())
# Cross-entropy over the logits; targets are integer token ids.
loss_fn = nn.CrossEntropyLoss()
# The dataset tensors stay on the CPU; each batch is moved to `device`
# inside the training loop.
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=128)

# Best checkpoint seen so far and its loss; both are updated by the training loop.
best_model = None
best_loss = np.inf

In [15]:
# Running epoch counter, incremented once per epoch by the training loop.
# Presumably kept in its own cell so re-running the training cell continues
# the count instead of restarting it — confirm intent.
current_epoch = 0

In [16]:
# Number of epochs performed by one run of the training cell.
n_epochs = 200

In [None]:
for e in range(n_epochs):
    current_epoch += 1
    model.float()
    model.train()
    for X_batch, y_batch in tqdm(loader):
        # The model already lives on `device`, so its output does too.
        y_pred = model(X_batch.float().to(device))
        loss = loss_fn(y_pred, y_batch.long().to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluation pass over the same (training) data: sum of per-batch losses.
    model.eval()
    loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.float().to(device))
            # .item() keeps the running total a plain Python float instead of
            # a device tensor.
            loss += loss_fn(y_pred, y_batch.long().to(device)).item()
        if loss < best_loss:
            best_loss = loss
            # state_dict() returns references to the live parameters, which
            # later optimizer steps overwrite in place. Clone every tensor so
            # the snapshot really holds the weights of the best epoch.
            best_model = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        print("Epoch %d: Cross-entropy: %.4f" % (current_epoch, loss))

torch.save(best_model, "single-rwkv-tokenizer.pth")

In [18]:
seq_length = 100
# Pick a random window of `seq_length` items from `raw_text` (the token-id
# list at this point). No random seed is set, so the window differs per run.
start = np.random.randint(0, len(raw_text)-seq_length)
# NOTE(review): the generation cell assigns `prompt` a literal string before
# using it, so this token window appears to be unused — confirm.
prompt = raw_text[start:start+seq_length]

In [19]:
import numpy as np
import torch
import torch.nn as nn

# Load the checkpoint onto the CPU regardless of the device it was saved
# from; the model below is also created on the CPU, and this keeps the cell
# usable on machines without CUDA.
best_model = torch.load("single-rwkv-tokenizer.pth", map_location="cpu")

model = LSTMModel()
model.load_state_dict(best_model)

# The text file is deliberately NOT re-read here: the previous version
# overwrote `raw_text` (the token-id list) with a plain string through an
# unclosed file handle, only to compute a `start` value that was never used.
prompt = 'an with almost no restrictions '
print("xd " + str(prompt) + " dx ")
# Sliding window of token ids fed to the model; starts as the encoded prompt.
pattern = tokenizer.encode(prompt).ids

model.eval()
print('Prompt: "%s"' % prompt)

# Token ids generated but not yet printed (see the '\ufffd' check below).
tokens_to_decode = []

with torch.no_grad():
    for i in range(1000):
        # Scale the ids by the vocabulary size, as was done for the training
        # inputs, and shape them as (batch=1, seq_len, 1).
        x = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        x = torch.tensor(x, dtype=torch.float32)
        # Logits over the vocabulary for the next token.
        prediction = model(x)
        # Greedy choice, converted to a plain int so that `pattern` and
        # `tokens_to_decode` stay lists of ints rather than a mix of ints
        # and 0-d tensors.
        index = int(prediction.argmax())
        tokens_to_decode.append(index)
        result = tokenizer.decode(tokens_to_decode)
        # '\ufffd' marks an incomplete multi-byte character: hold the tokens
        # back until the decoded text is clean, then print and reset.
        if '\ufffd' not in result:
            print(result, end="", flush=True)
            tokens_to_decode = []
        # Slide the window: append the new token and drop the oldest one.
        pattern.append(index)
        pattern = pattern[1:]
print()
print("Done.")

xd an with almost no restrictions  dx 
Prompt: "an with almost no restrictions "
 king,
the melancholyor “’ washead
ugarouse ed al
thehess, “are washead
 eyelouse presents,
 king,
 turning of tea, found hedgeytone.pig“e mockouse heartsling
chapter sortii twice near pig cried. pig turtle.
 arm,
let bowed, “’ washead
 eyelouse presents,
 king,
 turning of tea, found hedgeytone.pig“e mockouse heartsling
chapter sortii twice near pig cried. pig turtle.
 arm,
let bowed, “’ washead
 eyelouse presents,
 king,
 turning of tea, found hedgeytone.pig“e mockouse heartsling
chapter sortii twice near pig cried. pig turtle.
 arm,
let bowed, “’ washead
 eyelouse presents,
 king,
 turning of tea, found hedgeytone.pig“e mockouse heartsling
chapter sortii twice near pig cried. pig turtle.
 arm,
let bowed, “’ washead
 eyelouse presents,
 king,
 turning of tea, found hedgeytone.pig“e mockouse heartsling
chapter sortii twice near pig cried. pig turtle.
 arm,
let bowed, “’ washead
 eyelouse presents,
 king,


In [None]:
# Report the installed PyTorch build.
print(f"Torch version: {torch.__version__}")

# Report whether a CUDA device is usable from this environment.
print(f"Is CUDA enabled? {torch.cuda.is_available()}")