# scBasset on the PBMC dataset described on the scBasset GitHub page

In [None]:
import os
import numpy as np
import h5py
import gc
import psutil
import anndata
import pickle
from scipy import sparse
import tensorflow as tf
from datetime import datetime

# Check whether a GPU is visible to TensorFlow
tf.config.list_physical_devices('GPU')

In [None]:
# a generator to read examples from h5 file
# create a tf dataset
class generator:
    """Callable that streams (one-hot sequence, label vector) examples.

    Intended to be passed to ``tf.data.Dataset.from_generator``: each call
    opens the h5 file, walks dataset ``'X'`` row by row and yields one
    example per row.

    Args:
        file: Path to the h5 file holding the sequences in dataset ``'X'``.
            Row ``i`` of ``X`` supplies one column index (0-3) per position.
        m: CSR matrix with sequences as rows and cells as columns; row ``i``
            holds the labels for sequence ``i`` of ``file``.
        seq_len: Number of positions per sequence. Defaults to 1344, the
            length that was previously hard-coded.
    """

    def __init__(self, file, m, seq_len=1344):
        self.file = file  # h5 file for sequence
        self.m = m  # csr matrix, rows as seqs, cols are cells
        self.n_cells = m.shape[1]
        self.seq_len = seq_len
        # Reused for every example: one unit entry per sequence position.
        self.ones = np.ones(seq_len)
        self.rows = np.arange(seq_len)

    def _one_hot(self, x):
        """Return the (seq_len, 4) int8 one-hot matrix for the index vector x."""
        return sparse.coo_matrix(
            (self.ones, (self.rows, x)),
            shape=(self.seq_len, 4),
            dtype='int8',
        ).toarray()

    def _labels(self, i):
        """Return the dense int8 0/1 vector over cells for row i of the matrix."""
        # Column indices stored for CSR row i are the cells marked for it.
        start, stop = self.m.indptr[i], self.m.indptr[i + 1]
        y_tf = np.zeros(self.n_cells, dtype='int8')
        y_tf[self.m.indices[start:stop]] = 1
        return y_tf

    def __call__(self):
        """Yield ``(x_tf, y_tf)`` for every row of dataset ``'X'`` in the h5 file.

        Raises:
            KeyError: If the h5 file has no dataset named ``'X'``.
        """
        # The file stays open only while the generator is being consumed.
        with h5py.File(self.file, 'r') as hf:
            X = hf['X']
            for i in range(X.shape[0]):
                yield self._one_hot(X[i]), self._labels(i)

def print_memory():
    """Print the resident set size of the current process in GB (1e9 bytes)."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print('cpu memory used: %.1fGB.' % (rss_bytes / 1e9))

In [None]:
# Directory holding the preprocessed inputs, and the directory for model outputs
input_dir = '/cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/pbmc-granulocyte-sorted-3k_10x-Multiome/processed'
output_dir = '/cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/pbmc-granulocyte-sorted-3k_10x-Multiome/model'

# Every input file sits directly under input_dir
split_file, train_file, val_file, test_file, ad_file = (
    os.path.join(input_dir, fname)
    for fname in (
        'splits.h5',
        'train_seqs.h5',
        'val_seqs.h5',
        'test_seqs.h5',
        'atac_ad.h5ad',
    )
)

# Load data

In [None]:
# Grab the sparse matrix from the anndata object
adata = anndata.read_h5ad(ad_file)
# First axis of the AnnData object; used below as the number of cells
n_cells = adata.shape[0]
# Transpose X and store it as CSR so that each row can be sliced per sequence
# (the generator reads m.indptr / m.indices row by row).
# NOTE(review): .tocoo() assumes adata.X is a scipy sparse matrix — confirm.
m = adata.X.tocoo().transpose().tocsr()

In [None]:
# Report memory while the AnnData object is still loaded
print_memory()     # memory usage
# adata is not used again below; drop the reference and run the garbage collector
del adata
gc.collect()

In [None]:
# Read the train / validation row indices from the split file into memory
with h5py.File(split_file, 'r') as splits:
    train_ids, val_ids = (splits[key][:] for key in ('train_ids', 'val_ids'))

In [None]:
# Select the rows of m that belong to the train and validation splits
m_train, m_val = m[train_ids, :], m[val_ids, :]

# The full matrix is no longer needed once the two subsets exist
del m
gc.collect()

# Displayed as the cell output
m_train.shape, m_val.shape

In [None]:
# Create the tf datasets
# Training pipeline: stream examples from the generator, shuffle with a
# 2000-example buffer (reshuffled every epoch), batch by 128, and prefetch.
train_ds = tf.data.Dataset.from_generator(
    generator(train_file, m_train),
    output_signature=(
        tf.TensorSpec(shape=(1344, 4), dtype=tf.int8),
        # (n_cells,) is a real 1-tuple; the previous `(n_cells)` was a bare int
        tf.TensorSpec(shape=(n_cells,), dtype=tf.int8),
    ),
).shuffle(2000, reshuffle_each_iteration=True).batch(128).prefetch(tf.data.AUTOTUNE)

In [None]:
# Validation pipeline: same generator-backed dataset as for training, but
# without shuffling; batch by 128 and prefetch.
val_ds = tf.data.Dataset.from_generator(
    generator(val_file, m_val),
    output_signature=(
        tf.TensorSpec(shape=(1344, 4), dtype=tf.int8),
        # (n_cells,) is a real 1-tuple; the previous `(n_cells)` was a bare int
        tf.TensorSpec(shape=(n_cells,), dtype=tf.int8),
    ),
).batch(128).prefetch(tf.data.AUTOTUNE)

In [None]:
# Get an example batch from training dataset
# Pull a single batch and print the shapes of the inputs and labels as a sanity check
for x, y in train_ds.take(1):
    print(x.shape, y.shape)

# Build an scBasset model with their code

In [None]:
from scbasset.utils import make_model

In [None]:
# Build the model with scBasset's make_model helper, sized to the number of cells.
# NOTE(review): 32 is presumably the bottleneck (cell-embedding) size — confirm against scbasset.utils.make_model.
model = make_model(32, n_cells)

In [None]:
# Set up the loss, optimizer, and metrics, then compile the model
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.95, beta_2=0.9995)
model.compile(
    loss=loss_fn,
    optimizer=optimizer,
    # Name the metrics explicitly. Keras otherwise auto-numbers AUC metrics
    # ('auc', 'auc_1', then 'auc_2', 'auc_3' when this cell is re-run in the
    # same session), and a callback monitoring 'auc' would silently stop
    # matching. 'auc' / 'auc_1' are the names a fresh session assigns, so the
    # keys in the training history are unchanged.
    metrics=[
        tf.keras.metrics.AUC(curve='ROC', multi_label=True, name='auc'),
        tf.keras.metrics.AUC(curve='PR', multi_label=True, name='auc_1'),
    ],
)


In [None]:
# Path where the checkpoint callback writes the best model weights
filepath = os.path.join(output_dir, 'best_model.h5')

# Timestamped TensorBoard log directory under output_dir
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logs = os.path.join(output_dir, "logs" + run_stamp)

In [None]:
# Write TensorBoard events to the timestamped `logs` directory defined above,
# so each run gets its own folder instead of mixing event files into output_dir.
tensorboard_callback = tf.keras.callbacks.TensorBoard(logs)

# Keep only the weights of the epoch with the highest 'auc'. The monitored name
# has no 'val_' prefix, i.e. it is the metric computed on the training data.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='auc',
    mode='max',
)

# Stop once 'auc' has not improved by at least 1e-6 for 50 consecutive epochs.
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='auc',
    min_delta=1e-6,
    mode='max',
    patience=50,
    verbose=1,
)
callbacks = [tensorboard_callback, checkpoint_callback, earlystopping_callback]

In [39]:
# Train for up to 1000 epochs; the early-stopping callback may end the run sooner
history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=1000,
    callbacks=callbacks,
)

Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000


2023-09-21 10:34:28.844961: W tensorflow/core/framework/op_kernel.cc:1733] UNKNOWN: KeyError: "Unable to open object (object 'X' doesn't exist)"
Traceback (most recent call last):

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1004, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/tmp/ipykernel_4095444/344002438.py", line 13, in __call__
    X = hf['X']

  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper

  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper

  File "/cellar/

UnknownError: Graph execution error:

2 root error(s) found.
  (0) UNKNOWN:  KeyError: "Unable to open object (object 'X' doesn't exist)"
Traceback (most recent call last):

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1004, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/tmp/ipykernel_4095444/344002438.py", line 13, in __call__
    X = hf['X']

  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper

  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/h5py/_hl/group.py", line 305, in __getitem__
    oid = h5o.open(self.id, self._e(name), lapl=self._lapl)

  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper

  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper

  File "h5py/h5o.pyx", line 190, in h5py.h5o.open

KeyError: "Unable to open object (object 'X' doesn't exist)"


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]]
	 [[assert_greater_equal_1/Assert/AssertGuard/else/_59/assert_greater_equal_1/Assert/AssertGuard/Assert/data_1/_148]]
  (1) UNKNOWN:  KeyError: "Unable to open object (object 'X' doesn't exist)"
Traceback (most recent call last):

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1004, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/tmp/ipykernel_4095444/344002438.py", line 13, in __call__
    X = hf['X']

  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper

  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper

  File "/cellar/users/aklie/opt/miniconda3/envs/scbasset/lib/python3.7/site-packages/h5py/_hl/group.py", line 305, in __getitem__
    oid = h5o.open(self.id, self._e(name), lapl=self._lapl)

  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper

  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper

  File "h5py/h5o.pyx", line 190, in h5py.h5o.open

KeyError: "Unable to open object (object 'X' doesn't exist)"


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_7154]

In [None]:
# Persist the per-epoch metrics; the context manager guarantees the file is
# flushed and closed (the bare open() call previously leaked the handle).
with open('%s/history.pickle'%output_dir, 'wb') as history_file:
    pickle.dump(history.history, history_file)

# Train with script 

%%bash
# Train from the command line with the scBasset training script, reading the
# processed inputs and writing results to a separate output folder.
# NOTE(review): assumes a conda environment named `scbasset` exists and that
# --input_folder contains the files the script expects — confirm.
source activate scbasset
python /cellar/users/aklie/opt/ml4gland/scBasset/bin/scbasset_train.py \
    --input_folder /cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/pbmc-granulocyte-sorted-3k_10x-Multiome/processed \
    --out_path /cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/pbmc-granulocyte-sorted-3k_10x-Multiome/model/21Sep23/scbasset/script