In [None]:
import argparse
import os
import pickle

import cv2
import imagehash
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from torch import load
from torch.cuda import is_available
from tqdm import tqdm

from src.resnet import ResNet
from src.tools import art_cropper
from src.transformations import final_data_transforms
from scipy.ndimage import rotate
from skimage.transform import rotate
from skimage.feature import local_binary_pattern
from skimage import data
from skimage.color import label2rgb


In [None]:
# Scratch check on one known card: crop its artwork and reorder the channels
# for OpenCV (presumably art_cropper returns an RGB PIL image — confirm).
# NOTE(review): `img` is reassigned by the reference-loading loop further down.
img = np.array(art_cropper(Image.open(
    "output/Divine-Arsenal-AA-ZEUS---Sky-Thunder-3119-90448279/90448279.jpg"
)))
img = img[:, :, ::-1].copy()

In [None]:
# CLAHE equalizer used by ReferenceImage and histogram_adjust:
# clip limit 2.0 over an 8 x 8 grid of tiles.
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

In [None]:
class ReferenceImage:
    """
    Container for a card image and the associated recognition data.

    The artwork window is cropped out of ``original_image``, contrast-adjusted
    with CLAHE on the lightness channel, and perceptually hashed.

    Parameters
    ----------
    name : str
        Identifier for the card.
    original_image : ndarray or None
        Full card image in BGR channel order, or None (e.g. ``cv2.imread`` could
        not read the file). When None, no adjustment or hashing is done and
        ``phash`` keeps the value passed in.
    clahe : object with an ``apply`` method
        CLAHE equalizer, e.g. the result of ``cv2.createCLAHE``.
    phash : optional
        Precomputed hash; replaced by a freshly computed one when an image is
        given.
    """

    # Row/column window holding the card artwork (204 x 204 pixels).
    ART_ROWS = slice(71, 275)
    ART_COLS = slice(32, 236)

    def __init__(self, name, original_image, clahe, phash=None):
        self.name = name
        # Test for None *before* slicing: slicing None raises TypeError, which
        # used to make the `is not None` guard below unreachable.
        if original_image is None:
            self.original = None
        else:
            self.original = original_image[self.ART_ROWS, self.ART_COLS]
        self.clahe = clahe
        self.adjusted = None
        self.phash = phash

        if self.original is not None:
            self.histogram_adjust()
            self.calculate_phash()

    def calculate_phash(self):
        """
        Calculate the perceptual hash (``hash_size=32``) of the adjusted image.

        The adjusted BGR image is converted to RGB before hashing.
        """
        # NOTE(review): if `self.adjusted` is uint8, `255 * ...` wraps modulo 256
        # instead of scaling to 0-255. The query-side hashing in this notebook
        # uses the same expression, so both must be changed together to keep
        # the distances comparable.
        self.phash = imagehash.phash(
            Image.fromarray(np.uint8(255 * cv2.cvtColor(
                self.adjusted, cv2.COLOR_BGR2RGB))),
            hash_size=32)

    def histogram_adjust(self):
        """
        Adjust contrast by CLAHE applied to the lightness channel only.

        ``self.original`` (BGR) is converted to LAB, its L channel is equalized
        with ``self.clahe``, the two colour channels are left untouched, and the
        result is stored back in BGR order in ``self.adjusted``.
        """
        lab = cv2.cvtColor(self.original, cv2.COLOR_BGR2LAB)
        lightness, redness, yellowness = cv2.split(lab)
        corrected_lightness = self.clahe.apply(lightness)
        limg = cv2.merge((corrected_lightness, redness, yellowness))
        self.adjusted = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)

In [None]:
def histogram_adjust(image, equalizer=None):
    """
    Adjust contrast with CLAHE (contrast limited adaptive histogram equalization).

    Only the lightness channel of the LAB representation is equalized; the two
    colour channels pass through unchanged.

    Parameters
    ----------
    image : ndarray
        Image in BGR channel order (OpenCV convention).
    equalizer : object with an ``apply`` method, optional
        CLAHE instance to use. Defaults to the notebook-level ``clahe``, so
        existing one-argument calls behave as before.

    Returns
    -------
    ndarray
        Contrast-adjusted image in BGR order.
    """
    if equalizer is None:
        equalizer = clahe  # notebook-level CLAHE created earlier
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    lightness, redness, yellowness = cv2.split(lab)
    corrected_lightness = equalizer.apply(lightness)
    limg = cv2.merge((corrected_lightness, redness, yellowness))
    return cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)

In [None]:
# Query image: load it, resize to the 204 x 204 artwork size that
# ReferenceImage crops out ([71:275, 32:236]), and apply the same CLAHE
# adjustment as the references.
maxsize = 1000  # retained for compatibility; no longer used below

target = cv2.imread("output.png")
if target is None:
    # cv2.imread reports a missing or unreadable file by returning None.
    raise FileNotFoundError("Could not read query image 'output.png'")
# The old `min(shape) > maxsize` downscale ran *after* this fixed-size resize,
# so it could never trigger (and it assigned to `img`, not `target`); removed.
target = cv2.resize(target, (204, 204), interpolation=cv2.INTER_LINEAR)

test_image = histogram_adjust(target)

In [None]:
# Load every file under ./output/ as a reference card, named after its folder.
# Paths are collected first so the progress-bar total reflects what is on disk
# instead of a hard-coded count. os.walk order is preserved on purpose: later
# cells index reference_images by position.
reference_paths = [
    (subdir, os.path.join(subdir, file))
    for subdir, _dirs, files in os.walk("./output/")
    for file in files
]

reference_images = []
skipped_paths = []
for subdir, abs_file_path in tqdm(reference_paths, desc="Loading reference images", colour='cyan'):
    img = cv2.imread(abs_file_path)
    if img is None:
        # Not a readable image (e.g. a stray non-image file): skip it instead
        # of failing inside ReferenceImage.
        skipped_paths.append(abs_file_path)
        continue
    img_name = subdir
    reference_images.append(ReferenceImage(img_name, img, clahe))

if skipped_paths:
    print(f"Skipped {len(skipped_paths)} unreadable file(s)")

In [None]:
def phash_diff(target, references):
    """
    Distance between one perceptual hash and each reference image's hash.

    Parameters
    ----------
    target
        Hash to compare; ``target - ref.phash`` must yield a number.
    references : sequence
        Objects exposing a ``phash`` attribute (e.g. ``ReferenceImage``).

    Returns
    -------
    ndarray of float, shape (len(references),)
        ``target - ref.phash`` for every reference, in order.
    """
    return np.fromiter(
        (target - ref.phash for ref in references),
        dtype=float,
        count=len(references),
    )

In [None]:
# Hash the query at each right-angle rotation and compare it with every
# reference. For each rotation the separation score is how many standard
# deviations the best (smallest) distance lies below the mean of the others.
rotations = np.array([0., 90., 180., 270.])
RECOGNITION_THRESHOLD = 4  # separation (in std devs) required to call a match

if np.any(np.mod(rotations, 90) != 0):
    raise ValueError("rotations must be multiples of 90 degrees")
if not reference_images:
    raise ValueError("reference_images is empty; load the references first")


def _phash_bgr(image):
    """Perceptual hash of a BGR image, encoded like ReferenceImage.calculate_phash."""
    # NOTE(review): for a uint8 image `255 * ...` wraps modulo 256 instead of
    # scaling. The reference hashes use the same expression, so it is kept
    # here to keep distances comparable; fix both sides together.
    return imagehash.phash(
        Image.fromarray(np.uint8(255 * cv2.cvtColor(image, cv2.COLOR_BGR2RGB))),
        hash_size=32)


d_0_dist = np.zeros(len(rotations))
d_0 = np.zeros((len(reference_images), len(rotations)))

for j, rot in enumerate(rotations):
    # np.rot90 turns counter-clockwise like skimage's rotate but keeps dtype and
    # pixel values exactly. skimage's rotate returns float64, which
    # cv2.cvtColor rejects, so every non-zero rotation used to raise.
    quarter_turns = int(round(rot / 90.0)) % 4
    rotated = np.ascontiguousarray(np.rot90(test_image, k=quarter_turns))
    phash_im = _phash_bgr(rotated)

    d_0[:, j] = phash_diff(phash_im, reference_images)
    distances = d_0[:, j]
    best = np.amin(distances)
    d_0_ = distances[distances > best]
    if d_0_.size == 0:
        # All references are equally far away: nothing stands out.
        d_0_dist[j] = 0.0
    else:
        d_0_ave = np.average(d_0_)
        d_0_std = np.std(d_0_)
        # Zero spread among the others: the best match is infinitely separated.
        d_0_dist[j] = np.inf if d_0_std == 0 else (d_0_ave - best) / d_0_std
    card_name = reference_images[np.argmin(distances)].name
    # Previously hard-wired to True (the threshold test was commented out).
    is_recognized = bool(d_0_dist[j] > RECOGNITION_THRESHOLD)
    recognition_score = d_0_dist[j] / RECOGNITION_THRESHOLD
    print(recognition_score, card_name)

In [None]:
# Phash distance matrix: one row per reference image, one column per rotation.
d_0

In [None]:
# The 100 smallest distances in column 1, i.e. the 90-degree rotation.
sorted(d_0[:, 1])[:100]

In [None]:
# NOTE: these hold the values from the last loop iteration (the 270-degree
# rotation), not necessarily from the best rotation.
(is_recognized, recognition_score, card_name)

In [None]:
# Print the position(s) of a known card in reference_images (used to pick the
# hard-coded indices below; positions depend on os.walk order).
for j, ref in enumerate(reference_images):
    if ref.name != "./output/Mermail-Abysslung-0-95466842":
        continue
    print(j)

In [None]:
# Perceptual hash of the unrotated query, encoded like the reference hashes
# (BGR -> RGB, hash_size=32).
# NOTE(review): if test_image is uint8, the `* 255` wraps modulo 256 rather than
# scaling; the reference hashes do the same, so distances stay comparable.
phash_im = imagehash.phash(
    Image.fromarray(np.uint8(cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB) * 255)),
    hash_size=32,
)

In [None]:
# Distance to reference 8042 (presumably the Mermail-Abysslung index printed
# above; the index depends on os.walk order — verify).
phash_im - reference_images[8042].phash

In [None]:

# Distance to reference 450 for comparison (hard-coded index; depends on os.walk order).
phash_im - reference_images[450].phash

In [None]:
# Show the adjusted query inline. cv2.imshow opens a native window and
# cv2.waitKey(0) blocks the kernel until a key press, which hangs or fails on
# headless/remote Jupyter servers; matplotlib renders into the notebook
# instead (it expects RGB, hence the conversion from OpenCV's BGR).
fig_query, ax_query = plt.subplots()
ax_query.imshow(cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB))
ax_query.set_title('Query (CLAHE-adjusted)')
ax_query.axis('off')
plt.show()

In [None]:
# Show reference 8042 (hard-coded index; depends on os.walk order) inline.
# matplotlib instead of cv2.imshow: no blocking native window and it works on
# headless kernels; convert OpenCV's BGR to RGB for display.
fig_ref, ax_ref = plt.subplots()
ax_ref.imshow(cv2.cvtColor(reference_images[8042].adjusted, cv2.COLOR_BGR2RGB))
ax_ref.set_title(reference_images[8042].name)
ax_ref.axis('off')
plt.show()

In [None]:
# LBP settings for the texture overview: n_points neighbours sampled on a
# circle of `radius` pixels.
radius = 3
n_points = radius * 8

In [None]:
def overlay_labels(image, lbp, labels):
    """
    Tint the pixels whose LBP code is in ``labels`` on top of ``image``.

    Parameters
    ----------
    image : ndarray
        Background image.
    lbp : ndarray
        LBP codes, same shape as ``image``.
    labels : iterable of int
        LBP codes to highlight.

    Returns
    -------
    ndarray
        ``label2rgb`` rendering of the mask over ``image`` (alpha 0.5).
    """
    # np.isin builds the mask in one vectorized pass and, unlike reducing a
    # list of per-label masks, still returns an all-False array (not a scalar)
    # when `labels` is empty.
    mask = np.isin(lbp, list(labels))
    return label2rgb(mask, image=image, bg_label=0, alpha=0.5)

In [None]:
def highlight_bars(bars, indexes):
    """
    Colour selected histogram bars red.

    Parameters
    ----------
    bars : sequence
        Bar artists, e.g. the third value returned by ``ax.hist``.
    indexes : iterable of int
        Positions in ``bars`` to recolour.
    """
    selected = (bars[k] for k in indexes)
    for bar in selected:
        bar.set_facecolor('r')

In [None]:
# Grayscale query and its LBP codes for the texture overview below.
# method='uniform' gives rotation-invariant uniform codes in 0..n_points+1, which
# is what that figure's edge/flat/corner label ranges and `n_points + 2` x-limit
# assume; the default method yields codes up to 2**n_points - 1 (millions of
# histogram bins for n_points=24).
image = cv2.cvtColor(test_image, cv2.COLOR_BGR2GRAY)
lbp = local_binary_pattern(image, n_points, radius, method='uniform')

In [None]:
def hist(ax, lbp):
    """
    Draw a normalized histogram of LBP codes on ``ax``.

    One unit-wide bin per integer code from 0 to ``lbp.max()``.

    Returns
    -------
    Whatever ``ax.hist`` returns (counts, bin edges, bar artists for matplotlib).
    """
    top_code = lbp.max()
    bin_count = int(top_code + 1)
    flat_codes = lbp.ravel()
    return ax.hist(flat_codes, density=True, bins=bin_count,
                   range=(0, bin_count), facecolor='0.5')

In [None]:
# Texture overview (adapted from the scikit-image LBP example): for each pattern
# family, tint the matching pixels (top row) and highlight their bins in the LBP
# histogram (bottom row). The label ranges assume uniform LBP codes 0..n_points+1.
fig, (ax_img, ax_hist) = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))
plt.gray()

titles = ('edge', 'flat', 'corner')
w = radius - 1  # half-width of each label window
edge_labels = range(n_points // 2 - w, n_points // 2 + w + 1)
flat_labels = list(range(0, w + 1)) + list(range(n_points - w, n_points + 2))
i_14 = n_points // 4  # 1/4th of the histogram
i_34 = 3 * (n_points // 4)  # 3/4th of the histogram
corner_labels = list(range(i_14 - w, i_14 + w + 1)) + list(
    range(i_34 - w, i_34 + w + 1)
)

label_sets = (edge_labels, flat_labels, corner_labels)

for ax, labels in zip(ax_img, label_sets):
    ax.imshow(overlay_labels(image, lbp, labels))

for ax, labels, name in zip(ax_hist, label_sets, titles):
    counts, _, bars = hist(ax, lbp)
    highlight_bars(bars, labels)
    # Scale the y axis without the last bin (presumably the catch-all
    # non-uniform code), which would otherwise dwarf the rest.
    ax.set_ylim(top=np.max(counts[:-1]))
    ax.set_xlim(right=n_points + 2)
    ax.set_title(name)

# hist() uses density=True with unit-wide bins, so bar heights are fractions.
ax_hist[0].set_ylabel('Fraction of pixels')
for ax in ax_img:
    ax.axis('off')

plt.show()

In [None]:
# Smaller LBP neighbourhood for the matching experiment.
# NOTE: this rebinds radius/n_points, so re-running the earlier LBP cells after
# this one gives different results.
radius = 2
n_points = radius * 8

In [None]:
def kullback_leibler_divergence(p, q):
    """
    Base-2 Kullback-Leibler divergence sum(p * log2(p / q)) between two histograms.

    Bins where either ``p`` or ``q`` is zero are skipped entirely, so the
    result stays finite even where the true divergence would be infinite.
    """
    p_arr, q_arr = np.asarray(p), np.asarray(q)
    support = (p_arr != 0) & (q_arr != 0)
    p_kept, q_kept = p_arr[support], q_arr[support]
    return np.sum(p_kept * np.log2(p_kept / q_kept))

In [None]:
def match(refs, img, method='default'):
    """
    Return the reference whose LBP histogram is closest to that of ``img``.

    Closeness is ``kullback_leibler_divergence`` between normalized LBP
    histograms. The LBP parameters come from the notebook-level ``n_points``
    and ``radius``.

    Parameters
    ----------
    refs : dict[str, ndarray]
        Reference name -> precomputed LBP code image. These must be computed
        with the same ``n_points``, ``radius`` and ``method``.
    img : ndarray
        Single-channel image to classify.
    method : str, optional
        Passed to ``local_binary_pattern``. The default ``'default'`` keeps the
        previous behaviour.

    Returns
    -------
    str or None
        Best-matching key of ``refs``. None only if ``refs`` is empty or no
        score compares below infinity (e.g. all NaN).
    """
    # Start from +inf: the old cut-off of 10 made match() return None whenever
    # every divergence happened to be >= 10.
    best_score = np.inf
    best_name = None
    lbp = local_binary_pattern(img, n_points, radius, method=method)
    n_bins = int(lbp.max() + 1)
    # Reference histograms reuse the query's bin layout so they are comparable.
    query_hist, _ = np.histogram(lbp, density=True, bins=n_bins, range=(0, n_bins))
    for name, ref in refs.items():
        ref_hist, _ = np.histogram(ref, density=True, bins=n_bins, range=(0, n_bins))
        score = kullback_leibler_divergence(query_hist, ref_hist)
        if score < best_score:  # strict: ties keep the first reference seen
            best_score = score
            best_name = name
    return best_name

In [None]:
# Single-channel inputs for the LBP matching experiment (names borrowed from the
# scikit-image texture example). Channel 0 is blue, since these images are BGR.
# The reference indices are hard-coded and depend on os.walk order.
brick, grass, gravel = (
    test_image[:, :, 0],                       # query card
    reference_images[450].adjusted[:, :, 0],   # reference candidate
    reference_images[8042].adjusted[:, :, 0],  # reference candidate
)

In [None]:
# Show the single-channel `gravel` image inline (matplotlib instead of a
# blocking cv2.imshow window). vmin/vmax pin the grey scale to 0-255
# (assuming uint8 data) instead of stretching it to the data range.
fig_gravel, ax_gravel = plt.subplots()
ax_gravel.imshow(gravel, cmap='gray', vmin=0, vmax=255)
ax_gravel.set_title('gravel')
ax_gravel.axis('off')
plt.show()

In [None]:
# Reference LBP code images keyed by name. The query ('brick') is left out,
# presumably so it cannot trivially match itself.
refs = {
    texture: local_binary_pattern(channel, n_points, radius)
    for texture, channel in (('grass', grass), ('gravel', gravel))
}

In [None]:
# Best-matching reference name for the unrotated query channel.
match(refs, brick)

In [None]:
# Each case: (label, image, rotation angle in degrees). The printed label is
# built from the angle actually used; the old hard-coded text claimed
# 30deg/145deg while the code rotated by 0deg.
rotation_cases = [
    ('brick', brick, 0),
    ('brick', brick, 70),
    ('grass', grass, 0),
]

print('Rotated images matched against references using LBP:')
for texture, texture_img, angle in rotation_cases:
    print(
        f'original: {texture}, rotated: {angle}deg, match result: ',
        match(refs, rotate(texture_img, angle=angle, resize=False)),
    )

In [None]:
# Query channel and two reference channels (top row) with their LBP histograms
# (bottom row), computed with the current n_points/radius.
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))
plt.gray()

# `refs` need not contain 'brick' (the query is excluded from the references),
# in which case refs['brick'] would raise KeyError; compute it here instead.
brick_lbp = refs['brick'] if 'brick' in refs else local_binary_pattern(brick, n_points, radius)

ax1.imshow(brick)
ax1.axis('off')
ax1.set_title('brick (query)')
hist(ax4, brick_lbp)
# hist() uses density=True with unit-wide bins, so bar heights are fractions.
ax4.set_ylabel('Fraction of pixels')

ax2.imshow(grass)
ax2.axis('off')
ax2.set_title('grass')
hist(ax5, refs['grass'])
# These LBPs use the default (not 'uniform') method.
ax5.set_xlabel('LBP values')

ax3.imshow(gravel)
ax3.axis('off')
ax3.set_title('gravel')
hist(ax6, refs['gravel'])

plt.show()

In [None]:
# Raw query channel (channel 0 of test_image) used in the LBP experiments.
brick