In [None]:
import time
import sys
import os
import glob
import math
import threading
import concurrent.futures as cf

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Input, Model, layers, metrics, losses, callbacks, optimizers, models, utils
from keras import backend as K
import gc
import keras_tuner as kt
from pyfaidx import Fasta

K.clear_session()
gc.collect()

datasets_path = "../../Datasets/"
models_path = "../../Models/"

This notebook was used to adjust the fasta to match my purposes.  The first cell below is the only current and useful one.

In [1]:
valid_chroms = set([f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"])

input_fasta = "chr_genome.fa"  # Removed chrM because there aren't introns on it and I didn't want to confuse the training data
output_fasta = "trim_chr_genome.fa"

keep = False
'''
Removes non-chromosome entries in the human genome data
If line is a record key (starswith(">")), it compares first 'word' in the record key to valid_chroms
If comparison is true, it keeps writing lines until it finds another record key line to make the comparison
'''
with open(input_fasta, "r") as fin, open(output_fasta, "w") as fout:
    for line in fin:
        if line.startswith(">"):
            chrom_name = line.strip()[1:].split()[0]
            keep = (chrom_name in valid_chroms)

            if keep:
                fout.write(line)
                print(chrom_name)
        else:
            if keep:
                fout.write(line)


chr1
chr2
chr3
chr4
chr5
chr6
chr7
chr8
chr9
chr10
chr11
chr12
chr13
chr14
chr15
chr16
chr17
chr18
chr19
chr20
chr21
chr22
chrX
chrY


In [11]:
valid_chroms = set(["chrX", "chrY", "chrM"])

input_fasta = "genome.fa"
output_fasta = "test_genome.fa"

keep = False
'''
Makes a smaller test set of chromosomes
If line is a record key (starswith(">")), it compares first 'word' in the record key to valid_chroms
If comparison is true, it keeps writing lines until it finds another record key line to make the comparison
'''
with open(input_fasta, "r") as fin, open(output_fasta, "w") as fout:
    for line in fin:
        if line.startswith(">"):
            chrom_name = line.strip()[1:].split()[0]
            keep = (chrom_name in valid_chroms)

            if keep:
                fout.write(line)
                print(chrom_name)
        else:
            if keep:
                fout.write(line)

chrX
chrY
chrM


In [2]:
valid_chroms = set([f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY", "chrM"])

input_gtf = "annotations.gtf"
output_gtf = "chr_annotations.gtf"

'''
Removes annotations for non-chromosome entries in the human genome data
If a line is a comment, it keeps it
Otherwise, it checks if the line is in record format (tab separated) and checks field 0 for chromosome name of that record
If record is a member of valid_chroms, it writes the line
'''
with open(input_gtf, "r") as f_in, open(output_gtf, "w") as f_out:
    for line in f_in:
        if line.startswith("#"):
            f_out.write(line)
            continue

        fields = line.strip().split("\t")
        if len(fields) > 0:
            chrom = fields[0]
            if chrom in valid_chroms:
                f_out.write(line)


In [12]:
valid_chroms = set(["chrX", "chrY", "chrM"])

input_gtf = "annotations.gtf"
output_gtf = "test_annotations.gtf"

'''
Makes a smaller test set of annotations
If a line is a comment, it keeps it
Otherwise, it checks if the line is in record format (tab separated) and checks field 0 for chromosome name of that record
If record is a member of valid_chroms, it writes the line
'''
with open(input_gtf, "r") as f_in, open(output_gtf, "w") as f_out:
    for line in f_in:
        if line.startswith("#"):
            f_out.write(line)
            continue

        fields = line.strip().split("\t")
        if len(fields) > 0:
            chrom = fields[0]
            if chrom in valid_chroms:
                f_out.write(line)


In [15]:
import pandas as pd

def load_gtf_annotations(gtf_file):
    """
    Loads GTF into a pandas DataFrame and converts start and end to zero-based indexing because python.
    """
    gtf_data = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None,
        names=['seqname', 'source', 'feature', 'start', 'end', 
               'score', 'strand', 'frame', 'attribute']
    )
    # Convert to zero-based indexing at start.  Note that end is not here because 1 indexing is end-inclusive and 0 indexing is end-exclusive
    gtf_data['start'] = gtf_data['start'] - 1
    return gtf_data

In [3]:
from pyfaidx import Fasta

def compute_chunk_indices(fasta_file, chunk_size):
    """
    Creates a list of (record_id, start, end) for each chunk in the FASTA.
    """
    fa = Fasta(fasta_file)  # for indexed random access
    chunk_indices = []
    for record_id in fa.keys():
        seq_len = len(fa[record_id])
        for start in range(0, seq_len, chunk_size):
            end = min(start + chunk_size, seq_len)
            chunk_indices.append((record_id, start, end))
    return chunk_indices

    


In [5]:
'''
Reads DNA to confirm there are no special characters that require masking besides N.  Skips header lines
'''

unique_chars = set()

with open("chr_genome.fa", "r") as f:
    for line in f:
        line = line.strip()
        if line.startswith(">"):
            continue
        
        for char in line:
            unique_chars.add(char)


print("Unique characters in the FASTA file:")
for char in sorted(unique_chars):
    print(char)


Unique characters in the FASTA file:
A
C
G
N
T


Output to the above is this:

Unique characters in the FASTA file:

A

C

G

N

T

In [2]:
import pandas as pd
import os
os.environ['BEDTOOLS_PATH'] = '/usr/bin/bedtools'
import pybedtools

# Load your GTF into a DataFrame
gtf = pd.read_csv(
    "annotations.gtf", sep='\t', comment='#', header=None,
    names=['seqname','source','feature','start','end','score','strand','frame','attribute']
)

# If your GTF is truly 1-based, convert to 0-based
# (uncomment if needed)
gtf['start'] = gtf['start'] - 1

# Subset exons & introns
exons_df = gtf[gtf['feature'] == 'exon'].copy()
introns_df = gtf[gtf['feature'] == 'intron'].copy()

# pybedtools expects at least 3 columns: chrom, start, end
# If you want to keep strand, you can include that as well
exons_bed = pybedtools.BedTool.from_dataframe(
    exons_df[['seqname', 'start', 'end', 'feature', 'score', 'strand']]
)
introns_bed = pybedtools.BedTool.from_dataframe(
    introns_df[['seqname', 'start', 'end', 'feature', 'score', 'strand']]
)

# Intersect
overlaps = exons_bed.intersect(introns_bed, wa=True, wb=True)

# 'overlaps' is a BedTool object. You can convert to a DataFrame:
overlaps_df = overlaps.to_dataframe(
    names=[
        'exon_chrom','exon_start','exon_end','exon_feature','exon_score','exon_strand',
        'intron_chrom','intron_start','intron_end','intron_feature','intron_score','intron_strand'
    ]
)

print("Number of exon-intron overlaps:", len(overlaps_df))
print(overlaps_df.head())

Number of exon-intron overlaps: 0
Empty DataFrame
Columns: []
Index: []


MU273365.1	5969	8223	exon	.	+

MU273365.1	5969	8223	exon	.	+



In [3]:
import pandas as pd
import pybedtools

# Load GTF into a DataFrame
gtf = pd.read_csv(
    "annotations.gtf", sep='\t', comment='#', header=None,
    names=["seqname","source","feature","start","end","score","strand","frame","attribute"]
)

# Filter for genes only
genes_df = gtf[gtf["feature"] == "gene"].copy()

# If needed: convert to 0-based. (Uncomment if your GTF is standard 1-based.)
# genes_df["start"] = genes_df["start"] - 1

# Now we have a table of genes, each with [seqname, start, end, strand, etc.].
# We'll pick columns for a BED-like structure:
genes_df["score"] = 0  # If you don't have a score, just set 0

# columns = [chrom, start, end, name, score, strand]
genes_bedtool = pybedtools.BedTool.from_dataframe(
    genes_df[["seqname", "start", "end", "feature", "score", "strand"]]
)


Number of overlaps between exons: 0
Empty DataFrame
Columns: []
Index: []


MU273365.1	HAVANA	exon	5969	8223	.	+	.	"gene_id ""ENSG00000291343.1""; transcript_id ""ENST00000707199.1""; gene_type ""processed_pseudogene""; gene_name ""BEND3P2""; transcript_type ""processed_pseudogene""; transcript_name ""BEND3P2-202""; exon_number 1; exon_id ""ENSE00004001511.1""; level 2; hgnc_id ""HGNC:45015""; ont ""PGO:0000004""; tag ""basic""; tag ""Ensembl_canonical"";"

Error: Invalid record in file /tmp/pybedtools.ute8oiog.tmp. Record is 
KI270861.1	HAVANA	exon	0	5793	.	-	.	"gene_id ""ENSG00000278550.4""; transcript_id ""ENST00000634102.1""; gene_type ""protein_coding""; gene_name ""SLC43A2""; transcript_type ""protein_coding""; transcript_name ""SLC43A2-224""; exon_number 14; exon_id ""ENSE00003783572.1""; level 2; protein_id ""ENSP00000488355.1""; transcript_support_level ""1""; hgnc_id ""HGNC:23087""; tag ""basic""; havana_gene ""OTTHUMG00000191160.1""; havana_transcript ""OTTHUMT00000486890.1"";"


In [4]:
# import pandas as pd
# import pybedtools

# Load GTF into a DataFrame
gtf = pd.read_csv(
    "annotations.gtf", sep='\t', comment='#', header=None,
    names=["seqname","source","feature","start","end","score","strand","frame","attribute"]
)

# Filter for genes only
genes_df = gtf[gtf["feature"] == "gene"].copy()

# If needed: convert to 0-based. (Uncomment if your GTF is standard 1-based.)
genes_df["start"] = genes_df["start"] - 1

# Now we have a table of genes, each with [seqname, start, end, strand, etc.].
# We'll pick columns for a BED-like structure:
genes_df["score"] = 0  # If you don't have a score, just set 0

# columns = [chrom, start, end, name, score, strand]
genes_bedtool = pybedtools.BedTool.from_dataframe(
    genes_df[["seqname", "start", "end", "feature", "score", "strand"]]
)

overlaps = genes_bedtool.intersect(genes_bedtool, wa=True, wb=True)

overlaps_df = overlaps.to_dataframe(
    names=[
        "chromA","startA","endA","nameA","scoreA","strandA",
        "chromB","startB","endB","nameB","scoreB","strandB"
    ]
)
print(len(overlaps_df))


148583


MU273365.1	5969	8223	gene	0	+

MU273365.1	5969	8223	gene	0	+



In [7]:
overlaps_df.head()
filtered_overlap = overlaps_df[
    ~(
        (overlaps_df["startA"] == overlaps_df["startB"]) &
        (overlaps_df["endA"]   == overlaps_df["endB"])
    )
]
print(len(filtered_overlap))

77926


In [8]:
filtered_overlap.head()

Unnamed: 0,chromA,startA,endA,nameA,scoreA,strandA,chromB,startB,endB,nameB,scoreB,strandB
1,chr1,11868,14409,gene,0,+,chr1,12009,13670,gene,0,+
2,chr1,12009,13670,gene,0,+,chr1,11868,14409,gene,0,+
4,chr1,14695,24886,gene,0,-,chr1,17368,17436,gene,0,-
7,chr1,17368,17436,gene,0,-,chr1,14695,24886,gene,0,-
9,chr1,29553,31109,gene,0,+,chr1,30365,30503,gene,0,+


In [9]:
def canonical_pair(row):
    # Build tuples for A and B
    A = (row["chromA"], row["startA"], row["endA"], row["strandA"])
    B = (row["chromB"], row["startB"], row["endB"], row["strandB"])
    # Sort them so the 'lesser' always comes first
    # (Python compares tuples lexicographically)
    pair = tuple(sorted([A, B]))
    return pair

# Apply and store the result in a new column
filtered_overlap["pair_key"] = filtered_overlap.apply(canonical_pair, axis=1)

# Drop duplicates so that each pair_key appears only once
filtered_overlap = filtered_overlap.drop_duplicates(subset="pair_key")

# Optionally remove the pair_key column afterwards
# filtered_overlap = filtered_overlap.drop(columns=["pair_key"])
print(len(filtered_overlap))
filtered_overlap.head()

38958


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_overlap["pair_key"] = filtered_overlap.apply(canonical_pair, axis=1)


Unnamed: 0,chromA,startA,endA,nameA,scoreA,strandA,chromB,startB,endB,nameB,scoreB,strandB,pair_key
1,chr1,11868,14409,gene,0,+,chr1,12009,13670,gene,0,+,"((chr1, 11868, 14409, +), (chr1, 12009, 13670,..."
4,chr1,14695,24886,gene,0,-,chr1,17368,17436,gene,0,-,"((chr1, 14695, 24886, -), (chr1, 17368, 17436,..."
9,chr1,29553,31109,gene,0,+,chr1,30365,30503,gene,0,+,"((chr1, 29553, 31109, +), (chr1, 30365, 30503,..."
15,chr1,57597,64116,gene,0,+,chr1,62948,63887,gene,0,+,"((chr1, 57597, 64116, +), (chr1, 62948, 63887,..."
19,chr1,89294,133566,gene,0,-,chr1,89550,91105,gene,0,-,"((chr1, 89294, 133566, -), (chr1, 89550, 91105..."


In [10]:
filtered_overlap = filtered_overlap.drop(columns=["pair_key"])
print(len(filtered_overlap))
filtered_overlap.head()

38958


Unnamed: 0,chromA,startA,endA,nameA,scoreA,strandA,chromB,startB,endB,nameB,scoreB,strandB
1,chr1,11868,14409,gene,0,+,chr1,12009,13670,gene,0,+
4,chr1,14695,24886,gene,0,-,chr1,17368,17436,gene,0,-
9,chr1,29553,31109,gene,0,+,chr1,30365,30503,gene,0,+
15,chr1,57597,64116,gene,0,+,chr1,62948,63887,gene,0,+
19,chr1,89294,133566,gene,0,-,chr1,89550,91105,gene,0,-


In [8]:
import os
import pybedtools

print("PATH =", os.environ["PATH"])
!which bedtools
!bedtools --version

PATH = /home/virtuousrogue/Deep Learning Projects/venv/bin:/home/virtuousrogue/.vscode-server/bin/91fbdddc47bc9c09064bf7acf133d22631cbf083/bin/remote-cli:/home/virtuousrogue/.local/bin:/home/virtuousrogue/.local/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/usr/lib/wsl/lib:/mnt/c/Program Files/Common Files/Oracle/Java/javapath:/mnt/c/Program Files/Python310/Scripts/:/mnt/c/Program Files/Python310/:/mnt/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/bin:/mnt/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/libnvvp:/mnt/c/Program Files/Python312/Scripts/:/mnt/c/Program Files/Python312/:/mnt/c/WINDOWS/system32:/mnt/c/WINDOWS:/mnt/c/WINDOWS/System32/Wbem:/mnt/c/WINDOWS/System32/WindowsPowerShell/v1.0/:/mnt/c/WINDOWS/System32/OpenSSH/:/mnt/c/Program Files/Microsoft VS Code/bin:/mnt/c/Program Files (x86)/NVIDIA Corporation/PhysX/Common:/mnt/c/Program Files/dotnet/:/mnt/c/Program Files/Microsoft SQL Server/150

In [17]:
import numpy as np

'''
Uses a dictionary to encode bases. N's are masked as all zeroes
'''
n_base_encoder = {
    'A': [1, 0, 0, 0],
    'C': [0, 1, 0, 0],
    'G': [0, 0, 1, 0],
    'T': [0, 0, 0, 1],
    'N': [0, 0, 0, 0]  
}


def one_hot_encode_with_N(sequence):
    return [n_base_encoder.get(nuc, [0, 0, 0, 0]) for nuc in sequence]

def pad_one_hot_sequence(encoded_seq, target_length):
    """
    Pads sequence chunks so all are the same size
    """
    seq_len = len(encoded_seq)
    padding_length = target_length - seq_len
    if padding_length > 0:
        encoded_seq += [[0, 0, 0, 0]] * padding_length
    return encoded_seq

def label_sequence(sequence_length, annotations):
    """
    Assigns categorical numeric labels to save memory:
      0 = non-coding
      1 = intron start
      2 = intron end
      3 = exon start
      4 = exon end
    """
    labels = [0] * sequence_length  # All non-coding by default
    for _, row in annotations.iterrows():
        start, end, feature = int(row['start']), int(row['end']), row['feature']
        # Clip to chunk boundaries
        start = max(0, start)
        end   = min(sequence_length, end)

        if feature == 'exon':
            if start < sequence_length:
                labels[start] = 3  # Exon start
            if end - 1 < sequence_length and end - 1 >= 0:
                labels[end - 1] = 4  # Exon end
            
        elif feature == 'intron':
            if start < sequence_length:
                labels[start] = 1  # Intron start
            if end - 1 < sequence_length and end - 1 >= 0:
                labels[end - 1] = 2  # Intron end

    return labels

def pad_one_hot_sequence(encoded_seq, target_len):
    pad_size = target_len - len(encoded_seq)
    if pad_size > 0:
        encoded_seq += [[0, 0, 0, 0]] * pad_size
    return encoded_seq

def pad_labels(labels, target_length):
    """
    Pads label array up to target_length with 0 (non-coding).
    """
    padding_length = target_length - len(labels)
    if padding_length > 0:
        labels += [0] * padding_length
    return labels


In [18]:
import tensorflow as tf


def serialize_example(X, y):
    """
    Convert (X, y) to a tf.train.Example for TFRecord.
    X: shape (chunk_size, 4), float32
    y: shape (chunk_size,), int
    """
    # Flatten X to store as bytes or as a list of floats
    X_flat = X.reshape(-1).tolist()
    y_flat = y.tolist()

    feature = {
        'X': tf.train.Feature(float_list=tf.train.FloatList(value=X_flat)),
        'y': tf.train.Feature(int64_list=tf.train.Int64List(value=y_flat)),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def _write_shard(
    shard_id,
    shard_indices,
    fasta_file,
    gtf_data,
    chunk_size,
    out_dir,
    compression="GZIP"
):
    """
    Worker function that writes TFRecord for the assigned shard_indices.
    Each shard goes into its own file: e.g. out_dir/shard_{shard_id:04d}.tfrecord.gz
    """
    # We'll open the FASTA in each process so we can do random-access.
    fa = Fasta(fasta_file)
    
    # Create TFRecord writer with compression
    shard_filename = os.path.join(out_dir, f"shard_{shard_id:04d}.tfrecord")
    options = tf.io.TFRecordOptions(compression_type=compression)
    with tf.io.TFRecordWriter(shard_filename, options=options) as writer:
        for (record_id, start, end) in shard_indices:
            seq = fa[record_id][start:end].seq
            
            # Filter relevant annotations
            sub_anno = gtf_data[
                (gtf_data['seqname'] == record_id) &
                (gtf_data['start'] < end) &
                (gtf_data['end'] > start)
            ].copy()
            # Shift coordinates relative to chunk
            sub_anno['start'] -= start
            sub_anno['end']   -= start

            # One-hot encode
            encoded = one_hot_encode_with_N(seq)
            labels  = label_sequence(len(seq), sub_anno)

            # Pad to chunk_size if needed
            encoded = pad_one_hot_sequence(encoded, chunk_size)
            labels  = pad_labels(labels, chunk_size)

            X = np.array(encoded, dtype=np.float32)
            y = np.array(labels,  dtype=np.int32)

            example = serialize_example(X, y)
            writer.write(example)

    return shard_filename


def parallel_write_tfrecords(
    fasta_file,
    gtf_file,
    chunk_size,
    out_dir,
    num_shards=8,
    compression="GZIP"
):
    """
    Splits chunk_indices into N shards and processes them in parallel.
    Each shard is written to out_dir/shard_xxxx.tfrecord (compressed).
    """
    os.makedirs(out_dir, exist_ok=True)

    print("Loading GTF...")
    gtf_data = load_gtf_annotations(gtf_file)

    print("Computing chunk indices...")
    chunk_indices = compute_chunk_indices(fasta_file, chunk_size)
    
    # Shuffle if you want random distribution among shards
    np.random.shuffle(chunk_indices)

    # Split chunk_indices into num_shards subsets
    shard_size = math.ceil(len(chunk_indices) / num_shards)
    shard_splits = [
        chunk_indices[i * shard_size : (i + 1) * shard_size]
        for i in range(num_shards)
    ]

    # Launch parallel processes
    futures = []
    print(f"Writing {len(chunk_indices)} chunks into {num_shards} shards...")
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_shards) as executor:
        for shard_id, shard_indices in enumerate(shard_splits):
            futures.append(
                executor.submit(
                    _write_shard,
                    shard_id,
                    shard_indices,
                    fasta_file,
                    gtf_data,      # pass the entire DataFrame to each process
                    chunk_size,
                    out_dir,
                    compression
                )
            )

    results = []
    for f in concurrent.futures.as_completed(futures):
        results.append(f.result())

    print("All shards written:")
    for r in results:
        print("  ", r)




In [19]:
fasta_file = "genome.fa"
gtf_file   = "annotations.gtf"
chunk_size = 5000


# We'll produce 8 shards in parallel using GZIP compression
output_dir = "tfrecord_shards"
parallel_write_tfrecords(
    fasta_file=fasta_file,
    gtf_file=gtf_file,
    chunk_size=chunk_size,
    out_dir=output_dir,
    num_shards=8,          # or however many parallel processes / shards you want
    compression="GZIP"     # or "ZLIB" or None for no compression
    )

print("Done writing TFRecord!")

Loading GTF...
Computing chunk indices...
Writing 658691 chunks into 8 shards...


KeyboardInterrupt: 

In [None]:
def parse_tfrecord(example_proto, chunk_size=5000):
    # Define the features
    feature_spec = {
        'X': tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
        'y': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=True)
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_spec)
    # 'X' is shape [chunk_size*4], 'y' is shape [chunk_size]
    X = parsed_features['X']
    y = parsed_features['y']

    # Reshape X back to [chunk_size, 4]
    X = tf.reshape(X, [chunk_size, 4])
    y = tf.reshape(y, [chunk_size])
    
    return X, tf.cast(y, tf.int32)

def build_dataset_from_tfrecords(file_pattern, chunk_size=5000, batch_size=8):
    files = tf.data.Dataset.list_files(file_pattern)
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(lambda x: parse_tfrecord(x, chunk_size), num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

# Example usage:
write_tfrecords("genome.fa", "annotations.gtf", 5000, "my_dataset.tfrecord")
dataset = build_dataset_from_tfrecords("my_dataset.tfrecord", chunk_size=5000, batch_size=8)

# Now this dataset is purely Tensor-based and can be saved or re-loaded if needed.


In [None]:
def parse_tfrecord(example_proto, chunk_size=5000):
    """
    Parses a single TFRecord example into (X, y) Tensors.
    """
    feature_spec = {
        'X': tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
        'y': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=True),
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_spec)
    
    # 'X' is a 1D float array of length (chunk_size * 4)
    X = parsed_features['X']
    y = parsed_features['y']
    
    # Reshape X to [chunk_size, 4]
    X = tf.reshape(X, [chunk_size, 4])
    y = tf.reshape(y, [chunk_size])
    
    # Cast y to int32 if needed
    y = tf.cast(y, tf.int32)
    return X, y

def build_dataset_from_tfrecords(file_pattern, chunk_size=5000, batch_size=8):
    """
    Reads TFRecords, parses them, and returns a tf.data.Dataset of (X, y).
    """
    # file_pattern can be a single file or a wildcard to multiple files
    dataset = tf.data.Dataset.list_files(file_pattern)
    dataset = dataset.interleave(
        lambda fp: tf.data.TFRecordDataset(fp, compression_type=None),
        cycle_length=4,       # how many files to read in parallel
        num_parallel_calls=tf.data.AUTOTUNE
    )
    # Parse each record
    dataset = dataset.map(
        lambda x: parse_tfrecord(x, chunk_size),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    # Shuffle (buffer size is a hyperparameter) + batch + prefetch
    dataset = dataset.shuffle(1024).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

# Usage:
# after writing TFRecords (my_dataset.tfrecord),
# read them back in a TF data pipeline:
ds = build_dataset_from_tfrecords("my_dataset.tfrecord", 
                                  chunk_size=5000,
                                  batch_size=8)
for X_batch, y_batch in ds.take(1):
    print(X_batch.shape, y_batch.shape)  # (8, 5000, 4), (8, 5000)


Hybrid CNN LSTM Serial Model

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, LSTM, Bidirectional, Dense, Dropout, Flatten
from tensorflow.keras.models import Model

# Define input shape
input_shape = (sequence_length, 4)  # One-hot encoded sequence

# Input layer
inputs = Input(shape=input_shape)

# 1D-CNN block
cnn = Conv1D(filters=64, kernel_size=5, activation='relu', padding='same')(inputs)
cnn = MaxPooling1D(pool_size=2)(cnn)

cnn = Conv1D(filters=128, kernel_size=5, activation='relu', padding='same')(cnn)
cnn = MaxPooling1D(pool_size=2)(cnn)

# LSTM block
lstm = Bidirectional(LSTM(128, return_sequences=True))(cnn)
lstm = Bidirectional(LSTM(128))(lstm)

# Fully connected layers
dense = Dense(128, activation='relu')(lstm)
dense = Dropout(0.5)(dense)

# Output layer
output = Dense(num_classes, activation='softmax')(dense)  # For multi-class classification

# Create model
model = Model(inputs=inputs, outputs=output)

# Compile model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['f1_score'])  # might try adagrad if adam is bad

# Model summary
model.summary()


CNN only with similar calculation number to hybrid 

In [None]:
# Adjusted 1D-CNN Model
inputs = Input(shape=(sequence_length, 4))

# Add more filters and layers to match the hybrid's complexity
cnn = Conv1D(filters=96, kernel_size=5, activation='relu', padding='same')(inputs)
cnn = MaxPooling1D(pool_size=2)(cnn)

cnn = Conv1D(filters=192, kernel_size=5, activation='relu', padding='same')(cnn)
cnn = MaxPooling1D(pool_size=2)(cnn)

cnn = Conv1D(filters=128, kernel_size=5, activation='relu', padding='same')(cnn)

# Fully connected layers
flatten = Flatten()(cnn)
dense = Dense(128, activation='relu')(flatten)
dense = Dropout(0.5)(dense)

# Output layer
output = Dense(num_classes, activation='softmax')(dense)

cnn_model = Model(inputs=inputs, outputs=output)

# Compile and summarize
cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['f1_score'])  # might try adagrad if adam is bad
cnn_model.summary()


LSTM only with similar calculation number to hybrid

In [None]:
# Adjusted LSTM Model
inputs = Input(shape=(sequence_length, 4))

# Increase LSTM hidden size to match hybrid complexity
lstm = Bidirectional(LSTM(192, return_sequences=True))(inputs)  # Larger hidden size
lstm = Bidirectional(LSTM(192))(lstm)

# Fully connected layers
dense = Dense(128, activation='relu')(lstm)
dense = Dropout(0.5)(dense)

# Output layer
output = Dense(num_classes, activation='softmax')(dense)

lstm_model = Model(inputs=inputs, outputs=output)

# Compile and summarize
lstm_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['f1_score'])  # might try adagrad if adam is bad
lstm_model.summary()


Hybrid parallel model

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, LSTM, Bidirectional, Dense, Dropout, Flatten, Concatenate
from tensorflow.keras.models import Model

# Define input shape
input_shape = (sequence_length, 4)  # One-hot encoded sequence

# Input layer
inputs = Input(shape=input_shape)

# CNN Branch
cnn = Conv1D(filters=64, kernel_size=5, activation='relu', padding='same')(inputs)
cnn = MaxPooling1D(pool_size=2)(cnn)
cnn = Conv1D(filters=128, kernel_size=5, activation='relu', padding='same')(cnn)
cnn = MaxPooling1D(pool_size=2)(cnn)
cnn = Flatten()(cnn)  # Flatten for concatenation

# LSTM Branch
lstm = Bidirectional(LSTM(128, return_sequences=True))(inputs)
lstm = Bidirectional(LSTM(128))(lstm)  # Return a single vector

# Concatenate the branches
combined = Concatenate()([cnn, lstm])

# Fully connected layers
dense = Dense(128, activation='relu')(combined)
dense = Dropout(0.5)(dense)

# Output layer
output = Dense(num_classes, activation='softmax')(dense)  # For multi-class classification

# Create model
parallel_hybrid_model = Model(inputs=inputs, outputs=output)

# Compile model
parallel_hybrid_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['f1_score'])  # might try adagrad if adam is bad

# Model summary
parallel_hybrid_model.summary()


In [None]:
# Training parameters
max_time_seconds = 3600  # 1 hour
batch_size = 32
epochs = 100  # Set high enough to allow stopping by time

time_limit_callback = TimeLimit(max_time_seconds=max_time_seconds)

# Train 1D-CNN Model
history_cnn = cnn_model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[time_limit_callback]
)

# Train LSTM Model
history_lstm = lstm_model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[time_limit_callback]
)

# Train Hybrid Model
history_hybrid = hybrid_model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[time_limit_callback]
)


In [None]:
test_loss_cnn, test_acc_cnn = cnn_model.evaluate(x_test, y_test)
test_loss_lstm, test_acc_lstm = lstm_model.evaluate(x_test, y_test)
test_loss_hybrid, test_acc_hybrid = hybrid_model.evaluate(x_test, y_test)


In [5]:
def write_tfrecords_for_test(fasta_file, gtf_file, chunk_size, output_path):
    """Writes chunks in the EXACT order they appear in chunk_indices (no shuffle)."""
    import math
    import os
    import numpy as np
    import tensorflow as tf
    from pyfaidx import Fasta
    
    gtf_data = load_gtf_annotations(gtf_file)
    chunk_indices = compute_chunk_indices(fasta_file, chunk_size)  # <--- no shuffle
    fa = Fasta(fasta_file)
    
    # For test, let's do no compression to keep it simple (you can add if you want)
    with tf.io.TFRecordWriter(output_path) as writer:
        for (record_id, start, end) in chunk_indices:
            seq = fa[record_id][start:end].seq
            
            # Filter relevant annotations
            sub_anno = gtf_data[
                (gtf_data['seqname'] == record_id) &
                (gtf_data['start'] < end) &
                (gtf_data['end'] > start)
            ].copy()
            sub_anno['start'] -= start
            sub_anno['end']   -= start

            # One-hot + label
            encoded = one_hot_encode_with_N(seq)
            labels  = label_sequence(len(seq), sub_anno)
            
            encoded = pad_one_hot_sequence(encoded, chunk_size)
            labels  = pad_labels(labels, chunk_size)

            X = np.array(encoded, dtype=np.float32)
            y = np.array(labels, dtype=np.int32)

            example = serialize_example(X, y)
            writer.write(example)

    print("Wrote TFRecord for test in same order as chunk_indices:", output_path)


In [6]:
fasta_file = "genome.fa"
gtf_file   = "annotations.gtf"
chunk_size = 5000
test_tfrecord_path = "test_ordered.tfrecord"

write_tfrecords_for_test(fasta_file, gtf_file, chunk_size, test_tfrecord_path)


KeyboardInterrupt: 

In [7]:
chunk_indices = compute_chunk_indices(fasta_file, chunk_size)  # same function
# The "30th" chunk in zero-based indexing is chunk_indices[29],
# or if you literally mean the 30th chunk in one-based counting, chunk_indices[30 - 1].
# Let's assume zero-based for programming clarity:
chunk_idx = 29  # zero-based
record_id_30, start_30, end_30 = chunk_indices[chunk_idx]

print("Chunk #30 details (zero-based index = 29):")
print("  record_id:", record_id_30)
print("  start:", start_30)
print("  end:  ", end_30)


Chunk #30 details (zero-based index = 29):
  record_id: chr1
  start: 145000
  end:   150000


In [8]:
def parse_tfrecord(example_proto, chunk_size=5000):
    feature_spec = {
        'X': tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
        'y': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=True),
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_spec)
    X = parsed_features['X']
    y = parsed_features['y']
    
    X = tf.reshape(X, [chunk_size, 4])
    y = tf.reshape(y, [chunk_size])
    y = tf.cast(y, tf.int32)
    return X, y


dataset = tf.data.TFRecordDataset(test_tfrecord_path)  # no compression
dataset = dataset.map(lambda x: parse_tfrecord(x, chunk_size=chunk_size))

# Skip 29 examples, then take 1
dataset_30th = dataset.skip(chunk_idx).take(1)

# Pull it into Python
for X_30th, y_30th in dataset_30th:
    X_30th_np = X_30th.numpy()
    y_30th_np = y_30th.numpy()

print("X_30th_np shape:", X_30th_np.shape)
print("y_30th_np shape:", y_30th_np.shape)


I0000 00:00:1737078009.848599     745 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1737078010.001964     745 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1737078010.002015     745 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1737078010.005459     745 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1737078010.005509     745 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

X_30th_np shape: (5000, 4)
y_30th_np shape: (5000,)


2025-01-16 18:40:10.809666: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [9]:
from pyfaidx import Fasta

fasta = Fasta(fasta_file)
seq_30 = fasta[record_id_30][start_30:end_30].seq

gtf_data = load_gtf_annotations(gtf_file)
sub_anno_30 = gtf_data[
    (gtf_data['seqname'] == record_id_30) &
    (gtf_data['start'] < end_30) &
    (gtf_data['end'] > start_30)
].copy()

sub_anno_30['start'] -= start_30
sub_anno_30['end']   -= start_30

encoded_30 = one_hot_encode_with_N(seq_30)
labels_30  = label_sequence(len(seq_30), sub_anno_30)

encoded_30 = pad_one_hot_sequence(encoded_30, chunk_size)
labels_30  = pad_labels(labels_30, chunk_size)

X_30th_manual = np.array(encoded_30, dtype=np.float32)
y_30th_manual = np.array(labels_30,  dtype=np.int32)

print("Manual shapes:", X_30th_manual.shape, y_30th_manual.shape)


Manual shapes: (5000, 4) (5000,)


In [10]:
import numpy as np

same_X = np.array_equal(X_30th_np, X_30th_manual)
same_y = np.array_equal(y_30th_np, y_30th_manual)

print(f"X array match? {same_X}")
print(f"y array match? {same_y}")

# If you want a more flexible comparison (allow floating rounding errors):
# np.allclose(X_30th_np, X_30th_manual, atol=1e-7)


X array match? True
y array match? True


In [11]:
import os
import math
import numpy as np
import pandas as pd
import tensorflow as tf
import concurrent.futures
from pyfaidx import Fasta

def load_gtf_annotations(gtf_file):
    gtf_data = pd.read_csv(
        gtf_file, sep='\t', comment='#', header=None,
        names=['seqname', 'source', 'feature', 'start', 'end', 
               'score', 'strand', 'frame', 'attribute']
    )
    gtf_data['start'] = gtf_data['start'] - 1  # zero-based
    return gtf_data

def compute_chunk_indices(fasta_file, chunk_size):
    fa = Fasta(fasta_file)
    chunk_indices = []
    idx = 0
    for record_id in fa.keys():
        seq_len = len(fa[record_id])
        for start in range(0, seq_len, chunk_size):
            end = min(start + chunk_size, seq_len)
            chunk_indices.append((idx, record_id, start, end))
            idx += 1
    return chunk_indices

NUC_ENCODING = {
    'A': [1, 0, 0, 0],
    'C': [0, 1, 0, 0],
    'G': [0, 0, 1, 0],
    'T': [0, 0, 0, 1],
    'N': [0, 0, 0, 0]
}

def one_hot_encode_with_N(sequence):
    return [NUC_ENCODING.get(n, [0,0,0,0]) for n in sequence]

def label_sequence(seq_len, annotations):
    labels = [0]*seq_len
    for _, row in annotations.iterrows():
        start = max(0, int(row['start']))
        end   = min(seq_len, int(row['end']))
        feat  = row['feature']
        if feat == 'exon':
            if start < seq_len:
                labels[start] = 3
            if (end - 1) < seq_len and (end-1) >= 0:
                labels[end-1] = 4
            for i in range(start+1, end-1):
                if 0 <= i < seq_len:
                    labels[i] = 6
        elif feat == 'intron':
            if start < seq_len:
                labels[start] = 1
            if (end - 1) < seq_len and (end-1) >= 0:
                labels[end-1] = 2
            for i in range(start+1, end-1):
                if 0 <= i < seq_len:
                    labels[i] = 5
    return labels

def pad_one_hot_sequence(encoded_seq, target_len):
    pad_size = target_len - len(encoded_seq)
    if pad_size > 0:
        encoded_seq += [[0,0,0,0]]*pad_size
    return encoded_seq

def pad_labels(labels, target_len):
    pad_size = target_len - len(labels)
    if pad_size > 0:
        labels += [0]*pad_size
    return labels

def serialize_example(X, y, chunk_idx, record_id, start, end):
    """
    Store:
      - X, y
      - chunk_idx (int)
      - record_id (string)
      - start, end (int)
    """
    # Flatten X for storage
    X_flat = X.reshape(-1).tolist()
    y_flat = y.tolist()

    # Build feature dict
    feature = {
        'X': tf.train.Feature(float_list=tf.train.FloatList(value=X_flat)),
        'y': tf.train.Feature(int64_list=tf.train.Int64List(value=y_flat)),
        'chunk_idx': tf.train.Feature(int64_list=tf.train.Int64List(value=[chunk_idx])),
        'record_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[record_id.encode()])),
        'start': tf.train.Feature(int64_list=tf.train.Int64List(value=[start])),
        'end':   tf.train.Feature(int64_list=tf.train.Int64List(value=[end])),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def _write_shard(
    shard_id,
    shard_indices,
    fasta_file,
    gtf_data,
    chunk_size,
    out_dir,
    compression="GZIP"
):
    fa = Fasta(fasta_file)
    shard_fname = os.path.join(out_dir, f"shard_{shard_id:04d}.tfrecord")
    options = tf.io.TFRecordOptions(compression_type=compression)
    with tf.io.TFRecordWriter(shard_fname, options=options) as writer:
        for (chunk_idx, record_id, start, end) in shard_indices:
            seq = fa[record_id][start:end].seq
            sub_anno = gtf_data[
                (gtf_data['seqname'] == record_id) &
                (gtf_data['start'] < end) &
                (gtf_data['end'] > start)
            ].copy()
            sub_anno['start'] -= start
            sub_anno['end']   -= start

            encoded = one_hot_encode_with_N(seq)
            labels  = label_sequence(len(seq), sub_anno)

            # pad
            encoded = pad_one_hot_sequence(encoded, chunk_size)
            labels  = pad_labels(labels, chunk_size)

            X = np.array(encoded, dtype=np.float32)
            y = np.array(labels,  dtype=np.int32)

            example = serialize_example(X, y, chunk_idx, record_id, start, end)
            writer.write(example)

    return shard_fname


def parallel_write_tfrecords(
    fasta_file,
    gtf_file,
    chunk_size,
    out_dir,
    num_shards=4,
    compression="GZIP"
):
    os.makedirs(out_dir, exist_ok=True)
    gtf_data = load_gtf_annotations(gtf_file)
    chunk_indices = compute_chunk_indices(fasta_file, chunk_size)
    np.random.shuffle(chunk_indices)  # optional

    shard_size = math.ceil(len(chunk_indices) / num_shards)
    shard_splits = [
        chunk_indices[i * shard_size : (i+1)*shard_size]
        for i in range(num_shards)
    ]

    futures = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_shards) as executor:
        for shard_id, shard_idxs in enumerate(shard_splits):
            futures.append(
                executor.submit(
                    _write_shard,
                    shard_id,
                    shard_idxs,
                    fasta_file,
                    gtf_data,
                    chunk_size,
                    out_dir,
                    compression
                )
            )
    
    results = []
    for f in concurrent.futures.as_completed(futures):
        results.append(f.result())

    print("Shards written:")
    for r in results:
        print(r)


In [12]:
def parse_tfrecord(example_proto, chunk_size=5000):
    feature_spec = {
        'X':         tf.io.FixedLenSequenceFeature([], dtype=tf.float32, allow_missing=True),
        'y':         tf.io.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=True),
        'chunk_idx': tf.io.FixedLenFeature([], tf.int64),
        'record_id': tf.io.FixedLenFeature([], tf.string),
        'start':     tf.io.FixedLenFeature([], tf.int64),
        'end':       tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    
    X = parsed['X']
    y = parsed['y']
    chunk_idx = parsed['chunk_idx']
    record_id = parsed['record_id']
    start     = parsed['start']
    end       = parsed['end']

    # Reshape X, y
    X = tf.reshape(X, [chunk_size, 4])
    y = tf.reshape(y, [chunk_size])

    return (X, y, chunk_idx, record_id, start, end)


def build_dataset_from_tfrecord_shards(
    shard_pattern, 
    chunk_size=5000, 
    compression="GZIP"
):
    dataset = tf.data.Dataset.list_files(shard_pattern, shuffle=True)
    dataset = dataset.interleave(
        lambda fp: tf.data.TFRecordDataset(fp, compression_type=compression),
        cycle_length=4,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        lambda x: parse_tfrecord(x, chunk_size),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    return dataset


In [13]:
def find_chunk(dataset, target_idx):
    """
    Scans the dataset until it finds the record where chunk_idx == target_idx.
    Returns (X, y, record_id, start, end).
    """
    for X, y, chunk_idx, record_id, start, end in dataset:
        # chunk_idx, record_id, etc., are tf.Tensor objects.
        if int(chunk_idx.numpy()) == target_idx:
            return (X.numpy(), y.numpy(),
                    record_id.numpy().decode('utf-8'),
                    int(start.numpy()),
                    int(end.numpy()))
    return None

# Example usage:
sharded_ds = build_dataset_from_tfrecord_shards(
    shard_pattern="tfrecord_shards/shard_*.tfrecord",
    chunk_size=5000,
    compression="GZIP"
)

target_chunk_idx = 30
result = find_chunk(sharded_ds, target_chunk_idx)
if result is None:
    print(f"Could not find chunk_idx={target_chunk_idx}")
else:
    X_array, y_array, rec_id_str, start_val, end_val = result
    print(f"Found chunk_idx={target_chunk_idx}, record_id={rec_id_str}, start={start_val}, end={end_val}")
    print("X_array shape:", X_array.shape)  # (5000, 4)
    print("y_array shape:", y_array.shape)  # (5000,)

    # For the "first 1000 bases" — that corresponds to X_array[:1000, :]
    # It's a one-hot. Let's decode or compare to the original FASTA to confirm.


InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true. Summarized data: b'No files matched pattern: tfrecord_shards/shard_*.tfrecord'

In [None]:
from pyfaidx import Fasta

def compare_chunk(
    X_array, y_array,   # from TFRecord
    record_id, start, end,
    fasta_file, gtf_data,
    chunk_size=5000,
    n_bases=1000
):
    """
    Compare the first n_bases in X_array to the actual FASTA substring.
    Also check the labeling vs. GTF if desired.
    """
    fa = Fasta(fasta_file)
    seq = fa[record_id][start:end].seq  # original substring
    # If we want only the first n_bases, slice:
    seq_substring = seq[:n_bases]  
    print("FASTA substring (first 1000 if that many):", seq_substring)

    # Convert X_array[:n_bases] from one-hot back to letters to see if they match
    # your original substring.
    one_hot_slice = X_array[:n_bases]  # shape (1000, 4)
    decoded = []
    for row in one_hot_slice:
        # row is [A, C, G, T], find which index = 1
        # but here we can handle partial or zeros
        a, c, g, t = row
        if a == 1.0 and c==0.0 and g==0.0 and t==0.0:
            decoded.append('A')
        elif a==0.0 and c==1.0 and g==0.0 and t==0.0:
            decoded.append('C')
        elif a==0.0 and c==0.0 and g==1.0 and t==0.0:
            decoded.append('G')
        elif a==0.0 and c==0.0 and g==0.0 and t==1.0:
            decoded.append('T')
        else:
            decoded.append('N')  # or 'N' if all zeros

    decoded_str = ''.join(decoded)
    print("Decoded from one-hot (first 1000 bases):", decoded_str)

    # Check if it matches:
    if decoded_str == seq_substring.upper():
        print("First 1000 bases match perfectly!")
    else:
        print("Mismatch found between TFRecord data and FASTA substring.")


# Suppose you found chunk_idx=30 in the TFRecord dataset, and got:
# X_array, y_array, rec_id_str, start_val, end_val = ...
gtf_data = load_gtf_annotations("annotations.gtf")  # or reuse if you have it
compare_chunk(
    X_array, y_array,
    rec_id_str, start_val, end_val,
    fasta_file="genome.fa", 
    gtf_data=gtf_data,
    chunk_size=5000,
    n_bases=1000
)
