https://towardsdatascience.com/building-a-custom-model-in-scikit-learn-b0da965a1299

In [None]:
import re
import pandas as pd
import numpy as np
import datetime
import time

import unicodedata
import emoji

import matplotlib.pyplot as plt

# Show full cell contents instead of truncating wide text columns.
pd.options.display.max_colwidth = None

In [None]:
class modelMaker(tf.keras.Model):
    """Subclassed Keras model with a dense head or a ResNet50 backbone.

    When ``trained == "dense"`` the network is
    ``Flatten -> Dense(128) -> Dense(num_classes)``.
    For any other value of ``trained`` the network is
    ``resnet50.preprocess_input -> ResNet50 (no top, ImageNet weights)
    -> Flatten -> Dense(num_classes)``.

    Args:
        img_height: Height of the input images in pixels.
        img_width: Width of the input images in pixels.
        num_classes: Number of units in the final ``classify`` layer.
        trained: ``"dense"`` selects the plain dense network; every other
            value selects the ResNet50-based network.
    """

    def __init__(self, img_height, img_width, num_classes=1, trained='dense'):
        super().__init__()
        self.trained = trained
        # Inputs are treated as 3-channel images.
        self.IMG_SHAPE = (img_height, img_width, 3)

        # Layers shared by both variants.
        self.flat = tf.keras.layers.Flatten(name="flatten")
        self.classify = tf.keras.layers.Dense(num_classes, name="classify")

        if self.trained == "dense":
            # Hidden layer used only by the dense variant.
            self.dense = tf.keras.layers.Dense(128, name="dense128")
        else:
            # Backbone used by every non-"dense" variant.
            self.pre_resnet = tf.keras.applications.resnet50.preprocess_input
            self.base_model = tf.keras.applications.ResNet50(
                input_shape=self.IMG_SHAPE,
                include_top=False,
                weights='imagenet',
            )
            # Freeze the backbone, then re-flag only BatchNormalization layers
            # as trainable; every other layer keeps trainable=False.
            # NOTE(review): base_model.trainable itself stays False, so whether
            # the BatchNormalization weights are actually updated during
            # fit() depends on how Keras resolves nested flags - confirm.
            self.base_model.trainable = False
            for layer in self.base_model.layers:
                layer.trainable = isinstance(
                    layer, tf.keras.layers.BatchNormalization
                )

    def call(self, inputs):
        """Run the forward pass for the variant chosen at construction."""
        if self.trained == "dense":
            x = self.flat(inputs)
            x = self.dense(x)
            return self.classify(x)

        x = self.pre_resnet(inputs)
        x = self.base_model(x)
        x = self.flat(x)
        return self.classify(x)

    def summary(self, *args, **kwargs):
        """Print a layer-by-layer summary with concrete output shapes.

        The model is re-wrapped as a functional ``tf.keras.Model`` built
        from a symbolic input of shape ``IMG_SHAPE`` and that wrapper's
        ``summary`` is called. Positional and keyword arguments are
        forwarded unchanged, so callers can still pass the standard
        ``summary`` options (for example ``print_fn``).

        Returns:
            Whatever the wrapped model's ``summary`` returns.
        """
        x = tf.keras.Input(shape=self.IMG_SHAPE, name="input_layer")
        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary(*args, **kwargs)
    
# Create the ResNet-backed variant for 128x128 inputs.
model = modelMaker(img_height=128, img_width=128, trained="resnet")
# Build the model for a batch of ten 128x128x3 inputs.
model.build(input_shape=(10, 128, 128, 3))
# Print the summary.
model.summary()

In [49]:
class skmodel_experiment:
    """Fit a classifier on a train/validation split, optionally under MLflow.

    The data splits are stored on the instance so that ``run_experiment``
    and ``create_confusion_matrix`` do not depend on module-level globals.

    Args:
        experiment_id: MLflow experiment ID passed to ``mlflow.start_run``.
        run_name: Default MLflow run name for this experiment.
        x_train: Training features handed to ``classifier.fit``.
        y_train: Training labels handed to ``classifier.fit``.
        x_valid: Validation features handed to ``classifier.predict_proba``.
        y_valid: Validation labels used for the ROC curve.

    The data arguments are optional so the object can be created (and
    printed) before the data is available; they may also be assigned as
    attributes later. They must be set before ``run_experiment`` is called.
    """

    def __init__(self, experiment_id, run_name,
                 x_train=None, y_train=None, x_valid=None, y_valid=None):
        self.experiment_id = experiment_id
        self.run_name = run_name

        self.x_train = x_train
        self.y_train = y_train
        self.x_valid = x_valid
        self.y_valid = y_valid

        # Populated by run_experiment().
        self.classifier = None
        self.model_name = None
        self.roc_auc = None

    def __str__(self):
        """Return a one-line summary of the experiment ID and run name."""
        return f"Mlflow experiment ID: {self.experiment_id} ({self.run_name})\n"

    def _require_data(self):
        """Raise ValueError if any train/validation split is still unset."""
        missing = [
            name
            for name in ("x_train", "y_train", "x_valid", "y_valid")
            if getattr(self, name) is None
        ]
        if missing:
            raise ValueError(
                "Missing data for run_experiment(): " + ", ".join(missing)
            )

    def _fit_and_score(self, classifier):
        """Fit ``classifier`` and return (validation scores, validation AUC)."""
        classifier.fit(self.x_train, self.y_train)
        # Column 1 of predict_proba is used as the score for the ROC curve.
        valid_prediction = classifier.predict_proba(self.x_valid)[:, 1]
        fpr, tpr, _ = roc_curve(self.y_valid, valid_prediction)
        return valid_prediction, auc(fpr, tpr)

    def run_experiment(self, classifier, model_name, run_name=None,
                       log_metrics=False):
        """Fit ``classifier`` and score it on the validation split.

        Args:
            classifier: Estimator exposing ``fit`` and ``predict_proba``.
            model_name: Name used for the confusion-matrix file name.
            run_name: MLflow run name; defaults to the instance's
                ``run_name`` when ``None``.
            log_metrics: When true, the fit happens inside an MLflow run
                with sklearn autologging enabled and the validation AUC is
                logged as ``validation_auc``.

        Returns:
            Tuple ``(y_valid, valid_prediction)``.

        Raises:
            ValueError: If any of the data splits has not been provided.
        """
        self._require_data()
        self.classifier = classifier
        self.model_name = model_name
        if run_name is None:
            run_name = self.run_name

        if not log_metrics:
            valid_prediction, self.roc_auc = self._fit_and_score(classifier)
            return self.y_valid, valid_prediction

        with mlflow.start_run(experiment_id=self.experiment_id,
                              run_name=run_name) as run:
            run_id = run.info.run_uuid
            MlflowClient().set_tag(run_id,
                                   "mlflow.note.content",
                                   "Testing baseline models for binary classification between compliant(0) and non-compliant(1) post")
            # Enable autologging before fit() so the fit is captured.
            mlflow.sklearn.autolog()

            tags = {"Application": "Binary Classification - Non-Compliant/Compliant",
                    "release_version": "1.0.0"}
            mlflow.set_tags(tags)

            valid_prediction, self.roc_auc = self._fit_and_score(classifier)
            mlflow.log_metrics({"validation_auc": self.roc_auc})

        return self.y_valid, valid_prediction

    def create_confusion_matrix(self, savefig=False, log_artifact=False):
        """Plot the validation confusion matrix of the last fitted classifier.

        Args:
            savefig: Save the figure to
                ``./{model_name}_validation_confusion_matrix.png``.
            log_artifact: Log that file to MLflow. The figure is written
                first even when ``savefig`` is false, because the artifact
                must exist on disk to be logged.

        Raises:
            RuntimeError: If ``run_experiment`` has not been called yet.
        """
        if self.classifier is None:
            raise RuntimeError(
                "run_experiment() must be called before create_confusion_matrix()"
            )

        ConfusionMatrixDisplay.from_estimator(self.classifier,
                                              self.x_valid, self.y_valid,
                                              cmap='magma')
        plt.title('Confusion Matrix')
        plt.axis("off")
        filename = f'./{self.model_name}_validation_confusion_matrix.png'

        if savefig or log_artifact:
            plt.savefig(filename)

        if log_artifact:
            mlflow.log_artifact(filename)

In [50]:
# Smoke-test object: experiment ID '1234', run name 'log_reg'.
test_model = skmodel_experiment(experiment_id='1234', run_name='log_reg')

In [51]:
# Display the experiment summary produced by __str__.
print(str(test_model))

Mlflow Experiment ID: 1234 (log_reg)
Testing
