In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import StratifiedGroupKFold


class StratiSampler:
    """Iterate over stratified, group-aware index splits of a label array.

    The sampler wraps ``StratifiedGroupKFold`` configured with
    ``int(len(y) / folds)`` splits, so each test split holds roughly
    ``folds`` samples (as far as the group constraint allows). Iterating
    the sampler yields the ``(train_index, test_index)`` pairs produced by
    the splitter.

    Args:
        y: 1D labels. A torch tensor is detached, moved to the CPU and
            converted to a NumPy array; any other 1D array-like is stored
            as given.
        groups: Group label per sample, forwarded to the splitter as-is.
        folds: Target number of samples per test split; also exposed as
            ``batch_size``.
        shuffle: Whether the splitter shuffles. When true, the splitter is
            reseeded at the start of every iteration.

    Raises:
        ValueError: If ``y`` is not one-dimensional.
    """

    def __init__(self, y, groups, folds, shuffle=True):
        if torch.is_tensor(y):
            # The conversion returns a new object; it must be re-bound or the
            # tensor (possibly on a GPU) is what ends up stored and split.
            y = y.detach().cpu().numpy()
        if np.ndim(y) != 1:
            # Explicit raise instead of `assert`, which vanishes under -O.
            raise ValueError("y must be a 1D tensor")

        self.y = y
        self.groups = groups
        self.shuffle = shuffle
        self.batch_size = folds

        # One split per "batch" of about `folds` samples.
        n_batches = int(len(y) / folds)
        self.skf = StratifiedGroupKFold(n_splits=n_batches, shuffle=shuffle)

    def __iter__(self):
        """Yield ``(train_index, test_index)`` for every split."""
        if self.shuffle:
            # Fresh seed per pass so repeated iterations give different
            # splits. Only set when shuffling: the splitter is created with
            # the same `shuffle` flag, and a seed is meaningless without it.
            self.skf.random_state = np.random.randint(0, 1000)
        # The labels double as the feature argument; only their length and
        # values are available here, and the splitter needs nothing more.
        yield from self.skf.split(self.y, self.y, self.groups)

    def __len__(self):
        """Return the number of samples (not the number of splits)."""
        return len(self.y)
    


