In [None]:
import os
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np


# Fixed categorical palette (19 hex colors) from which cell-type colors are
# assigned in order; labels beyond the 19th reuse colors from the start.
COLORS = [
    '#1f77b4',
    '#ff7f0e',
    '#279e68',
    '#d62728',
    '#aa40fc',
    '#8c564b',
    '#e377c2',
    '#b5bd61',
    '#17becf',
    '#aec7e8',
    '#ffbb78',
    '#98df8a',
    '#ff9896',
    '#c5b0d5',
    '#c49c94',
    '#f7b6d2',
    '#c7c7c7',
    '#dbdb8d',
    '#9edae5',
]

def build_color_map_for_cts(all_cell_types):
    """Assign a palette color to every cell-type label.

    Args:
        all_cell_types: iterable of cell-type labels; position determines color.

    Returns:
        dict mapping each label to ``COLORS[i % len(COLORS)]``, where ``i`` is
        the label's position, so the palette wraps around when there are more
        labels than colors.
    """
    return {
        label: COLORS[position % len(COLORS)]
        for position, label in enumerate(all_cell_types)
    }

def get_slice_number(file_name, method):
    """Extract the slice identifier from an ``.h5ad`` file name.

    The identifier is one ``_``-separated field of ``file_name``; which field
    depends on the naming scheme of the producing method:

    * ``"original"``          -> field 0
    * ``"SPIRAL"``, ``"GPSA"`` -> field 2
    * anything else           -> field 3

    Raises:
        IndexError: if ``file_name`` has too few ``_``-separated fields.
    """
    if method == "original":
        field_index = 0
    elif method in ["SPIRAL", "GPSA"]:
        field_index = 2
    else:
        field_index = 3
    fields = file_name.split("_")
    return fields[field_index]

def get_save_path(dataset, slice_number, method):
    """Build the output PDF file name for one slice of one method.

    The name is ``{dataset}_{slice_number}_{method}{tag}_GT.pdf``, where
    ``tag`` is ``_pairwise`` for ``"PASTE"``, ``_center`` for
    ``"PASTE_CENTER"`` and empty for every other method.
    """
    method_tags = {"PASTE": "_pairwise", "PASTE_CENTER": "_center"}
    tag = method_tags.get(method, "")
    return f"{dataset}_{slice_number}_{method}{tag}_GT.pdf"

def get_all_cell_types(folder_paths):
    """Collect the sorted union of ground-truth labels over every slice.

    Every ``.h5ad`` file in every folder of ``folder_paths`` is read and the
    unique values of its ``obs['Ground_Truth']`` column are gathered.

    Labels are converted to ``str`` before being collected. This keeps them
    identical to per-spot lookups that use ``.astype(str)``. It also ensures
    ``sorted`` never has to compare mixed types (e.g. a ``str`` label with a
    NaN float), which would raise ``TypeError``.

    Args:
        folder_paths: mapping of method name -> folder containing ``.h5ad``
            files. Folders that do not exist are reported and skipped.

    Returns:
        Sorted list of unique labels as strings; empty if no file was found.
    """
    all_cell_types = set()
    for folder_path in folder_paths.values():
        if not os.path.exists(folder_path):
            print(f"Folder doesn't exist: {folder_path}")
            continue
        for file_name in os.listdir(folder_path):
            if not file_name.endswith(".h5ad"):
                continue
            adata = sc.read_h5ad(os.path.join(folder_path, file_name))
            # str-normalize so int / NaN labels neither break sorting nor
            # mismatch string-keyed color lookups.
            all_cell_types.update(adata.obs["Ground_Truth"].astype(str).unique())
    return sorted(all_cell_types)

def get_global_bounds(folder_paths):
    """Compute the x/y bounding box that encloses every spot of every slice.

    Coordinates come from ``obsm['spatial']`` of each ``.h5ad`` file in the
    given folders. Only the first two columns are used (x, y), matching how
    the coordinates are plotted; extra columns such as z do not break the
    computation. Slices with no spots are skipped, because ``min``/``max``
    of an empty array raises.

    Args:
        folder_paths: mapping of method name -> folder containing ``.h5ad``
            files. Folders that do not exist are silently skipped.

    Returns:
        dict with keys ``min_x``, ``min_y``, ``max_x``, ``max_y``. If no
        spot was found at all, the mins are ``inf`` and the maxes ``-inf``.
    """
    global_min_x = float('inf')
    global_min_y = float('inf')
    global_max_x = float('-inf')
    global_max_y = float('-inf')

    for folder_path in folder_paths.values():
        if not os.path.exists(folder_path):
            continue
        for file_name in os.listdir(folder_path):
            if not file_name.endswith(".h5ad"):
                continue
            adata = sc.read_h5ad(os.path.join(folder_path, file_name))
            xy = np.asarray(adata.obsm["spatial"])[:, :2]
            if xy.shape[0] == 0:
                continue

            min_x, min_y = xy.min(axis=0)
            max_x, max_y = xy.max(axis=0)

            global_min_x = min(global_min_x, min_x)
            global_min_y = min(global_min_y, min_y)
            global_max_x = max(global_max_x, max_x)
            global_max_y = max(global_max_y, max_y)

    return {
        'min_x': global_min_x,
        'min_y': global_min_y,
        'max_x': global_max_x,
        'max_y': global_max_y,
    }

def _figure_size(global_bounds, fig_width=10):
    """Return ``(width, height)`` in inches matching the bounds' aspect ratio.

    Falls back to a square figure when the box has zero, negative or
    non-finite extent. Otherwise that case would divide by zero or yield a
    NaN/zero figure height.
    """
    x_range = global_bounds['max_x'] - global_bounds['min_x']
    y_range = global_bounds['max_y'] - global_bounds['min_y']
    if not (np.isfinite(x_range) and np.isfinite(y_range)) or x_range <= 0 or y_range <= 0:
        return fig_width, fig_width
    aspect_ratio = x_range / y_range
    return fig_width, fig_width / aspect_ratio


def _plot_slice(adata, save_path, global_bounds, color_map, all_cell_types, figsize):
    """Scatter one slice's spots colored by ``obs['Ground_Truth']`` and save to ``save_path``."""
    coords = np.asarray(adata.obsm["spatial"])
    point_colors = [color_map[ct] for ct in adata.obs["Ground_Truth"].astype(str)]
    min_x, max_x = global_bounds['min_x'], global_bounds['max_x']
    min_y, max_y = global_bounds['min_y'], global_bounds['max_y']

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Shared limits for every slice so plots of different methods align.
        ax.set_xlim(min_x, max_x)
        ax.set_ylim(min_y, max_y)
        ax.set_aspect('equal')
        # y grows downward (image-style orientation).
        ax.invert_yaxis()

        # Fully transparent rectangle along the global bounds, kept from the
        # original. NOTE(review): presumably meant to pin the saved extent under
        # bbox_inches='tight'; likely redundant with the explicit limits above.
        for xs, ys in (
            ([min_x, max_x], [min_y, min_y]),
            ([min_x, max_x], [max_y, max_y]),
            ([min_x, min_x], [min_y, max_y]),
            ([max_x, max_x], [min_y, max_y]),
        ):
            ax.plot(xs, ys, '-', color=(0, 0, 0, 0))

        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=point_colors,
            s=10,
            marker='o',
            edgecolors='none',
            linewidths=0,
            alpha=1,
        )

        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

        # Legend lists every known cell type (not only those in this slice)
        # so all figures share an identical legend.
        handles = [
            plt.Line2D(
                [0], [0],
                marker='o',
                color='w',
                markerfacecolor=color_map[ct],
                markersize=10,
                linestyle='',
                markeredgecolor='none',
            )
            for ct in all_cell_types
        ]
        ax.legend(handles, all_cell_types, loc='center left', bbox_to_anchor=(1, 0.5))
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)


def plot_spatial_data_with_bounds(dataset, folder_paths, global_bounds):
    """Save one spatial scatter plot (PDF) per ``.h5ad`` slice, colored by ground truth.

    All slices share the same axis limits, figure size and cell-type -> color
    mapping, so plots from different methods are directly comparable.

    Args:
        dataset: dataset name; used as the prefix of every output file name.
        folder_paths: mapping of method name -> folder of ``.h5ad`` files.
            Folders that do not exist are reported and skipped.
        global_bounds: dict with ``min_x``, ``min_y``, ``max_x``, ``max_y``
            (the shape returned by ``get_global_bounds``).

    Output files are written to the current working directory, named by
    ``get_save_path``.
    """
    # Per-spot labels are looked up via .astype(str), so the color-map keys
    # must be strings as well. dict.fromkeys dedupes while keeping the
    # incoming (sorted) order.
    all_cell_types = list(dict.fromkeys(str(ct) for ct in get_all_cell_types(folder_paths)))
    color_map = build_color_map_for_cts(all_cell_types)
    print(f"All cell types: {all_cell_types}")
    # Identical for every slice, so computed once.
    figsize = _figure_size(global_bounds)

    for method, folder_path in folder_paths.items():
        if not os.path.exists(folder_path):
            print(f"Folder doesn't exist: {folder_path}")
            continue

        print(f"\nMethod: {method}")
        for file_name in sorted(os.listdir(folder_path)):
            if not file_name.endswith(".h5ad"):
                continue
            file_path = os.path.join(folder_path, file_name)
            print(f"Processing:{file_path}")

            adata = sc.read_h5ad(file_path)
            slice_number = get_slice_number(file_name, method)
            save_path = get_save_path(dataset, slice_number, method)
            _plot_slice(adata, save_path, global_bounds, color_map, all_cell_types, figsize)

            print(f"Saved:{save_path}")




def plot_spatial(methods, dataset):
    """Plot every slice of every method on one shared coordinate frame.

    The common bounding box is computed first over all folders in
    ``methods``, then passed to the per-slice plotter.

    Args:
        methods: mapping of method name -> folder of ``.h5ad`` files.
        dataset: dataset name used as the output file-name prefix.
    """
    plot_spatial_data_with_bounds(dataset, methods, get_global_bounds(methods))

# Input folders per registration method. Paths are relative to the current
# working directory. They are derived from ``dataset`` and ``donor``, so
# changing either updates the inputs and the output file names together.
dataset = "D64"
donor = "donor4"
_registration_root = f"../../../result/registration/{dataset}"
methods = {
    "PASTE":        f"{_registration_root}/paste/{donor}/paste_pairwise_aligned_slices",
    "PASTE_CENTER": f"{_registration_root}/paste/{donor}/paste_center_aligned_slices",
    "PASTE2":       f"{_registration_root}/PASTE2/{donor}",
    "SPIRAL":       f"{_registration_root}/SPIRAL/{donor}/spiral_aligned_slices",
    "GPSA":         f"{_registration_root}/GPSA/{donor}/GPSA_aligned_slices",
    "original":     f"../../../data/dataset_final/{dataset}/processed_new/{donor}",
}

plot_spatial(methods, dataset)