In [None]:
# %load ../snippets/basic_settings.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import sys
import plotly.express as px
import yaml

sns.set_context("notebook", font_scale=1.1)
# DataFrame display limits (pairs of option name, value).
pd.set_option("display.max_columns", 100, "display.max_rows", 100)
# Global matplotlib defaults: large figures, high-DPI saving, larger fonts and lines.
plt.rcParams.update({
    "figure.figsize": (16, 12),
    "savefig.dpi": 200,
    "figure.autolayout": False,
    "axes.labelsize": 18,
    "axes.titlesize": 20,
    "font.size": 16,
    "lines.linewidth": 2.0,
    "lines.markersize": 8,
    "legend.fontsize": 14,
    # NOTE(review): confirm "cm" matches an installed serif font name.
    "font.serif": "cm",
})
#pd.set_option('display.float_format', lambda x: '{:,.2f}'.format(x))

In [None]:
# Output directory of the SB2B/ML5 TnSeq test run (machine-specific absolute path).
# NOTE(review): `root` is reassigned to a different directory further down in this notebook.
root = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/deutschbauer/fastq/test_out")

# Annotated barcode maps

In [None]:
# Load the annotated barcode maps of the two test libraries.
# NOTE(review): `root` already ends in "test_out", so these paths resolve to
# .../test_out/test_out/... — confirm this nesting is intended.
map1 = pd.read_csv(root/"test_out/TnSeq_SB2B_ML5_l0.annotated.csv")
map2 = pd.read_csv(root/"test_out/TnSeq_SB2B_ML5_tn2_l0.annotated.csv")

In [None]:
# Reads-per-row histogram for 20,000 randomly sampled rows, zoomed to 0-100 reads.
# NOTE(review): sample() is unseeded, so the plot changes on every run.
map1.sample(20000).number_of_reads.hist(bins=1000)
plt.xlim(0, 100)

In [None]:
map1.shape

In [None]:
# Rows with more than 100 reads.
map1[map1.number_of_reads > 100].shape

In [None]:
# Scratch calculation.
np.log2(20)

In [None]:
map2.shape

In [None]:
# Rows with more than 20 reads.
map2[map2.number_of_reads > 20].shape

In [None]:
# Same histogram for map2, 40,000 sampled rows (unseeded).
map2.sample(40000).number_of_reads.hist(bins=1000)
plt.xlim(0, 100)

In [None]:
# The 20 rows with the most reads.
map2.sort_values('number_of_reads').tail(20)

In [None]:
# Spot-check individual barcodes.
map1[map1.barcode == 'CTCAACATTTGAAGATGTTT']

In [None]:
map2[map2.barcode == "CTTATGCTTCACAAATTGAG" ]

In [None]:
# 10th percentile of reads per row.
np.quantile(map2.number_of_reads, 0.1)

In [None]:
map2[map2.number_of_reads > 10].shape

In [None]:
# NOTE(review): duplicates an earlier cell with the same threshold.
map2[map2.number_of_reads > 20].shape

In [None]:
# log2(reads) distribution for 20,000 sampled rows (unseeded).
#map2[map2.multimap == True].sample(20000).number_of_reads.hist(bins=500)
np.log2(map2.sample(20000).number_of_reads).hist(bins=500)

In [None]:
# Raw BLASTN hits for the SB2B/ML5 library.
blast_file = root/"test_out/TnSeq_SB2B_ML5_l0.blastn"

In [None]:
# Headerless, tab-separated BLASTN output.
df = pd.read_table(blast_file, header=None)

In [None]:
# Column names in the order of the BLASTN output format used for this file.
df.columns = "qseqid sseqid pident length qstart qend sstart send evalue bitscore qseq sstrand".split()

In [None]:
df.shape

In [None]:
# Drop spurious hits: keep e-value < 0.1 and alignment length > 20.
df = df[(df.evalue < 0.1) & (df.length > 20)]

In [None]:
# Preview the filtered BLAST hits (this cell previously referenced an undefined `df3`).
df.head()

In [None]:
# Best bitscore per query (qseqid). agg() with a list yields MultiIndex columns,
# which are flattened in the next cell.
best_hits = df.groupby('qseqid').agg({'bitscore': ['max']}).reset_index()

In [None]:
# The barcode is field 2 of the '_'-separated qseqid.
best_hits.columns = ['qseqid', 'bitscore']
best_hits['barcode'] = best_hits['qseqid'].str.split('_', expand=True)[[2]]

In [None]:
best_hits.head()

In [None]:
# The read count is field 4 of the '_'-separated qseqid.
best_hits['cnt'] = best_hits['qseqid'].str.split('_', expand=True)[[4]].astype(int)

In [None]:
best_hits.head()

In [None]:
#total_count = best_hits.groupby('barcode').cnt.sum().reset_index()

In [None]:
#total_count.columns = ['barcode', 'total_count']

In [None]:
#total_count.head()

In [None]:
#best_hits = best_hits.merge(total_count, how='left', on='barcode')

In [None]:
# Re-attach the alignment columns to each best hit; a query with several hits
# tied on the max bitscore keeps all of them.
query_best_hits = best_hits.merge(df, how='left', on=['qseqid', 'bitscore'])

In [None]:
query_best_hits[query_best_hits.barcode == 'CTCTTGGACGTTGGCGCGAG']

In [None]:
# Reads per (barcode, exact start) and `tts` = that start's share of the barcode's reads.
total_counts = query_best_hits.groupby(['barcode', 'sstart']).cnt.sum().reset_index()
total_counts.columns = ['barcode', 'sstart', 'total_cnt']
total_counts['tts'] = total_counts['total_cnt'] / total_counts.groupby('barcode')['total_cnt'].transform('sum')

In [None]:
total_counts[total_counts.barcode == "AAGACGCCCTGCAGGGATGT"]

In [None]:
# Starts carrying > 75% of their barcode's reads and more than 10 reads.
total_counts[(total_counts.tts > 0.75) & (total_counts.total_cnt > 10)]

In [None]:
# Barcodes whose reads at starts with a 10-75% share sum to more than 10
# (presumably candidate multimappers — confirm).
mp = total_counts[(total_counts.tts > 0.1) & (total_counts.tts < 0.75)].groupby('barcode').total_cnt.sum().reset_index()
mp[mp.total_cnt > 10].shape




In [None]:
# Distribution of per-start read shares for starts with more than 10 reads.
total_counts[total_counts.total_cnt > 10].tts.hist(bins=500)

In [None]:
# Rank hits within each barcode by read count (rank 0 = most reads).
# NOTE(review): re-binds query_best_hits (sorted, with a new `rank` column).
query_best_hits = query_best_hits.sort_values(['barcode', 'cnt'], ascending=False)
query_best_hits['rank'] = query_best_hits.groupby(['barcode']).cumcount()
#query_best_hits = query_best_hits[query_best_hits['rank'] == 0].copy()
#query_best_hits.drop('rank', axis=1, inplace=True)

In [None]:
query_best_hits[query_best_hits.barcode == 'CTCTTGGACGTTGGCGCGAG']

In [None]:
best_hits[best_hits.barcode == 'CTCTTGGACGTTGGCGCGAG']

In [None]:
query_best_hits[query_best_hits.barcode == 'CTCTTGGACGTTGGCGCGAG']

In [None]:


        # Reference: selection step of the original `_find_most_likely_positions` (v1),
# pasted from a class method. Dedented so it runs as a cell; the result goes to
# `positions_v1` instead of `self.positions`, and `query_best_hits` is no longer
# overwritten, so the cell can be re-run.
# Note: total counts are calculated with cnt 1 included, but low counts are
# filtered out right after.
# There still could be multiple hits for each qseqid if they share the best blast score.

# A barcode is flagged as multimapping when the population standard deviation of its
# hit starts exceeds 5 bp.
multimap = (query_best_hits.groupby('barcode').sstart.std(ddof=0) > 5).rename('multimap').reset_index()
# For each barcode keep the position supported by most reads (first row after sorting).
positions_v1 = (query_best_hits.drop(columns=['rank'], errors='ignore')
                .merge(multimap, on='barcode')
                .sort_values(['barcode', 'cnt'], ascending=False)
                .drop_duplicates('barcode'))
positions_v1

# Implementing the new `_find_most_likely_positions` (v2)

In [None]:
root = "/nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/Projects_NCCR/ref/mbarq_test_data/dnaid1315/expected_outcomes"

In [None]:
test_blastn="/nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/Projects_NCCR/ref/mbarq_test_data/dnaid1315/expected_outcomes/library_11_1_FKDL202598974-1a-D701-AK1682_HHG5YDSXY_L4_subsample_1.blastn"
positions = "/nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/Projects_NCCR/ref/mbarq_test_data/dnaid1315/expected_outcomes/likely_positions.csv"

In [None]:
def _find_most_likely_positions_v2(temp_blastn_file, filter_below, perc_primary_location=0.75) -> None:
    """
     Takes in blast file, and provides most likely locations for each barcode
     :param: blast_file
     :param: filter_below
     :param: logger
     :return: pd.DataFrame
     """
    
    def merge_similar_locations(df):
        df = df.sort_values(['sstart']).reset_index()
        df['Group']=((df.end.rolling(window=2,min_periods=1).min()
                    -df.sstart.rolling(window=2,min_periods=1).max())<0).cumsum()
        cnt = df.groupby(['Group']).agg({'cnt': ['sum']}).reset_index()
        cnt.columns = ['Group', 'total_count']
        loc = df.loc[df.groupby(['barcode', 'Group'])['cnt'].idxmax()]
        loc = loc.merge(cnt, on=['Group'])
        return loc[['sstart', 'sstrand', 'total_count']]

    df = pd.read_table(temp_blastn_file, header=None)
    df.columns = "qseqid sseqid pident length qstart qend sstart send evalue bitscore qseq sstrand".split()
    # Filter out spurious hits
    df = df[(df.evalue < 0.1) & (df.length > 20)]
    # Get a best hit for each qseqID( barcode:host combo): group by qseqid, find max bitscore
    best_hits = df.groupby('qseqid').agg({'bitscore': ['max']}).reset_index()
    best_hits.columns = ['qseqid', 'bitscore']
    # Get barcode out of qseqid
    best_hits['barcode'] = best_hits['qseqid'].str.split('_', expand=True)[[2]]
    # Get count out of qseqid
    best_hits['cnt'] = best_hits['qseqid'].str.split('_', expand=True)[[4]].astype(int)
    query_best_hits = best_hits.merge(df, how='left', on=['qseqid', 'bitscore'])
    query_best_hits['end'] = query_best_hits['sstart'] + 5
    
    # for each barcode, find all positions detected, and count how many reads per position
    total_counts = query_best_hits.groupby(['barcode', 'sseqid']).apply(merge_similar_locations).reset_index()
    total_counts['prop_read_per_position'] = total_counts['total_count'] / total_counts.groupby('barcode')['total_count'].transform('sum')
    likely_positions = total_counts[total_counts['prop_read_per_position'] > perc_primary_location].reset_index()
    likely_multimappers = (total_counts[(total_counts['prop_read_per_position'] < perc_primary_location) 
                                        & (likely_positions.total_count > filter_below)]
                          .barcode.nunique())
    print(likely_multimappers)
    likely_positions = likely_positions[likely_positions.total_count > filter_below]
    
    return likely_positions[['barcode', 'sseqid', 'sstrand', 'sstart', 'total_count', 'prop_read_per_position']]

In [None]:
lp2 = _find_most_likely_positions_v2(test_blastn, 0)

In [None]:

lp2.to_csv("/nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/Projects_NCCR/ref/mbarq_test_data/dnaid1315/expected_outcomes/23-06-22-likely_positions.csv",
          index=False)



In [None]:
lp2.head()

In [None]:
# NOTE(review): scratch cells inspecting `loc`, `cnt`, `cnt2` and `df1` were removed here.
# `loc`/`cnt` exist only as locals of `merge_similar_locations` (the top-level names of the
# same spelling are created further down, with different columns), and `df1` is never
# defined, so these cells raised NameError/KeyError on a fresh kernel.

In [None]:
import re

In [None]:
# NOTE(review): `_merge_colliding_barcodes` and the `Barcode` class it uses are defined
# further down, so this cell fails on a fresh kernel (Restart & Run All); move those
# definitions above this cell.
lp3 = _merge_colliding_barcodes(lp2)

In [None]:
# Reference merge result from the expected_outcomes directory, for comparison.
ex_merge = pd.read_csv(Path(root)/"merge_colliding_bcs.csv")

In [None]:
ex_merge

In [None]:
# Outer join on barcode: barcodes present in only one of the two results get NaNs.
t = lp3.merge(ex_merge, how='outer', on='barcode')

In [None]:
lp2[lp2.barcode == 'ACCCCACACATAGGTGT']

In [None]:
# v2 positions in a narrow window of start coordinates.
lp2[(lp2.sstart > 1456300) & (lp2.sstart < 1456320.0)]

In [None]:
lp3.to_csv(Path(root)/"23-06-22-merge_colliding_bcs.csv", index=False)

In [None]:
def _merge_colliding_barcodes(pps):

    """
    Takes data frame of positions, and merges colliding barcodes
    """
    pps = pps[['sseqid', 'sstart', 'sstrand', 'barcode', 'total_count']].copy()

    positions_sorted = (pps.groupby('sseqid')
                        .apply(pd.DataFrame.sort_values, 'sstart')
                        .drop(['sseqid'], axis=1)
                        .reset_index()
                        .drop(['level_1'], axis=1))

    # Get indices for rows with collisions
    collision_index = list(
        positions_sorted[(positions_sorted.sstart.diff() < 5) & (positions_sorted.sstart.diff() >= 0)].index)
    collision_index.extend([i - 1 for i in collision_index if i - 1 not in collision_index])
    collision_index.sort()

    # Barcodes without collisions

    unique = positions_sorted[~positions_sorted.index.isin(collision_index)]

    collisions = positions_sorted.iloc[collision_index]
    if collisions.empty:
        return unique[['barcode', 'total_count', 'sstart', 'sseqid', 'sstrand']]
    else:
        def row_to_barcode(structure, row):
            bc = Barcode(structure)
            bc.bc_seq = row.barcode
            bc.chr = row.sseqid
            bc.start = row.sstart
            bc.strand = row.sstrand
            bc.count = row.total_count
            return bc

        bcs = []
        final_bcs = []
        for i, r in collisions.iterrows():
            bcs.append(row_to_barcode('B17N13GTGTATAAGAGACAG', r))

        bc = bcs.pop(0)
        cnt = bc.count
        while len(bcs) > 0:
            bc2 = bcs.pop(0)
            if bc.chr != bc2.chr or abs(bc.start - bc2.start) > 5:
                bc.count = cnt
                final_bcs.append(bc)
                bc = bc2
                cnt = bc2.count
            else:
                if bc.editdistance(bc2) > 3:
                    bc.count = cnt
                    final_bcs.append(bc)
                    bc = bc2
                    cnt = bc2.count
                else:
                    cnt += bc2.count
                    if bc.count < bc2.count:
                        bc = bc2
        if bc not in final_bcs:
            final_bcs.append(bc)
        resolved_collisions = pd.DataFrame([[bc.chr, bc.start, bc.strand,
                                             bc.bc_seq, bc.count] for bc in final_bcs],
                                           columns=['sseqid', 'sstart', 'sstrand', 'barcode', 'total_count'])
        return pd.concat([unique, resolved_collisions])[['barcode', 'total_count', 'sstart',
                                                                   'sseqid', 'sstrand']]
        

In [None]:

class FastA:
    """Minimal record for one FASTA entry: its header and its sequence."""

    __slots__ = ['header', 'sequence']

    def __init__(self, header: str, sequence: str) -> None:
        # __slots__ restricts instances to these two attributes (no per-instance __dict__).
        self.header, self.sequence = header, sequence


class Barcode:
    """
    Transposon barcode: its sequence, the read structure it was parsed with and,
    optionally, where it maps.

    :param structure: read-structure string such as 'B17N13GTGTATAAGAGACAG' — 'B<n>' is
        the barcode length, an optional 'N<m>' the spacer length, and the first run of
        A/C/G/T the transposon end sequence
    :param sequence: barcode sequence
    :raises ValueError: if neither structure nor sequence is given, or the structure
        cannot be parsed
    """

    def __init__(self, structure='', sequence=''):
        self.structure: str = structure
        self.bc_seq: str = sequence
        self.host: str = ''
        # Declared only; assigned by _parse_structure() when a structure is given.
        self.bc_len: int
        self.tn_seq: str
        self.count: int = -1
        self.bc_before_tn: bool
        self.len_spacer: int
        # In theory these are optional
        self.start: Optional[int] = None  # todo don't need these, need insertion site
        self.end: Optional[int] = None
        self.chr: Optional[str] = None
        self.strand: Optional[str] = None
        self.multimap: Optional[bool] = None
        self.identifiers: Optional[List[str]] = None
        if self.structure:
            self._parse_structure()
        if not self.structure and not self.bc_seq:
            raise ValueError("Please provide either structure or sequence")

    def _parse_structure(self):
        """
        Parse ``self.structure`` into tn_seq, bc_len, len_spacer and bc_before_tn.

        :raises ValueError: if the structure has no A/C/G/T run or no 'B<n>' token
        """
        tn_match = re.search('[ACGT]+', self.structure)
        bc_match = re.search('B(\\d+)', self.structure)
        if tn_match is None or bc_match is None:
            # Previously `raise '<str>'`, which itself fails with TypeError.
            raise ValueError(f'Could not parse transposon structure provided: {self.structure}')
        spacer = re.search('N(\\d+)', self.structure)
        self.tn_seq = tn_match.group(0)
        self.bc_len = int(bc_match.group(1))
        self.len_spacer = int(spacer.group(1)) if spacer else 0
        # Compare token positions directly; str.index(<digits>) could match the spacer digits.
        self.bc_before_tn = tn_match.start() > bc_match.start()

    def editdistance(self, other_barcode: "Barcode") -> int:
        """
        Number of mismatching positions (Hamming distance) between two equal-length barcodes.

        :param other_barcode: barcode to compare against
        :return: count of positions where the sequences differ (0 if identical)
        :raises Exception: if either barcode has no sequence, or the lengths differ
        """
        if not self.bc_seq or not other_barcode.bc_seq:
            raise Exception("Could not find sequence for one or more barcodes")
        if self.bc_seq == other_barcode.bc_seq:
            return 0
        if len(self.bc_seq) != len(other_barcode.bc_seq):
            raise Exception(
                f'{self.bc_seq} and {other_barcode.bc_seq} have different length. Edit distance can be computed on same length sequences only.')
        return sum(a != b for a, b in zip(self.bc_seq, other_barcode.bc_seq))

    def extract_barcode_host(self, r1: "FastA") -> None:
        """
        Extract barcode and host sequence from a read containing the transposon end (tp2).

        Sets ``self.bc_seq`` and ``self.host``. Both are left as '' when the transposon
        sequence is not in the read, or when the barcode side of the read is shorter than
        spacer + barcode. Returns None.

        Layout when the barcode precedes the transposon (bc_before_tn=True)::

            -----|BARCODE|----------|TN end sequence (tp2)|---Host------
            -----|-bc_len-|--len_spacer--|---------tn_seq---------|-------------
            -----|-17bp--|---13bp---|---------15bp--------|----?--------
             ---(-30)---(-13)-------(0)---------------------------------

        :param r1: read object with a ``sequence`` attribute (FastA)
        """
        splits = r1.sequence.split(self.tn_seq)
        self.bc_seq = ''
        self.host = ''
        if len(splits) < 2:
            # Transposon end not found in the read: nothing to extract.
            return
        if self.bc_before_tn:
            bc_start = -(self.bc_len + self.len_spacer)
            bc_end = None if self.len_spacer == 0 else -self.len_spacer
            bc_side, host_side = splits[0], splits[1]
        else:
            bc_start = self.len_spacer
            bc_end = self.len_spacer + self.bc_len
            bc_side, host_side = splits[1], splits[0]

        # Only accept reads whose barcode side holds the spacer plus a complete barcode.
        if len(bc_side) >= self.len_spacer + self.bc_len:
            self.bc_seq = bc_side[bc_start:bc_end]
            self.host = host_side

    def __repr__(self):
        if self.bc_seq:
            return f"{self.bc_seq}: {self.count}"
        else:
            return f"Barcode({self.structure})"


In [None]:
# NOTE(review): `tc` is not defined anywhere in this notebook (nor is `qh` below); these
# cells only run with state left over from an earlier session. `tc` presumably held
# per-position counts with a `total_cnt` column — TODO confirm or replace with `lp2`.
y = tc[tc.barcode == 'ACCACGCAGTATTTTGC'].sort_values(['sseqid', 'sstart']).copy()

In [None]:
# Gap to the previous start and running read total for this barcode.
y['st'] = y.sstart.diff()
y['cs'] = y.total_cnt.cumsum()

In [None]:
y

In [None]:
# Expected positions from the reference output, ordered by start.
ex_df = pd.read_csv(positions)[['barcode', 'sseqid', 'sstart', 'total_count', 'multimap']].sort_values("sstart")

In [None]:
# Number of rows with a truthy `multimap` flag in the expected output.
ex_df.multimap.sum()

In [None]:
# Computed vs expected per barcode (suffix _x = tc, _y = expected).
mdf = tc.merge(ex_df, how='outer', on='barcode')

In [None]:
mdf.head()

In [None]:
qh[qh.barcode =='TCTCGGGACAGTTAGCC']

In [None]:
# Barcodes whose start differs; NaN != NaN, so barcodes missing on one side are listed too.
mdf[mdf.sstart_x != mdf.sstart_y]

In [None]:
# Read counts, computed vs expected, for barcodes on FQ312003.1.
px.scatter(mdf[mdf.sseqid_x == 'FQ312003.1'], x='total_count_x', y='total_count_y', hover_data=['barcode'])

In [None]:
# Scratch re-run of the collision detection from `_merge_colliding_barcodes` on `tc`.
# NOTE(review): `tc` is undefined in this notebook.
# Sort positions by sequence, then start.
positions_sorted = (tc.groupby('sseqid')
                        .apply(pd.DataFrame.sort_values, 'sstart')
                        .drop(['sseqid'], axis=1)
                        .reset_index()
                        .drop(['level_1'], axis=1))
# Rows starting < 5 bp after the previous row, plus that previous row. The diff is not
# reset between sequences, so it also compares the last row of one sequence to the first of the next.
collision_index = list(
        positions_sorted[(positions_sorted.sstart.diff() < 5) & (positions_sorted.sstart.diff() >= 0)].index)
collision_index.extend([i - 1 for i in collision_index if i - 1 not in collision_index])
collision_index.sort()
collisions = positions_sorted.iloc[collision_index]

In [None]:
collisions.head(20)

In [None]:
# Toy hits: one barcode on two sequences. Built column-wise so the numeric columns get
# integer dtypes; the previous list-of-rows + transpose construction left every column
# with object dtype.
test = pd.DataFrame({
    'barcode': ['aaa', 'aaa', 'aaa', 'aaa', 'aaa'],
    'sseqid': ['a', 'a', 'a', 'b', 'b'],
    'left': [300, 302, 500, 550, 552],
    'count': [3, 4, 1, 1, 1],
})
# Fixed 5 bp window per hit.
test['right'] = test['left'] + 5

In [None]:

# Reads per (barcode, sseqid, left) for the toy hits; overwrites the earlier `total_counts`.
total_counts = test.groupby(['barcode', 'sseqid', 'left'])['count'].sum().reset_index()
total_counts.columns = ['barcode', 'sseqid','left','total_cnt']
# Fixed 5 bp window per position.
total_counts['right'] = total_counts['left']+5

In [None]:
def get_group(x):
    """
    Label runs of overlapping intervals in ``x`` (columns ``left`` and ``right``).

    A new group starts at a row whose interval does not overlap the interval of the row
    directly before it. Only adjacent rows are compared, so ``x`` should already be
    sorted by ``left`` (then ``right``).

    :param x: DataFrame with numeric ``left`` and ``right`` columns
    :return: integer Series aligned to ``x.index`` with group labels starting at 0
    """
    overlap = (x.right.rolling(window=2, min_periods=1).min()
               - x.left.rolling(window=2, min_periods=1).max())
    # Previously the expression was computed but never returned, so callers got None.
    return (overlap < 0).cumsum()

In [None]:
total_counts

In [None]:
# Overlap-group labels per sequence.
total_counts.groupby('sseqid').apply(get_group)

In [None]:
# Same grouping inline on the toy hits: sort by interval, then start a new group whenever
# a row's interval does not overlap the previous row's (this version ignores sseqid).
test=test.sort_values(['left','right'])
test['Group']=((test.right.rolling(window=2,min_periods=1).min()
                -test.left.rolling(window=2,min_periods=1).max())<0).cumsum()


In [None]:
test

In [None]:
def get_location(g):
    """
    Return the ``left`` coordinate of the row in ``g`` with the highest ``count``.

    Ties go to the first such row.
    """
    # Positional index so the label returned by idxmax doubles as a row position.
    rows = g.reset_index(drop=True)
    top = rows['count'].idxmax()
    return rows['left'].iloc[top]

In [None]:
# Per group: `left` of the row with most reads, and the summed reads.
loc = test.groupby('Group').apply(get_location).reset_index()
cnt = test.groupby('Group').agg({'count': ['sum']})

In [None]:
# Toy intervals to check the adjacent-overlap grouping.
x = pd.DataFrame({'left': [0,5,10,3,12,13,18,31], 'right':[4,8,13,7,19,16,23,35]})
x = x.sort_values(['left','right'])

In [None]:
# Only adjacent rows are compared: [18, 23] starts a new group even though it overlaps
# [12, 19], because its direct predecessor is [13, 16].
x['Group']=((x.right.rolling(window=2,min_periods=1).min()
                -x.left.rolling(window=2,min_periods=1).max())<0).cumsum()


In [None]:
x

In [None]:
x.right.rolling(window=2,min_periods=1).min()