In [72]:
import gc
import glob
import random
import torch

#from others.logging import logger

In [62]:
import sys
import argparse
import torch

In [None]:
class Batch(object):
    """A padded minibatch of extractive-summarization examples moved to a device.

    Each example is a tuple ``(src, labels, segs, clss[, src_txt, tgt_txt])``
    (the layout produced by ``DataIterator.preprocess``). The two trailing
    text fields are only read when ``is_test`` is true.
    """

    def _pad(self, data, pad_id, width=-1):
        """Right-pad every list in ``data`` with ``pad_id`` up to ``width``.

        When ``width`` is -1 the longest list in ``data`` sets the width
        (0 for empty ``data``). Lists already longer than ``width`` are
        returned unchanged.
        """
        if width == -1:
            width = max((len(d) for d in data), default=0)
        return [d + [pad_id] * (width - len(d)) for d in data]

    def __init__(self, data=None, device=None, is_test=False):
        """Create a Batch from a list of examples.

        Args:
            data: list of example tuples (see class docstring), or None for an
                empty placeholder batch (``len() == 0``).
            device: torch device every tensor is moved to.
            is_test: when true, also keep ``src_str`` / ``tgt_str`` taken from
                the last two fields of each example.
        """
        self.batch_size = 0
        if data is None:
            return

        self.batch_size = len(data)
        pre_src = [x[0] for x in data]
        pre_labels = [x[1] for x in data]
        pre_segs = [x[2] for x in data]
        pre_clss = [x[3] for x in data]

        src = torch.tensor(self._pad(pre_src, 0))
        labels = torch.tensor(self._pad(pre_labels, 0))
        segs = torch.tensor(self._pad(pre_segs, 0))
        # src is padded with 0 above; mask is True on non-pad positions.
        # `1 - bool_tensor` raises on PyTorch >= 1.2, so negate with `~`.
        mask = ~(src == 0)

        clss = torch.tensor(self._pad(pre_clss, -1))
        # -1 marks padded sentence slots: record them in mask_cls, then point
        # the padded slots at position 0 so clss stays a valid index tensor.
        mask_cls = ~(clss == -1)
        clss[clss == -1] = 0

        self.clss = clss.to(device)
        self.mask_cls = mask_cls.to(device)
        self.src = src.to(device)
        self.labels = labels.to(device)
        self.segs = segs.to(device)
        self.mask = mask.to(device)

        if is_test:
            self.src_str = [x[-2] for x in data]
            self.tgt_str = [x[-1] for x in data]

    def __len__(self):
        """Number of examples in the batch (0 for a placeholder batch)."""
        return self.batch_size

In [None]:
def batch(data, batch_size, batch_size_fn=None):
    """Yield elements from data in chunks whose measured size fits batch_size.

    Args:
        data: iterable of examples.
        batch_size: size budget for one chunk.
        batch_size_fn: ``fn(example, count)`` returning the chunk size after
            ``example`` became its ``count``-th element. Defaults to
            ``simple_batch_size_fn`` (looked up at call time because it is
            defined in a later cell).

    A chunk is emitted as soon as it exactly reaches ``batch_size``. When an
    example pushes a chunk over budget, the chunk without that example is
    emitted and the example starts the next chunk. An example that is over
    budget on its own never produces an empty chunk.
    """
    if batch_size_fn is None:
        batch_size_fn = simple_batch_size_fn
    minibatch = []
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch))
        if size_so_far == batch_size:
            yield minibatch
            minibatch = []
        elif size_so_far > batch_size:
            # Flush everything before `ex`; skip the flush when `ex` is alone,
            # which previously yielded an empty list.
            if len(minibatch) > 1:
                yield minibatch[:-1]
            minibatch = minibatch[-1:]
            # Called for its side effect: re-seeds the size function's running
            # state for the new chunk that starts with `ex`.
            batch_size_fn(ex, 1)
    if minibatch:
        yield minibatch

In [None]:
def load_dataset(args, corpus_type, shuffle):
    """Lazily yield the preprocessed shards of one corpus split.

    Shards are files matching ``<args.bert_data_path>.<corpus_type>.<N>*.pt``
    and are yielded in numeric shard order (or randomly when ``shuffle``).
    If no shard matches, the single file ``<args.bert_data_path>.<corpus_type>.pt``
    is loaded instead.

    Args:
        args: object with a ``bert_data_path`` attribute (path prefix).
        corpus_type: one of "train", "valid", "test".
        shuffle: randomise shard order with ``random.shuffle``.

    Yields:
        The object ``torch.load`` returns for each file, one file at a time.
    """
    import logging
    import re

    assert corpus_type in ["train", "valid", "test"]
    # The module-level logger import is commented out in this notebook, so use
    # a standard-library logger instead of the undefined name `logger`.
    logger = logging.getLogger(__name__)

    def _lazy_dataset_loader(pt_file, corpus_type):
        # NOTE(security): torch.load unpickles the file; only load trusted data.
        dataset = torch.load(pt_file)
        logger.info('Loading %s dataset from %s, number of examples: %d',
                    corpus_type, pt_file, len(dataset))
        return dataset

    prefix = args.bert_data_path + '.' + corpus_type + '.'

    def _shard_index(path):
        # Numeric shard id right after the prefix, e.g. 132 in
        # "<prefix>132.bert.pt"; -1 when it cannot be parsed.
        if not path.startswith(prefix):
            return -1
        match = re.match(r'\d+', path[len(prefix):])
        return int(match.group()) if match else -1

    # Sort by numeric shard index (0, 1, 2, ..., 10) rather than by file name,
    # which would give 0, 1, 10, 100, ... . Escape the prefix so glob
    # metacharacters in the data path are matched literally.
    pts = sorted(glob.glob(glob.escape(prefix) + '[0-9]*.pt'),
                 key=lambda path: (_shard_index(path), path))
    if pts:
        if shuffle:
            random.shuffle(pts)

        for pt in pts:
            yield _lazy_dataset_loader(pt, corpus_type)
    else:
        # No numbered shards: fall back to a single unsharded file.
        pt = args.bert_data_path + '.' + corpus_type + '.pt'
        yield _lazy_dataset_loader(pt, corpus_type)

In [None]:
def simple_batch_size_fn(new, count):
    """Return the padded size of a minibatch whose ``count``-th example is ``new``.

    The size is ``count`` times the longest ``src`` sequence (``new[0]``) seen
    since the minibatch started. The running maxima live in module globals,
    so calls must follow the minibatch order: ``count == 1`` starts a new
    minibatch and resets them.
    """
    global max_n_sents, max_n_tokens, max_size
    src, _labels = new[0], new[1]  # new[1] is read but not otherwise used
    if count == 1:
        # First example of a fresh minibatch: reset every running maximum.
        max_size = max_n_sents = max_n_tokens = 0
    # Despite its name, max_n_sents tracks the longest len(src) seen so far.
    max_n_sents = max(len(src), max_n_sents)
    max_size = max(max_n_sents, max_size)
    return max_size * count

In [None]:
class Dataloader(object):
    """Stream batches from a sequence of datasets (shards), one shard at a time.

    ``datasets`` may be any iterable of datasets, e.g. the generator returned
    by ``load_dataset`` or a plain list. Each dataset is wrapped in a
    ``DataIterator``. Construction fails with AssertionError if ``datasets``
    is empty.
    """

    def __init__(self, args, datasets,  batch_size,
                 device, shuffle, is_test):
        self.args = args
        self.datasets = datasets
        # One shared iterator: __init__ takes the first dataset and __iter__
        # continues from the next. iter() also makes plain lists work, where
        # next(list) previously raised TypeError.
        self._dataset_iter = iter(datasets)
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.is_test = is_test
        self.cur_iter = self._next_dataset_iterator(self._dataset_iter)

        assert self.cur_iter is not None

    def __iter__(self):
        """Yield every batch of the current dataset, then of each remaining one."""
        while self.cur_iter is not None:
            for batch in self.cur_iter:
                yield batch
            # Release the exhausted iterator (which references its dataset)
            # before the next dataset is loaded, so the old one can be freed.
            self.cur_iter = None
            self.cur_iter = self._next_dataset_iterator(self._dataset_iter)

    def _next_dataset_iterator(self, dataset_iter):
        """Advance ``dataset_iter`` and wrap the next dataset in a DataIterator.

        Returns None once ``dataset_iter`` is exhausted.
        """
        try:
            # Drop the current dataset for decreasing memory before the next
            # one is pulled (and, for load_dataset, read from disk).
            if hasattr(self, "cur_dataset"):
                self.cur_dataset = None
                gc.collect()
                del self.cur_dataset
                gc.collect()

            self.cur_dataset = next(dataset_iter)
        except StopIteration:
            return None

        return DataIterator(args=self.args,
                            dataset=self.cur_dataset, batch_size=self.batch_size,
                            device=self.device, shuffle=self.shuffle,
                            is_test=self.is_test)


class DataIterator(object):
    """Turn one in-memory dataset (a list of example dicts) into ``Batch`` objects.

    Examples are gathered into large buffers (budget ``batch_size * 50`` as
    measured by ``simple_batch_size_fn``). Each buffer is sorted by
    ``len(clss)``, cut into minibatches with ``batch``, and optionally
    shuffled.
    """

    def __init__(self, args, dataset,  batch_size,  device=None, is_test=False,
                 shuffle=True):
        self.args = args
        self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset
        self.iterations = 0
        self.device = device
        self.shuffle = shuffle

        self.sort_key = lambda x: len(x[1])

        # Batches already produced this epoch; __iter__ skips that many
        # ("fast-forward", e.g. after restoring saved state).
        self._iterations_this_epoch = 0

    def data(self):
        """Return the dataset, shuffled in place first when ``self.shuffle`` is set."""
        if self.shuffle:
            random.shuffle(self.dataset)
        return self.dataset

    def preprocess(self, ex, is_test):
        """Convert one raw example dict into the tuple layout ``Batch`` expects.

        Reads ``src``, ``segs``, ``clss``, ``src_txt``, ``tgt_txt`` and the labels
        (``labels``, falling back to ``src_sent_labels``). Segment ids are
        replaced by zeros unless ``self.args.use_interval`` is truthy.

        Returns:
            ``(src, labels, segs, clss, src_txt, tgt_txt)`` when ``is_test``,
            otherwise ``(src, labels, segs, clss)``.
        """
        src = ex['src']
        labels = ex['labels'] if 'labels' in ex else ex['src_sent_labels']

        segs = ex['segs']
        if not self.args.use_interval:
            segs = [0] * len(segs)
        clss = ex['clss']
        src_txt = ex['src_txt']
        tgt_txt = ex['tgt_txt']

        if is_test:
            return src, labels, segs, clss, src_txt, tgt_txt
        return src, labels, segs, clss

    def batch_buffer(self, data, batch_size):
        """Yield lists of preprocessed examples within ``batch_size``.

        Sizes are measured by ``simple_batch_size_fn``. Examples with an empty
        ``src`` are skipped.
        """
        minibatch = []
        for ex in data:
            if len(ex['src']) == 0:
                continue
            ex = self.preprocess(ex, self.is_test)
            if ex is None:
                continue
            minibatch.append(ex)
            size_so_far = simple_batch_size_fn(ex, len(minibatch))
            if size_so_far == batch_size:
                yield minibatch
                minibatch = []
            elif size_so_far > batch_size:
                # Flush everything before `ex`; never yield an empty buffer
                # when `ex` alone exceeds the budget.
                if len(minibatch) > 1:
                    yield minibatch[:-1]
                minibatch = minibatch[-1:]
                # Re-seed the size function's global running state for the
                # buffer that now starts with `ex`.
                simple_batch_size_fn(ex, 1)
        if minibatch:
            yield minibatch

    def create_batches(self):
        """Yield minibatches (lists of example tuples) for one pass over the data."""
        data = self.data()
        for buffer in self.batch_buffer(data, self.batch_size * 50):
            # Sort by number of sentences (len(clss)) so minibatches cut from
            # the buffer hold examples of similar length.
            p_batch = sorted(buffer, key=lambda x: len(x[3]))
            p_batch = list(batch(p_batch, self.batch_size))
            if self.shuffle:
                random.shuffle(p_batch)
            for b in p_batch:
                yield b

    def __iter__(self):
        """Yield ``Batch`` objects for one epoch, skipping already-seen batches."""
        self.batches = self.create_batches()
        for idx, minibatch in enumerate(self.batches):
            # fast-forward if loaded from state
            if self._iterations_this_epoch > idx:
                continue
            self.iterations += 1
            self._iterations_this_epoch += 1
            yield Batch(minibatch, self.device, self.is_test)
        # Epoch finished: reset so a later pass starts from the first batch
        # instead of skipping every batch of the new epoch.
        self._iterations_this_epoch = 0

In [None]:
# Outline of the Batch class defined earlier in this notebook (bodies elided).
# Kept as comments: as live code this cell does not compile (`__init__` has no
# body), and if it did, the stub would shadow the working class above.
#
# class Batch(object):
#     def _pad(self, data, pad_id, width=-1):
#         return rtn_data
#
#     def __init__(self, data=None, device=None,  is_test=False):
#         ...
#
#     def __len__(self):
#         return self.batch_size

In [None]:
# Outline of the helper functions defined earlier in this notebook (bodies
# elided). Kept as comments: executed as code, these stubs would replace the
# working `batch`, `load_dataset` and `simple_batch_size_fn` with versions
# that raise NameError (`minibatch`, `_lazy_dataset_loader`, `pt` and
# `src_elements` are undefined inside them).
#
# def batch(data, batch_size):
#     yield minibatch
#
# def load_dataset(args, corpus_type, shuffle):
#     yield _lazy_dataset_loader(pt, corpus_type)
#
# def simple_batch_size_fn(new, count):
#     return src_elements

In [None]:
# Outline of the Dataloader and DataIterator classes defined earlier in this
# notebook (bodies elided). Kept as comments: as live code this cell does not
# compile (`DataIterator.__init__` has no body), and if it did, the stubs would
# shadow the working classes above.
#
# class Dataloader(object):
#     def __init__(self, args, datasets,  batch_size, device, shuffle, is_test):
#         assert self.cur_iter is not None
#
#     def __iter__(self):
#         yield batch
#
#     def _next_dataset_iterator(self, dataset_iter):
#         return DataIterator(args = self.args,  dataset=self.cur_dataset,  batch_size=self.batch_size,
#             device=self.device, shuffle=self.shuffle, is_test=self.is_test)
#
# class DataIterator(object):
#     def __init__(self, args, dataset,  batch_size,  device=None, is_test=False, shuffle=True):
#         ...
#
#     def data(self):
#         return xs
#
#     def preprocess(self, ex, is_test):
#         return src,labels,segs, clss
#
#     def batch_buffer(self, data, batch_size):
#         yield minibatch
#
#     def create_batches(self):
#         yield b
#
#     def __iter__(self):
#         batch = Batch(minibatch, self.device, self.is_test)
#         yield batch

In [68]:
# Locate the preprocessed training shards.
# `args` is not created in any cell above, so build it here; otherwise the
# attribute assignment fails with NameError on a fresh kernel.
# `use_interval` is only read by DataIterator.preprocess; True is an assumption.
# TODO: the absolute path is machine-specific; point it at your own data.
BERT_DATA_PATH = '/home/alvin/workspace/BertSum/bert_data/cnndm'
corpus_type = 'train'
args = argparse.Namespace(bert_data_path=BERT_DATA_PATH, use_interval=True)
pts = sorted(glob.glob(args.bert_data_path + '.' + corpus_type + '.[0-9]*.pt'))
# Only a cell's last expression is displayed, so print the path explicitly.
print(args.bert_data_path)
pts

['/home/alvin/workspace/BertSum/bert_data/cnndm.train.0.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.1.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.10.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.100.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.101.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.102.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.103.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.104.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.105.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.106.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.107.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.108.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.109.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.11.bert.pt',
 '/home/alvin/workspace/BertSum/bert_data/cnndm.train.

In [74]:
# Shuffle the shard list with a fixed seed so re-running the notebook picks
# the same shard. This shuffles `pts` in place; the next cell loads pts[0].
SHARD_SHUFFLE_SEED = 0
assert pts, 'no shards found; check args.bert_data_path'
random.Random(SHARD_SHUFFLE_SEED).shuffle(pts)
print(type(pts))
print(pts[0])


<class 'list'>
/home/alvin/workspace/BertSum/bert_data/cnndm.train.132.bert.pt


In [75]:
# Load the first shard of the (shuffled) list. torch.load unpickles the file,
# so only open shards from a trusted source.
dataset = torch.load(f=pts[0])

In [101]:
# Inspect the container types of the loaded shard and of its first example.
print(f'type of dataset = {type(dataset)}')
print(f'type of dataset[0] = {type(dataset[0])}')
print(f'keys of dataset[0] = {dataset[0].keys()}')

type of dataset = <class 'list'>
type of dataset[0] = <class 'dict'>
keys of dataset[0] = dict_keys(['src', 'labels', 'segs', 'clss', 'src_txt', 'tgt_txt'])


In [90]:
# Length of the first example's `src_txt` field (28 in the saved output);
# presumably one entry per source sentence — TODO confirm against `clss`.
len(dataset[0]['src_txt'])

28

In [97]:
# Print every field of the first example under its own banner line.
print(f"---------------------------------- src\n {dataset[0]['src']}")
print(f"---------------------------------- label\n {dataset[0]['labels']}")
print(f"---------------------------------- seg\n {dataset[0]['segs']}")
print(f"---------------------------------- clss\n {dataset[0]['clss']}")
print(f"---------------------------------- src_txt\n {dataset[0]['src_txt']}")
print(f"---------------------------------- tgt_txt\n {dataset[0]['tgt_txt']}")

---------------------------------- src
 [101, 1996, 1057, 1012, 1055, 1012, 3187, 1997, 2740, 2758, 2045, 2024, 3497, 1036, 2062, 3572, 1005, 1997, 1041, 24290, 2006, 2149, 5800, 1998, 2008, 3199, 11326, 2013, 6712, 3032, 2003, 6827, 1999, 10723, 1996, 3659, 1012, 102, 101, 1036, 2057, 2018, 2028, 2553, 1998, 1045, 2228, 2045, 2089, 2022, 2060, 3572, 1010, 1998, 1045, 2228, 2057, 2031, 2000, 6807, 2008, 2004, 1037, 3842, 1010, 1005, 20934, 2099, 4381, 2056, 2012, 1037, 2865, 6350, 2651, 1012, 102, 101, 20934, 2099, 4381, 4208, 2006, 3199, 11326, 1998, 2056, 2348, 2009, 2001, 2025, 2531, 3867, 4621, 2016, 2018, 7023, 1999, 1996, 3653, 3540, 13700, 1012, 102, 101, 2062, 3497, 2062, 3572, 1024, 2740, 1998, 2529, 2578, 3187, 13378, 20934, 2099, 4381, 2038, 2056, 2008, 3572, 1997, 1041, 24290, 2089, 2525, 2022, 1999, 1996, 2142, 2163, 1998, 2008, 29174, 1997, 1996, 4295, 2003, 6827, 2005, 9740, 102, 101, 26629, 2472, 1024, 7458, 3604, 2000, 1998, 2013, 2225, 3088, 2035, 2362, 2097, 4652, 29