In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
import torch
import os
# Switch off the tokenizers library's internal parallelism.
# NOTE(review): presumably set to avoid the fork warning / deadlock when the
# DataLoader below spawns worker processes — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Project-local helpers: text cleaning + dataset, the LSTM model, the training loop.
from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import train_model

In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Сбор и подготовка данных

In [3]:

# Pretrained WordPiece tokenizer; its vocabulary is reused by the LSTM model.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Cleaned corpus loaded from the raw CSV.
texts = get_clean_text("./data/raw_dataset.csv")

# Two-stage split: 80% train, then the remaining 20% is halved into
# validation and test (10% / 10%). The same seed is used for both stages.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

# Wrap every split in the project dataset, all sharing one tokenizer.
train_dataset, val_dataset, test_dataset = (
    TextDataset(split, tokenizer) for split in (train_data, val_data, test_data)
)

In [4]:
BATCH_SIZE = 128
NUM_WORKERS = 4

# Page-locked (pinned) host memory only helps host-to-GPU copies. The notebook's
# `device` may fall back to CPU, so pinning must follow it rather than being
# hard-coded to "cuda".
USE_CUDA = device.type == "cuda"


def make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by all three splits.

    Args:
        dataset: dataset to iterate over (one of the TextDataset splits).
        shuffle: whether to reshuffle the samples at every epoch.

    Returns:
        A DataLoader that batches with the project `collate_fn`, keeps its
        worker processes alive between epochs, and pins memory for CUDA only
        when the notebook actually runs on a GPU.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS,
        pin_memory=USE_CUDA,
        # "" is the DataLoader default, i.e. no explicit pinning device.
        pin_memory_device="cuda" if USE_CUDA else "",
        persistent_workers=True,
    )


# Only the training split is shuffled; evaluation order stays fixed.
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

Обучение модели

In [5]:
# The model's vocabulary size is taken from the tokenizer so the two stay in sync.
model = LSTMAutoCopleteText(vocab_size=tokenizer.vocab_size)
# Runs the training loop; per the log below it reports train/val loss,
# validation accuracy and ROUGE-1/ROUGE-2 after every epoch.
train_model(model, device, train_loader, val_loader, tokenizer)

100%|██████████| 10002/10002 [11:28<00:00, 14.53it/s]


Epoch 1 | Train Loss: 5.139 | Val Loss: 4.670 | Val Accuracy: 25.12% | ROUGE1: 0.0486 | ROUGE2: 0.0048


100%|██████████| 10002/10002 [11:33<00:00, 14.43it/s]


Epoch 2 | Train Loss: 4.729 | Val Loss: 4.536 | Val Accuracy: 26.05% | ROUGE1: 0.0493 | ROUGE2: 0.0051


100%|██████████| 10002/10002 [11:33<00:00, 14.43it/s]


Epoch 3 | Train Loss: 4.628 | Val Loss: 4.470 | Val Accuracy: 26.47% | ROUGE1: 0.0504 | ROUGE2: 0.0056


100%|██████████| 10002/10002 [11:33<00:00, 14.43it/s]


Epoch 4 | Train Loss: 4.570 | Val Loss: 4.427 | Val Accuracy: 26.78% | ROUGE1: 0.0494 | ROUGE2: 0.0052


100%|██████████| 10002/10002 [11:32<00:00, 14.43it/s]


Epoch 5 | Train Loss: 4.532 | Val Loss: 4.397 | Val Accuracy: 26.98% | ROUGE1: 0.0507 | ROUGE2: 0.0056


100%|██████████| 10002/10002 [11:33<00:00, 14.43it/s]


Epoch 6 | Train Loss: 4.504 | Val Loss: 4.375 | Val Accuracy: 27.11% | ROUGE1: 0.0504 | ROUGE2: 0.0052


In [6]:
# Persist only the learned parameters (state_dict), not the whole module object.
# The path is relative to the notebook's working directory.
torch.save(model.state_dict(), "weights.pt")

Проверка на тестовом наборе

In [7]:
from src.lstm_traint import evaluate, evaluate_rouge

# Loss / accuracy on the held-out test split.
test_loss, test_acc = evaluate(model, device, test_loader)

# ROUGE on the same split. The notebook-wide `device` is passed instead of a
# hard-coded 'cuda' so this cell also runs on a CPU-only machine and stays
# consistent with the `evaluate` call above.
# NOTE(review): `max_batches=20` presumably limits ROUGE to the first 20 batches — confirm.
r1, r2 = evaluate_rouge(model, test_loader, tokenizer=tokenizer, device=device, max_batches=20)

print(f"Test Loss: {test_loss:.3f} | Test Accuracy: {test_acc:.2%} | ROUGE1: {r1:.4f} | ROUGE2: {r2:.4f}")

Test Loss: 4.373 | Test Accuracy: 27.19% | ROUGE1: 0.0505 | ROUGE2: 0.0052


Примеры генерации:

In [8]:
# Prompts used for qualitative inspection of the generated continuations.
prompts = [
    "i know it sounds silly but",
    "this weekend we will",
    "i was about to text you but",
    "the best thing about today is",
    "if anyone needs me",
    "remember when we said we would",
    "today is going to be",
    "can someone explain why"
]

# Generation samples with temperature / top-p, so fix the RNG state to make the
# printed examples repeatable across re-runs.
# NOTE(review): assumes model.generate samples through torch's global RNG — confirm in src.lstm_model.
torch.manual_seed(42)

for i, p in enumerate(prompts, 1):
    gen = model.generate(tokenizer, p, max_new_tokens=40, temperature=0.9, top_p=0.9)
    print(f"\n[{i}] PROMPT:\n{p}\n---\nOUTPUT:\n{gen}\n" + "-"*60)


[1] PROMPT:
i know it sounds silly but
---
OUTPUT:
i know it sounds silly but now i'm losing i can go now but i have to cancel my ipod
------------------------------------------------------------

[2] PROMPT:
this weekend we will
---
OUTPUT:
this weekend we will be with them all alone
------------------------------------------------------------

[3] PROMPT:
i was about to text you but
---
OUTPUT:
i was about to text you but have a feeling!!
------------------------------------------------------------

[4] PROMPT:
the best thing about today is
---
OUTPUT:
the best thing about today is no rain
------------------------------------------------------------

[5] PROMPT:
if anyone needs me
---
OUTPUT:
if anyone needs me, thanks to user - so you'll be a. g. if you know he was on the high?
------------------------------------------------------------

[6] PROMPT:
remember when we said we would
---
OUTPUT:
remember when we said we would be excited to return the lot of his twitterr!!
------------

Замер ROUGE для distilgpt2 (на том же тестовом наборе, для сравнения с LSTM)

In [None]:
from src.eval_transformer_pipline import evaluate_rouge_distilgpt2

# Baseline: ROUGE of distilgpt2 on the same test loader; the LSTM's tokenizer is
# handed over via `lstm_tok`.
# NOTE(review): presumably used to decode the LSTM-tokenized batches back to text — confirm.
# NOTE(review): r1 / r2 rebind the names that held the LSTM test ROUGE scores
# from the test-evaluation cell, so those values are no longer available afterwards.
r1, r2 = evaluate_rouge_distilgpt2(test_loader, lstm_tok=tokenizer)
print(f"distilgpt2_ROUGE1: {r1:.4f} | distilgpt2_ROUGE2: {r2:.4f}")

Device set to use cuda:0


Примеры генерации текстов моделью distilgpt2 (на тех же промптах)

In [None]:
from src.eval_transformer_pipline import transformer_generate

# Reuses the `prompts` list defined in the LSTM generation cell, so that cell
# must have been executed first.
transformer_generate(prompts)