# Word2Vec — Skip-gram with Negative Sampling

## 0. Imports

In [1]:
%load_ext lab_black

In [2]:
import pickle
import random
from collections import Counter
from typing import List, Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from konlpy.tag import Mecab
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

## 1. Preprocess

### Text preprocessing

In [3]:
def preprocess(
    data_path: str,
    word_index: Optional[dict] = None,
    num_words: int = 10000,
):
    """Tokenize pickled documents into sequences of noun indices.

    Args:
        data_path: Path to a pickle file containing an iterable of documents.
            ``pickle.load`` can execute arbitrary code, so only pass files
            you trust.
        word_index: Existing ``{word: index}`` vocabulary to reuse (e.g. the
            training vocabulary when encoding a test split). If falsy, a new
            vocabulary is built from this file.
        num_words: When building a new vocabulary, keep only this many most
            frequent nouns; index 0 is reserved for ``"<UNK>"``.

    Returns:
        ``(corpus, word_index, index_word)``: ``corpus`` holds one list of
        indices per document that produced at least one noun (nouns missing
        from the vocabulary map to 0), ``word_index`` is the vocabulary used,
        and ``index_word`` is its inverse mapping.
    """
    tokenizer = Mecab()

    # 0. data load (trusted input only -- see docstring)
    with open(data_path, "rb") as f:
        data = pickle.load(f)

    # 1. noun tokenization
    vocab, docs = [], []
    for doc in tqdm(data):
        # The NSMC data contains NaN (float) entries for missing reviews.
        # Skip anything that is not a non-empty string explicitly instead of
        # wrapping the tokenizer in a bare ``except``, which would also hide
        # real tokenizer errors and swallow KeyboardInterrupt.
        if not isinstance(doc, str) or not doc:
            continue
        nouns = tokenizer.nouns(doc)
        vocab.extend(nouns)
        docs.append(nouns)

    # 2. build vocab (only when no vocabulary was supplied)
    if not word_index:
        # 3. reserve index 0 for the unknown token; most frequent words follow
        word_index = {"<UNK>": 0}
        for idx, (word, _) in enumerate(Counter(vocab).most_common(num_words), 1):
            word_index[word] = idx

    index_word = {idx: word for word, idx in word_index.items()}

    # 4. encode documents; documents that yielded no nouns are dropped
    corpus = [[word_index.get(word, 0) for word in doc] for doc in docs if doc]

    return corpus, word_index, index_word

In [4]:
# Build the vocabulary from the training split only, then encode the test
# split with that same vocabulary so word indices agree between splits.
nsmc_dir = "../data/nsmc"
train_path = f"{nsmc_dir}/train_data.pkl"
test_path = f"{nsmc_dir}/test_data.pkl"

train_corpus, word_index, index_word = preprocess(train_path)
test_corpus, _, _ = preprocess(test_path, word_index=word_index)

  0%|          | 0/150000 [00:00<?, ?it/s]

  0%|          | 0/50000 [00:00<?, ?it/s]

### skipgrams

In [5]:
# Reference: https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/sequence.py
def skipgrams(
    sequence: List[int],
    vocab_size: int,
    window_size: int = 4,
    negative_samples: float = 1.0,
    rng: Optional[random.Random] = None,
):
    """Generate skip-gram (target, context) pairs with negative samples.

    Index 0 is treated as ``<UNK>`` and never used as a target or context.

    Args:
        sequence: Word indices of one document.
        vocab_size: Vocabulary size; negative contexts are drawn uniformly
            from ``[1, vocab_size - 1]``.
        window_size: Max distance between target and context positions.
        negative_samples: Ratio of negative to positive pairs (may be
            fractional; ``int(num_positive * ratio)`` negatives are drawn).
            ``<= 0`` disables negative sampling.
        rng: Optional ``random.Random`` for reproducible sampling; defaults
            to the global ``random`` module state.

    Returns:
        ``(couples, labels)``: all positive ``[target, context]`` pairs
        (label 1) followed by the negative pairs (label 0). A random negative
        context can coincide with a real one by chance.
    """
    rand = random if rng is None else rng
    couples, labels = [], []
    for i, wi in enumerate(sequence):
        if not wi:  # skip <UNK> targets
            continue
        window_start = max(0, i - window_size)
        window_end = min(len(sequence), i + window_size + 1)

        for j in range(window_start, window_end):
            if j == i:
                continue
            wj = sequence[j]
            if not wj:  # skip <UNK> contexts
                continue
            couples.append([wi, wj])
            labels.append(1)

    if negative_samples > 0:
        num_negative_samples = int(len(labels) * negative_samples)
        # Negative targets are drawn from the positive targets (shuffled and
        # cycled); contexts are uniform random non-<UNK> indices.
        words = [c[0] for c in couples]
        rand.shuffle(words)

        couples += [
            [words[idx % len(words)], rand.randint(1, vocab_size - 1)]
            for idx in range(num_negative_samples)
        ]
        labels += [0] * num_negative_samples

    return couples, labels

In [6]:
# train skipgrams
# Expand every training document into labelled skip-gram pairs,
# with 0.5 negative samples per positive pair.
vocab_size = len(word_index)
train_pairs, train_labels = [], []
for sequence in tqdm(train_corpus):
    pairs, targets = skipgrams(
        sequence=sequence, vocab_size=vocab_size, negative_samples=0.5
    )
    train_pairs += pairs
    train_labels += targets

  0%|          | 0/141731 [00:00<?, ?it/s]

In [7]:
# test skipgrams
# Same expansion for the test documents, using the training vocabulary size.
vocab_size = len(word_index)
test_pairs, test_labels = [], []
for sequence in tqdm(test_corpus):
    pairs, targets = skipgrams(
        sequence=sequence, vocab_size=vocab_size, negative_samples=0.5
    )
    test_pairs += pairs
    test_labels += targets

  0%|          | 0/47238 [00:00<?, ?it/s]

### Dataset

In [8]:
class W2VDataset(Dataset):
    """Map-style dataset of skip-gram ``(pair, label)`` examples.

    Indexing (including slicing) is forwarded to both underlying lists, so
    ``ds[i]`` is ``(pairs[i], labels[i])``.
    """

    def __init__(self, pairs: List[List[int]], labels: List[int]):
        self.pairs, self.labels = pairs, labels

    def __len__(self):
        # The pairs list defines the dataset size.
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        label = self.labels[idx]
        return pair, label

In [9]:
# Wrap each split's skip-gram pairs and labels in a Dataset.
trainset = W2VDataset(pairs=train_pairs, labels=train_labels)
testset = W2VDataset(pairs=test_pairs, labels=test_labels)

In [10]:
# Peek at the first two examples; slicing returns (pairs, labels) lists.
trainset[0:2]

([[77, 12], [77, 319]], [1, 1])

### collate function

In [11]:
def collate_fn(batch):
    """Collate ``(pair, label)`` samples into tensors for the model.

    Args:
        batch: Sequence of ``([target, context], label)`` items, as returned
            by ``W2VDataset.__getitem__``.

    Returns:
        ``(targets, contexts, labels)``: int64 tensors of target and context
        indices (for ``nn.Embedding`` lookup) and a float32 tensor of 0/1
        labels (``BCEWithLogitsLoss`` needs floating-point targets).
    """
    # torch.tensor(..., dtype=...) replaces the legacy typed constructors
    # torch.LongTensor / torch.FloatTensor.
    targets = torch.tensor([pair[0] for pair, _ in batch], dtype=torch.long)
    contexts = torch.tensor([pair[1] for pair, _ in batch], dtype=torch.long)
    labels = torch.tensor([label for _, label in batch], dtype=torch.float32)
    return targets, contexts, labels

### dataloader

In [12]:
# Both splits share batching/collation settings; only training is shuffled.
loader_kwargs = dict(batch_size=256, collate_fn=collate_fn, num_workers=8)

train_loader = DataLoader(dataset=trainset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset=testset, shuffle=False, **loader_kwargs)

In [13]:
# Grab the first (targets, contexts, labels) batch for inspection.
for sample in train_loader:
    batch = sample
    break

In [14]:
# sample[0]

## 2. Word2Vec Model

In [16]:
class Word2Vec(pl.LightningModule):
    """Skip-gram Word2Vec trained with negative sampling.

    Two embedding tables are learned: ``input_embed`` for target words and
    ``output_embed`` for context words. A (target, context) pair is scored by
    the dot product of its two vectors and trained as binary classification
    against 0/1 labels (real context vs. negative sample).

    Args:
        vocab_size: Number of rows in each embedding table.
        embed_dim: Embedding dimensionality.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).
    """

    def __init__(self, vocab_size: int, embed_dim: int = 100, lr: float = 1e-3):
        super().__init__()
        # Store constructor arguments in checkpoints so the model can be
        # rebuilt with ``Word2Vec.load_from_checkpoint(path)``.
        self.save_hyperparameters()
        self.lr = lr

        self.input_embed = nn.Embedding(vocab_size, embed_dim)
        self.output_embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, target, context):
        """Return one logit per pair: the dot product of its two embeddings.

        ``target`` and ``context`` are index tensors of equal length.
        """
        u = self.input_embed(target)  # [batch_size, embed_dim]
        v = self.output_embed(context)  # [batch_size, embed_dim]
        return torch.sum(u * v, dim=1)  # [batch_size]

    def loss_fn(self, logits, labels):
        """Mean binary cross-entropy computed directly on logits."""
        # The functional form avoids constructing a new loss module per step.
        return F.binary_cross_entropy_with_logits(logits, labels)

    def accuracy(self, logits, labels):
        """Fraction of pairs whose predicted class matches the label.

        A logit > 0 (equivalently sigmoid > 0.5) predicts 1. Thresholding the
        logit skips the sigmoid and the float32 edge case where
        ``round(sigmoid(x))`` maps tiny positive logits to 0.
        """
        preds = (logits > 0).float()
        return (preds == labels).float().mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Return ``(loss, accuracy)`` for one ``(targets, contexts, labels)`` batch."""
        targets, contexts, labels = batch
        logits = self(targets, contexts)
        return self.loss_fn(logits, labels), self.accuracy(logits, labels)

    def training_step(self, train_batch, batch_idx):
        loss, acc = self._shared_step(train_batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_step(val_batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)

## 3. Train

In [17]:
# model init
# One embedding row per vocabulary entry, including <UNK> at index 0.
vocab_size = len(word_index)
model = Word2Vec(vocab_size=vocab_size, embed_dim=100)

In [18]:
# Data-parallel training on two GPUs for 10 epochs.
trainer = pl.Trainer(
    accelerator="dp",
    gpus=2,
    max_epochs=10,
    val_check_interval=0.5,  # validate twice per training epoch
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


In [19]:
# Train on train_loader and validate on test_loader at the Trainer's
# val_check_interval. Arguments are positional because the dataloader keyword
# names differ between Lightning versions.
trainer.fit(model, train_loader, test_loader)


  | Name         | Type      | Params
-------------------------------------------
0 | input_embed  | Embedding | 1.0 M 
1 | output_embed | Embedding | 1.0 M 
-------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



1

## 4. Check using gensim

In [20]:
import numpy as np

from gensim.models import KeyedVectors

### Export the trained vectors to a word2vec-format text file

In [22]:
# Copy the learned target-word (input) embeddings into a NumPy array.
# ``.numpy()`` only accepts CPU tensors; ``.cpu()`` makes this safe even if the
# weights are still on a GPU after multi-GPU training, and is a no-op on CPU.
embedding = model.input_embed.weight.detach().cpu().numpy()

In [23]:
# Write the embeddings in word2vec text format: a "<count> <dim>" header,
# then one "<word> <v1> ... <vd>" line per word. The <UNK> row (index 0) is
# skipped.
# Derive the dimension from the matrix rather than repeating the model's
# hard-coded embed_dim, and count exactly the rows written for the header.
embed_dim = embedding.shape[1]
export_rows = [(word, idx) for word, idx in word_index.items() if idx != 0]
with open("./vectors.txt", "w", encoding="utf8") as f:
    f.write(f"{len(export_rows)} {embed_dim}\n")
    for word, idx in export_rows:
        str_vec = " ".join(map(str, embedding[idx, :]))
        f.write(f"{word} {str_vec}\n")

In [24]:
# Reload the exported text-format vectors with gensim for similarity queries.
w2v = KeyedVectors.load_word2vec_format(fname="./vectors.txt", binary=False)

In [28]:
# Ten nearest neighbours (cosine similarity) of "배우" ("actor").
w2v.most_similar(positive=["배우"], topn=10)

[('영화', 0.8665787577629089),
 ('연기', 0.8557248115539551),
 ('스토리', 0.8543652296066284),
 ('최고', 0.8537499308586121),
 ('나', 0.8484581112861633),
 ('내용', 0.8468432426452637),
 ('생각', 0.8454784154891968),
 ('것', 0.8232203722000122),
 ('수', 0.8192157745361328),
 ('듯', 0.8190522193908691)]