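This notebook converts FASTA (reference sequence) and GTF (annotation) data into TFRecord files. It uses parallel processing — multiple processes across shards and multiple threads within each shard — to keep the processing time reasonable.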

In [None]:
import time
import sys
import os
import glob
import math
import threading
import concurrent.futures as cf

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Input, Model, layers, metrics, losses, callbacks, optimizers, models, utils
from keras import backend as K
import gc
import keras_tuner as kt
from pyfaidx import Fasta

# Reset any Keras/TensorFlow state left over from a previous run, then
# run a garbage-collection pass so released objects are actually reclaimed.
K.clear_session()
gc.collect()

# Shared data/model directories, relative to this notebook's location.
datasets_path, models_path = "../../Datasets/", "../../Models/"

The first processing cell (below) defines utility functions and a generator that converts FASTA and GTF data into one-hot encoded arrays with per-base boundary labels, tagged with each chunk's genomic location (record ID, start, end, strand).

In [10]:
from keras import backend as K
import gc

# Repeats the session reset from the setup cell: reset Keras' global state and
# run a garbage-collection pass. Because gc.collect() is the last expression,
# the notebook displays its return value (the number of unreachable objects found).
K.clear_session()
gc.collect()

0

In [None]:
def load_gtf_annotations(gtf_file):
    """
    Load a GTF file into a pandas DataFrame with zero-based, half-open coordinates.

    GTF coordinates are 1-based and inclusive on both ends. Subtracting 1 from
    the start turns each feature into the zero-based, half-open interval
    [cstart, cend). The 1-based inclusive end already equals the exclusive end
    in that convention, so 'cend' is intentionally left unchanged.

    Parameters:
      gtf_file: Path or file-like object of a tab-separated GTF file. Text
        after a '#' is ignored by the parser (this skips header/comment lines;
        NOTE(review): it would also truncate a '#' inside an attribute value).

    Returns:
      pd.DataFrame with columns seqname, source, feature, cstart, cend, score,
      strand, frame, attribute.
    """
    gtf_data = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None,
        names=['seqname', 'source', 'feature', 'cstart', 'cend',
               'score', 'strand', 'frame', 'attribute'],
        # Keep text columns as str: otherwise numeric contig names such as "1"
        # are parsed as integers (possibly only in some chunks of a large file)
        # and never match the string record ids coming from the FASTA.
        dtype={'seqname': str, 'source': str, 'feature': str,
               'strand': str, 'frame': str, 'attribute': str},
    )
    # 1-based start -> 0-based start; see docstring for why cend is untouched.
    gtf_data['cstart'] = gtf_data['cstart'] - 1
    return gtf_data


# def compute_chunk_indices(fasta_file, chunk_size):
#     """
#     Creates a list of (record_id, cstart, cend) for each chunk in the FASTA.
#     This version does not take window shifts
#     """
#     print('Running compute_chunk_indices')
    
#     fa = Fasta(fasta_file)  # for indexed random access
#     chunk_indices = []
#     for record_id in fa.keys():
#         seq_len = len(fa[record_id])
#         for cstart in range(0, seq_len, chunk_size):
#             cend = min(cstart + chunk_size, seq_len)
#             chunk_indices.append((record_id, cstart, cend))
#     return chunk_indices


def compute_chunk_indices(fasta_file, chunk_size, shifts=None):
    """
    Build the list of (record_id, cstart, cend) windows to extract from a FASTA.

    For every record and every shift in ``shifts``, windows start at ``shift``
    and advance by ``chunk_size``; the last window of a record is truncated at
    the record end. Several shift values produce offset tilings of the genome
    (data augmentation). Bases before ``shift`` are not covered by that
    shift's tiling.

    Parameters:
      fasta_file (str): Path to an indexable FASTA file.
      chunk_size (int): Window length in bases; must be positive.
      shifts (iterable of int, optional): Start offset of each tiling.
        Defaults to (0,). Every shift must be non-negative.

    Returns:
      list of (str, int, int): Zero-based, half-open windows [cstart, cend).

    Raises:
      ValueError: if chunk_size <= 0 or any shift is negative.
    """
    print('Running compute_chunk_indices with data augmentation')
    # None instead of a mutable list default; behaves like the old [0] default.
    shifts = (0,) if shifts is None else tuple(shifts)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if any(shift < 0 for shift in shifts):
        raise ValueError(f"shifts must be non-negative, got {list(shifts)}")

    chunk_indices = []
    # The context manager releases the FASTA/index file handles when done.
    with Fasta(fasta_file) as fa:
        for record_id in fa.keys():
            seq_len = len(fa[record_id])
            for shift in shifts:
                for cstart in range(shift, seq_len, chunk_size):
                    chunk_indices.append((record_id, cstart, min(cstart + chunk_size, seq_len)))
    return chunk_indices


def one_hot_encode_reference(sequence):
    """
    One-hot encode a nucleotide sequence.

    Each base becomes a 4-element list in channel order (A, C, G, T). Matching
    is case-insensitive, so soft-masked (lowercase) reference bases are encoded
    like their uppercase counterparts instead of being treated as unknown.
    'N' and any other symbol (IUPAC ambiguity codes, gaps) map to all zeros.

    Parameters:
      sequence: str (or iterable of single-character strings).

    Returns:
      list[list[int]]: One fresh 4-element list per base; rows are never
      shared, so callers may modify them independently.
    """
    n_base_encoder = {
        'A': (1, 0, 0, 0),
        'C': (0, 1, 0, 0),
        'G': (0, 0, 1, 0),
        'T': (0, 0, 0, 1),
    }
    unknown = (0, 0, 0, 0)
    # upper(): many reference FASTAs soft-mask repeats in lowercase.
    return [list(n_base_encoder.get(nuc.upper(), unknown)) for nuc in sequence]


def label_sequence_local(sequence_length, annotations, window=50):
    """
    Build a (sequence_length x 5) soft-label matrix for exon/intron boundaries.

    Each boundary annotation gives full credit to its base and linearly
    decaying partial credit to its neighbours:
      - distance 0 (the annotated base): 1.0
      - distance d, for 1 <= d <= window: 0.5 - (0.5 / window) * (d - 1)
        (0.5 at d = 1, falling to 0.5 / window at d = window; for the default
        window=50 this is 0.5 - 0.01 * (d - 1), i.e. 0.01 at d = 50)
      - beyond window: 0.0
    When several annotations of the same channel are nearby, each base keeps
    the maximum credit it receives.

    Channels:
      - Column 0: non-coding (background). 0 where any other channel is
        exactly 1.0 (an annotated base); 1 everywhere else, including bases
        that only carry partial credit.
      - Column 1: intron start (base at cstart for feature "intron")
      - Column 2: intron end (base at cend-1 for feature "intron")
      - Column 3: exon start (base at cstart for feature "exon")
      - Column 4: exon end (base at cend-1 for feature "exon")
    Boundary bases outside [0, sequence_length) are ignored and spread no
    partial credit into the sequence. Feature names are compared after
    stripping whitespace and lower-casing; other features and missing values
    are ignored.

    Parameters:
      sequence_length (int): Length of the sequence.
      annotations (DataFrame): Must have columns 'cstart', 'cend', 'feature'
        in local zero-based, half-open coordinates.
      window (int): How far (in bases) partial credit spreads (default 50).

    Returns:
      np.ndarray: A (sequence_length x 5) float label matrix.
    """
    # astype(str) so a missing feature becomes 'nan' (ignored) instead of crashing .strip().
    features = annotations['feature'].astype(str).str.strip().str.lower().to_numpy()
    starts = annotations['cstart'].astype(int).to_numpy()
    # cend is exclusive, so the last base of a feature is cend - 1.
    last_bases = annotations['cend'].astype(int).to_numpy() - 1

    def mark(positions):
        """Binary track with 1.0 at every in-range position."""
        track = np.zeros(sequence_length)
        in_range = positions[(positions >= 0) & (positions < sequence_length)]
        track[in_range] = 1
        return track

    def smooth(binary):
        """Spread each 1.0 into the decaying partial-credit profile (max-combined)."""
        smoothed = binary.copy()
        # Offsets >= sequence_length cannot reach any base, so stop early.
        for d in range(1, min(window, sequence_length - 1) + 1):
            credit = max(0.5 - (0.5 / window) * (d - 1), 0)
            # Base i + d receives credit from an annotation at i.
            smoothed[d:] = np.maximum(smoothed[d:], credit * binary[:-d])
            # Base i - d receives credit from an annotation at i.
            smoothed[:-d] = np.maximum(smoothed[:-d], credit * binary[d:])
        return smoothed

    is_intron = features == 'intron'
    is_exon = features == 'exon'

    # Columns: [non-coding, intron cstart, intron cend, exon cstart, exon cend]
    labels = np.zeros((sequence_length, 5))
    labels[:, 1] = smooth(mark(starts[is_intron]))
    labels[:, 2] = smooth(mark(last_bases[is_intron]))
    labels[:, 3] = smooth(mark(starts[is_exon]))
    labels[:, 4] = smooth(mark(last_bases[is_exon]))

    # Background is 0 only at bases carrying a full (1.0) annotation.
    labels[:, 0] = np.where(labels[:, 1:].max(axis=1) == 1, 0, 1)
    return labels


def pad_labels(labels, target_length):
    """
    Pad a label matrix (rows of 5 values) up to ``target_length`` rows.

    Padding rows are [1, 0, 0, 0, 0], i.e. pure non-coding background.

    Parameters:
      labels: List of 5-element rows or an array of shape (n, 5); may be empty.
      target_length (int): Desired number of rows.

    Returns:
      The input object unchanged when it already has at least
      ``target_length`` rows (it is never truncated). Otherwise a new list of
      float rows.
    """
    current_length = len(labels)
    if current_length >= target_length:
        return labels
    background_row = np.array([[1, 0, 0, 0, 0]], dtype=float)
    label_array = np.asarray(labels, dtype=float)
    if label_array.size == 0:
        # An empty list becomes shape (0,), which cannot be concatenated with (k, 5) rows.
        label_array = label_array.reshape(0, background_row.shape[1])
    padding = np.repeat(background_row, target_length - current_length, axis=0)
    return np.concatenate([label_array, padding], axis=0).tolist()



def pad_encoded_seq(encoded_seq, target_length):
    """
    Pad an encoded sequence (rows of 5 values) with all-zero rows up to ``target_length``.

    The input list is not modified; a new list is returned. Each padding row
    is a distinct list object, so rows can be modified independently.
    Sequences already at or beyond ``target_length`` are returned as an
    untruncated copy.
    """
    pad_size = max(target_length - len(encoded_seq), 0)
    return list(encoded_seq) + [[0, 0, 0, 0, 0] for _ in range(pad_size)]


def build_chunk_data_for_indices(fasta_file, gtf_df, subset_indices, skip_empty=True, chunk_size=5000,
                                 label_window=100):
    """
    Yield encoded, labelled examples for each (record_id, cstart, cend) window.

    The reference bases of each window in ``subset_indices`` are read once.
    Then one example is produced per strand ('+' and '-'):
      1) X: a [chunk_size x 5] float32 array with 4 one-hot base channels plus
         a strand channel (1 for '+', 0 for '-'). The same reference bases are
         used for both strands (no reverse complement).
      2) y: a [chunk_size x 5] float32 label matrix from label_sequence_local,
         built from the GTF rows on this record and strand that overlap the
         window.
      3) If ``skip_empty`` is True and every label row is pure background
         [1, 0, 0, 0, 0], the example is skipped.
    Windows shorter than chunk_size (at a record end) are zero-padded in X and
    background-padded in y.

    Parameters:
      fasta_file (str): FASTA path, opened with pyfaidx.
      gtf_df (DataFrame): Annotations with 'seqname', 'strand', 'cstart',
        'cend', 'feature' in zero-based, half-open coordinates.
      subset_indices (iterable): (record_id, cstart, cend) windows.
      skip_empty (bool): Drop examples where no base has any full or partial label.
      chunk_size (int): Padded example length.
      label_window (int): Partial-credit spread passed to label_sequence_local.

    Yields:
      (X, y, record_id, cstart, cend, strand_symbol, chunk_size)
    """
    print('running build_chunk_data_for_indices')
    background = [1, 0, 0, 0, 0]

    # Pre-group GTF by (seqname, strand) so each window only scans its own contig/strand.
    grouped_gtf = {key: sub_df for key, sub_df in gtf_df.groupby(['seqname', 'strand'])}

    # The context manager releases pyfaidx file handles when the generator
    # finishes or is closed.
    with Fasta(fasta_file) as fa:
        for (record_id, cstart, cend) in subset_indices:
            seq = str(fa[record_id][cstart:cend])
            base_encoded_4 = one_hot_encode_reference(seq)
            chunk_len = len(base_encoded_4)  # < chunk_size at a record end

            for strand_symbol in ('+', '-'):
                strand_flag = 1 if strand_symbol == '+' else 0
                encoded_seq_5 = [row + [strand_flag] for row in base_encoded_4]

                labels = None
                sub_df = grouped_gtf.get((record_id, strand_symbol))
                if sub_df is not None:
                    # Half-open overlap test against [cstart, cend).
                    # NOTE(review): boundaries of features that end just before or start
                    # just after the window (within label_window) are not included, so
                    # their partial credit does not reach the window edges.
                    overlap = sub_df[(sub_df['cstart'] < cend) & (sub_df['cend'] > cstart)].copy()
                    if len(overlap) > 0:
                        # Shift to window-local coordinates (window start => 0).
                        overlap['cstart'] = overlap['cstart'] - cstart
                        overlap['cend'] = overlap['cend'] - cstart
                        labels = label_sequence_local(chunk_len, overlap, window=label_window).tolist()
                if labels is None:
                    # No annotations on this contig/strand/window: all background.
                    labels = [background] * chunk_len

                if skip_empty and all(lbl == background for lbl in labels):
                    continue

                encoded_seq_5 = pad_encoded_seq(encoded_seq_5, chunk_size)
                labels = pad_labels(labels, chunk_size)

                X = np.array(encoded_seq_5, dtype=np.float32)  # [chunk_size, 5]
                y = np.array(labels, dtype=np.float32)         # [chunk_size, 5]

                # chunk_size is passed through so readers can reshape the flattened record.
                yield (X, y, record_id, cstart, cend, strand_symbol, chunk_size)

The second cell defines helpers that serialize the output of the first cell's final function (build_chunk_data_for_indices) into tf.train.Example records for writing to a TFRecord file.

In [4]:
def float_feature_list(value_list):
    """
    Wrap a sequence of floats in a tf.train.Feature backed by a FloatList.

    Model inputs and soft labels are kept as floats because training
    (activations, losses, backpropagation) works on floating-point tensors.
    """
    floats = tf.train.FloatList(value=value_list)
    return tf.train.Feature(float_list=floats)


def int_feature_list(value_list):
    """
    Wrap a sequence of integers in a tf.train.Feature backed by an Int64List.

    In this notebook it stores integer metadata such as window coordinates
    and the chunk size.
    """
    ints = tf.train.Int64List(value=value_list)
    return tf.train.Feature(int64_list=ints)


def bytes_feature(value):
    """
    Wrap a single string or bytes value in a tf.train.Feature (BytesList).

    ``str`` input is UTF-8 encoded first; anything else is stored as given.
    """
    encoded = value.encode('utf-8') if isinstance(value, str) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded]))


def serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol, chunk_size = 5000):
    """
    Serialize one chunk into a tf.train.Example byte string.

    :param X: array (float32) of shape [chunk_size, 5]: one-hot bases plus strand channel
    :param y: array (float32) of shape [chunk_size, 5]: soft boundary labels
    :param record_id: str (e.g. chromosome name)
    :param cstart, cend: int window coordinates
    :param strand_symbol: '+' or '-'
    :param chunk_size: int, number of rows in X and y; stored so readers can reshape

    X and y are flattened in row-major (C) order and both stored as FloatLists.
    y holds fractional partial-credit labels, so it is not stored as ints.
    Readers must parse both as flat float features and reshape them to
    [chunk_size, -1].
    """
    feature_dict = {
        'X':          float_feature_list(np.ravel(X).tolist()),
        'y':          float_feature_list(np.ravel(y).tolist()),
        'record_id':  bytes_feature(record_id),
        # int() so numpy integer scalars are accepted uniformly.
        'cstart':     int_feature_list([int(cstart)]),
        'cend':       int_feature_list([int(cend)]),
        'strand':     bytes_feature(strand_symbol),
        'chunk_size': int_feature_list([int(chunk_size)]),
    }

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example.SerializeToString()

The third cell defines a function that generates the examples for one shard and splits the writing work across threads.

Uses functions defined in first and second cells.

In [5]:
def write_to_shard_with_threads(
    shard_id,
    shard_path,
    num_shards,
    all_indices,
    fasta_file,
    gtf_df,
    compression_type="GZIP",
    skip_empty=True,
    max_threads_per_process=4,
    chunk_size_input = 5000
):
    """
    Write one TFRecord shard, generating examples on several threads.

    The shard owns every window whose position in ``all_indices`` is congruent
    to ``shard_id`` modulo ``num_shards``. Those windows are dealt round-robin
    to ``max_threads_per_process`` threads. Each thread streams examples from
    build_chunk_data_for_indices and appends them to the shared writer under a
    lock, so the whole shard is never held in memory.

    NOTE(review): the threads share one interpreter (GIL), so any speedup
    presumably comes from work that releases the GIL (file reads,
    compression). Confirm with profiling.

    Parameters:
      shard_id (int): Index of this shard in [0, num_shards).
      shard_path (str): Output TFRecord path.
      num_shards (int): Total number of shards.
      all_indices (sequence): Every (record_id, cstart, cend) window.
      fasta_file (str): FASTA path.
      gtf_df (DataFrame): Annotations passed to build_chunk_data_for_indices.
      compression_type (str or None): 'GZIP', 'ZLIB', or None.
      skip_empty (bool): Skip examples without labels.
      max_threads_per_process (int): Number of worker threads.
      chunk_size_input (int): Padded example length.

    Errors raised while serializing or writing a single example are reported
    and that example is skipped. Errors raised while generating examples
    propagate to the caller (after the writer has been closed).
    """
    print('Running write_to_shard_with_threads')

    subset_indices = [
        idx for i, idx in enumerate(all_indices)
        if i % num_shards == shard_id
    ]
    print(f'Shard subset indices gathered for shard {shard_id}')

    options = tf.io.TFRecordOptions(compression_type=compression_type)
    lock = threading.Lock()  # serializes writer.write() across threads

    def thread_worker(subset_indices_split, writer):
        """Generate, serialize and write every example for one split of windows."""
        for X, y, record_id, cstart, cend, strand_symbol, chunk_size in build_chunk_data_for_indices(
            fasta_file, gtf_df, subset_indices_split, skip_empty=skip_empty, chunk_size=chunk_size_input
        ):
            try:
                example_str = serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol, chunk_size)
                with lock:
                    writer.write(example_str)
            except Exception as e:
                # Best effort: report which example failed, skip it, keep going.
                print(f"Error writing chunk {record_id}:{cstart}-{cend} ({strand_symbol}): {e}")

    # Deal the shard's windows round-robin to the threads.
    subset_splits = [
        subset_indices[i::max_threads_per_process] for i in range(max_threads_per_process)
    ]

    # The context manager flushes and closes the writer even if a thread fails;
    # an unclosed compressed writer can leave a truncated shard on disk.
    with tf.io.TFRecordWriter(shard_path, options=options) as writer:
        with cf.ThreadPoolExecutor(max_threads_per_process) as thread_executor:
            thread_futures = [
                thread_executor.submit(thread_worker, subset_split, writer)
                for subset_split in subset_splits
            ]
            for worker_number, future in enumerate(thread_futures):
                future.result()  # re-raises any exception from the worker thread
                print(f'Thread executor {worker_number} completed for process executor {shard_id}.')

The fourth cell splits data generation across multiple processes (one task per shard) and passes the configuration options down to the lower-level functions.

Uses functions defined in first and third cells.

In [None]:
def write_tfrecord_in_shards_hybrid(
    shard_prefix,
    fasta_file,
    gtf_df,
    num_shards=4,
    compression_type="GZIP",
    max_processes=4,
    max_threads_per_process=4,
    chunk_size=5000,
    skip_empty=True,
    shifts=None
):
    """
    Write the whole dataset as ``num_shards`` TFRecord files in parallel.

    Each shard is written by a separate process (up to ``max_processes`` at a
    time), and each process uses ``max_threads_per_process`` threads
    (see write_to_shard_with_threads).

    :param shard_prefix: Base path for shards; shard i is written to
        "{shard_prefix}-{i:04d}.tfrecord", plus ".gz" when compression_type is "GZIP"
    :param fasta_file: Path to the reference FASTA
    :param gtf_df: Annotation DataFrame (zero-based, half-open cstart/cend)
    :param num_shards: Number of TFRecord shards to create
    :param compression_type: 'GZIP', 'ZLIB', or None
    :param max_processes: Number of processes to use for parallel writing
    :param max_threads_per_process: Number of threads to use within each process
    :param chunk_size: Window length and padded example length
    :param skip_empty: Skip examples without labels
    :param shifts: Window start offsets for data augmentation (default [0])
    """
    print('Running write_tfrecord_in_shards_hybrid')
    # None instead of a mutable list default; equivalent to the old [0].
    if shifts is None:
        shifts = [0]

    # Compute all chunk indices once, in the parent process.
    all_indices = compute_chunk_indices(fasta_file, chunk_size, shifts=shifts)
    print('all_indices calculated')

    suffix = ".gz" if compression_type == "GZIP" else ""  # naming convention
    shard_paths = [f"{shard_prefix}-{shard_id:04d}.tfrecord{suffix}" for shard_id in range(num_shards)]
    print(shard_paths)

    # Note: every submission pickles all_indices and gtf_df to the worker process.
    with cf.ProcessPoolExecutor(max_workers=max_processes) as process_executor:
        futures = []
        for shard_id, shard_path in enumerate(shard_paths):
            futures.append(process_executor.submit(
                write_to_shard_with_threads,
                shard_id,
                shard_path,
                num_shards,
                all_indices,
                fasta_file,
                gtf_df,
                compression_type,
                skip_empty,
                max_threads_per_process,
                chunk_size,
            ))
            print(f'Process executor {shard_id} has started')

        for shard_number, future in enumerate(futures):
            future.result()  # re-raises any exception from the worker process
            print(f"Process executor {shard_number} completed.")

74 mins 37.5 seconds with compression

73 mins 28.6 seconds without compression

Compression adds only about a minute of runtime but shrinks the output roughly 27-fold, so the I/O savings far outweigh the time spent compressing and decompressing.

Compressed: 200 MB

Not compressed: 5.4 GB

In [None]:
def main():
    """Generate one augmented set of TFRecord shards from the trimmed genome."""
    my_fasta = 'trim_chr_genome.fa'
    # seqname as str so numeric contig names still match FASTA record ids.
    my_gtf_df = pd.read_csv("FinalIntronExonDF.csv", dtype={'seqname': str})
    output_directory = "Final_Optimized_TFRecord_Shards"
    os.makedirs(output_directory, exist_ok=True)

    # Change the shifts for each augmentation run. The shard prefix is derived
    # from them, so a new run cannot silently overwrite another run's shards.
    shifts = [3334]
    my_prefix = os.path.join(output_directory, '_'.join(map(str, shifts)) + '_inex_shard')

    write_tfrecord_in_shards_hybrid(
        shard_prefix=my_prefix,
        fasta_file=my_fasta,
        gtf_df=my_gtf_df,
        num_shards=4,
        compression_type="GZIP",
        max_processes=4,
        max_threads_per_process=2,
        chunk_size=5000,
        skip_empty=True,
        shifts=shifts
    )

if __name__ == "__main__":
    main()

Running write_tfrecord_in_shards_hybrid
Running compute_chunk_indices with data augmentation
all_indices calculated
['Final_Optimized_TFRecord_Shards/3334_inex_shard-0000.tfrecord.gz', 'Final_Optimized_TFRecord_Shards/3334_inex_shard-0001.tfrecord.gz', 'Final_Optimized_TFRecord_Shards/3334_inex_shard-0002.tfrecord.gz', 'Final_Optimized_TFRecord_Shards/3334_inex_shard-0003.tfrecord.gz']
Process executor 0 has started
Process executor 1 has started
Process executor 2 has started
Process executor 3 has started
Running write_to_shard_with_threads
Shard subset indices gathered for shard 0
running build_chunk_data_for_indicesrunning build_chunk_data_for_indices

Running write_to_shard_with_threads
Shard subset indices gathered for shard 1
running build_chunk_data_for_indicesrunning build_chunk_data_for_indices

Running write_to_shard_with_threads
Shard subset indices gathered for shard 2
running build_chunk_data_for_indicesrunning build_chunk_data_for_indices

Running write_to_shard_with_thr

Dataset size grew from 0.79 GB to 1.14 GB after a few fixes; that run took 89 min 25 s.