In [None]:
from wilds.datasets.civilcomments_dataset import CivilCommentsDataset

In [1]:
from civilcomments import DWCivilCommentsDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root folder that holds (or will receive) the CivilComments download.
data_dir = "data"
dataset = DWCivilCommentsDataset(
    root_dir=data_dir,
    download=True,
)

  self._identity_array = torch.LongTensor(self._identity_array)


In [3]:
import torch
from transformers import DistilBertTokenizerFast

class MyDistilBertTokenizer(DistilBertTokenizerFast):
    """DistilBERT fast tokenizer whose call returns one stacked tensor.

    Calling an instance fills in fixed-length defaults (``padding='max_length'``,
    ``max_length=300``, ``truncation=True``, ``return_tensors='pt'``) for any of
    those keywords the caller did not pass. It then stacks ``input_ids`` and
    ``attention_mask`` along dim 2 and squeezes dim 0 (a size-1 leading axis,
    presumably the batch axis for a single string -- TODO confirm).
    """

    def __init__(self, *args, **kwargs):
        # Explicit pass-through constructor, kept so the signature seen by
        # from_pretrained is unchanged.
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Caller-provided values always win; these are fallbacks only.
        fallbacks = (
            ('padding', 'max_length'),
            ('max_length', 300),
            ('truncation', True),
            ('return_tensors', 'pt'),
        )
        for option, value in fallbacks:
            kwargs.setdefault(option, value)

        encoded = super().__call__(*args, **kwargs)

        # Pair token ids with their mask along a new trailing axis.
        paired = torch.stack([encoded["input_ids"], encoded["attention_mask"]], dim=2)
        return paired.squeeze(0)

In [4]:
class GroupSampler:
    """
        Constructs batches by first sampling groups,
        then sampling data from those groups.
        It drops the last batch if it's incomplete.

        Each batch holds ``n_groups_per_batch`` groups with
        ``batch_size // n_groups_per_batch`` examples drawn from each group.

        Args:
            - group_ids: per-example group ids, passed to ``split_into_groups``
              (presumably a 1-D torch tensor -- TODO confirm against the grouper).
            - batch_size (int): examples per batch; must be a multiple of n_groups_per_batch.
            - n_groups_per_batch (int): groups sampled per batch; must be positive.
            - uniform_over_groups (bool): if True, every present group is equally likely;
              otherwise a group is chosen with probability proportional to its size.
            - distinct_groups (bool): if True, the groups within one batch are sampled
              without replacement.
        Raises:
            - ValueError: if n_groups_per_batch is not positive, batch_size is not a
              multiple of it, the data cannot fill one batch, or distinct groups are
              requested but fewer than n_groups_per_batch groups occur in group_ids.
    """
    def __init__(self, group_ids, batch_size, n_groups_per_batch,
                 uniform_over_groups, distinct_groups):

        if n_groups_per_batch <= 0:
            raise ValueError(f'n_groups_per_batch must be positive, got {n_groups_per_batch}.')
        if batch_size % n_groups_per_batch != 0:
            raise ValueError(f'batch_size ({batch_size}) must be evenly divisible by n_groups_per_batch ({n_groups_per_batch}).')
        if len(group_ids) < batch_size:
            raise ValueError(f'The dataset has only {len(group_ids)} examples but the batch size is {batch_size}. There must be enough examples to form at least one complete batch.')

        self.group_ids = group_ids
        self.unique_groups, self.group_indices, unique_counts = split_into_groups(group_ids)

        # Only groups that actually occur in group_ids can be drawn. Without this
        # check, asking for more distinct groups than exist fails later, inside
        # np.random.choice, when the first batch is requested.
        if distinct_groups and n_groups_per_batch > len(self.unique_groups):
            raise ValueError(
                f'n_groups_per_batch ({n_groups_per_batch}) exceeds the number of groups '
                f'present in the data ({len(self.unique_groups)}), so distinct groups cannot be sampled.')

        self.distinct_groups = distinct_groups
        self.n_groups_per_batch = n_groups_per_batch
        self.n_points_per_group = batch_size // n_groups_per_batch

        self.dataset_size = len(group_ids)
        # Floor division: the incomplete final batch is dropped.
        self.num_batches = self.dataset_size // batch_size

        if uniform_over_groups: # Sample uniformly over groups
            self.group_prob = None
        else: # Sample a group proportionately to its size
            counts = unique_counts.numpy()
            self.group_prob = counts / counts.sum()

    def __iter__(self):
        n_unique = len(self.unique_groups)
        for _ in range(self.num_batches):
            # Note that we are selecting group indices rather than groups
            groups_for_batch = np.random.choice(
                n_unique,
                size=self.n_groups_per_batch,
                replace=(not self.distinct_groups),
                p=self.group_prob)
            sampled_ids = [
                np.random.choice(
                    self.group_indices[group],
                    size=self.n_points_per_group,
                    # Sample with replacement only when the group is too small to
                    # supply n_points_per_group distinct examples.
                    replace=len(self.group_indices[group]) < self.n_points_per_group,
                    p=None)
                for group in groups_for_batch]

            # Flatten the per-group samples into one batch of example indices.
            yield np.concatenate(sampled_ids)

    def __len__(self):
        return self.num_batches


In [5]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# Default number of DataLoader worker processes used by get_train_loader.
NUM_WORKERS: int = 3


def get_train_loader(loader, dataset, batch_size,
        uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, **loader_kwargs):
    """
    Constructs and returns the data loader for training.
    Args:
        - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders,
                        which first samples groups and then samples a fixed number of examples belonging
                        to each group.
        - dataset (WILDSDataset or WILDSSubset): Data
        - batch_size (int): Batch size
        - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according
                                              to the natural data distribution.
                                              Setting to None applies the defaults for each type of loaders.
                                              For standard loaders, the default is False. For group loaders,
                                              the default is True.
        - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True
        - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders.
        - n_groups_per_batch (int): Number of groups to sample in each minibatch for group loaders.
        - loader_kwargs: kwargs passed into torch DataLoader initialization. A `num_workers`
                         entry here overrides the NUM_WORKERS default.
    Output:
        - data loader (DataLoader): Data loader.
    Raises:
        - ValueError: if `loader` is neither 'standard' nor 'group', or if n_groups_per_batch
                      exceeds grouper.n_groups for a group loader.
    """
    # Previously an unrecognised loader type fell through and returned None.
    if loader not in ('standard', 'group'):
        raise ValueError(f"Unknown loader type {loader!r}; expected 'standard' or 'group'.")

    # Honour a caller-supplied num_workers instead of failing with a
    # duplicate-keyword TypeError when it is also passed explicitly below.
    if 'num_workers' not in loader_kwargs:
        loader_kwargs['num_workers'] = NUM_WORKERS

    if loader == 'standard':
        if not uniform_over_groups:  # None or False: plain shuffled loader
            return DataLoader(
                dataset,
                shuffle=True, # Shuffle training dataset
                sampler=None,
                collate_fn=dataset.collate,
                batch_size=batch_size,
                **loader_kwargs)

        assert grouper is not None, 'grouper is required when uniform_over_groups=True'
        groups, group_counts = grouper.metadata_to_group(
            dataset.metadata_array,
            return_counts=True)
        # Inverse-frequency weights: every group's examples sum to the same total weight,
        # so each group is drawn equally often in expectation.
        group_weights = 1 / group_counts
        weights = group_weights[groups]

        # Replacement needs to be set to True, otherwise we'll run out of minority samples
        sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
        return DataLoader(
            dataset,
            shuffle=False, # The WeightedRandomSampler already shuffles
            sampler=sampler,
            collate_fn=dataset.collate,
            batch_size=batch_size,
            **loader_kwargs)

    # loader == 'group'
    if uniform_over_groups is None:
        uniform_over_groups = True
    assert grouper is not None, 'grouper is required for group loaders'
    assert n_groups_per_batch is not None, 'n_groups_per_batch is required for group loaders'
    if n_groups_per_batch > grouper.n_groups:
        raise ValueError(f'n_groups_per_batch was set to {n_groups_per_batch} but there are only {grouper.n_groups} groups specified.')

    group_ids = grouper.metadata_to_group(dataset.metadata_array)
    batch_sampler = GroupSampler(
        group_ids=group_ids,
        batch_size=batch_size,
        n_groups_per_batch=n_groups_per_batch,
        uniform_over_groups=uniform_over_groups,
        distinct_groups=distinct_groups)

    # The batch_sampler builds whole batches, so shuffle/sampler/drop_last stay at
    # their "off" values and batch_size is not passed to DataLoader.
    return DataLoader(dataset,
          shuffle=None,
          sampler=None,
          collate_fn=dataset.collate,
          batch_sampler=batch_sampler,
          drop_last=False,
          **loader_kwargs)


In [6]:
# Fixed-length DistilBERT tokenizer (the MyDistilBertTokenizer subclass) built from
# the pretrained 'distilbert-base-uncased' vocabulary.
# NOTE(review): transformers warns that the checkpoint names a different tokenizer
# class; presumably harmless because MyDistilBertTokenizer subclasses
# DistilBertTokenizerFast -- confirm the tokenization matches.
tokenizer = MyDistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased"
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'MyDistilBertTokenizer'.


In [10]:
from wilds.common.utils import split_into_groups

In [7]:
# Training split, with the tokenizer supplied as the per-example transform.
train_dataset = dataset.get_subset(
    "train",
    transform=tokenizer,
)

In [8]:
# Number of examples in the training subset.
print(f"{len(train_dataset)}")

309220


In [11]:
# Group loader: 16 groups per batch with 16 // 16 = 1 example from each. Groups
# are drawn in proportion to their size (uniform_over_groups=False) and, with the
# default distinct_groups=True, without repeats inside a batch.
train_loader = get_train_loader(
    "group",
    train_dataset,
    batch_size=16,
    uniform_over_groups=False,
    grouper=dataset.train_grouper,
    n_groups_per_batch=16,
)

In [1]:
import pandas as pd
import os

In [2]:
# Read the raw CivilComments metadata CSV directly, using its first column as the index.
_data_dir = "data/civilcomments_v1.0"
_metadata_csv = os.path.join(_data_dir, "all_data_with_identities.csv")
_metadata_df = pd.read_csv(_metadata_csv, index_col=0)

In [3]:
# Keep only rows that have comment text. dropna returns a new frame, so
# _metadata_df itself still contains the rows with a missing comment_text.
# NOTE(review): df_cleaned is not read by the cells that follow in this view
# (they use _metadata_df) -- confirm which frame was intended.
df_cleaned = _metadata_df.dropna(
    axis="index",
    subset=["comment_text"],
)

In [4]:
# Per-row split labels as a plain NumPy array. `.to_numpy()` always returns an
# ndarray, whereas `.values` can return a pandas ExtensionArray for extension
# dtypes (e.g. the pandas string dtype).
# NOTE(review): this reads _metadata_df, not df_cleaned, so rows with a missing
# comment_text are included -- confirm that is intended.
_split_array = _metadata_df['split'].to_numpy()

In [5]:
# Show which container type holds the split labels.
print(_split_array.__class__)

<class 'numpy.ndarray'>
