In [1]:
from wilds.common.data_loaders import get_eval_loader, get_train_loader
from wilds.common.grouper import CombinatorialGrouper
from models.initializer import get_dataset
import torchvision.transforms as transforms

In [19]:
model_transforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Lambda(lambda image: image.convert('RGB')),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

full_dataset = get_dataset(dataset='waterbirds',
                           root_dir='/media/SSD2/Dataset',
                           download=True,
                           split_scheme='official',
                           seed=11111111)

train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=['generic-spurious', 'y']
    )

data = full_dataset.get_subset('test',frac=1.0,transform=model_transforms)
loader = get_train_loader(loader='standard', 
                          dataset=data, 
                          batch_size=10,
                          uniform_over_groups=True, 
                          grouper=train_grouper,
                          n_groups_per_batch=4)

print(len(data[0]))
                        
for batch in loader:
    print(batch[0].shape)
    print(batch[1]) 
    print(batch[2])
    print() 
    break                  

tensor([0.0016, 0.0016, 0.0016,  ..., 0.0004, 0.0004, 0.0004])
Using WeightedRandomSampler with 5794 groups
3
torch.Size([10, 3, 224, 224])
tensor([1, 0, 1, 1, 1, 1, 0, 0, 0, 1])
tensor([[1, 1, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 1, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 0],
        [1, 1, 0]])



In [9]:
full_dataset.metadata_array

tensor([[1, 1, 0],
        [1, 1, 1],
        [0, 1, 0],
        ...,
        [0, 0, 0],
        [0, 0, 1],
        [1, 0, 0]])

In [8]:
loader = get_train_loader(loader='standard', 
                          dataset=data, 
                          batch_size=10,
                          uniform_over_groups=False, 
                          grouper=train_grouper,
                          n_groups_per_batch=4)

print(len(data[0]))
                        
for batch in loader:
    print(batch[0].shape)
    print(batch[1]) 
    print(batch[2])
    break      

3
torch.Size([10, 3, 224, 224])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([[1, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1]])


In [14]:
groups, group_counts = train_grouper.metadata_to_group(
                data.metadata_array,
                return_counts=True)
print(groups[:-10])
print(len(group_counts))

tensor([3, 3, 3,  ..., 0, 0, 1])
4


In [16]:
group_weights = 1 / group_counts
weights = group_weights[groups]
weights

tensor([0.0009, 0.0009, 0.0009,  ..., 0.0003, 0.0003, 0.0003])

In [17]:
len(groups),len(weights)

(4795, 4795)