### Note: We need preprocessed texts (tokinized and numericalized) in this notebook

> Notebook based on:
> 1. https://github.com/fastai/course-v3/blob/master/nbs/dl2/12_text.ipynb
> 2. https://github.com/fastai/course-v3/blob/master/nbs/dl2/12a_awd_lstm.ipynb
> 3. https://github.com/fastai/course-v3/blob/master/nbs/dl2/12b_lm_pretrain.ipynb
> 4. https://github.com/fastai/course-v3/blob/master/nbs/dl2/12c_ulmfit.ipynb
> 
> Video:
> - https://youtu.be/vnOpEwmtFJ8?t=4687 from 1:18:00 to 2:08:00 (50 mins)

# Imports

In [1]:
import numpy as np
import pathlib
from tqdm.notebook import tqdm
from collections import Counter, defaultdict

import torch
from torch.utils.data import Dataset, DataLoader, Sampler

# Data

In [2]:
!ls "../../Datasets/NLP/IMBd_prepro"

test  train  unsup  vocab.pkl


In [3]:
!ls "../../Datasets/NLP/IMBd_prepro/train"

neg  pos


# <center> Dataset & Dataloader for classification

- **Dataset**
  - **X**: Some unique text (a review in IMDb dataset)
  - **Y**: Some class label (pos or neg in IMDb dataset)
- **Dataloader**
  - Batch: We need **padding** for dealing with texts of different lenghts.
  - Sampler: To avoid mixing very long texts with very short ones, we will also use `Sampler` to sort (with a bit of randomness for the training set) our samples by length.

In [37]:
class FolderDataset(Dataset):
    
    def __init__(self, root_dir, file_extensions, x_tfms=None):
        root_dir = pathlib.Path(root_dir)

        def dir_conditions(folder: pathlib.PosixPath) -> bool:
            is_folder  = folder.is_dir()
            not_empty  = any(folder.iterdir()) if is_folder else False
            not_hidden = not folder.name.startswith('.')
            return is_folder and not_empty and not_hidden
        
        def file_conditions(file: pathlib.PosixPath) -> bool:
            is_file    = file.is_file()
            good_file  = file.name.endswith(file_extensions) # str or tuple of strings
            not_hidden = not file.name.startswith('.')
            return is_file and good_file and not_hidden

        #               __element__  ___________________________loop___________________________   
        self.classes = [folder.name  for folder in root_dir.iterdir() if dir_conditions(folder)]
        self.classes.sort()
        
        self.class2num = {cls_name: i for i,cls_name in enumerate(self.classes)}

        #                 ______________xy_element______________    
        self.samples = [ (str(file), self.class2num[folder.name]) 
        #                _________________________1st loop_________________________
                         for folder in root_dir.iterdir() if dir_conditions(folder)
        #                ______________________2nd loop______________________
                         for file in folder.iterdir() if file_conditions(file) ]
        
        self.x_tfms = x_tfms
            
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int):
        x, y = self.samples[idx]
        if self.x_tfms: x = self.x_tfms(x)
        
        return torch.tensor(x), torch.tensor(y)

In [31]:
ds_train = FolderDataset(root_dir        = '../../Datasets/NLP/IMBd_prepro/train',
                         file_extensions = (".npy"),
                         x_tfms          = lambda path: np.load(path).astype("int64") )

ds_valid = FolderDataset(root_dir        = '../../Datasets/NLP/IMBd_prepro/test',
                         file_extensions = (".npy"),
                         x_tfms          = lambda path: np.load(path).astype("int64") )

In [35]:
ds_train[0]

(tensor([    2,     7,    64,    12,  1405,    29,    15,    19,    18,    41,
            37,   690,    14,  1690,     9,    18,    58,    35,   140,   109,
           234,     8,   706,    13,     8,    29,    15, 31838,    16,   155,
            42, 50903,    30,  3837,    74,    43,   345,    78,   114,    19,
            29,    11,     8,   137,    47,     7,  8030,    15,    56,  1405,
             8,   348,   331,    13,     8,    29,   181,  2388,  1532,    11,
             8,    29,    61,    12,    83,    67,  2102,   291,    80,   181,
             8,    29,   429,     9,    32,   155,   126,    19,    29,   301,
            10,    18,   122,  1559,    88,    16,    22,    72,    14, 19687,
             7,  3837,    18,   133,    19,    29,    11,    16,    22,    37,
            56,   635,    11,   172,    30,    16,   103,   426,  2388,    12,
            83,    67,  9245,    11,    64,  2388,   155,    42,   422,    52,
            18,   122,    19,    15,    43,    13,  

In [36]:
ds_train[-1]

(tensor([    2,     7,   365,    10,   740,    10,    47,    23,     8,    23,
          1637,   201,  5532,    45,   242,    61,   206,    67,   326,   306,
            11,   983,    13,    36,  2952,    33,  4889,    11,   614,    14,
          4807,    36,  2952,    33,   502,   467,     9,     7,   676,    16,
            22,    59,  1342,  4954, 20939,    34,    12,   288,    13,   121,
            90,    58,    35,   479,    48, 15666,    59,    23,    63,    20,
            22,   147,   169,    10,  3085,    17,     9,    36,   194,   383,
           124,   263,    33,     3]),
 tensor(0))

# Dataloader with Custom sampler and collate_fn

### Samplers for putting texts of similar lenght together
- **For the validation set**: we will simply sort the samples by length, and we begin with the longest ones for memory reasons (it's better to always have the biggest tensors first).

In [None]:
# Sampler for validation
class SortSampler(Sampler):
    
    def __init__(self, data_source, key):
        self.data_source = data_source
        self.key         = key
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        return iter(sorted(list(range(len(self.data_source))), key=self.key, reverse=True))

# Sampler for training
class SortishSampler(Sampler):
    def __init__(self, data_source, key, bs):
        self.data_source,self.key,self.bs = data_source,key,bs

    def __len__(self) -> int: return len(self.data_source)

    def __iter__(self):
        idxs = torch.randperm(len(self.data_source))
        megabatches = [idxs[i:i+self.bs*50] for i in range(0, len(idxs), self.bs*50)]
        sorted_idx = torch.cat([tensor(sorted(s, key=self.key, reverse=True)) for s in megabatches])
        batches = [sorted_idx[i:i+self.bs] for i in range(0, len(sorted_idx), self.bs)]
        max_idx = torch.argmax(tensor([self.key(ck[0]) for ck in batches]))  # find the chunk with the largest key,
        batches[0],batches[max_idx] = batches[max_idx],batches[0]            # then make sure it goes first.
        batch_idxs = torch.randperm(len(batches)-2)
        sorted_idx = torch.cat([batches[i+1] for i in batch_idxs]) if len(batches) > 1 else LongTensor([])
        sorted_idx = torch.cat([batches[0], sorted_idx, batches[-1]])
        return iter(sorted_idx)

### Custom Collate for adding PADDING

In [47]:
x = list([torch.randint(0, 10, (x,)) for x in range(1, 11)])
y = list(range(10,20))
x, y

([tensor([0]),
  tensor([8, 0]),
  tensor([1, 1, 8]),
  tensor([2, 3, 6, 9]),
  tensor([2, 8, 5, 5, 1]),
  tensor([3, 4, 2, 1, 6, 1]),
  tensor([4, 3, 4, 0, 3, 3, 3]),
  tensor([9, 9, 5, 9, 1, 2, 7, 7]),
  tensor([7, 1, 5, 3, 0, 5, 9, 7, 1]),
  tensor([0, 6, 6, 5, 9, 9, 4, 8, 1, 2])],
 [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

In [48]:
dataset = list(zip(x,y))
dataset[5]

(tensor([3, 4, 2, 1, 6, 1]), 15)

In [57]:
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch looks like [(x0,y0), (x4,y4), (x2,y2)... ]
    batch_x, batch_y = zip(*batch) 
    
    batch_x = pad_sequence(batch_x, batch_first=True, padding_value=1)
    batch_y = torch.tensor(batch_y)
    
    return batch_x, batch_y

dl = torch.utils.data.DataLoader(dataset,
                                 batch_size=3,
                                 shuffle=True,
                                 collate_fn=pad_collate)
for xb, yb in dl:
    print(xb)

tensor([[0, 6, 6, 5, 9, 9, 4, 8, 1, 2],
        [8, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [9, 9, 5, 9, 1, 2, 7, 7, 0, 0]])
tensor([[3, 4, 2, 1, 6, 1],
        [0, 0, 0, 0, 0, 0],
        [1, 1, 8, 0, 0, 0]])
tensor([[7, 1, 5, 3, 0, 5, 9, 7, 1],
        [2, 8, 5, 5, 1, 0, 0, 0, 0],
        [4, 3, 4, 0, 3, 3, 3, 0, 0]])
tensor([[2, 3, 6, 9]])
