## MusicLIME

The class for multimodal text and audio explanations is provided below. Functions and classes (from AudioLIME) needed to run MusicLIME are also provided in the next cell.

In [None]:
!pip install -q openunmix
#!pip install -q spleeter

In [None]:
import ...

In [None]:
class OpenunmixFactorization(SourceSeparationBasedFactorization):
    """Source-separation factorization whose sources come from ``predict.separate``."""

    def __init__(self, input, temporal_segmentation_params, composition_fn, target_sr=44100):
        # The parent expects ``target_sr`` as its second positional argument.
        super().__init__(input, target_sr, temporal_segmentation_params, composition_fn)

    def initialize_components(self):
        """Separate the original mix and return ``(components, names)``.

        The mix is resampled from ``self.target_sr`` to the separation rate,
        separated, and every estimate is resampled back to ``self.target_sr``.

        Returns:
            A list with one waveform per estimated source and the list of the
            corresponding keys of the separation result.
        """
        separation_sr = 44100

        mix = librosa.resample(self._original_mix, orig_sr=self.target_sr, target_sr=separation_sr)
        mix = np.expand_dims(mix, axis=1)
        estimates = predict.separate(torch.as_tensor(mix).float(), rate=separation_sr)

        components_names = list(estimates.keys())
        original_components = []
        for source_name in components_names:
            # First entry of the estimate, averaged over its leading axis
            # (presumably the channel axis -- confirm against the separator output).
            mono = estimates[source_name][0].mean(dim=0).numpy()
            original_components.append(
                librosa.resample(mono, orig_sr=separation_sr, target_sr=self.target_sr))

        return original_components, components_names

# MusicLIME can be implemented with the original audioLIME factorization function (spleeter)
# class SpleeterFactorization(SourceSeparationBasedFactorization):
#     def __init__(self, input, temporal_segmentation_params, composition_fn, target_sr=16000,
#                  model_name="spleeter:5stems"):
#         self.model_name = model_name
#         super().__init__(input, target_sr, temporal_segmentation_params, composition_fn)

#     def initialize_components(self):
#         spleeter_sr = 44100

#         waveform = self._original_mix
#         separator = Separator(self.model_name, multiprocess=False)
#         waveform = librosa.resample(waveform, self.target_sr, spleeter_sr)
#         waveform = np.expand_dims(waveform, axis=1)
#         prediction = separator.separate(waveform)

#         original_components = [
#             librosa.resample(np.mean(prediction[key], axis=1), spleeter_sr, self.target_sr) for
#             key in prediction]

#         components_names = list(prediction.keys())
#         return original_components, components_names

class LimeMusicExplainer(object):

    def __init__(self,
                 kernel_width=25,
                 kernel=None,
                 verbose=False,
                 class_names=None,
                 feature_selection='auto',
                 absolute_feature_sort=False,
                 split_expression=r'\W+',
                 bow=True,
                 mask_string=None,
                 random_state=None,
                 char_level=False):
        """Configure the multimodal (lyrics + audio) explainer.

        Args:
            kernel_width: width handed to the kernel function (converted to float).
            kernel: callable ``kernel(d, kernel_width)``; when None an
                exponential kernel is used.
            verbose: forwarded to the base explainer.
            class_names: stored on the instance.
            feature_selection: feature selection method used when explaining.
            absolute_feature_sort: forwarded to the base explainer.
            split_expression: regular expression used to split the text.
            bow: bag-of-words flag used when indexing the text.
            mask_string: replacement string used when indexing the text.
            random_state: seed or RandomState, normalised with ``check_random_state``.
            char_level: when True the text is indexed per character.
        """
        width = float(kernel_width)

        if kernel is None:
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

        weighting_fn = partial(kernel, kernel_width=width)

        self.random_state = check_random_state(random_state)
        # NOTE(review): ``absolute_feature_sort`` is passed as the third
        # positional argument; confirm that the imported ``lime_base.LimeBase``
        # expects it in that position.
        self.base = lime_base.LimeBase(weighting_fn, verbose, absolute_feature_sort)

        self.feature_selection = feature_selection
        self.class_names = class_names
        self.vocabulary = None

        # Settings used later to index the lyrics.
        self.split_expression = split_expression
        self.mask_string = mask_string
        self.char_level = char_level
        self.bow = bow

    def explain_instance(self,
                         factorization,
                         text_instance,
                         predict_fn,
                         labels=None,
                         top_labels=None,
                         num_reg_targets=None,
                         num_features=100000,
                         num_samples=1000,
                         batch_size=10,
                         distance_metric='cosine',
                         model_regressor=None,
                         random_seed=None,
                         fit_intercept=True,
                         modality='both'):
        if labels or top_labels:
            is_classification = True
        if is_classification and num_reg_targets:
            raise ValueError('Set labels or top_labels for classification. '
                             'Set num_reg_targets for regression.')
        if modality not in ['both', 'lyrical', 'audio']:
             raise ValueError('Set modality arguement to "both", "lyrical" or "audio".')

        if random_seed is None:
            random_seed = self.random_state.randint(0, high=1000)

        self.factorization = factorization
        top = labels

        indexed_string = (lime_text.IndexedCharacters(text_instance, bow=self.bow, mask_string=self.mask_string) if self.char_level else
                          lime_text.IndexedString(text_instance, bow=self.bow, split_expression=self.split_expression, mask_string=self.mask_string))
        domain_mapper = lime_text.TextDomainMapper(indexed_string)

        data, labels, distances = self.combined_data_labels_distances(indexed_string, predict_fn,
        	num_samples, batch_size=batch_size, distance_metric=distance_metric, modality=modality)

        ret_exp = MultimodalExplanation(indexed_string, self.factorization, data, labels, modality=modality)

        if top_labels:
            top = np.argsort(labels[0])[-top_labels:]
            ret_exp.top_labels = list(top)
            ret_exp.top_labels.reverse()
        for label in top:
            (ret_exp.intercept[label],
             ret_exp.local_exp[label],
             ret_exp.score[label], ret_exp.local_pred[label]) = self.base.explain_instance_with_data(
                data, labels, distances, label, num_features,
                model_regressor=model_regressor,
                feature_selection=self.feature_selection,
                )

        return ret_exp

    def combined_data_labels_distances(self,
                                       indexed_string,
                                       predict_fn,
                                       num_samples,
                                       modality,
                                       batch_size=10,
                                       distance_metric='cosine'):
        """Sample binary perturbations and query the model on them.

        In ``'both'`` mode each row is laid out as audio components first,
        followed by the words of the text.

        Args:
            indexed_string: indexed lyrics (provides ``num_words`` and
                ``inverse_removing``).
            predict_fn: callable ``predict_fn(texts, audios)``.
            num_samples: number of perturbed samples (row 0 is all ones).
            modality: ``'both'``, ``'lyrical'`` or ``'audio'``.
            batch_size: number of samples per ``predict_fn`` call.
            distance_metric: metric for ``sklearn.metrics.pairwise_distances``.

        Returns:
            ``(data, labels, distances)``: the binary sample matrix, the model
            outputs as an array, and the distance of every row to row 0.

        Raises:
            ValueError: if ``modality`` is not one of the supported values.
        """
        doc_size = indexed_string.num_words()
        audio_size = self.factorization.get_number_components()

        if modality == 'both':
            total_features = doc_size + audio_size
        elif modality == 'lyrical':
            total_features = doc_size
        elif modality == 'audio':
            total_features = audio_size
        else:
            raise ValueError('modality must be "both", "lyrical" or "audio".')

        data = self.random_state.randint(0, 2, num_samples * total_features) \
            .reshape((num_samples, total_features))
        data[0, :] = 1

        compose = self.factorization.compose_model_input
        remove_words = indexed_string.inverse_removing

        # Inputs of the modality that is NOT perturbed are the same for every
        # row, so they are built once instead of once per sample.
        full_audio = compose(np.arange(audio_size)) if modality == 'lyrical' else None
        # Integer dtype: the empty selection is used as an index downstream.
        full_text = remove_words(np.array([], dtype=int)) if modality == 'audio' else None

        labels = []
        audios = []
        texts = []

        for row in data:
            if modality == 'lyrical':
                audio = full_audio
                text = remove_words(np.where(row == 0)[0])
            elif modality == 'audio':
                audio = compose(np.where(row != 0)[0])
                text = full_text
            else:
                audio = compose(np.where(row[:audio_size] != 0)[0])
                text = remove_words(np.where(row[audio_size:] == 0)[0])

            audios.append(audio)
            texts.append(text)

            if len(audios) == batch_size:
                labels.extend(predict_fn(texts, np.array(audios)))
                audios = []
                texts = []

        # Last, possibly incomplete, batch.
        if len(audios) > 0:
            labels.extend(predict_fn(texts, np.array(audios)))

        distances = sklearn.metrics.pairwise_distances(
            data, data[0].reshape(1, -1), metric=distance_metric).ravel()

        return data, np.array(labels), distances


class MultimodalExplanation(object):
    """Container for the local explanation of a (lyrics, audio) instance."""

    def __init__(self, indexed_string, factorization, neighborhood_data, neighborhood_labels, modality):
        """Store the explained inputs and create empty per-label result dicts."""
        self.indexed_string = indexed_string
        self.factorization = factorization
        self.modality = modality
        self.neighborhood_data = neighborhood_data
        self.neighborhood_labels = neighborhood_labels
        # Filled per label by the explainer.
        self.intercept = {}
        self.local_exp = {}
        self.local_pred = {}
        self.score = {}
        self.distance = {}

    def get_sorted_components(self, label, positive_components=True, negative_components=True, num_components='all',
                              min_abs_weight=0.0, return_indeces=False):
        """Return names and weights of the features explaining ``label``.

        Features keep the order stored in ``self.local_exp[label]``. In
        ``'both'`` mode indices below the number of audio components name
        audio components; the remaining ones name words.

        Args:
            label: key into ``self.local_exp``.
            positive_components: keep features with positive weight.
            negative_components: keep features with negative weight.
            num_components: ``'all'`` or an int limiting the number of features.
            min_abs_weight: when non-zero, drop features with a smaller
                absolute weight.
            return_indeces: also return the feature indices.

        Returns:
            ``(components, weights)`` or ``(components, weights, indices)``.

        Raises:
            KeyError: if ``label`` was not explained.
            ValueError: if both sign filters are False.
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        if positive_components is False and negative_components is False:
            raise ValueError('positive_components, negative_components or both must be True')

        audio_count = self.factorization.get_number_components()

        pairs = [[entry[0], entry[1]] for entry in self.local_exp[label]]
        feature_ids = np.array(pairs, dtype=int)[:, 0]
        coefs = np.array(pairs)[:, 1]

        # Optional sign filter.
        keep = None
        if not negative_components:
            keep = np.flatnonzero(coefs > 0)
        elif not positive_components:
            keep = np.flatnonzero(coefs < 0)
        if keep is not None:
            feature_ids, coefs = feature_ids[keep], coefs[keep]

        # Optional magnitude filter.
        if min_abs_weight != 0.0:
            strong = np.flatnonzero(np.abs(coefs) >= min_abs_weight)
            feature_ids, coefs = feature_ids[strong], coefs[strong]

        if num_components == 'all':
            num_components = len(feature_ids)
        else:
            assert(isinstance(num_components, int))

        feature_ids = feature_ids[:num_components]
        coefs = coefs[:num_components]

        components = []
        for index in feature_ids:
            is_audio = self.modality == 'audio' or (self.modality == 'both' and index < audio_count)
            is_text = self.modality == 'lyrical' or (self.modality == 'both' and not index < audio_count)
            if is_audio:
                components.append(self.factorization.get_ordered_component_names()[index])
            elif is_text:
                # Word indices follow the audio indices only in 'both' mode.
                word_index = index - audio_count if self.modality == 'both' else index
                components.append(self.indexed_string.word(word_index))

        if return_indeces:
            return components, coefs, feature_ids
        return components, coefs

In [None]:
def default_composition_fn(x):
    """Return ``x`` unchanged; used when no composition function is supplied."""
    return x

def load_audio(audio_path, target_sr):
    """Load ``audio_path`` as a mono waveform at ``target_sr`` and return the samples only."""
    # librosa.load returns (waveform, sample_rate); only the waveform is needed.
    return librosa.load(audio_path, mono=True, sr=target_sr)[0]

def compute_segments(signal, sr, temporal_segmentation_params=None):
    """Split ``signal`` into temporal segments.

    Args:
        signal: the audio samples (only its length is used).
        sr: sample rate, used for the default of one segment per second.
        temporal_segmentation_params: None (one segment per second, at most
            10, at least 1), an int (number of equally long segments), or a
            dict with ``'type'`` equal to ``'fixed_length'`` (plus
            ``'n_temporal_segments'``) or ``'manual'`` (plus
            ``'manual_segments'``, a list of ``(start, end)`` pairs).

    Returns:
        ``(segments, explained_length)`` where ``segments`` is a list of
        ``(start_sample, end_sample)`` pairs and ``explained_length`` is the
        number of samples covered.

    Raises:
        ValueError: for an unknown segmentation type or fewer than one segment.
    """
    # TODO: parameter for return type (samples, frames, seconds)?
    audio_length = len(signal)

    if temporal_segmentation_params is None:
        # 1 segment per second, maximally 10 segments. Audio shorter than one
        # second still gets a single segment instead of a division by zero.
        n_temporal_segments_default = max(1, min(int(audio_length // sr), 10))
        temporal_segmentation_params = {'type': 'fixed_length',
                                        'n_temporal_segments': n_temporal_segments_default}
    elif isinstance(temporal_segmentation_params, int):
        temporal_segmentation_params = {'type': 'fixed_length',
                                        'n_temporal_segments': temporal_segmentation_params}

    segmentation_type = temporal_segmentation_params['type']

    if segmentation_type == "manual":
        segments = temporal_segmentation_params["manual_segments"]
        return segments, segments[-1][1]  # end of last segment

    if segmentation_type != "fixed_length":
        # Explicit raise: an assert would be stripped under ``python -O``.
        raise ValueError(
            "temporal segmentation type must be 'fixed_length' or 'manual', got {!r}".format(
                segmentation_type))

    n_temporal_segments = temporal_segmentation_params['n_temporal_segments']
    if n_temporal_segments < 1:
        raise ValueError("n_temporal_segments must be at least 1, got {}".format(n_temporal_segments))

    samples_per_segment = audio_length // n_temporal_segments
    explained_length = samples_per_segment * n_temporal_segments
    if explained_length < audio_length:
        warnings.warn("last {} samples are ignored".format(audio_length - explained_length))

    segments = [(s * samples_per_segment, (s + 1) * samples_per_segment)
                for s in range(n_temporal_segments)]
    return segments, explained_length


class Factorization(object):
    """Base class describing how an audio input is split into components."""

    def __init__(self, input, target_sr, temporal_segmentation_params=None, composition_fn=None):
        """Load the input if it is a path and compute its temporal segments.

        Args:
            input: a file path (str) or an already loaded waveform.
            target_sr: sample rate used for loading and segmenting.
            temporal_segmentation_params: forwarded to ``compute_segments``.
            composition_fn: applied to composed components; defaults to
                ``default_composition_fn``.
        """
        self.target_sr = target_sr

        given_as_path = isinstance(input, str)
        self._audio_path = input if given_as_path else None
        self._original_mix = load_audio(input, target_sr) if given_as_path else input

        self._composition_fn = default_composition_fn if composition_fn is None else composition_fn

        self.original_components = []
        self.components = []
        self._components_names = []

        segmentation = compute_segments(self._original_mix, self.target_sr,
                                        temporal_segmentation_params)
        self.temporal_segments, self.explained_length = segmentation

    def compose_model_input(self, components=None):
        """Retrieve the selected components and pass them through the composition function."""
        selected = self.retrieve_components(components)
        return self._composition_fn(selected)

    def get_number_components(self):
        """Return the number of named components."""
        # TODO: probably no need to overwrite in other classes
        return len(self._components_names)

    def retrieve_components(self, selection_order=None):
        """Return the selected components; must be implemented by subclasses."""
        raise NotImplementedError

    def get_ordered_component_names(self): # e.g. instrument names
        """Return the component names in order."""
        return self._components_names


class TimeOnlyFactorization(Factorization):
    """Factorization whose components are the temporal segments of the mix."""
    # TODO: add other baseline except 0's?

    def __init__(self, input, target_sr, temporal_segmentation_params=None, composition_fn=None):
        """Name one component per temporal segment: ``T1``, ``T2``, ..."""
        super().__init__(input, target_sr, temporal_segmentation_params, composition_fn)
        segment_count = len(self.temporal_segments)
        self._components_names.extend("T" + str(number) for number in range(1, segment_count + 1))

    def retrieve_components(self, selection_order=None):
        """Return the mix with every non-selected segment set to zero.

        With ``selection_order=None`` the original mix itself is returned.
        """
        # TODO: check if selection_order contains out of bounds segments
        if selection_order is None:
            return self._original_mix

        masked_mix = np.zeros_like(self._original_mix)
        for segment_index in selection_order:
            start, end = self.temporal_segments[segment_index]
            masked_mix[start:end] = self._original_mix[start:end]
        return masked_mix


class SourceSeparationBasedFactorization(Factorization):
    """Factorization whose components are (source, temporal segment) pairs.

    Subclasses implement ``initialize_components`` to provide the separated
    sources and their names.
    """

    def __init__(self, input, target_sr=44100, temporal_segmentation_params=None, composition_fn=None):
        super().__init__(input, target_sr, temporal_segmentation_params, composition_fn)
        # the following part is specific to each source sep. algorithm
        self.original_components, self._components_names = self.initialize_components()
        self.prepare_components(0, len(self._original_mix))

    def compose_model_input(self, components=None):
        """Sum the selected components and apply the composition function."""
        sel_sources = self.retrieve_components(selection_order=components)
        if len(sel_sources) > 1:
            y = sum(sel_sources)
        else:
            y = sel_sources[0]
        return self._composition_fn(y)

    def get_number_components(self):
        """Return the number of prepared components."""
        return len(self.components)

    def retrieve_components(self, selection_order=None):
        """Return the selected components as a list.

        ``None`` selects every component; an empty selection yields a single
        all-zero component so the result can still be composed.
        """
        if selection_order is None:
            return self.components
        if len(selection_order) == 0:
            return [np.zeros_like(self.components[0])]
        return [self.components[o] for o in selection_order]

    def get_ordered_component_names(self):
        """Return the component names; raises if no names were set."""
        if len(self._components_names) == 0:
            raise Exception("Components were not named.")
        return self._components_names

    def initialize_components(self):
        """Return ``(original_components, components_names)``; subclass hook."""
        raise NotImplementedError

    def prepare_components(self, start_sample, y_length):
        """Build one component per (temporal segment, source) pair.

        Each component is an array of ``self.explained_length`` zeros that
        carries one source inside one segment. Names are the source name
        followed by the segment index.

        Args:
            start_sample: first sample of the window taken from each source.
            y_length: length of that window in samples.
        """
        # Remember the per-source names from the first call. Without this a
        # second call would append segment indices to names that already
        # carry one, instead of resetting them together with the components.
        if not hasattr(self, '_source_names'):
            self._source_names = list(self._components_names)

        # this resets in case temporal segmentation was previously applied
        sources = [
            comp[start_sample:start_sample + y_length] for comp in self.original_components]

        component_names = []
        temporary_components = []
        for s, (segment_start, segment_end) in enumerate(self.temporal_segments):
            for co, source in enumerate(sources):
                current_component = np.zeros(self.explained_length, dtype=np.float32)
                current_component[segment_start:segment_end] = source[segment_start:segment_end]
                temporary_components.append(current_component)
                component_names.append(self._source_names[co] + str(s))

        self.components = temporary_components
        self._components_names = component_names

def pickle_dump(x, path):
    """Pickle ``x`` to the file at ``path``, closing the file afterwards."""
    with open(path, "wb") as handle:
        pickle.dump(x, handle)


def pickle_load(path):
    """Load and return the pickled object stored at ``path``.

    The file is closed before returning.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)

class LimeBase(object):
    """Class for learning a locally linear sparse model from perturbed data"""
    def __init__(self,
                 kernel_fn,
                 verbose=False,
                 absolute_feature_sort=False,
                 random_state=None):
        """Store the kernel and the output options.

        Args:
            kernel_fn: callable turning distances into sample weights.
            verbose: when True, details of the fitted local model are printed.
            absolute_feature_sort: when True, explanations are sorted by
                absolute coefficient instead of signed coefficient.
            random_state: seed or RandomState, normalised with
                ``check_random_state``.
        """
        self.random_state = check_random_state(random_state)
        self.absolute_feature_sort = absolute_feature_sort
        self.verbose = verbose
        self.kernel_fn = kernel_fn

    @staticmethod
    def generate_lars_path(weighted_data, weighted_labels):
        """Run ``lars_path`` in lasso mode and return ``(alphas, coefs)``."""
        path = lars_path(weighted_data,
                         weighted_labels,
                         method='lasso',
                         verbose=False)
        # The middle element of the result is not needed here.
        alphas, _unused, coefs = path
        return alphas, coefs

    def forward_selection(self, data, labels, weights, num_features):
        """Greedily pick the features that most improve a weighted fit.

        In every round each unused feature is tried together with the ones
        already chosen; the feature giving the highest score is kept.

        Returns:
            Array with the chosen column indices in the order of selection.
        """
        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
        n_columns = data.shape[1]
        chosen = []
        for _ in range(min(num_features, n_columns)):
            top_score = -100000000
            top_feature = 0
            for candidate in range(n_columns):
                if candidate in chosen:
                    continue
                trial_columns = chosen + [candidate]
                clf.fit(data[:, trial_columns], labels,
                        sample_weight=weights)
                trial_score = clf.score(data[:, trial_columns],
                                        labels,
                                        sample_weight=weights)
                if trial_score > top_score:
                    top_feature, top_score = candidate, trial_score
            chosen.append(top_feature)
        return np.array(chosen)

    def feature_selection(self, data, labels, weights, num_features, method):
        """Selects features for the model. see explain_instance_with_data to
           understand the parameters."""
        if method == 'none':
            return np.array(range(data.shape[1]))
        elif method == 'forward_selection':
            return self.forward_selection(data, labels, weights, num_features)
        elif method == 'highest_weights':
            clf = Ridge(alpha=0, fit_intercept=True,
                        random_state=self.random_state)
            clf.fit(data, labels, sample_weight=weights)

            coef = clf.coef_
            if sp.sparse.issparse(data):
                coef = sp.sparse.csr_matrix(clf.coef_)
                weighted_data = coef.multiply(data[0])
                # Note: most efficient to slice the data before reversing
                sdata = len(weighted_data.data)
                argsort_data = np.abs(weighted_data.data).argsort()
                # Edge case where data is more sparse than requested number of feature importances
                # In that case, we just pad with zero-valued features
                if sdata < num_features:
                    nnz_indexes = argsort_data[::-1]
                    indices = weighted_data.indices[nnz_indexes]
                    num_to_pad = num_features - sdata
                    indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
                    indices_set = set(indices)
                    pad_counter = 0
                    for i in range(data.shape[1]):
                        if i not in indices_set:
                            indices[pad_counter + sdata] = i
                            pad_counter += 1
                            if pad_counter >= num_to_pad:
                                break
                else:
                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
                    indices = weighted_data.indices[nnz_indexes]
                return indices
            else:
                weighted_data = coef * data[0]
                feature_weights = sorted( # TODO: check if abs should be optional
                    zip(range(data.shape[1]), weighted_data),
                    key=lambda x: np.abs(x[1]),
                    reverse=True)
                return np.array([x[0] for x in feature_weights[:num_features]])
        elif method == 'lasso_path':
            weighted_data = ((data - np.average(data, axis=0, weights=weights))
                             * np.sqrt(weights[:, np.newaxis]))
            weighted_labels = ((labels - np.average(labels, weights=weights))
                               * np.sqrt(weights))
            nonzero = range(weighted_data.shape[1])
            _, coefs = self.generate_lars_path(weighted_data,
                                               weighted_labels)
            for i in range(len(coefs.T) - 1, 0, -1):
                nonzero = coefs.T[i].nonzero()[0]
                if len(nonzero) <= num_features:
                    break
            used_features = nonzero
            return used_features
        elif method == 'auto':
            if num_features <= 6:
                n_method = 'forward_selection'
            else:
                n_method = 'highest_weights'
            return self.feature_selection(data, labels, weights,
                                          num_features, n_method)

    def explain_instance_with_data(self,
                                   neighborhood_data,
                                   neighborhood_labels,
                                   distances,
                                   label,
                                   num_features,
                                   feature_selection='auto',
                                   model_regressor=None,
                                   fit_intercept=True):
        """Fit a weighted local model on the perturbed samples.

        Args:
            neighborhood_data: perturbed samples; row 0 is used as the
                instance whose local prediction is returned.
            neighborhood_labels: model outputs for the samples; column
                ``label`` is explained.
            distances: distance of every sample to the original instance.
            label: column of ``neighborhood_labels`` to explain.
            num_features: maximum number of features to select.
            feature_selection: method passed to ``self.feature_selection``.
            model_regressor: model with ``fit``/``score``/``predict`` and
                ``coef_``/``intercept_``; a ``Ridge(alpha=1)`` when None.
            fit_intercept: used only when the default Ridge model is created.

        Returns:
            ``(intercept, sorted_local_exp, prediction_score, local_pred)``
            where ``sorted_local_exp`` is a list of ``(feature, coefficient)``
            pairs sorted in descending order.
        """
        sample_weights = self.kernel_fn(distances)
        target_column = neighborhood_labels[:, label]
        chosen = self.feature_selection(neighborhood_data,
                                        target_column,
                                        sample_weights,
                                        num_features,
                                        feature_selection)

        easy_model = model_regressor
        if easy_model is None:
            easy_model = Ridge(alpha=1, fit_intercept=fit_intercept,
                               random_state=self.random_state)

        easy_model.fit(neighborhood_data[:, chosen],
                       target_column, sample_weight=sample_weights)
        prediction_score = easy_model.score(
            neighborhood_data[:, chosen],
            target_column, sample_weight=sample_weights)

        # Local prediction for the first row of the neighbourhood.
        local_pred = easy_model.predict(neighborhood_data[0, chosen].reshape(1, -1))

        if self.absolute_feature_sort:
            def sort_key(pair):
                return np.abs(pair[1])
        else:
            def sort_key(pair):
                return pair[1]
        sorted_local_exp = sorted(zip(chosen, easy_model.coef_),
                                  key=sort_key, reverse=True)

        if self.verbose:
            print('Intercept:', easy_model.intercept_)
            print('Prediction_local:', local_pred,)
            print('Right:', neighborhood_labels[0, label])
            print('Score:', prediction_score)

        return (easy_model.intercept_,
                sorted_local_exp,
                prediction_score, local_pred)

class AudioExplanation(object):
    """Container for the local explanation of an audio instance."""

    def __init__(self, factorization, neighborhood_data, neighborhood_labels):
        """Init function.

        Args:
            factorization: a Factorization object
            neighborhood_data: the perturbed samples used for the explanation
            neighborhood_labels: the model outputs for those samples
        """
        self.factorization = factorization
        self.neighborhood_data = neighborhood_data
        self.neighborhood_labels = neighborhood_labels
        # Filled per label by the explainer.
        self.intercept = {}
        self.local_exp = {}
        self.local_pred = {}
        self.score = {}
        self.distance = {}

    def get_sorted_components(self, label, positive_components=True, negative_components=True, num_components='all',
                              min_abs_weight=0.0, return_indeces=False):
        """Retrieve the audio components that explain ``label``.

        Features keep the order stored in ``self.local_exp[label]``.

        Args:
            label: key into ``self.local_exp``.
            positive_components: keep features with positive weight.
            negative_components: keep features with negative weight.
            num_components: ``'all'`` or an int limiting the number of features.
            min_abs_weight: when non-zero, drop features with a smaller
                absolute weight.
            return_indeces: also return the indices of the used features.

        Returns:
            The result of ``factorization.retrieve_components`` for the used
            features, optionally followed by the feature indices.

        Raises:
            KeyError: if ``label`` was not explained.
            ValueError: if both sign filters are False or
                ``num_components == 'auto'``.
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        if positive_components is False and negative_components is False:
            raise ValueError('positive_components, negative_components or both must be True')
        if num_components == 'auto':
            raise ValueError("num_components='auto' was removed.")

        exp = self.local_exp[label]

        w = [[x[0], x[1]] for x in exp]
        used_features, weights = np.array(w, dtype=int)[:, 0], np.array(w)[:, 1]

        if not negative_components:
            pos_weights = np.argwhere(weights > 0)[:, 0]
            used_features = used_features[pos_weights]
            weights = weights[pos_weights]
        elif not positive_components:
            neg_weights = np.argwhere(weights < 0)[:, 0]
            used_features = used_features[neg_weights]
            weights = weights[neg_weights]
        if min_abs_weight != 0.0:
            abs_weights = np.argwhere(abs(weights) >= min_abs_weight)[:, 0]
            used_features = used_features[abs_weights]
            weights = weights[abs_weights]

        if num_components == 'all':
            num_components = len(used_features)
        else:
            assert(isinstance(num_components, int))

        used_features = used_features[:num_components]
        components = self.factorization.retrieve_components(used_features)
        if return_indeces:
            return components, used_features
        return components

    def get_detailed_components(self, label, positive_components=True, negative_components=True, num_components='all', min_abs_weight=0.0, return_indices=False):
        """Return the name and weight of every component explaining ``label``.

        The filter arguments have the same meaning as in
        ``get_sorted_components``.

        Args:
            return_indices: when True, also return the used feature indices.

        Returns:
            A list of ``{"label": name, "weight": coefficient}`` dicts, and
            the feature indices as second element when ``return_indices`` is
            True.
        """
        _, used_indices = self.get_sorted_components(
            label, positive_components, negative_components, num_components,
            min_abs_weight, return_indeces=True)

        # "weight" must be the coefficient stored in the explanation, not the
        # retrieved component audio.
        weight_by_feature = {int(entry[0]): entry[1] for entry in self.local_exp[label]}
        names = self.factorization.get_ordered_component_names()

        detailed_components_list = [
            {"label": names[index], "weight": weight_by_feature[int(index)]}
            for index in used_indices]

        if return_indices:
            return detailed_components_list, used_indices
        return detailed_components_list


class LimeAudioExplainer(object):
    """Explains predictions on audio data."""

    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                 feature_selection='auto', absolute_feature_sort=False, random_state=None):
        """Configure the audio explainer.

        Args:
            kernel_width: width handed to the kernel function (converted to float).
            kernel: callable ``kernel(d, kernel_width)``; when None an
                exponential kernel is used.
            verbose: forwarded to ``LimeBase``.
            feature_selection: feature selection method used when explaining.
            absolute_feature_sort: forwarded to ``LimeBase``.
            random_state: seed or RandomState, normalised with
                ``check_random_state`` and shared with ``LimeBase``.
        """
        width = float(kernel_width)

        if kernel is None:
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

        weighting_fn = partial(kernel, kernel_width=width)

        self.feature_selection = feature_selection
        self.random_state = check_random_state(random_state)
        self.base = LimeBase(weighting_fn, verbose, absolute_feature_sort,
                             random_state=self.random_state)

    def explain_instance(self, factorization, predict_fn,
                         labels=None,
                         top_labels=None,
                         num_reg_targets=None,
                         num_features=100000,
                         num_samples=1000,
                         batch_size=32,
                         distance_metric='cosine',
                         model_regressor=None,
                         random_seed=None,
                         fit_intercept=True):
        """Explain the prediction for one audio instance.

        Args:
            factorization: Factorization providing the audio components.
            predict_fn: callable taking an array of audios and returning one
                row of outputs per audio.
            labels: output columns to explain (classification).
            top_labels: if set, explain the ``top_labels`` columns with the
                highest output for the unperturbed instance.
            num_reg_targets: number of output columns to explain (regression).
            num_features: maximum number of features in the explanation.
            num_samples: number of perturbed samples.
            batch_size: number of samples handed to ``predict_fn`` at once.
            distance_metric: metric used for sample-to-original distances.
            model_regressor: optional surrogate model.
            random_seed: drawn from the explainer's random state when None.
            fit_intercept: forwarded to the base explainer.

        Returns:
            An AudioExplanation with one entry per explained output.

        Raises:
            ValueError: if classification and regression arguments are mixed.
        """
        is_classification = bool(labels or top_labels)
        if is_classification and num_reg_targets:
            raise ValueError('Set labels or top_labels for classification. '
                             'Set num_reg_targets for regression.')

        if random_seed is None:
            random_seed = self.random_state.randint(0, high=1000)

        self.factorization = factorization
        requested_labels = labels

        data, labels = self.data_labels(predict_fn, num_samples,
                                        batch_size=batch_size)

        reference_row = data[0].reshape(1, -1)
        distances = sklearn.metrics.pairwise_distances(
            data, reference_row, metric=distance_metric).ravel()

        ret_exp = AudioExplanation(self.factorization, data, labels)

        # Decide which output columns get a local model.
        if is_classification:
            targets = requested_labels
            if top_labels:
                targets = np.argsort(labels[0])[-top_labels:]
                ret_exp.top_labels = list(targets)
                ret_exp.top_labels.reverse()
        else:
            targets = range(num_reg_targets)

        for target in targets:
            fitted = self.base.explain_instance_with_data(
                data, labels, distances, target, num_features,
                model_regressor=model_regressor,
                feature_selection=self.feature_selection,
                fit_intercept=fit_intercept)
            (ret_exp.intercept[target],
             ret_exp.local_exp[target],
             ret_exp.score[target],
             ret_exp.local_pred[target]) = fitted

        return ret_exp

    def data_labels(self,
                    predict_fn,
                    num_samples,
                    batch_size=10):

        n_features = self.factorization.get_number_components()
        if num_samples == 'exhaustive':
            import itertools
            num_samples = 2**n_features
            data = np.array(list(map(list, itertools.product([1, 0], repeat=n_features))))
        else:
            data = self.random_state.randint(0, 2, num_samples * n_features) \
                .reshape((num_samples, n_features))
            data[0, :] = 1  # first row all is set to 1

        labels = []
        audios = []
        for row in data:
            non_zeros = np.where(row != 0)[0]
            temp = self.factorization.compose_model_input(non_zeros)
            audios.append(temp)
            if len(audios) == batch_size:
                preds = predict_fn(np.array(audios))
                labels.extend(preds)
                audios = []
        if len(audios) > 0:
            preds = predict_fn(np.array(audios))
            labels.extend(preds)
        return data, np.array(labels)o