In [None]:
import numpy as np
import pandas as pd
import random
import os

def generate_single_dataset(
    output_dir,
    total_genes,
    num_timepoints,  # Unused, kept for compatibility
    num_correlated_genes,
    correlated_value,  # Unused
    noise_level,       # Unused
    zero_correlated,   # Unused
    celltypes,
    seed,
    ppi_edges_per_gene,
    dataset_index
):
    """Write one synthetic "shape" dataset and its matching PPI edge list.

    Two files are created in ``output_dir`` (which is created if missing):

    * ``synthetic_shape_corr{num_correlated_genes}_genes{total_genes}_idx{dataset_index}.csv``
      with an unnamed first column of ``<gene>_<celltype>`` labels and a
      single ``shape`` column. Every non-"Fakecells" cell type gets 0; the
      ``<gene>_Fakecells`` row is 1 for correlated genes and 0 otherwise.
    * the same prefix with ``_ppi.tsv``: a tab-separated ``gene1``/``gene2``
      edge list. Correlated genes are fully connected to each other; every
      other gene gets up to ``ppi_edges_per_gene`` randomly sampled partners.

    Args:
        output_dir: Directory the two files are written to.
        total_genes: Number of genes, named ``gene0`` .. ``gene{total_genes-1}``.
        num_timepoints: Ignored; accepted for signature compatibility.
        num_correlated_genes: How many leading genes are marked correlated.
            Values above ``total_genes`` are capped so that each gene is
            emitted once (the file name still carries the requested value).
        correlated_value: Ignored; accepted for signature compatibility.
        noise_level: Ignored; accepted for signature compatibility.
        zero_correlated: Ignored; accepted for signature compatibility.
        celltypes: Iterable of cell type names; ``"Fakecells"`` entries are
            excluded from the all-zero "regular" rows.
        seed: Base seed; ``seed + dataset_index`` seeds both ``numpy.random``
            and ``random``.
        ppi_edges_per_gene: Number of random partners sampled for each
            non-correlated gene (capped at the number of other genes).
        dataset_index: Index used in the seed and in the file names.

    Returns:
        None. The generated paths are printed to stdout.
    """
    np.random.seed(seed + dataset_index)
    random.seed(seed + dataset_index)

    # Single shape column
    shape_labels = ["shape"]
    gene_names = [f"gene{gene_index}" for gene_index in range(total_genes)]
    regular_celltypes = [ct for ct in celltypes if ct != "Fakecells"]

    # The correlated genes are the leading ones. Slicing (instead of wrapping
    # with a modulo) guarantees each gene appears once even when more
    # correlated genes are requested than exist, and is safe for zero genes.
    correlated_genes = gene_names[:max(num_correlated_genes, 0)]
    correlated_lookup = set(correlated_genes)  # O(1) membership tests below

    # 1. All regular celltypes have shape=0
    data = [
        [f"{gene}_{celltype}", 0]
        for gene in gene_names
        for celltype in regular_celltypes
    ]
    # 2. Correlated Fakecells entries -> shape=1
    data.extend([f"{gene}_Fakecells", 1] for gene in correlated_genes)
    # 3. Non-correlated Fakecells entries -> shape=0
    data.extend(
        [f"{gene}_Fakecells", 0]
        for gene in gene_names
        if gene not in correlated_lookup
    )

    # Create DataFrame; the label column intentionally has an empty header.
    df = pd.DataFrame(data, columns=[""] + shape_labels)

    # Save data
    os.makedirs(output_dir, exist_ok=True)
    filename_prefix = (
        f"synthetic_shape_corr{num_correlated_genes}_genes{total_genes}_idx{dataset_index}"
    )
    data_path = os.path.join(output_dir, f"{filename_prefix}.csv")
    df.to_csv(data_path, index=False)

    # 4. Generate PPI network. Edges are stored as sorted tuples in a set so
    # that (a, b) and (b, a) collapse into a single undirected edge.
    ppi_edges = set()

    # Fully connect correlated genes
    for position, first_gene in enumerate(correlated_genes):
        for second_gene in correlated_genes[position + 1:]:
            ppi_edges.add(tuple(sorted((first_gene, second_gene))))

    # Randomly connect remaining genes (targets may include correlated genes)
    for gene in gene_names:
        if gene in correlated_lookup:
            continue
        possible_targets = [g for g in gene_names if g != gene]
        targets = random.sample(
            possible_targets, min(ppi_edges_per_gene, len(possible_targets))
        )
        for target in targets:
            ppi_edges.add(tuple(sorted((gene, target))))

    # Save PPI, sorted so the file content does not depend on set ordering.
    ppi_path = os.path.join(output_dir, f"{filename_prefix}_ppi.tsv")
    with open(ppi_path, "w") as f:
        f.write("gene1\tgene2\n")
        f.writelines(f"{g1}\t{g2}\n" for g1, g2 in sorted(ppi_edges))

    print(f"✅ Generated: {data_path}, {ppi_path}")


def generate_all_combinations(
    output_dir="datasets",
    timepoint_range=(4,),                  # unused but kept for compatibility
    correlated_gene_range=(2, 4, 6),
    total_gene_range=(50, 75, 100),
    correlated_value=1.0,
    noise_level=0.05,
    zero_correlated=True,
    celltypes=None,
    seed=42,
    ppi_edges_per_gene=5
):
    """Generate one dataset per (timepoints, correlated genes, total genes) combination.

    Iterates the Cartesian product of ``timepoint_range``,
    ``correlated_gene_range`` and ``total_gene_range`` and calls
    ``generate_single_dataset`` for each combination, skipping (with a printed
    warning) any combination where the number of correlated genes is not
    strictly smaller than the total number of genes.

    Args:
        output_dir: Directory passed through to ``generate_single_dataset``.
        timepoint_range: Values forwarded as ``num_timepoints``; the loop is
            kept for compatibility, so each entry produces one pass over the
            other two ranges.
        correlated_gene_range: Candidate numbers of correlated genes.
        total_gene_range: Candidate total gene counts.
        correlated_value: Forwarded unchanged.
        noise_level: Forwarded unchanged.
        zero_correlated: Forwarded unchanged.
        celltypes: Cell type names to forward. ``None`` (the default) selects
            Bcell, Tcell, NKcell, Monocyte, Dendritic and Fakecells.
        seed: Base seed forwarded to every dataset.
        ppi_edges_per_gene: Forwarded unchanged.

    Returns:
        None.
    """
    # Build the default here instead of in the signature so that no mutable
    # list object is shared between calls.
    if celltypes is None:
        celltypes = ["Bcell", "Tcell", "NKcell", "Monocyte", "Dendritic", "Fakecells"]

    # Only incremented for datasets that are actually generated, so skipped
    # combinations do not leave gaps in the index sequence.
    dataset_index = 0
    for num_timepoints in timepoint_range:  # loop kept for compatibility
        for num_correlated_genes in correlated_gene_range:
            for total_genes in total_gene_range:
                if num_correlated_genes >= total_genes:
                    print(f"⚠️ Skipping: correlated_genes >= total_genes ({num_correlated_genes} >= {total_genes})")
                    continue
                generate_single_dataset(
                    output_dir=output_dir,
                    total_genes=total_genes,
                    num_timepoints=num_timepoints,
                    num_correlated_genes=num_correlated_genes,
                    correlated_value=correlated_value,
                    noise_level=noise_level,
                    zero_correlated=zero_correlated,
                    celltypes=celltypes,
                    seed=seed,
                    ppi_edges_per_gene=ppi_edges_per_gene,
                    dataset_index=dataset_index
                )
                dataset_index += 1


# Driver: write every dataset combination (single "shape" column) into
# the "datasets" directory.
generate_all_combinations(
    output_dir="datasets",
    # candidate numbers of correlated genes
    correlated_gene_range=(1, 2),
    # candidate total gene counts
    total_gene_range=(50, 75, 100),
    # random edges sampled for each non-correlated gene
    ppi_edges_per_gene=20,
)

✅ Generated: datasets/synthetic_shape_corr1_genes50_idx0.csv, datasets/synthetic_shape_corr1_genes50_idx0_ppi.tsv
✅ Generated: datasets/synthetic_shape_corr1_genes75_idx1.csv, datasets/synthetic_shape_corr1_genes75_idx1_ppi.tsv
✅ Generated: datasets/synthetic_shape_corr1_genes100_idx2.csv, datasets/synthetic_shape_corr1_genes100_idx2_ppi.tsv
✅ Generated: datasets/synthetic_shape_corr2_genes50_idx3.csv, datasets/synthetic_shape_corr2_genes50_idx3_ppi.tsv
✅ Generated: datasets/synthetic_shape_corr2_genes75_idx4.csv, datasets/synthetic_shape_corr2_genes75_idx4_ppi.tsv
✅ Generated: datasets/synthetic_shape_corr2_genes100_idx5.csv, datasets/synthetic_shape_corr2_genes100_idx5_ppi.tsv
