# Notebook służący do wyprodukowania własnego modelu

In [1]:
%matplotlib widget
import numpy as np
import math
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import random

import matplotlib.pyplot as plt

import re

from tqdm import tqdm
from tqdm import tnrange, tqdm_notebook

from joblib import Parallel, delayed

import multiprocessing
from datetime import datetime

from torch.utils.data import DataLoader, Dataset
from IPython.display import clear_output

Wybór urządzenia na którym ma przebiegać uczenie. Nie polecamy CPU, a nawet odradzamy!

In [2]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0")
device

device(type='cuda', index=0)

Cele reprodukcyjne!

In [3]:
torch.manual_seed(1010101011)
random.seed(1010101011)

## Załadowania datasetów
Tutaj właściwie testset(walidacyjny) nie jest wykorzystywany, a limit co do słów został ustalony na długość pojedyńczego hasła poniżej 22 znaków.

In [4]:
ftrainset = open("../dataset/train_all.txt", encoding='utf8')
trainset = [slowo.replace("\n", "") for slowo in ftrainset.readlines() if len(slowo) < 22]
ftrainset.close()

ftestset = open("../dataset/test_all.txt", encoding='utf8')
testset = [slowo.replace("\n", "") for slowo in ftestset.readlines()]
ftestset.close()

In [6]:
print("Train: ", len(trainset))
print("Test: ", len(testset))

Train:  169728
Test:  8934


## Preprocess danych
Należy zbudować słownik unikalnych znaków występujących w korpusie haseł. Dodajemy tutaj jeszcze znaki specjalne jak *< EMPTY >* oraz *< START >. Pełnią one bardzo ważną rolę przy uczeniu sekwencji modelu. Każda sekwencja wchodząca do modelu zaczyna się od *< START >*, a puste miejsca (czy tam końcówki) uzupełniane są *< EMPTY >*. Dodatkowo w celu ujednolicenia danych wejściowych wyznacza się tzw. wielkość okna, która odpowiada długości najdłuższego hasła + 1 (znak *< START >*).

In [7]:
chartoidx = {}

cnt = 0

longestword = 0

chartoidx["<EMPTY>"] = 0
chartoidx["<START>"] = 1

for slowo in tqdm_notebook(trainset+testset):  
    for litera in slowo:
        if litera not in list(chartoidx.keys()):
            chartoidx[litera] = cnt+2
            cnt = cnt + 1
            
    if len(slowo) > longestword:
        longestword = len(slowo)
            
vocabsize = len(list(chartoidx.keys()))

HBox(children=(IntProgress(value=0, max=178662), HTML(value='')))




Dodajemy 1 dla startu 😎

In [8]:
longestword += 1; longestword

20

## Sprawdzenie danych wejściowych

In [9]:
chartoidx

{'<EMPTY>': 0,
 '<START>': 1,
 '1': 2,
 '9': 3,
 '5': 4,
 '3': 5,
 'l': 6,
 'e': 7,
 'm': 8,
 'd': 9,
 'y': 10,
 'r': 11,
 'h': 12,
 'a': 13,
 'n': 14,
 'R': 15,
 'u': 16,
 's': 17,
 'i': 18,
 '0': 19,
 'p': 20,
 'z': 21,
 'o': 22,
 '7': 23,
 '8': 24,
 '2': 25,
 'f': 26,
 'c': 27,
 'k': 28,
 'b': 29,
 't': 30,
 '6': 31,
 'g': 32,
 '4': 33,
 'x': 34,
 'w': 35,
 'j': 36,
 'B': 37,
 'v': 38,
 'W': 39,
 'M': 40,
 'S': 41,
 'O': 42,
 'A': 43,
 'K': 44,
 'L': 45,
 'P': 46,
 'I': 47,
 'E': 48,
 'H': 49,
 'q': 50,
 'F': 51,
 'Z': 52,
 'V': 53,
 'D': 54,
 'C': 55,
 'J': 56,
 'G': 57,
 'Q': 58,
 'Y': 59,
 'U': 60,
 'N': 61,
 '.': 62,
 'T': 63,
 ' ': 64,
 'X': 65,
 '!': 66,
 '-': 67,
 '#': 68,
 '@': 69,
 '%': 70,
 '&': 71,
 '?': 72,
 '$': 73,
 '_': 74,
 '^': 75,
 '*': 76,
 '/': 77,
 '~': 78,
 '`': 79,
 ';': 80,
 '=': 81,
 '+': 82,
 ',': 83,
 '(': 84,
 ')': 85,
 'Ĺ': 86,
 '‚': 87,
 '[': 88,
 ']': 89}

In [11]:
slowo = trainset[0]; slowo

'1953lem'

## Zbudowanie klasy Datasetu
Całym *clue* modelu jest odpowiednie podawanie danych. Weźmy za przykład hasło *kicia08*. Dataset dla wielkości okna = 10 wygeneruje następujące próbki (każda pozycja musi zostać zastąpiona ID z *chartoidx*:

[< START >, k, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >]

[< START >, k, i, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >]

[< START >, k, i, c, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >]

[< START >, k, i, c, i, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >]

[< START >, k, i, c, i, a, < EMPTY >, < EMPTY >, < EMPTY >, < EMPTY >]

[< START >, k, i, c, i, a, 0, < EMPTY >, < EMPTY >, < EMPTY >]

[< START >, k, i, c, i, a, 0, 8, < EMPTY >, < EMPTY >]


In [12]:
class MyDataset(Dataset):
    def __init__(self, slowa, chartoidx, longestword, padding):
        self.slowa = slowa
        self.chartoidx = chartoidx
        self.longestword = longestword
        self.padding = padding
        self.indexpos = []
        
        """ tutaj fajny triczek wyszedł """
        """ napotykałem sytuacje, w których """
        """ klasa datasetu spowalniała uczenie """
        """ wynikało to z konieczności iterowania przez cały DS """
        """ by wygenerować sekwencję wejściową z odpowiednią długością """
        """ niemaskowanych liter """
        """ *** """
        """ w tym celu najlepiej wyznaczyć macierz, która poprzez operator """
        """ pobierania indexu (macierz[index]) będzie w stanie od razu wskazać """
        """ pojedyńczą i konkretną próbkę, bez większych iteracji """
        """ *** """
        """ indexpos zawiera obiekty tuple opisujące ID slowa, którego dotyczy """
        """ index oraz numer litery, do której należy uzupełnić okno literami """      
        
        for idd, slowo in enumerate(slowa):
            for idd2, letter in enumerate(slowo):
                self.indexpos.append((idd, idd2))
            self.indexpos.append((idd, idd2+1))
        
    def __len__(self):
        return len(self.indexpos)
        
    def __getitem__(self, index):
        literyx = []
        literyy = []
        
        """ wybierz słowo """
        slowo = self.slowa[self.indexpos[index][0]]
        
        """ przygotuj wektor próbki - zaczynający od <START> i wypełniony "<END>" """
        literyx = [self.chartoidx["<START>"]] + [self.chartoidx["<EMPTY>"] for _ in range(self.longestword - 1)]
        
        """ uzupełnianie aż do ID z tuple """
        for i in range(self.indexpos[index][1]):
            literyx[i+1] = self.chartoidx[slowo[i]]
        
        """ wstawienie odpowiedniego znaku jako wartości wyjściowej """
        if self.indexpos[index][1] < len(slowo):
            literyy = [self.chartoidx[slowo[self.indexpos[index][1]]]]
        else:
            literyy = [self.chartoidx["<EMPTY>"]]
                
        return np.array(literyx, dtype="float32"), np.array(literyy, dtype="long")

In [13]:
DS_train = MyDataset(trainset, chartoidx, longestword, 0)
DS_test = MyDataset(testset, chartoidx, longestword, 0)

## Hiperparametry
Warto przestawić w zależności od sprzętu 😎

In [14]:
BS = 1500

lstms = 15
hiddensize = 40

epochs = 50

lrmin = 1e-8
lrmax = 1e-3

In [16]:
DL_train = DataLoader(dataset=DS_train, batch_size=BS, num_workers=0)
DL_test = DataLoader(dataset=DS_test, batch_size=BS, num_workers=0)

## Klasa modelu
![Przygotowany przez nas model](https://i.imgur.com/ytNEXZc.png)

In [17]:
class CharacterLSTM(nn.Module):
    def __init__(self, vocabsize, lstmlayers, hiddensize):
        super(CharacterLSTM, self).__init__()
        
        ## WARSTWY
        self.embd = nn.Embedding(vocabsize, vocabsize)
        self.LSTM1 = nn.GRU(vocabsize, hiddensize, lstmlayers, batch_first=True, bidirectional=True)
        self.linear_ins = nn.Linear(2*hiddensize, vocabsize)

        self.drop = nn.Dropout(p=0.1)
        
        ## OUTS
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x, hidden, NLL=True): 
        # WEJSCIE
        y0 = self.embd(x)
        
        # LSTM
        y, h1_ = self.LSTM1(y0, hidden)
        
        y = self.drop(y)
        
        # LINEAR OUT 1
        y = self.linear_ins(y)
        
        """ Ewentualnie, gdyby się kto chciał bawić w NLLLoss """
        if NLL:
            y = self.softmax(y[:,-1])
    
        return y, h1_

In [19]:
chlstm = CharacterLSTM(vocabsize, lstms, hiddensize).to(device)

In [20]:
criterionPretraining = nn.CrossEntropyLoss()

In [23]:
optimizerLSTM = optim.RMSprop(chlstm.parameters(), lr=lrmax)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizerLSTM, epochs, eta_min = lrmin)

## Pętla ucząca

In [24]:
t_epochs = tnrange(epochs)

losssess = []

for epoch in t_epochs:
    t_batch = tqdm_notebook(DL_train, leave=False)
    
    for batch in t_batch:
        d1, d2 = batch[0].shape
        xreals = batch[0].long().to(device)
        y = batch[1].long().to(device)
        
        hiddens1 = torch.zeros(2*lstms, d1, hiddensize).to(device)
        
        chlstm.train()
        optimizerLSTM.zero_grad()
        
        y_, _ = chlstm(xreals, hiddens1, NLL=False)
        loss = criterionPretraining(y_[:,-1], y.view(-1))
        
        loss.backward()
        optimizerLSTM.step()
        
        losss = loss.item()
        losssess.append(losss)
        
        t_batch.set_description("Loss: {:.8f}".format(losss))
        

    t_batch.close()
    t_epochs.set_description("Epoch {}/{}".format(epoch+1, epochs))
    
    print("Epoch {}/{}, Loss {:.8f}, LR {:.8f}".format(epoch+1, epochs, losss, scheduler.get_lr()[0]))
    scheduler.step()

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 1/50, Loss 2.86446047, LR 0.00100000


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 2/50, Loss 2.77597690, LR 0.00099803


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 3/50, Loss 2.72162485, LR 0.00099311


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 4/50, Loss 2.68221998, LR 0.00098625


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 5/50, Loss 2.64546180, LR 0.00097749


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 6/50, Loss 2.62575030, LR 0.00096684


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 7/50, Loss 2.59372830, LR 0.00095436


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 8/50, Loss 2.59944081, LR 0.00094010


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 9/50, Loss 2.57402825, LR 0.00092411


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 10/50, Loss 2.56117773, LR 0.00090645


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 11/50, Loss 2.55739880, LR 0.00088719


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 12/50, Loss 2.54003835, LR 0.00086642


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 13/50, Loss 2.53934717, LR 0.00084420


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 14/50, Loss 2.54102850, LR 0.00082064


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 15/50, Loss 2.53015208, LR 0.00079581


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 16/50, Loss 2.51514101, LR 0.00076983


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 17/50, Loss 2.52420425, LR 0.00074279


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 18/50, Loss 2.50713086, LR 0.00071480


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 19/50, Loss 2.50442982, LR 0.00068596


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 20/50, Loss 2.50744081, LR 0.00065640


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 21/50, Loss 2.49528503, LR 0.00062624


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 22/50, Loss 2.48372793, LR 0.00059558


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 23/50, Loss 2.47770572, LR 0.00056455


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 24/50, Loss 2.47945356, LR 0.00053327


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 25/50, Loss 2.46733999, LR 0.00050187


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 26/50, Loss 2.46504736, LR 0.00047046


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 27/50, Loss 2.46500683, LR 0.00043919


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 28/50, Loss 2.46680188, LR 0.00040815


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 29/50, Loss 2.46166515, LR 0.00037749


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 30/50, Loss 2.45627785, LR 0.00034732


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 31/50, Loss 2.46001816, LR 0.00031776


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 32/50, Loss 2.45061159, LR 0.00028892


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 33/50, Loss 2.45044708, LR 0.00026092


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 34/50, Loss 2.44100094, LR 0.00023387


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 35/50, Loss 2.45267797, LR 0.00020788


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 36/50, Loss 2.43763709, LR 0.00018304


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 37/50, Loss 2.43053865, LR 0.00015947


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 38/50, Loss 2.43240786, LR 0.00013724


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 39/50, Loss 2.42440939, LR 0.00011644


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 40/50, Loss 2.42890120, LR 0.00009716


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 41/50, Loss 2.42494488, LR 0.00007948


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 42/50, Loss 2.41430640, LR 0.00006345


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 43/50, Loss 2.42449689, LR 0.00004915


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 44/50, Loss 2.41393828, LR 0.00003662


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 45/50, Loss 2.42209053, LR 0.00002592


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 46/50, Loss 2.41481733, LR 0.00001707


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 47/50, Loss 2.41870141, LR 0.00001009


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 48/50, Loss 2.42077041, LR 0.00000500


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 49/50, Loss 2.41326308, LR 0.00000177


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))

Epoch 50/50, Loss 2.42311025, LR 0.00000026



## Zapisanie modelu i innych parametrów do ewaluacji

In [27]:
torch.save(chlstm, "../models/NEWDS_START_bezrelu_lstm_15_hidden_40_cosine1e-8_rmsprop1e-7_50epoch_loss_"+str(losss)+".pt")

  "type " + obj.__name__ + ". It won't be checked "


In [39]:
torch.save([chartoidx, longestword, lstms, hiddensize], "../models/zmienne_modelu.pth")