In [1]:
import os

# Set before transformers is imported; the DataLoaders below fork worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import train_model

  from .autonotebook import tqdm as notebook_tqdm


Сбор и подготовка данных

In [2]:

# Tokenizer plus the corpus as returned by get_clean_text (presumably a list of
# cleaned strings -- TODO confirm against src.data_utils).
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
texts = get_clean_text("./data/raw_dataset.csv")

# 80 / 10 / 10 split: hold out 20% of the texts, then halve that hold-out pool
# into validation and test. Both splits use a fixed random_state.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

# Wrap each split with the same tokenizer (constructed in train, val, test order).
train_dataset, val_dataset, test_dataset = (
    TextDataset(split, tokenizer) for split in (train_data, val_data, test_data)
)

In [3]:
# Shared loader settings for every split -- tune them here, not in three places.
BATCH_SIZE = 256
NUM_WORKERS = 4  # must stay > 0: DataLoader rejects persistent_workers=True otherwise
# Pinned host memory (and pin_memory_device="cuda") is only meaningful when a
# CUDA device exists; on CPU-only machines skip pinning instead of warning.
PIN_MEMORY = torch.cuda.is_available()


def make_loader(dataset, shuffle=False):
    """Build a DataLoader for one data split using the notebook's shared settings.

    Args:
        dataset: map-style dataset for the split (here a TextDataset); batches are
            assembled with ``collate_fn`` from ``src.data_utils``.
        shuffle: reshuffle every epoch; only the training split should enable it.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``NUM_WORKERS`` persistent workers.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        pin_memory_device="cuda" if PIN_MEMORY else "",
        persistent_workers=True,
    )


train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset)
test_loader = make_loader(test_dataset)

Обучение модели

In [4]:
# Seed torch's global RNG (all devices) so weight initialisation and the training
# loader's shuffle order are repeatable on Restart & Run All; 42 matches the
# random_state used for the train/val/test split. cuDNN LSTM kernels may still
# introduce small run-to-run differences on GPU.
torch.manual_seed(42)

model = LSTMAutoCopleteText(vocab_size=tokenizer.vocab_size)
train_model(model, train_loader, val_loader)

cuda


100%|██████████| 323/323 [04:09<00:00,  1.29it/s]


Epoch 1 | Train Loss: 6.996 | Val Loss: 5.606 | Val Accuracy: 19.50%


100%|██████████| 323/323 [03:40<00:00,  1.46it/s]


Epoch 2 | Train Loss: 5.711 | Val Loss: 5.342 | Val Accuracy: 21.09%


100%|██████████| 323/323 [03:24<00:00,  1.58it/s]


Epoch 3 | Train Loss: 5.473 | Val Loss: 5.194 | Val Accuracy: 22.01%


100%|██████████| 323/323 [03:31<00:00,  1.53it/s]


Epoch 4 | Train Loss: 5.320 | Val Loss: 5.088 | Val Accuracy: 22.50%


100%|██████████| 323/323 [03:50<00:00,  1.40it/s]


Epoch 5 | Train Loss: 5.208 | Val Loss: 5.009 | Val Accuracy: 23.01%


100%|██████████| 323/323 [04:01<00:00,  1.34it/s]


Epoch 6 | Train Loss: 5.121 | Val Loss: 4.948 | Val Accuracy: 23.28%


In [8]:
# Qualitative check: let the trained model continue a short prompt.
prompt = "i want eat"
completion = model.generate(tokenizer, prompt)
print(completion)

i want eat.. but i have to wake up so now i hate being a bad time and no twitter for the house for this hours
