In [1]:
import numpy as np
import pandas as pd
import anndata
import matplotlib.pyplot as plt
import spatialdata as sd
import scanpy as sc
from scipy import stats
import seaborn as sns

In [None]:
# Read the AnnData object stored in 'filtered_data.h5ad'; the path is relative
# to the notebook's working directory. Every later cell operates on this object.
filtered_data = sc.read_h5ad('filtered_data.h5ad')
# Bare expression: the notebook displays the object's summary repr.
filtered_data

In [None]:
# Show the distinct labels in the per-cell 'sample' column. The differential
# expression cell further down selects control/treated cells by these labels.
filtered_data.obs['sample'].unique()

In [None]:
# Calculate total counts per cell (row sums of X).
# NOTE: run this before the normalization/log1p cells below -- they modify X
# in place, so re-running this cell afterwards yields a different median.
total_counts = filtered_data.X.sum(axis=1)

# Calculate median of total counts; used below as the normalization target
median_reads = np.median(total_counts)

print(f"Median number of reads per cell: {median_reads}")

In [5]:
# Normalize main matrix (X): scale each cell's counts to the median library
# size computed above. The return value is discarded, i.e. filtered_data is
# modified in place -- re-running this cell normalizes the already-normalized X.
sc.pp.normalize_total(filtered_data, target_sum=median_reads)

In [6]:
# log(1 + x) transform of X, applied in place on filtered_data.
# Not idempotent: running this cell twice log-transforms the data twice.
sc.pp.log1p(filtered_data)

In [None]:
# Sanity check: undo the log1p transform and sum per cell. After the
# normalization above, each row total is expected to match median_reads
# (presumably except for cells with zero counts -- verify on the output).
np.expm1(filtered_data.X).sum(axis=1)

In [None]:
from statsmodels.stats.multitest import multipletests
import numpy as np
from scipy import sparse, stats

# Differential expression of the pooled "early" and "late" treated samples
# against the control sample, one gene at a time:
#   * two-sided Mann-Whitney U test on the expression values in X,
#   * log2 fold change of the mean expression (treated / control),
#   * Benjamini-Hochberg correction of the p-values within each comparison.
# The results are written to filtered_data.var as
# '<group>_adjusted_p_value' and '<group>_log2_fold_change'.

control_group = 'C3control'
early_groups = ['B4Tg15min', 'B5Tg30min']
late_groups = ['C4Tg2h', 'C5Tg4h']
combined_groups = {'early': early_groups, 'late': late_groups}


def _log2_fold_change(group_log_values, control_log_values):
    """Return log2(mean treated / mean control) on the normalized-count scale.

    X holds log1p-transformed values at this point (sc.pp.log1p was applied in
    an earlier cell), so the values are back-transformed with expm1 before
    averaging. Taking the ratio of the means of the log values, as this cell
    did originally, does not yield a fold change.

    Returns NaN when either mean is not strictly positive, because the ratio
    (or its logarithm) is undefined in that case.
    """
    mean_group = np.mean(np.expm1(group_log_values))
    mean_control = np.mean(np.expm1(control_log_values))
    if mean_group > 0 and mean_control > 0:
        return np.log2(mean_group / mean_control)
    return np.nan


def _compare_to_control(group_values, control_values):
    """Return (raw p-value, log2 fold change) for one gene and one comparison.

    Both entries are NaN when either side has no cells or when
    scipy.stats.mannwhitneyu raises ValueError (some SciPy versions do so for
    degenerate input such as all-identical values).
    """
    if len(group_values) == 0 or len(control_values) == 0:
        return np.nan, np.nan
    try:
        # The test is rank based, so it is unaffected by the log1p transform.
        _, p_value = stats.mannwhitneyu(group_values, control_values, alternative='two-sided')
    except ValueError:
        return np.nan, np.nan
    return p_value, _log2_fold_change(group_values, control_values)


def _benjamini_hochberg(p_values):
    """BH-adjust the non-NaN entries of p_values; NaN entries stay NaN."""
    p_values = np.asarray(p_values, dtype=float)
    adjusted = np.full(p_values.shape, np.nan)
    testable = ~np.isnan(p_values)
    if testable.any():
        adjusted[testable] = multipletests(p_values[testable], method='fdr_bh')[1]
    return adjusted


# Convert once to a layout with cheap per-gene (column) access instead of
# slicing the AnnData object for every gene; also works when X is dense.
expression = filtered_data.X
expression = expression.tocsc() if sparse.issparse(expression) else np.asarray(expression)

# The cell masks do not depend on the gene, so build them once.
sample_labels = filtered_data.obs['sample']
control_mask = (sample_labels == control_group).to_numpy()
group_masks = {
    group_name: sample_labels.isin(group_list).to_numpy()
    for group_name, group_list in combined_groups.items()
}

raw_p_values = {group_name: [] for group_name in combined_groups}
results = {
    'adjusted_p_values': {},
    'log2_fold_change': {group_name: [] for group_name in combined_groups}
}

# Genes are visited in filtered_data.var order, so the result lists line up
# with the rows of filtered_data.var.
for gene_idx in range(expression.shape[1]):
    gene_column = expression[:, gene_idx]
    if sparse.issparse(gene_column):
        gene_column = gene_column.toarray()
    gene_totalRNA = np.asarray(gene_column).ravel()
    control_totalRNA = gene_totalRNA[control_mask]
    for group_name, group_mask in group_masks.items():
        p_value, log2_fold_change = _compare_to_control(gene_totalRNA[group_mask], control_totalRNA)
        raw_p_values[group_name].append(p_value)
        results['log2_fold_change'][group_name].append(log2_fold_change)

# Multiple-testing correction is done separately for each comparison.
for group in combined_groups:
    results['adjusted_p_values'][group] = _benjamini_hochberg(raw_p_values[group]).tolist()

for group in combined_groups:
    filtered_data.var[f'{group}_adjusted_p_value'] = results['adjusted_p_values'][group]
    filtered_data.var[f'{group}_log2_fold_change'] = results['log2_fold_change'][group]


In [10]:
# Export the differential-expression summary (adjusted p-value and log2 fold
# change for the early and late comparisons). The columns are selected by
# name rather than by position so the export stays correct if other columns
# are ever added to filtered_data.var; a missing column raises a KeyError
# instead of silently exporting unrelated data.
de_result_columns = [
    f'{group}_{metric}'
    for group in ('early', 'late')
    for metric in ('adjusted_p_value', 'log2_fold_change')
]
last_four_columns = filtered_data.var[de_result_columns]

last_four_columns.to_csv('wilcoxon_results_last_four_columns.csv')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# Volcano plots (log2 fold change vs. -log10 adjusted p-value) for the early
# and late comparisons against control. Genes passing both thresholds are
# highlighted (red = up, blue = down) and their names are written to
# '<period>_upregulated_genes.csv' / '<period>_downregulated_genes.csv'.

p_value_threshold = 0.05
log2_fold_change_threshold = 0.4
# Upper limit of the y axis; larger -log10(p) values (including p == 0) are
# drawn at this cap.
neg_log10_p_cap = 60

fig, axes = plt.subplots(1, 2, figsize=(18, 4))
fig.subplots_adjust(wspace=0.4)

early_log2fc = filtered_data.var['early_log2_fold_change']
early_padj = filtered_data.var['early_adjusted_p_value']

late_log2fc = filtered_data.var['late_log2_fold_change']
late_padj = filtered_data.var['late_adjusted_p_value']

# The period label used in the output file names travels with its data
# instead of being inferred from the position in the list.
comparisons = [
    ('early', early_log2fc, early_padj, 'Early vs Control'),
    ('late', late_log2fc, late_padj, 'Late vs Control'),
]

for ax, (period, log2_fold_change, adjusted_p_values, title) in zip(axes, comparisons):
    # log10(0) is -inf; that is intended here (it is capped right after), so
    # silence the divide-by-zero warning for this one expression.
    with np.errstate(divide='ignore'):
        y_values = np.minimum(-np.log10(adjusted_p_values), neg_log10_p_cap)

    # Comparisons against NaN are False, so untested genes are never flagged.
    significant = adjusted_p_values < p_value_threshold
    upregulated_genes = significant & (log2_fold_change > log2_fold_change_threshold)
    downregulated_genes = significant & (log2_fold_change < -log2_fold_change_threshold)
    all_genes = pd.Series(True, index=log2_fold_change.index)

    # Grey background layer first, then the highlighted genes on top.
    for gene_mask, color, alpha in [
        (all_genes, 'grey', 0.6),
        (upregulated_genes, 'red', 0.8),
        (downregulated_genes, 'blue', 0.8),
    ]:
        sns.scatterplot(
            x=log2_fold_change[gene_mask],
            y=y_values[gene_mask],
            alpha=alpha,
            color=color,
            ax=ax,
            s=10
        )

    ax.axvline(x=log2_fold_change_threshold, color='blue', linestyle='--', linewidth=1.5)
    ax.axvline(x=-log2_fold_change_threshold, color='blue', linestyle='--', linewidth=1.5)
    ax.axhline(y=-np.log10(p_value_threshold), color='green', linestyle='--', linewidth=1.5)
    ax.set_ylim(0, neg_log10_p_cap)
    # NOTE: genes with |log2 fold change| > 2 fall outside the x range and are
    # not visible in the figure; they are still included in the CSV lists.
    ax.set_xlim(-2, 2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Log2 Fold Change', fontsize=12)
    ax.set_ylabel('-Log10(Adjusted P-value)', fontsize=12)

    upregulated_gene_names = filtered_data.var.index[upregulated_genes].tolist()
    downregulated_gene_names = filtered_data.var.index[downregulated_genes].tolist()
    # One gene name per line, no header/index; dtype=object keeps an empty
    # list from triggering pandas' default-dtype warning.
    pd.Series(upregulated_gene_names, dtype=object).to_csv(
        f'{period}_upregulated_genes.csv', index=False, header=False)
    pd.Series(downregulated_gene_names, dtype=object).to_csv(
        f'{period}_downregulated_genes.csv', index=False, header=False)

# Save via the figure handle rather than pyplot's "current figure" state,
# then close it to release the figure's memory.
fig.savefig('totalRNA_volcano_plots_early_late.pdf')
plt.close(fig)