In [145]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [146]:
# Hyperparameters shared by the rest of the notebook.
batch_size = 64      # number of training windows drawn per optimisation step
context_size = 8     # number of characters the model sees before predicting the next one
embedding_dim = 128  # size of each character embedding vector

# Seed PyTorch's global RNG so that weight initialisation and batch sampling
# are reproducible under Restart Kernel -> Run All.
seed = 1337
torch.manual_seed(seed)

In [None]:
# Read the whole training corpus into memory as a single string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# The vocabulary is every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

print(len(text))
print(chars)
print(vocab_size)

# Lookup tables between characters and their integer ids (position in `chars`).
ix_to_char = dict(enumerate(chars))
char_to_ix = {ch: i for i, ch in ix_to_char.items()}


def encode(s):
    """Return the list of integer ids for the characters of `s`."""
    return [char_to_ix[c] for c in s]


def decode(x):
    """Return the string spelled by an iterable of integer ids."""
    return ''.join(ix_to_char[i] for i in x)


def decode_torch(x):
    """Like `decode`, but for an iterable of tensor elements (uses `.item()` on each)."""
    return ''.join(ix_to_char[i.item()] for i in x)

In [148]:
# All tensors and the model live on the CPU.
device = torch.device('cpu')

# Encode the full corpus as a tensor of integer character ids, created
# directly on `device`.
data = torch.tensor(encode(text), dtype=torch.long, device=device)

# The first 90% of the ids are used for training; the remainder is held out
# as a test split (not used yet).
n = int(len(data) * 0.9)
train_data, test_data = data[:n], data[n:]

In [149]:
class LM(nn.Module):
    """Fixed-context MLP language model.

    Each of the `context_size` input token ids is embedded, the embeddings are
    concatenated, and a two-layer Tanh MLP maps them to one logit per
    vocabulary entry (the prediction for the next token).

    Args:
        n_vocab: number of distinct tokens. Defaults to the notebook-level
            `vocab_size`.
        n_ctx: number of tokens in the input window. Defaults to the
            notebook-level `context_size`.
        n_embd: embedding width. Defaults to the notebook-level
            `embedding_dim`.
        n_hidden: width of the hidden layer (256, as before, by default).
    """

    def __init__(self, n_vocab=None, n_ctx=None, n_embd=None, n_hidden=128 * 2):
        super().__init__()
        # Fall back to the notebook globals only when a size is not supplied,
        # so the original zero-argument `LM()` call keeps working.
        self.vocab_size = vocab_size if n_vocab is None else n_vocab
        self.context_size = context_size if n_ctx is None else n_ctx
        self.embedding_dim = embedding_dim if n_embd is None else n_embd

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(self.context_size * self.embedding_dim, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, self.vocab_size)
        )

    def init_weights(self):
        """Re-initialise parameters: Kaiming-normal weights, zero biases.

        `kaiming_normal_` needs at least two dimensions to compute fan-in, so
        applying it to the 1-D bias vectors raises ValueError; biases are
        zeroed instead.
        """
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.kaiming_normal_(param)
            else:
                nn.init.zeros_(param)

    def forward(self, x):
        """Return next-token logits for a batch of context windows.

        Args:
            x: B x T integer tensor, where B is the batch size and T must
                equal the model's context size.

        Returns:
            B x vocab_size tensor of unnormalised logits.

        Raises:
            AssertionError: if T differs from the model's context size.
        """
        B, T = x.size()

        assert T == self.context_size, (
            f'expected context length {self.context_size}, got {T}'
        )

        # Embed the input: B x T -> B x T x embedding_dim
        x = self.embedding(x)

        # Concatenate the T embeddings of each sample: B x (T * embedding_dim)
        x = x.view(B, -1)

        # Pass through the MLP to get logits
        return self.mlp(x)

In [None]:
# Build the language model, then move its parameters onto the chosen device.
model = LM()
model = model.to(device)

def sample(model, prompt=None, max_new_tokens=100):
    """Greedily generate text from `model`, print it, and return it.

    Starting from a seed string, the model is repeatedly fed the last
    `context_size` characters and the highest-scoring next character is
    appended.

    Args:
        model: callable taking a 1 x context_size tensor of character ids and
            returning a 1 x vocab tensor of logits.
        prompt: seed string. Defaults to the first `context_size` characters
            of the corpus `text`. Only its last `context_size` characters are
            fed to the model.
        max_new_tokens: number of characters to generate (100 by default).

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: if `prompt` is shorter than `context_size`.
    """
    if prompt is None:
        prompt = text[:context_size]
    if len(prompt) < context_size:
        raise ValueError(
            f'prompt must contain at least {context_size} characters, got {len(prompt)}'
        )

    # Collect pieces and join once instead of growing a string in the loop.
    pieces = [prompt]
    context = torch.tensor(encode(prompt[-context_size:]), dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(context.view(1, -1))
            # Greedy decoding: take the single most likely next character.
            pred = torch.argmax(logits, dim=1)
            pieces.append(decode_torch(pred))
            # Slide the window: append the prediction, keep the newest ids.
            context = torch.cat([context, pred])[-context_size:]

    generated = ''.join(pieces)
    print(generated)
    return generated

# Optimisation settings.
iters = 20000
lr = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Offsets 0 .. context_size-1. Adding them to a start index yields the indices
# of one whole context window; hoisted because it never changes.
window = torch.arange(context_size, device=device)

for i in range(iters + 1):
    # Random window start positions. The upper bound leaves room for the
    # target character that immediately follows each window.
    idx = torch.randint(0, len(train_data) - context_size, (batch_size,), device=device)
    # Gather every window in one indexing operation: batch_size x context_size.
    batch = train_data[idx.unsqueeze(1) + window]
    # The character right after each window is what the model must predict.
    target = train_data[idx + context_size]

    optimizer.zero_grad()
    output = model(batch)

    # Cross entropy loss manually, in log space: subtracting logsumexp gives
    # the log-softmax without ever exponentiating raw logits, so large logits
    # cannot overflow to inf/NaN. Equivalent to F.cross_entropy(output, target).
    log_probs = output - output.logsumexp(dim=1, keepdim=True)
    loss2 = -log_probs.gather(1, target.view(-1, 1)).mean()

    loss2.backward()
    optimizer.step()

    if i % 100 == 0:
        print(f'iter {i} loss {loss2.item()}')

    if i % 1000 == 0:
        sample(model)