In [23]:
import gradio as gr
from PIL import Image as im_lib
from PIL import ImageFilter
from random import randint
import numpy as np
from keras import models
from keras import layers



def crop_image(image_array, size=256):
    """Randomly crop a ``size`` x ``size`` RGB patch out of an image array.

    Gradio takes the uploaded image and automatically converts it into an
    ndarray; this function validates that array and, if possible, cuts a
    random square patch out of it.

    Args:
        image_array: image as an ndarray indexed ``[row, column, channel]``,
            i.e. axis 0 is the height and axis 1 is the width.
        size: side length of the square crop in pixels. Defaults to 256, the
            input size the model in this notebook expects.

    Returns:
        A ``(status, array)`` tuple:

        * ``-1`` -- the image is not 3-channel RGB (grayscale, RGBA, ...);
          ``array`` is the original ``image_array``.
        * ``-2`` -- the image is smaller than ``size`` in either dimension;
          ``array`` is the original ``image_array``.
        * ``1``  -- success; ``array`` is a new ``(size, size, 3)`` array.
    """
    # Require exactly three channels: 2-D (grayscale) and 4-channel (RGBA)
    # inputs would otherwise slip through and break the (1, 256, 256, 3)
    # reshape done before prediction.
    if image_array.ndim != 3 or image_array.shape[2] != 3:
        return -1, image_array
    # NumPy image arrays are (height, width, channels): axis 0 is vertical,
    # axis 1 is horizontal. Mixing them up lets the crop box run past the
    # image edge on non-square images.
    height, width = image_array.shape[:2]
    if height < size or width < size:
        return -2, image_array
    # Choose the top-left corner so the whole patch lies inside the image.
    left = randint(0, width - size)
    upper = randint(0, height - size)
    patch = image_array[upper:upper + size, left:left + size]
    # Return an independent, C-contiguous copy rather than a view into the
    # caller's array.
    return 1, patch.copy()

def HPF_filter(image):
    """High-pass filter a PIL image by subtracting a blurred copy of itself.

    Args:
        image: a PIL ``Image``.

    Returns:
        A new PIL ``Image`` built from the array ``original - blurred``, where
        ``blurred`` is ``image`` run through ``ImageFilter.GaussianBlur``.

    NOTE(review): for 8-bit images both arrays are uint8, so negative
    differences wrap around modulo 256 instead of clipping at 0. The model
    weights were presumably trained on exactly this output, so do not change
    the arithmetic without retraining -- confirm.
    """
    pixels = np.asarray(image)
    # GaussianBlur is passed as the class itself; PIL's Image.filter
    # instantiates it with its default arguments.
    smoothed = np.asarray(image.filter(ImageFilter.GaussianBlur))
    residual = pixels - smoothed
    return im_lib.fromarray(residual)

def final_image_array(image_array):
    """Crop a random 256x256 patch and high-pass filter it.

    Args:
        image_array: image ndarray as received from Gradio.

    Returns:
        The high-pass-filtered patch as an ndarray if ``crop_image`` succeeds;
        otherwise the input array, returned unchanged and unfiltered.
    """
    status, patch = crop_image(image_array)
    if status != 1:
        # On failure crop_image hands back the untouched input array.
        return patch
    filtered_image = HPF_filter(im_lib.fromarray(patch))
    return np.asarray(filtered_image)

def create_model(model_weights_file = "single_channel_model_best_val_real_precision_weights"):
    """Build the classifier CNN, load its trained weights and compile it.

    Args:
        model_weights_file: path handed to ``load_weights``; the stored
            weights must match the architecture defined below.

    Returns:
        A compiled ``keras.models.Sequential`` that accepts batches of
        ``(256, 256, 3)`` inputs and outputs a 14-way softmax.
    """
    recalled_model = models.Sequential()
    # Two conv + max-pool stages for feature extraction...
    recalled_model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(256, 256, 3,)))
    recalled_model.add(layers.MaxPooling2D((2, 2), strides=2))
    recalled_model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    recalled_model.add(layers.MaxPooling2D((2, 2), strides=2))
    # ...followed by a small dense head over 14 classes.
    recalled_model.add(layers.Flatten())
    recalled_model.add(layers.Dense(64, activation='relu'))
    recalled_model.add(layers.Dense(14, activation='softmax'))
    recalled_model.load_weights(model_weights_file)
    # real_or_not only calls predict(); the optimizer and loss matter only if
    # this model is trained further.
    recalled_model.compile(optimizer='adam',
                  loss='categorical_crossentropy')
    return recalled_model

# Build the model once at import time. real_or_not uses it as the default
# value of its ``model`` parameter, which is evaluated when that function is
# defined, so this must run first.
model_test = create_model()

    
def real_or_not(image_array, model = model_test):
    """Classify an uploaded image as real or AI-generated.

    Args:
        image_array: the image Gradio passes in as an ndarray, or ``None``
            when Submit is clicked before an image has been uploaded.
        model: Keras model with a softmax output whose last class is treated
            as "real"; defaults to the module-level ``model_test``.

    Returns:
        A human-readable verdict or error message for the output textbox.
    """
    if image_array is None:
        # Gradio supplies None for an empty image input; without this guard
        # crop_image would raise AttributeError on ``None.shape``.
        return "Please upload an image first."
    cropped_key, cropped_image_array = crop_image(image_array)
    if cropped_key == -1:
        return "This image cannot be processed as it is not RGB."
    if cropped_key == -2:
        return "This image is too small to be processed."
    # The patch is already 256x256, so high-pass filter it directly instead
    # of going through final_image_array, which would crop it a second time.
    filtered = np.asarray(HPF_filter(im_lib.fromarray(cropped_image_array)))
    prediction_list = model.predict(filtered.reshape(1, 256, 256, 3))
    class_scores = prediction_list[0]
    # NOTE(review): the last softmax index is assumed to be the "real" class
    # and every other index a kind of fake -- confirm against training labels.
    if np.argmax(class_scores) == len(class_scores) - 1:
        return "This image is probably real."
    return "This image is probably fake."
        
with gr.Blocks() as webpage:
    # Page banner.
    gr.HTML("<center> <b> Using this webpage, determine if your image is AI-generated. </b> </center><br>")
    # type="numpy" (Gradio's default, made explicit) because real_or_not and
    # crop_image operate on ndarrays.
    image_input = gr.Image(type="numpy")
    text_output = gr.Textbox(interactive = False, value = "This is where I'll display whether your image is real",
                            label = "Answer: ")
    image_button = gr.Button("Submit")
    # Run the classifier on the uploaded image and show its verdict.
    image_button.click(real_or_not, inputs = image_input, outputs = text_output)

# Start the local server only when run as a script or notebook cell, not when
# this module is imported.
if __name__ == "__main__":
    webpage.launch()

Running on local URL:  http://127.0.0.1:7880

To create a public link, set `share=True` in `launch()`.




