In [3]:
import torch

import pickle

In [2]:

# Download a text file from a GitHub repository
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Open the downloaded file for reading with UTF-8 encoding
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
len(text)
print(text[:1000])

--2024-02-26 12:30:14--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8001::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1,1M) [text/plain]
Saving to: ‘input.txt’


2024-02-26 12:30:15 (1,37 MB/s) - ‘input.txt’ saved [1115394/1115394]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

F

Save raw text :

In [5]:
with open("./../data/shakespeare.txt", "wb") as f:
    pickle.dump(text, f)

Re-load :

In [113]:
with open("./../data/shakespeare.txt", "rb") as f:
    text = pickle.load(f)

In [114]:

# Create a sorted list of unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [115]:

# Create character-to-index and index-to-character mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Functions to encode and decode text
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[x] for x in l])

In [116]:

encode("hi there")

[46, 47, 1, 58, 46, 43, 56, 43]

In [117]:
decode([46, 47, 1, 58, 46, 43, 56, 43])

'hi there'

In [119]:
# Convert the text to a PyTorch tensor of character indices
data = torch.tensor(encode(text), dtype=torch.long)


In [120]:
data.shape

torch.Size([1115394])

In [121]:

# Split the data into training and validation sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [122]:

# Define the block size for context
block_size = 8
# Create training examples (x) and their corresponding targets (y)
train_data[:block_size+1]
x = train_data[:block_size]
y = train_data[1:block_size+1]

In [123]:
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print("ctx ", context, "target", target)

ctx  tensor([18]) target tensor(47)
ctx  tensor([18, 47]) target tensor(56)
ctx  tensor([18, 47, 56]) target tensor(57)
ctx  tensor([18, 47, 56, 57]) target tensor(58)
ctx  tensor([18, 47, 56, 57, 58]) target tensor(1)
ctx  tensor([18, 47, 56, 57, 58,  1]) target tensor(15)
ctx  tensor([18, 47, 56, 57, 58,  1, 15]) target tensor(47)
ctx  tensor([18, 47, 56, 57, 58,  1, 15, 47]) target tensor(58)


In [124]:

# Import PyTorch neural network modules
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1137)

<torch._C.Generator at 0x7f2fdcfdd8d0>

In [125]:

# Set batch size and block size
batch_size = 4
block_size = 8
torch.randint(6, (4,))

tensor([0, 5, 3, 1])

In [126]:

# Data loading function to get input (x) and target (y) batches
def get_batch(split):
    # Generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [127]:

xb, yb = get_batch('train')

In [128]:
xb.shape

torch.Size([4, 8])

In [129]:
yb.shape

torch.Size([4, 8])

In [130]:


for b in range(batch_size):
    for t in range(block_size):
        context = xb[b][:t+1]
        target = yb[b][t]
        print(context, "-----", target)

tensor([1]) ----- tensor(58)
tensor([ 1, 58]) ----- tensor(46)
tensor([ 1, 58, 46]) ----- tensor(43)
tensor([ 1, 58, 46, 43]) ----- tensor(1)
tensor([ 1, 58, 46, 43,  1]) ----- tensor(46)
tensor([ 1, 58, 46, 43,  1, 46]) ----- tensor(53)
tensor([ 1, 58, 46, 43,  1, 46, 53]) ----- tensor(54)
tensor([ 1, 58, 46, 43,  1, 46, 53, 54]) ----- tensor(43)
tensor([53]) ----- tensor(1)
tensor([53,  1]) ----- tensor(39)
tensor([53,  1, 39]) ----- tensor(1)
tensor([53,  1, 39,  1]) ----- tensor(61)
tensor([53,  1, 39,  1, 61]) ----- tensor(47)
tensor([53,  1, 39,  1, 61, 47]) ----- tensor(44)
tensor([53,  1, 39,  1, 61, 47, 44]) ----- tensor(43)
tensor([53,  1, 39,  1, 61, 47, 44, 43]) ----- tensor(0)
tensor([39]) ----- tensor(52)
tensor([39, 52]) ----- tensor(42)
tensor([39, 52, 42]) ----- tensor(56)
tensor([39, 52, 42, 56]) ----- tensor(63)
tensor([39, 52, 42, 56, 63]) ----- tensor(1)
tensor([39, 52, 42, 56, 63,  1]) ----- tensor(44)
tensor([39, 52, 42, 56, 63,  1, 44]) ----- tensor(53)
tensor([

In [131]:

# Function to create a decay matrix with a specified dimension and gamma values
def get_decay_matrix(dim, gamma):
    d = torch.ones(dim)
    d = torch.tril(d)
    for index, head in enumerate(d):
        g = gamma[index]
        for idx, x in enumerate(torch.tril(head)):
            for idy, y in enumerate(x):
                if idx >= idy:
                    head[idx][idy] = g ** (idx-idy)
    return d

In [30]:

# Install the 'einops' library
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [31]:

# Import the 'einops' library
import einops
from einops import rearrange, reduce, repeat

----

In [32]:

class ChunkwiseRetention(nn.Module):
    def __init__(self, chunk_size, num_head, block_size):
        super().__init__()
        self.key = nn.Linear(n_embed,  chunk_size * num_head, bias = False)
        self.query = nn.Linear(n_embed,  chunk_size * num_head, bias = False)
        self.value = nn.Linear(n_embed,  chunk_size * num_head, bias = False)
        self.gamma = 1.0-2.0**(-5-torch.arange(0,num_head))
        self.decay_mask = get_decay_matrix((num_head, block_size, block_size), self.gamma)
        self.chunk_decay = self.gamma
        self.gn = nn.GroupNorm(1, num_head)
        self.num_head = num_head
        self.chunk_size = chunk_size


    def forward(self, x, past_kv):


        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)


        k = rearrange(k, ('b t (h c) -> b h t c'), t=T, h=self.num_head, c =self.chunk_size)
        q = rearrange(q, ('b t (h c) -> b h t c'), t=T, h=self.num_head, c =self.chunk_size)
        v = rearrange(v, ('b t (h c) -> b h t c'), t=T, h=self.num_head, c =self.chunk_size)




        retention = q @ k.transpose(-1, -2)


        # b h t c , b h c t -> b h t t


        retention = retention  * self.decay_mask   # b h t t* h t t
        inner_retention = retention @ v
        past_kv = repeat(past_kv, 'n q v -> B n q v', B=B)
        pb, pn, pq, pv = past_kv.shape


        padding = torch.zeros(pb, pn, pq, self.chunk_size)
        past_kv = past_kv+ padding
        dm = repeat(self.decay_mask, 'h c d -> B h c d', B=B)
        pp = q @ past_kv
        cross_retention = pp.transpose(-1, -2) @ dm
        cross_retention = cross_retention.transpose(-1, -2)
        retention = inner_retention + cross_retention
        current_kv = self.gamma.view(self.num_head, 1, 1) * past_kv + (k.transpose(-1, -2) @ v)
        output = self.gn(retention.transpose(-1,-2))
        output = rearrange(output, 'b c h t -> b t (c h)')
        return output, current_kv.mean(dim=0)
class GatedMultiScaleRetention(nn.Module):
    def __init__(self, chunk_size, num_head, block_size):
        super().__init__()
        self.wg = nn.Linear(n_embed,  n_embed, bias = False)
        self.act = nn.SiLU()
        self.y= ChunkwiseRetention(num_head = n_head, chunk_size = n_embed//n_head, block_size=block_size)
        self.wo = nn.Linear(n_embed,  n_embed, bias = False)
        self.past = torch.zeros(num_head, chunk_size, chunk_size)
    def forward(self, x):
        wgx = self.wg(x)
        wgx = self.act(wgx)
        y, past = self.y(wgx, self.past)
        self.past = past.detach()
        y = wgx * y
        return self.wo(y)


class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4* n_embed),
            nn.GELU(),
            nn.Linear(4 * n_embed, n_embed),
         nn.Dropout(dropout))


    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embed, n_head, block_size):
        super().__init__()
        self.sa_head= GatedMultiScaleRetention(num_head = n_head, chunk_size = n_embed//n_head, block_size=block_size)
        self.ffw=  FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)


    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x+self.ffw(self.ln2(x))
        return x
class RetNet(nn.Module):
    def __init__(self, block_size):
        super().__init__()


        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, block_size=block_size) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)


    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss


    def generate(self, idx, max_new_tokes):
        for _ in range(max_new_tokes):
            b, s = idx.shape
            bk = min(s, block_size)
            idx_cond =  torch.cat((torch.zeros(b, block_size-bk, dtype=int), idx), dim=1)[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim = 1)
        return idx

In [266]:
# Hyperparameters
batch_size = 16
block_size = 32
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_head = 4
n_layer = 4
dropout = 0.0
n_embed = 32

In [133]:

# Data loading function to get input (x) and target (y) batches
def get_batch(split, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [134]:

# Initialize the RetNet model
model = RetNet(block_size=block_size)
# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size)
# Initialize the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Forward pass and loss calculation
logits, loss = model(xb, yb)
loss.shape

torch.Size([])

In [141]:
logits.shape

torch.Size([512, 65])

In [137]:
xb.shape

torch.Size([16, 32])

In [142]:
# Function to estimate loss on train and validation sets
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size=batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size)


# Forward pass and loss calculation
logits, loss = model(xb, yb)

In [37]:

# Training loop
for iter in range(max_iters):
    # Every once in a while, evaluate the loss on train and val sets
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    # Sample a batch of data
    xb, yb = get_batch('train', batch_size=batch_size)


    # Forward pass, loss calculation, backpropagation, and optimization
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.5425, val loss 4.5270
step 100: train loss 2.7412, val loss 2.7844
step 200: train loss 2.7678, val loss 3.0618
step 300: train loss 2.4871, val loss 2.6357
step 400: train loss 2.4499, val loss 2.6510
step 500: train loss 2.6725, val loss 3.0618
step 600: train loss 3.4132, val loss 3.7535
step 700: train loss 2.4730, val loss 2.6705
step 800: train loss 2.8632, val loss 3.3417
step 900: train loss 3.0594, val loss 4.0267
step 1000: train loss 2.6181, val loss 2.9877
step 1100: train loss 2.6858, val loss 3.2692
step 1200: train loss 2.1505, val loss 2.2882
step 1300: train loss 2.4247, val loss 2.8718
step 1400: train loss 2.1464, val loss 2.4026
step 1500: train loss 2.1511, val loss 2.5579
step 1600: train loss 2.7613, val loss 3.3510
step 1700: train loss 2.5429, val loss 2.7941
step 1800: train loss 2.5993, val loss 3.0266
step 1900: train loss 2.5095, val loss 3.1794
step 2000: train loss 2.3715, val loss 2.7155
step 2100: train loss 2.4116, val loss 2.7164


In [38]:

# Create a context for text generation
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Generate text using the model
print(decode(model.generate(context, max_new_tokes=200)[0].tolist()))


Thp; as mth dofowhiicaye dw RDUo,
Se thixghad pe tellldeaseadm f gad,
OWINANTh o, w Le E aps hpor ale,
An bl Bou by slor!
TIThandellatr gonghed ty Myoll cat tomillitu wiswingoblthithusferd win 'LDIUMY


In [39]:

# Create another context for text generation
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Generate more text using the model
print(decode(model.generate(context, max_new_tokes=200)[0].tolist()))


A
L torn d s y, whr, rdend t Thaly, athaturOFr,
ICINRINGENatora lbst:
Hir ydorand.
The cheedustocngurgothiserisdr ttortrcoryof sped s mswithary ithoknwe l
Borifry lyolo foou;
ICHUCUWA beenomz,
Tin ies


----

In [40]:

# Install the 'tiktoken' library
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting regex>=2022.1.18 (from tiktoken)
  Downloading regex-2023.12.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.9/40.9 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests>=2.26.0 (from tiktoken)
  Downloading requests-2.31.0-py3-none-any.whl.metadata (4.6 kB)
Collecting charset-normalizer<4,>=2 (from requests>=2.26.0->tiktoken)
  Downloading charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (33 kB)
Collecting idna<4,>=2.5 (from requests>=2.26.0->tiktoken)
  Downloading idna-3.6-py3-none-any.whl.metadata (9.9 kB)
Collecting urllib3<3,>=1.21.1 (from requests>=2.26.0->tiktoken)
  Downloading urllib3-2.2.1-py3-none-any.whl.metadata (6.4 kB)
Collecting certifi>=2017.4.17 (from requests>=2.26.0->tiktok

In [245]:

# Import the 'tiktoken' library
import tiktoken
# Get the encoding for a specific model
enc = tiktoken.get_encoding("r50k_base")

In [246]:
enc

<Encoding 'r50k_base'>

In [247]:

# Assert that encoding and decoding work correctly
assert enc.decode(enc.encode("hello world")) == "hello world"

In [248]:
enc.encode("hello world")

[31373, 995]

In [147]:
"""
# To get the tokeniser corresponding to a specific model in the OpenAI API:
enc = tiktoken.encoding_for_model("gpt-4")"""

In [148]:
"""
# Assert that encoding and decoding work correctly for the new model
assert enc.decode(enc.encode("hello world")) == "hello world" """

In [249]:

# Encode "hello world" using the tokeniser
enc.encode("hello world")

[31373, 995]

In [250]:
len(text)

1115394

In [251]:
text_sub = text[:5000]

In [252]:

# Count the number of tokens in the text
text_tokens = enc.encode(text_sub)
len(text_tokens)

1393

In [253]:
text_tokens

[5962,
 22307,
 25,
 198,
 8421,
 356,
 5120,
 597,
 2252,
 11,
 3285,
 502,
 2740,
 13,
 198,
 198,
 3237,
 25,
 198,
 5248,
 461,
 11,
 2740,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 1639,
 389,
 477,
 12939,
 2138,
 284,
 4656,
 621,
 284,
 1145,
 680,
 30,
 198,
 198,
 3237,
 25,
 198,
 4965,
 5634,
 13,
 12939,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 5962,
 11,
 345,
 760,
 327,
 1872,
 385,
 1526,
 28599,
 318,
 4039,
 4472,
 284,
 262,
 661,
 13,
 198,
 198,
 3237,
 25,
 198,
 1135,
 760,
 470,
 11,
 356,
 760,
 470,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 5756,
 514,
 1494,
 683,
 11,
 290,
 356,
 1183,
 423,
 11676,
 379,
 674,
 898,
 2756,
 13,
 198,
 3792,
 470,
 257,
 15593,
 30,
 198,
 198,
 3237,
 25,
 198,
 2949,
 517,
 3375,
 319,
 470,
 26,
 1309,
 340,
 307,
 1760,
 25,
 1497,
 11,
 1497,
 0,
 198,
 198,
 12211,
 22307,
 25,
 198,
 3198,
 1573,
 11,
 922,
 4290,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 1135,
 389,
 17830,
 3595,
 4290,
 11,
 262,
 1458,


In [254]:

# Create a sorted list of unique characters in the text
chars = sorted(list(set(text_tokens)))
vocab_size = len(chars)
vocab_size

522

In [255]:

# Decode the first token in the text
enc.decode([text_tokens[0]])

'First'

In [256]:

data = torch.tensor(text_tokens, dtype=torch.long)
data.shape

torch.Size([1393])

In [258]:
learning_rate = 3e-4

In [54]:
"""
chars = sorted(list(set(text.split(' '))))
vocab_size = len(chars)"""

In [55]:
"""vocab_size"""

42197

In [59]:
"""
chars[100]"""

"'banished'?\n\nFRIAR"

In [66]:
"""
# Create word-to-index and index-to-word mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


# Functions to encode and decode words
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: " ".join([itos[x] for x in l])
# Encode words using the mappings
data = torch.tensor(encode(text.split(' ')), dtype = torch.long)
# Display the first 10 tokens in the data
data[:10]"""

tensor([ 1455,   957, 39874, 29614,  5949, 16628, 18572, 24432, 34050, 34057])

In [73]:
"""
# Decode the first 10 tokens in the data
decode(encode(text.split("\n")[:2]))"""

KeyError: 'First Citizen:'

In [213]:
len(data)

134353

In [259]:

# Split the data into training and validation sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [260]:
len(val_data)

140

In [261]:

# Set hyperparameters for the model
batch_size = 8
block_size = 16

In [262]:
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32
n_embed = 32
n_head = 4
n_layer = 4
dropout = 0.0

In [267]:
# Initialize the RetNet model
model = RetNet(block_size=block_size)
# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size)
# Initialize the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [268]:
xb.shape

torch.Size([16, 32])

In [269]:

# Forward pass and loss calculation
logits, loss = model(xb, yb)
loss.shape

IndexError: index out of range in self

In [234]:
xb[0:3]

tensor([[ 3957,   956,   264, 36543,  1980,  2460,   512,  2822,   810,  7556,
           389,   956,    26,  1095,   433,   387],
        [  382,  2460,   512, 96945,    11,  6604,   382,  5451, 47317,   512,
          2675,   527,   682, 20250,  4856,   311],
        [ 9354,   311,   279,  1274,   382,  2460,   512,  1687,  1440,   956,
            11,   584,  1440,   956,   382,  5451]])

In [102]:
yb[0:3]

tensor([[  280,  3112,   358,   656,  3987,  1695,  2919,   323,  1317,   311,
          1518,   382, 58163,    44,  3895,   512,    46, 28146,    11,  1778,
           264,  2324,    11,   449,  1778,   264,  7555,    11,  1051, 15234,
          4999,  4071],
        [46811,    11, 24613,  2277,   757,   198,  1962,   311,  5622,   279,
         96923,   382, 16041, 52483,   261,   512, 18293,   279, 38736,   304,
         26236,  4059,    11,   323, 48839,  1461,   539,    25,   568,   198,
         41450,  1672],
        [ 4648,    11,   719,  2547,   596,  9120, 16409,   382,  3442,  6903,
           512, 34042,    11,  9120, 16409,     0,   387, 16888,  5092,    11,
          2019,   364, 63007, 99419,  2520, 61087, 52677,   810,  8818,   304,
           813,  1427]])

In [108]:

# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size)


# Forward pass and loss calculation
logits, loss = model(xb, yb)

IndexError: index out of range in self

In [98]:

# Forward pass and loss calculation
logits, loss = model(xb, yb)

IndexError: index out of range in self

In [None]:

# Training loop
for iter in range(max_iters):
    # Every once in a while, evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    # Sample a batch of data
    xb, yb = get_batch('train', batch_size=batch_size)


    # Forward pass, loss calculation, backpropagation, and optimization
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:

# Generate text with different initial contexts
context1 = torch.tensor([encode("thou art kneel before king".split(' '))], dtype=torch.long)
context2 = torch.tensor([encode("Hermione".split(' '))], dtype=torch.long)
context3 = torch.tensor([encode("come".split(' '))], dtype=torch.long)

In [None]:

# Print generated text using different contexts
print(decode(model.generate(context1, max_new_tokes=200)[0].tolist()))
print(decode(model.generate(context2, max_new_tokes=200)[0].tolist()))
print(decode(model.generate(context3, max_new_tokes=200)[0].tolist()))