# Model II: generate data, perform Sobolev Alignment and analyse results.

In [None]:
import os, sys, pylab, gc, umap, re, scanpy, torch, scipy
import numpy as np
import seaborn as sns
from anndata import AnnData
import matplotlib.pyplot as plt
from functools import reduce
from pickle import dump, load
from copy import deepcopy
import pandas as pd
from joblib import Parallel, delayed
import umap
from statannot import add_stat_annotation
from shutil import copyfile
import rpy2
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
from sklearn.decomposition import PCA
from time import process_time
%config IPCompleter.use_jedi = False

from sobolev_alignment import SobolevAlignment, KRRApprox

In [None]:
# Bar colours per feature status; only the 'specific' colour differs between datasets.
source_palette = dict(common='#7F91D0', specific='#71AB5A', half='tab:grey')
target_palette = dict(common='#7F91D0', specific='#C89E34', half='tab:grey')

## Generate data
Call an external script to generate the data with dyngen. <br/>
<b>WARNING:</b> Running this script from a terminal may work better, as it avoids the overhead of the Jupyter Notebook.

In [None]:
# IPython shell escape: run the external data-generation script (dyngen simulation).
!sh generate_data.sh

## Pre-processing
### Import data

In [None]:
# Number of targets in the simulated datasets; it is part of the file names.
n_targets = 5
X_source = scanpy.read_h5ad('./data/source_dataset_%s_targets_large.h5ad' % n_targets)
X_target = scanpy.read_h5ad('./data/target_dataset_%s_targets_large.h5ad' % n_targets)

# Store the count matrices as dense numpy arrays.
for adata in (X_source, X_target):
    adata.X = np.array(adata.X.todense())

# Treat the simulation identifier as a categorical string label.
for adata in (X_target, X_source):
    adata.obs['simulation_i'] = adata.obs['simulation_i'].astype(str)
del adata

### Potential filtering

In [None]:
def remove_target_kousekeeping_genes(X):
    """Return ``X`` restricted to genes that are neither target nor housekeeping genes.

    A gene is dropped when its name (``X.var.index``) contains 'Target' or 'HK'.
    All observations are kept; the selection is ``X[:, mask]``.
    """
    gene_names = X.var.index
    has_target = np.array(gene_names.str.contains('Target'))
    has_housekeeping = np.array(gene_names.str.contains('HK'))
    # De Morgan: keep genes matching neither pattern.
    keep = ~(has_target | has_housekeeping)
    return X[:, keep]

In [None]:
# NOTE(review): unconditional stop. Nothing below runs unless this line is
# removed by hand — presumably a deliberate guard so the optional gene
# filtering is only applied on purpose.
assert False
# Drop target and housekeeping genes from both datasets.
X_source = remove_target_kousekeeping_genes(X_source)
X_target = remove_target_kousekeeping_genes(X_target)

### Save data

In [None]:
# Stack source and target counts into one table whose outer index level is the
# dataset name ('CELL_LINE' for source, 'TUMOR' for target), then write it as CSV.
X_combined = pd.concat(
    [X_source.to_df(), X_target.to_df()],
    keys=['CELL_LINE', 'TUMOR'],
)
X_combined.to_csv('./data/combined_counts.csv', sep=',')

## Sobolev Alignment
Perform Sobolev Alignment between source and target and save the results. <br/>
<b>WARNING:</b> Running this script from a terminal may work better, as it avoids the overhead of the Jupyter Notebook.

In [None]:
# IPython shell escape: run the external script that performs and saves the Sobolev Alignment.
!sh launch_model_II.sh

## Post-processing
### Parameters

In [None]:
# Where model outputs are read from and where figures are written to.
output_folder, figure_folder = './output/', './figures/'

### Computation of factor weights

In [None]:
# Run the feature-weight computation on the models stored in output_folder.
# NOTE(review): -j / -m are options of compute_feature_weights.py (presumably
# number of parallel jobs and maximum expansion order) — confirm in the script.
%run ./model_II_scripts/compute_feature_weights.py -o $output_folder -j 10 -m 5

### Load Sobolev Alignment

In [None]:
# Input data: the combined counts written in the "Save data" step, read back from
# the same path. The outer index level is 'CELL_LINE' (source) or 'TUMOR' (target).
# The path is a plain literal: it has no '%s' placeholder to format.
X_combined = pd.read_csv(
    './data/combined_counts.csv',
    sep=',',
    index_col=[0, 1]
)
X_source = X_combined.loc['CELL_LINE']
X_target = X_combined.loc['TUMOR']
X_input = {'source': X_source, 'target': X_target}

In [None]:
iter_idx = 0
sobolev_alignment_clf = {}

# Load the saved Sobolev Alignment of every kernel whose output folder exists
# (KRR approximations only, without the underlying model), then reattach the
# raw training data and flag that the KRR input is log-transformed.
for kernel_type in ('laplacian', 'gaussian'):
    if 'iter_%s_nu_%s' % (iter_idx, kernel_type) in os.listdir(output_folder):
        sobolev_alignment_clf[kernel_type] = SobolevAlignment.load(
            '%s/iter_%s_nu_%s/sobolev_alignment_model' % (output_folder, iter_idx, kernel_type),
            with_model=False,
            with_krr=True,
        )
        sobolev_alignment_clf[kernel_type].training_data = dict(source=X_source, target=X_target)
        sobolev_alignment_clf[kernel_type].krr_log_input_ = True

In [None]:
# Anchor points of the Laplacian KRR approximation of each dataset, converted
# from torch tensors to DataFrames carrying the gene names as columns.
X_log_input = {
    name: pd.DataFrame(
        sobolev_alignment_clf['laplacian'].approximate_krr_regressions_[name].anchors().detach().numpy(),
        columns=X_source.columns,
    )
    for name in ('source', 'target')
}

## Goodness of fit

In [None]:
# scVI latent embedding of each dataset (space-separated file without header).
latent_embedding = {
    side: pd.read_csv('%s/scvi_embedding_%s.csv' % (output_folder, side), sep=' ', header=None)
    for side in ('source', 'target')
}

In [None]:
# Goodness of fit: for each dataset, the Spearman correlation between the KRR
# prediction and the scVI latent embedding, one value per latent factor; the
# distribution of these correlations is plotted per dataset.
plt.figure(figsize=(8,3))
prediction_latent_corr = {}
for data_source in ['source', 'target']:
    print('START %s'%(data_source))

    # KRR prediction on the Frobenius-normalised log10(counts + 1) input.
    krr_pred = sobolev_alignment_clf['laplacian'].approximate_krr_regressions_[data_source].transform(
        torch.Tensor(
            sobolev_alignment_clf['laplacian']._frobenius_normalisation(
                data_source,
                np.log10(X_input[data_source] + 1),
                frob_norm_source=True
            ).values
        )
    )
    prediction_latent_corr[data_source] = [
        scipy.stats.spearmanr(krr_pred[:,x], latent_embedding[data_source][x])[0]
        for x in range(krr_pred.shape[1])
    ]

    # NOTE(review): sns.distplot is deprecated since seaborn 0.11; switch to
    # sns.histplot / sns.kdeplot if the installed seaborn no longer ships it.
    sns.distplot(
        prediction_latent_corr[data_source],
        label='CELL LINE' if data_source == 'source' else 'TUMOR',
        kde_kws={"lw": 3}
    )

plt.xlabel('Spearman correlation between\n KRR and scVI', fontsize=20, color='black')
plt.xticks(fontsize=15)
plt.ylabel('Proportion', fontsize=20, color='black')
plt.yticks([], [])
plt.legend(fontsize=15)
plt.tight_layout()
plt.savefig('%s/hist_spearman_corr_reconstruction_latent.png'%(figure_folder))
plt.show()

## Recompute all PVs (if need be)

In [None]:
for kernel_type in sobolev_alignment_clf:
    # Recompute every principal vector, then overwrite the saved model
    # (KRR approximations included, underlying model excluded).
    sobolev_alignment_clf[kernel_type]._compute_principal_vectors(all_PVs=True)
    sobolev_alignment_clf[kernel_type].save(
        with_model=False,
        with_krr=True,
        folder='%s/iter_%s_nu_%s/sobolev_alignment_model' % (output_folder, iter_idx, kernel_type),
    )

## Import corrected terms

### Linear

In [None]:
from supporting_scripts.linear_terms_treatment import read_latent_factors
from supporting_scripts.linear_terms_treatment import process_df
from supporting_scripts.linear_terms_treatment import aggregate_norm_comparison
from supporting_scripts.linear_terms_treatment import process_PV_linear_norm_comparison
from supporting_scripts.linear_terms_treatment import correct_linear_features
from supporting_scripts.linear_terms_treatment import assess_equality_contribution
from supporting_scripts.linear_terms_treatment import process_linear_weights
from supporting_scripts.linear_terms_treatment import compute_gene_std
from supporting_scripts.linear_terms_treatment import _compute_kernel_param

In [None]:
# Read data
# Iteration index and kernel whose order-1 ("linear") PV weights are analysed below.
n_iter = 0
kernel_name = 'laplacian'

# Load the order-1 PV weights of each dataset from output_folder through the
# supporting helper read_latent_factors.
source_PV_linear_features = read_latent_factors(
    output_folder,
    'PV_linear_weights_source_order_1',
    n_iter,
    kernel_name
)
target_PV_linear_features = read_latent_factors(
    output_folder,
    'PV_linear_weights_target_order_1',
    n_iter,
    kernel_name
)

# Process files
# NOTE(review): square_value=True is forwarded to process_df; presumably the
# weights are squared there — confirm in supporting_scripts.linear_terms_treatment.
source_PV_linear_features = process_df(source_PV_linear_features, square_value=True)
target_PV_linear_features = process_df(target_PV_linear_features, square_value=True)

In [None]:
# Format DataFrame
# Sum the processed weights per (kernel, factor, iter) and build the norm
# comparison used to rescale the weights (see process_PV_linear_norm_comparison).
source_linear_PV_norm_comparison = process_PV_linear_norm_comparison(
    source_PV_linear_features.groupby(['kernel', 'factor', 'iter']).agg('sum'),
    kernel_name=kernel_name
)
target_linear_PV_norm_comparison = process_PV_linear_norm_comparison(
    target_PV_linear_features.groupby(['kernel', 'factor', 'iter']).agg('sum'),
    kernel_name=kernel_name
)

# Correct scaling coefficient between Gaussian and Laplacian
# (the corrected loadings are read later from the 'corrected_value' column).
source_PV_linear_features = correct_linear_features(source_PV_linear_features, source_linear_PV_norm_comparison)
target_PV_linear_features = correct_linear_features(target_PV_linear_features, target_linear_PV_norm_comparison)

# Assert that the norm correction is correct
# (the check is only run when the analysed kernel is not Gaussian).
if kernel_name != 'gaussian':
    assess_equality_contribution(source_PV_linear_features)
    assess_equality_contribution(target_PV_linear_features)

# Put the matrix in the form genes x PVs
source_linear_weights = process_linear_weights(source_PV_linear_features, kernel_name=kernel_name)
target_linear_weights = process_linear_weights(target_PV_linear_features, kernel_name=kernel_name)

# Process the gene names: keep the identifier that precedes the literal '^1'
# (order-1) marker in the 'variable' column.
source_PV_linear_features['variable'] = source_PV_linear_features['variable'].str.extract(r'([A-Za-z0-9-_]*)\^1')
target_PV_linear_features['variable'] = target_PV_linear_features['variable'].str.extract(r'([A-Za-z0-9-_]*)\^1')

In [None]:
# Correct for standard deviation
# Kernel parameter of each dataset; a single shared value is required below.
gamma = {s: _compute_kernel_param(s, sobolev_alignment_clf) for s in ['source', 'target']}
assert gamma['source'] == gamma['target']
gamma = gamma['source']

# Compute feature-level (gene * exp offset) standard deviation.
# Each anchor x is paired with the offset exp(-gamma * ||x||^2).
all_genes_std = {
    data_source: compute_gene_std(
        X_log_input[data_source],
        np.exp(- gamma * np.square(np.linalg.norm(X_log_input[data_source], axis=1)))
    )
    for data_source in ['source', 'target']
}
# One-column ('feature_std') DataFrame per dataset; presumably indexed by
# feature name (keys returned by compute_gene_std) — TODO confirm.
all_genes_std = {
    data_source: pd.DataFrame(all_genes_std[data_source], index=['feature_std']).T
    for data_source in ['source', 'target']
}

# Merge dataset and correct (default inner merge: rows whose 'variable' has no
# standard deviation are dropped).
source_PV_linear_features = source_PV_linear_features.merge(
    all_genes_std['source'],
    left_on='variable',
    right_index=True
)
target_PV_linear_features = target_PV_linear_features.merge(
    all_genes_std['target'],
    left_on='variable',
    right_index=True
)

# Correct naming: strip the '_TF1' suffix from feature names.
source_PV_linear_features['variable'] = source_PV_linear_features['variable'].str.replace('_TF1', '')
target_PV_linear_features['variable'] = target_PV_linear_features['variable'].str.replace('_TF1', '')

# WARNING: HAS TO BE MULTIPLIED BY STANDARD DEVIATION
# standardized_value = corrected loading * feature standard deviation.
for df in [source_PV_linear_features, target_PV_linear_features]:
    df['standardized_value'] = df['corrected_value'] * df['feature_std']

In [None]:
# Keep the analysed kernel only and reshape to a (feature x factor) table of
# standardized loadings.
source_PV_linear_features = (
    source_PV_linear_features
    .set_index('kernel')
    .loc[kernel_name]
    .pivot(index='variable', columns='factor', values='standardized_value')
)

target_PV_linear_features = (
    target_PV_linear_features
    .set_index('kernel')
    .loc[kernel_name]
    .pivot(index='variable', columns='factor', values='standardized_value')
)

### Interactions

In [None]:
from supporting_scripts.interaction_terms_treatment import read_interaction_weights
from supporting_scripts.interaction_terms_treatment import compute_all_interactions_std, _compute_kernel_param

In [None]:
iter_idx = 0

# Order-2 (interaction) PV weights of each dataset. read_interaction_weights
# returns a (scaling coefficient, interaction features) pair.
source_file_radical = 'PV_linear_weights_source_order_2'
source_interaction_scaling_coef, source_PV_interaction_features = read_interaction_weights(
    file_radical=source_file_radical,
    output_folder=output_folder,
    iter_idx=iter_idx,
    kernel_name=kernel_name
)

target_file_radical = 'PV_linear_weights_target_order_2'
target_interaction_scaling_coef, target_PV_interaction_features = read_interaction_weights(
    file_radical=target_file_radical,
    output_folder=output_folder,
    iter_idx=iter_idx,
    kernel_name=kernel_name
)
# Explicit garbage collection after loading (presumably because the interaction
# tables are large — confirm).
gc.collect()

#### Correct for standard deviation

In [None]:
# Standard deviation of every pairwise-interaction feature per dataset, computed on
# the log-input anchors with the exp(-gamma * ||x||^2) offset. The helper's output
# parts are concatenated with '+' and stored as a DataFrame indexed by
# (gene_A, gene_B) with a single 'interaction_std' column.
interactions_std = {
    name: pd.DataFrame(
        reduce(
            lambda acc, part: acc + part,
            compute_all_interactions_std(
                X_log_input[name],
                np.exp(- gamma * np.square(np.linalg.norm(X_log_input[name], axis=1)))
            )
        ),
        columns=['gene_A', 'gene_B', 'interaction_std']
    ).set_index(['gene_A', 'gene_B'])
    for name in X_log_input
}

In [None]:
# Scale every interaction loading by the standard deviation of its interaction
# feature (rows aligned on (gene_A, gene_B) by the column-wise concat), then drop
# the helper 'interaction_std' column.
source_PV_interaction_features = pd.concat(
    [source_PV_interaction_features, interactions_std['source']],
    axis=1
)
for col in source_PV_interaction_features.columns:
    # Skip the scaling column itself (named 'interaction_std' when interactions_std
    # is built), so the result does not depend on that column coming last.
    if col == 'interaction_std':
        continue
    source_PV_interaction_features[col] *= source_PV_interaction_features['interaction_std']
del source_PV_interaction_features['interaction_std']

target_PV_interaction_features = pd.concat(
    [target_PV_interaction_features, interactions_std['target']],
    axis=1
)
for col in target_PV_interaction_features.columns:
    if col == 'interaction_std':
        continue
    target_PV_interaction_features[col] *= target_PV_interaction_features['interaction_std']
del target_PV_interaction_features['interaction_std']

### Compute square loadings

In [None]:
# Add a '<column>_squared' companion for every PV column of both the linear and
# the interaction tables. Column names come from the linear table and are
# snapshotted first, so the new '_squared' columns are not iterated over.
for col in list(source_PV_linear_features.columns):
    for frame in (source_PV_linear_features, source_PV_interaction_features):
        frame['%s_squared' % col] = np.square(frame[col])
for col in list(target_PV_linear_features.columns):
    for frame in (target_PV_linear_features, target_PV_interaction_features):
        frame['%s_squared' % col] = np.square(frame[col])

## Compute all weights

In [None]:
def plot_contribution_prop(data_source, kernel_name='gaussian', iter_idx=0, orders=range(1, 5)):
    """Stacked barplot of each SPV's squared-weight mass, split by expansion order.

    For every order ``i`` in ``orders`` the gzip pickle
    ``<output_folder>/iter_<iter_idx>_nu_<kernel_name>/PV_linear_weights_<data_source>_order_<i>.gz``
    is read and its squared entries are summed row-wise (axis=1). Each row is drawn
    as one bar labelled 'SPV <row + 1>'; a dashed line marks the value 1.

    Parameters
    ----------
    data_source : str
        'source' or 'target', used in the file name.
    kernel_name, iter_idx : str, int
        Select the model folder; the defaults read 'iter_0_nu_gaussian'.
    orders : iterable of int
        Expansion orders to read (default 1 to 4).

    NOTE(review): the rest of the post-processing uses kernel_name = 'laplacian';
    confirm that the Gaussian model is the intended default here.
    """
    weights = {}
    weights_sum = {}

    # Read coefficients
    for i in orders:
        weights[i] = pd.read_pickle(
            '%s/iter_%s_nu_%s/PV_linear_weights_%s_order_%s.gz'%(output_folder, iter_idx, kernel_name, data_source, i),
            compression='gzip'
        )
        weights_sum[i] = np.sum(np.square(weights[i]), axis=1)

    # Format weights and plot (one column per order, stacked)
    weights_sum_df = pd.concat(weights_sum, axis=1)
    weights_sum_df.index = ['SPV %s'%(i+1) for i in range(weights_sum_df.shape[0])]
    weights_sum_df.plot.bar(stacked=True, width=0.95)
    plt.axhline(1, linestyle='--', color='black')

    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=15).set_title('Order', prop={'size':20})
    plt.xticks(fontsize=20, color='black')
    plt.yticks(fontsize=15, color='black')
    plt.ylabel('Contribution', fontsize=20, color='black')
    plt.tight_layout()

for data_source in ['source', 'target']:
    plot_contribution_prop(data_source)
    plt.savefig('%s/proportion_PVs_per_order_%s.png'%(figure_folder, data_source), dpi=300)
    plt.show()

## Similarity visualisation

In [None]:
# The linear tables hold one raw and one '_squared' column per PV, hence the halving.
n_source_PVs, n_target_PVs = (
    source_PV_linear_features.shape[1] // 2,
    target_PV_linear_features.shape[1] // 2,
)

In [None]:
def plot_feature_correlation_matrix(source_PV_feat, target_PV_feat, n_source_PVs, n_target_PVs):
    """Heatmap of correlations between target PV loadings (rows) and source PV loadings (columns).

    The 'PV 0'..'PV <n-1>' columns of both tables are aligned on their feature index by
    ``pd.concat(axis=1)``. ``DataFrame.corr`` (Pearson by default) is applied and the
    block ``[n_source_PVs:, :n_source_PVs]`` (target x source) is drawn on the current
    pyplot figure.
    NOTE(review): callers save this as 'cosine_sim_*' although the values are Pearson
    correlations.
    """
    # Compute correlation
    feat_corr_matrix = pd.concat([
        source_PV_feat[['PV %s'%(e) for e in range(n_source_PVs)]],
        target_PV_feat[['PV %s'%(e) for e in range(n_target_PVs)]]
    ], axis=1).corr().values[n_source_PVs:,:n_source_PVs]

    # Plot the heatmap
    ax = sns.heatmap(feat_corr_matrix, cmap='seismic_r', center=0)
    
    # Visualisation routines; the last axes of the figure is the heatmap colourbar
    cax = plt.gcf().axes[-1]
    cax.tick_params(labelsize=15)
    plt.xticks(
        np.arange(n_source_PVs)+.5, ['SPV %s'%(i+1) for i in range(n_source_PVs)],
        fontsize=15, color='black', rotation=90
    )
    plt.yticks(
        np.arange(n_target_PVs)+.5, ['SPV %s'%(i+1) for i in range(n_target_PVs)],
        fontsize=15, color='black', rotation=0
    )
    plt.xlabel('Source', fontsize=20, color='black')
    plt.ylabel('Target', fontsize=20, color='black')
    plt.tight_layout()

# Linear terms
plot_feature_correlation_matrix(
    source_PV_linear_features, target_PV_linear_features, n_source_PVs, n_target_PVs
)
plt.savefig('%s/cosine_sim_linear_terms.png'%(figure_folder), dpi=300)
plt.show()

# Interaction terms
plot_feature_correlation_matrix(
    source_PV_interaction_features, target_PV_interaction_features, n_source_PVs, n_target_PVs
)
plt.savefig('%s/cosine_sim_interaction_terms.png'%(figure_folder), dpi=300)

In [None]:
def plot_one_self_corr_matrix(PV_feat, n_PV, ax, label):
    """Draw on ``ax`` the correlation matrix of the first ``n_PV`` PV loading columns.

    Uses the columns 'PV 0'..'PV <n_PV-1>' of ``PV_feat`` and ``DataFrame.corr``
    (Pearson by default). The colour scale is fixed to [-1, 1] and centred on 0;
    both axes are labelled ``label`` and ticked 'SPV 1'..'SPV <n_PV>'.
    """
    # Correlate the PV columns directly. Concatenating the table with itself and
    # slicing the off-diagonal block gives the same matrix at twice the cost.
    pv_columns = ['PV %s'%(e) for e in range(n_PV)]
    self_corr_matrix = PV_feat[pv_columns].corr().values
    sns.heatmap(self_corr_matrix, cmap='seismic_r', center=0, vmin=-1, vmax=1, ax=ax)
    ax.set_xlabel(label, fontsize=20, color='black')
    ax.set_ylabel(label, fontsize=20, color='black')
    ax.set_xticks(np.arange(n_PV)+.5)
    ax.set_xticklabels(['SPV %s'%(i+1) for i in range(n_PV)], fontsize=15, color='black', rotation=90)
    ax.set_yticks(np.arange(n_PV)+.5)
    ax.set_yticklabels(['SPV %s'%(i+1) for i in range(n_PV)], fontsize=15, color='black', rotation=0)

# Linear to linear
fig, axes = plt.subplots(1,2, figsize=(10,4))
plot_one_self_corr_matrix(source_PV_linear_features, n_source_PVs, axes[0], 'Source')
plot_one_self_corr_matrix(target_PV_linear_features, n_target_PVs, axes[1], 'Target')
plt.tight_layout()
plt.savefig('%s/self_cosine_sim_linear_terms.png'%(figure_folder), dpi=300)
plt.show()

# Interactions to interactions
fig, axes = plt.subplots(1,2, figsize=(10,4))
plot_one_self_corr_matrix(source_PV_interaction_features, n_source_PVs, axes[0], 'Source')
plot_one_self_corr_matrix(target_PV_interaction_features, n_target_PVs, axes[1], 'Target')
plt.tight_layout()
plt.savefig('%s/self_cosine_sim_interaction_terms.png'%(figure_folder), dpi=300)
plt.show()

## Feature waterfall and boxplot
### Routine

In [None]:
common_genes = [
    'Burn1', 'Burn2', 'Burn3', 'Burn4',
    'A1', 'A2', 'A3', 'A4', 'A5',
    'B1', 'B2', 'B3', 'B4', 'B5',
    'C1', 'C2', 'C3', 'C4', 'C5'
]
individual_genes = [
    'X1', 'X2', 'X3', 'Y1', 'Y2', 'Y3',
    'Burn5', 'Burn6', 'Burn7', 'Burn8'
]
# Display renaming applied to feature names in the barplots.
feature_renaming = {'Burn': 'Ext'}
colors = {'common': 'tab:green', 'specific': 'tab:red', 'half': 'tab:grey'}

def compute_ordered_contributions(linear_df, interactions_df, PV_number, exclusion_pattern):
    """
    Stack the squared linear and interaction loadings of one PV and label each feature.

    Parameters
    ----------
    linear_df : DataFrame indexed by feature name with a 'PV <PV_number>_squared' column.
    interactions_df : DataFrame indexed by (gene_A, gene_B) with the same column.
    PV_number : int, PV index used to build the column name.
    exclusion_pattern : regex; features whose gene_A or gene_B matches it are removed,
        e.g. Y genes for source and X genes for target.

    Returns
    -------
    A new DataFrame with columns 'PV <PV_number>_squared', 'gene_A', 'gene_B' ('NA'
    for linear features) and 'feature_status'. The status is 'common' when every
    gene of the feature is in common_genes, 'half' for an interaction with exactly
    one common gene, and 'specific' otherwise. Rows are linear features first, then
    interactions; they are NOT sorted (barplot_top_contributors sorts them).
    """
    squared_col = 'PV %s_squared'%(PV_number)

    # Format linear features (explicit copy: columns are added to it below)
    combined_features_df = linear_df[[squared_col]].copy()
    combined_features_df['gene_A'] = combined_features_df.index
    combined_features_df['gene_B'] = 'NA'

    # Concatenate linear and interaction terms and format the names: keep the leading
    # alphanumeric run (e.g. 'A1_TF1' -> 'A1'); expand=False returns a Series.
    combined_features_df = pd.concat(
        [combined_features_df, interactions_df[[squared_col]].reset_index()],
        axis=0
    )
    for gene_col in ['gene_A', 'gene_B']:
        combined_features_df[gene_col] = combined_features_df[gene_col].str.extract(
            r'([A-Za-z0-9]*)', expand=False
        )

    # Compute feature status = (#common genes) - 1, where a common gene of a linear
    # feature counts twice: 1 -> common, 0 -> half, -1 -> specific.
    gene_a_common = combined_features_df['gene_A'].isin(common_genes).astype(int)
    gene_b_common = combined_features_df['gene_B'].isin(common_genes).astype(int)
    is_linear = (combined_features_df['gene_B'] == 'NA').astype(int)
    combined_features_df['feature_status'] = gene_a_common * (1 + is_linear) + gene_b_common - 1

    # Remove features involving excluded (other-dataset-specific) genes for comparison
    excluded = (
        combined_features_df['gene_A'].str.contains(exclusion_pattern)
        | combined_features_df['gene_B'].str.contains(exclusion_pattern)
    )
    combined_features_df = combined_features_df.loc[~excluded].copy()

    feature_status_dict = {0: 'half', 1: 'common', -1: 'specific'}
    combined_features_df['feature_status'] = combined_features_df['feature_status'].apply(
        lambda x: feature_status_dict[x]
    )

    return combined_features_df

def barplot_top_contributors(df, PV_number, palette, file_save=None, top_genes=30, ax=None):
    """Barplot of the ``top_genes`` largest squared loadings of one PV.

    Bars are coloured by 'feature_status' through ``palette``. Labels are
    'gene_A-gene_B' (just 'gene_A' for linear features), with ``feature_renaming``
    applied. Draws on ``ax`` when given; otherwise it draws on the current figure,
    formats it and shows it. When ``file_save`` is given, the current figure is
    saved there. ``df`` is not modified.
    """
    squared_col = 'PV %s_squared'%(PV_number)
    # sort_values returns a new frame, so df itself is left untouched
    plot_df = df.sort_values(squared_col, ascending=False).reset_index(drop=True).reset_index()
    plot_df['color'] = plot_df['feature_status'].apply(lambda x: palette[x])
    # Literal (non-regex) replacements
    plot_df['feat_name'] = (plot_df['gene_A'] + '-' + plot_df['gene_B']).str.replace('-NA', '', regex=False)
    for x in feature_renaming:
        plot_df['feat_name'] = plot_df['feat_name'].str.replace(x, feature_renaming[x], regex=False)

    # Pass only the colours of the plotted bars, in bar order
    top_df = plot_df.head(top_genes)
    sns.barplot(
        data=top_df,
        x='feat_name', y=squared_col, palette=list(top_df['color']), ax=ax
    )
    if ax is None:
        plt.xticks(fontsize=15, color='black', rotation=90)
        plt.yticks(fontsize=15, color='black', rotation=0)
        plt.ylabel('Loadings', fontsize=20, color='black')
        plt.xlabel('')
        plt.tight_layout()
    else:
        ax.tick_params(axis='x', labelsize=15, labelcolor='black', rotation=90)
        ax.tick_params(axis='y', labelsize=15, labelcolor='black', rotation=0)
        ax.set_xlabel('')
        ax.set_ylabel('Loadings', size=20, color='black')

    if file_save is not None:
        plt.savefig(file_save, dpi=300)
    if ax is None:
        plt.show()

### Source

In [None]:
# Per-PV top-contributor barplots for the source, excluding Y / Burn7 / Burn8
# features; every per-PV frame is kept (indexed by gene pair and status).
source_combined_features_df = []
for PV_number in range(source_PV_linear_features.shape[1] // 2):
    combined_features_df = compute_ordered_contributions(
        source_PV_linear_features, source_PV_interaction_features, PV_number, r'(Y|Burn7|Burn8)'
    )
    barplot_top_contributors(
        combined_features_df,
        PV_number,
        palette=source_palette,
        top_genes=20,
        file_save='%s/barplot_features_source_PV_%s.png' % (figure_folder, PV_number + 1),
    )
    source_combined_features_df.append(
        combined_features_df.set_index(['gene_A', 'gene_B', 'feature_status'])
    )

In [None]:
# Stacked plot of contributions
# For each PV, total squared loading per feature status ('common', 'half',
# 'specific'), normalised so the statuses of one PV sum to 1.
# NOTE(review): expects source_combined_features_df to still be the list built
# in the previous cell (the next cell replaces it with a DataFrame).
stacked_plot_df = pd.concat(source_combined_features_df, axis=1).groupby('feature_status').agg('sum').T
# Divide each PV (column after .T) by its total across statuses
stacked_plot_df = stacked_plot_df.T / np.sum(stacked_plot_df, axis=1)
stacked_plot_df.T.plot.bar(stacked=True, color=source_palette, width=0.9, figsize=(8,4))

plt.xticks(np.arange(n_source_PVs), ['SPV %s'%(i+1) for i in np.arange(n_source_PVs)], fontsize=20, rotation=90)
plt.yticks(fontsize=15, rotation=0)
plt.ylabel('PV contribution', fontsize=20, color='black')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
plt.tight_layout()
plt.savefig('%s/contributions_source.png'%(figure_folder), dpi=300)

del stacked_plot_df

In [None]:
# Concatenate the per-PV frames once (re-running the cell keeps the DataFrame) and
# flag linear features, i.e. those without a second gene.
if not isinstance(source_combined_features_df, pd.DataFrame):
    source_combined_features_df = pd.concat(source_combined_features_df, axis=1)
source_combined_features_df['is_linear'] = source_combined_features_df.index.get_level_values('gene_B') == 'NA'

# Linear features: each PV column is normalised to sum to 1 over linear features
# and then aggregated per feature status. Column sums use an explicit axis=0:
# np.sum(DataFrame) forwards axis=None, deprecated in pandas >= 2.1 (future: scalar).
linear_proportion = source_combined_features_df.loc[source_combined_features_df['is_linear']].drop(columns='is_linear')
linear_proportion = linear_proportion / linear_proportion.sum(axis=0)
linear_proportion = linear_proportion.groupby('feature_status').agg('sum')
linear_proportion.T.plot.bar(stacked=True, width=0.9, figsize=(8,4), color=source_palette)

plt.xticks(np.arange(n_source_PVs), ['SPV %s'%(i+1) for i in np.arange(n_source_PVs)], fontsize=20, rotation=90)
plt.yticks(fontsize=15, rotation=0)
plt.ylabel('PV contribution', fontsize=20, color='black')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
plt.tight_layout()
plt.savefig('%s/contributions_source_linear.png'%(figure_folder), dpi=300)
plt.show()

# Interaction features: same normalisation over interaction features only.
interaction_proportion = source_combined_features_df.loc[~source_combined_features_df['is_linear']].drop(columns='is_linear')
interaction_proportion = interaction_proportion / interaction_proportion.sum(axis=0)

interaction_proportion = interaction_proportion.groupby('feature_status').agg('sum')
interaction_proportion.T.plot.bar(stacked=True, width=0.9, figsize=(8,4), color=source_palette)

plt.xticks(np.arange(n_source_PVs), ['SPV %s'%(i+1) for i in np.arange(n_source_PVs)], fontsize=20, rotation=90)
plt.yticks(fontsize=15, rotation=0)
plt.ylabel('PV contribution', fontsize=20, color='black')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
plt.tight_layout()
plt.savefig('%s/contributions_source_interactions.png'%(figure_folder), dpi=300)
plt.show()

### Target

In [None]:
# Per-PV top-contributor barplots for the target, excluding X / Burn5 / Burn6
# features; every per-PV frame is kept (indexed by gene pair and status).
target_combined_features_df = []
# PV count taken from the linear table, as in the source loop and n_target_PVs
for PV_number in range(target_PV_linear_features.shape[1] // 2):
    combined_features_df = compute_ordered_contributions(
        target_PV_linear_features, target_PV_interaction_features, PV_number, r'(X|Burn5|Burn6)'
    )
    barplot_top_contributors(
        combined_features_df, PV_number, top_genes=20, palette=target_palette,
        file_save='%s/barplot_features_target_PV_%s.png'%(figure_folder, PV_number+1)
    )
    target_combined_features_df.append(combined_features_df.set_index(['gene_A', 'gene_B', 'feature_status']))

In [None]:
# Stacked plot of contributions
# For each PV, total squared loading per feature status ('common', 'half',
# 'specific'), normalised so the statuses of one PV sum to 1.
# NOTE(review): expects target_combined_features_df to still be the list built
# in the previous cell (the next cell replaces it with a DataFrame).
stacked_plot_df = pd.concat(target_combined_features_df, axis=1).groupby('feature_status').agg('sum').T
# Divide each PV (column after .T) by its total across statuses
stacked_plot_df = stacked_plot_df.T / np.sum(stacked_plot_df, axis=1)
stacked_plot_df.T.plot.bar(stacked=True, color=target_palette, width=0.9, figsize=(8,4))

plt.xticks(np.arange(n_target_PVs), ['SPV %s'%(i+1) for i in np.arange(n_target_PVs)], fontsize=20, rotation=90)
plt.yticks(fontsize=15, rotation=0)
plt.ylabel('PV contribution', fontsize=20, color='black')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
plt.tight_layout()
plt.savefig('%s/contributions_target.png'%(figure_folder), dpi=300)

del stacked_plot_df

In [None]:
# Concatenate the per-PV frames once (re-running the cell keeps the DataFrame) and
# flag linear features, i.e. those without a second gene.
if not isinstance(target_combined_features_df, pd.DataFrame):
    target_combined_features_df = pd.concat(target_combined_features_df, axis=1)
target_combined_features_df['is_linear'] = target_combined_features_df.index.get_level_values('gene_B') == 'NA'

# Linear features: each PV column is normalised to sum to 1 over linear features
# and then aggregated per feature status. Column sums use an explicit axis=0:
# np.sum(DataFrame) forwards axis=None, deprecated in pandas >= 2.1 (future: scalar).
linear_proportion = target_combined_features_df.loc[target_combined_features_df['is_linear']].drop(columns='is_linear')
linear_proportion = linear_proportion / linear_proportion.sum(axis=0)
linear_proportion = linear_proportion.groupby('feature_status').agg('sum')
linear_proportion.T.plot.bar(stacked=True, width=0.9, figsize=(8,4), color=target_palette)

plt.xticks(np.arange(n_target_PVs), ['SPV %s'%(i+1) for i in np.arange(n_target_PVs)], fontsize=20, rotation=90)
plt.yticks(fontsize=15, rotation=0)
plt.ylabel('PV contribution', fontsize=20, color='black')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
plt.tight_layout()
plt.savefig('%s/contributions_target_linear.png'%(figure_folder), dpi=300)
plt.show()

# Interaction features: same normalisation over interaction features only.
interaction_proportion = target_combined_features_df.loc[~target_combined_features_df['is_linear']].drop(columns='is_linear')
interaction_proportion = interaction_proportion / interaction_proportion.sum(axis=0)

interaction_proportion = interaction_proportion.groupby('feature_status').agg('sum')
interaction_proportion.T.plot.bar(stacked=True, width=0.9, figsize=(8,4), color=target_palette)

plt.xticks(np.arange(n_target_PVs), ['SPV %s'%(i+1) for i in np.arange(n_target_PVs)], fontsize=20, rotation=90)
plt.yticks(fontsize=15, rotation=0)
plt.ylabel('PV contribution', fontsize=20, color='black')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
plt.tight_layout()
plt.savefig('%s/contributions_target_interactions.png'%(figure_folder), dpi=300)
plt.show()

## Common plot

In [None]:
# Combined figure: top contributors of every source PV (left column) and every
# target PV (right column).
n_number_top_genes = 10
n_max_PV = max(source_PV_linear_features.shape[1] // 2, target_PV_linear_features.shape[1] // 2)
# squeeze=False keeps `axes` two-dimensional even when n_max_PV == 1
fig, axes = plt.subplots(n_max_PV, 2, figsize=(10,3*n_max_PV), squeeze=False)

# Fill source barplots
for PV_number in range(n_source_PVs):
    combined_features_df = compute_ordered_contributions(
        source_PV_linear_features, source_PV_interaction_features, PV_number, r'(Y|Burn7|Burn8)'
    )
    barplot_top_contributors(
        combined_features_df, PV_number, top_genes=n_number_top_genes, palette=source_palette, ax=axes[PV_number, 0]
    )

# Fill target barplots (target colours, as in the per-PV target barplots)
for PV_number in range(n_target_PVs):
    combined_features_df = compute_ordered_contributions(
        target_PV_linear_features, target_PV_interaction_features, PV_number, r'(X|Burn5|Burn6)'
    )
    barplot_top_contributors(
        combined_features_df, PV_number, top_genes=n_number_top_genes, palette=target_palette, ax=axes[PV_number, 1]
    )

# Remove empty ones: the dataset with fewer PVs leaves unused rows in its column
# (column 1 when the target has fewer PVs, column 0 otherwise). Defined before the
# loop so the final clean-up also works when both datasets have the same PV count.
empty_ax_idx = int(n_target_PVs < n_max_PV)
for PV_number in range(min(n_source_PVs, n_target_PVs), n_max_PV):
    axes[PV_number, empty_ax_idx].spines['top'].set_visible(False)
    axes[PV_number, empty_ax_idx].spines['right'].set_visible(False)
    axes[PV_number, empty_ax_idx].spines['bottom'].set_visible(False)
    axes[PV_number, empty_ax_idx].spines['left'].set_visible(False)
    axes[PV_number, empty_ax_idx].get_xaxis().set_ticks([])
    axes[PV_number, empty_ax_idx].get_yaxis().set_ticks([])

plt.tight_layout()
plt.savefig('%s/complete_feature_barplot.png'%(figure_folder), dpi=300)

del n_max_PV, n_number_top_genes, empty_ax_idx, PV_number