In [78]:
# Paths for this step, read from the `snakemake` object.
# NOTE(review): `snakemake` is not defined in this notebook; presumably it is injected by
# Snakemake when it executes the notebook -- confirm before running the cell stand-alone.
all_seqs = snakemake.input.all_seqs  # fasta file holding every cluster; split by main() below
output_dir = snakemake.output[0]  # first declared output; per-cluster fasta files go under it
log = snakemake.log[0]  # first declared log file; main() writes the cluster count to it

In [79]:
import os
import time


from Bio import SeqIO

def _save_cluster(output_dir, clusterID, records):
    """
    Write the sequence records of a single cluster to its own fasta file.

    Files are grouped into sub-directories named after the first two characters of the cluster ID:
        if the ClusterID is 1mai_N, the records will be saved to output_dir/1m/1mai_N.fasta
    """
    # sub-directory is the two-character prefix of the cluster ID
    prefix = clusterID[0:2]
    fasta_path = os.path.join(output_dir, prefix, f"{clusterID}.fasta")

    # create the parent directory on demand; it may already exist from an earlier cluster
    parent_dir = os.path.split(fasta_path)[0]
    os.makedirs(parent_dir, exist_ok=True)

    # write all records of the cluster in fasta format
    with open(fasta_path, "w") as handle:
        SeqIO.write(records, handle, "fasta")
    

def main(all_seqs, output_dir, log):
    """
    Split an MMseqs2 "all sequences" fasta file into one fasta file per cluster.

    Based on MMseqs2 format, clusters look like this:

    >1mai_N
    >1mai_N
    XXXXXXX
    >1nnz_A
    XXXXXXX

        A header without a sequence (the repeated header) marks the start of a new cluster
        and carries the cluster ID.
        Saving to individual fasta files inclusive of each cluster and without the repeated header.
        Each cluster and file is named after mmseqs2 cluster ID (the repeated ID)

    Parameters:
        all_seqs: path of the fasta file to split
        output_dir: directory below which the per-cluster fasta files are written
        log: path of the log file; it is overwritten with the number of clusters found
    """
    cur_clusterID = ""
    cluster_records = list()
    cluster_count = 0

    with open(all_seqs, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            if len(record) == 0:  # the cluster head will be record of zero length

                # save previous cluster
                # there is nothing to save yet when the very first cluster head is read
                if cur_clusterID:
                    _save_cluster(output_dir, cur_clusterID, cluster_records)

                # reset for next cluster
                cur_clusterID = record.description
                cluster_records = list()

                # keep count of clusters
                cluster_count += 1
            else:
                cluster_records.append(record)

    # A cluster is only written when the NEXT cluster head shows up, so the last cluster
    # of the file is still pending here and must be flushed explicitly.
    if cur_clusterID:
        _save_cluster(output_dir, cur_clusterID, cluster_records)

    with open(log, "w") as f:
        f.write(f"Extracted a total of {cluster_count} clusters")
                
    

In [80]:
%%timeit -r 1  -n 1
# Split all_seqs into per-cluster fasta files and write the cluster count to the log file.
main(all_seqs=all_seqs,
     output_dir=output_dir,
     log=log)

# CLEANUP #
# delete unneeded all_seqs fasta file to save space
# NOTE(review): the removal is part of the timed cell, so it is only safe because %%timeit is
# limited to a single execution (-r 1 -n 1); a second run would find all_seqs already deleted.
os.remove(all_seqs)