In [2]:
import sentencepiece as spm
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset, load_from_disk
import re

# Model

In [5]:
# Seed PyTorch's global RNG before the model is constructed further down in
# this cell, so the random weight initialisation is repeatable on a fresh kernel.
torch.manual_seed(42)

class GPT(nn.Module):
    """Small causally-masked Transformer language model with tied embeddings.

    Token ids are embedded, a learned positional table is added, the sum is run
    through a stack of ``nn.TransformerEncoderLayer`` blocks under a causal
    mask, and the result is projected to vocabulary logits by a linear layer
    that shares its weight matrix with the token embedding.

    Args:
        vocab_size: Number of token ids (embedding rows and logit width).
        attention_window: Maximum sequence length; rows of the positional table.
        nheads: Number of attention heads per layer.
        head_dim: Width of one attention head; ``d_model = nheads * head_dim``.
        n_attention_layers: Number of stacked encoder layers.
        dropout: Dropout probability handed to every encoder layer.
        activation: Feed-forward activation name handed to the encoder layer.

    The defaults reproduce the previously hard-coded configuration, so a bare
    ``GPT()`` builds the same network as before.
    """

    def __init__(
        self,
        vocab_size=2048,
        attention_window=256,
        nheads=8,
        head_dim=16,
        n_attention_layers=5,
        dropout=0.1,
        activation="gelu",
    ):
        super().__init__()

        # Hyperparameters are kept as attributes because the rest of the
        # notebook reads them (e.g. model.vocab_size, model.attention_window).
        self.vocab_size = vocab_size
        self.attention_window = attention_window
        self.nheads = nheads
        self.d_model = self.nheads * head_dim
        self.mlp_size = 4 * self.d_model
        self.n_attention_layers = n_attention_layers
        self.dropout = dropout
        self.activation = activation

        # Token embedding plus a learned positional table (initialised to zero).
        self.embed = nn.Embedding(self.vocab_size, self.d_model)
        self.pos = nn.Parameter(torch.zeros(1, self.attention_window, self.d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nheads,
            dim_feedforward=self.mlp_size,
            activation=self.activation,
            dropout=self.dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_attention_layers,
            norm=nn.LayerNorm(self.d_model),
        )

        # Output projection shares its weight with the embedding (weight tying).
        self.out = nn.Linear(self.d_model, self.vocab_size, bias=False)
        self.out.weight = self.embed.weight

    def forward(self, x):
        """Return next-token logits for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= attention_window``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``attention_window``; the
                positional table has no entries for the extra positions.
        """
        seq_len = x.size(1)
        if seq_len > self.attention_window:
            raise ValueError(
                f"sequence length {seq_len} exceeds attention_window "
                f"{self.attention_window}"
            )
        x = self.embed(x) + self.pos[:, :seq_len]
        mask = self._causal_mask(seq_len, x.device)
        x = self.transformer(x, mask=mask)
        return self.out(x)

    def _causal_mask(self, size, device):
        """Build a ``(size, size)`` additive attention mask.

        Entries strictly above the diagonal are ``-inf`` (future positions are
        hidden); all other entries are ``0``.
        """
        blocked = torch.full((size, size), float("-inf"), device=device)
        # triu keeps the strictly-upper triangle and zeroes everything else.
        return torch.triu(blocked, diagonal=1)

# Build the network with its default hyperparameters.
model = GPT()

# Prefer the GPU when torch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model = model.to(device)

# Count every scalar parameter. The tied embedding/output matrix is a single
# Parameter object, so parameters() yields it only once.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total number of parameters : {total_params}")

cpu
Total number of parameters : 1286528


# Heavy tuning (pre-training on Project Gutenberg)

## Load the dataset

In [4]:
# Fetch the French split of the Project Gutenberg dataset and cache it in
# ./gutenberg. The download is slow, so an existing on-disk copy is reused and
# the hub is only contacted when that copy is missing.
try:
    gutenberg = load_from_disk("gutenberg")
except FileNotFoundError:
    gutenberg = load_dataset("manu/project_gutenberg", split="fr", streaming=False)
    gutenberg.save_to_disk("gutenberg")

Downloading data: 100%|██████████| 52/52 [28:41<00:00, 33.11s/files]
Generating de split: 100%|██████████| 3131/3131 [00:08<00:00, 380.52 examples/s]
Generating en split: 100%|██████████| 61340/61340 [01:57<00:00, 523.31 examples/s]
Generating es split: 100%|██████████| 1202/1202 [00:02<00:00, 451.13 examples/s]
Generating fr split: 100%|██████████| 5493/5493 [00:10<00:00, 526.31 examples/s]
Generating it split: 100%|██████████| 1008/1008 [00:01<00:00, 640.86 examples/s]
Generating nl split: 100%|██████████| 1420/1420 [00:15<00:00, 93.87 examples/s] 
Generating pl split: 100%|██████████| 34/34 [00:00<00:00, 2181.33 examples/s]
Generating pt split: 100%|██████████| 1111/1111 [00:00<00:00, 1521.69 examples/s]
Generating ru split: 100%|██████████| 6/6 [00:00<00:00, 1316.62 examples/s]
Generating sv split: 100%|██████████| 388/388 [00:00<00:00, 921.52 examples/s]
Generating zh split: 100%|██████████| 437/437 [00:00<00:00, 470.27 examples/s]
Saving the dataset (5/5 shards): 100%|██████████|

## Train the tokenizer

In [7]:
# Flatten the cached corpus into one UTF-8 text file, one document per write,
# each followed by a newline. The text is read from the "text" field, falling
# back to "content", and finally to an empty string.
ds = load_from_disk("gutenberg")
with open("gutenberg.txt", "w", encoding="utf-8") as corpus_file:
    corpus_file.writelines(
        (example.get("text") or example.get("content") or "") + "\n"
        for example in ds
    )

In [32]:
# Options for training the SentencePiece tokenizer on the dumped corpus.
tokenizer_options = {
    "input": "gutenberg.txt",
    "model_prefix": "tok",                # prefix of the files the trainer writes
    "vocab_size": model.vocab_size,       # must match the model's embedding table
    "model_type": "unigram",
    "character_coverage": 0.999,
    # Reserved ids for the special tokens.
    "pad_id": 0,
    "unk_id": 1,
    "bos_id": 2,
    "eos_id": 3,
    # Sentence-sampling options passed to the trainer.
    "input_sentence_size": 2_000_000,
    "train_extremely_large_corpus": True,
    "shuffle_input_sentence": True,
}
spm.SentencePieceTrainer.train(**tokenizer_options)

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: gutenberg.txt
  input_format: 
  model_prefix: tok
  model_type: UNIGRAM
  vocab_size: 2048
  self_test_sample_size: 0
  character_coverage: 0.999
  input_sentence_size: 2000000
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 1
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  diffe

## Train the model

In [6]:
# Load the tokenizer from the "tok.model" file on disk.
sp = spm.SentencePieceProcessor(model_file="tok.model")

In [None]:
# Window geometry: seq_in tokens (a full attention window) are fed to the
# model, and the window advances by seq_out tokens (half a window) each time.
seq_in = model.attention_window
seq_out = seq_in // 2
stride = seq_out

# Loss and optimizer for this training run.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0)

# Start a fresh CSV log; the training loop appends rows to it.
log_path = "gutenberg.log"
with open(log_path, "w", encoding="utf-8") as log:
    log.write("epoch,step,loss\n")

def stream_batches(path, sp, seq_in, seq_out, batch_size=16):
    """Lazily yield overlapping ``(X, Y)`` token windows from a text file.

    The file is read line by line. Each stripped line is tokenised with
    ``sp.encode`` and followed by one ``sp.eos_id()`` token (blank lines
    therefore contribute a lone EOS). Over the resulting token stream a window
    of ``seq_in + seq_out`` tokens slides forward ``seq_out`` tokens at a time.

    Args:
        path: Path of the UTF-8 text file to read.
        sp: Tokenizer exposing ``encode(text, out_type=int)`` and ``eos_id()``.
        seq_in: Number of tokens in ``X``.
        seq_out: Number of tokens in ``Y``; also the stride between windows.
        batch_size: Unused. Windows are yielded one at a time and batching is
            left to the caller; the parameter is kept for compatibility.

    Yields:
        ``(X, Y)``: 1-D ``torch.long`` tensors of lengths ``seq_in`` and
        ``seq_out``. ``Y`` holds the tokens that immediately follow ``X`` in
        the stream. Trailing tokens that do not fill a window are dropped.

    Raises:
        ValueError: If ``seq_out`` is not positive (the window could never
            advance).
    """
    if seq_out <= 0:
        raise ValueError("seq_out must be a positive stride")

    window = seq_in + seq_out
    buf = []
    start = 0  # offset in buf of the first token of the next window
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            buf.extend(sp.encode(line.strip(), out_type=int))
            buf.append(sp.eos_id())
            while len(buf) - start >= window:
                split = start + seq_in
                X = torch.tensor(buf[start:split], dtype=torch.long)
                Y = torch.tensor(buf[split:start + window], dtype=torch.long)
                start += seq_out
                yield X, Y
            # Drop consumed tokens once per line instead of re-copying the
            # whole buffer after every window.
            if start:
                del buf[:start]
                start = 0

# Pre-train on the Gutenberg corpus: stream token windows, group them into
# batches of 16, and optimise next-token prediction.
for epoch in range(1):
    model.train()
    total_loss = 0
    step = 0
    batch_X, batch_Y = [], []

    for X, Y in stream_batches("gutenberg.txt", sp, seq_in, seq_out):
        batch_X.append(X)
        batch_Y.append(Y)
        if len(batch_X) < 16:
            continue  # keep accumulating until the batch is full

        Xb = torch.stack(batch_X).to(device)
        Yb = torch.stack(batch_Y).to(device)
        batch_X.clear()
        batch_Y.clear()
        step += 1

        # The model is causal, so the logits at position p predict token p + 1.
        # For the last seq_out input positions those targets are the inputs
        # shifted left by one, plus the first token that follows the window
        # (Yb[:, 0]). Scoring against Yb directly would train each position to
        # predict the token seq_out steps ahead, which contradicts sampling,
        # where logits[:, -1] is used as the next-token distribution.
        next_tokens = torch.cat(
            [Xb[:, seq_in - seq_out + 1:], Yb[:, :1]], dim=1
        )

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(
            logits[:, -seq_out:, :].reshape(-1, model.vocab_size),
            next_tokens.reshape(-1),
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Log the running mean loss of the epoch every 100 steps.
        if step % 100 == 0:
            avg_loss = total_loss / step
            with open(log_path, "a", encoding="utf-8") as log:
                log.write(f"{epoch+1},{step},{avg_loss:.10f}\n")

        # One-off learning-rate drop after 10k steps.
        if step >= 10_000 and optimizer.param_groups[0]["lr"] > 1e-4:
            for g in optimizer.param_groups:
                g["lr"] = 1e-4

# Persist the pre-trained weights.
torch.save(model.state_dict(), "gpt_gutenberg.pt")



In [None]:
# Sample continuations from the pre-trained model.
model.eval()
temperature = 0.8

# Use the first three input windows of the corpus as prompts.
starters = []
for X, _ in stream_batches("gutenberg.txt", sp, seq_in, seq_out):
    starters.append(X.tolist())
    if len(starters) >= 3:
        break

for i, subset in enumerate(starters):
    subset = subset.copy()
    # No gradients are needed anywhere in the generation loop.
    with torch.no_grad():
        for _ in range(200):
            # Feed only the most recent attention_window tokens.
            x = torch.tensor(
                [subset[-model.attention_window:]], dtype=torch.long, device=device
            )
            logits = model(x)
            # Temperature-scaled sampling from the last position's distribution.
            probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            subset.append(next_id)
    print(f"\n[SAMPLE {i+1}]\n", sp.decode(subset), "\n")


# Fine tuning

## Prepare the dataset

In [None]:
# Read the raw fine-tuning text.
with open("chateau.txt", "r", encoding="utf-8") as src:
    txt = src.read()

# Regex clean-up steps, applied in this exact order.
cleanup_steps = (
    (r"–\s*\d+\s*–\n", ""),  # remove page numbers
    (r"-\n", ""),            # fix split words
    (r"\n", " "),            # remove line breaks
    (r"–", "-"),             # use a single type of -
)
for pattern, replacement in cleanup_steps:
    txt = re.sub(pattern, replacement, txt)

# Save the cleaned text.
with open("clean.txt", "w", encoding="utf-8") as dst:
    dst.write(txt)

# Number of distinct whitespace-separated words in the cleaned text.
unique_words = set(txt.split())
print(len(unique_words))

In [None]:
# Load the tokenizer from "tok.model" and encode the cleaned fine-tuning text
# as one flat list of integer token ids.
sp = spm.SentencePieceProcessor(model_file="tok.model")
tokens = sp.encode(txt, out_type=int)

In [None]:
# Window geometry: seq_in context tokens followed by seq_out continuation
# tokens, with the window advancing by seq_out tokens each time.
seq_in, seq_out = model.attention_window, model.attention_window // 2
stride = seq_out
window = seq_in + seq_out

# Every start offset whose full window fits in the text. The "+ 1" keeps the
# last window when it ends exactly on the final token.
starts = range(0, len(tokens) - window + 1, stride)
if len(starts) == 0:
    raise ValueError(
        f"fine-tuning text has {len(tokens)} tokens; at least {window} are needed"
    )

inputs = [tokens[i : i + seq_in] for i in starts]
targets = [tokens[i + seq_in : i + window] for i in starts]

X = torch.tensor(inputs, dtype=torch.long)
Y = torch.tensor(targets, dtype=torch.long)

dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Fresh loss and optimizer for fine-tuning.
criterion = torch.nn.CrossEntropyLoss()
optimizer = AdamW(
    model.parameters(),
    lr=1e-5,
    weight_decay=0,
)

# Fine-tune on the prepared windows, then sample from the model and checkpoint it.
n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    total_loss = 0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)

        # The model is causal, so the logits at position p predict token p + 1.
        # Targets for the last seq_out positions are the inputs shifted left by
        # one, plus the first token after the window (Yb[:, 0]). Scoring against
        # Yb directly would train each position to predict the token seq_out
        # steps ahead, which does not match how samples are drawn below.
        next_tokens = torch.cat(
            [Xb[:, seq_in - seq_out + 1:], Yb[:, :1]], dim=1
        )

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(
            logits[:, -seq_out:, :].reshape(-1, model.vocab_size),
            next_tokens.reshape(-1),
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report every one of the first nine epochs, then every tenth epoch.
    if epoch + 1 < 10 or (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1} | loss={total_loss/len(loader):.5f}")

    # Sample and checkpoint every tenth epoch and always after the last one.
    # The previous "% 100" gate never fired within 10 epochs, so the generated
    # sample was discarded and no checkpoint was written.
    if (epoch + 1) % 10 == 0 or epoch + 1 == n_epochs:
        model.eval()
        subset = tokens[:model.attention_window]
        with torch.no_grad():
            for _ in range(200):
                x = torch.tensor(
                    [subset[-model.attention_window:]], dtype=torch.long, device=device
                )
                logits = model(x)
                # `temperature` comes from the sampling cell earlier in the notebook.
                probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
                next_id = torch.multinomial(probs, 1).item()
                subset.append(next_id)

        print("\n[SAMPLE]", sp.decode(subset), "\n")
        torch.save(model.state_dict(), f"gpt_epoch_{epoch+1}.pt")

In [None]:
# Sample continuations from the fine-tuned model, seeded with the first three
# consecutive attention windows of the fine-tuning text.
model.eval()
temperature = 0.8

for i in range(3):
    start = i * model.attention_window
    end = start + model.attention_window
    subset = tokens[start:end]
    if not subset:
        break  # the text is too short to provide another prompt window

    # No gradients are needed anywhere in the generation loop.
    with torch.no_grad():
        for _ in range(200):
            # Feed only the most recent attention_window tokens.
            x = torch.tensor(
                [subset[-model.attention_window:]], dtype=torch.long, device=device
            )
            logits = model(x)
            # Temperature-scaled sampling from the last position's distribution.
            probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            subset.append(next_id)
    print("\n[SAMPLE]", sp.decode(subset), "\n")