In [3]:
# Read the whole corpus into memory as a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [14]:
# Character-level vocabulary: every distinct character in the corpus, sorted so
# the char <-> id mapping built from it is the same on every run.
chars = sorted(set(text))  # sorted() takes any iterable and already returns a list
vocab_size = len(chars)
print(chars)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [5]:
# Character-level tokenizer: each vocabulary character maps to its index in
# `chars`, and back.
s2i = {ch: i for i, ch in enumerate(chars)}
i2s = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string into a list of integer token ids.

    Raises KeyError if ``s`` contains a character that is not in ``chars``.
    """
    return [s2i[c] for c in s]


def decode(l):
    """Decode an iterable of token ids back into a string.

    Every id is passed through ``int()`` before the lookup, so plain ints and
    the 0-d tensors produced by iterating a 1-D torch tensor both work (torch
    tensors hash by identity, so using them directly as ``i2s`` keys would
    raise KeyError).
    """
    return ''.join(i2s[int(i)] for i in l)

In [None]:
# For education only: compare our character-level tokenizer with GPT-2's
# Byte Pair Encoding (BPE) tokenizer, vocab size 50257. Nothing else in the
# notebook depends on this cell, so it is skipped (instead of aborting a
# Restart & Run All) when the optional `tiktoken` package is not installed.
try:
    import tiktoken
except ImportError:
    encoder = None
    print('tiktoken is not installed; skipping the GPT-2 BPE demo.')
else:
    encoder = tiktoken.get_encoding('gpt2')
    print(encoder.encode("hii there"))

In [6]:
import torch

# Encode the whole corpus once and keep it as one 1-D tensor of int64 token ids.
data = torch.as_tensor(encode(text), dtype=torch.long)
print(f'{data.shape} {data.dtype}')

torch.Size([1115394]) torch.int64


In [7]:
# Contiguous (unshuffled) split: the first 90% of the token stream is used for
# training, the final 10% is held out for evaluation.
n = int(len(data) * 0.9)
train, test = data[:n], data[n:]

In [None]:
# For education only: one chunk of block_size + 1 tokens packs block_size
# training examples. At position t the context is x[:t + 1] and the target is
# y[t], the token that comes right after it.
block_size = 8
x = train[:block_size]
y = train[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print('context: ', context, ', target: ', target)

In [13]:
# Batching hyperparameters: sequences per batch and tokens per sequence.
batch_size, block_size = 4, 8
torch.manual_seed(1337)  # fixed seed so the sampled batches are reproducible

def get_batch(isTrain):
    """Sample a random batch of (input, target) token sequences.

    Reads the notebook-level ``train``/``test`` splits and the ``batch_size``
    and ``block_size`` globals at call time, so cells that later change those
    settings take effect without redefining this function.

    Args:
        isTrain: if truthy, sample from ``train``; otherwise from ``test``.

    Returns:
        (x, y): two tensors of shape (batch_size, block_size). ``y`` is ``x``
        shifted one position ahead in the split, so ``y[b, t]`` is the token
        that follows the context ``x[b, :t + 1]``.

    Raises:
        ValueError: if the chosen split holds no more than ``block_size``
            tokens, so not even one (input, target) window fits.
    """
    # Named `split` (not `data`) to avoid shadowing the full-corpus tensor.
    split = train if isTrain else test
    # A window needs block_size inputs plus one extra token for the shifted
    # target; torch.randint's upper bound is exclusive, so valid start offsets
    # are 0 .. len(split) - block_size - 1.
    max_start = len(split) - block_size
    if max_start <= 0:
        raise ValueError(
            f'split has {len(split)} tokens; need more than block_size={block_size}'
        )
    starts = torch.randint(max_start, (batch_size,))
    x = torch.stack([split[i:i + block_size] for i in starts])
    y = torch.stack([split[i + 1:i + block_size + 1] for i in starts])
    return x, y

xb, yb = get_batch(True)

# Walk through every (context -> next token) example packed into the batch.
for row in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context, target = xb[row, :t + 1], yb[row, t]
        print('context: ', context.tolist(), ', target: ', target)
        

context:  [24] , target:  tensor(43)
context:  [24, 43] , target:  tensor(58)
context:  [24, 43, 58] , target:  tensor(5)
context:  [24, 43, 58, 5] , target:  tensor(57)
context:  [24, 43, 58, 5, 57] , target:  tensor(1)
context:  [24, 43, 58, 5, 57, 1] , target:  tensor(46)
context:  [24, 43, 58, 5, 57, 1, 46] , target:  tensor(43)
context:  [24, 43, 58, 5, 57, 1, 46, 43] , target:  tensor(39)
context:  [44] , target:  tensor(53)
context:  [44, 53] , target:  tensor(56)
context:  [44, 53, 56] , target:  tensor(1)
context:  [44, 53, 56, 1] , target:  tensor(58)
context:  [44, 53, 56, 1, 58] , target:  tensor(46)
context:  [44, 53, 56, 1, 58, 46] , target:  tensor(39)
context:  [44, 53, 56, 1, 58, 46, 39] , target:  tensor(58)
context:  [44, 53, 56, 1, 58, 46, 39, 58] , target:  tensor(1)
context:  [52] , target:  tensor(58)
context:  [52, 58] , target:  tensor(1)
context:  [52, 58, 1] , target:  tensor(58)
context:  [52, 58, 1, 58] , target:  tensor(46)
context:  [52, 58, 1, 58, 46] , 

In [17]:
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The whole model is one (vocab_size x vocab_size) lookup table; row ``i``
    holds the logits for the token that follows token id ``i``.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Each token id reads off its own row of next-token logits.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets=None):
        """Return next-token logits for every position in ``index``.

        Args:
            index: integer tensor of token ids; (batch, block) in this notebook.
            targets: next-token ids with the same shape as ``index``. Optional
                and currently ignored, so the model can also be called for
                inference with ``model(index)``.

        Returns:
            Logits of shape ``index.shape + (vocab_size,)``, i.e.
            (batch, block, vocab_size) for a (batch, block) input.
        """
        logits = self.token_embedding_table(index)  # (batch, block, vocab_size)
        return logits
    
# Smoke test: the model yields one logit vector per (batch, position) pair.
model = BigramLanguageModel(vocab_size=vocab_size)
output = model(xb, targets=yb)
print(output.size())

torch.Size([4, 8, 65])
