https://towardsdatascience.com/use-torchtext-to-load-nlp-datasets-part-i-5da6f1c89d84?fbclid=IwAR1zjXlM5w93z47QalvvWeX7OQkIRGL8KK8dAkHepITnk81XFJt_g_FKdVE

The code is heavily inspired from the above blogpost :-) 

In [15]:
import re
import logging

import numpy as np
import pandas as pd
import spacy
import torch
import csv
from torchtext import data
import urllib.request

In [16]:
#The User data 
User_data = pd.read_csv('data/user-info.csv', usecols = ['user.id','doc.id','rating'])
User_data = User_data.rename(columns={'user.id': 'user_id','doc.id': 'doc_id'})


In [17]:
# The article dataset
article_data = pd.read_csv('data/raw-data.csv', usecols = ['doc.id','title','citeulike.id', 'raw.abstract'],encoding = "ISO-8859-1")
article_data = article_data.rename(columns={'raw.abstract': 'abstract','doc.id': 'doc_id','citeulike.id': 'citeulike_id'})


In [18]:
CiteULike_data=pd.merge(User_data,article_data,on="doc_id")

In [19]:
le = max(len(x) for x in CiteULike_data.abstract)
print(le)

122938


In [20]:
NLP = spacy.load('en')
MAX_CHARS = 1229381
VAL_RATIO = 0.2
LOGGER = logging.getLogger("CiteULike_data")

In [21]:
def tokenizer(abs_text):
    abs_text = re.sub(
        r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", 
        str(abs_text))
    abs_text = re.sub(r"[ ]+", " ", str(abs_text))
    abs_text = re.sub(r"\!+", "!", str(abs_text))
    abs_text = re.sub(r"\,+", ",", str(abs_text))
    abs_text = re.sub(r"\?+", "?", str(abs_text))
    if (len(abs_text) > MAX_CHARS):
        abs_text = abs_text[:MAX_CHARS]
    return [
        x.text for x in NLP.tokenizer(abs_text) if x.text != " "]

In [22]:
VAL_RATIO = 0.2

def prepare_csv(seed=999):
    df_train = CiteULike_data
    df_train["abstract"] = \
        df_train.abstract.str.replace("\n", " ")
    idx = np.arange(df_train.shape[0])
    np.random.seed(seed)
    np.random.shuffle(idx)
    val_size = int(len(idx) * VAL_RATIO)
    df_train.iloc[idx[val_size:], :].to_csv(
        "cache/dataset_train.csv", index=False)
    df_train.iloc[idx[:val_size], :].to_csv(
        "cache/dataset_val.csv", index=False)

In [23]:
def get_dataset(fix_length=100, lower=False, vectors=None):
    if vectors is not None:
        # pretrain vectors only supports all lower cases
        lower = True
    LOGGER.debug("Preparing CSV files...")
    prepare_csv()
    abs_text = data.Field(
        sequential=True,
        fix_length=fix_length,
        tokenize=tokenizer,
        pad_first=True,
        tensor_type=torch.LongTensor,
        lower=lower
    )
    LOGGER.debug("Reading train csv file...")
    train, val = data.TabularDataset.splits(
        path='cache/', format='csv', skip_header=True,
        train='dataset_train.csv', validation='dataset_val.csv',
        fields=[
            ('abstract', abs_text),
            ('doc_id', data.Field(
                use_vocab=False, sequential=False, tensor_type=torch.ByteTensor)),
            ('rating', data.Field(
                use_vocab=False, sequential=False, tensor_type=torch.ByteTensor)),
            ('title', data.Field(
                use_vocab=False, sequential=False, tensor_type=torch.ByteTensor)),
            ('citeulike_id', data.Field(
                use_vocab=False, sequential=False, tensor_type=torch.ByteTensor)),
            ('user_id', data.Field(
                use_vocab=False, sequential=False, tensor_type=torch.ByteTensor)),
        ])
    LOGGER.debug("Building vocabulary...")
    abs_text.build_vocab(
        train, val,
        max_size=1229381,
        min_freq=20,
        vectors=vectors
    )
    LOGGER.debug("Done preparing the datasets")
    return train, val


In [32]:
def get_iterator(dataset, batch_size, train=True, shuffle=True, repeat=False, sort=None,sort_key=None):
    dataset_iter = data.Iterator(
        dataset, batch_size=batch_size, device=0,
        train=train, shuffle=shuffle, repeat=repeat,
        sort=sort,sort_key=sort_key
    )
    return dataset_iter

In [34]:
x=get_iterator(
            CiteULike_data, 100,sort=None ,train=True,
            shuffle=True, repeat=False,sort_key=True
        )

In [35]:
len(x)

2050

In [36]:
len(CiteULike_data)

204986