In [1]:
import numpy as np
import os
import shutil

from keras import Input
from keras.applications.efficientnet import EfficientNetB0
from keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint, ReduceLROnPlateau
from keras.layers import GlobalAveragePooling2D, BatchNormalization, Dropout, Dense
from keras.models import load_model, Model
from tensorflow.keras.optimizers import Adam
from keras.preprocessing import image
from keras_preprocessing.image import ImageDataGenerator
from PIL import ImageFile
from PIL import Image

import sys
# Register PIL's Image module under the bare top-level name 'Image', so a plain
# `import Image` resolves to it (presumably for legacy code that still uses the
# pre-Pillow import style -- confirm it is still needed).
sys.modules['Image'] = Image 
# Tell PIL to load partially written / truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


In [8]:

class NoisyStudent:
    """Iterative self-training loop built around an EfficientNetB0 classifier.

    One round (see ``main``) consists of:

    1. ``train`` -- fit a classifier on the labelled dataset (warm-started from
       ``teacher_model_path`` when that file exists) and save it back to
       ``teacher_model_path``.
    2. ``predict_unlabelled_data`` -- score every file of the unlabelled
       dataset with the saved model and move each image whose best score is
       strictly greater than ``confidence`` into the matching class folder of
       the labelled dataset.

    Args:
        labelled_dataset: Directory containing one sub-folder per class.
        unlabelled_dataset: Flat directory of images to pseudo-label.
        labels: Class names; each must be a sub-folder of ``labelled_dataset``.
            The order defines the meaning of the model's output indices.
        number_of_train: Number of train/predict rounds run by ``main``.
        teacher_model_path: Model file to load and to save to. ``""`` or
            ``None`` means "no existing teacher"; the path then falls back to
            ``"model.h5"`` and nothing is loaded up front.
        teacher_model_size: Image size passed to Keras as ``target_size`` and
            used as the first two dimensions of the model input.
        student_model_path: Stored on the instance; not used by this class.
        confidence: Score threshold for accepting a pseudo-label.
        batch_size: Batch size for training and prediction.
        epochs: Epochs per training round.
        learning_rate: Adam learning rate.
        loss_function: Keras loss identifier.

    Raises:
        FileNotFoundError: A dataset directory or a non-empty
            ``teacher_model_path`` does not exist.
        ValueError: ``labels`` is empty or names a missing class folder.
        RuntimeError: The teacher model exists but cannot be loaded.
    """

    def __init__(self, labelled_dataset: str, unlabelled_dataset: str, labels: list, number_of_train: int = 50,
                 teacher_model_path: str = "model.h5",
                 teacher_model_size: tuple = (224, 224),
                 student_model_path: str = None,
                 confidence: float = 0.90, batch_size: int = 70, epochs: int = 4, learning_rate: float = 0.001,
                 loss_function: str = "categorical_crossentropy"):
        self.confidence = confidence
        self.labelled_dataset = labelled_dataset
        self.unlabelled_dataset = unlabelled_dataset
        self.labels = labels
        self.teacher_model_path = teacher_model_path
        self.teacher_model = None
        self.teacher_model_size = teacher_model_size
        self.student_model_path = student_model_path
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.loss_function = loss_function
        self.num_classes = len(self.labels)
        self.number_of_train = number_of_train
        self.__check_require_data()

    def __check_require_data(self):
        """Validate the constructor arguments and load the teacher model if one is given.

        Failures are raised as exceptions rather than ``print`` + ``exit(0)``:
        exiting with status 0 reports success, and inside a notebook it ends
        the session instead of showing the error.
        """
        if not os.path.exists(self.labelled_dataset):
            raise FileNotFoundError(f"Labelled dataset not found: {self.labelled_dataset}")
        if not os.path.exists(self.unlabelled_dataset):
            raise FileNotFoundError(f"Unlabelled dataset not found: {self.unlabelled_dataset}")

        if len(self.labels) == 0:
            raise ValueError("not labels set")
        missing = [
            obj_class for obj_class in self.labels
            if not os.path.isdir(os.path.join(self.labelled_dataset, obj_class))
        ]
        if missing:
            raise ValueError(f"label not exist in labelled dataset: {missing}")

        # Both "" and None mean "no teacher yet": train() will create one at the default path.
        if self.teacher_model_path:
            if not os.path.exists(self.teacher_model_path):
                raise FileNotFoundError(f"teacher model not found: {self.teacher_model_path}")
            self.teacher_model = self.__load_model(self.teacher_model_path)
        else:
            self.teacher_model_path = "model.h5"

    def __load_model(self, path: str):
        """Load a saved Keras model from ``path``.

        Raises:
            RuntimeError: Loading failed; the original error is chained as ``__cause__``.
        """
        try:
            return load_model(path)
        except Exception as e:
            # Raising a bare string is itself a TypeError in Python 3 and would hide the real cause.
            raise RuntimeError("teacher model load error : " + str(e)) from e

    def __preprocess_dataset(self):
        """Build the training and validation iterators over the labelled dataset.

        Returns:
            ``(train_generator, validate_generator)``; 70% / 30% split of each class folder.
        """
        validation_split = 0.3
        augmenting_generator = ImageDataGenerator(
            rotation_range=10,  # rotation
            width_shift_range=0.2,  # horizontal shift
            height_shift_range=0.2,  # vertical shift
            zoom_range=0.2,  # zoom
            horizontal_flip=True,  # horizontal flip
            brightness_range=[0.2, 1.2],
            validation_split=validation_split
        )
        # Validation images are not augmented, so val_accuracy (which drives the
        # checkpoint and LR-reduction callbacks) is measured on unmodified data.
        # The split is taken per class from the file listing, so using the same
        # fraction in both generators keeps the two subsets disjoint.
        plain_generator = ImageDataGenerator(validation_split=validation_split)

        flow_options = dict(
            directory=self.labelled_dataset,
            target_size=self.teacher_model_size,
            # Pin the class list and order: otherwise Keras infers classes
            # alphabetically from every sub-folder, and self.labels[index]
            # in predict_unlabelled_data could name the wrong class.
            classes=list(self.labels),
            class_mode='categorical',
            color_mode="rgb",
            batch_size=self.batch_size,
            shuffle=True,
            seed=2020,  # to make the result reproducible
        )
        train_generator = augmenting_generator.flow_from_directory(subset='training', **flow_options)
        validate_generator = plain_generator.flow_from_directory(subset='validation', **flow_options)
        return train_generator, validate_generator

    def __monitoring_initial(self):
        """Create the training callbacks: early stopping, LR reduction, checkpointing, TensorBoard."""
        earlystop = EarlyStopping(patience=10)
        # "~" is not expanded by Keras; without expanduser a literal "~" directory is created.
        tensorboard = TensorBoard(log_dir=os.path.expanduser("~/logs"), histogram_freq=1, update_freq='batch',
                                  profile_batch=True,
                                  write_graph=True, write_images=True, write_steps_per_second=True)
        checkpoint = ModelCheckpoint(filepath="checkpoints/", save_weights_only=False, monitor='val_accuracy',
                                     mode='max',
                                     save_best_only=True)
        learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy', patience=2, verbose=1, factor=0.5,
                                                    min_lr=0.0001)
        return [earlystop, learning_rate_reduction, checkpoint, tensorboard]

    def __model_builder(self):
        """Build and compile the classifier: frozen ImageNet EfficientNetB0 plus a new softmax head."""
        inputs = Input(shape=(self.teacher_model_size[0], self.teacher_model_size[1], 3))
        backbone = EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")

        # Freeze the pretrained weights
        backbone.trainable = False

        # Rebuild top
        x = GlobalAveragePooling2D(name="avg_pool")(backbone.output)
        x = BatchNormalization()(x)

        top_dropout_rate = 0.2
        x = Dropout(top_dropout_rate, name="top_dropout")(x)
        outputs = Dense(self.num_classes, activation="softmax", name="pred")(x)

        # Compile
        model = Model(inputs, outputs, name="EfficientNet")
        # `lr` is a deprecated alias that newer optimizers reject; `learning_rate` is the supported name.
        optimizer = Adam(learning_rate=self.learning_rate)
        model.compile(
            optimizer=optimizer, loss=self.loss_function, metrics=["accuracy", "MeanSquaredError", "AUC"]
        )
        return model

    def train(self):
        """Run one training round and save the resulting model to ``teacher_model_path``."""
        train_generator, validate_generator = self.__preprocess_dataset()
        model = self.__model_builder()
        callbacks = self.__monitoring_initial()
        # Warm-start from the previous round's model when one has been saved.
        if self.teacher_model_path is not None and os.path.exists(self.teacher_model_path):
            model.load_weights(self.teacher_model_path)
        # Floor division yields 0 steps when a subset is smaller than one batch; always run at least one.
        steps_per_epoch = max(1, train_generator.samples // self.batch_size)
        validation_steps = max(1, validate_generator.samples // self.batch_size)
        model.fit(
            train_generator,
            epochs=self.epochs,
            validation_data=validate_generator,
            validation_steps=validation_steps,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks
        )
        model.save(self.teacher_model_path)

    def main(self):
        """Run ``number_of_train`` rounds of ``train`` followed by ``predict_unlabelled_data``."""
        print("start task")
        print("number of train ",self.number_of_train)
        for _ in range(self.number_of_train):
            print("===================== train ================")
            self.train()
            print("===================== predict ================")
            self.predict_unlabelled_data()

    def predict_unlabelled_data(self):
        """Pseudo-label the unlabelled dataset with the saved teacher model.

        Every regular file in ``unlabelled_dataset`` is scored in batches of
        ``batch_size`` (including the final, possibly smaller, batch). Images
        whose best class score is strictly greater than ``confidence`` are
        moved to ``<labelled_dataset>/<label>/``. Files that cannot be read as
        images, or cannot be moved, are reported and skipped; a failure to load
        the model or to run prediction is raised to the caller.
        """
        print("start predict")
        files = sorted(
            name for name in os.listdir(self.unlabelled_dataset)
            if os.path.isfile(os.path.join(self.unlabelled_dataset, name))
        )
        self.teacher_model = self.__load_model(self.teacher_model_path)

        # Step through the listing so the trailing partial batch is processed too.
        for start in range(0, len(files), self.batch_size):
            images = []
            images_path = []
            for image_name in files[start:start + self.batch_size]:
                image_path = os.path.join(self.unlabelled_dataset, image_name)
                try:
                    img = image.load_img(image_path, target_size=self.teacher_model_size)
                    img = image.img_to_array(img)
                except Exception as e:
                    # One unreadable file must not abort the whole pass.
                    print(image_name)
                    print(e)
                    continue
                images.append(np.expand_dims(img, axis=0))
                images_path.append(image_path)
            if not images:
                continue  # np.vstack rejects an empty list

            images_data = np.vstack(images)
            classes = self.teacher_model.predict(images_data, batch_size=self.batch_size)
            for image_path, predict in zip(images_path, classes):
                best = int(np.argmax(predict))
                if predict[best] > self.confidence:
                    label = self.labels[best]
                    print(f"{image_path} is {label} score {predict[best]}")
                    try:
                        shutil.move(image_path,
                                    os.path.join(self.labelled_dataset, label, os.path.basename(image_path)))
                    except OSError as e:
                        print(image_path)
                        print(e)

# Instantiate the self-training driver for the two example classes; all other
# options (teacher model path, image size, confidence, ...) keep their defaults.
ns = NoisyStudent(
    labelled_dataset="labled_data/train",
    unlabelled_dataset="unlabled_data/",
    labels=["class1", "class2"],
)


In [None]:
# Run the full loop: number_of_train rounds of train() followed by predict_unlabelled_data().
ns.main()