In [82]:
# data

text = None
data_file = "input.txt"
with open(data_file, "r") as f:
    text = f.read()
print("Estimated tokens:", len(text))


Estimated tokens: 1115393


In [83]:
# Tokenizer

chars = sorted(list(set(''.join(text))))
vocab_size = len(chars)
print("Unique characters:", ''.join(chars), end="\n\n")

i2char = { x:chars[x] for x in range(len(chars)) }
char2i = { chars[x]:x for x in range(len(chars)) }


encode = lambda string_ : [ char2i[s] for s in string_  ]
decode = lambda tokens : ''.join([ i2char[t] for t in tokens  ])

test_str = 'hi there'
print(f"encoding '{test_str}' gives {encode(test_str)}")


Unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz

encoding 'hi there' gives [46, 47, 1, 58, 46, 43, 56, 43]


In [84]:
# train test split
import torch

split_ratio = .9
size = len(text)
n = int( size * split_ratio)

encoded_text = torch.tensor(encode(text))
train_dataset = encoded_text[:n]
val_dataset = encoded_text[n:]

print(f"train : val split ====> {len(train_dataset)/size :.2f}:{len(val_dataset)/size :.2f}")



train : val split ====> 0.90:0.10


In [95]:
import torch.nn as nn

torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 4
block_size = 8

def get_batch(split):
    # generate a small batch of data of inputs x and targets y

    assert split in ['train', 'val']
    data = train_dataset if split == 'train' else val_dataset

    ix = torch.randint(low = 0, high = len(data) - batch_size, size=(batch_size,) )
    x = torch.stack([ data[i: i+block_size] for i in ix])
    y = torch.stack([ data[i+1: i+block_size+1] for i in ix])

    x, y = x.to(device), y.to(device)
    return x, y


class SimpleBigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) 

    def forward(self, idx, targets=None):
        # idx == (B, T)
        logits = self.token_embedding_table(idx) # B,T,C === Batch, Time, Channels === Rows, Columns, Channels

        if targets == None:
            loss = None
        else:
            B,T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = torch.nn.functional.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx,  max_new_tokens: int):
        # idx == (B, T)

        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :] # focus on last time step (col) ===> becomes (B, C)
            probs = torch.nn.functional.softmax(logits, dim=-1) # apply on C
            idx_next = torch.multinomial(probs, num_samples=1) # # sample from the distribution (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx
        pass



In [96]:
# training

# print(x)
# print(y)

model = SimpleBigramLanguageModel(vocab_size).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
max_iters = 3000
eval_interval = 300

model.train()

for epoch in range(max_iters):
    x, y = get_batch('train')

    if epoch % eval_interval == 0:
        model.eval()
        x_train, y_train = get_batch('train')
        x_val, y_val = get_batch('val')
        
        _, train_loss = model(x_train, y_train)
        _, val_loss = model(x_val, y_val)

        print(f"epoch {epoch}: train loss {train_loss :.4f}, val loss {val_loss:.4f} ")
        model.train()

    logits, loss = model(x, y)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()



epoch 0: train loss 4.6291, val loss 4.6232 
epoch 300: train loss 3.1551, val loss 3.0098 
epoch 600: train loss 3.0026, val loss 2.4398 
epoch 900: train loss 3.0212, val loss 2.8308 
epoch 1200: train loss 2.3367, val loss 2.3471 
epoch 1500: train loss 2.3850, val loss 2.5813 
epoch 1800: train loss 2.5350, val loss 2.5518 
epoch 2100: train loss 2.3550, val loss 2.3588 
epoch 2400: train loss 2.6374, val loss 2.4671 
epoch 2700: train loss 2.2695, val loss 2.1183 


In [97]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))


JOLI ou walur paly houmy rde t die;
CHFigovigeeat ueju ly, cowiccond d hand wouce ds nlowisheno, Inds s tasthid thy, thaimond's k; Oun'ssearethe fosonsigrisaples, br, tronshe lveald Wesourar
gis?atr, l be-d,
ENGyoy Em n.
chodn mfe me! fodapas. habyomamalwous sushaglowiress; yFlwelo, oowityo hy, e was t ms,
Thmind in, d y s imatheaccarcano thidsto boinerougre; E:
Burs higrind blos.

Y:

DYoupe cofe'diea! helalpou,
&'sV&'sprals.


Cay'le taie OHese
BELond ale g.
BUzencust s t rd t
FhRI y thend:
SC
