# 不平衡数据集的加权采样 (Weighted sampling for imbalanced datasets)

In [1]:
import torch
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data.dataloader import DataLoader

In [2]:
# Synthetic features for the demo: 1000 points with 5 dimensions each, drawn
# from a standard normal distribution. bs is the batch size used by the
# DataLoader further down; the 99:1 imbalanced labels are built separately.
numDataPoints, data_dim = 1000, 5
bs = 100
data = torch.randn((numDataPoints, data_dim))

In [3]:
data # samples: one row per data point

tensor([[-1.1185,  1.2206,  0.5193, -0.2541,  1.6835],
        [-2.2643, -0.1118, -0.5797,  0.8096, -1.0290],
        [ 0.9882,  0.6891,  0.2279, -0.2689, -0.2669],
        ...,
        [-0.4972, -0.8703, -0.6876, -1.3560, -0.3337],
        [ 0.3913,  0.2325, -1.4994, -1.0452,  0.0820],
        [ 0.0133, -0.5565,  0.1112,  2.1399, -0.6813]])

In [4]:
# Build the labels: roughly 99% class 0 followed by roughly 1% class 1.
# The minority class is sized with integer division and the majority class
# takes the remainder, so len(target) always equals numDataPoints. Sizing both
# halves with int(n * 0.99) and int(n * 0.01) truncates each one separately
# and comes up short whenever n is not a multiple of 100 (e.g. n = 150 gives
# 148 + 1 = 149 labels), which leaves target out of step with data.
num_minority = numDataPoints // 100
target = torch.cat((torch.zeros(numDataPoints - num_minority, dtype=torch.long),
                    torch.ones(num_minority, dtype=torch.long)))

In [5]:
target # labels: 990 zeros followed by 10 ones

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [6]:
# Report how many labels fall into class 0 and class 1.
print('target train 0/1: {}/{}'.format(
    *((target == label).sum() for label in (0, 1))))

target train 0/1: 990/10


In [7]:
# Per-class weights: each class is weighted by the inverse of its frequency,
# so the rarer a class is, the larger the weight its samples will carry.
# torch.unique(..., return_counts=True) tallies every class in one call
# instead of comparing the whole target tensor once per class in a Python
# loop and rebuilding a tensor from a list of 0-dim tensors.
# Counts are ordered by ascending label, so weight[k] lines up with label k
# as long as the labels are exactly 0..K-1 (true here: labels are 0 and 1).
class_sample_count = torch.unique(target, sorted=True, return_counts=True)[1]
weight = 1. / class_sample_count.float()
print(class_sample_count, weight) # the under-represented class gets the larger weight

tensor([990,  10]) tensor([0.0010, 0.1000])


In [8]:
# Give every sample the weight of its class. target is a long tensor, so it
# can index weight directly; this replaces a Python-level loop over every
# label with a single indexing operation and yields the same float tensor
# holding one weight per sample.
samples_weight = weight[target]

In [9]:
samples_weight # one weight per sample, so it has the same length as target

tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 

In [10]:
# Sampler that yields len(samples_weight) indices per pass, drawn with
# replacement according to the per-sample weights.
sampler = WeightedRandomSampler(
    weights=samples_weight,
    num_samples=len(samples_weight),
    replacement=True,
)

In [12]:
# Materialize one pass of sampled indices. Indices 990-999 (the minority
# class) show up repeatedly because they carry the larger weight.
list(sampler)

[827,
 996,
 999,
 775,
 480,
 995,
 997,
 589,
 564,
 108,
 994,
 521,
 841,
 437,
 741,
 261,
 334,
 862,
 807,
 247,
 991,
 999,
 617,
 859,
 995,
 991,
 996,
 431,
 860,
 993,
 332,
 991,
 119,
 995,
 995,
 965,
 996,
 997,
 933,
 998,
 991,
 255,
 999,
 994,
 278,
 569,
 208,
 767,
 67,
 918,
 994,
 996,
 593,
 646,
 994,
 991,
 991,
 243,
 996,
 441,
 991,
 997,
 997,
 992,
 541,
 997,
 877,
 998,
 966,
 992,
 136,
 731,
 555,
 60,
 997,
 992,
 993,
 995,
 998,
 862,
 193,
 287,
 618,
 39,
 695,
 993,
 993,
 992,
 995,
 75,
 999,
 995,
 995,
 998,
 991,
 591,
 998,
 749,
 990,
 547,
 996,
 991,
 845,
 70,
 220,
 280,
 189,
 481,
 996,
 999,
 997,
 250,
 974,
 990,
 196,
 794,
 999,
 992,
 578,
 990,
 997,
 996,
 991,
 993,
 32,
 606,
 770,
 993,
 992,
 998,
 995,
 510,
 998,
 535,
 638,
 365,
 693,
 628,
 399,
 241,
 189,
 994,
 996,
 619,
 601,
 978,
 990,
 997,
 996,
 993,
 998,
 999,
 992,
 145,
 762,
 999,
 604,
 999,
 514,
 882,
 616,
 994,
 995,
 999,
 423,
 530,
 19,
 326,

In [13]:
# Pair each row of data with its label so the dataset yields (sample, label) tuples.
train_dataset = torch.utils.data.TensorDataset(data, target)

In [19]:
# Batches are assembled from the indices produced by the weighted sampler.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=bs,
    sampler=sampler,
    num_workers=1,
)

# Walk through one epoch and print the class split of every batch; with the
# weighted sampler each batch should come out roughly balanced.
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i, *((y == label).sum() for label in (0, 1))))

batch index 0, 0/1: 47/53
batch index 1, 0/1: 48/52
batch index 2, 0/1: 50/50
batch index 3, 0/1: 49/51
batch index 4, 0/1: 49/51
batch index 5, 0/1: 53/47
batch index 6, 0/1: 58/42
batch index 7, 0/1: 52/48
batch index 8, 0/1: 50/50
batch index 9, 0/1: 48/52
