# the bare minimum viable product

In [None]:
import torch
import transformers
import gradio as gr

from onnxruntime import InferenceSession
from PIL import Image

from transformers import (
    pipeline,
    AutoFeatureExtractor, 
    DefaultDataCollator, 
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer
)

In [3]:
# Food classifier: the preprocessing step and the fine-tuned classification
# model are both loaded from the same checkpoint.
food_model_extractor = AutoFeatureExtractor.from_pretrained("stochastic/102722run")
food_classification_model = AutoModelForImageClassification.from_pretrained(
    "stochastic/102722run"
)

# Text-to-text generator used to describe the predicted food.
flan_pipe = pipeline(task="text2text-generation", model="google/flan-t5-large")

In [4]:
# helper 
def clean_up_answer(flan_answer):
    """Tidy generated text for display.

    Makes sure the text ends with a period and replaces underscores
    (e.g. from class labels such as ``ice_cream``) with spaces.

    Args:
        flan_answer: The raw text to clean. May be empty or ``None``.

    Returns:
        The cleaned text, or an empty string when there is nothing to clean.
    """
    # Nothing to punctuate; indexing the last character here would raise.
    if not flan_answer:
        return ""

    # Add the final period only when it is missing.
    if not flan_answer.endswith("."):
        flan_answer += "."

    # Get rid of underscores from answers.
    return flan_answer.replace("_", " ")

In [6]:
# Second copy of the checkpoint's preprocessor and classifier, under shorter names.
# NOTE(review): the visible code never reads `model` or `extractor` (the function
# below uses `food_model_extractor` / `food_classification_model`) - confirm
# whether another cell needs them before removing this duplicate load.
extractor = AutoFeatureExtractor.from_pretrained("stochastic/102722run")
model = AutoModelForImageClassification.from_pretrained("stochastic/102722run")

# Location of the exported ONNX classifier, relative to the working directory.
_ONNX_MODEL_PATH = "../../onnx/vit-model.onnx"

# ONNX Runtime session, created on first use and then reused for every request
# so the model file is not re-read and re-initialised on each call.
_onnx_session = None


def _get_onnx_session():
    """Return the shared ONNX Runtime session, creating it on first use."""
    global _onnx_session
    if _onnx_session is None:
        _onnx_session = InferenceSession(_ONNX_MODEL_PATH)
    return _onnx_session


def classify_and_describe_image(user_input):
    """Classify a food image and describe how to enjoy it.

    The image is classified with the ONNX export of the food classifier, then
    the text-generation pipeline is asked a fixed set of questions about the
    predicted food and the answers are stitched into one paragraph.

    Args:
        user_input: Path to the image file (the Gradio image input is
            configured with ``type='filepath'``), or ``None`` when no image
            was provided.

    Returns:
        A single string: the predicted food followed by the generated
        recommendations, or a short prompt to upload an image when
        ``user_input`` is ``None``.
    """
    if user_input is None:
        return "Please upload an image first."

    session = _get_onnx_session()
    output_name = session.get_outputs()[0].name

    # Close the file handle as soon as the pixels have been extracted.
    with Image.open(user_input) as image:
        # ONNX Runtime expects numpy inputs.
        inputs = food_model_extractor(image.convert("RGB"), return_tensors="np")
    outputs = session.run(output_names=[output_name], input_feed=dict(inputs))

    predicted_class_idx = outputs[0].argmax(-1).item()
    predicted_food = food_classification_model.config.id2label[predicted_class_idx]
    first_sentence = f"This looks like {predicted_food}! "

    # Question sent to the generator -> lead-in prepended to its answer.
    prompts = {
        f"Answer the following question: How do you enjoy {predicted_food}?": "I recommend to ",
        f"Answer the following question: What are the flavors of {predicted_food}?": "The flavors are similar to ",
        f"Answer the following question: What are the textures of {predicted_food}?": "The textures are ",
        f"Answer the following question: What are the aromas of {predicted_food}?": "The aromas will be ",
        f"Answer the following question: How do you improve the flavor of {predicted_food}?": "To enjoy, you can ",
    }

    answers = []
    for question, lead_in in prompts.items():
        answer = flan_pipe(question, max_length=100)[0]["generated_text"].strip().lower()
        if not answer:
            # An empty generation would otherwise yield a dangling lead-in.
            continue
        if not answer.endswith("."):
            answer += "."
        answers.append(lead_in + answer)

    return clean_up_answer(first_sentence + " ".join(answers))
    

# Build the web UI (one image input passed as a file path, one text box for
# the generated description), then start serving it.
demo = gr.Interface(
    fn=classify_and_describe_image,
    inputs=gr.Image(type="filepath", label="Image"),
    outputs=[
        gr.Textbox(lines=3, label="Here are the ways you can enjoy this and improve it!"),
    ],
    title="Upload food you haven't tried before",
)
demo.launch()

Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.


(<gradio.routes.App at 0x2b048cb29d0>, 'http://127.0.0.1:7866/', None)