In [2]:
from PIL import Image
import glob

In [3]:
import torch
import torch.utils.data as data

# torch.utils.data.dataset is an abstract class representing a dataset
from torch.utils.data.dataset import Dataset
# https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html

In [4]:
import os
import torch
import numpy as np
import pandas as pd
import sys
import csv

### Notes
- Pytorch Dataloading documentation: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
- torch.utils.data.Dataset is an abstract class representing a dataset.
- CIFAR-10 python version: http://www.cs.toronto.edu/~kriz/cifar.html



## Downloading the CIFAR-10 Dataset
You can read more about the CIFAR-10 dataset here: https://www.kaggle.com/c/cifar-10

1. Go to this link https://www.cs.toronto.edu/~kriz/cifar.html
2. Right click on "CIFAR-10 python version" and click "Copy Link Address"
3. Go to your CLI and go into the data directory.
4. Run this cURL command to start downloading the dataset: curl -O <URL of the link that you copied>
4. To extract the data from the .tar file run: tar -xzvf <name of file> (type man tar in your CLI to see the different options for running the tar command). NOTE: Each file in the directory contains a batch of images in CIFAR-10 that have been serialized using python's pickle module. You will have to first unpickle the data before loading it into your model.

Loaded in this way, each of the batch files contains a dictionary with the following elements:

- data -- a 10000x3072 numpy array of uint8s. Each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 the green, and the final 1024 the blue. The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image.
- labels -- a list of 10000 numbers in the range 0-9. The number at index i indicates the label of the ith image in the array data.

The dataset contains another file, called batches.meta. It too contains a Python dictionary object. It has the following entries:
- label_names -- a 10-element list which gives meaningful names to the numeric labels in the labels array described above. For example, label_names[0] == "airplane", label_names[1] == "automobile", etc.


## Preprocessing
- Unpickle


In [5]:
CIFAR_DIR = './datasets/cifar-10-batches-py/'

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict




In [6]:
all_batches = glob.glob(CIFAR_DIR+'data_batch*')


In [7]:
batch_1 = unpickle(all_batches[0])

In [8]:
batch_1

{b'batch_label': b'training batch 1 of 5',
 b'labels': [6,
  9,
  9,
  4,
  1,
  1,
  2,
  7,
  8,
  3,
  4,
  7,
  7,
  2,
  9,
  9,
  9,
  3,
  2,
  6,
  4,
  3,
  6,
  6,
  2,
  6,
  3,
  5,
  4,
  0,
  0,
  9,
  1,
  3,
  4,
  0,
  3,
  7,
  3,
  3,
  5,
  2,
  2,
  7,
  1,
  1,
  1,
  2,
  2,
  0,
  9,
  5,
  7,
  9,
  2,
  2,
  5,
  2,
  4,
  3,
  1,
  1,
  8,
  2,
  1,
  1,
  4,
  9,
  7,
  8,
  5,
  9,
  6,
  7,
  3,
  1,
  9,
  0,
  3,
  1,
  3,
  5,
  4,
  5,
  7,
  7,
  4,
  7,
  9,
  4,
  2,
  3,
  8,
  0,
  1,
  6,
  1,
  1,
  4,
  1,
  8,
  3,
  9,
  6,
  6,
  1,
  8,
  5,
  2,
  9,
  9,
  8,
  1,
  7,
  7,
  0,
  0,
  6,
  9,
  1,
  2,
  2,
  9,
  2,
  6,
  6,
  1,
  9,
  5,
  0,
  4,
  7,
  6,
  7,
  1,
  8,
  1,
  1,
  2,
  8,
  1,
  3,
  3,
  6,
  2,
  4,
  9,
  9,
  5,
  4,
  3,
  6,
  7,
  4,
  6,
  8,
  5,
  5,
  4,
  3,
  1,
  8,
  4,
  7,
  6,
  0,
  9,
  5,
  1,
  3,
  8,
  2,
  7,
  5,
  3,
  4,
  1,
  5,
  7,
  0,
  4,
  7,
  5,
  5,
  1,
  0,
  9,
  6,
  9,
 

In [9]:
batch_1[b'data']

array([[ 59,  43,  50, ..., 140,  84,  72],
       [154, 126, 105, ..., 139, 142, 144],
       [255, 253, 253, ...,  83,  83,  84],
       ...,
       [ 71,  60,  74, ...,  68,  69,  68],
       [250, 254, 211, ..., 215, 255, 254],
       [ 62,  61,  60, ..., 130, 130, 131]], dtype=uint8)

In [13]:

im = batch_1[b'data'][0]          
im_red = im[0:1024].reshape(32,32)
im_green = im[1024:2048].reshape(32,32)
im_blue = im[2048:3072].reshape(32,32)
im = np.stack([im_red, im_green, im_blue],axis=2)
im = rotate_img(im, 270)
im = Image.fromarray(im)
im.save('./testing/your_file_reshaped.jpeg')

## Save all images

In [14]:
all_batches = glob.glob(CIFAR_DIR+'data_batch*')
name_counter = 0
for batch_name in all_batches:
    batch = unpickle(batch_name)
    data = batch[b'data']
    for i in range(len(data)):
        im = batch_1[b'data'][i]          
        im_red = im[0:1024].reshape(32,32)
        im_green = im[1024:2048].reshape(32,32)
        im_blue = im[2048:3072].reshape(32,32)
        im = np.stack([im_red, im_green, im_blue],axis=2)
        im = Image.fromarray(im)
        im.save('./images/'+ str(name_counter) + '.jpeg')
        name_counter = name_counter +1 

KeyboardInterrupt: 

## Rotation

In [15]:
"""
Takes in an image and a rotation. Returns the the image with the rotation applied.
"""
def rotate_img(img, rot):
    if rot == 0: # 0 degrees rotation
        return img
    elif rot == 90: # 90 degrees rotation
        return np.flipud(np.transpose(img, (1,0,2)))
    elif rot == 180: # 90 degrees rotation
        return np.fliplr(np.flipud(img))
    elif rot == 270: # 270 degrees rotation / or -90
        return np.transpose(np.flipud(img), (1,0,2))
    else:
        raise ValueError('rotation should be 0, 90, 180, or 270 degrees')

## Load the data

In [18]:
'''
Pytorch uses datasets and has a very handy way of creating dataloaders in your main.py
Make sure you read enough documentation.
'''
class CIFAR(Dataset):
    """
    CIFAR dataset
    Implements Dataset (torch.utils.data.dataset)
    """
    def __init__(self, data_dir):
        """
        Args:
            data_dir (string): Directory with all the images
        """
        #gets the data from the directory
        self.image_list = glob.glob(data_dir)
        #calculates the length of image_list
        self.data_len = len(self.image_list)

    def __getitem__(self, index):
        """
        Lazily get the item at the index.
        """
        # Get image name from the pandas df
        single_image_path = self.image_list[index]
        
        # Open image
        image = Image.open(single_image_path)
        
        # Convert to numpy, dim = 28x28
        image_np = np.asarray(image)/255 
        
        # Do some operations on image
        
        image_0 = rotate_img(image_np, 0).reshape(3,32,32)
        image_90 = rotate_img(image_np, 90).reshape(3,32,32)
        image_180 = rotate_img(image_np, 180).reshape(3,32,32)
        image_270 = rotate_img(image_np, 270).reshape(3,32,32)
        
        # print(image_270.shape)
        
        image_stack = np.stack((image_0, image_90, image_180, image_270))
        # print(image_stack.shape)
        
        # print(image_stack.shape)
        
        # One hot encoding for the label
        label_stack = np.stack((0,1,2,3))
        # label_stack = np.stack((np.array([1,0,0,0]), np.array([0,1,0,0]), np.array([0,0,1,0]), np.array([0,0,0,1])))
        
        # print(label_stack.shape)
        
        
        # Convert numpy to a tensor
        image_tensor = torch.from_numpy(image_stack).float()
        
        label_tensor = torch.from_numpy(label_stack).float()
        label_tensor = label_tensor.type(torch.LongTensor)

        
        # print(image_tensor.shape)
        # print(label_tensor.shape)
        return (image_tensor, label_tensor)

    def __len__(self):
        return self.data_len



In [19]:
IMAGE_DIR = './images/' + "*"


In [20]:
data_loader = CIFAR(IMAGE_DIR)

In [27]:
data_loader.image_list

['./images/29283.jpeg',
 './images/18927.jpeg',
 './images/2839.jpeg',
 './images/1716.jpeg',
 './images/30558.jpeg',
 './images/32565.jpeg',
 './images/19635.jpeg',
 './images/11205.jpeg',
 './images/9326.jpeg',
 './images/38168.jpeg',
 './images/13238.jpeg',
 './images/26176.jpeg',
 './images/41948.jpeg',
 './images/23374.jpeg',
 './images/21349.jpeg',
 './images/43975.jpeg',
 './images/14007.jpeg',
 './images/42667.jpeg',
 './images/723.jpeg',
 './images/38492.jpeg',
 './images/4514.jpeg',
 './images/6529.jpeg',
 './images/37767.jpeg',
 './images/29779.jpeg',
 './images/36875.jpeg',
 './images/34848.jpeg',
 './images/30108.jpeg',
 './images/1346.jpeg',
 './images/6483.jpeg',
 './images/19265.jpeg',
 './images/32135.jpeg',
 './images/8864.jpeg',
 './images/26526.jpeg',
 './images/10947.jpeg',
 './images/689.jpeg',
 './images/11655.jpeg',
 './images/13668.jpeg',
 './images/38538.jpeg',
 './images/9776.jpeg',
 './images/16190.jpeg',
 './images/373.jpeg',
 './images/42237.jpeg',
 './ima

In [43]:

image_stack = []
label_stack = []
for i in range(32):
    temp = next(iter(data_loader))
    image_stack.append(temp[0])
    label_stack.append(temp[1])
    
img = np.concatenate(image_stack, axis=0)
lb = np.concatenate(label_stack, axis=0)

(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)
(4, 3, 32, 32)


In [44]:
img.shape

(128, 3, 32, 32)

In [45]:
lb.shape

(128,)