In [1]:
from collections import defaultdict
import os
import numpy as np
import pandas as pd

In [7]:
def getCombinedVariants(
    seq, 
    distance,
    post_context="",
    correct_snps=True,
    correct_indels=True,
    bases={'A', 'C', 'G', 'T', 'N'}
):
    """
    Generates all unique, fixed-length variants of a sequence within a given
    combined correction distance.

    This version has been corrected to remove the flawed pre_context logic.
    All deletions are now correctly compensated by the post_context.
    
    Returns a map of {variant: (post_context, distance)}.
    """
    # The state now tracks variant, post_context, and distance
    all_variants_map = {seq: (post_context, 0)}
    last_generation = [(seq, post_context)]

    for d in range(distance):
        current_generation = []
        
        for var, post in last_generation:
            
            # 1. Generate SNPs (these don't change post_context)
            if correct_snps:
                for i in range(len(var)):
                    original_base = var[i]
                    for new_base in bases:
                        if new_base != original_base:
                            new_variant = var[:i] + new_base + var[i+1:]
                            if new_variant not in all_variants_map:
                                all_variants_map[new_variant] = (post, d + 1)
                                current_generation.append((new_variant, post))

            # 2. Generate Indels
            if correct_indels:
                # Deletions within the slice (compensated by post_context)
                for i in range(len(var)):
                    deleted_in_slice = var[:i] + var[i+1:]
                    if post:
                        new_variant = deleted_in_slice + post[0]
                        if new_variant not in all_variants_map:
                            # New post_context is one char shorter
                            new_post = post[1:]
                            all_variants_map[new_variant] = (new_post, d + 1)
                            current_generation.append((new_variant, new_post))
                    else: # Fallback: No post_context, pad with all bases
                        for b in bases:
                            new_variant = deleted_in_slice + b
                            if new_variant not in all_variants_map:
                                all_variants_map[new_variant] = ("", d + 1)
                                current_generation.append((new_variant, ""))
                
                # Insertions anywhere (compensated by truncation)
                for i in range(len(var) + 1):
                    for b in bases:
                        inserted = var[:i] + b + var[i:]
                        new_variant = inserted[:len(seq)]
                        if new_variant not in all_variants_map:
                            # New post_context is the truncated char + old post
                            new_post = inserted[-1] + post
                            all_variants_map[new_variant] = (new_post, d + 1)
                            current_generation.append((new_variant, new_post))
        
        last_generation = current_generation
        
    return all_variants_map

In [5]:
getCombinedVariants("ACTGAGATAG",distance=1,post_context="T")

{'ACTGAGATAG': ('T', 0),
 'GCTGAGATAG': ('T', 1),
 'NCTGAGATAG': ('T', 1),
 'CCTGAGATAG': ('T', 1),
 'TCTGAGATAG': ('T', 1),
 'AGTGAGATAG': ('T', 1),
 'AATGAGATAG': ('T', 1),
 'ANTGAGATAG': ('T', 1),
 'ATTGAGATAG': ('T', 1),
 'ACGGAGATAG': ('T', 1),
 'ACAGAGATAG': ('T', 1),
 'ACNGAGATAG': ('T', 1),
 'ACCGAGATAG': ('T', 1),
 'ACTAAGATAG': ('T', 1),
 'ACTNAGATAG': ('T', 1),
 'ACTCAGATAG': ('T', 1),
 'ACTTAGATAG': ('T', 1),
 'ACTGGGATAG': ('T', 1),
 'ACTGNGATAG': ('T', 1),
 'ACTGCGATAG': ('T', 1),
 'ACTGTGATAG': ('T', 1),
 'ACTGAAATAG': ('T', 1),
 'ACTGANATAG': ('T', 1),
 'ACTGACATAG': ('T', 1),
 'ACTGATATAG': ('T', 1),
 'ACTGAGGTAG': ('T', 1),
 'ACTGAGNTAG': ('T', 1),
 'ACTGAGCTAG': ('T', 1),
 'ACTGAGTTAG': ('T', 1),
 'ACTGAGAGAG': ('T', 1),
 'ACTGAGAAAG': ('T', 1),
 'ACTGAGANAG': ('T', 1),
 'ACTGAGACAG': ('T', 1),
 'ACTGAGATGG': ('T', 1),
 'ACTGAGATNG': ('T', 1),
 'ACTGAGATCG': ('T', 1),
 'ACTGAGATTG': ('T', 1),
 'ACTGAGATAA': ('T', 1),
 'ACTGAGATAN': ('T', 1),
 'ACTGAGATAC': ('T', 1),


In [13]:
# For a list of correct barcodes, creates a dictionary mapping uncorrected barcodes to correct barcodes
def getCorrDict(correct_bars, default=None, ambiguous=None, **kwargs):
    out_dict = defaultdict(lambda: default) # returns default if a non-matching barcode is inputted
    ambiguous_variants = set()
    for correct_bar in correct_bars:
        for bar in getCombinedVariants(correct_bar, **kwargs):
            if bar not in ambiguous_variants:
                if bar in out_dict:
                    if out_dict[bar] != correct_bar:
                        ambiguous_variants.add(bar)
                        if ambiguous:
                            out_dict[bar] = ambiguous
                        else:
                            del out_dict[bar]
                else:
                    out_dict[bar] = correct_bar

    return out_dict
class Barcode:
    def __init__(self, seq='', plate='', well='', condition='', read_type='', gene='', next_commands=[], output_files=''):
        self.seq = seq # full sequence, including part that is not used
        self.plate = plate
        self.well = well
        self.condition = condition
        self.read_type = read_type
        self.gene = gene
        self.next_commands = next_commands
        self.length = len(seq)
        
        # Parse output files as semicolon-separated set
        if output_files:
            self.output_targets = {target.strip() for target in output_files.split(';') if target.strip()}
        else:
            self.output_targets = set()
        
class BarcodePos:
    def __init__(self, bar_list, length_to_use, plate_list, corr_dist, orientation, correct_snps, correct_indels):
        
        # Get length of shortest barcode, this is what will be used for disambiguation
        # Also see if there is a default for "No match"
        self.no_match = None
        self.ambiguous = None
        min_length = length_to_use
        true_bars = [] # doesn't include no match
        for bar in bar_list:
            if bar.seq == 'No match':
                self.no_match = bar
                continue
            if bar.seq == 'Ambiguous':
                self.ambiguous = bar
                continue
                
            if bar.length < min_length:
                min_length = bar.length
                
            true_bars.append(bar)
        
        self.length = min_length
        self.plate_list = plate_list
        self.orientation = orientation
                
        # Make barcode correction dictionary
        potential_corrections = defaultdict(list)
        for bar in true_bars:
            if plate_list:
                if bar.plate not in plate_list:
                    continue
                    
            if orientation == 'reverse':
                seq = revComp(bar.seq)
            else:
                seq = bar.seq
            
            post_context = seq[min_length:]
            bar_seq = seq[:min_length]
            
            variants_map = getCombinedVariants(bar_seq, corr_dist, post_context, correct_snps, correct_indels)
            for var_seq, (_, dist) in variants_map.items():
                potential_corrections[var_seq].append((bar, dist))

        self.corr_dict = defaultdict(lambda: self.no_match)
        for var_seq, candidates in potential_corrections.items():
            if len(candidates) == 1:
                self.corr_dict[var_seq] = candidates[0][0]
                continue

            min_dist = min(c[1] for c in candidates)
            best_matches = [c[0] for c in candidates if c[1] == min_dist]

            if len(best_matches) == 1:
                self.corr_dict[var_seq] = best_matches[0]
            elif self.ambiguous:
                self.corr_dict[var_seq] = self.ambiguous
                        
    
    # Returns barcode information
    def getBar(self, bar_seq):
        return self.corr_dict[bar_seq]

# Returns reverse complement of sequence
# Not as flexible as Bio.Seq, but probably faster and easier to use
def revComp(seq, comp_dict={'A':'T','G':'C','C':'G','T':'A','N':'N'}):
    return ''.join([comp_dict[base] for base in seq[::-1]])
    
def makeBarDict(barcode_folder, debug_mode=False):
    bar_dict = {}
    for bar_file in os.listdir(barcode_folder):
        if bar_file.endswith('.csv') and bar_file.startswith('Barcode-'):
            name = bar_file.split('-')[1].split('.')[0]
            with open(os.path.join(barcode_folder, bar_file), 'rt') as f:
                length_to_use = f.readline().split(',')[1].strip()
                if length_to_use == 'all':
                    length_to_use = float('inf')
                else:
                    length_to_use = int(length_to_use)
                    
                plates_to_use = f.readline().split(',')[1].strip()
                if plates_to_use == 'all':
                    plate_list = None
                else:
                    plate_list = plates_to_use.split(';')
                    
                corr_dist = f.readline().split(',')[1].strip()
                if corr_dist: # deals with 0 becoming blank in certain cases
                    corr_dist = int(corr_dist)
                else:
                    corr_dist = 0
                
                correct_snps = f.readline().split(',')[1].strip()
                if correct_snps.lower() == 'true':
                    correct_snps = True
                else:
                    correct_snps = False 
                    
                correct_indels = f.readline().split(',')[1].strip()
                if correct_indels.lower() == 'true':
                    correct_indels = True
                else:
                    correct_indels = False
                    
                if debug_mode:
                    print("Corr dist: %d" % corr_dist)
                orientation = f.readline().split(',')[1].strip()
                f.readline()
                f.readline()
                
                bar_list = []
                for line in f:
                    parts = line.strip().split(',')
                    if len(parts) < 7:
                        continue
                    
                    seq = parts[0] if len(parts) > 0 else ''
                    plate = parts[1] if len(parts) > 1 else ''
                    well = parts[2] if len(parts) > 2 else ''
                    condition = parts[3] if len(parts) > 3 else ''
                    read_type = parts[4] if len(parts) > 4 else ''
                    gene = parts[5] if len(parts) > 5 else ''
                    next_command_string = parts[6] if len(parts) > 6 else ''
                    output_file_number = parts[7] if len(parts) > 7 else ''
                    
                    if not seq: # blank line at end of file
                        break
                    
                    if next_command_string:
                        next_commands = next_command_string.split(';')
                    else:
                        next_commands = None
                    bar_list.append(Barcode(seq, plate, well, condition, read_type, gene, next_commands, output_file_number))                                                                                                        
            
            if debug_mode:
                print("%d barcodes found for %s" % (len(bar_list), name))
            bar_dict[name] = BarcodePos(bar_list, length_to_use, plate_list, corr_dist, orientation, correct_snps, correct_indels)
    return bar_dict

In [14]:
barcode_folder = "/ru-auth/local/home/aepstein/cursor_projects/jointsci-pipeline/sample_sheets/barcodes_targ_new_pipeline"
bar_dict = makeBarDict(barcode_folder, debug_mode=False)

In [19]:
bar_dict['SSSprimer'].bar_dict

AttributeError: 'BarcodePos' object has no attribute 'bar_dict'

In [24]:
bar_dict['SSSprimer'].corr_dict['CTATAGGATC'].seq

'CTAAAGGATCTTCGAAGACAAC'

In [31]:
bar_dict['SSSprimer'].corr_dict['GTAAAGGATC'].seq

'CTAAAGGATCTTCGAAGACAAC'

In [28]:
bar_dict['SSSprimer'].corr_dict['CTATAGGATC'].read_type

'ATM_SSS'

In [33]:
seq = 'CCGCTTTTTCTAT'[:10]
bar_dict['SSSprimer'].corr_dict[seq]

In [36]:
test_file = "/ru-auth/local/home/aepstein/cursor_projects/jointsci-pipeline/intermediate_files/targ_merged/new_pipeline/barcoded_fastqs_targ/Targeted_F9_output.csv"
with open(test_file, 'rt') as f:
    header = f.readline()
    fields = header.strip().split(',')
    lines = []
    for i in range(100):
        line = f.readline()
        lines.append(line.strip().split(','))
        
df = pd.DataFrame(lines, columns=fields)

In [44]:
bar_dict['SSSprimer'].corr_dict['CAGCCTGGGC']

In [49]:
len(list(bar_dict['SSSprimer'].corr_dict.keys()))

322

In [53]:
'CAGCCTGGGC' in bar_dict['SSSprimer'].corr_dict

True

In [55]:
bar_dict['SSSprimer'].corr_dict

defaultdict(<function __main__.BarcodePos.__init__.<locals>.<lambda>()>,
            {'CTAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'GTAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'ATAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'NTAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'TTAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CGAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CAAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CNAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CCAAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CTGAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CTNAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CTCAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CTTAAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CTAGAGGATC': <__main__.Barcode at 0x7f7c26346fa0>,
             'CTA

In [59]:
list(bar_dict['SSSprimer'].corr_dict.values())
bar_dict['SSSprimer'].corr_dict['CAGCCTGGGC']
bar_dict['SSSprimer'].corr_dict['CAGCCTGGGC']
bar_dict['SSSprimer'].corr_dict['CAGCCTGGGC']
bar_dict['SSSprimer'].corr_dict['CAGCCTGGGC']
bar_dict['SSSprimer'].corr_dict['CAGCCTGGGC']


In [65]:
def getReadType(bar_seq):
    if bar_dict['SSSprimer'].corr_dict[bar_seq[:10]]:
        return bar_dict['SSSprimer'].corr_dict[bar_seq[:10]].read_type
    else:
        return 'Unknown'

df['SSSprimer_redone'] = df['R2'].apply(getReadType)


In [67]:
df[df['SSSprimer_redone'] != 'Unknown']

Unnamed: 0,wells_conditions,UMI,read_types,R1,R2,I5,I7,SSSprimer_redone
27,1-D4;-C4;1-F9;-;-;-&;edge;HGG5;;;,ATGGGCGA,;TargRT;;ATM_SSS;ATM_WT;ATM_RT,ATGGGCGATTCTTATACTGAGAACGACGAGCAATGCGAGTAAGGTC...,CTAAAGGATCTTCGAAGACAACTGGAACTACATAAAGATCAGATGG...,CAATCTCAATGTGTAGATCTCGGT,AGACGTTAGA,ATM_SSS
46,1-D4;-D3;1-F9;-;-;-&;edge;HGG5;;;,CTCCATGC,;TargRT;;ATM_SSS;ATM_WT;ATM_RT,CTCCATGCCGAACCAGCAGAGAACGACGAGCAATGCGAGTAAGGTC...,CTAAAGGATCTTCGAAGACAACTGGAACTACATAAAGATCAGATGG...,CAATCTCAATGTGTAGATCTCGGT,AGACGTTAGA,ATM_SSS
67,1-H2;-D11;1-F9;-;-;-&;edge;HGG5;;;,ATGTCTTG,;TargRT;;ATM_SSS;ATM_WT;ATM_RT,ATGTCTTGGAAGTTGACGTAGAGAACGACGAGCAATGCGAGTAAGG...,CTAAAGGATCTTCGAAGACAACTGGAACTACATAAAGATCAGATGG...,TCCATACCTGGTGTAGATCTCGGT,AGACGTTAGA,ATM_SSS
77,2-H4;-D6;1-F9;-;-;-&;edge;HGG5;;;,GTCGCATT,;TargRT;;ATM_SSS;ATM_WT;ATM_RT,GTCGCATTCATACGACCTCGAGAACGACGAGCAATGCGAGTAAGGT...,CTAAAGGATCTTCGAAGACAACTGGAACTACATAAAGATCAGATGG...,GAATAGACGCGTGTAGATCTCGGT,AGACGTTAGA,ATM_SSS
96,1-D1;-D6;1-F9;-;-;-&;edge;HGG5;;;,TGGGGGTA,;TargRT;;IL6_SSS;IL6_WT;IL6_RT,TGGGGGTACATACGACCTCGAGAACGACGAGCAATGCGAGTAAGGT...,GATTTGAGAGTAGTGAGGAACAAGCCAGAGCTGTGCAGATGAGTAC...,CGGAGCGTCAGTGTAGATCTCGGT,AGACGTTAGA,IL6_SSS


In [61]:
x = defaultdict(lambda: 'Unknown')

In [63]:
x['la']

'Unknown'

In [64]:
'la' in x

True