### Project Setup

In [1]:
import mlx.core as mx
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
from tqdm.notebook import tqdm, trange
from pathlib import Path
import ipywidgets as widgets

### File Location

In [2]:
# Dataset folders, resolved against the kernel's current working directory.
ROOT = Path.cwd()
MAT_DIR, GT_DIR = (ROOT / folder for folder in ("MAT Files", "GT Files"))

### Datasets - Files & Keys

In [3]:
# Per-dataset file names and the MATLAB variable names stored inside each .mat file.
# Columns: display name, data file, ground-truth file, data variable, GT variable.
DATASETS = {
    name: dict(data_file=data_file, gt_file=gt_file, data_key=data_key, gt_key=gt_key)
    for name, data_file, gt_file, data_key, gt_key in (
        ("Pavia",    "Pavia.mat",    "Pavia_gt.mat",  "pavia",  "pavia_gt"),
        ("PaviaUni", "PaviaUni.mat", "PaviaU_gt.mat", "paviaU", "paviaU_gt"),
    )
}

In [4]:
DEFAULT_DS = "Pavia"

# Dataset picker plus an output area that the loader prints its summary into.
ds = widgets.Dropdown(
    options=[*DATASETS],
    value=DEFAULT_DS,
    description="Dataset:",
)
out = widgets.Output()

display(ds, out)

Dropdown(description='Dataset:', options=('Pavia', 'PaviaUni'), value='Pavia')

Output()

In [5]:
# ds = wgts.Dropdown(options=list(DATASETS.keys()), description="Dataset:")
# display(ds)

In [6]:
# cwd = os.getcwd()
# data_path = os.path.join(cwd, "MAT Files", "Indian_pines.mat")
# gt_path = os.path.join(cwd, "GT Files", "Indian_pines_gt.mat")
# gt_mat = sp.io.loadmat(gt_path)
# data_mat = sp.io.loadmat(data_path)

In [50]:
# Config of the dataset the dropdown showed when this cell ran. This is a snapshot:
# it does not follow later dropdown changes (load_ds looks up DATASETS itself).
keys = DATASETS[ds.value]

In [51]:
keys

{'data_file': 'PaviaUni.mat',
 'gt_file': 'PaviaU_gt.mat',
 'data_key': 'paviaU',
 'gt_key': 'paviaU_gt'}

In [52]:
# data_mat = sio.loadmat(os.path.join(MAT_DIR, keys["data_file"]))
# gt_mat   = sio.loadmat(os.path.join(GT_DIR,  keys["gt_file"]))



## Hyperspectral Image Cube

In [53]:
def load_ds(name: str):
    """Load data for the selected dataset and put into globals().

    Looks ``name`` up in ``DATASETS``, reads the data cube from ``MAT_DIR`` and the
    ground-truth map from ``GT_DIR``, and publishes them as the notebook globals
    ``data_cube`` and ``gt_data`` (globals so the dropdown callback can refresh the
    arrays used by later cells). A one-line summary is printed into the ``out`` widget.

    Raises:
        KeyError: ``name`` is not in DATASETS, or a .mat file lacks the expected key.
        ValueError: the cube's first two axes do not match the ground-truth shape
            (later cells broadcast a GT-derived mask over the cube's spatial axes).
    """
    cfg = DATASETS[name]
    # MAT_DIR / GT_DIR are pathlib.Path objects; join with `/` rather than os.path.
    data_mat = sio.loadmat(MAT_DIR / cfg["data_file"])
    gt_mat = sio.loadmat(GT_DIR / cfg["gt_file"])
    cube = data_mat[cfg["data_key"]]
    gt = gt_mat[cfg["gt_key"]]
    if cube.shape[:2] != gt.shape:
        raise ValueError(
            f"{name}: data cube spatial shape {cube.shape[:2]} does not match GT shape {gt.shape}"
        )
    globals()["data_cube"] = cube
    globals()["gt_data"] = gt
    with out:
        out.clear_output(wait=True)
        print(f"Loaded {name}: X {cube.shape}, GT {gt.shape}")


In [54]:
def _on_change(change):
    # Dropdown observer: reload the data whenever the "value" trait changes;
    # notifications for any other trait are ignored.
    if change["name"] != "value":
        return
    load_ds(change["new"])

In [55]:
ds.observe(_on_change, names="value")

In [56]:
load_ds(ds.value)

In [57]:
# data_cube = data_mat[keys['data_key']]

## Ground Truth Data

In [58]:
# gt_data = gt_mat[keys['gt_key']]

In [59]:
# Ground-truth label map of the loaded dataset; label 0 is treated as background below.
gt_data

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [2, 2, 2, ..., 0, 0, 0],
       [2, 2, 2, ..., 0, 0, 0],
       [2, 2, 2, ..., 0, 0, 0]], shape=(610, 340), dtype=uint8)

In [60]:
# (rows, cols); must match the first two axes of data_cube for the masking below.
gt_data.shape

(610, 340)

In [61]:
bg_indices = gt_data == 0

In [62]:
filter = mx.ones([data_cube.shape[0], data_cube.shape[1]])

In [63]:
mask = mx.where(bg_indices, 0.0, filter)

In [64]:
mask

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=float32)

In [65]:
mask3d = mask[..., None]

In [66]:
masked_cube = mx.multiply(data_cube, mask3d)

In [67]:
# Same (rows, cols, bands) layout as data_cube.
masked_cube.shape

(610, 340, 103)

In [68]:
type(data_cube)

numpy.ndarray

In [69]:
# mx.multiply returned an MLX array even though data_cube is a NumPy array.
type(masked_cube)

mlx.core.array

In [70]:
# Flatten the spatial axes: one row per pixel, one column per band (-1 infers rows*cols).
masked_cube_reshaped = mx.reshape(masked_cube, (-1, masked_cube.shape[-1]))

In [71]:
# rows*cols pixels by bands.
masked_cube_reshaped.shape

(207400, 103)

In [72]:
# Background pixels are all-zero rows after masking.
masked_cube_reshaped

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=float32)

In [73]:
# Keep exactly the pixels that carry a ground-truth label (gt != 0), in the same
# row-major order as `masked_cube_reshaped`. Selecting by label instead of by
# "any non-zero band" cannot silently drop a labeled pixel whose spectrum is all
# zeros, and it yields the matching labels (`filtered_labels`) for evaluation.
gt_flat = np.asarray(gt_data).reshape(-1)
assert gt_flat.shape[0] == masked_cube_reshaped.shape[0], "GT map and data cube differ in size"
labeled_mask = gt_flat != 0
cube_np = np.array(masked_cube_reshaped)
filtered_np = cube_np[labeled_mask]
filtered_labels = gt_flat[labeled_mask]
filtered_pixels = mx.array(filtered_np)

In [74]:
## K-Subspaces (KSS) clustering

In [75]:
def polar(x):
    """Return the orthogonal polar factor ``U @ Vt`` of a 2-D MLX array.

    Used by KSS to turn a random Gaussian (D, d) matrix into a basis with
    orthonormal columns.

    Args:
        x: 2-D ``mx.array`` of shape (M, N); tall or wide.

    Returns:
        ``mx.array`` of shape (M, N): ``U_k @ Vt_k`` where ``U_k``/``Vt_k`` are the
        leading ``k = min(M, N)`` left/right singular vectors of ``x``.
    """
    M, N = x.shape
    k = min(M, N)
    # SVD runs on the CPU stream. Truncate BOTH factors to k singular directions:
    # the previous `U[:, :k] @ Vt` multiplied (M, k) by the full (N, N) Vt, which
    # only lines up when M >= N and raised a shape error for wide inputs.
    U, _, Vt = mx.linalg.svd(x, stream=mx.cpu)
    return U[:, :k] @ Vt[:k, :]

In [76]:
def _kss_assign(X, U):
    """Label each column of X with the basis capturing most of its energy.

    Returns a NumPy integer array of length N holding, for every column x,
    argmax_k of sum((U[k].T @ x) ** 2).
    """
    scores = mx.stack(
        [mx.sum(mx.matmul(Uk.T, X, stream=mx.gpu) ** 2, axis=0) for Uk in U],
        axis=0,
    )  # (K, N)
    return np.argmax(np.array(scores), axis=0)


def KSS(X, d, *, niters=100, Uinit=None):
    """K-subspaces clustering by alternating basis fits and reassignments.

    Args:
        X: ``mx.array`` of shape (D, N), one sample per column.
        d: sequence of ints; ``d[k]`` is the dimension of subspace k and
            ``K = len(d)`` is the number of clusters.
        niters: maximum number of alternating rounds.
        Uinit: optional sequence of K initial bases, ``Uinit[k]`` of shape (D, d[k]).
            When omitted, random orthonormal bases are drawn from MLX's global RNG.

    Returns:
        ``(U, c)``: list of K bases (``U[k]`` is (D, d[k])) and an ``mx.array`` of
        length N with the cluster index of each column, consistent with ``U``.

    Raises:
        ValueError: ``Uinit`` does not contain exactly ``len(d)`` bases.
    """
    K = len(d)
    D, N = X.shape

    # Initialize
    if Uinit is None:
        U = [polar(mx.random.normal(shape=(D, di))) for di in d]
    else:
        U = list(Uinit)
        if len(U) != K:
            raise ValueError(f"Uinit has {len(U)} bases but d specifies {K} clusters")

    c = _kss_assign(X, U)
    c_prev = c.copy()

    # Iterations
    for t in trange(niters, desc="KSS", leave=False):
        # Update subspaces: dominant d[k]-dim eigenspace of each cluster's X_k X_k^T.
        for k in range(K):
            ilist = np.nonzero(c == k)[0]
            if ilist.size == 0:
                # Empty cluster: redraw a random basis. NOTE(review): this reseeds
                # MLX's global RNG with a value that depends only on k, so repeated
                # reinits of the same cluster draw the same basis.
                mx.random.seed(k + 42)
                U[k] = polar(mx.random.normal(shape=(D, d[k])))
            else:
                X_k = mx.take(X, mx.array(ilist), axis=1)
                A = mx.matmul(X_k, X_k.T, stream=mx.gpu)
                # eigh returns eigenvalues in ascending order, so the last d[k]
                # eigenvectors span the dominant subspace.
                _, V = mx.linalg.eigh(A, stream=mx.cpu)
                U[k] = V[:, -d[k]:]

        # Update clusters
        c = _kss_assign(X, U)

        # Stop once the assignment is a fixed point.
        if np.array_equal(c, c_prev):
            print(f"Terminated early at iteration {t+1}")
            break
        c_prev = c.copy()

    return U, mx.array(c)

In [77]:
def batch_KSS(X, d, *, niters=100, nruns=10):
    """Run KSS ``nruns`` times from different seeds and score each run.

    Args:
        X: ``mx.array`` of shape (D, N), one sample per column.
        d: subspace dimensions, forwarded to ``KSS``.
        niters: iteration cap per run, forwarded to ``KSS``.
        nruns: number of restarts; run ``idx`` seeds MLX's global RNG with ``idx + 1``.

    Returns:
        List of ``(U, c, total_cost)`` tuples, one per run, where
        ``total_cost = sum_i ||U[c_i].T @ x_i||^2`` is the energy captured by the
        assigned subspaces (KSS assigns by argmax of this, so larger is better).
    """
    runs = []
    for idx in trange(nruns, desc="batch KSS", leave=False):
        mx.random.seed(idx + 1)
        U, c_mx = KSS(X, d, niters=niters)

        # Score every column against every basis in one pass -> (K, N), then keep
        # each column's score under its assigned basis. This replaces a per-sample
        # Python loop that issued one GPU matmul and one `.item()` sync per column.
        scores = mx.stack(
            [mx.sum(mx.matmul(Uk.T, X, stream=mx.gpu) ** 2, axis=0) for Uk in U],
            axis=0,
        )
        assigned = mx.take_along_axis(scores, mx.expand_dims(c_mx, 0), axis=0)
        # Sum in float64 so the total over many float32 scores stays accurate.
        total_cost = float(np.sum(np.array(assigned), dtype=np.float64))

        runs.append((U, c_mx, total_cost))

    return runs

In [78]:
# Subspace dimension of each cluster; the number of clusters K is len(SUBSPACE_DIMS).
SUBSPACE_DIMS = [2, 2, 3, 2, 3, 2]
# Independent random restarts; batch_KSS seeds run i with i + 1.
N_RUNS = 10
# KSS takes one sample per column, so pass the (pixels, bands) matrix transposed.
KSS_Runs = batch_KSS(filtered_pixels.T, SUBSPACE_DIMS, nruns=N_RUNS)

batch KSS:   0%|          | 0/10 [00:00<?, ?it/s]

KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 35


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 29


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 71


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 49


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 67


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 18


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 50


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 33


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

Terminated early at iteration 46


KSS:   0%|          | 0/100 [00:00<?, ?it/s]

In [36]:
totalcosts = [KSS_Runs[i][2] for i in range(len(KSS_Runs))]

In [37]:
minidx_KSS = int(np.argmax(totalcosts))

In [38]:
c = KSS_Runs[minidx_KSS][1]

In [39]:
c_np = np.array(c)

In [40]:
for k in range(6):
    ilist = np.nonzero(c_np == k)[0]
    print(f"Cluster {k+1}: {ilist}")

Cluster 1: [     5     16     17 ... 104459 104460 104580]
Cluster 2: [    42     47     50 ... 101049 101053 101906]
Cluster 3: [    65     66     67 ... 148141 148142 148143]
Cluster 4: [   138    258    393 ... 148119 148148 148149]
Cluster 5: [     0      1      2 ... 148139 148144 148145]
Cluster 6: [   137    139    255 ... 148147 148150 148151]


### Rough Work

In [41]:
# U[1]

In [42]:
# c_np = np.array(c)

In [43]:
# c_np

In [44]:
# for k in range(6):
#     ilist = np.nonzero(c_np == k)[0]
#     print(f"Cluster {k+1}: {ilist}")

In [45]:
# Scratch cell: small MLX array (not referenced elsewhere in this notebook).
a = mx.array([[1, 2, 3, 4], [2, 3, 4, 5]])

In [46]:
# Is the Metal (GPU) backend usable? KSS and batch_KSS pin matmuls to mx.gpu.
mx.metal.is_available()

True

In [47]:
# Scratch random matrix. NOTE(review): draws from MLX's global RNG; batch_KSS
# reseeds before every run, so this does not change its results.
x = mx.random.normal((1000, 1000))

In [48]:
# Metal device details (memory limits, chip name).
mx.metal.device_info()

{'resource_limit': 499000,
 'max_buffer_length': 10726686720,
 'architecture': 'applegpu_g15s',
 'memory_size': 19327352832,
 'max_recommended_working_set_size': 14302248960,
 'device_name': 'Apple M3 Pro'}

In [49]:
# Record the MLX version these results were produced with.
print(mx.__version__)

0.29.2
