# Stratified Splitting in PyTorch

Importing Libraries:

In [1]:
import sys
sys.path.append('.')
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random


Importing the Data:

In [2]:
# Load the MNIST training split from the local "MNIST-data" folder
# (downloading it first when requested via download=True); ToTensor is
# applied to each image when a sample is fetched.
train_data = datasets.MNIST(
    root="MNIST-data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

0it [00:00, ?it/s]Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST-data\MNIST\raw\train-images-idx3-ubyte.gz
 93%|█████████▎| 9191424/9912422 [00:02<00:00, 6005216.90it/s]Extracting MNIST-data\MNIST\raw\train-images-idx3-ubyte.gz to MNIST-data\MNIST\raw

0it [00:00, ?it/s][ADownloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST-data\MNIST\raw\train-labels-idx1-ubyte.gz

  0%|          | 0/28881 [00:00<?, ?it/s][A

0it [00:00, ?it/s][A[AExtracting MNIST-data\MNIST\raw\train-labels-idx1-ubyte.gz to MNIST-data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST-data\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s][A[A

  3%|▎         | 49152/1648877 [00:00<00:05, 307300.58it/s][A[A

 13%|█▎        | 212992/1648877 [00:00<00:03, 392720.45it/s][A[A

 43%|████▎     | 712704/1648877 [00:00<00:01, 542619.58it/s][A[A

 85%|████████▍ | 1400832/1648877 [00:0

Creating a separate list of samples for each class:

In [3]:
# Group every (image, label) sample by its class label so that each class
# can be split into training/validation parts independently.
num_classes = 10  # labels are used directly as bucket indices (0..9)

my_data = [[] for _ in range(num_classes)]
# Filled in by the splitting step; created here so the name always exists.
my_val_data = [[] for _ in range(num_classes)]

for image, label in tqdm(train_data):
    my_data[int(label)].append((image, label))

# One shallow copy per class is enough to keep `my_data` independent of
# `my_train_data`; the original `[:][:]` copied each list twice.
my_train_data = [list(samples) for samples in my_data]

9920512it [00:05, 1756471.63it/s]                             
100%|██████████| 60000/60000 [00:08<00:00, 7226.34it/s]


Splitting each class into training and validation data (with the desired ratio):

In [4]:
# Fraction of each class that goes to the training split; the remainder
# becomes validation data.
split_ratio = 0.8

for class_idx in trange(10):
    class_samples = my_train_data[class_idx]
    # Truncate towards zero for the training size and give every leftover
    # sample to validation so the two sizes always add up to the class size.
    n_train = int(len(class_samples) * split_ratio)
    n_val = len(class_samples) - n_train
    train_subset, val_subset = torch.utils.data.random_split(
        class_samples, [n_train, n_val]
    )
    my_train_data[class_idx] = train_subset
    my_val_data[class_idx] = val_subset


100%|██████████| 10/10 [00:00<00:00, 771.18it/s]


Flattening the list of lists into a single list:

In [5]:
# Merge the per-class pieces back into one flat list per split.
# Note: `train_data` is rebound here from the original dataset to a list.
train_data = []
val_data = []

for class_idx in trange(10):
    train_data.extend(my_train_data[class_idx])
    val_data.extend(my_val_data[class_idx])

100%|██████████| 10/10 [00:00<00:00, 217.98it/s]


Just a simple shuffle to ensure the rows in the same class are not all near each other:

In [6]:
# Shuffle both splits in place (training first, then validation) so that
# samples of the same class are no longer stored next to each other.
for split in (train_data, val_data):
    random.shuffle(split)

In [7]:
# Report the size of each split.
print(f"Training rows: {len(train_data)} Validation rows: {len(val_data)}")

Training rows: 47995 Validation rows: 12005


Checking the percentage of each class in our training set:

In [8]:
# Count how many training samples carry each label (0..9), then print each
# label's share of the training set as a percentage rounded to 2 decimals.
counter_dict = dict.fromkeys(range(10), 0)
total = 0

for _, label in tqdm(train_data):
    counter_dict[int(label)] += 1
    total += 1

for digit, count in counter_dict.items():
    share = round(count / total * 100, 2)
    print(f"{digit}: {share}%")

100%|██████████| 47995/47995 [00:00<00:00, 601488.65it/s]0: 9.87%
1: 11.24%
2: 9.93%
3: 10.22%
4: 9.74%
5: 9.03%
6: 9.86%
7: 10.44%
8: 9.75%
9: 9.92%



Ensuring we get the same class proportions in the validation set:

In [9]:
# Count how many validation samples carry each label (0..9), then print each
# label's share of the validation set as a percentage rounded to 2 decimals.
counter_dict = dict.fromkeys(range(10), 0)
total = 0

for _, label in tqdm(val_data):
    counter_dict[int(label)] += 1
    total += 1

for digit, count in counter_dict.items():
    share = round(count / total * 100, 2)
    print(f"{digit}: {share}%")

100%|██████████| 12005/12005 [00:00<00:00, 401264.05it/s]0: 9.87%
1: 11.24%
2: 9.93%
3: 10.22%
4: 9.74%
5: 9.04%
6: 9.86%
7: 10.44%
8: 9.75%
9: 9.91%

