In [1]:
from src.data_utils import load_and_preprocess, split_data
from src.next_token_dataset import NextTokenDataset
from src.lstm_model import LSTMAutocomplete
from src.lstm_train import LSTMTrainer

from transformers import BertTokenizerFast

import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# Global configuration shared by the whole notebook.
RANDOM_STATE = 42  # seed for split_data and for torch's RNG (weight init, batch shuffling)

SEQ_LEN = 20  # passed as seq_len= to every NextTokenDataset below

# Seed torch so that model weight initialisation and DataLoader shuffling are
# reproducible under Restart Kernel -> Run All, not only the train/val/test split.
torch.manual_seed(RANDOM_STATE)

Выполним загрузку и обработку текстовых данных, просмотрим первые строки:

In [3]:
# Read the raw corpus and run the project's cleaning step on it.
# NOTE(review): the rendered output is a pandas Series named "cleaned_text",
# so despite its name `df` holds a Series, not a DataFrame.
RAW_DATA_PATH = "data/raw_dataset.txt"
df = load_and_preprocess(RAW_DATA_PATH)

df.head()

0    awww that s a bummer you shoulda got david car...
1    is upset that he can t update his facebook by ...
2    i dived many times for the ball managed to sav...
3       my whole body feels itchy and like its on fire
4    no it s not behaving at all i m mad why am i h...
Name: cleaned_text, dtype: object

Выполним разбиение на обучающую, валидационную и тестовую выборки и выведем их размеры:

In [4]:
# Split the cleaned texts; the result is unpacked as train / validation / test.
train_texts, val_texts, test_texts = split_data(df, RANDOM_STATE)

# Sanity check: the three parts together must account for every row of the corpus.
assert len(train_texts) + len(val_texts) + len(test_texts) == len(df), "split sizes do not add up to the corpus size"

# One labeled display instead of four unlabeled numbers.
display({
    "total": len(df),
    "train": len(train_texts),
    "val": len(val_texts),
    "test": len(test_texts),
})

1601127

1280901

160113

160113

In [5]:
# Load the pretrained uncased BERT tokenizer.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Next-token datasets for the train, validation and test splits, all built with the same SEQ_LEN.
train_dataset = NextTokenDataset(train_texts, tokenizer, seq_len=SEQ_LEN)
val_dataset = NextTokenDataset(val_texts, tokenizer, seq_len=SEQ_LEN)
test_dataset = NextTokenDataset(test_texts, tokenizer, seq_len=SEQ_LEN)

# A single batch size for all loaders instead of repeating the literal 64.
BATCH_SIZE = 64

# A dedicated generator seeded with RANDOM_STATE makes the training shuffle order
# reproducible, independently of other consumers of torch's global RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(RANDOM_STATE),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# len(tokenizer) also counts tokens added on top of the base vocabulary, so every id
# the tokenizer can emit is a valid index. Without added tokens it equals tokenizer.vocab_size.
vocab_size = len(tokenizer)

HIDDEN_DIM = 128  # LSTM hidden size
N_EPOCHS = 3

# Build and train the LSTM next-token model.
model = LSTMAutocomplete(vocab_size, hidden_dim=HIDDEN_DIM)
trainer = LSTMTrainer(model, vocab_size)

# NOTE(review): the saved output below logs 5 epochs although the code requests 3,
# so that output is stale and should be regenerated.
# In that log the validation loss is lowest after epoch 1 (5.220) and then rises
# while the train loss keeps falling, which is a sign of overfitting.
# TODO: consider fewer epochs, or keeping the best-validation checkpoint if
# LSTMTrainer supports it.
trainer.fit(train_loader, val_loader, n_epochs=N_EPOCHS)

Device: cuda


100%|██████████| 29576/29576 [04:06<00:00, 120.13it/s]


Epoch 1 | Train Loss: 5.266 | Val Loss: 5.220
------------------------------


100%|██████████| 29576/29576 [04:02<00:00, 121.89it/s]


Epoch 2 | Train Loss: 4.886 | Val Loss: 5.223
------------------------------


100%|██████████| 29576/29576 [04:03<00:00, 121.33it/s]


Epoch 3 | Train Loss: 4.776 | Val Loss: 5.259
------------------------------


100%|██████████| 29576/29576 [04:03<00:00, 121.52it/s]


Epoch 4 | Train Loss: 4.714 | Val Loss: 5.299
------------------------------


100%|██████████| 29576/29576 [04:00<00:00, 122.85it/s]


Epoch 5 | Train Loss: 4.672 | Val Loss: 5.336
------------------------------
