In [7]:
import os

# Set before `transformers` is imported so the tokenizers backend picks it up
# (avoids its fork/parallelism warning when DataLoader workers are spawned).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import train_model

Сбор и подготовка данных

In [8]:

# Tokenizer shared by every dataset split (and later by the model / training loop).
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Corpus as returned by the project helper (presumably cleaned text samples —
# see src.data_utils.get_clean_text).
texts = get_clean_text("./data/raw_dataset.csv")

# Roughly 80 / 10 / 10: hold out 20 % of the samples, then cut the hold-out in
# half to obtain the validation and test splits. Fixed random_state keeps the
# partition stable between runs.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

# Wrap each split in the project dataset class, in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    TextDataset(split, tokenizer) for split in (train_data, val_data, test_data)
)

In [9]:
def make_loader(dataset, shuffle, batch_size=256, num_workers=4):
    """Build a DataLoader with the settings shared by all three splits.

    Args:
        dataset: Dataset to iterate over (one of the TextDataset splits).
        shuffle: Whether to reshuffle the samples every epoch.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used for loading.

    Returns:
        A DataLoader that batches `dataset` through the project `collate_fn`.
    """
    # Pinning to the CUDA device only makes sense (and only works) when a GPU
    # is present; on a CPU-only machine fall back to the DataLoader defaults.
    use_cuda = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=use_cuda,
        pin_memory_device="cuda" if use_cuda else "",
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )


# Only the training split is shuffled; validation and test keep a fixed order.
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

Обучение модели

In [10]:
# Seed torch's global RNG before the model is built so that weight
# initialisation and the shuffled batch order are repeatable across
# Restart & Run All. NOTE(review): train_model may apply its own seeding —
# not visible from this notebook.
torch.manual_seed(42)

# The vocabulary size passed to the model is taken from the tokenizer.
model = LSTMAutoCopleteText(vocab_size=tokenizer.vocab_size)

# Trains `model` in place on train_loader and reports per-epoch validation
# metrics (see the logged output of this cell).
train_model(model, train_loader, val_loader, tokenizer)

cuda


100%|██████████| 5001/5001 [37:00<00:00,  2.25it/s]  


Epoch 1 | Train Loss: 5.304 | Val Loss: 4.753 | Val Accuracy: 24.60% | ROUGE1: 0.0486 | ROUGE2: 0.0047


100%|██████████| 5001/5001 [12:28<00:00,  6.68it/s]


Epoch 2 | Train Loss: 4.810 | Val Loss: 4.599 | Val Accuracy: 25.59% | ROUGE1: 0.0491 | ROUGE2: 0.0046


100%|██████████| 5001/5001 [18:32<00:00,  4.50it/s] 


Epoch 3 | Train Loss: 4.692 | Val Loss: 4.522 | Val Accuracy: 26.13% | ROUGE1: 0.0495 | ROUGE2: 0.0047


100%|██████████| 5001/5001 [56:04<00:00,  1.49it/s]  


Epoch 4 | Train Loss: 4.625 | Val Loss: 4.474 | Val Accuracy: 26.44% | ROUGE1: 0.0503 | ROUGE2: 0.0048


100%|██████████| 5001/5001 [55:24<00:00,  1.50it/s]  


Epoch 5 | Train Loss: 4.580 | Val Loss: 4.437 | Val Accuracy: 26.71% | ROUGE1: 0.0499 | ROUGE2: 0.0048


100%|██████████| 5001/5001 [56:20<00:00,  1.48it/s]  


Epoch 6 | Train Loss: 4.547 | Val Loss: 4.411 | Val Accuracy: 26.86% | ROUGE1: 0.0507 | ROUGE2: 0.0050


Проход по тестовому набору (TODO: оценка модели на test_loader пока не выполнена)

Примеры генерации:

In [None]:
# Short conversational openers used to eyeball the quality of the trained
# model's continuations.
prompts = [
    "i know it sounds silly but",
    "this weekend we will",
    "i was about to text you but",
    "the best thing about today is",
    "if anyone needs me",
    "remember when we said we would",
    "today is going to be",
    "can someone explain why",
]

# Visual divider printed after every sample.
separator = "-" * 60

for idx, prompt in enumerate(prompts, start=1):
    # Sampling settings are forwarded as-is to the model's own generate();
    # their exact semantics are defined in src.lstm_model.
    completion = model.generate(
        tokenizer,
        prompt,
        max_new_tokens=40,
        temperature=0.9,
        top_p=0.9,
    )
    print(f"\n[{idx}] PROMPT:\n{prompt}\n---\nOUTPUT:\n{completion}\n" + separator)