In [153]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

In [2]:
!pip install unidecode

Collecting unidecode
  Downloading Unidecode-1.3.8-py3-none-any.whl.metadata (13 kB)
Downloading Unidecode-1.3.8-py3-none-any.whl (235 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/235.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m225.3/235.5 kB[0m [31m9.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.5/235.5 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: unidecode
Successfully installed unidecode-1.3.8


In [118]:
# Re-import local modules (arch.py and utilsforjupyter.py are imported further
# down) automatically before each cell runs, so edits to them take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

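# Duomenų rinkinio kūrimas
Vardai surenkami iš [vardai.vlkk.lt](https://vardai.vlkk.lt/): visos raidės paverčiamos mažosiomis, o lietuviškos raidės su diakritiniais ženklais pakeičiamos lotyniškais atitikmenimis (pvz., „Žydrūnas“ → „zydrunas“). Vyriški ir moteriški vardai išsaugomi atskiruose failuose `m_names.csv` ir `w_names.csv`.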

In [140]:
import requests
from bs4 import BeautifulSoup
import numpy as np
import pandas as pd
import unidecode

def preprocess_name(name):
    """Normalise a scraped name to plain lowercase ASCII letters.

    The name is lowercased and transliterated with ``unidecode`` (e.g.
    "Žydrūnas" -> "zydrunas"). Any remaining character outside ``a``-``z``
    (surrounding or inner whitespace, hyphens, digits, leftover marks) is
    dropped, because the space character is later appended by ``NameDataset``
    as the end-of-name token and must not appear inside a name.

    Args:
        name: Raw name text as extracted from the HTML page.

    Returns:
        The normalised name; an empty string if ``name`` contains no letters.
    """
    # Lowercase again after transliteration, defensively, in case unidecode
    # yields uppercase ASCII for some input character.
    ascii_name = unidecode.unidecode(name.lower()).lower()
    return ''.join(ch for ch in ascii_name if 'a' <= ch <= 'z')

# List-page keys on vardai.vlkk.lt, one page per initial letter; the "-2" keys
# are presumably the pages for the diacritic letters (č, š, ž) — verify on the site.
LETTER_KEYS = ['a', 'b', 'c', 'c-2', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
               'm', 'n', 'o', 'p', 'r', 's', 's-2', 't', 'u', 'v', 'z', 'z-2']
REQUEST_TIMEOUT_S = 10

man_names = []
woman_names = []

with requests.Session() as session:
    for key in LETTER_KEYS:
        url = f'https://vardai.vlkk.lt/sarasas/{key}/'
        # The timeout keeps the cell from hanging forever on a stalled connection;
        # raise_for_status stops an HTTP error page from being parsed silently
        # into an incomplete dataset.
        response = session.get(url, timeout=REQUEST_TIMEOUT_S)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')

        man_links = soup.find_all('a', class_='names_list__links names_list__links--man')
        man_names += [preprocess_name(link.text) for link in man_links]

        woman_links = soup.find_all('a', class_='names_list__links names_list__links--woman')
        woman_names += [preprocess_name(link.text) for link in woman_links]

# Names that normalise to an empty string would be written as blank CSV rows,
# which pandas reads back as NaN.
man_names = [name for name in man_names if name]
woman_names = [name for name in woman_names if name]

# An empty list almost certainly means the page markup changed.
if not man_names or not woman_names:
    raise RuntimeError(
        f'Scraping returned {len(man_names)} male and {len(woman_names)} female names; '
        'check the CSS classes used on vardai.vlkk.lt'
    )

pd.DataFrame(man_names, columns=['name']).to_csv('m_names.csv', index=False)
pd.DataFrame(woman_names, columns=['name']).to_csv('w_names.csv', index=False)

In [141]:
class NameDataset(Dataset):
    """Character-level dataset of names read from a CSV file with a ``name`` column.

    Each item is one name followed by a single space (the end-of-name token),
    encoded as a 1-D integer tensor of vocabulary indices.

    Args:
        csv_file: Path to a CSV file with a ``name`` column.
        chars: Optional iterable of characters defining the vocabulary. Passing
            the same vocabulary to several datasets (e.g. male and female names)
            gives a model a consistent character-to-index mapping. When omitted,
            the vocabulary is built from this file's names. ``' '`` is always
            included.

    Raises:
        ValueError: If a name contains a character missing from ``chars``.

    Attributes:
        names: Array of name strings as read from the file.
        chars: Sorted vocabulary list; its position is the character's index.
        char_to_int / int_to_char: Mappings between characters and indices.
        vocab_size: Number of characters in the vocabulary.
    """

    def __init__(self, csv_file, chars=None):
        # dtype=str + keep_default_na=False: with pandas' default NA handling a
        # (lowercased) name such as "nan" or "null" would become a float NaN and
        # break the string joins below.
        self.names = pd.read_csv(csv_file, dtype=str, keep_default_na=False)['name'].values
        if chars is None:
            chars = ''.join(self.names)
        # ' ' is the end-of-name token appended in __getitem__.
        self.chars = sorted(set(chars) | {' '})
        self.char_to_int = {c: i for i, c in enumerate(self.chars)}
        self.int_to_char = {i: c for c, i in self.char_to_int.items()}
        self.vocab_size = len(self.chars)

        unknown = set(''.join(self.names)) - set(self.chars)
        if unknown:
            raise ValueError(f'Characters missing from vocabulary: {sorted(unknown)}')

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx] + ' '
        encoded_name = [self.char_to_int[char] for char in name]
        return torch.tensor(encoded_name)

In [142]:
m_dataset = NameDataset('m_names.csv')
w_dataset = NameDataset('w_names.csv')

# Both datasets feed one model, so they must agree on the character-to-index
# mapping. Each dataset derives its vocabulary from its own file, so a letter
# present in only one file would shift every later index. Install a single
# shared vocabulary on both.
shared_chars = sorted(set(m_dataset.chars) | set(w_dataset.chars))
for dataset in (m_dataset, w_dataset):
    dataset.chars = list(shared_chars)
    dataset.char_to_int = {c: i for i, c in enumerate(shared_chars)}
    dataset.int_to_char = {i: c for i, c in enumerate(shared_chars)}
    dataset.vocab_size = len(shared_chars)
assert m_dataset.char_to_int == w_dataset.char_to_int

decoded_m_name = ''.join([m_dataset.int_to_char[idx.item()] for idx in m_dataset[0]])
decoded_w_name = ''.join([w_dataset.int_to_char[idx.item()] for idx in w_dataset[0]])

print(f'Man name: "{decoded_m_name}"')
print(f'Woman name: "{decoded_w_name}"')

Man name: "abas "
Woman name: "abe "


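# Duomenų paruošimas
Vardai grupuojami į paketus po 32. Trumpesni vardai papildomi indeksu 0 iki ilgiausio paketo vardo ilgio. Indeksas 0 yra tarpas – vardo pabaigos simbolis, nes rikiuojant jis atsiduria prieš raides. Modelio įvestis – vardas be paskutinio simbolio, o taikinys – tas pats vardas, paslinktas per vieną simbolį, todėl kiekvienoje pozicijoje modelis mokosi nuspėti kitą raidę.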

In [144]:
def pad_collate(batch):
    """Collate variable-length encoded names into next-character training pairs.

    Sequences are right-padded with index 0 up to the longest one in the batch.
    Inputs are every position except the last; targets are the same rows
    shifted left by one, so ``targets[:, t]`` is the character following
    ``inputs[:, t]``.

    Args:
        batch: List of 1-D integer tensors, one per name.

    Returns:
        Tuple ``(inputs, targets)``, each of shape ``(len(batch), max_len - 1)``.
    """
    # With NameDataset's sorted vocabulary, index 0 is ' ' (the end-of-name
    # token) as long as no name contains a character sorting before the space,
    # so padding reads as "keep ending" — TODO confirm if the vocabulary changes.
    padded = pad_sequence(batch, padding_value=0, batch_first=True)
    return padded[:, :-1], padded[:, 1:]

BATCH_SIZE = 32
# Seeded generators make the shuffling order, and therefore training runs,
# reproducible.
SEED = 42

m_dataloader = DataLoader(m_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=pad_collate,
                          generator=torch.Generator().manual_seed(SEED))
w_dataloader = DataLoader(w_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=pad_collate,
                          generator=torch.Generator().manual_seed(SEED + 1))


# Modelio sukūrimas ir treniravimas
Modelis `MinimalTransformer` aprašytas faile `arch.py`. Jis skirtas kurti ir vyriškus, ir moteriškus vardus. Modelyje yra du atskiri koduotuvai (encoderiai): `m_encoder` – vyriškiems vardams, `w_encoder` – moteriškiems; jų architektūra vienoda. Kurį koduotuvą naudoti, nurodo parametras `gender`, kurio reikšmė yra `'m'` arba `'w'`.

In [154]:
import torch
from arch import MinimalTransformer
from utilsforjupyter import train, sample

EPOCHS = 15

# Seed weight initialisation so training runs are reproducible.
torch.manual_seed(42)

# A single model is trained on both datasets (per the model description, a
# separate m_encoder / w_encoder is chosen via `gender`). Its vocabulary must
# therefore cover both: sizing it from the male dataset alone would give
# out-of-range indices if the female names used a character the male names do not.
vocab_size = max(m_dataset.vocab_size, w_dataset.vocab_size)
model = MinimalTransformer(vocab_size=vocab_size, embed_size=128, num_heads=8, forward_expansion=4)

print("Training on man names dataset:")
train(model, m_dataloader, gender='m', epochs=EPOCHS)

print("\nTraining on woman names dataset:")
train(model, w_dataloader, gender='w', epochs=EPOCHS)




Training on man names dataset:
Epoch 1/15 (m), Loss: 1.2893228924964084
Epoch 2/15 (m), Loss: 1.1716887546964914
Epoch 3/15 (m), Loss: 1.1577868993617286
Epoch 4/15 (m), Loss: 1.1501042980793095
Epoch 5/15 (m), Loss: 1.1354533980700596
Epoch 6/15 (m), Loss: 1.1337486707474576
Epoch 7/15 (m), Loss: 1.1285725245791034
Epoch 8/15 (m), Loss: 1.1258878274397417
Epoch 9/15 (m), Loss: 1.125634959906586
Epoch 10/15 (m), Loss: 1.1218698940986445
Epoch 11/15 (m), Loss: 1.1129084633401602
Epoch 12/15 (m), Loss: 1.1126101204186432
Epoch 13/15 (m), Loss: 1.1088492481176517
Epoch 14/15 (m), Loss: 1.1037366961644701
Epoch 15/15 (m), Loss: 1.110900941466497

Training on woman names dataset:
Epoch 1/15 (w), Loss: 1.3816265746166831
Epoch 2/15 (w), Loss: 1.1929793779115032
Epoch 3/15 (w), Loss: 1.1792138142693311
Epoch 4/15 (w), Loss: 1.1663570888060377
Epoch 5/15 (w), Loss: 1.161831662170869
Epoch 6/15 (w), Loss: 1.1620637052937557
Epoch 7/15 (w), Loss: 1.154253229610902
Epoch 8/15 (w), Loss: 1.1643975

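# Vardų generavimas
Naudojama funkcija `sample` iš `utilsforjupyter.py`. Vardas generuojamas po vieną raidę: modelis įvertina galimų kitų raidžių tikimybes, o kita raidė parenkama atsitiktinai pagal šias tikimybes. Funkcija naudoja temperatūros parametrą, kuris valdo, kiek atsitiktinis yra parinkimas. Grupėje „Higher Confidence“ modelis renkasi tikėtinesnes raides. Grupėje „More Creative“ raidės parenkamos laisviau, todėl vardai originalesni, bet dažnai mažiau natūralūs.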

In [162]:
# Generate names beginning with "a" for both genders (5 per group, as in the
# output below), with generation capped by max_length=20.
sample(model, m_dataset, w_dataset, start_str='a', max_length=20, num_names=5)

Confidence           Man Names                      Woman Names                   
-------------------------------------------------------------------------------------
Higher Confidence   
                     airinas                        arime                         
                     alivilas                       ailija                        
                     adetanas                       aulija                        
                     ailejus                        alinone                       
                     alekonas                       augija                        

More Creative       
                     aitanrabas                     atiuedijogna                  
                     agmmaktas                      almja                         
                     aurntaneodvavigugane           auzkinyza                     
                     amatylgonileridagari           antinana                      
                     aldiamyus           

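# Modelio išsaugojimas
Šį langelį vykdykite tik baigę treniravimą – kitaip bus išsaugotas neapmokytas arba tik iš dalies apmokytas modelis.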

In [151]:
# Full pickled model, kept under the original filename for existing loaders.
torch.save(model, 'namesformer_model.pt')

# A pickled module is bound to the MinimalTransformer class and module layout
# at save time, and it does not carry the character vocabularies needed to
# encode prompts and decode generated names. Also store a portable checkpoint:
# the weights as a state_dict plus both vocabularies.
torch.save(
    {
        'state_dict': model.state_dict(),
        'm_chars': list(m_dataset.chars),
        'w_chars': list(w_dataset.chars),
    },
    'namesformer_checkpoint.pt',
)