# LOADING DATA

In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
# import squidpy as sq
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm
from anndata import AnnData
import scipy.sparse as sp
import anndata as ad
import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

sample_ids = [f"{i:03d}" for i in range(1, 19)]
sample_ids = ['001']

# List to store individual adatas
adatas = []
adata_dict = {}
spatial_dict = {}


# Loop through each sample
for sid in tqdm(sample_ids):
    sample_id = f"IMMUNEX{sid}"
    path = f"/scratch/Projects/IMMUNEX/segmentation/bin2cell/bin2cell_output/{sample_id}/adata_processed.h5ad"
    
    if os.path.exists(path):
        adata_sample = sc.read(path)
        adata_sample.var_names_make_unique()
        adata_sample.raw = adata_sample
        adata_sample.uns['sample_id'] = sample_id
        adata_sample.obs['sample_id'] = sample_id  # Add sample ID to obs for tracking
        display(adata_sample.obs.head())

        # Count total UMIs per bin
        adata_sample.obs['n_counts'] = np.array(adata_sample.X.sum(axis=1)).flatten()
        # Count number of genes per bin
        adata_sample.obs['n_genes'] = np.array((adata_sample.X > 0).sum(axis=1)).flatten()

        
        # Save .uns['spatial'] if it exists
        if 'spatial' in adata_sample.uns:
            spatial_dict[sample_id] = adata_sample.uns['spatial']
            
        adata_dict[sample_id] = adata_sample

        # adatas.append(adata_sample)
    else:
        print(f"File not found for sample {sample_id}: {path}")

    
# Concatenate all samples into one AnnData object
# adata = ad.concat(adatas, label='batch', keys=sample_ids, index_unique='-')
adata = ad.concat(adata_dict, label='sample_id', merge='unique')
adata.uns['spatial'] = spatial_dict

for sid in tqdm(adata.uns['spatial']):
    nested_key = list(adata.uns['spatial'][sid].keys())[0]  # e.g., 'Visium_NSCLC_IMMUNEX018'
    adata.uns['spatial'][sid] = adata.uns['spatial'][sid][nested_key]

adata.obs.sample(5)

  0%|                                                                                                                                                                                                                          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,in_tissue,array_row,array_col,n_counts,destripe_factor,n_counts_adjusted,labels_he,labels_he_expanded,labels_gex,labels_joint,labels_joint_source,sample_id
s_002um_02587_02503-1,1,2587,2503,16.0,0.369561,12.195529,290012,290012,42453,290012,primary,IMMUNEX001
s_002um_01674_00710-1,1,1674,710,5.0,0.368149,12.148933,0,221472,35534,221472,primary,IMMUNEX001
s_002um_02498_02808-1,1,2498,2808,6.0,0.321744,10.617562,269059,269059,0,269059,primary,IMMUNEX001
s_002um_00952_00136-1,1,952,136,9.0,0.706108,23.301573,0,0,17982,0,none,IMMUNEX001
s_002um_00564_00910-1,1,564,910,1.0,0.091521,3.02018,0,0,0,0,none,IMMUNEX001


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:19<00:00, 19.41s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18315.74it/s]


Unnamed: 0,in_tissue,array_row,array_col,n_counts,destripe_factor,n_counts_adjusted,labels_he,labels_he_expanded,labels_gex,labels_joint,labels_joint_source,sample_id,n_genes
s_002um_00303_02786-1,1,303,2786,21.576923,0.653846,21.576923,0,17873,6014,17873,primary,IMMUNEX001,9
s_002um_01259_00369-1,1,1259,369,25.610031,0.776062,25.610031,0,220570,38743,220570,primary,IMMUNEX001,26
s_002um_02580_02931-1,1,2580,2931,5.322421,0.161285,5.322421,0,310649,0,310649,primary,IMMUNEX001,4
s_002um_00598_02061-1,1,598,2061,11.974747,0.362871,11.974747,0,0,0,0,none,IMMUNEX001,6
s_002um_00531_03162-1,1,531,3162,8.550374,0.259102,8.550374,0,0,0,0,none,IMMUNEX001,4


In [2]:
import pandas as pd

# Load clinical data
clinical_df = pd.read_csv("/home/mounim/rawdata/IMMUNEX/data/VisiumHD_18_2024_NSCLC.csv")

# Preview
clinical_df['sample_id'] = clinical_df['HE_image_name'].str[:10]

adata.obs = adata.obs.merge(clinical_df, on='sample_id', how='left')

display(adata.obs.head())
display(adata.var.head())

Unnamed: 0,in_tissue,array_row,array_col,n_counts,destripe_factor,n_counts_adjusted,labels_he,labels_he_expanded,labels_gex,labels_joint,...,Average_fragment_size,Concentratio_pg_per_µL,Concentration finale (ng/µL),Final_concentration_nM,Dual _Index,Volume_available_µL,Volume_send_µL,Sample_code,Position_on_the_plaque,Sea_Dual_Index_TS_Set_A
0,1,2587,2503,12.195529,0.369561,12.195529,290012,290012,42453,290012,...,252,46.6,2.33,14.0,G3,around_24,22,IMMUNEX001,A1,"SI-TS-G3,CCAGACACGG,AGAAAGCGGT,ACCGCTTTCT"
1,1,1674,710,12.148933,0.368149,12.148933,0,221472,35534,221472,...,252,46.6,2.33,14.0,G3,around_24,22,IMMUNEX001,A1,"SI-TS-G3,CCAGACACGG,AGAAAGCGGT,ACCGCTTTCT"
2,1,2498,2808,10.617562,0.321744,10.617562,269059,269059,0,269059,...,252,46.6,2.33,14.0,G3,around_24,22,IMMUNEX001,A1,"SI-TS-G3,CCAGACACGG,AGAAAGCGGT,ACCGCTTTCT"
3,1,952,136,23.301573,0.706108,23.301573,0,0,17982,0,...,252,46.6,2.33,14.0,G3,around_24,22,IMMUNEX001,A1,"SI-TS-G3,CCAGACACGG,AGAAAGCGGT,ACCGCTTTCT"
4,1,564,910,3.02018,0.091521,3.02018,0,0,0,0,...,252,46.6,2.33,14.0,G3,around_24,22,IMMUNEX001,A1,"SI-TS-G3,CCAGACACGG,AGAAAGCGGT,ACCGCTTTCT"


Unnamed: 0,gene_ids,feature_types,genome,n_cells
OR4F5,ENSG00000186092,Gene Expression,GRCh38,18
SAMD11,ENSG00000187634,Gene Expression,GRCh38,328
NOC2L,ENSG00000188976,Gene Expression,GRCh38,1941
KLHL17,ENSG00000187961,Gene Expression,GRCh38,1126
PLEKHN1,ENSG00000187583,Gene Expression,GRCh38,984


In [3]:
# get a sample for testing
# 1. Filter to crop region (spatial coordinates)
# Define crop bounds
x_min, x_max = 100, 300
y_min, y_max = 100, 300

# Assuming 'array_row' and 'array_col' are in `adata.obs`
crop_mask = (
    (adata.obs['array_row'] >= x_min) & (adata.obs['array_row'] <= x_max) &
    (adata.obs['array_col'] >= y_min) & (adata.obs['array_col'] <= y_max)
)

adata_crop = adata[crop_mask].copy()

print(adata.shape)
print(adata_crop.shape)

(9613112, 18322)
(36966, 18322)




In [4]:
x_min, x_max = 1000, 1050
y_min, y_max = 1000, 1050

# Assuming 'array_row' and 'array_col' are in `adata.obs`
crop_mask = (
    (adata.obs['array_row'] >= x_min) & (adata.obs['array_row'] <= x_max) &
    (adata.obs['array_col'] >= y_min) & (adata.obs['array_col'] <= y_max)
)

adata_crop = adata[crop_mask].copy()
adata_crop



AnnData object with n_obs × n_vars = 1889 × 18322
    obs: 'in_tissue', 'array_row', 'array_col', 'n_counts', 'destripe_factor', 'n_counts_adjusted', 'labels_he', 'labels_he_expanded', 'labels_gex', 'labels_joint', 'labels_joint_source', 'sample_id', 'n_genes', 'Manip_Visium', 'DV200', 'HE_image_name', 'Cytasimage_path', 'TLS_status', 'area', 'slide', 'Average_fragment_size', 'Concentratio_pg_per_µL', 'Concentration finale  (ng/µL)', 'Final_concentration_nM', 'Dual _Index', 'Volume_available_µL', 'Volume_send_µL', 'Sample_code', 'Position_on_the_plaque', 'Sea_Dual_Index_TS_Set_A'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells'
    uns: 'spatial'
    obsm: 'spatial', 'spatial_cropped_150_buffer'

In [5]:
# get bins and segmentation status 
adata.obs[['labels_he', 'labels_he_expanded', 'labels_gex']] = adata.obs[['labels_he', 'labels_he_expanded', 'labels_gex']].fillna(0)

# Add isSegmented column
adata.obs['isSegmented'] = (
    (adata.obs['labels_he'] != 0) |
    (adata.obs['labels_he_expanded'] != 0) |
    (adata.obs['labels_gex'] != 0)
)

# Add isNuclei column
adata.obs['isNuclei'] = adata.obs['labels_he'] != 0

# Add isGExCell column
adata.obs['isGExCell'] = (adata.obs['labels_gex'] != 0) & (adata.obs['labels_he_expanded'] == 0)

bins_annotation = adata.obs[['array_row','array_col','isSegmented', 'isNuclei', 'isGExCell']]
bins_annotation.head()

Unnamed: 0,array_row,array_col,isSegmented,isNuclei,isGExCell
0,2587,2503,True,True,False
1,1674,710,True,False,False
2,2498,2808,True,True,False
3,952,136,True,False,True
4,564,910,False,False,False


In [6]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler

target = 'labels_he'  # or 'isSegmented' or 'isGExCell'

sc.pp.highly_variable_genes(adata, n_top_genes=1000)
adata_hvg = adata[:, adata.var['highly_variable']]
X = adata_hvg.X  # sparse matrix (safe to keep sparse)
y = adata.obs[target].astype(int)

# Keep only classes with ≥2 samples
value_counts = y.value_counts()
valid_classes = value_counts[value_counts >= 2].index
mask = y.isin(valid_classes)

# Convert mask to indices for sparse matrix slicing
valid_indices = np.where(mask)[0]

# Subset both X and y safely
X = X[valid_indices]
y = y.iloc[valid_indices]


# Now safe to stratify
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)





In [None]:
from lightgbm import LGBMClassifier

clf = LGBMClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)


In [None]:

# # RandomForest can accept sparse matrix directly in recent versions
# clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=222, verbose =1)
# clf.fit(X_train, y_train)


In [None]:
from lightgbm import LGBMClassifier
from sklearn.metrics import accuracy_score

# Predict on test set
y_pred = clf.predict(X_test)

# Compute accuracy
acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {acc:.4f}")


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_curve, auc

# Predict probabilities
y_pred_proba = clf.predict_proba(X_test)[:, 1]  # Probability for class 1
y_pred = clf.predict(X_test)

# Accuracy
acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {acc:.4f}")

# ROC curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Plot
plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.2f})", linewidth=2)
plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.show()
