In [6]:
# Plot imports
import matplotlib.pyplot as plt
# numpy for process large numerical matrixes
import numpy as np
# PyTorch to train and process deep learning and AI models
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
# Import torchvision (part of Pytorch)
# Process images and manipulate them (crop, resize)
import torchvision
from torchvision import datasets, transforms, models
# Python Imaging Library (PIL) to visualize images
from PIL import Image

# libraries that ensure the plots are shown inline and in high resolution
%matplotlib inline
%config InLineBackend.figure_format = 'retina'

In [8]:
# Root folder of the image dataset used below.
# Unzip Data.zip before running (presumably into this folder).
data_dir = "./data"

# Function to read the data; crop and resize the images; finally split it into test and train chunks (20-80)
def load_split_train_test(datadir, valid_size=.2, batch_size=16, seed=None):
    """Load an image folder and split it into train/test DataLoaders.

    Both loaders read the same directory through ``datasets.ImageFolder``.
    A shuffled list of sample indices is split so that a ``valid_size``
    fraction goes to the test loader and the remainder to the train loader.

    Args:
        datadir: Root folder passed to ``datasets.ImageFolder``.
        valid_size: Fraction of samples held out for testing, in [0, 1].
        batch_size: Batch size used by both loaders (default 16, as before).
        seed: Optional integer seed for the index shuffle. ``None`` (default)
            keeps the original behaviour of using NumPy's global RNG.

    Returns:
        Tuple ``(train_loader, test_loader)`` of ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``valid_size`` is not within [0, 1].
    """
    # Validate up front so a bad fraction fails before any disk I/O.
    if not 0 <= valid_size <= 1:
        raise ValueError(f"valid_size must be in [0, 1], got {valid_size!r}")

    from torch.utils.data.sampler import SubsetRandomSampler

    # Training images get random-crop augmentation. RandomResizedCrop already
    # outputs 224x224, so a trailing Resize(224) would be a no-op.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ])

    # Test images must be preprocessed deterministically. Random crops here
    # would make evaluation results change from run to run.
    test_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    # Two views over the same folder that differ only in their transform.
    train_data = datasets.ImageFolder(datadir, transform=train_transforms)
    test_data = datasets.ImageFolder(datadir, transform=test_transforms)

    num_samples = len(train_data)
    indices = list(range(num_samples))
    split = int(np.floor(valid_size * num_samples))
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(indices)
    # The first `split` shuffled indices form the test set; the rest train.
    train_idx, test_idx = indices[split:], indices[:split]

    train_loader = torch.utils.data.DataLoader(
        train_data, sampler=SubsetRandomSampler(train_idx), batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(
        test_data, sampler=SubsetRandomSampler(test_idx), batch_size=batch_size)

    return train_loader, test_loader

# Hold out 20% of the images for testing; the rest are used for training.
train_loader, test_loader = load_split_train_test(data_dir, valid_size=0.2)
# Show the class names ImageFolder derived from the dataset folder.
print(train_loader.dataset.classes)

['Basalt', 'Highland']
