In [None]:
%pip install gradio
%pip install pretrainedmodels

In [None]:
import gradio as gr
import requests
import torch
from PIL import Image
from io import BytesIO
import sqlite3
import os
import torch.nn as nn
import pretrainedmodels as pm
from torchvision import transforms
import gdown
from collections import UserDict
import logging

# ENABLE LOGGING--------------------------------------------------------------------------------------------------------
# Log INFO and above both to ./debug.log and to the console (stderr).
# force=True removes and closes any handlers already attached to the root
# logger first, so re-running this code does not duplicate output.
_log_handlers = [logging.FileHandler("debug.log"), logging.StreamHandler()]
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=_log_handlers,
    force=True,
)
# ----------------------------------------------------------------------------------------------------------------------

# CREATE DATABASE------------------------------------------------------------------------------------------------------
# Ensure flagged/ exists and flagged/data.db contains the `images` table.
# INTEGER PRIMARY KEY makes `id` an alias of SQLite's rowid, so ids are
# assigned automatically; SQLite has no SERIAL type, and a non-INTEGER primary
# key would leave every inserted id as NULL.
# An already existing table is left untouched (CREATE ... IF NOT EXISTS).
os.makedirs("flagged", exist_ok=True)
conn = sqlite3.connect(os.path.join("flagged", "data.db"))
c = conn.cursor()
try:
    c.execute('''CREATE TABLE IF NOT EXISTS images
                 (id INTEGER PRIMARY KEY, image_path TEXT, rating INTEGER)''')
    conn.commit()
finally:
    # Always release the connection; request handlers open their own.
    conn.close()
# ----------------------------------------------------------------------------------------------------------------------

# DOWNLOAD MODEL IF IT IS MISSING--------------------------------------------------------------------------------------

# Fetch the fine-tuned weights from Google Drive on first run.
if not os.path.exists("model.bin"):
    logging.info("Missing model, downloading")
    url = 'https://drive.google.com/uc?id=1yUlB7UcED604ZZMXWhRgKo73pa630d8v'
    output = 'model.bin'
    # gdown.download returns the output path on success. Some gdown versions
    # signal failure by returning None instead of raising, so check the result
    # rather than logging success unconditionally.
    if gdown.download(url, output, quiet=False) is None:
        raise RuntimeError(f"Failed to download model weights from {url}")
    logging.info("Model downloaded")
# ----------------------------------------------------------------------------------------------------------------------

# DEFINE MODEL----------------------------------------------------------------------------------------------------------
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ImageNet-pretrained ResNet-101 backbone. The pretrained variant is presumably
# what provides the `mean`/`std` attributes read below (verify in pretrainedmodels).
model = pm.__dict__["resnet101"](pretrained='imagenet')

# Replace pooling and the classifier head with the architecture stored in
# model.bin: 2048-d pooled features -> 2048 hidden units -> 10 outputs.
model.avg_pool = nn.AdaptiveAvgPool2d(1)
model.last_linear = nn.Sequential(
    nn.BatchNorm1d(2048),
    nn.Dropout(p=0.25),
    nn.Linear(in_features=2048, out_features=2048),
    nn.ReLU(),
    nn.BatchNorm1d(2048, eps=1e-05, momentum=0.1),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=2048, out_features=10),
)
# map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only hosts.
# The model is moved to `device` below.
# NOTE(review): torch.load unpickles the file and model.bin is downloaded from a
# remote URL, so only use it if that source is trusted.
model.load_state_dict(torch.load("model.bin", map_location="cpu"))

IMAGE_SIZE = 512
IMG_MEAN = model.mean
IMG_STD = model.std
# Resize the shorter side to IMAGE_SIZE, center-crop to a square, convert to a
# tensor and normalise with the backbone's channel statistics.
my_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMG_MEAN, IMG_STD)
])
model = model.to(device)
# Inference only: freeze every parameter.
model.requires_grad_(False)

model.eval()


# ----------------------------------------------------------------------------------------------------------------------

# CREATE AND DEFINE BUFFER----------------------------------------------------------------------------------------------
class Buffer(UserDict):
    """Size-bounded dict that evicts the oldest *inserted* key when full (FIFO).

    The module uses it to cache prediction results keyed by the request input.

    Args:
        max_size: Maximum number of entries kept. A value <= 0 stores nothing.
        *args, **kwargs: Initial contents, forwarded to ``UserDict``. They are
            subject to the same size limit.
    """

    def __init__(self, max_size=10, *args, **kwargs):
        # Must be set before super().__init__, which routes initial items
        # through __setitem__ (and therefore reads max_size).
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        if self.max_size <= 0:
            # Storage disabled. This also avoids next() on an empty dict.
            return
        # Overwriting an existing key does not grow the dict, so evict only
        # when a new key is added. The loop also copes with max_size having
        # been lowered after construction.
        if key not in self.data:
            while len(self.data) >= self.max_size:
                del self[next(iter(self.data))]
        super().__setitem__(key, value)


buffer = Buffer(1000)
# ----------------------------------------------------------------------------------------------------------------------

# Browser-like User-Agent sent with every image download. Presumably some image
# hosts reject the default python-requests agent (confirm).
headers = {
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36"}
# Build the UI. Components and their click handlers are registered inside this
# context; the app is launched after it closes.
with gr.Blocks() as app:
    gr.Markdown("This is demo of gradio.")
    with gr.Column(visible=True):
        # Add an image to the database: image URL plus its rating, both entered
        # as free text.
        img_input = gr.Textbox()
        score_input = gr.Textbox()
        button = gr.Button('Click me!')


        def add_image(inp, score=None):
            """Download the image at URL ``inp`` and record it with ``score``.

            The image is saved as ``flagged/<n>.jpg``, where ``n`` is the number
            of rows already in the ``images`` table. A row with that path and the
            given rating is then inserted.

            Args:
                inp: URL of the image to download.
                score: Rating stored alongside the image, as received from the
                    textbox.

            Returns:
                The string ``"readed"`` on success.

            Raises:
                requests.RequestException: if the download fails, times out or
                    the server answers with an HTTP error status.
                PIL.UnidentifiedImageError: if the content is not an image.
            """
            logging.info(f"Got image with src {inp} and score {score}")
            conn = sqlite3.connect(os.path.join("flagged", "data.db"))
            try:
                c = conn.cursor()
                c.execute("SELECT COUNT(*) FROM images")
                row_count = c.fetchone()[0]
                image_path = os.path.join("flagged", str(row_count) + ".jpg")

                # A timeout keeps a stalled host from hanging the handler forever.
                req = requests.get(inp, headers=headers, timeout=30)
                # Fail on 4xx/5xx instead of trying to decode an error page.
                req.raise_for_status()
                # JPEG cannot store alpha or palette modes (e.g. RGBA PNGs), so
                # convert to RGB before saving.
                img = Image.open(BytesIO(req.content)).convert('RGB')
                img.save(image_path)
                c.execute("INSERT INTO images (image_path, rating) VALUES (?, ?)",
                          (image_path, score))
                conn.commit()
            finally:
                # Release the connection even when the download or decode fails.
                conn.close()
            logging.info("Image added")
            return "readed"


        # No `outputs` are wired, so add_image's return value is not displayed.
        button.click(add_image, inputs=[img_input, score_input])

        # Connection test: a button whose handler writes "connected" into txt_output.
        txt_output = gr.Textbox()
        # Rebinding `button` is fine: the previous button keeps the handler
        # already registered on it above.
        button = gr.Button('Connection test')


        def return_ok(inp=None):
            """Connection-test handler: log and return ``"connected"``.

            Args:
                inp: Ignored. It defaults to ``None`` because the button is wired
                    without any ``inputs``, so the handler must be callable with
                    no arguments.

            Returns:
                The string ``"connected"``.
            """
            logging.info("Connected")
            return "connected"


        # No `inputs` are wired here, so the handler is presumably invoked with no
        # arguments. return_ok must accept that (verify against the gradio docs).
        button.click(return_ok, outputs=txt_output)

        # Model prediction: image URL in, predicted score (as text) out.
        img_input_predict = gr.Textbox()
        score_output = gr.Textbox()
        button = gr.Button('Predict')


        def predict(inp):
            """Return the model's predicted score for the image at URL ``inp``.

            Successful results are cached in the module-level ``buffer``, keyed
            by ``inp``, so repeated requests skip the download and inference.

            Args:
                inp: URL of the image to score.

            Returns:
                The predicted class as a string from ``"1"`` to ``"10"``, or
                ``"-1"`` if downloading, decoding or inference failed. Failures
                are logged and not cached.
            """
            if inp in buffer:
                logging.info(f"Using buffer, returned {buffer[inp]}")
                return buffer[inp]
            # Bound before the try so the error handler can tell whether a
            # response was ever received (requests.get itself may raise).
            req = None
            try:
                logging.info(f"Got: {inp}")
                # A timeout keeps a stalled host from blocking the request queue.
                req = requests.get(inp, headers=headers, timeout=30)
                logging.info("Got image")
                img = Image.open(BytesIO(req.content)).convert('RGB')
                logging.info("Transforming image")
                img = my_transforms(img).to(device).unsqueeze(0)
                logging.info("Predicting")
                with torch.no_grad():
                    logits = model(img)
                # The batch of one yields a (1, 10) output. argmax over it gives a
                # class index 0..9, and +1 maps that to a 1..10 score.
                ans = str(torch.argmax(logits.sigmoid()).item() + 1)
                logging.info(f"Responsed with: {ans}")
                buffer[inp] = ans
                torch.cuda.empty_cache()
                return ans
            except Exception as err:
                logging.error("Error: " + str(err), exc_info=True)
                if req is not None:
                    logging.error(f"Response status was {req.status_code}")
                return "-1"


        button.click(predict, inputs=img_input_predict, outputs=score_output)
app.queue(concurrency_count=1, max_size=30)
app.launch(max_threads=1, share=True)

while True:
    pass