In [None]:
import torch
import torch.nn as nn

from torchtext import data, datasets

import numpy as np
import random

import spacy
from spacy.tokenizer import Tokenizer

from string import punctuation


In [None]:
# Check whether PyTorch can see a CUDA device; the result is shown as the cell output.
torch.cuda.is_available()

True

In [None]:
# Load spaCy's small English pipeline; tokenize() below only uses its tokenizer.
nlp = spacy.load("en_core_web_sm")


In [None]:
def tokenize(sentence):
    """Lower-case ``sentence`` and split it into tokens with spaCy's tokenizer.

    Tokens made up solely of ASCII punctuation characters (for example
    ``","``, ``"..."`` or ``"--"``) are dropped.  The check is done per
    character: a plain ``tok.text not in punctuation`` is a *substring* test
    against ``string.punctuation``, which lets multi-character punctuation
    such as ``"..."`` through.

    Args:
        sentence: Raw text of one example.

    Returns:
        List of token strings in their original order.
    """
    punct_chars = set(punctuation)
    tokens = []
    for tok in nlp.tokenizer(sentence.lower()):
        text = tok.text
        # all() is True for an empty string, so empty tokens are dropped as
        # well (the substring test treated "" the same way).
        if all(ch in punct_chars for ch in text):
            continue
        tokens.append(text)
    return tokens

In [None]:
# Text field: examples are tokenized with tokenize().  With include_lengths=True
# batch.text is consumed later as a pair, batch.text[0] (tokens) and
# batch.text[1] (lengths).
TEXT = data.Field(tokenize=tokenize, include_lengths = True)
# Label field for the target classes.
LABEL = data.LabelField()

# IMDB train/test splits, processed with the two fields above.
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

In [None]:
# Seed Python's RNG and pass its state to split() explicitly.  random.seed()
# returns None, so `random_state=random.seed(1024)` handed None to split() and
# was reproducible only because evaluating the argument seeded the global RNG.
random.seed(1024)
# NOTE: this rebinds train_data, so re-running the cell splits the already
# reduced training set again.
train_data, valid_data = train_data.split(split_ratio=0.8, random_state=random.getstate())

In [None]:
# Report the number of examples in the training split (after the validation
# split was carved out) and in the test set.
print('len_train_data: ', len(train_data))
print('len_test_data: ',  len(test_data))

len_train_data:  20000
len_test_data:  25000


In [None]:
# Build the vocabulary from the training split only, so nothing from the
# validation set (which is used for model selection) leaks into the model's
# inputs; validation-only words map to <unk>, exactly like unseen test words.
# max_size caps the vocabulary at the 30,000 most frequent tokens, and tokens
# without a pretrained GloVe vector are initialised with torch.Tensor.normal_.
TEXT.build_vocab(train_data,
                 max_size = 30000,
                 vectors = "glove.6B.300d",
                 unk_init = torch.Tensor.normal_)
LABEL.build_vocab(train_data)


In [None]:
# Inspect the embedding matrix attached to the vocabulary by build_vocab.
TEXT.vocab.vectors

tensor([[-0.5858,  0.5646, -0.5422,  ...,  0.1476, -0.0430, -0.7319],
        [-2.0477, -1.3294, -0.3867,  ...,  1.1911, -0.0073,  0.3330],
        [ 0.0466,  0.2132, -0.0074,  ...,  0.0091, -0.2099,  0.0539],
        ...,
        [ 0.6308, -0.8578,  1.0551,  ...,  0.0186,  0.8295, -1.6352],
        [-0.1772,  0.4024, -0.3649,  ...,  0.1901,  0.6188,  0.0453],
        [-0.0640, -0.1976, -0.6130,  ...,  0.2417, -0.2630, -0.2261]])

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bucketed iterators over the three splits, with batches of 50 examples placed
# on `device`.  Batches are sorted by text length (sort_key + sort_within_batch),
# which Bi_LSTM.forward relies on when it calls pack_padded_sequence.
# NOTE(review): pack_padded_sequence expects descending lengths by default — confirm the sort order.
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits((train_data, valid_data, test_data),
        batch_size = 50,
        sort_key = lambda x: len(x.text), sort_within_batch = True,
        device = device)

In [None]:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class Bi_LSTM(nn.Module):
  """LSTM text classifier: embedding -> LSTM -> ReLU -> linear -> dropout -> linear.

  The final hidden state of the last LSTM layer (both directions concatenated
  when the LSTM is bidirectional) is used as the representation of the text.

  Args:
    vocab_size: Number of rows in the embedding table.
    embedding_size: Dimensionality of each token embedding.
    hidden_size: Hidden state size of the LSTM, per direction.
    size2: Width of the intermediate fully connected layer.
    num_classes: Number of output logits per example.
    count_layers: Number of stacked LSTM layers.
    bidirec_: Whether the LSTM runs in both directions.
    dropout: Dropout probability, applied between stacked LSTM layers and
      before the output layer.
    pad_index: Vocabulary index of the padding token.
  """

  def __init__(self, vocab_size, embedding_size, hidden_size, size2, num_classes,
               count_layers, bidirec_, dropout, pad_index):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, embedding_size,
                                  padding_idx=pad_index)

    # nn.LSTM only applies dropout between stacked layers and warns when a
    # non-zero value is given for a single layer, so pass 0 in that case.
    lstm_dropout = dropout if count_layers > 1 else 0.0
    self.lstm = nn.LSTM(embedding_size, hidden_size,
                        num_layers=count_layers,
                        bidirectional=bidirec_, dropout=lstm_dropout)

    self.dropout = nn.Dropout(dropout)
    # The classifier input is one final hidden state per direction; sizing it
    # from bidirec_ (instead of a fixed "* 2") keeps bidirec_=False usable.
    num_directions = 2 if bidirec_ else 1
    self.fc = nn.Linear(hidden_size * num_directions, size2)
    self.fc2 = nn.Linear(size2, num_classes)

    self.relu = nn.ReLU()

  def forward(self, text, text_lengths):
    """Return one row of ``num_classes`` logits per example in the batch.

    Args:
      text: Token ids laid out as (sequence length, batch), which is what
        the LSTM's default ``batch_first=False`` expects.
      text_lengths: Unpadded length of every sequence in the batch, in
        descending order (pack_padded_sequence's default requirement).

    Returns:
      Tensor of shape (batch, num_classes) holding raw logits.
    """
    embedded = self.embedding(text)
    # pack_padded_sequence needs the lengths on the CPU.
    packed = pack_padded_sequence(embedded, text_lengths.to('cpu'))
    _, (hidden, _) = self.lstm(packed)

    if self.lstm.bidirectional:
      # The last two entries of `hidden` are the forward and backward final
      # states of the top layer.
      final_hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
    else:
      # Unidirectional: only the top layer's final state is used.
      final_hidden = hidden[-1, :, :]

    activated = self.relu(final_hidden)
    projected = self.fc(activated)
    regularized = self.dropout(projected)
    return self.fc2(regularized)

In [None]:
# Hyperparameters used to build the Bi_LSTM model.
# embedding_size matches the 300-d "glove.6B.300d" vectors requested in build_vocab.
embedding_size = 300
vocab_size = len(TEXT.vocab)
hidden_size = 512
size2 = 1024
count_layers = 2
# A single output logit; the loss used later is BCEWithLogitsLoss.
output_size = 1
# NOTE: despite its name, this value is passed to Bi_LSTM as `dropout`.
dropout_keep_prob = 0.5
# Index of the padding token, passed to the embedding layer as padding_idx.
pad_index = TEXT.vocab.stoi[TEXT.pad_token]
bidirectional = True

In [None]:
 # Build the classifier from the hyperparameters defined above.
 model_lstm = Bi_LSTM(vocab_size=vocab_size, embedding_size=embedding_size,
                      hidden_size=hidden_size, size2=size2, num_classes=output_size,
                      count_layers=count_layers,
                      bidirec_=bidirectional, dropout=dropout_keep_prob,
                      pad_index=pad_index)

 # Initialise the embedding layer with the pretrained vectors that build_vocab
 # attached to the vocabulary; without this copy the GloVe vectors are loaded
 # but never used and the model trains from random embeddings.
 with torch.no_grad():
     model_lstm.embedding.weight.copy_(TEXT.vocab.vectors)
     # The copy overwrites the padding row that nn.Embedding zeroes for
     # padding_idx, so reset it to zeros.
     model_lstm.embedding.weight[pad_index].zero_()

In [None]:
import torch.optim as optim


In [None]:
# Move the model to the target device before constructing the optimizer, as
# the torch.optim documentation recommends, so the optimizer is created over
# the parameters as they live on that device.
model_lstm = model_lstm.to(device)

# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(model_lstm.parameters())

# Binary cross-entropy computed on the raw logits returned by the model.
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(device)

In [None]:
import time

In [None]:
def train(model, train_iterator, optimazer, criterion):
  """Run one training epoch and return the mean batch loss and accuracy.

  Args:
    model: Module called as ``model(tokens, lengths)``; its output is
      squeezed on dimension 1 to get one logit per example.
    train_iterator: Sized iterable of batches exposing ``batch.text`` (indexed
      as ``[0]`` tokens and ``[1]`` lengths) and ``batch.label``.
    optimazer: Optimizer that updates ``model``'s parameters.  The misspelled
      name is kept so existing keyword callers keep working.
    criterion: Loss called as ``criterion(logits, labels)`` with the labels
      cast to the logits' dtype.

  Returns:
    Tuple ``(mean_loss, mean_accuracy)``, each averaged over the batches.
  """
  # Use the optimizer that was passed in.  The body used to reach for a global
  # named `optimizer`, silently ignoring this argument.
  optimizer = optimazer

  epoch_loss = 0
  epoch_acc = 0
  model.train()
  start_time = time.time()
  for batch in train_iterator:
    optimizer.zero_grad()
    predictions = model(batch.text[0], batch.text[1]).squeeze(1)

    batch_label = batch.label.type_as(predictions)
    loss = criterion(predictions, batch_label)

    # Threshold the sigmoid probabilities at 0.5 to get hard 0/1 predictions.
    rounded_predictions = torch.round(torch.sigmoid(predictions))
    correct = (rounded_predictions == batch.label).float()
    acc = correct.sum()/len(correct)

    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()
    epoch_acc += acc.item()

  end_time = time.time()
  # Elapsed training time for the epoch, in minutes.
  print((end_time - start_time)/60)

  print('loss: ',epoch_loss / len(train_iterator),' accuracy: ' ,epoch_acc / len(train_iterator))
  return epoch_loss / len(train_iterator), epoch_acc / len(train_iterator)


In [None]:
def validation(model, valid_iterator, criterion):
  """Evaluate ``model`` on every batch of ``valid_iterator`` without gradients.

  Args:
    model: Module called as ``model(tokens, lengths)``; its output is
      squeezed on dimension 1 to get one logit per example.
    valid_iterator: Sized iterable of batches exposing ``batch.text`` (indexed
      as ``[0]`` tokens and ``[1]`` lengths) and ``batch.label``.
    criterion: Loss called as ``criterion(logits, labels)`` with the labels
      cast to the logits' dtype.

  Returns:
    Tuple ``(mean_loss, mean_accuracy)``, each averaged over the batches.
  """
  loss_total = 0
  acc_total = 0
  model.eval()
  started = time.time()
  with torch.no_grad():
    for batch in valid_iterator:
      logits = model(batch.text[0], batch.text[1]).squeeze(1)
      targets = batch.label.type_as(logits)
      batch_loss = criterion(logits, targets)

      # Hard 0/1 predictions: sigmoid probabilities rounded at 0.5.
      hard_predictions = torch.round(torch.sigmoid(logits))
      hits = (hard_predictions == batch.label).float()
      batch_acc = hits.sum()/len(hits)

      loss_total += batch_loss.item()
      acc_total += batch_acc.item()
  finished = time.time()

  # Elapsed evaluation time, in minutes.
  print('time val: ', (finished - started)/60)
  mean_loss = loss_total / len(valid_iterator)
  mean_acc = acc_total / len(valid_iterator)
  print('loss: ', mean_loss, ' accuracy: ', mean_acc)
  return mean_loss, mean_acc


In [None]:
# Lowest validation loss seen so far; the weights that achieve it are checkpointed.
best_valid_loss = float('inf')

# Train for 7 epochs, evaluating on the validation split after each one.
for i in range(7):
    print('iteration: ', i+1)
  
    train_loss, train_acc = train(model_lstm, train_iterator, optimizer,criterion)
    valid_loss, valid_acc = validation(model_lstm, valid_iterator, criterion)
    # Save the weights whenever the validation loss improves.
    if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model_lstm.state_dict(), 'best_model.pt')

# Restore the best checkpoint and evaluate it on the test set.
# NOTE(review): best_param receives the return value of load_state_dict, not
# the parameters themselves — confirm this name is not used as weights elsewhere.
best_param = model_lstm.load_state_dict(torch.load('best_model.pt'))
print('------------------------------')
best_result = validation(model_lstm, test_iterator, criterion)


iteration:  1
1.4593804995218913
loss:  0.5989452665299178  accuracy:  0.658299982920289
time val:  0.10149643421173096
loss:  0.40505763605237005  accuracy:  0.8153999769687652
iteration:  2
1.4596008698145548
loss:  0.30653180098161104  accuracy:  0.8706499746441841
time val:  0.10154728492101034
loss:  0.3358109851181507  accuracy:  0.8541999745368958
iteration:  3
1.455478568871816
loss:  0.17482712886296212  accuracy:  0.9321999773383141
time val:  0.10103125174840291
loss:  0.28466768652200697  accuracy:  0.8977999758720397
iteration:  4
1.4581735452016196
loss:  0.0863478634157218  accuracy:  0.9698999781906604
time val:  0.10168864727020263
loss:  0.3515908346325159  accuracy:  0.8965999794006347
iteration:  5
1.4620246569315591
loss:  0.041697045281471216  accuracy:  0.986299983561039
time val:  0.10099232196807861
loss:  0.43550104297697545  accuracy:  0.8939999747276306
iteration:  6
1.4592486262321471
loss:  0.029849791722081135  accuracy:  0.9894499859213829
time val:  0.1

In [None]:
# Display the model's parameters (the state dict that was saved to 'best_model.pt').
model_lstm.state_dict()

OrderedDict([('embedding.weight',
              tensor([[ 2.5999,  0.3954, -1.6498,  ..., -1.2378,  0.2920,  0.3005],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 0.4090,  0.5958,  0.6855,  ..., -0.5204,  0.3533,  0.2377],
                      ...,
                      [-0.0890, -0.4436, -2.3261,  ..., -0.1895, -0.0114,  0.6668],
                      [ 1.0576,  1.1353,  0.3472,  ...,  1.8292,  1.7474,  0.7690],
                      [ 1.6179,  0.0929, -0.3723,  ..., -1.6484, -0.2116, -0.3492]],
                     device='cuda:0')),
             ('lstm.weight_ih_l0',
              tensor([[-0.0918,  0.0313,  0.0592,  ...,  0.0040, -0.0879, -0.0878],
                      [ 0.0695, -0.0316,  0.0077,  ..., -0.0801, -0.0030, -0.0942],
                      [-0.0128, -0.0061,  0.0222,  ...,  0.0451,  0.0513,  0.0483],
                      ...,
                      [ 0.0286, -0.0383, -0.0487,  ...,  0.0141, -0.1381, -0.028

Файл зі state_dict можна завантажити з Google Диска за таким посиланням:
https://drive.google.com/file/d/1odJqGOytlxDNlA360Sud2TZX4vQ_NW4Y/view?usp=sharing