In [None]:
import sys
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
# Make the Blueprint repository importable so the project-local `src.*` modules resolve.
repo_dir = '/home/projects/amit/annaku/repos/Blueprint'
sys.path.append(repo_dir)
# Single-cell stack: AnnData containers (`ad`) and scanpy (`sc`).
import anndata as ad
import scanpy as sc
# Global plotting defaults for the whole notebook.
# NOTE(review): seaborn themes manage `image.cmap` themselves, so `sns.set_theme` below
# may override the "Set1" colormap set on the previous line — confirm the intended order.
plt.rcParams["image.cmap"] = "Set1"
sns.set_theme(style='ticks', rc={'axes.grid': False})

In [None]:
from src.cnv_utils import *
from src.palettes import *

%load_ext rpy2.ipython

In [None]:
pd.set_option('display.max_columns', 300)
%config InlineBackend.figure_format = 'png'
plt.rcParams['pdf.fonttype'] = 'truetype'
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['figure.dpi'] = 300

sns.set_style('ticks')
sns.set_style('white')

In [None]:
# Run-level settings: data version tag and figure-export switch.
version = '20250306'
save_figures = True

# Project configuration (OmegaConf YAML under <project_root>/configs).
project_root = '/home/projects/amit/annaku/repos/Blueprint'
conf_path = os.path.join(project_root, 'configs', 'config.yaml')
conf = OmegaConf.load(conf_path)

# CNV outputs to AnnData

In [None]:
# Load the merged, annotated AnnData for this `version` from the configured output directory.
data_path = conf['outputs']['output_dir']
filename = f'adata_PC_with_ann_merged_v_{version}.h5ad'
adata_cnv = sc.read_h5ad(os.path.join(data_path, filename))

# Keep only cells whose 'cells_rem_dupl_between_methods' flag is False.
# Slicing an AnnData returns a view; later cells assign new `.obs` columns, so the
# subset is materialised with `.copy()` instead of relying on an implicit copy-on-write.
keep_cells = adata_cnv.obs['cells_rem_dupl_between_methods'] == False
adata_cnv = adata_cnv[keep_cells].copy()

In [None]:
# Composite key "<Sample.Code>.<Populations>", used later to group cells.
sample_code = adata_cnv.obs['Sample.Code'].astype(str)
population = adata_cnv.obs['Populations'].astype(str)
adata_cnv.obs['Sample.Code.Cell'] = sample_code + '.' + population

In [None]:
# arch merge, attention to version
# Attach the sample-level archetype table to every cell through a lower-cased 'PID' key.

path_save_8arch = '/home/projects/amit/annaku/repos/Blueprint/data/processed/nmf_outputs/renamed_with_8_separated/'

arch = pd.read_csv(os.path.join(path_save_8arch,f'arch_sample_v7_samplelevel_only_malignant_renamed_with_added_samples_v_{version}.csv'), index_col = 0)
arch[['Cluster', 'Cluster_exp']] = arch[['Cluster', 'Cluster_exp']].astype(str)
arch_gene = pd.read_csv(os.path.join(path_save_8arch,'arch_gene_v7_samplelevel_only_malignant_renamed.csv'), index_col = 0)
arch_gene[['Cluster']] = arch_gene[['Cluster']].astype(str)

# PID = 'z.<Method>_<Populations>_<Sample.Code>', lower-cased once as a whole so it
# matches the lower-cased `arch['PID']` below.
adata_cnv.obs['PID'] = (
    'z.' + adata_cnv.obs['Method'].astype(str)
    + '_' + adata_cnv.obs['Populations'].astype(str)
    + '_' + adata_cnv.obs['Sample.Code'].astype(str)
).str.lower()

arch['PID'] = arch['PID'].str.lower()
arch['patient'] = arch['patient'].str.lower()

arch_rename_dict = {'2.0':'MM1',
                    '3.0':'MM2',
                    '4.0':'MM3',
                    '6.0':'MM4',
                    '7.0':'MM5'}

# `pd.merge` resets the index, so the cell index is stashed in an 'index' column and
# restored afterwards. `validate='many_to_one'` makes a duplicated PID in `arch` fail
# here with a clear MergeError instead of silently multiplying cells.
adata_cnv.obs['index'] = adata_cnv.obs.index
arch_columns = ['Cluster_exp', 'Cluster', 'PID', 'prolif_high_0.1', '8.0']
obs_with_arch = pd.merge(adata_cnv.obs, arch[arch_columns], how='left', on='PID',
                         validate='many_to_one')
obs_with_arch.index = obs_with_arch['index']

obs_with_arch = obs_with_arch.rename(columns = {'Cluster_exp':'arch', 'Cluster':'arch_with_8', '8.0':'prolif_coef'})
obs_with_arch['arch'] = obs_with_arch['arch'].replace(arch_rename_dict)

# Assign the finished table in one step so AnnData never sees the temporary integer index.
adata_cnv.obs = obs_with_arch

In [None]:
# Run the project helper `process_minibulk_cnv` on the cell-level object, using
# 'Normal_PC' in 'Populations' as reference and ['Sample.Code', 'Populations'] as grouping.
minibulk_kwargs = dict(
    to_filter=False,
    reference_key='Populations',
    reference_cat=['Normal_PC'],
    preproc_standart=False,
    cat_to_group=['Sample.Code', 'Populations'],
)
adata_minibulk_combined = process_minibulk_cnv(adata_cnv, **minibulk_kwargs)

adata_minibulk_combined.obs['Populations'].value_counts()

In [None]:
# Remove the 'cnv_leiden' column from the minibulk annotation.
adata_minibulk_combined.obs = adata_minibulk_combined.obs.drop('cnv_leiden', axis=1)

In [None]:
def _read_infercnv_groupings(groupings_dir):
    """Read `infercnv.observation_groupings.txt` from one inferCNV output directory.

    The file is parsed with the separator ` "` (space + double quote); the leftover
    double quotes are then stripped from the cell ids and from every remaining column.

    Returns a DataFrame indexed by 'cell_id' with the columns
    'group', 'dendrogram_color', 'annotation' and 'color'.
    """
    groupings = pd.read_csv(
        os.path.join(groupings_dir, 'infercnv.observation_groupings.txt'),
        sep=' "', engine='python', header=0,
        names=['cell_id', 'group', 'dendrogram_color', 'annotation', 'color'],
    )
    groupings['cell_id'] = groupings['cell_id'].str.strip('"')
    groupings = groupings.set_index('cell_id')
    return groupings.apply(lambda column: column.str.strip('"'))


methods = ['SPID', 'MARS']
base_data_path =conf['outputs']['output_dir'] + '/infercnv_r_output/output_infercnv_r_arch_prolif_'

# One table per sequencing method. Reading through the helper keeps loop temporaries
# local, so the notebook-level `data_path` / `filename` are no longer overwritten here.
dataframes = [_read_infercnv_groupings(base_data_path + method) for method in methods]

merged_obs_df = pd.concat(dataframes)

# Column-wise outer concat: rows are the union of the minibulk annotation index and
# the inferCNV grouping index.
ann = adata_minibulk_combined.obs.copy()
ann_cell_level = pd.concat([ann, merged_obs_df], axis=1)
print(ann_cell_level.shape)

ann_cell_level['prolif_high_0.1'] = ann_cell_level['prolif_high_0.1'].astype(str)

ann_cell_level.head(2)

In [None]:
%%R -o expr_data_MARS,row_names_MARS,col_names_MARS,gene_order_MARS
# Load the final inferCNV object of the MARS run; the variables listed after `-o`
# are handed back to the Python session.
result <- readRDS('/home/projects/amit/annaku/repos/Blueprint/data/processed/infercnv_r_output/output_infercnv_r_arch_prolif_MARS/run.final.infercnv_obj')

# Matrix stored in the `expr.data` slot.
expr_data_MARS <- result@expr.data

# Content of the `gene_order` slot.
gene_order_MARS <- result@gene_order

# Dimension names are exported separately and re-attached on the Python side.
row_names_MARS <- rownames(expr_data_MARS)
col_names_MARS <- colnames(expr_data_MARS)

In [None]:
%%R -o expr_data_SPID,row_names_SPID,col_names_SPID,gene_order_SPID
# Load the final inferCNV object of the SPID run; the variables listed after `-o`
# are handed back to the Python session.
result <- readRDS('/home/projects/amit/annaku/repos/Blueprint/data/processed/infercnv_r_output/output_infercnv_r_arch_prolif_SPID/run.final.infercnv_obj')

# Matrix stored in the `expr.data` slot.
expr_data_SPID <- result@expr.data

# Content of the `gene_order` slot.
gene_order_SPID <- result@gene_order

# Dimension names are exported separately and re-attached on the Python side.
row_names_SPID <- rownames(expr_data_SPID)
col_names_SPID <- colnames(expr_data_SPID)

In [None]:
# Re-attach the row/column names exported from R to the two expression matrices.
expr_data_MARS_df = pd.DataFrame(expr_data_MARS, index=row_names_MARS, columns=col_names_MARS)
print(expr_data_MARS_df.shape)
expr_data_SPID_df = pd.DataFrame(expr_data_SPID, index=row_names_SPID, columns=col_names_SPID)
print(expr_data_SPID_df.shape)
expr_data_MARS_df.head(2)

# Restrict both runs to the genes present in both gene-order tables (MARS order kept).
common_genes = gene_order_MARS.index.intersection(gene_order_SPID.index)
gene_order = gene_order_MARS.loc[common_genes]

expr_data_MARS_common = expr_data_MARS_df.loc[common_genes]
expr_data_SPID_common = expr_data_SPID_df.loc[common_genes]

# Rows stay the common genes; the columns of both methods are placed side by side.
cnv_pred = pd.concat(
    [expr_data_MARS_common, expr_data_SPID_common],
    axis=1,
)


In [None]:
# Tab-separated cytoband table stored next to the processed data.
path = '/home/projects/amit/annaku/repos/Blueprint/data/processed/'
cytoband_file = os.path.join(path, 'cytoBand.txt')
cb_file = pd.read_csv(cytoband_file, sep='\t')

def get_cytoband(chrom, pos, cytoband_df):
    """Return the 'band' of the first row of `cytoband_df` that covers `pos` on `chrom`.

    A row covers the position when its 'chr' equals `chrom` and
    'start' <= pos <= 'end' (both ends inclusive). Returns None when no row matches.
    """
    on_chrom = cytoband_df['chr'] == chrom
    covers_pos = cytoband_df['start'].le(pos) & cytoband_df['end'].ge(pos)
    hits = cytoband_df.loc[on_chrom & covers_pos, 'band']
    return hits.iloc[0] if len(hits) else None

def annotate_cytobands(df, cytoband_df):
    """Return a copy of `df` with a 'cytoband' column.

    Each row is looked up with `get_cytoband` using its 'chr' and 'start' values;
    the input frame is left untouched.
    """
    def _band_for_row(row):
        return get_cytoband(row['chr'], row['start'], cytoband_df)

    return df.assign(cytoband=df.apply(_band_for_row, axis=1))

# Gene table from the inferCNV gene order, with the gene name as a regular column.
gene_df = pd.DataFrame(gene_order).reset_index()
gene_df = gene_df.rename(columns={'index': 'gene_name'})

# Add the cytoband of each gene start and index the table by gene name (column kept).
ann_genes = annotate_cytobands(gene_df, cb_file)
ann_genes = ann_genes.set_index('gene_name', drop=False)

def get_arm(cytoband):
    """Return the first character of a cytoband label, or None for a missing label."""
    return None if pd.isna(cytoband) else cytoband[0]

# Region labels per gene: '<chr>_<arm>' and '<chr>_<cytoband>'.
# Missing arms/bands are stringified, so they show up as e.g. '<chr>_None'.
chrom_label = ann_genes['chr'].astype(str)

ann_genes['arm'] = ann_genes['cytoband'].apply(get_arm)
ann_genes['chr_arm'] = chrom_label + '_' + ann_genes['arm'].astype(str)
ann_genes['chr_band'] = chrom_label + '_' + ann_genes['cytoband'].astype(str)

print(ann_genes)
ann_genes.head()

In [None]:
display(pd.crosstab(ann_genes['chr'], ann_genes['arm']))

# too few genes per arm: for these chromosomes both arms are merged into a single
# whole-chromosome label.
chromosomes_to_merge = ['chr14', 'chr22', 'chr15', 'chr13', 'chr21']
arm_to_chromosome = {
    f'{chrom}_{arm}': chrom
    for chrom in chromosomes_to_merge
    for arm in ('p', 'q')
}
ann_genes['chr_arm'] = ann_genes['chr_arm'].replace(arm_to_chromosome)

In [None]:
# minibulk level
# Rows of the CNV matrix are reordered to follow the annotation index.
cell_ids = ann_cell_level.index
cnv_by_cell = cnv_pred.T.loc[cell_ids]

adata_cell_level = ad.AnnData(
    X=cnv_by_cell.values,
    obs=ann_cell_level,
    var=ann_genes,
)

# aggregate adata pat level
# Mean CNV signal and first available annotation per 'Sample.Code.Cell' group.
sample_key = adata_cell_level.obs['Sample.Code.Cell']
pat_level_expression = adata_cell_level.to_df().groupby(sample_key).mean()
pat_level_ann = adata_cell_level.obs.groupby('Sample.Code.Cell').first()

adata_pat_level = ad.AnnData(
    X=pat_level_expression.values,
    obs=pat_level_ann,
    var=adata_cell_level.var,
)

adata_pat_level

In [None]:
# Coarse band label: everything before the first '.' of 'chr_band'.
band_pieces = adata_pat_level.var['chr_band'].str.split('.', n=1)
adata_pat_level.var['chr_band_wide'] = band_pieces.str[0]

In [None]:
# Make the annotations categorical and register one colour per category for scanpy plots.
for obs_key, palette in (('Populations', pal_cell_pb), ('arch', pal_architype_renamed)):
    adata_pat_level.obs[obs_key] = pd.Categorical(adata_pat_level.obs[obs_key])
    categories = adata_pat_level.obs[obs_key].cat.categories
    adata_pat_level.uns[f'{obs_key}_colors'] = [palette[category] for category in categories]

# PCA on the aggregated profiles.
sc.tl.pca(adata_pat_level, svd_solver='arpack', n_comps=50)
sc.pl.pca(
    adata_pat_level,
    color=['prolif_high_0.1', 'arch', 'Method'],
    ncols=2,
    frameon=False,
)

# Neighbourhood graph on the first 30 PCs, then UMAP.
sc.pp.neighbors(adata_pat_level, n_neighbors=15, n_pcs=30)
sc.tl.umap(adata_pat_level, min_dist=0.8)

sc.pl.umap(
    adata_pat_level,
    color=['arch', 'prolif_high_0.1', 'Populations'],
    frameon=False,
    ncols=2,
)

sc.pl.pca_variance_ratio(adata_pat_level, n_pcs=50)

# Feature extraction

In [None]:
# Output directory for the CSV/h5ad files written further down; this replaces the
# `data_path` value assigned earlier in the notebook.
# The trailing slash matters: a later cell builds a CSV path by plain string concatenation.
data_path ='/home/projects/amit/annaku/repos/Blueprint/data/processed/'

In [None]:
from src.cnv_utils import calculate_genome_wide_cnv_burden

In [None]:
# Genome-wide CNV burden via the project helper (window of 100, step of 25,
# `sort_genes=True`).
burden_kwargs = dict(window_size=100, window_step=25, sort_genes=True)
adata_pat_level = calculate_genome_wide_cnv_burden(adata=adata_pat_level, **burden_kwargs)

# Recompute the coarse band label (text before the first '.') on the returned object.
band_pieces = adata_pat_level.var['chr_band'].str.split('.', n=1)
adata_pat_level.var['chr_band_wide'] = band_pieces.str[0]

In [None]:
# Genes carrying one of these arm-less labels ('<chr>_None') are excluded first.
regions_todel = [f'chr{number}_None' for number in (5, 17, 18, 20)]

mask = adata_pat_level.var['chr_arm'].isin(regions_todel)
adata_pat_level_ = adata_pat_level[:, ~mask].copy()

# Sliding-window settings passed to the project helper for arm-level regions.
window_size, window_step, min_genes = 10, 5, 20

adata_arms = sliding_window_cnv_region(
    adata_pat_level_,
    region_col='chr_arm',
    window_size=window_size,
    window_step=window_step,
    min_genes=min_genes,
)

def chr_arm_key(x):
    """Sort key for region names such as 'chr1_p', 'chr14' or 'chr1_p36'.

    Returns `(chromosome_rank, arm_rank)`:
    - chromosome_rank is the integer after the 'chr' prefix, `inf` for X/Y and 0 for
      any other non-numeric label;
    - arm_rank is 0 for the p arm (or when the name has no '_<arm>' suffix), else 1.

    The arm is read from the first character of the suffix, so band-level names
    ('chr1_p36') rank as p-arm too instead of falling through to the q rank.
    """
    chrom_label, sep, region = x.partition('_')
    chrom = chrom_label[3:]

    try:
        chrom_num = int(chrom)
    except ValueError:
        chrom_num = float('inf') if chrom in ['X', 'Y'] else 0

    is_p_arm = (not sep) or region.startswith('p')
    return (chrom_num, 0 if is_p_arm else 1)

# Unique arm-level region names in chromosome / arm order.
sorted_arms = list(adata_arms.var_names.unique())
sorted_arms.sort(key=chr_arm_key)

In [None]:
# Band-level regions use the full gene set (nothing is filtered out here).
adata_pat_level_ = adata_pat_level.copy()

# Sliding-window settings passed to the project helper for band-level regions.
window_size, window_step, min_genes = 10, 5, 10

adata_bands = sliding_window_cnv_region(
    adata_pat_level_,
    region_col='chr_band_wide',
    window_size=window_size,
    window_step=window_step,
    min_genes=min_genes,
)

def chr_arm_key(x):
    """Sort key for region names such as 'chr1_p36', 'chr1_q' or 'chr14'.

    Returns `(chromosome_rank, arm_rank)`:
    - chromosome_rank is the integer after the 'chr' prefix, `inf` for X/Y and 0 for
      any other non-numeric label;
    - arm_rank is 0 for the p arm (or when the name has no '_<region>' suffix), else 1.

    This cell sorts band names, whose suffix is e.g. 'p36'; comparing the whole suffix
    with 'p' ranked every band as q, so only the first character is inspected.
    """
    chrom_label, sep, region = x.partition('_')
    chrom = chrom_label[3:]

    try:
        chrom_num = int(chrom)
    except ValueError:
        chrom_num = float('inf') if chrom in ['X', 'Y'] else 0

    is_p_arm = (not sep) or region.startswith('p')
    return (chrom_num, 0 if is_p_arm else 1)

# Unique band-level region names ordered with `chr_arm_key`; the cell shows how many there are.
sorted_bands = list(adata_bands.var_names.unique())
sorted_bands.sort(key=chr_arm_key)
len(sorted_bands)

def extract_chr_and_arm(var_name):
    """Split a region name into `(chromosome, chromosome_arm)`.

    The chromosome is the text before the first '_'. The arm is the first character of
    the following '_'-separated piece; the second item is '<chromosome>_<arm>', or None
    when that piece is missing or empty.
    """
    chr_part, *rest = var_name.split('_')
    arm = rest[0][:1] if rest else ''
    if not arm:
        return chr_part, None
    return chr_part, f"{chr_part}_{arm}"

# Derive chromosome and arm labels for every band-level variable.
var_names_list = list(adata_bands.var_names)

extracted_data = list(map(extract_chr_and_arm, var_names_list))

extracted_df = pd.DataFrame(extracted_data, columns=['chr', 'chr_arm'], index=adata_bands.var_names)

for label_column in ('chr', 'chr_arm'):
    adata_bands.var[label_column] = extracted_df[label_column]

unique_chromosomes = {name.split('_')[0] for name in adata_bands.var_names}

# CNV call binarisation and HP (hyperdiploidy) calls

In [None]:
# Missing archetypes are labelled 'NA' before the column becomes categorical.
# The fill has to happen before the string cast: on pandas versions where `astype(str)`
# stringifies NaN to 'nan', a `fillna('NA')` placed after the cast never fires.
arch_labels = adata_cell_level.obs['arch']
adata_cell_level.obs['arch'] = (
    arch_labels.astype(object)
    .where(arch_labels.notna(), 'NA')
    .astype(str)
    .astype('category')
)
adata_cell_level.obs['Populations'] = adata_cell_level.obs['Populations'].astype('category')

In [None]:
# Heatmap of all genes, rows grouped by population and archetype, with a dendrogram.
groupby_keys = ['Populations', 'arch']
sc.tl.dendrogram(adata_cell_level, groupby=list(groupby_keys))

# Diverging colour scale centred on 1 and clipped to [0.6, 1.4].
heatmap_style = dict(
    cmap='bwr',
    vcenter=1,
    vmin=0.6,
    vmax=1.4,
    figsize=(6, 4),
    show_gene_labels=False,
    swap_axes=False,
)
sc.pl.heatmap(
    adata_cell_level,
    var_names=adata_cell_level.var_names,
    groupby=list(groupby_keys),
    dendrogram=True,
    **heatmap_style,
)

In [None]:
# Binarise every arm-level region into a per-sample call:
# 'ampl' above 1.05, 'del' below 0.95, 'neutral' otherwise.
for var_name in adata_arms.var_names:
    region_scores = adata_arms[:, var_name].X.flatten()
    region_calls = np.select(
        [region_scores > 1.05, region_scores < 0.95],
        ['ampl', 'del'],
        default='neutral',
    ).astype(object)
    adata_arms.obs[f'{var_name}_call'] = region_calls

print(f"Added {len(adata_arms.var_names)} new call columns")
call_columns = [column for column in adata_arms.obs.columns if column.endswith('_call')]
print("Example column names:", call_columns[:3])

# Whole-chromosome call for chromosomes that have separate p and q calls:
# the shared call when both arms agree, 'mixed' otherwise.
for chrom in [3, 5, 7, 9, 11, 19]:
    p_arm_calls = adata_arms.obs[f'chr{chrom}_p_call']
    q_arm_calls = adata_arms.obs[f'chr{chrom}_q_call']
    arms_agree = (p_arm_calls == q_arm_calls).to_numpy()
    adata_arms.obs[f'chr{chrom}_call'] = np.where(arms_agree, p_arm_calls.to_numpy(), 'mixed')

print("Example values for chr3_call:", adata_arms.obs['chr3_call'].value_counts())

# HP
chromosomes = [3, 5, 7, 9, 11, 15, 19, 21]

# Per sample, count how many of the listed chromosomes are called 'ampl'.
ampl_counts = np.zeros(len(adata_arms))
for chrom in chromosomes:
    ampl_counts += (adata_arms.obs[f'chr{chrom}_call'] == 'ampl').to_numpy()

# HP_call is 'Yes' with ≥2 amplified chromosomes, 'No' otherwise.
adata_arms.obs['HP_call'] = np.where(ampl_counts >= 2, 'Yes', 'No')

# Sanity check of the resulting distribution.
print("HP_call distribution:", adata_arms.obs['HP_call'].value_counts())
print("\nMedian number of amplifications:", np.median(ampl_counts))

In [None]:
# Font handling for exported PDF/SVG figures, and the output folder for figure 3 panels.
plt.rcParams.update({
    'pdf.fonttype': 'truetype',
    'svg.fonttype': 'none',
})

save_dir = '/home/projects/amit/annaku/repos/Blueprint/figures/fig3/'

In [None]:
from scipy import stats

def plot_cnv_distribution(adata_arms, regions, figsize=(12, 3), threshold_del=0.95, threshold_amp=1.05):
    """Plot per-archetype KDEs of region scores and test for differences between archetypes.

    Parameters
    ----------
    adata_arms : AnnData-like
        Object indexable as ``adata_arms[row_mask, region].X`` with an ``obs['arch']`` column.
    regions : list of str
        Variable names to plot, one panel per region.
    figsize : tuple
        Size of the whole figure.
    threshold_del, threshold_amp : float
        Positions of the two dashed vertical threshold lines.

    Returns
    -------
    (fig, stats_dict)
        The matplotlib figure and ``{region: {'h': H, 'p': p}}`` from a Kruskal-Wallis
        test across the archetypes that are not control labels. Both values are NaN
        when fewer than two such archetypes exist, because the test needs two groups.
    """
    n = len(regions)
    fig, axes = plt.subplots(1, n, figsize=figsize, sharey=True)
    axes = [axes] if n == 1 else axes

    # Stringified missing archetypes are drawn as a single 'Reference' group.
    control_labels = ['nan', 'NA', 'None']
    palette = {**pal_architype_renamed, 'Reference': 'red'}

    # Group membership does not depend on the region, so it is computed once.
    arch_labels = adata_arms.obs['arch'].astype(str)
    archetypes = sorted(arch_labels.unique())
    mm_archs = [a for a in archetypes if a not in control_labels]
    arch_masks = {arch: (arch_labels == arch).to_numpy() for arch in archetypes}
    stats_dict = {}

    for ax, region in zip(axes, regions):
        scores_by_arch = {
            arch: adata_arms[arch_masks[arch], region].X.flatten()
            for arch in archetypes
        }

        for arch, scores in scores_by_arch.items():
            label = 'Reference' if arch in control_labels else arch
            sns.kdeplot(scores, label=label, color=palette.get(label, '#7f7f7f'), ax=ax)

        ax.axvline(x=threshold_del, color='r', linestyle='--', alpha=0.7)
        ax.axvline(x=threshold_amp, color='r', linestyle='--', alpha=0.7)
        ax.set(xlabel=f'{region} score', ylabel='Density', title=region)
        ax.spines[['top', 'right']].set_visible(False)
        # Only the last panel carries the legend.
        if ax is axes[-1]:
            ax.legend(fontsize=8)

        mm_scores = [scores_by_arch[a] for a in mm_archs]
        if len(mm_scores) >= 2:
            h, p = stats.kruskal(*mm_scores)
        else:
            h, p = float('nan'), float('nan')
        stats_dict[region] = {'h': h, 'p': p}

    plt.tight_layout()
    return fig, stats_dict

# QC panels for four regions; figures are written as SVG and PNG when saving is enabled.
regions = ['chr1_p', 'chr1_q', 'chr13', 'chr17_p']
fig, stats_dict = plot_cnv_distribution(adata_arms, regions)

if save_figures:
    for fmt in ('svg', 'png'):
        fig.savefig(f'{save_dir}hist_cnv_QC_binarization_v_{version}.{fmt}', format=fmt)

plt.show()

# Kruskal-Wallis statistic and p-value per region.
for region, s in stats_dict.items():
    print(f"{region}: H={s['h']:.2f}, p={s['p']:.2e}")

In [None]:
# separated annotation
# Gather every '*_call' column from the arm- and band-level objects plus the CNV burden.
arms_obs = adata_arms.obs
bands_obs = adata_bands.obs

cnv_call_annotation_arms = arms_obs.loc[:, arms_obs.columns.str.endswith('_call')]
cnv_call_annotation_bands = bands_obs.loc[:, bands_obs.columns.str.endswith('_call')]
cnv_call_annotation_burden = adata_pat_level.obs[['cnv_burden']]

cnv_call_annotation = pd.concat(
    [cnv_call_annotation_arms, cnv_call_annotation_bands, cnv_call_annotation_burden],
    axis=1,
)

cnv_call_annotation.head()

In [None]:
# save
# Only rows whose index label contains 'Malignant' are exported and kept.
ind_malignant = [sample for sample in cnv_call_annotation.index if 'Malignant' in sample]

malignant_annotation = cnv_call_annotation.loc[ind_malignant]
malignant_annotation.to_csv(data_path+f'cnv_ann_per_sample_v_{version}.csv')

cnv_call_annotation = malignant_annotation.copy()

In [None]:
# Append the annotation columns that the band-level object does not have yet.
existing_columns = set(adata_bands.obs.columns)
cols = [c for c in cnv_call_annotation.columns if c not in existing_columns]
print(f'adding {len(cols)} columns')
if cols:
    adata_bands.obs = pd.concat(
        [adata_bands.obs, cnv_call_annotation[cols]],
        axis=1,
        join='outer',
    )

In [None]:
# Cast every annotation column to string labels, writing missing values as 'NA'.
# The fill happens before the cast: on pandas versions where `astype(str)` stringifies
# NaN to 'nan', a `fillna('NA')` placed after the cast never fires.
for col in cnv_call_annotation.columns:
    column_values = adata_bands.obs[col]
    adata_bands.obs[col] = (
        column_values.astype(object).where(column_values.notna(), 'NA').astype(str)
    )

In [None]:
# Append the annotation columns that the gene-level (per-sample) object does not have yet.
existing_columns = set(adata_pat_level.obs.columns)
cols = [c for c in cnv_call_annotation.columns if c not in existing_columns]
print(f'adding {len(cols)} columns')
if cols:
    adata_pat_level.obs = pd.concat(
        [adata_pat_level.obs, cnv_call_annotation[cols]],
        axis=1,
        join='outer',
    )

In [None]:
# Cast every annotation column (not only the newly added ones) to string labels,
# writing missing values as 'NA'. The fill happens before the cast: on pandas versions
# where `astype(str)` stringifies NaN to 'nan', a trailing `fillna('NA')` never fires.
for col in cnv_call_annotation.columns:
    column_values = adata_pat_level.obs[col]
    adata_pat_level.obs[col] = (
        column_values.astype(object).where(column_values.notna(), 'NA').astype(str)
    )

In [None]:
# Append the annotation columns that the arm-level object does not have yet.
existing_columns = set(adata_arms.obs.columns)
cols = [c for c in cnv_call_annotation.columns if c not in existing_columns]
print(f'adding {len(cols)} columns')
if cols:
    adata_arms.obs = pd.concat(
        [adata_arms.obs, cnv_call_annotation[cols]],
        axis=1,
        join='outer',
    )

In [None]:
# Cast every annotation column to string labels, writing missing values as 'NA'.
# The fill happens before the cast: on pandas versions where `astype(str)` stringifies
# NaN to 'nan', a `fillna('NA')` placed after the cast never fires.
for col in cnv_call_annotation.columns:
    column_values = adata_arms.obs[col]
    adata_arms.obs[col] = (
        column_values.astype(object).where(column_values.notna(), 'NA').astype(str)
    )

In [None]:
# Broadcast the per-sample annotation onto every cell through 'Sample.Code.Cell'.
cols = [c for c in cnv_call_annotation.columns if c not in adata_cell_level.obs.columns]

# `pd.merge` discards the left index, so it is captured here and restored afterwards.
original_index_name = adata_cell_level.obs.index.name
original_index = adata_cell_level.obs.index

# `validate='many_to_one'` guarantees one annotation row per key; a duplicated key would
# otherwise multiply cells and only fail later when the index is restored.
merged_obs = pd.merge(
    adata_cell_level.obs,
    cnv_call_annotation[cols],
    on='Sample.Code.Cell',
    how='left',
    validate='many_to_one',
)

merged_obs = merged_obs.set_index(original_index)
adata_cell_level.obs = merged_obs

In [None]:
# Cast the newly merged columns to string labels, writing missing values as 'NA'.
# The fill happens before the cast: on pandas versions where `astype(str)` stringifies
# NaN to 'nan', a `fillna('NA')` placed after the cast never fires.
for col in cols:
    column_values = adata_cell_level.obs[col]
    adata_cell_level.obs[col] = (
        column_values.astype(object).where(column_values.notna(), 'NA').astype(str)
    )

In [None]:
# Population counts including missing values; rows with a missing population are
# removed in the next cell.
adata_cell_level.obs['Populations'].value_counts(dropna = False)

In [None]:
# artefact from removing samples sequenced by both methods:
# keep only the rows that still carry a population label.
has_population = adata_cell_level.obs['Populations'].notna()
adata_cell_level = adata_cell_level[has_population].copy()

In [None]:
# Heatmap of all genes, rows grouped by population and chr1q call.
# Diverging colour scale centred on 1 and clipped to [0.6, 1.4].
heatmap_style = dict(
    cmap='bwr',
    vcenter=1,
    vmin=0.6,
    vmax=1.4,
    figsize=(6, 4),
    dendrogram=False,
    show_gene_labels=False,
    swap_axes=False,
)
sc.pl.heatmap(
    adata_cell_level,
    var_names=adata_cell_level.var_names,
    groupby=['Populations', 'chr1_q_call'],
    **heatmap_style,
)

In [None]:
# Number of distinct 'Sample.Code' labels after casting to string.
adata_pat_level.obs['Sample.Code'].astype(str).nunique(dropna=False)

In [None]:
# Number of aggregated ('Sample.Code.Cell') profiles per population.
adata_pat_level.obs['Populations'].value_counts()

In [None]:
# Same heatmap on the aggregated profiles, rows grouped by population and chr1q call.
# Diverging colour scale centred on 1 and clipped to [0.6, 1.4].
heatmap_style = dict(
    cmap='bwr',
    vcenter=1,
    vmin=0.6,
    vmax=1.4,
    figsize=(6, 4),
    dendrogram=False,
    show_gene_labels=False,
    swap_axes=False,
)
sc.pl.heatmap(
    adata_pat_level,
    var_names=adata_pat_level.var_names,
    groupby=['Populations', 'chr1_q_call'],
    **heatmap_style,
)

# Saving the AnnData objects

In [None]:
def _stringify_var_stats(adata, stat_cols=('std', 'max', 'min')):
    """Cast the listed `.var` columns of `adata` to str, skipping columns that are absent.

    Replaces a bare `try/except: continue`, which hid every error rather than only the
    missing-column case.
    NOTE(review): the cast is presumably needed so these columns serialise to h5ad — confirm.
    """
    for col in stat_cols:
        if col in adata.var.columns:
            adata.var[col] = adata.var[col].astype(str)


# Band-level object (obs: samples, var: bands).
_stringify_var_stats(adata_bands)
file_path = os.path.join(data_path,f'adata_cnv_obs_pat_var_bands_v_{version}.h5ad')
adata_bands.write(file_path)

# Arm-level object (obs: samples, var: arms).
_stringify_var_stats(adata_arms)
file_path = os.path.join(data_path,f'adata_cnv_obs_pat_var_arms_v_{version}.h5ad')
adata_arms.write(file_path)

# Gene-level object aggregated per sample.
_stringify_var_stats(adata_pat_level)
file_path = os.path.join(data_path,f'adata_cnv_obs_pat_var_genes_v_{version}.h5ad')
adata_pat_level.write(file_path)

# Minibulk-level object: object-dtype obs columns are cast to plain strings first.
object_cols = [
    col for col in adata_cell_level.obs.columns
    if adata_cell_level.obs[col].dtype == 'object'
]
for col in object_cols:
    adata_cell_level.obs[col] = adata_cell_level.obs[col].astype(str)

_stringify_var_stats(adata_cell_level)
file_path = os.path.join(data_path,f'adata_cnv_obs_minibulk_var_genes_v_{version}.h5ad')
adata_cell_level.write(file_path)