In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-12-27 12:00:23--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-12-27 12:00:24 (79.8 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [None]:
import torch
from torch.nn import functional as F

In [None]:
# Number of leading characters of input.txt that are kept as the corpus.
corpus_length = 10000
# Explicit encoding: the default text encoding depends on the platform locale.
with open("input.txt", "r", encoding="utf-8") as file:
  # read(n) stops after n characters instead of loading the whole file and
  # slicing it afterwards; the resulting string is the same.
  corpus = file.read(corpus_length)


In [None]:
# End-of-sequence marker: appended to every training example, and also used by
# texts2seqs as the padding token and as the fallback id for unknown words.
EOS = "<eos>"

## data loading

In [None]:
# Build the training examples: every non-empty corpus line contributes
# word-level prefixes, each terminated by the EOS marker.
lines = [line for line in corpus.split('\n') if line]
examples = []
for line in lines:
  words = line.split()
  # NOTE(review): the upper bound stops two words short of the full line, so
  # the last two words of a line never appear in the examples built from it
  # and lines with fewer than three words contribute nothing. Left unchanged;
  # use range(1, len(words) + 1) if every prefix is wanted.
  examples.extend([" ".join(words[:i] + [EOS]) for i in range(1, len(words) - 1)])

# Sort the vocabulary so that word ids are identical across kernel restarts
# (set iteration order of strings is not); otherwise saved weights cannot be
# matched to words in a later session.
vocabulary = sorted({word for example in examples for word in example.split()})
word_index = {word: i for i, word in enumerate(vocabulary)}
# Derive the inverse mapping from word_index so the two can never disagree.
index_word = {i: word for word, i in word_index.items()}
# Longest example measured in tokens (EOS included). The previous
# len(max(examples)) took the character count of the lexicographically
# largest string, which is unrelated to the sequence length.
maximum_sequence_length = max(len(example.split()) for example in examples)
num_words = len(word_index)

## helper functions for conversions

In [None]:
def texts2seqs(texts, word_index, maximum_sequence_length):
  """Encode texts as a batch of fixed-length word-id sequences.

  Args:
    texts: iterable of whitespace-separated strings.
    word_index: dict mapping word -> integer id; must contain the EOS marker.
    maximum_sequence_length: number of ids in every output row.

  Returns:
    A torch.long tensor of shape (len(texts), maximum_sequence_length).
    Words missing from `word_index` are encoded as EOS, short texts are
    right-padded with EOS, and texts longer than `maximum_sequence_length`
    are truncated so rows always have the same length.

  Raises:
    KeyError: if EOS is not a key of `word_index`.
  """
  eos_id = word_index[EOS]
  seqs = []
  for text in texts:
    seq = [word_index.get(word, eos_id) for word in text.split()]
    # Truncate first: a negative pad count would silently leave the row too
    # long, producing ragged batches and out-of-range position ids downstream.
    seq = seq[:maximum_sequence_length]
    seq.extend([eos_id] * (maximum_sequence_length - len(seq)))
    seqs.append(seq)
  # Explicit dtype/shape so an empty `texts` still yields an integer (0, L) tensor.
  return torch.tensor(seqs, dtype=torch.long).reshape(len(seqs), maximum_sequence_length)

def seqs2texts(seqs, index_word):
  """Decode sequences of word ids back into space-joined strings.

  Each element of a sequence must expose `.item()`. Ids missing from
  `index_word` decode to EOS, and decoding of a sequence stops right after
  its first EOS (which is kept in the output).
  """
  def decode(seq):
    tokens = []
    for idx in seq:
      token = index_word.get(idx.item())
      token = EOS if token is None else token
      tokens.append(token)
      if token == EOS:
        # Everything after the first EOS is padding / ignored.
        break
    return ' '.join(tokens)

  return [decode(seq) for seq in seqs]

def seqs2onehot(seqs, num_classes):
  """One-hot encode a batch of id sequences as float32.

  Args:
    seqs: an integer tensor of ids, or an iterable of equally long id
      sequences (tensors or lists).
    num_classes: size of the one-hot dimension.

  Returns:
    A float32 tensor with one extra trailing dimension of size
    `num_classes`, on the same device as the input tensor.
  """
  if not torch.is_tensor(seqs):
    seqs = list(seqs)
    if not seqs:
      return torch.zeros((0, 0, num_classes), dtype=torch.float32)
    seqs = torch.stack([torch.as_tensor(seq) for seq in seqs])
  # A single vectorized call replaces the per-row numpy round trip, which was
  # slow and raised for CUDA tensors (Tensor.numpy() only works on CPU).
  return F.one_hot(seqs, num_classes=num_classes).type(torch.float32)

def dense2seqs(dense):
  """Sample one class index per row from a 3-D tensor of sampling weights.

  `dense` is unpacked as three dimensions; for every entry along the first
  one, torch.multinomial draws a single index from each of its rows.
  Returns a list with one 1-D index tensor per entry of the first dimension.
  """
  batch_size, _, _ = dense.shape  # the unpacking also rejects non-3-D input
  return [
      torch.multinomial(dense[row], num_samples=1).squeeze(1)
      for row in range(batch_size)
  ]

## dataset preparation

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
class MyDataset(Dataset):
  """Map-style dataset that encodes one text example per index."""

  def __init__(self, examples, word_index, maximum_sequence_length):
    # Encoding is deferred to __getitem__; only references are stored here.
    self.maximum_sequence_length = maximum_sequence_length
    self.word_index = word_index
    self.examples = examples

  def __len__(self):
    """Number of stored examples."""
    return len(self.examples)

  def __getitem__(self, idx):
    """Encode example `idx` via texts2seqs and return the single resulting row."""
    encoded_batch = texts2seqs(
        [self.examples[idx]],
        self.word_index,
        self.maximum_sequence_length,
    )
    return encoded_batch[0]

# Wrap the prefix examples in a dataset and serve them in shuffled mini-batches.
dataset = MyDataset(
    examples=examples,
    word_index=word_index,
    maximum_sequence_length=maximum_sequence_length,
)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)


## model architecture

In [None]:
from torch import nn

In [None]:
class FFN(nn.Module):
  """Position-wise feed-forward layer: one square Linear followed by ReLU."""

  def __init__(self, embedding_dim):
    super().__init__()
    projection = nn.Linear(embedding_dim, embedding_dim)
    activation = nn.ReLU()
    self.layer = nn.Sequential(projection, activation)

  def forward(self, x):
    """Apply the Linear + ReLU stack to `x`."""
    hidden = self.layer(x)
    return hidden

class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention with an output projection.

  Args:
    embedding_dim: size of the last dimension of the input and of the output.
    query_dim: per-head size of the query / key / value projections.
    num_heads: number of attention heads.
    causal: when True (default) position t only attends to positions <= t,
      which is what autoregressive next-word sampling requires. Pass False
      for unmasked (bidirectional) attention.

  Parameter names (w_qs, w_ks, w_vs, w) are unchanged, so existing
  state_dicts keep loading.
  """
  def __init__(self, embedding_dim, query_dim, num_heads, causal=True):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_dim = query_dim
    self.num_heads = num_heads
    self.causal = causal

    def scaled_randn(fan_in, fan_out):
      # Unit-variance weights multiply the activation scale by ~sqrt(fan_in),
      # which drives the softmax into saturation; scale them down instead.
      return nn.Parameter(torch.randn(fan_in, fan_out) * fan_in ** -0.5)

    self.w_qs = nn.ParameterList([
            scaled_randn(embedding_dim, query_dim) for _ in range(num_heads)
        ])
    self.w_ks = nn.ParameterList([
            scaled_randn(embedding_dim, query_dim) for _ in range(num_heads)
        ])
    self.w_vs = nn.ParameterList([
            scaled_randn(embedding_dim, query_dim) for _ in range(num_heads)
        ])
    self.w = scaled_randn(num_heads * query_dim, embedding_dim)

  def forward(self, x):
    """Attend over `x` of shape (B, T, embedding_dim); returns the same shape."""
    T = x.shape[1]
    # Scores are dot products of query_dim-sized vectors, so that is the
    # dimension to scale by (not the embedding width).
    scale = self.query_dim ** -0.5
    future = None
    if self.causal:
      # True above the diagonal = key position lies in the query's future.
      future = torch.triu(
          torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
      )
    zs = []
    for w_q, w_k, w_v in zip(self.w_qs, self.w_ks, self.w_vs):
      q = x @ w_q
      k = x @ w_k
      v = x @ w_v
      scores = (q @ k.transpose(1, 2)) * scale  # (B, T_query, T_key)
      if future is not None:
        scores = scores.masked_fill(future, float("-inf"))
      # Normalize over the key axis so each query's weights sum to 1
      # (dim=1 normalized over queries instead).
      zs.append(F.softmax(scores, dim=-1) @ v)
    zs = torch.cat(zs, dim=2)
    return zs @ self.w

class TrainablePositionalEncoding(nn.Module):
  """Learned positional encoding: an embedding table indexed by position id."""

  def __init__(self, num_positions, embedding_dim):
    super().__init__()
    self.emb = nn.Embedding(num_embeddings=num_positions, embedding_dim=embedding_dim)

  def forward(self, x):
    """Look up the embedding row for every position id in `x`."""
    position_vectors = self.emb(x)
    return position_vectors


class Block(nn.Module):
  """Transformer block: self-attention and feed-forward sub-layers, each
  wrapped in a residual connection followed by layer normalization.

  Args:
    embedding_dim: size of the last dimension of the block's input/output.
    query_dim: per-head projection size passed to MultiheadAttention.
    num_heads: number of attention heads.
  """
  def __init__(self, embedding_dim, query_dim, num_heads):
    super().__init__()
    self.multihead_attention = MultiheadAttention(embedding_dim, query_dim, num_heads)
    self.ffn = FFN(embedding_dim)

  def forward(self, x):
    """Transform `x` of shape (B, T, C); the output has the same shape."""
    _, _, C = x.shape
    # Normalize each token's feature vector on its own. Using (B, T, C) as the
    # normalized shape computed one mean/variance over the whole batch, which
    # mixes samples and positions and makes the output depend on batch content.
    y = self.multihead_attention(x)
    x = F.layer_norm(x + y, normalized_shape=(C,))

    y = self.ffn(x)
    return F.layer_norm(x + y, normalized_shape=(C,))

class Model(nn.Module):
  """Word-level transformer: token + learned position embeddings, a stack of
  Blocks, and a two-layer head producing a distribution over the vocabulary.

  Args:
    num_positions: size of the positional table, i.e. the longest sequence
      the model accepts.
    num_words: vocabulary size.
    embedding_dim: width of the token / position embeddings.
    query_dim: per-head projection size inside each Block.
    num_heads: attention heads per Block.
    num_blocks: number of stacked Blocks.
  """
  def __init__(self, num_positions, num_words, embedding_dim, query_dim, num_heads, num_blocks=6):
    super().__init__()
    self.num_positions = num_positions
    self.embedding_dim = embedding_dim
    self.query_dim = query_dim
    self.num_heads = num_heads
    self.positional_encoding = TrainablePositionalEncoding(num_positions, embedding_dim)
    self.embedding = nn.Embedding(num_words, embedding_dim)
    # ModuleList is the container meant for sub-modules; ParameterList is for
    # Parameters and rejects Modules on older PyTorch releases. The
    # state_dict keys ("blocks.<i>. ...") are the same either way.
    self.blocks = nn.ModuleList([Block(embedding_dim, query_dim, num_heads) for _ in range(num_blocks)])
    self.last_layer = nn.Sequential(
        nn.Linear(embedding_dim, embedding_dim),
        nn.ReLU(),
        nn.Linear(embedding_dim, num_words),
    )

  def forward(self, x):
    """Map word ids of shape (B, T) to probabilities of shape (B, T, num_words).

    The output is already softmax-normalized (callers sample from it
    directly), so take its log before feeding a negative log-likelihood loss.

    Raises:
      IndexError: if T exceeds `num_positions`.
    """
    B, T = x.shape
    if T > self.num_positions:
      # Fail early with a readable message; on GPU the out-of-range embedding
      # lookup otherwise surfaces as an opaque device-side assert.
      raise IndexError(
          f"sequence length {T} exceeds the {self.num_positions} positions the model was built with"
      )
    x = self.embedding(x)
    x = x + self.positional_encoding(torch.arange(T, device=x.device))
    for block in self.blocks:
      x = block(x)
    x = self.last_layer(x)
    x = F.softmax(x, dim=2)
    return x

  def summary(self):
    """Print the total and the trainable parameter counts."""
    # Count the number of parameters
    num_params = sum(p.numel() for p in self.parameters())
    print(f"Total number of parameters: {num_params}")

    # Count only trainable parameters
    num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {num_trainable_params}")


# Model hyperparameters, passed to Model(...) when the network is built.
embedding_dim = 512  # width of the token / position embeddings
query_dim = 64       # per-head projection size
num_heads = 6
num_blocks = 32

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
def infer(model, start_word, num_words):
  """Sample a continuation of `start_word` from the model.

  Args:
    model: callable mapping (1, T) word ids to (1, T, vocab) probabilities;
      it is switched to eval mode.
    start_word: prompt text; only its first word is used as the seed token.
    num_words: maximum number of words to generate.

  Returns:
    A one-element list with the decoded text (cut after the first EOS).
  """
  model.eval()
  tokens = texts2seqs([start_word], word_index, maximum_sequence_length)[:, :1]
  tokens = tokens.to(device)
  eos_id = word_index.get(EOS)
  with torch.no_grad():
    for _ in range(num_words):
      # The model has only maximum_sequence_length position embeddings, so
      # feed at most that many trailing tokens; a longer input indexes past
      # the positional table.
      context = tokens[:, -maximum_sequence_length:]
      probs = model(context)
      probs = probs[:, -1, :].unsqueeze(1)
      next_token = dense2seqs(probs)[0]
      tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=1)
      if next_token.item() == eos_id:
        # seqs2texts drops everything after the first EOS anyway.
        break
  texts = seqs2texts(tokens, index_word)
  return texts

In [None]:
# Build the network from the hyperparameters above and move it to `device`.
model = Model(
    num_positions=maximum_sequence_length,
    num_words=num_words,
    embedding_dim=embedding_dim,
    query_dim=query_dim,
    num_heads=num_heads,
    num_blocks=num_blocks,
)
model = model.to(device)



In [None]:
model.train()
num_epochs = 100
lr = 3e-4
log_every = 10  # report the mean loss over this many batches
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# The model already returns softmax probabilities, so the loss must be a
# negative log-likelihood on their log. CrossEntropyLoss would apply a second
# softmax, and it expects the class axis at dim 1, not last.
loss_fn = nn.NLLLoss()

for epoch in range(num_epochs):
  total_loss = 0
  for batch, x in enumerate(loader):
    x = x.to(device)
    # Next-word objective: position t is trained to predict token t + 1,
    # matching infer(), which samples the next word from the last position.
    inputs, targets = x[:, :-1], x[:, 1:]
    probs = model(inputs)  # (B, T - 1, vocab)
    log_probs = torch.log(probs.clamp_min(1e-9))  # clamp avoids log(0)
    optimizer.zero_grad()
    # Flatten batch and time so every position is one classification sample.
    loss = loss_fn(log_probs.reshape(-1, log_probs.size(-1)), targets.reshape(-1))
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    # Report after exactly `log_every` batches so the divisor matches the
    # number of accumulated losses (previously `% 9` paired with `/ 10`).
    if (batch + 1) % log_every == 0:
      print(f"Epoch: {epoch}, Batch: {batch}, Loss: {total_loss/log_every}")
      total_loss = 0


In [None]:
# Sample up to 10 words following the seed word.
infer(model, start_word='almost', num_words=10)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Persist only the learned weights (the state_dict), not the module object.
torch.save(obj=model.state_dict(), f="model_state.pth")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
