In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from tqdm import tqdm

In [None]:
# Step 1: Data preparation
class TextDataset(Dataset):
    """Dataset of text comments loaded from a CSV file and tokenized on access.

    Every row whose ``comment`` field is non-empty contributes one sample.
    The raw strings are kept in ``self.data``; tokenization happens lazily in
    ``__getitem__``.

    Args:
        data_path: Path to a UTF-8 encoded CSV file with a header row that
            contains a ``comment`` column.
        tokenizer: Callable invoked as ``tokenizer(text, padding="max_length",
            truncation=True, max_length=..., return_tensors="pt")``. Its result
            must expose ``input_ids`` and ``attention_mask`` attributes.
        max_length: Length passed to the tokenizer for padding and truncation.

    Raises:
        KeyError: If the CSV header has no ``comment`` column.
    """

    def __init__(self, data_path, tokenizer, max_length=128):
        # Function-local import so the class does not depend on the notebook's import cell.
        import csv

        self.tokenizer = tokenizer
        self.max_length = max_length

        # The csv module expects files opened with newline="". Without it,
        # universal-newline translation rewrites line breaks that appear inside
        # quoted fields before the parser ever sees them.
        with open(data_path, "r", encoding="utf-8", newline="") as file:
            reader = csv.DictReader(file)
            # Keep only rows with a non-empty "comment" value.
            self.data = [row["comment"] for row in reader if row["comment"]]

    def __len__(self):
        """Return the number of non-empty comments that were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the comment at ``idx``.

        Returns:
            A tuple ``(input_ids, attention_mask)`` with the leading dimension
            of the tokenizer output squeezed away.
        """
        text = self.data[idx]
        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        # return_tensors="pt" is requested for a single text; squeeze(0) removes
        # the resulting leading dimension so the DataLoader can stack samples.
        return tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0)

In [None]:
# Step 2: Model definition
class TransformerModel(nn.Module):
    """Token-level Transformer that maps token ids to vocabulary logits.

    The model sums a token embedding with a learned positional table, runs the
    result through a stack of ``nn.TransformerEncoderLayer`` blocks and projects
    every position onto the vocabulary.

    Args:
        vocab_size: Number of rows in the embedding table and of output logits.
        embed_size: Embedding / model dimension (``d_model``).
        num_heads: Number of attention heads per layer.
        hidden_dim: Width of the feed-forward sub-layer.
        num_layers: Number of stacked encoder layers.
        max_length: Number of positions in the learned positional table; input
            sequences must not be longer than this.
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_length):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Learned positional table, initialised to zeros and trained with the model.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_length, embed_size))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_dim
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, input_ids, attention_mask, causal=True):
        """Compute per-position vocabulary logits.

        Args:
            input_ids: Integer tensor of token ids, indexed as (batch, seq).
            attention_mask: Tensor of the same (batch, seq) layout; zero marks
                padding positions that must not be attended to.
            causal: When True (default) position ``i`` may only attend to
                positions ``<= i``. This is required for next-token training:
                without it every position can see the token it is asked to
                predict. Pass False for bidirectional encoding.

        Returns:
            Tensor of logits indexed as (batch, seq, vocab_size).
        """
        seq_len = input_ids.size(1)
        embeddings = self.embedding(input_ids) + self.positional_encoding[:, :seq_len, :]

        # True marks padded keys that attention has to skip.
        padding_mask = ~attention_mask.bool()

        causal_mask = None
        if causal:
            # Boolean upper-triangular mask: True above the diagonal blocks
            # attention to future positions. Kept boolean so it has the same
            # dtype as the key-padding mask.
            # NOTE(review): assumes padding sits at the end of each sequence;
            # with left padding a query row could end up fully masked.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
                diagonal=1,
            )

        # The encoder layers were built with the default batch_first=False, so
        # they expect (seq, batch, dim); transpose in and back out.
        transformer_output = self.transformer(
            embeddings.transpose(0, 1),
            mask=causal_mask,
            src_key_padding_mask=padding_mask,
        )
        logits = self.fc(transformer_output.transpose(0, 1))
        return logits

In [None]:
# Step 3: Model training
def train_model(data_path, epochs=5, batch_size=16, lr=5e-5, max_length=128,
                save_path="llm_model.pth"):
    """Train ``TransformerModel`` for next-token prediction on a CSV of comments.

    The ``bert-base-uncased`` tokenizer is loaded, the comments at ``data_path``
    are wrapped in a ``TextDataset``, and a 4-layer model (embed size 256,
    8 heads, feed-forward width 512) is trained with AdamW. After the last
    epoch the model's ``state_dict`` is written to ``save_path``.

    Args:
        data_path: Path to the CSV file consumed by ``TextDataset``.
        epochs: Number of passes over the dataset.
        batch_size: Number of samples per batch.
        lr: Learning rate for AdamW.
        max_length: Tokenized sequence length, used both for the dataset and
            for the model's positional table.
        save_path: File the trained ``state_dict`` is saved to.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer setup
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    vocab_size = tokenizer.vocab_size

    # Data preparation
    dataset = TextDataset(data_path, tokenizer, max_length=max_length)
    if len(dataset) == 0:
        # Fail early with a clear message instead of dividing by zero when
        # averaging the loss over an empty loader.
        raise ValueError(f"No non-empty comments found in {data_path!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model setup
    model = TransformerModel(
        vocab_size=vocab_size,
        embed_size=256,
        num_heads=8,
        hidden_dim=512,
        num_layers=4,
        max_length=max_length
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # Targets equal to ignore_index contribute nothing to the loss; padding
    # positions are set to this value below.
    ignore_index = -100
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)

    # Training loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")
        for step, (input_ids, attention_mask) in enumerate(progress_bar, start=1):
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

            # Shift by one for language modelling: position i predicts token i + 1.
            # Targets that fall on padding are excluded so the loss is not
            # dominated by predicting pad tokens.
            labels = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, ignore_index)
            model_inputs = input_ids[:, :-1].contiguous()
            model_mask = attention_mask[:, :-1]

            optimizer.zero_grad()
            logits = model(model_inputs, model_mask)

            loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Running mean over the batches processed so far in this epoch.
            progress_bar.set_postfix(loss=total_loss / step)

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")

    # Save the trained weights
    torch.save(model.state_dict(), save_path)
    print(f"Модель сохранена как {save_path}")

In [None]:
# Step 4: Launch training
# In a notebook kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    data_path = "parsers/data/2ch_data_b.csv"  # relative path, resolved against the current working directory
    train_model(data_path)