<a href="https://colab.research.google.com/github/anihab/dnaTokenization/blob/main/segment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

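Takes as input either a directory of genomic sequence files in FASTA format (gzip-compressed files ending in `.gz` are supported) or a `.txt` file listing one FASTA path per line, and cuts every sequence into segments of at most *l* nucleotides (`--max_length`). By default the segments are consecutive and non-overlapping; if a shift *s* is given (`--shift_amount`), a window of length *l* is instead slid along the sequence *s* nucleotides at a time.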

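Output is one CSV file per input file, written to the output directory as `<name>_segmented.csv`, where `<name>` is the input file name up to its first `.`. The CSV has no header row and contains the following columns: name, sample_ID, start_position, end_position, length, sequence, label. The label column is left blank because the model will produce the labels.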

In [None]:
import multiprocessing
# Report how many CPUs are available; read_input sizes its worker pool the same way.
processes = multiprocessing.cpu_count()
print("The number of processes:", processes, sep="\n")


import argparse
import gzip
import os
import math
import re
import random
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
import numpy as np

!pip install Bio
from Bio import SeqIO

Collecting Bio
  Downloading bio-1.6.2-py3-none-any.whl (278 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.6/278.6 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting biopython>=1.80 (from Bio)
  Downloading biopython-1.82-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
Collecting mygene (from Bio)
  Downloading mygene-3.2.2-py2.py3-none-any.whl (5.4 kB)
Collecting gprofiler-official (from Bio)
  Downloading gprofiler_official-1.0.0-py3-none-any.whl (9.3 kB)
Collecting biothings-client>=0.2.6 (from mygene->Bio)
  Downloading biothings_client-0.3.1-py2.py3-none-any.whl (29 kB)
Installing collected packages: biopython, gprofiler-official, biothings-client, mygene, Bio
Successfully installed Bio-1.6.2 biopython-1.82 biothings-client-0.3.1 gprofiler-official-1.0.0 mygene-3.2.2


In [None]:
def process_file(args):
    """Pool worker: segment a single FASTA file unless its output already exists.

    ``args`` is a ``(filepath, max_length, shift_amount, output_path)`` tuple;
    the arguments are packed into one tuple so the function can be used
    directly with ``Pool.map``.
    """
    path, max_length, shift_amount, output_path = args
    if is_file_read(output_path, os.path.basename(path)):
        # a non-empty *_segmented.csv is already there -- nothing to do
        return
    segment(path, max_length, shift_amount, output_path)


def read_input(input_path, max_length, output_path, **kwargs):
    """Collect FASTA files from ``input_path`` and segment them in parallel.

    Args:
        input_path: either a path ending in ``.txt`` that lists one FASTA file
            path per line, or a directory whose regular files (non-recursive)
            are all treated as FASTA inputs. Listed paths that are not existing
            files are skipped, as is a ``.txt`` list that does not exist.
        max_length: segment length, passed through to ``segment``.
        output_path: directory that receives the ``*_segmented.csv`` files.
            It is created if it does not exist; ``None`` means the current
            working directory.
        **kwargs: ``shift_amount`` (optional) -- sliding-window step passed
            through to ``segment``; ``None`` selects non-overlapping segments.
    """
    shift_amount = kwargs.get('shift_amount', None)

    # --output is optional on the command line; previously None reached the
    # workers and crashed in os.path.join / string concatenation.
    if output_path is None:
        output_path = os.curdir
    os.makedirs(output_path, exist_ok=True)

    candidates = []
    if input_path.endswith('.txt'):  # the input path is a list of files
        if os.path.isfile(input_path):
            with open(input_path, 'r') as list_file:
                candidates = [line.strip() for line in list_file]
    else:                            # the input path is a directory
        candidates = [os.path.join(input_path, name) for name in os.listdir(input_path)]

    files_to_process = [
        (f, max_length, shift_amount, output_path)
        for f in candidates
        if os.path.isfile(f)
    ]
    if not files_to_process:
        return

    # Process files in parallel. The context manager terminates the workers
    # if map() raises; on success close()/join() shut them down gracefully.
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        pool.map(process_file, files_to_process)
        pool.close()
        pool.join()

In [None]:
def is_file_read(directory, filename):
    """Return True when ``filename`` already has a non-empty segmented CSV in ``directory``.

    The expected output is ``<stem>_segmented.csv``, where ``<stem>`` is the
    part of ``filename`` before its first '.', matching the name ``segment``
    writes.
    """
    stem = filename.split('.')[0]
    csv_path = os.path.join(directory, stem + '_segmented.csv')
    return os.path.isfile(csv_path) and os.path.getsize(csv_path) > 0

In [None]:
def segment(filepath, max_length, shift_amount, output_path):
    """Segment one FASTA file and write the result as a headerless CSV.

    The output goes to ``<output_path>/<stem>_segmented.csv``, where ``<stem>``
    is the input file name up to its first '.'.

    Args:
        filepath: FASTA file to read (``.gz`` files are handled downstream).
        max_length: segment length.
        shift_amount: sliding-window step; ``None`` selects consecutive,
            non-overlapping segments (``preprocess_data``), otherwise
            ``preprocess_shift`` is used.
        output_path: existing directory for the CSV.
    """
    stem = os.path.basename(filepath).split('.')[0]

    # process data to get sequences of appropriate length
    if shift_amount is None:
        df = preprocess_data(filepath, max_length)
    else:
        df = preprocess_shift(filepath, max_length, shift_amount)

    # Build the path with os.path.join so it is identical to the one
    # is_file_read checks (and portable), instead of '/' concatenation.
    out_file = os.path.join(output_path, stem + '_segmented.csv')
    df.to_csv(out_file, encoding='utf-8', index=False, header=False)

In [None]:
def preprocess_data(filepath, max_length):
    """Cut every sequence in a FASTA file into consecutive, non-overlapping segments.

    Args:
        filepath: path to a FASTA file; a name ending in ``.gz`` is read as gzip.
        max_length: segment length. Each sequence yields ``len // max_length``
            full segments plus one shorter trailing segment for any leftover
            bases (no zero-length segment is produced). ``None`` keeps each
            sequence whole as a single segment.

    Returns:
        A DataFrame with one row per segment and columns ``name`` (file name
        up to its first '.'), ``sample_ID`` (FASTA record name),
        ``start_position``, ``end_position`` (exclusive), ``length``,
        ``sequence`` (upper-cased) and an empty ``label``.

    Raises:
        ValueError: if ``max_length`` is not positive.
        Parsing/IO errors propagate. Previously a ``return`` inside ``finally``
        swallowed them and a silently truncated result was written.
    """
    if max_length is not None and max_length <= 0:
        # a non-positive length would never advance through the sequence
        raise ValueError(f"max_length must be positive, got {max_length}")

    name = os.path.basename(filepath).split('.')[0]
    records = []

    # open the file ourselves (gzip or plain text) so the handle is always closed
    opener = gzip.open if filepath.endswith('.gz') else open
    with opener(filepath, 'rt', encoding='utf-8') as handle:
        for record in SeqIO.parse(handle, 'fasta'):
            sample_ID = str(record.name)
            seq = str(record.seq).upper()
            step = len(seq) if max_length is None else max_length
            # Index-based slicing: each base is copied once, unlike the
            # quadratic `seq = seq[max_length:]` re-slicing.
            for start in range(0, len(seq), max(step, 1)):
                chunk = seq[start:start + step]
                records.append(
                    {
                        'name': name,
                        'sample_ID': sample_ID,
                        'start_position': start,
                        'end_position': start + len(chunk),
                        'length': len(chunk),
                        'sequence': chunk,
                        'label': ''
                    }
                )

    return pd.DataFrame(data=records)

In [None]:
def preprocess_shift(filepath, max_length, shift_amount):
    """Cut every sequence in a FASTA file into sliding-window segments.

    Full windows of ``max_length`` bases start at positions 0,
    ``shift_amount``, ``2 * shift_amount``, ... for as long as a full window
    fits. If bases remain at the next start position, one shorter trailing
    segment running to the end of the sequence is added. No zero-length
    segment is produced.

    Args:
        filepath: path to a FASTA file; a name ending in ``.gz`` is read as gzip.
        max_length: window length.
        shift_amount: distance between consecutive window starts.

    Returns:
        A DataFrame with one row per segment and columns ``name`` (file name
        up to its first '.'), ``sample_ID`` (FASTA record name),
        ``start_position``, ``end_position`` (exclusive), ``length``,
        ``sequence`` (upper-cased) and an empty ``label``.

    Raises:
        ValueError: if ``max_length`` or ``shift_amount`` is not positive.
        Parsing/IO errors propagate. Previously a ``return`` inside ``finally``
        swallowed them and a silently truncated result was written.
    """
    # non-positive values would make the window loop spin forever
    if max_length is None or max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if shift_amount is None or shift_amount <= 0:
        raise ValueError(f"shift_amount must be positive, got {shift_amount}")

    name = os.path.basename(filepath).split('.')[0]
    records = []

    # open the file ourselves (gzip or plain text) so the handle is always closed
    opener = gzip.open if filepath.endswith('.gz') else open
    with opener(filepath, 'rt', encoding='utf-8') as handle:
        for record in SeqIO.parse(handle, 'fasta'):
            sample_ID = str(record.name)
            seq = str(record.seq).upper()
            total = len(seq)
            pos = 0

            # full windows; indexing avoids re-copying the remaining sequence
            while total - pos >= max_length:
                records.append(
                    {
                        'name': name,
                        'sample_ID': sample_ID,
                        'start_position': pos,
                        'end_position': pos + max_length,
                        'length': max_length,
                        'sequence': seq[pos:pos + max_length],
                        'label': ''
                    }
                )
                pos += shift_amount

            # shorter trailing segment, only if any bases remain at `pos`
            if pos < total:
                tail = seq[pos:]
                records.append(
                    {
                        'name': name,
                        'sample_ID': sample_ID,
                        'start_position': pos,
                        'end_position': pos + len(tail),
                        'length': len(tail),
                        'sequence': tail,
                        'label': ''
                    }
                )

    return pd.DataFrame(data=records)

In [None]:
def main(argv=None):
    """Parse command-line options and segment the requested FASTA files.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. This allows
            calls such as ``main(['--input', 'data/', '--max_length', '512'])``
            from a notebook.
    """
    parser = argparse.ArgumentParser()
    # Parameters
    parser.add_argument(
        "--input", default=None, type=str, required=True, help="The input directory or txt file list."
    )
    parser.add_argument(
        "--max_length", default=None, type=int, required=True, help="The max sequence length for parsing"
    )
    # Defaults to the current directory; a None output path used to crash
    # every worker when building the output file name.
    parser.add_argument(
        "--output", default=".", type=str, required=False, help="The output directory."
    )
    parser.add_argument(
        "--shift_amount", default=None, type=int, required=False, help="The amount of nucleotides to shift by when parsing"
    )
    args = parser.parse_args(argv)

    # Non-positive lengths/shifts would make the segmentation loops spin
    # forever, so reject them up front.
    if args.max_length <= 0:
        parser.error("--max_length must be a positive integer")
    if args.shift_amount is not None and args.shift_amount <= 0:
        parser.error("--shift_amount must be a positive integer")

    # read and format files
    read_input(input_path=args.input,
               max_length=args.max_length,
               output_path=args.output,
               shift_amount=args.shift_amount)


# NOTE(review): when this cell runs inside a notebook kernel, sys.argv holds the
# kernel's own arguments rather than this script's, so parse_args will likely
# exit with an error there; call main([...]) with an explicit list instead.
if __name__ == "__main__":
    main()