# Quality Model App

## Requirements

In [None]:
%pip install ipywidgets==7.7.3
%pip install open-clip-torch

### App Modules

In [None]:
import os
import random
import hashlib
import uuid
import json
import datetime
from tqdm import tqdm
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output


## Preparing Data

### Specify the Data Location
Specify the Mega share URL of the dataset to download, and optionally the destination path it should be transferred to.

In [None]:
# Data URL on Mega (passed as the source argument to `mega-get` by transfer())
data_url = 'https://mega.nz/file/sBwG1DqZ#NcN7Q97CJPh5DB-tXGb3a4SV6Bubw9xPFRfqPQ4bmw8'
# Destination path for data transfer; when left empty the transfer cell
# uses a local "downloads" directory instead
transfer_to_path = ''

### Mega Installation

In [None]:
import urllib.request

HOME = os.path.expanduser("~")
# Fetch the OneClickRun helper module (provides runSh / loadingAn) once.
# NOTE(security): this downloads unpinned third-party code from GitHub and
# imports it below; review or pin the file before running in a sensitive
# environment.
if not os.path.exists(f"{HOME}/.ipython/ocr.py"):
    hCode = "https://raw.githubusercontent.com/biplobsd/" \
                "OneClickRun/master/res/ocr.py"
    # urlretrieve does not create parent directories, so make sure the
    # target folder exists before downloading into it.
    os.makedirs(f"{HOME}/.ipython", exist_ok=True)
    urllib.request.urlretrieve(hCode, f"{HOME}/.ipython/ocr.py")

# NOTE(review): presumably ~/.ipython is on sys.path in the notebook
# runtime, which is what makes this import resolve - confirm.
from ocr import (runSh, loadingAn)

# Install the MEGAcmd command line tools unless they are already present.
if not os.path.exists("/usr/bin/mega-cmd"):
    loadingAn()
    print("Installing MEGA ...")
    runSh('sudo apt-get -y update')
    runSh('sudo apt-get -y install libmms0 libc-ares2 libc6 libcrypto++6 libgcc1 libmediainfo0v5 libpcre3 libpcrecpp0v5 libssl1.1 libstdc++6 libzen0v5 zlib1g apt-transport-https')
    runSh('sudo curl -sL -o /var/cache/apt/archives/MEGAcmd.deb https://mega.nz/linux/MEGAsync/Debian_9.0/amd64/megacmd-Debian_9.0_amd64.deb', output=True)
    runSh('sudo dpkg -i /var/cache/apt/archives/MEGAcmd.deb', output=True)
    print("MEGA is installed.")
    # Wipes the cell output produced above (including the progress messages).
    clear_output()

### Transfer Data from Mega to Session Storage

In [None]:
import os
import subprocess
import contextlib
from functools import wraps
import errno
import signal
import subprocess
import glob

# Line terminators that unbuffered() treats as end-of-line:
# Unix, Windows and old Macintosh end-of-line
newlines = ['\n', '\r\n', '\r']

def latest_file(folder):
  """Return the entry directly inside ``folder`` with the largest ctime.

  Raises ``ValueError`` (from ``max``) when the folder has no entries.
  """
  # '*' matches every non-hidden entry directly inside the folder
  candidates = glob.glob(f'{folder}/*')
  return max(candidates, key=lambda entry: os.path.getctime(entry))

def unbuffered(proc, stream='stdout'):
    """Yield the output of a subprocess line by line as it is produced.

    The stream is read one character at a time so a line can be yielded as
    soon as its terminator is read.

    Args:
        proc: object exposing the stream as an attribute and a ``poll()``
            method (``transfer`` passes a ``subprocess.Popen``).
        stream: name of the attribute on ``proc`` to read from.

    Yields:
        The characters read up to (not including) the next terminator found
        in ``newlines``, joined into one string. Output that ends without a
        terminator is still yielded once the process has exited.

    NOTE(review): assumes a text-mode stream (reads are compared with '' and
    joined as str). Because each read returns at most one character, only
    single-character entries of ``newlines`` can ever match.
    """
    # Resolve the attribute name to the actual pipe object.
    stream = getattr(proc, stream)
    # Close the pipe when the generator finishes or is closed.
    with contextlib.closing(stream):
        while True:
            out = []
            last = stream.read(1)
            # Don't loop forever: empty read + finished process means no more output
            if last == '' and proc.poll() is not None:
                break
            # Collect characters until a line terminator is read.
            while last not in newlines:
                # Don't loop forever: stop on EOF once the process has exited.
                # NOTE(review): while read() returns '' but poll() is still
                # None this loop keeps spinning (appending ''), and an extra
                # empty string can be yielded afterwards.
                if last == '' and proc.poll() is not None:
                    break
                out.append(last)
                last = stream.read(1)
            out = ''.join(out)
            yield out


def transfer(url):
    """Run ``mega-get`` for ``url`` into ``OUTPUT_PATH`` and echo its output.

    stderr is merged into stdout, and every line produced by the command is
    printed as soon as ``unbuffered`` yields it.
    """
    process = subprocess.Popen(
        ["mega-get", url, OUTPUT_PATH],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        # Text mode: make all end-of-lines '\n'
        universal_newlines=True,
    )
    for text in unbuffered(process):
        print(text)
        
# Directory that transfer() downloads into: the user-supplied
# `transfer_to_path` when it is set, otherwise a local "downloads" folder.
# It is always assigned, so transfer() no longer hits a NameError when a
# custom destination is configured.
OUTPUT_PATH = transfer_to_path if transfer_to_path else "downloads"
os.makedirs(OUTPUT_PATH, exist_ok=True)


class TimeoutError(Exception):
    """Raised by the ``timeout`` decorator when the wrapped call runs too long.

    Note: within this module the name shadows the built-in ``TimeoutError``.
    """

def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Build a decorator that aborts the wrapped call after ``seconds`` seconds.

    The limit is enforced with ``signal.alarm`` / ``SIGALRM``; per the
    ``signal`` module documentation this is Unix-only and handlers can only
    be installed from the main thread.

    Args:
        seconds: whole number of seconds before the alarm fires; ``0`` arms
            no alarm.
        error_message: message carried by the raised ``TimeoutError``.

    Returns:
        A decorator. The decorated function returns whatever the wrapped
        function returns and raises ``TimeoutError`` if the alarm fires first.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Remember whatever handler was installed so it can be put back;
            # otherwise every decorated call permanently replaces it.
            previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Always cancel the pending alarm, even when func raised.
                signal.alarm(0)
                # None means the old handler was not installed from Python
                # and cannot be re-installed through signal.signal().
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return wrapper

    return decorator


@timeout(10)
def runShT(args):
    """Run ``args`` through ``runSh`` with ``output=True`` under a 10-second alarm."""
    result = runSh(args, output=True)
    return result

# Download the dataset, then pick up the most recently changed entry in the
# directory the download was written to (OUTPUT_PATH, the same directory
# transfer() targets) instead of a hard-coded './downloads'.
transfer(data_url)
tagged_dataset_path = latest_file(OUTPUT_PATH)

### Unzip Downloaded Data (if it is a ZIP Archive)

In [None]:
from zipfile import ZipFile

# Path of the downloaded archive (zip file) to unpack
downloaded_data_zip = './downloads/Tile_Generator_Genetic_Algo_V1_16x16-2023-23-2--16-01-20.zip'
# Directory that receives the extracted contents
unzip_target_path = './dataset/'

# Open the archive read-only and unpack every member into the target directory
with ZipFile(downloaded_data_zip, 'r') as zip_object:
    zip_object.extractall(path=unzip_target_path)

## Specify Data Source and Output File for Model App

In [None]:
# Specify the input data directory for quality model app (unzipped data);
# it is walked recursively for .png / .jpg files when the data dictionary is built
input_dir = './dataset/Tile_Generator_Genetic_Algo_V1_16x16-2023-23-2--16-01-20'
# Specify path for the output JSON File (one record is appended per selection)
output_path = 'output.json'

## Functions Definition

### Create Hash and CLIP Model Object

In [None]:
import hashlib
import torch
import open_clip

# Hash generator
def create_hasher():
    """Return a new, empty SHA-256 hash object."""
    sha256_hasher = hashlib.sha256()
    return sha256_hasher

# CLIP model
def get_clip(clip_model_type = 'ViT-B-32' , pretrained = 'openai'):
    """Load an OpenCLIP model ready for inference.

    Args:
        clip_model_type: model architecture name passed to
            ``open_clip.create_model_and_transforms``.
        pretrained: pretrained-weights tag passed to the same call.

    Returns:
        Tuple ``(clip_model, preprocess, device)``: the model (moved to
        ``device`` and switched to eval mode), the image preprocessing
        transform returned by open_clip, and the device string
        (``"cuda"`` when available, otherwise ``"cpu"``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Get CLIP model
    clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_model_type, pretrained=pretrained)
    # Inputs are moved to `device` before encoding, so the model has to live
    # on the same device; eval() turns off training-only behaviour.
    clip_model = clip_model.to(device)
    clip_model.eval()
    return clip_model , preprocess , device

### Compute Hash Function

In [None]:
def compute_hash(hasher, file_path):
    """Return the hex digest of the file at ``file_path``.

    The digest is computed on a copy of ``hasher``, so the object passed in
    is left untouched and the result does not depend on which files were
    hashed before. (Updating the shared hasher directly made every id a
    digest of all files processed so far.)

    Args:
        hasher: hashlib-style hash object providing ``copy()``; any state it
            already holds acts as a common prefix for every file.
        file_path: path of the file to hash.

    Returns:
        The hexadecimal digest string.
    """
    file_hasher = hasher.copy()
    with open(file_path, 'rb') as img_file:
        # Feed the file in 1 MiB chunks instead of loading it whole.
        for chunk in iter(lambda: img_file.read(1 << 20), b''):
            file_hasher.update(chunk)
    return file_hasher.hexdigest()

### Compute CLIP Vector Function

In [None]:
def compute_clip(img, clip_model, preprocess, device):
    """Compute the CLIP image embedding of ``img`` as a NumPy array.

    Args:
        img: image accepted by ``preprocess`` (the caller passes a PIL image).
        clip_model: model exposing ``encode_image``.
        preprocess: callable turning ``img`` into a tensor.
        device: device the preprocessed batch is moved to before encoding.

    Returns:
        The result of ``encode_image`` for a batch of one image, converted
        to a NumPy array.
    """
    # Add the batch dimension and move the input to the target device.
    batch = preprocess(img).unsqueeze(0).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features = clip_model.encode_image(batch)
    # .numpy() only works on CPU tensors, so copy back from the GPU first.
    return features.detach().cpu().numpy()

## Run on Data Source

### Creating Dictionary from Input Data 

In [None]:
def create_input_data_dict(input_dir,
                           hasher,
                           clip_model=None,
                           include_clip:bool = False):
    """Scan ``input_dir`` recursively and index its PNG/JPG images by hash.

    Args:
        input_dir: root directory to walk.
        hasher: hash object forwarded to ``compute_hash`` for every file.
        clip_model: dict with keys ``'model'``, ``'preprocess'`` and
            ``'device'`` forwarded to ``compute_clip``; only read when
            ``include_clip`` is true. Defaults to a dict of ``None`` values.
        include_clip: when true, also compute a CLIP vector per image.

    Returns:
        dict mapping the value returned by ``compute_hash`` to
        ``{'file_path', 'file_name', 'clip_vector'}``; ``clip_vector`` is
        ``[]`` when ``include_clip`` is false. A later file producing the
        same hash replaces the earlier entry.
    """
    if clip_model is None:
        clip_model = {'model':None, 'preprocess':None, 'device':None}

    # Accepted image extensions (compared case-insensitively)
    image_extensions = ('.png', '.jpg')

    # Placeholder for dict to contain result from running on data source.
    data_dict = {}

    print ('[INFO] Running on Data Source...')

    # Walking thru files
    for root, _, files in os.walk(input_dir):
        for file in tqdm(files):
            # Skip anything that is not a png or jpg
            if os.path.splitext(file)[-1].lower() not in image_extensions:
                continue
            # Get file path
            file_path = os.path.join(root, file)

            try:
                # Compute hash
                hash_id = compute_hash(hasher, file_path)
                # Compute CLIP vector ONLY if include_clip is True
                if include_clip:
                    # Close the image file as soon as the vector is computed
                    with Image.open(file_path) as img:
                        clip_vector = compute_clip(img, clip_model['model'], clip_model['preprocess'], clip_model['device'])
                else:
                    clip_vector = []

                data_dict[hash_id]={'file_path':file_path, 'file_name':file, 'clip_vector':clip_vector}

            except Exception as e:
                # Best effort: report the failing file and keep scanning.
                # (The original `print [...]` subscripted the builtin and
                # raised TypeError from inside this handler.)
                print(f'[WARNING] Error when processing file: {e}')

    # Number of images
    n_images = len(data_dict)
    print (f'[INFO] Completed. Number of images: {n_images}')

    return data_dict

# Hash object handed to the directory scan
hasher = create_hasher()

# Specify whether CLIP vectors are computed and included in the dictionary
include_clip = False

# Start from the empty bundle and load the real CLIP model only when
# include_clip is True
clip_model = {'model':None, 'preprocess':None, 'device':None}
if include_clip:
    model, preprocess, device = get_clip()
    clip_model = {'model':model, 'preprocess':preprocess, 'device':device}

# Scan the input directory and build the hash -> image info dictionary
data_dict = create_input_data_dict(input_dir, hasher, clip_model, include_clip=include_clip)

## Widgets Definition

In [None]:
class QualityModelWidgets(object):
    """Notebook UI that shows two random images and records the preferred one.

    Each round displays a pair of images drawn from ``data_dict`` with a
    SELECT button for each and a SKIP button. SELECT appends a record to
    ``output_path`` and moves on to a new pair; SKIP moves on without
    writing anything.
    """

    # Current Step
    n = 1
    # Placeholder for currently displayed images
    file_dict_1 = {}
    file_dict_2 = {}

    def __init__(self, data_dict, output_path = 'output.json') -> None:
        """Store the image index and the path the records are appended to.

        Args:
            data_dict: mapping of image hash to a dict providing at least
                ``'file_path'`` and ``'file_name'``.
            output_path: file that selection records are appended to.
        """
        self.data_dict = data_dict
        self.output_path = output_path

    def start(self):
        """Create the widgets, bind the button callbacks and show the first pair."""

        # Initial Images
        self._load_new_pair()
        # Title label
        self.lbl_title_value = 'Quality Model App'
        self.lbl_title = widgets.HTML(value=f'<p style="font-size: 24px ; font-weight: bold ; color:rgb(75,75,75)">{self.lbl_title_value}</p>')
        # Tagging User
        self.lbl_user = widgets.HTML(value='<p style="font-size: 16px ; font-weight: bold ; color:rgb(75,75,75) ; height: 20px">Tagging User: </p>')
        self.txt_user = widgets.Text(value='', disabled=False)
        self.txt_user.layout.width = '250px'
        # Status label
        self.lbl_status = self._make_status_label()
        # Selection buttons
        self.btn_select_1 = widgets.Button(description = 'SELECT', icon='check', button_style = 'success')
        self.btn_select_1.style.button_color = 'rgb(30,144,255)'
        self.btn_select_2 = widgets.Button(description = 'SELECT', icon='check', button_style = 'success')
        self.btn_select_2.style.button_color = 'rgb(30,144,255)'
        # Skip button
        self.btn_skip = widgets.Button(description = 'SKIP')
        self.btn_skip.style.button_color = 'rgb(225,225,225)'
        # Layout shared by the image, select and skip rows
        self.box_layout = widgets.Layout(display='flex',
                                    flex_flow='row',
                                    justify_content = 'space-around',
                                    align_items='center',
                                    width='100%'
                                    )

        # binding skip button to skip function callback
        self.btn_skip.on_click(self.skip_pressed)
        # binding select button 1 and 2 to select function callback
        self.btn_select_1.on_click(self.select_pressed)
        self.btn_select_2.on_click(self.select_pressed)

        # Show widgets
        self._show_current()

    @staticmethod
    def _load_file_dict(data_dict, hash_id):
        """Read the image bytes for ``hash_id`` and bundle them with its metadata."""
        entry = data_dict[hash_id]
        with open(entry['file_path'], 'rb') as img_file:
            img_bytes = img_file.read()
        return {'hash': hash_id, 'file_path': entry['file_path'], 'file_name': entry['file_name'], 'img_bytes': img_bytes}

    @staticmethod
    def _make_image_widget(file_dict):
        """Build the image widget for one file, with the format taken from its extension."""
        extension = os.path.splitext(file_dict['file_name'])[-1].lower()
        # The widget format becomes the image MIME subtype, so PNG data
        # must not be labelled as JPEG (and 'jpeg' is the standard subtype).
        img_format = 'png' if extension == '.png' else 'jpeg'
        return widgets.Image(value=file_dict['img_bytes'], format=img_format, width=300, height=400)

    def _make_status_label(self):
        """Build the status label showing the current step number."""
        lbl_status_value = f'Choose Best Image - {self.n}'
        return widgets.HTML(value=f'<p style="font-size: 20px ; font-weight: bold ; color:rgb(75,75,75)">{lbl_status_value}</p>')

    def _load_new_pair(self):
        """Draw a new pair of images and create the two image widgets for it."""
        self.file_dict_1, self.file_dict_2 = self.get_2_rand_images(self.data_dict)
        self.img_widget_1 = self._make_image_widget(self.file_dict_1)
        self.img_widget_2 = self._make_image_widget(self.file_dict_2)

    def _show_current(self):
        """Display the full UI using the widgets currently stored on the instance."""
        self.show_widgets(
                    self.lbl_title,
                    self.lbl_status,
                    self.lbl_user,
                    self.txt_user,
                    self.img_widget_1,
                    self.img_widget_2,
                    self.btn_select_1,
                    self.btn_select_2,
                    self.btn_skip,
                    self.box_layout
                    )

    def _next_round(self):
        """Advance the step counter and redraw the UI with a new image pair."""
        # Increment step
        self.n += 1
        # Clearing widgets
        clear_output()
        # Update status label
        self.lbl_status = self._make_status_label()
        # Get new images
        self._load_new_pair()
        self._show_current()

    def get_2_rand_images (self, data_dict):
        """Pick two random images from ``data_dict`` and load their bytes.

        The two images are distinct whenever ``data_dict`` has at least two
        entries. With a single entry that image is returned twice; an empty
        ``data_dict`` raises ``IndexError`` from ``random.choice``.

        Returns:
            Tuple of two dicts with keys ``'hash'``, ``'file_path'``,
            ``'file_name'`` and ``'img_bytes'``.
        """
        # List of hashes (keys in data_dict)
        hash_list = list(data_dict)
        if len(hash_list) >= 2:
            # Sampling without replacement never pairs an image with itself
            hash_1, hash_2 = random.sample(hash_list, 2)
        else:
            hash_1 = hash_2 = random.choice(hash_list)

        file_dict_1 = self._load_file_dict(data_dict, hash_1)
        file_dict_2 = self._load_file_dict(data_dict, hash_2)

        return file_dict_1, file_dict_2

    def show_widgets(self, lbl_title, lbl_status, lbl_user, txt_user, img_1, img_2, btn_select_1, btn_select_2, btn_skip, box_layout):
        """Arrange the given widgets into rows and display them top to bottom.

        Title, user and status rows are left-aligned; the image, select and
        skip rows use ``box_layout``.
        """
        def left_aligned_row(children):
            # Every row gets its own Layout object
            return widgets.Box(children=children, layout=widgets.Layout(display='flex', flex_flow='row', justify_content = 'flex-start', align_items='center', width='100%'))

        self.box_title = left_aligned_row([lbl_title])
        self.box_user = left_aligned_row([lbl_user, txt_user])
        self.box_status = left_aligned_row([lbl_status])
        self.box_images = widgets.Box(children=[img_1, img_2], layout=box_layout)
        self.box_select = widgets.Box(children=[btn_select_1, btn_select_2], layout=box_layout)
        self.box_skip = widgets.Box(children=[btn_skip], layout=box_layout)
        for box in (self.box_title, self.box_user, self.box_status, self.box_images, self.box_select, self.box_skip):
            display(box)

    def skip_pressed(self, button):
        """SKIP callback: show the next pair without recording anything."""
        self._next_round()

    def select_pressed(self, button):
        """SELECT callback: record which image was chosen, then show the next pair."""
        # Time Stamp
        timestamp_str = str(datetime.datetime.now())

        # Which image is selected
        if button is self.btn_select_1:
            selected = self.file_dict_1
        elif button is self.btn_select_2:
            selected = self.file_dict_2
        else:
            # Unknown sender: nothing is recorded, the round still advances
            selected = None

        if selected is not None:
            self.save_to_json_file(selected = selected, options = [self.file_dict_1, self.file_dict_2], time_stamp = timestamp_str, output_path = self.output_path)

        self._next_round()

    def save_to_json_file(self, selected, options, time_stamp, output_path):
        """Append one selection record to ``output_path``.

        Args:
            selected: file dict of the chosen image (its ``'hash'`` is stored).
            options: the two file dicts that were shown, in display order.
            time_stamp: timestamp string stored with the record.
            output_path: file the record is appended to.

        The record holds the task name, both image hashes, the chosen hash,
        the user text and the timestamp; the CLIP vector is not written.
        NOTE(review): each record is an indented JSON object appended to the
        file, so the file as a whole is a sequence of JSON documents rather
        than one JSON value.
        """
        # Task String
        TASK_NAME = 'ChooseImage-1of2-ChooseBest'
        out_json = {'taskname': TASK_NAME, 'input_image1': options[0]['hash'], 'input_image2': options[1]['hash'], 'chosen_image':selected['hash'], 'user':self.txt_user.value, 'timestamp':time_stamp}
        # Serializing json
        json_object = json.dumps(out_json, indent=4)
        # Writing to output file
        with open(output_path, "a") as outfile:
            outfile.write(json_object)
            outfile.write('\n')

## Widgets (Start the App)

In [None]:
# Build the tagging UI over the scanned images (data_dict) and display it;
# selections are appended to output_path
qualityModelWidgets = QualityModelWidgets(data_dict, output_path)
qualityModelWidgets.start()