In [1]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
import torch
import os
# The DataLoaders in this notebook use worker processes (num_workers=4).
# Turning off the HuggingFace tokenizers' internal parallelism avoids the
# "process just got forked after parallelism has already been used" warning
# and the deadlock risk it describes once those workers are forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Project-local modules.
# NOTE(review): `LSTMAutoCopleteText` / `lstm_traint` look like misspellings of
# "Complete" / "train"; they must be renamed in src/ before changing them here.
from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import train_model

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (seeds CPU and all CUDA devices) so model weight
# initialisation and DataLoader shuffling are reproducible on Restart & Run All.
# NOTE(review): cuDNN LSTM kernels may still be non-deterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

Сбор и подготовка данных

In [3]:

# One tokenizer is shared by all three datasets, the model's vocabulary size
# and text generation.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
texts = get_clean_text("./data/raw_dataset.csv")

# 80 / 10 / 10 split: hold out 20% of the texts, then halve that hold-out
# into validation and test. Fixed random_state keeps the split reproducible.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

train_dataset, val_dataset, test_dataset = (
    TextDataset(split_texts, tokenizer)
    for split_texts in (train_data, val_data, test_data)
)

In [4]:
BATCH_SIZE = 256
NUM_WORKERS = 4
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# it is useless and PyTorch warns about it, so derive it from `device`.
PIN_MEMORY = device.type == "cuda"


def make_loader(dataset, shuffle: bool) -> DataLoader:
    """Build a DataLoader with the settings shared by every data split.

    Args:
        dataset: one of the TextDataset splits.
        shuffle: reshuffle each epoch (True for the training split only).

    Returns:
        A DataLoader that batches with ``collate_fn`` and, when worker
        processes are used, keeps them alive between epochs.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        # "" is DataLoader's default, i.e. no explicit pinning device.
        pin_memory_device="cuda" if PIN_MEMORY else "",
        # persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=NUM_WORKERS > 0,
    )


train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

Обучение модели

In [5]:
# Size the model's vocabulary from the same tokenizer that encoded the
# datasets, then train; train_model reports validation metrics each epoch.
vocab_size = tokenizer.vocab_size
model = LSTMAutoCopleteText(vocab_size=vocab_size)
train_model(model, device, train_loader, val_loader, tokenizer)

  0%|          | 0/63 [00:07<?, ?it/s]

Epoch 1 | Train Loss: 8.996 | Val Loss: 7.828 | Val Accuracy: 13.94% | ROUGE1: 0.0381 | ROUGE2: 0.0028


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2 | Train Loss: 7.403 | Val Loss: 6.526 | Val Accuracy: 16.18% | ROUGE1: 0.0401 | ROUGE2: 0.0032


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3 | Train Loss: 6.500 | Val Loss: 5.960 | Val Accuracy: 17.54% | ROUGE1: 0.0421 | ROUGE2: 0.0035


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4 | Train Loss: 6.100 | Val Loss: 5.736 | Val Accuracy: 18.69% | ROUGE1: 0.0433 | ROUGE2: 0.0038


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 5 | Train Loss: 5.906 | Val Loss: 5.626 | Val Accuracy: 19.13% | ROUGE1: 0.0437 | ROUGE2: 0.0040


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 6 | Train Loss: 5.783 | Val Loss: 5.556 | Val Accuracy: 19.70% | ROUGE1: 0.0449 | ROUGE2: 0.0038


In [11]:
# Persist only the learned parameters (the state_dict), not the module object.
WEIGHTS_PATH = "weights.pt"
torch.save(model.state_dict(), WEIGHTS_PATH)

Проверка на тестовом наборе

In [7]:
from src.lstm_traint import evaluate, evaluate_rouge

# Loss/accuracy are computed by `evaluate` over test_loader; ROUGE is limited
# to at most this many test batches (presumably to bound generation time —
# confirm in src/lstm_traint.evaluate_rouge).
ROUGE_MAX_BATCHES = 20

test_loss, test_acc = evaluate(model, device, test_loader)
# Pass the notebook's device rather than a hard-coded 'cuda' so the cell also
# runs on CPU-only machines; str(device) is exactly 'cuda' on GPU.
r1, r2 = evaluate_rouge(
    model,
    test_loader,
    tokenizer=tokenizer,
    device=str(device),
    max_batches=ROUGE_MAX_BATCHES,
)
print(f"Test Loss: {test_loss:.3f} | Test Accuracy: {test_acc:.2%} | ROUGE1: {r1:.4f} | ROUGE2: {r2:.4f}")

Test Loss: 5.510 | Test Accuracy: 20.34% | ROUGE1: 0.0429 | ROUGE2: 0.0038


Примеры генерации LSTM-модели:

In [8]:
prompts = [
    "i know it sounds silly but",
    "this weekend we will",
    "i was about to text you but",
    "the best thing about today is",
    "if anyone needs me",
    "remember when we said we would",
    "today is going to be",
    "can someone explain why"
]

# Sampling settings shared by every prompt.
MAX_NEW_TOKENS = 40
TEMPERATURE = 0.9
TOP_P = 0.9
GEN_SEED = 42

# Temperature / top-p sampling is random; fix the seed so the printed samples
# are reproducible on re-run.
# NOTE(review): assumes model.generate draws from torch's global RNG — confirm
# in src/lstm_model.py.
torch.manual_seed(GEN_SEED)

# Inference only: disable training-mode layers (e.g. dropout, if any) and
# gradient tracking.
model.eval()
with torch.no_grad():
    for i, prompt in enumerate(prompts, 1):
        generated = model.generate(
            tokenizer,
            prompt,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
        print(f"\n[{i}] PROMPT:\n{prompt}\n---\nOUTPUT:\n{generated}\n" + "-" * 60)


[1] PROMPT:
i know it sounds silly but
---
OUTPUT:
i know it sounds silly but it think nothings... after get the store i can't tonight but in nice job
------------------------------------------------------------

[2] PROMPT:
this weekend we will
---
OUTPUT:
this weekend we will be there's!!
------------------------------------------------------------

[3] PROMPT:
i was about to text you but
---
OUTPUT:
i was about to text you but sick i'm are til a she...
------------------------------------------------------------

[4] PROMPT:
the best thing about today is
---
OUTPUT:
the best thing about today is going to go to the
------------------------------------------------------------

[5] PROMPT:
if anyone needs me
---
OUTPUT:
if anyone needs me up of the day i guess i have to really sleep
------------------------------------------------------------

[6] PROMPT:
remember when we said we would
---
OUTPUT:
remember when we said we would be but i were.
------------------------------------------

Замер ROUGE distilgpt2

In [9]:
from src.eval_transformer_pipline import evaluate_rouge_distilgpt2

# Use separate names so the LSTM test-set ROUGE (r1, r2 from the test
# evaluation cell) is not overwritten and both models stay comparable.
# NOTE(review): the LSTM ROUGE was computed with max_batches=20; confirm this
# function scores the same subset of test_loader so the numbers are comparable.
gpt2_r1, gpt2_r2 = evaluate_rouge_distilgpt2(test_loader, lstm_tok=tokenizer)
print(f"distilgpt2_ROUGE1: {gpt2_r1:.4f} | distilgpt2_ROUGE2: {gpt2_r2:.4f}")

Device set to use cuda:0


distilgpt2_ROUGE1: 0.0550 | distilgpt2_ROUGE2: 0.0057


Примеры генерации текстов distilgpt2

In [10]:
from src.eval_transformer_pipline import transformer_generate

# Generate distilgpt2 continuations for the same `prompts` list used for the
# LSTM samples, so both models are shown on identical inputs.
transformer_generate(prompts)


[1] PROMPT:
i know it sounds silly but
---
OUTPUT:
i know it sounds silly but I will explain what they are doing here.
------------------------------------------------------------

[2] PROMPT:
this weekend we will
---
OUTPUT:
this weekend we will be taking a look at all the issues that have been raised and it should take some time to get
------------------------------------------------------------

[3] PROMPT:
i was about to text you but
---
OUTPUT:
i was about to text you but they didn't realize she had a bad attitude. So, he took her in and made him think
------------------------------------------------------------

[4] PROMPT:
the best thing about today is
---
OUTPUT:
the best thing about today is that we're all better than ever.
------------------------------------------------------------

[5] PROMPT:
if anyone needs me
---
OUTPUT:
if anyone needs me.
The only way you can do this is to just let go of your fears and anxiety,
---------------------------------------------------------