In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from nltk.tokenize import WordPunctTokenizer
import gensim
from gensim.models import KeyedVectors
from IPython.display import clear_output
from collections import Counter
from tqdm import tqdm
import seaborn as sns
import nltk

%matplotlib inline

In [2]:
from pymystem3 import Mystem
from string import punctuation

In [3]:
def read_queries_with_lemmatization(path):
    f = open(path)
    queries = []
    tags = []
    mystem = Mystem()
    tokenizer = WordPunctTokenizer()
    
    for line in f:
        tmp = []
        tmp_ = []
        text = tokenizer.tokenize(line.lower())
        text_tagged = nltk.pos_tag(text, lang='rus')
        
        for i, q in enumerate(text):
            if not np.all(np.any(np.array(list(q)).reshape(-1, 1) == np.array(list(punctuation)).reshape(1, -1), axis=1)):
                q_ = mystem.lemmatize(q)
                tmp.append("".join(q_).split()[0])
                tmp_.append(text_tagged[i][1])
        queries.append(tmp)
        tags.append(tmp_)
    f.close()
    return (queries, tags)

In [4]:
path = 'data/requests.uniq.train'
train_lem = read_queries_with_lemmatization(path)
path = 'data/requests.uniq.test'
test_lem = read_queries_with_lemmatization(path)
train_lem[0][:5], train_lem[1][:5], test_lem[0][:5], test_lem[1][:5], len(train_lem[0]), len(test_lem[0])

([['сибирский', 'сеть', 'личный', 'кабинет', 'бердск'],
  ['1', 'сантим', 'алжир', '1964'],
  ['река', 'колыма', 'на', 'карта', 'россия'],
  ['ноофен', 'для', 'какой', 'болезнь'],
  ['маус', 'хаус', 'спб']],
 [['A=pl', 'S', 'A=m', 'S', 'S'],
  ['NUM=ciph', 'V', 'S', 'NUM=ciph'],
  ['S', 'S', 'PR', 'S', 'S'],
  ['V', 'PR', 'A-PRO=pl', 'A=f'],
  ['NONLEX', 'NONLEX', 'NONLEX']],
 [['сбербанк', 'в', 'кунцево', 'плаза'],
  ['торт', 'дикий', 'вишня'],
  ['тася', 'кривун', 'танец', 'на', 'тнт'],
  ['рбт', 'ру'],
  ['toplü', 'vay', 'sexx']],
 [['V', 'PR', 'S', 'S'],
  ['S', 'A=f', 'S'],
  ['S', 'S', 'S', 'PR', 'S'],
  ['V', 'S'],
  ['NONLEX', 'NONLEX', 'NONLEX']],
 51353,
 21174)

In [5]:
count_words = Counter()

for d in [train_lem[0]]:
    for q in d:
        for word in q:
            count_words[word] += 1
        
freq, counts = np.unique(np.array(list(count_words.values())), return_counts=True) 
p = counts * freq 
p = p / p.sum()
p = np.cumsum(p)
freq[:10], p[:10]

(array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]),
 array([0.12185658, 0.16966756, 0.20326518, 0.23018925, 0.2515358 ,
        0.27062838, 0.28792218, 0.30358516, 0.31885922, 0.32964299]))

In [6]:
emb_2 = KeyedVectors.load_word2vec_format("ft_native_300_ru_wiki_lenta_lower_case.vec")

In [7]:
tmp = np.unique(np.hstack(train_lem[1]))
tmp1 = np.unique(np.hstack(test_lem[1]))
tmp = np.hstack([tmp, tmp1])
tmp = np.unique(tmp)
tmp, len(tmp)

(array(['A', 'A-PRO', 'A-PRO=f', 'A-PRO=m', 'A-PRO=n', 'A-PRO=pl',
        'A-PRO=sg', 'A=brev', 'A=comp', 'A=comp2', 'A=f', 'A=m', 'A=n',
        'A=pl', 'A=sg', 'ADV', 'ADV-PRO', 'ADV-PRO=abbr', 'ADV-PRO=comp',
        'ADV-PRO=distort', 'ADV=abbr', 'ADV=comp', 'ADV=comp2',
        'ANUM=ciph', 'ANUM=f', 'ANUM=m', 'ANUM=n', 'ANUM=pl', 'ANUM=sg',
        'CONJ', 'INIT=abbr', 'INTJ', 'INTJ=distort', 'NONLEX',
        'NONLEX=abbr', 'NUM', 'NUM=acc', 'NUM=ciph', 'NUM=comp', 'NUM=dat',
        'NUM=f', 'NUM=gen', 'NUM=ins', 'NUM=loc', 'NUM=m', 'NUM=n',
        'NUM=nom', 'PARENTH', 'PART', 'PR', 'PRAEDIC', 'PRAEDIC-PRO',
        'PRAEDIC=comp', 'S', 'S-PRO', 'S-PRO=acc', 'S-PRO=dat',
        'S-PRO=gen', 'S-PRO=ins', 'S-PRO=loc', 'S-PRO=n=sg', 'S-PRO=pl',
        'S=m', 'V'], dtype='<U32'), 64)

In [8]:
tags = tmp
tags_to_ind = {}
ind = 0
for t in tags:
    tags_to_ind[t] = ind
    ind += 1

# Сетки

* аддитивный attention

In [9]:
import torch, torch.nn as nn
import torch.nn.functional as F

In [10]:
def calculate_n_tokens(emb):
    n_tokens = 0
    for word in count_words.keys():
        if word in emb.vocab and count_words[word] >= 3:
            n_tokens += 1
    return n_tokens + 1

In [11]:
def transform_to_features(emb, emb_size, ind_to_word, batch_x, batch_x_tags):
    to_emb = np.zeros((len(batch_x), len(batch_x[0])+1, emb_size + len(tags)))
    for i in range(len(batch_x)):
        to_emb[i][0] = np.ones(emb_size + len(tags))
        for j in range(len(batch_x[i])):
            if batch_x[i][j] != pad_id:
                to_emb[i][j+1][:emb_size] = emb[ind_to_word[batch_x[i][j]]]
                if batch_x_tags[i][j] >= 0:
                    to_emb[i][j+1][emb_size + batch_x_tags[i][j]] = 1
    return to_emb

## Архитектуры сетей

In [12]:
class Net(nn.Module):
    def __init__(self, emb, ind_to_word, emb_size=300, lstm_units=256):
        super(self.__class__, self).__init__()
        n_tokens = calculate_n_tokens(emb)
        self.lstm = nn.LSTM(emb_size + len(tags), lstm_units, batch_first=True)
        self.logits = nn.Linear(lstm_units, n_tokens) 
        self.emb = emb
        self.emb_size = emb_size
        self.ind_to_word = ind_to_word
        
    def forward(self, batch_x, batch_x_tags):
        input_emb = transform_to_features(self.emb, self.emb_size, self.ind_to_word, batch_x, batch_x_tags)
        input_emb = torch.tensor(input_emb, dtype=torch.float32)
        lstm_out = self.lstm(input_emb)
        logits = self.logits(lstm_out[0])
        
        return logits

In [13]:
class NetWithAttention(nn.Module):
    def __init__(self, emb, ind_to_word, emb_size=300, lstm_units=256, hid_size=256):
        super(self.__class__, self).__init__()
        n_tokens = calculate_n_tokens(emb)
        self.lstm = nn.LSTM(emb_size + len(tags), lstm_units, batch_first=True)
        self.linear_lstm = nn.Linear(lstm_units, hid_size)
        self.linear_out_lstm = nn.Linear(lstm_units, hid_size)
        self.final_linear = nn.Linear(hid_size, 1)
        self.logits = nn.Linear(hid_size, n_tokens)
        self.emb = emb
        self.emb_size = emb_size
        self.ind_to_word = ind_to_word
        
    def forward(self, batch_x, batch_x_tags):
        input_emb = transform_to_features(self.emb, self.emb_size, self.ind_to_word, batch_x, batch_x_tags)
        input_emb = torch.tensor(input_emb, dtype=torch.float32)
        lstm_out = self.lstm(input_emb)
        #lstm_out = lstm_out[0][1:]
        
        
        lstm_out_linear = self.linear_out_lstm(lstm_out[0])
        lstm_linear = self.linear_lstm(lstm_out[0])
        pre_logits = torch.zeros_like(lstm_out[0])
        pre_logits[:, 0, :] = lstm_out[0][:, 0, :]
        pre_logits[:, 1, :] = lstm_out[0][:, 1, :]
        
        mask = np.array([batch_x == pad_id], dtype=int)[0]
        mask = torch.tensor(mask, dtype=torch.float32)
        #print(mask.shape)
        
        for i in range(2, pre_logits.shape[1]):
            to_add = lstm_out_linear[:, i, :]
            basic = lstm_linear[:, 1:i+1, :]
            to_add = to_add.reshape(pre_logits.shape[0], 1, pre_logits.shape[2])
            tmp = basic + to_add
            tmp = torch.tanh(tmp)
            tmp = self.final_linear(tmp)
            tmp = tmp.reshape(tmp.shape[0], tmp.shape[1])
            tmp = F.softmax(tmp, dim=1)
            #print(tmp.shape)
            #print(mask.shape)
            mask = np.array([batch_x == pad_id], dtype=int)[0][:, :i]
            mask[:, -1] = np.ones(mask.shape[0])
            mask = torch.tensor(mask, dtype=torch.float32)
            
            tmp = tmp * mask
            tmp = tmp / torch.sum(tmp, dim=1).reshape(-1, 1)
            tmp = tmp.reshape(tmp.shape[0], tmp.shape[1], 1)
            pre_logits[:, i, :] = torch.sum(tmp * lstm_out[0][:, 1:i+1, :], dim=1)
        
        logits = self.logits(pre_logits)
        return logits

In [14]:
pad = '#PAD#'
pad_id = 0

def construct_vocab(emb, count_words):
    word_to_ind = dict()
    word_to_ind['#PAD#'] = 0
    ind_to_word = ['#PAD#', ]
    
    count = 1
    for word in count_words.keys():
        if count_words[word] >= 3 and word in emb.vocab:
            ind_to_word.append(word)
            word_to_ind[word] = count
            count += 1
    return ind_to_word, word_to_ind


def as_matrix(sequences, tags, word_to_ind, max_len=None):
    """ Convert a list of tokens into a matrix with padding """
    max_len = max_len or max(map(len,sequences))
    
    matrix = np.zeros((2, len(sequences), max_len), dtype=int)
    for i, seq in enumerate(sequences):
        for j, word in enumerate(seq[:max_len]):
            if word in word_to_ind.keys():
                matrix[0][i][j] = word_to_ind[word]
                matrix[1][i][j] = tags_to_ind[tags[i][j]]
            else:
                matrix[0][i][j] = pad_id
                matrix[1][i][j] = -1
        for j in range(max_len, len(seq)):
            matrix[0][i][j] = pad_id
            matrix[1][i][j] = -1
    
    return matrix

In [15]:
ind_to_word, word_to_ind = construct_vocab(emb_2, count_words)
network = NetWithAttention(emb_2, ind_to_word)

In [16]:
def compute_loss(network, batch):
    """
    use scalar crossentropy loss (neg llh) loss 
    """
    batch_x = batch[0]
    batch_tags = batch[1]
    batch_x = np.array(batch_x)
    batch_tags = np.array(batch_tags)

    batch_x_inp = batch_x[:, :-1]
    batch_x_next = batch_x[:, 1:]
    batch_tags_inp = batch_tags[:, :-1]
    batch_tags_next = batch_tags[:, 1:]
    
    logits_for_next = network.forward(batch_x_inp, batch_tags_inp)
    logits_for_next = logits_for_next[:, 1:]
    
    answers = torch.argmax(logits_for_next, dim=-1).numpy()
    logits_for_next = logits_for_next.contiguous()
    logits_for_next = logits_for_next.view(-1, logits_for_next.shape[-1])
    
    accr = np.array([answers == batch_x_next]) * np.array([answers != pad_id])
    accr = accr.sum()
    to_div = np.sum(np.array([batch_x_next != pad_id]))
    batch_x_next = torch.tensor(batch_x_next, dtype=torch.int64)
    batch_x_next = batch_x_next.view(-1)
    
    loss = F.cross_entropy(logits_for_next, batch_x_next, ignore_index=pad_id, reduction='mean')
    
    
    return loss, accr, to_div

In [17]:
from random import choice

def generate_batch(train, batch_size, word_to_ind, max_len=None):
    random_x = np.random.randint(0, len(train[0]), size=batch_size)
    batch_x = []
    batch_tags = []
    for x in random_x:
        batch_x.append(train[0][x])
        batch_tags.append(train[1][x])
    return as_matrix(batch_x, batch_tags, word_to_ind, max_len)

In [18]:
batch_size = 64 
n_epochs = 20 
n_batches_per_epoch = 400  
n_validation_batches = 160

In [19]:
from tqdm import tqdm
from torch.optim import Adam

ind_to_word, word_to_ind = construct_vocab(emb_2, count_words)
network = NetWithAttention(emb_2, ind_to_word)
opt = Adam(network.parameters())

train_loss, val_loss, train_accr, val_accr = [], [], [], []

for epoch in range(n_epochs):
    train_loss_=0
    train_accr_=0
    to_div = 0
    network.train(True)
    for _ in tqdm(range(n_batches_per_epoch)):
        
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(train_lem, batch_size, word_to_ind))
        
        loss_t.backward()
        opt.step()
        opt.zero_grad()
        
        train_loss_ += loss_t.item()
        train_accr_ += accr_t.item()
        to_div += to_div_t
        
    train_loss_ /= n_batches_per_epoch
    #train_accr_ /= n_batches_per_epoch
    train_accr_ /= to_div
    
    val_loss_=0
    val_accr_=0
    to_div = 0
    network.train(False)
    for _ in range(n_validation_batches):
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(test_lem, batch_size, word_to_ind))
        
        val_loss_ += loss_t.item()
        val_accr_ += accr_t.item()
        to_div += to_div_t
        
    val_loss_ /= n_validation_batches
    #val_accr_ /= n_validation_batches
    val_accr_ /= to_div
    
    train_loss.append(train_loss_)
    val_loss.append(val_loss_)
    train_accr.append(train_accr_)
    val_accr.append(val_accr_)
    
    print('\nEpoch: {}, train loss: {}, val loss: {}'.format(epoch, train_loss_, val_loss_))
    print('\nEpoch: {}, train accr: {}, val accr: {}'.format(epoch, train_accr_, val_accr_))

print("Finished!")

100%|██████████| 400/400 [03:55<00:00,  1.88it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 0, train loss: 7.423060523271561, val loss: 7.038127017021179

Epoch: 0, train accr: 0.05640062682373284, val accr: 0.07788437710199313


100%|██████████| 400/400 [04:19<00:00,  1.87it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 1, train loss: 6.594546213150024, val loss: 6.531922671198845

Epoch: 1, train accr: 0.11072340539312023, val accr: 0.10850409119518847


100%|██████████| 400/400 [04:33<00:00,  1.60it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 2, train loss: 6.104903918504715, val loss: 6.355370584130287

Epoch: 2, train accr: 0.1335185135244465, val accr: 0.12147528680175834


100%|██████████| 400/400 [04:27<00:00,  1.94it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 3, train loss: 5.768276606798172, val loss: 6.197636517882347

Epoch: 3, train accr: 0.14809663349222854, val accr: 0.1302777388132978


100%|██████████| 400/400 [03:49<00:00,  1.93it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 4, train loss: 5.478484307527542, val loss: 6.077467900514603

Epoch: 4, train accr: 0.16635284376961393, val accr: 0.13939157842547012


100%|██████████| 400/400 [03:37<00:00,  2.28it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 5, train loss: 5.223425855636597, val loss: 6.061522600054741

Epoch: 5, train accr: 0.18128843850512064, val accr: 0.14998230088495576


100%|██████████| 400/400 [03:37<00:00,  1.94it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 6, train loss: 5.033553264141083, val loss: 6.016300585865975

Epoch: 6, train accr: 0.19589634858224078, val accr: 0.1527566842086674


100%|██████████| 400/400 [03:40<00:00,  1.33it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 7, train loss: 4.848545799255371, val loss: 6.030133962631226

Epoch: 7, train accr: 0.21071012805587894, val accr: 0.15863623555931247


100%|██████████| 400/400 [03:47<00:00,  1.62it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 8, train loss: 4.637233272194862, val loss: 5.978240105509758

Epoch: 8, train accr: 0.22795453016748504, val accr: 0.15952150710934423


100%|██████████| 400/400 [03:44<00:00,  1.42it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 9, train loss: 4.4731349092721935, val loss: 6.013766032457352

Epoch: 9, train accr: 0.24753330543820104, val accr: 0.15749294105343883


100%|██████████| 400/400 [03:47<00:00,  1.90it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 10, train loss: 4.340551814436912, val loss: 6.031679448485375

Epoch: 10, train accr: 0.2594030412121863, val accr: 0.15970196464344708


100%|██████████| 400/400 [03:34<00:00,  1.64it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 11, train loss: 4.178167462348938, val loss: 6.010131880640984

Epoch: 11, train accr: 0.2762568442010951, val accr: 0.15855325412680055


100%|██████████| 400/400 [03:52<00:00,  2.24it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 12, train loss: 4.044670078754425, val loss: 6.038313618302345

Epoch: 12, train accr: 0.2901517589017589, val accr: 0.16670189530050025


100%|██████████| 400/400 [03:29<00:00,  2.10it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 13, train loss: 3.9076805233955385, val loss: 6.103361284732818

Epoch: 13, train accr: 0.30927472289515356, val accr: 0.15643564356435644


100%|██████████| 400/400 [03:38<00:00,  1.57it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 14, train loss: 3.7983185827732084, val loss: 6.087379065155983

Epoch: 14, train accr: 0.3216994621093484, val accr: 0.16348541829276947


100%|██████████| 400/400 [03:51<00:00,  2.05it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 15, train loss: 3.6559036004543306, val loss: 6.134150323271752

Epoch: 15, train accr: 0.3420181913169281, val accr: 0.16439474618388356


100%|██████████| 400/400 [03:37<00:00,  1.94it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 16, train loss: 3.544307135939598, val loss: 6.165590533614159

Epoch: 16, train accr: 0.35751057303841427, val accr: 0.16379249166134413


100%|██████████| 400/400 [03:50<00:00,  1.82it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 17, train loss: 3.4641764444112777, val loss: 6.1288181185722355

Epoch: 17, train accr: 0.36971964131374246, val accr: 0.16460630750860433


100%|██████████| 400/400 [03:35<00:00,  2.33it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 18, train loss: 3.365718601942062, val loss: 6.159314242005348

Epoch: 18, train accr: 0.38398104265402844, val accr: 0.16759638469142776


100%|██████████| 400/400 [03:39<00:00,  2.27it/s]



Epoch: 19, train loss: 3.2709632605314254, val loss: 6.236696863174439

Epoch: 19, train accr: 0.40104025205062743, val accr: 0.16279069767441862
Finished!


In [20]:
torch.save(network.state_dict(), 'additive_attention.pwf')

In [21]:
def compute_accr(network, batch):
    batch_x = batch[0]
    batch_tag = batch[1]
    
    batch_x = np.array(batch_x)
    batch_tag = np.array(batch_tag)
    batch_x_inp = batch_x[:, :-1]
    batch_x_next = batch_x[:, 1:]
    batch_tag_inp = batch_tag[:, :-1]
    batch_tag_next = batch_tag[:, 1:]
    
    logits_for_next = network.forward(batch_x_inp, batch_tag_inp)
    logits_for_next = logits_for_next[:, 1:]
    
    answers = torch.argmax(logits_for_next, dim=-1).numpy()
    
    accr = np.array([answers == batch_x_next]) * np.array([answers != pad_id])
    accr = accr[0]
    accr = accr.sum(axis=0)
    to_divide = np.array([batch_x_next != pad_id])[0].sum(axis=0)
    
    return accr, to_divide

def get_batch(data, left, right, batch_size, word_to_ind, max_len=None):
    slice_x = np.arange(left, right, 1)
    batch_x = []
    batch_tag = []
    for x in slice_x:
        batch_x.append(data[0][x])
        batch_tag.append(data[1][x])
    return as_matrix(batch_x, batch_tag, word_to_ind, max_len)

def try_lengthes(data):
    accr = np.zeros(np.max(list(map(len, data[0]))))
    to_div = np.zeros(np.max(list(map(len, data[0]))))
    
    for _ in tqdm_notebook(range(0, len(data[0])-batch_size, batch_size)):
        accr_t, div_t = compute_accr(network, get_batch(data, _, _+batch_size, batch_size, word_to_ind))
        accr[:len(accr_t)] += accr_t
        to_div[:len(div_t)] += div_t
    eps = 1
    return accr / (to_div + eps), accr.sum() / to_div.sum()

In [22]:
def approximate_pad(data):
    to_pad = 0
    all_ = 0
    for x in data:
        for word in x:
            all_ += 1
            if word not in word_to_ind.keys():
                to_pad += 1
    return to_pad / all_

In [23]:
from tqdm import tqdm_notebook

on_train, all_accr = try_lengthes(train_lem)
on_train[:10], all_accr

HBox(children=(IntProgress(value=0, max=802), HTML(value='')))




(array([0.22727977, 0.34278654, 0.47388723, 0.52354379, 0.56286934,
        0.56820909, 0.57379249, 0.58333333, 0.54569362, 0.55749129]),
 0.40412957534357097)

In [24]:
from tqdm import tqdm_notebook

on_test, all_accr = try_lengthes(test_lem)
on_test[:10], all_accr

HBox(children=(IntProgress(value=0, max=330), HTML(value='')))




(array([0.14205298, 0.14598705, 0.17501215, 0.18585638, 0.19293589,
        0.19933222, 0.19914347, 0.19434307, 0.14467409, 0.10951009]),
 0.16324087691097508)

* attention с cosine similarity

In [27]:
class NetWithAttentionCosine(nn.Module):
    def __init__(self, emb, ind_to_word, emb_size=300, lstm_units=256, hid_size=256):
        super(self.__class__, self).__init__()
        n_tokens = calculate_n_tokens(emb)
        self.lstm = nn.LSTM(emb_size + len(tags), lstm_units, batch_first=True)
        self.logits = nn.Linear(hid_size, n_tokens)
        self.emb = emb
        self.emb_size = emb_size
        self.ind_to_word = ind_to_word
        
    def forward(self, batch_x, batch_x_tags):
        input_emb = transform_to_features(self.emb, self.emb_size, self.ind_to_word, batch_x, batch_x_tags)
        input_emb = torch.tensor(input_emb, dtype=torch.float32)
        lstm_out = self.lstm(input_emb)        
        
        pre_logits = torch.zeros_like(lstm_out[0])
        pre_logits[:, 0, :] = lstm_out[0][:, 0, :]
        pre_logits[:, 1, :] = lstm_out[0][:, 1, :]
        
        for i in range(2, pre_logits.shape[1]):
            current = lstm_out[0][:, i, :]
            previous = lstm_out[0][:, 1:i+1, :]
            current = current.reshape(pre_logits.shape[0], 1, pre_logits.shape[2])
            cosine = torch.sum(previous * current, dim=-1) / torch.sum(previous * previous, dim=-1)
            cosine /= torch.sum(current * current)
            cosine = F.softmax(cosine, dim=1)
            
            mask = np.array([batch_x == pad_id], dtype=int)[0][:, :i]
            mask[:, -1] = np.ones(mask.shape[0])
            mask = torch.tensor(mask, dtype=torch.float32)
            
            tmp = cosine * mask
            tmp = tmp / torch.sum(tmp, dim=1).reshape(-1, 1)
            tmp = tmp.reshape(tmp.shape[0], tmp.shape[1], 1)
            pre_logits[:, i, :] = torch.sum(tmp * lstm_out[0][:, 1:i+1, :], dim=1)
        
        logits = self.logits(pre_logits)
        return logits

In [28]:
from tqdm import tqdm
from torch.optim import Adam

ind_to_word, word_to_ind = construct_vocab(emb_2, count_words)
network = NetWithAttentionCosine(emb_2, ind_to_word)
opt = Adam(network.parameters())

train_loss, val_loss, train_accr, val_accr = [], [], [], []

for epoch in range(n_epochs):
    train_loss_=0
    train_accr_=0
    to_div = 0
    network.train(True)
    for _ in tqdm(range(n_batches_per_epoch)):
        
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(train_lem, batch_size, word_to_ind))
        
        loss_t.backward()
        opt.step()
        opt.zero_grad()
        
        train_loss_ += loss_t.item()
        train_accr_ += accr_t.item()
        to_div += to_div_t
        
    train_loss_ /= n_batches_per_epoch
    #train_accr_ /= n_batches_per_epoch
    train_accr_ /= to_div
    
    val_loss_=0
    val_accr_=0
    to_div = 0
    network.train(False)
    for _ in range(n_validation_batches):
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(test_lem, batch_size, word_to_ind))
        
        val_loss_ += loss_t.item()
        val_accr_ += accr_t.item()
        to_div += to_div_t
        
    val_loss_ /= n_validation_batches
    #val_accr_ /= n_validation_batches
    val_accr_ /= to_div
    
    train_loss.append(train_loss_)
    val_loss.append(val_loss_)
    train_accr.append(train_accr_)
    val_accr.append(val_accr_)
    
    print('\nEpoch: {}, train loss: {}, val loss: {}'.format(epoch, train_loss_, val_loss_))
    print('\nEpoch: {}, train accr: {}, val accr: {}'.format(epoch, train_accr_, val_accr_))

print("Finished!")

100%|██████████| 400/400 [03:39<00:00,  2.06it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 0, train loss: 7.44505086183548, val loss: 7.1137505799531935

Epoch: 0, train accr: 0.05662855687651624, val accr: 0.07278391390540924


100%|██████████| 400/400 [03:11<00:00,  2.30it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 1, train loss: 6.6961684632301335, val loss: 6.67335011959076

Epoch: 1, train accr: 0.1010148849797023, val accr: 0.10251135073779796


100%|██████████| 400/400 [03:12<00:00,  2.25it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 2, train loss: 6.236177433729171, val loss: 6.43335530757904

Epoch: 2, train accr: 0.12295665489126785, val accr: 0.11792368805443727


100%|██████████| 400/400 [03:10<00:00,  2.23it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 3, train loss: 5.887610920667648, val loss: 6.2970844358205795

Epoch: 3, train accr: 0.1419347043977877, val accr: 0.12405333333333333


100%|██████████| 400/400 [03:13<00:00,  2.35it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 4, train loss: 5.630113972425461, val loss: 6.220975038409233

Epoch: 4, train accr: 0.15478319327731094, val accr: 0.1336074714245888


100%|██████████| 400/400 [03:13<00:00,  2.07it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 5, train loss: 5.417030644416809, val loss: 6.17641467154026

Epoch: 5, train accr: 0.1676844852169702, val accr: 0.13669139568744368


100%|██████████| 400/400 [03:18<00:00,  2.15it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 6, train loss: 5.174362083673477, val loss: 6.178724017739296

Epoch: 6, train accr: 0.18613872068402812, val accr: 0.14180086026092212


100%|██████████| 400/400 [03:24<00:00,  1.96it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 7, train loss: 5.0053264272212985, val loss: 6.099540641903877

Epoch: 7, train accr: 0.20013222516055912, val accr: 0.1474485480687905


100%|██████████| 400/400 [03:09<00:00,  1.98it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 8, train loss: 4.83504567027092, val loss: 6.105298948287964

Epoch: 8, train accr: 0.2127192387974735, val accr: 0.1495208568207441


100%|██████████| 400/400 [03:10<00:00,  2.06it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 9, train loss: 4.647806782126427, val loss: 6.073650851845741

Epoch: 9, train accr: 0.23155523591114488, val accr: 0.15241282141599155


100%|██████████| 400/400 [03:27<00:00,  2.35it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 10, train loss: 4.488833258748055, val loss: 6.0926534056663515

Epoch: 10, train accr: 0.2483433048050868, val accr: 0.15284524635669675


100%|██████████| 400/400 [03:09<00:00,  2.08it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 11, train loss: 4.367524445056915, val loss: 6.085264527797699

Epoch: 11, train accr: 0.25973675693170784, val accr: 0.15710855286390282


100%|██████████| 400/400 [03:23<00:00,  2.19it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 12, train loss: 4.252106469869614, val loss: 6.086110788583755

Epoch: 12, train accr: 0.2699989262321486, val accr: 0.15500281056773468


100%|██████████| 400/400 [03:14<00:00,  1.52it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 13, train loss: 4.136816199421883, val loss: 6.092726942896843

Epoch: 13, train accr: 0.2853014261019879, val accr: 0.15868284471875438


100%|██████████| 400/400 [03:20<00:00,  1.65it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 14, train loss: 3.992578672170639, val loss: 6.15544844865799

Epoch: 14, train accr: 0.3023696682464455, val accr: 0.15538369177525868


100%|██████████| 400/400 [03:24<00:00,  2.09it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 15, train loss: 3.867003793120384, val loss: 6.12329578101635

Epoch: 15, train accr: 0.31621807770740207, val accr: 0.15921052631578947


100%|██████████| 400/400 [03:22<00:00,  1.14it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 16, train loss: 3.7838563936948777, val loss: 6.206732812523842

Epoch: 16, train accr: 0.3294041398306801, val accr: 0.15472139515099959


100%|██████████| 400/400 [03:29<00:00,  1.92it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 17, train loss: 3.6661137396097185, val loss: 6.205356431007385

Epoch: 17, train accr: 0.34507799892415275, val accr: 0.16134695740721233


100%|██████████| 400/400 [03:08<00:00,  2.27it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 18, train loss: 3.571841896176338, val loss: 6.259359449148178

Epoch: 18, train accr: 0.35601102076473357, val accr: 0.1579226660062971


100%|██████████| 400/400 [03:24<00:00,  2.12it/s]



Epoch: 19, train loss: 3.4841410970687865, val loss: 6.270016765594482

Epoch: 19, train accr: 0.3691182446160097, val accr: 0.16115981119352663
Finished!


In [29]:
torch.save(network.state_dict(), 'cosine_attention.pwf')

In [30]:
on_train, all_accr = try_lengthes(train_lem)
on_train[:10], all_accr

HBox(children=(IntProgress(value=0, max=802), HTML(value='')))




(array([0.22640145, 0.33578489, 0.43739364, 0.47390519, 0.49326986,
        0.48579811, 0.48703041, 0.48782344, 0.45627876, 0.43902439]),
 0.37205174407659414)

In [31]:
on_test, all_accr = try_lengthes(test_lem)
on_test[:10], all_accr

HBox(children=(IntProgress(value=0, max=330), HTML(value='')))




(array([0.14172185, 0.14365306, 0.17306757, 0.18476632, 0.18231931,
        0.18764608, 0.17558887, 0.16970803, 0.14626391, 0.09510086]),
 0.15946167789596088)

* attention с dot product similarity

In [34]:
class NetWithAttentionDotProduct(nn.Module):
    def __init__(self, emb, ind_to_word, emb_size=300, lstm_units=256, hid_size=256):
        super(self.__class__, self).__init__()
        n_tokens = calculate_n_tokens(emb)
        self.lstm = nn.LSTM(emb_size + len(tags), lstm_units, batch_first=True)
        self.logits = nn.Linear(hid_size, n_tokens)
        self.emb = emb
        self.emb_size = emb_size
        self.ind_to_word = ind_to_word
        
    def forward(self, batch_x, batch_x_tags):
        input_emb = transform_to_features(self.emb, self.emb_size, self.ind_to_word, batch_x, batch_x_tags)
        input_emb = torch.tensor(input_emb, dtype=torch.float32)
        lstm_out = self.lstm(input_emb)        
        
        pre_logits = torch.zeros_like(lstm_out[0])
        pre_logits[:, 0, :] = lstm_out[0][:, 0, :]
        pre_logits[:, 1, :] = lstm_out[0][:, 1, :]
        
        for i in range(2, pre_logits.shape[1]):
            current = lstm_out[0][:, i, :]
            previous = lstm_out[0][:, 1:i+1, :]
            current = current.reshape(pre_logits.shape[0], 1, pre_logits.shape[2])
            dot_product = torch.sum(previous * current, dim=-1)
            dot_product = F.softmax(dot_product, dim=1)
            
            mask = np.array([batch_x == pad_id], dtype=int)[0][:, :i]
            mask[:, -1] = np.ones(mask.shape[0])
            mask = torch.tensor(mask, dtype=torch.float32)
            
            tmp = dot_product * mask
            tmp = tmp / torch.sum(tmp, dim=1).reshape(-1, 1)
            tmp = tmp.reshape(tmp.shape[0], tmp.shape[1], 1)
            pre_logits[:, i, :] = torch.sum(tmp * lstm_out[0][:, 1:i+1, :], dim=1)
        
        logits = self.logits(pre_logits)
        return logits

In [35]:
from tqdm import tqdm
from torch.optim import Adam

ind_to_word, word_to_ind = construct_vocab(emb_2, count_words)
network = NetWithAttentionDotProduct(emb_2, ind_to_word)
opt = Adam(network.parameters())

train_loss, val_loss, train_accr, val_accr = [], [], [], []

for epoch in range(n_epochs):
    train_loss_=0
    train_accr_=0
    to_div = 0
    network.train(True)
    for _ in tqdm(range(n_batches_per_epoch)):
        
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(train_lem, batch_size, word_to_ind))
        
        loss_t.backward()
        opt.step()
        opt.zero_grad()
        
        train_loss_ += loss_t.item()
        train_accr_ += accr_t.item()
        to_div += to_div_t
        
    train_loss_ /= n_batches_per_epoch
    #train_accr_ /= n_batches_per_epoch
    train_accr_ /= to_div
    
    val_loss_=0
    val_accr_=0
    to_div = 0
    network.train(False)
    for _ in range(n_validation_batches):
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(test_lem, batch_size, word_to_ind))
        
        val_loss_ += loss_t.item()
        val_accr_ += accr_t.item()
        to_div += to_div_t
        
    val_loss_ /= n_validation_batches
    #val_accr_ /= n_validation_batches
    val_accr_ /= to_div
    
    train_loss.append(train_loss_)
    val_loss.append(val_loss_)
    train_accr.append(train_accr_)
    val_accr.append(val_accr_)
    
    print('\nEpoch: {}, train loss: {}, val loss: {}'.format(epoch, train_loss_, val_loss_))
    print('\nEpoch: {}, train accr: {}, val accr: {}'.format(epoch, train_accr_, val_accr_))

print("Finished!")

100%|██████████| 400/400 [03:29<00:00,  1.59s/it]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 0, train loss: 7.416083530187607, val loss: 7.038138619065284

Epoch: 0, train accr: 0.05802608883139284, val accr: 0.08346372688477952


100%|██████████| 400/400 [03:20<00:00,  2.27it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 1, train loss: 6.5898543739318844, val loss: 6.519152516126633

Epoch: 1, train accr: 0.11141763695619836, val accr: 0.11138875202365031


100%|██████████| 400/400 [03:15<00:00,  2.15it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 2, train loss: 6.10610211968422, val loss: 6.32130691409111

Epoch: 2, train accr: 0.1343413112753407, val accr: 0.12266789393563918


100%|██████████| 400/400 [03:18<00:00,  1.39it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 3, train loss: 5.750876413583756, val loss: 6.216240027546883

Epoch: 3, train accr: 0.1520407891583908, val accr: 0.13323290358744394


100%|██████████| 400/400 [03:30<00:00,  2.22it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 4, train loss: 5.495996385812759, val loss: 6.108094716072083

Epoch: 4, train accr: 0.1652547492783782, val accr: 0.14129738957251584


100%|██████████| 400/400 [03:27<00:00,  2.25it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 5, train loss: 5.2199262070655825, val loss: 6.037539073824883

Epoch: 5, train accr: 0.18134526977982096, val accr: 0.15114578254509994


100%|██████████| 400/400 [03:11<00:00,  2.19it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 6, train loss: 4.988677082061767, val loss: 6.045672821998596

Epoch: 6, train accr: 0.19864235789699627, val accr: 0.15292818466799657


100%|██████████| 400/400 [03:12<00:00,  2.10it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 7, train loss: 4.826058906316757, val loss: 5.987785187363625

Epoch: 7, train accr: 0.21086780210867803, val accr: 0.15512416928996153


100%|██████████| 400/400 [03:10<00:00,  2.18it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 8, train loss: 4.625342229604721, val loss: 6.016707813739776

Epoch: 8, train accr: 0.23316446145139993, val accr: 0.1602275920202304


100%|██████████| 400/400 [03:15<00:00,  1.99it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 9, train loss: 4.445227854847908, val loss: 6.0072618335485455

Epoch: 9, train accr: 0.24838079849188716, val accr: 0.15843671170392895


100%|██████████| 400/400 [03:15<00:00,  2.45it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 10, train loss: 4.286374972462654, val loss: 5.997842311859131

Epoch: 10, train accr: 0.26692625877470727, val accr: 0.16398950131233594


100%|██████████| 400/400 [03:11<00:00,  1.88it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 11, train loss: 4.163279265165329, val loss: 6.0592894673347475

Epoch: 11, train accr: 0.28128586279989737, val accr: 0.1627637944331487


100%|██████████| 400/400 [03:16<00:00,  2.04it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 12, train loss: 4.02643232345581, val loss: 6.039686298370361

Epoch: 12, train accr: 0.29388512423362373, val accr: 0.1644576212068364


100%|██████████| 400/400 [03:23<00:00,  2.18it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 13, train loss: 3.896793942451477, val loss: 6.106889122724533

Epoch: 13, train accr: 0.31116321938789343, val accr: 0.16126767978099885


100%|██████████| 400/400 [03:17<00:00,  1.83it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 14, train loss: 3.7756284737586974, val loss: 6.086659649014473

Epoch: 14, train accr: 0.3285609795741453, val accr: 0.16556914393226718


100%|██████████| 400/400 [03:29<00:00,  1.82it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 15, train loss: 3.651914670467377, val loss: 6.104388958215713

Epoch: 15, train accr: 0.3412018906144497, val accr: 0.16692519213142495


100%|██████████| 400/400 [03:24<00:00,  1.93it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 16, train loss: 3.541927777528763, val loss: 6.125961038470268

Epoch: 16, train accr: 0.3553207729080464, val accr: 0.16136747326955542


100%|██████████| 400/400 [03:14<00:00,  2.16it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 17, train loss: 3.4509879976511, val loss: 6.157786852121353

Epoch: 17, train accr: 0.37248050511319175, val accr: 0.1617590688922187


100%|██████████| 400/400 [03:17<00:00,  2.34it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 18, train loss: 3.3672933101654055, val loss: 6.181353515386581

Epoch: 18, train accr: 0.37959298659833923, val accr: 0.16628101527179523


100%|██████████| 400/400 [03:14<00:00,  2.33it/s]



Epoch: 19, train loss: 3.2424461460113525, val loss: 6.271638941764832

Epoch: 19, train accr: 0.40121446093215907, val accr: 0.16192755590483676
Finished!


In [36]:
torch.save(network.state_dict(), 'dotproduct_attention.pwf')

In [37]:
on_train, all_accr = try_lengthes(train_lem)
on_train[:10], all_accr

HBox(children=(IntProgress(value=0, max=802), HTML(value='')))




(array([0.22640145, 0.33978583, 0.47082404, 0.52184311, 0.5541694 ,
        0.55914122, 0.57088551, 0.57420091, 0.5687048 , 0.51103368]),
 0.40098969959390046)

In [38]:
on_test, all_accr = try_lengthes(test_lem)
on_test[:10], all_accr

HBox(children=(IntProgress(value=0, max=330), HTML(value='')))




(array([0.14172185, 0.1508056 , 0.17491492, 0.18190489, 0.19089424,
        0.20133556, 0.19379015, 0.18613139, 0.12877583, 0.10086455]),
 0.16308697287868942)

* location based attention

In [39]:
class NetWithAttentionLocbased(nn.Module):
    def __init__(self, emb, ind_to_word, emb_size=300, lstm_units=256, hid_size=256):
        super(self.__class__, self).__init__()
        n_tokens = calculate_n_tokens(emb)
        self.lstm = nn.LSTM(emb_size + len(tags), lstm_units, batch_first=True)
        self.logits = nn.Linear(hid_size, n_tokens)
        self.inner = nn.Linear(lstm_units, 1)
        self.emb = emb
        self.emb_size = emb_size
        self.ind_to_word = ind_to_word
        
    def forward(self, batch_x, batch_x_tags):
        input_emb = transform_to_features(self.emb, self.emb_size, self.ind_to_word, batch_x, batch_x_tags)
        input_emb = torch.tensor(input_emb, dtype=torch.float32)
        lstm_out = self.lstm(input_emb)        
        
        pre_logits = torch.zeros_like(lstm_out[0])
        pre_logits[:, 0, :] = lstm_out[0][:, 0, :]
        pre_logits[:, 1, :] = lstm_out[0][:, 1, :]
        
        for i in range(2, pre_logits.shape[1]):
            current = lstm_out[0][:, i, :]
            current = self.inner(current)
            current = F.softmax(current, dim=1)
            
            mask = np.array([batch_x == pad_id], dtype=int)[0][:, :i]
            mask[:, -1] = np.ones(mask.shape[0])
            mask = torch.tensor(mask, dtype=torch.float32)
            
            tmp = current * mask
            tmp = tmp / torch.sum(tmp, dim=1).reshape(-1, 1)
            tmp = tmp.reshape(tmp.shape[0], tmp.shape[1], 1)
            pre_logits[:, i, :] = torch.sum(tmp * lstm_out[0][:, 1:i+1, :], dim=1)
        
        logits = self.logits(pre_logits)
        return logits

In [40]:
from tqdm import tqdm
from torch.optim import Adam

ind_to_word, word_to_ind = construct_vocab(emb_2, count_words)
network = NetWithAttentionLocbased(emb_2, ind_to_word)
opt = Adam(network.parameters())

train_loss, val_loss, train_accr, val_accr = [], [], [], []

for epoch in range(n_epochs):
    train_loss_=0
    train_accr_=0
    to_div = 0
    network.train(True)
    for _ in tqdm(range(n_batches_per_epoch)):
        
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(train_lem, batch_size, word_to_ind))
        
        loss_t.backward()
        opt.step()
        opt.zero_grad()
        
        train_loss_ += loss_t.item()
        train_accr_ += accr_t.item()
        to_div += to_div_t
        
    train_loss_ /= n_batches_per_epoch
    #train_accr_ /= n_batches_per_epoch
    train_accr_ /= to_div
    
    val_loss_=0
    val_accr_=0
    to_div = 0
    network.train(False)
    for _ in range(n_validation_batches):
        loss_t, accr_t, to_div_t = compute_loss(network, generate_batch(test_lem, batch_size, word_to_ind))
        
        val_loss_ += loss_t.item()
        val_accr_ += accr_t.item()
        to_div += to_div_t
        
    val_loss_ /= n_validation_batches
    #val_accr_ /= n_validation_batches
    val_accr_ /= to_div
    
    train_loss.append(train_loss_)
    val_loss.append(val_loss_)
    train_accr.append(train_accr_)
    val_accr.append(val_accr_)
    
    print('\nEpoch: {}, train loss: {}, val loss: {}'.format(epoch, train_loss_, val_loss_))
    print('\nEpoch: {}, train accr: {}, val accr: {}'.format(epoch, train_accr_, val_accr_))

print("Finished!")

100%|██████████| 400/400 [03:14<00:00,  2.16it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 0, train loss: 7.447025760412216, val loss: 7.115234318375587

Epoch: 0, train accr: 0.055655706828822045, val accr: 0.068666057239294


100%|██████████| 400/400 [03:11<00:00,  2.21it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 1, train loss: 6.730785636901856, val loss: 6.675900080800057

Epoch: 1, train accr: 0.09917220789795088, val accr: 0.09547720669990316


100%|██████████| 400/400 [03:23<00:00,  1.74it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 2, train loss: 6.216401575803757, val loss: 6.456563133001327

Epoch: 2, train accr: 0.12674486564364645, val accr: 0.1163982683982684


100%|██████████| 400/400 [03:16<00:00,  2.31it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 3, train loss: 5.895000921487808, val loss: 6.307088363170624

Epoch: 3, train accr: 0.14418386491557222, val accr: 0.12609337316871244


100%|██████████| 400/400 [03:17<00:00,  2.15it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 4, train loss: 5.621777420043945, val loss: 6.220817524194717

Epoch: 4, train accr: 0.15472871112771322, val accr: 0.13327038746677858


100%|██████████| 400/400 [03:20<00:00,  2.03it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 5, train loss: 5.404830946922302, val loss: 6.151641166210174

Epoch: 5, train accr: 0.16867956537544937, val accr: 0.14225120061695937


100%|██████████| 400/400 [03:21<00:00,  1.66it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 6, train loss: 5.185947972536087, val loss: 6.151007956266403

Epoch: 6, train accr: 0.18437419894496837, val accr: 0.1418501893673727


100%|██████████| 400/400 [03:07<00:00,  2.31it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 7, train loss: 4.970528242588043, val loss: 6.1207690834999084

Epoch: 7, train accr: 0.20352290293710845, val accr: 0.14715449998251076


100%|██████████| 400/400 [03:21<00:00,  2.18it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 8, train loss: 4.812678690552712, val loss: 6.095470517873764

Epoch: 8, train accr: 0.21427996025885448, val accr: 0.14671612022813954


100%|██████████| 400/400 [03:16<00:00,  1.92it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 9, train loss: 4.673326399922371, val loss: 6.081770867109299

Epoch: 9, train accr: 0.2272511043153245, val accr: 0.15194751093522182


100%|██████████| 400/400 [03:16<00:00,  2.46it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 10, train loss: 4.514206776618957, val loss: 6.050610202550888

Epoch: 10, train accr: 0.24193199632412563, val accr: 0.15652850473900126


100%|██████████| 400/400 [03:18<00:00,  1.91it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 11, train loss: 4.365518608093262, val loss: 6.067493364214897

Epoch: 11, train accr: 0.25680583203314145, val accr: 0.1573266769793278


100%|██████████| 400/400 [03:30<00:00,  2.27it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 12, train loss: 4.239488666057587, val loss: 6.0544316411018375

Epoch: 12, train accr: 0.2719080299225077, val accr: 0.15548090523338048


100%|██████████| 400/400 [03:20<00:00,  2.04it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 13, train loss: 4.110196410417557, val loss: 6.112796288728714

Epoch: 13, train accr: 0.28842671423551985, val accr: 0.15614840989399292


100%|██████████| 400/400 [03:16<00:00,  2.13it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 14, train loss: 3.9873586916923522, val loss: 6.166152790188789

Epoch: 14, train accr: 0.30071155193790094, val accr: 0.15404192145912127


100%|██████████| 400/400 [03:14<00:00,  2.35it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 15, train loss: 3.8687827450037005, val loss: 6.131808218359947

Epoch: 15, train accr: 0.31547010386954416, val accr: 0.15707902820441216


100%|██████████| 400/400 [03:25<00:00,  2.11it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 16, train loss: 3.792925373911858, val loss: 6.204553681612015

Epoch: 16, train accr: 0.32246884945250553, val accr: 0.15393298059964727


100%|██████████| 400/400 [03:20<00:00,  2.08it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 17, train loss: 3.657875906229019, val loss: 6.196912091970444

Epoch: 17, train accr: 0.34223889596034035, val accr: 0.1590597659955729


100%|██████████| 400/400 [03:21<00:00,  2.31it/s]
  0%|          | 0/400 [00:00<?, ?it/s]


Epoch: 18, train loss: 3.6070441538095475, val loss: 6.2312178134918215

Epoch: 18, train accr: 0.3490930716622064, val accr: 0.1564366998577525


100%|██████████| 400/400 [03:15<00:00,  1.66it/s]



Epoch: 19, train loss: 3.502924472093582, val loss: 6.266480302810669

Epoch: 19, train accr: 0.36231185387540843, val accr: 0.156059377313917
Finished!


In [41]:
torch.save(network.state_dict(), 'locationbased_attention.pwf')

In [42]:
on_train, all_accr = try_lengthes(train_lem)
on_train[:10], all_accr

HBox(children=(IntProgress(value=0, max=802), HTML(value='')))




(array([0.22913976, 0.34031537, 0.44196952, 0.46821854, 0.48711425,
        0.4844646 , 0.49686941, 0.48363775, 0.45693623, 0.42973287]),
 0.3735040206546002)

In [43]:
on_test, all_accr = try_lengthes(test_lem)
on_test[:10], all_accr

HBox(children=(IntProgress(value=0, max=330), HTML(value='')))




(array([0.14099338, 0.14523415, 0.16752552, 0.17917972, 0.17741935,
        0.18297162, 0.18254818, 0.17791971, 0.14626391, 0.08933718]),
 0.15739252368412052)