# In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import scipy.io
import matplotlib.pyplot as plt
import os
import sys
from STINR import model
import warnings
warnings.filterwarnings("ignore")

import torch
import random

# Fractions of spots whose in-plane coordinates are shuffled per slice.
# 0.0 is the unshuffled baseline (control group).
shuffle_ratios_to_test = [0.0, 0.1, 0.25, 0.5, 1.0]
seed = 1
n_clusters = 7  # number of Gaussian-mixture components fitted on the latent space
slice_idx = [151673, 151674, 151675, 151676]


def ratio_to_pct(ratio):
    """Return a shuffle ratio as an integer percentage (used in logs and folder names).

    Rounds instead of truncating: ``0.29 * 100 == 28.999999999999996``, so
    ``int()`` alone would report 28.
    """
    return int(round(ratio * 100))


def shuffle_slice_coords(coords, shuffle_ratio, seed):
    """Return a copy of ``coords`` with part of each slice's x/y positions permuted.

    ``coords`` is used as ``coords[:, :2]`` (in-plane position) and
    ``coords[:, 2]`` (slice id). Within every slice,
    ``int(n_spots * shuffle_ratio)`` spots are drawn without replacement,
    and their first two columns are permuted among themselves. The slice
    column is never modified. A permutation may map a spot onto itself,
    so the fraction of spots that actually move can be slightly below
    ``shuffle_ratio``.

    NumPy's *global* RNG is reseeded with ``seed`` and then drawn from, so
    the result is reproducible for a given seed.

    Raises:
        ValueError: if ``shuffle_ratio`` is outside [0, 1].
    """
    if not 0.0 <= shuffle_ratio <= 1.0:
        raise ValueError(f"shuffle_ratio must be in [0, 1], got {shuffle_ratio}")
    out = np.array(coords, copy=True)
    np.random.seed(seed)
    for slice_id in np.unique(out[:, 2]):
        indices = np.flatnonzero(out[:, 2] == slice_id)
        n_total = len(indices)
        n_shuffle = int(n_total * shuffle_ratio)
        print(f"  Slice {int(slice_id)}: shuffling {n_shuffle}/{n_total} spots")
        chosen = np.random.choice(indices, size=n_shuffle, replace=False)
        # The right-hand fancy index yields a copy, so this is a true permutation.
        out[chosen, :2] = out[np.random.permutation(chosen), :2]
    return out


def slice_row_positions(adata_st, adata_slices):
    """Return, for each slice, the row positions of its spots within ``adata_st``.

    Spots are matched by ``obs_names`` when ``adata_st`` has unique names and
    every slice name is found. Otherwise this falls back to assuming that
    ``adata_st`` is the slices concatenated in order. In that case the total
    spot count must match exactly; a mismatch would silently misassign
    coordinates.

    Raises:
        ValueError: if names cannot be matched and the spot counts disagree.
    """
    names = adata_st.obs_names
    if names.is_unique:
        rows = [names.get_indexer(s.obs_names) for s in adata_slices]
        if all((r >= 0).all() for r in rows):
            return rows
    sizes = [s.n_obs for s in adata_slices]
    if sum(sizes) != adata_st.n_obs:
        raise ValueError(
            "Cannot map slice spots onto adata_st: obs_names do not match and "
            f"slice sizes sum to {sum(sizes)}, but adata_st has {adata_st.n_obs} spots"
        )
    bounds = np.cumsum([0] + sizes)
    return [np.arange(lo, hi) for lo, hi in zip(bounds[:-1], bounds[1:])]


def propagate_coords(coords, adata_st, adata_slices):
    """Copy rows of ``coords`` (aligned with ``adata_st``) into each slice in place.

    Writes ``obsm['3D_coor']``. It also overwrites ``obsm['spatial']`` with
    columns 0–1 when that key exists, and overwrites ``obs['array_row']`` /
    ``obs['array_col']`` with columns 1 / 0 when both exist.

    NOTE(review): this assumes 3D_coor x/y share units with ``spatial``
    (``sc.pl.spatial`` scales ``spatial`` by the image scale factor), and
    that column 0 is the column axis. Confirm both.
    """
    for adata_slice, rows in zip(adata_slices, slice_row_positions(adata_st, adata_slices)):
        slice_coords = coords[rows]  # fancy indexing -> independent copy
        adata_slice.obsm['3D_coor'] = slice_coords
        if 'spatial' in adata_slice.obsm:
            adata_slice.obsm['spatial'] = slice_coords[:, :2].copy()
        if 'array_row' in adata_slice.obs and 'array_col' in adata_slice.obs:
            adata_slice.obs['array_row'] = slice_coords[:, 1]
            adata_slice.obs['array_col'] = slice_coords[:, 0]


# Guarded so the helpers above can be imported without running the
# experiment; __name__ is "__main__" both as a script and in a notebook kernel.
if __name__ == "__main__":
    from sklearn.mixture import GaussianMixture

    # Only effective before CUDA is first initialised, so set it once up front.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Identity relabelling of GMM components; edit to reorder clusters.
    order = list(range(n_clusters))

    for shuffle_ratio in shuffle_ratios_to_test:
        pct = ratio_to_pct(shuffle_ratio)
        print(f"\n{'='*60}")
        print(f"Running experiment with {pct}% shuffling")
        print(f"{'='*60}\n")

        # Reset every RNG so each ratio starts from the same state.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

        # Re-read each iteration: the shuffle below mutates these objects in place.
        adata_st = ad.read_h5ad('adata_st_DLPFC.h5ad')
        adata_st_list_raw = [
            ad.read_h5ad(f'adata_st_list_raw{i}.h5ad') for i in range(len(slice_idx))
        ]

        # ====== 2. Shuffle coordinates (apply to both adata_st and adata_st_list_raw) ======
        if shuffle_ratio > 0:
            print(f"Shuffling {pct}% of coordinates...")
            coords = shuffle_slice_coords(adata_st.obsm['3D_coor'], shuffle_ratio, seed)
            adata_st.obsm['3D_coor'] = coords
            # The raw slices are used by eval/plotting, so they must see the same shuffle.
            propagate_coords(coords, adata_st, adata_st_list_raw)
            print("Coordinate shuffling completed!\n")
        else:
            print("No coordinate shuffling (control group)\n")

        # ====== 3. Train model ======
        adata_basis = ad.read_h5ad('adata_basis_DLPFC.h5ad')
        st_model = model.Model(adata_st, adata_basis)
        st_model.train()

        save_path = f"results_shuffle_{pct}pct"
        os.makedirs(save_path, exist_ok=True)

        result = st_model.eval(adata_st_list_raw, save=False, output_path=save_path)

        # Cluster the learned latent space. random_state=1234 gives the same
        # stream as seeding the global RNG with 1234, without touching it.
        # NOTE(review): 10e-4 == 1e-3; confirm 1e-4 was not intended.
        gm = GaussianMixture(n_components=n_clusters, covariance_type='tied',
                             reg_covar=10e-4, init_params='kmeans',
                             random_state=1234)
        labels = gm.fit_predict(st_model.adata_st.obsm['latent'])
        st_model.adata_st.obs["GM"] = labels
        st_model.adata_st.obs["GM"].to_csv(os.path.join(save_path, "clustering_result.csv"))
        st_model.adata_st.obs["Cluster"] = [order[label] for label in st_model.adata_st.obs["GM"].values]

        # Transfer cluster labels to the per-slice results by spot name.
        for res in result:
            res.obs["GM"] = st_model.adata_st.obs.loc[res.obs_names, "GM"].values
            res.obs["Cluster"] = st_model.adata_st.obs.loc[res.obs_names, "Cluster"].values

        cell_types = list(adata_basis.obs.index)
        for sid, res in zip(slice_idx, result):
            print("Slice %d cell-type deconvolution result:" % sid)
            print(cell_types)
            sc.pl.spatial(res, img_key="lowres", color=cell_types, size=1.)

        print(f"\nResults saved to: {save_path}")
