In [10]:
import h5py
import numpy as np
from tqdm import tqdm
import random

import seaborn as sns
import matplotlib.pyplot as plt
# Switch plotting to seaborn's default theme (called with no overrides).
sns.set_theme()

In [4]:
# Source text file (read line by line) and the HDF5 file that will be created.
# NOTE(review): both are absolute, machine-specific paths; adjust before
# running on another machine.
input_file, output_file = (
    '/mnt/f/hprc/hprc-v1.1-mc-chm13_segments.gfa',
    '/mnt/f/hprc/segments_b.hdf5',
)

In [12]:
# Integer code stored for each nucleotide character. Characters that are not
# keys of this mapping have no code.
nucleotide_to_index = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

# Derived from the mapping so the two can never drift apart if a symbol is
# added or removed.
vocab_size = len(nucleotide_to_index)  # Number of unique nucleotides

# Variable-length uint8 element type: each dataset entry holds one encoded
# sequence of arbitrary length.
dt = h5py.vlen_dtype(np.uint8)

# Upper bound on the number of sequences written to the output file.
max_seq_count = 100_000

In [13]:
# Probability that a sequence is routed to each split.
train_prob = 0.8
val_prob = 0.1
test_prob = 0.1

# Compare with a tolerance rather than `==`: binary floating point makes
# perfectly valid splits such as 0.7 / 0.2 / 0.1 sum to 0.9999999999999999,
# which an exact equality check would wrongly reject.
assert abs(train_prob + val_prob + test_prob - 1.0) < 1e-9, "Split probabilities must sum to 1.0"

# Running totals of how many sequences were assigned to each split.
train_count = 0
val_count = 0
test_count = 0

In [14]:
# Cap on the number of sequences written to the output file.
# NOTE(review): this repeats the assignment already made in the encoding
# configuration cell; keep the two values in sync or drop one of them.
max_seq_count = 100_000

In [15]:
# Build the train / validation / test datasets: every non-empty input line is
# encoded to uint8 codes, randomly routed to one split, and appended to the
# matching variable-length dataset in the output HDF5 file.

# Reset all counters here. The output file is recreated from scratch below
# (mode 'w'), so re-running this cell must not keep adding to totals left over
# from an earlier execution.
added_seq_count = 0
train_count = 0
val_count = 0
test_count = 0

# Dedicated seeded generator: the split assignment becomes reproducible and
# does not depend on (or disturb) the global `random` state.
split_seed = 42
split_rng = random.Random(split_seed)

with h5py.File(output_file, 'w') as hdf5_file:
    # Start empty and unbounded along the only axis; each dataset grows by one
    # entry per stored sequence.
    train_dataset = hdf5_file.create_dataset('train_sequences', shape=(0,), maxshape=(None,), dtype=dt)
    val_dataset = hdf5_file.create_dataset('val_sequences', shape=(0,), maxshape=(None,), dtype=dt)
    test_dataset = hdf5_file.create_dataset('test_sequences', shape=(0,), maxshape=(None,), dtype=dt)

    with open(input_file, 'r') as f:
        for line in tqdm(f, desc="Processing sequences"):
            seq = line.strip()
            if not seq:
                continue  # Skip empty lines

            # Encode the sequence. Characters without a code (anything that is
            # not a key of nucleotide_to_index) are dropped, not rejected, so
            # the lookup below can never raise KeyError.
            # NOTE(review): this also drops lowercase bases and 'N' and joins
            # the neighbouring bases together -- confirm that is intended.
            encoded_seq = np.array(
                [nucleotide_to_index[nuc] for nuc in seq if nuc in nucleotide_to_index],
                dtype=np.uint8,
            )
            if len(encoded_seq) == 0:
                continue  # Skip sequences with no valid nucleotides

            # Randomly assign the sequence to a split
            rand_num = split_rng.random()
            if rand_num < train_prob:
                dataset = train_dataset
                train_count += 1
            elif rand_num < train_prob + val_prob:
                dataset = val_dataset
                val_count += 1
            else:
                dataset = test_dataset
                test_count += 1

            # Grow the chosen dataset by one slot and store the sequence in it
            dataset.resize((dataset.shape[0] + 1,))
            dataset[-1] = encoded_seq

            added_seq_count += 1
            # `>=` rather than `==` so the loop still stops if the cap is
            # non-positive or gets changed to a non-integer value.
            if added_seq_count >= max_seq_count:
                break

print(f"Finished processing. Train sequences: {train_count}, Validation sequences: {val_count}, Test sequences: {test_count}")

Processing sequences: 99999it [00:04, 23249.62it/s]


Finished processing. Train sequences: 80072, Validation sequences: 10053, Test sequences: 9875
