In [16]:
import torch
from torch import nn
import tiktoken

In [19]:
# Get data from the dataset

with open('../data/russian-novels/WAP.txt', 'r', encoding='utf-8-sig') as file:
    text = file.read()

enc = tiktoken.get_encoding("cl100k_base")
enc.encode(text)

words = sorted(list(set(text)))
# print(words)
# print(words)
# ctoi = {c: i+1 for i, c in enumerate(words)}
# ctoi['_'] = 0
# itoc = {i: c for c, i in ctoi.items()}

encode = enc.encode
decode = enc.decode

encode(text[:100])

# print(words)

[46558,
 3651,
 22557,
 5693,
 1432,
 1383,
 37848,
 59602,
 267,
 2303,
 17146,
 337,
 88341,
 1432,
 262,
 36962,
 271,
 262,
 48198,
 25002,
 25,
 220,
 5245,
 20,
 271,
 262,
 96424,
 358,
 271,
 262,
 6969,
 2599,
 2505]

In [4]:
# Build dataset
block_size = 8

X, Y = [], []
X_test, Y_test = [], []

test = text[int(len(text) * 0.8):]
text = text[:int(len(text) * 0.8)]

for i in range(0, len(text) - block_size):
    X.append([ctoi[c] for c in text[i:i+block_size]])
    Y.append(ctoi[text[i+block_size]])

for i in range(0, len(test) - block_size):
    X_test.append([ctoi[c] for c in test[i:i+block_size]])
    Y_test.append(ctoi[test[i+block_size]])

X = torch.tensor(X)
Y = torch.tensor(Y)

X_test = torch.tensor(X_test)
Y_test = torch.tensor(Y_test)

X.shape, Y.shape

(torch.Size([2566610, 8]), torch.Size([2566610]))

In [5]:
def get_batch(batch_size=32):
    idx = torch.randint(0, X.size(0), (batch_size,))
    return X[idx], Y[idx]

def get_test_batch(batch_size=32):
    idx = torch.randint(0, X_test.size(0), (batch_size,))
    return X_test[idx], Y_test[idx]

batch = get_batch(32)

In [6]:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.C = torch.randn(len(chars)+1, 2)
        self.W1 = torch.randn(block_size * 2, 200) * 0.2
        self.b1 = torch.randn(200) * 0.01
        self.W2 = torch.randn(200, len(chars)+1) * 0.01
        self.b2 = torch.randn(len(chars)+1) * 0
        self.parameters = [self.C, self.W1, self.b1, self.W2, self.b2] 
       
        
    def forward(self, x, y):
        
        emb = self.C[x].view(-1, 2*block_size)
        h = torch.relu(emb @ self.W1 + self.b1)
        logits = h @ self.W2 + self.b2
        preds = torch.softmax(logits, dim=1)
        loss = torch.nn.functional.cross_entropy(logits, y)
        
        return preds, loss

model = Model()

for p in model.parameters:
    p.requires_grad = True

In [7]:
x_train, y_train = batch

optimizer = torch.optim.SGD(model.parameters, lr=0.001, momentum=0.9)

for i in range(99999):
    
    optimizer.zero_grad()
    x, y = get_batch(32)
    
    outputs, loss = model(x, y)
    
    loss.backward()
    
    optimizer.step()
    
    # if i % 1000 == 0:
print(f'Loss: {loss.item()}')
        

Loss: 2.501145362854004


In [9]:
# Sample the model

result = encode("War and ")
print("War and ", end='')
while result[-1] != 0:
    x = result[-8:]
    x = x.view(1, -1)
    y = torch.tensor([result[-1]])
    preds, loss = model.forward(x, y)
    next_char = torch.multinomial(preds[0], 1)
    result = torch.cat([result, next_char])
    print(decode(next_char) , end='')

War and them
Pé they
ghey the osarper, dransyne breelh Pe H co leAteon aluritupfes. beyl onheestwasd fepe als ane,
itifhsgingst he Pas gas dal wingeng Alpat
w,e then
him loith! angand yep tawd oufite. paid him. and tott limsáou andards, 
“Wyichey agosllit’.”,” Tied ap neys cei loupesl

“Oep in letterow, le pastse paskhe?” Sv

 DHeouan, femaller of the wownou hisfeon and dhes Poellvnoun ind that therrent putúeoud.” deins did feribe of tunaley tat fouchter
mteated. gut lBaw
Iritte wob
and f2ale—
H clkith ind lelicsid his to coIr eebaise bet of anith wasród hiol Sar de, sacderard the
dourdy;d?” Din lool ob fean Icotelinco thar hud proy, dhoulantean, rhe clentesing denemesared
cobe bthe (roit bentav 
S her,”
beoothan. Bad, hee Mrinces Kothing dit,z Wurárr The tritcostr beadone Pny that themiding 
The 
nise thec Coyper of avunter,” Inad bl mund” Fedd an, an ouge betáow mitktin?” he sace
the
duddt Nounon’en
wreelhed to mef
vnl;—Iengter to woin he. beilder bat, jemh a mork
bOlHs, Snasd, the t

KeyboardInterrupt: 

In [366]:
# Test the model

x, y = get_test_batch(32)
outputs, loss = model(x, y)
print(loss.item())


2.2150869369506836


In [284]:
batch[0].shape

torch.Size([5, 3])