# Imports

In [1]:
import scipy.io as spio
import matplotlib.pyplot as plt
import numpy as np
import random
import spams
from scipy import sparse
from sklearn.feature_extraction import image

# Read-in HSI

In [2]:
#pavia_mat = spio.loadmat('data/PaviaU/PaviaU.mat')
#pavia_gt_mat = spio.loadmat('data/PaviaU/PaviaU_gt.mat')
#pavia_image = pavia_mat['paviaU']
#pavia_gt = pavia_gt_mat['paviaU_gt']

# IP_data_mat=spio.loadmat("data/IP/Indian_pines_corrected.mat")
# IP_gt_mat=spio.loadmat("data/IP/Indian_pines_gt.mat")
# IP_image=IP_data_mat["indian_pines_corrected"]
# IP_gt=IP_gt_mat['indian_pines_gt']

salinas_mat = spio.loadmat('data/Salinas/salinas.mat')
salinas_gt_mat = spio.loadmat('data/Salinas/salinas_gt.mat')
salinas_image = salinas_mat['salinasA_corrected']
salinas_gt = salinas_gt_mat['salinasA_gt']

# JSR Dictionary Learning

In [3]:
def progress(curr_it, max_it):
    if (curr_it) < max_it:
        print('[{0}] {1}%'.format((curr_it+1), np.round(((curr_it+1)/max_it)*100,2)), end='\r')
    else:
        print('[{0}] {1}%'.format((curr_it+1), np.round(((curr_it+1)/max_it)*100,2)), end='\n')

# Dictionary update step from ODL, taken from J Mairal "Online Dictionary Learning for Sparse Coding"
def updateDict(D, A, B):
    D = D.copy()
    DA  = D.dot(A)
    count = 0
    for j in range(D.shape[1]):
        u_j = (B[:, j] - np.matmul(D, A[:, j]))/A[j, j] + D[:, j]
        D[:, j] = u_j/max([1, np.linalg.norm(u_j)])
    D = np.array(D, dtype=np.double)
    return D

class Jsr_odl:
    def drawRand(self):
        rand = random.randint(0,self.tiles.shape[0]-1)
        return self.tiles[:,rand]
    def normaliseD(self):
        D = self.atoms.copy()
        self.atoms = np.asfortranarray(D / np.tile(np.sqrt((D*D).sum(axis=0)),(D.shape[0],1)),dtype=np.double)
    def tile(self):
#         hsi = self.image.copy()
#         ts = self.ts
#         tiles = np.array([hsi[i:i+ts, j:j+ts]for j in range(0, hsi.shape[1], ts) for i in range(0, hsi.shape[0], ts)])
#         tiles = np.reshape(tiles, (tiles.shape[0], ts*ts, tiles.shape[3]))
#         tiles = np.reshape(tiles, (tiles.shape[0]*tiles.shape[1], tiles.shape[2]))
#         tiles=np.array(tiles, dtype=np.double)
#         tiles/=np.linalg.norm(tiles, axis=1, keepdims=True)
        ts = self.ts
        thsi = np.pad(self.image.copy(), [1,1], mode='edge')
        hsi = thsi[:,:,1:thsi.shape[2]-1]
        tiles = image.extract_patches_2d(hsi, patch_size=[ts, ts], max_patches=None)
        tiles = np.reshape(tiles, (tiles.shape[0], ts*ts, tiles.shape[3]))
        tiles = np.reshape(tiles, (tiles.shape[0]*tiles.shape[1], tiles.shape[2]))
        tiles=np.array(tiles, dtype=np.double)
        tiles/=np.linalg.norm(tiles, axis=1, keepdims=True)
        return np.asfortranarray(tiles.T, dtype=np.double)
    def initD(self):
        self.atoms = np.zeros((self.tiles.shape[0], self.num_coms))
        for i in range(self.num_coms):
            self.atoms[:,i]=self.drawRand()
        self.atoms=np.asfortranarray(self.atoms, dtype=np.double)
    def __init__(self, image, k=100, tilesize=10):
        if k <= image.shape[2]:
            print('Select an adequate number of components for %s.' %str(image.shape))
            return
        else:
            self.num_coms = k
        self.image = image
        self.ts = tilesize
        self.tiles=self.tile()
        assert np.allclose(np.linalg.norm(self.tiles, axis=0), 1.)
        self.atoms=None
        self.coefs=None
        self.initD()
        assert np.allclose(np.linalg.norm(self.atoms, axis=0), 1.)
        self.ind_groups=np.array(np.arange(0,self.tiles.shape[1],self.ts*self.ts), dtype=np.int32)
    def get_coefs(self):
        self.coefs = spams.somp(self.tiles[:,:], self.atoms, self.ind_groups, L=self.L, eps=self.eps, numThreads=self.num_threads)
        return self.coefs
    def fit(self, max_iter=1000, L=3, eps=0.1, numThreads=-1):
        self.max_iter=max_iter
        self.L=L
        self.eps=eps
        self.num_threads=numThreads
        A=np.zeros((self.num_coms, self.num_coms))
        B=np.zeros((self.tiles.shape[0], self.num_coms))
        for i in range(self.max_iter):
            progress(i, self.max_iter)
            random_index = random.choice(self.ind_groups)
            signals = self.tiles[:,random_index:(random_index+(self.ts**2))]
            alphas = spams.somp(signals, self.atoms, self.ind_groups[:1], L=self.L, eps=self.eps, numThreads=self.num_threads)
            alphas = sparse.coo_matrix(alphas)
            alphas = alphas.todense()
            for j in range(alphas.shape[1]):
                alpha=alphas[:, j]
                signal=signals[:,j]
                A += (alpha.T.dot(alpha))
                B += (signal[:, None]*alpha[:].T)
                self.atoms = updateDict(self.atoms, A, B)
                self.atoms = np.asfortranarray(self.atoms)
                self.normaliseD()

## Experiment with tilesize

In [5]:
tilesizes = [9, 11]

for ts in tilesizes:
    jsr = Jsr_odl(image=salinas_image, k=330, tilesize=ts)
    jsr.fit(max_iter=5000, L=3, eps=0.1, numThreads=-1)
    final_d = jsr.atoms
    print('results/JSR/salinas_k%d_L%d_t%d_eps001_ts%d\n'%(330, 3, 5000, ts))
    print(final_d)
    print('\n')
    np.save('results/JSR/salinas_k%d_L%d_t%d_eps001_ts%d'%(330, 3, 5000, ts), final_d)

results/JSR/salinas_k330_L3_t5000_eps001_ts9

[[ 5.93725336e-03 -1.55883578e-02 -1.24121890e-02 ...  1.06615079e-02
   1.04379314e-02  1.00466524e-02]
 [-1.28761095e-01 -2.24541615e-01 -7.26078667e-02 ...  1.33207234e-02
   1.33327900e-02  1.30584566e-02]
 [ 3.73916556e-02  8.05717818e-03 -1.27416428e-02 ...  1.92227450e-02
   1.92795936e-02  1.92230811e-02]
 ...
 [ 1.96181008e-03  2.37066048e-03  5.60560636e-04 ...  7.29358327e-05
   7.96498907e-05  6.76912436e-05]
 [-4.20170753e-03 -6.56205373e-03 -1.86885317e-03 ...  8.37232405e-05
   5.69903648e-05  3.64877653e-05]
 [ 3.90223911e-03  5.23170030e-03  1.33838379e-03 ...  9.21971392e-05
   1.01406867e-04  1.12204660e-04]]


results/JSR/salinas_k330_L3_t5000_eps001_ts11

[[ 1.17651969e-02 -6.18741601e-03 -1.03252943e-02 ...  1.06357098e-02
   1.06296037e-02  1.05920113e-02]
 [ 4.28532400e-02  7.46253557e-02  9.40297077e-03 ...  1.28320198e-02
   1.30273938e-02  1.32021865e-02]
 [ 2.37659018e-02 -2.74298346e-03 -1.57410476e-02 ...  1.88