In [None]:
# Change to the parent directory (presumably the repository root) so the relative
# paths used below ("tutorials/...", "paper_figures/...") resolve.
%cd ..

In [None]:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler, FunctionTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_squared_error
from imvc.datasets import LoadDataset
from imvc.decomposition import jNMF
from imvc.preprocessing import MultiViewTransformer, ConcatenateViews, DropView
from imvc.ampute import Amputer
import matplotlib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from functools import reduce
from sklearn.preprocessing import FunctionTransformer
from imvc.feature_selection import jNMFFeatureSelector
from sklearn.metrics import accuracy_score
import matplotlib.patches as mpatches

In [None]:
from tueplots import axes, bundles

# ICML 2022 styling from tueplots, then enlarge every text size for legibility:
# the legend by 3, all other text sizes by 6.
plt.rcParams.update(**bundles.icml2022(), **axes.lines())
font_size_increments = {
    "axes.labelsize": 6,
    "axes.titlesize": 6,
    "font.size": 6,
    "legend.fontsize": 3,
    "xtick.labelsize": 6,
    "ytick.labelsize": 6,
}
for rc_key, increment in font_size_increments.items():
    plt.rcParams[rc_key] += increment

In [None]:
# Load the nutrimouse dataset: a list of views `Xs` plus the label table `y`.
Xs, y = LoadDataset.load_dataset(dataset_name="nutrimouse", return_y=True)
view_widths = [view.shape[1] for view in Xs]
print("Samples:", len(Xs[0]), "\t", "Modalities:", len(Xs), "\t", "Features:", view_widths)
# Keep only the first label column as the classification target.
y = y.iloc[:, 0]
n_clusters = int(y.nunique())
y.value_counts()

In [None]:
# Experimental grid swept by the benchmark below.
ps = np.arange(0, 0.9, 0.2)  # missing rates 0.0, 0.2, ..., 0.8
n_components_list = [2 ** k for k in range(5)]  # 1, 2, 4, 8, 16
n_times = 50  # repetitions (one random seed each) per configuration
algorithms = [
    "Dimensionality reduction",
    "Feature selection",
    "Random features",
]
mechanisms = "um pm mcar mnar".split()
all_metrics = dict()

In [None]:
def build_pipeline(algorithm, n_components, seed, n_features):
    """Build the unfitted reduction pipeline for one run of the benchmark.

    Every variant first min-max scales each view separately.

    Parameters
    ----------
    algorithm : str
        "Dimensionality reduction" (jNMF embedding), "Feature selection"
        (jNMF-based feature selector) or "Random features" (random subset of
        the concatenated, imputed features).
    n_components : int
        Number of jNMF components, or number of random features to keep.
    seed : int
        Seed for every random step of the pipeline.
    n_features : int
        Total number of features over all views; only used by "Random features".

    Returns
    -------
    sklearn.pipeline.Pipeline

    Raises
    ------
    ValueError
        If ``algorithm`` is unknown, or (for "Random features") if
        ``n_components`` exceeds ``n_features``.
    """
    scaler = MultiViewTransformer(MinMaxScaler().set_output(transform="pandas"))
    if algorithm == "Dimensionality reduction":
        return make_pipeline(scaler, jNMF(n_components=n_components, random_state=seed))
    if algorithm == "Feature selection":
        return make_pipeline(
            scaler,
            jNMFFeatureSelector(n_components=n_components, random_state=seed),
            FunctionTransformer(lambda views: np.concatenate(views, axis=1)),
            SimpleImputer(),
        )
    if algorithm == "Random features":
        # Sample column positions without replacement so exactly `n_components`
        # distinct features are kept.
        columns = np.random.default_rng(seed).choice(n_features, size=n_components, replace=False)
        return make_pipeline(
            scaler,
            ConcatenateViews(),
            SimpleImputer().set_output(transform="pandas"),
            FunctionTransformer(lambda X: X.iloc[:, columns]),
        )
    raise ValueError(f"Unknown algorithm: {algorithm!r}")


# Full sweep: algorithm x mechanism x n_components x missing rate x repetition.
# Results go to all_metrics[algorithm][mechanism][n_components][missing %][run].
for algorithm in tqdm(algorithms):
    all_metrics[algorithm] = {}
    for mechanism in tqdm(mechanisms):
        all_metrics[algorithm][mechanism] = {}
        for n_components in n_components_list:
            all_metrics[algorithm][mechanism][n_components] = {}
            for p in ps:
                # round() rather than plain int(): int() truncates float noise such as 69.99999999999999.
                missing_percentage = int(round(p * 100))
                runs = all_metrics[algorithm][mechanism][n_components][missing_percentage] = {}
                for i in range(n_times):
                    run = runs[i] = {}
                    # Pass the mechanism explicitly; otherwise every mechanism in this loop
                    # would share the same amputation and produce identical results.
                    Xs_train = Amputer(p=p, mechanism=mechanism, random_state=i).fit_transform(Xs)
                    # Additionally blank individual entries of each view with probability p.
                    for X in Xs_train:
                        X.iloc[np.random.default_rng(i).choice([True, False], p=[p, 1 - p], size=X.shape)] = np.nan
                    try:
                        pipeline = build_pipeline(algorithm, n_components, seed=i,
                                                  n_features=sum(view.shape[1] for view in Xs_train))
                        transformed_X = pipeline.fit_transform(Xs_train)
                        # NOTE: the SVC is fitted and scored on the same samples, so this is
                        # training accuracy on the reduced representation.
                        preds = SVC(random_state=i).fit(transformed_X, y).predict(transformed_X)
                        run["Accuracy"] = accuracy_score(y_true=y, y_pred=preds)
                        run["Comments"] = None
                    except Exception as ex:
                        # Best effort: record the failure (with its type) and keep the sweep going.
                        run["Comments"] = f"{type(ex).__name__}: {ex}"

In [None]:
# Flatten the nested results
# {method: {mechanism: {components: {missing %: {run: metrics}}}}} into one row per run.
# Raw string: the backslash is kept on purpose (presumably for LaTeX text rendering),
# and "\%" in a normal literal is an invalid escape sequence.
rate_col = r"Missing rate (\%)"
flattened_data = [
    {
        "Method": algorithm,
        "Mechanism": mechanism,
        rate_col: p,
        "Components": n_components,
        "Iteration": i,
        **iter_dict,
    }
    for algorithm, algorithm_dict in all_metrics.items()
    for mechanism, mechanism_dict in algorithm_dict.items()
    for n_components, n_components_dict in mechanism_dict.items()
    for p, p_dict in n_components_dict.items()
    for i, iter_dict in p_dict.items()
]
df = pd.DataFrame(flattened_data)
df = df.sort_values(["Method", "Mechanism", rate_col, "Components", "Iteration"])
df.to_csv("tutorials/reduction_results.csv", index=False)
print(df.shape)
df.head()

In [None]:
# Reload the saved results so the figures below can be regenerated without re-running the sweep.
results_path = "tutorials/reduction_results.csv"
df = pd.read_csv(results_path)
df

In [None]:
# Rows whose "Comments" field is set are runs that raised during the sweep
# (successful runs leave it empty, which reads back from the CSV as missing).
errors = df.loc[df["Comments"].notna()]
print("errors", errors.shape)
errors

In [None]:
# Long names of the missingness mechanisms (not used by the panel titles below, which show the short codes).
mechanism_names = {"um": "Unpaired missing", "pm": "Partial missing", "mnar": "Missing not at random", "mcar": "Missing completely at random"}
colorblind_palette = sns.color_palette("colorblind")
# Raw strings: "\%" / "\(" are invalid escapes in normal literals; the backslashes are intended.
rate_col = r"Missing rate (\%)"
method_order = df["Method"].unique().tolist()
method_linestyles = ["-", "--", ":"]

# One panel per (number of components, mechanism): accuracy vs. missing rate for each method.
# An explicit hue_order guarantees the hand-built legend below matches the drawn lines.
g = sns.FacetGrid(data=df, col="Mechanism", row="Components", despine=False).map_dataframe(
    sns.pointplot, x=rate_col, y="Accuracy", hue="Method", hue_order=method_order,
    linestyles=method_linestyles, capsize=0.05, seed=42, palette=colorblind_palette,
)

handles = [plt.Line2D([0], [0], color=color, lw=2, linestyle=linestyle)
           for color, linestyle in zip(colorblind_palette, method_linestyles)]
g.axes[0][0].legend(handles=handles, labels=method_order, loc="best")

# Title each panel from the facet it actually shows. This avoids zipping against df order,
# and avoids shadowing the `axes` module imported from tueplots.
for (panel_components, panel_mechanism), ax in g.axes_dict.items():
    ax.set_title(rf"{panel_mechanism.upper()}, Components\(|\)Features = {panel_components}")

plt.tight_layout()
plt.savefig("paper_figures/selection_results_comps.pdf")
plt.savefig("paper_figures/selection_results_comps.svg")

In [None]:
# A single amputed copy of the data (MCAR, 20 % missing, seed 42) for the explanation figures.
p = 0.2
amputed_Xs = Amputer(p=p, mechanism="mcar", random_state=42).fit_transform(Xs)
for view in amputed_Xs:
    # Also blank individual entries with probability p, as in the benchmark sweep.
    entry_mask = np.random.default_rng(42).choice([True, False], p=[p, 1 - p], size=view.shape)
    view.iloc[entry_mask] = np.nan
n_components = 4

In [None]:
# Accuracy vs. missing rate for the chosen number of components/features.
# NOTE(review): rows of all missingness mechanisms are pooled here; also filter on
# df["Mechanism"] if a per-mechanism figure is intended.
rate_col = r"Missing rate (\%)"  # raw string: "\%" is an invalid escape in a normal literal
method_order = df["Method"].unique().tolist()
method_linestyles = ["-", "--", ":"]

plt.figure(figsize=(4, 3))
ax = sns.pointplot(data=df[df["Components"] == n_components], x=rate_col, y="Accuracy",
                   hue="Method", hue_order=method_order, linestyles=method_linestyles,
                   capsize=0.05, seed=42, palette=colorblind_palette, legend=False)

# Hand-built legend; hue_order above guarantees colours/line styles line up with these labels.
handles = [plt.Line2D([0], [0], color=color, lw=2, linestyle=linestyle)
           for color, linestyle in zip(colorblind_palette, method_linestyles)]
ax.legend(handles=handles, labels=method_order, loc="best")

plt.savefig("paper_figures/selection_results.pdf")
plt.savefig("paper_figures/selection_results.svg")

In [None]:
def modality_of(feature):
    """Return the view a feature belongs to.

    "Genes" if it is a column of the first view only, "Fatty Acids" if it is a
    column of the second view only, otherwise "Not found" (in neither or both).
    """
    in_genes = feature in Xs[0].columns
    in_fatty_acids = feature in Xs[1].columns
    if in_genes and not in_fatty_acids:
        return "Genes"
    if in_fatty_acids and not in_genes:
        return "Fatty Acids"
    return "Not found"


# Select features by their average jNMF weight over components on the amputed data.
pipeline = make_pipeline(MultiViewTransformer(MinMaxScaler().set_output(transform="pandas")),
                         jNMFFeatureSelector(n_components=n_components, select_by="average", random_state=42,
                                             f_per_component=sum(X.shape[1] for X in Xs) // n_components))
# Pipeline.fit returns the fitted pipeline itself, not transformed data, so nothing is bound.
pipeline.fit(amputed_Xs)
selector = pipeline[-1]
selected_features = pd.DataFrame({"Feature": selector.selected_features_,
                                  "Feature Importance": selector.weights_})
selected_features = selected_features.sort_values(by="Feature Importance", ascending=False)
selected_features["Modality"] = selected_features["Feature"].apply(modality_of)
# Colour per modality, assigned in order of first appearance among the ranked features.
palette = {mod: col for mod, col in zip(selected_features["Modality"].unique(), ["#2ca25f", "#99d8c9"])}
# Share (in %) of the total importance carried by each modality.
selected_features = selected_features.groupby("Modality")["Feature Importance"].sum()
selected_features = selected_features.div(selected_features.sum()).mul(100)
selected_features = selected_features.sort_values(ascending=False)
selected_features

In [None]:
# One bar per modality. Look each bar's colour up by its label: `palette` is ordered by the
# top-ranked features, which need not match this importance-sorted bar order, so passing
# palette.values() positionally can paint a bar with another modality's colour.
# Modalities missing from the palette (e.g. "Not found") fall back to grey.
bar_colors = [palette.get(modality, "lightgrey") for modality in selected_features.index]
ax = selected_features.plot(kind="bar", color=bar_colors,
                            figsize=(4, 3), ylabel=r"Modality Importance (\%)",
                            rot=0)

handles = [mpatches.Patch(color=color, label=modality) for modality, color in palette.items()]
ax.legend(handles=handles, title="Modality", loc="best")

plt.savefig("paper_figures/expl_mod.pdf")
plt.savefig("paper_figures/expl_mod.svg")

In [None]:
# Select features component by component (select_by="component") on the amputed data.
pipeline = make_pipeline(MultiViewTransformer(MinMaxScaler().set_output(transform="pandas")),
                         jNMFFeatureSelector(n_components=n_components, select_by="component", random_state=42))
# Pipeline.fit returns the fitted pipeline itself, not transformed data, so nothing is bound.
pipeline.fit(amputed_Xs)
selector = pipeline[-1]
selected_features = pd.DataFrame({"Feature": selector.selected_features_,
                                  "Feature Importance": selector.weights_,
                                  "Component": selector.component_})
selected_features = selected_features.sort_values(by="Feature Importance", ascending=False)
selected_features["Modality"] = selected_features["Feature"].apply(
    lambda x: "Genes" if ((x in Xs[0].columns) and (x not in Xs[1].columns))
    else ("Fatty Acids" if ((x in Xs[1].columns) and (x not in Xs[0].columns)) else "Not found"))
# Shift to 1-based component numbers for display (component_ presumably 0-based).
selected_features["Component"] += 1
# NOTE(review): this rebinds `palette` from the modality->colour dict to a per-row colour
# list (read by the next plotting cell), so re-running this cell alone fails; re-run the
# modality-importance cell first.
palette = [palette[mod] for mod in selected_features["Modality"]]
selected_features

In [None]:
# Horizontal bar per component showing its selected feature's importance.
plt.figure(figsize=(4, 3))
ax = sns.barplot(data=selected_features, y="Component", x="Feature Importance",
                 legend=False, orient="h", order=selected_features["Component"])
# Leave room right of the longest bar for its feature-name label.
ax.set_xlim(0, selected_features["Feature Importance"].max() + .8)

# Recolour the bars by modality: `palette` holds one colour per row, in bar order.
for bar, bar_color in zip(ax.patches, palette[:len(selected_features)]):
    bar.set_color(bar_color)

for container in ax.containers:
    ax.bar_label(container, labels=selected_features["Feature"], padding=3)

plt.savefig("paper_figures/selected_features.pdf")
plt.savefig("paper_figures/selected_features.svg")

In [None]:
# Select the 3 most important features of every component.
pipeline = make_pipeline(MultiViewTransformer(MinMaxScaler().set_output(transform="pandas")),
                         jNMFFeatureSelector(n_components=n_components, select_by="component", random_state=42,
                                             f_per_component=3))
# Pipeline.fit returns the fitted pipeline itself, not transformed data, so nothing is bound.
pipeline.fit(amputed_Xs)
selector = pipeline[-1]
selected_features = pd.DataFrame({"Feature": selector.selected_features_,
                                  "Feature Importance": selector.weights_,
                                  "Component": selector.component_})
selected_features = selected_features.sort_values(by=["Component", "Feature Importance"], ascending=[True, False])
selected_features["Modality"] = selected_features["Feature"].apply(
    lambda x: "Genes" if ((x in Xs[0].columns) and (x not in Xs[1].columns))
    else ("Fatty Acids" if ((x in Xs[1].columns) and (x not in Xs[0].columns)) else "Not found"))
# Rank of each feature inside its component (0 = most important), used as the bar hue.
# Counting per group keeps this correct for any n_components / f_per_component, unlike
# tiling range(Component.max()), which only has the right length when n_components - 1 == f_per_component.
selected_features["Hue"] = selected_features.groupby("Component").cumcount()
# Shift to 1-based component numbers for display (component_ presumably 0-based).
selected_features["Component"] += 1
selected_features

In [None]:
# Colour per modality, assigned in order of first appearance among the selected features.
palette = {mod: col for mod, col in zip(selected_features["Modality"].unique(), ["#2ca25f", "#99d8c9"])}
plt.figure(figsize=(4, 3))
ax = sns.barplot(data=selected_features, y="Component", x="Feature Importance",
                 hue="Hue", legend=False, orient="h", width=.9)
# Leave room right of the longest bar for its feature-name label.
ax.set_xlim(0, selected_features["Feature Importance"].max() + .8)

# Bars are added one hue level (rank) at a time, each level in component order, so list the
# rows in exactly that order. Sorting on "Hue" alone would rely on the default (unstable)
# sort to keep components ordered. `selected_features` itself is left untouched.
bar_order = selected_features.sort_values(["Hue", "Component"])
bar_colors = [palette[mod] for mod in bar_order["Modality"]]
for bar, bar_color in zip(ax.patches, bar_colors):
    bar.set_color(bar_color)

# One container per hue level (rank within component); label each bar with its feature name.
for hue_level, container in enumerate(ax.containers):
    ax.bar_label(container, labels=bar_order.loc[bar_order["Hue"] == hue_level, "Feature"],
                 padding=3)

plt.savefig("paper_figures/expl_features.pdf")
plt.savefig("paper_figures/expl_features.svg")