In [1]:

# Load the default configuration
import yaml
import os
import time
import pandas as pd
import random

config_path = './configs/default_cifar.yml'

# safe_load only builds plain dicts/lists/scalars (all this config needs) and never
# instantiates arbitrary Python objects from YAML tags.
with open(config_path) as file:
    config = yaml.safe_load(file)

# Create the project and group directories if they are missing. makedirs also creates
# missing parent directories and does not fail when the directory already exists.
os.makedirs(config["PROJECT"]["project_dir"], exist_ok=True)
os.makedirs(config["PROJECT"]["group_dir"], exist_ok=True)
    
    
#############################################################################################
# LOAD DATA (the pool of images to annotate)
#############################################################################################
from data_utils import CIFAR10Data

# Raw CIFAR-10 arrays (normalize_data=False).
cifar10_data = CIFAR10Data()
num_classes = len(cifar10_data.classes)
x_train, y_train, x_test, y_test = cifar10_data.get_data(normalize_data=False)

# Every training image is labeled in this run: shuffle all indices with a fixed seed
# and leave the unlabeled pool empty.
random.seed(101)
indices = list(range(len(x_train)))
random.shuffle(indices)
labeled_set, unlabeled_set = indices, []

# Evaluate on the full test split.
NUM_IMAGES_TEST = len(x_test)
test_set = [*range(NUM_IMAGES_TEST)]

# Network settings derived from the data: input side length and class names.
config["NETWORK"].update(INPUT_SIZE=x_train[0].shape[0], CLASSES=cifar10_data.classes)


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Display the configuration dict (as modified by the data-loading cell) that is passed to the trainer.
config

{'ACTIVE_ALGO': {'LOSSLEARNING': 1.0},
 'DATASET': {'height_shift_range': 4,
  'horizontal_flip': True,
  'width_shift_range': 4},
 'NETWORK': {'CLASSES': ['plane',
   'car',
   'bird',
   'cat',
   'deer',
   'dog',
   'frog',
   'horse',
   'ship',
   'truck'],
  'INPUT_SIZE': 32,
  'MARGIN': 1.0,
  'embedding_size': 128},
 'PROJECT': {'Backbone': 'resnet18',
  'dataset_name': 'CIFAR',
  'group': 'Classif_all_data_0912',
  'group_dir': '/mnt/Ressources/Andres/Temp_active/runs/Classif_all_data_0912',
  'project': 'Active_Learning_CIFAR',
  'project_dir': '/mnt/Ressources/Andres/Temp_active/runs',
  'source': 'CIFAR'},
 'RUNS': {'ADDENDUM': 1000,
  'CYCLES': 1,
  'SUBSET': -1,
  'TRIALS': 1,
  'test_each': 1},
 'TEST': {'batch_size': 128},
 'TRAIN': {'Data_augementation': True,
  'EPOCH_SLIT': 20,
  'EPOCH_WARMUP': 2,
  'EPOCH_WHOLE': 40,
  'MILESTONES': [25, 35],
  'batch_size': 128,
  'gamma': 0.1,
  'lr': 0.01,
  'start_epoch': 0,
  'transfer_weight_path': False,
  'w_c_loss': 1.0,


In [3]:

class Active_Learning_train:
    """Build and compile the classifier + LossNet model for one active-learning stage.

    The constructor only prepares what training needs: the TF session, the run and
    checkpoint directories, the data generator, the compiled ``self.model`` and
    ``self.callbacks``. Training is started by the caller, e.g.
    ``self.model.fit_generator(self.Data_Generator, callbacks=self.callbacks, ...)``.
    """

    def __init__(self, config, labeled_set, test_set, num_run, resume_model_path, resume=False):
        """Set up the session, run directories, data generator and compiled model.

        Args:
            config: parsed YAML config dict (PROJECT, NETWORK, TRAIN, RUNS sections are read).
            labeled_set: indices into the CIFAR training arrays used to train this stage.
            test_set: test indices; stored on ``self.test_set``.
            num_run: stage number; the run directory is ``<group_dir>/Stage_<num_run>``.
            resume_model_path: path to a ``checkpoint.NNN.hdf5`` weights file, or a falsy
                value to start from fresh weights. When it is exactly ``False`` an existing
                run directory is deleted first (the whole group dir when ``num_run == 0``).
            resume: if True, continue the epoch counter from the epoch encoded in
                ``resume_model_path``. If no weights could be loaded, start at epoch 0.

        Raises:
            NameError: if ``config["PROJECT"]["source"]`` is not ``'CIFAR'``.
            ValueError: if the resumed epoch is greater than the total number of epochs.
        """
        #############################################################################################
        # LIBRARIES
        #############################################################################################
        import os
        import shutil

        # Script/ray variant (does not work inside a notebook, kept for reference):
        #   self.run_path = os.path.dirname(os.path.realpath(__file__))
        #   os.chdir(self.run_path)
        #   core = local_module("core")
        #   backbones = local_module("backbones")
        #   self.user = get_user()
        import core
        import backbones

        import tensorflow as tf
        from tensorflow.keras import optimizers, losses, models, backend, layers, metrics

        #############################################################################################
        # SETUP TENSORFLOW SESSION
        #############################################################################################
        config_tf = tf.ConfigProto(allow_soft_placement=True)
        config_tf.gpu_options.allow_growth = True   # allocate GPU memory on demand
        self.sess = tf.Session(config=config_tf)
        # A session created without a graph argument runs the default graph, so Keras only
        # needs to be pointed at it. Calling `self.sess.graph.as_default()` without `with`
        # would be a no-op, because it just returns a context manager.
        backend.set_session(self.sess)

        #############################################################################################
        # PARAMETERS RUN
        #############################################################################################
        self.config            = config
        self.num_run           = num_run
        self.group             = "Stage_" + str(num_run)
        self.name_run          = "Train_" + self.group

        self.run_dir           = os.path.join(config["PROJECT"]["group_dir"], self.group)
        self.run_dir_check     = os.path.join(self.run_dir, 'checkpoints')
        self.checkpoints_path  = os.path.join(self.run_dir_check, 'checkpoint.{epoch:03d}.hdf5')
        self.test_run_id       = None
        self.stop_flag         = False
        self.training_thread   = None
        self.resume_training   = resume

        self.num_data_train    = len(labeled_set)
        self.resume_model_path = resume_model_path
        self.transfer_weight_path = self.config['TRAIN']["transfer_weight_path"]
        self.num_class         = len(self.config["NETWORK"]["CLASSES"])
        self.input_shape       = [self.config["NETWORK"]["INPUT_SIZE"], self.config["NETWORK"]["INPUT_SIZE"], 3]

        # Coloured log prefixes: cyan for normal messages, red for problems.
        self.pre     = '\033[1;36m' + self.name_run + '\033[0;0m'
        self.problem = '\033[1;31m' + self.name_run + '\033[0;0m'

        # Creating the train folder. When not resuming (resume_model_path is False), old
        # results are wiped. Stage 0 wipes the whole group directory.
        if os.path.exists(self.run_dir) and self.resume_model_path is False:
            if num_run == 0:
                shutil.rmtree(config["PROJECT"]["group_dir"])
                os.mkdir(config["PROJECT"]["group_dir"])
            else:
                shutil.rmtree(self.run_dir)

        # Creates run_dir as well (and any missing parent).
        os.makedirs(self.run_dir_check, exist_ok=True)

        #############################################################################################
        # SETUP WANDB (disabled)
        #############################################################################################
        # import wandb
        # self.wandb = wandb
        # self.wandb.init(project  = config["PROJECT"]["project"],
        #                 group    = config["PROJECT"]["group"],
        #                 name     = "Train_"+str(num_run),
        #                 job_type = self.group ,
        #                 sync_tensorboard = True,
        #                 config = config)

        #############################################################################################
        # GLOBAL PROGRESS
        #############################################################################################
        self.current_epoch = 0
        self.split_epoch   = self.config['TRAIN']["EPOCH_WHOLE"]
        self.total_epochs  = self.config['TRAIN']["EPOCH_WHOLE"] + self.config['TRAIN']["EPOCH_SLIT"]
        self.progress = round(self.current_epoch / self.total_epochs * 100.0, 2)

        #############################################################################################
        # LOAD DATA
        #############################################################################################
        if self.config["PROJECT"]["source"] == 'CIFAR':
            from data_utils import CIFAR10Data
            cifar10_data = CIFAR10Data()
            x_train, y_train, _, _ = cifar10_data.get_data(normalize_data=False)

            # Keep only the labeled subset for this stage.
            x_train = x_train[labeled_set]
            y_train = y_train[labeled_set]

            self.test_set = test_set
        else:
            raise NameError('This is not implemented yet')

        #############################################################################################
        # DATA GENERATOR
        #############################################################################################
        self.Data_Generator = core.Generator_cifar_train(x_train, y_train, config)

        #############################################################################################
        # DEFINE CLASSIFIER
        #############################################################################################
        # Backbone names noted by the author: ResNet18, ResNet50, ResNet101, ResNet152,
        # ResNet50V2, ResNet101V2, ResNet152V2, ResNeXt50, ResNeXt101.
        img_input = tf.keras.Input(self.input_shape, name='input_image')

        include_top = True

        # Get the selected backbone
        self.backbone = getattr(backbones, "ResNet18_cifar")
        c_pred_features = self.backbone(input_tensor=img_input, classes=self.num_class, include_top=include_top)
        self.c_pred_features = c_pred_features
        if include_top:
            # The backbone already ends in the classifier: element 0 is the class prediction.
            c_pred = c_pred_features[0]
        else:
            x = layers.GlobalAveragePooling2D(name='pool1')(c_pred_features[0])
            x = layers.Dense(self.num_class, name='fc1')(x)
            c_pred = layers.Activation('softmax', name='c_pred')(x)
            c_pred_features[0] = c_pred

        self.classifier = models.Model(inputs=[img_input], outputs=c_pred_features, name='Classifier')

        #############################################################################################
        # DEFINE FULL MODEL
        #############################################################################################
        c_pred_features_1 = self.classifier(img_input)
        c_pred_1 = c_pred_features_1[0]

        # LossNet head fed with the classifier outputs/features.
        loss_pred_embeddings = core.Lossnet(c_pred_features_1, self.config["NETWORK"]["embedding_size"])

        self.model = models.Model(inputs=img_input, outputs=[c_pred_1] + loss_pred_embeddings)

        #############################################################################################
        # DEFINE LOSSES
        #############################################################################################
        self.loss_dict = {}
        self.loss_dict['Classifier'] = losses.categorical_crossentropy
        self.loss_dict['l_pred_w']   = core.Loss_Lossnet
        self.loss_dict['l_pred_s']   = core.Loss_Lossnet
        # Backend variables so a callback can switch the LossNet loss weights during training
        # without recompiling.
        self.weight_w = backend.variable(1)
        self.weight_s = backend.variable(0)

        self.loss_w_dict = {}
        self.loss_w_dict['Classifier'] = 1
        self.loss_w_dict['l_pred_w']   = self.weight_w
        self.loss_w_dict['l_pred_s']   = self.weight_s
        self.loss_w_dict['Embedding']  = 0

        #############################################################################################
        # DEFINE METRICS
        #############################################################################################
        self.metrics_dict = {}
        self.metrics_dict['Classifier'] = metrics.categorical_accuracy
        self.metrics_dict['l_pred_w']   = core.MAE_Lossnet
        self.metrics_dict['l_pred_s']   = core.MAE_Lossnet

        #############################################################################################
        # DEFINE OPTIMIZER
        #############################################################################################
        # The LearningRateScheduler below resets the LR at every epoch. The configured base LR
        # is used here too so the two agree.
        self.opt = optimizers.Adam(lr=self.config['TRAIN'].get('lr', 0.01))

        #############################################################################################
        # DEFINE CALLBACKS
        #############################################################################################
        self.callbacks = []
        # Save weights every `test_each` epochs as checkpoint.NNN.hdf5.
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                                                filepath=self.checkpoints_path,
                                                save_weights_only=True,
                                                period=self.config["RUNS"]["test_each"])
        self.callbacks.append(model_checkpoint_callback)

        # Callback to wandb (disabled)
        # self.callbacks.append(self.wandb.keras.WandbCallback())

        # Step-decay learning rate. The config is read at call time, so later config edits apply.
        def scheduler(epoch):
            return self._step_lr(epoch,
                                 self.config['TRAIN']['lr'],
                                 self.config['TRAIN']['MILESTONES'],
                                 self.config['TRAIN'].get('gamma', 0.1))

        self.callbacks.append(tf.keras.callbacks.LearningRateScheduler(scheduler))

        # Callback to change the LossNet loss weights for the split training.
        # NOTE(review): the hard-coded 3 looks like a debugging value (the captured run switches
        # weights after epoch 3); self.split_epoch is probably intended. Confirm before changing.
        self.callbacks.append(core.Change_loss_weights(self.weight_w, self.weight_s, 3))

        #############################################################################################
        # LOAD PREVIOUS WEIGHTS / RESUME
        #############################################################################################
        loaded_epoch = self._load_resume_weights()
        self._restore_epoch(loaded_epoch)

        #############################################################################################
        # COMPILE MODEL
        #############################################################################################
        self.model.compile(loss=self.loss_dict,
                           loss_weights=self.loss_w_dict,
                           metrics=self.metrics_dict,
                           optimizer=self.opt)

        #############################################################################################
        # INIT VARIABLES
        #############################################################################################
        backend.set_session(self.sess)
        self.sess.run(tf.local_variables_initializer())

    def _load_resume_weights(self):
        """Load ``self.resume_model_path`` into ``self.model`` and return its epoch, or None.

        The epoch is parsed from the ``checkpoint.NNN.hdf5`` filename written by the
        ModelCheckpoint callback. Any failure (e.g. an unparsable name or a missing file)
        is reported and None is returned, so training starts from scratch. The parsed
        epoch is only returned once the weights were actually loaded.
        """
        if not self.resume_model_path:
            return None
        try:
            loaded_epoch = int(self.resume_model_path.split('.')[-2])
            print(self.pre, "Loading weigths from: ", self.resume_model_path)
            print(self.pre, "The detected epoch is: ", loaded_epoch)
            self.model.load_weights(self.resume_model_path)
        except Exception as err:  # best effort: fall back to training from scratch
            print(self.problem, "=> Problem loading the weights from ", self.resume_model_path)
            print(self.problem, '=> It will train from scratch')
            print(self.problem, "=>", err)
            return None
        return loaded_epoch

    def _restore_epoch(self, loaded_epoch):
        """When resuming, continue the epoch counter and progress from ``loaded_epoch``.

        Does nothing unless ``self.resume_training`` is set. If no checkpoint epoch is
        available (no path given, or loading failed), training restarts at epoch 0.

        Raises:
            ValueError: if ``loaded_epoch`` is greater than ``self.total_epochs``.
        """
        if not self.resume_training:
            return
        if loaded_epoch is None:
            print(self.problem, "=> Resume requested but no weights were loaded; starting at epoch 0")
            return
        self.current_epoch = loaded_epoch
        self.progress = round(self.current_epoch / self.total_epochs * 100.0, 2)
        if self.current_epoch > self.total_epochs:
            raise ValueError("The starting epoch is higher that the total epochs")
        print(self.pre, "Resuming the training from stage: ", self.num_run, " at epoch ", self.current_epoch)

    @staticmethod
    def _step_lr(epoch, base_lr, milestones, gamma=0.1):
        """Step decay: scale ``base_lr`` by ``gamma`` once for every milestone strictly below ``epoch``.

        ``epoch`` is the 0-based index Keras passes to LearningRateScheduler. With
        milestones [25, 35], the first decay applies from epoch index 26.
        """
        lr = base_lr
        for milestone in milestones:
            if epoch > milestone:
                lr *= gamma
        return lr

    #############################################################################################
    # SETUP WATCHER / remote training API (ray-based, currently disabled; kept as text)
    #############################################################################################
    """
        self.run_watcher = get_run_watcher()
        
        self.run_watcher.add_run.remote(name=self.name_run,
                                        user=self.user,
                                        progress=self.progress,
                                        wandb_url=self.wandb.run.get_url(),
                                        status="Idle")
        print(self.pre,'Init done')
        

    
    @ray.method(num_returns = 0)
    def start_training(self):
        import threading
        import os
        import numpy as np
        from copy import deepcopy
        from tensorflow.keras import backend
        
        def train():
            try:
                self.sess.graph.as_default()
                backend.set_session(self.sess)
                #self.sess.run(tf.local_variables_initializer())
                
                print( self.pre ,"Start training")
                self.run_watcher.update_run.remote(name=self.name_run, status="Training")
                
                ###############################################################################
                # TRAIN THE WHOLE NETWORK
                ###############################################################################
                if self.current_epoch <= self.split_epoch: 
                    
                    print( self.pre ,"Compile with new weights for the losses")
                    # change the weigth to the predictions of the whole network
                    #self.loss_w_dict['l_pred_w']   = 1
                    #self.loss_w_dict['l_pred_s']   = 0
                    

                    print( self.pre ,"End compiling")
                    
                    self.sess.run(tf.local_variables_initializer())
                    print( self.pre ,"Init local")
                    # run epoch by epoch to be able to have the stop flag
                    for epoch in range(self.current_epoch, self.split_epoch):
                        
                        print( self.pre ,"Training epoch", epoch)
                        
                        if self.stop_flag:
                            self.run_watcher.update_run.remote(name=self.name_run, status="Idle")
                            break
                        
                        history = self.model.fit_generator(self.Data_Generator,
                                                           epochs=epoch+1, 
                                                           callbacks = self.callbacks,
                                                           initial_epoch=epoch,
                                                           verbose=1)

   
                        self.current_epoch = epoch
                        self.progress = round(self.current_epoch / self.total_epochs * 100.0, 2)
                        self.run_watcher.update_run.remote(name=self.name_run, progress=self.progress)

                    
                if self.current_epoch <= self.total_epochs:
                    print( self.pre ,"Compile with new weights for the losses")
                    # change the weigth to the predictions of the whole network
                    self.loss_w_dict['l_pred_w']   = 0
                    self.loss_w_dict['l_pred_s']   = 1
                    
                    # compile the model
                    self.model.compile(loss = self.loss_dict, 
                           loss_weights = self.loss_w_dict, 
                           metrics = self.metrics_dict, 
                           optimizer = self.opt)
                    
                    self.sess.run(tf.local_variables_initializer())
                    # run epoch by epoch to be able to have the stop flag
                    for epoch in range(self.current_epoch, self.total_epochs):
                        print( self.pre ,"Training epoch", epoch)

                        if self.stop_flag:
                            self.run_watcher.update_run.remote(name=self.name_run, status="Idle")
                            break
                        
                        history = self.model.fit_generator(self.Data_Generator,
                                                           epochs=epoch+1, 
                                                           callbacks = self.callbacks,
                                                           initial_epoch=epoch ,verbose=1)

   
                        self.current_epoch = epoch
                        self.progress = round(self.current_epoch / self.total_epochs * 100.0, 2)
                        self.run_watcher.update_run.remote(name=self.name_run, progress=self.progress)
                if self.current_epoch > self.total_epochs:
                    print(self.problem, 'The starting epoch is higher that the total epochs')
                self.run_watcher.update_run.remote(name=self.name_run, status="Finished", progress=self.progress)
                
            except Exception as e:
                self.run_watcher.update_run.remote(name=self.name_run, status="Error")
                print( self.problem ,e)
            
        if self.training_thread is None or not self.training_thread.is_alive():
            self.stop_flag=False
            self.training_thread = threading.Thread(target=train, args=(), daemon=True)
            self.training_thread.start()
            
    @ray.method(num_returns = 1)
    def isTraining(self):
        return not (self.training_thread is None or not self.training_thread.is_alive())

    @ray.method(num_returns = 0)
    def stop_training(self):
        self.stop_flag=True

    @ray.method(num_returns = 1)
    def get_progress(self):
        return {"global_step" : self.global_step_val, "progress": self.progress, }
        
        """

In [4]:
# Checkpoint of an earlier run of this stage, kept for reference when resuming.
resume_model_path = '/mnt/Ressources/Andres/Temp_active/runs/Classif_all_data_0912/Stage_5000/checkpoints/checkpoint.002.hdf5'
# Overrides the line above: train from scratch. With resume_model_path == False the
# constructor deletes an existing Stage_5000 run directory.
resume_model_path = False

fles = Active_Learning_train(
    config,
    labeled_set,
    test_set=[],
    num_run=5000,
    resume_model_path=resume_model_path,
)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.


In [5]:
from tensorflow.keras import optimizers, losses, models, backend, layers, metrics
import tensorflow as tf
import core

In [6]:
print(fles.pre, "Start training")

# Short run: 6 epochs starting at epoch 0, using the generator and callbacks
# (checkpointing, LR schedule, loss-weight switch) prepared by the trainer.
history = fles.model.fit_generator(
    fles.Data_Generator,
    epochs=6,
    callbacks=fles.callbacks,
    initial_epoch=0,
)




[1;36mTrain_Stage_5000[0;0m Start training
Instructions for updating:
Use tf.cast instead.
Epoch 1/6
out 1 0
Epoch 2/6
out 1 0
Epoch 3/6
out 1 0
Epoch 4/6
out 0 1
Epoch 5/6
out 0 1
Epoch 6/6
out 0 1
