In [1]:
import pandas as pd
from pathlib import Path
import regex as re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.lstm.functions import encode, f_score

# The end of preprocessing

In [2]:
data_root = Path('../../data/NER/processed/')

In [3]:
data = pd.read_csv(data_root / 'tokens_labels_lstm.csv', sep = ';')

In [4]:
data.head()

Unnamed: 0,token,label
0,заметной,0
1,добивался,0
2,регистраторши,0
3,подгоревшего,0
4,ленбах,1


# Training

In [5]:
X = list(data['token'])
y = list(data['label'])
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [6]:
from torch.utils.data import DataLoader
from models.lstm.token_dataset import TokensDataset

In [7]:
train_ds = TokensDataset(X_train, y_train)
test_ds = TokensDataset(X_test, y_test)

In [8]:
batch_size = 1000
vocab_size = 33
max_len = len(max(data['token'], key=lambda i: len(i)))
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn = TokensDataset.collate_fn)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, collate_fn = TokensDataset.collate_fn)
emb_dim = 16

In [9]:
from tensorboardX import SummaryWriter
writer = SummaryWriter()

In [10]:
import wandb
wandb.login()

wandb: Currently logged in as: metpinc (use `wandb login --relogin` to force relogin)


True

In [11]:
wandb.init(project="ner_network", config={
    "learning_rate": 0.01,
})
config = wandb.config

wandb: wandb version 0.12.0 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [12]:
from tqdm import tqdm 
import sklearn
def train_model(model, epochs=10, lr=0.001):
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=lr)
    for i in range(epochs):
        
        print('\n\n')
        print('------- EPOCH', i, '--------' )
        
        train_prebs_history = []
        train_labels_history = []
        train_accuracy_history = []
        
        test_preds_history = []
        test_labels_history = []

        for batch_idx, (data, label) in tqdm(enumerate(train_dl), total = len(train_dl)):

            optimizer.zero_grad()
            
            preds = model.forward(data)

            loss = nn.CrossEntropyLoss()

            output = loss(preds, label)
            
            output.backward()
            optimizer.step()

            train_prebs_history.append(preds)
            train_labels_history.append(label)
                
                
        print('--------TRAIN----------')
        wandb.log({"roc_auc/train": sklearn.metrics.roc_auc_score(label, preds.argmax(dim=1), average= 'weighted')})
        print(f_score(train_prebs_history, train_labels_history))
        train_accuracy = (preds.argmax(dim=1) == label).float().mean()
        writer.add_scalar('Accuracy/train', train_accuracy, i)
        writer.add_scalar('Loss/train', output, i)
        writer.add_scalar('F1_score/train', f_score(train_prebs_history, train_labels_history), i) 
        
        
        for batch_idx, (data, label) in tqdm(enumerate(test_dl), total = len(test_dl)):
            
            test_preds = model.forward(data.long())
            test_preds_history.append(test_preds)
            test_labels_history.append(label)
            
            test_loss = loss(test_preds, label)
        
        print('-----------TEST----------')
        wandb.log({"roc_auc/test": sklearn.metrics.roc_auc_score(label, test_preds.argmax(dim=1), average= 'weighted')})
        print(f_score(test_preds_history, test_labels_history))
        test_accuracy = (test_preds.argmax(dim=1) == label).float().mean()
        writer.add_scalar('Accuracy/test', test_accuracy, i)
        writer.add_scalar('Loss/test', test_loss, i)
        writer.add_scalar('F1_score/test', f_score(test_preds_history, test_labels_history), i)
            

            

In [13]:
from models.lstm.model import LSTMFixedLen

In [14]:
model_fixed = LSTMFixedLen(vocab_size, emb_dim, 128, max_len)

In [15]:
train_model(model_fixed, epochs=10, lr=0.01)

  0%|                                                                                           | 0/53 [00:00<?, ?it/s]




------- EPOCH 0 --------


  allow_unreachable=True)  # allow_unreachable flag
100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:15<00:00,  3.40it/s]


--------TRAIN----------
0.7736961100697597


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.27it/s]


-----------TEST----------
0.8645970937912815



------- EPOCH 1 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.55it/s]
  0%|                                                                                           | 0/14 [00:00<?, ?it/s]

--------TRAIN----------
0.8731958762886598


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.27it/s]
  0%|                                                                                           | 0/53 [00:00<?, ?it/s]

-----------TEST----------
0.8864462809917356



------- EPOCH 2 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.60it/s]


--------TRAIN----------
0.8912600160669043


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.49it/s]
  0%|                                                                                           | 0/53 [00:00<?, ?it/s]

-----------TEST----------
0.9003424943613734



------- EPOCH 3 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.60it/s]
  0%|                                                                                           | 0/14 [00:00<?, ?it/s]

--------TRAIN----------
0.9059405940594059


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.10it/s]


-----------TEST----------
0.9017083436494182



------- EPOCH 4 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.58it/s]


--------TRAIN----------
0.9148672639445109


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.45it/s]


-----------TEST----------
0.9085954917017587



------- EPOCH 5 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.60it/s]


--------TRAIN----------
0.9233848502981656


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.46it/s]
  0%|                                                                                           | 0/53 [00:00<?, ?it/s]

-----------TEST----------
0.9063801432617026



------- EPOCH 6 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.63it/s]
  0%|                                                                                           | 0/14 [00:00<?, ?it/s]

--------TRAIN----------
0.9298888431453273


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.42it/s]
  0%|                                                                                           | 0/53 [00:00<?, ?it/s]

-----------TEST----------
0.9084935107606374



------- EPOCH 7 --------


100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [00:14<00:00,  3.57it/s]


--------TRAIN----------
0.9400612633496152


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00,  8.23it/s]


-----------TEST----------
0.9077368638808441



------- EPOCH 8 --------


 25%|████████████████████                                                              | 13/53 [00:03<00:12,  3.31it/s]


KeyboardInterrupt: 

In [16]:
model_fixed.save('./weights/model_lstm_fixed.pt')

In [17]:
loaded_model = LSTMFixedLen().load('./weights/model_lstm_fixed.pt')

In [18]:
loaded_model.prediction('')

1

In [19]:
loaded_model.prediction('стол')

1

In [21]:
loaded_model.extract_names('привет меня я и ручка стул комод шкаф тумбочка рамка')

['привет',
 'меня',
 'я',
 'и',
 'ручка',
 'стул',
 'комод',
 'шкаф',
 'тумбочка',
 'рамка']

In [22]:
a = 'привет меня зовут майя'
type(a.split(' '))

list

In [23]:
import numpy
a = torch.tensor([1, 2]).detach().numpy()
a.argmax()

1

In [24]:
data.iloc[2].label

0

In [25]:
for i in range(10000):
    if loaded_model.prediction(data.iloc[i].token)!=data.iloc[i].label:
        print(data.iloc[i].token)
        print('prediction: ', loaded_model.prediction(data.iloc[i].token))
        print('label:', data.iloc[i].label)
        print('\n\n')

заметной
prediction:  1
label: 0



добивался
prediction:  1
label: 0



регистраторши
prediction:  1
label: 0



просившей
prediction:  1
label: 0



создавая
prediction:  1
label: 0



обеим
prediction:  1
label: 0



шенграбеном
prediction:  1
label: 0



посторонись
prediction:  1
label: 0



влиятельным
prediction:  1
label: 0



падали
prediction:  1
label: 0



отражающие
prediction:  1
label: 0



проспали
prediction:  1
label: 0



записка
prediction:  1
label: 0



оборотом
prediction:  1
label: 0



подойдт
prediction:  1
label: 0



входившему
prediction:  1
label: 0



крючковатый
prediction:  1
label: 0



листках
prediction:  1
label: 0



отрешнное
prediction:  1
label: 0



пуки
prediction:  1
label: 0



смены
prediction:  1
label: 0



скривя
prediction:  1
label: 0



копии
prediction:  1
label: 0



спутал
prediction:  1
label: 0



квадратик
prediction:  1
label: 0



бульварном
prediction:  1
label: 0



точка
prediction:  1
label: 0



плену
prediction:  1
label

KeyboardInterrupt: 

In [26]:
labels = data.label[:100]

In [29]:
string = ''
for i in a:
    string +=i
    string+=' '
print(string)

TypeError: must be str, not numpy.int64

In [30]:
loaded_model.extract_names(string)

['']