In [1]:
from torch.utils.data import Sampler

In [2]:
class SkipBatchSampler(Sampler):
    """Batch indices from ``sampler`` while discarding the first ``skip_batches`` batches.

    Indices are drawn from ``sampler`` in order and grouped into lists of
    ``batch_size``. The first ``skip_batches`` full batches are consumed but not
    yielded, which is useful for resuming an epoch part-way through. A trailing
    partial batch is yielded only if all requested batches were already skipped
    and ``drop_last`` is False.

    Args:
        sampler: Iterable of indices. It must also support ``len()`` if
            ``len(self)`` is called.
        batch_size: Number of indices per batch; must be positive.
        skip_batches: Number of leading full batches to discard; must be >= 0.
        drop_last: If True, never yield a trailing batch smaller than
            ``batch_size`` (mirrors ``torch.utils.data.BatchSampler``).

    Raises:
        ValueError: If ``batch_size <= 0`` or ``skip_batches < 0``.
    """

    def __init__(self, sampler, batch_size, skip_batches=0, drop_last=False):
        # Reject values that would make __iter__ and __len__ disagree
        # (negative skip) or break __len__ outright (zero batch size).
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        if skip_batches < 0:
            raise ValueError(f"skip_batches must be non-negative, got {skip_batches!r}")
        self.sampler = sampler
        self.batch_size = batch_size
        self.skip_batches = skip_batches
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of indices, skipping the first ``skip_batches`` full batches."""
        batch = []
        # Number of full batches formed so far (skipped or yielded).
        batch_index = 0
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                if batch_index >= self.skip_batches:
                    yield batch
                batch_index += 1
                # Start a fresh list so previously yielded batches are not mutated.
                batch = []

        # The partial tail is only reachable once every requested skip was consumed.
        if batch and not self.drop_last and batch_index >= self.skip_batches:
            yield batch

    def __len__(self):
        """Return the number of batches ``__iter__`` will yield (requires ``len(sampler)``)."""
        num_indices = len(self.sampler)
        if self.drop_last:
            total_batches = num_indices // self.batch_size
        else:
            # Ceiling division: the trailing partial batch counts as one batch.
            total_batches = (num_indices + self.batch_size - 1) // self.batch_size
        return max(0, total_batches - self.skip_batches)

In [3]:
# Skip the first two batches of size 3 from 10 indices: [0, 1, 2] and [3, 4, 5]
# are discarded, leaving one full batch and the trailing partial batch.
sampler = list(range(10))
skipper = SkipBatchSampler(sampler, batch_size=3, skip_batches=2)

for batch in skipper:
    print(batch)

# Pin the expected result so Restart & Run All catches regressions, and check
# that __len__ agrees with the number of batches actually produced.
assert list(skipper) == [[6, 7, 8], [9]]
assert len(skipper) == len(list(skipper))

[6, 7, 8]
[9]
