In [None]:
import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import colorsys

In [None]:
# Where the processed grain pickles live, which processing run to read, and which samples to load.
# NOTE(review): this is an absolute, machine-specific path; it must be adjusted to run elsewhere.
DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/processed_grains/")
DATE = "2024-03-14"
SAMPLE_TYPES = ["ON_REL", "ON_SC", "OT1_REL", "OT1_SC", "OT2_REL", "OT2_SC"]

# Maps each sample type to the object unpickled from that sample's file.
data = {}

for sample_type in SAMPLE_TYPES:
    pickle_path = (
        DATA_DIR
        / f"unbound_{sample_type}"
        / f"date_{DATE}"
        / f"{sample_type}_turn_in_distance_grain_dict_defect_degrees_80_defect_nm_5.0.pkl"
    )
    # pickle.load can execute arbitrary code, so only read files from a trusted source.
    with open(pickle_path, "rb") as pickle_file:
        loaded_grains = pickle.load(pickle_file)
    print(f"Loaded {len(loaded_grains)} grains for {sample_type}")

    data[sample_type] = loaded_grains

# One bar chart per sample type: how many grains carry each simple-method tag.
tag_labels = ["churro", "pasty", "dorito", "teardrop", "open"]

for sample_type, grains in data.items():
    print(f"Sample type: {sample_type}")

    simple_tags = [grain["simple_tag"] for grain in grains.values()]

    # Tags that are not listed in tag_labels are simply not drawn.
    tag_counts = [sum(1 for tag in simple_tags if tag == label) for label in tag_labels]

    plt.bar(tag_labels, tag_counts, color="orange")
    plt.title(f"Simple tag counts for {sample_type}, degrees per nm: 12, sample rate: 0.5 nm")
    plt.show()

# One bar chart per sample type: how many grains carry each complex-method tag.
# NOTE(review): the title reports 80 degrees in 5.0 nm to agree with the
# "defect_degrees_80_defect_nm_5.0" pickle loaded for every sample type; confirm that this
# is the threshold the complex tags were actually computed with.
tag_labels = ["churro", "pasty", "dorito", "teardrop", "open"]

for sample_type, grains in data.items():
    print(f"Sample type: {sample_type}")

    complex_tags = [grain["complex_tag"] for grain in grains.values()]

    # Tags that are not listed in tag_labels are simply not drawn.
    tag_counts = [complex_tags.count(label) for label in tag_labels]

    plt.bar(tag_labels, tag_counts, color="blue")
    plt.title(f"Complex tag counts for {sample_type}, threshold: 80 degrees in 5.0 nm")
    plt.show()

# Compare two shape descriptors across sample types: one KDE curve per sample type on each
# panel (left: feret ratio, right: perimeter / area ratio).
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
for sample_type in SAMPLE_TYPES:
    grains = data[sample_type]
    feret_ratios = [grain_data["feret_ratio"] for grain_data in grains.values()]
    perimeter_area_ratios = [grain_data["perimeter_area_ratio"] for grain_data in grains.values()]

    # Add kde of feret ratios
    sns.kdeplot(feret_ratios, ax=ax[0], label=sample_type)
    # Add kde of perimeter / area ratios
    sns.kdeplot(perimeter_area_ratios, ax=ax[1], label=sample_type)
ax[0].set_title("Feret Ratio (Greater means more round)")
# NOTE(review): the title says "Area / Perimeter" while the plotted key is
# "perimeter_area_ratio"; confirm which way round the stored ratio is.
ax[1].set_title("Area / Perimeter Ratio (Greater means more round)")
# plt.legend() only annotates the current (last created) axes, which left the first panel
# without a legend; add one to each panel explicitly.
for panel in ax:
    panel.legend()
plt.show()

# Pool the defect distances of every dorito across all sample types, separately for grains
# tagged "dorito" by the simple method and by the complex method.
simple_dorito_distances = []
complex_dorito_distances = []
for grains in data.values():
    for grain in grains.values():
        # The two tags are independent: a grain may contribute to one list, both, or neither.
        is_simple_dorito = grain["simple_tag"] == "dorito"
        if is_simple_dorito:
            simple_dorito_distances.extend(grain["simple_defect_distances"])
        is_complex_dorito = grain["complex_tag"] == "dorito"
        if is_complex_dorito:
            complex_dorito_distances.extend(grain["complex_distances_between_defects"])


# Draw one KDE figure per tagging method for the pooled dorito defect distances.
# Both figures share the same x-range (0 to 45) so they can be compared by eye.
dorito_figures = (
    (simple_dorito_distances, "orange", "Simple method dorito distances between defects"),
    (complex_dorito_distances, "blue", "Complex method dorito distances between defects"),
)
for pooled_distances, curve_colour, figure_title in dorito_figures:
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    sns.kdeplot(pooled_distances, ax=ax, color=curve_colour)
    ax.set_title(figure_title)
    ax.set_xlabel("Distance (nm)")
    ax.set_xlim(0, 45)
    plt.show()

In [None]:
def _summarise_pasty_distances(grains):
    """Summarise the defect-to-defect distances of the complex-method pasties in ``grains``.

    Args:
        grains: mapping of grain key -> grain dict. Each grain dict is read for its
            "complex_tag" and "complex_distances_between_defects" entries.

    Returns:
        A 4-tuple of parallel lists with one entry per pasty:
        ``(distance_pairs, shorter, longer, differences)``, where ``distance_pairs`` holds the
        untouched per-grain distance sequences, ``shorter``/``longer`` their min/max and
        ``differences`` the absolute difference between the two distances.

    Raises:
        ValueError: if a pasty does not have exactly two distances. The difference is only
            meaningful for a pair; this is the same check (and message) that the later
            dorito/pasty pooling cell applies.
    """
    distance_pairs = []
    for grain in grains.values():
        if grain["complex_tag"] != "pasty":
            continue
        distances = grain["complex_distances_between_defects"]
        if len(distances) != 2:
            raise ValueError("Expected 2 distances for pasty")
        # Keep the per-grain structure rather than flattening.
        distance_pairs.append(distances)

    shorter = [min(pair) for pair in distance_pairs]
    longer = [max(pair) for pair in distance_pairs]
    differences = [np.abs(pair[0] - pair[1]) for pair in distance_pairs]
    return distance_pairs, shorter, longer, differences


# Pasty defect distances for the OT1_SC sample.
(
    complex_ot1_sc_pasty_distances,
    complex_ot1_sc_pasty_distances_shorter,
    complex_ot1_sc_pasty_distances_longer,
    complex_ot1_sc_pasty_distances_differences,
) = _summarise_pasty_distances(data["OT1_SC"])

# Pasty defect distances for the OT2_SC sample.
(
    complex_ot2_sc_pasty_distances,
    complex_ot2_sc_pasty_distances_shorter,
    complex_ot2_sc_pasty_distances_longer,
    complex_ot2_sc_pasty_distances_differences,
) = _summarise_pasty_distances(data["OT2_SC"])


# Overlay, on one set of axes, the KDEs of the shorter and the longer defect-to-defect
# distance of each pasty, for OT1_SC and OT2_SC. Legend labels carry the sample sizes.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
pasty_length_curves = [
    (
        complex_ot1_sc_pasty_distances_shorter,
        "orange",
        f"Shorter length OT1_SC n={len(complex_ot1_sc_pasty_distances_shorter)}",
    ),
    (
        complex_ot1_sc_pasty_distances_longer,
        "red",
        f"Longer length OT1_SC n={len(complex_ot1_sc_pasty_distances_longer)}",
    ),
    (
        complex_ot2_sc_pasty_distances_shorter,
        "aqua",
        f"Shorter length OT2_SC n={len(complex_ot2_sc_pasty_distances_shorter)}",
    ),
    (
        complex_ot2_sc_pasty_distances_longer,
        "blue",
        f"Longer length OT2_SC n={len(complex_ot2_sc_pasty_distances_longer)}",
    ),
]
for curve_values, curve_colour, curve_label in pasty_length_curves:
    sns.kdeplot(curve_values, ax=ax, color=curve_colour, label=curve_label, linestyle="-")
ax.set_title("Complex method OT1_SC & OT2_SC pasty distances between defects")
ax.set_xlabel("Distance (nm)")
plt.legend()
plt.show()

# Overlay the KDEs of the per-pasty absolute difference between the two defect distances,
# one curve per sample (OT1_SC and OT2_SC). Legend labels carry the sample sizes.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
pasty_difference_curves = [
    (
        complex_ot1_sc_pasty_distances_differences,
        "orange",
        f"OT1_SC n={len(complex_ot1_sc_pasty_distances_differences)}",
    ),
    (
        complex_ot2_sc_pasty_distances_differences,
        "blue",
        f"OT2_SC n={len(complex_ot2_sc_pasty_distances_differences)}",
    ),
]
for curve_values, curve_colour, curve_label in pasty_difference_curves:
    sns.kdeplot(curve_values, ax=ax, color=curve_colour, label=curve_label)
ax.set_title("Complex method OT1_SC and OT2_SC pasty differences between defects")
ax.set_xlabel("Difference (nm)")
plt.legend()
plt.show()

In [None]:
def mycolor(h, s=1.0, v=1.0):
    """Convert an HSV colour to an ``(r, g, b)`` tuple using ``colorsys.hsv_to_rgb``.

    Args:
        h: hue component.
        s: saturation component (defaults to 1.0).
        v: value component (defaults to 1.0).

    Returns:
        The ``(r, g, b)`` tuple produced by ``colorsys.hsv_to_rgb(h, s, v)``.
    """
    rgb = colorsys.hsv_to_rgb(h, s, v)
    return rgb[0], rgb[1], rgb[2]


# Pool defect distances, across every sample type, for grains that the complex method
# tagged as "dorito" or "pasty".
#
# NOTE(review): the "all" pools (dorito_distances / pasty_distances) are filled from
# "simple_defect_distances" even though grains are selected by "complex_tag"; confirm that
# mixing the two methods here is intended.

dorito_distances = []
dorito_distances_min = []
dorito_distances_middle = []
dorito_distances_max = []
dorito_distances_min_middle_sum = []

pasty_distances = []
pasty_distances_min = []
pasty_distances_max = []
# Distinct names for the per-sample dict and the per-grain dict: reusing one name for both
# loop variables made the outer variable unusable inside the loop body.
for grains in data.values():
    for grain in grains.values():
        complex_tag = grain["complex_tag"]

        if complex_tag == "pasty":
            between_defects = grain["complex_distances_between_defects"]
            # Validate before reading min/max so a malformed grain is reported with this
            # message and never contributes partial values to the pools.
            if len(between_defects) != 2:
                raise ValueError("Expected 2 distances for pasty")
            pasty_distances.extend(grain["simple_defect_distances"])
            pasty_distances_min.append(min(between_defects))
            pasty_distances_max.append(max(between_defects))

        elif complex_tag == "dorito":
            # Plot the dorito
            plt.imshow(grain["image"], cmap="afmhot")
            plt.show()

            between_defects = grain["complex_distances_between_defects"]
            if len(between_defects) != 3:
                raise ValueError("Expected 3 distances for dorito")
            # Sort once; the three positions are the min, middle and max distance.
            shortest, middle, longest = sorted(between_defects)

            dorito_distances.extend(grain["simple_defect_distances"])
            dorito_distances_min.append(shortest)
            dorito_distances_middle.append(middle)
            dorito_distances_max.append(longest)
            dorito_distances_min_middle_sum.append(shortest + middle)

# Overlay every pooled dorito and pasty distance distribution on a single set of axes.
# Each entry is (values, HSV colour arguments for mycolor, legend label, extra line kwargs);
# the two "all" pools are drawn dashed, the per-grain summaries with the default line style.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
distance_curves = [
    (dorito_distances, (0.9, 0.5, 1), f"dorito all, n={len(dorito_distances)}", {"linestyle": "--"}),
    (dorito_distances_min, (1, 1, 1), "dorito min", {}),
    (dorito_distances_middle, (0.9, 1, 1), "dorito middle", {}),
    (dorito_distances_max, (0.8, 1, 1), "dorito max", {}),
    (dorito_distances_min_middle_sum, (0.7, 1, 1), "dorito min + middle", {}),
    (pasty_distances, (0.5, 0.5, 1), f"pasty all, n={len(pasty_distances)}", {"linestyle": "--"}),
    (pasty_distances_min, (0.5, 1, 1), "pasty min", {}),
    (pasty_distances_max, (0.6, 1, 1), "pasty max", {}),
]
for curve_values, hsv, curve_label, line_kwargs in distance_curves:
    sns.kdeplot(curve_values, ax=ax, color=mycolor(*hsv), label=curve_label, **line_kwargs)
# NOTE(review): the title only mentions doritos although pasty curves are drawn as well.
ax.set_title(f"Complex method dorito distances between defects n={len(dorito_distances)}")
ax.set_xlabel("Distance (nm)")
plt.legend()
plt.show()

In [None]:
# Pool the complex-method defect distances of every grain that has exactly two of them,
# across all sample types, and show their distribution.
all_distances = []
for sample_type, grain_data in data.items():
    print(f"Sample type: {sample_type}")
    for grain in grain_data.values():
        between_defects = grain["complex_distances_between_defects"]
        # The truthiness test comes first so a falsy entry is skipped before len() is called.
        if between_defects and len(between_defects) == 2:
            all_distances.extend(between_defects)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
sns.kdeplot(all_distances, ax=ax, color="orange")
ax.set_title("Complex method distances between defects")
ax.set_xlabel("Distance (nm)")
plt.show()


# Classify every grain that has exactly two complex-method defect distances by how different
# those two distances are, relative to their mean.
#
# The thresholds are fractions of the mean distance (0.05 == 5 %), not percent values:
#   relative difference <= threshold_1                 -> "churro"
#   threshold_1 < relative difference <= threshold_2   -> "pasty"
#   relative difference > threshold_2                  -> "dorito"
distance_difference_percentage_threshold_1 = 0.05
distance_difference_percentage_threshold_2 = 0.10


# sample type -> {"churro" | "pasty" | "dorito": list of grain keys in that class}
sample_classifications = {}

for sample_type, grain_data in data.items():
    # Despite the "_indexes" suffix these hold the keys of the per-sample grain dict.
    churro_defect_indexes = []
    pasty_defect_indexes = []
    dorito_defect_indexes = []

    print(f"Sample type: {sample_type}")
    for grain_index, grain in grain_data.items():
        distances = grain["complex_distances_between_defects"]
        # Only grains with exactly two distances are classified; everything else is skipped.
        if not distances or len(distances) != 2:
            continue

        distance_difference = np.abs(distances[0] - distances[1])
        distance_difference_percentage = distance_difference / np.mean(distances)

        # Test the classes from the largest difference downwards so the bands are contiguous.
        # With strict comparisons on both sides of threshold_2, a value exactly equal to it
        # used to fall through to the "churro" branch. A NaN ratio (0 / 0) fails both
        # comparisons and is still classed as "churro".
        if distance_difference_percentage > distance_difference_percentage_threshold_2:
            dorito_defect_indexes.append(grain_index)
        elif distance_difference_percentage > distance_difference_percentage_threshold_1:
            pasty_defect_indexes.append(grain_index)
        else:
            churro_defect_indexes.append(grain_index)

    sample_classifications[sample_type] = {
        "churro": churro_defect_indexes,
        "pasty": pasty_defect_indexes,
        "dorito": dorito_defect_indexes,
    }

print(sample_classifications)

# Show the image of every grain classified as a churro, one greyscale figure per grain,
# grouped by sample type.
for sample_type, grain_data in data.items():
    for churro_key in sample_classifications[sample_type]["churro"]:
        plt.imshow(grain_data[churro_key]["image"], cmap="gray")
        plt.title(f"Churros for {sample_type}")
        plt.show()

# Show the image of every grain classified as a pasty, one greyscale figure per grain,
# grouped by sample type.
for sample_type, grain_data in data.items():
    for pasty_key in sample_classifications[sample_type]["pasty"]:
        plt.imshow(grain_data[pasty_key]["image"], cmap="gray")
        plt.title(f"Pasties for {sample_type}")
        plt.show()