In [None]:
import os
import sys
import typing as t

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# scikit-image 0.19 renamed greycomatrix -> graycomatrix and deprecated (later
# removed) the old spelling; import whichever exists under the name used below.
try:
    from skimage.feature import graycomatrix as greycomatrix
except ImportError:  # scikit-image < 0.19
    from skimage.feature import greycomatrix

# The plotting cell creates one figure per (distance, angle) pair; silence
# matplotlib's "too many open figures" warning.
plt.rcParams.update({'figure.max_open_warning': 0})

# Make the project's `utils` package importable (assumes the notebook is run
# from a directory two levels below it — TODO confirm).
sys.path.append('../..')
from utils.db_helper import get_image_data

In [None]:
# Root folder holding one sub-folder per image dataset. Override it with the
# COMATRIX_DATABASE_ROOT environment variable instead of editing this cell.
DATABASE_ROOT = os.environ.get("COMATRIX_DATABASE_ROOT", "C:/database")
SOURCE_DIR_V1 = f"{DATABASE_ROOT}/StyleGanv1"
SOURCE_DIR_V2 = f"{DATABASE_ROOT}/StyleGanv2"
SOURCE_DIR_FFHQ = f"{DATABASE_ROOT}/FFHQ"

In [None]:
# Passed to get_image_data(grayscale=...) and used to pick plot labels / output folder.
GRAYSCALE = False
# Output folder for the saved figures: ../../../exp/Comatrix
base_save_path = os.path.join(*([".."] * 3), "exp", "Comatrix")
# Pixel-pair offsets 1..10 evaluated for every co-occurrence matrix.
DISTANCES = list(range(1, 11))
# Pixel-pair directions in radians (0, 45, 90 and 135 degrees); the first entry
# is kept as the int 0 because the value is embedded in saved filenames.
ANGLES = [0, np.pi / 4, np.pi / 2, 3 * np.pi / 4]

In [None]:
def comatrix_from_image(np_img: np.ndarray, distances: t.List[int], angles: t.List[float],
                        levels: int = 256) -> np.ndarray:
    """Compute grey-level co-occurrence matrices (GLCMs) of a 2-D or 3-D image.

    Args:
        np_img: integer image, either ``(H, W)`` or ``(H, W, C)``. Pixel values
            must lie in ``[0, levels)``.
        distances: pixel-pair offsets passed to ``greycomatrix``.
        angles: pixel-pair angles in radians passed to ``greycomatrix``.
        levels: number of grey levels (256 for 8-bit images). Passed explicitly
            because newer scikit-image versions require it for non-uint8 input.

    Returns:
        For a 2-D image, ``greycomatrix``'s output of shape
        ``(levels, levels, len(distances), len(angles))``. For a 3-D image, a
        float64 array of shape ``(levels, levels, C, len(distances), len(angles))``
        holding one GLCM per channel.

    Raises:
        ValueError: if the image is neither 2-D nor 3-D.
    """
    if np_img.ndim == 2:
        return greycomatrix(np_img, distances, angles, levels=levels)
    if np_img.ndim != 3:
        raise ValueError('Bad shape of the image')

    n_channels = np_img.shape[-1]
    # The GLCM's first two axes are (levels, levels), independent of the image
    # height/width, so the buffer must not be sized from np_img.shape.
    np_comatrix = np.empty((levels, levels, n_channels, len(distances), len(angles)))
    for channel in range(n_channels):
        np_comatrix[:, :, channel] = greycomatrix(np_img[:, :, channel], distances, angles,
                                                  levels=levels)
    return np_comatrix

In [None]:
def comatrix_from_dataset(src_path: str, distances: t.List[int], angles: t.List[float],
                          grayscale=False) -> np.ndarray:
    """Average the per-image co-occurrence matrices over every image under ``src_path``.

    Args:
        src_path: dataset location handed to ``get_image_data``.
        distances: pixel-pair offsets forwarded to ``comatrix_from_image``.
        angles: pixel-pair angles (radians) forwarded to ``comatrix_from_image``.
        grayscale: forwarded to ``get_image_data``; presumably loads
            single-channel images when True (TODO confirm in utils.db_helper).

    Returns:
        float64 array with the shape of a single image's co-occurrence matrix,
        holding the mean over all images of the dataset.

    Raises:
        ValueError: if the dataset yields no images, or if the running sum
            approaches the float64 maximum.
    """
    print("Loading dataset...")
    # Assumed to yield one numpy image array per file — TODO confirm.
    dataset_gen = get_image_data(src_path, type='int', grayscale=grayscale)
    final_type = np.float64
    overflow_limit = 0.95 * np.finfo(final_type).max
    print("Processing database...")

    # The accumulator is sized from the first co-occurrence matrix rather than
    # the first image: the matrix shape depends on grey levels, not image size,
    # and peeking at the generator with next() would drop that image.
    np_comatrix_dataset = None
    number_of_images = 0
    for np_image in dataset_gen:
        np_tmp_comatrix = comatrix_from_image(np_image, distances, angles)
        if np_comatrix_dataset is None:
            np_comatrix_dataset = np.zeros(np_tmp_comatrix.shape, dtype=final_type)
        np_comatrix_dataset += np_tmp_comatrix
        number_of_images += 1
        if number_of_images % 100 == 0:
            print(f"Image number: {number_of_images}")
        if np_comatrix_dataset.max() > overflow_limit:
            raise ValueError("Dataset too large, datatype overflow")

    if number_of_images == 0:
        raise ValueError(f"No images found in path: {src_path}")
    print(f"{number_of_images} processed from path: {src_path}")
    return np_comatrix_dataset / number_of_images

In [None]:
# Mean co-occurrence matrix of each dataset, computed with the same distances,
# angles and colour mode (one full pass over every dataset — slow).
np_comatrix_ffhq = comatrix_from_dataset(
    SOURCE_DIR_FFHQ, distances=DISTANCES, angles=ANGLES, grayscale=GRAYSCALE
)
np_comatrix_style1 = comatrix_from_dataset(
    SOURCE_DIR_V1, distances=DISTANCES, angles=ANGLES, grayscale=GRAYSCALE
)
np_comatrix_style2 = comatrix_from_dataset(
    SOURCE_DIR_V2, distances=DISTANCES, angles=ANGLES, grayscale=GRAYSCALE
)

In [None]:
# Label -> averaged co-occurrence matrix; each entry becomes one row of plots.
datasets = {'ffhq': np_comatrix_ffhq,
            'StyleGanv1': np_comatrix_style1,
            'StyleGanv2': np_comatrix_style2}

# Channel index (stored as a string key) -> label used in subplot titles.
color_channels = {"0": 'Gray'} if GRAYSCALE else {"0": 'R', "1": 'G', '2': 'B'}

In [None]:
# True keeps every figure open so Jupyter renders it inline (memory-heavy);
# False closes each figure once its PNG is written.
SHOW_FIGURES = False

directory = 'Grayscale' if GRAYSCALE else 'RGB'
save_dir = os.path.join(base_save_path, directory)
# savefig does not create missing folders.
os.makedirs(save_dir, exist_ok=True)

n_channels = len(color_channels)
for i, distance in enumerate(DISTANCES):
    for j, angle in enumerate(ANGLES):
        # One row per dataset, one column per colour channel; squeeze=False keeps
        # `axs` 2-D even for the single grayscale column.
        fig, axs = plt.subplots(nrows=len(datasets), ncols=n_channels,
                                figsize=(15, 15), squeeze=False)
        fig.suptitle(f'Comatrixes for distance: {distance} and angle: {angle} rads')
        for row, (name, dataset) in enumerate(datasets.items()):
            for col in range(n_channels):
                if n_channels > 1:
                    # RGB results are indexed [:, :, channel, distance_idx, angle_idx].
                    matrix = dataset[:, :, col, i, j]
                    title = f"{name}_{color_channels[str(col)]}"
                else:
                    # Grayscale results are indexed [:, :, distance_idx, angle_idx].
                    matrix = dataset[:, :, i, j]
                    title = f"{name}"
                axs[row, col].imshow(matrix)
                axs[row, col].set_title(title)
        # Filenames embed the raw distance and angle values.
        fig.savefig(os.path.join(save_dir, f"{distance}_{angle}.png"))
        if not SHOW_FIGURES:
            # Otherwise len(DISTANCES) * len(ANGLES) large figures stay in memory.
            plt.close(fig)