In [1]:
from __future__ import print_function
import os
import re
import gzip
import itertools
import argparse
import time
import logging

In [2]:
# Authorship metadata for the code in this notebook.
__author__ = 'Maggie Ruimin Sun'

In [5]:
# Module-level logger; demultiplex() uses it for progress and summary messages.
# NOTE(review): no handler or level is configured in the visible cells, so the
# INFO-level messages may not be displayed unless logging is set up elsewhere
# (e.g. logging.basicConfig(level=logging.INFO)) -- confirm.
logger = logging.getLogger('root')

In [3]:
def read_fq(filename):
    """Yield FASTQ records from a plain-text or gzip-compressed file.

    Args:
        filename: Path to the FASTQ file. A name ending in ``.gz`` is read
            through ``gzip``; anything else is opened as plain text.

    Yields:
        list: ``[header, sequence, separator, quality]`` -- the four lines of
        one record as text strings, each still carrying its line terminator.
        If the file ends in the middle of a record, the missing lines of that
        last record are empty strings.
    """
    if filename.endswith('.gz'):
        # On Python 3 a binary gzip stream yields bytes, which cannot be
        # joined/written as text downstream; ask for text mode there.
        # On Python 2 (str is bytes) binary mode already yields str.
        gzip_mode = 'rb' if str is bytes else 'rt'
        fastq = gzip.open(filename, gzip_mode)
    else:
        fastq = open(filename)
    with fastq as f:
        while True:
            l1 = f.readline()
            if not l1:
                # readline() returns an empty string only at end of file.
                break
            l2 = f.readline()
            l3 = f.readline()
            l4 = f.readline()
            yield [l1, l2, l3, l4]

In [4]:
def get_sample_id(i1, i2, sample_names):
    """Map the barcode of an index-read pair to a sample identifier.

    The barcode is the sequence line of ``i1`` followed by the sequence line
    of ``i2``. Surrounding whitespace (in particular the line terminator that
    ``read_fq`` leaves on each line) is stripped first, so the barcode can
    match the keys of ``sample_names`` and is safe to use in a file name.

    Args:
        i1: FASTQ record (sequence of four lines) of the first index read.
        i2: FASTQ record (sequence of four lines) of the second index read.
        sample_names: dict mapping barcode string -> sample id.

    Returns:
        The sample id registered for the barcode, or the barcode itself when
        it is not present in ``sample_names``.
    """
    sample_barcode = i1[1].strip() + i2[1].strip()
    return sample_names.get(sample_barcode, sample_barcode)

In [6]:
def _load_sample_names(sample_barcodes):
    """Return a barcode -> sample id dict from a dict, a TSV path, or None."""
    if isinstance(sample_barcodes, dict):
        return sample_barcodes
    sample_names = {}
    if sample_barcodes is not None:
        with open(sample_barcodes, 'r') as handle:
            for line in handle:
                fields = line.strip().split('\t')
                # Only well-formed "<sample id>\t<barcode>" lines are used.
                if len(fields) == 2:
                    sampleid, barcode = fields
                    sample_names[barcode] = sampleid
    return sample_names


def _write_record_set(handles, records):
    """Write one (r1, r2, i1, i2) record tuple to the matching four handles."""
    for handle, record in zip(handles, records):
        handle.write(''.join(record))


def demultiplex(read1, read2, index1, index2, sample_barcodes, out_dir, min_reads):
    """Split paired reads and their index reads into per-sample FASTQ files.

    The four input files are read in lockstep. Each read set is assigned a
    sample id by ``get_sample_id`` from its two index reads. Reads of a sample
    are held in memory until that sample has accumulated ``min_reads`` reads;
    from then on they are written to ``<sample>.r1.fq``, ``<sample>.r2.fq``,
    ``<sample>.i1.fq`` and ``<sample>.i2.fq`` in ``out_dir``. Reads of samples
    that never reach the threshold are written to ``undetermined_r1.fq``,
    ``undetermined_r2.fq``, ``undetermined_i1.fq`` and ``undetermined_i2.fq``,
    which are always created.

    Args:
        read1: Path to the read 1 FASTQ file.
        read2: Path to the read 2 FASTQ file.
        index1: Path to the index 1 FASTQ file.
        index2: Path to the index 2 FASTQ file.
        sample_barcodes: Either a dict mapping barcode -> sample id, the path
            of a tab-separated file with ``<sample id>\\t<barcode>`` lines,
            or None for no known samples.
        out_dir: Output directory; created if it does not exist.
        min_reads: Number of reads a sample needs before it gets its own
            files. Values of 1 or less give every sample its own files.

    Returns:
        None. Results are written to ``out_dir`` and a summary is logged.
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    sample_names = _load_sample_names(sample_barcodes)

    kinds = ('r1', 'r2', 'i1', 'i2')
    outfiles = {}      # sample id -> list of open handles, ordered as `kinds`
    buffers = {}       # sample id -> list of (r1, r2, i1, i2) awaiting min_reads
    undetermined = []  # handles of the undetermined_* files, ordered as `kinds`
    count = {}         # sample id -> number of reads seen
    total_count = 0

    # Iterate lazily on both Python 2 (itertools.izip) and Python 3 (zip);
    # the inputs are far too large to materialise as a list.
    lazy_zip = getattr(itertools, 'izip', zip)

    start = time.time()
    try:
        for records in lazy_zip(read_fq(read1), read_fq(read2),
                                read_fq(index1), read_fq(index2)):
            total_count += 1
            if total_count % 1000000 == 0:
                logger.info("Procesed %d reads in %.1f minutes.", total_count, (time.time()-start)/60)
            sample_id = get_sample_id(records[2], records[3], sample_names)
            count[sample_id] = count.get(sample_id, 0) + 1

            handles = outfiles.get(sample_id)
            if handles is not None:
                # Sample already passed the threshold: stream straight to disk.
                _write_record_set(handles, records)
                continue

            pending = buffers.setdefault(sample_id, [])
            pending.append(records)
            if count[sample_id] >= min_reads:
                # Threshold reached: open the sample's files and flush what
                # was buffered (including the current read set). Handles are
                # registered one by one so a failed open cannot leak the rest.
                handles = outfiles[sample_id] = []
                for kind in kinds:
                    handles.append(open(os.path.join(out_dir, '%s.%s.fq' % (sample_id, kind)), 'w'))
                for buffered in buffers.pop(sample_id):
                    _write_record_set(handles, buffered)

        # Whatever is still buffered belongs to samples below the threshold.
        for kind in kinds:
            undetermined.append(open(os.path.join(out_dir, 'undetermined_%s.fq' % kind), 'w'))
        for pending in buffers.values():
            for buffered in pending:
                _write_record_set(undetermined, buffered)
    finally:
        # Close every handle even if reading or writing failed part-way.
        for handles in outfiles.values():
            for handle in handles:
                handle.close()
        for handle in undetermined:
            handle.close()

    num_fastqs = sum(1 for reads in count.values() if reads >= min_reads)
    logger.info('Wrote FASTQs for %d sample barcodes out of %d with at least %d reads.', num_fastqs, len(count), min_reads)

In [8]:
def main():
    """Run demultiplexing on the hard-coded QIAGEN colon data set.

    The input FASTQ paths, the output directory and the read threshold are
    fixed inside this function. demultiplex() is only invoked when the barcode
    table holds more than one entry; with the empty table defined here the
    function returns without doing anything.
    """
    data_root = '/home/yaneng/RSun/Data/qiagen-colon/'
    barcode_table = {} # key=barcode, value=sample_id

    # Guard clause: skip the run unless several barcodes are known.
    if not len(barcode_table) > 1:
        return

    fastq_names = (
        'QIAGEN-2959YJ_S2_L001_R1_001_undetermined.fq',
        'QIAGEN-2959YJ_S2_L001_R2_001_undetermined.fq',
        'QIAGEN-2959YJ_S2_L001_I1_001_undetermined.fq',
        'QIAGEN-2959YJ_S2_L001_I2_001_undetermined.fq',
    )
    fastq_r1, fastq_r2, fastq_i1, fastq_i2 = [data_root + name for name in fastq_names]
    read_threshold = 10000
    demultiplex(fastq_r1, fastq_r2, fastq_i1, fastq_i2, barcode_table,
                data_root + 'demultiplexed/', min_reads=read_threshold)


In [9]:
# Run the entry point only when this code is executed directly (as a script
# or notebook cell), not when it is imported as a module.
if __name__ == '__main__':
    main()