In [1]:
from collections import Counter
import json
import os
import string
import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader

In [2]:
namedf = pd.read_csv("surnames_with_splits.csv")

In [3]:
class LMDataset(Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, dataframe, ngram_len=3):
        
        #self.X will be an n-tuple of characters
        #self.Y will be the next character after the n-tuple
        
        self.X = []
        self.Y = []
        for i, row in dataframe.iterrows():
            name = f"#{row['surname'].lower()}$" #add a # at the beginning and end of the name to indicate the start and end of the name
            if (len(name) + 1) < (ngram_len + 1):
                continue
            
            for i in range(len(name)-ngram_len):
                self.X.append(name[i:i+ngram_len])
                self.Y.append(name[i+ngram_len])
        
    def vectorize_name(self, trigram):
        pass
        #return ...
    
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.X)

    def __getitem__(self, index):
        'Generates one sample of data'
       
        x = self.X[index]
        y = self.Y[index]
        
        return x,y

def get_trigram_loader(dataframe, batch_size=16, shuffle=True):
    'returns a DataLoader instance with the given parameters'
    dataset = LMDataset(dataframe, ngram_len = 3)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


In [4]:
train_loader = get_trigram_loader(namedf[namedf.split == 'train'])

In [5]:
class TrigramLMOneHot(nn.Module):

    def __init__(self, hidden_dim, vocab_dim):
       
        super(TrigramLMOneHot, self).__init__()
        
        self.hidden_layer = nn.Linear(vocab_dim*3,hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, vocab_dim)
        self.sm = nn.Softmax(dim=1) # softmax über output nodes/spalten deswegen dim=1    
        
        

    def forward(self, x_in):
        x = self.hidden_layer(x_in)
        x = self.output_layer(x)
        
        return self.sm(x) 
    
 # https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html   

In [30]:
def build_vocab(start: chr, end: chr, dataframe: pd.DataFrame):
    names = dataframe.surname.str.lower()
    
    vocab = {start: 0, end: 1}
    for name in names:
        for c in name:
            if c not in vocab:
                vocab[c] = len(vocab)
                
    return vocab

In [40]:
valid_chars = list(string.ascii_lowercase) + ['ä', 'ö', 'ü', 'ß']

vocab = build_vocab('#', '$', namedf)
trilm = TrigramLMOneHot(128, len(vocab))

In [32]:
print(vocab)

{'#': 0, '$': 1, 't': 2, 'o': 3, 'a': 4, 'h': 5, 'b': 6, 'u': 7, 'd': 8, 'f': 9, 'k': 10, 'r': 11, 'y': 12, 's': 13, 'e': 14, 'g': 15, 'c': 16, 'm': 17, 'i': 18, 'n': 19, 'w': 20, 'l': 21, 'z': 22, 'q': 23, 'j': 24, '-': 25, 'p': 26, 'x': 27, ':': 28, 'v': 29, '1': 30, '/': 31, 'é': 32, "'": 33, 'ç': 34, 'ê': 35, 'ß': 36, 'ö': 37, 'ä': 38, 'ü': 39, 'ú': 40, 'à': 41, 'ò': 42, 'è': 43, 'ó': 44, 'ù': 45, 'ì': 46, 'ś': 47, 'ą': 48, 'ń': 49, 'á': 50, 'ż': 51, 'ł': 52, 'õ': 53, 'ã': 54, 'í': 55, 'ñ': 56}


In [62]:
lr = 0.3
optimizer = optim.Adam(params=trilm.parameters(), lr=lr)
ce_loss = nn.NLLLoss()
epoch = 0  # reset epoch counter
n_epochs = 10

In [34]:
def vectorize_trigrams(trigrams: list[str], vocab: dict[chr, int]):
    vecs = []
    len_vocab = len(vocab)
    for trigram in trigrams:
        vec = torch.zeros(3*len_vocab)
        for i, char in enumerate(trigram):
            vec[i*len_vocab + vocab[char]] = 1
        vecs.append(vec)
    return torch.stack(vecs)

def target_to_indicies(chars: list[chr], vocab: dict[chr, int]):
    vec = []
    for char in chars:
        vec.append(vocab[char])
    return torch.tensor(vec)

In [70]:
#For each epoch
for epoch in range(n_epochs):
    print('Epoch : ', epoch)
    sum_loss=0
    
    #For each batch
    for batch_index, (x,y) in enumerate(train_loader):
        #print(x, x.type(), x.size())
    
        #reset gradient
        optimizer.zero_grad()
        
        #prepare the x as tensor
        x = vectorize_trigrams(x, vocab)
        y = target_to_indicies(y, vocab)
        
        #run the model, get the prediction for x
        y_pred = trilm(x)
        
        #print the dimensions of x, y, y_pred
        #print(x.size(), y.size(), y_pred.size())
        
        #compare it with label, calculate loss, add it to epoch loss
        loss=ce_loss(y_pred, y)
        sum_loss+=loss.item()
        
        #calculate gradients
        loss.backward()
        
        #call optimizer to update the weights backwards 
        optimizer.step()
        
    print('Loss : ', sum_loss)

Epoch :  0
Loss :  -217.9375
Epoch :  1
Loss :  -217.9375
Epoch :  2
Loss :  -217.9375
Epoch :  3
Loss :  -217.98611111193895
Epoch :  4
Loss :  -218.0347222238779
Epoch :  5
Loss :  -217.9375
Epoch :  6
Loss :  -217.98611111193895
Epoch :  7
Loss :  -217.9375
Epoch :  8
Loss :  -217.9375
Epoch :  9
Loss :  -217.9375


In [67]:
def generate_name(model: nn.Module, prefix: str, vocab: dict[chr, int]):
    model.eval()
    name = prefix
    
    #limit the length of the generated text to 15 characters
    while len(name) < 15:
        #vectorize the last 3 characters of the generated text
        x = vectorize_trigrams([name[-3:]], vocab)
        
        #run the model, get the prediction for x
        y_pred = model(x)
        
        #if sample is True, sample the next character from the predicted probabilities
        y_pred = torch.argmax(y_pred, dim=1).detach().numpy()[0]
        
        #get the character corresponding to the index
        char = list(vocab.keys())[list(vocab.values()).index(y_pred)]
        
        if char == '$':
            break
        elif char == '#':
            print("Error: Start character found in the middle of the name")
            break
        
        #add the character to the generated text
        name += char
    
    if(len(name) == 15):
        print("Error: Maximum length reached")
    print(f"final name: {name}")
        
        
        

Error: Maximum length reached
final name: aanaaaaaaaaaaaa
