In [3]:
# coding: utf-8

"""Visualization of the transforms in data_loading.py"""
# **Author**: `Francisco Belchí <frbegu@gmail.com>, <https://github.com/KikoBelchi/2d_to_3d>`_

###
### Imports
###
from __future__ import print_function, division
import itertools
import os
import torch
import pandas as pd
from skimage import io, transform # package 'scikit-image'
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

import data_loading
import functions_data_processing
import functions_plot

# Imports for plotting
import matplotlib.pyplot as plt # Do not use when running on the server
from mpl_toolkits.mplot3d import axes3d # Do not use when running on the server

# Allow the interactive rotation of 3D scatter plots in jupyter notebook
import sys    
import os    
file_name =  os.path.basename(sys.argv[0])
#print(file_name == 'ipykernel_launcher.py') # This basicaly asks whether this file is a jupyter notebook?
if __name__ == "__main__":
    if file_name == 'ipykernel_launcher.py': # Run only in .ipynb, not in exported .py scripts
        get_ipython().run_line_magic('matplotlib', 'notebook') # Equivalent to ''%matplotlib notebook', but it is also understood by .py scripts


In [4]:
# This script worked before changing the list of arguments of data_loading.vertices_Dataset() by the args produced by the parser.
# Now al lthis should be adapted and run from a .py script in order to work.

# In particular, we will need to add the following and then modify the args as appropriate in each cell
import argparse
import functions_data_processing
args = functions_data_processing.parser()


usage: ipykernel_launcher.py [-h] [--sequence-name SEQUENCE_NAME]
                             [--dataset-number N] [--reordered-dataset N]
                             [--num-selected-vertices N]
                             [--submesh-num-vertices-vertical N]
                             [--submesh-num-vertices-horizontal N]
                             [--batch-size N] [--num-epochs N] [--lr LR]
                             [--momentum M] [--gamma N] [--no-cuda] [--seed S]
                             [--log-interval N] [--log-epochs N]
                             [--num-workers N] [--resnet-version N]
                             [--frozen-resnet N] [--hyperpar-option N]
                             [--crop-centre-or-ROI N] [--camera-coordinates N]
                             [--random-seed-to-choose-video-sequences N]
                             [--random-seed-to-shuffle-training-frames N]
                             [--random-seed-to-shuffle-validation-frames N]
             

SystemExit: 2

In [2]:
# Horizontal Flip (applied only to the 2D image)
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    tsfrm = transforms.RandomHorizontalFlip()
    sample = vertex_dataset[22]
    # RGBA_to_RGB(sample['image'])

    from PIL import Image
    im = Image.open(sample['img_name'])
    # im.show()
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # transformed_image.show()
    plt.figure()
    plt.imshow(transformed_image)
    plt.title('Transformed image')
    plt.show()

TypeError: __init__() missing 1 required positional argument: 'args'

In [None]:
# Random Resized Crop (applied only to the 2D image)
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    tsfrm = transforms.RandomResizedCrop(224)
    sample = vertex_dataset[22]
    # RGBA_to_RGB(sample['image'])

    im = Image.open(sample['img_name'])
    # im.show()
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # transformed_image.show()
    plt.figure()
    plt.imshow(transformed_image)
    plt.title('Transformed image')
    plt.show()

    # The following does not work, so I had to use PIL images
    # transformed_image = tsfrm(Image.open(io.BytesIO(sample['img_name'])))
    # transformed_image = tsfrm(RGBA_to_RGB(sample['image']))
    # fig = plt.figure()
    # ax = plt.subplot

In [None]:
# Resize (without cropping) (applied only to the 2D image)
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    tsfrm = transforms.Resize(10)
    sample = vertex_dataset[22]
    # RGBA_to_RGB(sample['image'])

    im = Image.open(sample['img_name'])
    # im.show()
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # transformed_image.show()
    plt.figure()
    plt.imshow(transformed_image)
    plt.title('Transformed image')
    plt.show()

    # The following does not work, so I had to use PIL images
    # transformed_image = tsfrm(Image.open(io.BytesIO(sample['img_name'])))
    # transformed_image = tsfrm(RGBA_to_RGB(sample['image']))
    # fig = plt.figure()
    # ax = plt.subplot

In [None]:
# Resize (without cropping) (applied only to the 2D image)
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    tsfrm = transforms.Resize((10, 20))
    sample = vertex_dataset[22]
    # RGBA_to_RGB(sample['image'])

    im = Image.open(sample['img_name'])
    # im.show()
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # transformed_image.show()
    plt.figure()
    plt.imshow(transformed_image)
    plt.title('Transformed image')
    plt.show()

    # The following does not work, so I had to use PIL images
    # transformed_image = tsfrm(Image.open(io.BytesIO(sample['img_name'])))
    # transformed_image = tsfrm(RGBA_to_RGB(sample['image']))
    # fig = plt.figure()
    # ax = plt.subplot

In [None]:
# Resize (without cropping) (applied only to the 2D image)
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    tsfrm = transforms.Resize((20, 20))
    sample = vertex_dataset[22]
    # RGBA_to_RGB(sample['image'])

    im = Image.open(sample['img_name'])
    # im.show()
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # transformed_image.show()
    plt.figure()
    plt.imshow(transformed_image)
    plt.title('Transformed image')
    plt.show()

    # The following does not work, so I had to use PIL images
    # transformed_image = tsfrm(Image.open(io.BytesIO(sample['img_name'])))
    # transformed_image = tsfrm(RGBA_to_RGB(sample['image']))
    # fig = plt.figure()
    # ax = plt.subplot

In [None]:
# Resize to ResNet needed size (without cropping) (applied only to the 2D image)
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    tsfrm = transforms.Resize((224, 224))
    sample = vertex_dataset[22]
    # RGBA_to_RGB(sample['image'])

    im = Image.open(sample['img_name'])
    # im.show()
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # transformed_image.show()
    plt.figure()
    plt.imshow(transformed_image)
    plt.title('Transformed image')
    plt.show()

    # The following does not work, so I had to use PIL images
    # transformed_image = tsfrm(Image.open(io.BytesIO(sample['img_name'])))
    # transformed_image = tsfrm(RGBA_to_RGB(sample['image']))
    # fig = plt.figure()
    # ax = plt.subplot

In [None]:
###
### CAVEAT: Since our images have high resolution and many background pixels, 
### some random crops don't even show the towel
###
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()
    sample = vertex_dataset[22]
    plt.figure()
    image = sample['image']
    plt.imshow(image)
    plt.title('Original image')
    plt.show()

    # Lucky example in which most of the towel shows:
    plt.figure()
    image = io.imread('resizedRandomCrop_goodLuckExample')
    plt.imshow(image)
    plt.title('Lucky example in which most of the towel shows')
    plt.show()

    # Example in which only some of the towel shows:
    plt.figure()
    image = io.imread('resizedRandomCrop_averageLuckExample')
    plt.imshow(image)
    plt.title('Example in which only some of the towel shows')
    plt.show()

    # Bad luck example in which none of the towel shows:
    plt.figure()
    image = io.imread('resizedRandomCrop_badLuckExample')
    plt.imshow(image)
    plt.title('Example in which none of the towel shows')
    plt.show()

# CAVEAT: Permutation of axes by 'torchvision.transforms.ToTensor' 
'torchvision.transforms.ToTensor' <br>
converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a 
torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].

To convert it back to a tensor of shape (H x W x C) in the range [0, 255], and to remove the extra alpha channel we are carrying with out RGBA pictures, we create the following functions:
data_loading.unnormalize_RGB_of_HWC_tensor
data_loading.unnormalize_RGB_of_CHW_tensor
data_loading.tensor_to_plot

# CAVEAT: Permutation of axes by 'torchvision.transforms.ToTensor'. Part II 
Notice, though, that 'torchvision.utils.make_grid' 
and some functions for CNNs take as input torch.FloatTensor of shape (C x H x W).


In [None]:
###
### Composition of off-the-shelf torchvision.transforms
###
# Plot using 'data_loading.tensor_to_plot'
if __name__ == '__main__':
    vertex_dataset = data_loading.vertices_Dataset()

    mean_for_normalization = [0.485, 0.456, 0.406]
    std_for_normalization = [0.229, 0.224, 0.225]

    tsfrm = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean_for_normalization, std_for_normalization)
        ])

    sample = vertex_dataset[22]
    im = Image.open(sample['img_name'])
    # print(im.size)
    plt.figure()
    plt.imshow(im)
    plt.title('Original image')
    plt.show()

    transformed_image = tsfrm(im)
    # print(type(transformed_image))
    # print(transformed_image.shape)

    transformed_image_toPlot = data_loading.tensor_to_plot(transformed_image, mean_for_normalization, std_for_normalization)

    plt.figure()
    plt.imshow(transformed_image_toPlot) # Plot the R channel
    plt.title('Transformed image')
    plt.show()

In [None]:
###
### All transforms together in the instanciation of the dataset class
###
# 
# Let's put this all together to create a dataset with composed
# transforms.
# To summarize, every time this dataset is sampled:
# 
# -  An image is read from the file on the fly
# -  Transforms are applied on the read image
# -  Since some of the transforms are random, data is augmentated on
#    sampling
# 
# ##  Iterating through the dataset
# 
# We can iterate over the created dataset with a loop as before,
# using the function functions_plot.plot_a_fixed_list_images_randomly_transformed(transformed_dataset).

if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.RandomResizedCrop(224),
                                               transforms.RandomHorizontalFlip(),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]))

    functions_plot.plot_a_fixed_list_images_randomly_transformed(transformed_dataset)


In [None]:
# Rectangular box containing the towel in the 2D image rescaled to 224x224
if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((224,224)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],
                                                     [0.229, 0.224, 0.225])
                               ])  
    transformed_dataset = data_loading.vertices_Dataset(transform=transform,
                                                   crop_centre_or_ROI=2)
    functions_plot.plot_a_fixed_list_images_randomly_transformed(transformed_dataset)

In [None]:
###
### Only CenterCrop transform
###
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.CenterCrop(224),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]))

    functions_plot.plot_a_fixed_list_images_randomly_transformed(transformed_dataset)

In [None]:
###
### Using torch.utils.data.DataLoader
###
# 
# However, we are losing a lot of features by using a simple loop to
# iterate over the data. In particular, we are missing out on:
# 
# -  Batching the data
# -  Shuffling the data
# -  Load the data in parallel using ``multiprocessing`` workers.
# 
# ``torch.utils.data.DataLoader`` is an iterator which provides all these
# features. Parameters used below should be clear. One parameter of
# interest is ``collate_fn``. You can specify how exactly the samples need
# to be batched using ``collate_fn``. However, default collate should work
# fine for most use cases.

# Piece of code from when the transformed sample['image'] had the shape HxWxC instead of CxHxW:
#     for i in range(batch_size):
#         plt.figure()
#         plt.imshow(images_batch[i, :, :, :]) 
# #         group_number, animation_frame = data_loading.group_and_frame_from_idx(i)
# #         plt.title('Transformed image. Group no.: ' + str(group_number) + '. Frame = ' + str(animation_frame))
#         plt.show()

# Show a batch from DataLoader - only CenterCrop
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.CenterCrop(224),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]))

    dataloader = DataLoader(transformed_dataset, batch_size=16,
                            shuffle=True, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['Vertex_coordinates'].size())

        # observe 16th batch and stop.
        if i_batch == 15:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A random batch of unnormalized centre cropped images')
            plt.ioff()
            plt.show()
            break

# Show a batch from DataLoader
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.RandomResizedCrop(224),
                                               transforms.RandomHorizontalFlip(),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]))

    dataloader = DataLoader(transformed_dataset, batch_size=5,
                            shuffle=True, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['Vertex_coordinates'].size())

        # observe 5th batch and stop.
        if i_batch == 4:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A random batch of unnormalized transformed images')
            plt.ioff()
            plt.show()
            break


In [None]:
### Resize instead of crop
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.Resize(35),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]))

    dataloader = DataLoader(transformed_dataset, batch_size=2,
                            shuffle=False, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['Vertex_coordinates'].size())

        # observe 5th batch and stop.
        if i_batch == 4:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A batch of unnormalized Resized images')
            plt.ioff()
            plt.show()
            break
            
### No resize nor crop
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]))

    dataloader = DataLoader(transformed_dataset, batch_size=2,
                            shuffle=False, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['Vertex_coordinates'].size())

        # observe 5th batch and stop.
        if i_batch == 4:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A batch of unnormalized images')
            plt.ioff()
            plt.show()
            break

In [None]:
### Resize ROI to 35 (very small to see the difference)
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.Resize(35),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]), crop_centre_or_ROI=1
                                                       )

    dataloader = DataLoader(transformed_dataset, batch_size=1,
                            shuffle=False, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['Vertex_coordinates'].size())

        # observe 5th batch and stop.
        if i_batch == 4:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A batch of unnormalized Resized images')
            plt.ioff()
            plt.show()
            break
            
### No resize of the ROI
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]), crop_centre_or_ROI=1
                                                       )

    dataloader = DataLoader(transformed_dataset, batch_size=1,
                            shuffle=False, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['Vertex_coordinates'].size())

        # observe 5th batch and stop.
        if i_batch == 4:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A batch of unnormalized images')
            plt.ioff()
            plt.show()
            break

In [None]:
### Resize ROI to 224 (the amount needed for ResNet)
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.Resize(224),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]), crop_centre_or_ROI=1
                                                       )

    batch_size = 1
    dataloader = DataLoader(transformed_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print('Member of the batch number ' + str(i_batch) + '. Image size: ' + str(sample_batched['image'].size()))

        # observe last member of the batch and stop.
        if i_batch == batch_size-1:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A batch of unnormalized Resized images')
            plt.ioff()
            plt.show()
            break
            
### No resize of the ROI
if __name__ == '__main__':
    transformed_dataset = data_loading.vertices_Dataset(transform = 
                                           transforms.Compose([
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                           ]), crop_centre_or_ROI=1
                                                       )

    batch_size = 1
    dataloader = DataLoader(transformed_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=4)


    for i_batch, sample_batched in enumerate(dataloader):
        print('Member of the batch number ' + str(i_batch) + '. Image size: ' + str(sample_batched['image'].size()))

        # observe last member of the batch and stop.
        if i_batch == batch_size-1:
            plt.figure()
            functions_plot.show_image_batch(sample_batched)
            plt.axis('off')
            plt.title('A batch of unnormalized images')
            plt.ioff()
            plt.show()
            break