In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import itertools
import kagglehub
import os
import pandas as pd
import numpy as np
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
print("Torch:", torch.__version__)

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

Device: cuda
Torch: 2.6.0+cu124


## Load dataset with Russian Jokes

In [2]:
path = kagglehub.dataset_download("vsevolodbogodist/data-jokes")
print("Path to dataset files:", path)
csv_path = os.path.join(path, "dataset.csv")
df = pd.read_csv(csv_path, sep=",", quotechar='"')
print(df.head())

texts = df["text"].astype(str).tolist()
texts = texts[:100000]
print(len(texts))

Path to dataset files: /kaggle/input/data-jokes
                                                text
0  - Зять, а ты знаешь, где найти того мужчину, к...
1  После проведения акции "К животным по-человече...
2  Штирлиц пришел домой и сразу завалился на боко...
3  Комету нашли русские, а захватила ее Европа. И...
4  - Мальчик, какой у тебя огромный рюкзачок, что...
100000


## Токенизация и словарь


In [3]:
def tokenize(s):
    s = re.sub(r"[^а-я ]+", "", s.lower())
    return s.split()

tokens_list = [tokenize(t) for t in texts]
tokens = list(itertools.chain.from_iterable(tokens_list))

vocab = list(set(tokens))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
V = len(vocab)
print("Размер словаря:", V)

Размер словаря: 175504


In [4]:
def generate_cbow_pairs(tokens, window_size=6):
    pairs = []
    for center_pos in range(window_size, len(tokens) - window_size):
        context = (
            tokens[center_pos - window_size:center_pos] +
            tokens[center_pos + 1:center_pos + window_size + 1]
        )
        target = tokens[center_pos]
        pairs.append((context, target))
    return pairs

pairs = generate_cbow_pairs(tokens, window_size=6)
training_data = [([word2idx[w] for w in context], word2idx[target]) for context, target in pairs]
print("Количество пар:", len(training_data))

Количество пар: 2356463


## параметры модели


In [5]:
class CBOWModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOWModel, self).__init__()
        # "input" и "output" эмбеддинги (как в word2vec)
        self.in_embed = nn.Embedding(vocab_size, embedding_dim)
        self.out_embed = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, context_idxs):
        """
        context_idxs: LongTensor [batch_size, context_len]
        """
        # 1. эмбеддинги всех слов контекста
        context_vecs = self.in_embed(context_idxs)        # [batch, context_len, D]
        v_context = context_vecs.mean(dim=1)              # [batch, D]

        # 2. логиты = скалярные произведения со всеми "output embeddings"
        #   out_embed.weight имеет shape [V, D]
        logits = torch.matmul(v_context, self.out_embed.weight.t())  # [batch, V]

        # 3. softmax будет применён внутри CrossEntropyLoss
        return logits

In [6]:
embedding_dim = 50
model = CBOWModel(V, embedding_dim).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

epochs = 10
batch_size = 256

## Обучение


In [7]:
def get_batch(data, batch_size, device):
    batch = random.sample(data, batch_size)
    contexts, targets = zip(*batch)
    contexts_tensor = torch.tensor(list(contexts), dtype=torch.long, device=device)
    targets_tensor = torch.tensor(list(targets), dtype=torch.long, device=device)
    return contexts_tensor, targets_tensor


for epoch in range(epochs):
    total_loss = 0
    steps = len(training_data) // batch_size
    for _ in range(steps):
        contexts, targets = get_batch(training_data, batch_size, device)

        logits = model(contexts)                # [batch, V]
        loss = criterion(logits, targets)       # CE: softmax+NLL внутри

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Эпоха {epoch+1}, loss = {total_loss:.4f}")

Эпоха 1, loss = 130540.3924
Эпоха 2, loss = 128897.1386
Эпоха 3, loss = 127759.2010
Эпоха 4, loss = 126768.4041
Эпоха 5, loss = 125720.7439
Эпоха 6, loss = 124525.1913
Эпоха 7, loss = 123137.3921
Эпоха 8, loss = 121722.9296
Эпоха 9, loss = 120561.6587
Эпоха 10, loss = 119529.0292


## Сохранение

In [8]:
torch.save({
    "model_state_dict": model.state_dict(),
    "word2idx": word2idx,
    "idx2word": idx2word,
    "embedding_dim": embedding_dim,
    "vocab_size": V
}, "word2vec_CBOW.pth")

print("model saved")

model saved
