# Generating Gold Standard Benchmark Dataset

## Processing STARmap and seqFISH data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import pandas as pd
import os
import scanpy as sc
import distinctipy
from matplotlib.image import imread
import warnings
warnings.simplefilter(action='ignore', category=Warning)
import scanpy as sc
sc.settings.verbosity = 0
import time
from datetime import datetime

def check_dir(dir_path):
    """Ensure that `dir_path` exists as a directory and return it.

    Missing parent directories are created as well. ``exist_ok=True`` is used
    instead of a separate ``os.path.exists`` check so that a directory created
    by someone else between the check and the creation does not raise.

    Args:
        dir_path: Path of the directory that must exist.

    Returns:
        The `dir_path` that was passed in, so the call can be used inline.

    Raises:
        FileExistsError: If `dir_path` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def get_spatial_data(directory, is_index_col = 0):
    """Load the simulated spatial tables under `directory` into one AnnData.

    Four tab-separated files are read from ``<directory>/simulated_data``:
    the spot-by-gene counts, the spot locations, the per-spot observation
    table and the ground-truth table.

    Args:
        directory: Dataset folder that contains the ``simulated_data`` folder.
        is_index_col: ``index_col`` passed to ``pd.read_csv`` for the count
            and location tables (the other two tables always use column 0).

    Returns:
        AnnData with float32 counts in ``X``, ``x``/``y`` coordinates in
        ``obsm['spatial']`` and in ``obs[['X', 'Y']]``, every gene flagged as
        highly variable, and the ground truth, dataset name, hvg flavor and
        source file paths stored in ``uns``.
    """
    file_paths = {
        "count_file": os.path.join(directory, "simulated_data/combined_spatial_count.txt"),
        "location_file": os.path.join(directory, "simulated_data/combined_Locations.txt"),
        "obs_file": os.path.join(directory, "simulated_data/combined_cell_counts.txt"),
        "truth_file" : os.path.join(directory, "simulated_data/combined_spot_clusters.txt"),
    }

    def read_table(key, index_col):
        # Every input is a tab-separated text table.
        return pd.read_csv(file_paths[key], sep='\t', index_col=index_col)

    counts = read_table("count_file", is_index_col)
    locations = read_table("location_file", is_index_col)
    spot_obs = read_table("obs_file", 0)
    ground_truth = read_table("truth_file", 0)

    # Assemble the AnnData object; genes become the index of `var`.
    expression = counts.values.astype(np.float32)
    adata = sc.AnnData(X=expression, obs=spot_obs, var={'genes': counts.columns})
    adata.var["highly_variable"] = True
    adata.var.set_index("genes", inplace=True)

    # The same coordinates are exposed both as obsm['spatial'] and as obs columns.
    xy = locations[["x","y"]].values
    adata.obsm['spatial'] = xy
    adata.obs[["X", "Y"]] = xy

    adata.uns["ground_truth"] = ground_truth
    adata.uns["dataset_name"] = os.path.basename(directory)
    adata.uns["hvg"] = {'flavor': 'simulate'}
    adata.uns["paths"] = file_paths
    return adata


def get_single_cell_data_li2022(directory):
    """Load the raw single-cell tables of a li2022 dataset into an AnnData.

    Reads ``Spatial_count.txt``, ``Locations.txt`` and ``Spatial_annotate.txt``
    (all tab-separated) from ``<directory>/Rawdata``.

    Args:
        directory: Dataset folder that contains the ``Rawdata`` folder.

    Returns:
        AnnData with the counts in ``X``, the annotation table as ``obs``
        (extended with ``X``/``Y`` columns taken from the locations table),
        the gene names in ``var['genes']`` and the folder name stored in
        ``uns['dataset_name']``.
    """
    count_path = os.path.join(directory, "Rawdata/Spatial_count.txt")
    location_path = os.path.join(directory, "Rawdata/Locations.txt")
    annotation_path = os.path.join(directory, "Rawdata/Spatial_annotate.txt")

    counts = pd.read_csv(count_path, sep='\t', index_col=0)
    # The locations table is read without an index column; all of its
    # columns are copied positionally into obs[['X', 'Y']] below.
    locations = pd.read_csv(location_path, sep='\t')
    annotations = pd.read_csv(annotation_path, sep='\t', index_col=0)

    adata = sc.AnnData(X=counts.values, obs=annotations, var={'genes': counts.columns})
    adata.uns["dataset_name"] = os.path.basename(directory)
    adata.obs[["X", "Y"]] = locations.values
    return adata



def plot_sim_spatial_single(adata_sc, window = 750, is_show = False):
    """Save a scatter plot of the cells coloured by cell type.

    The figure is written to
    ``<results_dir>/<dataset_name>/analysis/<dataset_name>_scdata_plot.pdf``,
    where ``results_dir`` is the module-level results folder.

    Args:
        adata_sc: AnnData with ``obs`` columns ``X``, ``Y`` and ``celltype``
            and ``uns['dataset_name']``.
        window: Spacing of the axis ticks (and therefore of the grid lines).
        is_show: When true, also display the figure before closing it.
    """
    obs = adata_sc.obs
    sample_name = adata_sc.uns["dataset_name"]

    # Y is shifted so that the lowest cell sits at y = 0.
    points = pd.DataFrame({
        'x': obs["X"],
        'y': obs["Y"] - obs["Y"].min(),
        'Cell Type': obs["celltype"],
    })

    # One visually distinct colour per cell type.
    cell_types = points['Cell Type'].unique()
    colours = distinctipy.get_colors(len(cell_types))
    colour_by_type = {cell_type: colour for cell_type, colour in zip(cell_types, colours)}

    plt.figure(figsize=(8, 5), dpi=100)
    sns.scatterplot(data=points, x='x', y='y', hue='Cell Type', palette=colour_by_type, s=20, edgecolors='none')
    plt.grid(True, which="both")
    plt.legend(title='Cell Type', bbox_to_anchor=(1,1))

    # Ticks every `window` units, from 0 up to the largest coordinate.
    x_ticks = range(0, int(points["x"].max()) + 1, window)
    y_ticks = range(0, int(points["y"].max()) + 1, window)
    plt.xticks(x_ticks)
    plt.yticks(y_ticks)

    plt.tight_layout()
    output_dir = check_dir(os.path.join(results_dir, sample_name, "analysis"))
    plt.savefig(os.path.join(output_dir, f"{sample_name}_scdata_plot.pdf"))

    if is_show:
        plt.show()

    plt.close()

    
# Root folders of the benchmark: inputs live under data/, outputs under results/.
base_dir = "D:\\MorrissyLab Dropbox\\Visium_profiling\\benchmark"
data_dir = os.path.join(base_dir, "data")
results_dir = os.path.join(base_dir, "results")

# li2022 datasets to convert; "window" is the tick spacing of the overview plot.
li2022_datasets = {
    "Dataset10_STARmap_li2022_sim_norm_mm": {"window": 750},
    "Dataset4_seqFISH_li2022_sim_norm_mm": {"window": 500},
}

for dataset, plot_settings in li2022_datasets.items():
    dataset_path = os.path.join(data_dir, dataset)

    # Simulated spatial data -> adata_spatial.h5ad
    adata_spatial = get_spatial_data(dataset_path)
    adata_spatial.write_h5ad(os.path.join(dataset_path, "adata_spatial.h5ad"))

    # Raw single-cell data -> adata_sc.h5ad, plus an overview plot
    adata_sc = get_single_cell_data_li2022(dataset_path)
    adata_sc.write_h5ad(os.path.join(dataset_path, "adata_sc.h5ad"))

    plot_sim_spatial_single(adata_sc, window=plot_settings["window"], is_show=False)


## Processing Stereo-Seq data

In [None]:
def bin_matrix(df, window_size=100):
    """Group points into square bins and list the ids that fall in each bin.

    Every point is assigned to the ``window_size`` x ``window_size`` square
    that contains it; the bin is identified by the centre of that square.
    The binning is vectorised (no per-row Python loop), so column dtypes are
    preserved and large tables stay fast.

    Args:
        df: DataFrame with the columns ``id``, ``x`` and ``y``.
        window_size: Edge length of a bin, in the units of ``x``/``y``.

    Returns:
        DataFrame with one row per non-empty bin, in order of first
        appearance in `df`, indexed ``bin<window_size>_1``,
        ``bin<window_size>_2``, ... with the columns ``x`` and ``y`` (bin
        centre), ``ids`` (list of the ids in the bin) and ``n_cells`` (length
        of that list). An empty `df` gives an empty frame with these columns.
        Rows whose coordinates are missing (NaN) are not assigned to any bin.

    Raises:
        ValueError: If `window_size` is not positive.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size!r}")

    if len(df) == 0:
        return pd.DataFrame(columns=["x", "y", "ids", "n_cells"])

    half_window = window_size // 2
    # Lower-left corner of the containing bin, shifted to the bin centre.
    centers = pd.DataFrame({
        "x": (df["x"].to_numpy() // window_size) * window_size + half_window,
        "y": (df["y"].to_numpy() // window_size) * window_size + half_window,
        "ids": df["id"].to_numpy(),
    })

    # sort=False keeps the bins in order of first appearance, so the bin
    # numbering depends only on the row order of `df`.
    binned_df = (
        centers.groupby(["x", "y"], sort=False)["ids"]
        .agg(list)
        .reset_index()
    )

    binned_df["n_cells"] = binned_df["ids"].apply(len)
    binned_df.index = [f"bin{window_size}_{i}" for i in range(1, len(binned_df) + 1)]
    return binned_df


def process_steroseq_data(file_path, window = 100):
    """Aggregate a cell-level h5ad file into square bins and write the result.

    The cells are grouped into ``window`` x ``window`` bins with
    ``bin_matrix`` and four tab-separated tables are written to a
    ``simulated_data`` folder next to `file_path`:

    * ``combined_Locations.txt`` and ``combined_cell_counts.txt`` - the bin
      table returned by ``bin_matrix`` (the same table is written twice);
    * ``combined_spatial_count.txt`` - expression summed over the cells of
      each bin;
    * ``combined_spot_clusters.txt`` - number of cells of each ``annotation``
      value in each bin.

    Args:
        file_path: Path of the h5ad file; it must provide
            ``obsm['spatial']`` and an ``annotation`` column in ``obs``.
        window: Bin edge length passed to ``bin_matrix``.
    """

    # Outputs go to <folder of file_path>/simulated_data (created if missing).
    output_path = check_dir(os.path.join(os.path.dirname(file_path), "simulated_data"))

    adata = sc.read_h5ad(file_path)
    adata.obsm["X_spatial"]=adata.obsm["spatial"]
    # Overview plot of the cells coloured by their annotation.
    sc.pl.spatial(adata, basis="spatial", color="annotation", spot_size=30)
    
    # One row per cell: its id and its x/y position.
    df_cord = pd.DataFrame(adata.obs.index, columns=["id"])
    df_cord["x"] = adata.obsm["spatial"][:,0]
    df_cord["y"] = adata.obsm["spatial"][:,1]
    binned_df = bin_matrix(df_cord, window_size=window)
    # Report the largest number of cells in a bin and the number of bins.
    print(binned_df["n_cells"].max(), len(binned_df))
    # The same bin table is written under both file names.
    binned_df.to_csv(os.path.join(output_path, "combined_Locations.txt"), sep="\t")
    binned_df.to_csv(os.path.join(output_path, "combined_cell_counts.txt"), sep="\t")

    # Expand the bin table to one row per cell (indexed by cell id) and attach
    # it to adata.obs, so that every cell carries the label of its bin.
    annotation_df = binned_df.explode(["ids"])
    annotation_df["bin_id"] = annotation_df.index
    annotation_df.index = annotation_df["ids"].values
    adata.obs = adata.obs.join(annotation_df)

    merged_df = adata.to_df().join(adata.obs[["bin_id"]], lsuffix='_left', rsuffix='_right')
    # Group by bin_id and sum the values of the other columns
    grouped_df = merged_df.groupby('bin_id').sum()
    # Rows are written in the bin order of binned_df so all tables line up.
    grouped_df.loc[binned_df.index,:].to_csv(os.path.join(output_path, "combined_spatial_count.txt"), sep="\t")


    # Count the cells of each annotation within each bin.
    pivot_df = adata.obs.pivot_table(index='bin_id', columns='annotation', aggfunc='size', fill_value=0)
    pivot_df.loc[binned_df.index,:].to_csv(os.path.join(output_path, "combined_spot_clusters.txt"), sep="\t")

# Bin the Stereo-seq cell-level data, then load the binned tables as AnnData.
# The input path is built from `data_dir` instead of repeating the absolute
# path, so the dataset folder is defined in exactly one place.
dataset_path = os.path.join(data_dir, "stereoseq_mouse_brain_li2023_sim_norm_mm")
process_steroseq_data(os.path.join(dataset_path, "Mouse_brain_cell_bin.h5ad"), window=100)
adata = get_spatial_data(dataset_path)
adata.write_h5ad(os.path.join(dataset_path, "adata_spatial.h5ad"))


## Processing Dance Data

In [None]:
import os
import scanpy as sc
from anndata import AnnData
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import pandas as pd
import os
import scanpy as sc
import distinctipy
from matplotlib.image import imread
import warnings
warnings.simplefilter(action='ignore', category=Warning)
import scanpy as sc
sc.settings.verbosity = 0
import time
from datetime import datetime



def convert_dance_data(dataset_path):
    """Convert one dataset folder into ``adata_spatial.h5ad`` and ``adata_sc.h5ad``.

    Every ``.csv`` and ``.h5ad`` file in `dataset_path` is loaded as a table
    keyed by its file name without extension. The tables ``ref_sc_count``,
    ``ref_sc_annot``, ``mix_count`` and ``true_p`` are required;
    ``spatial_location`` is optional and replaced by all-zero ``x``/``y``
    coordinates when absent.

    Three files are written into `dataset_path`: ``obs_file.txt`` (the spot
    observation table), ``adata_spatial.h5ad`` (the mixture counts with the
    ground truth and source paths in ``uns``) and ``adata_sc.h5ad`` (the
    reference counts with ``cellType`` renamed to ``celltype``). These files
    and ``readme.txt`` are ignored when scanning the folder, so running the
    conversion again gives the same result as the first run.

    Args:
        dataset_path: Folder that holds the input tables of one dataset.

    Raises:
        KeyError: If one of the required tables is missing.
    """
    # Files written by this function; they must not be read back as inputs.
    generated_files = {"adata_spatial.h5ad", "adata_sc.h5ad", "obs_file.txt"}
    # Well-known inputs are recorded in `paths` under these canonical keys;
    # any other file is recorded under its own name.
    canonical_keys = {
        "mix_count": "count_file",
        "spatial_location": "location_file",
        "true_p": "truth_file",
    }

    raw_data_dict = {}
    paths = {}
    for f in os.listdir(dataset_path):
        if f == "readme.txt" or f in generated_files:
            continue

        filepath = os.path.join(dataset_path, f)
        filename, ext = os.path.splitext(f)
        if ext == ".csv":
            table = pd.read_csv(filepath, header=0, index_col=0)
        elif ext == ".h5ad":
            table = sc.read_h5ad(filepath).to_df()
        else:
            table = None
            warnings.warn(f"Unsupported file type {ext!r}. Only csv or h5ad are supported now.")

        if table is not None:
            # String labels keep the row names comparable across tables.
            table.index = table.index.astype(str)
            raw_data_dict[filename] = table

        paths[canonical_keys.get(filename, filename)] = filepath

    required = ("ref_sc_count", "ref_sc_annot", "mix_count", "true_p")
    missing = [name for name in required if name not in raw_data_dict]
    if missing:
        raise KeyError(f"{dataset_path}: missing required input table(s): {missing}")

    ref_count = raw_data_dict["ref_sc_count"]
    ref_annot = raw_data_dict["ref_sc_annot"]
    count_matrix = raw_data_dict["mix_count"]
    cell_type_portion = raw_data_dict["true_p"]
    if (spatial := raw_data_dict.get("spatial_location")) is None:
        spatial = pd.DataFrame(0, index=count_matrix.index, columns=["x", "y"])
    spatial = spatial.astype(np.float32)

    # The matrix is cast explicitly: the `dtype` argument of AnnData is
    # deprecated, and get_spatial_data builds its matrix the same way.
    adata_spatial = AnnData(
        count_matrix.values.astype(np.float32),
        obs=pd.DataFrame(index=count_matrix.index.tolist()),
        var=pd.DataFrame(index=count_matrix.columns.tolist()),
    )
    adata_spatial.uns["ground_truth"] = cell_type_portion.astype(np.float32)
    adata_spatial.obsm["spatial"] = spatial[["x","y"]].values
    adata_spatial.uns["hvg"] = {'flavor': 'simulate'}
    adata_spatial.uns["dataset_name"] = os.path.basename(dataset_path)
    adata_spatial.var["highly_variable"] = True
    # NOTE(review): this assignment aligns on the row labels of `spatial`,
    # while obsm["spatial"] above is positional - confirm the location table
    # uses the same spot names as mix_count.
    adata_spatial.obs[["X", "Y"]] = spatial

    paths["obs_file"] = os.path.join(dataset_path, "obs_file.txt")
    adata_spatial.obs.to_csv(paths["obs_file"], sep="\t")

    adata_spatial.uns["paths"] = paths

    adata_spatial.write_h5ad(os.path.join(dataset_path, "adata_spatial.h5ad"))

    adata_ref = AnnData(
        ref_count.values.astype(np.float32),
        obs=ref_annot,
        var=pd.DataFrame(index=ref_count.columns.tolist()),
    )
    adata_ref.obs = adata_ref.obs.rename(columns={'cellType': 'celltype'})
    adata_ref.var["genes"] = adata_ref.var.index

    adata_ref.write_h5ad(os.path.join(dataset_path, "adata_sc.h5ad"))




# NOTE(review): this root is on drive Z: while `base_dir` points at drive D: -
# confirm which location is intended.
dance_data_dir = r"Z:\MorrissyLab Dropbox\Visium_profiling\benchmark\data\dance"
for dataset in os.listdir(dance_data_dir):
    dataset_path = os.path.join(dance_data_dir, dataset)
    # Only folders are datasets; a stray file in the root would make the
    # os.listdir call inside convert_dance_data raise.
    if not os.path.isdir(dataset_path):
        continue
    convert_dance_data(dataset_path)

## Prepare single-cell data

In [None]:
def sum_cell_types(adata_sc):
    """Sum the expression of every gene over the cells of each cell type.

    The cells are grouped positionally by ``obs["celltype"]``, i.e. row *i* of
    ``adata_sc.to_df()`` is paired with row *i* of ``adata_sc.obs``. This does
    not rely on index names and leaves `adata_sc` unmodified.

    Args:
        adata_sc: AnnData (or any object with the same ``obs``, ``var`` and
            ``to_df()`` members) providing ``obs["celltype"]`` and
            ``var["genes"]``.

    Returns:
        DataFrame with one row per gene (labelled with ``var["genes"]``) and
        one column per cell type (columns axis named ``celltype``), holding
        the summed expression.
    """
    counts_df = adata_sc.to_df()
    counts_df.columns = adata_sc.var["genes"].values

    # `.values` keeps a categorical column categorical while dropping the
    # index, so the grouping is purely positional.
    cell_types = adata_sc.obs["celltype"].values
    per_type_totals = counts_df.groupby(cell_types).sum()

    cell_type_gene_df = per_type_totals.rename_axis("celltype").T

    return cell_type_gene_df

# Write the per-cell-type summed expression table of every dataset folder.
for dataset in os.listdir(dance_data_dir):
    dataset_path = os.path.join(dance_data_dir, dataset)
    # Only folders are datasets; stray files in the root hold no adata_sc.h5ad.
    if not os.path.isdir(dataset_path):
        continue
    cell_type_genes_path = os.path.join(dataset_path, "cell_type_gene_df.csv")
    adata_sc = sc.read_h5ad(os.path.join(dataset_path, "adata_sc.h5ad"))
    cell_type_gene_df = sum_cell_types(adata_sc)
    cell_type_gene_df.to_csv(cell_type_genes_path)


In [None]:
# import sys
# import os
# # Update path for imports
# sys.path.append('..')
# from decon.spatialtm.utils import read_stereo_seq
# from decon.spatialtm.annotate import sum_cell_types

# steroseq_dataset_path = os.path.join(data_dir, "stereoseq_mouse_brain_li2023_sim_norm_mm")
# steroseq_sc_path = os.path.join(steroseq_dataset_path, "Mouse_brain_cell_bin.h5ad")
# steroseq_adata_sc = read_stereo_seq(steroseq_sc_path)
# steroseq_cell_type_gene_df = sum_cell_types(steroseq_adata_sc)
# steroseq_cell_type_gene_df.to_csv(os.path.join(steroseq_dataset_path, "cell_type_gene_df.csv"))