In [88]:
# built in
from argparse import Namespace
from functools import partial

# torch
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader,default_collate
from torch.utils.data.backward_compatibility import worker_init_fn
## data
from torchdata import datapipes as dp
## test
from torchtext import vocab

# manipulation
import numpy as np


In [87]:
# Central configuration for the notebook, grouped by pipeline stage.
args = Namespace(**{
    # data: each split is read from <data_base_path><name>.csv, fields split on `delimiter`
    "data_base_path": "../data/cbow/",
    "datasets": ["train", "val", "test"],
    "delimiter": "#",

    # vocab: special tokens, and the number of context words on each side of the target
    "mask_tkn": "<MASK>",
    "unk_tkn": "<UKN>",
    "window_size": 2,

    # model: (no options yet)

    # training: `batches` is used as the batch size (samples per batch)
    "batches": 32,

    # running options
    "cuda": torch.cuda.is_available(),
    "device": "cuda" if torch.cuda.is_available() else "cpu",
})

# Datapipes

In [9]:
# Exploratory look at the raw training CSV; open_parsed_dict wraps the same steps for every split.
opener_dp = dp.iter.FileOpener([f"{args.data_base_path}train.csv"])
parser_dp = dp.iter.CSVParser(opener_dp, delimiter=args.delimiter)

In [56]:
def convert_to_tuples(row):
    """Split a parsed CSV row into ``(context_tokens, [target])``.

    Args:
        row: sequence whose first field is the whitespace-separated context
            and whose second field is the target word.

    Returns:
        Tuple of (list of context tokens, single-element list holding the target).
    """
    # str.split() with no separator collapses runs of whitespace and ignores
    # leading/trailing whitespace. An empty or oddly spaced context therefore
    # never yields "" tokens. Those would otherwise become a vocabulary entry
    # and occupy context slots that should be filled with the mask token.
    return (row[0].split(), [row[1]])

In [57]:
def open_parsed_dict(args):
    """Build one CSV datapipe per dataset split.

    Each pipe opens ``<args.data_base_path><split>.csv``, parses it using
    ``args.delimiter`` and maps every row through ``convert_to_tuples``.

    Args:
        args: config namespace providing ``datasets``, ``data_base_path``
            and ``delimiter``.

    Returns:
        Dict mapping each name in ``args.datasets`` to its datapipe.
    """
    return {
        split_name: (
            dp.iter.FileOpener([args.data_base_path + f"{split_name}.csv"])
            .parse_csv(delimiter=args.delimiter)
            .map(convert_to_tuples)
        )
        for split_name in args.datasets
    }

In [58]:
# One parsed datapipe per split, keyed by split name.
csv_parse_dict = open_parsed_dict(args=args)

In [59]:
# Peek at the first parsed training row: (context_tokens, [target]).
next(iter(csv_parse_dict["train"]))

(['start', 'of', 'project', 'gutenberg'], ['the'])

In [62]:
def join_context_target_fn(row):
    """Concatenate a row's context tokens and target tokens into one sequence.

    Used when building the vocabulary, so that target words are counted
    alongside context words.
    """
    context_tokens = row[0]
    target_tokens = row[1]
    return context_tokens + target_tokens

In [63]:
# Peek at the first training row after merging context and target into a single token list.
next(iter(csv_parse_dict["train"].map(join_context_target_fn)))

['start', 'of', 'project', 'gutenberg', 'the']

In [64]:
def build_vocab(train_pipe,args=args):
    """Build the CBOW vocabulary from the training split.

    Every context and target token of each row is fed to torchtext's
    ``build_vocab_from_iterator`` together with the specials ``args.unk_tkn``
    and ``args.mask_tkn``. The unknown token's index is then installed as the
    default, so out-of-vocabulary lookups map to it.

    Args:
        train_pipe: datapipe yielding ``(context_tokens, [target])`` rows.
        args: config namespace; defaults to the notebook-level ``args`` that
            existed when this cell was executed.

    Returns:
        The torchtext vocabulary with its default index set to the unknown token.
    """
    token_stream = train_pipe.map(join_context_target_fn)
    special_tokens = [args.unk_tkn, args.mask_tkn]
    built_vocab = vocab.build_vocab_from_iterator(token_stream, specials=special_tokens)
    fallback_index = built_vocab[args.unk_tkn]
    built_vocab.set_default_index(fallback_index)
    return built_vocab

In [65]:
# The vocabulary is built from the training split only.
cbow_vocab = build_vocab(train_pipe=csv_parse_dict["train"])

In [67]:
# Sanity check: map a sample context (the first training row) to vocabulary indices.
cbow_vocab.lookup_indices(['start', 'of', 'project', 'gutenberg'])

[14053, 5, 5328, 5006]

In [72]:
# Index of a single target word, extracted the same way create_dataset_dict does it.
# NOTE(review): the two specials passed to build_vocab presumably occupy indices 0-1,
# which would explain why the first regular word shows up at 2 — confirm.
cbow_vocab.lookup_indices(['the'])[-1]

2

# Build the dataset

In [69]:
def vectorize(context,vocab,args):
    """Encode a context window as a fixed-length vector of vocabulary indices.

    The first ``len(context)`` slots hold the tokens' indices, in order. The
    remaining slots are filled with the index of ``args.mask_tkn``, so every
    vector has length ``2 * args.window_size``.

    Args:
        context: list of context tokens (at most ``2 * args.window_size``).
        vocab: vocabulary supporting ``lookup_indices(tokens)`` and ``vocab[token]``.
        args: config namespace providing ``window_size`` and ``mask_tkn``.

    Returns:
        ``np.ndarray`` of dtype ``int64`` and shape ``(2 * args.window_size,)``.

    Raises:
        ValueError: if ``context`` has more tokens than the vector can hold.
    """
    length = args.window_size * 2
    indices = vocab.lookup_indices(context)
    if len(indices) > length:
        raise ValueError(
            f"context has {len(indices)} tokens but the window holds at most {length}"
        )
    # These are vocabulary indices, so keep an integer dtype. They are
    # presumably consumed by an embedding lookup (model not shown here), which
    # needs integer indices. float32 would also round indices above 2**24.
    vector = np.full(length, vocab[args.mask_tkn], dtype=np.int64)
    vector[:len(indices)] = indices
    return vector

In [74]:
def create_dataset_dict(vocab,args,row):
    """Turn one parsed row into a training sample.

    Args:
        vocab: vocabulary used for token -> index lookups.
        args: config namespace forwarded to ``vectorize``.
        row: ``(context_tokens, target_tokens)`` pair, as produced by ``convert_to_tuples``.

    Returns:
        ``{"x": <context index vector>, "y": <index of the last target token>}``.
    """
    context_tokens = row[0]
    target_tokens = row[1]
    return dict(
        x=vectorize(context_tokens, vocab=vocab, args=args),
        y=vocab.lookup_indices(target_tokens)[-1],
    )

In [82]:
def build_dataset_dict(csv_parse_dict,vocab,args=args):
    """Turn each parsed split into a datapipe of batched sample dicts.

    Every row is mapped through ``create_dataset_dict`` and grouped into lists
    of ``args.batches`` samples. Only the ``"train"`` split is shuffled.

    NOTE(review): ``generate_batches`` passes ``batch_size=args.batches`` to its
    DataLoader as well, so these already-batched lists get batched a second
    time (nested lists of lists). Confirm which of the two batching layers is
    intended.

    Args:
        csv_parse_dict: mapping of split name -> datapipe of parsed rows.
        vocab: vocabulary passed through to ``create_dataset_dict``.
        args: config namespace; defaults to the notebook-level ``args`` that
            existed when this cell was executed.

    Returns:
        Dict mapping each split name to its batched datapipe.
    """
    to_sample = partial(create_dataset_dict, vocab, args)

    def _pipeline(split_name, pipe):
        # Shuffle training rows only; val/test keep file order.
        source = pipe.shuffle() if split_name == "train" else pipe
        return source.map(to_sample).batch(args.batches)

    return {
        split_name: _pipeline(split_name, pipe)
        for split_name, pipe in csv_parse_dict.items()
    }

In [83]:
# Sample datapipes for every split, built with the training vocabulary.
dataset_dict = build_dataset_dict(csv_parse_dict=csv_parse_dict, vocab=cbow_vocab)

In [84]:
# Each split is a datapipe yielding lists of args.batches sample dicts (see build_dataset_dict).
dataset_dict

{'train': BatcherIterDataPipe,
 'val': BatcherIterDataPipe,
 'test': BatcherIterDataPipe}

In [89]:
def collate_fn(args,x):
    """Collate sample dicts into a single dict of tensors on ``args.device``.

    Args:
        args: config namespace; only ``args.device`` is read.
        x: list of sample dicts (``{"x": ..., "y": ...}``) assembled by a
            DataLoader. If the dataset already yields batches (lists of sample
            dicts, as ``.batch(...)`` in ``build_dataset_dict`` produces), ``x``
            is a list of such lists. It is then flattened first, so the result
            holds every sample from every group.

    Returns:
        Dict mapping each sample key to a stacked tensor moved to ``args.device``.
    """
    samples = list(x)
    # Collating a nested list directly makes default_collate return one dict
    # per position inside the groups. Merging those dicts key by key keeps only
    # the last one, silently dropping most samples. Flatten the groups instead.
    if samples and isinstance(samples[0], (list, tuple)):
        samples = [sample for group in samples for sample in group]
    collated = default_collate(samples)
    return {key: value.to(args.device) for key, value in collated.items()}

In [90]:
def generate_batches(dataset,args,shuffle):
    """Yield collated batches from ``dataset`` through a PyTorch DataLoader.

    The loader groups ``args.batches`` items per batch and drops a trailing
    incomplete batch. It collates each batch with ``collate_fn`` bound to
    ``args`` and forwards ``shuffle`` to the DataLoader. ``worker_init_fn`` only
    matters with worker processes; ``num_workers`` is left at the DataLoader default.

    NOTE(review): datapipes from ``build_dataset_dict`` already yield lists of
    ``args.batches`` samples, so this loader batches them a second time —
    confirm which batching layer is intended.

    Args:
        dataset: dataset or datapipe to iterate.
        args: config namespace providing ``batches`` (and ``device`` for ``collate_fn``).
        shuffle: forwarded to ``DataLoader(shuffle=...)``.

    Yields:
        Whatever ``collate_fn`` returns for each batch.
    """
    loader_options = dict(
        batch_size=args.batches,
        shuffle=shuffle,
        drop_last=True,
        collate_fn=partial(collate_fn, args),
        worker_init_fn=worker_init_fn,
    )
    yield from DataLoader(dataset=dataset, **loader_options)