<a href="https://colab.research.google.com/github/Arindam-18/BTP/blob/main/Word2Vec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install portalocker

In [None]:
import os
import shutil
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText103
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm

In [None]:
# --- Context window / text filtering -------------------------------------
# Number of context tokens taken on EACH side of the target token in
# collate(); a text needs at least 2 * MINIMUM_LENGTH + 1 tokens to be used.
MINIMUM_LENGTH = 4
# Token sequences are truncated to this many tokens before windows are cut.
MAXIMUM_LENGTH = 256

# --- Vocabulary -----------------------------------------------------------
# Passed as `min_freq` to build_vocab_from_iterator() in build_vocab().
MINIMUM_FREQUENCY = 75

# --- Model ----------------------------------------------------------------
# Width of each embedding vector (and input width of the output layer).
EMBEDDING_DIMENSION = 150
# Passed as `max_norm` to nn.Embedding in Word2Vec.
EMBEDDING_MAX_NORM = 1

In [None]:
class Word2Vec(nn.Module):
    """Embedding-averaging network that scores every vocabulary entry.

    The input indices are embedded, the embeddings are averaged along
    dimension 1, and the average is projected to ``vocab_size`` scores.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output scores.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # Embedding table; looked-up rows are renormalised to at most
        # EMBEDDING_MAX_NORM by nn.Embedding's `max_norm` option.
        self.embeddings = nn.Embedding(
            vocab_size, EMBEDDING_DIMENSION, max_norm=EMBEDDING_MAX_NORM
        )
        # Projection from the averaged embedding to one score per word.
        self.linear = nn.Linear(EMBEDDING_DIMENSION, vocab_size)

    def forward(self, inp):
        """Return vocabulary scores for ``inp``.

        Assumes ``inp`` is an integer tensor shaped (batch, context) so that
        dimension 1 is the context axis -- TODO confirm against callers.
        """
        averaged_context = self.embeddings(inp).mean(dim=1)
        return self.linear(averaged_context)

In [None]:
class Trainer:
    """Train and validate a model, recording per-epoch losses and checkpoints.

    Args:
        model: Module to optimise; moved to CUDA when available, else CPU.
        epochs: Number of epochs to run; also the length of the linear
            learning-rate decay.
        train_data_loader: Iterable of batches indexable as ``batch[0]``
            (inputs) and ``batch[1]`` (labels).
        train_steps: Stop each training epoch after this many batches;
            ``None`` consumes the whole loader.
        val_data_loader: Same as ``train_data_loader``, used for validation.
        val_steps: Stop each validation epoch after this many batches;
            ``None`` consumes the whole loader.
        checkpoint_frequency: Save a checkpoint every this many epochs; a
            falsy value disables checkpointing.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.
        optimizer: Optimiser over the model's parameters.
        model_dir: Directory that receives checkpoints and the final model
            (it is not created here).
    """

    def __init__(
        self,
        model,
        epochs,
        train_data_loader,
        train_steps,
        val_data_loader,
        val_steps,
        checkpoint_frequency,
        loss_fn,
        optimizer,
        model_dir,
    ):
        self.model = model
        self.epochs = epochs
        self.train_data_loader = train_data_loader
        self.train_steps = train_steps
        self.val_data_loader = val_data_loader
        self.val_steps = val_steps
        self.checkpoint_frequency = checkpoint_frequency
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        # Linear decay: the multiplier goes from 1 at epoch 0 towards 0 at
        # epoch `epochs`. The scheduler's `verbose` flag is deprecated in
        # recent PyTorch releases, so train() reports the rate itself via
        # get_last_lr() instead.
        self.lr_scheduler = LambdaLR(
            optimizer, lr_lambda=lambda epoch: (epochs - epoch) / epochs
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = model_dir

        # Mean loss of every finished epoch, in order.
        self.loss = {"train": [], "val": []}
        self.model.to(self.device)

    def train(self):
        """Run all epochs: train, validate, report, decay the LR, checkpoint."""
        for epoch in range(self.epochs):
            print("Learning rate: {}".format(self.lr_scheduler.get_last_lr()))
            print("Training...")
            self.train_epoch()
            print("Validating...")
            self.val_epoch()
            print(
                "Epoch: {} of {}\nTrain Loss = {:.5f}\nValidation Loss = {:.5f}\n".format(
                    epoch + 1, self.epochs, self.loss["train"][-1], self.loss["val"][-1]
                )
            )
            self.lr_scheduler.step()
            if self.checkpoint_frequency:
                self.save_checkpoint(epoch)

    def train_epoch(self):
        """Run one optimisation pass and append its mean loss to ``loss["train"]``."""
        epoch_loss = self._run_epoch(
            tqdm(self.train_data_loader), self.train_steps, training=True
        )
        self.loss["train"].append(epoch_loss)

    def val_epoch(self):
        """Run one gradient-free pass and append its mean loss to ``loss["val"]``."""
        epoch_loss = self._run_epoch(
            tqdm(self.val_data_loader), self.val_steps, training=False
        )
        self.loss["val"].append(epoch_loss)

    def _run_epoch(self, batches, max_steps, training):
        """Iterate ``batches`` once and return the mean batch loss.

        Args:
            batches: Iterable of batches indexable as ``batch[0]`` (inputs)
                and ``batch[1]`` (labels).
            max_steps: Stop after this many batches; ``None`` never stops
                early.
            training: When true the model is put in train mode and the
                optimiser is stepped; otherwise the model is put in eval
                mode and gradients are disabled.

        Returns:
            The mean of the per-batch losses as a float, or NaN when no
            batch contributed a loss.
        """
        self.model.train(training)
        batch_losses = []
        with torch.set_grad_enabled(training):
            for step, batch_data in enumerate(batches, 1):
                inputs = batch_data[0].to(self.device)
                labels = batch_data[1].to(self.device)

                # The collate function yields empty tensors when every text
                # in a batch is too short; feeding those to the model fails,
                # so such batches are skipped (they still count as a step).
                if labels.numel() > 0:
                    if training:
                        self.optimizer.zero_grad()
                    outputs = self.model(inputs)
                    loss = self.loss_fn(outputs, labels)
                    if training:
                        loss.backward()
                        self.optimizer.step()
                    batch_losses.append(loss.item())

                if step == max_steps:
                    break

        if not batch_losses:
            return float("nan")
        return float(np.mean(batch_losses))

    def save_checkpoint(self, epoch):
        """Save the model when the 1-based epoch is a multiple of the frequency.

        Args:
            epoch: Zero-based index of the epoch that has just finished.
        """
        epoch_num = epoch + 1
        if epoch_num % self.checkpoint_frequency == 0:
            model_path = "checkpoint_{}.pt".format(str(epoch_num).zfill(3))
            model_path = os.path.join(self.model_dir, model_path)
            # NOTE: this pickles the whole module object, not a state_dict;
            # kept as-is so existing loading code continues to work.
            torch.save(self.model, model_path)

    def save_model(self):
        """Save the whole model object to ``model.pt`` inside ``model_dir``."""
        model_path = os.path.join(self.model_dir, "model.pt")
        torch.save(self.model, model_path)

In [None]:
def build_vocab(data_iter, tokenizer):
    """Build a vocabulary from the tokenised texts of ``data_iter``.

    Args:
        data_iter: Iterable of raw texts.
        tokenizer: Callable turning one text into a sequence of tokens.

    Returns:
        The vocabulary, with ``"<unk>"`` as a special token whose index is
        also the default index used for unknown tokens.
    """
    tokenised_texts = (tokenizer(text) for text in data_iter)
    vocab = build_vocab_from_iterator(
        tokenised_texts,
        min_freq=MINIMUM_FREQUENCY,
        specials=["<unk>"],
    )
    unknown_index = vocab["<unk>"]
    vocab.set_default_index(unknown_index)
    return vocab


def get_data_iterator(data_type):
    """Return the WikiText103 split ``data_type`` as a map-style dataset.

    Args:
        data_type: Split name forwarded to ``WikiText103(split=...)``.
    """
    return to_map_style_dataset(WikiText103(split=data_type))


def collate(batch, text_pipeline):
    """Turn a batch of texts into (context, target) training tensors.

    Every text is converted to token ids, skipped when shorter than one full
    window (``2 * MINIMUM_LENGTH + 1`` ids), truncated to ``MAXIMUM_LENGTH``
    ids, and then scanned with a sliding window. For each window the middle
    id is the target and the remaining ``2 * MINIMUM_LENGTH`` ids are the
    context.

    Args:
        batch: Iterable of raw texts.
        text_pipeline: Callable mapping one text to a list of token ids.

    Returns:
        A pair ``(contexts, targets)`` of ``torch.long`` tensors. When no
        text in the batch is long enough, both tensors are empty.
    """
    window_size = 2 * MINIMUM_LENGTH + 1
    contexts, targets = [], []

    for text in batch:
        ids = text_pipeline(text)

        if len(ids) < window_size:
            continue
        if MAXIMUM_LENGTH:
            ids = ids[:MAXIMUM_LENGTH]

        for start in range(len(ids) - window_size + 1):
            centre = start + MINIMUM_LENGTH
            left_context = ids[start:centre]
            right_context = ids[centre + 1 : start + window_size]
            contexts.append(left_context + right_context)
            targets.append(ids[centre])

    return (
        torch.tensor(contexts, dtype=torch.long),
        torch.tensor(targets, dtype=torch.long),
    )


def get_data_loader_and_vocab(data_type, batch_size, shuffle, vocab=None):
    """Create a DataLoader for a dataset split, building a vocab if needed.

    Args:
        data_type: Split name passed to ``get_data_iterator``.
        batch_size: Number of texts per batch handed to ``collate``.
        shuffle: Whether the DataLoader shuffles the texts.
        vocab: Existing vocabulary to reuse; when ``None`` a new one is
            built from this split.

    Returns:
        A pair ``(data_loader, vocab)``.
    """
    dataset = get_data_iterator(data_type)
    tokenizer = get_tokenizer("basic_english", language="en")
    if vocab is None:
        vocab = build_vocab(dataset, tokenizer)

    def text_to_ids(text):
        # Tokenise, then map the tokens to their vocabulary indices.
        return vocab(tokenizer(text))

    batch_builder = partial(collate, text_pipeline=text_to_ids)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=batch_builder,
    )
    return loader, vocab


def save_vocab(vocab, model_dir):
    """Serialise ``vocab`` with ``torch.save`` to ``<model_dir>/vocab.pt``."""
    torch.save(vocab, os.path.join(model_dir, "vocab.pt"))


def save_config(config: dict, model_dir: str):
    """Write ``config`` as YAML to ``<model_dir>/config.yaml``.

    Args:
        config: Mapping of settings to dump.
        model_dir: Directory in which ``config.yaml`` is created or
            overwritten.
    """
    target = os.path.join(model_dir, "config.yaml")
    with open(target, "w") as handle:
        yaml.dump(config, handle)


def train(config):
    """Train a Word2Vec model end to end and store every artifact.

    The output directory is wiped and recreated, data loaders and the
    vocabulary are built, the model is trained, and finally the model, the
    vocabulary and the config are written to the output directory.

    Args:
        config: Mapping with the keys ``model_dir``, ``train_batch_size``,
            ``val_batch_size``, ``shuffle``, ``learning_rate``, ``epochs``,
            ``train_steps``, ``val_steps`` and ``checkpoint_frequency``.
    """
    model_dir = config["model_dir"]

    # Start from an empty output directory: anything already there is removed.
    if os.path.isdir(model_dir):
        shutil.rmtree(model_dir)
    os.makedirs(model_dir)

    # The vocabulary comes from the training split and is reused as-is for
    # the validation split.
    train_data_loader, vocab = get_data_loader_and_vocab(
        data_type="train",
        batch_size=config["train_batch_size"],
        shuffle=config["shuffle"],
    )
    val_data_loader, _ = get_data_loader_and_vocab(
        data_type="valid",
        batch_size=config["val_batch_size"],
        shuffle=config["shuffle"],
        vocab=vocab,
    )

    vocab_size = len(vocab.get_stoi())
    print(f"Vocabulary Size: {vocab_size}\n")

    model = Word2Vec(vocab_size)
    trainer = Trainer(
        model=model,
        epochs=config["epochs"],
        train_data_loader=train_data_loader,
        train_steps=config["train_steps"],
        val_data_loader=val_data_loader,
        val_steps=config["val_steps"],
        checkpoint_frequency=config["checkpoint_frequency"],
        loss_fn=nn.CrossEntropyLoss(),
        optimizer=optim.Adam(model.parameters(), lr=config["learning_rate"]),
        model_dir=model_dir,
    )

    trainer.train()
    print("Training Finished.")

    # Persist everything needed to reuse the model later.
    trainer.save_model()
    save_vocab(vocab, model_dir)
    save_config(config, model_dir)
    print("Model artifacts saved to folder:", model_dir)

In [None]:
# Settings for one training run; every value is read by key inside train().
config = {
    # Data loading
    "shuffle": True,
    "train_batch_size": 64,
    "val_batch_size": 64,
    # Optimisation
    "epochs": 10,
    "learning_rate": 0.05,
    # None never equals a batch index, so every epoch consumes its whole loader.
    "train_steps": None,
    "val_steps": None,
    # Output: a checkpoint is written every 2 epochs.
    "checkpoint_frequency": 2,
    # train() deletes and recreates this directory.
    # NOTE(review): the path looks like a mounted Google Drive folder, but no
    # mount call is visible in this file -- confirm Drive is mounted first.
    "model_dir": "/content/drive/MyDrive/Word2Vec",
}

# Kick off the run.
train(config)