In [1]:
import pandas as pd
import numpy as np
import random

from tqdm import tqdm


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

In [3]:
df = pd.read_csv("./data/characters_metadata.csv")

In [4]:
df.columns

Index(['ID', 'Name', 'Alias', 'Gender', 'Hair Color', 'Love Rank', 'Hate Rank',
       'Eye color', 'Birthday', 'Blood Type', 'Tags', 'Love Count',
       'Hate Count', 'Description', 'url'],
      dtype='object')

In [5]:
df.head()

Unnamed: 0,ID,Name,Alias,Gender,Hair Color,Love Rank,Hate Rank,Eye color,Birthday,Blood Type,Tags,Love Count,Hate Count,Description,url
0,0,L,Ryuzaki,Male,Black,1.0,48.0,Black,"October 31, 1979",Unknown,"Analytical, Barefoot, Detectives, Eye Bags, Sw...",44.829,3.447,"Secretive, meticulous and cunning, L's desire ...",https://www.anime-planet.com/characters/l-deat...
1,1,Haru YOSHIDA,Unknown,Male,Black,346.0,4.172,Black,April 2,Unknown,"High School Students, Hot-Headed, Teenagers",4.669,124.0,Unknown,https://www.anime-planet.com/characters/haru-y...
2,2,Shinobu MAEHARA,Unknown,Female,Blue,2.942,9.11,Unknown,Unknown,Unknown,"Cooks, Crybabies, Middle School Students, Shy",823.0,53.0,Unknown,https://www.anime-planet.com/characters/shinob...
3,3,Chizuru OSHIMA,Unknown,Female,Black,3.877,1.801,Unknown,Unknown,Unknown,"Class Representatives, Glasses, High School St...",633.0,269.0,Unknown,https://www.anime-planet.com/characters/chizur...
4,4,Yuuzan YOSHIDA,Unknown,Male,Black,3.577,2.819,Unknown,Unknown,Unknown,Unknown,684.0,180.0,Unknown,https://www.anime-planet.com/characters/yuuzan...


In [6]:
def get_first_name(name):
    chars = [*name]
    len_chars = len(chars)
    idx = 0
    while idx < len_chars and chars[idx].isalpha() and chars[idx].isascii():
        idx += 1
    
    return "".join(chars[:idx])

In [7]:
female_names = list(set([*map(lambda full_name: get_first_name(full_name).lower(), df[df["Gender"] == "Female"]["Name"])]))
male_names = list(set([*map(lambda full_name: get_first_name(full_name).lower(), df[df["Gender"] == "Male"]["Name"])]))

In [8]:
female_names = [*filter(lambda name: len(name) != 0, female_names)]
male_names = [*filter(lambda name: len(name) != 0, male_names)]

In [9]:
print(f"max female name len: {max(map(lambda name: len(name), female_names))} | max male name len: {max(map(lambda name: len(name), male_names))}")
print(f"total female names: {len(female_names)} | total male names: {len(male_names)}")

max female name len: 20 | max male name len: 18
total female names: 17290 | total male names: 30758


In [10]:


# Special Tokens

END_TOKEN = "<end>"
PAD_TOKEN = "<pad>"

MALE_NAME_TOKEN = "<M>"
FEMALE_NAME_TOKEN = "<F>"

In [51]:
token2idx

{'a': 0,
 'b': 1,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12,
 'n': 13,
 'o': 14,
 'p': 15,
 'q': 16,
 'r': 17,
 's': 18,
 't': 19,
 'u': 20,
 'v': 21,
 'w': 22,
 'x': 23,
 'y': 24,
 'z': 25,
 '<M>': 26,
 '<F>': 27,
 '<end>': 28,
 '<pad>': 29}

In [11]:

token2idx = {token: idx for idx, token in enumerate([chr(i) for i in range(97,123)] + [MALE_NAME_TOKEN, FEMALE_NAME_TOKEN, END_TOKEN, PAD_TOKEN])}
idx2token = {v:k for k,v in token2idx.items()}

In [12]:
DEVICE      = "cuda:0"
VOCAB_SIZE  = len(token2idx)
MAX_LEN     = 24
EMBED_DIM   = 512
HIDDEN_DIM  = 1024

EPOCHS      = 8

In [13]:

def tokenizer_encode(name, gender, max_len=24):

    if gender == "Male":
        gender_token = MALE_NAME_TOKEN
    elif gender == "Female":
        gender_token = FEMALE_NAME_TOKEN
    else:
        raise RuntimeError("Invalid gender")
    
    name = [gender_token] + [*name[:max_len]]
    name.append(END_TOKEN)
    
    while len(name) < max_len:
        name.append(PAD_TOKEN)
        
    return [token2idx[c] for c in name]

In [14]:

class RNN(nn.Module):
    
    def __init__(self, vocab_size=VOCAB_SIZE, embd_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM):
    
        super().__init__()
        self.W_hh = nn.Linear(hidden_dim, hidden_dim)
        self.W_xh = nn.Linear(embd_dim, hidden_dim)
        self.W_hy = nn.Linear(hidden_dim, vocab_size)
        # self.W_hg = nn.Linear(hidden_dim)
        
        self.h = nn.Parameter(torch.randn(hidden_dim))
        self.embeddings = nn.Embedding(vocab_size, embd_dim)
        
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        
    def forward(self, x, device=DEVICE):

        # h = self.h
        x = self.embeddings(x)
        batch_size, seq_len, embd_dim = x.shape


        output = torch.zeros(batch_size, seq_len - 1, self.vocab_size).to(device)
        hiddens = torch.zeros(batch_size, self.hidden_dim).to(device)
        
        for i in range(batch_size):
            hiddens[i] = self.h
        
        for i in range(seq_len - 1):
            
            hiddens = F.tanh(self.W_hh(hiddens) + self.W_xh(x[:,i] + x[:,0]))
            y = self.W_hy(hiddens)
            output[:,i] = y
        
        return output

In [15]:
model = RNN().to(DEVICE)

In [16]:

def criterion(input_tokens, y_pred):
        
    y_true = input_tokens[:, 1:].clone()

    # The first token will be the gender token
    y_true[0] = -100
    y_true.masked_fill_(y_true == token2idx[PAD_TOKEN], -100)
    
    # print(y_pred.shape, y_true.shape)
    loss = F.cross_entropy(y_pred.reshape(-1,VOCAB_SIZE), y_true.reshape(-1))
    return loss

In [17]:
xs = []

for idx, name in tqdm(enumerate(female_names)):
    xs.append(tokenizer_encode(name, "Female"))

print(xs[-1])
for idx, name in tqdm(enumerate(male_names)):
    xs.append(tokenizer_encode(name, "Male"))

print(xs[-1])

17290it [00:00, 688812.11it/s]


[27, 1, 4, 13, 10, 0, 19, 4, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29]


30758it [00:00, 342730.09it/s]

[26, 2, 7, 14, 10, 8, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29]





In [18]:

random.shuffle(xs)
xs = torch.tensor(xs).to(DEVICE)

In [19]:
xs[:11]

tensor([[26,  5, 11,  8, 17, 19,  0, 19,  8, 14, 20, 18, 28, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [26, 14, 17,  8,  1,  4, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [26, 10, 20, 18,  0, 13, 14, 18, 20, 10,  4, 28, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [26,  4, 20, 13,  6, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [27, 13,  0, 13,  0,  1,  0, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [26, 19,  0,  0, 10,  8, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [27, 12,  4, 13,  2,  7,  8,  7, 20,  0,  7, 20,  0, 28, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [26, 13, 14, 25, 20, 12, 14, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
         29, 29, 29, 29, 29, 29],
        [26, 10,  8, 14, 20, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,

In [20]:

BATCH_SIZE = 64
optimizer = AdamW(model.parameters())

In [21]:

def inference(model, input_str, gender):
    
    while True:
        x = torch.tensor(tokenizer_encode(input_str, gender, max_len=len(input_str))).reshape(1,-1).to(DEVICE)
        last_logits = model(x, device=DEVICE).squeeze()[-1]
        new_char_idx = last_logits.softmax(dim=-1).argmax().item()
    
        if new_char_idx == token2idx[END_TOKEN]:
            # print(input_str)
            break
        
        input_str = input_str + idx2token[new_char_idx]
    return input_stri

In [30]:


for epoch in range(EPOCHS):
    print("="*40)
    print(f"EPOCH: {epoch}")
    print("="*40)
    training_loss = []
    step_count = 0
    inference_per_step = 50
    
    for batch_start_idx in (tbar := tqdm(range(0,len(xs),BATCH_SIZE))):
        
        optimizer.zero_grad()
        
        input_xs = xs[batch_start_idx: batch_start_idx + BATCH_SIZE]
        
        pred_ys = model(input_xs, device=DEVICE)
        loss = criterion(input_xs,pred_ys)
        
        loss.backward()
        optimizer.step()
        
        step_count += 1
        training_loss.append(loss.item())
        
        if step_count % inference_per_step == 0:
            model.eval()
            name = inference(model, "y", "Female")
            print(f"step count: {step_count} | name: {name} | gender: Female")
            name = inference(model, "y", "Male")
            print(f"step count: {step_count} | name: {name} | gender: Male")
            model.train()
        
        
        tbar.set_description(f"loss: {loss.item()} | training loss: {np.mean(training_loss)}")

EPOCH: 0


loss: 2.1484215259552 | training loss: 2.116554381031739:   9%|██████████▋                                                                                                       | 70/751 [00:00<00:05, 135.62it/s]

step count: 50 | name: yumiko | gender: Female
step count: 50 | name: yoshimaru | gender: Male


loss: 2.116305112838745 | training loss: 2.1232803337217314:  15%|████████████████▌                                                                                             | 113/751 [00:00<00:04, 135.91it/s]

step count: 100 | name: yuuka | gender: Female
step count: 100 | name: yashimaru | gender: Male


loss: 2.116968870162964 | training loss: 2.1205304820429194:  23%|████████████████████████▉                                                                                     | 170/751 [00:01<00:04, 135.11it/s]

step count: 150 | name: yurin | gender: Female
step count: 150 | name: yoshinosuke | gender: Male


loss: 2.0876643657684326 | training loss: 2.123368576234658:  30%|█████████████████████████████████                                                                             | 226/751 [00:01<00:03, 135.68it/s]

step count: 200 | name: yuu | gender: Female
step count: 200 | name: yoshinosuke | gender: Male


loss: 2.0936179161071777 | training loss: 2.1235292939072483:  36%|██████████████████████████████████████▉                                                                      | 268/751 [00:02<00:03, 135.83it/s]

step count: 250 | name: yuuka | gender: Female
step count: 250 | name: yoshin | gender: Male


loss: 2.1938581466674805 | training loss: 2.1232378803626477:  43%|███████████████████████████████████████████████                                                              | 324/751 [00:02<00:03, 135.86it/s]

step count: 300 | name: yukina | gender: Female
step count: 300 | name: yukita | gender: Male


loss: 2.172584056854248 | training loss: 2.1246895682590394:  49%|█████████████████████████████████████████████████████▌                                                        | 366/751 [00:02<00:02, 136.04it/s]

step count: 350 | name: yukino | gender: Female
step count: 350 | name: yukimaru | gender: Male


loss: 2.188952922821045 | training loss: 2.1250536749620395:  56%|█████████████████████████████████████████████████████████████▊                                                | 422/751 [00:03<00:02, 134.93it/s]

step count: 400 | name: yoshihara | gender: Female
step count: 400 | name: yoshitaka | gender: Male


loss: 2.134693145751953 | training loss: 2.126616526802047:  62%|████████████████████████████████████████████████████████████████████▌                                          | 464/751 [00:03<00:02, 134.27it/s]

step count: 450 | name: yukino | gender: Female
step count: 450 | name: yamanosuke | gender: Male


loss: 2.1601579189300537 | training loss: 2.1275345220511506:  69%|███████████████████████████████████████████████████████████████████████████▍                                 | 520/751 [00:03<00:01, 135.52it/s]

step count: 500 | name: yuuri | gender: Female
step count: 500 | name: yamaka | gender: Male


loss: 2.169656991958618 | training loss: 2.1268962877431954:  77%|████████████████████████████████████████████████████████████████████████████████████▌                         | 577/751 [00:04<00:01, 140.33it/s]

step count: 550 | name: yuuta | gender: Female
step count: 550 | name: yamagi | gender: Male


loss: 2.1856465339660645 | training loss: 2.1270670613223475:  83%|██████████████████████████████████████████████████████████████████████████████████████████▏                  | 621/751 [00:04<00:00, 138.06it/s]

step count: 600 | name: yuuta | gender: Female
step count: 600 | name: yasuke | gender: Male


loss: 2.0408565998077393 | training loss: 2.127186980473219:  88%|█████████████████████████████████████████████████████████████████████████████████████████████████             | 663/751 [00:04<00:00, 134.91it/s]

step count: 650 | name: yuushi | gender: Female
step count: 650 | name: yoshimaru | gender: Male


loss: 2.152297258377075 | training loss: 2.1270945114866104:  96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▎    | 719/751 [00:05<00:00, 133.99it/s]

step count: 700 | name: yukimi | gender: Female
step count: 700 | name: yamanosuke | gender: Male


loss: 2.02445387840271 | training loss: 2.1265811507457424: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 135.71it/s]


step count: 750 | name: yumin | gender: Female
step count: 750 | name: yoshika | gender: Male
EPOCH: 1


loss: 2.17546010017395 | training loss: 2.114105504828614:   9%|██████████▌                                                                                                      | 70/751 [00:00<00:05, 135.99it/s]

step count: 50 | name: yuuki | gender: Female
step count: 50 | name: yuuki | gender: Male


loss: 2.08793306350708 | training loss: 2.1176368983711784:  15%|████████████████▋                                                                                              | 113/751 [00:00<00:04, 135.28it/s]

step count: 100 | name: yana | gender: Female
step count: 100 | name: yashimaru | gender: Male


loss: 2.109851121902466 | training loss: 2.1144928153265607:  23%|████████████████████████▉                                                                                     | 170/751 [00:01<00:04, 134.48it/s]

step count: 150 | name: yuuichi | gender: Female
step count: 150 | name: yoshinosuke | gender: Male


loss: 2.1019809246063232 | training loss: 2.1173462039601487:  28%|███████████████████████████████                                                                              | 214/751 [00:01<00:03, 138.15it/s]

step count: 200 | name: yuushi | gender: Female
step count: 200 | name: yoshinosuke | gender: Male


loss: 2.0887606143951416 | training loss: 2.118050344153862:  36%|███████████████████████████████████████▌                                                                      | 270/751 [00:02<00:03, 136.36it/s]

step count: 250 | name: yuuki | gender: Female
step count: 250 | name: yashin | gender: Male


loss: 2.1987051963806152 | training loss: 2.117777778832555:  43%|███████████████████████████████████████████████▋                                                              | 326/751 [00:02<00:03, 135.66it/s]

step count: 300 | name: yukika | gender: Female
step count: 300 | name: yukita | gender: Male


loss: 2.1756725311279297 | training loss: 2.11922719718923:  49%|██████████████████████████████████████████████████████▍                                                        | 368/751 [00:02<00:02, 135.51it/s]

step count: 350 | name: yuuzaku | gender: Female
step count: 350 | name: yukito | gender: Male


loss: 2.145167112350464 | training loss: 2.1193398419120504:  56%|██████████████████████████████████████████████████████████████                                                | 424/751 [00:03<00:02, 134.24it/s]

step count: 400 | name: yoshika | gender: Female
step count: 400 | name: yoshitaka | gender: Male


loss: 2.12910795211792 | training loss: 2.1208032086616804:  62%|████████████████████████████████████████████████████████████████████▉                                          | 466/751 [00:03<00:02, 134.30it/s]

step count: 450 | name: yukiha | gender: Female
step count: 450 | name: yamanosuke | gender: Male


loss: 2.2206735610961914 | training loss: 2.1217532135234585:  70%|███████████████████████████████████████████████████████████████████████████▊                                 | 522/751 [00:03<00:01, 135.47it/s]

step count: 500 | name: yuu | gender: Female
step count: 500 | name: yamanosuke | gender: Male


loss: 2.0994958877563477 | training loss: 2.1211721073422165:  75%|█████████████████████████████████████████████████████████████████████████████████▊                           | 564/751 [00:04<00:01, 135.05it/s]

step count: 550 | name: yuuta | gender: Female
step count: 550 | name: yasuhito | gender: Male


loss: 2.1757969856262207 | training loss: 2.121512993671107:  83%|██████████████████████████████████████████████████████████████████████████████████████████▊                   | 620/751 [00:04<00:00, 135.32it/s]

step count: 600 | name: yuuta | gender: Female
step count: 600 | name: yuusuke | gender: Male


loss: 2.0060911178588867 | training loss: 2.1215274838300853:  90%|██████████████████████████████████████████████████████████████████████████████████████████████████           | 676/751 [00:04<00:00, 134.58it/s]

step count: 650 | name: yuusa | gender: Female
step count: 650 | name: yukiyoshi | gender: Male


loss: 2.0534989833831787 | training loss: 2.1213126546579018:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏    | 718/751 [00:05<00:00, 134.83it/s]

step count: 700 | name: yuriko | gender: Female
step count: 700 | name: yamagawa | gender: Male


loss: 2.018373966217041 | training loss: 2.1208934088680302: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 135.54it/s]


step count: 750 | name: yuuta | gender: Female
step count: 750 | name: yoshiki | gender: Male
EPOCH: 2


loss: 2.1561663150787354 | training loss: 2.1070746833627876:   9%|██████████▎                                                                                                   | 70/751 [00:00<00:05, 135.94it/s]

step count: 50 | name: yuuya | gender: Female
step count: 50 | name: yukito | gender: Male


loss: 2.082498550415039 | training loss: 2.1133690997371524:  17%|██████████████████▍                                                                                           | 126/751 [00:00<00:04, 135.73it/s]

step count: 100 | name: yuu | gender: Female
step count: 100 | name: yashima | gender: Male


loss: 2.100679397583008 | training loss: 2.1100676784461196:  22%|████████████████████████▌                                                                                     | 168/751 [00:01<00:04, 133.47it/s]

step count: 150 | name: yukimi | gender: Female
step count: 150 | name: yoshinosuke | gender: Male


loss: 2.1703174114227295 | training loss: 2.1126989046377793:  30%|████████████████████████████████▏                                                                            | 222/751 [00:01<00:04, 125.40it/s]

step count: 200 | name: yuusuke | gender: Female
step count: 200 | name: yoshinori | gender: Male


loss: 2.1465933322906494 | training loss: 2.1138828031338046:  36%|███████████████████████████████████████▊                                                                     | 274/751 [00:02<00:03, 123.57it/s]

step count: 250 | name: yuuri | gender: Female
step count: 250 | name: yoshinari | gender: Male


loss: 2.1485648155212402 | training loss: 2.1123594130262915:  42%|█████████████████████████████████████████████▍                                                               | 313/751 [00:02<00:03, 122.74it/s]

step count: 300 | name: yukiha | gender: Female
step count: 300 | name: yukitaka | gender: Male


loss: 2.111905574798584 | training loss: 2.1144050560533043:  49%|█████████████████████████████████████████████████████▍                                                        | 365/751 [00:02<00:03, 123.73it/s]

step count: 350 | name: yuu | gender: Female
step count: 350 | name: yokoko | gender: Male


loss: 2.1696739196777344 | training loss: 2.1145276941780775:  56%|████████████████████████████████████████████████████████████▌                                                | 417/751 [00:03<00:02, 122.57it/s]

step count: 400 | name: yoshimi | gender: Female
step count: 400 | name: yoshitaka | gender: Male


loss: 2.1166903972625732 | training loss: 2.1161161007257454:  62%|████████████████████████████████████████████████████████████████████                                         | 469/751 [00:03<00:02, 123.62it/s]

step count: 450 | name: yukihime | gender: Female
step count: 450 | name: yamanosuke | gender: Male


loss: 2.136939525604248 | training loss: 2.116612213225592:  70%|█████████████████████████████████████████████████████████████████████████████▏                                 | 522/751 [00:04<00:01, 125.06it/s]

step count: 500 | name: yuuta | gender: Female
step count: 500 | name: yamanosuke | gender: Male


loss: 2.2284016609191895 | training loss: 2.115952104226222:  76%|████████████████████████████████████████████████████████████████████████████████████                          | 574/751 [00:04<00:01, 124.39it/s]

step count: 550 | name: yuuta | gender: Female
step count: 550 | name: yamagaki | gender: Male


loss: 2.0827808380126953 | training loss: 2.1167697520592275:  82%|████████████████████████████████████████████████████████████████████████████████████████▉                    | 613/751 [00:04<00:01, 123.50it/s]

step count: 600 | name: yuuta | gender: Female
step count: 600 | name: yasuno | gender: Male


loss: 2.114816665649414 | training loss: 2.1168840835712572:  89%|█████████████████████████████████████████████████████████████████████████████████████████████████▍            | 665/751 [00:05<00:00, 125.15it/s]

step count: 650 | name: yuki | gender: Female
step count: 650 | name: yuusaku | gender: Male


loss: 2.1775259971618652 | training loss: 2.1170142831473515:  95%|████████████████████████████████████████████████████████████████████████████████████████████████████████     | 717/751 [00:05<00:00, 125.89it/s]

step count: 700 | name: yuriko | gender: Female
step count: 700 | name: yamane | gender: Male


loss: 1.9987648725509644 | training loss: 2.116729253935274: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 126.55it/s]


step count: 750 | name: yuu | gender: Female
step count: 750 | name: yoshika | gender: Male
EPOCH: 3


loss: 2.0878360271453857 | training loss: 2.10004502137502:   9%|██████████▎                                                                                                     | 69/751 [00:00<00:05, 128.51it/s]

step count: 50 | name: yukiko | gender: Female
step count: 50 | name: yasuke | gender: Male


loss: 2.0879437923431396 | training loss: 2.105988828659058:  16%|█████████████████▋                                                                                            | 121/751 [00:00<00:04, 128.06it/s]

step count: 100 | name: yuuna | gender: Female
step count: 100 | name: yashima | gender: Male


loss: 2.1767191886901855 | training loss: 2.1040634311948505:  23%|█████████████████████████▎                                                                                   | 174/751 [00:01<00:04, 127.58it/s]

step count: 150 | name: yukiko | gender: Female
step count: 150 | name: yoshimaru | gender: Male


loss: 2.1506731510162354 | training loss: 2.1080383825302125:  28%|██████████████████████████████▉                                                                              | 213/751 [00:01<00:04, 126.79it/s]

step count: 200 | name: yuu | gender: Female
step count: 200 | name: yuuza | gender: Male


loss: 2.08617901802063 | training loss: 2.108539922945741:  36%|███████████████████████████████████████▊                                                                        | 267/751 [00:02<00:03, 129.31it/s]

step count: 250 | name: yuuri | gender: Female
step count: 250 | name: yashin | gender: Male


loss: 2.072596549987793 | training loss: 2.1072490387696488:  43%|███████████████████████████████████████████████▏                                                              | 322/751 [00:02<00:03, 128.47it/s]

step count: 300 | name: yukiko | gender: Female
step count: 300 | name: yukita | gender: Male


loss: 2.1833693981170654 | training loss: 2.109751499493917:  50%|██████████████████████████████████████████████████████▉                                                       | 375/751 [00:02<00:02, 127.54it/s]

step count: 350 | name: yukiko | gender: Female
step count: 350 | name: yukitaka | gender: Male


loss: 2.131394863128662 | training loss: 2.109721932771071:  55%|█████████████████████████████████████████████████████████████▏                                                 | 414/751 [00:03<00:02, 123.93it/s]

step count: 400 | name: yoshika | gender: Female
step count: 400 | name: yoshitaka | gender: Male


loss: 2.1152641773223877 | training loss: 2.111282930595462:  62%|████████████████████████████████████████████████████████████████████▎                                         | 466/751 [00:03<00:02, 123.21it/s]

step count: 450 | name: yukino | gender: Female
step count: 450 | name: yamanosuke | gender: Male


loss: 2.1296753883361816 | training loss: 2.1118109305699666:  69%|███████████████████████████████████████████████████████████████████████████▏                                 | 518/751 [00:04<00:01, 124.19it/s]

step count: 500 | name: yuu | gender: Female
step count: 500 | name: yamaka | gender: Male


loss: 2.112790822982788 | training loss: 2.1112253996487094:  76%|███████████████████████████████████████████████████████████████████████████████████▋                          | 571/751 [00:04<00:01, 130.47it/s]

step count: 550 | name: yuuki | gender: Female
step count: 550 | name: yasuhito | gender: Male


loss: 2.161250114440918 | training loss: 2.111774773308725:  82%|██████████████████████████████████████████████████████████████████████████████████████████▌                    | 613/751 [00:04<00:01, 133.53it/s]

step count: 600 | name: yune | gender: Female
step count: 600 | name: yasuka | gender: Male


loss: 2.118776559829712 | training loss: 2.111823716050941:  89%|██████████████████████████████████████████████████████████████████████████████████████████████████▉            | 669/751 [00:05<00:00, 134.96it/s]

step count: 650 | name: yuuko | gender: Female
step count: 650 | name: yuusaku | gender: Male


loss: 2.053154706954956 | training loss: 2.112017230941636:  97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 725/751 [00:05<00:00, 135.25it/s]

step count: 700 | name: yuri | gender: Female
step count: 700 | name: yamagawa | gender: Male


loss: 2.0002238750457764 | training loss: 2.1116979530108435: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 129.01it/s]


step count: 750 | name: yoshino | gender: Female
step count: 750 | name: yoshiki | gender: Male
EPOCH: 4


loss: 2.1756601333618164 | training loss: 2.095976922419164:  10%|██████████▋                                                                                                    | 72/751 [00:00<00:04, 135.82it/s]

step count: 50 | name: yuuki | gender: Female
step count: 50 | name: yukito | gender: Male


loss: 2.2600653171539307 | training loss: 2.1021197419317943:  15%|████████████████▌                                                                                            | 114/751 [00:00<00:04, 134.01it/s]

step count: 100 | name: yashima | gender: Female
step count: 100 | name: yasuhito | gender: Male


loss: 2.096046209335327 | training loss: 2.0994137660258234:  23%|████████████████████████▉                                                                                     | 170/751 [00:01<00:04, 135.14it/s]

step count: 150 | name: yukimi | gender: Female
step count: 150 | name: yoshimaru | gender: Male


loss: 2.057142734527588 | training loss: 2.102840086962158:  29%|███████████████████████████████▊                                                                               | 215/751 [00:01<00:03, 139.41it/s]

step count: 200 | name: yuusuke | gender: Female
step count: 200 | name: yoshin | gender: Male


loss: 2.075011968612671 | training loss: 2.103451877294465:  36%|████████████████████████████████████████▏                                                                      | 272/751 [00:02<00:03, 138.31it/s]

step count: 250 | name: yuuhi | gender: Female
step count: 250 | name: yoshinaga | gender: Male


loss: 2.202199935913086 | training loss: 2.1031588641412418:  42%|██████████████████████████████████████████████▏                                                               | 315/751 [00:02<00:03, 136.25it/s]

step count: 300 | name: yung | gender: Female
step count: 300 | name: yukinosuke | gender: Male


loss: 2.1625657081604004 | training loss: 2.104865209809665:  49%|██████████████████████████████████████████████████████▎                                                       | 371/751 [00:02<00:02, 136.92it/s]

step count: 350 | name: yuki | gender: Female
step count: 350 | name: yukitaka | gender: Male


loss: 2.155316114425659 | training loss: 2.105250768136643:  55%|█████████████████████████████████████████████████████████████                                                  | 413/751 [00:03<00:02, 135.32it/s]

step count: 400 | name: yoshino | gender: Female
step count: 400 | name: yoshitaka | gender: Male


loss: 2.1389000415802 | training loss: 2.1063543220735945:  63%|██████████████████████████████████████████████████████████████████████                                          | 470/751 [00:03<00:02, 136.57it/s]

step count: 450 | name: yamano | gender: Female
step count: 450 | name: yamagawa | gender: Male


loss: 2.128142833709717 | training loss: 2.107217392387607:  68%|███████████████████████████████████████████████████████████████████████████▊                                   | 513/751 [00:03<00:01, 136.72it/s]

step count: 500 | name: yuuka | gender: Female
step count: 500 | name: yuuki | gender: Male


loss: 2.1050403118133545 | training loss: 2.1063368733668866:  76%|██████████████████████████████████████████████████████████████████████████████████▋                          | 570/751 [00:04<00:01, 136.68it/s]

step count: 550 | name: yuuji | gender: Female
step count: 550 | name: yamagi | gender: Male


loss: 2.006047248840332 | training loss: 2.106594613782919:  84%|████████████████████████████████████████████████████████████████████████████████████████████▊                  | 628/751 [00:04<00:00, 139.19it/s]

step count: 600 | name: yune | gender: Female
step count: 600 | name: yasuke | gender: Male


loss: 2.1075098514556885 | training loss: 2.1070284913878674:  89%|█████████████████████████████████████████████████████████████████████████████████████████████████▏           | 670/751 [00:04<00:00, 137.80it/s]

step count: 650 | name: yuushi | gender: Female
step count: 650 | name: yuuichi | gender: Male


loss: 2.066108465194702 | training loss: 2.1070990523085142:  97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▍   | 727/751 [00:05<00:00, 137.33it/s]

step count: 700 | name: yuri | gender: Female
step count: 700 | name: yamagawa | gender: Male


loss: 2.020493984222412 | training loss: 2.106670296144549: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 136.98it/s]


step count: 750 | name: yuune | gender: Female
step count: 750 | name: yoshika | gender: Male
EPOCH: 5


loss: 2.1389920711517334 | training loss: 2.0927371163117257:   9%|██████████▎                                                                                                   | 70/751 [00:00<00:05, 134.05it/s]

step count: 50 | name: yuuno | gender: Female
step count: 50 | name: yukinosuke | gender: Male


loss: 2.075221538543701 | training loss: 2.0991143889314547:  17%|██████████████████▍                                                                                           | 126/751 [00:00<00:04, 135.13it/s]

step count: 100 | name: yuuka | gender: Female
step count: 100 | name: yashimaru | gender: Male


loss: 2.0986809730529785 | training loss: 2.096466985615817:  22%|████████████████████████▌                                                                                     | 168/751 [00:01<00:04, 134.02it/s]

step count: 150 | name: yoshino | gender: Female
step count: 150 | name: yoshimaru | gender: Male


loss: 2.0465540885925293 | training loss: 2.0996047951576466:  30%|████████████████████████████████▌                                                                            | 224/751 [00:01<00:03, 134.90it/s]

step count: 200 | name: yuusuke | gender: Female
step count: 200 | name: yoshino | gender: Male


loss: 2.079559564590454 | training loss: 2.1006427211451615:  35%|██████████████████████████████████████▉                                                                       | 266/751 [00:02<00:03, 135.27it/s]

step count: 250 | name: yuuka | gender: Female
step count: 250 | name: yoshinari | gender: Male


loss: 2.1795318126678467 | training loss: 2.099911212192034:  43%|███████████████████████████████████████████████▍                                                              | 324/751 [00:02<00:03, 137.55it/s]

step count: 300 | name: yukita | gender: Female
step count: 300 | name: yukita | gender: Male


loss: 2.181553840637207 | training loss: 2.1012765071120123:  49%|█████████████████████████████████████████████████████▉                                                        | 368/751 [00:02<00:02, 138.74it/s]

step count: 350 | name: yuu | gender: Female
step count: 350 | name: yukihiko | gender: Male


loss: 2.1710729598999023 | training loss: 2.1016746245725932:  57%|█████████████████████████████████████████████████████████████▋                                               | 425/751 [00:03<00:02, 136.86it/s]

step count: 400 | name: yoshika | gender: Female
step count: 400 | name: yoshitoki | gender: Male


loss: 2.1333529949188232 | training loss: 2.1025897829287707:  62%|███████████████████████████████████████████████████████████████████▊                                         | 467/751 [00:03<00:02, 135.45it/s]

step count: 450 | name: yukihiko | gender: Female
step count: 450 | name: yamano | gender: Male


loss: 2.1493020057678223 | training loss: 2.1032113793678935:  70%|███████████████████████████████████████████████████████████████████████████▉                                 | 523/751 [00:03<00:01, 136.19it/s]

step count: 500 | name: yuuko | gender: Female
step count: 500 | name: yamakawa | gender: Male


loss: 2.1283252239227295 | training loss: 2.1026083533628:  75%|████████████████████████████████████████████████████████████████████████████████████▌                           | 567/751 [00:04<00:01, 139.82it/s]

step count: 550 | name: yuusuke | gender: Female
step count: 550 | name: yamakichi | gender: Male


loss: 1.9790138006210327 | training loss: 2.102943015516184:  83%|███████████████████████████████████████████████████████████████████████████████████████████▊                  | 627/751 [00:04<00:00, 142.56it/s]

step count: 600 | name: yung | gender: Female
step count: 600 | name: yuuki | gender: Male


loss: 2.083526372909546 | training loss: 2.1030930928676836:  89%|██████████████████████████████████████████████████████████████████████████████████████████████████▍           | 672/751 [00:04<00:00, 138.45it/s]

step count: 650 | name: yumiko | gender: Female
step count: 650 | name: yukihiko | gender: Male


loss: 2.086172103881836 | training loss: 2.103153886296562:  95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▋     | 715/751 [00:05<00:00, 138.14it/s]

step count: 700 | name: yuri | gender: Female
step count: 700 | name: yamagawa | gender: Male


loss: 2.008774995803833 | training loss: 2.1025925923917645: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 137.04it/s]


step count: 750 | name: yumina | gender: Female
step count: 750 | name: yoshiki | gender: Male
EPOCH: 6


loss: 2.1515326499938965 | training loss: 2.087489405235687:  10%|███████████                                                                                                    | 75/751 [00:00<00:04, 139.85it/s]

step count: 50 | name: yumiko | gender: Female
step count: 50 | name: yukihito | gender: Male


loss: 2.079709768295288 | training loss: 2.092847849440387:  16%|█████████████████▎                                                                                             | 117/751 [00:00<00:04, 137.36it/s]

step count: 100 | name: yuuko | gender: Female
step count: 100 | name: yashimaru | gender: Male


loss: 2.098708391189575 | training loss: 2.0895525567275657:  23%|█████████████████████████▋                                                                                    | 175/751 [00:01<00:04, 137.49it/s]

step count: 150 | name: yukimi | gender: Female
step count: 150 | name: yoshinosuke | gender: Male


loss: 2.029306411743164 | training loss: 2.092065235066519:  29%|████████████████████████████████                                                                               | 217/751 [00:01<00:03, 136.07it/s]

step count: 200 | name: yuusuke | gender: Female
step count: 200 | name: yoshin | gender: Male


loss: 2.0460855960845947 | training loss: 2.093063059696652:  36%|████████████████████████████████████████▏                                                                     | 274/751 [00:02<00:03, 137.54it/s]

step count: 250 | name: yuuka | gender: Female
step count: 250 | name: yashin | gender: Male


loss: 2.0470962524414062 | training loss: 2.0924135754658626:  42%|█████████████████████████████████████████████▊                                                               | 316/751 [00:02<00:03, 132.87it/s]

step count: 300 | name: yukiko | gender: Female
step count: 300 | name: yukitaka | gender: Male


loss: 2.165719747543335 | training loss: 2.094781127611796:  50%|██████████████████████████████████████████████████████▉                                                        | 372/751 [00:02<00:02, 129.17it/s]

step count: 350 | name: yuuka | gender: Female
step count: 350 | name: yukichi | gender: Male


loss: 2.124938488006592 | training loss: 2.0950769071017996:  56%|██████████████████████████████████████████████████████████████                                                | 424/751 [00:03<00:02, 127.75it/s]

step count: 400 | name: yoshika | gender: Female
step count: 400 | name: yoshitaka | gender: Male


loss: 2.0380091667175293 | training loss: 2.0965992940099616:  62%|███████████████████████████████████████████████████████████████████▍                                         | 465/751 [00:03<00:02, 127.49it/s]

step count: 450 | name: yamano | gender: Female
step count: 450 | name: yamanosuke | gender: Male


loss: 2.1095540523529053 | training loss: 2.0974583217075895:  69%|███████████████████████████████████████████████████████████████████████████▍                                 | 520/751 [00:03<00:01, 129.96it/s]

step count: 500 | name: yuu | gender: Female
step count: 500 | name: yuuki | gender: Male


loss: 2.1929190158843994 | training loss: 2.096997666151627:  75%|██████████████████████████████████████████████████████████████████████████████████▎                           | 562/751 [00:04<00:01, 129.40it/s]

step count: 550 | name: yuuko | gender: Female
step count: 550 | name: yasuhito | gender: Male


loss: 2.1742348670959473 | training loss: 2.097893444824219:  82%|██████████████████████████████████████████████████████████████████████████████████████████▌                   | 618/751 [00:04<00:01, 128.83it/s]

step count: 600 | name: yuuta | gender: Female
step count: 600 | name: yuukichi | gender: Male


loss: 2.1019372940063477 | training loss: 2.0982369218049226:  90%|█████████████████████████████████████████████████████████████████████████████████████████████████▋           | 673/751 [00:05<00:00, 126.93it/s]

step count: 650 | name: yukimura | gender: Female
step count: 650 | name: yukihiko | gender: Male


loss: 2.098848819732666 | training loss: 2.09825405877212:  97%|████████████████████████████████████████████████████████████████████████████████████████████████████████████    | 725/751 [00:05<00:00, 125.80it/s]

step count: 700 | name: yumi | gender: Female
step count: 700 | name: yamagawa | gender: Male


loss: 2.0036697387695312 | training loss: 2.097848358230489: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 131.50it/s]


step count: 750 | name: yoshiko | gender: Female
step count: 750 | name: yoshika | gender: Male
EPOCH: 7


loss: 2.080983877182007 | training loss: 2.0826602093378703:   9%|█████████▌                                                                                                     | 65/751 [00:00<00:05, 124.89it/s]

step count: 50 | name: yuumi | gender: Female
step count: 50 | name: yuusuke | gender: Male


loss: 2.1254055500030518 | training loss: 2.088418445279521:  16%|█████████████████▏                                                                                            | 117/751 [00:00<00:05, 124.64it/s]

step count: 100 | name: yashima | gender: Female
step count: 100 | name: yashimaru | gender: Male


loss: 2.128410816192627 | training loss: 2.086179007121495:  23%|████████████████████████▉                                                                                      | 169/751 [00:01<00:04, 124.71it/s]

step count: 150 | name: yukimi | gender: Female
step count: 150 | name: younghoon | gender: Male


loss: 2.153684377670288 | training loss: 2.0899422828640257:  29%|████████████████████████████████▎                                                                             | 221/751 [00:01<00:04, 123.13it/s]

step count: 200 | name: yuusuke | gender: Female
step count: 200 | name: yoshinosuke | gender: Male


loss: 2.109699249267578 | training loss: 2.091477231387674:  36%|████████████████████████████████████████▎                                                                      | 273/751 [00:02<00:03, 123.26it/s]

step count: 250 | name: yuubei | gender: Female
step count: 250 | name: younghyung | gender: Male


loss: 2.0632846355438232 | training loss: 2.090511634166424:  43%|███████████████████████████████████████████████▌                                                              | 325/751 [00:02<00:03, 125.01it/s]

step count: 300 | name: yukina | gender: Female
step count: 300 | name: yukitaka | gender: Male


loss: 2.1593730449676514 | training loss: 2.093007394110176:  49%|█████████████████████████████████████████████████████▉                                                        | 368/751 [00:02<00:02, 132.85it/s]

step count: 350 | name: yukino | gender: Female
step count: 350 | name: yukitaka | gender: Male


loss: 2.1600818634033203 | training loss: 2.093095566126446:  57%|██████████████████████████████████████████████████████████████▍                                               | 426/751 [00:03<00:02, 136.28it/s]

step count: 400 | name: yoshima | gender: Female
step count: 400 | name: yoshitoki | gender: Male


loss: 2.0894384384155273 | training loss: 2.0942095018282876:  62%|████████████████████████████████████████████████████████████████████                                         | 469/751 [00:03<00:02, 136.40it/s]

step count: 450 | name: yukihime | gender: Female
step count: 450 | name: yamanosuke | gender: Male


loss: 2.1297008991241455 | training loss: 2.0949975019161795:  70%|████████████████████████████████████████████████████████████████████████████▎                                | 526/751 [00:04<00:01, 137.82it/s]

step count: 500 | name: yuuka | gender: Female
step count: 500 | name: yamaka | gender: Male


loss: 2.0849497318267822 | training loss: 2.0943936632245626:  76%|██████████████████████████████████████████████████████████████████████████████████▍                          | 568/751 [00:04<00:01, 136.78it/s]

step count: 550 | name: yuuka | gender: Female
step count: 550 | name: yamanosuke | gender: Male


loss: 2.1639304161071777 | training loss: 2.09489219553733:  83%|████████████████████████████████████████████████████████████████████████████████████████████▏                  | 624/751 [00:04<00:00, 137.75it/s]

step count: 600 | name: yuuta | gender: Female
step count: 600 | name: yuuta | gender: Male


loss: 2.0836479663848877 | training loss: 2.0951461684580566:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▋            | 666/751 [00:05<00:00, 136.62it/s]

step count: 650 | name: yuuko | gender: Female
step count: 650 | name: yukimaru | gender: Male


loss: 2.0737173557281494 | training loss: 2.0953957645896377:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▉    | 723/751 [00:05<00:00, 137.32it/s]

step count: 700 | name: yuri | gender: Female
step count: 700 | name: yamagawa | gender: Male


loss: 1.987791895866394 | training loss: 2.094895131419724: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:05<00:00, 131.40it/s]

step count: 750 | name: yukitsu | gender: Female
step count: 750 | name: yoshino | gender: Male





In [24]:
model.eval()

RNN(
  (W_hh): Linear(in_features=1024, out_features=1024, bias=True)
  (W_xh): Linear(in_features=512, out_features=1024, bias=True)
  (W_hy): Linear(in_features=1024, out_features=30, bias=True)
  (embeddings): Embedding(30, 512)
)

In [32]:
inference(model, "h", "Male")

'hanan'

In [34]:
torch.topk(,2)

torch.return_types.topk(
values=tensor([9, 8]),
indices=tensor([9, 8]))

In [37]:
random.choice(torch.arange(10).topk(2).indices.tolist())

9

In [45]:
torch.arange(10).sort(descending=True)

torch.return_types.sort(
values=tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
indices=tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0]))

In [43]:
torch.sort?

[0;31mDocstring:[0m
sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor)

Sorts the elements of the :attr:`input` tensor along a given dimension
in ascending order by value.

If :attr:`dim` is not given, the last dimension of the `input` is chosen.

If :attr:`descending` is ``True`` then the elements are sorted in descending
order by value.

If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving
the order of equivalent elements.

A namedtuple of (values, indices) is returned, where the `values` are the
sorted values and `indices` are the indices of the elements in the original
`input` tensor.

Args:
    input (Tensor): the input tensor.
    dim (int, optional): the dimension to sort along
    descending (bool, optional): controls the sorting order (ascending or descending)
    stable (bool, optional): makes the sorting routine stable, which guarantees that the order
       of equivalent elements is preserved.

Keyword arg

In [46]:
1

1

In [47]:
torch.save(model.state_dict(), "./weights/base_male_female_RNN.pt")

In [48]:
max(range(10))

9

In [49]:
z = [*range(10)]

In [50]:
max(i for i in z)

9

In [52]:
model.eval()

RNN(
  (W_hh): Linear(in_features=1024, out_features=1024, bias=True)
  (W_xh): Linear(in_features=512, out_features=1024, bias=True)
  (W_hy): Linear(in_features=1024, out_features=30, bias=True)
  (embeddings): Embedding(30, 512)
)

In [55]:
inference(model, "y", "Male")

'yoshino'