In [75]:
import os
import numpy as np
import pandas as pd
import cv2
import re
import matplotlib.pyplot as plt
from cmcrameri import cm
from cmap import Colormap
from PIL import Image
from sklearn.metrics import mean_squared_error as MSE

In [28]:
# Input/output locations, relative to the notebook's working directory:
#   results_dir -- one sub-directory per run/seed; get_images() keeps those whose
#                  name starts with "pattern" or "structure".
#   dataset_dir -- one sub-directory per dataset code holding the
#                  "<code>_structure.png" and "<code>_+0+0+0.png" images.
results_dir, dataset_dir = "../results", "../datasets/FDP"

In [29]:
def ZMCC(A, B):
    """Zero-mean normalized cross-correlation of ``A`` and ``B``, rescaled to [0, 1].

    Both inputs are mean-centred and their correlation ``r`` (in [-1, 1]) is
    mapped to ``0.5 + r / 2``: 1.0 means perfectly correlated, 0.5
    uncorrelated, 0.0 perfectly anti-correlated.

    If either input is constant (zero variance), the correlation is undefined
    and the neutral value 0.5 is returned.
    """
    a_centred = A - np.mean(A)
    b_centred = B - np.mean(B)
    # The factor 2 folds the "/ 2" of the [0, 1] rescaling into the normalizer.
    scale = 2 * np.sqrt(np.sum(a_centred * a_centred) * np.sum(b_centred * b_centred))
    if scale == 0:
        return 0.5
    return 0.5 + np.sum(a_centred * b_centred) / scale

In [73]:
def _load_gray(path):
    """Read the image at ``path`` as a 2-D grayscale float array scaled to [0, 1].

    Raises FileNotFoundError when OpenCV cannot read the file. ``cv2.imread``
    returns ``None`` instead of raising, which would otherwise surface later
    as an opaque ``cvtColor`` error.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) / 255


def _save_colormapped(gray, cmap, path):
    """Apply ``cmap`` to a [0, 1] grayscale array and save the RGB result to ``path``."""
    rgb = cmap(gray)[:, :, 0:3]  # drop the alpha channel returned by the colormap
    # An (H, W, 3) uint8 array is inferred as RGB by Pillow; the explicit
    # ``mode=`` argument is deprecated in recent Pillow releases.
    Image.fromarray((rgb * 255).astype(np.uint8)).save(path)


def _seed_number(seed):
    """Return the trailing digits of a run directory name ("pattern_seed12" -> "12").

    Falls back to the last character when the name does not end in digits.
    This fallback matches the previous ``seed[-1]`` behaviour.
    """
    match = re.search(r"(\d+)$", seed)
    return match.group(1) if match else seed[-1]


def get_images(code, cmap, epoch="100", pattern=True):
    """Score and save the synthesized images of one dataset code across all runs.

    For every run directory in ``results_dir`` whose name starts with
    ``"pattern"`` (``pattern=True``) or ``"structure"`` (``pattern=False``),
    look at its ``<phase>_<epoch>`` sub-directories. "train" phases are
    skipped. For each synthesized image of ``code`` found, compare it with
    the ground truth from ``dataset_dir`` using ``ZMCC`` and ``MSE``.

    Every match is printed as ``code phase seed Z M``. Colormapped PNGs are
    written to ``codes/<code>/``: ``inpt.png``, ``real.png`` and one
    ``<seed>_<epoch>_<Z>.png`` per match.

    Parameters
    ----------
    code : int or str
        Dataset entry identifier; names the sub-directory of ``dataset_dir``.
    cmap : callable
        Matplotlib-style colormap mapping [0, 1] arrays to RGBA.
    epoch : str or int, default "100"
        Epoch tag to match in ``<phase>_<epoch>`` directory names.
    pattern : bool, default True
        True: input is the structure image, target is the "+0+0+0" pattern.
        False: the reverse direction.

    Returns
    -------
    list of float
        The ZMCC score of every match, in the order they were found.
        Returns an empty list if the input image or any match is missing.
    """
    code = str(code)
    epoch = str(epoch)  # directory names are strings; accept epoch=100 too
    prefix = "pattern" if pattern else "structure"

    # Input and ground-truth paths depend only on code and direction, not on the
    # run being inspected, so resolve them once.
    structure_path = os.path.join(dataset_dir, code, code + "_structure.png")
    pattern_path = os.path.join(dataset_dir, code, code + "_+0+0+0.png")
    if pattern:
        input_path, real_path = structure_path, pattern_path
        synth_name = code + "_structure_synthesized_image.png"
    else:
        input_path, real_path = pattern_path, structure_path
        synth_name = code + "_+0+0+0_synthesized_image.png"

    founds = []
    if not os.path.exists(input_path):
        return founds

    out_dir = os.path.join("codes", code)
    real = None  # loaded lazily on the first match

    # Sorted listings keep the printed results reproducible across filesystems.
    for seed in sorted(os.listdir(results_dir)):
        seed_dir = os.path.join(results_dir, seed)
        if not seed.startswith(prefix) or not os.path.isdir(seed_dir):
            continue
        n = _seed_number(seed)

        for output in sorted(os.listdir(seed_dir)):
            parts = output.split("_")
            if len(parts) != 2:
                continue  # not a "<phase>_<epoch>" run directory
            phase, e = parts
            if e != epoch or phase == "train":
                continue

            synth_path = os.path.join(seed_dir, output, "images", synth_name)
            if not os.path.exists(synth_path):
                continue

            if real is None:
                # First match: load the shared images and write their previews once.
                real = _load_gray(real_path)
                os.makedirs(out_dir, exist_ok=True)
                _save_colormapped(_load_gray(input_path), cmap, os.path.join(out_dir, "inpt.png"))
                _save_colormapped(real, cmap, os.path.join(out_dir, "real.png"))

            fake = _load_gray(synth_path)
            Z = ZMCC(fake, real)
            M = MSE(fake, real)
            founds.append(Z)
            print(code, phase, n, Z, M)
            _save_colormapped(fake, cmap, os.path.join(out_dir, f"{n}_{e}_{Z:.5f}.png"))

    return founds


In [79]:
# Colormap applied to every saved image. It is converted to a Matplotlib
# colormap so get_images() can call it directly on [0, 1] grayscale arrays.
cmap = Colormap("colorcet:cet_r3_r").to_mpl()

# Dataset codes to render. To survey the dataset instead, replace this list
# with e.g. a random sample of os.listdir(dataset_dir).
CODES = [153466]

for code in CODES:
    get_images(code, cmap)

153466 val 0 0.9293946154151358 0.03283850693693531
153466 test 2 0.8703395802662852 0.044170561089785176
153466 val 9 0.7288837570611932 0.07021381030583188
