In [1]:
# Third-party utilities: dataset splitting and batched loading.
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import os
# Must stay above the transformers import below.
# NOTE(review): presumably this silences the tokenizers fork/parallelism
# warning when DataLoader worker processes are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import BertTokenizerFast

# Project-local modules under src/.
from src.data_utils import get_clean_text, TextDataset, collate_fn
from src.lstm_model import LSTMAutoCopleteText
from src.lstm_traint import train_model

  from .autonotebook import tqdm as notebook_tqdm


Сбор и подготовка данных

In [2]:

# Tokenizer shared by all three splits.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Texts returned by get_clean_text for the raw CSV.
texts = get_clean_text("./data/raw_dataset.csv")

# First split: 80% train, 20% held out.
train_data, holdout_data = train_test_split(texts, test_size=0.2, random_state=42)
# Second split: the held-out part is halved into validation and test.
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

# Wrap each split in a TextDataset, built in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    TextDataset(split, tokenizer) for split in (train_data, val_data, test_data)
)

In [3]:
# Settings common to all three loaders; only the dataset and `shuffle` differ.
loader_kwargs = dict(
    batch_size=256,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

# Only the training loader shuffles; validation and test are not shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

Обучение модели

In [5]:
# Size the model's vocabulary from the tokenizer, then hand the model and the
# train / validation loaders to the training routine.
vocab_size = tokenizer.vocab_size
model = LSTMAutoCopleteText(vocab_size=vocab_size)
train_model(model, train_loader, val_loader)

  1%|          | 45/4995 [00:47<1:27:21,  1.06s/it]


KeyboardInterrupt: 

In [None]:
import os

import torch

# Checkpoint location, defined once so the save call and the log line below
# can never disagree.
CHECKPOINT_PATH = "checkpoints/lstm_checkpoint.pt"

# NOTE(review): this epoch number is recorded by hand, not read from the
# training loop — confirm it matches the run that produced these weights.
EPOCHS_COMPLETED = 7

os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# The model may not expose a `dropout` attribute, or that attribute may have
# no `.p`; fall back to 0.0 in both cases instead of raising AttributeError.
dropout_module = getattr(model, "dropout", None)

# Everything needed to rebuild the model: the weights plus the constructor
# hyper-parameters, read back from the live model's layers.
ckpt = {
    "model_state": model.state_dict(),
    "epoch": EPOCHS_COMPLETED,
    "model_config": {
        "vocab_size": model.embedding.num_embeddings,
        "emb_dim": model.embedding.embedding_dim,
        "hidden_dim": model.rnn.hidden_size,
        "num_layers": model.rnn.num_layers,
        "pad_id": getattr(model, "pad_id", 0),
        # True only when the output layer and the embedding share the very
        # same parameter object.
        "tie_weights": model.fc.weight is model.embedding.weight,
        "dropout": getattr(dropout_module, "p", 0.0),
    },
    "tokenizer_name": getattr(tokenizer, "name_or_path", None),
}
torch.save(ckpt, CHECKPOINT_PATH)
print(f"Saved to {CHECKPOINT_PATH}")

Saved to checkpoints/lstm_checkpoint.pt
