In [None]:
import matplotlib
import numpy as np
import scipy as scp
import pywt
import matplotlib.pyplot as plt
%matplotlib inline

import os
import pickle

import pdb

In [None]:
from sklearn.svm import SVC, LinearSVC
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import label_binarize

from sklearn.pipeline import Pipeline

In [None]:
from transform_scatt import load_transform, Concat_scat_tree, Concat_scal, JointScat

In [None]:
from compute_features import compute_features, load_features
from plot_utils import plot_confusion_matrix

# Hyperparameters

In [None]:
# Location of the ESC-50 dataset and log-compression settings.
db_location = '/users/data/blier/ESC-50/'
log_features = True
log_eps = 0.001  # offset added inside np.log(...) when preparing the network inputs

connections = 'pca_net'

# Name of the feature set (a single string). Other values mentioned previously:
# 'plain_scat_1', 'plain_scat_2'.
features = 'mfcc'

params = dict(
    channels=(84, 12),
    hops=(512, 4),
    fmin=32.7,
    fmax=11001,
    alphas=(6, 6),  # only used for flex scattering
    Qs=(12, 12),    # only used for flex scattering
    nclasses=5,
    n_itemsbyclass=40,
    max_sample_size=2**17,
    audio_ext='*.ogg',
)

# Scattering layout shared by the loading, model and input-splitting cells.
# NOTE(review): nfo / nfo2 are presumably filters per octave at orders 1 and 2 — confirm.
nOctaves = 10
nfo = 12
nfo2 = 12

## Load and prepare the data

In [None]:
# Select how (X, y) is obtained:
#   'raw'   -> precomputed features read by load_features
#   'scat'  -> coefficients read by load_transform through a Concat_scat_tree(M=2) transform
#   'joint' -> coefficients read by load_transform through a JointScat transform
load = 'raw'

if load == 'raw':
    # To recompute features from the audio files instead:
    #   X, y = compute_features("/users/data/blier/ESC-50", features, params)
    directory = "/users/data/blier/features_esc50/scat_10_12_12/"
    X, y = load_features(directory, params['nclasses'], params['n_itemsbyclass'])
elif load == 'scat':
    # NOTE(review): directory name suggests 8 octaves / nfo2=1, while nOctaves=10 and
    # nfo2=12 are passed to the transform — confirm these are compatible.
    directory = "/users/data/blier/features_esc50/scat_8_12_1/"
    trans_obj = Concat_scat_tree(M=2, transf=None, nOctaves=nOctaves, nfo=nfo, nfo2=nfo2)
    X, y = load_transform(directory, params['nclasses'], params['n_itemsbyclass'], trans_obj)
elif load == 'joint':
    vert_wav = JointScat(nOctaves, nfo, nfo2)
    directory = "/users/data/blier/features_esc50/scat_8_12_1/"
    X, y = load_transform(directory, params['nclasses'], params['n_itemsbyclass'], vert_wav)
else:
    raise ValueError("Unknown loading parameter")

In [None]:
#X[0].shape

In [None]:
from learned_joint_scat import learn_joint_scat_model

# Build the learned joint-scattering network and compile it for multi-class
# classification (targets are one-hot, see label_binarize further down).
model = learn_joint_scat_model(
    nOctaves,
    nfo,
    nfo2,
    filter_factor=2,
    nClasses=params['nclasses'],
    n_samples=256,  # NOTE(review): semantics defined in learned_joint_scat — confirm
)
model.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['categorical_accuracy'],
)


def scat_to_list(X, n_octaves=None, n_fo=None, n_fo2=None):
    """Split per-sample scattering outputs into the list of network inputs.

    Each element ``x`` of ``X`` is indexed as ``x[0]``, ``x[1]``, ``x[2]``
    (presumably zeroth-, first- and second-order coefficients). For every
    ``j2`` in ``1 .. n_octaves - 1``, each sample's ``x[2]`` is sliced to its
    first ``j2 * n_fo`` entries along axis 0 and columns
    ``j2 * n_fo2 : (j2 + 1) * n_fo2`` along axis 1 (all of axis 2). The slices
    for one ``j2`` are then stacked over samples.

    Args:
        X: sequence of per-sample 3-element containers; ``x[2]`` must be 3-D.
        n_octaves, n_fo, n_fo2: layout parameters. When omitted they fall back
            to the notebook-level ``nOctaves``, ``nfo`` and ``nfo2``, which
            matches the previous behaviour.

    Returns:
        ``(X0, X1, X2_list)``. ``X0`` and ``X1`` are lists of the per-sample
        ``x[0]`` and ``x[1]``. ``X2_list`` holds ``n_octaves - 1`` stacked
        arrays of shape ``(len(X), j2 * n_fo, n_fo2, x2.shape[2])``, provided
        ``x[2]`` is large enough along axes 0 and 1 (numpy slicing truncates
        otherwise).
    """
    # Fall back to the notebook globals so existing calls behave as before.
    if n_octaves is None:
        n_octaves = nOctaves
    if n_fo is None:
        n_fo = nfo
    if n_fo2 is None:
        n_fo2 = nfo2

    X0, X1, X2 = [[x[i] for x in X] for i in range(3)]
    X2_list = [
        np.stack([x2[:j2 * n_fo, j2 * n_fo2:(j2 + 1) * n_fo2, :] for x2 in X2])
        for j2 in range(1, n_octaves)
    ]
    return X0, X1, X2_list


# Assemble the network inputs: log-compressed zeroth/first-order coefficients
# averaged over axis 1 / axis 2 respectively, followed by the per-octave
# second-order blocks.
X0, X1, X2_list = scat_to_list(X)
X0 = np.mean(np.log(log_eps + np.abs(X0)), axis=1)
X1 = np.mean(np.log(log_eps + np.abs(X1)), axis=2)
inputs = [X0, X1]
inputs.extend(X2_list)

# `classes` is keyword-only in scikit-learn >= 1.0; passing it positionally raises TypeError.
y_binarized = label_binarize(y, classes=np.arange(params['nclasses']))

# Keras' validation_split takes the *last* fraction of the samples, before shuffling.
# If the data is ordered by class (NOTE(review): likely, since it is loaded per class —
# confirm), validation would contain only the last class(es). Shuffle once with a fixed seed.
perm = np.random.RandomState(42).permutation(len(y_binarized))
fit_inputs = [np.asarray(a)[perm] for a in inputs]
# NOTE(review): `nb_epoch` is the Keras 1 spelling; Keras 2+ expects `epochs`.
model.fit(fit_inputs, y_binarized[perm], nb_epoch=50, batch_size=32, validation_split=0.20)

In [None]:
# Candidate classifiers. Only the linear SVM goes into the pipeline;
# the random forest is kept as an alternative.
classifier1 = SVC(kernel='linear', C=1.)
classifier2 = RandomForestClassifier()

# Single-step pipeline. Alternative front-ends from commented-out code that can be
# prepended as ('concat', ...): Concat_scat_tree(M=3, transf='mean', nOctaves=nOctaves,
# nfo=nfo, nfo2=nfo2), Vertical_wavelet(), Concat_scal(transf='max').
estimators = [('classifier', classifier1)]
pipe = Pipeline(steps=estimators)

In [None]:
# Evaluate the pipeline. Cross-validation is optional. The hold-out split is always
# built and `pipe` is always fitted on it, because the confusion-matrix cells below
# need X_test / y_test and a fitted pipe.
cross_valid = True
if cross_valid:
    # cross_val_score fits clones of `pipe`; `pipe` itself stays unfitted.
    scores = cross_val_score(pipe, X, y, cv=10)
    print(scores, scores.mean())

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y)
pipe.fit(X_train, y_train)
score_train = pipe.score(X_train, y_train)
score_test = pipe.score(X_test, y_test)
# Top-level last expression, so Jupyter displays it (it was not displayed inside `else`).
score_train, score_test

In [None]:
# `scores` only exists when cross-validation ran in the evaluation cell.
scores.mean() if cross_valid else None

In [None]:
# Predictions on the hold-out split (requires X_test / y_test and a fitted `pipe`).
y_pred = pipe.predict(X_test)
conf = confusion_matrix(y_test, y_pred)

# Coarse ("super-class") confusion: labels are integer-divided by the category size.
# Assumes consecutive blocks of CLASSES_PER_CATEGORY labels share a major category
# (ESC-50 layout) — TODO confirm against load_features' label ordering.
# With fewer classes than CLASSES_PER_CATEGORY everything collapses to one category.
CLASSES_PER_CATEGORY = 10
conf_surclasses = confusion_matrix(y_test // CLASSES_PER_CATEGORY,
                                   y_pred // CLASSES_PER_CATEGORY)

## Confusion Matrix

In [None]:
# Per-class confusion matrix; cell values are not drawn (plot_values=False).
class_ticks = range(params["nclasses"])
plot_confusion_matrix(conf, class_ticks, plot_values=False, normalize=True)

In [None]:
# Super-class confusion matrix. One tick label per matrix row, so the labels match
# however many categories the integer division produced (not a hard-coded 5).
plot_confusion_matrix(conf_surclasses, range(conf_surclasses.shape[0]),
                      plot_values=True, normalize=True)