In [180]:
import os
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
from torch.utils.data import random_split

# Utils
from utils import _add_docstring_header, _create_dataset_directory, _wrap_split_argument
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [181]:
# IMDB data pipeline, adapted from https://github.com/pytorch/data/blob/main/examples/text/imdb.py

# Name handed to the dataset-directory decorator on IMDB().
DATASET_NAME = "IMDB"

# Remote location of the compressed archive.
URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

# Archive file name (same as the basename of URL; not referenced by the
# visible cells).
_PATH = "aclImdb_v1.tar.gz"

# MD5 digest the on-disk cache checks the downloaded archive against.
MD5 = "7c2ac02c03563afcf9b574c7e56c153a"

# Per-split line counts, forwarded to the docstring-header decorator on IMDB().
NUM_LINES = dict(train=25_000, test=25_000)

def _path_fn(root, path):
    """Map a remote or local ``path`` to a file of the same name under ``root``.

    Args:
        root: Directory the resulting path should live in.
        path: Any path or URL; only its final component is kept.

    Returns:
        ``root`` joined with the last component of ``path``.
    """
    # os.path.split returns (head, tail); the tail is the bare file name.
    file_name = os.path.split(path)[1]
    return os.path.join(root, file_name)


def _filter_fn(split, t):
    """Tell whether an extracted archive entry is a labelled sample of ``split``.

    The entry's path is expected to end in ``<split>/<label>/<file>``; the entry
    is kept only when the split folder equals ``split`` and the label folder is
    ``"pos"`` or ``"neg"``.

    Args:
        split: Split folder name to keep (e.g. ``"train"`` or ``"test"``).
        t: Tuple whose first element is the entry's path; other elements are
            ignored here.

    Returns:
        ``True`` if the entry should be kept, ``False`` otherwise. Paths with
        fewer than three components can never match and yield ``False``
        instead of raising ``IndexError``.
    """
    # Parse the path once instead of once per comparison.
    parts = Path(t[0]).parts
    # Too shallow to contain <split>/<label>/<file>: reject rather than crash.
    if len(parts) < 3:
        return False
    return parts[-3] == split and parts[-2] in ("pos", "neg")


def _file_to_sample(t):
    """Turn an extracted ``(path, stream)`` entry into a ``(label, text)`` pair.

    Args:
        t: Sequence whose first element is the entry's path and whose second
            element is a binary stream positioned at the file's content.

    Returns:
        A tuple of the name of the path's second-to-last component (the folder
        directly containing the file) and the stream's remaining bytes decoded
        as UTF-8.
    """
    # The label is taken from the path before the stream is touched.
    label_dir = Path(t[0]).parts[-2]
    raw_bytes = t[1].read()
    return label_dir, raw_bytes.decode("utf-8")

In [182]:
def _label_to_int(sample):
    """Replace the string label of a ``(label, text)`` sample with an integer.

    Kept at module level (not nested inside ``IMDB``) because nested functions
    cannot be pickled by the standard ``pickle`` module, while the other
    callables in this pipeline are module-level functions or ``partial``
    objects.

    Args:
        sample: Sequence whose first element is ``"pos"`` or ``"neg"`` and
            whose second element is the review text.

    Returns:
        ``(1, text)`` for ``"pos"`` and ``(0, text)`` for ``"neg"``.

    Raises:
        ValueError: If the label is neither ``"pos"`` nor ``"neg"``.
    """
    label = sample[0]
    if label == "pos":
        return 1, sample[1]
    if label == "neg":
        return 0, sample[1]
    raise ValueError(f'Error: {label} is not proper value')


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def IMDB(root, split):
    """Demonstrates a complex use case where each sample is stored in a separate
    file and all files are compressed into one tar archive.

    Filtering is needed to know which files belong to train/test and to the
    neg/pos labels. Mapping is needed to yield proper data samples by extracting
    the label from the file path and reading the data from the file.

    Args:
        root: Directory under which the downloaded archive is cached.
            NOTE(review): the notebook calls ``IMDB(split=...)`` without it, so
            a decorator presumably supplies a default -- confirm in ``utils``.
        split: Name of the split folder to read (``"train"`` or ``"test"``).

    Returns:
        A shuffled, sharding-aware datapipe yielding ``(label, text)`` tuples,
        where ``label`` is ``1`` for ``"pos"`` and ``0`` for ``"neg"``.
    """

    url_dp = IterableWrapper([URL])
    # cache data on-disk, verifying the archive against its MD5 digest
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=partial(_path_fn, root),
        hash_dict={_path_fn(root, URL): MD5},
        hash_type="md5",
    )
    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_dp = FileOpener(cache_dp, mode="b")

    # stack TAR extractor on top of load files data pipe
    extracted_files = cache_dp.load_from_tar()

    # filter the files as applicable to create dataset for given split (train or test)
    filter_files = extracted_files.filter(partial(_filter_fn, split))

    # map the file to yield proper data samples
    sample = filter_files.map(_file_to_sample)

    # convert the "pos"/"neg" folder name into an integer class label
    sample = sample.map(_label_to_int)

    # shuffle first, then apply the sharding filter
    sample = sample.shuffle().sharding_filter()
    return sample

In [183]:
# Split fraction; not used in this cell (presumably consumed by a later
# train/validation split -- confirm).
SPLIT_SIZE = 0.2

train_iter = IMDB(split="train")

# Tokenize
tokenizer = get_tokenizer("spacy", "en_core_web_md")

# Peek at a single sample from a fresh iterator over the pipeline.
first_sample = next(iter(train_iter))
print(first_sample)

# Walk the whole split once and tally the samples labelled 1 (positive).
count = 0
for label, text in train_iter:
    count += 1 if label == 1 else 0

print(count)

(0, "This movie had a good story, but was brought down because it didn't have enough horror film elements and violence. It was like watching a live action cartoon. It would of been better if this story is what they planned from the start of the first movie so they could of played seeds for where the series was going.")
