In [1]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import BertTokenizerFast

from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import train_model

Сбор и подготовка данных

In [2]:

# Load the pretrained BERT tokenizer and the texts returned by get_clean_text.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
texts = get_clean_text("./data/raw_dataset.csv")

# First split: 80% train / 20% hold-out. Second split: the hold-out is halved
# into validation and test. Both splits use the same fixed random_state.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

# Wrap every split in a TextDataset that shares the same tokenizer.
train_dataset, val_dataset, test_dataset = (
    TextDataset(split, tokenizer) for split in (train_data, val_data, test_data)
)

In [3]:
# Options shared by all three loaders; only the dataset and `shuffle` differ.
loader_kwargs = dict(
    batch_size=256,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
    pin_memory_device="cuda",
    persistent_workers=True,
)

# Only the training split is shuffled; validation and test are not.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

Обучение модели

In [None]:
# Build the LSTM model, passing the tokenizer's vocabulary size.
vocab_size = tokenizer.vocab_size
model = LSTMAutoCopleteText(vocab_size=vocab_size)

# Train using the training and validation loaders; the training procedure
# itself lives in src.lstm_traint.train_model.
train_model(model, train_loader, val_loader, tokenizer)

Проверка на тестовом наборе

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the
# top import cell next to `train_model`, which comes from the same module.
from src.lstm_traint import evaluate, evaluate_rouge

# Smoke-test generation on a single prompt and show the result; previously
# the generated text was assigned and then silently discarded.
text = model.generate(tokenizer, "today is going to be")
print(f"Sample generation: {text}")

# Loss and accuracy over the test loader.
test_loss, test_acc = evaluate(model, test_loader)

# ROUGE-1 / ROUGE-2 on the test loader. `max_batches=20` presumably caps the
# number of evaluated batches — confirm against evaluate_rouge.
r1, r2 = evaluate_rouge(model, test_loader, tokenizer=tokenizer, device='cuda', max_batches=20)

print(f"Test Loss: {test_loss:.3f} | Test Accuracy: {test_acc:.2%} | ROUGE1: {r1:.4f} | ROUGE2: {r2:.4f}")

Примеры генерации:

In [None]:
# Prompts fed to the model below; the same list is reused later in the notebook.
prompts = [
    "i know it sounds silly but",
    "this weekend we will",
    "i was about to text you but",
    "the best thing about today is",
    "if anyone needs me",
    "remember when we said we would",
    "today is going to be",
    "can someone explain why",
]

# Horizontal rule printed after each example.
separator = "-" * 60

# Generate one continuation per prompt and print it, numbered from 1.
for idx, prompt in enumerate(prompts, start=1):
    completion = model.generate(
        tokenizer,
        prompt,
        max_new_tokens=40,
        temperature=0.9,
        top_p=0.9,
    )
    print(f"\n[{idx}] PROMPT:\n{prompt}\n---\nOUTPUT:\n{completion}\n{separator}")

Примеры генерации текстов с помощью distilgpt2

In [7]:
# NOTE(review): imported here rather than in the top import cell — presumably
# to defer loading the transformer pipeline until it is needed; confirm, and
# otherwise move this import to the first cell.
from src.eval_transformer_pipline import transformer_generate

# Run the same `prompts` list through the transformer baseline (the section
# heading names distilgpt2) for comparison with the LSTM outputs.
transformer_generate(prompts)

Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



[1] PROMPT:
i know it sounds silly but
---
OUTPUT:
i know it sounds silly but if you’re in this situation, just put on the gloves and let them go. It
------------------------------------------------------------

[2] PROMPT:
this weekend we will
---
OUTPUT:
this weekend we will bring back an old friend, who had been on the team that he worked for. He was always
------------------------------------------------------------

[3] PROMPT:
i was about to text you but
---
OUTPUT:
i was about to text you but I could not get in on my shit. It's just a matter of time until that happens and
------------------------------------------------------------

[4] PROMPT:
the best thing about today is
---
OUTPUT:
the best thing about today is that you can look at what’s going on in the world, and it's different.
------------------------------------------------------------

[5] PROMPT:
if anyone needs me
---
OUTPUT:
if anyone needs me to get out of the way.”
------------------------------------------------