# Web App Testing

Try using `gradio` to build the web app.

```python
# Run this cell only if gradio is missing from the venv (it has already been added to the requirements file).
# %pip install gradio
```

## Workflow
**Inputs**
1. Allow the user to upload an image of an artwork.
2. Allow the user to choose whether segmentation is applied (`True`/`False`).
3. Allow the user to select a region of interest by cropping (still to decide: should the app prompt for this automatically?).
4. Flagging?

**Internal**

1. Preprocess the user's image into the model's input format: zero-pad it to a square and resize it to 224x224, with optional segmentation into a 3x3 grid.
2. Pass the image to the model (possibly allow model selection in the future).
3. Outputs: Style Class, Genre Class, Caption, (Image)
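Step 1 leaves open how the optional 3x3 segmentation would be done. A minimal sketch, assuming the grid is cut from the image first and each tile is then padded and resized separately (the order is not settled above): compute nine crop boxes in the `(left, upper, right, lower)` format that PIL's `Image.crop` accepts. `grid_boxes` is a hypothetical helper name.

```python
def grid_boxes(width, height, rows=3, cols=3):
    """Return (left, upper, right, lower) boxes tiling a width x height image
    into rows x cols cells, listed row by row.

    Integer division spreads any remainder across the cells, so the boxes
    cover every pixel exactly once with no overlap.
    """
    boxes = []
    for r in range(rows):
        upper = r * height // rows
        lower = (r + 1) * height // rows
        for c in range(cols):
            left = c * width // cols
            right = (c + 1) * width // cols
            boxes.append((left, upper, right, lower))
    return boxes


# Assumed usage with a PIL image `img`:
#   tiles = [img.crop(box) for box in grid_boxes(*img.size)]
# each tile could then be passed through the zero-pad/resize preprocessing.
```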

**Output**

1. Display outputs
2. Allow for feedback
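The feedback step could start as a plain append-only log. A minimal sketch, assuming feedback is stored in a local CSV file; the field names and `log_feedback` are hypothetical. Gradio's built-in flagging (`allow_flagging="manual"` on `gr.Interface`) is an alternative that also saves flagged samples to a local directory.

```python
import csv
import os
from datetime import datetime, timezone

# Hypothetical record layout for one piece of user feedback
FIELDS = ["timestamp", "image_path", "predicted_style", "corrected_style", "caption_rating"]


def log_feedback(csv_path, image_path, predicted_style, corrected_style=None, caption_rating=None):
    """Append one feedback record to csv_path, writing the header if the file is new or empty."""
    is_new = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        if is_new:
            writer.writeheader()
        writer.writerow({
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "image_path": image_path,
            "predicted_style": predicted_style,
            # Empty strings mark "no correction" / "no rating given"
            "corrected_style": corrected_style or "",
            "caption_rating": "" if caption_rating is None else caption_rating,
        })
```

Keeping one row per submission makes it easy to filter the corrected rows later if user corrections are ever used to retrain the model.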

## Basic Webapp
Test app with basic functionality and integrations.

```python
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# Untrained (randomly initialised) model for testing the app's plumbing
class RandomModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        # Assumes a 3-channel 224x224 input, flattened into a single vector
        self.fc = nn.Linear(224 * 224 * 3, num_classes)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # (N, 3, 224, 224) -> (N, 150528)
        return self.fc(x)

# Style labels for testing purposes
classes = ['Realism', 'Baroque', 'Post_Impressionism', 'Impressionism',
           'Romanticism', 'Art_Nouveau', 'Northern_Renaissance', 'Symbolism',
           'Naive_Art_Primitivism', 'Expressionism', 'Cubism', 'Fauvism',
           'Analytical_Cubism', 'Abstract_Expressionism', 'Synthetic_Cubism',
           'Pointillism', 'Early_Renaissance', 'Color_Field_Painting',
           'New_Realism', 'Ukiyo_e', 'Rococo', 'High_Renaissance',
           'Mannerism_Late_Renaissance', 'Pop_Art', 'Contemporary_Realism',
           'Minimalism', 'Action_painting']

# Convert a PIL image to a tensor and normalise it with the ImageNet mean/std
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_image(image):
    """Zero-pad a PIL image to a square, resize it to 224x224 and return a (1, 3, 224, 224) tensor."""
    image = image.convert("RGB")  # drop any alpha channel
    width, height = image.size
    max_dim = max(width, height)
    # Centre the image on a black (zero) square canvas to get a 1:1 aspect ratio
    padded_image = Image.new("RGB", (max_dim, max_dim))
    padded_image.paste(image, ((max_dim - width) // 2, (max_dim - height) // 2))
    resized_image = padded_image.resize((224, 224))
    return to_tensor(resized_image).unsqueeze(0)  # add the batch dimension

# Build the model once instead of on every request
model = RandomModel(num_classes=len(classes))
model.eval()

def predict_genre(image):
    # gr.ImageEditor passes a dict with 'background', 'layers' and 'composite';
    # 'composite' is the flattened (cropped) image, while 'layers' holds only the drawing layers
    if image is None or image.get('composite') is None:
        return 'Please upload an image.'
    input_tensor = preprocess_image(image['composite'])
    with torch.no_grad():
        outputs = model(input_tensor)
    predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
    return classes[predicted.item()]

# Define the Gradio interface
input_image = gr.ImageEditor(type='pil', image_mode='RGB', transforms=('crop',), eraser=False, brush=False)
output_genre = gr.Textbox()

test = gr.Interface(fn=predict_genre, inputs=input_image, outputs=output_genre, title='Art Style Prediction',
                    description='Upload an image (and crop it to a region of interest)', allow_flagging='never')
test.launch()
```

## TODO

Add custom blocks:
- Inputs
    - Segmentation Button/Checkbox (True/False)
    - Select Model?
- Outputs
    - Style Labels (sorted by highest likelihood)
    - Genre Labels (sorted by highest likelihood)
    - Caption
    - Display Image
- Feedback
    - Flagging (should this new image be added to the training data to improve the model?)
    - Satisfaction (quality of the caption)
    - Corrections (allow users to submit the correct class if the prediction is wrong)
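For the "sorted by highest likelihood" outputs, the model's raw scores need converting into probabilities and ranking. A minimal sketch, assuming the scores arrive as a plain list of floats (e.g. `outputs[0].tolist()` from the test app); `ranked_labels` is a hypothetical helper. A `{label: confidence}` dict like the one it returns is a value that a `gr.Label` output accepts, and `gr.Label` also has a `num_top_classes` parameter to limit how many are shown.

```python
import math


def ranked_labels(logits, labels, top_k=5):
    """Softmax raw scores and return the top_k labels mapped to their probability,
    highest probability first (dicts keep insertion order)."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    m = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return dict(ranked[:top_k])
```

The same helper would serve both the style and the genre outputs, each with its own label list.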