In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
# !git clone -b master 'https://github.com/InesVATI/data-challenge-kernel-methods.git'

In [2]:
# %cd data-challenge-kernel-methods

# %ls

In [3]:
# !pip install jaxopt

In [4]:
from src.load_data import load_data
import jax
import jax.numpy as jnp
from src.kmeans import SphericalKMeans
from src.svm import MultiClassKernelSVM
import pandas as pd
from typing import Tuple
from src.kernels import RBF
import os
import time
import pickle

# Train CKN and save models 

In [5]:
def get_ovelapping_patch_idx(h : int, w:int, patch_size: int =3):
        """  """
        xp, yp = jnp.meshgrid(jnp.arange(w-patch_size+1), jnp.arange(h-patch_size+1))
        x, y = jnp.meshgrid(jnp.arange(patch_size), jnp.arange(patch_size))
        X = x[None, None, ...] + xp[..., None, None]
        X = X.reshape(-1, patch_size, patch_size)
        Y = y[None, None, ...] + yp[..., None, None]
        Y = Y.reshape(-1, patch_size, patch_size)

        return (Y, X)

def normalize_row(X : jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
     norm_rows = jnp.linalg.norm(X, axis=1, ord=2, keepdims=True)
     return X / (norm_rows + 1e-5), norm_rows

def kappa(x : jnp.ndarray, alpha : float):
     """ gaussian function """
     return jnp.exp(alpha * (x - 1))

def linear_gaussian_pooling(map : jnp.ndarray, beta : float = 1, sampling_factor : int = 2):
    h, w, c = map.shape
    hpool = h // sampling_factor
    conv_pool = jnp.zeros((hpool, hpool, c))
    z = jnp.stack(jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing='ij'), axis=2)
    for i in range(0, hpool):
        for j in range(0, hpool):
            pixel = jnp.array([i*sampling_factor, j*sampling_factor])
            conv_pool = conv_pool.at[i,j].set(jnp.sum(map * jnp.exp( - beta *  jnp.sum((pixel[None, None, :] - z)**2 , axis=2))[..., None], axis = (0, 1)))

    return conv_pool


class ConvKN:
    """ CKN Layer """

    def __init__(self, patch_size:int, out_channels : int, sampling_factor : int = 2,
                 n_patch_per_img_for_kmean : int = 10) -> None:
        self.patch_size = patch_size
        self.patch_idx = None
        self.patch_pad_idx = None
        self.Z = None
        self.out_channels = out_channels
        self.sampling_factor = sampling_factor
        self.spherical_kmeans = SphericalKMeans(nb_clusters=self.out_channels, max_iter=1000)
        self.n_patch_per_img_for_kmean = n_patch_per_img_for_kmean

    def extract_patches_2d(self, img : jnp.ndarray, inference : bool = False):
        if inference :
            if self.patch_pad_idx is None:
                h, w, _ = img.shape
                self.patch_pad_idx = get_ovelapping_patch_idx(h, w, self.patch_size)
            return img[self.patch_pad_idx]
        else :
            if self.patch_idx is None :
                # to do once
                h, w, _ = img.shape
                self.patch_idx = get_ovelapping_patch_idx(h, w, self.patch_size)

            return img[self.patch_idx]
    
    def train(self, key, input_maps):
        
        # extract random patches 
        batch_size, _, _, in_channels = input_maps.shape
        X = jnp.empty((0, in_channels*self.patch_size**2))
        for i in range(batch_size):
            patches = self.extract_patches_2d(input_maps[i])
            patches = patches.reshape(-1, in_channels*self.patch_size**2)
            
            # remove constant patches
            non_cst_idx = jnp.any(patches != patches[:, [0]], axis=1) # to remove
            n_patch = min(non_cst_idx.sum(), self.n_patch_per_img_for_kmean)
            # n_patch = min(patches.shape[0], self.n_patch_per_img_for_kmean)
            key, _ = jax.random.split(key)
            X = jnp.vstack((X, 
                           jax.random.choice(key, patches[non_cst_idx], shape=(n_patch,), replace=False, axis=0))
            )
            
        X = jax.random.permutation(key, X, axis=0)
        # normalize row applying kmeans
        X, _ = normalize_row(X) 
        print(f'For training kmeans, X shape {X.shape}')
        self.Z, _ = self.spherical_kmeans.fit(X, init_centroids=self.Z) # centroids are normalized
        
        # compute linear weights
        mat = kappa(self.Z.dot(self.Z.T), alpha = 1/ .5**2) # out_channels x out_channels
        D, U = jnp.linalg.eigh(mat)
        D = D.at[D<1e-6].set(1e-6)
        inv_sqrt_D = jnp.diag(D  ** (-0.5))
        self.W = U.dot(inv_sqrt_D.dot(U.T))         
        

    def __call__(self, input_maps : jnp.ndarray):
        """
        Suppose images are squarred
        :param input_maps : array of size (B, H, W, C) B is batch size, C is channel size (e.g. 3)
        """

        if self.Z is None:
             raise Warning('Filters Z have to be initialized or learned. Call .train() befor evaluating model')

        batch_size, h, w, in_channels = input_maps.shape

        for i in range(batch_size):
            # add padding, so that there is a patch for each pixel in input map
            pad = (self.patch_size - 1)
            p = pad // 2
            image = jnp.zeros((h+pad, w+pad, in_channels))
            image = image.at[p:p+h, p:p+w, :].set(input_maps[i])

            patches = self.extract_patches_2d(image, inference=True)
            patches = patches.reshape(-1, in_channels*self.patch_size**2)

            normalized_x, norm_x = normalize_row(patches)
    
            out_map = norm_x * kappa( normalized_x.dot(self.Z.T), alpha=1/(.25)).dot(self.W)
            out_map = out_map.reshape(h, w, self.out_channels)
            # linear pooling
            beta = jnp.square(jnp.sqrt(2)/self.sampling_factor) # jnp.square(1 / (h*self.sampling_factor))

            conv_pool = linear_gaussian_pooling(out_map, beta=beta,
                                                sampling_factor=self.sampling_factor)
            if i == 0:
                out_maps = conv_pool[None, ...]
            else :
                out_maps = jnp.vstack((out_maps,
                                    conv_pool[None, ...]))
                
            if jnp.isnan(out_maps).any():
                print('i', i)
                raise Warning(f'Output map has {jnp.isnan(out_maps).sum()} NaN values')
            
        return out_maps
    

class ModelCKN:
    def __init__(self, patch_sizes : list,
                 out_channels : list,
                 subsampling_factors : list,
                 n_patch_per_img_for_kmean : int = 20):
        self.n_layers = len(out_channels)
        self.layers = []
        for i in range(self.n_layers):
            self.layers.append( ConvKN(patch_size=patch_sizes[i], 
                                       out_channels=out_channels[i],
                                       sampling_factor=subsampling_factors[i],
                                       n_patch_per_img_for_kmean = n_patch_per_img_for_kmean)
                                       )
            
    def train(self, key, input_maps, batch_size=2000):
        keys = [key]

        N = input_maps.shape[0]
        train_idx = jax.random.choice(key, jnp.arange(N), (N//batch_size, batch_size), replace=False)
        for b in range(len(train_idx)):
            input_maps_layer = input_maps[train_idx[b]]
            keys = jax.random.split(keys[-1], self.n_layers)
            print('b', b)
            for i in range(self.n_layers):
                self.layers[i].train(keys[i], input_maps_layer)
                input_maps_layer = self.layers[i](input_maps_layer)
                print(f'Layer {i}: out max {input_maps_layer.max()} out min {input_maps_layer.min()} Z max {self.layers[i].Z.max()} W max {self.layers[i].W.max()}')
                print(f'Output map shape {input_maps_layer.shape}')

    def __call__(self, input_maps):
    
        for i in range(self.n_layers):
            input_maps = self.layers[i](input_maps)

        return input_maps

In [7]:
data_folder = os.path.join(os.getcwd(), 'data')
Xtr, Ytr, Xte = load_data(data_folder, reshape=True)
print(f"Xtr {Xtr.shape}; Ytr {Ytr.shape}; Xte {Xte.shape}")

Xtr (5000, 32, 32, 3); Ytr (5000,); Xte (2000, 32, 32, 3)


In [8]:
# Check ConvKN layer works
key = jax.random.PRNGKey(6)
model = ConvKN(patch_size=2, out_channels=64, sampling_factor=2, n_patch_per_img_for_kmean=20)
model.train(key, Xtr[:20])
out = model(Xtr[:20])
jnp.isnan(out).any()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


For training kmeans, X shape (400, 12)
Spherical Kmean main loop starts
Spherical K means at iter 10 :  mean cosine similarity 0.8659
End of main loop


Array(False, dtype=bool)

In [15]:
# Maybe try 2 layers w [128, 256] out_channels and [3, 2] or [2, 1]  patch size 
key = jax.random.PRNGKey(9) 
myCKN = ModelCKN(patch_sizes=[3, 2, 2], 
                 out_channels=[64, 128, 256],
                 subsampling_factors=[2, 4, 4],
                 n_patch_per_img_for_kmean=20)
key = jax.random.PRNGKey(0)
t0 = time.time()
idx = jax.random.choice(key, jnp.arange(5000), (2000,))
myCKN.train(key, Xtr[idx], batch_size=500)
t1 = time.time()

print(f"Training myCKN took {(t1-t0)/60:.1} min ")

In [11]:
models_folder = os.path.join(os.getcwd(), 'models')

In [12]:
if not os.path.exists(models_folder):
    os.makedirs(models_folder)
    
with open(f'{models_folder}/myckn_f.pkl', 'wb') as f:
    pickle.dump(myCKN, f)

In [19]:
with open(f'{models_folder}/myckn_f.pkl', 'rb') as f:
    myCKN = pickle.load(f)
    
myCKN.layers[0].W

Array([[ 1.0860980e+00, -1.3044222e-03,  6.2260870e-03, ...,
        -4.8774155e-04,  6.0615120e-03, -4.4224094e-04],
       [-1.3044182e-03,  1.0384851e+00,  8.0706197e-04, ...,
        -1.9493288e-03,  9.8094472e-04,  2.9343476e-03],
       [ 6.2260889e-03,  8.0704771e-04,  1.1474202e+00, ...,
        -4.1138922e-04, -1.6049797e-02,  4.4441978e-03],
       ...,
       [-4.8774216e-04, -1.9493316e-03, -4.1136827e-04, ...,
         1.1899413e+00,  1.5721554e-02, -8.4726915e-02],
       [ 6.0615172e-03,  9.8093669e-04, -1.6049797e-02, ...,
         1.5721556e-02,  1.0434704e+00, -8.4156178e-02],
       [-4.4224047e-04,  2.9343530e-03,  4.4441856e-03, ...,
        -8.4726900e-02, -8.4156148e-02,  1.0826175e+00]], dtype=float32)

In [20]:
print('Compute output of saved CKN model')
t0 = time.time()
outputs = []
# To avoid jax out of memory issues
for i in range(5):
    out = myCKN(Xtr[i*1000:(i+1)*1000])
    outputs.append(out.reshape(out.shape[0], -1))
t1 = time.time()
print(f'Evaluate myCKN took {(t1-t0)/60:.1} min')

Compute output of saved CKN model
Evaluate myCKN took 1e+02 min


In [21]:
outputs = jnp.vstack(outputs)
print(outputs.shape)

(5000, 2048)


In [22]:
# z-score output
mu = outputs.mean(axis=0, keepdims=True)
X = outputs - mu
s = jnp.sqrt( jnp.mean( X**2, axis=0, keepdims=True) )
X = X / s
print('Input shape', X.shape)
print(f'mu {mu}, s {s}')


jnp.save(os.path.join(models_folder, 'scaler_features_means.npy'), mu)
jnp.save(os.path.join(models_folder, 'scaler_features_std.npy'), s)

Input shape (5000, 2048)
mu [[0.1365981  0.09182774 0.10040864 ... 0.42932993 0.37867564 0.31226152]], s [[0.20315729 0.09915642 0.09488929 ... 0.3609316  0.26758936 0.17920268]]


In [16]:
ntrain = X.shape[0] - 80
kernel_func = RBF(sigma=jnp.sqrt(X.shape[1]))
my_svm = MultiClassKernelSVM(num_classes=10, kernel_func=kernel_func, c=1)
my_svm.fit(X[:ntrain], Ytr[:ntrain])

preds = my_svm.predict(X[ntrain:ntrain+80])
print('Acc', jnp.mean(preds == Ytr[ntrain:ntrain+80]))

In [17]:
if not os.path.exists(models_folder):
    os.makedirs(models_folder)
with open(f'{models_folder}/svm_full.pkl', 'wb') as f:
    pickle.dump(my_svm, f)

In [18]:
with open(f'{models_folder}/svm_full.pkl', 'rb') as f:
    my_svm = pickle.load(f)

In [19]:
print('Compute Xte features from CKN')
t0 = time.time()
out_test = myCKN(Xte)
out_test = out_test.reshape(out_test.shape[0], -1)
t1 = time.time()
print(f'Evaluating CKN on test data took {(t1-t0)/60:.1} min')

mu = out_test.mean(axis=0)
X = out_test - mu[None, :]
s = jnp.sqrt( jnp.mean( X**2, axis=0) )
X = X / s[None, :]
print('Input shape', X.shape)
print(f'mu {mu}, s {s}')

Compute Xte features from CKN
Evaluating CKN on test data took 3e+01 min
Input shape (2000, 2048)
mu [0.13454047 0.08940188 0.10523514 ... 0.4221093  0.37207156 0.31006384], s [0.19884168 0.0885532  0.10756616 ... 0.36169964 0.27463686 0.18510593]


In [20]:
Yte = my_svm.predict(X)
Yte = {"Prediction": Yte}
dataframe = pd.DataFrame(Yte)
dataframe.index += 1

dataframe.to_csv(f"{data_folder}/Yte.csv", index_label="Id")

Function predict took 15.40 seconds
