In [None]:
import scanpy as sc
import numpy as np
import gc

# Load the full processed dataset. Note: this reads the whole file into memory
# (it does NOT use backed mode).
orig_path = "/scratch/saigum/PerturbationPredictionUsingGeneNetworks/vcc_data/arc/perturb_processed.h5ad"
adata = sc.read_h5ad(orig_path)

# Split fractions: share of control cells / of perturbation conditions sent to train.
CTRL_TRAIN_FRAC = 0.5
PERT_TRAIN_FRAC = 0.6

# Per-cell condition labels as a plain array; control cells are labelled 'ctrl'.
conditions = adata.obs['condition'].to_numpy()
ctrl_mask = conditions == 'ctrl'
pert_mask = ~ctrl_mask

# Unique perturbation names. np.unique returns them sorted, so the split below
# is deterministic (alphabetical order), not random.
pert_list = np.unique(conditions[pert_mask]).tolist()
print(f"Number of perturbations: {len(pert_list)}")

# Calculate split indices
n_ctrl = np.sum(ctrl_mask)
ctrl_split_idx = int(n_ctrl * CTRL_TRAIN_FRAC)

n_pert = len(pert_list)
pert_split_idx = int(n_pert * PERT_TRAIN_FRAC)

train_list = pert_list[:pert_split_idx]
val_list = pert_list[pert_split_idx:]
print(f"Train perturbations: {len(train_list)}, Val perturbations: {len(val_list)}")

# Control cells: the first part (in file order) goes to train, the rest to val.
ctrl_indices = np.where(ctrl_mask)[0]
train_ctrl_mask = np.zeros(len(conditions), dtype=bool)
val_ctrl_mask = np.zeros(len(conditions), dtype=bool)
train_ctrl_mask[ctrl_indices[:ctrl_split_idx]] = True
val_ctrl_mask[ctrl_indices[ctrl_split_idx:]] = True

# Perturbed cells follow the split assignment of their condition.
train_pert_mask = np.isin(conditions, train_list)
val_pert_mask = np.isin(conditions, val_list)

# Final train and validation masks
train_mask = train_ctrl_mask | train_pert_mask
val_mask = val_ctrl_mask | val_pert_mask

# Invariants: the two splits share no cell and together cover every cell.
assert not np.any(train_mask & val_mask), "train/val splits overlap"
assert np.all(train_mask | val_mask), "some cells were assigned to neither split"


In [None]:
# Select the training cells (train_mask from the first cell) from the full dataset.
print("Processing training data...")
adata_train = adata[train_mask]
train_shape = adata_train.shape
print(f"Shape of adata train: {train_shape}")


In [None]:

# Persist the training split, then release it before building the validation split.
train_out_path = "/scratch/saigum/PerturbationPredictionUsingGeneNetworks/vcc_train/arc/perturb_processed.h5ad"
adata_train.write_h5ad(train_out_path)
del adata_train
gc.collect()

In [None]:

# Select the validation cells (val_mask from the first cell), save them, and
# then free both the split and the full dataset.
print("Processing validation data...")
adata_val = adata[val_mask]
val_shape = adata_val.shape
print(f"Shape of adata val: {val_shape}")
val_out_path = "/scratch/saigum/PerturbationPredictionUsingGeneNetworks/vcc_val/arc/perturb_processed.h5ad"
adata_val.write_h5ad(val_out_path)
# `adata` is deleted here too: any later cell that needs it must reload it.
del adata_val, adata
gc.collect()

print("Train and validation datasets created successfully.")

In [None]:
# One-cell version of the train/val split: reload, split, and write.
# The previous cell deletes `adata`, so it is reloaded here from `orig_path`
# (defined in the first cell).
adata = sc.read_h5ad(orig_path)
conditions = adata.obs['condition'].to_numpy()
is_ctrl = conditions == 'ctrl'

## splitting ctrl 50-50 train-val (first half in file order -> train)
ctrl_idx = np.flatnonzero(is_ctrl)
n_ctrl_train = int(len(ctrl_idx) * 0.5)

# Sorted unique perturbation names (np.unique), matching the first cell so that
# both code paths write the same split to the same output files.
pert_list = np.unique(conditions[~is_ctrl]).tolist()
print(len(pert_list))
#splitting 60 -40 train -val
n_pert_train = int(len(pert_list) * 0.6)
train_list = pert_list[:n_pert_train]
val_list = pert_list[n_pert_train:]
print(len(train_list), len(val_list))

## adding ctrl to train and val via boolean masks. AnnData objects have no
## `.concat` method, and row-subsetting (unlike concatenation with default
## options) keeps `adata.uns`, which downstream loaders may rely on.
train_mask = np.isin(conditions, train_list)
train_mask[ctrl_idx[:n_ctrl_train]] = True
val_mask = np.isin(conditions, val_list)
val_mask[ctrl_idx[n_ctrl_train:]] = True
assert not np.any(train_mask & val_mask), "train/val splits overlap"

adata_train = adata[train_mask]
adata_val = adata[val_mask]

print(f"Shape of adata train: {adata_train.shape}")
print(f"Shape of adata val: {adata_val.shape}")
adata_train.write_h5ad("/scratch/saigum/PerturbationPredictionUsingGeneNetworks/vcc_train/arc/perturb_processed.h5ad")
adata_val.write_h5ad("/scratch/saigum/PerturbationPredictionUsingGeneNetworks/vcc_val/arc/perturb_processed.h5ad")
print("Train and validation datasets created successfully.")

In [None]:
from gears import PertData
# Load a GEARS PertData object; data_path presumably must contain a
# perturb_processed.h5ad file — confirm against the GEARS version in use.
# NOTE(review): the validation split above is written to
# .../vcc_val/arc/perturb_processed.h5ad, but this loads .../vcc_data/arc_val/.
# Confirm this is the intended directory and not a stale path.
# NOTE(review): GEARS appears to derive the dataset name from the last path
# component of data_path; the trailing '/' may make that name empty — verify.
valpert_data = PertData('./vcc_data') # specific saved folder
valpert_data.load(data_path="/scratch/saigum/PerturbationPredictionUsingGeneNetworks/vcc_data/arc_val/")
