In [None]:
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

# Root directory of the organized scans.
# It must contain one subfolder per class, and each subfolder must hold the
# images that belong to that class.
data_dir = 'example_dir/Code/scans'

# Split and loading settings.
TEST_SIZE = 0.2   # fraction of the images held out from training
BATCH_SIZE = 4    # images per mini-batch
SEED = 42         # random state for the train/test split

# Per-channel mean and standard deviation used to normalize the image tensors.
# NOTE(review): these match the statistics torchvision documents for its
# ImageNet-pretrained models -- confirm they suit the model trained here.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip (augmentation),
# followed by tensor conversion and normalization.
_train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Validation pipeline: deterministic resize + center crop only, followed by
# the same tensor conversion and normalization.
_val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Transform pipelines keyed by phase name.
data_transforms = {'train': _train_pipeline, 'val': _val_pipeline}

# Two views over the same image folder, one per transform pipeline.
# A transform belongs to the dataset object, not to a Subset or DataLoader, so
# a single shared ImageFolder cannot augment the training images while leaving
# the validation images deterministic. `data` carries the validation
# transforms and also supplies the labels and class names used below.
data = datasets.ImageFolder(data_dir, transform=data_transforms["val"])
train_data = datasets.ImageFolder(data_dir, transform=data_transforms["train"])

# The split indices are computed once and applied to both views, so both must
# list exactly the same files in the same order.
if train_data.samples != data.samples:
    raise RuntimeError(
        "Contents of data_dir changed while the datasets were being built"
    )


# Generate indices: integer positions are split instead of the actual data.
# Stratifying on the labels keeps the class proportions equal in both splits.
train_indices, test_indices, _, _ = train_test_split(
    range(len(data)),
    data.targets,
    stratify=data.targets,
    test_size=TEST_SIZE,
    random_state=SEED,
)

# Each split indexes the view that carries the matching transform.
train_split = Subset(train_data, train_indices)
test_split = Subset(data, test_indices)


# The held-out split is the validation set.
image_datasets = {"train": train_split, "val": test_split}

# The loaders wrap the Subsets themselves (not their underlying `.dataset`),
# so each one iterates only its own split. Only the training loader is
# shuffled; validation order does not matter.
dataloaders = {
    x: DataLoader(
        image_datasets[x],
        batch_size=BATCH_SIZE,
        shuffle=(x == "train"),
        num_workers=2,
    )
    for x in ["train", "val"]
}

# Number of samples per phase; matches what each loader iterates.
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "val"]}
class_names = data.classes

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")