In [None]:
pip install PyPDF2

Collecting PyPDF2
  Downloading pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Downloading pypdf2-3.0.1-py3-none-any.whl (232 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/232.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.7/232.6 kB[0m [31m2.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: PyPDF2
Successfully installed PyPDF2-3.0.1


In [None]:

# %%
import os
import re
import math
import torch
import random
import pathlib
import itertools
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
import PyPDF2  # for PDF text extraction


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [None]:
def extract_text_from_pdf(pdf_path):
    reader = PyPDF2.PdfReader(pdf_path)
    pages = [page.extract_text() for page in reader.pages]
    return "\n".join(pages)



In [None]:
def clean_text(text):
    # Strip Gutenberg markers
    text = re.sub(r"\*\*\* START.*?\*\*\*", "", text, flags=re.S)
    text = re.sub(r"\*\*\* END.*", "", text, flags=re.S)
    # Keep letters, digits, basic punctuation
    text = re.sub(r"[^\w\s\.,;:'\"\?!-]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.lower().strip()

In [None]:
def get_stats(vocab):
    pairs = {}
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pair = (symbols[i], symbols[i+1])
            pairs[pair] = pairs.get(pair, 0) + freq
    return pairs


In [None]:

def merge_vocab(pair, vocab):
    v_out = {}
    bigram = ' '.join(pair)
    replacement = ''.join(pair)
    for word, freq in vocab.items():
        w_out = re.sub(rf"(?<=\b){re.escape(bigram)}(?=\b)", replacement, word)
        v_out[w_out] = freq
    return v_out

In [None]:
def learn_bpe(text, num_merges=5000):
    words = text.split()
    vocab = { ' '.join(list(w)) + ' </w>': words.count(w) for w in set(words) }
    merges = []
    for _ in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
        merges.append(best)
    return merges

In [None]:
class BPEEncoder:
    def __init__(self, merges):
        self.merges = merges
        tokens = sorted({c for pair in merges for c in pair}) + ['</w>', '<unk>', '<pad>']
        self.token2id = {tok: i for i, tok in enumerate(tokens)}
        self.id2token = {i: tok for tok, i in self.token2id.items()}

    def encode_word(self, word):
        symbols = list(word) + ['</w>']
        for a, b in self.merges:
            i = 0
            while i < len(symbols)-1:
                if (symbols[i], symbols[i+1]) == (a, b):
                    symbols[i] = a + b
                    del symbols[i+1]
                else:
                    i += 1
        return [self.token2id.get(s, self.token2id['<unk>']) for s in symbols]


In [None]:
class TextDataset(Dataset):
    def __init__(self, ids, seq_len=256, overlap=50):
        self.seq_len, self.overlap = seq_len, overlap
        self.ids = ids
        self.seqs = [torch.tensor(ids[i:i+seq_len], dtype=torch.long)
                     for i in range(0, len(ids)-seq_len, seq_len-overlap)]

    def __len__(self): return len(self.seqs)
    def __getitem__(self, i):
        seq = self.seqs[i]
        return seq[:-1], seq[1:]

In [None]:
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4, d_ff=512,
                 num_layers=3, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.enc_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, activation='relu')
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        b, t = x.shape
        pos = torch.arange(t, device=x.device).unsqueeze(0)
        h = self.tok_emb(x) + self.pos_emb(pos)
        out = h.transpose(0,1)
        for layer in self.enc_layers:
            out = layer(out)
        out = self.norm(out.transpose(0,1))
        return self.head(out)

In [None]:
def train(model, loader, epochs=3, lr=1e-3):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for ep in range(1, epochs+1):
        model.train()
        total = 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        print(f"Epoch {ep}/{epochs}, Loss: {total/len(loader):.4f}")

In [None]:
from collections import Counter
def retrieve_context(query, text, window=3):
    sents = re.split(r'(?<=[\.\?\!]) ', text)
    q_tokens = set(query.lower().split())
    scores = [len(q_tokens & set(sent.split())) for sent in sents]
    idx = scores.index(max(scores))
    start, end = max(0, idx-window), min(len(sents), idx+window)
    return " ".join(sents[start:end])

In [None]:
def generate_answer(model, encoder, text, question, max_len=50):
    ctx = retrieve_context(question, text)
    prompt = question + ' ' + ctx
    tok_ids = list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in prompt.split()))
    inp = torch.tensor(tok_ids[-256:], device=device).unsqueeze(0)
    model.eval()
    with torch.no_grad(): out = model(inp)
    ids = out.argmax(-1).squeeze().tolist()[:max_len]
    return ''.join(encoder.id2token[i].replace('</w>',' ') for i in ids)

In [None]:
raw_pdf = extract_text_from_pdf('/content/The_Art_Of_War.pdf')
text = clean_text(raw_pdf)


In [None]:
text

"g53 g75 g6e g20 g54 g7a g75 g20 g6f g6e g20 g74 g68 g65 g20 g41 g72 g74 g20 g6f g66 g20 g57 g61 g72 the oldest military treatise in the world g41 g6c g6c g61 g6e g64 g61 g6c g65 g4f g6e g6c g69 g6e g65 g50 g75 g62 g6c g69 g73 g68 g69 g6e g67 g43 g6c g61 g73 g73 g69 g63 g20 g45 g74 g65 g78 g74 g73 g20 g53 g65 g72 g69 g65 g73 g53 g75 g6e g20 g54 g7a g75 g20 g6f g6e g20 g74 g68 g65 g20 g41 g72 g74 g20 g6f g66 g20 g57 g61 g72 the oldest mili tary treatise in the world translated from the chinese by lionel giles, m.a. 1910 published by allandale online publishing 2 park house, 21 st leonards rd, leicester le2 1ws, england published 2000 the classic etexts series is made up of manuscripts available in the public domain, all of which are out of copyright. the series was origi- nated to aid the distribution of these etexts in multiple formats and to highlight the benefits of pdf ebooks. they are freely distributable with no payment required. allandale online publishing ebooks are distributed 

In [None]:

      # 2. Learn BPE merges and initialize encoder
merges = learn_bpe(text)
encoder = BPEEncoder(merges)


In [None]:
token_ids = list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in text.split()))

In [None]:

      # 4. Create DataLoader
dataset = TextDataset(token_ids)
loader = DataLoader(dataset, batch_size=8, shuffle=True)


In [None]:

      # 5. Initialize and train model
model = MiniTransformer(vocab_size=len(encoder.token2id))
train(model, loader)


Epoch 1/3, Loss: 0.6586
Epoch 2/3, Loss: 0.6040
Epoch 3/3, Loss: 0.6029


In [None]:

      # 6. Generate sample answer
question = "What is the key principle of warfare?"
answer = generate_answer(model, encoder, text, question)
print("Q:", question)
print("A:", answer)

Q: What is the key principle of warfare?
A: <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>


In [None]:
# ## 1. Setup and Imports

# %%
import os
import re
import math
import string
import torch
import random
import itertools
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
import PyPDF2  # for PDF text extraction

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



def extract_text_from_pdf(pdf_path):
    reader = PyPDF2.PdfReader(pdf_path)
    pages = [page.extract_text() or '' for page in reader.pages]
    return "\n".join(pages)

def clean_text(text):
    text = re.sub(r"\*\*\* START.*?\*\*\*", "", text, flags=re.S)
    text = re.sub(r"\*\*\* END.*", "", text, flags=re.S)
    text = re.sub(r"[^\w\s\.,;:'\"\?!-]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.lower().strip()


# %%
def get_stats(vocab):
    pairs = {}
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pair = (symbols[i], symbols[i+1])
            pairs[pair] = pairs.get(pair, 0) + freq
    return pairs


def merge_vocab(pair, vocab):
    v_out = {}
    bigram = ' '.join(pair)
    replacement = ''.join(pair)
    for word, freq in vocab.items():
        w_out = re.sub(rf"(?<=\b){re.escape(bigram)}(?=\b)", replacement, word)
        v_out[w_out] = freq
    return v_out


def learn_bpe(text, num_merges=5000):
    words = text.split()
    vocab = { ' '.join(list(w)) + ' </w>': words.count(w) for w in set(words) }
    merges = []
    for _ in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
        merges.append(best)
    return merges


class BPEEncoder:
    def __init__(self, merges, base_text):
        self.merges = merges
        # Determine base symbols from text
        chars = sorted(set(base_text))
        tokens = chars + ['</w>', '<unk>', '<pad>']
        self.token2id = {tok: i for i, tok in enumerate(tokens)}
        self.id2token = {i: tok for tok, i in self.token2id.items()}

    def encode_word(self, word):
        symbols = list(word) + ['</w>']
        for a, b in self.merges:
            i = 0
            while i < len(symbols)-1:
                if (symbols[i], symbols[i+1]) == (a, b):
                    symbols[i] = a + b
                    del symbols[i+1]
                else:
                    i += 1
        return [self.token2id.get(s, self.token2id['<unk>']) for s in symbols]


# %%
class TextDataset(Dataset):
    def __init__(self, ids, seq_len=256, overlap=50):
        self.seq_len, self.overlap = seq_len, overlap
        self.ids = ids
        self.seqs = [torch.tensor(ids[i:i+seq_len], dtype=torch.long)
                     for i in range(0, len(ids)-seq_len, seq_len-overlap)]

    def __len__(self): return len(self.seqs)
    def __getitem__(self, i):
        seq = self.seqs[i]
        return seq[:-1], seq[1:]



# %%
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4, d_ff=512,
                 num_layers=3, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.enc_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, activation='relu')
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        b, t = x.shape
        pos = torch.arange(t, device=x.device).unsqueeze(0)
        h = self.tok_emb(x) + self.pos_emb(pos)
        out = h.transpose(0,1)
        for layer in self.enc_layers:
            out = layer(out)
        out = self.norm(out.transpose(0,1))
        return self.head(out)



# %%
def train(model, loader, epochs=3, lr=1e-3):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for ep in range(1, epochs+1):
        model.train()
        total = 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        print(f"Epoch {ep}/{epochs}, Loss: {total/len(loader):.4f}")



# %%
def retrieve_context(query, text, window=3):
    sents = re.split(r'(?<=[\.\?\!]) ', text)
    q_tokens = set(query.lower().split())
    scores = [len(q_tokens & set(sent.split())) for sent in sents]
    idx = scores.index(max(scores))
    start, end = max(0, idx-window), min(len(sents), idx+window)
    return " ".join(sents[start:end])



# %%
def generate_answer(model, encoder, text, question, max_len=50):
    ctx = retrieve_context(question, text)
    prompt = question + ' ' + ctx
    tok_ids = list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in prompt.split()))
    inp = torch.tensor(tok_ids[-256:], device=device).unsqueeze(0)
    model.eval()
    with torch.no_grad(): out = model(inp)
    ids = out.argmax(-1).squeeze().tolist()[:max_len]
    return ''.join(encoder.id2token[i].replace('</w>',' ') for i in ids)



# %%
if __name__ == '__main__':
    # 1. Load and clean PDF text
    raw_pdf = extract_text_from_pdf('/content/The_Art_Of_War.pdf')
    text = clean_text(raw_pdf)

    # 2. Learn BPE merges and initialize encoder (pass base text)
    merges = learn_bpe(text)
    encoder = BPEEncoder(merges, base_text=text)

    # 3. Encode entire text
    token_ids = list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in text.split()))

    # 4. Create DataLoader
    dataset = TextDataset(token_ids)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # 5. Initialize and train model
    model = MiniTransformer(vocab_size=len(encoder.token2id))
    train(model, loader)

    # 6. Generate sample answer
    question = "What is the key principle of warfare?"
    answer = generate_answer(model, encoder, text, question)
    print("Q:", question)
    print("A:", answer)

# End of Notebook


Using device: cpu
Epoch 1/3, Loss: 2.8080
Epoch 2/3, Loss: 2.4267
Epoch 3/3, Loss: 2.3611
Q: What is the key principle of warfare?
A: d tyn hnen tn th<unk>nhts hnl n <unk>n th<unk>oor hnng  tu ten


In [None]:
# Notebook: Building a Custom Small LLM from Scratch on "The Art of War"
# ==============================================================================
# This notebook implements a mini-Transformer encoder-only model trained on "The Art of War".
# Now includes improved decoding of BPE tokens, sampling-based generation, and training tips.


# %%
import os
import re
import math
import torch
import random
import itertools
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
import PyPDF2

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



# %%

def extract_text_from_pdf(pdf_path):
    reader = PyPDF2.PdfReader(pdf_path)
    pages = [page.extract_text() or '' for page in reader.pages]
    return "\n".join(pages)


def clean_text(text):
    text = re.sub(r"\*\*\* START.*?\*\*\*", "", text, flags=re.S)
    text = re.sub(r"\*\*\* END.*", "", text, flags=re.S)
    text = re.sub(r"[^\w\s\.,;:'\"\?!-]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.lower().strip()


# %%
def get_stats(vocab):
    pairs = {}
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pair = (symbols[i], symbols[i+1])
            pairs[pair] = pairs.get(pair, 0) + freq
    return pairs


def merge_vocab(pair, vocab):
    v_out = {}
    bigram = ' '.join(pair)
    replacement = ''.join(pair)
    for word, freq in vocab.items():
        w_out = re.sub(rf"(?<=\b){re.escape(bigram)}(?=\b)", replacement, word)
        v_out[w_out] = freq
    return v_out


def learn_bpe(text, num_merges=8000):
    words = text.split()
    vocab = { ' '.join(list(w)) + ' </w>': words.count(w) for w in set(words) }
    merges = []
    for _ in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
        merges.append(best)
    return merges


class BPEEncoder:
    def __init__(self, merges, base_text):
        self.merges = merges
        # create tokens from all unique chars
        chars = sorted(set(base_text))
        tokens = chars + ['</w>', '<unk>', '<pad>']
        self.token2id = {tok: i for i, tok in enumerate(tokens)}
        self.id2token = {i: tok for tok, i in self.token2id.items()}

    def encode_word(self, word):
        symbols = list(word) + ['</w>']
        for a, b in self.merges:
            i = 0
            while i < len(symbols)-1:
                if (symbols[i], symbols[i+1]) == (a, b):
                    symbols[i] = a + b
                    del symbols[i+1]
                else:
                    i += 1
        return [self.token2id.get(s, self.token2id['<unk>']) for s in symbols]

    def decode_ids(self, ids):
        words = []
        cur = ''
        for i in ids:
            tok = self.id2token[i]
            if tok.endswith('</w>'):
                cur += tok.replace('</w>', '')
                words.append(cur)
                cur = ''
            else:
                cur += tok
        return ' '.join(words)


# %%
class TextDataset(Dataset):
    def __init__(self, ids, seq_len=256, overlap=50):
        self.seq_len, self.overlap = seq_len, overlap
        self.seqs = [torch.tensor(ids[i:i+seq_len], dtype=torch.long)
                     for i in range(0, len(ids)-seq_len, seq_len-overlap)]

    def __len__(self): return len(self.seqs)
    def __getitem__(self, i):
        seq = self.seqs[i]
        return seq[:-1], seq[1:]



# %%
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4, d_ff=512,
                 num_layers=4, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, activation='relu')
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        b, t = x.shape
        pos = torch.arange(t, device=x.device).unsqueeze(0)
        h = self.tok_emb(x) + self.pos_emb(pos)
        out = h.transpose(0,1)
        for layer in self.layers:
            out = layer(out)
        out = self.norm(out.transpose(0,1))
        return self.head(out)



# %%
def train(model, loader, epochs=10, lr=1e-3):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for ep in range(1, epochs+1):
        model.train()
        total = 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        print(f"Epoch {ep}/{epochs}, Loss: {total/len(loader):.4f}")



# %%
def retrieve_context(query, text, window=3):
    sents = re.split(r'(?<=[\.!?]) ', text)
    q_tokens = set(query.lower().split())
    scores = [len(q_tokens & set(sent.split())) for sent in sents]
    idx = scores.index(max(scores))
    start, end = max(0, idx-window), min(len(sents), idx+window)
    return " ".join(sents[start:end])



# %%
def generate_answer(model, encoder, text, question, max_len=50, top_k=5, temperature=1.0):
    model.eval()
    ctx = retrieve_context(question, text)
    prompt = question + ' ' + ctx
    ids = list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in prompt.split()))
    inp = torch.tensor(ids[-256:], device=device).unsqueeze(0)

    generated = []
    with torch.no_grad():
        for _ in range(max_len):
            logits = model(inp)[0, -1, :]
            logits = logits / temperature
            top = torch.topk(logits, top_k)
            probs = torch.softmax(top.values, dim=-1)
            idx = torch.multinomial(probs, 1).item()
            next_id = top.indices[idx].item()
            generated.append(next_id)
            inp = torch.cat([inp, torch.tensor([[next_id]], device=device)], dim=1)

    return encoder.decode_ids(generated)



# %%
if __name__ == '__main__':
    # 1. Load & clean PDF
    raw = extract_text_from_pdf('/content/The_Art_Of_War.pdf')
    text = clean_text(raw)

    # 2. BPE merges & encoder
    merges = learn_bpe(text)
    encoder = BPEEncoder(merges, base_text=text)

    # 3. Encode all text
    ids = list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in text.split()))

    # 4. DataLoader
    ds = TextDataset(ids)
    dl = DataLoader(ds, batch_size=8, shuffle=True)

    # 5. Model & train
    model = MiniTransformer(vocab_size=len(encoder.token2id))
    train(model, dl)

    # 6. Generate with sampling
    question = "What is the key principle of warfare?"
    ans = generate_answer(model, encoder, text, question)
    print("Q:", question)
    print("A:", ans)

# End of Notebook


Using device: cpu
Epoch 1/10, Loss: 2.8302
Epoch 2/10, Loss: 2.4178
Epoch 3/10, Loss: 2.3547
Epoch 4/10, Loss: 2.3284
Epoch 5/10, Loss: 2.3027
Epoch 6/10, Loss: 2.2814
Epoch 7/10, Loss: 2.2611
Epoch 8/10, Loss: 2.2360
Epoch 9/10, Loss: 2.2088
Epoch 10/10, Loss: 2.1750
Q: What is the key principle of warfare?
A: o tto atth<unk>p whalith<unk>tho when c<unk>my wh<unk>was h<unk>m<unk>s


In [None]:
# Notebook: Building a Custom Small LLM from Scratch on "The Art of War"
# ==============================================================================
# Enhanced mini-Transformer encoder-only model with training best practices.



# %%
import os
import re
import torch
import random
import itertools
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import PyPDF2
from transformers import get_linear_schedule_with_warmup

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



# %%
def extract_text_from_pdf(pdf_path):
    reader = PyPDF2.PdfReader(pdf_path)
    pages = [page.extract_text() or '' for page in reader.pages]
    return "\n".join(pages)

def clean_text(text):
    text = re.sub(r"\*\*\* START.*?\*\*\*", "", text, flags=re.S)
    text = re.sub(r"\*\*\* END.*", "", text, flags=re.S)
    text = re.sub(r"[^\w\s\.,;:'\"\?!-]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.lower().strip()



# %%
def learn_bpe(text, num_merges=8000):
    words = text.split()
    vocab = { ' '.join(list(w)) + ' </w>': words.count(w) for w in set(words) }
    merges = []
    for _ in range(num_merges):
        pairs = {}
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pair = (symbols[i], symbols[i+1])
                pairs[pair] = pairs.get(pair, 0) + freq
        if not pairs: break
        best = max(pairs, key=pairs.get)
        merges.append(best)
        bigram = ' '.join(best)
        replacement = ''.join(best)
        vocab = {re.sub(rf"(?<=\b){re.escape(bigram)}(?=\b)", replacement, w): f for w,f in vocab.items()}
    return merges

class BPEEncoder:
    def __init__(self, merges, text):
        self.merges = merges
        chars = sorted(set(text))
        tokens = chars + ['</w>', '<unk>', '<pad>']
        self.token2id = {tok:i for i,tok in enumerate(tokens)}
        self.id2token = {i:tok for tok,i in self.token2id.items()}

    def encode_word(self, word):
        symbols = list(word) + ['</w>']
        for a,b in self.merges:
            i=0
            while i < len(symbols)-1:
                if (symbols[i],symbols[i+1])==(a,b): symbols[i]=a+b; del symbols[i+1]
                else: i+=1
        return [self.token2id.get(s,self.token2id['<unk>']) for s in symbols]

    def decode_ids(self, ids):
        words=[]; cur=''
        for idx in ids:
            tok=self.id2token[idx]
            if tok.endswith('</w>'):
                cur+=tok[:-4]; words.append(cur); cur=''
            else: cur+=tok
        return ' '.join(words)



# %%
class TextDataset(Dataset):
    def __init__(self, ids, seq_len=256, overlap=50):
        self.seq_len, self.overlap = seq_len, overlap
        self.ids = ids
        self.samples = []
        step = seq_len - overlap
        for i in range(0, len(ids)-seq_len, step):
            chunk = torch.tensor(ids[i:i+seq_len], dtype=torch.long)
            self.samples.append((chunk[:-1], chunk[1:]))
    def __len__(self): return len(self.samples)
    def __getitem__(self,i): return self.samples[i]



# %%
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8, d_ff=1024,
                 num_layers=6, max_len=512, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff,
                                                   dropout=dropout, activation='gelu')
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        # weight tie
        self.head.weight = self.tok_emb.weight

    def forward(self, x):
        b,t = x.shape
        pos = torch.arange(t, device=x.device).unsqueeze(0)
        h = self.tok_emb(x) + self.pos_emb(pos)
        out = self.encoder(h.transpose(0,1)).transpose(0,1)
        out = self.norm(out)
        return self.head(out)



# %%
def train(model, train_loader, val_loader=None, epochs=20, lr=5e-4,
          warmup_steps=1000, max_grad_norm=1.0):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    total_steps = epochs * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_steps)
    loss_fn = nn.CrossEntropyLoss()
    for ep in range(1, epochs+1):
        model.train(); train_loss=0
        for xb,yb in train_loader:
            xb,yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = loss_fn(logits.view(-1,logits.size(-1)), yb.view(-1))
            optimizer.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step(); scheduler.step()
            train_loss+=loss.item()
        avg_train = train_loss/len(train_loader)
        print(f"Epoch {ep}: Train Loss {avg_train:.4f}")
        if val_loader:
            model.eval(); val_loss=0
            with torch.no_grad():
                for xb,yb in val_loader:
                    xb,yb=xb.to(device), yb.to(device)
                    val_loss+= loss_fn(model(xb).view(-1,model(xb).size(-1)), yb.view(-1)).item()
            print(f"       Val   Loss {val_loss/len(val_loader):.4f}")



# %%
def retrieve_context(query, text, window=5):
    sents=re.split(r'(?<=[\.!?]) ', text)
    q_tokens=set(query.lower().split())
    scores=[len(q_tokens & set(s.split())) for s in sents]
    idx=max(range(len(scores)), key=lambda i: scores[i])
    start, end = max(0, idx-window), min(len(sents), idx+window)
    return ' '.join(sents[start:end])



# %%
def generate_answer(model, encoder, text, question, max_len=100,
                    top_k=10, temperature=0.7):
    model.eval()
    ctx=retrieve_context(question,text)
    prompt=question + ' ' + ctx
    idxs=list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in prompt.split()))
    inp=torch.tensor(idxs[-512:],device=device).unsqueeze(0)
    generated=[]
    with torch.no_grad():
        for _ in range(max_len):
            logits=model(inp)[0,-1]/temperature
            top=logits.topk(top_k)
            probs=torch.softmax(top.values,dim=-1)
            choice=torch.multinomial(probs,1).item()
            nxt=top.indices[choice].item(); generated.append(nxt)
            inp=torch.cat([inp, torch.tensor([[nxt]],device=device)],1)
    return encoder.decode_ids(generated)



# %%
if __name__=='__main__':
    raw=extract_text_from_pdf('/content/The_Art_Of_War.pdf')
    text=clean_text(raw)
    merges=learn_bpe(text);
    encoder=BPEEncoder(merges,text)
    all_ids=list(itertools.chain.from_iterable(
        encoder.encode_word(w) for w in text.split()))
    ds=TextDataset(all_ids)
    train_sz=int(0.9*len(ds)); val_sz=len(ds)-train_sz
    train_ds, val_ds=random_split(ds, [train_sz,val_sz])
    train_dl=DataLoader(train_ds,batch_size=16,shuffle=True)
    val_dl=DataLoader(val_ds,batch_size=16)
    model=MiniTransformer(vocab_size=len(encoder.token2id))
    train(model, train_dl, val_loader=val_dl)
    q="What is the key principle of warfare?"
    print(generate_answer(model,encoder,text,q))

# End of Notebook


Using device: cpu




Epoch 1: Train Loss 77.8847
       Val   Loss 69.2789
Epoch 2: Train Loss 50.7393
       Val   Loss 32.3044
Epoch 3: Train Loss 19.6537
       Val   Loss 8.8584
Epoch 4: Train Loss 7.3926
       Val   Loss 4.0256
Epoch 5: Train Loss 6.0670
       Val   Loss 3.4120
Epoch 6: Train Loss 5.5494
       Val   Loss 3.3500
Epoch 7: Train Loss 5.0814
       Val   Loss 3.2206
Epoch 8: Train Loss 4.6374
       Val   Loss 3.0999
Epoch 9: Train Loss 4.2815
       Val   Loss 3.1774
Epoch 10: Train Loss 4.0213
       Val   Loss 2.9618
Epoch 11: Train Loss 3.8307
       Val   Loss 2.7932
Epoch 12: Train Loss 3.6891
       Val   Loss 2.7211
Epoch 13: Train Loss 3.5652
       Val   Loss 2.7336
Epoch 14: Train Loss 3.4466
       Val   Loss 2.7476
Epoch 15: Train Loss 3.3568
       Val   Loss 2.6655
Epoch 16: Train Loss 3.2705
       Val   Loss 2.7005
Epoch 17: Train Loss 3.2236
       Val   Loss 2.6757
Epoch 18: Train Loss 3.1312
       Val   Loss 2.5926
Epoch 19: Train Loss 3.0712
       Val   Loss 2.65

IndexError: index out of range in self