In [None]:
%pip install gradio
!rm /content/sample_data -rf

In [None]:
import logging
import sqlite3

import gdown
from PIL import Image
from torch.utils.data import Dataset
import os
import torch.nn as nn
from torchvision.models import resnet101, ResNet101_Weights
from PIL import ImageFile
import torch.optim as optim
from torch.optim import lr_scheduler
import torch
from torchvision import transforms
import pandas as pd

# Let PIL decode images whose files are cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Prefer CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of rating buckets the classifier predicts (ratings 1..NUM_CLASSES).
NUM_CLASSES = 10


class CollectionsDataset(Dataset):
    """Image/rating dataset indexed by a CSV file.

    Each CSV row must provide an ``image_path`` column (joined onto
    ``root_dir``) and a ``rating`` column holding an integer in
    ``1..num_classes``. Items are dicts with ``'image'`` (the RGB image,
    passed through ``transform`` when one is given) and ``'labels'`` (a
    one-hot float tensor of length ``num_classes``).
    """

    def __init__(self,
                 csv_file,
                 root_dir,
                 num_classes,
                 transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.num_classes = num_classes

    def __len__(self):
        return len(self.data)

    def _rating_to_label(self, rating):
        """Return a one-hot float tensor for a 1-based ``rating``.

        Raises:
            ValueError: if ``rating`` is missing/NaN, not integral, or outside
                ``1..num_classes``. A rating of 0 would otherwise index -1
                and silently label the sample as the last class.
        """
        # pandas hands back numpy ints, or numpy floats once the column holds
        # any NaN; torch rejects float indices, so normalise to a Python int.
        try:
            as_float = float(rating)
        except (TypeError, ValueError):
            raise ValueError(f"rating {rating!r} is not a number") from None
        if not as_float.is_integer():  # also False for NaN and inf
            raise ValueError(f"rating {rating!r} is not an integer")
        value = int(as_float)
        if not 1 <= value <= self.num_classes:
            raise ValueError(f"rating {value} is outside 1..{self.num_classes}")
        label = torch.zeros(self.num_classes)
        label[value - 1] = 1
        return label

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.data.loc[idx, 'image_path'])
        # convert() returns a new in-memory image, so the file handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')
        label_tensor = self._rating_to_label(self.data.loc[idx, 'rating'])

        if self.transform:
            image = self.transform(image)

        return {'image': image,
                'labels': label_tensor
                }


def train_model(model,
                data_loader,
                dataset_size,
                optimizer,
                scheduler,
                num_epochs):
    """Fine-tune ``model`` in place with ``BCEWithLogitsLoss`` and return it.

    Only the parameters registered with ``optimizer`` are made trainable;
    every other parameter is frozen. ``optimizer.zero_grad()`` only clears
    the optimizer's own parameters, so gradients for anything else would be
    computed for nothing and left lying around in ``.grad``.

    Args:
        model: network whose output logits have the same shape as the labels.
        data_loader: iterable of dicts with ``"image"`` and ``"labels"``
            tensors (batch first).
        dataset_size: number of samples in the dataset. Kept for interface
            compatibility; the logged epoch loss averages over the samples
            actually processed.
        optimizer: optimizer over the parameters to update.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over ``data_loader``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    trainable = {id(p) for group in optimizer.param_groups for p in group['params']}
    for param in model.parameters():
        param.requires_grad = id(param) in trainable

    # Send batches to wherever the model's weights live.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    # BatchNorm1d on (N, C) input cannot compute batch statistics from a
    # single sample in train mode and raises; such batches must be skipped.
    has_bn1d = any(isinstance(m, nn.BatchNorm1d) for m in model.modules())

    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(num_epochs):
        logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
        logging.info('-' * 10)

        model.train()
        running_loss = 0.0
        seen = 0
        for batch in data_loader:
            inputs = batch["image"].to(target_device, dtype=torch.float)
            labels = batch["labels"].to(target_device, dtype=torch.float)
            if has_bn1d and inputs.size(0) < 2:
                logging.warning("Skipping batch of size %d: BatchNorm1d needs >= 2 samples",
                                inputs.size(0))
                continue

            optimizer.zero_grad()
            # Re-enable autograd even if the caller is inside a no_grad block.
            with torch.enable_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            seen += inputs.size(0)
        # An empty loader would otherwise divide by zero.
        epoch_loss = running_loss / seen if seen else 0.0
        scheduler.step()
        logging.info('Loss: {:.4f}'.format(epoch_loss))
    return model


def train_from_resnet101_with_dataset(num_epochs=5):
    """Train a fresh ResNet-101 rating model on the balanced offline dataset.

    Reads ``dataset/balanced/data.csv`` (image paths relative to
    ``dataset/balanced``) and fine-tunes ``layer4`` and the ``fc`` head.

    Args:
        num_epochs: number of training epochs (default 5).

    Returns:
        The trained model, on ``device``.
    """
    model = get_model_definition()

    # define some re-usable stuff
    IMAGE_SIZE = 512
    BATCH_SIZE = 9
    IMG_MEAN = [0.485, 0.456, 0.406]
    IMG_STD = [0.229, 0.224, 0.225]

    # Deterministic preprocessing only (no random augmentation): resize the
    # shorter side, center-crop to a square, normalise with ImageNet stats.
    train_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMG_MEAN, IMG_STD)
    ])

    # os.path.join picks the right separator; hard-coded backslashes only
    # resolve on Windows.
    CSV_FILE_TRAIN = os.path.join("dataset", "balanced", "data.csv")
    ROOT_DIR_TRAIN = os.path.join("dataset", "balanced")

    train_dataset = CollectionsDataset(CSV_FILE_TRAIN, ROOT_DIR_TRAIN, NUM_CLASSES, train_transform)

    # A lone trailing sample would reach the BatchNorm1d layers of the head
    # as a batch of one, which raises in train mode; drop it in that case.
    train_dataset_loader = torch.utils.data.DataLoader(train_dataset,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True,
                                                       drop_last=len(train_dataset) % BATCH_SIZE == 1)
    # push model to device
    model = model.to(device)

    # Backbone gets a tiny LR, the new head a larger one; the optimizer-level
    # lr only applies to groups that do not set their own.
    plist = [
        {'params': model.layer4.parameters(), 'lr': 1e-5},
        {'params': model.fc.parameters(), 'lr': 5e-3}
    ]
    optimizer_ft = optim.Adam(plist, lr=0.001)
    lr_sch = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

    model = train_model(model,
                        train_dataset_loader,
                        len(train_dataset),
                        optimizer_ft,
                        lr_sch,
                        num_epochs=num_epochs)
    return model


def train_model_from_ram_with_flagged(model, transforms, num_epochs=5, batch_size=9):
    """Fine-tune an already-loaded ``model`` on the user-flagged images.

    Exports ``flagged/data.db`` to ``flagged/data.csv`` via ``sql_to_csv`` and
    trains ``layer4`` and ``fc`` with small learning rates. The model is
    updated in place, then frozen and put in eval mode before being returned.

    Args:
        model: network with ``layer4`` and ``fc`` submodules (as built by
            ``get_model_definition``), already on ``device``.
        transforms: transform applied to each flagged image (this parameter
            shadows the torchvision ``transforms`` module inside the function).
        num_epochs: number of training epochs.
        batch_size: images per batch.

    Returns:
        The same ``model``, frozen and in eval mode. If there are no flagged
        images, training is skipped.
    """
    sql_to_csv()
    CSV_FILE_TRAIN = os.path.join("flagged", "data.csv")
    ROOT_DIR_TRAIN = ""
    train_dataset = CollectionsDataset(CSV_FILE_TRAIN, ROOT_DIR_TRAIN, NUM_CLASSES, transforms)

    if len(train_dataset) == 0:
        logging.warning("No flagged images to train on; skipping fine-tuning")
    else:
        # Drop a lone trailing sample: the BatchNorm1d head cannot train on a
        # batch of one.
        train_dataset_loader = torch.utils.data.DataLoader(train_dataset,
                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           drop_last=len(train_dataset) % batch_size == 1)

        plist = [
            {'params': model.layer4.parameters(), 'lr': 1e-6},
            {'params': model.fc.parameters(), 'lr': 5e-4}
        ]
        optimizer_ft = optim.Adam(plist, lr=0.0001)
        lr_sch = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

        model = train_model(model,
                            train_dataset_loader,
                            len(train_dataset),
                            optimizer_ft,
                            lr_sch,
                            num_epochs=num_epochs)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    return model


def train_model_from_my_pretrained(transforms, num_epochs=5, batch_size=9):
    """Load the pretrained checkpoint and fine-tune it on the flagged images.

    Downloads ``model_pretrained.bin`` from Google Drive if it is missing,
    builds a new network with ``get_model_definition``, loads the checkpoint,
    then trains ``layer4`` and ``fc`` on ``flagged/data.csv`` (exported from
    the flag DB by ``sql_to_csv``).

    Args:
        transforms: transform applied to each flagged image (shadows the
            torchvision ``transforms`` module inside this function).
        num_epochs: number of training epochs.
        batch_size: images per batch.

    Returns:
        A newly built model (not any caller-held one), frozen and in eval
        mode. If there are no flagged images, the checkpoint is returned
        without further training.
    """
    if not os.path.exists("model_pretrained.bin"):
        logging.info("Missing model, downloading")
        url = 'https://drive.google.com/uc?id=1uyIYjcPRg6TwIrLpa_p9bmCALtAax79F'
        output = 'model_pretrained.bin'
        gdown.download(url, output, quiet=False)
        logging.info("Model downloaded")
    model = get_model_definition()
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    model.load_state_dict(torch.load("model_pretrained.bin", map_location=device))
    model = model.to(device)
    sql_to_csv()
    CSV_FILE_TRAIN = os.path.join("flagged", "data.csv")
    ROOT_DIR_TRAIN = ""
    train_dataset = CollectionsDataset(CSV_FILE_TRAIN, ROOT_DIR_TRAIN, NUM_CLASSES, transforms)

    if len(train_dataset) == 0:
        logging.warning("No flagged images to train on; skipping fine-tuning")
    else:
        # Drop a lone trailing sample: the BatchNorm1d head cannot train on a
        # batch of one.
        train_dataset_loader = torch.utils.data.DataLoader(train_dataset,
                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           drop_last=len(train_dataset) % batch_size == 1)

        plist = [
            {'params': model.layer4.parameters(), 'lr': 1e-6},
            {'params': model.fc.parameters(), 'lr': 5e-4}
        ]
        optimizer_ft = optim.Adam(plist, lr=0.0001)
        lr_sch = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

        model = train_model(model,
                            train_dataset_loader,
                            len(train_dataset),
                            optimizer_ft,
                            lr_sch,
                            num_epochs=num_epochs)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    return model


def get_model_definition(num_classes=10):
    """Build the rating classifier: ImageNet ResNet-101 with an MLP head.

    Args:
        num_classes: number of output logits (default 10, the size the
            checkpoints loaded in this file were trained with).

    Returns:
        A torchvision ResNet-101 initialised with ``IMAGENET1K_V2`` weights,
        whose ``fc`` is replaced by BatchNorm/Dropout/Linear layers.
    """
    model = resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)
    # torchvision's ResNet.forward pools through ``avgpool``; assigning an
    # ``avg_pool`` attribute would only register an extra, unused module.
    # The pooling layer has no parameters, so state_dict keys are unaffected.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    in_features = model.fc.in_features  # 2048 for ResNet-101
    model.fc = nn.Sequential(
        nn.BatchNorm1d(in_features),
        nn.Dropout(p=0.25),
        nn.Linear(in_features=in_features, out_features=2048),
        nn.ReLU(),
        nn.BatchNorm1d(2048, eps=1e-05, momentum=0.1),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=2048, out_features=num_classes),
    )
    return model


def sql_to_csv(db_path=os.path.join("flagged", "data.db"),
               csv_path=os.path.join("flagged", "data.csv")):
    """Export the ``images`` table of the flag database to a CSV file.

    Args:
        db_path: SQLite database containing the ``images`` table.
        csv_path: destination CSV; overwritten, written without the index.
    """
    from contextlib import closing

    # A sqlite3 connection's own context manager only commits/rolls back;
    # closing() is what actually releases the database handle.
    with closing(sqlite3.connect(db_path)) as conn:
        db_df = pd.read_sql_query("SELECT * FROM images", conn)
    db_df.to_csv(csv_path, index=False)



In [None]:
import base64
from collections import UserDict
from io import BytesIO

import gradio as gr
import requests


# ENABLE LOGGING--------------------------------------------------------------------------------------------------------
logging.basicConfig(
    # force=True replaces any handlers already attached to the root logger,
    # so re-running this code does not duplicate every log line.
    force=True,
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    # Write to debug.log in the working directory and to stderr.
    handlers=[logging.FileHandler("debug.log"), logging.StreamHandler()],
)
# ----------------------------------------------------------------------------------------------------------------------

# CREATE DATABASE------------------------------------------------------------------------------------------------------
# makedirs(exist_ok=True) is a no-op when the folder already exists.
os.makedirs("flagged", exist_ok=True)
conn = sqlite3.connect(os.path.join("flagged", "data.db"))
try:
    c = conn.cursor()
    # SQLite has no SERIAL type: a "SERIAL PRIMARY KEY" column is simply left
    # NULL by INSERTs that omit it. INTEGER PRIMARY KEY aliases the rowid, so
    # ``id`` is filled in automatically. An existing ``images`` table keeps
    # whatever schema it was created with.
    c.execute('''CREATE TABLE IF NOT EXISTS images
             (id INTEGER PRIMARY KEY AUTOINCREMENT, image_path TEXT, rating INTEGER)''')
    conn.commit()
finally:
    # Release the handle whether or not the table had to be created.
    conn.close()
# ----------------------------------------------------------------------------------------------------------------------

# DOWNLOAD MODEL IF IT IS MISSING---------------------------------------------------------------------------------------

# Fetch the serving checkpoint from Google Drive only when no local copy exists.
if not os.path.exists("model.bin"):
    url, output = (
        'https://drive.google.com/uc?id=1uyIYjcPRg6TwIrLpa_p9bmCALtAax79F',
        'model.bin',
    )
    logging.info("Missing model, downloading")
    gdown.download(url, output, quiet=False)
    logging.info("Model downloaded")
# ----------------------------------------------------------------------------------------------------------------------

# DEFINE MODEL----------------------------------------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = get_model_definition()
# map_location lets a checkpoint saved from a GPU load on a CPU-only host.
model.load_state_dict(torch.load("model.bin", map_location=device))

# Inference preprocessing: resize the shorter side, center-crop to a square,
# then normalise with ImageNet mean/std.
IMAGE_SIZE = 512
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]
my_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMG_MEAN, IMG_STD)
])
# Serve from a frozen model in eval mode (dropout off, BatchNorm uses running stats).
model = model.to(device)
for param in model.parameters():
    param.requires_grad = False

model.eval()


# ----------------------------------------------------------------------------------------------------------------------

# CREATE AND DEFINE BUFFER----------------------------------------------------------------------------------------------
class Buffer(UserDict):
    """Dict with a size cap that evicts its oldest-inserted entry (FIFO).

    Args:
        max_size: maximum number of entries kept; must be at least 1.
        *args, **kwargs: initial contents, as accepted by ``UserDict``.

    Raises:
        ValueError: if ``max_size`` is less than 1. Otherwise the first
            insert would fail with a StopIteration from the eviction step.
    """

    def __init__(self, max_size=10, *args, **kwargs):
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        # Set before super().__init__, which inserts initial items through
        # __setitem__ and therefore needs max_size.
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        # Overwriting an existing key does not grow the dict, so room is only
        # made when a new key is inserted; otherwise an unrelated entry would
        # be evicted for nothing.
        if key not in self.data:
            while len(self.data) >= self.max_size:
                del self.data[next(iter(self.data))]
        super().__setitem__(key, value)


buffer = Buffer(1000)
# ----------------------------------------------------------------------------------------------------------------------

headers = {
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36"}
with gr.Blocks() as app:
    gr.Markdown("Image scorer backend.")
    with gr.Tab("I am here for api only, pls don't touch me", visible=True):
        # Endpoint that stores a rated image in the training database.
        img_input = gr.Textbox()
        score_input = gr.Textbox()
        button = gr.Button('Click me!')


        def add_image(inp, score=None):
            """Download the image at URL ``inp`` and record it with its rating.

            The image is saved as ``flagged/<n>.png`` (``n`` = current row
            count of ``images``) and a ``(image_path, rating)`` row is inserted.

            Args:
                inp: URL of the image.
                score: rating in ``1..NUM_CLASSES``; Gradio passes it as text.

            Returns:
                The status string ``"readed"``.

            Raises:
                ValueError: if ``score`` is not an integer in range. Rejecting
                    it here keeps bad labels out of the training data.
                requests.HTTPError: if the download returns an error status.
            """
            logging.info(f"Got image with src {inp} and score {score}")
            try:
                rating = int(str(score).strip())
            except ValueError:
                raise ValueError(f"score must be an integer, got {score!r}") from None
            if not 1 <= rating <= NUM_CLASSES:
                raise ValueError(f"score must be between 1 and {NUM_CLASSES}, got {rating}")

            # Download first, so a failed request touches neither disk nor DB.
            req = requests.get(inp, headers=headers, timeout=30)
            req.raise_for_status()
            # Training reads images back with convert('RGB') anyway. Converting
            # now also lets modes PNG cannot store (e.g. CMYK JPEGs) be saved.
            img = Image.open(BytesIO(req.content)).convert('RGB')

            conn = sqlite3.connect(os.path.join("flagged", "data.db"))
            try:
                c = conn.cursor()
                c.execute("SELECT COUNT(*) FROM images")
                row_count = c.fetchone()[0]
                image_path = os.path.join("flagged", str(row_count) + ".png")
                img.save(image_path)
                c.execute("INSERT INTO images (image_path, rating) VALUES (?, ?)",
                          (image_path, rating))
                conn.commit()
            finally:
                conn.close()
            logging.info("Image added")
            return "readed"


        button.click(add_image, inputs=[img_input, score_input])

        # Connection test endpoint (presumably used by API clients to check
        # that the backend is reachable).
        txt_output = gr.Textbox()
        button = gr.Button('Connection test')


        def return_ok():
            """Log the ping and reply with the fixed string ``"connected"``."""
            reply = "connected"
            logging.info("Connected")
            return reply


        button.click(return_ok, outputs=txt_output)

        # Endpoint that predicts the rating of an image.
        img_input_predict = gr.Textbox()
        score_output = gr.Textbox()
        button = gr.Button('Predict')


        def predict(inp):
            """Predict an image's rating (1..NUM_CLASSES) and return it as text.

            Args:
                inp: an http(s) URL of the image, or a base64 data URL
                    (``...base64,<payload>``; the payload follows the first comma).

            Returns:
                The rating as a string (cached in ``buffer`` per ``inp``), or
                ``"-1"`` for empty/unsupported input or on any error (logged).
            """
            if not inp:
                logging.info("Got empty string, returning -1")
                return "-1"
            if inp in buffer:
                logging.info(f"Using buffer, returned {buffer[inp]}")
                return buffer[inp]
            req = None  # only set on the URL path; inspected by the error handler
            try:
                logging.info(f"Got: {inp}")
                src = str(inp)
                # Dispatch on the scheme prefix. Substring tests misroute input:
                # base64 payloads can contain "http" by chance, and URLs can
                # contain "base64".
                if src.lstrip().lower().startswith(("http://", "https://")):
                    logging.info("Requesting image")
                    req = requests.get(inp, headers=headers, timeout=30)
                    logging.info("Got image")
                    image = Image.open(BytesIO(req.content)).convert('RGB')
                elif "base64" in src:
                    logging.info("Got base64, decoding")
                    image_bytes = base64.b64decode(src.split(',', 1)[1])
                    image = Image.open(BytesIO(image_bytes)).convert('RGB')
                else:
                    logging.warning("Unsupported input: expected an http(s) URL or a base64 data URL")
                    return "-1"
                logging.info("Transforming image")
                batch = my_transforms(image).to(device).unsqueeze(0)
                logging.info("Predicting")
                # No autograd bookkeeping is needed for inference.
                with torch.inference_mode():
                    logits = model(batch)
                ans = str(torch.argmax(logits.sigmoid()).item() + 1)
                logging.info(f"Responsed with: {ans}")
                buffer[inp] = ans
                torch.cuda.empty_cache()
                return ans
            except Exception as err:
                logging.error("Error: " + str(err), exc_info=True)
                if req is not None:
                    logging.error(f"Response status was {req.status_code}")
                return "-1"


        button.click(predict, inputs=img_input_predict, outputs=score_output)

    with gr.Tab("Train tab", visible=True):
        button = gr.Button('Train model from ram')


        def standart_eval_prep():
            """Freeze the serving model, put it in eval mode, drop cached predictions.

            The prediction cache is cleared because its entries were produced
            by the weights that existed before training.
            """
            for param in model.parameters():
                param.requires_grad = False
            model.eval()
            torch.cuda.empty_cache()
            buffer.clear()


        def train_with_flagged():
            """Fine-tune the serving model in place on the flagged images.

            A failure is logged with its traceback instead of being swallowed,
            and the model is returned to inference mode either way.
            """
            try:
                train_model_from_ram_with_flagged(model, my_transforms, 15)
            except Exception:
                logging.exception("Training on flagged images failed")
            finally:
                standart_eval_prep()


        button.click(train_with_flagged)

        button = gr.Button('Train model from pretrained')


        def train_from_pretrained():
            """Retrain from the pretrained checkpoint and serve the result.

            ``train_model_from_my_pretrained`` builds and returns a new network.
            It replaces the module-level ``model`` that ``predict`` reads;
            otherwise the training run would have no effect. If training fails,
            the current model is kept and the error is logged.
            """
            global model
            try:
                model = train_model_from_my_pretrained(my_transforms, 15)
            except Exception:
                logging.exception("Training from the pretrained checkpoint failed")
            finally:
                standart_eval_prep()


        button.click(train_from_pretrained)

# Handle one event at a time: prediction and training share the global model.
# Older Gradio takes ``concurrency_count``; newer releases reject it (TypeError
# or a raised DeprecationWarning) in favour of ``default_concurrency_limit``.
try:
    app.queue(concurrency_count=1, max_size=60)
except (TypeError, DeprecationWarning):
    app.queue(default_concurrency_limit=1, max_size=60)
app.launch(max_threads=1, share=True)



import time

# Block forever so the server started by ``app.launch`` keeps serving.
# Sleeping instead of ``pass`` stops the loop from pinning a CPU core.
while True:
    time.sleep(1)