In [None]:
import time
import sys
import os
import glob
import math
import threading
import concurrent.futures as cf

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Input, Model, layers, metrics, losses, callbacks, optimizers, models, utils
from keras import backend as K
import gc
import keras_tuner as kt
from pyfaidx import Fasta

K.clear_session()
gc.collect()

datasets_path = "../../Datasets/"
models_path = "../../Models/"

In [None]:
# ------------------------------------
# 1) Loading and shifting GTF annotations
# ------------------------------------

def load_gtf_annotations(gtf_file):
    """
    Loads GTF into a pandas DataFrame and converts cstart and cend to zero-based indexing.
    """
    gtf_data = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None,
        names=['seqname', 'source', 'feature', 'cstart', 'cend', 
               'score', 'strand', 'frame', 'attribute']
    )
    # Convert to zero-based indexing for cstart
    gtf_data['cstart'] = gtf_data['cstart'] - 1
    return gtf_data

# ------------------------------------
# 2) Generating uniform chunks
# ------------------------------------

def compute_chunk_indices(fasta_file, chunk_size):
    """
    Creates a list of (record_id, cstart, cend) for each chunk in the FASTA.
    """
    fa = Fasta(fasta_file)  # for indexed random access
    chunk_indices = []
    for record_id in fa.keys():
        seq_len = len(fa[record_id])
        for cstart in range(0, seq_len, chunk_size):
            cend = min(cstart + chunk_size, seq_len)
            chunk_indices.append((record_id, cstart, cend))
    return chunk_indices

# ------------------------------------
# 3) One-hot encoding + Strand Feature
# ------------------------------------

n_base_encoder = {
    'A': [1, 0, 0, 0],
    'C': [0, 1, 0, 0],
    'G': [0, 0, 1, 0],
    'T': [0, 0, 0, 1],
    'N': [0, 0, 0, 0]  
}

def one_hot_encode_reference(sequence):
    """
    Returns a list of 4-element lists for each base (A, C, G, T).
    """
    return [n_base_encoder.get(nuc, [0, 0, 0, 0]) for nuc in sequence]

def label_sequence_local(sequence_length, annotations):
    """
    Assigns categorical numeric labels:
      0 = non-coding
      1 = intron cstart
      2 = intron cend
      3 = exon cstart
      4 = exon cend

    annotations here should already be filtered to the chunk + strand,
    with 'cstart' and 'cend' in [0..sequence_length).
    """
    labels = [0] * sequence_length  # All non-coding by default
    for _, row in annotations.iterrows():
        cstart, cend, feature = int(row['cstart']), int(row['cend']), row['feature']

        # Clip to chunk boundaries (paranoia check; might already be clipped)
        cstart = max(0, cstart)
        cend   = min(sequence_length, cend)

        if feature == 'exon':
            if cstart < sequence_length:
                labels[cstart] = 3  # Exon cstart
            if (cend - 1) < sequence_length and (cend - 1) >= 0:
                labels[cend - 1] = 4  # Exon cend

        elif feature == 'intron':
            if cstart < sequence_length:
                labels[cstart] = 1  # Intron cstart
            if (cend - 1) < sequence_length and (cend - 1) >= 0:
                labels[cend - 1] = 2  # Intron cend

    return labels

def pad_labels(labels, target_length):
    """
    Pads label array up to target_length with 0 (non-coding).
    """
    if len(labels) < target_length:
        labels += [0] * (target_length - len(labels))
    return labels

def pad_encoded_seq(encoded_seq, target_length):
    """
    Pads sequence of shape (seq_len, 5) up to (target_length, 5) with zeros.
    """
    seq_len = len(encoded_seq)
    pad_size = target_length - seq_len
    if pad_size > 0:
        encoded_seq += [[0, 0, 0, 0, 0]] * pad_size
    return encoded_seq

# ------------------------------------
# 4) Main function to build chunk data
# ------------------------------------

def build_chunk_data(fasta_file, gtf_df, chunk_size=5000, skip_empty=True):
    """
    For each chunk (record_id, cstart, cend) in the FASTA:
      - Extract the reference sequence
      - For each strand (+ and -):
          1) Create a [chunk_size x 5] input: 4 channels for bases, 1 for strand
          2) Label the chunk (0..chunk_size-1) using the GTF annotations that
             fall on this chunk AND on this strand
          3) If skip_empty=True and all labels are 0, skip
          4) Yield (X, y) or store it somewhere

    Returns: generator of (X, y, record_id, chunk_cstart, chunk_cend, strand)
             or you could accumulate in a list.
    """
    fa = Fasta(fasta_file)

    # Pre-group GTF by (seqname, strand) to speed up filtering
    grouped_gtf = {}
    for (seqname, strand), subdf in gtf_df.groupby(['seqname', 'strand']):
        grouped_gtf[(seqname, strand)] = subdf

    chunk_list = compute_chunk_indices(fasta_file, chunk_size)
    
    for (record_id, cstart, cend) in chunk_list:
        # Read the reference chunk
        seq = str(fa[record_id][cstart:cend])  # raw bases from reference
        base_encoded_4 = one_hot_encode_reference(seq)  # shape => (chunk_len, 4)

        chunk_len = len(base_encoded_4)  # could be < chunk_size if at the cend of the chromosome

        for strand_symbol in ['+', '-']:
            # append the 5th channel for strand
            # If you prefer 1 for +, 0 for -:
            strand_flag = 1 if strand_symbol == '+' else 0
            encoded_seq_5 = [row + [strand_flag] for row in base_encoded_4]

            # Filter GTF for (record_id, strand_symbol)
            if (record_id, strand_symbol) not in grouped_gtf:
                # No annotations for that contig+strand => all labels=0
                labels = [0]*chunk_len
            else:
                subdf = grouped_gtf[(record_id, strand_symbol)]
                # Keep only rows that overlap [cstart, cend)
                # Then shift 'cstart' and 'cend' to local coordinates
                overlap = subdf[
                    (subdf['cstart'] < cend) & 
                    (subdf['cend'] > cstart)
                ].copy()

                if len(overlap) == 0:
                    labels = [0]*chunk_len
                else:
                    # Shift coords so that cstart => 0
                    overlap['cstart'] = overlap['cstart'] - cstart
                    overlap['cend']   = overlap['cend']   - cstart

                    # Now label them in local chunk coords
                    labels = label_sequence_local(chunk_len, overlap)

            # Optionally skip if all labels=0 and skip_empty=True
            if skip_empty and all(lbl == 0 for lbl in labels):
                continue

            # Pad up to chunk_size if needed
            encoded_seq_5 = pad_encoded_seq(encoded_seq_5, chunk_size)
            labels = pad_labels(labels, chunk_size)

            # Convert to np.array for deep learning frameworks
            X = np.array(encoded_seq_5, dtype=np.float32)   # shape [chunk_size, 5]
            y = np.array(labels, dtype=np.int32)            # shape [chunk_size]

            # Yield or store
            yield (X, y, record_id, cstart, cend, strand_symbol)



In [6]:
my_fasta = 'test_genome.fa'
my_gtf = load_gtf_annotations('test_annotations.gtf')

In [2]:
def write_tfrecord_shards(chunk_generator, output_dir, shard_size=1000, compression_type="GZIP"):
    """
    Writes data from a generator into compressed TFRecord shards.

    Parameters:
        chunk_generator (generator): An already instantiated generator object.
        output_dir (str): The directory to save the TFRecord shards.
        shard_size (int): Number of records per shard.
        compression_type (str): Compression type, e.g., "GZIP".
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    shard_index = 0
    record_count = 0
    shard_path = os.path.join(output_dir, f"data_shard_{shard_index:03d}.tfrecord")

    # Initialize the TFRecordWriter with compression
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    writer = tf.io.TFRecordWriter(shard_path, options=options)

    for record in chunk_generator: 
        writer.write(str(record))
        record_count += 1

        # Rotate shards when the shard size is reached
        if record_count >= shard_size:
            writer.close()
            shard_index += 1
            shard_path = os.path.join(output_dir, f"data_shard_{shard_index:03d}.tfrecord")
            writer = tf.io.TFRecordWriter(shard_path, options=options)
            record_count = 0

    # Close the last writer
    writer.close()
    print(f"Data written to {shard_index + 1} shards in {output_dir}")

In [23]:
output_directory = "test_tfrecord_shards"
write_tfrecord_shards(build_chunk_data(my_fasta, my_gtf, 5000, True), output_directory, shard_size=1000, compression_type="GZIP")

Data written to 9 shards in test_tfrecord_shards


In [8]:
i = 0
for X, y, record_id, cstart, cend, strand in build_chunk_data(my_fasta, my_gtf, 5000, True):
    if i <10:
        print(X, y, record_id, cstart, cend, strand)
        print(len(X))
        print(len(y))
        i+=1
    else:
        break

[[0. 0. 0. 1. 1.]
 [0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 1.]
 ...
 [1. 0. 0. 0. 1.]
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 0. 1.]] [0 0 0 ... 0 0 4] chrX 250000 255000 +
5000
5000
[[0. 1. 0. 0. 1.]
 [0. 0. 1. 0. 1.]
 [0. 0. 1. 0. 1.]
 ...
 [0. 1. 0. 0. 1.]
 [0. 0. 0. 1. 1.]
 [0. 0. 1. 0. 1.]] [3 0 0 ... 0 0 0] chrX 255000 260000 +
5000
5000
[[0. 1. 0. 0. 1.]
 [1. 0. 0. 0. 1.]
 [0. 0. 1. 0. 1.]
 ...
 [0. 0. 1. 0. 1.]
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 0. 1.]] [0 0 0 ... 0 0 0] chrX 275000 280000 +
5000
5000
[[1. 0. 0. 0. 1.]
 [0. 0. 1. 0. 1.]
 [0. 1. 0. 0. 1.]
 ...
 [1. 0. 0. 0. 1.]
 [1. 0. 0. 0. 1.]
 [0. 0. 0. 1. 1.]] [0 0 0 ... 0 0 0] chrX 280000 285000 +
5000
5000
[[0. 0. 0. 1. 1.]
 [1. 0. 0. 0. 1.]
 [0. 0. 0. 1. 1.]
 ...
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 0. 1.]
 [0. 1. 0. 0. 1.]] [0 0 0 ... 0 0 0] chrX 285000 290000 +
5000
5000
[[1. 0. 0. 0. 1.]
 [0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 1.]
 ...
 [0. 1. 0. 0. 1.]
 [0. 0. 0. 1. 1.]
 [0. 1. 0. 0. 1.]] [0 0 0 ... 0 0 0] chrX 290000 295000 +
5000
5000
[[0. 0. 0. 1. 1.]
 [1.

In [4]:
my_fasta = 'chr_genome.fa'
my_gtf = load_gtf_annotations('chr_annotations.gtf')
output_directory = "basic_2_tfrecord_shards"
write_tfrecord_shards(build_chunk_data(my_fasta, my_gtf, 5000, True), output_directory, shard_size=20000, compression_type="GZIP")

Data written to 11 shards in basic_2_tfrecord_shards


In [None]:
import os
import numpy as np
import tensorflow as tf

def float_feature_list(value_list):
    """
    Utility to convert a list of floats into a FloatList. Floats are needed for backpropagation, 
    activation functions, and loss calculations and are thus the default type in TF and PyTorch.
    """
    return tf.train.Feature(float_list=tf.train.FloatList(value=value_list))

def int_feature_list(value_list):
    """
    Utility to convert a list of ints into an Int64List.  Ints fine because this is classification
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

def bytes_feature(value):
    """
    Utility for a single string/bytes feature. For efficient string storage
    """
    if isinstance(value, str):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol):
    """
    Converts a single chunk's data into a tf.train.Example protobuf.

    :param X: np.array(float32) of shape [chunk_size, 5]
    :param y: np.array(int32) of shape [chunk_size]
    :param record_id: str (e.g. chromosome name)
    :param cstart, cend: int
    :param strand_symbol: '+' or '-'

    We'll flatten X and y for storage. Then parse/reshape at read time.
    """
    chunk_size = X.shape[0]
    
    # Flatten X to 1D. We'll store it in row-major order.
    X_flat = X.flatten().tolist()
    # y is already 1D, but ensure it's a list of int
    y_list = y.tolist()
    
    # Build a dictionary of features using utility functions above 
    # to cast into types preferred by tensorflow
    feature_dict = {
        'chunk_size':  int_feature_list([chunk_size]),
        'X':           float_feature_list(X_flat),
        'y':           int_feature_list(y_list),
        'record_id':   bytes_feature(record_id),
        'cstart':      int_feature_list([cstart]),
        'cend':        int_feature_list([cend]),
        'strand':      bytes_feature(strand_symbol)
    }
    
    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example.SerializeToString()


# def write_tfrecord(output_path, build_chunk_generator):
#     """
#     Iterates over the generator that yields (X, y, record_id, cstart, cend, strand_symbol)
#     and writes them into a TFRecord file.

#     :param output_path: e.g. "my_chunks.tfrecord"
#     :param build_chunk_generator: a function or generator that yields training examples.
#     """
#     with tf.io.TFRecordWriter(output_path) as writer:
#         for i, (X, y, record_id, cstart, cend, strand_symbol) in enumerate(build_chunk_generator):
#             example_str = serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol)
#             writer.write(example_str)
#     print(f"TFRecord written to {output_path}")


In [None]:
# fasta_file = "hg38.fa"
# gtf_file   = "hg38.gtf"
# chunk_size = 5000

# # Suppose we have the gtf_df loaded
# gtf_df = load_gtf_annotations(gtf_file)

# def chunk_generator():
#     # This is just a wrapper around build_chunk_data with some arguments
#     for (X, y, record_id, cstart, cend, strand_symbol) in build_chunk_data(fasta_file, gtf_df, chunk_size):
#         yield (X, y, record_id, cstart, cend, strand_symbol)

# # Actually write the TFRecord
# write_tfrecord("my_chunks.tfrecord", chunk_generator())

In [None]:
# def parse_chunk_example(serialized_example):
#     """
#     Parses a single serialized tf.train.Example back into tensors.
#     """
#     feature_spec = {
#         'chunk_size': tf.io.FixedLenFeature([1], tf.int64),
#         'X':          tf.io.VarLenFeature(tf.float32),
#         'y':          tf.io.VarLenFeature(tf.int64),
#         'record_id':  tf.io.FixedLenFeature([], tf.string),
#         'cstart':     tf.io.FixedLenFeature([1], tf.int64),
#         'cend':       tf.io.FixedLenFeature([1], tf.int64),
#         'strand':     tf.io.FixedLenFeature([], tf.string)
#     }
    
#     parsed = tf.io.parse_single_example(serialized_example, feature_spec)
    
#     # chunk_size is shape [1]
#     chunk_size = parsed['chunk_size'][0]
    
#     # X and y are sparse (VarLenFeature). Convert to dense.
#     # X shape: (chunk_size*5,)
#     X_flat = tf.sparse.to_dense(parsed['X'])
#     y_flat = tf.sparse.to_dense(parsed['y'])
    
#     # Reshape X into [chunk_size, 5]
#     X_reshaped = tf.reshape(X_flat, [chunk_size, 5])
#     # y is already [chunk_size], so no reshape needed except for chunk_size
#     # but let's explicitly reshape to ensure correctness
#     y_reshaped = tf.reshape(y_flat, [chunk_size])
    
#     record_id = parsed['record_id']
#     cstart    = parsed['cstart'][0]
#     cend      = parsed['cend'][0]
#     strand    = parsed['strand']
    
#     return X_reshaped, y_reshaped, record_id, cstart, cend, strand

In [None]:
# def build_dataset_from_tfrecord(tfrecord_path, batch_size=16):
#     """
#     Returns a tf.data.Dataset that yields (X, y, record_id, cstart, cend, strand) in batches.
#     """
#     dataset = tf.data.TFRecordDataset([tfrecord_path])
#     dataset = dataset.map(parse_chunk_example, num_parallel_calls=tf.data.AUTOTUNE)
#     # We can batch if we like, though each chunk has a different shape if chunk_size is constant it's okay
#     # However, chunk_size is the same for all? Then a standard batch is possible
#     dataset = dataset.batch(batch_size)
#     return dataset

# # Usage
# ds = build_dataset_from_tfrecord("my_chunks.tfrecord", batch_size=2)
# for batch in ds.take(1):
#     (X_batch, y_batch, record_id_batch, cstart_batch, cend_batch, strand_batch) = batch
#     print("X shape:", X_batch.shape)        # (batch_size, chunk_size, 5)
#     print("y shape:", y_batch.shape)        # (batch_size, chunk_size)
#     print("record_id:", record_id_batch)
#     print("cstart:", cstart_batch)
#     print("cend:", cend_batch)
#     print("strand:", strand_batch)


In [7]:
import tensorflow as tf

def write_tfrecord_in_shards(
    shard_prefix,
    build_chunk_generator,
    num_shards=5,
    compression_type="GZIP"
):
    """
    Writes data in multiple TFRecord shards, each compressed with GZIP by default.

    :param shard_prefix: Base path for shards, e.g. 'my_chunks'
    :param build_chunk_generator: yields (X, y, record_id, cstart, cend, strand)
    :param num_shards: number of TFRecord files to create
    :param compression_type: 'GZIP' or 'ZLIB' or None
    """
    # Create TFRecordWriters for each shard
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    
    writers = []
    for shard_id in range(num_shards):
        shard_path = f"{shard_prefix}-{shard_id:04d}.tfrecord"
        if compression_type:
            shard_path += ".gz"  # Just a naming convention
        writer = tf.io.TFRecordWriter(shard_path, options=options)
        writers.append(writer)

    # Round-robin or sequential assignment of examples
    shard_index = 0

    for i, (X, y, record_id, cstart, cend, strand_symbol) in enumerate(build_chunk_generator):
        example_str = serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol)
        
        # Write to the chosen shard
        writers[shard_index].write(example_str)
        
        # Move to next shard
        shard_index = (shard_index + 1) % num_shards

    # Close all
    for w in writers:
        w.close()

    print(f"Wrote data into {num_shards} shards with {compression_type} compression.")


In [None]:
def parse_chunk_example(serialized_example):
    """
    Parses a single serialized tf.train.Example back into tensors.
    """
    feature_spec = {
        'chunk_size': tf.io.FixedLenFeature([1], tf.int64),
        'X':          tf.io.VarLenFeature(tf.float32),
        'y':          tf.io.VarLenFeature(tf.int64),
        'record_id':  tf.io.FixedLenFeature([], tf.string),
        'cstart':     tf.io.FixedLenFeature([1], tf.int64),
        'cend':       tf.io.FixedLenFeature([1], tf.int64),
        'strand':     tf.io.FixedLenFeature([], tf.string)
    }
    
    parsed = tf.io.parse_single_example(serialized_example, feature_spec)
    
    # chunk_size is shape [1]
    chunk_size = parsed['chunk_size'][0]
    
    # X and y are sparse (VarLenFeature). Convert to dense.
    # X shape: (chunk_size*5,)
    X_flat = tf.sparse.to_dense(parsed['X'])
    y_flat = tf.sparse.to_dense(parsed['y'])
    
    # Reshape X into [chunk_size, 5]
    X_reshaped = tf.reshape(X_flat, [chunk_size, 5])
    # y is already [chunk_size], so no reshape needed except for chunk_size
    # but let's explicitly reshape to ensure correctness
    y_reshaped = tf.reshape(y_flat, [chunk_size])
    
    record_id = parsed['record_id']
    cstart    = parsed['cstart'][0]
    cend      = parsed['cend'][0]
    strand    = parsed['strand']
    
    return X_reshaped, y_reshaped, record_id, cstart, cend, strand


In [None]:
def build_dataset_from_shards(shard_prefix, num_shards=5, compression_type="GZIP"):
    """
    Reads from multiple compressed TFRecord shards, e.g. my_chunks-0000.tfrecord.gz, ...
    Parses them into (X, y).
    """
    shard_files = [
        f"{shard_prefix}-{shard_id:04d}.tfrecord.gz"
        for shard_id in range(num_shards)
    ]
    
    dataset = tf.data.TFRecordDataset(
        shard_files,
        compression_type=compression_type  # Must match what you used in writing
    )
    dataset = dataset.map(parse_chunk_example, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

# Example usage
ds = build_dataset_from_shards("my_chunks", num_shards=5)
ds = ds.batch(2).prefetch(tf.data.AUTOTUNE)

for X_batch, y_batch in ds.take(1):
    print("X shape:", X_batch.shape)
    print("y shape:", y_batch.shape)


Here is the cstart of the parallel processing shards code

In [38]:
import pandas as pd
import numpy as np
import os
import tensorflow as tf
from pyfaidx import Fasta

# ------------------------------------
# 1) Loading and shifting GTF annotations
# ------------------------------------

def load_gtf_annotations(gtf_file):
    """
    Loads GTF into a pandas DataFrame and converts cstart and cend to zero-based indexing.
    """
    gtf_data = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None,
        names=['seqname', 'source', 'feature', 'cstart', 'cend', 
               'score', 'strand', 'frame', 'attribute']
    )
    # Convert to zero-based indexing for cstart
    gtf_data['cstart'] = gtf_data['cstart'] - 1
    return gtf_data

# ------------------------------------
# 2) Generating uniform chunks
# ------------------------------------

def compute_chunk_indices(fasta_file, chunk_size):
    """
    Creates a list of (record_id, cstart, cend) for each chunk in the FASTA.
    """
    print('Running compute_chunk_indices')
    
    fa = Fasta(fasta_file)  # for indexed random access
    chunk_indices = []
    for record_id in fa.keys():
        seq_len = len(fa[record_id])
        for cstart in range(0, seq_len, chunk_size):
            cend = min(cstart + chunk_size, seq_len)
            chunk_indices.append((record_id, cstart, cend))
    return chunk_indices

# ------------------------------------
# 3) One-hot encoding + Strand Feature
# ------------------------------------

n_base_encoder = {
    'A': [1, 0, 0, 0],
    'C': [0, 1, 0, 0],
    'G': [0, 0, 1, 0],
    'T': [0, 0, 0, 1],
    'N': [0, 0, 0, 0]  
}

def one_hot_encode_reference(sequence):
    """
    Returns a list of 4-element lists for each base (A, C, G, T).
    """
    return [n_base_encoder.get(nuc, [0, 0, 0, 0]) for nuc in sequence]

def label_sequence_local(sequence_length, annotations):
    """
    Assigns categorical numeric labels:
      0 = non-coding
      1 = intron cstart
      2 = intron cend
      3 = exon cstart
      4 = exon cend

    annotations here should already be filtered to the chunk + strand,
    with 'cstart' and 'cend' in [0..sequence_length).
    """
    labels = [0] * sequence_length  # All non-coding by default
    for _, row in annotations.iterrows():
        cstart, cend, feature = int(row['cstart']), int(row['cend']), row['feature']

        # Clip to chunk boundaries (paranoia check; might already be clipped)
        cstart = max(0, cstart)
        cend   = min(sequence_length, cend)

        if feature == 'exon':
            if cstart < sequence_length:
                labels[cstart] = 3  # Exon cstart
            if (cend - 1) < sequence_length and (cend - 1) >= 0:
                labels[cend - 1] = 4  # Exon cend

        elif feature == 'intron':
            if cstart < sequence_length:
                labels[cstart] = 1  # Intron cstart
            if (cend - 1) < sequence_length and (cend - 1) >= 0:
                labels[cend - 1] = 2  # Intron cend

    return labels

def pad_labels(labels, target_length):
    """
    Pads label array up to target_length with 0 (non-coding).
    """
    if len(labels) < target_length:
        labels += [0] * (target_length - len(labels))
    return labels

def pad_encoded_seq(encoded_seq, target_length):
    """
    Pads sequence of shape (seq_len, 5) up to (target_length, 5) with zeros.
    """
    seq_len = len(encoded_seq)
    pad_size = target_length - seq_len
    if pad_size > 0:
        encoded_seq += [[0, 0, 0, 0, 0]] * pad_size
    return encoded_seq

# ------------------------------------
# 4) Main function to build chunk data
# ------------------------------------

def build_chunk_data_for_indices(fasta_file, gtf_df, subset_indices, skip_empty=True, chunk_size=5000):
    """
    For each chunk (record_id, cstart, cend) in the FASTA:
      - Extract the reference sequence
      - For each strand (+ and -):
          1) Create a [chunk_size x 5] input: 4 channels for bases, 1 for strand
          2) Label the chunk (0..chunk_size-1) using the GTF annotations that
             fall on this chunk AND on this strand
          3) If skip_empty=True and all labels are 0, skip
          4) Yield (X, y) or store it somewhere

    Returns: generator of (X, y, record_id, chunk_cstart, chunk_cend, strand)
             or you could accumulate in a list.
    """
    print('running build_chunk_data_for_indices')
    
    fa = Fasta(fasta_file)

    # Pre-group GTF by (seqname, strand) to speed up filtering
    grouped_gtf = {}
    for (seqname, strand), sub_df in gtf_df.groupby(['seqname', 'strand']):
        grouped_gtf[(seqname, strand)] = sub_df

        
    for (record_id, cstart, cend) in subset_indices:
        # Read the reference chunk
        seq = str(fa[record_id][cstart:cend])  # raw bases from reference
        base_encoded_4 = one_hot_encode_reference(seq)  # shape => (chunk_len, 4)

        chunk_len = len(base_encoded_4)  # could be < chunk_size if at the cend of the chromosome

        for strand_symbol in ['+', '-']:
            # append the 5th channel for strand
            # If you prefer 1 for +, 0 for -:
            strand_flag = 1 if strand_symbol == '+' else 0
            encoded_seq_5 = [row + [strand_flag] for row in base_encoded_4]

            # Filter GTF for (record_id, strand_symbol)
            if (record_id, strand_symbol) not in grouped_gtf:
                # No annotations for that contig+strand => all labels=0
                labels = [0]*chunk_len
            else:
                sub_df = grouped_gtf[(record_id, strand_symbol)]
                # Keep only rows that overlap [cstart, cend)
                # Then shift 'cstart' and 'cend' to local coordinates
                overlap = sub_df[
                    (sub_df['cstart'] < cend) & 
                    (sub_df['cend'] > cstart)
                ].copy()

                if len(overlap) == 0:
                    labels = [0]*chunk_len
                else:
                    # Shift coords so that cstart => 0
                    overlap['cstart'] = overlap['cstart'] - cstart
                    overlap['cend']   = overlap['cend']   - cstart

                    # Now label them in local chunk coords
                    labels = label_sequence_local(chunk_len, overlap)

            # Optionally skip if all labels=0 and skip_empty=True
            if skip_empty and all(lbl == 0 for lbl in labels):
                continue

            # Pad up to chunk_size if needed
            encoded_seq_5 = pad_encoded_seq(encoded_seq_5, chunk_size)
            labels = pad_labels(labels, chunk_size)

            # Convert to np.array for deep learning frameworks
            X = np.array(encoded_seq_5, dtype=np.float32)   # shape [chunk_size, 5]
            y = np.array(labels, dtype=np.int32)            # shape [chunk_size]

            # Yield or store
            yield (X, y, record_id, cstart, cend, strand_symbol)

In [39]:
import os
import numpy as np
import tensorflow as tf

def float_feature_list(value_list):
    """
    Utility to convert a list of floats into a FloatList. Floats are needed for backpropagation, 
    activation functions, and loss calculations and are thus the default type in TF and PyTorch.
    """
    return tf.train.Feature(float_list=tf.train.FloatList(value=value_list))

def int_feature_list(value_list):
    """
    Utility to convert a list of ints into an Int64List.  Ints fine because this is classification
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

def bytes_feature(value):
    """
    Utility for a single string/bytes feature. For efficient string storage
    """
    if isinstance(value, str):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol):
    """
    Converts a single chunk's data into a tf.train.Example protobuf.

    :param X: np.array(float32) of shape [chunk_size, 5]
    :param y: np.array(int32) of shape [chunk_size]
    :param record_id: str (e.g. chromosome name)
    :param cstart, cend: int
    :param strand_symbol: '+' or '-'

    We'll flatten X and y for storage. Then parse/reshape at read time.
    """
    
    chunk_size = X.shape[0]
    
    # Flatten X to 1D. We'll store it in row-major order.
    X_flat = X.flatten().tolist()
    # y is already 1D, but ensure it's a list of int
    y_list = y.tolist()
    
    # Build a dictionary of features using utility functions above 
    # to cast into types preferred by tensorflow
    feature_dict = {
        'chunk_size':  int_feature_list([chunk_size]),
        'X':           float_feature_list(X_flat),
        'y':           int_feature_list(y_list),
        'record_id':   bytes_feature(record_id),
        'cstart':      int_feature_list([cstart]),
        'cend':        int_feature_list([cend]),
        'strand':      bytes_feature(strand_symbol)
    }
    
    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example.SerializeToString()

In [40]:
import concurrent.futures as cf
import threading
import tensorflow as tf

def write_to_shard_with_threads(
    shard_id, 
    shard_path, 
    num_shards, 
    all_indices, 
    fasta_file, 
    gtf_df, 
    compression_type="GZIP", 
    skip_empty=True, 
    max_threads_per_process=4
):
    """
    Writes data for a specific shard using threads for concurrent writes.
    Each thread serializes its own portion of the data and writes it to the shared TFRecord file.

    :param shard_id: ID of the shard (used to filter data)
    :param shard_path: Path to the shard file
    :param num_shards: Total number of shards
    :param all_indices: All data indices for chunk generation
    :param fasta_file: FASTA file for DNA sequences
    :param gtf_df: DataFrame of GTF annotations
    :param compression_type: Compression for TFRecord files (e.g., "GZIP")
    :param skip_empty: Skip empty chunks
    :param max_threads_per_process: Number of threads to use within this shard
    """
    print('Running write_to_shard_with_threads')
    
    # Step 1: Filter indices for this shard
    subset_indices = [
        idx for i, idx in enumerate(all_indices)
        if i % num_shards == shard_id
    ]
    print(f'Shard subset indices gathered for shard {shard_id}')
    
    # Step 2: Precompute all chunk data for this shard
    shard_chunks = list(build_chunk_data_for_indices(
        fasta_file, gtf_df, subset_indices, skip_empty=skip_empty
    ))
    [print(f'Shard chunks for {shard_id} built')]
    
    # Step 3: Split the chunks into groups for each thread
    chunk_splits = [
        shard_chunks[i::max_threads_per_process] for i in range(max_threads_per_process)
    ]
    print('Shard chunks built for write threads')
    
    # Step 4: Define the thread worker
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    lock = threading.Lock()  # To avoid race conditions during writes

    def thread_worker(chunk_group):
        """Thread worker to serialize and write chunks."""
        print('Thread worker initialized')
        with tf.io.TFRecordWriter(shard_path, options=options) as writer:
            for X, y, record_id, cstart, cend, strand_symbol in chunk_group:
                # Serialize the chunk
                example_str = serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol)
                with lock:  # Ensure only one thread writes at a time
                    writer.write(example_str)

    # Step 5: Launch threads to process each group of chunks
    with cf.ThreadPoolExecutor(max_threads_per_process) as thread_executor:
        thread_futures = [
            thread_executor.submit(thread_worker, chunk_group)
            for chunk_group in chunk_splits
        ]
        for future in thread_futures:
            future.result()
            print(f'future.result() line used for {shard_id}')

In [41]:
def parallel_write_tfrecord_in_shards_hybrid(
    shard_prefix,
    fasta_file,
    gtf_df,
    num_shards=4,
    compression_type="GZIP",
    max_processes=4,
    max_threads_per_process=4,
    chunk_size=5000,
    skip_empty=True
):
    """
    Writes data in multiple TFRecord shards using multiprocessing for shards
    and threading within each shard.

    :param shard_prefix: Base path for shards, e.g., "my_chunks"
    :param build_chunk_generator: Generator yielding (X, y, record_id, cstart, cend, strand)
    :param num_shards: Number of TFRecord shards to create
    :param compression_type: 'GZIP', 'ZLIB', or None
    :param max_processes: Number of processes to use for parallel writing
    :param max_threads_per_process: Number of threads to use within each process
    """
    print('Running parallel_write_tfrecord_in_shards_hybrid')
    # Compute all chunk indices first
    all_indices = compute_chunk_indices(fasta_file, chunk_size)
    print('all_indices calculated')
    
    # Create shard paths
    shard_paths = []
    for shard_id in range(num_shards):
        shard_path = f"{shard_prefix}-{shard_id:04d}.tfrecord"
        if compression_type:
            shard_path += ".gz"  # Naming convention
        shard_paths.append(shard_path)
        print(shard_paths)

     # Spawn multiple processes (one per shard, or up to max_processes)
    with cf.ProcessPoolExecutor(max_workers=max_processes) as process_executor:
        futures = []
        for shard_id in range(num_shards):
            fut = process_executor.submit(
                write_to_shard_with_threads,
                shard_id,
                shard_paths[shard_id],
                num_shards,
                all_indices,
                fasta_file,
                gtf_df,
                compression_type,
                skip_empty,
                max_threads_per_process
            )
            futures.append(fut)
            print('futures.append(fut) line has been run')

        # Wait for all shards to finish
        for future in futures:
            shard_done_id = future.result()
            print(f"Shard {shard_done_id} completed.")

In [42]:
my_fasta = 'test_genome.fa'
my_gtf_df = load_gtf_annotations('test_annotations.gtf')
output_directory = "test_tfrecord_multi"
if not os.path.exists(output_directory):
        os.makedirs(output_directory)
my_prefix = output_directory + '/inex_shard'
# all_indices = compute_chunk_indices(my_fasta, chunk_size=5000)

parallel_write_tfrecord_in_shards_hybrid(my_prefix, my_fasta, my_gtf_df, 4, 'GZIP', 4, 4, 5000, True)

Running parallel_write_tfrecord_in_shards_hybrid
Running compute_chunk_indices
all_indices calculated
['test_tfrecord_multi/inex_shard-0000.tfrecord.gz']
['test_tfrecord_multi/inex_shard-0000.tfrecord.gz', 'test_tfrecord_multi/inex_shard-0001.tfrecord.gz']
['test_tfrecord_multi/inex_shard-0000.tfrecord.gz', 'test_tfrecord_multi/inex_shard-0001.tfrecord.gz', 'test_tfrecord_multi/inex_shard-0002.tfrecord.gz']
['test_tfrecord_multi/inex_shard-0000.tfrecord.gz', 'test_tfrecord_multi/inex_shard-0001.tfrecord.gz', 'test_tfrecord_multi/inex_shard-0002.tfrecord.gz', 'test_tfrecord_multi/inex_shard-0003.tfrecord.gz']
futures.append(fut) line has been run
futures.append(fut) line has been run
futures.append(fut) line has been run
futures.append(fut) line has been run
Running write_to_shard_with_threads
Shard subset indices gathered for shard 0
running build_chunk_data_for_indices
Running write_to_shard_with_threads
Shard subset indices gathered for shard 1
running build_chunk_data_for_indices
Ru

In [60]:
print(all_indices)

[('chrX', 0, 5000), ('chrX', 5000, 10000), ('chrX', 10000, 15000), ('chrX', 15000, 20000), ('chrX', 20000, 25000), ('chrX', 25000, 30000), ('chrX', 30000, 35000), ('chrX', 35000, 40000), ('chrX', 40000, 45000), ('chrX', 45000, 50000), ('chrX', 50000, 55000), ('chrX', 55000, 60000), ('chrX', 60000, 65000), ('chrX', 65000, 70000), ('chrX', 70000, 75000), ('chrX', 75000, 80000), ('chrX', 80000, 85000), ('chrX', 85000, 90000), ('chrX', 90000, 95000), ('chrX', 95000, 100000), ('chrX', 100000, 105000), ('chrX', 105000, 110000), ('chrX', 110000, 115000), ('chrX', 115000, 120000), ('chrX', 120000, 125000), ('chrX', 125000, 130000), ('chrX', 130000, 135000), ('chrX', 135000, 140000), ('chrX', 140000, 145000), ('chrX', 145000, 150000), ('chrX', 150000, 155000), ('chrX', 155000, 160000), ('chrX', 160000, 165000), ('chrX', 165000, 170000), ('chrX', 170000, 175000), ('chrX', 175000, 180000), ('chrX', 180000, 185000), ('chrX', 185000, 190000), ('chrX', 190000, 195000), ('chrX', 195000, 200000), ('ch

In [16]:
my_prefix = 'test_tfrecord_multi'
my_fasta = 'test_genome.fa'
my_gtf_df = load_gtf_annotations('test_annotations.gtf')
output_directory = "test_tfrecord_multi"
all_indices = compute_chunk_indices(my_fasta, 5000)
test_indices = all_indices[:10]
print(len(all_indices))
subset_indices = [
        idx for i, idx in enumerate(all_indices)
        if i % 4 == 2
    ]
print(len(subset_indices))

42659
10665


In [17]:
print(subset_indices)

[('chrX', 10000, 15000), ('chrX', 30000, 35000), ('chrX', 50000, 55000), ('chrX', 70000, 75000), ('chrX', 90000, 95000), ('chrX', 110000, 115000), ('chrX', 130000, 135000), ('chrX', 150000, 155000), ('chrX', 170000, 175000), ('chrX', 190000, 195000), ('chrX', 210000, 215000), ('chrX', 230000, 235000), ('chrX', 250000, 255000), ('chrX', 270000, 275000), ('chrX', 290000, 295000), ('chrX', 310000, 315000), ('chrX', 330000, 335000), ('chrX', 350000, 355000), ('chrX', 370000, 375000), ('chrX', 390000, 395000), ('chrX', 410000, 415000), ('chrX', 430000, 435000), ('chrX', 450000, 455000), ('chrX', 470000, 475000), ('chrX', 490000, 495000), ('chrX', 510000, 515000), ('chrX', 530000, 535000), ('chrX', 550000, 555000), ('chrX', 570000, 575000), ('chrX', 590000, 595000), ('chrX', 610000, 615000), ('chrX', 630000, 635000), ('chrX', 650000, 655000), ('chrX', 670000, 675000), ('chrX', 690000, 695000), ('chrX', 710000, 715000), ('chrX', 730000, 735000), ('chrX', 750000, 755000), ('chrX', 770000, 7750

In [18]:
chunks = list(build_chunk_data_for_indices(my_fasta, my_gtf_df, subset_indices, skip_empty=True))
for chunk in chunks[:5]:
    print(chunk)

(array([[0., 0., 0., 1., 1.],
       [0., 0., 1., 0., 1.],
       [1., 0., 0., 0., 1.],
       ...,
       [1., 0., 0., 0., 1.],
       [0., 1., 0., 0., 1.],
       [0., 1., 0., 0., 1.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 4], dtype=int32), 'chrX', 250000, 255000, '+')
(array([[1., 0., 0., 0., 1.],
       [0., 0., 1., 0., 1.],
       [1., 0., 0., 0., 1.],
       ...,
       [0., 1., 0., 0., 1.],
       [0., 0., 0., 1., 1.],
       [0., 1., 0., 0., 1.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 0], dtype=int32), 'chrX', 290000, 295000, '+')
(array([[0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       ...,
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 4], dtype=int32), 'chrX', 310000, 315000, '-')
(array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.],
       ...,
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [

In [21]:
print(type(chunks[0]))
shard_chunks = list(chunks)

print(shard_chunks)

<class 'tuple'>
[(array([[0., 0., 0., 1., 1.],
       [0., 0., 1., 0., 1.],
       [1., 0., 0., 0., 1.],
       ...,
       [1., 0., 0., 0., 1.],
       [0., 1., 0., 0., 1.],
       [0., 1., 0., 0., 1.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 4], dtype=int32), 'chrX', 250000, 255000, '+'), (array([[1., 0., 0., 0., 1.],
       [0., 0., 1., 0., 1.],
       [1., 0., 0., 0., 1.],
       ...,
       [0., 1., 0., 0., 1.],
       [0., 0., 0., 1., 1.],
       [0., 1., 0., 0., 1.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 0], dtype=int32), 'chrX', 290000, 295000, '+'), (array([[0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       ...,
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 4], dtype=int32), 'chrX', 310000, 315000, '-'), (array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.],
       ...,
       [0., 1., 0., 0., 0.],
       [1., 0., 0

In [20]:
chunk_splits = [
    chunks[i::4] for i in range(4)
]
print(chunk_splits)

[[(array([[0., 0., 0., 1., 1.],
       [0., 0., 1., 0., 1.],
       [1., 0., 0., 0., 1.],
       ...,
       [1., 0., 0., 0., 1.],
       [0., 1., 0., 0., 1.],
       [0., 1., 0., 0., 1.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 4], dtype=int32), 'chrX', 250000, 255000, '+'), (array([[0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       ...,
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32), array([0, 0, 0, ..., 0, 0, 0], dtype=int32), 'chrX', 370000, 375000, '-'), (array([[0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.],
       ...,
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32), array([3, 0, 0, ..., 0, 0, 0], dtype=int32), 'chrX', 1010000, 1015000, '-'), (array([[0., 0., 0., 1., 1.],
       [0., 0., 1., 0., 1.],
       [0., 0., 0., 1., 1.],
       ...,
       [1., 0., 0., 0., 1.],
       [0., 1., 0., 0., 1.],
 

In [43]:
my_fasta = 'chr_genome.fa'
my_gtf_df = load_gtf_annotations('chr_annotations.gtf')
output_directory = "Optimized_TFRecord_Shards"
if not os.path.exists(output_directory):
        os.makedirs(output_directory)
my_prefix = output_directory + '/inex_shard'
parallel_write_tfrecord_in_shards_hybrid(my_prefix, my_fasta, my_gtf_df, 4, 'GZIP', 4, 4, 5000, True)

Running parallel_write_tfrecord_in_shards_hybrid
Running compute_chunk_indices
all_indices calculated
['Optimized_TFRecord_Shards/inex_shard-0000.tfrecord.gz']
['Optimized_TFRecord_Shards/inex_shard-0000.tfrecord.gz', 'Optimized_TFRecord_Shards/inex_shard-0001.tfrecord.gz']
['Optimized_TFRecord_Shards/inex_shard-0000.tfrecord.gz', 'Optimized_TFRecord_Shards/inex_shard-0001.tfrecord.gz', 'Optimized_TFRecord_Shards/inex_shard-0002.tfrecord.gz']
['Optimized_TFRecord_Shards/inex_shard-0000.tfrecord.gz', 'Optimized_TFRecord_Shards/inex_shard-0001.tfrecord.gz', 'Optimized_TFRecord_Shards/inex_shard-0002.tfrecord.gz', 'Optimized_TFRecord_Shards/inex_shard-0003.tfrecord.gz']
futures.append(fut) line has been run
futures.append(fut) line has been run
futures.append(fut) line has been run
futures.append(fut) line has been run
Running write_to_shard_with_threads
Shard subset indices gathered for shard 0
running build_chunk_data_for_indices
Running write_to_shard_with_threads
Shard subset indices 

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.