In [7]:
import asyncio
from typing import io
import os
import esm

In [8]:
async def async_read_file(file_name: os.PathLike, queue: asyncio.Queue, event_finishing: asyncio.Event, return_idx: bool = True):
    """Stream the records of a FASTA file into ``queue``.

    A record starts at a line beginning with ``>`` (the header, stored without
    the ``>`` and stripped of surrounding whitespace) and runs until the next
    header or the end of the file. The lines in between are stripped and
    concatenated into one sequence string.

    Args:
        file_name: Path of the FASTA file to read.
        queue: Queue that receives one item per record.
        event_finishing: Event that is set once every record has been queued.
        return_idx: If true, items are ``(index, header, sequence)`` with a
            zero-based record index; otherwise ``(header, sequence)``.

    Note:
        The file is read with ordinary blocking I/O; the coroutine only yields
        to the event loop inside ``queue.put``.
    """
    sequence_idx = 0
    header = None
    chunks = []  # sequence lines of the record currently being read

    def make_record():
        # Join once per record instead of growing a string line by line.
        sequence = ''.join(chunks)
        if return_idx:
            return (sequence_idx, header, sequence)
        return (header, sequence)

    with open(file_name, 'r') as f:
        for line in f:
            if line.startswith('>'):
                # A new header terminates the previous record, if there is one.
                if header is not None:
                    await queue.put(make_record())
                    sequence_idx += 1
                header = line[1:].strip()
                # Starting fresh here also drops any lines that preceded the
                # first header, so they cannot leak into the first record.
                chunks = []
            else:
                chunks.append(line.strip())

    # The last record has no following header to trigger its emission,
    # so it has to be flushed explicitly once the file is exhausted.
    if header is not None:
        await queue.put(make_record())

    event_finishing.set()
                
async def async_tokenize_fasta(batch_converter: esm.data.BatchConverter, in_queue: asyncio.Queue, out_queue: asyncio.Queue):
    """Worker loop: tokenize records from ``in_queue`` and forward them to ``out_queue``.

    Runs until the task is cancelled. Each input item is either
    ``(index, header, sequence)`` or ``(header, sequence)``; only the header
    and the sequence are used. ``in_queue.task_done()`` is called after the
    result has been placed on ``out_queue``.

    Args:
        batch_converter: Callable invoked with a list holding one
            ``(header, sequence)`` pair. It must return three indexable
            batches (labels, strings, tokens); the first token entry must
            provide ``.tolist()``.
        in_queue: Queue of records to tokenize.
        out_queue: Queue receiving ``(label, sequence_string, tokens)`` where
            ``tokens`` is the token ids joined by single spaces.
    """
    while True:
        item = await in_queue.get()
        # Accept both item shapes: a leading record index, when present, is
        # not needed for tokenization and is simply ignored.
        *_, header, sequence = item

        batch_labels, batch_strs, batch_tokens = batch_converter([(header, sequence)])

        # The batch holds exactly one record, so take element 0 of each part.
        labels = batch_labels[0]  # header
        strs = batch_strs[0]  # sequence string as returned by the converter
        # Serialize the token ids as one space-separated line of text.
        tokens = ' '.join(map(str, batch_tokens[0].tolist()))

        await out_queue.put((labels, strs, tokens))
        in_queue.task_done()
        
async def async_write_to_file(queue, f_labels: io.TextIO, f_strs: io.TextIO, f_tokens: io.TextIO):
    """Consume ``(labels, strs, tokens)`` items from ``queue`` and append each part to its file.

    Runs until the task is cancelled. For every item one line is written to
    ``f_labels``, then ``f_strs``, then ``f_tokens``, after which the item is
    marked as done on the queue.

    Args:
        queue: Queue yielding 3-tuples of strings.
        f_labels: Writable text handle receiving the first element.
        f_strs: Writable text handle receiving the second element.
        f_tokens: Writable text handle receiving the third element.
    """
    sinks = (f_labels, f_strs, f_tokens)
    while True:
        record = await queue.get()
        labels, strs, tokens = record

        # One line per field, written in the fixed order labels -> strs -> tokens.
        for sink, text in zip(sinks, (labels, strs, tokens)):
            sink.write(text + '\n')
        queue.task_done()

In [None]:
# UniRef50 FASTA input; the path is relative to the notebook's working directory.
fasta_path = '../../../data/raw/2022-07-25_uniref/uniref50.fasta'
# NOTE(review): presumably a 10,000-record subset for quick test runs (going by
# the name) -- it is not referenced by any of the visible cells; confirm.
first_10000 = '../../../data/raw/2022-07-25_uniref/uniref50-first10000.fasta'

In [None]:
q_in = asyncio.Queue(maxsize=100)
q_out = asyncio.Queue(maxsize=100)
finish_signal = asyncio.Event() # signals to the script that we consumed the generator fully

n_workers = 24

creator = asyncio.create_task(read_file(fasta_path, q_in, dr)) # create a task to yield sequences from the FASTA file

processor = [asyncio.create_task(async_tokenize_fasta(batch_converter, q_in, q_out)) for _ in range(n_workers)]  
# create tasks for `n_workers` workers to tokenize the sequences from `q_in` and then put them into `q_out` for usage

# file handles for writing outputs
f_tokens = open('/work/ucsf/ntranos/variant-gsp1/tokens.txt', 'w') # tokens
f_labels = open('/work/ucsf/ntranos/variant-gsp1/labels.txt', 'w') # headers
f_strs = open('/work/ucsf/ntranos/variant-gsp1/headers.txt', 'w') # actual string AA sequences

async_writer = asyncio.create_task(async_write_to_file(qo, f_labels, f_strs, f_tokens)) # one task to async write to file

try:
    await finish_signal.wait()
    
finally:
    # when we hit here we want to clean everything up so we can 
    creator.cancel()
    for proc in processor:
        proc.cancel()
    async_writer.cancel()
    
    for f in (f_tokens, f_labels, f_strs):
        f.cancel()
