In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define model hyperparameters
BATCH_SIZE = 32  # independent sequences processed per optimizer step
CONTEXT_WINDOW = 64  # max context length (in tokens) used for predictions
EPOCHS = 5000  # NOTE: counts optimizer steps (one mini-batch each), not full passes over the data
CHECKPOINT_INTERVAL = 500  # print train/val loss estimates every N steps (nothing is saved to disk)
LR = 3e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_ITERS = 200  # mini-batches averaged per loss estimate
EMBEDDING_DIM = 384
HEADS = 6
LAYERS = 6
DROPOUT_RATE = 0.2

# Each attention head gets EMBEDDING_DIM // HEADS channels; a remainder would silently be dropped.
assert EMBEDDING_DIM % HEADS == 0, "EMBEDDING_DIM must be divisible by HEADS"

torch.manual_seed(457);  # trailing ';' hides the returned Generator's repr in the cell output

<torch._C.Generator at 0x781c1869aa30>

In [None]:
# Load dataset
with open('stoic.txt', 'r', encoding='utf-8') as file:
    corpus = file.read()

if not corpus:
    # An empty corpus gives VOCAB_SIZE == 0 and a confusing failure much later.
    raise ValueError("stoic.txt is empty; nothing to train on")

# Character-level vocabulary: every distinct character in the corpus is one token id.
char_list = sorted(set(corpus))
VOCAB_SIZE = len(char_list)
char_to_index = {ch: i for i, ch in enumerate(char_list)}
index_to_char = dict(enumerate(char_list))


def encode_text(s):
    """Map a string to a list of token ids (raises KeyError for characters absent from the corpus)."""
    return [char_to_index[c] for c in s]


def decode_text(l):
    """Map an iterable of token ids back to the corresponding string."""
    return ''.join(index_to_char[i] for i in l)


# Train/validation split: the first 90% of tokens are used for training, the rest for validation.
data_tensor = torch.tensor(encode_text(corpus), dtype=torch.long)
split_idx = int(0.9 * len(data_tensor))
train_data, val_data = data_tensor[:split_idx], data_tensor[split_idx:]

In [None]:
# Function to generate mini-batches
def get_batch(mode):
    """Sample a random mini-batch of (input, next-token target) windows.

    Args:
        mode: 'train' to sample from train_data, 'val' to sample from val_data.

    Returns:
        (x_batch, y_batch): two long tensors of shape (BATCH_SIZE, CONTEXT_WINDOW) on DEVICE.
        y_batch is x_batch shifted by one position in the source text, so y_batch[b, t]
        is the token that follows x_batch[b, t].

    Raises:
        ValueError: if mode is not 'train'/'val', or the selected split is too short to
            hold one window plus its shifted target.
    """
    if mode == 'train':
        dataset = train_data
    elif mode == 'val':
        dataset = val_data
    else:
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

    # Start indices are drawn from [0, len - CONTEXT_WINDOW) so the target slice
    # (shifted by one) still ends within the dataset.
    max_start = len(dataset) - CONTEXT_WINDOW
    if max_start <= 0:
        raise ValueError(
            f"{mode} split has {len(dataset)} tokens; need more than CONTEXT_WINDOW={CONTEXT_WINDOW}"
        )
    idxs = torch.randint(max_start, (BATCH_SIZE,))
    x_batch = torch.stack([dataset[i:i + CONTEXT_WINDOW] for i in idxs])
    y_batch = torch.stack([dataset[i + 1:i + CONTEXT_WINDOW + 1] for i in idxs])
    return x_batch.to(DEVICE), y_batch.to(DEVICE)

In [None]:
@torch.no_grad()
def compute_loss():
    """Estimate the global `model`'s mean loss on the train and val splits.

    Averages EVAL_ITERS random mini-batches per split with dropout disabled (eval mode)
    and without autograd.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor holding the mean batch loss.
    """
    losses = {}
    was_training = model.training
    model.eval()
    try:
        for mode in ['train', 'val']:
            batch_losses = torch.zeros(EVAL_ITERS)
            for i in range(EVAL_ITERS):
                x, y = get_batch(mode)
                _, loss = model(x, y)
                batch_losses[i] = loss.item()
            losses[mode] = batch_losses.mean()
    finally:
        # Restore the caller's mode even if evaluation is interrupted; otherwise training
        # could silently continue with dropout switched off.
        model.train(was_training)
    return losses

In [None]:
class KVCache(nn.Module):
    """Identity module that keeps a snapshot of the last tensor passed through it.

    NOTE(review): despite the name, this is not an attention key/value cache. It stores
    copies of its *input* in `key_cache` / `value_cache`, and nothing in this notebook
    reads them back.
    """

    def __init__(self):
        super().__init__()
        self.key_cache = None    # copy of the most recent input, or None before the first call
        self.value_cache = None  # independent copy of the most recent input

    def forward(self, x):
        """Return x unchanged, recording detached copies of it."""
        # detach() so the stored copies do not keep the autograd graph of x alive
        # after backward() / between training steps.
        self.key_cache = x.detach().clone()
        self.value_cache = x.detach().clone()
        return x

In [None]:
class RoPE(nn.Module):
    """Add a learned absolute position embedding to its input.

    NOTE(review): despite the name, this is not rotary position embedding. Each position
    0..context_window-1 has a trainable vector that is added to the input, and the sum
    (input + position vector) is returned.
    """

    def __init__(self, context_window, embedding_dim):
        super().__init__()
        self.position_embedding_table = nn.Embedding(context_window, embedding_dim)

    def forward(self, x):
        """x: (batch, seq_len, embedding_dim); seq_len must not exceed context_window."""
        batch_size, seq_len = x.size(0), x.size(1)
        position_ids = torch.arange(seq_len, device=x.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_len)
        position_vectors = self.position_embedding_table(position_ids)
        return x + position_vectors

In [None]:
# Define a single head of causal (masked) self-attention
class Head(nn.Module):
    """A single head of causal scaled dot-product self-attention.

    Projects EMBEDDING_DIM-wide inputs to `head_dim`-wide queries, keys and values.
    Each position may attend only to itself and to earlier positions.
    """

    def __init__(self, head_dim):
        super().__init__()
        # Creation order (key, query, value) is unchanged so seeded initialisation matches.
        self.key = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.query = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.value = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 iff j <= i, i.e. position i may see position j.
        self.register_buffer('tril', torch.tril(torch.ones(CONTEXT_WINDOW, CONTEXT_WINDOW)))
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        """x: (B, T, EMBEDDING_DIM) with T <= CONTEXT_WINDOW; returns (B, T, head_dim)."""
        _, seq_len, _ = x.shape
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        scale = keys.shape[-1] ** -0.5  # standard 1/sqrt(head_dim) scaling
        scores = (queries @ keys.transpose(-2, -1)) * scale
        future_positions = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future_positions, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values

In [None]:
class MultiHeadAttention(nn.Module):
    """Run `num_heads` causal attention heads in parallel and merge their outputs.

    Head outputs are concatenated on the last dimension (num_heads * head_size wide),
    projected back to EMBEDDING_DIM, and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        concat_dim = head_size * num_heads
        self.proj = nn.Linear(concat_dim, EMBEDDING_DIM)
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        merged = torch.cat(head_outputs, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

In [None]:
class Expert(nn.Module):
    """One expert of MixtureOfExperts: a thin wrapper around a FeedForward network.

    FeedForward is defined in a later cell; it is only looked up when an Expert is constructed.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.ffwd = FeedForward(embedding_dim)

    def forward(self, x):
        expert_output = self.ffwd(x)
        return expert_output

class MixtureOfExperts(nn.Module):
    """Dense "mixture": every expert processes every token and the outputs are averaged.

    NOTE(review): there is no router or gating network, so every expert runs on every
    token and each contributes with weight 1/num_experts. This is a uniform ensemble
    of FeedForward experts, not a sparse top-k MoE.

    Args:
        num_experts: number of experts (must be >= 1).
        embedding_dim: input/output width of each expert.

    Raises:
        ValueError: if num_experts < 1 (torch.stack would otherwise fail later, at forward time).
    """

    def __init__(self, num_experts, embedding_dim):
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.experts = nn.ModuleList([Expert(embedding_dim) for _ in range(num_experts)])

    def forward(self, x):
        # Stack along a new leading axis (num_experts, ...) and take the uniform mean over it.
        expert_outputs = torch.stack([expert(x) for expert in self.experts])
        return torch.mean(expert_outputs, dim=0)

In [None]:
# Define feed-forward block
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d -> 4d) -> ReLU -> Linear(4d -> d) -> Dropout."""

    def __init__(self, embedding_dim):
        super().__init__()
        hidden_dim = 4 * embedding_dim
        # Layer order is unchanged, so `net.0` / `net.2` parameter names and init order match.
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(DROPOUT_RATE),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        mlp = self.net
        return mlp(x)

In [None]:
# Define the main block incorporating all components
class Block(nn.Module):
    """Pre-LayerNorm Transformer block: x + MHA(LN1(x)), then x + FFN(LN2(x))."""

    def __init__(self, embedding_dim, heads):
        super().__init__()
        head_size = embedding_dim // heads  # each head gets an equal slice of the embedding
        # Submodule creation order is unchanged so seeded initialisation matches.
        self.mha = MultiHeadAttention(heads, head_size)
        self.ffwd = FeedForward(embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        # Residual connection around attention, then around the feed-forward network.
        attended = x + self.mha(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

In [None]:
# Define the language model
class LanguageModel(nn.Module):
    """Character-level decoder-only Transformer language model.

    Pipeline:
        token embedding -> learned position embedding -> LAYERS x Block
        -> final LayerNorm -> MixtureOfExperts -> linear head over VOCAB_SIZE.
    """

    def __init__(self):
        super().__init__()
        # Submodule creation order is unchanged: construction and self.apply(...) both
        # draw from the seeded RNG, so reordering would change the initial weights.
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.position_embedding = RoPE(CONTEXT_WINDOW, EMBEDDING_DIM)  # returns input + position vectors
        self.blocks = nn.Sequential(*[Block(EMBEDDING_DIM, HEADS) for _ in range(LAYERS)])
        self.ln_f = nn.LayerNorm(EMBEDDING_DIM)
        self.lm_head = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)
        self.kv_cache = KVCache()  # identity module that snapshots its input; snapshots are not read here
        # NOTE(review): applied after ln_f with no residual connection; it is a uniform average of FFNs.
        self.mo_experts = MixtureOfExperts(num_experts=2, embedding_dim=EMBEDDING_DIM)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear/Embedding weights from N(0, 0.02) and zero Linear biases; other modules keep defaults."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: long tensor of token ids, shape (B, T) with T <= CONTEXT_WINDOW.
            targets: optional long tensor of shape (B, T) with the expected next tokens.

        Returns:
            (logits, loss). Without targets: logits of shape (B, T, VOCAB_SIZE) and loss None.
            With targets: logits flattened to (B * T, VOCAB_SIZE) and the mean cross-entropy loss.

        Raises:
            ValueError: if T exceeds CONTEXT_WINDOW (the position table has no rows beyond it).
        """
        B, T = idx.shape
        if T > CONTEXT_WINDOW:
            raise ValueError(f"sequence length {T} exceeds CONTEXT_WINDOW={CONTEXT_WINDOW}")

        tok_emb = self.token_embedding_table(idx)
        # position_embedding already returns tok_emb + position vectors. The old code then
        # added tok_emb a second time, which doubled the token embedding.
        x = self.position_embedding(self.kv_cache(tok_emb))
        x = self.blocks(x)
        x = self.ln_f(x)
        x = self.mo_experts(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        _, _, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens and append them to `idx`.

        Sampling runs in eval mode (dropout off) without autograd. The module's previous
        train/eval mode is restored afterwards.

        Args:
            idx: long tensor of token ids, shape (B, T), used as the prompt.
            max_new_tokens: number of tokens to sample.

        Returns:
            Long tensor of shape (B, T + max_new_tokens).
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last CONTEXT_WINDOW tokens fit the position table.
                idx_cond = idx[:, -CONTEXT_WINDOW:]
                logits, _ = self(idx_cond)
                probs = F.softmax(logits[:, -1, :], dim=-1)  # distribution over the next token
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

In [None]:
model = LanguageModel()
m = model.to(DEVICE)  # Module.to() moves parameters in place and returns the same module, so m is model

param_count = sum(p.numel() for p in m.parameters())
print(param_count / 1e6, 'M parameters')

# AdamW (Adam with decoupled weight decay), using PyTorch's default weight_decay
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    # Periodically, and on the very last step, report averaged train/val loss estimates.
    report_now = epoch % CHECKPOINT_INTERVAL == 0 or epoch == EPOCHS - 1
    if report_now:
        losses = compute_loss()
        print(f"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    # Standard step: clear old grads, backpropagate, update parameters.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

13.098331 M parameters
step 0: train loss 4.5188, val loss 4.5200
step 500: train loss 1.9286, val loss 2.0548
step 1000: train loss 1.6357, val loss 1.8127
step 1500: train loss 1.5044, val loss 1.6995
step 2000: train loss 1.4311, val loss 1.6429
step 2500: train loss 1.3622, val loss 1.5822
step 3000: train loss 1.3384, val loss 1.5664
step 3500: train loss 1.3074, val loss 1.5485
step 4000: train loss 1.2804, val loss 1.5213
step 4500: train loss 1.2604, val loss 1.5143
step 4999: train loss 1.2448, val loss 1.4903


In [None]:
# Generate 500 new characters from the model, prompting with a single token id 0
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
generated_ids = m.generate(context, max_new_tokens=500)[0].tolist()
print(decode_text(generated_ids))


to his stud the world bad so. But in last nall
and she advangratual on the whole of them to imagine to a philosophers than so
in letter plaination, and seemed it set read into his
infitending with with our plying impactive it formed from Ferecultus, Regood Horant they when he don possessions, we
ought not such to any man do who lives, at fully? But because we be lot to happy thus, where is able dogs in from the tosphurden lament? why
did I become explee with paulan, we do not cause among himself


In [None]:
# Sample 10,000 new tokens from the same prompt and save the decoded text.
# Use UTF-8 explicitly (the corpus was read as UTF-8), and close the file deterministically.
long_sample = decode_text(m.generate(context, max_new_tokens=10000)[0].tolist())
with open('sasta2.0_v2_stoic.txt', 'w', encoding='utf-8') as sample_file:
    chars_written = sample_file.write(long_sample)
chars_written  # number of characters written (prompt + generated), shown as the cell output

10001