In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
torch.manual_seed(69)

<torch._C.Generator at 0x13affdb1530>

In [3]:
with open("data/input.txt", "r") as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(chars)}
INDEX_TO_CHAR = {i: ch for i, ch in enumerate(chars)}

def encode(text):
    return torch.tensor([CHAR_TO_INDEX[ch] for ch in text], dtype=torch.long)

def decode(tensor):
    return "".join([INDEX_TO_CHAR[int(i)] for i in tensor])

In [5]:
encoded_text = encode('Luke, I am your father.')
print(encoded_text)
print(decode(encoded_text))

tensor([24, 59, 49, 43,  6,  1, 21,  1, 39, 51,  1, 63, 53, 59, 56,  1, 44, 39,
        58, 46, 43, 56,  8])
Luke, I am your father.


In [6]:
data = torch.tensor(encode(text), dtype=torch.long)
pct = 0.9
train_data = data[:int(len(data)*pct)]
val_data = data[int(len(data)*pct):]

  data = torch.tensor(encode(text), dtype=torch.long)


In [7]:
SEQ_LEN = 100
BATCH_SIZE = 64

def get_batch(split):
    dt = train_data if split == "train" else val_data
    ix = torch.randint(len(dt) - SEQ_LEN, (BATCH_SIZE,))
    x = torch.stack([dt[i:i+SEQ_LEN] for i in ix])
    y = torch.stack([dt[i+1:i+SEQ_LEN+1] for i in ix])
    return x, y

In [8]:
EMBEDDING_DIM = 64
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.emb = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.fc = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)

    def forward(self, x, hidden=None):
        x = self.emb(x)
        x = self.fc(x)
        return x
    
    def generate(self, x, num_chars):
        for _ in range(num_chars):
            tmp = self(x)
            tmp = F.softmax(tmp[-1], dim=0)
            tmp = torch.multinomial(tmp, 1)
            x = torch.cat([x, tmp])
        return x

In [9]:
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [12]:
NUM_EPOCHS = 1000
for epoch in range(NUM_EPOCHS):
    model.train()
    x, y = get_batch("train")
    optimizer.zero_grad()
    y_pred = model(x)
    loss = F.cross_entropy(y_pred.view(-1, VOCAB_SIZE), y.view(-1))
    loss.backward()
    optimizer.step()
    model.eval()
    x, y = get_batch("val")
    y_pred = model(x)
    loss = F.cross_entropy(y_pred.view(-1, VOCAB_SIZE), y.view(-1))
    print(f"Epoch {epoch+1} Loss: {loss.item()}")

Epoch 1 Loss: 2.491325616836548
Epoch 2 Loss: 2.5031158924102783
Epoch 3 Loss: 2.4512479305267334
Epoch 4 Loss: 2.4868037700653076
Epoch 5 Loss: 2.4764151573181152
Epoch 6 Loss: 2.4607126712799072
Epoch 7 Loss: 2.472249984741211
Epoch 8 Loss: 2.474839448928833
Epoch 9 Loss: 2.457688808441162
Epoch 10 Loss: 2.4986538887023926
Epoch 11 Loss: 2.4976232051849365
Epoch 12 Loss: 2.4818215370178223
Epoch 13 Loss: 2.469834327697754
Epoch 14 Loss: 2.4705610275268555
Epoch 15 Loss: 2.4820897579193115
Epoch 16 Loss: 2.462399482727051
Epoch 17 Loss: 2.5267186164855957
Epoch 18 Loss: 2.509347915649414
Epoch 19 Loss: 2.458144187927246
Epoch 20 Loss: 2.4689748287200928
Epoch 21 Loss: 2.4895541667938232
Epoch 22 Loss: 2.4916417598724365
Epoch 23 Loss: 2.4854557514190674
Epoch 24 Loss: 2.509185552597046
Epoch 25 Loss: 2.480280876159668
Epoch 26 Loss: 2.5097341537475586
Epoch 27 Loss: 2.48793888092041
Epoch 28 Loss: 2.467266798019409
Epoch 29 Loss: 2.466728448867798
Epoch 30 Loss: 2.4775567054748535
Epo

In [17]:
context = encode("First citizen:")
print(decode(model.generate(context, 10000)))

First citizen:
Tonir ce me wou puct he;
Vingo, ERI aten
An fenopot ar cean t f t setupe gerear'sory CELO:
ES:
S: seamy onayore.
S:
But mys nd y llisereheas ould.

Buligrges we, wapofe t he;
Were qurooureane ar worine myodis sorad praverengncouches mesed byo anort s, t iflotoosure hithor bery D:
RGhas l n ate t meas; jerir;
Gos:
TKENous,

Rim bed muror, frous lt mmuremantorowis blllors lve
GLA: d Gr.
unouf givof iomyofor tld;
Be vakssthe t ang, wh t tharu hat arend cor wicanse I'llomer othou isindeper--y it
WINo thamur searet rar s Tou!

Wh wothe. af f t ant ced sthence, isn arthee ch son m, ar crratawourakeango m s ftyo ace s istos thallve, wo:
CA:
I avorcou in:
Wh!
MNI tipthis ceanaiee fatichofame shat ber huten
I y tht aliswefowe his.-f;
Ifatonteale, ne.
CHod she, hes ourerirenoromeeak andowin'ltonreedod, s, hendotenoor
ERDY lilif,
ADoon ilfow:

Maverichad:
I b!

S:
TORe h And ced bullleeane ll?
Yom turtouy; but?
Gunach Whinouthedeshille IZELLOLAnd thid he buthicompamareca gong outhe