In [None]:
# %% Deep learning - Section 18.171
#    Image transforms

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [1]:
# %% Libraries and modules
import numpy                  as np
import matplotlib.pyplot      as plt
import torch
import torch.nn               as nn
import seaborn                as sns
import copy
import torch.nn.functional    as F
import pandas                 as pd
import scipy.stats            as stats
import sklearn.metrics        as skm
import time
import sys
import imageio.v2             as imageio
import torchvision
import torchvision.transforms as T

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from scipy.stats                      import zscore
from sklearn.decomposition            import PCA
from scipy.signal                     import convolve2d
from IPython                          import display
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Render inline notebook figures in SVG format
set_matplotlib_formats('svg')
# Reset matplotlib to its default style
plt.style.use('default')


In [None]:
# %% New dataset !

# CIFAR-10 stored under the 'cifar10' folder (download=True fetches it if needed)
cifar10_data = torchvision.datasets.CIFAR10(download=True, root='cifar10')
print(cifar10_data)


In [None]:
# %% Check dataset

image_array   = cifar10_data.data
class_names   = cifar10_data.classes
target_labels = cifar10_data.targets

# Dimensions: (number of images, height, width, RGB channels)
print(image_array.shape)

# Category names; the integer targets index into this list
print(class_names)

# One integer label per image
print(len(target_labels))


In [None]:
# %% Plotting

# Inspect a grid of distinct, randomly chosen images with their class labels
phi = (1 + np.sqrt(5)) / 2          # golden ratio, used for the figure aspect ratio
fig,axs = plt.subplots(5,5,figsize=(phi*9,9))

# One index per panel, drawn without replacement so no image is shown twice
randidxs = np.random.choice(len(cifar10_data.targets), size=axs.size, replace=False)

for ax,randidx in zip(axs.flatten(),randidxs):

  # extract image and label
  pic   = cifar10_data.data[randidx,:,:,:]
  label = cifar10_data.classes[cifar10_data.targets[randidx]]

  ax.imshow(pic)
  # imshow spans x from -0.5 to width-0.5, so (width-1)/2 is the horizontal centre;
  # derived from the image instead of hard-coding it for 32-pixel-wide images
  ax.text((pic.shape[1]-1)/2,0,label,ha='center',fontweight='bold',color='k',backgroundcolor='y')
  ax.axis('off')

plt.tight_layout()

plt.savefig('figure26_image_transforms.png')
plt.show()
files.download('figure26_image_transforms.png')


In [None]:
# %% Apply some transforms

# Pipeline: convert to a tensor, upsample (smaller edge -> 32*4 pixels), then reduce
# the RGB channels to a single greyscale channel
Ts = T.Compose([
  T.ToTensor(),
  T.Resize(32*4),
  T.Grayscale(num_output_channels=1),
])

# Store the pipeline on the dataset object
cifar10_data.transform = Ts

# Alternatively, pass the pipeline in when constructing the dataset:
# cifar10_data = torchvision.datasets.CIFAR10(root='cifar10', download=True, transform=Ts)

# Attaching a transform leaves the stored image array untouched
print(cifar10_data.data[123].shape)

# Way 1: call the pipeline directly on one image
img_1 = Ts(cifar10_data.data[123])
print(img_1.shape)

# Way 2: call the pipeline through the dataset's transform attribute
img_2 = cifar10_data.transform(cifar10_data.data[123])
print(img_2.shape)


In [None]:
# %% Plotting

# Side by side: the untouched image and the two transformed copies, which should match
phi = (1 + np.sqrt(5)) / 2
fig, ax = plt.subplots(1, 3, figsize=(phi*6, 6))

ax[0].imshow(cifar10_data.data[123])
ax[0].set_title('Original')

# squeeze drops the singleton greyscale channel so imshow receives a 2D image
ax[1].imshow(torch.squeeze(img_1), cmap='gray')
ax[1].set_title('Upsampled 1')
ax[2].imshow(torch.squeeze(img_2), cmap='gray')
ax[2].set_title('Upsampled 2')

plt.tight_layout()
plt.savefig('figure27_image_transforms.png')
plt.show()
files.download('figure27_image_transforms.png')


In [14]:
# %% Note about ToTensor() and normalization:

# ToTensor() converts an image (PIL image or numpy array laid out H x W x C) into a
# float tensor laid out C x H x W. For uint8 inputs (and the standard PIL modes) it
# also rescales pixel values from the [0,255] range to the [0,1] range; other inputs
# are converted without rescaling. The docstring below states the exact contract.
print(T.ToTensor.__doc__)


In [26]:
# %% Exercise 1
#    There are many other transforms available in torchvision: https://pytorch.org/vision/stable/transforms.html
#    Many transformations are useful for data preparation and augmentation. We'll cover some of them later in the course,
#    but for now, read about RandomCrop(), RandomHorizontalFlip(), and CenterCrop(). Then implement them to understand
#    what they do to images.
#    Tip: It's probably best to test these transforms separately, and on one test image, as we did above.

# RandomCrop(n); n = cropping window size
# RandomHorizontalFlip(p); p = probability of flipping
# CenterCrop(n); n = output size, i.e. an n x n crop taken from the centre of the image

# Pipeline under test: tensor conversion, an n x n centre crop (n=24), then greyscale
Ts = T.Compose([
  T.ToTensor(),
  T.CenterCrop(24),
  T.Grayscale(num_output_channels=1),
])

cifar10_data.transform = Ts
img_transformed = cifar10_data.transform(cifar10_data.data[123])


In [None]:
# %% Exercise 1
#    Continue ...

# Plotting: original image next to the output of the exercise pipeline
phi = (1 + np.sqrt(5)) / 2
fig,ax = plt.subplots(1,2,figsize=(phi*6,6))
ax[0].imshow(cifar10_data.data[123,:,:,:])
ax[0].set_title('Original')

ax[1].imshow(torch.squeeze(img_transformed),cmap='gray')
# Build the title from the pipeline attached to the dataset (the one that produced
# img_transformed), so it stays correct whichever exercise transform is tried.
# ToTensor is skipped because it is only the array -> tensor conversion.
applied_names = ' + '.join( type(t).__name__ for t in cifar10_data.transform.transforms
                            if not isinstance(t,T.ToTensor) )
ax[1].set_title(applied_names)

plt.tight_layout()

plt.savefig('figure28_image_transforms_extra1.png')
plt.show()
files.download('figure28_image_transforms_extra1.png')
