# Test du solveur par transport optimal

Le but de ce notebook est d'analyser les résultats fournis par notre solveur par transport optimal, défini dans un précédent notebook.

## Code du solveur

Voici le code précédemment généré pour le solveur.

In [2]:
import ot
import scipy.optimize, scipy.signal
import numpy as np
import matplotlib.pyplot as plt

def frequential_cost_matrix(F):
    """
    Generates a frequential cost matrix, i.e. the frequency moving cost.
    Entry (i, j) of the result is the squared difference (F[i] - F[j])**2.
    ---
    Inputs:
        - F : 1-D float array. The range of frequencies to be considered.
    ---
    Outputs:
        - C : 2-D float array of shape (M, M), with M = F.shape[0].
                    The wanted cost matrix.
    """
    F = np.asarray(F)
    # Broadcasting a column against a row yields every pairwise difference at
    # once, instead of M**2 iterations of a Python-level double loop.
    pairwise_diff = F[:, np.newaxis] - F[np.newaxis, :]
    # Cast so the result is float64 whatever the dtype of F, as it would be
    # when filling a matrix created with np.zeros.
    return (pairwise_diff ** 2).astype(np.float64)

def simplex_OT_canonical_matrices(v, W, C):
    """
    Generates the canonical matrices for the simplex algorithm.

    The unknown vector x has M**2 + K entries, with M = v.shape[0] and
    K = W.shape[1]: first an (M, M) matrix T flattened row by row
    (T[i, j] sits at index i*M + j), then a vector h of K entries.
    The equality system A x = b encodes two families of M constraints:
        - rows 0..M-1   : sum_j T[i, j] = v[i]
        - rows M..2M-1  : sum_i T[i, j] - sum_k W[j, k] * h[k] = 0
    ---
    Inputs:
        - v : 1-D float array. The target trame to be matched.
        - W : 2-D float array. The signal dictionnary.
        - C : 2-D float array. The frequential cost matrix.
    ---
    Outputs:
        - c : 1-D float array. The objective function vector for the simplex algorithm
                    (i.e. the flattened cost matrix, followed by K zeros so that
                    h carries no cost).
        - A : 2-D float array. The left-side (variables sided) equality matrix for the simplex algorithm.
        - b : 1-D float array. The right-side (value constraints) equality vector for the simplex algorithm.
    """
    M = v.shape[0]
    K = W.shape[1]
    n_plan = M * M  # number of entries of the flattened matrix T

    c = np.zeros((n_plan + K,))
    A = np.zeros((2 * M, n_plan + K))
    b = np.zeros((2 * M,))

    # Row-major flattening of the (M, M) cost block: index i*M + j <-> C[i, j].
    c[:n_plan] = np.asarray(C)[:M, :M].ravel()

    identity = np.eye(M)

    # First constraints set: row i has ones on columns i*M .. i*M + M - 1
    # (every T[i, j] for that i), and the right-hand side is v[i].
    A[:M, :n_plan] = np.repeat(identity, M, axis=1)
    b[:M] = v

    # Second constraints set: row M + j has ones on columns j, M + j, 2M + j, ...
    # (every T[i, j] for that j) and -W[j, k] on the h columns; the right-hand
    # side stays zero.
    A[M:, :n_plan] = np.tile(identity, (1, M))
    A[M:, n_plan:] = -W[:M, :]

    return c, A, b

def simplex_algorithm(c, A, b):
    """
    Solves the equality-constrained linear program through scipy.optimize.linprog.
    No `method` argument is passed, so the solver is SciPy's default one.
    A failure message (if any) and the iteration count are printed.
    ---
    Inputs:
        - c : 1-D float array. The objective function vector
                    (i.e. the flattened cost matrix).
        - A : 2-D float array. The left-side (variables sided) equality matrix.
        - b : 1-D float array. The right-side (value constraints) equality vector.
    ---
    Outputs:
        - x : 1-D float array. The solution reported by the solver,
                    i.e. the vector that minimizes <c.x>,
                    so that Ax = b and x >= 0.
                    It is returned as reported, even when the solver signals a failure.
    """
    # Every variable is constrained to [0, +inf).
    non_negative = (0, None)
    outcome = scipy.optimize.linprog(c, A_eq=A, b_eq=b, bounds=non_negative)

    solved = outcome.success
    if not solved:
        print("Failure: ", outcome.message)
    print("Simplex iterations:", outcome.nit)

    return outcome.x

def compute_NMF(v,C,W):
    """
    Computes a Non-negative Matrix Factorization of v based on W, optimizing with respects to C.
    Builds the linear program with simplex_OT_canonical_matrices, solves it with
    simplex_algorithm and extracts the last K = W.shape[1] entries of the solution.
    Note that the argument order here is (v, C, W), whereas
    simplex_OT_canonical_matrices takes (v, W, C).
    ---
    Inputs:
        - v : 1-D float array. The target trame to be matched.
        - C : 2-D float array. The frequential cost matrix.
        - W : 2-D float array. The signal dictionnary.
    ---
    Outputs:
        - h : 1-D float array. An optimal vector so that v ~ Wh, with respects to C.
    """
    c, A, b = simplex_OT_canonical_matrices(v, W, C)
    th = simplex_algorithm(c, A, b)
    K = W.shape[1]
    # The solver vector holds the flattened transport variables first and the K
    # entries of h last. An explicit start index is used because th[-K:] would
    # return the whole vector (instead of an empty one) when K == 0.
    h = th[len(th) - K:]
    return h

def slice_spectrogram(f, t, S, fmin=0, fmax=50000):
    """
    Slices the frequencies of a spectrogram: keeps the entries of f lying in
    [fmin, fmax] (both bounds included) and the matching entries of S along
    its first axis.
    ---
    Inputs:
        - f : 1-D float array. The spectrogram's frequency values.
        - t : 1-D float array. The spectrogram's time values.
        - S : 2-D float array. The spectrogram, indexed by frequency along
                    its first axis.
    ---
    Parameters:
        - fmin : float. The slice's minimal frequency.
        - fmax : float. The slice's maximal frequency.
    ---
    Outputs:
        - f : 1-D float array. The sliced spectrogram's frequency values.
        - t : 1-D float array. The spectrogram's time values, returned unchanged.
        - S : 2-D float array. The sliced spectrogram.
    """
    # A boolean mask selects along the first axis directly; indexing with the
    # tuple returned by np.where would add a leading axis that then has to be
    # stripped with [0].
    in_band = (f >= fmin) & (f <= fmax)

    return f[in_band], t, S[in_band, :]