In [3]:
import os 
import random 
import numpy as np 
import matplotlib.pyplot as plt 
import math 
from PIL import Image 

import numpy.random as rng
import import_ipynb
import  class_augmentator 

import copy

importing Jupyter notebook from class_augmentator.ipynb


In [6]:
class Lichensloader:
    """
    Load and prepare the lichens dataset for one-shot (Siamese) learning.

    ``dataset_path`` must contain a ``train`` and a ``test`` directory, each with
    one sub-directory per lichen holding that lichen's image files. The train
    lichens are split into train and validation lichens by
    ``split_train_datasets``; the test lichens are used for evaluation. The class
    builds training batches of image pairs and one-shot tasks.

    Attributes:
        dataset_path: root directory of the lichen dataset.
        train_dictionary: lichen name -> list of image file names found in
            ``<dataset_path>/train/<lichen>``; used for training and validation.
        evaluation_dictionary: lichen name -> list of image file names found in
            ``<dataset_path>/test/<lichen>``.
        image_width, image_height, channels: image shape (400 x 400 x 3).
        batch_size: number of pairs produced by ``create_pairs_for_batch``.
        use_augmentation: whether data augmentation was requested.
        image_augmentor: ImageAugmentator built by ``createAugmentor`` when
            ``use_augmentation`` is truthy, otherwise None. No method of this
            class applies it yet.
        train_data, validation_data, test_data: lichen names of each split.
    """

    def __init__(self, dataset_path, use_augmentation, batch_size):
        self.dataset_path = dataset_path
        self.train_dictionary = {}
        self.evaluation_dictionary = {}
        self.image_width = 400
        self.image_height = 400
        self.channels = 3
        self.batch_size = batch_size
        self.use_augmentation = use_augmentation
        self.train_data = []
        self.validation_data = []
        self.test_data = []

        # Cursors telling get_one_shot_batch which lichen of the validation /
        # test split is currently being evaluated.
        self.current_train_lichen_index = 0
        self.current_validation_lichen_index = 0
        self.current_test_lichen_index = 0

        self.load_dataset()

        # Always define image_augmentor so callers can test it against None.
        self.image_augmentor = self.createAugmentor() if self.use_augmentation else None

    @staticmethod
    def _list_entries(path):
        """Return the entry names of directory ``path``, skipping ``.DS_Store``."""
        return [name for name in os.listdir(path) if name != ".DS_Store"]

    def load_dataset(self):
        """
        Fill train_dictionary and evaluation_dictionary from disk.

        Every entry of ``<dataset_path>/train`` (except ``.DS_Store``) is treated
        as a lichen directory; a directory with the same name must also exist
        under ``<dataset_path>/test``.
        """
        train_path = os.path.join(self.dataset_path, 'train')
        test_path = os.path.join(self.dataset_path, 'test')

        for lich in self._list_entries(train_path):
            self.train_dictionary[lich] = self._list_entries(os.path.join(train_path, lich))
            self.evaluation_dictionary[lich] = self._list_entries(os.path.join(test_path, lich))

    def createAugmentor(self):
        """
        Create the ImageAugmentator used for data augmentation.

        Ranges passed to it:
            rotation: -15 to 15 degrees
            shear: -0.3 to 0.3 radians, converted to degrees before being passed
            zoom: [0.5, 2]
            shift: [5, 5]

        The leading 0.5 argument is interpreted by ImageAugmentator
        (NOTE(review): presumably the probability of applying augmentation --
        confirm against class_augmentator).
        """
        rotation_range = [-15, 15]
        shear_range = [-0.3 * 180 / math.pi, 0.3 * 180 / math.pi]
        zoom_range = [0.5, 2]
        shift_range = [5, 5]

        return class_augmentator.ImageAugmentator(0.5, shear_range, rotation_range, shift_range, zoom_range)

    def split_train_datasets(self, tr=0.8):
        """
        Randomly split the train lichens into train and validation lichens.

        A fraction ``tr`` (default 80%, rounded down) of the lichens in
        train_dictionary goes to train_data; the rest goes to validation_data.
        test_data receives every lichen of evaluation_dictionary. Calling this
        again replaces the previous split instead of appending to it.

        Args:
            tr: fraction of the train lichens used for training, in [0, 1].
        """
        available_lichens = list(self.train_dictionary.keys())
        number_of_train_lichens = int(tr * len(available_lichens))

        # Sample from range(n), so every lichen, including the last, can be drawn.
        train_indexes = set(random.sample(range(len(available_lichens)), number_of_train_lichens))

        self.train_data = [lich for i, lich in enumerate(available_lichens) if i in train_indexes]
        self.validation_data = [lich for i, lich in enumerate(available_lichens) if i not in train_indexes]
        self.test_data = list(self.evaluation_dictionary.keys())

    @staticmethod
    def _standardize(image):
        """
        Return ``image`` as float64 with zero mean and unit standard deviation.

        A constant image (std == 0) is only mean-centred, avoiding a division by zero.
        """
        image = np.asarray(image, dtype=np.float64)
        centred = image - image.mean()
        std = image.std()
        return centred / std if std > 0 else centred

    def _load_image(self, path):
        """Open the image file at ``path`` and return it standardized (file is closed)."""
        with Image.open(path) as img:
            return self._standardize(np.asarray(img))

    def convert_path_list_to_images_and_labels(self, path_list, is_one_shot_task):
        """
        Load consecutive pairs of image paths into two image arrays plus labels.

        Args:
            path_list: flat list of image paths; entries 2*k and 2*k+1 form pair k.
                A trailing unpaired path is ignored.
            is_one_shot_task: if True, only pair 0 is labelled 1 (the layout built
                by get_one_shot_batch) and the order is kept. If False, labels
                alternate 1, 0, 1, 0, ... and pairs and labels are shuffled together.

        Returns:
            (pairs_of_images, labels): pairs_of_images is a list of two float64
            arrays of shape (number_of_pairs, image_height, image_width, channels);
            labels has shape (number_of_pairs, 1).
        """
        number_of_pairs = len(path_list) // 2
        shape = (number_of_pairs, self.image_height, self.image_width, self.channels)
        pairs_of_images = [np.zeros(shape), np.zeros(shape)]
        labels = np.zeros((number_of_pairs, 1))

        for pair in range(number_of_pairs):
            # Store the whole image (all channels) of each member of the pair.
            pairs_of_images[0][pair] = self._load_image(path_list[2 * pair])
            pairs_of_images[1][pair] = self._load_image(path_list[2 * pair + 1])
            if is_one_shot_task:
                labels[pair] = 1 if pair == 0 else 0
            else:
                labels[pair] = 1 if pair % 2 == 0 else 0

        if not is_one_shot_task:
            # Same permutation for both image arrays and the labels keeps them aligned.
            permutation = np.random.permutation(number_of_pairs)
            labels = labels[permutation]
            pairs_of_images = [images[permutation] for images in pairs_of_images]

        return pairs_of_images, labels

    def create_pairs_for_batch(self):
        """
        Create a training batch of ``batch_size`` image pairs.

        The first half of the batch holds pairs of images from two different
        lichens (target 0); the second half holds two distinct images of the same
        lichen (target 1). Lichens are drawn with replacement from train_data, so
        split_train_datasets must have been called first.

        Returns:
            (pairs, targets): pairs is a list of two float64 arrays of shape
            (batch_size, image_height, image_width, channels); targets has shape
            (batch_size,).
        """
        train_dir = os.path.join(self.dataset_path, 'train')
        categories = rng.choice(self.train_data, size=(self.batch_size,), replace=True)
        shape = (self.batch_size, self.image_height, self.image_width, self.channels)
        # Float buffers: standardized images are real-valued and would be
        # truncated by an integer dtype.
        pairs = [np.zeros(shape), np.zeros(shape)]
        targets = np.zeros((self.batch_size,))
        targets[self.batch_size // 2:] = 1

        for i, category in enumerate(categories):
            images = self.train_dictionary[category]
            idx_1 = rng.randint(0, len(images))
            pairs[0][i] = self._load_image(os.path.join(train_dir, category, images[idx_1]))

            if i >= self.batch_size // 2:
                # Same lichen, different image: a non-zero offset modulo the count.
                idx_2 = (idx_1 + rng.randint(1, len(images))) % len(images)
                second_path = os.path.join(train_dir, category, images[idx_2])
            else:
                other_categories = [c for c in self.train_data if c != category]
                category_2 = rng.choice(other_categories)
                images_2 = self.train_dictionary[category_2]
                second_path = os.path.join(train_dir, category_2, images_2[rng.randint(0, len(images_2))])

            pairs[1][i] = self._load_image(second_path)

        return pairs, targets

    def _one_shot_context(self, is_validation):
        """Return (lichens, current lichen index, file dictionary, image directory) of a split."""
        if is_validation:
            return (self.validation_data, self.current_validation_lichen_index,
                    self.train_dictionary, os.path.join(self.dataset_path, 'train'))
        return (self.test_data, self.current_test_lichen_index,
                self.evaluation_dictionary, os.path.join(self.dataset_path, 'test'))

    def _get_one_shot_paths(self, support_set_size, is_validation):
        """
        Build the flat path list of one one-shot task for the current lichen.

        Pair 0 is (test image, another image of the same lichen). Each of the
        following ``support_set_size - 1`` pairs is (test image, a random image of
        a different lichen); those lichens are drawn without replacement. The
        split's lichen list is not modified.

        Raises:
            ValueError: (from numpy) if fewer than ``support_set_size - 1`` other
                lichens exist; (from random.sample) if the current lichen has
                fewer than two images.
        """
        lichens, current_index, dictionary, image_dir = self._one_shot_context(is_validation)
        current_lichen = lichens[current_index]

        # Two distinct images of the current lichen: the one to classify and its match.
        test_name, match_name = random.sample(dictionary[current_lichen], 2)
        test_path = os.path.join(image_dir, current_lichen, test_name)
        paths = [test_path, os.path.join(image_dir, current_lichen, match_name)]

        other_lichens = [lich for lich in lichens if lich != current_lichen]
        for lichen in rng.choice(other_lichens, size=support_set_size - 1, replace=False):
            paths.append(test_path)
            paths.append(os.path.join(image_dir, lichen, random.choice(dictionary[lichen])))
        return paths

    def get_one_shot_batch(self, support_set_size, is_validation):
        """
        Build one one-shot task for the current validation/test lichen.

        A test image is compared against a support set of ``support_set_size``
        images: the first pair contains an image of the same lichen (label 1),
        and the remaining pairs contain images of other lichens (label 0).

        Returns:
            (images, labels) as produced by convert_path_list_to_images_and_labels
            with ``is_one_shot_task=True``.
        """
        paths = self._get_one_shot_paths(support_set_size, is_validation)
        return self.convert_path_list_to_images_and_labels(paths, is_one_shot_task=True)

    def one_shot_test(self, model, support_set_size, number_of_tasks_per_L, is_validation):
        """
        Evaluate ``model`` on one-shot tasks and return the mean accuracy.

        For every lichen of the chosen split, ``number_of_tasks_per_L`` tasks are
        built with get_one_shot_batch and scored with ``model.predict_on_batch``.
        A task counts as correct when the highest score belongs to the first pair
        (the matching image) and the scores are not (almost) constant
        (std > 0.01).

        Args:
            model: object exposing ``predict_on_batch(images)`` that returns a
                numpy array of scores.
            support_set_size: number of pairs per task.
            number_of_tasks_per_L: number of tasks evaluated per lichen.
            is_validation: evaluate on validation_data if True, else on test_data.

        Returns:
            Mean accuracy over all tasks of all lichens (0.0 when there are no tasks).
        """
        if is_validation:
            lichens = self.validation_data
            print('\nMaking One Shot Task on validation alphabets:')
        else:
            lichens = self.test_data
            print('\nMaking One Shot Task on evaluation alphabets:')

        mean_global_accuracy = 0.0

        try:
            for index, lich in enumerate(lichens):
                # Point the cursor read by get_one_shot_batch at this lichen.
                if is_validation:
                    self.current_validation_lichen_index = index
                else:
                    self.current_test_lichen_index = index

                mean_lichen_accuracy = 0.0
                for _ in range(number_of_tasks_per_L):
                    images, _ = self.get_one_shot_batch(support_set_size, is_validation=is_validation)
                    probabilities = model.predict_on_batch(images)

                    # Nearly identical outputs for every pair would make argmax 0
                    # by definition, so such tasks are not counted as correct.
                    if np.argmax(probabilities) == 0 and probabilities.std() > 0.01:
                        accuracy = 1.0
                    else:
                        accuracy = 0.0

                    mean_lichen_accuracy += accuracy
                    mean_global_accuracy += accuracy

                if number_of_tasks_per_L > 0:
                    mean_lichen_accuracy /= number_of_tasks_per_L
                print(f'{lich} lichen, accuracy: {mean_lichen_accuracy}')
        finally:
            # Reset the cursor so the next evaluation starts at the first lichen.
            if is_validation:
                self.current_validation_lichen_index = 0
            else:
                self.current_test_lichen_index = 0

        total_tasks = len(lichens) * number_of_tasks_per_L
        if total_tasks > 0:
            mean_global_accuracy /= total_tasks

        print('\nMean global accuracy: ' + str(mean_global_accuracy))

        return mean_global_accuracy

        
        
    
    
    
            
            
            
        
        
        
        
        
    
    
    


    
    

        

    
    