In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as Optim
import random

In [5]:
# import text & shuffle set
with open("./assets/names.txt", mode="r", encoding="utf-8") as file:
    names = file.readlines()
# shuffle it
random.seed(42)
random.shuffle(names)
# dataset size
print(len(names))
print(names[:10])

62262
['Rieder Berg\n', 'Alttiefenweg\n', 'Goßmannsdorf\n', 'Gemeindebühl\n', 'Mader\n', 'Kroissenhof\n', 'Schlappenreuth\n', 'Obermitterdorf\n', 'Ullading\n', 'Großköllnbach\n']


In [6]:
# setup vocabulary
all_chars = list(sorted(set([("".join(char)) for name in names for char in name])))
print(all_chars)
vocab_size = len(all_chars)
print(vocab_size)

['\n', ' ', '-', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'Ä', 'Ö', 'Ü', 'ß', 'ä', 'ö', 'ü']
61


In [18]:
# hyperparameters
context_len = 64
n_embd = 256
n_head = 8
n_layer = 8
batch_size = 64
learning_rate = 3e-4
train_iter = 8000
eval_iter = 150
eval_interval = 1000
dropout = 0.2
# configured to port data to mac gpu if available
device = "mps" if torch.backends.mps.is_available() else "cpu"
#weight_decay = 0.001

In [8]:
# vocabulary mapping dicts
itos = {i:s for i, s in enumerate(all_chars)}
stoi = {s:i for i, s in itos.items()}
#print(itos)
#print(stoi)
# voc encoding / decoding functions
encode = lambda input: [stoi[i] for i in input]
decode = lambda input: "".join([itos[i] for i in input])
#print(encode(names[0]))
#print(decode(encode(names[0])))

In [9]:
# convert names list to data: concat text, encode it, tensor it
data = torch.tensor(encode("".join(names)), dtype=torch.long)
# split data into train / dev / test with 0.8 / 0.1 / 0.1
border_1 = int(0.8 * len(data))
border_2 = int(0.9 * len(data))
train_split = data[:border_1]
dev_split = data[border_1:border_2]
test_split = data[border_2:]
print(len(train_split), len(dev_split), len(test_split))

540104 67513 67513


In [10]:
# data loading: deliver batches of X, Y tensors for chosen split
torch.manual_seed(42)
def get_batch(split):
    """ delivers a batch of X, Y tensors for specified split"""
    # get random numbers (in amount of "batch_size") within split boundaries to grab data for the batch samples
    batch_borders = torch.randint(0, len(split)-context_len, (batch_size,))
    x = torch.stack([split[t : t+context_len] for t in batch_borders])
    y = torch.stack([split[t+1 : t+context_len+1] for t in batch_borders])
    return x, y
x, y = get_batch(train_split)
print(x.shape, y.shape)

torch.Size([64, 64]) torch.Size([64, 64])


In [12]:
# validate loss function outsite backprop; called from training function after defined training steps
@torch.no_grad()
def check_loss():
    m.eval()
    out = {}
    # calc train & dev loss as averages after defined eval steps
    for split in [train_split, dev_split]:
        losses = torch.zeros(eval_iter)
        # calc loss for every batch and save result into tensor
        for i in range(eval_iter):
            x, y = get_batch(split)
            x, y = x.to(device), y.to(device)
            _, loss = m(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean() 
    m.train()
    return out

In [13]:
# single self-attention head; called from multi-head-attention class
class Head(nn.Module):
    
    def __init__(self, h_size):
        super().__init__()
        self.query = nn.Linear(in_features=n_embd, out_features=h_size, bias=False)
        self.key = nn.Linear(in_features=n_embd, out_features=h_size, bias=False)
        self.value = nn.Linear(in_features=n_embd, out_features=h_size, bias=False)
        # helper matrix for triangular masking; pre-registered as full-size buffer for performance; all zero values above the diagonal
        self.register_buffer("tril", torch.tril(torch.ones(context_len, context_len)))
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        # B, T, H
        q = self.query(x)
        k = self.key(x)
        # B, T, T
        wei = q @ torch.transpose(k, dim0=-1, dim1=-2) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.drop(wei)
        # B, T, H
        v = self.value(x)
        out = wei @ v
        return out

In [14]:
# multiple heads of self-attention in parallel
class MultiHeadAttention(nn.Module):
    
    def __init__(self, n_head, head_size):
        super().__init__()
        self.heads = nn.ModuleList( Head(head_size) for _ in range(n_head))
        # linear projection layer to blend all cat head outputs
        self.proj = nn.Linear(n_embd, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # cat / stack each head's out_features along last dim to total of n_embd out_features
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.drop(self.proj(out))
        return out

In [15]:
# mlp layer with relu; widened first linear layer
class Ffw(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4, bias=None),
            nn.ReLU(),
            nn.Linear(n_embd * 4, n_embd, bias=None),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.layers(x)

In [16]:
# transformer block: communication in multi-head-attention, then computation in ffw layers
class TransformerBlock(nn.Module):
    
    def __init__(self):
        super().__init__()
        head_size = n_embd // n_head
        self.multi_head_sa = MultiHeadAttention(n_head, head_size)
        self.ffw = Ffw()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
    
    def forward(self, x):
        x = x + self.multi_head_sa((self.ln1(x)))
        x = x + self.ffw(self.ln2(x))
        return x

In [17]:
# Core GPT logic setting up NN
class GPT(nn.Module):
    
    def __init__(self):
        super().__init__()
        # embeddings
        self.tok_embeddings = nn.Embedding(vocab_size, n_embd)
        self.pos_embeddings = nn.Embedding(context_len, n_embd)
        # transformer blocks of amount n_layer
        self.t_blocks = nn.Sequential(*[TransformerBlock() for _ in range(n_layer)])
        # output layer
        self.lm_head = nn.Linear(n_embd, vocab_size)
 
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # x comes as B, T; token embeddings; B,T,C
        tok_emb = self.tok_embeddings(idx)
        # creates 1D-tensor with values from 0 - context_len; T
        pos_idx = torch.arange(0, T, device=device)
        # position embeddings; T, C
        pos_emb = self.pos_embeddings(pos_idx)
        # combined emds for token + pos; B, T, C
        emb = tok_emb + pos_emb
        # hidden layers & logits
        h = self.t_blocks(emb)
        logits = self.lm_head(h)
        # calc loss if targets are available; otherwise set loss to None for sampling
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # flatten logits into B*T, C
            logits = logits.view(B*T, C)
            # flatten targets into B*T
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    # generate names of tbd amount; name ends at first line break char
    def generate(self, amount_names):
        out = []
        for _ in range(amount_names):
            name = []
            # start always with 0 context for linebreak as first char; forward pass expects shape of (1, 1) to work
            context = torch.zeros((1, 1), dtype=torch.long)
            context = context.to(device)
            while True:
                # context must not be greater than context_len, otherwise mat mul in forward pass does not work; cut max latest context
                context_cut = context[:, -context_len:]
                logits, _ = self(context_cut)
                # grab logits at last timestep
                logits = logits[:, -1, :]
                logits = F.softmax(logits, dim=-1)
                idx = torch.multinomial(logits, num_samples=1, replacement=True).item()
                name.append(itos[idx])
                # end name gen when first linebreak is sampled
                if idx == 0:
                    break
                else:
                    # as long as no linebreak is hit, add last idx to context and sample next char for name
                    context = torch.cat((context, torch.tensor([[idx]], dtype=torch.long, device=device)), dim=1)
            out.append("".join(name))
        return out

In [124]:
# init model, port to gpu, init optimizer, print model params
model = GPT()
m = model.to(device)
optimizer = Optim.Adam(m.parameters(), lr=learning_rate)
parameters = m.parameters()
print(sum(p.nelement() for p in parameters))

6349373


In [125]:
# train model over defined train steps
def train_model():

    for i in range(train_iter):
    
        # eval loss & print after certain amount of train steps
        if i % eval_interval == 0:
            losses = check_loss()
            print(f"loss after {i} iterations: train_loss {losses[train_split]}; eval_loss {losses[dev_split]}")
        
        # forward pass
        Xtr, Ytr = get_batch(train_split)
        Xtr, Ytr = Xtr.to(device), Ytr.to(device)
        _, loss = m(Xtr, Ytr)

        # backward pass
        optimizer.zero_grad()
        loss.backward()

        # update params
        optimizer.step()

train_model()

loss after 0 iterations: train_loss 4.603909492492676; eval_loss 4.60637903213501
loss after 1000 iterations: train_loss 1.7214868068695068; eval_loss 1.741114854812622
loss after 2000 iterations: train_loss 1.587981104850769; eval_loss 1.6238888502120972
loss after 3000 iterations: train_loss 1.5157275199890137; eval_loss 1.577518105506897
loss after 4000 iterations: train_loss 1.4594924449920654; eval_loss 1.5488260984420776
loss after 5000 iterations: train_loss 1.4075320959091187; eval_loss 1.5284029245376587
loss after 6000 iterations: train_loss 1.3645418882369995; eval_loss 1.5264160633087158
loss after 7000 iterations: train_loss 1.3213069438934326; eval_loss 1.5272397994995117


In [126]:
# sample from model with amount names
m.generate(50)

['Lindgraben\n',
 'Frankendorf\n',
 'Rißtaubieren\n',
 'Felkendorf\n',
 'Oberstaufen\n',
 'Scharloch\n',
 'Stumpfental\n',
 'Schwalbmühle\n',
 'Wiesenhausen\n',
 'Rothenbuch\n',
 'Atter\n',
 'Westenerhof\n',
 'Einzienla\n',
 'Pötzling\n',
 'Gstrieß\n',
 'Hummelsried\n',
 'Reifling\n',
 'Obernöringen\n',
 'Branntännl\n',
 'Krösel\n',
 'Bärenlingen\n',
 'Breitel\n',
 'Siegskofen\n',
 'Mainauf\n',
 'Bihelhaeden\n',
 'Buxach\n',
 'Wineder\n',
 'Thierau\n',
 'Hofdorf\n',
 'Ruhmannsdorf\n',
 'Scheckenbach\n',
 'Hollberg\n',
 'Reicher Spiedler\n',
 'Wartenhofen\n',
 'Krietzendorf\n',
 'Langenprechting\n',
 'Gstorfing\n',
 'Bugendorf\n',
 'Grimmsried\n',
 'Horrheim\n',
 'Nindicherles\n',
 'Ablom\n',
 'Alte Leonitz\n',
 'Unternore Wald\n',
 'Pliener\n',
 'Lippertsgrün\n',
 'Krippenstaller\n',
 'Querenhofen\n',
 'Kötzwieser\n',
 'Vogelschlag\n']

# RUN 1: 1.2M params; 7,3 min; train_loss 1.5999797582626343; eval_loss 1.6328703165054321
context_len = 64
n_embd = 128
n_head = 4
n_layer = 6
batch_size = 64
learning_rate = 3e-4
train_iter = 5000
eval_iter = 150
eval_interval = 500
dropout = 0.2
"loss after 0 iterations: train_loss 4.494983673095703; eval_loss 4.498855113983154
loss after 500 iterations: train_loss 2.1576623916625977; eval_loss 2.1638336181640625
loss after 1000 iterations: train_loss 1.8927803039550781; eval_loss 1.8983594179153442
loss after 1500 iterations: train_loss 1.794490933418274; eval_loss 1.806968331336975
loss after 2000 iterations: train_loss 1.7333722114562988; eval_loss 1.7506181001663208
loss after 2500 iterations: train_loss 1.6985934972763062; eval_loss 1.7193843126296997
loss after 3000 iterations: train_loss 1.6637684106826782; eval_loss 1.6839677095413208
loss after 3500 iterations: train_loss 1.6381772756576538; eval_loss 1.6658086776733398
loss after 4000 iterations: train_loss 1.6181026697158813; eval_loss 1.6488206386566162
loss after 4500 iterations: train_loss 1.5999797582626343; eval_loss 1.6328703165054321"
['Oberkirchen\n',
 'Riedenog\n',
 'Schömatsbach\n',
 'Löhendorf\n',
 'Zauchberg\n',
 'Hahnath\n',
 'Haid\n',
 'Hintsterheim\n',
 'Kreuzwinden\n',
 'Degrüblach\n']


---
# RUN 2: 6.34M params; canceled after 31 min due to overfitting
context_len = 64
n_embd = 256
n_head = 8
n_layer = 8
batch_size = 64
learning_rate = 3e-4
train_iter = 20000
eval_iter = 150
eval_interval = 1000
dropout = 0.2

loss after 0 iterations: train_loss 4.703804969787598; eval_loss 4.707516193389893
loss after 1000 iterations: train_loss 1.7233854532241821; eval_loss 1.7396659851074219
loss after 2000 iterations: train_loss 1.5916658639907837; eval_loss 1.6276907920837402
loss after 3000 iterations: train_loss 1.5190118551254272; eval_loss 1.5804953575134277
loss after 4000 iterations: train_loss 1.4579596519470215; eval_loss 1.5506863594055176
loss after 5000 iterations: train_loss 1.4088348150253296; eval_loss 1.540738582611084
loss after 6000 iterations: train_loss 1.359342336654663; eval_loss 1.5292009115219116
loss after 7000 iterations: train_loss 1.3172597885131836; eval_loss 1.527799129486084
loss after 8000 iterations: train_loss 1.2786493301391602; eval_loss 1.5266221761703491
loss after 9000 iterations: train_loss 1.2427423000335693; eval_loss 1.546021580696106
loss after 10000 iterations: train_loss 1.2063122987747192; eval_loss 1.563255786895752

---
# RUN 3: 14.2M params; canceled after 33 min due to overfitting
context_len = 64
n_embd = 384
n_head = 6
n_layer = 8
batch_size = 64
learning_rate = 3e-4
train_iter = 10000
eval_iter = 150
eval_interval = 1000
dropout = 0.2

loss after 0 iterations: train_loss 4.475872039794922; eval_loss 4.478583335876465
loss after 1000 iterations: train_loss 1.6488356590270996; eval_loss 1.6782140731811523
loss after 2000 iterations: train_loss 1.5181366205215454; eval_loss 1.580893874168396
loss after 3000 iterations: train_loss 1.4300034046173096; eval_loss 1.5446563959121704
loss after 4000 iterations: train_loss 1.358554482460022; eval_loss 1.53664231300354
loss after 5000 iterations: train_loss 1.2867943048477173; eval_loss 1.5338422060012817
loss after 6000 iterations: train_loss 1.2154250144958496; eval_loss 1.5534075498580933
loss after 7000 iterations: train_loss 1.1497846841812134; eval_loss 1.5951995849609375

---
# RUN 4: 2.7m; canceled after 10 min due to learning rate stalling
context_len = 32
n_embd = 192
n_head = 6
n_layer = 6
batch_size = 32
learning_rate = 1e-4
train_iter = 8000
eval_iter = 150
eval_interval = 1000
dropout = 0.3
weight_decay = 0.01

loss after 0 iterations: train_loss 4.531078338623047; eval_loss 4.526917457580566
loss after 1000 iterations: train_loss 2.3952038288116455; eval_loss 2.3950064182281494
loss after 2000 iterations: train_loss 2.3764455318450928; eval_loss 2.3757312297821045
loss after 3000 iterations: train_loss 2.367504596710205; eval_loss 2.3657174110412598
loss after 4000 iterations: train_loss 2.365919351577759; eval_loss 2.3662781715393066
loss after 5000 iterations: train_loss 2.369229555130005; eval_loss 2.365706205368042

# RUN 5: 4.7m; canceled after 10 min due to learning rate stalling
context_len = 64
n_embd = 256
n_head = 8
n_layer = 6
batch_size = 64
learning_rate = 3e-4
train_iter = 10000
eval_iter = 150
eval_interval = 1000
dropout = 0.2
weight_decay = 0.01

loss after 0 iterations: train_loss 4.481040954589844; eval_loss 4.479975700378418
loss after 1000 iterations: train_loss 2.363853693008423; eval_loss 2.3665316104888916
loss after 2000 iterations: train_loss 2.3708105087280273; eval_loss 2.3742127418518066
loss after 3000 iterations: train_loss 2.3888113498687744; eval_loss 2.3926215171813965

# RUN 6: 6.34M params; canceled after 8 min due to significantly worse performance than Run 2
-> same as Run 2 but with small L2 weight decay
context_len = 64
n_embd = 256
n_head = 8
n_layer = 8
batch_size = 64
learning_rate = 3e-4
train_iter = 10000
eval_iter = 150
eval_interval = 1000
dropout = 0.2
weight_decay = 0.001

loss after 0 iterations: train_loss 4.602679252624512; eval_loss 4.598903656005859
loss after 1000 iterations: train_loss 2.1808369159698486; eval_loss 2.190643787384033
loss after 2000 iterations: train_loss 2.142751455307007; eval_loss 2.1480987071990967

# RUN 7: 6.34M params; 23 min -> 1.52 nlll

context_len = 64
n_embd = 256
n_head = 8
n_layer = 8
batch_size = 64
learning_rate = 3e-4
train_iter = 8000
eval_iter = 150
eval_interval = 1000
dropout = 0.2

loss after 0 iterations: train_loss 4.603909492492676; eval_loss 4.60637903213501
loss after 1000 iterations: train_loss 1.7214868068695068; eval_loss 1.741114854812622
loss after 2000 iterations: train_loss 1.587981104850769; eval_loss 1.6238888502120972
loss after 3000 iterations: train_loss 1.5157275199890137; eval_loss 1.577518105506897
loss after 4000 iterations: train_loss 1.4594924449920654; eval_loss 1.5488260984420776
loss after 5000 iterations: train_loss 1.4075320959091187; eval_loss 1.5284029245376587
loss after 6000 iterations: train_loss 1.3645418882369995; eval_loss 1.5264160633087158
loss after 7000 iterations: train_loss 1.3213069438934326; eval_loss 1.5272397994995117

'Lindgraben\n',
 'Frankendorf\n',
 'Rißtaubieren\n',
 'Felkendorf\n',
 'Oberstaufen\n',
 'Scharloch\n',
 'Stumpfental\n',
 'Schwalbmühle\n',
 'Wiesenhausen\n',
 'Rothenbuch\n',
 'Atter\n',
 'Westenerhof\n',
 'Einzienla\n',
 'Pötzling\n',
 'Gstrieß\n',
 'Hummelsried\n',
 'Reifling\n',
 'Obernöringen\n',
 'Branntännl\n',
 'Krösel\n',
 'Bärenlingen\n',
 'Breitel\n',
 'Siegskofen\n',
 'Mainauf\n',
 'Bihelhaeden\n',
 'Buxach\n',
 'Wineder\n',
 'Thierau\n',
 'Hofdorf\n',
 'Ruhmannsdorf\n',
 'Scheckenbach\n',
 'Hollberg\n',
 'Reicher Spiedler\n',
 'Wartenhofen\n',
 'Krietzendorf\n',
 'Langenprechting\n',
 'Gstorfing\n',
 'Bugendorf\n',
 'Grimmsried\n',
 'Horrheim\n',
 'Nindicherles\n',
 'Ablom\n',
 'Alte Leonitz\n',
 'Unternore Wald\n',
 'Pliener\n',
 'Lippertsgrün\n',
 'Krippenstaller\n',
 'Querenhofen\n',
 'Kötzwieser\n',
 'Vogelschlag\n']