In [1]:
import os

# Set before importing `transformers`: the DataLoaders below fork worker
# processes, and the HF `tokenizers` backend warns about (and disables) its own
# thread pool when a fork happens after it has been used. Opting out up front
# avoids that.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import evaluate, evaluate_rouge, train_model

Сбор и подготовка данных

In [2]:

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
texts = get_clean_text("./data/raw_dataset.csv")

# 80 / 10 / 10 split: hold out 20% of the texts, then halve the hold-out into
# validation and test. random_state pins both splits so they are reproducible.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)
# Sanity check: the three splits together account for every cleaned text.
assert len(train_data) + len(val_data) + len(test_data) == len(texts)

train_dataset = TextDataset(train_data, tokenizer)
val_dataset = TextDataset(val_data, tokenizer)
test_dataset = TextDataset(test_data, tokenizer)

In [3]:
BATCH_SIZE = 256
NUM_WORKERS = 4


def make_loader(dataset, shuffle: bool = False, batch_size: int = BATCH_SIZE,
                num_workers: int = NUM_WORKERS,
                pin_memory_device: str = "cuda") -> DataLoader:
    """Build a DataLoader with the settings shared by every split.

    All splits use the same batch size, ``collate_fn`` and worker pool; only
    ``shuffle`` differs (True for training, False for val/test so evaluation
    always walks the data in the same order).
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        # copy batches into pinned (page-locked) host memory for faster GPU transfer
        pin_memory=True,
        pin_memory_device=pin_memory_device,
        # keep workers alive between passes; DataLoader rejects this when num_workers == 0
        persistent_workers=num_workers > 0,
    )


train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset)
test_loader = make_loader(test_dataset)

Обучение модели

In [4]:
# Seed torch's global RNG right before building the model so parameter
# initialisation and the training loader's shuffle order repeat across
# Restart & Run All (the data split above is already pinned via random_state=42).
# NOTE: cuDNN LSTM kernels may still be non-deterministic on GPU.
torch.manual_seed(42)
model = LSTMAutoCopleteText(vocab_size=tokenizer.vocab_size)
train_model(model, train_loader, val_loader, tokenizer)

cuda


100%|██████████| 63/63 [01:21<00:00,  1.29s/it]


Epoch 1 | Train Loss: 8.947 | Val Loss: 7.680 | Val Accuracy: 13.97% | ROUGE1: 0.0323 | ROUGE2: 0.0026


100%|██████████| 63/63 [01:13<00:00,  1.16s/it]


Epoch 2 | Train Loss: 7.322 | Val Loss: 6.421 | Val Accuracy: 17.07% | ROUGE1: 0.0373 | ROUGE2: 0.0031


100%|██████████| 63/63 [01:18<00:00,  1.24s/it]


Epoch 3 | Train Loss: 6.482 | Val Loss: 5.918 | Val Accuracy: 18.12% | ROUGE1: 0.0372 | ROUGE2: 0.0031


100%|██████████| 63/63 [01:12<00:00,  1.15s/it]


Epoch 4 | Train Loss: 6.112 | Val Loss: 5.700 | Val Accuracy: 19.25% | ROUGE1: 0.0379 | ROUGE2: 0.0033


100%|██████████| 63/63 [01:11<00:00,  1.14s/it]


Epoch 5 | Train Loss: 5.916 | Val Loss: 5.582 | Val Accuracy: 19.93% | ROUGE1: 0.0392 | ROUGE2: 0.0029


100%|██████████| 63/63 [01:10<00:00,  1.12s/it]


Epoch 6 | Train Loss: 5.793 | Val Loss: 5.505 | Val Accuracy: 20.16% | ROUGE1: 0.0390 | ROUGE2: 0.0031


Оценка модели на тестовом наборе

In [None]:
# Final evaluation on the test split, which was never passed to train_model
# (only train_loader / val_loader were).
test_loss, test_acc = evaluate(model, test_loader)

# NOTE(review): max_batches presumably caps how many test batches are scored,
# so the ROUGE numbers cover only a subset of the test split — TODO confirm in
# src.lstm_traint.evaluate_rouge.
ROUGE_MAX_BATCHES = 20
r1, r2 = evaluate_rouge(
    model, test_loader, tokenizer=tokenizer, device='cuda', max_batches=ROUGE_MAX_BATCHES
)
print(f"Test Loss: {test_loss:.3f} | Test Accuracy: {test_acc:.2%} | ROUGE1: {r1:.4f} | ROUGE2: {r2:.4f}")

Test Loss: 5.508 | Test Accuracy: 20.47% | ROUGE1: 0.0362 | ROUGE2: 0.0026


Примеры генерации:

In [21]:
prompts = [
    "i know it sounds silly but",
    "this weekend we will",
    "i was about to text you but",
    "the best thing about today is",
    "if anyone needs me",
    "remember when we said we would",
    "today is going to be",
    "can someone explain why"
]

# Decoding settings shared by every sample below.
GEN_KWARGS = dict(max_new_tokens=40, temperature=0.9, top_p=0.9)

# temperature / top_p imply stochastic sampling inside model.generate; reseed so
# the samples are identical on every Restart & Run All (assumes generate draws
# from torch's global RNG — TODO confirm in src.lstm_model).
# NOTE(review): the samples contain "amp ;", i.e. HTML entities such as `&amp;`
# appear to survive get_clean_text — consider html.unescape() during cleaning.
# NOTE(review): prompt and continuation are sometimes glued together
# ("isover!!") — check how generate joins the decoded tokens to the prompt.
torch.manual_seed(42)
for idx, prompt in enumerate(prompts, 1):
    sample = model.generate(tokenizer, prompt, **GEN_KWARGS)
    print(f"\n[{idx}] PROMPT:\n{prompt}\n---\nOUTPUT:\n{sample}\n" + "-"*60)


[1] PROMPT:
i know it sounds silly but
---
OUTPUT:
i know it sounds silly but i'm r to om ; it's brown de... wow it's not as in horse work w i miss
------------------------------------------------------------

[2] PROMPT:
this weekend we will
---
OUTPUT:
this weekend we will be in myw.
------------------------------------------------------------

[3] PROMPT:
i was about to text you but
---
OUTPUT:
i was about to text you but is not some have to gonna be seen the amp ; i'm so amp ; time!
------------------------------------------------------------

[4] PROMPT:
the best thing about today is
---
OUTPUT:
the best thing about today isover!!
------------------------------------------------------------

[5] PROMPT:
if anyone needs me
---
OUTPUT:
if anyone needs me that they are not today...
------------------------------------------------------------

[6] PROMPT:
remember when we said we would
---
OUTPUT:
remember when we said we would be today...
--------------------------------------------