In [1]:
import mlcroissant as mlc
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import matplotlib.pyplot as plt
import cmcrameri.cm as cmc
from sklearn.metrics.pairwise import pairwise_distances_argmin
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

# Disable Pillow's pixel-count limit (its decompression-bomb guard) so that the
# very large images in this dataset can be opened without a warning/error.
# NOTE(review): only appropriate when the image files come from a trusted source.
Image.MAX_IMAGE_PIXELS = None

In [2]:
# Open the dataset described by the local Croissant metadata file.
ds = mlc.Dataset("./gebco_doi.json")

# Upper bound on how many records the smoke-read below will touch.
MAX_RECORDS = 1

# Smoke-read: look up the image field on the first few records to confirm the
# record set can be iterated and the field exists. The value is discarded.
for i, record in enumerate(ds.records(record_set="default")):
    record["images/content"]
    # `i` is zero-based, so the last permitted index is MAX_RECORDS - 1.
    if i >= MAX_RECORDS - 1:
        break

In [3]:
# Request the "default" record set again; this handle is what the tiling loop
# further down iterates over.
records = ds.records(record_set="default")

In [4]:
def split_2d(array, splits):
    """Cut a 2D array into a stack of equally sized rectangular tiles.

    Parameters
    ----------
    array : array-like of shape (H, W)
        The 2D array to tile.
    splits : pair of int
        ``(n_row_blocks, n_col_blocks)``: axis 0 is divided into ``splits[0]``
        equal parts and axis 1 into ``splits[1]`` equal parts.

    Returns
    -------
    numpy.ndarray of shape (splits[0] * splits[1], H / splits[0], W / splits[1])
        Tiles are ordered column-block first: all row blocks of the leftmost
        column block, then all row blocks of the next column block, and so on.

    Notes
    -----
    ``np.split`` is used with integer section counts, so it raises an error
    when either dimension is not exactly divisible by its split count.
    """
    row_blocks, col_blocks = splits
    # Slice into vertical strips, then pile the strips on top of each other so
    # that every tile becomes a contiguous run of rows.
    strips = np.split(array, col_blocks, axis=1)
    piled = np.concatenate(strips)
    # Each run of H / row_blocks rows in the pile is one tile.
    tiles = np.split(piled, row_blocks * col_blocks)
    return np.stack(tiles)

In [5]:
# Tile each record's image and collect one tile stack per image.
# NOTE(review): this walks every record in `records`; MAX_RECORDS is not
# applied here, so the full record set is loaded and tiled.
TILE_SPLITS = (100, 100)  # (row blocks, column blocks) passed to split_2d

arrs = []
for arr in records:
    img = np.asarray(arr['images/content'])
    subset = split_2d(img, TILE_SPLITS)
    arrs.append(subset)

In [6]:
# Join the per-image tile stacks end to end along the leading (tile) axis.
big_arr = np.concatenate(arrs)

In [8]:
# Collapse axes 1 and 2 of every tile into one feature axis, giving one row per
# tile. Assumes big_arr is 3D (tiles, rows, cols).
flattened = big_arr.reshape((big_arr.shape[0], big_arr.shape[1] * big_arr.shape[2]))

In [None]:
# Find the first three principal components of the flattened tiles and project
# every tile onto them.
# random_state is pinned because scikit-learn may choose a randomized SVD
# solver for large inputs, which would otherwise make the components (and
# everything derived from them) vary from run to run.
pca = PCA(n_components=3, random_state=0)
pca_arr = pca.fit(flattened)  # fit() returns the estimator, so pca_arr is pca
pca_d = pca_arr.transform(flattened)

In [None]:
# Display the fraction of total variance captured by each fitted component.
pca_arr.explained_variance_ratio_

In [None]:
# Scatter the tiles in the plane of the first two principal components.
plt.scatter(pca_d[:, 0], pca_d[:, 1], s=0.1)
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.title("Tiles projected onto the first two principal components");

In [None]:
n_clusters = 8

# Cluster the PCA-projected tiles. random_state is pinned so the k-means++
# initialisation, and therefore the resulting clusters, is reproducible.
kmeans = KMeans(init="k-means++", n_clusters=n_clusters, n_init=4, random_state=0)
k_means = kmeans.fit(pca_d)
k_means_cluster_centers = k_means.cluster_centers_

# Label every point with the index of its nearest cluster centre.
k_means_labels = pairwise_distances_argmin(pca_d, k_means_cluster_centers)

fig = plt.figure(figsize=(8, 3))
fig.subplots_adjust(left=0.02, right=0.98, bottom=0.05, top=0.9)
ax = fig.add_subplot(1, 1, 1)

# Spread the cluster indices over [0, 1] to sample the colormap; the max()
# guard avoids a division by zero when n_clusters == 1.
colour_steps = max(n_clusters - 1, 1)

# Only the first two columns of pca_d are drawn, although the clustering
# itself was fitted on every column.
for k in range(n_clusters):
    col = cmc.glasgowS(k / colour_steps)
    my_members = k_means_labels == k
    cluster_center = k_means_cluster_centers[k]
    ax.scatter(pca_d[my_members, 0], pca_d[my_members, 1], color=col, s=0.1)
    # Mark the cluster centre with a larger, black-edged dot in the same colour.
    ax.plot(
        cluster_center[0],
        cluster_center[1],
        "o",
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=6,
    )
ax.set_title("KMeans")
ax.set_xticks(())
ax.set_yticks(());
