# IMPORTS

In [1]:
%run notebook_setup.ipynb

In [None]:
import statsmodels.api as sm

from matplotlib import pyplot as plt
import seaborn as sns

import os
import time
from pathlib import Path

from urllib.request import urlretrieve

from joblib import Parallel, delayed
from joblib import memory

from typing import Tuple,Dict

import cv2
import image_similarity_measures
from image_similarity_measures.quality_metrics import fsim,issm,psnr,rmse,sam,sre,ssim,uiq 
# https://github.com/up42/image-similarity-measures
# https://up42.com/blog/tech/image-similarity-measures

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
#from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

# CLASS DEF

In [None]:
# --- aliases used purely as annotations further down ---

# image array, i.e. what cv2.imread hands back
IMG = np.ndarray

# metric name -> similarity score
IMG_SIMILARITY_DICT = Dict[str, float]

# frame whose columns are similarity metrics
IMG_SIMILARITY_DF = pd.DataFrame

# one column taken out of an IMG_SIMILARITY_DF
IMG_SIMILARITY_SERIES = pd.Series

# return annotation for methods that only draw and return nothing
NOTEBOOK_PLOTS = None

class Dataset:
    '''
    wrapper around the 'Sales Of Summer Cloths.csv' listings :
    cleaned dataframe, numeric feature extraction, kmeans groups and
    product picture download / similarity helpers
    '''

    # file name of the raw data, looked up in the working directory
    csv = 'Sales Of Summer Cloths.csv'

    # columns kept by clean_df, every other raw column is dropped
    preselected_columns = [
        # listing text
        'title', 'title_orig',
        # price + volume
        'price', 'retail_price', 'units_sold',
        # product ratings
        'rating', 'rating_count',
        'rating_five_count', 'rating_four_count', 'rating_three_count',
        'rating_two_count', 'rating_one_count',
        # descriptive attributes
        'tags', 'product_color',
        # merchant
        'merchant_title', 'merchant_name',
        'merchant_rating_count', 'merchant_rating',
        # url of the listing image
        'product_picture',
    ]
    

    def __init__(
        self,
        do_fsim = False,
        do_issm = False,
        do_psnr = True,
        do_rmse = True,
        do_sam  = True,
        do_sre  = True,
        do_ssim = True,
        do_uiq  = False,
        ) -> None :
        '''
        read the csv from the working directory, clean it and precompute the kmeans groups

        every do_* flag switches one image similarity metric on / off :
            do_fsim : Feature-based similarity index (FSIM)      0 to 1, 0 == bad, 1 == identical                        (sort ascending=False)
            do_issm : Information theoretic-based Statistic Similarity Measure (ISSM)
            do_psnr : Peak signal-to-noise ratio (PSNR)                                                                   (sort ascending=False)
            do_rmse : Root mean square error (RMSE)              0 to inf, 0 == identical, the smaller the better        (sort ascending=True)
            do_sam  : Spectral angle mapper (SAM)                                                                         (sort ascending=False)
            do_sre  : Signal to reconstruction error ratio (SRE)                                                          (sort ascending=False)
            do_ssim : Structural Similarity Index (SSIM)         -1 to 1, -1 == bad, 1 == good, the larger the better    (sort ascending=False)
            do_uiq  : Universal image quality index (UIQ)        -1 to 1, -1 == bad, 1 == good, the larger the better    (sort ascending=False)
        '''
        # remember which metrics img_similarity_pair should compute
        self.do_fsim, self.do_issm = do_fsim, do_issm
        self.do_psnr, self.do_rmse = do_psnr, do_rmse
        self.do_sam,  self.do_sre  = do_sam,  do_sre
        self.do_ssim, self.do_uiq  = do_ssim, do_uiq

        # directories : csv lives in the working dir, downloaded jpgs in <cwd>/cache
        self.cwd      = Path(os.getcwd())
        self.cachedir = self.cwd.joinpath('cache')

        # data : raw csv first, cleaned copy second (clean_df reads self.raw)
        self.raw = self.read_csv()
        self.df  = self.clean_df()

        # kmeans labels, computed once up front (precompute_kmeans reads self.df)
        self.kmean_groups = self.precompute_kmeans()
        
    # simple read
    def read_csv(self) -> pd.DataFrame :
        return pd.read_csv(self.cwd / Dataset.csv) # read raw df
    
    # remove columns deemed to be useless + duplicate listings
    def clean_df(
        self,
        ) -> pd.DataFrame :
        df = self.raw.copy()
        df = df.loc[:,Dataset.preselected_columns] # subset to columns deemed useful
        df = df.drop_duplicates() # remove duplicate listings
        return df

    # extract df of numeric features, then optionally impute nans, standard-scale and reduce with pca
    def df_num(
        self,
        should_impute_nans      : bool = True,
        should_standard_scale   : bool = True,
        #should_minmax_scale     : bool = True,
        #should_minmax_scale_abs : float = 2,
        should_pca              : bool = True,
        should_pca_components   : float = 6,
        ) -> pd.DataFrame :
        # strip
        df_num = self.df.select_dtypes(include=['int64','float64']).copy()
        
        # impute nans
        if should_impute_nans:
            imp = SimpleImputer(missing_values=np.nan, strategy='mean')
            imputed = imp.fit_transform(df_num)
            df_num = pd.DataFrame(imputed,index=df_num.index,columns=df_num.columns)
    
        # apply standard scaler
        if should_standard_scale:
            s_scaler = StandardScaler()
            scaled = s_scaler.fit_transform(df_num)
            df_num = pd.DataFrame(scaled,index=df_num.index,columns=df_num.columns)
            
        '''
        # apply min max scaler
        if should_minmax_scale:
            mm_scaler = MinMaxScaler(feature_range=(should_minmax_scale_abs*-1,should_minmax_scale_abs))
            scaled = mm_scaler.fit_transform(df_num)
            df_num = pd.DataFrame(scaled,index=df_num.index,columns=df_num.columns)
        ''' 
        # apply pca
        if should_pca:
            pca = PCA(n_components=should_pca_components)
            reduced = pca.fit_transform(df_num)
            df_num = pd.DataFrame(reduced)
            df_num.columns = ['ev'+str(x) for x in df_num.columns] # rename columns

        # return
        return df_num

    def precompute_kmeans(
        self,
        ) -> np.ndarray :
        '''
        cluster the listings with kmeans and return the cluster label of every row

        the model is fitted on self.df_num() with its default arguments;
        the scaled features are used because the kmeans metric is sensitive to scale
        '''
        features = self.df_num()

        model = KMeans(
            n_clusters   = 10, # original note : residuals already look pretty good with 10 groups
            init         = "random",
            n_init       = 10,
            max_iter     = 300,
            random_state = 42,
        )
        model.fit(features)

        # labels_ holds the cluster assigned to each fitted row
        return model.labels_

    def url_to_cache_filepath(
        self,
        url : str, # = 'https://contestimg.wish.com/api/webimage/5e9ae51d43d6a96e303acdb0-medium.jpg',
        ) -> Path :
        return self.cachedir / Path(url).name

    # plot the jpgs associated with multiple locs in a grid (nothing is returned)
    def get_product_pictures(
        self,
        locs       : List[int]      = None,  # 0 / index name to read
        ) -> IMG :
        '''
        d.get_product_pictures(locs=[1307,758,1183,1516])
        '''
        # gather imgs
        imgs = [self.get_product_picture(loc=loc,plot=False) for loc in locs]
        
        def ceildiv(a,b): return -(a // -b) # ceiling division without needing any imports

        # prep plotting device
        MAX_COLS   = 5
        PLOT_WIDTH = 5
        plot_cols = min(5,len(locs))
        plot_rows = ceildiv(len(locs),plot_cols) # max number of rows needed to plot 'locs' pictures with 'cols' pictures on each row
        fig,ax = plt.subplots(plot_rows,plot_cols,figsize=(PLOT_WIDTH * plot_cols,PLOT_WIDTH * plot_rows))

        # plot imgs
        curr_plotting_row = -1 # increments to 0 on first loop as increment condition is true
        for i,img in enumerate(imgs):
            # plotting column location
            curr_plotting_col = i%plot_cols
            # increment plotting row each time we hit the 'cols'th column
            if curr_plotting_col == 0:
                curr_plotting_row = curr_plotting_row + 1
                
            # plot
            if plot_rows == 1:
                ax[curr_plotting_col].imshow(img) # dont need 2nd dimension for ax if only 1 row
            else:
                ax[curr_plotting_row,curr_plotting_col].imshow(img) # need 2nd dimension for ax if more than 1 row

        # remove axes from ALL ax's
        if plot_rows == 1:
            for curr_plotting_col in range(plot_cols):
                ax[curr_plotting_col].set_axis_off()
        else:
            for curr_plotting_row in range(plot_rows):
                for curr_plotting_col in range(plot_cols):
                    ax[curr_plotting_row,curr_plotting_col].set_axis_off()


    # return jpg associated with url in product_picture
    def get_product_picture(
        self,
        url        : str            = None,  # 'https://contestimg.wish.com/api/webimage/5e9ae51d43d6a96e303acdb0-medium.jpg'
        loc        : int            = None,  # 0 / index name to read
        plot       : bool           = True,  # should plot final result
        grayscale  : bool           = False, # to avoid finding same item in different colour
        blur       : bool           = False, # to avoid finding same item with minor difference in listing
        blur_ksize : Tuple[int,int] = (5,5), # the larger the stronger the smoothing
        verbose    : int            = 0,     # give details
        ) -> IMG :
        '''
        d=Dataset() # instantiate
        d.get_product_picture(url='https://contestimg.wish.com/api/webimage/5e9ae51d43d6a96e303acdb0-medium.jpg')
        d.get_product_picture(loc=3)
        d.get_product_picture(loc=3,blur=False,grayscale=True)
        '''
        
        ####################################
        # ensure url is populated
        ####################################
        if loc is not None:
            url = d.df['product_picture'].loc[loc]
            
        ####################################
        # ensure local cache populated
        ####################################
        local_filepath = self.url_to_cache_filepath(url) # target location
        if local_filepath.exists():
            if verbose>0: print(f'read {local_filepath}')
        else:
            if not self.cachedir.exists(): os.mkdir(self.cachedir) # ensure cache dir exists
            urlretrieve(url,local_filepath) # populate local cache each time url / jpg is requested
            if verbose>0: print(f'cache {url}')
        
        ####################################
        # read img from cache
        ####################################
        # read raw file
        img = cv2.imread(str(local_filepath))
        
        # should apply grayscale?
        if grayscale:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # yes
            img = np.stack([img,img,img],axis=2)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # no, then just do default RGB
        
        # should apply blur?
        if blur:
            img = cv2.blur(img, blur_ksize) 

        # should plot?
        if plot: plt.imshow(img);plt.show()

        # return
        return img

    # force populate cache with all imgs from urls in df['product_picture']
    def populate_cache(self):
        '''
        download every picture referenced in self.df['product_picture'] into the cache dir

        urls are fetched in parallel through get_product_picture, missing urls are skipped
        and each distinct url is requested once; nothing is plotted and nothing is returned

        d=Dataset() # instantiate
        d.populate_cache() # force read on all product_picture urls
        '''
        # create the cache dir up front so the workers do not race to create it
        self.cachedir.mkdir(parents=True,exist_ok=True)

        # skip nans and repeated urls, every file only needs to be fetched once
        urls = self.df['product_picture'].dropna().unique()

        # batch download images, plot=False because the workers are only here to fill the cache
        Parallel(n_jobs=-1)(
            delayed(self.get_product_picture)(url=url,plot=False) for url in urls
        )

    # compute similarity metrics between some input image and some other image
    def img_similarity_pair(
        self,
        tgt_loc     : int,
        src_loc     : int            = None,
        src_img     : IMG            = None,
        plot_src    : bool           = True,
        plot_tgt    : bool           = True,
        grayscale   : bool           = False,
        blur        : bool           = False,
        blur_ksize  : Tuple[int,int] = (5,5),
        ) -> IMG_SIMILARITY_DICT :
        '''
        compare a source image against the picture of listing tgt_loc

        tgt_loc            : index label of the listing to compare against
        src_loc / src_img  : the source, one of the two needs to be given; src_loc wins when both are passed
        plot_src, plot_tgt : show the respective image while loading it
        grayscale          : to avoid finding same item in different colour
        blur, blur_ksize   : to avoid finding same item with minor difference in listing, larger ksize == stronger smoothing

        the target is resized to the source dimensions before measuring
        returns a dict with one entry per metric enabled through the do_* flags

        d.img_similarity_pair(src_loc=3,tgt_loc=0)
        '''
        print(f'tgt_loc = {tgt_loc}')

        # preprocessing applied identically to both pictures
        load_opts = dict(grayscale=grayscale, blur=blur, blur_ksize=blur_ksize)

        # source : load from the dataset unless the caller handed over an image
        if src_loc is not None:
            src_img = self.get_product_picture(loc=src_loc, plot=plot_src, **load_opts)

        # (width, height) the target gets resized to, as percent of the source size
        scale_pct   = 100
        src_img_dim = tuple(int(src_img.shape[axis] * scale_pct / 100) for axis in (1,0))

        # target : load, then bring to the source size
        tgt_img = self.get_product_picture(loc=tgt_loc, plot=plot_tgt, **load_opts)
        resized_tgt_img = cv2.resize(tgt_img, src_img_dim, interpolation = cv2.INTER_AREA)

        # (key, enabled, function) for every known metric, in output order
        metrics = (
            ('fsim', self.do_fsim, fsim),
            ('issm', self.do_issm, issm),
            ('psnr', self.do_psnr, psnr),
            ('rmse', self.do_rmse, rmse),
            ('sam',  self.do_sam,  sam),
            ('sre',  self.do_sre,  sre),
            ('ssim', self.do_ssim, ssim),
            ('uiq',  self.do_uiq,  uiq),
        )

        # only the enabled metrics are evaluated
        return {
            key : measure(src_img, resized_tgt_img)
            for key,enabled,measure in metrics
            if enabled
        }
    
    # compute similarity of some input image to all other images
    def img_similarity_all(
        self,
        loc         : int            = None,
        plot_src    : bool           = True,
        plot_tgt    : bool           = False,
        grayscale   : bool           = False,
        blur        : bool           = False,
        blur_ksize  : Tuple[int,int] = (5,5),
        do_parallel : bool           = True,
        ) -> IMG_SIMILARITY_DF :
        '''
        compare the picture of listing loc against the picture of every listing in self.df

        loc                : index label of the source listing
        plot_src, plot_tgt : show the source / each target image while loading
        grayscale          : to avoid finding same item in different colour
        blur, blur_ksize   : to avoid finding same item with minor difference in listing, larger ksize == stronger smoothing
        do_parallel        : spread the comparisons over joblib workers (n_jobs=-1) instead of a plain loop

        returns a dataframe indexed like self.df with one column per enabled metric
        '''
        # preprocessing shared by the source and all targets
        preprocessing = dict(grayscale=grayscale, blur=blur, blur_ksize=blur_ksize)

        # the source picture is loaded once and reused for every comparison
        src_img = self.get_product_picture(loc=loc, plot=plot_src, **preprocessing)

        # one comparison, this is the unit of work handed to joblib
        def compare_with(
            tgt_loc : int,
            ) -> IMG_SIMILARITY_DICT :
            return self.img_similarity_pair(
                tgt_loc  = tgt_loc,
                src_loc  = None,
                src_img  = src_img,
                plot_src = plot_src,
                plot_tgt = plot_tgt,
                **preprocessing,
            )

        tgt_locs = self.df.index
        if do_parallel:
            rows = Parallel(n_jobs=-1)(delayed(compare_with)(tgt_loc) for tgt_loc in tgt_locs)
        else:
            rows = [compare_with(tgt_loc) for tgt_loc in tgt_locs]

        # one row per target listing
        return pd.DataFrame(rows,index=tgt_locs)

    # show distribution of image similarity metrics for our dataset
    def img_similarity_plot(
        self,
        img_similarity_df : IMG_SIMILARITY_DF,
        ) -> NOTEBOOK_PLOTS :
        '''
        draw a histogram for every column of img_similarity_df

        img_similarity_df : frame of similarity metrics, presumably the output of img_similarity_all
        '''
        df = img_similarity_df.copy() # work on a copy so the caller's frame stays as it is
        df = df.replace([np.inf, -np.inf],0) # infinite values are swapped for 0 before plotting
        df.hist(figsize=(20,15))