In [1]:
import sentencepiece as spm
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import re

In [2]:
with open("chateau.txt", "r", encoding="utf-8") as f:
    txt = f.read()

# remove page numbers
txt = re.sub(r"–\s*\d+\s*–\n", "", txt)

# fix split words
txt = re.sub(r"-\n", "", txt)

# remove line breaks
txt = re.sub(r"\n", " ", txt)

# use a single type of -
txt = re.sub(r"–", "-", txt)

with open("clean.txt", "w", encoding="utf-8") as f:
    f.write(txt)

print(len(set(txt.split())))

15494


In [3]:
with open("clean.txt", "r", encoding="utf-8") as f:
    tmp = f.read()

tmp = re.sub(r"\s-\s", "\n- ", txt)
tmp = re.sub(r"(?<!K)\.\s", ".\n", txt)

with open("tokenizable.txt", "w", encoding="utf-8") as f:
    f.write(tmp)

spm.SentencePieceTrainer.train(
    input = "tokenizable.txt",
    model_prefix = "tok",
    vocab_size = 512,
    model_type = "unigram",
    character_coverage = 1.0,
)

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokenizable.txt
  input_format: 
  model_prefix: tok
  model_type: UNIGRAM
  vocab_size: 512
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  differential_

In [4]:
sp = spm.SentencePieceProcessor(model_file="tok.model")
tokens = sp.encode(txt, out_type=int)

In [5]:
torch.manual_seed(42)

class GPT(nn.Module):
    def __init__(self):
        super().__init__()

        self.vocab_size = len(sp)
        self.attention_window = 64
        self.nheads = 6
        self.d_model = self.nheads * 8
        self.mlp_size = 4 * self.d_model
        self.n_attention_layers = 3

        self.embed = nn.Embedding(self.vocab_size, self.d_model)
        self.pos = nn.Parameter(torch.zeros(1, self.attention_window, self.d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nheads,
            dim_feedforward=self.mlp_size,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_attention_layers,
        )

        self.out = nn.Linear(self.d_model, self.vocab_size, bias=False)
        self.out.weight = self.embed.weight

    def forward(self, x):
        x = self.embed(x) + self.pos[:, :x.size(1)]
        mask = self._causal_mask(x.size(1), x.device)
        x = self.transformer(x, mask=mask)
        return self.out(x)

    def _causal_mask(self, size, device):
        mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

model = GPT()

In [6]:
model.eval()

for i in range(3):
    start = i * 64
    end = start + 64
    subset = tokens[start:end]
    x = torch.tensor([subset], dtype=torch.long)
    for _ in range(32):
        with torch.no_grad():
            logits = model(x)
        next_id = torch.argmax(logits[0, -1]).item()
        subset.append(next_id)
        x = torch.tensor([subset[-64:]], dtype=torch.long)

    print(f"\n--- window {i+1} ---")
    print(sp.decode(subset))


--- window 1 ---
Il était tard lorsque K. arriva. Une neige épaisse couvrait le village. La colline était cachée par la brume et par la nuit, nul rayon de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de

--- window 2 ---
lumière n’indiquait le grand Château. K. resta longtemps sur le pont de bois qui menait de la grand-route au village, les yeux levés vers ces hauteurs qui sembbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb

--- window 3 ---
laient vides. Puis il alla chercher un gîte ; les gens de l’auberge n’étaient pas encore au lit ; on n’avait pas de chambre à louer, mais, surpris et déconcerrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr


In [7]:
model.train()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)

seq_in, seq_out = 64, 32
stride = 32

inputs, targets = [], []
for i in range(0, len(tokens) - seq_in - seq_out, stride):
    chunk = tokens[i : i + seq_in + seq_out]
    inputs.append(chunk[:seq_in])
    targets.append(chunk[seq_in:])

X = torch.tensor(inputs, dtype=torch.long)
Y = torch.tensor(targets, dtype=torch.long)

dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

criterion = torch.nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    total_loss = 0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(
            logits[:, -seq_out:, :].reshape(-1, model.vocab_size),
            Yb.reshape(-1)
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} | loss={total_loss/len(loader):.4f}")

torch.save(model.state_dict(), "gpt.pt")

cpu
Epoch 1 | loss=7.0911
Epoch 2 | loss=5.3308
Epoch 3 | loss=5.2791
Epoch 4 | loss=5.2607
Epoch 5 | loss=5.2532
Epoch 6 | loss=5.2480
Epoch 7 | loss=5.2453
Epoch 8 | loss=5.2420
Epoch 9 | loss=5.2392
Epoch 10 | loss=5.2380


In [8]:
model = GPT()
model.load_state_dict(torch.load("gpt.pt"))
model.eval()

for i in range(3):
    start = i * 64
    end = start + 64
    subset = tokens[start:end]
    x = torch.tensor([subset], dtype=torch.long)
    for _ in range(32):
        with torch.no_grad():
            logits = model(x)
        next_id = torch.argmax(logits[0, -1]).item()
        subset.append(next_id)
        x = torch.tensor([subset[-64:]], dtype=torch.long)

    print(f"\n--- window {i+1} ---")
    print(sp.decode(subset))


--- window 1 ---
Il était tard lorsque K. arriva. Une neige épaisse couvrait le village. La colline était cachée par la brume et par la nuit, nul rayon de                                

--- window 2 ---
lumière n’indiquait le grand Château. K. resta longtemps sur le pont de bois qui menait de la grand-route au village, les yeux levés vers ces hauteurs qui semb                       sssssssss

--- window 3 ---
laient vides. Puis il alla chercher un gîte ; les gens de l’auberge n’étaient pas encore au lit ; on n’avait pas de chambre à louer, mais, surpris et déconcerssssssssssssssssssssssssssssssss
