In [None]:
# default_exp dataloaders
# default_cls_lvl 3

In [None]:
#hide
%matplotlib widget
from fastai2.callback.progress import *
from fastai2.callback.tracker import *
from fastai2.callback.schedule import *

In [None]:
#export
from seqdata.core import *
from seqdata.model import *
from seqdata.learner import *
from fastai2.basics import *

import math

## Custom Dataloaders
> PyTorch/fastai dataloaders for training models on sequential data

# Truncated Backpropagation Through Time

The TBPTT dataloader splits every minibatch along the time axis into several smaller minibatches (windows of `sub_seq_len` steps). These windows are returned one after another before the next minibatch is created, so a stateful model sees each sequence in consecutive chunks.

In [None]:
#export
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# indexed with the boolean `num_workers == 0`: False -> multi-process iterator, True -> single-process iterator
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)
@delegates()
class TbpttDl(TfmdDL):
    '''`TfmdDL` for truncated backpropagation through time (TBPTT).

    Every minibatch created by `TfmdDL` is split along dimension 1 into `n_sub_seq`
    windows of `sub_seq_len` steps (the last window may be shorter). The windows are
    returned one after another before the next minibatch. `rnn_reset` is set to True
    for the first window of each minibatch so a callback can reset a stateful model.

    sub_seq_len: window length; None disables the split (one window per minibatch)
    max_batches: optional cap on the number of windows per epoch
    seq_len:     full sequence length; inferred from the first item when None
    '''

    def __init__(self, dataset, sub_seq_len=None, max_batches=None, seq_len=None, shuffle=True, num_workers=0, **kwargs):
        # stored before `super().__init__`, which presumably already needs them while setting up
        store_attr(self,'sub_seq_len,max_batches,seq_len')
        super().__init__(dataset=dataset, shuffle=shuffle, num_workers=num_workers, **kwargs)
        self.rnn_reset = False

    @property
    def n_sub_seq(self):
        "Number of windows each minibatch is split into (1 when `sub_seq_len` is None)."
        if self.sub_seq_len is None: return 1
        # infer the sequence length once from the first element of item 0 and cache it
        if self.seq_len is None: self.seq_len = self.do_item(0)[0].shape[0]
        return math.ceil(self.seq_len / self.sub_seq_len)

    def __len__(self):
        "Number of windows per epoch: minibatches times `n_sub_seq`, capped at `max_batches`."
        l = super().__len__() * self.n_sub_seq
        return l if self.max_batches is None else min(l, self.max_batches)

    def _next_worker(self,w_id):
        "Id of the worker after `w_id`, wrapping around to 0."
        return (w_id + 1) % self.fake_l.num_workers

    def _limit_reached(self, idx):
        "True when `idx` windows have already been yielded and `max_batches` is reached."
        return self.max_batches is not None and idx >= self.max_batches

    def _finalize_batch(self, b, idx):
        "Set `rnn_reset` for window number `idx`, move `b` to `self.device` and apply `after_batch`."
        self.rnn_reset = (idx % self.n_sub_seq) == 0
        return self.after_batch(b if self.device is None else to_device(b, self.device))

    def __iter__(self):
        '''Yield all windows of the epoch in sequence order, updating `rnn_reset` for each.

        With `num_workers > 0`, torch interleaves the windows produced by different
        workers, but the `n_sub_seq` windows of one minibatch must be yielded
        consecutively. Windows from workers that are not currently active are buffered
        per worker and released when that worker's turn comes. Windows still buffered
        when the underlying loader is exhausted are flushed at the end instead of being
        dropped. At most `max_batches` windows are yielded.'''
        self.randomize()
        self.before_iter()
        n_workers = self.fake_l.num_workers
        n_sub_seq = self.n_sub_seq
        queue = {n: [] for n in range(n_workers)}
        current_worker = None
        idx = 0  # number of windows yielded so far (drives `rnn_reset` and the batch limit)
        for loaded_b, w_id in _loaders[n_workers == 0](self.fake_l):
            if self._limit_reached(idx): break
            if w_id is None:
                # single process: windows already arrive in sequence order
                yield self._finalize_batch(loaded_b, idx)
                idx += 1
                continue
            if current_worker is None: current_worker = w_id
            queue[w_id].append(loaded_b)
            # release every window that is now in order; switch worker after each complete minibatch
            while queue[current_worker] and not self._limit_reached(idx):
                yield self._finalize_batch(queue[current_worker].pop(0), idx)
                idx += 1
                if idx % n_sub_seq == 0: current_worker = self._next_worker(current_worker)
        # flush windows that are still buffered once the loader is done, skipping workers with nothing left
        while current_worker is not None and any(queue.values()) and not self._limit_reached(idx):
            if not queue[current_worker]:
                current_worker = self._next_worker(current_worker)
                continue
            yield self._finalize_batch(queue[current_worker].pop(0), idx)
            idx += 1
            if idx % n_sub_seq == 0: current_worker = self._next_worker(current_worker)
        self.after_iter()
        if hasattr(self, 'it'): delattr(self, 'it')

    def create_batches(self, samps):
        "Split every minibatch created by `TfmdDL.create_batches` into TBPTT windows."
        yield from self._tbptt_generator(super().create_batches(samps))

    def _tbptt_generator(self,batch_iter):
        '''Split each batch of `batch_iter` into `n_sub_seq` windows along dimension 1.

        Yields `(window, worker_id)` pairs; `worker_id` is None outside of torch worker
        processes. Truncation to `max_batches` is done in `__iter__`, not here.'''
        # evaluated on the first `next()`, i.e. in the process that consumes this generator
        worker_info = torch.utils.data.get_worker_info()
        w_id = None if worker_info is None else worker_info.id
        for b in batch_iter:
            for i in range(self.n_sub_seq):
                if self.sub_seq_len is None:
                    trunc_b = b
                else:
                    # keep the tuple type and the tensor subclasses, otherwise later transforms may not work
                    trunc_b = tuple(retain_type(x[:, i*self.sub_seq_len:(i+1)*self.sub_seq_len], x) for x in b)
                yield trunc_b, w_id
                    

In [None]:
# Test dataloaders: the HDF files in test_data/ are cut into windows (win_sz=1000+1, stp_sz=1000;
# see DfHDFCreateWindows for the exact semantics). Inputs are the 'current' and 'voltage'
# signals and the output is 'voltage'; clm_shift offsets the columns (defined in seqdata.core).
tfm_lst = [DfHDFCreateWindows(win_sz=1000+1,stp_sz=1000,clm='current')]
seq = DataBlock(blocks=(SequenceBlock.from_hdf(['current','voltage'],TensorSequencesInput,clm_shift=[-1,-1]),
                        SequenceBlock.from_hdf(['voltage'],TensorSequencesOutput,clm_shift=[1])),
                 get_items=CreateDict(tfm_lst),
                 splitter=ApplyToDict(ParentSplitter()))
# TbpttDl splits every minibatch into windows of 100 steps, yields at most 1000 windows per epoch
# and uses 5 worker processes (its __iter__ restores the window order across workers)
db = seq.dataloaders(get_hdf_files('test_data/'),dl_type=TbpttDl,sub_seq_len=100,max_batches=1000,num_workers=5)

In [None]:
# Iterate once over the training set and keep, for every window, the target (last tuple element)
# of the first batch element, first channel — assumes targets are (batch, time, channel); TODO confirm
l = [array(x[-1][0,:,0].cpu()) for x in db.train]


In [None]:
# Consecutive windows of one minibatch should join into a continuous curve; jumps are expected
# only where a new minibatch (a new sequence) starts.
fig, ax = plt.subplots()
ax.plot(np.concatenate(l))
ax.set(title='Training targets (first batch element, channel 0), windows concatenated',
       xlabel='time step', ylabel='voltage');

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f27f08f8210>]

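With parallel workers (`num_workers > 0`), torch interleaves the minibatches of different workers, which would corrupt the order of the windows. `TbpttDl.__iter__` therefore buffers windows per worker and releases them so that the windows of each sequence stay consecutive. If the plot above still shows a corrupted order, fall back to `num_workers=0`.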

## TBPTT_Reset_Callback
A stateful model has to reset its hidden state when a new sequence begins. The callback reads the dataloader's `rnn_reset` flag and calls `model.reset()` when it is set.

In [None]:
#export
class TbpttResetCB(Callback):
    "`Callback` resets the rnn model with every new sequence for tbptt"

    def begin_batch(self):
        '''Reset the model when the current dataloader flags the first window of a new sequence.

        The flag is read from the dataloader that is actually being iterated (`learn.dl`),
        so prediction on a custom dataloader (e.g. `get_preds(dl=...)`) uses the right flag.
        If the learner has no `dl` attribute, it falls back to `dls.train` / `dls.valid`.'''
        dl = getattr(self.learn, 'dl', None)
        if dl is None: dl = self.learn.dls.train if self.training else self.learn.dls.valid
        # dataloaders without TBPTT support simply have no `rnn_reset` attribute
        if getattr(dl, 'rnn_reset', False) and hasattr(self.model, 'reset'):
            self.model.reset()

    def after_fit(self):
        "Reset the model once fitting has finished."
        if hasattr(self.model, 'reset'): self.model.reset()

## Example

In [None]:
# One-layer GRU; the metric wraps fun_rmse in SkipNLoss(..., 100) — presumably skipping the first
# 100 time steps, see seqdata.learner.
# NOTE(review): with stateful=False the hidden state is presumably not carried from one window to
# the next, which would make the TBPTT resets a no-op — confirm whether stateful=True was intended.
lrn = RNNLearner(db,num_layers=1,rnn_type='gru',stateful=False,metrics=[SkipNLoss(fun_rmse,100)])
# resets the model whenever the dataloader flags the first window of a new minibatch
lrn.add_cb(TbpttResetCB())

<fastai2.learner.Learner at 0x7f27f083a210>

In [None]:
# one epoch with the one-cycle schedule, 100-step TBPTT windows, at most 1000 windows
lrn.fit_one_cycle(1,lr_max=3e-2)

epoch,train_loss,valid_loss,fun_rmse,time
0,1.400621,0.016195,0.09802,00:02


In [None]:
# NOTE: changes the training dataloader in place — len(db.train) and its iteration are capped
# at 100 windows from now on; re-create `db` to undo
db.train.max_batches = 100

In [None]:
# NOTE: in-place change to windows of 10 steps; n_sub_seq is recomputed on next use from the cached seq_len
db.train.sub_seq_len = 10

In [None]:
# train `lrn` again with the modified training dataloader (max_batches=100, sub_seq_len=10)
lrn.fit_one_cycle(1,lr_max=3e-2)

epoch,train_loss,valid_loss,fun_rmse,time
0,0.027535,0.002094,0.016635,00:02


# Weighted Sampling Dataloader

A weighted-sampling dataloader for non-uniformly distributed data. A factory function receives a base dataloader class (a `TfmdDL` subclass) and returns a subclass of it that draws its samples according to per-item weights.

In [None]:
#export
def WeightedDL_Factory(cls):
    '''
    Create a subclass of the dataloader class `cls` (must subclass `TfmdDL`) that draws
    its indices with replacement according to per-item sampling probabilities.

    The returned class accepts an extra argument
    wgts: one non-negative weight per item; normalized so that the weights sum to 1.
          If None, the weights are read from the 'p_sample' key of the items when the
          items are a non-empty list of dicts. Otherwise a message is printed and the
          base class's index order is used.
    Weighted sampling is only applied when `shuffle` is True.
    '''
    assert issubclass(cls, TfmdDL)

    def _normalize(w):
        "Convert `w` to a float array that sums to 1."
        w = np.asarray(w, dtype=float)
        return w / w.sum()

    class WeightedDL(cls):
        def __init__(self, dataset, wgts=None, **kwargs):
            # `get_idxs` reads `self.wgts` and the parent constructor may already call it,
            # so the attribute has to exist before `super().__init__`
            self.wgts = None
            super().__init__(dataset=dataset, **kwargs)
            if wgts is not None:
                self.wgts = _normalize(wgts)
                return
            # not every dataset exposes `items` (e.g. a plain list), and it may be empty
            items = getattr(self, 'items', None)
            if (type(items) is list and len(items) > 0 and
                type(items[0]) is dict and 'p_sample' in items[0]):
                self.wgts = _normalize([x['p_sample'] for x in items])
            else:
                print('No wgts provided for WeightedDL. Was that intentional?')

        def get_idxs(self):
            "Weighted draw of `n` indices with replacement when shuffling; base behavior otherwise."
            if self.n==0: return []
            if not self.shuffle or self.wgts is None: return super().get_idxs()
            return list(np.random.choice(self.n, self.n, p=self.wgts))
    return WeightedDL

In [None]:
# toy example: ten items alternating 1, 2; items at even positions get weight 2, odd positions weight 1
dl = WeightedDL_Factory(TfmdDL)([1,2]*5,bs=10,wgts=[2,1]*5)

In [None]:
# normalized weights: 2/15 and 1/15 alternating
dl.wgts

In [None]:
# NOTE(review): WeightedDL only samples by weight when shuffle=True; TfmdDL presumably defaults
# to shuffle=False, in which case this batch is the items in their original order — TODO confirm
dl.one_batch()

## ItemLst Transform for weight calculation

In [None]:
#export
def uniform_p_of_category(cat_name):
    '''Return a dataframe transform that adds a 'p_sample' column so that every
    category of column `cat_name` gets the same total sampling weight.

    Each row is weighted with 1 / (number of rows in its category); the weights
    are attached with an inner merge on `cat_name`.'''
    def _weight_by_category(df):
        inverse_counts = (1 / df[cat_name].value_counts()).rename('p_sample')
        return df.merge(inverse_counts, left_on=cat_name, right_index=True)

    return _weight_by_category

In [None]:
def train_valid(df):
    '''Test helper: add a boolean 'train' column that is True for rows whose `path`
    contains the substring 'train'.

    Returns a new dataframe; the caller's `df` is left unchanged.'''
    return df.assign(train=df.path.astype(str).str.contains('train', regex=False))

In [None]:
# Label rows as train/valid from their path, cut windows, then weight every window with
# 1/(number of windows with the same 'train' value), so both splits get the same total weight.
tfm_lst = [train_valid, DfHDFCreateWindows(win_sz=1000+1,stp_sz=1000,clm='current') ,uniform_p_of_category('train')]
apply_df_tfms(get_hdf_files('test_data/'),tfm_lst) 

In [None]:
# Same DataBlock as before, but the items now carry 'p_sample' (added by tfm_lst above), which the
# weighted TBPTT dataloader uses as sampling weights; 10-step windows, at most 1 window per epoch.
seq = DataBlock(blocks=(SequenceBlock.from_hdf(['current','voltage'],TensorSequencesInput,clm_shift=[-1,-1]),
                        SequenceBlock.from_hdf(['voltage'],TensorSequencesOutput,clm_shift=[1])),
                 get_items=CreateDict(tfm_lst),
                 splitter=ApplyToDict(ParentSplitter()))
db = seq.dataloaders(get_hdf_files('test_data/'),dl_type=WeightedDL_Factory(TbpttDl),sub_seq_len=10,max_batches=1)

In [None]:
# first five normalized sampling weights of each split
db.train.wgts[:5],db.valid.wgts[:5]

In [None]:
#hide
# export the `#export` cells of all notebooks to the python package (nbdev)
from nbdev.export import *
notebook2script()

Converted 00_core.ipynb.
Converted 01_model.ipynb.
Converted 02_learner.ipynb.
Converted 03_dataloaders.ipynb.
Converted 11_dualrnn.ipynb.
Converted 12_TensorQuaternions.ipynb.
Converted 13_HPOpt.ipynb.
Converted index.ipynb.
