This notebook is meant to contain every step needed to generate the intron-exon data (I'm only human, so I may have missed something), because the data itself is way too big for GitHub.  It is a Frankenstein's monster of a notebook: the code comes from a bunch of different notebooks, so it's possible something is missing somewhere.  A lot of the style choices are not terribly consistent because of that, too.

GTF and Fasta files can be found on this page: https://www.gencodegenes.org/human/

GTF file is the 'Basic gene annotation' with CHR as the 'regions' and can be direct downloaded here: https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_47/gencode.v47.basic.annotation.gtf.gz

I unzipped and saved it as 'basic_annotations.gtf'

The fasta file is 'Genome sequence (GRCh38.p14)' with ALL as the 'regions' and can be direct downloaded here: https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_47/GRCh38.p14.genome.fa.gz

I unzipped and saved it as 'chr_genome.fa'

In [None]:
import time
import sys
import os
import glob
import math
import threading
import concurrent.futures as cf
import random
import re

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Input, Model, layers, metrics, losses, callbacks, optimizers, models, utils
from keras import backend as K
import gc
import keras_tuner as kt
from pyfaidx import Fasta

# Reset the Keras backend session and run a garbage-collection pass before
# any work is done in this notebook.
K.clear_session()
gc.collect()

# Relative directories for datasets and saved models; datasets_path is
# prefixed onto file names in the cells below.
datasets_path = "../../Datasets/"
models_path = "../../Models/"

2025-03-17 01:10:21.974832: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-17 01:10:21.993167: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-17 01:10:21.998951: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-17 01:10:22.013136: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


0

Writes out a new fasta that uses the full sequences of the 22 numbered chromosomes and the X and Y chromosomes.  chrM has no introns.

In [None]:
# Records to keep: the 22 numbered chromosomes plus X and Y.
valid_chroms = {f"chr{i}" for i in range(1, 23)} | {"chrX", "chrY"}

# chrM is left out because there aren't introns on it and it would only
# confuse the training data.
input_fasta = datasets_path + "chr_genome.fa"
output_fasta = datasets_path + "trim_chr_genome.fa"

# Removes non-chromosome entries from the human genome FASTA.
# A line starting with ">" is a record key: the first word after ">" is
# compared against valid_chroms, and that decision (keep / drop) applies to
# every following sequence line until the next record key is reached.
keep = False
with open(input_fasta, "r") as fin, open(output_fasta, "w") as fout:
    for line in fin:
        is_header = line.startswith(">")
        if is_header:
            chrom_name = line.strip()[1:].split()[0]
            keep = chrom_name in valid_chroms
        if not keep:
            continue
        fout.write(line)
        if is_header:
            # Progress report: one line per chromosome that is kept.
            print(chrom_name)

Loads gtf annotations with proper 0 index numbering

In [None]:
def load_gtf_annotations(gtf_file):
    """
    Read a GTF file into a pandas DataFrame and shift 'cstart' to zero-based.

    The file is parsed as tab-separated text without a header row; anything
    from a '#' onward is treated as a comment by the parser.  The nine GTF
    fields are named seqname, source, feature, cstart, cend, score, strand,
    frame and attribute.

    Only 'cstart' is decremented by one; 'cend' is left exactly as read.
    (Assuming the usual 1-based, end-inclusive GTF convention, this makes
    each row a half-open [cstart, cend) interval.)

    Parameters:
      gtf_file: Path or file-like object accepted by pandas.read_csv.

    Returns:
      pd.DataFrame: One row per GTF record.
    """
    column_names = [
        'seqname', 'source', 'feature', 'cstart', 'cend',
        'score', 'strand', 'frame', 'attribute',
    ]
    frame = pd.read_csv(
        gtf_file,
        sep='\t',
        comment='#',
        header=None,
        names=column_names,
    )
    frame['cstart'] = frame['cstart'].sub(1)
    return frame

A handy helper function for searching the GTF annotations by sequence name and coordinate range.

In [None]:
def search_gtf_by_range(gtf_df, seqname, pos_min, pos_max, require_both=False):
    """
    Select annotation rows on one sequence whose boundaries fall in a range.

    Parameters:
      gtf_df (pd.DataFrame): Annotations with at least the columns
                             'seqname', 'cstart' and 'cend'.
      seqname (str): Sequence name to match exactly (e.g. 'chr1').
      pos_min (int): Lower bound of the position range (inclusive).
      pos_max (int): Upper bound of the position range (inclusive).
      require_both (bool):
           - False (default): keep rows where 'cstart' OR 'cend' is in range.
           - True: keep only rows where BOTH 'cstart' and 'cend' are in range.

    Returns:
      pd.DataFrame: The matching rows, with their original index labels.
    """
    # Narrow to the requested sequence before testing coordinates.
    on_sequence = gtf_df[gtf_df['seqname'] == seqname]

    # Series.between is inclusive on both ends by default.
    start_in_range = on_sequence['cstart'].between(pos_min, pos_max)
    end_in_range = on_sequence['cend'].between(pos_min, pos_max)

    if require_both:
        selected = start_in_range & end_in_range
    else:
        selected = start_in_range | end_in_range

    return on_sequence[selected]

Defines the function that calculates intron cstart and cend based on the exon cstart and cend values in the basic annotation gtf.

In [None]:
def calculate_introns(gtf_df):
    """
    Derive intron intervals for every gene from its exon records.

    Rows are grouped by gene id.  If the frame has no 'gene_id' column, one
    is extracted from the 'attribute' column using the pattern
    gene_id "XYZ"; (rows without a match are ignored).  For each gene the
    exon intervals are sorted by start, overlapping or touching exons are
    merged into their union, and every gap between two consecutive merged
    exons is reported as one intron.

    Coordinates are reported in ascending genomic order for BOTH strands:
    'cstart' is the end of the left-hand merged exon and 'cend' is the start
    of the right-hand one, so cstart < cend always holds.  No strand-dependent
    reversal is applied.

    Parameters:
      gtf_df (pd.DataFrame): Needs the columns 'seqname', 'feature',
          'cstart', 'cend', 'strand' and either 'gene_id' or 'attribute'.
          The input frame is not modified.

    Returns:
      pd.DataFrame: One row per intron with the columns seqname, feature
      (always "intron"), cstart, cend and strand.  The columns are present
      even when no intron is found.
    """
    output_columns = ['seqname', 'feature', 'cstart', 'cend', 'strand']

    # Add a gene_id column if needed, working on a copy so the caller's
    # frame is left untouched.
    if 'gene_id' not in gtf_df.columns:
        gtf_df = gtf_df.copy()
        gtf_df['gene_id'] = gtf_df['attribute'].str.extract(
            r'gene_id\s+"([^"]+)"', expand=False
        )

    intron_records = []

    # groupby drops rows whose gene_id is missing, so no explicit check for
    # an absent id is required inside the loop.
    for _gene_id, group in gtf_df.groupby('gene_id'):
        exon_rows = group[group['feature'] == 'exon']
        if exon_rows.empty:
            continue

        # Chromosome and strand come from the gene-level record when there
        # is one, otherwise from the first row of the group.
        gene_rows = group[group['feature'] == 'gene']
        anchor = gene_rows.iloc[0] if not gene_rows.empty else group.iloc[0]
        seqname = anchor['seqname']
        strand = anchor['strand']

        # Sort exon intervals by start and merge any that overlap or touch
        # (e.g. alternative transcripts of the same gene).
        exon_intervals = sorted(
            zip(exon_rows['cstart'], exon_rows['cend']),
            key=lambda interval: interval[0],
        )
        merged_exons = []
        for exon_start, exon_end in exon_intervals:
            if merged_exons and exon_start <= merged_exons[-1][1]:
                merged_exons[-1][1] = max(merged_exons[-1][1], exon_end)
            else:
                merged_exons.append([exon_start, exon_end])

        # Each gap between neighbouring merged exons is an intron; a gene
        # with a single merged exon therefore contributes nothing.
        for left_exon, right_exon in zip(merged_exons, merged_exons[1:]):
            intron_start = left_exon[1]
            intron_end = right_exon[0]
            if intron_end > intron_start:
                intron_records.append({
                    'seqname': seqname,
                    'feature': 'intron',
                    'cstart': intron_start,
                    'cend': intron_end,
                    'strand': strand,
                })

    return pd.DataFrame(intron_records, columns=output_columns)

Data wrangling and applying calculate_introns

In [None]:
# Load the basic annotation and drop chrM records before deriving introns.
annotation_data = load_gtf_annotations(datasets_path + 'basic_annotations.gtf')
annotation_data = annotation_data[annotation_data["seqname"] != "chrM"]
introns = calculate_introns(annotation_data)

# Keep only the columns that the intron table also has.
trimmed_annotation_data = annotation_data[["seqname", "feature", "cstart", "cend", "strand"]]

# ignore_index=True gives the combined table a fresh 0..n-1 index.  Without
# it the intron rows reuse index labels already taken by annotation rows,
# and duplicate labels are a hazard for the label-based .loc assignment
# applied to this frame in a later cell.
IntronExonDF = pd.concat([trimmed_annotation_data, introns], ignore_index=True)

# IntronExonDF.to_csv(datasets_path + 'IntronExonDF.csv', index=False)
# introns.to_csv(datasets_path + 'BetterIntrons.csv', index=False)

Below is a fix to an old version of calculate_introns, included to keep the pipeline working.

In [None]:
def swap_columns_if_needed(df, col_a, col_b):
    """
    Make sure col_a never exceeds col_b by swapping the pair where it does.

    Written as a fix for output of an older intron calculation in which some
    rows had the boundary columns in descending order.  Rows where the value
    in col_a is greater than the value in col_b get the two values exchanged;
    all other rows and columns are left alone.

    The dataframe is modified IN PLACE and the same object is returned.

    Parameters:
        df (pd.DataFrame): The dataframe to process.
        col_a (str): Column that should hold the smaller value.
        col_b (str): Column that should hold the larger value.

    Returns:
        pd.DataFrame: The same dataframe object, after swapping.
    """
    # Rows whose pair is in descending order.
    needs_swap = df[col_a].gt(df[col_b])

    # Take the pair in reversed column order as a plain array (so pandas
    # does not re-align on column names) and write it back in normal order.
    reversed_pairs = df.loc[needs_swap, [col_b, col_a]].to_numpy()
    df.loc[needs_swap, [col_a, col_b]] = reversed_pairs

    return df

Used the print statement in the next cell to confirm the merge worked by comparing to https://genome.ucsc.edu/cgi-bin/hgGateway  

On the human genome UCSC browser, due to indexing 1 on their end and 0 in python, cstart here is the last base to the right of the feature on the browser.

cend looks like the correct spot, but only because python excludes the last base which cancels out the off by 1 issue

Printing line by line swaps back into 1 indexing in a .txt file so locations are accurate as long as line 1 has the first base. That might not always be the case if I put a print statement somewhere without thinking.

In [None]:
# Put every row's boundaries in ascending order (modifies IntronExonDF in place).
FixedIntronExonDF = swap_columns_if_needed(IntronExonDF, 'cstart', 'cend')

# Keep only exon and intron records, and only the five columns used downstream.
is_exon_or_intron = FixedIntronExonDF["feature"].isin(["exon", "intron"])
Trimmed_Intron_Exon_DF = FixedIntronExonDF.loc[
    is_exon_or_intron, ["seqname", "feature", "cstart", "cend", "strand"]
]

# Random rows for a manual spot check.
print(Trimmed_Intron_Exon_DF.sample(10))

Trimmed_Intron_Exon_DF.to_csv(datasets_path + "FinalIntronExonDF.csv", index=False)

The next five code cells are used to parallel process the GTF and Fasta data into tfrecord.gz files.

First cell defines utility functions and a function that converts fasta and gtf data into encoded arrays with DNA location info as annotations.

In [None]:
def load_gtf_annotations(gtf_file):
    """
    Load a GTF file as a pandas DataFrame with a zero-based 'cstart' column.

    The file is read as headerless, tab-separated text with '#' as the
    comment character, and the nine fields are named seqname, source,
    feature, cstart, cend, score, strand, frame and attribute.

    One is subtracted from 'cstart' only; 'cend' keeps the value found in
    the file.

    Parameters:
      gtf_file: Path or file-like object accepted by pandas.read_csv.

    Returns:
      pd.DataFrame: One row per GTF record.
    """
    gtf_columns = (
        'seqname', 'source', 'feature', 'cstart', 'cend',
        'score', 'strand', 'frame', 'attribute',
    )
    table = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None, names=list(gtf_columns)
    )
    zero_based_start = table['cstart'] - 1
    table['cstart'] = zero_based_start
    return table


# def compute_chunk_indices(fasta_file, chunk_size):
#     """
#     Creates a list of (record_id, cstart, cend) for each chunk in the FASTA.
#     This version does not take window shifts
#     """
#     print('Running compute_chunk_indices')
    
#     fa = Fasta(fasta_file)  # for indexed random access
#     chunk_indices = []
#     for record_id in fa.keys():
#         seq_len = len(fa[record_id])
#         for cstart in range(0, seq_len, chunk_size):
#             cend = min(cstart + chunk_size, seq_len)
#             chunk_indices.append((record_id, cstart, cend))
#     return chunk_indices


def compute_chunk_indices(fasta_file, chunk_size, shifts=(0,)):
    """
    List the (record_id, cstart, cend) window for every chunk of the FASTA.

    For each record and each value in `shifts`, windows start at the shift
    and advance by `chunk_size`; the last window of a pass is clipped to the
    record length.  Passing several shifts yields several overlapping
    tilings of the same sequence, which is used to augment the dataset.

    Parameters:
      fasta_file (str): Path of the FASTA file opened with pyfaidx.
      chunk_size (int): Window length in bases; must be positive.
      shifts (iterable of int): Start offsets, one tiling pass per value.
          Defaults to a single unshifted pass.

    Returns:
      list of tuple: (record_id, cstart, cend) with cstart inclusive and
      cend exclusive.

    Raises:
      ValueError: If chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    print('Running compute_chunk_indices with data augmentation')
    fa = Fasta(fasta_file)  # for indexed random access
    try:
        chunk_indices = []
        for record_id in fa.keys():
            seq_len = len(fa[record_id])
            for shift in shifts:
                # Start at the given shift, then step by chunk_size.
                chunk_indices.extend(
                    (record_id, cstart, min(cstart + chunk_size, seq_len))
                    for cstart in range(shift, seq_len, chunk_size)
                )
    finally:
        # Release the file handle held by the FASTA reader.
        fa.close()
    return chunk_indices


def one_hot_encode_reference(sequence):
    """
    One-hot encode a nucleotide sequence, one 4-element row per base.

    Channel order is (A, C, G, T).  Upper- and lower-case letters are
    encoded identically, so soft-masked (lower-case) bases keep their
    identity instead of collapsing to the all-zero row.  'N' and any other
    symbol encode as [0, 0, 0, 0].
    NOTE(review): treating lower-case bases as real bases assumes the input
    FASTA uses lower case only for soft-masking -- confirm for your data.

    Parameters:
      sequence (iterable of str): The bases, typically a string.

    Returns:
      list of list of int: One freshly created 4-element list per base, so
      rows can be modified independently of each other.
    """
    channels_by_base = {
        'A': (1, 0, 0, 0), 'a': (1, 0, 0, 0),
        'C': (0, 1, 0, 0), 'c': (0, 1, 0, 0),
        'G': (0, 0, 1, 0), 'g': (0, 0, 1, 0),
        'T': (0, 0, 0, 1), 't': (0, 0, 0, 1),
    }
    unknown = (0, 0, 0, 0)
    # list(...) gives every position its own row object; handing out the
    # same list for repeated bases would let one edit leak into the others.
    return [list(channels_by_base.get(nuc, unknown)) for nuc in sequence]


def label_sequence_local(sequence_length, annotations, window=50):
    """
    Build a per-base label matrix with soft credit around feature boundaries.

    Four boundary channels are filled from the annotations, and a fifth
    background channel is derived from them:

      - Column 0: background.  0 where any boundary channel equals exactly 1,
                  otherwise 1 (including bases that only carry partial credit).
      - Column 1: intron start (base at 'cstart' of an "intron" row)
      - Column 2: intron end   (base at 'cend' - 1 of an "intron" row)
      - Column 3: exon start   (base at 'cstart' of an "exon" row)
      - Column 4: exon end     (base at 'cend' - 1 of an "exon" row)

    Credit assigned around a boundary at position i:
      - position i itself: 1.0
      - positions i +/- d for d = 1..window:
            max(0.5 - (0.5 / window) * (d - 1), 0)
        i.e. 0.5 next to the boundary, falling linearly by 0.5 / window per
        base (0.01 per base with the default window of 50).
    Where several boundaries of the same channel are close together, each
    base keeps the largest credit it receives.

    Boundaries that fall outside [0, sequence_length) are ignored, as are
    rows whose feature is neither "exon" nor "intron" (the feature name is
    compared after stripping whitespace and lower-casing).

    Parameters:
      sequence_length (int): Number of bases to label.
      annotations (DataFrame): Must have columns 'cstart', 'cend', 'feature'.
      window (int): How far (in bases) partial credit is spread (default 50).

    Returns:
      np.ndarray: A (sequence_length x 5) label matrix.
    """
    # Exact boundary markers: feature name -> (start track, end track).
    boundary_marks = {
        'intron': (np.zeros(sequence_length), np.zeros(sequence_length)),
        'exon': (np.zeros(sequence_length), np.zeros(sequence_length)),
    }

    for record in annotations.itertuples(index=False):
        start_pos = int(record.cstart)
        last_pos = int(record.cend) - 1  # cend is exclusive
        tracks = boundary_marks.get(record.feature.strip().lower())
        if tracks is None:
            continue
        start_track, end_track = tracks
        if 0 <= start_pos < sequence_length:
            start_track[start_pos] = 1
        if 0 <= last_pos < sequence_length:
            end_track[last_pos] = 1

    def spread_credit(marks):
        """Turn a 0/1 marker track into the soft-credit track described above."""
        length = len(marks)
        soft = np.zeros(length)
        hits = np.flatnonzero(marks == 1)
        if hits.size == 0:
            return soft

        # falloff[d - 1] is the credit at distance d from a boundary.
        falloff = np.array(
            [max(0.5 - (0.5 / window) * (d - 1), 0) for d in range(1, window + 1)]
        )
        for pos in hits:
            soft[pos] = 1.0

            # Left neighbours: the nearest one sits at the end of the slice,
            # so the falloff is applied in reverse.
            reach_left = min(window, pos)
            if reach_left > 0:
                left_side = soft[pos - reach_left:pos]
                np.maximum(left_side, falloff[:reach_left][::-1], out=left_side)

            # Right neighbours, clipped at the end of the sequence.
            reach_right = min(window, length - 1 - pos)
            if reach_right > 0:
                right_side = soft[pos + 1:pos + 1 + reach_right]
                np.maximum(right_side, falloff[:reach_right], out=right_side)
        return soft

    # Columns: [background, intron start, intron end, exon start, exon end]
    labels = np.zeros((sequence_length, 5))
    intron_starts, intron_ends = boundary_marks['intron']
    exon_starts, exon_ends = boundary_marks['exon']
    labels[:, 3] = spread_credit(exon_starts)
    labels[:, 4] = spread_credit(exon_ends)
    labels[:, 1] = spread_credit(intron_starts)
    labels[:, 2] = spread_credit(intron_ends)

    # Background is switched off only where a boundary is marked exactly.
    has_exact_boundary = labels[:, 1:].max(axis=1) == 1
    labels[:, 0] = np.where(has_exact_boundary, 0, 1)

    return labels


def pad_labels(labels, target_length):
    """
    Extend a label sequence to target_length with background rows.

    Each added row is [1, 0, 0, 0, 0].  `labels` may be a NumPy array of
    shape [current_length, 5] or a nested list of the same shape.

    Note on the return type: when rows are added, the result is a nested
    Python list; when `labels` is already at least target_length long it is
    returned unchanged, as the very same object.
    """
    missing_rows = target_length - len(labels)
    if missing_rows <= 0:
        return labels

    # Stack the single background row `missing_rows` times.
    background_block = np.repeat(np.array([[1, 0, 0, 0, 0]]), missing_rows, axis=0)
    return np.concatenate([labels, background_block], axis=0).tolist()


def pad_encoded_seq(encoded_seq, target_length):
    """
    Pad an encoded sequence of shape (seq_len, 5) up to (target_length, 5).

    Padding rows are [0, 0, 0, 0, 0].  The caller's list is never modified:
    when padding is needed a new list is returned, made of the original rows
    followed by independent padding rows (each padding row is its own list
    object).  When the sequence is already at least target_length long it
    is returned as-is.

    Parameters:
      encoded_seq (list of list): Rows of 5 values each.
      target_length (int): Desired number of rows.

    Returns:
      list of list: The sequence with at least target_length rows.
    """
    pad_size = target_length - len(encoded_seq)
    if pad_size <= 0:
        return encoded_seq
    # A comprehension rather than `[[...]] * pad_size`, which would repeat
    # one shared row object pad_size times.
    return list(encoded_seq) + [[0, 0, 0, 0, 0] for _ in range(pad_size)]


def build_chunk_data_for_indices(fasta_file, gtf_df, subset_indices, skip_empty=True, chunk_size=5000):
    """
    Generate model inputs and labels for a list of FASTA chunks.

    For each chunk (record_id, cstart, cend):
      - the reference bases are read from the FASTA and one-hot encoded;
      - for each strand ('+' then '-'):
          1) the input gets a 5th channel holding the strand flag
             (1 for '+', 0 for '-');
          2) annotations on that record and strand that overlap
             [cstart, cend) are shifted to chunk-local coordinates and
             turned into labels with label_sequence_local (window=100);
          3) if skip_empty is True and every label row is pure background
             ([1, 0, 0, 0, 0]), the chunk/strand pair is skipped;
          4) inputs and labels are padded to chunk_size rows.

    Parameters:
      fasta_file (str): Path of the FASTA file opened with pyfaidx.
      gtf_df (pd.DataFrame): Annotations with columns 'seqname', 'strand',
          'cstart', 'cend' and 'feature'.
      subset_indices (iterable): (record_id, cstart, cend) tuples.
      skip_empty (bool): Skip chunk/strand pairs without any label signal.
      chunk_size (int): Number of rows every yielded array is padded to.

    Yields:
      tuple: (X, y, record_id, chunk_cstart, chunk_cend, strand, chunk_size)
      where X and y are float32 arrays of shape [chunk_size, 5].
    """
    print('running build_chunk_data_for_indices')

    background_row = [1, 0, 0, 0, 0]

    # Pre-group GTF by (seqname, strand) to speed up filtering.
    grouped_gtf = {
        key: sub_df for key, sub_df in gtf_df.groupby(['seqname', 'strand'])
    }

    fa = Fasta(fasta_file)
    try:
        for record_id, cstart, cend in subset_indices:
            # Read and encode the reference chunk once for both strands.
            seq = str(fa[record_id][cstart:cend])
            base_encoded_4 = one_hot_encode_reference(seq)  # (chunk_len, 4)

            # Can be < chunk_size at the end of a chromosome.
            chunk_len = len(base_encoded_4)

            for strand_symbol in ('+', '-'):
                labels = None
                sub_df = grouped_gtf.get((record_id, strand_symbol))
                if sub_df is not None:
                    # Keep only rows that overlap [cstart, cend).
                    overlap = sub_df[
                        (sub_df['cstart'] < cend) &
                        (sub_df['cend'] > cstart)
                    ].copy()

                    if len(overlap) > 0:
                        # Shift coords so that the chunk start becomes 0.
                        overlap['cstart'] = overlap['cstart'] - cstart
                        overlap['cend'] = overlap['cend'] - cstart
                        labels = label_sequence_local(chunk_len, overlap, window=100).tolist()

                if labels is None:
                    # No annotation for this record/strand in the chunk:
                    # every base is background.  Each row is its own list.
                    labels = [list(background_row) for _ in range(chunk_len)]

                # Optionally skip chunks whose labels are all background.
                if skip_empty and all(lbl == background_row for lbl in labels):
                    continue

                # Append the 5th channel for strand: 1 for +, 0 for -.
                strand_flag = 1 if strand_symbol == '+' else 0
                encoded_seq_5 = [row + [strand_flag] for row in base_encoded_4]

                # Pad up to chunk_size if needed.
                encoded_seq_5 = pad_encoded_seq(encoded_seq_5, chunk_size)
                labels = pad_labels(labels, chunk_size)

                X = np.array(encoded_seq_5, dtype=np.float32)  # [chunk_size, 5]
                y = np.array(labels, dtype=np.float32)         # [chunk_size, 5]

                # chunk_size is passed through so the record can be reshaped
                # when it is read back later.
                yield (X, y, record_id, cstart, cend, strand_symbol, chunk_size)
    finally:
        # Runs when the generator is exhausted, closed or fails, so the
        # FASTA file handle is always released.
        fa.close()

Second cell defines functions to optimize output of first cell's last function for use in a tfrecord file.

In [None]:
def float_feature_list(value_list):
    """
    Wrap a list of floats in a tf.train.Feature holding a FloatList.

    Floats are needed for backpropagation, activation functions, and loss
    calculations and are thus the default type in TF and PyTorch.
    """
    packed_floats = tf.train.FloatList(value=value_list)
    return tf.train.Feature(float_list=packed_floats)


def int_feature_list(value_list):
    """
    Wrap a list of ints in a tf.train.Feature holding an Int64List.
    """
    packed_ints = tf.train.Int64List(value=value_list)
    return tf.train.Feature(int64_list=packed_ints)


def bytes_feature(value):
    """
    Wrap a single string or bytes value in a tf.train.Feature (BytesList).

    A str is encoded as UTF-8 first; any other value is stored as given.
    """
    raw = value.encode('utf-8') if isinstance(value, str) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))


def serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol, chunk_size = 5000):
    """
    Converts a single chunk's data into a serialized tf.train.Example.

    :param X: np.array(float32) with the encoded input, e.g. [chunk_size, 5]
    :param y: np.array(float32) with the label matrix, e.g. [chunk_size, 5]
    :param record_id: str (e.g. chromosome name)
    :param cstart, cend: int chunk boundaries
    :param strand_symbol: '+' or '-'
    :param chunk_size: int, stored alongside the data so the reader can
                       restore the array shapes

    X and y are both flattened in row-major order and stored as float
    lists, so they have to be reshaped at read time.

    :return: bytes, the serialized Example
    """
    # Flatten to 1D (row-major) and convert to plain Python lists.
    X_flat = X.ravel().tolist()
    y_flat = y.ravel().tolist()

    # Build the feature dictionary with the utility functions that cast each
    # value into the protobuf type TensorFlow expects.  The integer fields
    # go through int() so NumPy integer scalars are accepted as well.
    feature_dict = {
        'X':           float_feature_list(X_flat),
        'y':           float_feature_list(y_flat),
        'record_id':   bytes_feature(record_id),
        'cstart':      int_feature_list([int(cstart)]),
        'cend':        int_feature_list([int(cend)]),
        'strand':      bytes_feature(strand_symbol),
        'chunk_size':  int_feature_list([int(chunk_size)]),
    }

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example.SerializeToString()

Third cell defines a function that performs row/example generation and breaks writing process into threads.  

Uses functions defined in first and second cells.

In [None]:
def write_to_shard_with_threads(
    shard_id,
    shard_path,
    num_shards,
    all_indices,
    fasta_file,
    gtf_df,
    compression_type="GZIP",
    skip_empty=True,
    max_threads_per_process=4,
    chunk_size_input = 5000
):
    """
    Write one TFRecord shard, producing its examples in several threads.

    The shard takes every num_shards-th entry of all_indices, starting at
    position shard_id.  Those entries are split round-robin across
    max_threads_per_process threads; each thread builds its chunks with
    build_chunk_data_for_indices, serializes them and writes them to the
    shared writer under a lock.  Chunks are produced one at a time by a
    generator rather than being collected first.

    The writer is used as a context manager so the shard file is flushed and
    closed when the function ends, including when a worker raises.

    Parameters:
      shard_id (int): Index of this shard, 0 <= shard_id < num_shards.
      shard_path (str): Output path of the TFRecord file.
      num_shards (int): Total number of shards.
      all_indices (list): (record_id, cstart, cend) tuples for all chunks.
      fasta_file (str): Path of the FASTA file.
      gtf_df (pd.DataFrame): Annotation table passed to the chunk builder.
      compression_type (str): Passed to tf.io.TFRecordOptions.
      skip_empty (bool): Passed to the chunk builder.
      max_threads_per_process (int): Number of worker threads.
      chunk_size_input (int): Chunk size passed to the chunk builder.
    """
    print('Running write_to_shard_with_threads')

    # Filter indices for this shard
    subset_indices = [
        idx for i, idx in enumerate(all_indices)
        if i % num_shards == shard_id
    ]
    print(f'Shard subset indices gathered for shard {shard_id}')

    # Divide the subset_indices into splits for each thread
    subset_splits = [
        subset_indices[i::max_threads_per_process] for i in range(max_threads_per_process)
    ]

    options = tf.io.TFRecordOptions(compression_type=compression_type)
    lock = threading.Lock()  # Ensure thread-safe writes

    with tf.io.TFRecordWriter(shard_path, options=options) as writer:

        def thread_worker(subset_indices_split):
            """Thread worker to process and write chunks."""
            for X, y, record_id, cstart, cend, strand_symbol, chunk_size in build_chunk_data_for_indices(
                fasta_file, gtf_df, subset_indices_split, skip_empty=skip_empty, chunk_size=chunk_size_input
                ):
                try:
                    # Serialize the chunk
                    example_str = serialize_chunk_example(X, y, record_id, cstart, cend, strand_symbol, chunk_size)
                    with lock:  # Ensure thread-safe writes
                        writer.write(example_str)
                except Exception as e:
                    # Best effort: report the failed chunk and carry on with
                    # the remaining ones.
                    print(f"Error writing chunk: {e}")

        # Start threads
        with cf.ThreadPoolExecutor(max_threads_per_process) as thread_executor:
            thread_futures = [
                thread_executor.submit(thread_worker, subset_split)
                for subset_split in subset_splits
            ]

            # Wait for each worker; result() re-raises any exception that
            # escaped the worker (e.g. from the chunk generator itself).
            for worker_number, future in enumerate(thread_futures):
                future.result()
                print(f'Thread executor {worker_number} completed for process executor {shard_id}.')

Fourth cell splits data generation into multiple processes and feeds options to lower functions.

Uses functions defined in first and third cells.

In [None]:
def write_tfrecord_in_shards_hybrid(
    shard_prefix,
    fasta_file,
    gtf_df,
    num_shards=4,
    compression_type="GZIP",
    max_processes=4,
    max_threads_per_process=4,
    chunk_size=5000,
    skip_empty=True,
    shifts=(0,)
):
    """
    Writes data in multiple TFRecord shards using multiprocessing for shards
    and threading within each shard (hybrid).

    All chunk indices are computed once with compute_chunk_indices; each
    shard is then written by write_to_shard_with_threads in a worker process.

    Var_name reminders:
    :param shard_prefix: Base path for shards, e.g., "my_chunks".  Shard
                         files are named "<prefix>-<4-digit id>.tfrecord",
                         with ".gz" appended for GZIP compression.
    :param fasta_file: Path of the FASTA file
    :param gtf_df: Annotation DataFrame handed to every shard writer
    :param num_shards: Number of TFRecord shards to create
    :param compression_type: 'GZIP', 'ZLIB', or None
    :param max_processes: Number of processes to use for parallel writing
    :param max_threads_per_process: Number of threads to use within each process
    :param chunk_size: Chunk length in bases
    :param skip_empty: Skip chunks whose labels are all background
    :param shifts: Iterable of window start offsets (data augmentation);
                   defaults to a single unshifted pass
    """
    print('Running write_tfrecord_in_shards_hybrid')

    # Compute all chunk indices first
    all_indices = compute_chunk_indices(fasta_file, chunk_size, shifts=shifts)
    print('all_indices calculated')

    # Create shard paths (".gz" suffix is a naming convention for GZIP)
    suffix = ".gz" if compression_type == "GZIP" else ""
    shard_paths = [
        f"{shard_prefix}-{shard_id:04d}.tfrecord{suffix}"
        for shard_id in range(num_shards)
    ]
    print(shard_paths)

    # Spawn multiple processes (one per shard, or up to max_processes)
    with cf.ProcessPoolExecutor(max_workers=max_processes) as process_executor:
        futures = []
        for shard_id, shard_path in enumerate(shard_paths):
            proc = process_executor.submit(
                write_to_shard_with_threads,
                shard_id,
                shard_path,
                num_shards,
                all_indices,
                fasta_file,
                gtf_df,
                compression_type,
                skip_empty,
                max_threads_per_process,
                chunk_size
            )
            futures.append(proc)
            print(f'Process executor {shard_id} has started')

        # Wait for every shard; result() re-raises an exception from the
        # worker process, if there was one.
        for shard_number, future in enumerate(futures):
            future.result()
            print(f"Process executor {shard_number} completed.")

The cell below processes the data in parallel and writes the tfrecord.gz shards.  It turns out that a notebook cell runs with `__name__` set to `"__main__"`, so the usual `if __name__ == "__main__": main()` pattern works here.

In [None]:
def main():
    """
    Generate one augmented set of gzip-compressed TFRecord shards.

    Reads the trimmed genome fasta and the intron/exon annotation CSV from
    `datasets_path`, makes sure the output directory exists, and hands
    everything to write_tfrecord_in_shards_hybrid.

    The shard filename prefix is derived from `shifts`, so the only value
    that has to be edited between runs is the `shifts` list itself.
    """
    my_fasta = datasets_path + 'trim_chr_genome.fa'
    my_gtf_df = pd.read_csv(datasets_path + "FinalIntronExonDF.csv")
    output_directory = datasets_path + "Final_Optimized_TFRecord_Shards"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(output_directory, exist_ok=True)

    # Change this every time. The shifts are for data augmentation.
    # It's more recommended to do a single shift at a time as done here;
    # mixing multiple shifted datasets can be done later on.
    shifts = [3334]

    # Naming convention to work with the shuffle function: the 4 digit shift
    # values, zero padded, lowest to highest.
    # ex: [0, 2500, 4000] becomes /000025004000_inex_shard
    shift_tag = "".join(f"{shift:04d}" for shift in sorted(shifts))
    my_prefix = f"{output_directory}/{shift_tag}_inex_shard"

    write_tfrecord_in_shards_hybrid(
        shard_prefix=my_prefix,
        fasta_file=my_fasta,
        gtf_df=my_gtf_df,
        num_shards=4,
        compression_type="GZIP",
        max_processes=4,
        max_threads_per_process=2,
        chunk_size=5000,
        skip_empty=True,
        shifts=shifts
    )


if __name__ == "__main__":
    main()

Clunky shuffle method in the following 3 code cells

In [None]:
def parse_chunk_example(serialized_example):
    """
    Decodes one serialized tf.train.Example into dense tensors plus metadata.
    Used in testing datasets and in piping tfrecords to DL Algorithms

    Returns the tuple (X, y, record_id, cstart, cend, strand), where X and y
    have both been reshaped to [chunk_size, 5].
    """
    schema = {
        'X':          tf.io.VarLenFeature(tf.float32),
        'y':          tf.io.VarLenFeature(tf.float32),
        'record_id':  tf.io.FixedLenFeature([], tf.string),
        'cstart':     tf.io.FixedLenFeature([1], tf.int64),
        'cend':       tf.io.FixedLenFeature([1], tf.int64),
        'strand':     tf.io.FixedLenFeature([], tf.string),
        'chunk_size': tf.io.FixedLenFeature([1], tf.int64),
    }
    example = tf.io.parse_single_example(serialized_example, schema)

    # chunk_size is stored with shape [1]; index it to get the scalar
    n_rows = example['chunk_size'][0]
    target_shape = [n_rows, 5]

    # X and y were written flattened and are parsed as sparse tensors, so
    # densify each one and restore the [chunk_size, 5] layout
    X_dense = tf.reshape(tf.sparse.to_dense(example['X']), target_shape)
    y_dense = tf.reshape(tf.sparse.to_dense(example['y']), target_shape)

    return (
        X_dense,
        y_dense,
        example['record_id'],
        example['cstart'][0],
        example['cend'][0],
        example['strand'],
    )

In [None]:
def build_dataset_from_tfrecords(
    tfrecord_pattern,
    batch_size=28,
    compression_type='GZIP',
    shuffle_buffer=66000,
):
    '''
    Builds a shuffled dataset of raw records from the tfrecord files
    matching tfrecord_pattern.  The records are returned still serialized
    (unparsed), so the result is not human readable.
    '''

    def open_shard(fname):
        return tf.data.TFRecordDataset(fname, compression_type=compression_type)

    # Read the shards round robin (4 files at a time, 1 record per turn)
    # for slightly increased mixing
    shard_files = tf.data.Dataset.list_files(tfrecord_pattern, shuffle=True)
    records = shard_files.interleave(
        open_shard,
        cycle_length=4,
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    return (
        records
        # Shuffle at the record level
        .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
        # Shuffle at the batch level, then flatten back to single records
        .batch(batch_size)
        .shuffle(8 * batch_size, reshuffle_each_iteration=True)
        .unbatch()
        # Prefetch for efficient access
        .prefetch(tf.data.AUTOTUNE)
    )

Ran this manually 10 times to get a good shuffle for the initial dataset.

In [None]:
# One manual shuffle pass: stream the previous pass's shards (Shuffle_9 here)
# through build_dataset_from_tfrecords as raw serialized records.  The
# commented-out code below would write that stream back out round-robin into
# the next pass's directory (Shuffle_10).
# `options` is only consumed by the commented-out writers.
options = tf.io.TFRecordOptions(compression_type="GZIP")
tfrecord_pattern = datasets_path + "Shuffling/Shuffle_9/shuffled_shard_*.tfrecord.gz"
ds = build_dataset_from_tfrecords(tfrecord_pattern,
                                  batch_size=32, compression_type='GZIP',
                                  shuffle_buffer=50000)

'''Commented out so I don't accidentally try to rewrite anything'''
# output_path = datasets_path + "Shuffling/Shuffle_10"
# if not os.path.exists(output_path):
#     os.makedirs(output_path)

# num_shards = 4
# writers = [
#     tf.io.TFRecordWriter(f"{output_path}/shuffled_shard_{i}.tfrecord.gz", options=options)
#     for i in range(num_shards)
# ]

# # Write out round-robin to each shard
# for i, serialized_example in enumerate(ds):
#     shard_index = i % num_shards
#     writers[shard_index].write(serialized_example.numpy())

# # Close all writers
# for w in writers:
#     w.close()

Better modular method: Split shards from parallel process into 24 tiny shards each, then write out a dataset that grabs one record from each tiny shard in a random order.

First, split shards into tiny shards

In [None]:
def split_tfrecords(input_directory, output_directory, num_splits=24):
    """
    Splits each TFRecord in the input_directory (assumed to be gzipped TFRecords)
    into 'num_splits' smaller TFRecords. The goal is to randomize the order of the
    smaller TFRecords when grabbing a row to produce highly shuffled data.

    Each output file is named based on the original filename's components:
    "<set>_inex_shard-<NNNN>..." becomes
    "<set>_<last digit of NNNN>_<split index>_tiny_inex_shard.tfrecord.gz".
    Directory entries whose name does not follow that pattern (for example
    notebook checkpoint folders) are reported and skipped instead of
    aborting the whole run.

    Parameters:
      input_directory (str): Relative or absolute path to the input TFRecords.
      output_directory (str): Relative or absolute path where the smaller shards will be saved.
      num_splits (int): Number of splits (shards) to create per input file.

    Raises:
      ValueError: If num_splits is smaller than 1.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be at least 1, got {num_splits}")

    shard_marker = "_inex_shard-"
    current_directory = os.getcwd()
    input_path = os.path.join(current_directory, input_directory)
    output_path = os.path.join(current_directory, output_directory)
    os.makedirs(output_path, exist_ok=True)

    for basename in sorted(os.listdir(input_path)):
        # Example filename: "1000_inex_shard-0002.tfrecord.gz"
        set_index, marker, remainder = basename.partition(shard_marker)
        # First 4 characters after the marker are the shard number.
        sub_index = remainder[:4]
        if not marker or not sub_index:
            print(f"Skipping {basename}: name does not contain '{shard_marker}<shard number>'.")
            continue

        file = os.path.join(input_path, basename)
        # NOTE: only the final digit of the shard number is kept, so two
        # inputs of the same set whose shard numbers end in the same digit
        # would map to the same output names.
        final_digit = sub_index[-1]

        # First pass: Count records without loading them all.
        total_records = sum(
            1 for _ in tf.data.TFRecordDataset(file, compression_type="GZIP")
        )

        # Even splits: the first `remainder_count` tiny shards get one extra record.
        chunk_size, remainder_count = divmod(total_records, num_splits)

        # boundaries[i] is the exclusive end index of tiny shard i.
        boundaries = []
        end = 0
        for i in range(num_splits):
            end += chunk_size + (1 if i < remainder_count else 0)
            boundaries.append(end)

        options = tf.io.TFRecordOptions(compression_type="GZIP")
        writers = []
        try:
            # Open all the TFRecord writers.
            for i in range(num_splits):
                new_filename = f"{set_index}_{final_digit}_{i:02d}_tiny_inex_shard.tfrecord.gz"
                new_filepath = os.path.join(output_path, new_filename)
                writers.append(tf.io.TFRecordWriter(new_filepath, options=options))

            # Second pass: Write records to the appropriate shard in a streaming fashion.
            current_shard = 0
            records = tf.data.TFRecordDataset(file, compression_type="GZIP")
            for current_index, record in enumerate(records):
                # The last boundary equals total_records, so this never runs
                # past the final writer.
                while current_index >= boundaries[current_shard]:
                    current_shard += 1
                writers[current_shard].write(record.numpy())
        finally:
            # Close every writer that was opened, even if reading or writing failed.
            for writer in writers:
                writer.close()

        print(f"Processed {basename}: {total_records} records split into {num_splits} tiny shards.")

In [None]:
# input_directory = datasets_path + "Augmented TFRecords 3/"
# output_directory = datasets_path + "Shuffle Shards/"
# split_tfrecords(input_directory, output_directory, 24)

Now construct the dataset taking one sample from each tiny shard in a random order.  Generates a new random order each time after all tiny shards are visited.

In [None]:
def stream_shuffled_records(input_dir, allowed_indices):
    """
    Generator that lazily walks the TFRecord files in input_dir whose filenames
    start with one of the allowed_indices.  Every round the open files are put
    in a fresh random order and one record is taken from each; a file that has
    run out of records is dropped from the following rounds.

    Args:
        input_dir (str): Directory containing the TFRecord files.
        allowed_indices (list): List of allowed starting indices (as strings or integers).

    Yields:
        A TFRecord (as a tf.Tensor) from one of the files.

    Raises:
        ValueError: If no filename in input_dir starts with an allowed index.
    """
    # Collect the paths of every file whose name starts with an allowed index.
    matching_paths = []
    for fname in os.listdir(input_dir):
        for idx in allowed_indices:
            if fname.startswith(str(idx)):
                matching_paths.append(os.path.join(input_dir, fname))
                break

    if len(matching_paths) == 0:
        raise ValueError("No TFRecord files found matching allowed indices.")

    # Pair each path with an iterator over that file's records.
    live = []
    for path in matching_paths:
        records = tf.data.TFRecordDataset(path, compression_type="GZIP")
        live.append((path, iter(records)))

    # One pass of this loop is one round; stop once every file is exhausted.
    while live:
        random.shuffle(live)
        survivors = []
        for entry in live:
            path, record_iter = entry
            try:
                yield next(record_iter)
                survivors.append(entry)
            except StopIteration:
                print(f"File {path} is exhausted and will be skipped.")
        live = survivors

def write_shuffled_records_to_single_tfrecord(input_dir, allowed_indices, output_filepath):
    """
    Writes all records produced by stream_shuffled_records into one big gzip-compressed TFRecord file.

    The writer is used as a context manager so the output file is closed
    even if reading or writing fails part way through.

    Args:
        input_dir (str): Directory containing the source TFRecord files.
        allowed_indices (list): List of allowed starting indices.
        output_filepath (str): Full path to the output TFRecord file.
    """
    # Set up the TFRecord writer with gzip compression.
    options = tf.io.TFRecordOptions(compression_type="GZIP")

    record_count = 0
    with tf.io.TFRecordWriter(output_filepath, options=options) as writer:
        # Stream through the records.
        for record in stream_shuffled_records(input_dir, allowed_indices):
            writer.write(record.numpy())
            record_count += 1
            # Print a status update every 1000 records.
            if record_count % 1000 == 0:
                print(f"{record_count} records written...")

    print(f"Finished writing {record_count} records to {output_filepath}")

In [None]:
if __name__ == "__main__":
    # Define the directory containing your shuffled tiny shards.
    input_directory = datasets_path + "Shuffle Shards"  # Adjust as needed.

    # Define the allowed starting indices (adjust to your needs).
    allowed_indices = ["0000", "2500"]  # Example indices.

    # Define the output filepath for the big combined TFRecord.
    output_filename = datasets_path + "AugDataSets/00002500_shuffled.tfrecord.gz"
    current_directory = os.getcwd()
    output_filepath = os.path.join(current_directory, output_filename)

    # Make sure the destination folder exists before the writer tries to
    # open the output file inside it.
    os.makedirs(os.path.dirname(output_filepath), exist_ok=True)

    # Run the function to write the big TFRecord.
    write_shuffled_records_to_single_tfrecord(input_directory, allowed_indices, output_filepath)

The next two cells print tfrecord data to check quality

In [None]:
def test_dataset_from_tfrecords(
    tfrecord_pattern,
    batch_size=32,
    compression_type='GZIP',
    shuffle_buffer=75000
):
    '''
    Loads the tfrecord(s) matching tfrecord_pattern, shuffles them, parses
    every record with parse_chunk_example and returns a batched, human
    readable dataset.
    Two goals:
        1. To confirm tfrecord(s) is/are saved properly
        2. To view list of record_ids in the batch to see if dataset
            is sufficiently shuffled.  Ideally, a good spread of chrN
            shows up.
    '''

    def open_shard(fname):
        return tf.data.TFRecordDataset(fname, compression_type=compression_type)

    # Read the shards round robin (4 files at a time, 1 record per turn)
    # for slightly increased mixing
    shard_files = tf.data.Dataset.list_files(tfrecord_pattern, shuffle=True)
    records = shard_files.interleave(
        open_shard,
        cycle_length=4,
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    return (
        records
        # Shuffle at the record level
        .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
        # Shuffle at the batch level
        .batch(batch_size)
        .shuffle(8 * batch_size, reshuffle_each_iteration=True)
        # Back to single records so each one can be parsed
        .unbatch()
        .map(parse_chunk_example, num_parallel_calls=tf.data.AUTOTUNE)
        # Rebatch the parsed records and prefetch for efficient reading
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

In [None]:
# Quality check: load the Shuffle_10 shards through the parsing pipeline and
# print the contents of a single batch.
tfrecord_pattern = datasets_path + "Shuffling/Shuffle_10/shuffled_shard_*.tfrecord.gz"

ds = test_dataset_from_tfrecords(tfrecord_pattern,
                                  batch_size=32, compression_type='GZIP',
                                  shuffle_buffer=50000)

# take(1) limits the loop to one batch; each element unpacks into the six
# fields returned by parse_chunk_example, batched together.
for X_batch, y_batch, record_id_batch, cstart_batch, cend_batch, strand_batch in ds.take(1):
    print("X shape:", X_batch.shape)
    print("y shape:", y_batch.shape)
    # The record_id list is what shows how well mixed the batch is.
    print("record_id:", record_id_batch)
    print("cstart:", cstart_batch)
    print("cend:", cend_batch)
    print("strand:", strand_batch)
    # Might not work as written due to changes since writing this
    # for i in range(5000):
    #     print(f"Data: {X_batch[0][i]},   {y_batch[0][i]} :Label")
    # print(f"chr: {record_id_batch[0]}, cstart: {cstart_batch[0]}, cend: {cend_batch[0]}")

The next two code cells can be used to generate Test, Validate, Train splits

In [None]:
def tvt_split_tfrecords(original_pattern, train_path, val_path, test_path, train_frac=0.8, val_frac=0.10):
    """
    Splits TFRecord files into separate train, validation, and test sets *without parsing*.
    Reads raw serialized records and writes them into new TFRecord files.

    The records are counted in a first pass, then read again through a
    25000-record shuffle buffer; the first train_frac of that stream goes to
    train, the next val_frac to validation and whatever is left to test.

    Args:
        original_pattern: Glob pattern of the gzip-compressed TFRecord file(s) to split.
        train_path: Output path of the train TFRecord.
        val_path: Output path of the validation TFRecord.
        test_path: Output path of the test TFRecord.
        train_frac: Fraction of the records written to train.
        val_frac: Fraction of the records written to validation.

    Raises:
        ValueError: If a fraction is negative, if the two fractions add up to
            more than 1, or if no file matches original_pattern.
    """
    # Validate before touching any file so a bad call leaves no empty outputs.
    # The small tolerance keeps float sums such as 0.9 + 0.1 acceptable.
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1 + 1e-9:
        raise ValueError(
            f"train_frac and val_frac must be non-negative and sum to at most 1, "
            f"got train_frac={train_frac}, val_frac={val_frac}"
        )

    # List the original TFRecord files (once, reused for both passes)
    source_files = tf.io.gfile.glob(original_pattern)
    if not source_files:
        raise ValueError(f"No TFRecord files match pattern: {original_pattern}")

    # First pass: count the records
    num_records = sum(
        1 for _ in tf.data.TFRecordDataset(source_files, compression_type='GZIP')
    )
    print(f"Total records found: {num_records}")

    # Compute split sizes
    train_size = int(train_frac * num_records)
    val_size   = int(val_frac * num_records)
    test_size  = num_records - train_size - val_size  # Ensuring all records are accounted for

    print(f"Splitting into -> Train: {train_size}, Val: {val_size}, Test: {test_size}")

    # Second pass: shuffled stream of raw records
    dataset = tf.data.TFRecordDataset(source_files, compression_type='GZIP')
    dataset = dataset.shuffle(25000, reshuffle_each_iteration=True)

    options = tf.io.TFRecordOptions(compression_type="GZIP")
    train_count, val_count, test_count = 0, 0, 0

    # The context managers close all three writers even if a write fails.
    with tf.io.TFRecordWriter(train_path, options=options) as train_writer, \
            tf.io.TFRecordWriter(val_path, options=options) as val_writer, \
            tf.io.TFRecordWriter(test_path, options=options) as test_writer:
        for i, raw_record in enumerate(dataset):
            if i < train_size:
                train_writer.write(raw_record.numpy())
                train_count += 1
            elif i < train_size + val_size:
                val_writer.write(raw_record.numpy())
                val_count += 1
            else:
                test_writer.write(raw_record.numpy())
                test_count += 1

    print(f"Final Split Counts -> Train: {train_count}, Val: {val_count}, Test: {test_count}")

Using the tvt_split_tfrecords function

In [None]:
# directory = datasets_path + "AugDataSets/New/"
# paths = os.listdir(directory)
# for filename in paths:
#     pattern = directory + filename
#     tvt_split_tfrecords(
#         original_pattern=pattern,
#         train_path="TestValTrain/train_" + filename,
#         val_path="TestValTrain/val_" + filename,
#         test_path="TestValTrain/test_" + filename,
#     )

The custom not-quite-label-smoothing was hard-coded into the data.  The cell below can remove it, generating fully binary versions of the datasets passed to it.

In [None]:
def convert_labels_to_binary(x, y):
    """
    Returns (x, y_binary): x is passed through untouched, and y_binary holds
    1 wherever y is exactly 1 and 0 everywhere else, in y's own dtype.
    Both x and y are expected to be tensors of shape (chunk_size, 5).
    """
    is_one = tf.equal(y, 1.0)
    return x, tf.cast(is_one, y.dtype)

def _float_feature(value):
    """Wraps a list of floats in a tf.train.Feature holding a FloatList."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wraps a list of ints in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    """
    Wraps a single string (or byte string) in a tf.train.Feature holding a
    one-element BytesList.  A str is UTF-8 encoded first.
    """
    payload = value.encode('utf-8') if isinstance(value, str) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)

def serialize_example_with_metadata_no_convert(x, y, record_id, cstart, cend, strand):
    """
    Builds a tf.train.Example from one sample and returns it serialized.
    No label conversion happens here: y is written exactly as given, so it
    should already be binary.
    x and y are tensors of shape (chunk_size, 5) and are stored flattened;
    chunk_size is stored alongside them. All metadata is preserved.
    """
    def to_python(value):
        # Tensors are pulled out to their numpy value; anything else passes through.
        return value.numpy() if isinstance(value, tf.Tensor) else value

    def to_int(value):
        # Accepts either a scalar or a sequence whose first element is the value.
        raw = to_python(value)
        if isinstance(raw, (list, tuple, np.ndarray)):
            return int(raw[0])
        return int(raw)

    features = tf.train.Features(feature={
        'X': _float_feature(tf.reshape(x, [-1]).numpy().tolist()),
        'y': _float_feature(tf.reshape(y, [-1]).numpy().tolist()),
        'record_id': _bytes_feature(to_python(record_id)),
        'cstart': _int64_feature([to_int(cstart)]),
        'cend': _int64_feature([to_int(cend)]),
        'strand': _bytes_feature(to_python(strand)),
        # chunk_size is the first dimension of x
        'chunk_size': _int64_feature([int(x.shape[0])]),
    })

    return tf.train.Example(features=features).SerializeToString()

def parse_chunk_example(serialized_example):
    """
    Decodes one serialized tf.train.Example into dense tensors plus metadata.
    The TFRecord is expected to carry these fields:
      - 'X': VarLenFeature(tf.float32)
      - 'y': VarLenFeature(tf.float32)
      - 'record_id': FixedLenFeature([], tf.string)
      - 'cstart': FixedLenFeature([1], tf.int64)
      - 'cend': FixedLenFeature([1], tf.int64)
      - 'strand': FixedLenFeature([], tf.string)
      - 'chunk_size': FixedLenFeature([1], tf.int64)

    Returns the tuple (X, y, record_id, cstart, cend, strand), where X and y
    have both been reshaped to [chunk_size, 5].
    """
    schema = {
        'X':          tf.io.VarLenFeature(tf.float32),
        'y':          tf.io.VarLenFeature(tf.float32),
        'record_id':  tf.io.FixedLenFeature([], tf.string),
        'cstart':     tf.io.FixedLenFeature([1], tf.int64),
        'cend':       tf.io.FixedLenFeature([1], tf.int64),
        'strand':     tf.io.FixedLenFeature([], tf.string),
        'chunk_size': tf.io.FixedLenFeature([1], tf.int64),
    }
    example = tf.io.parse_single_example(serialized_example, schema)

    # chunk_size is stored with shape [1]; index it to get the scalar
    target_shape = [example['chunk_size'][0], 5]

    def densify(name):
        # Sparse -> dense, then back to the [chunk_size, 5] layout
        return tf.reshape(tf.sparse.to_dense(example[name]), target_shape)

    X_dense = densify('X')
    y_dense = densify('y')

    return (
        X_dense,
        y_dense,
        example['record_id'],
        example['cstart'][0],
        example['cend'][0],
        example['strand'],
    )

def convert_and_write_tfrecord(input_tfrecord, output_tfrecord, compression_type="GZIP"):
    """
    Rewrites a TFRecord file with binary labels: every record of
    input_tfrecord (smoothed labels) is parsed, its labels are converted to
    binary, and the result is serialized to output_tfrecord with the same
    metadata.

    Args:
      input_tfrecord: Path to the original TFRecord file.
      output_tfrecord: Path where the new TFRecord (with binary labels) will be saved.
      compression_type: Compression type used in the TFRecord (e.g., "GZIP").
    """
    def binarize(x, y, record_id, cstart, cend, strand):
        # Only the labels change; the metadata rides along untouched.
        x_out, y_out = convert_labels_to_binary(x, y)
        return x_out, y_out, record_id, cstart, cend, strand

    # Read -> parse -> binarize
    raw_records = tf.data.TFRecordDataset(
        input_tfrecord,
        compression_type=compression_type,
        num_parallel_reads=tf.data.AUTOTUNE,
    )
    parsed = raw_records.map(parse_chunk_example, num_parallel_calls=tf.data.AUTOTUNE)
    converted = parsed.map(binarize, num_parallel_calls=tf.data.AUTOTUNE)

    # Serialize every converted sample into the new file.
    writer_options = tf.io.TFRecordOptions(compression_type=compression_type)
    with tf.io.TFRecordWriter(output_tfrecord, options=writer_options) as writer:
        for fields in converted:
            # fields is (X, y_binary, record_id, cstart, cend, strand)
            writer.write(serialize_example_with_metadata_no_convert(*fields))

In [None]:
# Commented out to prevent overwrites/rewrites
# Convert and write new TFRecord files for train, validation, and test splits.
# convert_and_write_tfrecord(datasets_path + "TestValTrain/train.tfrecord.gz", datasets_path + "TestValTrain/train_binary.tfrecord.gz")
# convert_and_write_tfrecord(datasets_path + "TestValTrain/val.tfrecord.gz", datasets_path + "TestValTrain/val_binary.tfrecord.gz")
# convert_and_write_tfrecord(datasets_path + "TestValTrain/test.tfrecord.gz", datasets_path + "TestValTrain/test_binary.tfrecord.gz")

All this and we haven't even started deep learning!