In [5]:
import os

import numpy as np
import pandas as pd
import scanpy as sc

def sample_celltype_with_oversampling(adata, celltype_key, n_samples, seed, *, make_unique=True):
    """
    Draw exactly ``n_samples`` cells per cell type to build a class-balanced subset.

    Cell types with at least ``n_samples`` cells are subsampled without
    replacement. Rarer types are sampled *entirely with replacement* (random
    oversampling), so a cell can appear several times and some cells of a rare
    type may not appear at all.

    Parameters
    ----------
    adata : AnnData
        Source dataset. Only ``adata.obs``, label-based indexing
        (``adata[labels]``) and ``.copy()`` are used.
    celltype_key : str
        Column of ``adata.obs`` holding the cell-type labels. Cells whose label
        is missing (NaN/None) are skipped.
    n_samples : int
        Number of cells to draw for every cell type.
    seed : int
        Seed for a private ``np.random.RandomState``. The draws are identical to
        seeding the global NumPy RNG with the same value, but the global RNG
        state is left untouched.
    make_unique : bool, default True
        If True and the result has duplicate obs names (from oversampling),
        deduplicate them with ``obs_names_make_unique``.

    Returns
    -------
    AnnData
        An independent copy with ``n_samples`` cells per (non-missing) cell type,
        ordered by cell type in order of first appearance in ``adata.obs``.
    """
    # A private generator produces the same stream as the former
    # np.random.seed(seed) + np.random.choice calls, so previously generated
    # files remain reproducible without mutating global RNG state.
    rng = np.random.RandomState(seed)
    labels = adata.obs[celltype_key]
    all_indices = []
    # dropna(): a missing label never matches ``==``, which would leave an empty
    # pool and make ``choice`` raise.
    for celltype in labels.dropna().unique():
        celltype_indices = adata.obs.index[(labels == celltype).to_numpy()]
        # Sample with replacement (oversample) only when the type is too rare.
        replace = len(celltype_indices) < n_samples
        all_indices.extend(rng.choice(celltype_indices, n_samples, replace=replace))
    sampled = adata[all_indices].copy()
    if make_unique and not sampled.obs.index.is_unique:
        # Oversampled cells share obs names; duplicates trigger anndata warnings
        # and make label-based lookups on the saved file ambiguous.
        sampled.obs_names_make_unique()
    return sampled

# Load data
adata = sc.read_h5ad("reference_adata.h5ad")

# Sampling configuration
sample_sizes = [5, 10, 50]  # cells drawn per cell type in each subset
celltype_key = 'cell_type'  # Adjust based on your actual celltype key
seeds = range(10)  # Using seeds 0 to 9
output_dir = "./few_shot_file/MY"

# Fail fast with the available column names if the key was not adjusted
# correctly, before any files are written.
if celltype_key not in adata.obs.columns:
    raise KeyError(
        f"{celltype_key!r} is not a column of adata.obs; "
        f"available columns: {list(adata.obs.columns)}"
    )

# Writing into a directory that does not exist fails, so create it up front.
os.makedirs(output_dir, exist_ok=True)

for seed in seeds:
    for n in sample_sizes:
        sampled_adata = sample_celltype_with_oversampling(adata, celltype_key, n, seed)

        # Save sampled dataset
        filename = f"{output_dir}/My_Train_data_{n}_seed_{seed}.h5ad"
        sampled_adata.write(filename)

        # Output distribution summary (only once per sample size)
        if seed == 0:
            print(f"\nDistribution Summary (n={n}, seed={seed}):")
            print(sampled_adata.obs[celltype_key].value_counts())
            print(f"Total cells: {sampled_adata.n_obs}")
            print(f"Total genes: {sampled_adata.n_vars}")
            print(f"obs keys: {list(sampled_adata.obs.keys())}")
            print(f"var keys: {list(sampled_adata.var.keys())}")
            print(f"File saved as: {filename}")

# Report the full (unbalanced) reference dataset so the per-type counts of the
# balanced subsets above can be compared against the natural frequencies.
original_summary = [
    "\nOriginal Dataset Distribution:",
    adata.obs[celltype_key].value_counts(),
    f"Total cells: {adata.n_obs}",
    f"Total genes: {adata.n_vars}",
    f"obs keys: {list(adata.obs.keys())}",
    f"var keys: {list(adata.var.keys())}",
]
for summary_line in original_summary:
    print(summary_line)


Distribution Summary (n=5, seed=0):
cell_type
Macro_C1QC      5
Macro_FN1       5
Macro_GPNMB     5
Macro_IL1B      5
Macro_INHBA     5
Macro_ISG15     5
Macro_LYVE1     5
Macro_NLRP3     5
Macro_SPP1      5
Mono_CD14       5
Mono_CD16       5
cDC1_CLEC9A     5
cDC2_CD1A       5
cDC2_CD1C       5
cDC2_CXCL9      5
cDC2_CXCR4hi    5
cDC2_FCN1       5
cDC2_IL1B       5
cDC2_ISG15      5
cDC3_LAMP3      5
pDC_LILRA4      5
Name: count, dtype: int64
Total cells: 105
Total genes: 3000
obs keys: ['cell_type', 'cancer_type', 'batch']
var keys: ['highly_variable', 'means', 'dispersions', 'dispersions_norm']
File saved as: ./few_shot_file/MY/My_Train_data_5_seed_0.h5ad

Distribution Summary (n=10, seed=0):
cell_type
Macro_C1QC      10
Macro_FN1       10
Macro_GPNMB     10
Macro_IL1B      10
Macro_INHBA     10
Macro_ISG15     10
Macro_LYVE1     10
Macro_NLRP3     10
Macro_SPP1      10
Mono_CD14       10
Mono_CD16       10
cDC1_CLEC9A     10
cDC2_CD1A       10
cDC2_CD1C       10
cDC2_CXCL9      

  utils.warn_names_duplicates("obs")



Distribution Summary (n=50, seed=0):
cell_type
Macro_C1QC      50
Macro_FN1       50
Macro_GPNMB     50
Macro_IL1B      50
Macro_INHBA     50
Macro_ISG15     50
Macro_LYVE1     50
Macro_NLRP3     50
Macro_SPP1      50
Mono_CD14       50
Mono_CD16       50
cDC1_CLEC9A     50
cDC2_CD1A       50
cDC2_CD1C       50
cDC2_CXCL9      50
cDC2_CXCR4hi    50
cDC2_FCN1       50
cDC2_IL1B       50
cDC2_ISG15      50
cDC3_LAMP3      50
pDC_LILRA4      50
Name: count, dtype: int64
Total cells: 1050
Total genes: 3000
obs keys: ['cell_type', 'cancer_type', 'batch']
var keys: ['highly_variable', 'means', 'dispersions', 'dispersions_norm']
File saved as: ./few_shot_file/MY/My_Train_data_50_seed_0.h5ad


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")



Original Dataset Distribution:
cell_type
Mono_CD16       1721
Mono_CD14       1619
cDC2_CD1C       1368
Macro_C1QC       792
Macro_GPNMB      456
Macro_IL1B       456
Macro_NLRP3      454
Macro_LYVE1      447
cDC2_CXCR4hi     371
Macro_SPP1       367
Macro_FN1        263
cDC2_CD1A        246
cDC2_IL1B        245
Macro_INHBA      199
Macro_ISG15      169
cDC2_FCN1        136
cDC3_LAMP3       121
cDC1_CLEC9A      112
pDC_LILRA4       111
cDC2_ISG15        59
cDC2_CXCL9        36
Name: count, dtype: int64
Total cells: 9748
Total genes: 3000
obs keys: ['cell_type', 'cancer_type', 'batch']
var keys: ['highly_variable', 'means', 'dispersions', 'dispersions_norm']


  utils.warn_names_duplicates("obs")
