## Fine-tune Borzoi to predict RNA-seq coverage

In [65]:
import numpy as np
import pandas as pd
import torch
import os
import tiledb
import importlib
import multiprocessing
from multiprocessing import Pool
from tqdm import tqdm
from typing import Callable, List, Optional, Sequence, Tuple, Union
from grelu.utils import get_aggfunc, get_transform_func
from grelu.transforms.label_transforms import LabelTransform
from grelu.data.augment import Augmenter, _split_overall_idx
import time

from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from grelu.data.utils import _create_task_data, get_chromosomes
from grelu.sequence.format import convert_input_type, indices_to_one_hot
from grelu.sequence.utils import get_unique_length, get_lengths, resize
from grelu.io.bed import read_bed
from grelu.resources import load_model
from grelu.utils import make_list
from genomicarrays import buildutils_tiledb_array as uta
from grelu.data.tdb_utils import _write_cov, _write_seqs, create_tiledb_array
from grelu.data.preprocess import filter_chrom_ends
# Load the sequence intervals from the BED file (download command is kept,
# commented out, at the end of this cell).
intervals = read_bed('sequences_human.bed.gz')
# Column 3 is compared against fold labels; 'fold4' is used as the validation split.
# NOTE(review): presumably filter_chrom_ends drops intervals lying within `pad`
# bases of a chromosome end — confirm against the grelu documentation.
val_intervals = filter_chrom_ends(intervals[intervals[3] == 'fold4'], genome='hg38', pad=200000)
# Bare expression: displays the number of retained validation intervals in the notebook.
len(val_intervals)
# The test/train splits are currently disabled; cells that reference
# test_intervals / train_intervals need these two lines re-enabled first.
#test_intervals = intervals[intervals[3] == 'fold3']
#train_intervals = intervals[~intervals[3].isin(['fold3', 'fold4'])]
# One-off downloads: two ENCODE bigWig tracks and the Borzoi sequence BED file.
#!wget https://www.encodeproject.org/files/ENCFF732DIV/@@download/ENCFF732DIV.bigWig
#!wget https://www.encodeproject.org/files/ENCFF411BAA/@@download/ENCFF411BAA.bigWig
#!wget https://raw.github.com/calico/borzoi/refs/heads/main/data/sequences_human.bed.gz

Keeping 6886 intervals


6886

In [2]:
# File accessions (one per bigWig track) and the experiment each one belongs to.
# The two lists are positionally aligned.
file_accessions = ['ENCFF411BAA', 'ENCFF732DIV']
experiment_accessions = ['ENCSR967DSJ', 'ENCSR373XVN']

# Task metadata table, indexed by file accession.
new_metadata = pd.DataFrame(
    data={'experiment_acc': experiment_accessions},
    index=file_accessions,
)

# bigWig file names, in the same order as the metadata rows.
bw_files = [f"{accession}.bigWig" for accession in file_accessions]

In [3]:
# Directory that will hold the TileDB arrays for the validation split.
tdb_path = 'tutorial_7/tdb/val'
# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(tdb_path, exist_ok=True)

In [31]:
def _write_params(str_params, int_params, uri):
    """Store scalar dataset parameters in a single-cell sparse TileDB array.

    Every key of ``int_params`` and ``str_params`` becomes one attribute of the
    array created at ``uri``. All values are written to the only cell of the
    array, at ``unit_id == 0``.

    Args:
        str_params: Mapping of parameter name to string value; stored with
            dtype ``"U256"``.
        int_params: Mapping of parameter name to integer value; stored as
            ``np.int32``.
        uri: Location at which the parameter array is created.
            NOTE(review): presumably creation fails if an array already exists
            at this location — confirm against the TileDB-Py documentation.
    """
    # One named attribute per parameter: integers first, then strings.
    attributes = [tiledb.Attr(name=k, dtype=np.int32) for k in int_params]
    attributes += [tiledb.Attr(name=k, dtype="U256") for k in str_params]

    # The domain has exactly one coordinate (0), i.e. the array holds one cell.
    domain = tiledb.Domain(
        tiledb.Dim(name="unit_id", domain=(0, 0), tile=1, dtype=np.int64)
    )
    schema = tiledb.ArraySchema(domain=domain, attrs=attributes, sparse=True)
    tiledb.SparseArray.create(uri, schema)

    # Wrap every value in a length-1 array so it can be written to the single cell.
    arr_dict = {k: np.array([v], dtype=np.int32) for k, v in int_params.items()}
    arr_dict.update(
        {k: np.array([v], dtype="U256") for k, v in str_params.items()}
    )

    with tiledb.open(uri, mode="w") as arr:
        arr[np.array([0], dtype=np.int64)] = arr_dict


In [36]:
def bigwigs_to_tiledb(tdb_path, intervals, seq_len, label_len, max_seq_shift, max_pair_shift, bin_size, aggfunc, bw_files, genome, tasks, num_threads, chunk_size):
    """Build a TileDB dataset of genome sequences and BigWig coverage.

    Five arrays are created under ``tdb_path``: ``tasks``, ``intervals``,
    ``sequences``, ``labels`` and ``params``. Sequences are stored padded by
    ``max_seq_shift + max_pair_shift`` on each side, and labels padded by
    ``max_pair_shift`` on each side.

    Args:
        tdb_path: Directory under which the arrays are created. It is created
            (including parents) if it does not exist.
        intervals: Genomic intervals to store. They are resized and written with
            ``tiledb.from_pandas``.
        seq_len: Unpadded input sequence length.
        label_len: Unpadded label length, in bases.
        max_seq_shift: Maximum sequence shift; adds ``2 * max_seq_shift`` bases
            of padding to the stored sequence.
        max_pair_shift: Maximum paired shift; adds ``2 * max_pair_shift`` bases
            of padding to both sequence and label. Must be a multiple of
            ``bin_size``.
        bin_size: Number of bases aggregated into one label bin.
        aggfunc: Name of the aggregation applied within each bin; stored as a
            string parameter and forwarded to the coverage writer.
        bw_files: BigWig path or list of paths, one per task.
        genome: Genome name; stored as a string parameter and forwarded to the
            sequence writer.
        tasks: ``None`` (task names are derived from the BigWig file names), a
            list of task names, or a table of task metadata with one row per
            BigWig file. A table passed by the caller is copied, not modified.
        num_threads: Forwarded to the sequence and coverage writers. Values
            above 1 also switch the multiprocessing start method to "spawn".
        chunk_size: Forwarded to the sequence and coverage writers.

    Raises:
        AssertionError: If ``max_pair_shift`` is not a multiple of ``bin_size``.
    """
    # makedirs also creates missing parent directories and tolerates an existing one.
    os.makedirs(tdb_path, exist_ok=True)

    task_uri = f"{tdb_path}/tasks"
    intervals_uri = f"{tdb_path}/intervals"
    seq_uri = f"{tdb_path}/sequences"
    label_uri = f"{tdb_path}/labels"
    params_uri = f"{tdb_path}/params"

    assert max_pair_shift % bin_size == 0, "max_pair_shift must be a multiple of bin_size"

    # Normalise to a list before counting tasks: len() of a bare path string
    # would count characters instead of files.
    bw_files = make_list(bw_files)

    padded_label_len = label_len + (2 * max_pair_shift)
    int_params = {
        'n_seqs': len(intervals),
        'seq_len': seq_len,
        'n_tasks': len(bw_files),
        'label_len': label_len,
        'label_bins': label_len // bin_size,
        'max_seq_shift': max_seq_shift,
        'max_pair_shift': max_pair_shift,
        'bin_size': bin_size,
        'padded_seq_len': seq_len + (2 * max_seq_shift) + (2 * max_pair_shift),
        'padded_label_len': padded_label_len,
        'padded_label_bins': padded_label_len // bin_size,
    }

    str_params = {
        'aggfunc': aggfunc,
        'genome': genome,
    }
    _write_params(str_params, int_params, params_uri)

    # Create task dataframe
    if tasks is None:
        tasks = [os.path.splitext(os.path.basename(f))[0] for f in bw_files]
    if isinstance(tasks, list):
        tasks = _create_task_data(tasks)
    else:
        # Work on a copy so the caller's metadata table is left untouched.
        tasks = tasks.copy()
    tasks["task_idx"] = range(len(tasks))
    tasks["bigwig_path"] = [os.path.abspath(f) for f in bw_files]

    # Write task dataframe
    tiledb.from_pandas(task_uri, tasks)
    uta.optimize_tiledb_array(task_uri)

    # Write intervals dataframe, resized to the padded sequence length
    intervals = resize(intervals, seq_len=int_params['padded_seq_len'], end='both')
    tiledb.from_pandas(intervals_uri, intervals)
    uta.optimize_tiledb_array(intervals_uri)

    # Create empty arrays: sequences are (n_seqs, padded_seq_len) int8,
    # labels are (n_seqs, n_tasks, padded_label_bins) float32.
    create_tiledb_array(seq_uri, x_dim_length=int_params['n_seqs'], y_dim_length=int_params['padded_seq_len'],
                        x_dim_tile=1, y_dim_tile=64000, matrix_dim_dtype=np.int8)
    create_tiledb_array(label_uri, x_dim_length=int_params['n_seqs'],
                        y_dim_length=int_params['n_tasks'], z_dim_length=int_params['padded_label_bins'],
                        x_dim_tile=1, y_dim_tile=int_params['n_tasks'], z_dim_tile=int_params['padded_label_bins'],
                        matrix_dim_dtype=np.float32)

    # Set up multiprocessing
    if num_threads > 1:
        try:
            multiprocessing.set_start_method("spawn", force=True)
        except RuntimeError:
            # Deliberate best effort: keep whatever start method is already active.
            pass

    print("Writing genome sequence")
    _write_seqs(intervals, chunk_size, genome, seq_uri, num_threads)

    print("Writing coverage from BigWig files")
    # Labels cover a narrower window than sequences, so resize again before reading coverage.
    intervals = resize(intervals, int_params['padded_label_len'])
    _write_cov(intervals, chunk_size, tasks, label_uri, bin_size, aggfunc, num_threads)


In [37]:
# Build the validation TileDB: two bigWig tracks over the fold-4 intervals,
# with labels summed into 32-base bins.
bigwigs_to_tiledb(
    tdb_path,
    val_intervals,
    seq_len=524288,
    label_len=196608,
    max_seq_shift=3,
    max_pair_shift=640,
    bin_size=32,
    aggfunc='sum',
    bw_files=bw_files,
    genome='hg38',
    tasks=new_metadata,
    num_threads=16,
    chunk_size=1000,
)

[Attr(name='n_seqs', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='seq_len', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='n_tasks', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='label_len', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='label_bins', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='max_seq_shift', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='max_pair_shift', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='bin_size', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='padded_seq_len', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='padded_label_len', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='padded_label_bins', dtype='int32', var=False, nullable=False, enum_label=None), Attr(name='aggfunc', dtype='<U0', var=True, nullable=False, enum

7it [00:56,  8.11s/it]


Optimizing tutorial_7/tdb/val/labels
Fragments before consolidation: 14
Fragments after consolidation: 1


In [67]:
# Validation dataset backed by the TileDB written above; reverse-complement
# augmentation is switched off.
# NOTE(review): TileDBSeqDataset is not imported in the visible cells — the
# recorded traceback further down suggests it was defined in the notebook's
# __main__ at the time; confirm where it should come from.
val_ds = TileDBSeqDataset(
    tdb_path = tdb_path,
    rc = False,
    #label_transform_func=np.sqrt,
)

The following parameters will be loaded from the TileDB
{'n_seqs': 6886, 'seq_len': 524288, 'n_tasks': 2, 'label_len': 196608, 'label_bins': 6144, 'max_seq_shift': 3, 'max_pair_shift': 640, 'bin_size': 32, 'padded_seq_len': 525574, 'padded_label_len': 197888, 'padded_label_bins': 6184, 'aggfunc': 'sum', 'genome': 'hg38', 'unit_id': 0}


  self.n_augmented = len(self.augmenter)


In [72]:
import grelu.data.tdb_utils

In [71]:

# time.perf_counter() is used for all measurements: it is a monotonic,
# high-resolution clock intended for timing short intervals, unlike time.time().

# Test single sample timing
start = time.perf_counter()
_ = val_ds[0]  # Direct access to dataset
print("Single sample fetch time:", time.perf_counter() - start)

# Test batch timing without DataLoader (manual fetch)
start = time.perf_counter()
batch = [val_ds[i] for i in range(12)]  # Fetch 12 items manually
print("Manual batch fetch time:", time.perf_counter() - start)

# Test DataLoader fetch timing
# NOTE: with the "spawn" start method each worker must be able to import the
# dataset class; the recorded traceback shows this failing for a class defined
# in the notebook's __main__. The measurement below also includes worker start-up.
val_dl = DataLoader(
    val_ds,
    batch_size=12,
    shuffle=False,
    num_workers=2,
    worker_init_fn=grelu.data.tdb_utils.worker_init_fn,
    persistent_workers=True,
)
start = time.perf_counter()
_ = next(iter(val_dl))  # Fetch batch from DataLoader
print("DataLoader batch fetch time:", time.perf_counter() - start)

Single sample fetch time: 0.05424189567565918
Manual batch fetch time: 0.49542975425720215


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'TileDBSeqDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [None]:
#multiprocessing.set_start_method('spawn', force=True)
# NOTE(review): the Config object built here is neither assigned nor passed to
# a context, so it presumably has no effect on later TileDB reads — confirm and,
# if a single reader thread is intended, attach it to the context actually used.
tiledb.Config({"sm.num_reader_threads": 1})

In [None]:
# Re-execute grelu.data.tdb_utils so that edits to the module source are picked
# up without restarting the kernel.
importlib.reload(grelu.data.tdb_utils)


Single sample fetch time: 0.22231197357177734
Manual batch fetch time: 4.7304770946502686
DataLoader batch fetch time: 15.908597230911255


In [32]:
# Open an existing TileDB array for the ad-hoc read timings in the next cells.
# NOTE(review): the handle is not closed in any visible cell.
f = tiledb.open('tutorial_7/tdb/chr1/')

In [33]:
%%time
# Time a read of the first five columns across all rows using plain slicing.
f[:, :5]

CPU times: user 2.16 s, sys: 10.3 s, total: 12.5 s
Wall time: 256 ms


OrderedDict([('data',
              array([[4, 4, 4, 4, 4],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]], dtype=int8))])

In [24]:
%%time
# Time the same kind of read through multi_index.
# NOTE(review): the recorded output has five columns, which suggests
# multi_index treats slice(1, 5) as inclusive of both ends — confirm.
f.multi_index[:, slice(1, 5)]

CPU times: user 6.47 s, sys: 15.3 s, total: 21.8 s
Wall time: 400 ms


OrderedDict([('data',
              array([[4, 4, 4, 4, 4],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]], dtype=int8))])

In [29]:
# NOTE(review): len() of a DataLoader already counts batches, so dividing by
# the batch size of 12 again looks unintended — confirm what this was meant to show.
len(val_dl)/12

287.9166666666667

In [20]:
import multiprocessing
# Make "spawn" the process start method for this session; force=True overrides
# a start method that has already been set.
multiprocessing.set_start_method('spawn', force=True)

In [None]:
# Load the "human_fold0" model from the "borzoi" project; this is the model to
# be fine-tuned on the RNA-seq coverage tracks.
model = grelu.resources.load_model(
    project="borzoi",
    model_name="human_fold0",
)

In [None]:
# Settings shared by the training and test datasets: same TileDB, same window
# sizes, and labels summed per 32-base bin then raised to the power 0.75.
_shared_ds_kwargs = dict(
    tdb_path=tdb_path,
    seq_len=524288,
    end='both',
    label_len=196608,
    bin_size=32,
    label_transform_func=lambda x: x**0.75,
    label_aggfunc='sum',
)

# NOTE(review): train_intervals and test_intervals are only defined in
# commented-out lines of the first cell; re-enable them before running this cell.

# Training set: augmented with shifts of up to 5 bases and reverse complementing.
train_ds = grelu.data.dataset.TileDBSeqDataset(
    intervals=train_intervals,
    max_seq_shift=5,
    rc=True,
    **_shared_ds_kwargs,
)

# Test set: no augmentation.
test_ds = grelu.data.dataset.TileDBSeqDataset(
    intervals=test_intervals,
    max_seq_shift=0,
    rc=False,
    **_shared_ds_kwargs,
)