In [13]:
import numpy
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
import torch.optim as optim
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder

In [15]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomGrayscale(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomInvert(),
    transforms.RandomRotation(30),
])

data = datasets.ImageFolder('dataset', transform=transform)

# Split into train/test sets:
train_len = int(len(data)*0.8)
train_set, test_set = random_split(data, [train_len, len(data) - train_len])

# Extract classes:
train_classes = [train_set.dataset.targets[i] for i in train_set.indices]
# Calculate support:
class_count = Counter(train_classes)
# Calculate class weights:
class_weights = torch.DoubleTensor([len(train_classes)/c for c in pd.Series(class_count).sort_index().values]) 
# Sampler needs the respective class weight supplied for each image in the dataset:
sample_weights = [class_weights[train_set.dataset.targets[i]] for i in train_set.indices]

sampler = WeightedRandomSampler(weights=sample_weights, num_samples=int(len(train_set)*2), replacement=True)

batch_size=32

# Create torch dataloaders:
train_loader = DataLoader(train_set, batch_size=4, sampler=sampler, num_workers=12)
print("The number of images in a training set is:", len(train_loader)*batch_size)

test_loader = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=12)
print("The number of images in a test set is:", len(test_loader)*batch_size)

The number of images in a training set is: 861856
The number of images in a test set is: 107744
