In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
# Prefer GPU index 5, but only when that device actually exists: is_available()
# being True says nothing about how many GPUs there are, and an out-of-range
# ordinal fails later at the first tensor/module transfer. Fall back to GPU 0,
# or to the CPU when CUDA is unavailable.
preferred_gpu = 5
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = f'cuda:{gpu_index}'
else:
    device = 'cpu'
eval_iters = 200
# ------------

# Seed PyTorch's global RNG before anything random happens.
torch.manual_seed(1337)

# The corpus is expected next to the notebook; fetch it beforehand with:
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', mode='r', encoding='utf-8') as f:
    text = f.read()  # entire file as a single string

# Vocabulary: every distinct character that occurs in the text, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Lookup tables between characters and integer ids (an id is the character's
# position in `chars`).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encoder: take a string, output a list of integers.

    Each character is looked up in `stoi`; a character that is not in the
    vocabulary raises KeyError.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string.

    Each id is looked up in `itos`; an id that is not in the table raises
    KeyError.
    """
    return ''.join(itos[i] for i in l)

# Train / validation split: encode the whole corpus once, then cut it at 90%.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # boundary index between the two splits
# Leading 90% is used for training, the trailing remainder for validation.
train_data, val_data = data[:n], data[n:]

In [15]:
# Display the encoded training split as the cell's output.
train_data

tensor([18, 47, 56,  ..., 43, 56, 43])

In [18]:
# Sample one batch of input contexts from the training split.
# Use a dedicated name instead of rebinding the global `data`: that name holds
# the full encoded corpus, and overwriting it here would leave the kernel in a
# state where `data` silently means "train only" for every later cell.
batch_source = train_data
# Random start offsets; the upper bound leaves room for a full block_size window.
x_index = torch.randint(len(batch_source) - block_size, (batch_size,))
# One block_size-long slice per offset, stacked into a single 2-D tensor
# with one row per sampled sequence.
x = torch.stack([batch_source[i : i+block_size] for i in x_index])
x


tensor([[35, 43,  1, 39, 56, 43,  1, 52],
        [56, 61, 53, 56, 52,  1, 61, 47],
        [52,  1, 47, 58, 57, 43, 50, 44],
        [23, 17, 10,  0, 28, 39, 50, 43],
        [56, 57, 58,  1, 15, 47, 58, 47],
        [46, 58, 10,  0, 21,  1, 54, 56],
        [39, 52, 42,  8,  0,  0, 25, 13],
        [52, 42,  1, 57, 53, 60, 43, 56],
        [63,  1, 44, 39, 58, 46, 43, 56],
        [58, 57,  6,  1, 40, 59, 58,  1],
        [ 0, 32, 47, 40, 43, 56,  1, 47],
        [47, 41, 12,  1, 46, 39, 56, 49],
        [ 1, 41, 39, 56, 43,  0, 20, 39],
        [57,  1, 57, 46, 43,  1, 58, 53],
        [41, 50, 53, 57, 43, 42,  1, 47],
        [39, 49, 43,  1, 58, 46, 43, 51],
        [39, 44, 58, 43, 56, 52, 53, 53],
        [10,  0, 25, 63,  1, 50, 47, 44],
        [43, 39, 57, 53, 52, 10,  0, 32],
        [53, 52,  0, 27, 44,  1, 58, 46],
        [52, 58,  8,  0,  0, 15, 27, 30],
        [42, 43, 39, 56, 12,  0, 32, 46],
        [56, 47, 52, 41, 43,  6,  0, 13],
        [53, 61,  1, 52, 53, 58,  