In [18]:
import os 
import projetLib as proj
import torch 
from math import sqrt,ceil
import numpy as np
from PIL import Image
from torchvision import transforms

# Default (height, width) input size for the models.
resize = (224,224)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

modelsaveFolder = "./modelSave/"
# Each sub-directory of modelsaveFolder is one saved model: predict() lists the
# files inside the chosen one to find a checkpoint. Keep only directories
# (stray files would crash predict) and sort them, because os.listdir returns
# entries in arbitrary order and the UI dropdown should be stable.
folder_list = sorted(
    entry for entry in os.listdir(modelsaveFolder)
    if os.path.isdir(os.path.join(modelsaveFolder, entry))
)

def transformImg(img,resize,doRGB,doCrop):
    """Apply the project's size-normalising transform to ``img``.

    Args:
        img: image to transform (as produced by ``extract_img``).
        resize: target size, passed unchanged to the project transform.
        doRGB: forwarded to the transform's RGB flag.
        doCrop: use ``proj.image.Crop_img`` when true, otherwise
            ``proj.image.Resize_img``.

    Returns:
        Whatever the selected project transform returns for ``img``.
    """
    transform_factory = proj.image.Crop_img if doCrop else proj.image.Resize_img
    transform = transform_factory(resize, doRGB)
    return transform(img)
    
def extract_img(filepath):
    """Render the raw bytes of a file as a square 8-bit grayscale image.

    Each byte becomes one pixel value (0-255). The byte stream is zero-padded
    at the end up to the next perfect square ``sq * sq`` with
    ``sq = ceil(sqrt(n_bytes))``, then laid out row by row.

    Args:
        filepath: path of the file to convert.

    Returns:
        A ``PIL.Image.Image`` of size ``(sq, sq)``.

    Raises:
        ValueError: if the file is empty. A 0x0 image cannot be resized or
            cropped downstream, so fail early with a clear message.
    """
    with open(filepath, 'rb') as img_set:
        data = img_set.read()
    if not data:
        raise ValueError(f"Cannot build an image from empty file: {filepath}")

    # View the bytes directly as uint8. This avoids building a Python list of
    # ints and a float32 copy, both costly for multi-megabyte binaries.
    pixels = np.frombuffer(data, dtype=np.uint8)
    sq = ceil(sqrt(pixels.size))
    # Zero-pad so the pixel count is exactly sq * sq.
    pixels = np.pad(pixels, (0, sq * sq - pixels.size))
    # A 2-D uint8 array is mapped to mode "L" (8-bit grayscale) automatically;
    # passing the mode explicitly is deprecated in recent Pillow releases.
    return Image.fromarray(pixels.reshape(sq, sq))

def _find_checkpoint(model_save):
    """Return the path of the checkpoint file to load for the save ``model_save``.

    Files inside ``modelsaveFolder/model_save`` are sorted by name and the last
    one is used. ``os.listdir`` returns entries in arbitrary order, so sorting
    is what makes the choice deterministic.
    NOTE(review): this assumes checkpoint file names sort chronologically
    (e.g. zero-padded epochs or timestamps). Confirm against the training code.

    Raises:
        ValueError: if no save was selected.
        FileNotFoundError: if the save directory contains no file.
    """
    if not model_save:
        raise ValueError("No saved model selected")
    save_dir = os.path.join(modelsaveFolder, model_save)
    checkpoints = sorted(
        name for name in os.listdir(save_dir)
        if os.path.isfile(os.path.join(save_dir, name))
    )
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint file found in {save_dir}")
    return os.path.join(save_dir, checkpoints[-1])


def _build_model(model_type, inputchannels):
    """Instantiate the untrained architecture named ``model_type``.

    Raises:
        ValueError: if ``model_type`` is not a known architecture name.
    """
    builders = {
        "Basic": lambda: proj.model.Basic(inputchannels),
        "Resnet50": lambda: proj.model.getCNNresnet(50, inputchannels),
        "Resnet101": lambda: proj.model.getCNNresnet(101, inputchannels),
        "Resnet152": lambda: proj.model.getCNNresnet(152, inputchannels),
        "VGG": lambda: proj.model.VGG16(inputchannels),
    }
    if model_type not in builders:
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of {sorted(builders)}"
        )
    return builders[model_type]()


def _display_name(path):
    """Return the last component of ``path``, splitting on both '/' and '\\'."""
    return path.replace("\\", "/").rsplit("/", 1)[-1]


def predict(malwares,model_type,model_save,resize,doRGB,doCrop):
    """Score uploaded files with a saved model (label: Malware = 1, Goodware = 0).

    Gradio callback. Each file is turned into an image by ``extract_img``,
    resized or cropped by ``transformImg``, and converted to a tensor. The
    tensors are batched and fed through the selected architecture, restored
    from the checkpoint chosen by ``_find_checkpoint``.

    Args:
        malwares: uploaded file objects; only their ``.name`` path is used.
        model_type: one of "Basic", "Resnet50", "Resnet101", "Resnet152", "VGG".
        model_save: sub-directory of ``modelsaveFolder`` holding checkpoints.
        resize: target side length in pixels (cast to int; inputs are square).
        doRGB: build 3-channel inputs instead of 1-channel.
        doCrop: crop (True) rather than resize (False) to the target size.

    Returns:
        (images, scores): the transformed images for the gallery, and a dict
        mapping each file's base name to the model output for that file.
        Assumes the model emits a single value per sample (``.item()``).

    Raises:
        ValueError: no file uploaded, no save selected, or unknown model type.
    """
    if not malwares:
        raise ValueError("No file to analyse")

    inputchannels = 3 if doRGB else 1
    resize = int(resize)  # gr.Number delivers a float

    model = _build_model(model_type, inputchannels)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(_find_checkpoint(model_save), map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    to_tensor = transforms.ToTensor()
    images = []
    tensors = []
    for file in malwares:
        img = transformImg(extract_img(file.name), (resize, resize), doRGB, doCrop)
        images.append(img)
        tensors.append(to_tensor(img))
    batch = torch.stack(tensors).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y = model(batch)

    scores = {
        _display_name(malware.name): y[i].item()
        for i, malware in enumerate(malwares)
    }
    return images, scores

In [20]:
import gradio as gr

# Preselect the "resnet50_classic" save only if it exists, so the dropdown
# never starts on a value that is not among its choices.
default_save = "resnet50_classic" if "resnet50_classic" in folder_list else None

demo = gr.Interface(
    predict,

    [gr.File(file_count="multiple",label="Files to analyse"),
     # Default architecture matches the default save above. Without a default,
     # predict() would receive None and could not build a model.
     gr.Dropdown(["Resnet50","Resnet101","Resnet152","VGG","Basic"],label="Model Types",value="Resnet50"),
     gr.Dropdown(folder_list,label="Model Saved",value=default_save),
     # Numeric default (was the string "224"), taken from the config's resize.
     gr.Number(value=resize[0],label="Resize value"),
     gr.Checkbox(label="RGB Image"),
     gr.Checkbox(label="Crop",value=True)],

    [gr.Gallery(label="Malwares en Image"),
     gr.Label(label="Résultat : Malware = 1, Goodware = 0")]
    )

if __name__ == "__main__":
    demo.launch()

IMPORTANT: You are using gradio version 3.0.22, however version 3.14.0 is available, please upgrade.
--------
Running on local URL:  http://127.0.0.1:7866/

To create a public link, set `share=True` in `launch()`.
