# Setup

In [None]:
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.datasets import WikiText2

from utils.model_config import ModelConfig
from utils.project_config import ProjectConfig
from utils.dataloader import Word2VecDataLoader
from utils.model import ModelType
from utils.model import CBOWModel, SkipGramModel
from utils.trainer import Trainer
from utils.helper import (
    get_lr_scheduler,
    save_vocab,
)

In [None]:
# Global configs.
# Seed torch's global RNG before anything else so that repeated
# "Restart & Run All" executions are comparable. Anything drawing from that
# RNG later (presumably weight init and DataLoader shuffling) becomes
# repeatable for a fixed SEED.
SEED = 42
torch.manual_seed(SEED)

project_config = ProjectConfig()
project_config.use_cuda = True

project_config.model_type = ModelType.SKIPGRAM
# Stored as a class, not an instance; it is instantiated in the training
# setup cell via project_config.criterion().
project_config.criterion = nn.CrossEntropyLoss

project_config.learning_rate = 0.025
project_config.epochs = 8

project_config.data_dir = "data/"
project_config.model_dir = "models/"
project_config.train_batch_size = 16
project_config.val_batch_size = 16
project_config.shuffle = True
project_config.checkpoint_freq = 4

word2vec_config = ModelConfig()
word2vec_dataloader = Word2VecDataLoader(word2vec_config)

# Check current device: use CUDA only when it is both requested in the
# config and actually available; otherwise fall back to CPU.
cuda_requested_and_available = (
    project_config.use_cuda and torch.cuda.is_available()
)
device = torch.device("cuda" if cuda_requested_and_available else "cpu")
print(f"Current Device: {device}")

# Construct DataLoader

In [None]:
# Tokenizer: torchtext's "basic_english" scheme for English text.
tokenizer = get_tokenizer("basic_english", language="en")
# WikiText2() is unpacked into train / validation / test splits, and each
# split is passed through to_map_style_dataset before use.
wiki_train, wiki_val, wiki_test = (
    to_map_style_dataset(split) for split in WikiText2()
)

In [None]:
# Histogram of document length.
# NOTE(review): len(doc) measures each dataset item as-is; if the items are
# raw strings this counts characters rather than tokens -- confirm intent.
doc_len = [len(doc) for doc in wiki_train]
fig, ax = plt.subplots()
# Bin edges are 1, 101, ..., 901, so lengths outside [1, 901] are not drawn.
ax.hist(doc_len, bins=list(range(1, 1000, 100)))
ax.set_xlabel("Document length")
ax.set_ylabel("Number of documents")
ax.set_title("WikiText2 training split: document length")
# plt.show() instead of fig.show(): Figure.show() requires a GUI backend and
# emits a warning under non-GUI backends such as the notebook inline one.
plt.show()

In [None]:
# Construct the vocabulary from the training split using the tokenizer.
vocab = word2vec_dataloader.build_vocab(wiki_train, tokenizer)
vocab_size = len(vocab)
print(f"Vocab size: {vocab_size}")
# Show the first 10 entries of the index-to-token list.
first_tokens = vocab.get_itos()[:10]
print("First 10 tokens: ", first_tokens)
# Look up one token expected to exist and one expected to be missing.
# Per the original note, a missing token maps to the default index 0
# (this depends on how build_vocab configures the vocab -- confirm).
sample_ids = vocab(["next", "fishball"])
print('Token id of ["next", "fishball"]: ', sample_ids)

In [None]:
# Build the training and validation DataLoaders.
# Arguments that are identical for both splits are collected once; `vocab`
# is deliberately passed per call so the validation loader receives the
# vocab object returned by the training call.
common_loader_kwargs = {
    "model": project_config.model_type,
    "shuffle": project_config.shuffle,
}
train_dataloader, vocab = word2vec_dataloader.get_dataloader_and_vocab(
    data_iter=wiki_train,
    batch_size=project_config.train_batch_size,
    vocab=vocab,
    **common_loader_kwargs,
)
# The vocab returned for the validation split is discarded.
val_dataloader, _ = word2vec_dataloader.get_dataloader_and_vocab(
    data_iter=wiki_val,
    batch_size=project_config.val_batch_size,
    vocab=vocab,
    **common_loader_kwargs,
)
# Reports len() of each loader (presumably the number of batches -- confirm
# against the loader type returned by get_dataloader_and_vocab).
for split_label, loader in (
    ("Training", train_dataloader),
    ("Validation", val_dataloader),
):
    print(f"{split_label} data size: {len(loader)}")

# Training

In [None]:
# Training setup.
# Pick the model class matching the configured type; any type other than
# CBOW falls through to skip-gram.
if project_config.model_type == ModelType.CBOW:
    model_cls = CBOWModel
else:
    model_cls = SkipGramModel
model = model_cls(vocab_size=vocab_size, config=word2vec_config)

# project_config.criterion holds the loss class; instantiate it here.
criterion = project_config.criterion()
optimizer = optim.Adam(model.parameters(), lr=project_config.learning_rate)
lr_scheduler = get_lr_scheduler(optimizer, project_config.epochs, verbose=True)

# NOTE(review): train_steps / val_steps are read from project_config but are
# not assigned in the visible config cell; presumably ProjectConfig supplies
# defaults -- confirm.
trainer = Trainer(
    model=model,
    epochs=project_config.epochs,
    train_dataloader=train_dataloader,
    train_steps=project_config.train_steps,
    val_dataloader=val_dataloader,
    val_steps=project_config.val_steps,
    checkpoint_frequency=project_config.checkpoint_freq,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    device=device,
    model_dir=project_config.model_dir,
)

In [None]:
# Run the training loop configured on the Trainer above, then report
# completion once train() returns.
trainer.train()
print("Train finished.")

In [None]:
# Persist training artifacts: the model and loss via the trainer, then the
# vocab into the configured model directory.
artifact_dir = project_config.model_dir
trainer.save_model()
trainer.save_loss()
save_vocab(vocab, artifact_dir)
print(f"Model artifacts saved to folder: {artifact_dir}")

# Performance Visualization

In [None]:
# Plot the training and validation loss series recorded by the trainer.
train_loss = trainer.loss["train"]
val_loss = trainer.loss["val"]

fig, ax = plt.subplots()
ax.plot(train_loss, label="train_loss")
ax.plot(val_loss, label="val_loss")
ax.legend()
ax.set(
    xlabel="Epochs",
    ylabel="Loss",
    title="Training and Validation Loss over Epochs",
)
# plt.show() instead of fig.show(): Figure.show() requires a GUI backend and
# emits a warning under non-GUI backends such as the notebook inline one.
plt.show()