In [1]:
import torch

import torch.nn as nn
from torch.nn import functional as F

import pickle

from retention import SimpleRetention, MultiScaleRetention
from retnet import RetNet


# Import the 'einops' library
import einops
from einops import rearrange, reduce, repeat


  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),


In [2]:

# Download a text file from a GitHub repository
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Open the downloaded file for reading with UTF-8 encoding
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
len(text)
print(text[:1000])

--2024-02-26 12:30:14--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8001::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1,1M) [text/plain]
Saving to: ‘input.txt’


2024-02-26 12:30:15 (1,37 MB/s) - ‘input.txt’ saved [1115394/1115394]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

F

Save raw text :

In [5]:
with open("./../data/shakespeare.txt", "wb") as f:
    pickle.dump(text, f)

Re-load :

In [2]:
with open("./../data/shakespeare.txt", "rb") as f:
    text = pickle.load(f)

In [3]:

# Function to create a decay matrix with a specified dimension and gamma values
def get_decay_matrix(dim, gamma):
    d = torch.ones(dim)
    d = torch.tril(d)
    for index, head in enumerate(d):
        g = gamma[index]
        for idx, x in enumerate(torch.tril(head)):
            for idy, y in enumerate(x):
                if idx >= idy:
                    head[idx][idy] = g ** (idx-idy)
    return d

In [4]:

# Data loading function to get input (x) and target (y) batches
def get_batch(split, batch_size, train_data, val_data, block_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

----

In [5]:
# Hyperparameters
batch_size = 16
seq_len = 20
max_iters = 100
eval_interval = 10
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 20
n_head = 4
n_layer = 2
dropout = 0.0
n_embed = 8

In [6]:
vocab_size = 65

In [7]:
xb_test = torch.tensor(torch.zeros([3, seq_len]), dtype=torch.long)
yb_test = torch.tensor(torch.zeros([3, seq_len]), dtype=torch.long)

  xb_test = torch.tensor(torch.zeros([3, seq_len]), dtype=torch.long)
  yb_test = torch.tensor(torch.zeros([3, seq_len]), dtype=torch.long)


In [8]:
class BigRetNet(nn.Module):
    def __init__(self, n_layer, n_embed, ffn_size, n_head, vocab_size, seq_len):
        super().__init__()

        self.retnet = RetNet(layers=n_layer, hidden_dim=n_embed, ffn_size=ffn_size, heads=n_head)

        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(seq_len, n_embed)
        self.blocks = nn.Sequential(*[RetNet(layers=n_layer, hidden_dim=n_embed, ffn_size=ffn_size, heads=n_head) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    

    def forward_recurrent(self, idx):
        pass
    
    """def train_parallelize(self, idx, targets):
        logits = self(idx)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss"""

In [9]:
# Initialize the RetNet model
model = BigRetNet(n_layer, n_embed, 2*n_embed, n_head, vocab_size, seq_len)

In [10]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [11]:
output, loss = model(xb_test, yb_test)

In [12]:
output.shape

torch.Size([60, 65])

In [13]:
loss

tensor(3.3119, grad_fn=<NllLossBackward0>)

----

In [14]:

# Create a sorted list of unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [15]:

# Create character-to-index and index-to-character mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# Functions to encode and decode text
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[x] for x in l])

In [16]:

encode("hi there")

[46, 47, 1, 58, 46, 43, 56, 43]

In [77]:
decode([46, 47, 1, 58, 46, 43, 56, 43])

'hi there'

In [17]:
# Convert the text to a PyTorch tensor of character indices
data = torch.tensor(encode(text), dtype=torch.long)

In [18]:

# Split the data into training and validation sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [19]:
# Function to estimate loss on train and validation sets
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size=batch_size, train_data=train_data, val_data=val_data, block_size=seq_len)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size, train_data=train_data, val_data=val_data, block_size=seq_len)


# Forward pass and loss calculation
logits, loss = model(xb, yb)

In [20]:
batch_size * seq_len

320

In [21]:
logits.shape

torch.Size([320, 65])

In [22]:

# Training loop
for iter in range(max_iters):
    # Every once in a while, evaluate the loss on train and val sets
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    # Sample a batch of data
    xb, yb = get_batch('train', batch_size=batch_size, train_data=train_data, val_data=val_data, block_size=seq_len)


    # Forward pass, loss calculation, backpropagation, and optimization
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.6459, val loss 4.6237
step 99: train loss 3.7512, val loss 3.7784


STOP - TO DO NEXT

---

In [38]:

# Create a context for text generation
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Generate text using the model
print(decode(model.generate(context, max_new_tokes=200)[0].tolist()))


Thp; as mth dofowhiicaye dw RDUo,
Se thixghad pe tellldeaseadm f gad,
OWINANTh o, w Le E aps hpor ale,
An bl Bou by slor!
TIThandellatr gonghed ty Myoll cat tomillitu wiswingoblthithusferd win 'LDIUMY


In [39]:

# Create another context for text generation
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Generate more text using the model
print(decode(model.generate(context, max_new_tokes=200)[0].tolist()))


A
L torn d s y, whr, rdend t Thaly, athaturOFr,
ICINRINGENatora lbst:
Hir ydorand.
The cheedustocngurgothiserisdr ttortrcoryof sped s mswithary ithoknwe l
Borifry lyolo foou;
ICHUCUWA beenomz,
Tin ies


----

In [40]:

# Install the 'tiktoken' library
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting regex>=2022.1.18 (from tiktoken)
  Downloading regex-2023.12.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.9/40.9 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests>=2.26.0 (from tiktoken)
  Downloading requests-2.31.0-py3-none-any.whl.metadata (4.6 kB)
Collecting charset-normalizer<4,>=2 (from requests>=2.26.0->tiktoken)
  Downloading charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (33 kB)
Collecting idna<4,>=2.5 (from requests>=2.26.0->tiktoken)
  Downloading idna-3.6-py3-none-any.whl.metadata (9.9 kB)
Collecting urllib3<3,>=1.21.1 (from requests>=2.26.0->tiktoken)
  Downloading urllib3-2.2.1-py3-none-any.whl.metadata (6.4 kB)
Collecting certifi>=2017.4.17 (from requests>=2.26.0->tiktok

In [245]:

# Import the 'tiktoken' library
import tiktoken
# Get the encoding for a specific model
enc = tiktoken.get_encoding("r50k_base")

In [246]:
enc

<Encoding 'r50k_base'>

In [247]:

# Assert that encoding and decoding work correctly
assert enc.decode(enc.encode("hello world")) == "hello world"

In [248]:
enc.encode("hello world")

[31373, 995]

In [147]:
"""
# To get the tokeniser corresponding to a specific model in the OpenAI API:
enc = tiktoken.encoding_for_model("gpt-4")"""

In [148]:
"""
# Assert that encoding and decoding work correctly for the new model
assert enc.decode(enc.encode("hello world")) == "hello world" """

In [249]:

# Encode "hello world" using the tokeniser
enc.encode("hello world")

[31373, 995]

In [250]:
len(text)

1115394

In [251]:
text_sub = text[:5000]

In [252]:

# Count the number of tokens in the text
text_tokens = enc.encode(text_sub)
len(text_tokens)

1393

In [253]:
text_tokens

[5962,
 22307,
 25,
 198,
 8421,
 356,
 5120,
 597,
 2252,
 11,
 3285,
 502,
 2740,
 13,
 198,
 198,
 3237,
 25,
 198,
 5248,
 461,
 11,
 2740,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 1639,
 389,
 477,
 12939,
 2138,
 284,
 4656,
 621,
 284,
 1145,
 680,
 30,
 198,
 198,
 3237,
 25,
 198,
 4965,
 5634,
 13,
 12939,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 5962,
 11,
 345,
 760,
 327,
 1872,
 385,
 1526,
 28599,
 318,
 4039,
 4472,
 284,
 262,
 661,
 13,
 198,
 198,
 3237,
 25,
 198,
 1135,
 760,
 470,
 11,
 356,
 760,
 470,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 5756,
 514,
 1494,
 683,
 11,
 290,
 356,
 1183,
 423,
 11676,
 379,
 674,
 898,
 2756,
 13,
 198,
 3792,
 470,
 257,
 15593,
 30,
 198,
 198,
 3237,
 25,
 198,
 2949,
 517,
 3375,
 319,
 470,
 26,
 1309,
 340,
 307,
 1760,
 25,
 1497,
 11,
 1497,
 0,
 198,
 198,
 12211,
 22307,
 25,
 198,
 3198,
 1573,
 11,
 922,
 4290,
 13,
 198,
 198,
 5962,
 22307,
 25,
 198,
 1135,
 389,
 17830,
 3595,
 4290,
 11,
 262,
 1458,


In [254]:

# Create a sorted list of unique characters in the text
chars = sorted(list(set(text_tokens)))
vocab_size = len(chars)
vocab_size

522

In [255]:

# Decode the first token in the text
enc.decode([text_tokens[0]])

'First'

In [256]:

data = torch.tensor(text_tokens, dtype=torch.long)
data.shape

torch.Size([1393])

In [258]:
learning_rate = 3e-4

In [54]:
"""
chars = sorted(list(set(text.split(' '))))
vocab_size = len(chars)"""

In [55]:
"""vocab_size"""

42197

In [59]:
"""
chars[100]"""

"'banished'?\n\nFRIAR"

In [66]:
"""
# Create word-to-index and index-to-word mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


# Functions to encode and decode words
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: " ".join([itos[x] for x in l])
# Encode words using the mappings
data = torch.tensor(encode(text.split(' ')), dtype = torch.long)
# Display the first 10 tokens in the data
data[:10]"""

tensor([ 1455,   957, 39874, 29614,  5949, 16628, 18572, 24432, 34050, 34057])

In [73]:
"""
# Decode the first 10 tokens in the data
decode(encode(text.split("\n")[:2]))"""

KeyError: 'First Citizen:'

In [213]:
len(data)

134353

In [259]:

# Split the data into training and validation sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [260]:
len(val_data)

140

In [261]:

# Set hyperparameters for the model
batch_size = 8
block_size = 16

In [262]:
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32
n_embed = 32
n_head = 4
n_layer = 4
dropout = 0.0

In [267]:
# Initialize the RetNet model
model = RetNet(block_size=block_size)
# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size)
# Initialize the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [268]:
xb.shape

torch.Size([16, 32])

In [269]:

# Forward pass and loss calculation
logits, loss = model(xb, yb)
loss.shape

IndexError: index out of range in self

In [234]:
xb[0:3]

tensor([[ 3957,   956,   264, 36543,  1980,  2460,   512,  2822,   810,  7556,
           389,   956,    26,  1095,   433,   387],
        [  382,  2460,   512, 96945,    11,  6604,   382,  5451, 47317,   512,
          2675,   527,   682, 20250,  4856,   311],
        [ 9354,   311,   279,  1274,   382,  2460,   512,  1687,  1440,   956,
            11,   584,  1440,   956,   382,  5451]])

In [102]:
yb[0:3]

tensor([[  280,  3112,   358,   656,  3987,  1695,  2919,   323,  1317,   311,
          1518,   382, 58163,    44,  3895,   512,    46, 28146,    11,  1778,
           264,  2324,    11,   449,  1778,   264,  7555,    11,  1051, 15234,
          4999,  4071],
        [46811,    11, 24613,  2277,   757,   198,  1962,   311,  5622,   279,
         96923,   382, 16041, 52483,   261,   512, 18293,   279, 38736,   304,
         26236,  4059,    11,   323, 48839,  1461,   539,    25,   568,   198,
         41450,  1672],
        [ 4648,    11,   719,  2547,   596,  9120, 16409,   382,  3442,  6903,
           512, 34042,    11,  9120, 16409,     0,   387, 16888,  5092,    11,
          2019,   364, 63007, 99419,  2520, 61087, 52677,   810,  8818,   304,
           813,  1427]])

In [108]:

# Get a batch of training data
xb, yb = get_batch('train', batch_size=batch_size)


# Forward pass and loss calculation
logits, loss = model(xb, yb)

IndexError: index out of range in self

In [98]:

# Forward pass and loss calculation
logits, loss = model(xb, yb)

IndexError: index out of range in self

In [None]:

# Training loop
for iter in range(max_iters):
    # Every once in a while, evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    # Sample a batch of data
    xb, yb = get_batch('train', batch_size=batch_size)


    # Forward pass, loss calculation, backpropagation, and optimization
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:

# Generate text with different initial contexts
context1 = torch.tensor([encode("thou art kneel before king".split(' '))], dtype=torch.long)
context2 = torch.tensor([encode("Hermione".split(' '))], dtype=torch.long)
context3 = torch.tensor([encode("come".split(' '))], dtype=torch.long)

In [None]:

# Print generated text using different contexts
print(decode(model.generate(context1, max_new_tokes=200)[0].tolist()))
print(decode(model.generate(context2, max_new_tokes=200)[0].tolist()))
print(decode(model.generate(context3, max_new_tokes=200)[0].tolist()))