In [1]:
import torch
import torch.nn.functional as F

In [2]:
with open('./names.txt','r',encoding='utf-8') as f:
    words = f.read().splitlines()

In [34]:
len(words)

32033

In [3]:
words[:4]

['emma', 'olivia', 'ava', 'isabella']

In [4]:
list(zip(words[:4],words[1:5]))

[('emma', 'olivia'),
 ('olivia', 'ava'),
 ('ava', 'isabella'),
 ('isabella', 'sophia')]

In [5]:
chars = []
chars.append('.')
chars = chars+sorted(list(set(''.join(words))))

In [6]:
# 定义两个函数用以转换字符到数字和数字到字符
stoi = lambda x: chars.index(x)
itos = lambda x: chars[x]

print(stoi('a'))
print(itos(0))

1
.


In [7]:
for w in words[:3]:
    chs = '.'+ w + '.'
    for c1, c2 in zip(chs, chs[1:]):
        print(c1, c2)

. e
e m
m m
m a
a .
. o
o l
l i
i v
v i
i a
a .
. a
a v
v a
a .


In [8]:
F.one_hot(torch.tensor([0,1, 2, 3]), 4)

tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])

In [9]:
x = []
y = []
for w in words[:3]:
    tx = []
    ty = []
    chs = '.'+ w + '.'
    for c1, c2 in zip(chs, chs[1:]):
        tx.append(c1)
        ty.append(c2)
    x.append(tx)
    y.append(ty)
print(x)
print(y)

[['.', 'e', 'm', 'm', 'a'], ['.', 'o', 'l', 'i', 'v', 'i', 'a'], ['.', 'a', 'v', 'a']]
[['e', 'm', 'm', 'a', '.'], ['o', 'l', 'i', 'v', 'i', 'a', '.'], ['a', 'v', 'a', '.']]


In [10]:
xenc = [[stoi(x) for x in xi] for xi in x]
yenc = [[stoi(y) for y in yi] for yi in y]

In [11]:
g = torch.Generator().manual_seed(2147483647)

In [12]:
X = F.one_hot(torch.tensor(xenc[0]),num_classes=len(chars)).float()
Y = F.one_hot(torch.tensor(yenc[0]),num_classes=len(chars)).float()

In [13]:
w = torch.randn((27,27),generator=g,requires_grad=True)

In [14]:
b = torch.randn((1,27),generator=g)

In [15]:
logits = X @ w

In [16]:
logits

tensor([[ 1.5674e+00, -2.3729e-01, -2.7385e-02, -1.1008e+00,  2.8588e-01,
         -2.9643e-02, -1.5471e+00,  6.0489e-01,  7.9136e-02,  9.0462e-01,
         -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
          9.8488e-01, -1.4837e-01, -1.4795e+00,  4.4830e-01, -7.0730e-02,
          2.4968e+00,  2.4448e+00],
        [ 4.7236e-01,  1.4830e+00,  3.1748e-01,  1.0588e+00,  2.3982e+00,
          4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2197e-01,  5.1007e-01,
          1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01,
          5.5570e-01,  4.7458e-01, -1.3867e+00,  1.6229e+00,  1.7197e-01,
          9.8846e-01,  5.0657e-01,  1.0198e+00, -1.9062e+00, -4.2753e-01,
         -2.1259e+00,  9.6041e-01],
        [ 1.9359e-01,  1.0532e+00,  6.3393e-01,  2.5786e-01,  9.6408e-01,
         -2.4855e-01,  2.4756e-02, -3.0404e-02,  1.5622e+00, -4.4852e-01,
         -1.2345e+00,  1.1220e+00, -6.73

In [17]:
logits.shape

torch.Size([5, 27])

In [18]:
prob = torch.exp(logits) /(torch.sum(torch.exp(logits),dim=1,keepdim=True))

In [19]:
torch.sum(prob,dim=1,keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]], grad_fn=<SumBackward1>)

In [20]:
prob[0]

tensor([0.0607, 0.0100, 0.0123, 0.0042, 0.0168, 0.0123, 0.0027, 0.0232, 0.0137,
        0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2378, 0.0603, 0.0025,
        0.0249, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1537, 0.1459],
       grad_fn=<SelectBackward0>)

In [21]:
Y[0]

tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [22]:
prob[0] * Y[0]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0123, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       grad_fn=<MulBackward0>)

In [23]:
prob[0,yenc[0][0]]

tensor(0.0123, grad_fn=<SelectBackward0>)

In [24]:
r = prob * Y

In [25]:
rs = torch.sum(r, dim=1)

In [26]:
nll =-torch.log(rs)

In [27]:
nll.mean()

tensor(3.7693, grad_fn=<MeanBackward0>)

In [31]:
# all the data
g = torch.Generator().manual_seed(2147483647)
x = []
y = []
for w in words:
    tx = []
    ty = []
    chs = '.'+ w + '.'
    for c1, c2 in zip(chs, chs[1:]):
        tx.append(c1)
        ty.append(c2)
    x.append(tx)
    y.append(ty)
print(len(x))
print(len(y))

32033
32033


In [32]:
xenc = [[stoi(x) for x in xi] for xi in x]
yenc = [[stoi(y) for y in yi] for yi in y]

In [35]:
w = torch.randn((27,27),generator=g,requires_grad=True)

# 这里采用的是将每个单词作为一个训练数据用来计算loss，然后更新w的梯度。

for i in range(len(words)):
    X = F.one_hot(torch.tensor(xenc[i]),num_classes=len(chars)).float()
    Y = F.one_hot(torch.tensor(yenc[i]),num_classes=len(chars)).float()
    logits = X @ w
    prob = torch.exp(logits) /(torch.sum(torch.exp(logits),dim=1,keepdim=True))
    r_prob = prob * Y
    r_prob_sum = torch.sum(r_prob, dim=1)
    nll =-torch.log(r_prob_sum)
    loss = torch.mean(nll)
    print(loss)
    w.grad = None
    loss.backward()
    w.grad.data = w.grad.data - 0.1*w.grad.data

tensor(2.8476, grad_fn=<MeanBackward0>)
tensor(3.6510, grad_fn=<MeanBackward0>)
tensor(3.0100, grad_fn=<MeanBackward0>)
tensor(3.3325, grad_fn=<MeanBackward0>)
tensor(3.9500, grad_fn=<MeanBackward0>)
tensor(4.0501, grad_fn=<MeanBackward0>)
tensor(4.1169, grad_fn=<MeanBackward0>)
tensor(3.6715, grad_fn=<MeanBackward0>)
tensor(4.4472, grad_fn=<MeanBackward0>)
tensor(3.9716, grad_fn=<MeanBackward0>)
tensor(3.7524, grad_fn=<MeanBackward0>)
tensor(3.9505, grad_fn=<MeanBackward0>)
tensor(3.3415, grad_fn=<MeanBackward0>)
tensor(4.0999, grad_fn=<MeanBackward0>)
tensor(3.4673, grad_fn=<MeanBackward0>)
tensor(3.7433, grad_fn=<MeanBackward0>)
tensor(3.9199, grad_fn=<MeanBackward0>)
tensor(4.1724, grad_fn=<MeanBackward0>)
tensor(3.4524, grad_fn=<MeanBackward0>)
tensor(4.0665, grad_fn=<MeanBackward0>)
tensor(3.7990, grad_fn=<MeanBackward0>)
tensor(3.5193, grad_fn=<MeanBackward0>)
tensor(2.9589, grad_fn=<MeanBackward0>)
tensor(4.2547, grad_fn=<MeanBackward0>)
tensor(3.9794, grad_fn=<MeanBackward0>)


In [None]:
# all the data
g = torch.Generator().manual_seed(2147483647)
x = []
y = []
for w in words:
    chs = '.'+ w + '.'
    for c1, c2 in zip(chs, chs[1:]):
        x.append(stoi(c1))
        y.append(stoi(c2))
   
print(x[:5])
print(y[:5])

[0, 5, 13, 13, 1]
[5, 13, 13, 1, 0]


In [None]:
len(x)

228146

In [None]:
X = F.one_hot(torch.tensor(x),num_classes=len(chars)).float()
Y = F.one_hot(torch.tensor(y),num_classes=len(chars)).float()

In [None]:
X.shape

torch.Size([228146, 27])

In [None]:
w = torch.randn((27,27),generator=g,requires_grad=True)

In [None]:
for i in range(100):
    logits = X @ w
    prob = torch.exp(logits) /(torch.sum(torch.exp(logits),dim=1,keepdim=True))
    r_prob = prob * Y
    r_prob_sum = torch.sum(r_prob, dim=1)
    nll =-torch.log(r_prob_sum)
    loss = torch.mean(nll)
    print(loss)
    w.grad = None
    loss.backward()
    w.data = w.data - 50*w.grad.data


tensor(2.4622, grad_fn=<MeanBackward0>)
tensor(2.4621, grad_fn=<MeanBackward0>)
tensor(2.4621, grad_fn=<MeanBackward0>)
tensor(2.4620, grad_fn=<MeanBackward0>)
tensor(2.4620, grad_fn=<MeanBackward0>)
tensor(2.4619, grad_fn=<MeanBackward0>)
tensor(2.4619, grad_fn=<MeanBackward0>)
tensor(2.4618, grad_fn=<MeanBackward0>)
tensor(2.4618, grad_fn=<MeanBackward0>)
tensor(2.4617, grad_fn=<MeanBackward0>)
tensor(2.4617, grad_fn=<MeanBackward0>)
tensor(2.4617, grad_fn=<MeanBackward0>)
tensor(2.4616, grad_fn=<MeanBackward0>)
tensor(2.4616, grad_fn=<MeanBackward0>)
tensor(2.4615, grad_fn=<MeanBackward0>)
tensor(2.4615, grad_fn=<MeanBackward0>)
tensor(2.4615, grad_fn=<MeanBackward0>)
tensor(2.4614, grad_fn=<MeanBackward0>)
tensor(2.4614, grad_fn=<MeanBackward0>)
tensor(2.4613, grad_fn=<MeanBackward0>)
tensor(2.4613, grad_fn=<MeanBackward0>)
tensor(2.4613, grad_fn=<MeanBackward0>)
tensor(2.4612, grad_fn=<MeanBackward0>)
tensor(2.4612, grad_fn=<MeanBackward0>)
tensor(2.4611, grad_fn=<MeanBackward0>)


In [None]:
# sample
x = stoi('.')
x_input = F.one_hot(torch.tensor(x), len(chars)).float()
y_pred_logits = x_input @ w
y_pred_logits.shape

torch.Size([27])

In [None]:
y_pred_prob = torch.exp(y_pred_logits)/torch.sum(torch.exp(y_pred_logits))

In [None]:
y_pred_prob

tensor([0.1960, 0.0164, 0.0160, 0.0139, 0.0308, 0.0204, 0.0040, 0.0050, 0.0688,
        0.0487, 0.0052, 0.0168, 0.0746, 0.0482, 0.1605, 0.0019, 0.0024, 0.0018,
        0.0963, 0.0330, 0.0203, 0.0112, 0.0246, 0.0047, 0.0054, 0.0605, 0.0128],
       grad_fn=<DivBackward0>)

In [None]:
ix = torch.multinomial(y_pred_prob,num_samples=1,replacement=True,generator=g).item()
ix

0

In [None]:
# code for sample
for i in range(10):
    ix = stoi('.')
    str = []
    while True:        
        x_input = F.one_hot(torch.tensor(ix), len(chars)).float()
        y_pred_logits = x_input @ w
        y_pred_prob = torch.exp(y_pred_logits)/torch.sum(torch.exp(y_pred_logits))
        ix = torch.multinomial(y_pred_prob,num_samples=1,replacement=True).item()
        if ix == 0:
            print(''.join(str))
            break
        else:
            x = itos(ix)
            str.append(x)

jesh
ddeslahaleabilbemmah
wharsadheloueri
iagtowa
hn
kiomartirex
rtyrolianyanieriavahalerie
adh
ke
isicannaynaslukaia


In [51]:
# scale up the model
w1 = torch.randn(27,100,requires_grad=True)
b1 = torch.randn(100,requires_grad=True)
w2 = torch.randn(100,27,requires_grad=True)
b2 = torch.randn(27,requires_grad=True)

In [37]:
# all the data
g = torch.Generator().manual_seed(2147483647)
x = []
y = []
for w in words:
    chs = '.'+ w + '.'
    for c1, c2 in zip(chs, chs[1:]):
        x.append(stoi(c1))
        y.append(stoi(c2))
   
print(x[:5])
print(y[:5])

[0, 5, 13, 13, 1]
[5, 13, 13, 1, 0]


In [38]:
X = F.one_hot(torch.tensor(x[0]),num_classes=len(chars)).float()
h1 = X @ w1+b1
a1 =F.relu(h1)

In [39]:
logits = a1 @ w2+b2
logits

tensor([  2.4759,  10.0417,   2.9225,  -1.2142,   6.1475,   0.9338, -10.9739,
         -4.5203,   2.3280,  -5.2843, -19.3769,   5.9894,  11.0459,  -1.7935,
          6.6703,  -2.5579,   4.3409,   1.6304,   8.3766,   1.2052,  12.8617,
          3.6009,  -9.3809,   2.3046,   4.3582,  -2.2076,  -0.6693],
       grad_fn=<AddBackward0>)

In [40]:
prob = torch.exp(logits) /(torch.sum(torch.exp(logits)))

-torch.log(torch.sum(prob * Y[0]))

tensor(13.7450, grad_fn=<NegBackward0>)

In [41]:
# 分割训练集和测试集
split_size = 0.8
train_size = int(split_size * len(x))
train_data = x[:train_size]
train_target = y[:train_size]
test_data = x[train_size:]
test_target = y[train_size:]

In [42]:
X = F.one_hot(torch.tensor(x),num_classes=len(chars)).float()
Y = F.one_hot(torch.tensor(y),num_classes=len(chars)).float()
X.shape

torch.Size([228146, 27])

In [43]:
train_X = F.one_hot(torch.tensor(train_data),num_classes=len(chars)).float()
train_Y = F.one_hot(torch.tensor(train_target),num_classes=len(chars)).float()

test_X = F.one_hot(torch.tensor(test_data),num_classes=len(chars)).float()
test_Y = F.one_hot(torch.tensor(test_target),num_classes=len(chars)).float()

In [44]:
h1 = X @ w1+b1
a1 = F.relu(h1)
logits = a1 @ w2+b2
prob = torch.exp(logits) /(torch.sum(torch.exp(logits),dim=1,keepdim=True))
r_prob = prob * Y
r_prob_sum = torch.sum(r_prob, dim=1)
nll = -torch.log(r_prob_sum)
loss = torch.mean(nll)
print(loss)

tensor(16.2311, grad_fn=<MeanBackward0>)


In [52]:
lr = 1

In [53]:
for i in range(1001):
    h1 = train_X @ w1+b1
    a1 = F.relu(h1)
    logits = a1 @ w2+b2
    prob = torch.exp(logits) /(torch.sum(torch.exp(logits),dim=1,keepdim=True))
    r_prob = prob * train_Y
    r_prob_sum = torch.sum(r_prob, dim=1)
    nll = -torch.log(r_prob_sum)
    loss = torch.mean(nll)
    
    w1.grad = None
    w2.grad = None
    b1.grad = None
    b2.grad = None
    loss.backward()
    w1.data = w1.data - lr*w1.grad.data
    b1.data = b1.data - lr*b1.grad.data
    w2.data = w2.data - lr*w2.grad.data
    b2.data = b2.data - lr*b2.grad.data
    
    # 测试集上验证
    val_h1 = test_X @ w1+b1
    val_a1 = F.relu(val_h1)
    val_logits = val_a1 @ w2+b2
    val_prob = torch.exp(val_logits) /(torch.sum(torch.exp(val_logits),dim=1,keepdim=True))
    val_r_prob = val_prob * test_Y
    val_r_prob_sum = torch.sum(val_r_prob, dim=1)
    val_nll = -torch.log(val_r_prob_sum)
    val_loss = torch.mean(val_nll)
    
    if i%100==0:
        print('epoch:',i,"loss:",loss,"val:",val_loss)

epoch: 0 loss: tensor(17.1601, grad_fn=<MeanBackward0>) val: tensor(14.4430, grad_fn=<MeanBackward0>)
epoch: 100 loss: tensor(2.5366, grad_fn=<MeanBackward0>) val: tensor(2.7490, grad_fn=<MeanBackward0>)
epoch: 200 loss: tensor(2.4707, grad_fn=<MeanBackward0>) val: tensor(2.6639, grad_fn=<MeanBackward0>)
epoch: 300 loss: tensor(2.4540, grad_fn=<MeanBackward0>) val: tensor(2.6397, grad_fn=<MeanBackward0>)
epoch: 400 loss: tensor(2.4454, grad_fn=<MeanBackward0>) val: tensor(2.6273, grad_fn=<MeanBackward0>)
epoch: 500 loss: tensor(2.4404, grad_fn=<MeanBackward0>) val: tensor(2.6204, grad_fn=<MeanBackward0>)
epoch: 600 loss: tensor(2.4372, grad_fn=<MeanBackward0>) val: tensor(2.6161, grad_fn=<MeanBackward0>)
epoch: 700 loss: tensor(2.4350, grad_fn=<MeanBackward0>) val: tensor(2.6130, grad_fn=<MeanBackward0>)
epoch: 800 loss: tensor(2.4333, grad_fn=<MeanBackward0>) val: tensor(2.6107, grad_fn=<MeanBackward0>)
epoch: 900 loss: tensor(2.4320, grad_fn=<MeanBackward0>) val: tensor(2.6088, grad_

In [65]:
for i in range(20):
    ix = stoi('.')
    str = []
    while True:        
        x_input = F.one_hot(torch.tensor(ix), len(chars)).float()
        h1 = x_input @ w1 + b1
        a1 = F.relu(h1)
        y_pred_logits = a1 @ w2 + b2
        y_pred_prob = torch.exp(y_pred_logits)/torch.sum(torch.exp(y_pred_logits))
        ix = torch.multinomial(y_pred_prob,num_samples=1,replacement=True).item()
        if ix == 0:
            print(''.join(str))
            break
        else:
            x = itos(ix)
            str.append(x)

seledaann
kieliss
caisyslen
de
arin
ele
jadaie
kah
jassorynackayugr
riemiealmana
zemi
sinnn
n
ceezy
ilyn
vi
vi
kaspa
lilysabezi
a


### Use 3 words to predict the next word 使用3个词预测下一个词

In [89]:
pre_chs_count= 3

for w in words[:1]:
    chs = '.'+ w + '.'
    pre_chs = '.'*pre_chs_count+ w + '.'
    index = 0
    for c1, c2 in zip(chs, chs[1:]):
        print(pre_chs[index:index+pre_chs_count],c2)
        index += 1

... e
..e m
.em m
emm a
mma .


In [94]:
# all the data
pre_chs_count= 3

x = []
y = []
for w in words:
    chs = '.'+ w + '.'
    pre_chs = '.'*pre_chs_count+ w + '.'
    index = 0
    for c1, c2 in zip(chs, chs[1:]):
        x.append([stoi(x) for x in pre_chs[index:index+pre_chs_count]])
        y.append(stoi(c2))
        index += 1
print(x[:5])
print(y[:5])

[[0, 0, 0], [0, 0, 5], [0, 5, 13], [5, 13, 13], [13, 13, 1]]
[5, 13, 13, 1, 0]


In [117]:
X = [torch.flatten(F.one_hot(torch.tensor(ix), num_classes=27)).float() for ix in x]

In [167]:
Y = y

In [151]:
X = torch.stack(X,dim=0)

In [153]:
X.shape

torch.Size([228146, 81])

In [118]:
X[0].dtype

torch.float32

In [150]:
w = torch.randn([pre_chs_count * 27,27],requires_grad=True,dtype=torch.float)
b = torch.randn(27,requires_grad=True,dtype=torch.float)

In [134]:
print(f'w shape:{w.shape}')
print(f'b shape:{b.shape}')

w shape:torch.Size([81, 27])
b shape:torch.Size([27])


In [154]:
logits = X @ w + b

In [155]:
logits.shape

torch.Size([228146, 27])

In [156]:
prob = torch.exp(logits)/torch.sum(torch.exp(logits), dim=1, keepdim=True)

In [177]:
prob.shape

torch.Size([228146, 27])

In [185]:
len(Y)

228146

In [186]:
Y[0]

5

In [191]:
row_indices = torch.arange(X.shape[0])

In [192]:
prob[row_indices,Y]

tensor([0.0021, 0.0088, 0.0001,  ..., 0.0030, 0.0516, 0.0018],
       grad_fn=<IndexBackward0>)

In [193]:
prob[0][Y[0]]

tensor(0.0021, grad_fn=<SelectBackward0>)

In [194]:
loss = -torch.log(prob[row_indices,Y])

In [197]:
loss

tensor([6.1428, 4.7275, 9.0175,  ..., 5.8259, 2.9640, 6.2942],
       grad_fn=<NegBackward0>)

In [198]:
epoch = 10
lr = 1

In [200]:
row_indices = torch.arange(X.shape[0])
for i in range(1000):
    logits = X @ w + b
    prob = torch.exp(logits)/torch.sum(torch.exp(logits), dim=1, keepdim=True)
    pred = prob[row_indices,Y]
    loss = torch.mean(-torch.log(pred))
    w.grad = None
    b.grad = None
    loss.backward()
    w.data = w.data + -lr * w.grad.data
    b.data = b.data + -lr * b.grad.data
    print(loss.item())

4.092034816741943
4.050347805023193
4.011786460876465
3.9759325981140137
3.942457914352417
3.9111006259918213
3.881643533706665
3.85390567779541
3.8277299404144287
3.802978038787842
3.779527187347412
3.757265567779541
3.7360939979553223
3.715919256210327
3.696659803390503
3.6782379150390625
3.660586357116699
3.6436431407928467
3.627352476119995
3.611664295196533
3.5965335369110107
3.5819196701049805
3.567786693572998
3.554102659225464
3.5408384799957275
3.5279672145843506
3.5154662132263184
3.5033137798309326
3.491490125656128
3.4799787998199463
3.4687633514404297
3.4578285217285156
3.447161912918091
3.4367504119873047
3.426582098007202
3.416646957397461
3.406934976577759
3.3974363803863525
3.3881425857543945
3.3790464401245117
3.3701391220092773
3.3614139556884766
3.3528642654418945
3.3444831371307373
3.3362655639648438
3.32820463180542
3.3202950954437256
3.312532901763916
3.304912805557251
3.2974295616149902
3.2900795936584473
3.282858371734619
3.275761365890503
3.2687861919403076
3.

In [216]:
# sample
start_chs = '...'
x_start = [F.one_hot(torch.tensor(stoi(ix)),num_classes=27) for ix in start_chs]
x_start = torch.cat(x_start).float()
x_start

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [236]:
start_chs = '...'
x_start = [F.one_hot(torch.tensor(stoi(ix)),num_classes=27) for ix in start_chs]
x_start = torch.cat(x_start).float()
logits = x_start @ w + b
prob = torch.exp(logits)/torch.sum(torch.exp(logits))
print(prob)

tensor([0.0009, 0.1382, 0.0385, 0.0473, 0.0524, 0.0476, 0.0120, 0.0210, 0.0266,
        0.0223, 0.0723, 0.0898, 0.0494, 0.0783, 0.0341, 0.0181, 0.0137, 0.0029,
        0.0514, 0.0648, 0.0411, 0.0071, 0.0153, 0.0049, 0.0044, 0.0179, 0.0277],
       grad_fn=<DivBackward0>)


In [261]:
for i in range(10):
    str = ''
    start_chs = '...'
    while True:      
        x_start = [F.one_hot(torch.tensor(stoi(ix)),num_classes=27) for ix in start_chs]
        x_start = torch.cat(x_start).float()
        logits = x_start @ w + b
        prob = torch.exp(logits)/torch.sum(torch.exp(logits))
        ix = torch.multinomial(prob,num_samples=1,replacement=True)
        if ix == 0:
            print(str)
            break
        else:
            start_chs = start_chs[1:] + itos(ix)
        str += itos(ix)

dalyna
maica
arivi
dee
thila
olen
ojalie
ary
gey
cel
