In [1]:
import torch
import torch.utils.data
import random
from collections import defaultdict

def inf_shuffle(xs):
    """Yield the elements of ``xs`` forever, reshuffling before each pass.

    Each pass is a fresh random permutation drawn from the module-level
    ``random`` generator, so every element appears exactly once per pass.
    A private copy is shuffled, so the caller's sequence is never modified.

    Args:
        xs: Any finite iterable of items (list, tuple, ...).

    Yields:
        Items of ``xs`` in an endless sequence of shuffled passes. If ``xs``
        is empty the generator stops immediately instead of looping forever.
    """
    pool = list(xs)  # copy: random.shuffle works in place
    while pool:
        random.shuffle(pool)
        yield from pool

class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields dataset indices cycling round-robin through labels.

    Indices are grouped by label. Iteration then emits one index per label in
    turn (labels ordered by first appearance in the dataset), drawing each
    label's indices from an endlessly reshuffled stream. Smaller classes are
    therefore oversampled so every label is drawn equally often.

    One epoch has ``max_class_size * n_labels`` indices, i.e. enough rounds
    for the largest class to be visited once.

    NOTE(review): despite the name this is an index sampler (it is passed as
    ``sampler=`` to ``DataLoader``), not a ``batch_sampler``.
    """

    def __init__(self, dataset):
        """Scan ``dataset`` once and group its indices by label.

        Args:
            dataset: Indexable dataset with ``__len__``, a ``__getitem__``
                returning ``(input, label)`` with a hashable label, and a
                ``transform`` attribute. ``transform`` is set to ``None``
                while scanning (only labels are needed) and is always
                restored afterwards, even if scanning raises.
                NOTE(review): assumes the dataset skips the input transform
                when ``transform`` is ``None`` — confirm for each dataset used.
        """
        transform = dataset.transform
        dataset.transform = None  # trick to avoid useless computations
        try:
            indices = defaultdict(list)
            for i in range(len(dataset)):
                _, label = dataset[i]
                indices[label].append(i)
        finally:
            # Restore even on failure so the caller's dataset is not left broken.
            dataset.transform = transform

        # One list of indices per label, in order of first appearance.
        self.indices = list(indices.values())
        # Epoch length: enough rounds for the largest class to be seen once.
        # default=0 keeps an empty dataset from crashing max().
        self.n = max((len(ids) for ids in self.indices), default=0) * len(self.indices)

    def __iter__(self):
        """Yield ``len(self)`` indices, one per label per round."""
        m = 0
        for xs in zip(*(inf_shuffle(xs) for xs in self.indices)):
            for i in xs:  # yield one index of each label
                yield i
                m += 1
                if m >= self.n:
                    return

    def __len__(self):
        """Number of indices yielded per epoch."""
        return self.n

In [2]:
import dataset
from astropy.io import fits

def transform(images):
    """Load one sample's four FITS bands and stack them into a single tensor.

    Args:
        images: Sequence of four FITS file paths in the band order
            VIS, J, Y, H (the order the unpacking below assumes).

    Returns:
        Tensor of shape ``(4, H_vis, W_vis)``. The VIS band comes first,
        followed by the J, Y and H bands bilinearly resampled to the VIS
        grid. Each band is first multiplied by its fixed scale factor.
    """
    def load(path):
        # The context manager closes the file handle. With memmap=False the
        # pixel data is read into memory, and astype makes an independent
        # copy, so the array stays valid after the file is closed.
        with fits.open(path, memmap=False) as hdul:
            data = hdul[0].data
            # FITS data is big-endian and torch.from_numpy rejects
            # non-native byte order. Converting to the native-order dtype
            # works on any platform and also for already-native arrays
            # (ndarray.newbyteorder was removed in NumPy 2.0).
            data = data.astype(data.dtype.newbyteorder('='))
        return torch.from_numpy(data)

    # Per-band multiplicative scale factors, in VIS, J, Y, H order.
    # NOTE(review): the origin of these constants is not documented here —
    # presumably chosen to bring the bands to comparable magnitudes; confirm.
    scales = (3.5239e+10, 1.5327e+09, 1.8903e+09, 1.2963e+09)
    vis, j, y, h = [load(path).mul(s) for path, s in zip(images, scales)]

    # Resample J/Y/H onto the VIS grid. Bilinear interpolate needs an
    # (N, C, H, W) input, hence the [None] / squeeze(0). torch.stack requires
    # the J, Y and H images to share one shape.
    jyh = torch.nn.functional.interpolate(
        torch.stack([j, y, h])[None],
        size=tuple(vis.shape),
        mode='bilinear',
        align_corners=True,
    ).squeeze(0)
    return torch.cat([vis[None], jyh], dim=0)

def target_transform(labels):
    """Map a label record to a binary target.

    Returns 1.0 when ``labels['n_sources']`` is truthy (e.g. non-zero)
    and -1.0 otherwise.
    """
    has_sources = bool(labels['n_sources'])
    if has_sources:
        return 1.0
    return -1.0
    
# Build the GG2 dataset and a loader that draws indices round-robin per label.
# NOTE(review): this rebinds ``dataset`` from the module imported above to the
# dataset instance, so ``dataset.GG2`` is unreachable afterwards; re-running
# this statement without the preceding ``import dataset`` will fail.
dataset = dataset.GG2(
    '~/.torchvision/datasets/GG2',
    transform=transform,
    target_transform=target_transform,
)
loader = torch.utils.data.DataLoader(
    dataset,
    sampler=BalancedBatchSampler(dataset),
    batch_size=10,
)

In [3]:
# Sanity check: pull a single batch and show its shape and targets.
# next(iter(...)) raises StopIteration on an empty loader instead of silently
# printing nothing, as a for/break loop would.
batch = next(iter(loader))
x, y = batch
assert len(x) == len(y), "inputs and targets must have the same batch size"
print(x.shape)
print(y)

torch.Size([10, 4, 200, 200])
tensor([ 1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.], dtype=torch.float64)
