# From the paper

We attempt to reproduce, and improve upon, the results of the CVPR 2016 paper "Sliced Wasserstein Kernels for Probability Distributions" (Kolouri et al.): https://openaccess.thecvf.com/content_cvpr_2016/papers/Kolouri_Sliced_Wasserstein_Kernels_CVPR_2016_paper.pdf

In [None]:
import numpy as onp
import jax.numpy as jnp
import jax
from jax.config import config

# Precision switch: set `run64` to True to enable JAX's x64 mode.
run64 = False
if run64:
    config.update("jax_enable_x64", True)
# Floating-point dtype used when building arrays in the cells below.
global_type = jnp.float64 if run64 else jnp.float32

In [None]:
import ott
from ott.geometry import pointcloud
from ott.core import sinkhorn
from ott.tools import transport

In [None]:
def get_indexes_support(n, m):
    """Return the normalized (row, col) coordinates of every cell of an n x m grid.

    Args:
        n: number of rows of the grid.
        m: number of columns of the grid.

    Returns:
        A float32 array of shape ``(n * m, 2)``. Entry ``i * m + j`` holds
        ``(i / n, j / m)``: cells are listed in row-major order.
    """
    rows = jnp.arange(n, dtype=jnp.float32)[:, None]
    rows = jnp.tile(rows, m) / n  # (n, m), constant along columns
    cols = jnp.arange(m, dtype=jnp.float32)[:, None]
    cols = jnp.tile(cols, n).T / m  # (n, m), constant along rows
    coords = jnp.stack([rows, cols], axis=-1)
    # Shape is passed positionally: the `newshape=` keyword is deprecated in
    # recent JAX releases, while the positional form works everywhere.
    return jnp.reshape(coords, (n * m, 2))

In [None]:
import tensorflow as tf
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX. In the visible cells TensorFlow is only used to list,
# read and decode the image files (tf.data / tf.io / tf.strings).
tf.config.experimental.set_visible_devices([], 'GPU')

In [None]:
import os

def retrieve_label(file_path):
    """Parse the class label and the image index encoded in a file path.

    The text after the last 'T' of the path, up to the first following '.',
    is split on '_'; its first two pieces are parsed as int64 numbers.

    Args:
        file_path: string tensor holding the path of one image.

    Returns:
        ``(class_label, image_index)`` as int64 tensors.
    """
    after_last_t = tf.strings.split(file_path, 'T')[-1]
    stem = tf.strings.split(after_last_t, '.')[0]
    pieces = tf.strings.split(stem, '_')
    return (
        tf.strings.to_number(pieces[0], out_type=tf.int64),
        tf.strings.to_number(pieces[1], out_type=tf.int64),
    )

def process_filename(file_path):
    """Load one image and the labels encoded in its path.

    Args:
        file_path: string tensor holding the path of a JPEG file.

    Returns:
        ``(img, class_label, image_index)`` where ``img`` is the JPEG decoded
        with a single channel and squeezed, and the two labels come from
        ``retrieve_label(file_path)``.
    """
    encoded = tf.io.read_file(file_path)
    pixels = tf.squeeze(tf.io.decode_jpeg(encoded, channels=1))
    return (pixels, *retrieve_label(file_path))

def create_texture_dataset():
    """Build the tf.data pipeline over every ``textures/*/*.jpg`` file.

    Returns:
        A dataset whose elements are the ``(img, class_label, image_index)``
        tuples produced by ``process_filename``; the file list is shuffled.
    """
    file_names = tf.data.Dataset.list_files("textures/*/*.jpg", shuffle=True)
    return file_names.map(process_filename)

In [None]:
import matplotlib.pyplot as plt
import itertools

def vizualize_features(features, channels='first', **kwargs):
    """Show a 2-D grid of images, one subplot per ``features[i, j]``.

    Args:
        features: 4-D array. With ``channels='first'`` the two leading axes
            index the subplot row and column and ``features[i, j]`` is the
            image drawn there. With ``channels='last'`` the two trailing axes
            are moved to the front first.
        channels: ``'first'`` (default) or ``'last'``; any value other than
            ``'last'`` leaves ``features`` as given.
        **kwargs: forwarded to ``imshow`` for every image.
    """
    if channels == 'last':
        features = jnp.transpose(features, axes=[2, 3, 0, 1])
    n_row, n_col = features.shape[0], features.shape[1]
    # squeeze=False always returns a 2-D array of axes, so a 1x1 (or 1xN)
    # grid needs no special handling.
    _, axs = plt.subplots(n_row, n_col, figsize=(n_col*4, n_row*4), squeeze=False)
    for i in range(n_row):
        for j in range(n_col):
            axs[i, j].imshow(features[i, j], **kwargs)
    plt.show()

In [None]:
# NumPy-yielding iterator over the texture dataset; each item is the
# (img, class_label, image_index) tuple returned by process_filename.
it_dataset = create_texture_dataset().as_numpy_iterator()

In [None]:
from skimage.feature import graycomatrix

def fill_glcm(glcm, lbda):
    """Add mass on the empty cells of co-occurrence matrices, then renormalize.

    Args:
        glcm: array whose last two axes form the co-occurrence matrix.
        lbda: weight of the mass added on the cells that are not > 0. That
            mass is shared equally between the empty cells of each matrix.

    Returns:
        Array of the same shape as ``glcm`` whose last two axes sum to 1.
    """
    matrix_axes = (-2, -1)
    # Small offset so the division below stays finite when no cell is empty.
    eps = 0.1 / (glcm.shape[-2] * glcm.shape[-1])
    empty = jnp.where(glcm > 0, 0., 1.)
    empty = empty / (jnp.sum(empty, axis=matrix_axes, keepdims=True) + eps)
    filled = glcm + lbda * empty
    return filled / jnp.sum(filled, axis=matrix_axes, keepdims=True)

def GLCM(img, depth=2, width=4, compression=0, lbda=0.):
    """Compute the gray-level co-occurrence matrices of an image.

    Args:
        img: integer image handed to ``graycomatrix``.
            NOTE(review): presumably 2-D with values in [0, 255], since 256
            levels are assumed below - confirm.
        depth: number of pixel distances; distances ``1..depth`` are used.
        width: number of angles. ``2`` selects ``0`` and ``pi / 2``; any other
            value selects ``width`` angles evenly spaced in ``[0, 2*pi)``.
        compression: gray values are integer-divided by ``2**compression``,
            leaving ``256 // 2**compression`` levels.
        lbda: when > 0, the result is passed through ``fill_glcm`` with this
            weight.

    Returns:
        Array of dtype ``global_type`` with the distance and angle axes first
        and the two gray-level axes last.
    """
    bin_size = 1 << compression
    quantized = onp.floor_divide(img, bin_size)
    distances = onp.arange(1, depth + 1)
    if width == 2:
        angles = jnp.array([0, jnp.pi / 2])
    else:
        angles = onp.linspace(0, 2*jnp.pi, num=width, endpoint=False)
    cooc = graycomatrix(
        quantized, distances, angles,
        levels=256 // bin_size,
        normed=True, symmetric=False,
    )
    # Move the (distance, angle) axes in front of the (level, level) axes.
    cooc = jnp.array(onp.transpose(cooc, axes=[2, 3, 0, 1]), dtype=global_type)
    if lbda > 0.:
        # force support with positive mass
        cooc = fill_glcm(cooc, lbda)
    return cooc

In [None]:
@jax.jit
def get_images_embedding(geom, glcm_features, mu, **kwargs):
    """Embed a stack of GLCMs through the ``g`` output of Sinkhorn solves.

    Each ``glcm_features[d, w]`` is flattened and passed, together with
    ``mu``, to ``sinkhorn.sinkhorn(geom, glcm, mu, **kwargs)``; the ``g``
    attribute of every solution is collected.

    Args:
        geom: geometry given as first argument to ``sinkhorn.sinkhorn``.
        glcm_features: array indexed as ``[d, w]`` over its two leading axes.
        mu: weight vector given as third argument to ``sinkhorn.sinkhorn``.
        **kwargs: forwarded to ``sinkhorn.sinkhorn``.

    Returns:
        ``(g_embedding, ot_sol)``: the ``g`` vectors concatenated along axis 0
        in ``(d, w)`` row-major order, and the solver output of the LAST
        ``(d, w)`` pair only.
    """
    n_depth, n_width = glcm_features.shape[0], glcm_features.shape[1]
    g_embeddings = []
    # Plain Python loops: the shape is static, so they unroll under jit.
    for d in range(n_depth):
        for w in range(n_width):
            glcm = glcm_features[d, w].ravel()
            ot_sol = sinkhorn.sinkhorn(geom, glcm, mu, **kwargs)
            g_embeddings.append(ot_sol.g)
    g_embedding = jnp.concatenate(g_embeddings, axis=0)
    return g_embedding, ot_sol

In [None]:
# Hyper-parameters shared by the cells below.
compression = 2  # gray values are integer-divided by 2**compression in GLCM
depth = 2  # forwarded as GLCM(depth=...): pixel distances 1..depth
width = 2  # forwarded as GLCM(width=...): number of angles
num_gray_levels = 256 // (1 << compression)  # side length of one co-occurrence matrix
lbda = 0.1  # weight of the mass fill_glcm adds on the empty cells of a GLCM
epsilon = 1e-3  # Sinkhorn regularization, passed as `epsilon` to grid.Grid
n, m = 480, 640  # NOTE(review): presumably the image height/width; not used in the visible cells
mu_type = "average"  # or "uniform"

In [None]:
import tqdm

def compute_uniform_mu():
    """Return the uniform weight vector over all ``num_gray_levels**2`` cells.

    Returns:
        1-D array of dtype ``global_type`` whose entries are equal and sum
        to 1.
    """
    n_cells = num_gray_levels * num_gray_levels
    ones = jnp.ones((n_cells,), dtype=global_type)
    total = jnp.sum(ones, axis=-1, keepdims=True)
    return ones / total

def compute_average_mu():
    """Build the reference weights ``mu`` as the average GLCM of the dataset.

    Every image of ``create_texture_dataset()`` is turned into its GLCM stack
    (using the module-level ``depth``, ``width``, ``compression`` and
    ``lbda``). The stacks are averaged over images, distances and angles, the
    average is passed through ``fill_glcm`` and then flattened.

    Returns:
        1-D array holding the flattened average co-occurrence matrix.
    """
    ds = create_texture_dataset()
    glcms = []
    # The context manager closes the bar even if an iteration raises, and
    # nothing touches `pbar` if creating the dataset or the bar itself fails.
    with tqdm.tqdm(total=int(ds.cardinality())) as pbar:
        for image, _label, _idx in ds:
            glcm = GLCM(image.numpy(), depth=depth, width=width,
                        compression=compression, lbda=lbda)
            glcms.append(glcm)
            pbar.update()
    mu = jnp.mean(jnp.stack(glcms), axis=(0, 1, 2))  # average over images, depths, widths
    mu = fill_glcm(mu, lbda)
    return mu.ravel()

In [None]:
# Visual sanity check on nine toy images (img1.jpg ... img9.jpg): compute the
# GLCM of each one and display them on a 3x3 grid.
toy_glcms = []
for img_id in range(1, 9+1):
    file_path = f"img{img_id}.jpg"
    # Same decoding steps as process_filename, without the label parsing.
    img = tf.io.read_file(file_path)
    img = tf.io.decode_jpeg(img, channels=1)
    img = tf.squeeze(img)
    img = img.numpy()
    glcm = GLCM(img, depth=depth, width=width, compression=compression, lbda=lbda)
    # Keep only the matrix at index [0, 0] (first distance, first angle).
    glcm = glcm[0,0]
    toy_glcms.append(glcm)
# Arrange the nine matrices as a (3, 3, levels, levels) array for plotting.
toy_glcms = onp.array(toy_glcms).reshape((3,3)+glcm.shape)
vizualize_features(toy_glcms, vmin=0.)

In [None]:
# Take the image (element 0) of the next dataset item and display its GLCMs,
# one subplot per (distance, angle) pair.
img = next(it_dataset)[0]
glcm = GLCM(img, depth=depth, width=width, compression=compression, lbda=lbda)
vizualize_features(glcm, vmin=0.)

In [None]:
from ott.geometry import grid
# Grid geometry with one axis per gray level of a co-occurrence matrix.
geom = grid.Grid(grid_size=[num_gray_levels, num_gray_levels], epsilon=epsilon)
# Reference weights against which every GLCM is transported.
if mu_type == "uniform":
    mu = compute_uniform_mu()
elif mu_type == "average":
    mu = compute_average_mu()
else:
    # Fail with a clear message instead of a NameError on `mu` below.
    raise ValueError(f"unknown mu_type {mu_type!r}; expected 'uniform' or 'average'")
vizualize_features(mu.reshape((1,1,num_gray_levels,num_gray_levels)))

In [None]:
# Embed the sample image and display the result with the same layout as its GLCM.
glcm = GLCM(img, depth=depth, width=width, compression=compression, lbda=lbda)
g_embeddings, ot_sol = get_images_embedding(geom, glcm, mu)
img_embedding = g_embeddings.reshape(glcm.shape)
# ot_sol is the output of the last Sinkhorn solve only, so these diagnostics
# describe that single solve.
print(f"errors={ot_sol.errors[0],ot_sol.errors[-1]}, converged={ot_sol.converged}, cost={ot_sol.reg_ot_cost:.4f}")
vizualize_features(img_embedding)

In [None]:
import pandas as pd
# Summary statistics of all embedding values of the sample image, shown as a
# single row (the describe() output is transposed).
df_describe = pd.DataFrame(img_embedding.ravel())
desc = df_describe.describe().transpose()
desc

In [None]:
import tqdm 

# Compute, for every image of the dataset, its flattened GLCM (`raw`) and its
# Sinkhorn embedding (`features`), together with its label and index.
ds = create_texture_dataset()
features = []
raw = []
labels = []
indices = []
converged_hist = []
# The context manager closes the bar even if an iteration raises, and nothing
# touches `pbar` if creating the bar itself fails.
with tqdm.tqdm(total=int(ds.cardinality())) as pbar:
    for image, label, idx in ds:
        img = image.numpy()
        glcm = GLCM(img, depth=depth, width=width, compression=compression, lbda=lbda)
        raw.append(glcm.ravel())
        g_embeddings, ot_sol = get_images_embedding(geom, jnp.array(glcm), jnp.array(mu))
        features.append(g_embeddings.reshape(glcm.shape))
        labels.append(label)
        indices.append(idx)
        # Only the last Sinkhorn solve of each image is tracked here.
        converged_hist.append(ot_sol.converged)
        pbar.set_description(f"converged={jnp.mean(jnp.array(converged_hist))*100:.2f}%")
        pbar.update()
# One row per image.
features = jnp.stack(features).reshape((len(features), -1))
raw = jnp.stack(raw, axis=0).reshape((len(raw), -1))
labels = jnp.stack([label.numpy() for label in labels])
indices = jnp.stack([idx.numpy() for idx in indices])

In [None]:
import os
# Set before wandb is imported.
# NOTE(review): presumably tells wandb which notebook the run belongs to - confirm.
os.environ['WANDB_NOTEBOOK_NAME'] = 'SinkhornKernelLargeScale.ipynb'
import wandb
wandb.login()

In [None]:
from types import SimpleNamespace
import math
# Hyper-parameters recorded with the wandb run.
# NOTE: this rebinds the module-level name `config`, which the first cell
# imported from jax.config.
config = SimpleNamespace(
    depth=depth, width=width, compression=compression,
    num_gray_levels=num_gray_levels, mu_type=mu_type, lbda=lbda,
)
wandb.init(project="sinkhorn_kernel", config=vars(config))
# One row per feature type, filled in after each grid search.
table = wandb.Table(columns=['Type', 'Best acc', 'Best params'])
print(config)

In [None]:
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, cross_validate, cross_val_score
from sklearn.model_selection import GridSearchCV

In [None]:
# Grid search of an SVC on the Sinkhorn embeddings.
svc = SVC()
grid_cv_features = GridSearchCV(svc, param_grid={'gamma':['scale','auto'], 'C':[0.1, 1, 10, 100, 1000]},
                                cv=5, scoring='accuracy',  # balanced learning so ok.
                                n_jobs=12)  # keep 4 cores for other process running in parallel
# Hand scikit-learn a float64 NumPy array rather than a JAX array.
grid_cv_features = grid_cv_features.fit(onp.array(features, dtype=onp.float64), labels)
grid_cv_features.best_score_, grid_cv_features.best_params_

In [None]:
# Fit an SVC on all the features with the best hyper-parameters found by
# the grid search.
svc = SVC(**grid_cv_features.best_params_)
svc.fit(features, labels)

In [None]:
# Row for the Sinkhorn embeddings: (Type, Best acc, Best params).
table.add_data('features', grid_cv_features.best_score_, grid_cv_features.best_params_)

In [None]:
# Baseline: the same grid search on the flattened GLCMs themselves.
svc = SVC()
grid_cv_raw = GridSearchCV(svc, param_grid={'gamma':['scale','auto'], 'C':[0.1, 1, 10, 100, 1000]},
                           cv=5, scoring='accuracy',  # balanced learning so ok.
                           n_jobs=12)  # keep 4 cores for other process running in parallel
# Convert through NumPy: with x64 disabled, a JAX `.astype(float64)` does not
# actually produce a float64 array.
grid_cv_raw = grid_cv_raw.fit(onp.array(raw, dtype=onp.float64), labels)
grid_cv_raw.best_score_, grid_cv_raw.best_params_

In [None]:
# Row for the raw-GLCM baseline: report the results of grid_cv_raw, not
# those of the embedding-based search.
table.add_data('raw', grid_cv_raw.best_score_, grid_cv_raw.best_params_)

In [None]:
# Upload the results table before closing the run: a wandb.Table is only
# saved once it is passed to wandb.log.
wandb.log({"svm_results": table})
wandb.finish()

In [None]:
# Second grid search on the Sinkhorn embeddings, over a narrower range of C.
svc = SVC()
grid_cv_features = GridSearchCV(
    svc,
    param_grid={'gamma': ['scale', 'auto'], 'C': [1, 10, 50, 100, 200]},
    scoring='accuracy',  # balanced learning so ok.
    cv=5,
    n_jobs=12,  # keep 4 cores for other process running in parallel
).fit(onp.array(features, dtype=onp.float64), labels)
grid_cv_features.best_score_, grid_cv_features.best_params_

In [None]:
# Second grid search on the flattened GLCMs, over a narrower range of C.
svc = SVC()
grid_cv_raw = GridSearchCV(
    svc,
    param_grid={'gamma': ['scale', 'auto'], 'C': [1, 10, 50, 100, 200]},
    scoring='accuracy',  # balanced learning so ok.
    cv=5,
    n_jobs=12,  # keep 4 cores for other process running in parallel
).fit(onp.array(raw, dtype=onp.float64), labels)
grid_cv_raw.best_score_, grid_cv_raw.best_params_