In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [2]:
#export
from exp.nb_11a import *

In [3]:
#export
class SplitData():
    def __init__(self, train, valid):
        self.train, self.valid = train, valid

    def __getattr__(self, k):
        return getattr(self.train, k)
    
    #This is needed if we want to pickle SplitData and be able to load it back without recursion errors
    def __setstate__(self,data:Any):
        self.__dict__.update(data)

    @classmethod
    def split_by_func(cls, il, f):
        lists = map(il.new, split_by_func(il, f)) # il or il.items (il.items in github nb_08.py)
        return cls(*lists)

    def __repr__(self):
        return f'{self.__class__.__name__}\nTrain: {self.train}\nValid: {self.valid}\n'

In [4]:
#export
def _label_by_func(ds, f, cls=ItemList):
    return cls([f(o) for o in ds.items], path=ds.path)

class LabeledData():
    def process(self, il, proc): return il.new(compose(il.items, proc))

    def __init__(self, x, y, proc_x=None, proc_y=None):
        self.x,self.y = self.process(x, proc_x),self.process(y, proc_y)
        self.proc_x,self.proc_y = proc_x,proc_y

    def __repr__(self): return f'{self.__class__.__name__}\nx: {self.x}\ny: {self.y}\n'
    def __getitem__(self,idx): return self.x[idx],self.y[idx]
    def __len__(self): return len(self.x)

    def x_obj(self, idx): return self.obj(self.x, idx, self.proc_x)
    def y_obj(self, idx): return self.obj(self.y, idx, self.proc_y)

    def obj(self, items, idx, procs):
        isint = isinstance(idx, int) or (isinstance(idx,torch.LongTensor) and not idx.ndim)
        item = items[idx]
        for proc in reversed(listify(procs)):
            item = proc.deproc1(item) if isint else proc.deprocess(item)
        return item

    @classmethod
    def label_by_func(cls, il, f, proc_x=None, proc_y=None):
        return cls(il, _label_by_func(il, f), proc_x=proc_x, proc_y=proc_y)
    
def label_by_func(sd, f, proc_x=None, proc_y=None):
    train = LabeledData.label_by_func(sd.train, f, proc_x=proc_x, proc_y=proc_y)
    valid = LabeledData.label_by_func(sd.valid, f, proc_x=proc_x, proc_y=proc_y)
    return SplitData(train,valid)

# Data

In [5]:
#export
imdb_path = 'https://s3.amazonaws.com/fast-ai-nlp/imdb'
path = untar_data(imdb_path)

In [6]:
path.ls()

[PosixPath('/home/ubuntu/learnai/dl/data/imdb/README'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/ll_clas.pkl'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/imdb.vocab'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/test'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/train'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/vocab_lm.pkl'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/ll_lm.pkl'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/tmp_lm'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/tmp_clas'),
 PosixPath('/home/ubuntu/learnai/dl/data/imdb/ld.pkl')]

In [7]:
#export
def read_file(fn):
    with open(fn, 'r', encoding='utf8') as f:
        return f.read()
    
class TextList(ItemList):
    @classmethod
    def from_files(cls, path, extensions='.txt', recurse=True, include=None, **kwargs):
        return cls(get_files(path, extensions, recurse=recurse, include=include), path, **kwargs)
    
    def get(self, i):
        if isinstance(i, Path):
            return read_file(i)
        return i

In [8]:
il = TextList.from_files(path, include=['train', 'test', 'unsup'])

In [9]:
len(il.items)

100000

In [10]:
il

TextList(100000 items)
[PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/10274_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/5668_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/8620_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/19150_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/37632_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/21939_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/20835_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/14011_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/20552_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/33059_0.txt')...]
Path: /home/ubuntu/learnai/dl/data/imdb

In [11]:
txt = il[0]
txt

"Well, this a very good TV series from our neighbor country Turkey. Its a series about romance and love.Very good indeed.Here in Greece we had a similar series but is was a dramatic one (cant understand why Greek writers are stuck on this kind...drama is not always good).To be exact I prefer the Turkish one cause combines lots of good things (characters, script etc) and cause as a curious Greek I'd love to see the Turkish point of view...with this one u can see the common cultural background of modern Greece and Turkey.This TV series was really a good surprise for me.I am addicted neighbors!!!Not forget to mention than Nehri Erdogan is the most beautiful creature on Earth!!I would be happy to see her coming here to Greece!Congratulations!!!Merhaba!"

In [12]:
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [13]:
sd

SplitData
Train: TextList(89907 items)
[PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/10274_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/5668_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/8620_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/19150_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/37632_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/21939_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/14011_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/20552_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/33059_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/25229_0.txt')...]
Path: /home/ubuntu/learnai/dl/data/imdb
Valid: TextList(10093 items)
[PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/20835_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/13943_0.txt'), PosixPath('/home/ubuntu/learnai/dl/data/imdb/unsup/25108_0.txt'), PosixPath('/home/ubuntu/learnai

# Tokenizing

In [14]:
#export
import spacy, html

In [15]:
#export
#special tokens

UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ = 'xxunk xxpad xxbos xxeos xxrep xxwrep xxup xxmaj'.split()

def sub_br(t):
    re_br = re.compile(r'<\s*br\s*/?>', re.IGNORECASE)
    return re_br.sub('\n', t)

def spec_add_spaces(t):
    "Add spaces around / and #"
    return re.sub(r'([/#])', r' \1', t)

def rm_useless_spaces(t):
    "Remove multiple spaces"
    return re.sub(' {2,}', ' ', t)

def replace_rep(t):
    def _replace_rep(m):
        c, cc = m.groups()
        return f' {TK_REP} {len(cc)+1} {c} '
    re_rep = re.compile(r'(\S)(\1{3,})')
    return re_rep.sub(_replace_rep, t)

def replace_wrep(t):
    "Replace word repetitions: word word word -> TK_WREP 3 word"
    def _replace_wrep(m):
        c, cc = m.groups()
        return f' {TK_WREP} {len(cc.split())+1} {c} '
    re_wrep = re.compile(r'(\b\w+\W+)(\1{3,})')
    return re_wrep.sub(_replace_wrep, t)

def fixup_text(x):
    "Remove messy things"
    re1 = re.compile(r'  +')
    x = x.replace('#39;',"'").replace('amp;', '&').replace('#146;', "'").replace(
        'nbsp;', ' ').replace('#36;', '$').replace('\\n',"\n").replace('quot;', "'").replace(
        '<br />', '\n').replace('\\"', '"').replace('<unk>', UNK).replace(' @.@ ', '.').replace(
        ' @-@ ', '-').replace('\\', ' \\ ')
    return re1.sub(' ', html.unescape(x))

default_pre_rules = [fixup_text, replace_rep, replace_wrep, spec_add_spaces, rm_useless_spaces, sub_br]
default_spec_tok = [UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ]

In [16]:
replace_rep('cccc')

' xxrep 4 c '

In [17]:
replace_wrep('word word word word word ')

' xxwrep 5 word  '

In [18]:
#export
def replace_all_caps(x):
    "Replace tokens in ALL CAPS by their lower version and add `TK_UP` before."
    res = []
    for t in x:
        if t.isupper() and len(t)>1:
            res.append(TK_UP)
            res.append(t.lower())
        else:
            res.append(t)
    return res

def deal_caps(x):
    "Replace all Capitalized tokens by their lower version and add `TK_MAJ` before."
    res = []
    for t in x:
        if t == '':
            continue
        if t[0].isupper() and len(t)>1 and t[1:].islower():
            res.append(TK_MAJ)
        res.append(t.lower())
    return res

def add_eos_bos(x):
    return [BOS] + x + [EOS]

default_post_rules = [deal_caps, replace_all_caps, add_eos_bos]

In [19]:
replace_all_caps('I AM SHOUTING'.split())

['I', 'xxup', 'am', 'xxup', 'shouting']

In [20]:
deal_caps('My name is Akhil'.split())

['xxmaj', 'my', 'name', 'is', 'xxmaj', 'akhil']

In [21]:
#export
from spacy.symbols import ORTH
from concurrent.futures import ProcessPoolExecutor

def parallel(func, arr, max_workers=4):
    if max_workers<2:
        results = list(progress_bar(map(func, enumerate(arr)), total=len(arr)))
    else:
        with ProcessPoolExecutor(max_workers=max_workers) as ex:
            return list(progress_bar(ex.map(func, enumerate(arr)), total=len(arr)))
    if any([o is not None for o in results]):
        return results

In [22]:
parallel(len, ['hello', 'world'], 8)

[2, 2]

In [23]:
??compose

In [24]:
#export
class TokenizeProcessor(Processor):
    def __init__(self, lang='en', chunksize=2000, pre_rules=None, post_rules=None, max_workers=4):
        self.chunksize, self.max_workers = chunksize, max_workers
        self.tokenizer = spacy.blank(lang).tokenizer
        for w in default_spec_tok:
            self.tokenizer.add_special_case(w, [{ORTH: w}])
        if pre_rules is None:
            self.pre_rules = default_pre_rules
        else:
            self.pre_rules = pre_rules
        if post_rules is None:
            self.post_rules = default_post_rules
        else:
            self.post_rules = post_rules
    
    def proc_chunk(self, args):
        i, chunk = args
        chunk = [compose(t, self.pre_rules) for t in chunk]
        docs = [[d.text for d in doc] for doc in self.tokenizer.pipe(chunk)]
        docs = [compose(t, self.post_rules) for t in docs]
        return docs
    
    def __call__(self, items):
        toks = []
        if isinstance(items[0], Path):
            items = [read_file(i) for i in items]
        chunks = [items[i: i+self.chunksize] for i in (range(0, len(items), self.chunksize))]
        toks = parallel(self.proc_chunk, chunks, max_workers=self.max_workers)
        return sum(toks, [])
    
    def proc1(self, item):
        return self.proc_chunk([toks])[0]
    
    def deprocess(self, toks):
        return [self.deproc1(tok) for tok in toks]
    
    def deproc1(self, tok):
        return ' '.join(tok)

In [25]:
tp = TokenizeProcessor()

In [26]:
txt[:250]

'Well, this a very good TV series from our neighbor country Turkey. Its a series about romance and love.Very good indeed.Here in Greece we had a similar series but is was a dramatic one (cant understand why Greek writers are stuck on this kind...drama'

In [27]:
' • '.join(tp(il[:100])[0])[:500]

'xxbos • xxmaj • well • , • this • a • very • good • tv • series • from • our • neighbor • country • xxmaj • turkey • . • xxmaj • its • a • series • about • romance • and • love • . • xxmaj • very • good • indeed • . • xxmaj • here • in • xxmaj • greece • we • had • a • similar • series • but • is • was • a • dramatic • one • ( • ca • nt • understand • why • xxmaj • greek • writers • are • stuck • on • this • kind • ... • drama • is • not • always • good).to • be • exact • i • prefer • the • xxma'

In [28]:
tp(['hello world ....'])

[['xxbos', 'hello', 'world', 'xxrep', '4', '.', 'xxeos']]

# Numericalizing

In [29]:
#export
import collections

class NumericalizeProcessor(Processor):
    def __init__(self, vocab=None, max_vocab=60000, min_freq=2):
        self.vocab, self.max_vocab, self.min_freq = vocab, max_vocab, min_freq
        
    def __call__(self, items):
        # vocab gets defined on first use
        if self.vocab is None:
            freq = Counter(p for o in items for p in o)
            self.vocab = [o for o,c in freq.most_common(self.max_vocab) if c>=self.min_freq]
            for o in reversed(default_spec_tok):
                if o in self.vocab:
                    self.vocab.remove(o)
                self.vocab.insert(0, o)
        if getattr(self, 'otoi', None) is None:
            self.otoi = collections.defaultdict(int, {v:k for k,v in enumerate(self.vocab)})
        return [self.proc1(o) for o in items]
    
    def proc1(self, item):
        return [self.otoi[o] for o in item]
    
    def deprocess(self, idxs):
        assert self.vocab is not None
        return [self.deproc1(idx) for idx in idxs]
    
    def deproc1(self, idx):
        return [self.vocab[i] for i in idx]

In [30]:
proc_tok, proc_num = TokenizeProcessor(max_workers=8), NumericalizeProcessor()

In [31]:
%time ll = label_by_func(sd, lambda x:0, proc_x = [proc_tok, proc_num])

CPU times: user 22 s, sys: 2.62 s, total: 24.6 s
Wall time: 2min 13s


In [32]:
ll.train.x_obj(0)

"xxbos xxmaj well , this a very good tv series from our neighbor country xxmaj turkey . xxmaj its a series about romance and love . xxmaj very good indeed . xxmaj here in xxmaj greece we had a similar series but is was a dramatic one ( ca nt understand why xxmaj greek writers are stuck on this kind ... drama is not always xxunk be exact i prefer the xxmaj turkish one cause combines lots of good things ( characters , script etc ) and cause as a curious xxmaj greek i 'd love to see the xxmaj turkish point of view ... with this one u can see the common cultural background of modern xxmaj greece and xxmaj turkey . xxmaj this tv series was really a good surprise for me . i am addicted xxunk forget to mention than xxmaj xxunk xxmaj xxunk is the most beautiful creature on xxunk would be happy to see her coming here to xxunk ! xxeos"

In [33]:
pickle.dump(ll, open(path/'ld.pkl', 'wb'))

In [34]:
ll = pickle.load(open(path/'ld.pkl', 'rb'))

# Batching

In [35]:
stream = """
In this notebook, we will go back over the example of classifying movie reviews we studied in part 1 and dig deeper under the surface. 
First we will look at the processing steps necessary to convert text into numbers and how to customize it. By doing this, we'll have another example of the Processor used in the data block API.
Then we will study how we build a language model and train it.\n
"""

In [36]:
tokens = np.array(tp([stream])[0])

In [37]:
import pandas as pd
from IPython.display import display, HTML

In [38]:
bs = 6

In [39]:
len(tokens)//bs

15

In [40]:
bs, seq_len = 6, len(tokens)//bs
d_tokens = np.array([tokens[i*seq_len:(i+1)*seq_len] for i in range(bs)])
df = pd.DataFrame(d_tokens)
display(HTML(df.to_html(index=False, header=None)))

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
xxbos,\n,xxmaj,in,this,notebook,",",we,will,go,back,over,the,example,of
classifying,movie,reviews,we,studied,in,part,1,and,dig,deeper,under,the,surface,.
\n,xxmaj,first,we,will,look,at,the,processing,steps,necessary,to,convert,text,into
numbers,and,how,to,customize,it,.,xxmaj,by,doing,this,",",we,'ll,have
another,example,of,the,xxmaj,processor,used,in,the,data,block,api,.,\n,xxmaj
then,we,will,study,how,we,build,a,language,model,and,train,it,.,\n\n


In [41]:
bs, bptt = 6,5
for k in range(len(tokens)//(bs*bptt)):
    d_tokens = np.array([tokens[i*seq_len + k*bptt:i*seq_len + (k+1)*bptt] for i in range(bs)])
    df = pd.DataFrame(d_tokens)
    display(HTML(df.to_html(index=False, header=None)))

0,1,2,3,4
xxbos,\n,xxmaj,in,this
classifying,movie,reviews,we,studied
\n,xxmaj,first,we,will
numbers,and,how,to,customize
another,example,of,the,xxmaj
then,we,will,study,how


0,1,2,3,4
notebook,",",we,will,go
in,part,1,and,dig
look,at,the,processing,steps
it,.,xxmaj,by,doing
processor,used,in,the,data
we,build,a,language,model


0,1,2,3,4
back,over,the,example,of
deeper,under,the,surface,.
necessary,to,convert,text,into
this,",",we,'ll,have
block,api,.,\n,xxmaj
and,train,it,.,\n\n


In [42]:
#export
class LM_Dataset():
    def __init__(self, data, bs=64, bptt=70, shuffle=False):
        self.data, self.bs, self.bptt, self.shuffle = data, bs, bptt, shuffle
        total_len = sum([len(t) for t in data.x])
        self.n_batch = total_len // bs
        self.batchify()
    
    def __len__(self):
        return ((self.n_batch-1) // self.bptt) * self.bs
    
    def __getitem__(self, idx):
        source = self.batched_data[idx % self.bs]
        seq_idx = (idx // self.bs) * self.bptt
        return source[seq_idx:seq_idx+self.bptt], source[seq_idx+1:seq_idx+self.bptt+1]
    
    def batchify(self):
        texts = self.data.x
        if self.shuffle:
            texts = texts[torch.randperm(len(texts))]
        stream = torch.cat([tensor(t) for t in texts])
        self.batched_data = stream[:self.n_batch * self.bs].view(self.bs, self.n_batch)

In [43]:
dl = DataLoader(LM_Dataset(ll.valid, shuffle=True), batch_size=64)

In [44]:
iter_dl = iter(dl)
x1, y1 = next(iter_dl)
x2, y2 = next(iter_dl)

In [45]:
x1.size(), y1.size()

(torch.Size([64, 70]), torch.Size([64, 70]))

In [46]:
x1

tensor([[    2,    16,    25,  ...,    45,   272,   101],
        [ 2772,     9,    24,  ...,   133,    42,   173],
        [10223,    17,  2425,  ...,   532,  4508,    17],
        ...,
        [   21,     7,     8,  ...,    10,  8359,    27],
        [  153,  6241,    43,  ...,     8, 23599,   142],
        [   30,   264,    11,  ...,     7,    68,    71]])

In [47]:
y1

tensor([[   16,    25,    19,  ...,   272,   101,    11],
        [    9,    24,    18,  ...,    42,   173,     9],
        [   17,  2425,    75,  ...,  4508,    17,     8],
        ...,
        [    7,     8,     7,  ...,  8359,    27,  6812],
        [ 6241,    43,  5855,  ..., 23599,   142,     7],
        [  264,    11,  3107,  ...,    68,    71,  2511]])

In [48]:
x2

tensor([[   11,   323,    10,  ...,    11,  5298, 21672],
        [    9,     7,     8,  ...,    10,  4258,    14],
        [    8,   337, 52433,  ...,    19,    29,    25],
        ...,
        [ 6812,   121,    11,  ...,   164,    13,   200],
        [    7,  1736,     7,  ..., 11449,    13,     7],
        [ 2511,    57,  1359,  ...,    60,    13,   402]])

In [49]:
vocab = proc_num.vocab

In [50]:
' '.join(vocab[o] for o in x1[0])

'xxbos it was this movie , not the sixth sense , that turned me into the huge m. night shyamalan fan i am today . i saw this movie back before i even heard of the sixth sense . i saw it because i saw a trailer for it before another movie i had rented . after viewing just the trailer , my sister and i looked at each other'

In [51]:
' '.join(vocab[o] for o in y1[0])

'it was this movie , not the sixth sense , that turned me into the huge m. night shyamalan fan i am today . i saw this movie back before i even heard of the sixth sense . i saw it because i saw a trailer for it before another movie i had rented . after viewing just the trailer , my sister and i looked at each other and'

In [52]:
' '.join(vocab[o] for o in x2[0])

'and said , " we have to see that ... " so , when i found it at blockbuster a few months later , i jumped to rent it . now i only wish that i had gotten a chance to see this limited release movie in theaters . the story of young xxmaj joshua questioning life while everyone , including his best friend , parents , and teachers scoff'

In [53]:
#export
def get_lm_dls(train_ds, valid_ds, bs, bptt, **kwargs):
    return (DataLoader(LM_Dataset(train_ds, bs, bptt, shuffle=True), batch_size=bs, **kwargs),
           DataLoader(LM_Dataset(valid_ds, bs, bptt, shuffle=False), batch_size=2*bs, **kwargs))

def lm_databunchify(sd, bs, bptt, **kwargs):
    return DataBunch(*get_lm_dls(sd.train, sd.valid, bs, bptt, **kwargs))

In [54]:
bs, bptt = 64, 70
data = lm_databunchify(ll, bs, bptt)

# Batching for classification

In [55]:
proc_cat = CategoryProcessor()

In [56]:
il = TextList.from_files(path, include=['train', 'test'])
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='test'))
ll = label_by_func(sd, parent_labeler, proc_x = [proc_tok, proc_num], proc_y=proc_cat)

In [57]:
pickle.dump(ll, open(path/'ll_clas.pkl', 'wb'))

In [58]:
ll = pickle.load(open(path/'ll_clas.pkl', 'rb'))

In [59]:
[(ll.train.x_obj(i), ll.train.y_obj(i)) for i in [1, 12552]]

[("xxbos xxmaj deanna xxmaj durbin , then 14 and just under contract to mgm , made a short feature in 1936 which paired her with xxmaj judy xxmaj garland , a year younger , in the first film for both of them . xxmaj louis b. xxmaj mayer then decided he did n't need two competing young singers , placed his bet on xxmaj garland and let xxmaj durbin go . xxmaj universal immediately signed xxmaj durbin , rushed her into xxmaj three xxmaj smart xxmaj girls and rewrote the screenplay to pump up her part . xxmaj she 's billed last , but with the xxunk equivalent of neon lights around her name . xxmaj universal was convinced xxmaj durbin would be a smash , and they were right . xxmaj three xxmaj smart xxmaj girls is less a musical and more a screwball comedy , and xxmaj durbin , 15 when the movie was released , carries it with aplomb . xxmaj she 's xxmaj penny xxmaj craig , and she and her older sisters , xxmaj joan and xxmaj kay , are determined to save their father , who had divorced their m

In [60]:
#export
from torch.utils.data import Sampler

class SortSampler(Sampler):
    def __init__(self, data_source, key):
        self.data_source, self.key = data_source, key
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        return iter(sorted(list(range(len(self.data_source))), key=self.key, reverse=True))

In [61]:
#export
class SortishSampler(Sampler):
    def __init__(self, data_source, key, bs):
        self.data_source, self.key, self.bs = data_source, key, bs
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        idxs = torch.randperm(len(self.data_source))
        megabatches = [idxs[i:i+self.bs*50] for i in range(0, len(idxs), self.bs*50)]
        sorted_idx = torch.cat([tensor(sorted(s, key=self.key, reverse=True)) for s in megabatches])
        batches = [sorted_idx[i:i+self.bs] for i in range(0, len(sorted_idx), self.bs)]
        max_idx = torch.argmax(tensor([self.key(ck[0]) for ck in batches]))
        batches[0], batches[max_idx] = batches[max_idx], batches[0]
        batch_idxs = torch.randperm(len(batches) - 2)
        sorted_idx = torch.cat([batches[i+1] for i in batch_idxs]) if len(batches) > 1 else LongTensor([])
        sorted_idx = torch.cat([batches[0], sorted_idx, batches[-1]])
        return iter(sorted_idx)

In [62]:
#export
def pad_collate(samples, pad_idx=1, pad_first=False):
    max_len = max([len(s[0]) for s in samples])
    res = torch.zeros(len(samples), max_len).long() + pad_idx
    for i,s in enumerate(samples):
        if pad_first:
            res[i, -len(s[0]):] = LongTensor(s[0])
        else:
            res[i, :len(s[0])] = LongTensor(s[0])
    return res, tensor([s[1] for s in samples])

In [63]:
pad_collate([[[23, 24], [45, 46, 47]]])

(tensor([[23, 24]]), tensor([[45, 46, 47]]))

In [64]:
bs = 64
train_sampler = SortishSampler(ll.train.x, key = lambda t:len(ll.train[int(t)][0]), bs=bs)
train_dl = DataLoader(ll.train, batch_size=bs, sampler = train_sampler, collate_fn=pad_collate)

In [65]:
iter_dl = iter(train_dl)
x, y = next(iter_dl)

In [66]:
x.size()

torch.Size([64, 3311])

In [67]:
lengths = []
for i in range(x.size(0)):
    lengths.append(x.size(1) - (x[i]==1).sum().item())
lengths[:5], lengths[-1]

([3311, 1576, 1486, 1481, 1425], 1022)

In [68]:
x, y = next(iter_dl)
lengths = []
for i in range(x.size(0)):
    lengths.append(x.size(1) - (x[i]==1).sum().item())
lengths[:5], lengths[-1]

([415, 415, 415, 415, 415], 393)

In [69]:
x

tensor([[    2,     7,  1383,  ...,   154,     9,     3],
        [    2,    18, 10054,  ...,    16,     9,     3],
        [    2,     7,    68,  ...,    72,     9,     3],
        ...,
        [    2,    21,     7,  ...,     1,     1,     1],
        [    2,     7,    19,  ...,     1,     1,     1],
        [    2,    18,  1738,  ...,     1,     1,     1]])

In [70]:
#export
def get_clas_dls(train_ds, valid_ds, bs, **kwargs):
    train_sampler = SortishSampler(train_ds.x, key=lambda t:len(train_ds.x[t]), bs=bs)
    valid_sampler = SortSampler(valid_ds.x, key=lambda t:len(valid_ds.x[t]))
    return (DataLoader(train_ds, batch_size=bs, sampler=train_sampler, collate_fn=pad_collate, **kwargs),
           DataLoader(valid_ds, batch_size=bs*2, sampler=valid_sampler, collate_fn=pad_collate, **kwargs))

def clas_databunchify(sd, bs, **kwargs):
    return DataBunch(*get_clas_dls(sd.train, sd.valid, bs, **kwargs))

In [71]:
bs, bptt = 64, 70
data = clas_databunchify(ll, bs)

# Export

In [72]:
!python notebook2script.py 12_text.ipynb

converted 12_text.ipynb to nb_12.py
