# Imports

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

from sklearn import set_config
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.exceptions import FitFailedWarning
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import GridSearchCV
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.metrics import concordance_index_censored
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.metrics import (concordance_index_censored, 
                            cumulative_dynamic_auc)
from sksurv.metrics import integrated_brier_score
import matplotlib.pyplot as plt
from scipy import stats
from sksurv.datasets import load_breast_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis, CoxPHSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder
# Initialize and run the analysis
from matplotlib.colors import rgb2hex
from pydeseq2 import preprocessing, dds, ds
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
import seaborn as sns
# Standard library imports
import itertools
import os
import random
import shutil
from itertools import combinations, cycle

# Third-party library imports
import anndata
import decoupler as dc
import gseapy as gp
from gseapy import prerank
from gseapy.plot import dotplot, gseaplot
import liana as ln
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, LogNorm, rgb2hex, to_rgba
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Ellipse, Polygon
import numpy as np
import openpyxl
from openpyxl.styles import Font, PatternFill
from openpyxl.utils.dataframe import dataframe_to_rows
import pandas as pd
import PyComplexHeatmap as pch
from pydeseq2 import preprocessing, dds, ds
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
import scanpy as sc
import scipy
from scipy import stats
from scipy.cluster import hierarchy
import scipy.cluster.hierarchy as sch
from scipy.cluster.hierarchy import dendrogram, linkage, cophenet
from scipy.spatial.distance import pdist, squareform
from scipy.stats import linregress, median_abs_deviation
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import NMF, PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    adjusted_rand_score,
    calinski_harabasz_score,
    confusion_matrix,
    davies_bouldin_score,
    make_scorer,
    roc_auc_score,
    silhouette_score,
)
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, StandardScaler
from sklearn.tree import plot_tree
from statsmodels.nonparametric.smoothers_lowess import lowess
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm
import umap

# Local imports
from cnmf import cNMF
from catboost import CatBoostClassifier

# Set global rcParams for consistent formatting
# Notebook-wide Matplotlib defaults: white figure/axes/savefig backgrounds,
# black text, ticks and labels, and a framed legend with a black edge.
plt.rcParams.update({
    **dict.fromkeys(
        ('figure.facecolor', 'axes.facecolor', 'savefig.facecolor'),
        'white',
    ),
    **dict.fromkeys(
        ('legend.edgecolor', 'legend.labelcolor', 'axes.labelcolor',
         'xtick.color', 'ytick.color', 'text.color'),
        'black',
    ),
    'legend.frameon': True,
    'legend.fontsize': 12,
})

# Functions

In [2]:
def revert_from_conversion(adata):
    """Restore column dtypes recorded in ``adata.uns['conversion_info']``.

    ``conversion_info`` maps ``"<attribute>:<column>"`` keys (for example
    ``"obs:age"``) to the name of the dtype the column originally had. Each
    listed column is converted back in place on the corresponding attribute
    of ``adata``.

    Args:
        adata: Object exposing ``uns`` (dict-like) and the DataFrame
            attributes named in the keys.

    Returns:
        The same ``adata`` object, modified in place.
    """
    conversion_info = adata.uns.get('conversion_info', {})

    for key, original_type in conversion_info.items():
        # Split on the first colon only, so column names that themselves
        # contain ':' are handled instead of raising on unpacking.
        df_name, col = key.split(':', 1)
        df = getattr(adata, df_name)
        kind = original_type.lower()

        # "interval[...]" contains the substring "int" but cannot be cast to
        # Int64; leave such columns untouched like any other unhandled type.
        if 'interval' in kind:
            continue

        if 'datetime' in kind:
            df[col] = pd.to_datetime(df[col])
        elif 'timedelta' in kind:
            df[col] = pd.to_timedelta(df[col])
        elif original_type == 'category':
            df[col] = df[col].astype('category')
        elif 'int' in kind:
            df[col] = df[col].astype('Int64')  # Use nullable integer type
        elif 'float' in kind:
            df[col] = df[col].astype('float64')
        elif 'bool' in kind:
            df[col] = df[col].astype('boolean')
        # Other types will remain as they are

    return adata

In [3]:


class DEGAnalysis:
    """PyDESeq2-based differential-expression workflow for one design factor.

    The class copies an AnnData object, fits a ``DeseqDataSet`` on the chosen
    layer, runs every one-vs-one and one-vs-rest contrast between the levels
    of ``design_factor`` and writes plots / CSV tables into ``output_dir``.

    Results are stored in ``self.results`` as ``DeseqStats`` objects keyed by
    ``"<group1>_vs_<group2>"`` (``"<group>_vs_rest"`` for one-vs-rest).
    """

    def __init__(self, adata, design_factor, layer='raw_counts', output_dir='deg_analysis_results'):
        """Store the inputs, assign one colour per group and create ``output_dir``.

        Args:
            adata: AnnData object; ``adata.obs[design_factor]`` defines the groups.
            design_factor: Name of the ``obs`` column to contrast.
            layer: Layer copied into ``X`` by ``prepare_data`` (``'X'`` keeps ``X``).
            output_dir: Folder for figures and CSV files (created if missing).
        """
        self.original_adata = adata
        self.design_factor = design_factor
        self.layer = layer
        self.adata = None
        self.dds = None
        self.results = {}
        # Filled by the decoupler-based methods; None until they have been run.
        self.tf_acts = None
        self.tf_comparison = None
        self.pathway_acts = None
        self.pathway_comparison = None
        self.enr_pvals = None
        self.colors = self._generate_colors()
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def _generate_colors(self):
        """Return ``{group: hex colour}`` for every level of the design factor."""
        unique_groups = self.original_adata.obs[self.design_factor].unique()
        n_colors = len(unique_groups)
        # plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated and
        # later removed from Matplotlib.
        color_map = plt.get_cmap('tab20')
        return {group: rgb2hex(color_map(i / n_colors)) for i, group in enumerate(unique_groups)}

    def prepare_data(self):
        """Copy the input AnnData and load ``self.layer`` into ``X``.

        If the resulting matrix has negative entries, the whole matrix is
        shifted so that its minimum becomes zero.

        Raises:
            ValueError: If ``self.layer`` is neither an existing layer nor ``'X'``.
        """
        self.adata = self.original_adata.copy()
        if self.layer in self.adata.layers:
            self.adata.X = self.adata.layers[self.layer].copy()
        elif self.layer != 'X':
            raise ValueError(f"Layer '{self.layer}' not found in the AnnData object.")
        min_val = self.adata.X.min()
        if min_val < 0:
            self.adata.X -= min_val

    def create_dds(self):
        """Build the ``DeseqDataSet`` (preparing the data if needed) and fit it."""
        if self.adata is None:
            self.prepare_data()
        self.dds = DeseqDataSet(
            adata=self.adata,
            design_factors=self.design_factor,
            refit_cooks=True,
        )
        self.dds.deseq2()

    def run_comparisons(self):
        """Run all one-vs-one and one-vs-rest contrasts and store them in ``self.results``."""
        # Fit the model on demand so calling this method first does not fail on
        # the still-unset self.adata / self.dds.
        if self.dds is None:
            self.create_dds()

        subtypes = sorted(self.adata.obs[self.design_factor].unique().tolist())
        n_subtypes = len(subtypes)

        # One-vs-One comparisons
        for i in range(n_subtypes):
            for j in range(i + 1, n_subtypes):
                self._run_comparison(subtypes[i], subtypes[j])

        # One-vs-Rest comparisons
        for subtype in subtypes:
            self._run_comparison(subtype, 'rest', one_vs_rest=True)

    def _run_comparison(self, group1, group2, one_vs_rest=False):
        """Compute one contrast and store it under ``"<group1>_vs_<group2>"``.

        For ``one_vs_rest`` a separate ``DeseqDataSet`` is fitted in which every
        sample not in ``group1`` is relabelled ``'rest'``. The results table
        gets an extra ``Log2FC_pval`` column (log2 fold change times
        ``-log10(pvalue)``) and is sorted by it in descending order.
        """
        comparison_name = f"{group1}_vs_{group2}"
        if one_vs_rest:
            temp_counts = pd.DataFrame(self.adata.X, index=self.adata.obs_names, columns=self.adata.var_names)
            temp_metadata = pd.DataFrame({self.design_factor: self.adata.obs[self.design_factor]})
            temp_metadata[self.design_factor] = np.where(temp_metadata[self.design_factor] == group1, group1, 'rest')

            temp_dds = DeseqDataSet(
                counts=temp_counts,
                metadata=temp_metadata,
                design_factors=self.design_factor,
                refit_cooks=True,
            )
            temp_dds.deseq2()
            res = DeseqStats(temp_dds, contrast=[self.design_factor, group1, 'rest'])
        else:
            res = DeseqStats(self.dds, contrast=[self.design_factor, group1, group2])

        res.summary()
        res.results_df["Log2FC_pval"] = res.results_df["log2FoldChange"] * -np.log10(res.results_df["pvalue"])
        res.results_df = res.results_df.sort_values("Log2FC_pval", ascending=False)
        self.results[comparison_name] = res

    def create_volcano_plots(self, highlight_dict=None):
        """Save one volcano plot per stored comparison.

        Args:
            highlight_dict: Optional ``{label: [genes]}``; those genes are drawn
                larger, coloured per label and annotated.
        """
        for comparison_name, res in self.results.items():
            self._create_volcano_plot(res.results_df, comparison_name, highlight_dict)

    def _create_volcano_plot(self, deg_df, comparison_name, highlight_dict=None):
        """Draw and save ``volcano_plot_<comparison_name>.png`` in ``output_dir``.

        Without ``highlight_dict`` the 20 genes with the smallest p-values are
        annotated instead.
        """
        with plt.rc_context({'figure.figsize': (12, 8)}):
            plt.figure()

            # Plot all genes
            plt.scatter(deg_df['log2FoldChange'], -np.log10(deg_df['pvalue']),
                        alpha=0.6, s=3, color='grey', label='Other')

            # Highlight significant genes
            significant = (deg_df['padj'] < 0.05) & (abs(deg_df['log2FoldChange']) > 0.5)
            plt.scatter(deg_df.loc[significant, 'log2FoldChange'],
                        -np.log10(deg_df.loc[significant, 'pvalue']),
                        alpha=0.6, s=3, color='lightgrey', label='Significant')

            # Highlight genes from the dictionary
            if highlight_dict:
                colors = sns.color_palette("husl", len(highlight_dict))
                for (group, genes), color in zip(highlight_dict.items(), colors):
                    mask = deg_df.index.isin(genes)
                    plt.scatter(deg_df.loc[mask, 'log2FoldChange'],
                                -np.log10(deg_df.loc[mask, 'pvalue']),
                                alpha=0.8, s=30, color=color, label=group)

                    # Annotate these genes
                    for gene in genes:
                        if gene in deg_df.index:
                            gene_data = deg_df.loc[gene]
                            plt.annotate(gene,
                                         (gene_data['log2FoldChange'], -np.log10(gene_data['pvalue'])),
                                         xytext=(5, 5), textcoords='offset points',
                                         ha='left', va='bottom',
                                         fontsize=8,
                                         bbox=dict(boxstyle='round,pad=0.1', fc='white', ec='none', alpha=0.7),
                                         arrowprops=dict(arrowstyle='->', color='black', lw=0.5))
            else:
                # If no highlight_dict, annotate top genes as before
                top_genes = deg_df.sort_values('pvalue').head(20)
                for _, gene in top_genes.iterrows():
                    plt.annotate(gene.name,
                                 (gene['log2FoldChange'], -np.log10(gene['pvalue'])),
                                 xytext=(3, 3), textcoords='offset points',
                                 ha='left', va='bottom',
                                 fontsize=6,
                                 bbox=dict(boxstyle='round,pad=0.1', fc='white', ec='none', alpha=0.7),
                                 arrowprops=dict(arrowstyle='->', color='black', lw=0.5))

            # Guide lines at the thresholds used for the "significant" mask above.
            plt.axvline(x=0.5, color='gray', linestyle='--', linewidth=1)
            plt.axvline(x=-0.5, color='gray', linestyle='--', linewidth=1)
            plt.axhline(y=-np.log10(0.05), color='gray', linestyle='--', linewidth=1)

            plt.xlabel('Log2 fold change')
            plt.ylabel('-Log10 p-value')
            plt.title(f'Volcano Plot: {comparison_name}')

            group1, group2 = comparison_name.split('_vs_')
            plt.text(0.02, 0.98, f"Positive log2FC: Upregulated in {group1}\nNegative log2FC: Upregulated in {group2}",
                     transform=plt.gca().transAxes, va='top', ha='left', fontsize=8,
                     bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

            if highlight_dict:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

            plt.tight_layout()
            plt.savefig(os.path.join(self.output_dir, f'volcano_plot_{comparison_name}.png'), dpi=300, bbox_inches='tight')
            plt.close()

    def create_clustermaps(self, marker_list=None, padj_threshold=0.05, log2fc_threshold=1):
        """Save one clustermap per stored comparison (see ``_create_clustermap``)."""
        for comparison_name, res in self.results.items():
            self._create_clustermap(res.results_df, comparison_name, marker_list, padj_threshold, log2fc_threshold)

    def _create_clustermap(self, results_df, comparison_name, marker_list=None, padj_threshold=0.05, log2fc_threshold=1):
        """Save a row-clustered, z-scored heatmap of the significant genes.

        Genes are kept when ``padj < padj_threshold`` and
        ``|log2FoldChange| > log2fc_threshold``. If ``marker_list`` is given and
        at least two of its genes are significant, only those are plotted.
        Pairwise comparisons show the samples of the two groups; ``*_vs_rest``
        comparisons show all samples. The comparison is skipped (with a
        message) when fewer than two genes or samples remain.
        """
        cluster1, cluster2 = comparison_name.split('_vs_')
        significant_genes = results_df[(results_df['padj'] < padj_threshold) &
                                       (abs(results_df['log2FoldChange']) > log2fc_threshold)].index

        if marker_list is not None:
            marker_genes = [gene for gene in marker_list if gene in significant_genes]
            if len(marker_genes) < 2:
                print(f"Warning: Less than 2 genes from the marker list are present in the significant genes for {comparison_name}.")
                print("Using all significant genes instead.")
                genes_to_plot = significant_genes
            else:
                genes_to_plot = marker_genes
        else:
            genes_to_plot = significant_genes

        if len(genes_to_plot) < 2:
            print(f"Not enough genes to plot for {comparison_name}. Skipping this comparison.")
            return

        if cluster2 == 'rest':
            dds_sub = self.dds[:, genes_to_plot]
        else:
            dds_sub = self.dds[self.dds.obs[self.design_factor].isin([cluster1, cluster2]), genes_to_plot]

        if dds_sub.shape[1] < 2 or dds_sub.shape[0] < 2:
            print(f"Not enough data to plot for {comparison_name}. Skipping this comparison.")
            print(f"Number of genes: {dds_sub.shape[1]}, Number of samples: {dds_sub.shape[0]}")
            return

        dds_sub = dds_sub[dds_sub.obs[self.design_factor].argsort()]
        # NOTE(review): the "lognormalized_counts" layer is not created anywhere
        # in this class; presumably it comes from the input AnnData -- confirm.
        grapher = pd.DataFrame(dds_sub.layers["lognormalized_counts"].T, index=dds_sub.var_names, columns=dds_sub.obs_names)
        col_colors_leiden = dds_sub.obs[self.design_factor].map(self.colors)

        # 'rest' is not a key of self.colors: a one-vs-rest map shows every
        # sample coloured by its own group, so the legend lists those groups.
        if cluster2 == 'rest':
            legend_labels = list(pd.unique(dds_sub.obs[self.design_factor]))
        else:
            legend_labels = [cluster1, cluster2]

        # Remember which figures exist so that exactly the ones opened below
        # are closed again, whether or not plotting succeeds.
        figures_before = set(plt.get_fignums())
        try:
            # sns.clustermap opens its own figure; a figure created beforehand
            # with plt.figure() would stay empty and never be closed.
            g = sns.clustermap(grapher,
                               z_score=0,
                               cmap="RdBu_r",
                               col_cluster=False,
                               row_cluster=True,
                               col_colors=col_colors_leiden,
                               dendrogram_ratio=(.1, .3),
                               robust=True)  # Add robust to handle outliers better

            # Adjust the colorbar to be symmetric around 0. get_clim() returns a
            # (vmin, vmax) tuple, so abs() has to be applied per element.
            heatmap_mesh = g.ax_heatmap.collections[0]
            vmax = max(abs(limit) for limit in heatmap_mesh.get_clim())
            heatmap_mesh.set_clim(-vmax, vmax)

            plt.setp(g.ax_heatmap.get_xticklabels(), rotation=90, ha='center', va='top')
            g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=8)

            plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0)

            # Zero-size bars only serve as legend handles.
            for label in legend_labels:
                g.ax_col_dendrogram.bar(0, 0, color=self.colors[label], label=label, linewidth=0)
            g.ax_col_dendrogram.legend(title=self.design_factor, loc="center", ncol=2)

            plt.suptitle(f"Clustermap for {comparison_name}", fontsize=16, y=1.02)
            plt.tight_layout()
            plt.savefig(os.path.join(self.output_dir, f"clustermap_{comparison_name}.png"), dpi=300, bbox_inches='tight')
        except Exception as e:
            print(f"Error creating clustermap for {comparison_name}: {str(e)}")
            # Return here so the success message below is not printed after a failure.
            return
        finally:
            for fig_num in set(plt.get_fignums()) - figures_before:
                plt.close(fig_num)

        print(f"Successfully created clustermap for {comparison_name} with {len(genes_to_plot)} genes.")

    def get_results(self):
        """Return the dictionary of stored ``DeseqStats`` objects."""
        return self.results

    def save_results(self):
        """Write every results table to ``DEG_results_<comparison>.csv`` in ``output_dir``."""
        for comparison_name, res in self.results.items():
            res.results_df.to_csv(os.path.join(self.output_dir, f"DEG_results_{comparison_name}.csv"))
        print(f"All DEG results have been saved to CSV files in the {self.output_dir} folder.")

    def create_boxplots(self, genes_to_plot, test='ttest', figsize=(20, 5), save_path=None):
        """Draw one box/strip plot per gene with pairwise significance brackets.

        Args:
            genes_to_plot: Gene names; names missing from the dataset are
                reported at the end and skipped.
            test: ``'ttest'`` (Welch) or ``'mannwhitneyu'``, applied to every
                pair of groups.
            figsize: Size of the whole figure (all genes share one row).
            save_path: Optional file name, relative to ``output_dir``.
        """
        adata_subset, valid_genes = self._prepare_data_for_boxplots(genes_to_plot)

        if not valid_genes:
            print("No valid genes found in the dataset.")
            return

        n_genes = len(valid_genes)
        fig, axes = plt.subplots(1, n_genes, figsize=figsize)
        if n_genes == 1:
            axes = [axes]

        for i, gene in enumerate(valid_genes):
            ax = axes[i]

            # Accept dense as well as sparse expression matrices.
            expression = adata_subset[:, gene].X
            if hasattr(expression, 'toarray'):
                expression = expression.toarray()

            data = pd.DataFrame({
                'expression': np.asarray(expression).flatten(),
                self.design_factor: adata_subset.obs[self.design_factor]
            })

            clusters = sorted(data[self.design_factor].unique())
            cluster_to_pos = {cluster: idx for idx, cluster in enumerate(clusters)}

            sns.boxplot(data=data, x=self.design_factor, y='expression', ax=ax, order=clusters, palette=self.colors)
            sns.stripplot(data=data, x=self.design_factor, y='expression', color='black', size=2, alpha=0.4, ax=ax, order=clusters)

            comparisons = list(itertools.combinations(clusters, 2))
            max_bars = len(comparisons)

            # Brackets are stacked above the plotted data, one level per pair.
            plot_top = ax.get_ylim()[1]
            bar_height = plot_top * 0.05
            spacing = plot_top * 0.1

            for idx, (c1, c2) in enumerate(comparisons):
                data1 = data[data[self.design_factor] == c1]['expression']
                data2 = data[data[self.design_factor] == c2]['expression']
                p_value = self._perform_test(data1, data2, test, gene, f"{self.design_factor} {c1}", f"{self.design_factor} {c2}")

                y_pos = plot_top + spacing + (bar_height + spacing) * idx

                x1, x2 = cluster_to_pos[c1], cluster_to_pos[c2]
                ax.plot([x1, x1, x2, x2], [y_pos, y_pos + bar_height, y_pos + bar_height, y_pos], lw=1.5, c='black')
                significance = self._get_stars(p_value)
                ax.text((x1 + x2) / 2, y_pos + bar_height, significance, ha='center', va='bottom', fontsize=10)

            ax.set_title(f'{gene}', fontsize=14)
            ax.set_xlabel(self.design_factor.capitalize(), fontsize=12)
            # The values come from sc.pp.log1p with its default base, i.e. the
            # natural logarithm, so the axis is labelled ln rather than log2.
            ax.set_ylabel('ln(CPM+1)' if i == 0 else '', fontsize=12)
            ax.set_ylim(0, plot_top + (bar_height + spacing) * (max_bars + 1))
            ax.tick_params(axis='both', which='major', labelsize=10)

            ax.set_xticks(range(len(clusters)))
            ax.set_xticklabels(clusters)

        plt.tight_layout()

        if save_path:
            plt.savefig(os.path.join(self.output_dir, save_path), dpi=300, bbox_inches='tight')
            print(f"Figure saved to {os.path.join(self.output_dir, save_path)}")

        plt.show()

        missing_genes = set(genes_to_plot) - set(valid_genes)
        if missing_genes:
            # Sorted so the message is the same on every run.
            print(f"The following genes were not found in the dataset: {', '.join(sorted(missing_genes))}")

    def create_volcano_grid(self, highlight_genes=None):
        """Save ``volcano_grid.png``: pairwise volcanoes above the diagonal, one-vs-rest on it."""
        # Get unique group names
        group_names = sorted(set([k.split('_vs_')[0] for k in self.results.keys() if '_vs_' in k]))
        n_groups = len(group_names)

        # Without any comparison the figure and grid would have size zero.
        if n_groups == 0:
            print("No comparisons available for the volcano grid. Run run_comparisons() first.")
            return

        fig = plt.figure(figsize=(4*n_groups, 4*n_groups))
        gs = GridSpec(n_groups, n_groups)

        for i, group1 in enumerate(group_names):
            for j, group2 in enumerate(group_names):
                if i < j:  # Upper triangle
                    ax = fig.add_subplot(gs[i, j])
                    comparison = f"{group1}_vs_{group2}"
                    if comparison in self.results:
                        self._plot_volcano(ax, self.results[comparison].results_df, f'{group1} vs. {group2}', self.colors[group1], highlight_genes=highlight_genes)
                elif i == j:  # Diagonal
                    ax = fig.add_subplot(gs[i, i])
                    comparison = f"{group1}_vs_rest"
                    if comparison in self.results:
                        self._plot_volcano(ax, self.results[comparison].results_df, f'{group1} vs. Others', self.colors[group1], highlight_genes=highlight_genes)

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, "volcano_grid.png"), dpi=300, bbox_inches='tight')
        plt.show()

    def _prepare_data_for_boxplots(self, genes_to_plot):
        """Return ``(normalised AnnData restricted to valid genes, valid gene list)``.

        A copy of the input AnnData is scaled to 1e6 counts per sample and
        log1p-transformed before subsetting.
        """
        # dict.fromkeys drops repeated names while keeping the requested order;
        # a repeated name would otherwise select the same column twice.
        valid_genes = [gene for gene in dict.fromkeys(genes_to_plot) if gene in self.original_adata.var_names]

        # NOTE(review): this normalises original_adata.X, not self.layer as
        # prepare_data does -- confirm that X holds raw counts here.
        adata_full = self.original_adata.copy()
        sc.pp.normalize_total(adata_full, target_sum=1e6)
        sc.pp.log1p(adata_full)

        adata_subset = adata_full[:, valid_genes]

        return adata_subset, valid_genes

    def _perform_test(self, data1, data2, test_type, gene, group1_name, group2_name):
        """Run a two-sample test, print a short summary and return the p-value.

        Args:
            test_type: ``'mannwhitneyu'`` or ``'ttest'`` (unequal variances).

        Raises:
            ValueError: For any other ``test_type``.
        """
        if test_type == 'mannwhitneyu':
            _, p_value = stats.mannwhitneyu(data1, data2)
        elif test_type == 'ttest':
            _, p_value = stats.ttest_ind(data1, data2, equal_var=False)
        else:
            raise ValueError("Invalid test type. Choose 'mannwhitneyu' or 'ttest'.")

        print(f"\nGene: {gene}")
        print(f"Comparison: {group1_name} vs {group2_name}")
        print(f"{group1_name}: n={len(data1)}, median={np.median(data1):.2f}, mean={np.mean(data1):.2f}")
        print(f"{group2_name}: n={len(data2)}, median={np.median(data2):.2f}, mean={np.mean(data2):.2f}")
        print(f"p-value: {p_value:.4f}")
        print(f"Significance: {self._get_stars(p_value)}")
        return p_value

    def _get_stars(self, p_value):
        """Translate a p-value into ``'***'``, ``'**'``, ``'*'`` or ``'ns'``.

        The unrounded p-value is compared with the thresholds: rounding first
        would move values such as 0.04996 onto the boundary and misreport them.
        """
        if p_value < 0.001:
            return '***'
        elif p_value < 0.01:
            return '**'
        elif p_value < 0.05:
            return '*'
        else:
            return 'ns'

    def _plot_volcano(self, ax, results, title, color, top_genes=5, highlight_genes=None):
        """Draw a compact volcano plot on ``ax``.

        The ``top_genes`` genes with the smallest p-values are labelled;
        ``highlight_genes`` present in ``results`` are marked and labelled in red.
        """
        ax.scatter(results['log2FoldChange'], -np.log10(results['pvalue']),
                   alpha=0.6, s=3, color=color)

        ax.axhline(-np.log10(0.05), color='red', linestyle='--', linewidth=0.5)
        ax.axvline(-1, color='red', linestyle='--', linewidth=0.5)
        ax.axvline(1, color='red', linestyle='--', linewidth=0.5)

        top = results.sort_values('pvalue').head(top_genes)
        for _, gene in top.iterrows():
            ax.text(gene['log2FoldChange'], -np.log10(gene['pvalue']), gene.name,
                    fontsize=6, ha='center', va='bottom', color='black')

        if highlight_genes:
            for gene in highlight_genes:
                if gene in results.index:
                    gene_data = results.loc[gene]
                    ax.text(gene_data['log2FoldChange'], -np.log10(gene_data['pvalue']), gene,
                            fontsize=6, ha='center', va='bottom', color='red', fontweight='bold')
                    ax.scatter(gene_data['log2FoldChange'], -np.log10(gene_data['pvalue']),
                               color='red', s=20, zorder=5)

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('log2(Fold Change)', fontsize=8)
        ax.set_ylabel('-log10(p-value)', fontsize=8)
        ax.tick_params(axis='both', which='major', labelsize=6)
        ax.grid(True, which="both", ls="-", alpha=0.2)

    def _resolve_comparison(self, comparison=None):
        """Return the ``self.results`` key a downstream analysis should use.

        With ``comparison=None`` the former hard-coded key
        ``'treatment.vs.control'`` is used if present, otherwise the only
        stored comparison.

        Raises:
            KeyError: If the comparison is unknown, or none was given and the
                choice is ambiguous.
        """
        if comparison is None:
            if 'treatment.vs.control' in self.results:
                return 'treatment.vs.control'
            if len(self.results) == 1:
                return next(iter(self.results))
            raise KeyError(
                f"Specify a comparison; available comparisons: {sorted(self.results)}"
            )
        if comparison not in self.results:
            raise KeyError(
                f"Unknown comparison '{comparison}'; available comparisons: {sorted(self.results)}"
            )
        return comparison

    def _stat_matrix(self, comparison):
        """Return the ``stat`` column of a comparison as a one-row frame named after it."""
        stat_df = self.results[comparison].results_df[['stat']]
        # Genes without a test statistic (NaN) are left out of the input.
        stat_df = stat_df.dropna(subset=['stat'])
        return stat_df.T.rename(index={'stat': comparison})

    def infer_tf_activities(self, comparison=None):
        """Infer transcription-factor activities with decoupler's ``run_ulm``.

        Uses the CollecTRI network and stores the activities in ``self.tf_acts``.

        Args:
            comparison: Key of ``self.results`` to analyse (see
                ``_resolve_comparison`` for the default).
        """
        comparison = self._resolve_comparison(comparison)

        # Retrieve CollecTRI gene regulatory network
        collectri = dc.get_collectri(organism='human', split_complexes=False)

        # Prepare data for TF activity inference
        mat = self._stat_matrix(comparison)

        # Infer TF activities with ulm
        self.tf_acts, tf_pvals = dc.run_ulm(mat=mat, net=collectri, verbose=True)
        self.tf_comparison = comparison

    def plot_tf_activities(self, top=25, comparison=None):
        """Save ``tf_activities.png`` with the ``top`` TF activities.

        Args:
            top: Number of transcription factors to show.
            comparison: Row of ``self.tf_acts`` to plot; defaults to the
                comparison used by ``infer_tf_activities``.

        Raises:
            RuntimeError: If ``infer_tf_activities`` has not been run.
        """
        if self.tf_acts is None:
            raise RuntimeError("Run infer_tf_activities() before plot_tf_activities().")
        contrast = comparison or self.tf_comparison or 'treatment.vs.control'
        dc.plot_barplot(
            acts=self.tf_acts,
            contrast=contrast,
            top=top,
            vertical=True,
            figsize=(3, 6)
        )
        plt.savefig(os.path.join(self.output_dir, 'tf_activities.png'), dpi=300, bbox_inches='tight')
        plt.close()

    def infer_pathway_activities(self, comparison=None):
        """Infer pathway activities with decoupler's ``run_mlm`` and PROGENy weights.

        Stores the activities in ``self.pathway_acts``.

        Args:
            comparison: Key of ``self.results`` to analyse (see
                ``_resolve_comparison`` for the default).
        """
        comparison = self._resolve_comparison(comparison)

        # Retrieve PROGENy model weights
        progeny = dc.get_progeny(top=500)

        # Prepare data for pathway activity inference
        mat = self._stat_matrix(comparison)

        # Infer pathway activities with mlm
        self.pathway_acts, pathway_pvals = dc.run_mlm(mat=mat, net=progeny, verbose=True)
        self.pathway_comparison = comparison

    def plot_pathway_activities(self, comparison=None):
        """Save ``pathway_activities.png`` for the inferred pathway activities.

        Args:
            comparison: Row of ``self.pathway_acts`` to plot; defaults to the
                comparison used by ``infer_pathway_activities``.

        Raises:
            RuntimeError: If ``infer_pathway_activities`` has not been run.
        """
        if self.pathway_acts is None:
            raise RuntimeError("Run infer_pathway_activities() before plot_pathway_activities().")
        contrast = comparison or self.pathway_comparison or 'treatment.vs.control'
        dc.plot_barplot(
            self.pathway_acts,
            contrast,
            top=25,
            vertical=False,
            figsize=(6, 3)
        )
        plt.savefig(os.path.join(self.output_dir, 'pathway_activities.png'), dpi=300, bbox_inches='tight')
        plt.close()

    def run_functional_enrichment(self, comparison=None):
        """Run over-representation analysis on MSigDB hallmark gene sets.

        Genes with ``padj < 0.05`` in the chosen comparison are tested; the
        result table is stored in ``self.enr_pvals``.

        Args:
            comparison: Key of ``self.results`` to analyse (see
                ``_resolve_comparison`` for the default).
        """
        comparison = self._resolve_comparison(comparison)

        # Retrieve MSigDB gene sets
        msigdb = dc.get_resource('MSigDB')
        msigdb = msigdb[msigdb['collection'] == 'hallmark']
        # Work on an explicit copy so the column assignment below does not
        # write into a slice of the downloaded table.
        msigdb = msigdb[~msigdb.duplicated(['geneset', 'genesymbol'])].copy()
        msigdb['geneset'] = [name.split('HALLMARK_')[1] for name in msigdb['geneset']]

        # Prepare data for enrichment analysis
        results_df = self.results[comparison].results_df
        top_genes = results_df[results_df['padj'] < 0.05]

        # Run ora
        self.enr_pvals = dc.get_ora_df(
            df=top_genes,
            net=msigdb,
            source='geneset',
            target='genesymbol'
        )

    def plot_functional_enrichment(self, top=15):
        """Save ``functional_enrichment.png`` with the ``top`` terms by combined score.

        Raises:
            RuntimeError: If ``run_functional_enrichment`` has not been run.
        """
        if self.enr_pvals is None:
            raise RuntimeError("Run run_functional_enrichment() before plot_functional_enrichment().")
        dc.plot_dotplot(
            self.enr_pvals.sort_values('Combined score', ascending=False).head(top),
            x='Combined score',
            y='Term',
            s='Odds ratio',
            c='FDR p-value',
            scale=1.5,
            figsize=(3, 6)
        )
        plt.savefig(os.path.join(self.output_dir, 'functional_enrichment.png'), dpi=300, bbox_inches='tight')
        plt.close()


In [4]:
import pandas as pd
from typing import Dict, List
import os

def get_top_genes(source_paths: Dict[str, str], n_genes: int = 10) -> Dict[str, List[str]]:
    """
    Collect the highest-ranked genes from differential-expression CSV files.

    Each file is read, ordered by its ``stat`` column from largest to
    smallest, and a preview of the leading rows is printed.

    Args:
        source_paths: Maps a comparison name to the path of its results CSV.
            Every file needs the columns ``gene_identifier``, ``stat``,
            ``log2FoldChange`` and ``padj``.
        n_genes: How many leading genes to keep per comparison.

    Returns:
        Maps each comparison name to the list of its top gene identifiers.
    """
    preview_columns = ['gene_identifier', 'stat', 'log2FoldChange', 'padj']
    top_by_comparison = {}

    for comparison in source_paths:
        ranked = pd.read_csv(source_paths[comparison]).sort_values('stat', ascending=False)
        leading_rows = ranked.head(n_genes)

        # Show the rows that were picked for this comparison.
        print(f"\n{comparison} Top Genes:")
        print(leading_rows[preview_columns])

        top_by_comparison[comparison] = leading_rows['gene_identifier'].tolist()

    return top_by_comparison



# DEG: data loading and preparation

In [5]:
import scanpy as sc
# Load the two survival tables (time to progression, overall survival), the
# log-normalised expression table and the preprocessed AnnData object.
# NOTE(review): these are absolute, machine-specific paths under two different
# roots (/home/rafaed/work and /mnt/work); consider a configurable data directory.
km_data_ttp = pd.read_csv(r"/home/rafaed/work/RO_src/Projects/THORA/StatisticalAnalysis/scripts/data_ttp.csv")
km_data_os = pd.read_csv(r"/home/rafaed/work/RO_src/Projects/THORA/StatisticalAnalysis/scripts/data_os.csv")
log_norm_counts = pd.read_csv(r"/home/rafaed/work/RO_src/Projects/THORA/StatisticalAnalysis/scripts/log_normalized_counts_df.csv")
adata_nmf = sc.read_h5ad(r"/mnt/work/RO_src/Projects/THORA/DataProcessing/data/processed/adata_cp_full_preprocessed.h5ad")

# Give the follow-up time columns endpoint-specific names before merging.
km_data_ttp.rename(columns={"tend": "tend_ttp"}, inplace=True)
km_data_os.rename(columns={"tend": "tend_os"}, inplace=True)

# The same columns are removed from both survival tables
# ('status' used to be listed twice; once is enough).
survival_cols_to_drop = ['ERCC1', 'ERCC2', 'ERCC5', 'BRCA1', 'TUBB3', 'STK11', 'HIF1A',
                         'status', "Unnamed: 0"]
km_data_ttp.drop(columns=survival_cols_to_drop, inplace=True)
km_data_os.drop(columns=survival_cols_to_drop, inplace=True)

# Merge the two survival tables on every column they share, then attach the
# expression table by sample ID.
common_cols = km_data_ttp.columns.intersection(km_data_os.columns)
merged_df = pd.merge(km_data_ttp, km_data_os, on=list(common_cols))

merged_data = pd.merge(merged_df, log_norm_counts, on="ID_Sample")
merged_df.set_index("ID_Sample", inplace=True)
merged_data.set_index("ID_Sample", inplace=True)

adata_nmf = revert_from_conversion(adata_nmf)
adata_nmf_cp = adata_nmf.copy()
adata_nmf_cp.obs.set_index("ID_Sample", inplace=True)

# Flag the AnnData samples that also occur in the merged table
common_idx = adata_nmf_cp.obs.index.isin(merged_data.index)

# Subset the AnnData copy to the samples that are present in merged_data
adata_nmf_cp = adata_nmf_cp[common_idx].copy()

# Assigning a DataFrame to .obs replaces the observations positionally, without
# aligning on the index. Reorder the merged table to the AnnData sample order
# first, so each row of X keeps the annotations of its own sample.
merged_data = merged_data.loc[adata_nmf_cp.obs.index]
adata_nmf_cp.obs = merged_data

# Rename the status columns to hyphenated names and turn the 0/1 codes into labels.
adata_nmf_cp.obs.rename(columns={"status_os":"status-os"}, inplace=True)
adata_nmf_cp.obs['status-os'] = adata_nmf_cp.obs['status-os'].map({0: 'Alive', 1: 'Dead'})
adata_nmf_cp.obs.rename(columns={"status_ttp":"status-ttp"}, inplace=True)
adata_nmf_cp.obs['status-ttp'] = adata_nmf_cp.obs['status-ttp'].map({0: 'Did not progress', 1: 'Progressed'})


# Evaluate gene sets

In [6]:
# Lower-cased status labels used in this notebook's obs columns, mapped to
# "was the event observed?".
_STATUS_EVENT_LABELS = {
    "dead": True,
    "progressed": True,
    "alive": False,
    "did not progress": False,
}


def _status_to_event(status):
    """Return a boolean event indicator (numpy array) for a status column.

    Numeric/boolean columns are cast directly (non-zero means event). Text
    columns are translated through ``_STATUS_EVENT_LABELS``; a plain
    ``astype(bool)`` would turn every non-empty label into ``True``.

    Raises
    ------
    ValueError
        If a text label is not one of the known status labels.
    """
    status = pd.Series(status)
    if pd.api.types.is_bool_dtype(status) or pd.api.types.is_numeric_dtype(status):
        return status.astype(bool).to_numpy()

    labels = status.astype(str).str.strip().str.lower()
    events = labels.map(_STATUS_EVENT_LABELS)
    unknown = labels[events.isna()]
    if len(unknown):
        raise ValueError(
            f"Unrecognised status labels {sorted(set(unknown))}; "
            f"expected one of {sorted(_STATUS_EVENT_LABELS)} or a numeric/boolean column."
        )
    return events.astype(bool).to_numpy()


def _plot_cv_curve(gcv):
    """Plot mean +/- std cross-validated concordance against alpha for a fitted grid search."""
    alphas = [alpha[0] for alpha in gcv.cv_results_['param_coxnetsurvivalanalysis__alphas']]
    mean = gcv.cv_results_['mean_test_score']
    std = gcv.cv_results_['std_test_score']

    fig, ax = plt.subplots(figsize=(9, 6))
    ax.plot(alphas, mean, label='Mean Concordance Index')
    ax.fill_between(alphas, mean - std, mean + std, alpha=0.15)
    ax.set_xscale("log")
    ax.set_ylabel("Concordance Index")
    ax.set_xlabel("Alpha")
    ax.axvline(gcv.best_params_["coxnetsurvivalanalysis__alphas"][0], c="C1", label="Best Alpha")
    ax.axhline(0.5, color="grey", linestyle="--", label="Random Concordance")
    ax.grid(True)
    ax.legend()
    plt.show()


def _plot_km_by_risk_group(ax, y, risk_groups, linestyle, split_name):
    """Draw one Kaplan-Meier step curve per risk group (0 = low, 1 = high) on ``ax``."""
    for group in (0, 1):
        mask = risk_groups == group
        if np.any(mask):
            time, survival_prob = kaplan_meier_estimator(y['status'][mask], y['time'][mask])
            ax.step(time, survival_prob, linestyle, where="post",
                    label=f"{'High' if group else 'Low'} Risk ({split_name})")


def _plot_time_dependent_auc(y_train, y_test, test_predictions, times, title):
    """Plot cumulative/dynamic AUC on the test set at the requested time points.

    Time points outside the test follow-up range are dropped first, because
    ``cumulative_dynamic_auc`` rejects them; the plot is skipped when nothing
    usable remains or when the estimator still raises.
    """
    follow_up = y_test['time']
    usable = times[(times >= follow_up.min()) & (times < follow_up.max())]
    if usable.size == 0:
        print(f"Skipping time-dependent AUC: none of {times.tolist()} lies within the test follow-up range.")
        return
    try:
        auc, mean_auc = cumulative_dynamic_auc(y_train, y_test, test_predictions, usable)
    except ValueError as err:
        print(f"Skipping time-dependent AUC: {err}")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(usable, auc, marker="o", color="crimson", label=f"AUC (mean={mean_auc:.3f})")
    for t, value in zip(usable, auc):
        ax.text(t, value + 0.02, f"{t/365:.0f} year AUC={value:.3f}", ha="center")
    ax.plot([usable[0], usable[-1]], [0.5, 0.5], color="gray", linestyle="--")
    ax.set_xlabel("Days")
    ax.set_ylabel("Time-dependent AUC")
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    plt.tight_layout()
    plt.show()


def evaluate_gene_sets(merged_data, adata, top_genes, layer="lognormalized_counts", endpoint="os",
                       auc_times=(365, 730, 1095)):
    """
    Evaluate different numbers of top genes for each comparison using Cox models.
    Includes risk group assignment, KM curves, and time-dependent AUC analysis.

    For every comparison, the first 5, 10, 15, ... genes of its ranked list are
    used to fit an elastic-net Cox model (alpha chosen by 5-fold CV on an 80%
    training split), and train/test concordance is recorded.

    NOTE: the "best" model per comparison is the one with the highest *test*
    C-index, so that reported test score is optimistic.

    Parameters:
    -----------
    merged_data : pandas.DataFrame
        Survival table with columns ``status-<endpoint>`` and ``tend_<endpoint>``.
        Rows are matched to ``adata`` by index when every sample is present,
        otherwise by position (the lengths must then agree).
    adata : AnnData-like
        Object exposing ``layers[layer]``, ``obs.index`` (samples) and
        ``var.index`` (genes).
    top_genes : dict
        Dictionary with comparisons as keys and ranked lists of gene names as values
    layer : str
        Name of the expression layer used as model features.
    endpoint : str
        Suffix selecting the status/time columns (e.g. "os" or "ttp").
    auc_times : sequence of float
        Time points (same unit as the time column; labelled as days/years in
        the plot) at which the time-dependent AUC is evaluated.

    Returns:
    --------
    dict
        ``results_df`` (one row per comparison / gene count), ``best_models``
        (per comparison: model, gene list, splits, scores) and
        ``coefficient_dfs`` (per comparison: coefficients of the best model).
        Comparisons for which no model could be fitted are left out.

    Raises:
    -------
    ValueError
        If status labels are not recognised or the survival table cannot be
        matched to the expression rows.
    """
    results = []
    best_models = {}

    # Samples x genes expression frame for the requested layer.
    layer_values = adata.layers[layer]
    if hasattr(layer_values, "toarray"):  # sparse layers cannot be wrapped in a DataFrame directly
        layer_values = layer_values.toarray()
    expr = pd.DataFrame(layer_values, index=adata.obs.index, columns=adata.var.index)

    # Put the survival rows in the same order as the expression rows.
    if not merged_data.index.equals(expr.index) and expr.index.isin(merged_data.index).all():
        merged_data = merged_data.loc[list(expr.index)]
    if len(merged_data) != len(expr):
        raise ValueError(
            f"Survival table has {len(merged_data)} rows but the expression layer has {len(expr)} samples."
        )

    # The survival outcome does not depend on the gene selection: build it once.
    status_bool = _status_to_event(merged_data[f'status-{endpoint}'])
    time = merged_data[f'tend_{endpoint}'].to_numpy()
    y_structured = np.zeros(len(time), dtype=[('status', bool), ('time', float)])
    y_structured['status'] = status_bool
    y_structured['time'] = time
    sample_positions = np.arange(len(y_structured))

    for comparison, ranked_genes in top_genes.items():
        ranked_genes = list(ranked_genes)
        print(f"\nProcessing comparison: {comparison}")
        comparison_best_score = -np.inf
        comparison_best_info = None

        # `+ 1` so that the complete list is evaluated when its length is a multiple of 5.
        for n_genes in range(5, len(ranked_genes) + 1, 5):
            print(f"Testing with top {n_genes} genes")

            selected_genes = ranked_genes[:n_genes]
            missing = [gene for gene in selected_genes if gene not in expr.columns]
            if missing:
                print(f"Skipping top {n_genes} genes; not found in layer '{layer}': {missing}")
                continue
            X = expr[selected_genes]

            # Fixed seed + stratification: every gene count sees the same split.
            X_train, X_test, y_train_idx, y_test_idx = train_test_split(
                X,
                sample_positions,
                test_size=0.2,
                random_state=42,
                stratify=status_bool
            )
            y_train = y_structured[y_train_idx]
            y_test = y_structured[y_test_idx]

            # Fit once along the regularisation path to obtain candidate alphas.
            coxnet_pipe = make_pipeline(CoxnetSurvivalAnalysis(l1_ratio=0.9, alpha_min_ratio=0.01, max_iter=1000))
            coxnet_pipe.fit(X_train, y_train)
            estimated_alphas = coxnet_pipe.named_steps["coxnetsurvivalanalysis"].alphas_

            # Pick alpha by cross-validated concordance on the training split.
            cv = KFold(n_splits=5, shuffle=True, random_state=0)
            gcv = GridSearchCV(
                make_pipeline(CoxnetSurvivalAnalysis(l1_ratio=0.9)),
                param_grid={"coxnetsurvivalanalysis__alphas": [[v] for v in estimated_alphas]},
                cv=cv,
                error_score=0.5,
                n_jobs=-1,
                verbose=0
            ).fit(X_train, y_train)
            _plot_cv_curve(gcv)

            best_alpha = gcv.best_params_["coxnetsurvivalanalysis__alphas"][0]
            best_model = gcv.best_estimator_.named_steps["coxnetsurvivalanalysis"]

            # Model predictions are used directly as risk scores.
            train_predictions = best_model.predict(X_train)
            test_predictions = best_model.predict(X_test)

            train_cindex = concordance_index_censored(
                y_train['status'], y_train['time'], train_predictions
            )[0]
            test_cindex = concordance_index_censored(
                y_test['status'], y_test['time'], test_predictions
            )[0]

            # Split into low/high risk at the median training risk score.
            risk_cutoff = np.median(train_predictions)
            train_risk_groups = (train_predictions > risk_cutoff).astype(int)
            test_risk_groups = (test_predictions > risk_cutoff).astype(int)

            results.append({
                'comparison': comparison,
                'n_genes': n_genes,
                'train_cindex': train_cindex,
                'test_cindex': test_cindex,
                'best_alpha': best_alpha,
                'cv_score': gcv.best_score_,
                'risk_cutoff': risk_cutoff,
                'train_high_risk_n': sum(train_risk_groups),
                'test_high_risk_n': sum(test_risk_groups)
            })

            if test_cindex > comparison_best_score:
                comparison_best_score = test_cindex
                comparison_best_info = {
                    'model': best_model,
                    'n_genes': n_genes,
                    'test_cindex': test_cindex,
                    'train_cindex': train_cindex,
                    'X_train': X_train,
                    'X_test': X_test,
                    'y_train': y_train,
                    'y_test': y_test,
                    'selected_genes': selected_genes,
                }

            print(f"Train C-index: {train_cindex:.3f}")
            print(f"Test C-index: {test_cindex:.3f}")

        if comparison_best_info is None:
            # Fewer than 5 genes, or every gene count was skipped.
            print(f"No model could be fitted for {comparison}; it is left out of the results.")
        else:
            best_models[comparison] = comparison_best_info

    results_df = pd.DataFrame(results)

    # Coefficients of each comparison's best model, indexed by gene.
    coef_dfs = {
        comparison: pd.DataFrame(
            model_info['model'].coef_,
            index=model_info['X_train'].columns,
            columns=["coefficient"]
        )
        for comparison, model_info in best_models.items()
    }

    if not best_models:
        # Nothing to plot.
        return {
            'results_df': results_df,
            'best_models': best_models,
            'coefficient_dfs': coef_dfs
        }

    # KM curves and time-dependent AUC for each best model.
    for comparison, model_info in best_models.items():
        best_model = model_info['model']
        y_train = model_info['y_train']
        y_test = model_info['y_test']
        n_genes = model_info['n_genes']

        train_predictions = best_model.predict(model_info['X_train'])
        test_predictions = best_model.predict(model_info['X_test'])

        risk_cutoff = np.median(train_predictions)
        train_risk_groups = (train_predictions > risk_cutoff).astype(int)
        test_risk_groups = (test_predictions > risk_cutoff).astype(int)

        fig, ax = plt.subplots(figsize=(10, 6))
        _plot_km_by_risk_group(ax, y_train, train_risk_groups, "-", "train")
        _plot_km_by_risk_group(ax, y_test, test_risk_groups, "--", "test")
        ax.set_xlabel('Time')
        ax.set_ylabel('Survival Probability')
        ax.set_title(f'Kaplan-Meier Curves by Risk Group\nBest Model for {comparison} ({n_genes} genes)')
        ax.grid(True)
        ax.legend()
        plt.show()

        _plot_time_dependent_auc(
            y_train, y_test, test_predictions,
            np.asarray(auc_times, dtype=float),
            f"Time-dependent ROC\nBest Model for {comparison} ({n_genes} genes)",
        )

    # Print best results and plot non-zero coefficients for each comparison.
    fig, axes = plt.subplots(len(best_models), 1, figsize=(10, 6*len(best_models)), squeeze=False)
    axes = axes[:, 0]

    for idx, (comparison, model_info) in enumerate(best_models.items()):
        print(f"\n=== Best Model for {comparison} ===")
        print(f"Number of genes: {model_info['n_genes']}")
        print(f"Test C-index: {model_info['test_cindex']:.3f}")
        print(f"Train C-index: {model_info['train_cindex']:.3f}")

        best_coefs = coef_dfs[comparison]
        non_zero_coefs = best_coefs.query("coefficient != 0")
        coef_order = non_zero_coefs.abs().sort_values("coefficient").index

        print(f"Number of non-zero coefficients: {len(non_zero_coefs)}")

        axes[idx].set_xlabel("coefficient")
        axes[idx].set_title(f"Coefficients for {comparison}")
        axes[idx].grid(True)
        if non_zero_coefs.empty:
            # The penalty can shrink every coefficient to zero; pandas cannot plot an empty frame.
            axes[idx].text(0.5, 0.5, "All coefficients are zero", ha="center", va="center",
                           transform=axes[idx].transAxes)
            continue
        non_zero_coefs.loc[coef_order].plot.barh(ax=axes[idx], legend=False)
        axes[idx].set_xlabel("coefficient")

    plt.tight_layout()
    plt.show()

    # Train/test concordance as a function of the number of genes.
    fig, ax = plt.subplots(figsize=(12, 6))
    for comparison in top_genes.keys():
        comparison_results = results_df[results_df['comparison'] == comparison]
        ax.plot(comparison_results['n_genes'], comparison_results['test_cindex'],
                marker='o', label=f'{comparison} (test)')
        ax.plot(comparison_results['n_genes'], comparison_results['train_cindex'],
                marker='o', linestyle='--', label=f'{comparison} (train)')

    ax.set_xlabel('Number of genes')
    ax.set_ylabel('Concordance Index')
    ax.set_title('Performance vs Number of Genes')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()
    plt.show()

    return {
        'results_df': results_df,
        'best_models': best_models,
        'coefficient_dfs': coef_dfs
    }

# Get risk scores

In [7]:
import numpy as np
import pandas as pd
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.compare import compare_survival
from sksurv.linear_model import CoxPHSurvivalAnalysis
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

def plot_risk_profile(risk_scores_df, merged_data, gene_coefficients, optimal_cutpoint):
    """
    Create a combined plot showing risk scores, survival status, and gene expression profiles.

    Three stacked panels share the patient axis: risk score per patient,
    follow-up time coloured by status, and a z-scored expression heatmap of
    the signature genes.

    Parameters
    ----------
    risk_scores_df : pandas.DataFrame
        One row per patient with columns ``risk_score``, ``risk_group``
        ('high'/'low'), ``time`` and ``status``. Rows are drawn in the order
        given; its index labels are used as row *positions* into
        ``merged_data``.
    merged_data : pandas.DataFrame
        Table holding one expression column per gene in ``gene_coefficients``.
    gene_coefficients : dict
        Gene name -> coefficient; only the keys are used here.
    optimal_cutpoint : float or None
        Risk-score threshold separating the groups; drawn as a horizontal
        line on the first panel when given.
    """
    # Local import: Line2D is only needed for the hand-built legend below.
    from matplotlib.lines import Line2D

    n_patients = len(risk_scores_df)
    # seaborn's heatmap centres cell i at i + 0.5; use the same x positions so
    # the scatter panels line up with the heatmap columns on the shared axis.
    x_positions = np.arange(n_patients) + 0.5

    fig, axes = plt.subplots(3, 1, figsize=(12, 12),
                             gridspec_kw={'height_ratios': [1, 1, 2]},
                             sharex=True)

    # 1. Risk score distribution
    group_colors = ['red' if group == 'high' else 'blue' for group in risk_scores_df['risk_group']]
    axes[0].scatter(x_positions, risk_scores_df['risk_score'], c=group_colors, s=30)
    # Vertical boundary after the last 'low' patient (meaningful when rows are sorted by risk score).
    axes[0].axvline(x=sum(risk_scores_df['risk_group'] == 'low'),
                    color='red', linestyle='--')
    if optimal_cutpoint is not None:
        axes[0].axhline(y=optimal_cutpoint, color='grey', linestyle=':',
                        label=f'Cutpoint = {optimal_cutpoint:.4f}')
        axes[0].legend(loc='upper left')
    axes[0].set_ylabel('Risk score')
    axes[0].grid(True)

    # 2. Survival status
    status_colors = ['red' if s else 'blue' for s in risk_scores_df['status']]
    axes[1].scatter(x_positions, risk_scores_df['time'], c=status_colors, s=30)
    axes[1].set_ylabel('Time (days)')
    axes[1].grid(True)

    # Add legend for survival status
    legend_elements = [Line2D([0], [0], marker='o', color='w',
                              markerfacecolor='red', label='Death', markersize=8),
                       Line2D([0], [0], marker='o', color='w',
                              markerfacecolor='blue', label='Live', markersize=8)]
    axes[1].legend(handles=legend_elements, loc='right')

    # 3. Gene expression heatmap
    genes = list(gene_coefficients.keys())
    if genes:
        # Index labels of risk_scores_df are treated as row positions in merged_data.
        expression_data = merged_data[genes].iloc[risk_scores_df.index]

        # Z-scale each gene across patients
        expression_data_scaled = pd.DataFrame(
            stats.zscore(expression_data, axis=0),
            columns=expression_data.columns,
            index=expression_data.index
        )

        sns.heatmap(expression_data_scaled.T,
                    cmap='RdBu_r',
                    center=0,
                    ax=axes[2],
                    cbar_kws={'label': 'Z-score'})
    else:
        axes[2].text(0.5, 0.5, 'No genes to display', ha='center', va='center',
                     transform=axes[2].transAxes)

    axes[2].set_xlabel('Patients (ranked by risk score)')
    axes[2].set_ylabel('Genes')

    # Adjust layout
    plt.tight_layout()
    plt.show()

def calculate_risk_scores(merged_data, coef_df, time_col='tend', status_col='status_os', id_col='ID_Sample'):
    """
    Calculate risk scores and find optimal cutpoint for stratification using sksurv.

    The risk score of a sample is the sum of ``coefficient * expression`` over
    the genes of ``coef_df`` present in ``merged_data``. The cutpoint
    (searched between the 15th and 85th percentile of the scores) that
    maximises the log-rank statistic splits samples into 'low' and 'high'
    risk. A combined risk-profile figure and Kaplan-Meier curves are drawn.

    NOTE: because the cutpoint is chosen to maximise the log-rank statistic,
    the reported p-value is optimistic (not corrected for the search).

    Parameters
    ----------
    merged_data : pandas.DataFrame
        One row per sample, with gene expression columns plus the time and
        status columns.
    coef_df : pandas.DataFrame
        Genes on the index, coefficient in the first column.
    time_col : str
        Column holding the follow-up time.
    status_col : str
        Column holding the event indicator; must be numeric/boolean
        (non-zero = event).
    id_col : str
        Column holding sample identifiers; the index is used if the column
        does not exist.

    Returns
    -------
    pandas.DataFrame
        Columns ``ID_Sample``, ``risk_score``, ``time``, ``status``,
        ``risk_group`` and ``rank``, sorted by ascending risk score.

    Raises
    ------
    ValueError
        If the status column is not numeric, or no cutpoint produces two groups.
    """
    # Coefficients rounded to 4 decimals, keyed by gene.
    original_coefficients = {
        gene: round(value, 4) for gene, value in zip(coef_df.index, coef_df.iloc[:, 0])
    }

    # Keep only the genes that are available as columns.
    coefficients = {
        gene: coef for gene, coef in original_coefficients.items()
        if gene in merged_data.columns
    }

    print("\nGenes used in risk score calculation:")
    for gene, coef in coefficients.items():
        print(f"{gene}: {coef}")

    print("\nGenes not found in data and skipped:")
    for gene in sorted(set(original_coefficients) - set(coefficients), key=str):
        print(gene)

    # Text labels such as 'Dead' would all be truthy, so insist on numbers.
    try:
        status_values = pd.to_numeric(merged_data[status_col]).to_numpy()
    except (ValueError, TypeError) as err:
        raise ValueError(
            f"Column '{status_col}' must be numeric or boolean (non-zero = event): {err}"
        ) from err
    time_values = merged_data[time_col].to_numpy()

    if id_col in merged_data.columns:
        sample_ids = merged_data[id_col].to_numpy()
    else:
        sample_ids = merged_data.index.to_numpy()

    # Risk score = weighted sum of expression values (one matrix product instead of a row loop).
    genes = list(coefficients)
    weights = np.array([coefficients[gene] for gene in genes], dtype=float)
    scores = merged_data[genes].to_numpy(dtype=float) @ weights

    # Fresh 0..n-1 index: plot_risk_profile uses it as row positions into merged_data.
    risk_scores_df = pd.DataFrame({
        'ID_Sample': sample_ids,
        'risk_score': scores,
        'time': time_values,
        'status': status_values,
    })

    # Structured survival array expected by sksurv.
    survival_data = np.zeros(len(merged_data), dtype=[('status', bool), ('time', float)])
    survival_data['status'] = status_values.astype(bool)
    survival_data['time'] = time_values

    # Find optimal cutpoint using log-rank statistics
    max_statistic = -np.inf
    optimal_cutpoint = None
    optimal_pvalue = None

    # Candidate cutpoints between the 15th and 85th percentiles (ties removed).
    cutpoints = np.unique(np.percentile(scores, np.arange(15, 86, 1)))

    for cutpoint in cutpoints:
        groups = (scores > cutpoint).astype(int)
        n_high = int(groups.sum())
        n_low = len(groups) - n_high
        if n_high == 0 or n_low == 0:
            # compare_survival needs two non-empty groups.
            continue
        chisq, pvalue = compare_survival(survival_data, groups)

        print(f"The p-value for {cutpoint} is {pvalue} with a chi-stat of {chisq} "
              f"using the following groups: low={n_low}, high={n_high}")

        if chisq > max_statistic:
            max_statistic = chisq
            optimal_cutpoint = cutpoint
            optimal_pvalue = pvalue

    if optimal_cutpoint is None:
        raise ValueError("No cutpoint splits the samples into two risk groups with a valid log-rank statistic.")

    # Assign risk groups - higher scores = higher risk
    risk_scores_df['risk_group'] = (risk_scores_df['risk_score'] > optimal_cutpoint).map({True: 'high', False: 'low'})

    # Sort by risk score for visualization
    risk_scores_df = risk_scores_df.sort_values('risk_score')
    risk_scores_df['rank'] = range(len(risk_scores_df))

    # Create the combined risk profile plot
    plot_risk_profile(risk_scores_df, merged_data, coefficients, optimal_cutpoint)

    # Hazard ratio of high vs low risk from a single-covariate Cox model.
    X = (risk_scores_df['risk_group'] == 'high').to_numpy().astype(float).reshape(-1, 1)
    y = np.zeros(len(X), dtype=[('status', bool), ('time', float)])
    y['status'] = risk_scores_df['status'].astype(bool)
    y['time'] = risk_scores_df['time']

    cph = CoxPHSurvivalAnalysis()
    cph.fit(X, y)
    hr = np.exp(cph.coef_[0])

    # Plot Kaplan-Meier curves with statistics
    plt.figure(figsize=(10, 6))

    for group, color in zip(['high', 'low'], ['red', 'blue']):
        mask = risk_scores_df['risk_group'] == group
        if sum(mask) > 0:
            time = risk_scores_df.loc[mask, 'time']
            status = risk_scores_df.loc[mask, 'status'].astype(bool)

            time_km, survival_prob = kaplan_meier_estimator(
                status,
                time
            )

            plt.step(time_km, survival_prob, where="post",
                     label=f"{group.capitalize()} risk (n={sum(mask)})",
                     color=color)

    # Add HR and p-value to the plot
    stats_text = (f'HR = {hr:.2f}\n'
                  f'Log-rank P = {optimal_pvalue:.2e}')
    plt.text(0.05, 0.15, stats_text, transform=plt.gca().transAxes,
             bbox=dict(facecolor='white', alpha=0.8))

    plt.xlabel('Time')
    plt.ylabel('Survival probability')
    plt.title('Kaplan-Meier Curves by Risk Group')
    plt.grid(True)
    plt.legend()
    plt.show()

    # Print statistics
    print("\nRisk group sizes:")
    print(risk_scores_df['risk_group'].value_counts())
    print(f"\nOptimal cutpoint: {optimal_cutpoint:.4f}")
    print(f"Log-rank test p-value: {optimal_pvalue:.2e}")
    print(f"Hazard Ratio: {hr:.2f}")

    return risk_scores_df


# Run pipeline

In [None]:
# Example usage:
# Locations of the DEG result tables that get_top_genes reads.
os_deg_csv = "/home/rafaed/work/RO_src/STAnalysis/notebooks/downstream/Bulk RNA/deg_analysis_os/DEG_results_Alive_vs_Dead.csv"
ttp_deg_csv = "/home/rafaed/work/RO_src/STAnalysis/notebooks/downstream/Bulk RNA/deg_analysis_ttp/DEG_results_Did not progress_vs_Progressed.csv"

source_paths = {
    #"I-NE vs Rest": "/home/rafaed/work/RO_src/STAnalysis/notebooks/downstream/Bulk RNA/deg_analysis_results_de_novo/deg_results_4_vs_rest.csv",
    #"I-nNE vs Rest": "/home/rafaed/work/RO_src/STAnalysis/notebooks/downstream/Bulk RNA/deg_analysis_results_de_novo/deg_results_3_vs_rest.csv",
    #"N vs Rest": "/home/rafaed/work/RO_src/STAnalysis/notebooks/downstream/Bulk RNA/deg_analysis_results_de_novo/deg_results_2_vs_rest.csv",
    #"A vs Rest": "/home/rafaed/work/RO_src/STAnalysis/notebooks/downstream/Bulk RNA/deg_analysis_results_de_novo/deg_results_1_vs_rest.csv",
    "Alive vs Dead": os_deg_csv
}


def _run_deg(design_factor, output_dir):
    """Run the DEG workflow on adata_nmf_cp for one design factor.

    Calls create_dds, run_comparisons, save_results and create_volcano_grid in
    that order and returns the analysis object together with its results.
    """
    analysis = DEGAnalysis(adata_nmf_cp, design_factor=design_factor, layer='raw_counts', output_dir=output_dir)
    analysis.create_dds()
    analysis.run_comparisons()
    analysis.save_results()
    analysis.create_volcano_grid()
    return analysis, analysis.get_results()


# DEG for overall-survival status (dead vs alive), then its top genes from the CSV
deg_os, results_deg_os = _run_deg('status-os', "./deg_analysis_os")
os_top_genes = get_top_genes(source_paths={"Alive vs Dead": os_deg_csv}, n_genes=100)

# DEG for progression status (progressed vs didn't progress), then its top genes from the CSV
deg_ttp, results_deg_ttp = _run_deg('status-ttp', "./deg_analysis_ttp")
ttp_top_genes = get_top_genes(source_paths={"Did not progress vs Progressed": ttp_deg_csv}, n_genes=100)

# LASSO-Cox coefficients for the DEG genes, one endpoint at a time (TTP first, then OS)
_result_keys = ('results_df', 'best_models', 'coefficient_dfs')

results_ttp = evaluate_gene_sets(merged_data, adata_nmf_cp, ttp_top_genes, endpoint="ttp")
results_df_ttp, best_models_ttp, coefficient_dfs_ttp = [results_ttp[key] for key in _result_keys]

results_os = evaluate_gene_sets(merged_data, adata_nmf_cp, os_top_genes, endpoint="os")
results_df_os, best_models_os, coefficient_dfs_os = [results_os[key] for key in _result_keys]

In [None]:
# Retrieve the risk scores
def _risk_score_table(adata, endpoint, event_label, layer="lognormalized_counts"):
    """Build the DataFrame that calculate_risk_scores reads from an AnnData object.

    calculate_risk_scores works on a DataFrame with one column per gene plus
    `ID_Sample`, `tend` and a numeric `status_os`; passing the AnnData object
    itself fails because it has no `.columns`. The fixed names `tend` and
    `status_os` are filled from the requested endpoint's obs columns
    (`tend_<endpoint>` and the text column `status-<endpoint>`).
    """
    table = pd.DataFrame(adata.layers[layer], index=adata.obs.index, columns=adata.var.index)
    table["tend"] = adata.obs[f"tend_{endpoint}"].to_numpy()
    # 1 where the endpoint's event was observed, 0 otherwise.
    table["status_os"] = (adata.obs[f"status-{endpoint}"] == event_label).astype(int).to_numpy()
    return table.rename_axis("ID_Sample").reset_index()


risk_scores_df_ttp = calculate_risk_scores(
    _risk_score_table(adata_nmf_cp, "ttp", "Progressed"),
    coef_df=coefficient_dfs_ttp["Did not progress vs Progressed"],
)
risk_scores_df_os = calculate_risk_scores(
    _risk_score_table(adata_nmf_cp, "os", "Dead"),
    coef_df=coefficient_dfs_os["Alive vs Dead"],
)