In [1]:
# built in
from argparse import Namespace
from functools import partial
from pathlib import Path

# torch
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader,default_collate
from torch.utils.data.backward_compatibility import worker_init_fn
## others
from torchdata import datapipes as dp
from torchtext import vocab
from torchmetrics import Accuracy

# manipulation
import numpy as np

# others
from tqdm import tqdm

In [2]:
args = Namespace(
    # data
    data_base_path = "../data/cbow/",
    datasets = ["train","val","test"],
    delimiter = "#",
    
    # vocab
    mask_tkn = "<MASK>",
    unk_tkn = "<UKN>",
    window_size = 2,
    
    # model
    embedding_dim = 50,
    model_base_path = "../models/cbow/",
    model_filename = "model.pth",
    
    
    # training
    batches = 32,
    learning_rate = 0.001,
    num_epochs = 100,
    early_stopping_criteria = 5,
    
    # running options
    cuda = torch.cuda.is_available(),
    device = "cuda" if torch.cuda.is_available() else "cpu",
    seed = 1432
)

for k,v in args._get_kwargs():
    if "base" in k:
        Path(v).mkdir(parents=True,exist_ok=True)

# Datapipes

## open and parse 

In [3]:
def convert_to_tuples(row):
    return (row[0].split(" "),[row[1]])

In [4]:
def open_parsed_dict(args):
    csv_pipe_dict = {}
    for fname in args.datasets:
        csv_pipe_dict[fname] = dp.iter\
            .FileOpener([args.data_base_path+f"{fname}.csv"])\
            .parse_csv(delimiter=args.delimiter)\
            .map(convert_to_tuples)
    return csv_pipe_dict            

In [5]:
csv_parse_dict = open_parsed_dict(args)

### testing

In [69]:
next(iter(csv_parse_dict["train"]))

(['start', 'of', 'project', 'gutenberg'], ['the'])

## build vocab

In [6]:
def join_context_target_fn(row):
    return row[0]+row[1]


In [7]:
def build_vocab(train_pipe,args=args):
    combined_dp = train_pipe.map(join_context_target_fn)
    cbow_vocab = vocab.build_vocab_from_iterator(combined_dp,
                                                 specials=[args.unk_tkn,args.mask_tkn])
    cbow_vocab.set_default_index(cbow_vocab[args.unk_tkn])
    return cbow_vocab

In [8]:
cbow_vocab = build_vocab(csv_parse_dict["train"])

### testing

In [23]:
cbow_vocab.lookup_indices(['start', 'of', 'project', 'gutenberg'])

[6938, 5, 5181, 4859]

In [24]:
cbow_vocab.lookup_indices(['the'])[-1]

3

# build dataset

In [9]:
def vectorize(context,vocab,args):
    indices = vocab.lookup_indices(context)
    vector = np.zeros(args.window_size*2,dtype=np.float32)
    vector[:len(indices)] = indices
    vector[len(indices):] = vocab[args.mask_tkn]
    return vector

In [10]:
def create_dataset_dict(vocab,args,row):
    context_vector = vectorize(row[0],vocab=vocab,args=args)
    target_index = vocab.lookup_indices(row[1])[-1]
    return {"x":context_vector,
            "y":target_index}

In [11]:
def build_dataset_dict(csv_parse_dict,vocab,args=args):
    dataset_dict = {}
    fn = partial(create_dataset_dict,vocab,args)
    for dataset_name,csv_parse_pipe in csv_parse_dict.items():
        if dataset_name == "train":
            csv_parse_pipe = csv_parse_pipe.shuffle()
        
        dataset_dict[dataset_name] = csv_parse_pipe.map(fn).batch(args.batches)
    
    return dataset_dict

In [12]:
dataset_dict = build_dataset_dict(csv_parse_dict,cbow_vocab)
dataset_dict

{'train': BatcherIterDataPipe,
 'val': BatcherIterDataPipe,
 'test': BatcherIterDataPipe}

## generate batches

In [13]:
def collate_fn(args,x):
    return {k:v.to(args.device).long()
            for x_ in default_collate(x)
            for k,v in x_.items()}

In [14]:
def generate_batches(dataset,args,shuffle):
    dataloader = DataLoader(dataset=dataset,batch_size=args.batches,
                            shuffle=shuffle,drop_last=True,
                            collate_fn=partial(collate_fn,args),
                            worker_init_fn=worker_init_fn)
    for batch in dataloader:
        yield batch

## Model

In [15]:
class CBOWClassifier(nn.Module):
    def __init__(self,vocabulary_size,embedding_size,padding_idx=0) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocabulary_size,
                                      embedding_dim=embedding_size,
                                      padding_idx=padding_idx)
        self.fc = nn.Linear(in_features=embedding_size,
                            out_features=vocabulary_size)
        
    
    def forward(self,input,apply_softmax=False):
        embed = self.embedding(input)
        embed_sum = F.dropout(embed.sum(dim=1),0.3)
        out = self.fc(embed_sum)
        if apply_softmax:
            out = torch.softmax(out,dim=1)
        return out

### test

In [48]:
classifier = CBOWClassifier(vocabulary_size=len(cbow_vocab),
                            embedding_size=args.embedding_dim)

In [49]:
sample = next(iter(generate_batches(dataset_dict["train"],
                                    args=args,shuffle=True)))

In [50]:
sample

{'x': tensor([[6175,    5,  492,    5]]), 'y': tensor([3])}

In [52]:
classifier(sample["x"]).shape

torch.Size([1, 6956])

# Training 

## Helper Function

In [16]:
def make_train_state(args=args):
    return {"stop_early":False,
            "early_stopping_step":0,
            "early_stopping_val":1e8,
            "learning_rate":args.learning_rate,
            "epoch_index":0,
            "model_filename":args.model_base_path+args.model_filename,
            "train_loss":[],
            "train_acc":[],
            "val_loss":[],
            "val_acc":[],
            "test_loss":-1,
            "test_acc":-1}
    
def update_train_state(train_state,model,args=args):
    if train_state["epoch_index"] == 0:
        torch.save(model.state_dict(),train_state["model_filename"])
        train_state["stop_early"] = False
        
    elif train_state["epoch_index"] >= 1:
        loss_tm1,loss_t = train_state["val_loss"][-2:]
        if loss_t >=  train_state["early_stopping_val"]:
            train_state["early_stopping_step"] += 1
        else:
            torch.save(model.state_dict(),train_state["model_filename"])
            train_state["early_stopping_step"] = 0 
            
        train_state["stop_early"] = train_state["early_stopping_step"] >= args.early_stopping_criteria
    return train_state

In [17]:
def set_seed_everywhere(seed,cuda):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)
        
set_seed_everywhere(args.seed,args.cuda)

## initialize

In [18]:
classifier = CBOWClassifier(vocabulary_size=len(cbow_vocab),
                            embedding_size=args.embedding_dim).to(args.device)
loss_fn = nn.CrossEntropyLoss()
acc_fn = Accuracy(task="multiclass",num_classes=len(cbow_vocab))
optimizer = optim.Adam(params=classifier.parameters(),
                       lr=args.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 mode="min",factor=0.5,
                                                 patience=1)
train_state = make_train_state(args)

In [19]:
for epoch_index in tqdm(range(args.num_epochs)):
    train_state["epoch_index"] = epoch_index
    
    # init the running variable
    running_loss = 0.0
    running_acc = 0.0
    
    # get the train dataloader
    batch_generator = generate_batches(dataset=dataset_dict["train"],
                                       args=args,shuffle=True)
    
    # put the model in training mode
    classifier.train()
    
    for batch_idx,batch_dict in enumerate(batch_generator):
        optimizer.zero_grad()
        logits = classifier(batch_dict["x"])
        # compute the loss per batch
        loss = loss_fn(logits,batch_dict["y"])
        loss_t = loss.item()
        running_loss += (loss_t - running_loss) / (batch_idx+1)
        
        # compute the acc per batch
        acc = acc_fn(logits,batch_dict["y"])
        acc_t = acc.item()
        running_acc += (acc_t - running_acc) / (batch_idx+1)
        
        loss.backward()
        
        optimizer.step()
        
    train_state["train_loss"].append(running_loss)
    train_state["train_acc"].append(running_acc)
    
    # iterate over the val dataset
    
    # init the running variable
    running_loss = 0.0
    running_acc = 0.0
    
    # get the val dataloader
    batch_generator  = generate_batches(dataset=dataset_dict["val"],
                                        args=args,shuffle=False)
    
    # put the model in eval mode
    classifier.eval()
    
    for batch_idx,batch_dict in enumerate(batch_generator):
        with torch.inference_mode():
            logits = classifier(batch_dict["x"])
            
            # compute the loss
            loss = loss_fn(logits,batch_dict["y"])
            loss_t = loss.item()
            running_loss += (loss_t-running_loss) /(batch_idx+1)
            
            # compute the acc
            acc = acc_fn(logits,batch_dict["y"])
            acc_t = acc.item()
            running_acc += (acc_t-running_acc) /(batch_idx+1)
            
    train_state["val_loss"].append(running_loss)
    train_state["val_acc"].append(running_acc)
    
    
    train_state = update_train_state(train_state=train_state,
                                     model=classifier,
                                     args=args)
    
    scheduler.step(train_state["val_loss"][-1])
    if train_state["stop_early"]:
        break          

100%|██████████| 100/100 [10:50<00:00,  6.50s/it]


# trained embedding

In [29]:
def get_closest(target_word,vocab,embedding,args,n=5):
    word_embedding = embedding[vocab.lookup_indices([target_word.lower()])[-1]]
    distance = []
    for word,index in vocab.get_stoi().items():
        if word in [args.mask_tkn,args.unk_tkn,target_word.lower()]:
            continue
        distance.append((word,torch.dist(word_embedding,embedding[index])))
    
    results = sorted(distance,key=lambda x:x[1])[1:n+2]
    return results

In [31]:
embeddings = classifier.embedding.weight.data
get_closest("science",cbow_vocab,embeddings,args)

[('simpler', tensor(7.6247)),
 ('tale', tensor(7.6839)),
 ('wished', tensor(7.7530)),
 ('malignity', tensor(7.7582)),
 ('timid', tensor(7.8019)),
 ('decides', tensor(7.8025))]