In [88]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import tqdm
from sklearn.metrics import classification_report
import pandas as pd

In [54]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 15
batch_size = 128
learning_rate = 0.001

In [48]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # mean and std of MNIST
])

train_dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, transform=transform)

In [27]:
np.unique(train_dataset.targets, return_counts=True)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]))

As we see, initially, the dataset is more or less balanced. We will take 10% of each classes to the new balanced dataset. The unbalanced dataset will consist of 100% samples from type 0, and 10% of samples from other types

In [79]:
indices = [np.where(np.array(train_dataset.targets) == 0)[0][:5500:20]]
indices_balanced = [np.where(np.array(train_dataset.targets) == 0)[0][:5000:100]]
for i in range(1, 10):
    indices.append(np.where(np.array(train_dataset.targets) == i)[0][:5000:200])
    indices_balanced.append(np.where(np.array(train_dataset.targets) == i)[0][:5000:100])
indices = np.concatenate(indices)
indices_balanced = np.concatenate(indices_balanced)
train_dataset_balanced = Subset(train_dataset, indices_balanced)
train_dataset_imbalanced = Subset(train_dataset, indices)
len(train_dataset_balanced), len(train_dataset_imbalanced)

  indices = [np.where(np.array(train_dataset.targets) == 0)[0][:5500:20]]
  indices_balanced = [np.where(np.array(train_dataset.targets) == 0)[0][:5000:100]]
  indices.append(np.where(np.array(train_dataset.targets) == i)[0][:5000:200])
  indices_balanced.append(np.where(np.array(train_dataset.targets) == i)[0][:5000:100])


(500, 500)

In [87]:
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)  # flatten
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_net():
    return NeuralNet().to(device)


def train(net: NeuralNet, train_loader, epochs):
        # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    # Training loop
    for epoch in tqdm.tqdm(range(epochs)):
        net.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = net(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

    return net


def eval(net, test_loader):
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return classification_report(all_labels, all_predictions, output_dict=True)


def avg_report(reports):
    res = {}
    for entry in zip(*[r.items() for r in reports]):
        if isinstance(entry[0][1], float):
            res[entry[0][0]] = sum(t[1] for t in entry) / len(reports)
        else:
            res[entry[0][0]] = avg_report([t[1] for t in entry])
    return res

In [89]:
EXPERIMENT_REPEATS = 20

train_loader_balanced = DataLoader(train_dataset_balanced, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

stats = []
for _ in range(EXPERIMENT_REPEATS):
    net = get_net()
    train(net, train_loader_balanced, epochs=num_epochs)
    stats.append(eval(net, test_loader))

100%|██████████| 15/15 [00:00<00:00, 15.49it/s]
100%|██████████| 15/15 [00:00<00:00, 15.89it/s]
100%|██████████| 15/15 [00:00<00:00, 15.62it/s]
100%|██████████| 15/15 [00:00<00:00, 15.42it/s]
100%|██████████| 15/15 [00:01<00:00, 14.05it/s]
100%|██████████| 15/15 [00:01<00:00, 14.93it/s]
100%|██████████| 15/15 [00:01<00:00, 14.93it/s]
100%|██████████| 15/15 [00:01<00:00, 13.37it/s]
100%|██████████| 15/15 [00:00<00:00, 15.04it/s]
100%|██████████| 15/15 [00:01<00:00, 14.84it/s]
100%|██████████| 15/15 [00:01<00:00, 11.26it/s]
100%|██████████| 15/15 [00:01<00:00, 12.87it/s]
100%|██████████| 15/15 [00:01<00:00, 14.81it/s]
100%|██████████| 15/15 [00:01<00:00, 13.98it/s]
100%|██████████| 15/15 [00:01<00:00, 14.33it/s]
100%|██████████| 15/15 [00:00<00:00, 16.05it/s]
100%|██████████| 15/15 [00:00<00:00, 15.98it/s]
100%|██████████| 15/15 [00:01<00:00, 14.68it/s]
100%|██████████| 15/15 [00:00<00:00, 15.10it/s]
100%|██████████| 15/15 [00:01<00:00, 13.88it/s]


In [90]:
stats = avg_report(stats)

In [92]:
pd.DataFrame(stats)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,accuracy,macro avg,weighted avg
precision,0.936353,0.929458,0.860811,0.788146,0.864094,0.819354,0.91048,0.905956,0.773553,0.828178,0.8603,0.861638,0.862898
recall,0.92199,0.942159,0.789099,0.842129,0.88666,0.75824,0.893633,0.853648,0.852823,0.846234,0.8603,0.858661,0.8603
f1-score,0.928969,0.935721,0.822889,0.813544,0.874834,0.786744,0.901871,0.878772,0.810106,0.836368,0.8603,0.858982,0.860451
support,980.0,1135.0,1032.0,1010.0,982.0,892.0,958.0,1028.0,974.0,1009.0,0.8603,10000.0,10000.0


In [93]:
train_loader_unbalanced = DataLoader(train_dataset_imbalanced, batch_size=batch_size, shuffle=True)

stats = []
for _ in range(EXPERIMENT_REPEATS):
    net = get_net()
    train(net, train_loader_unbalanced, epochs=num_epochs)
    stats.append(eval(net, test_loader))
stats = avg_report(stats)
pd.DataFrame(stats)

100%|██████████| 15/15 [00:00<00:00, 15.82it/s]
100%|██████████| 15/15 [00:00<00:00, 15.97it/s]
100%|██████████| 15/15 [00:00<00:00, 15.90it/s]
100%|██████████| 15/15 [00:00<00:00, 15.28it/s]
100%|██████████| 15/15 [00:00<00:00, 16.02it/s]
100%|██████████| 15/15 [00:00<00:00, 15.92it/s]
100%|██████████| 15/15 [00:00<00:00, 15.49it/s]
100%|██████████| 15/15 [00:00<00:00, 15.12it/s]
100%|██████████| 15/15 [00:00<00:00, 15.15it/s]
100%|██████████| 15/15 [00:00<00:00, 15.88it/s]
100%|██████████| 15/15 [00:00<00:00, 15.87it/s]
100%|██████████| 15/15 [00:00<00:00, 15.69it/s]
100%|██████████| 15/15 [00:01<00:00, 14.66it/s]
100%|██████████| 15/15 [00:00<00:00, 16.23it/s]
100%|██████████| 15/15 [00:01<00:00, 13.75it/s]
100%|██████████| 15/15 [00:01<00:00, 14.90it/s]
100%|██████████| 15/15 [00:00<00:00, 16.11it/s]
100%|██████████| 15/15 [00:00<00:00, 15.38it/s]
100%|██████████| 15/15 [00:00<00:00, 16.42it/s]
100%|██████████| 15/15 [00:00<00:00, 16.14it/s]


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,accuracy,macro avg,weighted avg
precision,0.76791,0.889459,0.866157,0.737081,0.792943,0.81903,0.876982,0.89793,0.759974,0.767212,0.8142,0.817468,0.818721
recall,0.991276,0.958414,0.781056,0.795594,0.770927,0.650448,0.875418,0.803405,0.73886,0.744995,0.8142,0.811039,0.8142
f1-score,0.865232,0.92252,0.820755,0.763951,0.779731,0.723541,0.876042,0.847299,0.747866,0.752968,0.8142,0.80999,0.812313
support,980.0,1135.0,1032.0,1010.0,982.0,892.0,958.0,1028.0,974.0,1009.0,0.8142,10000.0,10000.0


As we see, in specific scenarios class imbalance leads to performance downgrade when comparing with dataset of the same size but equal class distribution. However, the recall in case of overrepresented class is lower in balanced scenario.

In [96]:
indices_ext = [np.where(np.array(train_dataset.targets) == 0)[0][:4100:10]]
for i in range(1, 10):
    indices_ext.append(np.where(np.array(train_dataset.targets) == i)[0][:5000:500])
indices_ext = np.concatenate(indices_ext)
train_dataset_unbalanced_ext = Subset(train_dataset, indices_ext)
train_loader_unbalanced_ext = DataLoader(train_dataset_unbalanced_ext, batch_size=batch_size, shuffle=True)
len(train_dataset_unbalanced_ext)

  indices_ext = [np.where(np.array(train_dataset.targets) == 0)[0][:4100:10]]
  indices_ext.append(np.where(np.array(train_dataset.targets) == i)[0][:5000:500])


500

In [97]:
stats = []
for _ in range(EXPERIMENT_REPEATS):
    net = get_net()
    train(net, train_loader_unbalanced_ext, epochs=num_epochs)
    stats.append(eval(net, test_loader))
stats = avg_report(stats)
pd.DataFrame(stats)

100%|██████████| 15/15 [00:00<00:00, 15.28it/s]
100%|██████████| 15/15 [00:00<00:00, 16.04it/s]
100%|██████████| 15/15 [00:00<00:00, 15.37it/s]
100%|██████████| 15/15 [00:00<00:00, 16.26it/s]
100%|██████████| 15/15 [00:00<00:00, 15.93it/s]
100%|██████████| 15/15 [00:00<00:00, 15.88it/s]
100%|██████████| 15/15 [00:00<00:00, 15.49it/s]
100%|██████████| 15/15 [00:00<00:00, 15.90it/s]
100%|██████████| 15/15 [00:00<00:00, 15.73it/s]
100%|██████████| 15/15 [00:00<00:00, 15.94it/s]
100%|██████████| 15/15 [00:00<00:00, 15.57it/s]
100%|██████████| 15/15 [00:00<00:00, 15.55it/s]
100%|██████████| 15/15 [00:00<00:00, 15.57it/s]
100%|██████████| 15/15 [00:00<00:00, 15.69it/s]
100%|██████████| 15/15 [00:00<00:00, 15.10it/s]
100%|██████████| 15/15 [00:01<00:00, 11.14it/s]
100%|██████████| 15/15 [00:01<00:00, 12.04it/s]
100%|██████████| 15/15 [00:01<00:00, 13.03it/s]
100%|██████████| 15/15 [00:01<00:00, 12.64it/s]
100%|██████████| 15/15 [00:00<00:00, 15.14it/s]


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,accuracy,macro avg,weighted avg
precision,0.434724,0.768834,0.828723,0.744925,0.611503,0.862751,0.877502,0.808639,0.699019,0.4851,0.650885,0.712172,0.711858
recall,0.997959,0.958943,0.557122,0.614356,0.603157,0.361267,0.580219,0.644747,0.627618,0.498018,0.650885,0.644341,0.650885
f1-score,0.604535,0.851609,0.659446,0.670712,0.600033,0.496335,0.695098,0.71571,0.653211,0.486029,0.650885,0.643272,0.647724
support,980.0,1135.0,1032.0,1010.0,982.0,892.0,958.0,1028.0,974.0,1009.0,0.650885,10000.0,10000.0
