# In [None]:
import os

import numpy as np
import pandas as pd
import scanpy as sc

def sample_celltype_with_oversampling(adata, celltype_key, n_samples, seed):
    """
    Draw exactly ``n_samples`` cells from every cell type in ``adata``.

    Cell types with at least ``n_samples`` cells are subsampled without
    replacement; smaller cell types are randomly oversampled with
    replacement, so the same cell can appear more than once in the result.
    Cells whose label is missing (NaN) are ignored.

    Parameters
    ----------
    adata
        Annotated data object. It must expose ``.obs`` (a DataFrame holding
        the labels), support positional row indexing and provide ``.copy()``.
    celltype_key
        Column of ``adata.obs`` holding the cell-type labels.
    n_samples
        Number of cells to draw per cell type.
    seed
        Seed for the random generator; the same seed and input give the
        same selection.

    Returns
    -------
    A copy of ``adata`` restricted to the sampled cells, grouped by cell
    type in order of first appearance in ``adata.obs[celltype_key]``.
    """
    # A private generator keeps the draw reproducible without reseeding
    # NumPy's global random state as a side effect.
    rng = np.random.RandomState(seed)
    labels = adata.obs[celltype_key]

    sampled_positions = []
    # dropna(): a missing label matches no rows, and sampling from an empty
    # pool raises ValueError.
    for celltype in labels.dropna().unique():
        is_celltype = (labels == celltype).to_numpy(dtype=bool, na_value=False)
        # Positional indices stay unambiguous even if obs names repeat.
        celltype_positions = np.flatnonzero(is_celltype)
        # Sample with replacement only when the group is too small.
        oversample = len(celltype_positions) < n_samples
        sampled_positions.extend(
            rng.choice(celltype_positions, n_samples, replace=oversample)
        )

    # Explicit integer dtype so an empty selection is still a valid indexer.
    return adata[np.asarray(sampled_positions, dtype=np.intp)].copy()

# Load data
adata = sc.read_h5ad("reference_adata.h5ad")

# Perform sampling
sample_sizes = [5, 10, 50]
celltype_key = 'cell_type'  # Adjust based on your actual celltype key
output_dir = "./few_shot_file/MY"

# Create the target directory up front so that writing the sampled files
# does not fail when it is missing.
os.makedirs(output_dir, exist_ok=True)

for seed in range(10):  # Using seeds 0 to 9
    print(f"\nUsing seed {seed}:")
    for n in sample_sizes:
        sampled_adata = sample_celltype_with_oversampling(adata, celltype_key, n, seed)

        # Save sampled dataset (one file per sample size and seed)
        filename = f"{output_dir}/My_Train_data_{n}_seed_{seed}.h5ad"
        sampled_adata.write(filename)

        # Output distribution summary
        print(f"\nDistribution Summary (n={n}, seed={seed}):")
        print(sampled_adata.obs[celltype_key].value_counts())
        print(f"Total cells: {sampled_adata.n_obs}")
        print(f"Total genes: {sampled_adata.n_vars}")
        print(f"obs keys: {list(sampled_adata.obs.keys())}")
        print(f"var keys: {list(sampled_adata.var.keys())}")
        print(f"File saved as: {filename}")

# Summarize the full reference dataset so the sampled subsets can be compared against it
original_report = (
    "\nOriginal Dataset Distribution:",
    adata.obs[celltype_key].value_counts(),
    f"Total cells: {adata.n_obs}",
    f"Total genes: {adata.n_vars}",
    f"obs keys: {list(adata.obs.keys())}",
    f"var keys: {list(adata.var.keys())}",
)
for entry in original_report:
    print(entry)