In [65]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import cv2
import multiprocessing as mp
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import import_ipynb
import distance

importing Jupyter notebook from distance.ipynb


In [19]:
def read_alphabets(alphabet_directory_path, alphabet_directory_name):
    """
    Read every character image of one alphabet, plus three rotated copies.

    Each image is loaded with ``cv2.imread`` and resized to 28x28. It is then
    rotated by 90, 180 and 270 degrees with ``scipy.ndimage.rotate``. The
    original and each rotation get their own label, so every character
    contributes four label classes.

    Hidden entries (names starting with '.') and non-directory entries
    directly under ``alphabet_directory_path`` are skipped. Entries are read
    in sorted order so results are reproducible across file systems.

    Args:
        alphabet_directory_path (str): directory holding one sub-directory
            per character. A trailing '/' is optional.
        alphabet_directory_name (str): alphabet name used as label prefix.

    Returns:
        tuple(np.ndarray, np.ndarray): ``(datax, datay)``. ``datax`` stacks
        the images in the order original, 90, 180, 270 for each source image.
        ``datay`` holds the matching labels
        ``'<alphabet>_<character>_<0|90|180|270>'``.

    Raises:
        ValueError: if an image file cannot be decoded by OpenCV.
    """
    datax = []
    datay = []
    for character in sorted(os.listdir(alphabet_directory_path)):
        character_path = os.path.join(alphabet_directory_path, character)
        # skip e.g. .DS_Store or stray files that os.listdir cannot descend into
        if character.startswith('.') or not os.path.isdir(character_path):
            continue
        label_prefix = alphabet_directory_name + '_' + character + '_'
        for img in sorted(os.listdir(character_path)):
            if img.startswith('.'):
                continue
            img_path = os.path.join(character_path, img)
            raw = cv2.imread(img_path)
            # cv2.imread signals failure by returning None rather than raising
            if raw is None:
                raise ValueError('Could not read image: ' + img_path)
            image = cv2.resize(raw, (28, 28))
            datax.append(image)
            datay.append(label_prefix + '0')
            # each rotation is treated as a separate class
            for angle in (90, 180, 270):
                datax.append(ndimage.rotate(image, angle))
                datay.append(label_prefix + str(angle))
    return np.array(datax), np.array(datay)

def read_images(base_directory):
    """
    Read every alphabet under ``base_directory`` and stack the results.

    Alphabets are read sequentially, in sorted directory order, via
    ``read_alphabets``. Hidden entries (names starting with '.') and plain
    files in ``base_directory`` are skipped. So are alphabets that yield
    no images.

    Args:
        base_directory (str): directory holding one sub-directory per alphabet.

    Returns:
        tuple(np.ndarray, np.ndarray): ``(datax, datay)``, with all alphabets'
        images and labels concatenated along axis 0. Returns
        ``(None, None)`` if no images were found.
    """
    datax_parts = []
    datay_parts = []
    for directory in sorted(os.listdir(base_directory)):
        alphabet_path = os.path.join(base_directory, directory)
        if directory.startswith('.') or not os.path.isdir(alphabet_path):
            continue
        # trailing '/' keeps compatibility with path-concatenating readers
        alphabet_x, alphabet_y = read_alphabets(alphabet_path + '/', directory)
        # an empty alphabet gives a (0,)-shaped array that cannot be stacked
        if len(alphabet_x):
            datax_parts.append(alphabet_x)
            datay_parts.append(alphabet_y)
    if not datax_parts:
        return None, None
    # one concatenate instead of re-copying a growing array per alphabet
    return np.concatenate(datax_parts), np.concatenate(datay_parts)

In [62]:
def extract_sample(n_way, n_support, n_query, datax, datay, rng=None):
    """
    Randomly sample one few-shot episode.

    Picks ``n_way`` distinct labels from ``datay``. For each one, it takes
    ``n_support + n_query`` of that label's images in random order.

    Args:
        n_way (int): number of classes in the episode.
        n_support (int): support images per class.
        n_query (int): query images per class.
        datax (np.ndarray): images indexed along axis 0. This is assumed to
            be channels-last (N, height, width, channels), since the
            trailing axes are permuted to channels-first below.
        datay (np.ndarray): labels, one per entry of ``datax``.
        rng: optional ``np.random.Generator`` or ``RandomState`` for
            reproducible episodes. Defaults to the global ``np.random``.

    Returns:
        dict: ``'images'``, a float tensor of shape
        (n_way, n_support + n_query, channels, height, width), plus the
        ``'n_way'``, ``'n_support'`` and ``'n_query'`` arguments.

    Raises:
        ValueError: if a chosen class has fewer than ``n_support + n_query``
            images, or if ``n_way`` exceeds the number of distinct labels
            (raised by ``choice``).
    """
    if rng is None:
        rng = np.random
    n_per_class = n_support + n_query
    unique_y = np.unique(datay)
    K = rng.choice(unique_y, n_way, replace=False)
    sample = []
    for cls in K:
        datax_cls = datax[datay == cls]
        # a short class would make np.array(sample) ragged and fail obscurely
        if len(datax_cls) < n_per_class:
            raise ValueError(
                'class {!r} has only {} images; need n_support + n_query = {}'
                .format(cls, len(datax_cls), n_per_class))
        perm = rng.permutation(datax_cls)
        sample.append(perm[:n_per_class])
    # n_way arrays, each of shape (n_support + n_query, *image_shape)
    sample = torch.from_numpy(np.array(sample)).float()
    # (..., H, W, C) -> (..., C, H, W), the channels-first layout torch uses
    sample = sample.permute(0, 1, 4, 2, 3)
    return {'images': sample, 'n_way': n_way, 'n_support': n_support, 'n_query': n_query}


def display_sample(sample):
    """
    Display an episode of images as a grid, one row per class.

    Args:
        sample (torch.Tensor): 5-D tensor
            (n_way, images_per_class, channels, height, width), such as the
            ``'images'`` entry returned by ``extract_sample``. Floating values
            above 1 are treated as 0-255 pixel intensities and rescaled to
            [0, 1] for display.
    """
    # make_grid needs a 4-D batch; reshape (unlike view) also accepts
    # non-contiguous tensors
    sample_4D = sample.reshape(sample.shape[0] * sample.shape[1], *sample.shape[2:])
    # nrow = images per class, so each grid row holds one class
    out = torchvision.utils.make_grid(sample_4D, nrow=sample.shape[1])
    img = out.detach().cpu().permute(1, 2, 0)
    # imshow clips float RGB data to [0, 1]; rescale 0-255 floats so
    # intermediate gray levels are not all rendered as white
    if img.is_floating_point() and img.max() > 1:
        img = (img / 255).clamp(0, 1)
    plt.figure(figsize=(16, 7))
    plt.imshow(img.numpy())