In [3]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
from whiteGPT import word2vec
from whiteGPT import CBOW
from whiteGPT import Vocab

In [4]:
corpus_list = [
    # Short statements about cats, dogs and other animals.
    "this animal is cat.",
    "the quick brown fox jumps over the lazy dog.",
    "dog and cat are animal.",
    "I love dogs and cats.",
    "the dog sat on the rug.",
    "cats are independent animals. ",
    "foxes are wild animals.",
    "The cat prowled through the moonlit garden.",
    "Dogs eagerly awaited their owner's return at the doorstep.",
    "A cat's purr filled the room with comfort.",
    "The dog wagged its tail in excitement.",
    "Cats gracefully leaped from rooftop to rooftop.",

    # Descriptive scene sentences.
    "Walking down the street, I spotted a stray dog searching for scraps.",
    "The cat stretched lazily in the warmth of the sunbeam.",
    "Dogs barked joyfully in the park.",
    "A sleek black cat slinked along the fence.",
    "The old dog snoozed contentedly by the fireplace.",
    "Cats darted through the alleyways, chasing shadows.",
    "A fluffy white cat napped peacefully on the windowsill.",

    # General facts about pets.
    "Dogs are furry friends who love to play fetch and cuddle with you.",
    "Cats are soft and independent pets that enjoy lounging in sunny spots.",
    "Dogs wag their tails when they're happy and bark to say hello.",
    "Cats purr when they're content and love to curl up in your lap.",

    # Longer anecdotes.
    # NOTE(review): several entries keep trailing spaces; presumably harmless
    # because word2vec.modify() tokenizes them — TODO confirm.
    "Fido loves going for car rides, he always sticks his head out the window with a goofy grin.",
    "the old dog hobbled over to greet us, his tail thumping gently against the floor.",
    "the cat perched regally on the windowsill, surveying its outdoor kingdom. ",
    "the cat stalked a dust bunny across the floor, pouncing with laser focus. ",
    "the neighbor's cat, notorious for its thievery, snuck into our yard and made off with a shiny red ball of yarn. ",
    "curled up with a good book, I felt a soft nudge – my cat, wanting some attention, was rubbing against my leg. ",
    "the therapy dog, with its gentle demeanor, brought a wave of calm to the anxious patients in the waiting room.",
    "despite their different personalities, the dog and cat often napped curled up together. ",
    # Typo fix: "bow" -> "bowl" (a cat-food bowl).
    "we need to buy more cat food, Whiskers seems to be inhaling everything in the bowl. ",
    "despite being a scaredy cat, Luna the ginger bravely explored every corner of the new house. ",
    "Max the dog spent all afternoon digging a hole in the backyard, much to the gardener's dismay. ",
    "the smell of freshly baked cookies lured the cat out from its hiding spot. ",
    "during thunderstorms, Milo the dog would huddle under the bed, trembling uncontrollably. ",
    "we adopted a pair of playful kittens, and now our living room is a whirlwind of fur and feathery toys. ",
    "every morning, the rooster crows and the dog barks, creating a chaotic symphony to wake up the household. "
]

# Evaluation prompts. Many swap "dog"/"cat" relative to the training sentences.
# Judging from the training log (`context : target : prediction`), each prompt is
# split into context words and a target word to predict — TODO confirm against
# word2vec.TextDataset.test_corpus().
test_corpus_list = [
    "this animal is dog ",
    "the old cat hobbled ",
    "the dog perched regally ",
    "quick brown cat jumps ",
    "cat and dog often ",
    "the dog out from ",
    "Fido the cat would ",
    "Luna the cat spent ",
    "neighbor's dog ",
    "cat stalked a dust ",
    "buy more cat food",
    "cat and the dog barks",
]

### CBOWモデル

In [5]:
# Hyperparameters
embedding_dim = 16      # size of each word-embedding vector passed to word2vec.CBOW
num_epoch = 300         # number of passes over the training DataLoader
learning_rate = 0.001   # learning rate for the Adam optimizer
window_size = 3         # window size passed to word2vec.modify / word2vec.TextDataset

# Seed both RNGs so weight initialisation and batch shuffling are reproducible
# under Restart & Run All (whichever of `random` / `torch` they draw from).
# `random.seed` is called last so the cell does not echo torch's Generator repr.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

In [6]:
## Instantiation and preprocessing
# Builds the corpus, vocabulary, datasets, loss, model and optimizer used below.
batch_size = 4  # mini-batch size of the training DataLoader

train_corpus = word2vec.modify(corpus_list, window_size)
vocab = Vocab(train_corpus)
vocab_size = vocab.vocab_size

train_dataset = word2vec.TextDataset(vocab, train_corpus, window_size)
train_dataloader = word2vec.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# The test dataset starts from an empty corpus but shares the training vocab;
# test_corpus() then receives the evaluation prompts.
test_dataset = word2vec.TextDataset(vocab, '', window_size)
test_dataset.test_corpus(test_corpus_list)

criterion = nn.CrossEntropyLoss()
model = word2vec.CBOW(vocab_size, embedding_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
# Vocabulary size (the first argument given to word2vec.CBOW).
# Left as the cell's last expression so Jupyter renders it; no print() needed.
vocab.vocab_size

267


In [9]:
# Inspect one mini-batch: a dict whose 'source' holds the context word indices
# and whose 'target' holds the index of the word to predict.
train_iter = iter(train_dataloader)
first_batch = next(train_iter)
first_batch

{'source': tensor([[  0,   0,   0],
         [  0,   0,   0],
         [203,  18, 225],
         [ 22,   8,   0]]),
 'target': tensor([ 0,  0, 82,  0])}

In [10]:
## Training loop
log_interval = 10  # report loss and a sample prediction every `log_interval` epochs

for epoch in range(num_epoch):
    running_loss = 0.0
    for batch in train_dataloader:
        # Inputs (context word indices) and labels (target word index).
        context_indices, target_index = batch['source'], batch['target']
        # Forward pass; presumably logits over the vocabulary (fed to CrossEntropyLoss).
        output = model(context_indices)
        loss = criterion(output, target_index)

        # Reset gradients, backpropagate, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the batch loss for the epoch average.
        running_loss += loss.item()

    # Mean batch loss for this epoch; max(..., 1) avoids ZeroDivisionError
    # if the DataLoader yields no batches.
    total_loss = running_loss / max(len(train_dataloader), 1)

    # Log periodically AND after the last epoch. With the original
    # `epoch % 10 == 0` check, the final state (epoch num_epoch - 1) was never reported.
    if epoch % log_interval == 0 or epoch == num_epoch - 1:
        print(f'Epoch {epoch}, Loss: {total_loss:.4f}')
        # Evaluate on the test prompts; judging from the log this prints
        # `context : target : prediction`.
        test_dataset.test(model)

Epoch 0, Loss: 5.0240
the dog perched : regally : eagerly
Epoch 10, Loss: 2.7619
neighbor ' s : dog : the
Epoch 20, Loss: 2.2051
the dog perched : regally : .
Epoch 30, Loss: 1.7555
luna the cat : spent : ,
Epoch 40, Loss: 1.3978
the old cat : hobbled : snoozed
Epoch 50, Loss: 1.1200
and the dog : barks : ,
Epoch 60, Loss: 0.9119
cat and dog : often : often
Epoch 70, Loss: 0.7527
luna the cat : spent : ,
Epoch 80, Loss: 0.6345
the dog perched : regally : regally
Epoch 90, Loss: 0.5477
quick brown cat : jumps : jumps
Epoch 100, Loss: 0.4828
the dog perched : regally : regally
Epoch 110, Loss: 0.4341
and the dog : barks : cat
Epoch 120, Loss: 0.4009
the old cat : hobbled : snoozed
Epoch 130, Loss: 0.3764
fido the cat : would : for
Epoch 140, Loss: 0.3580
the dog out : from : would
Epoch 150, Loss: 0.3461
quick brown cat : jumps : jumps
Epoch 160, Loss: 0.3359
neighbor ' s : dog : cat
Epoch 170, Loss: 0.3311
the dog perched : regally : and
Epoch 180, Loss: 0.3247
fido the cat : would : fo

### 予測

In [21]:
# Run the test-set next-word prediction, then display the learned embedding table.
test_dataset.test(model)
learned_embeddings = model.embeddings.weight
learned_embeddings

cat stalked a : dust : dust


Parameter containing:
tensor([[ 0.4541,  0.6885, -0.5250,  ...,  0.0195, -0.1688, -1.5878],
        [ 0.6066,  0.2463,  0.6090,  ...,  1.2772, -0.5622,  0.6622],
        [-0.5138,  0.1587, -0.9021,  ...,  0.2716, -0.5239,  0.7931],
        ...,
        [ 3.5551, -2.0100, -1.3904,  ..., -2.0147, -3.9390, -2.1982],
        [-1.1596,  1.7805, -4.9148,  ..., -3.6662, -3.8871,  1.0931],
        [-1.6684, -0.4187,  0.3246,  ...,  4.3551, -3.6439, -1.3612]],
       requires_grad=True)