In [None]:
#package imports and database creation
from ftplib import FTP
import time
import os
from subprocess import Popen, PIPE
import sqlite3
import re
from ete3 import PhyloTree, Tree
from tqdm import tqdm
from collections import Counter
from ete3.coretype.tree import TreeError
import itertools
from Bio.Phylo.PAML import codeml
import requests, json
import numpy as np
import matplotlib.pyplot as plt
from math import exp, log, log10
from scipy.stats import chi2, mannwhitneyu, spearmanr, combine_pvalues
from statsmodels.stats.multitest import fdrcorrection as fdr
import statsmodels.api as sm
import statsmodels.formula.api as smf
import seaborn as sns
import pandas as pd
from collections import Counter

# Open (or create) the project database. Passing isolation_level=None puts the
# sqlite3 connection in autocommit mode, so each statement takes effect immediately.
db = sqlite3.connect('drosophilaDatabase_diptera', isolation_level=None)
cursor = db.cursor()

In [None]:
#data fetching and processing
# Download the translated CDS, CDS and GFF files for every species listed in
# drosophilaProjectNcbiFtpFileLocations.txt (tab-separated: species abbreviation,
# path below /genomes/all/ on the NCBI FTP server).


def _firstContaining(fileNames, marker):
    """Return the first name in `fileNames` containing `marker` (IndexError if none does)."""
    return [name for name in fileNames if marker in name][0]


def _download(ftp, remoteName, outputPath):
    """Binary-download `remoteName` from the FTP session's current directory into `outputPath`."""
    with open(outputPath, 'wb') as out:
        ftp.retrbinary('RETR ' + remoteName, out.write)


#connect to ftp server
ftp = FTP('ftp.ncbi.nlm.nih.gov')
ftp.login()
print('Connected to site successfully')
try:
    #loop over file locations for each species
    with open('drosophilaProjectNcbiFtpFileLocations.txt', 'r') as file:
        for line in file:
            if not line.strip():
                continue  # tolerate blank lines instead of failing the unpack below
            spAbb, path = line.strip('\n').split('\t')
            outputFileTrans = 'dipteraTranslations_raw/' + spAbb + '_raw_trans.faa.gz'
            outputFileCDS = 'dipteraCDS_raw/' + spAbb + '_raw_cds.fna.gz'
            outputFileGFF = 'dipteraGFF_raw/' + spAbb + '_raw.gff.gz'
            print('Species:', spAbb)
            fullPath = '/genomes/all/' + path
            ftp.cwd(fullPath)  # change to relevant directory for this species
            if 'GC' not in fullPath:
                # annotation_releases layout: the files sit one subdirectory further down,
                # so descend into the first directory entry of the LIST output
                listing = []
                ftp.retrlines('LIST', callback=listing.append)
                directory = [x for x in listing if x.startswith('d')][0].split(' ')[-1]
                ftp.cwd(directory)
            # NLST gives bare file names; searched on a fresh list so stale LIST lines can't match
            fileList = []
            ftp.retrlines('NLST', callback=fileList.append)
            faaFile = _firstContaining(fileList, 'translated_cds')
            fnaFile = _firstContaining(fileList, 'cds_from_genomic')
            gffFile = _firstContaining(fileList, 'genomic.gff')
            print('Transferring files...')
            t = time.time()
            _download(ftp, faaFile, outputFileTrans)
            _download(ftp, fnaFile, outputFileCDS)
            _download(ftp, gffFile, outputFileGFF)
            print('Done! (' + str(round(time.time() - t, 1)) + ' seconds)')
            print()
finally:
    ftp.close()  # release the connection even if a species fails part-way
print('All done! Goodbye.')
print()
print('Processing protein seq files...')
RAW_TRANS_SUFFIX = '_raw_trans.faa.gz'


def _parseTranslationHeader(header):
    """Parse an NCBI translated-CDS FASTA header into (gene, protein_id, usable).

    `usable` is False when the header has no protein id or matches none of the
    known layouts; the caller then ignores that record's sequence instead of
    attributing it to the previous gene.
    """
    if 'locus_tag' in header and 'gene=' in header:
        gene = re.search(r'\[gene=(.*)\] \[locus_tag.*\]', header).group(1)
        prot = re.search(r'\[protein_id=(.*)\] \[location=', header).group(1)
    elif 'gene=' not in header and 'GeneID' in header:
        gene = re.search(r'GeneID:(\d*)\] \[protein.*\]', header).group(1)
        prot = re.search(r'\[protein_id=(.*)\] \[location=', header).group(1)
    elif 'gene=' in header:
        gene = re.search(r'\[gene=(.*)\] \[db_xref.*\]', header).group(1)
        prot = re.search(r'\[protein_id=(.*)\] \[location=', header).group(1)
    elif 'locus_tag=' in header:  # for CCAP, and any others with no gene ids
        gene = re.search(r'\[locus_tag=(.*?)\]', header).group(1)
        protMatch = re.search(r'\[protein_id=(.*?)\]', header)
        if protMatch is None:  # no protein id given
            return gene, None, False
        prot = protMatch.group(1)
    else:
        print('??')
        print(header)
        return None, None, False
    return gene, prot, True


def _keepIfLongest(longest, gene, prot, seqOK, seqLines):
    """Store (prot, length, seq) for `gene` if usable and longer than the isoform already kept."""
    if not (seqOK and seqLines):  # unusable header, or header with no following lines
        return
    seq = ''.join(seqLines)
    if gene not in longest or longest[gene][1] < len(seq):
        longest[gene] = (prot, len(seq), seq)


def _longestProteinPerGene(lines):
    """Return {gene: (protein_id, length, sequence)} keeping the longest protein of each gene."""
    longest = {}
    gene, prot, seqOK, seqLines = None, None, False, []
    for line in lines:
        if line.startswith('>'):  # new header: finish the previous record first
            _keepIfLongest(longest, gene, prot, seqOK, seqLines)
            gene, prot, seqOK = _parseTranslationHeader(line)
            seqLines = []
        else:
            seqLines.append(line)
    _keepIfLongest(longest, gene, prot, seqOK, seqLines)  # last record in the file
    return longest


dictDict = {}
for file in os.listdir('./dipteraTranslations_raw'):
    if not file.endswith(RAW_TRANS_SUFFIX):
        continue  # skip stray files that are not raw translation archives
    filename = './dipteraTranslations_raw/' + file
    print(filename)
    #unzip to stdout and pipe to here, faster and more space efficient than having an uncompressed file
    unzipP = Popen(['gunzip', '-c', filename], stdout=PIPE)
    text, err = unzipP.communicate()
    if unzipP.returncode != 0:
        raise RuntimeError('gunzip failed for ' + filename)
    dictDict[file] = _longestProteinPerGene(text.decode().split('\n'))
    print('Writing to output file')
    # remove the literal suffix (str.strip would remove a *set of characters* from both ends)
    sp = file[:-len(RAW_TRANS_SUFFIX)]
    outpath = './dipteraTranslations_processed/' + sp + '.longest_only.faa'
    print(sp)
    #write longest sequence and reduced headers to file
    with open(outpath, 'w') as out:
        for gene, (prot, length, sequence) in dictDict[file].items():
            out.write('>' + gene + '|' + prot + '|' + sp + '\n')
            out.write(sequence + '\n')

In [None]:
#creating orthogroups table
#similar for all tools, only Orthofinder shown
#iterate over orthogroups making sure they're complete
#I think the latest version has fixed the issue where the groups don't necessarily reflect the orthologues but just in case
ORTHOLOGUES_DIR = 'dipteraTranslations_processed/Orthofinder/Results_Mar30/Orthologues/'


def _mergeIntoGroups(groups, groupSets, newMembers):
    """Add one orthologue line's genes to `groups`, merging every existing group sharing a gene.

    groups: list of member lists (first-seen order); groupSets: parallel list of sets used
    for fast overlap tests. Both are updated in place. One line can bridge several
    previously separate groups, so all overlapping groups collapse into the earliest one.
    """
    newMembers = list(dict.fromkeys(newMembers))  # drop repeats, keep order
    overlapping = [idx for idx, members in enumerate(groupSets) if not members.isdisjoint(newMembers)]
    if not overlapping:  # no matching group found, start a new one
        groups.append(newMembers)
        groupSets.append(set(newMembers))
        return
    target = overlapping[0]
    targetList, targetSet = groups[target], groupSets[target]
    for gene in [g for idx in overlapping[1:] for g in groups[idx]] + newMembers:
        if gene not in targetSet:
            targetList.append(gene)
            targetSet.add(gene)
    for idx in reversed(overlapping[1:]):  # delete from the end so indices stay valid
        del groups[idx]
        del groupSets[idx]


orthoGroupsReal = []
orthoGroupSets = []
cursor.execute('CREATE TABLE groups_Orthofinder(groupID INTEGER, groupMembers TEXT)')
for dCount, d in enumerate(os.listdir(ORTHOLOGUES_DIR), start=1):
    print('Current directory:', d)
    for filename in os.listdir(ORTHOLOGUES_DIR + d):
        print('   Doing', filename)
        path = ORTHOLOGUES_DIR + d + '/' + filename
        with open(path, 'r') as file:
            file.readline()  # skip the header row
            for line in file:
                group, focal, other = line.strip('\n').split('\t')
                #entire group that are orthologous according to this one line
                _mergeIntoGroups(orthoGroupsReal, orthoGroupSets, focal.split(', ') + other.split(', '))
    print()
    print(dCount, 'directories done')
    print()
cursor.executemany('INSERT INTO groups_Orthofinder VALUES(?,?)',
                   [(i, ','.join(group)) for i, group in enumerate(orthoGroupsReal)])
db.commit()

In [None]:
# create list of singletons from blast output (all v all blastp, filtered to only hits with E value under 0.1)
# A query appearing exactly once is treated as a singleton; this assumes its one hit is itself.
with open('dsuz.filtered.out', 'r') as file:
    gL = [line.strip().split('\t')[0] for line in file if line.strip()]  # query id = column 1
queryHitCounts = Counter(gL)  # one pass instead of list.count per gene
sing_list = [gene for gene, hits in queryHitCounts.items() if hits == 1]
print('Total singletons to be considered:', len(sing_list))
with open('dsuz_singletons.txt', 'w') as file:
    file.writelines(p + '\n' for p in sing_list)

In [None]:
#creating sequence table for easier fetching later


def _readFastaRecords(path):
    """Return [(header, sequence), ...] for every record in the FASTA file at `path`.

    The header has leading/trailing '>' stripped; multi-line sequences are joined.
    The final record is included (it has no following header to trigger a flush).
    """
    records = []
    header, seqParts = None, []
    with open(path, 'r') as fasta:
        for line in fasta:
            line = line.strip('\n')
            if line.startswith('>'):
                if header is not None:
                    records.append((header, ''.join(seqParts)))
                header, seqParts = line.strip('>'), []
            else:
                seqParts.append(line)
    if header is not None:
        records.append((header, ''.join(seqParts)))
    return records


cursor.execute('CREATE TABLE IF NOT EXISTS sequenceTab(id TEXT,seq TEXT)')
fastaFiles = [f for f in os.listdir('dipteraTranslations_processed/') if f.endswith('.faa')]
for c, file in enumerate(fastaFiles, start=1):
    print('Species', str(c), 'of', str(len(fastaFiles)) + ':', file)
    cursor.executemany('INSERT INTO sequenceTab VALUES(?,?)',
                       _readFastaRecords('dipteraTranslations_processed/' + file))
db.commit()

In [None]:
# Replace '(', ')' and ':' in sequence ids with '_'.
# String literals use single quotes: double quotes are identifiers in SQL and only
# work as strings through a SQLite compatibility quirk that builds may disable.
cursor.execute("UPDATE sequenceTab SET id = REPLACE(REPLACE(REPLACE(id, '(', '_'), ')', '_'), ':', '_')")
db.commit()

In [None]:
def extractSingGroups(groupTable,singList):
    """Yield the group ID of each singleton gene in `singList`, in order.

    groupTable: table whose rows are (group ID, comma-joined member ids).
    A gene is matched to the first group (in table order) listing it exactly as a
    member; failing that, to the first group whose comma-joined member string
    contains it as a substring. Genes matching nothing are skipped. A group is
    yielded once per singleton mapped to it, so IDs may repeat.
    """
    cursor.execute('SELECT * FROM '+groupTable)
    realGroupDict = {a: b.split(',') for a, b in cursor.fetchall()}
    # exact-membership index; setdefault keeps the first group, as a linear scan would
    memberIndex = {}
    for groupID, members in realGroupDict.items():
        for member in members:
            memberIndex.setdefault(member, groupID)
    joined = [(groupID, ','.join(members)) for groupID, members in realGroupDict.items()]
    for s in singList:
        if s in memberIndex:
            yield memberIndex[s]
            continue
        # fallback: substring match on the joined member string
        for groupID, memberString in joined:
            if s in memberString:
                yield groupID
                break

In [None]:
def createGroupFastas(groupGen,outputFastaDir,groupTable=None):
    """Write `group<ID>.fa` into `outputFastaDir` with the sequences of each group's members.

    Args:
        groupGen: iterable of group IDs (e.g. from extractSingGroups). Groups whose
            FASTA already exists are skipped, so repeated IDs cost nothing.
        outputFastaDir: output directory, created if missing.
        groupTable: two-column table of (group ID, comma-joined members). When omitted,
            a module-level `groupTable` variable is used if present (the original read
            that global), otherwise 'groups_Orthofinder'.

    Raises:
        ValueError: a member id has no row in sequenceTab. No file is written for
            that group, so a rerun does not mistake a partial file for a finished one.
        KeyError: a group ID is not present in groupTable.
    """
    if groupTable is None:
        groupTable = globals().get('groupTable', 'groups_Orthofinder')
    os.makedirs(outputFastaDir, exist_ok=True)
    print('Creating fastas for singleton orthogroups...')
    # select by position, like extractSingGroups/processTrees, so any 2-column group table works
    cursor.execute('SELECT * FROM ' + groupTable)
    groupMembers = {groupID: members for groupID, members in cursor.fetchall()}
    for group in groupGen:
        groupFasta = outputFastaDir + '/group' + str(group) + '.fa'
        if os.path.exists(groupFasta):
            continue
        fastaLines = []
        for g in groupMembers[group].split(','):
            if g == '*':
                continue
            # sequenceTab ids had '(', ')' and ':' replaced with '_'; mirror that for lookup
            for ch in '():':
                g = g.replace(ch, '_')
            cursor.execute('SELECT seq FROM sequenceTab WHERE id == ?', (g,))
            row = cursor.fetchone()
            if row is None:
                print(g, 'seems to be missing in seq table,from group', group)
                raise ValueError(g + ' missing from sequenceTab (group ' + str(group) + ')')
            fastaLines.append('>' + g + '\n' + row[0] + '\n')
        with open(groupFasta, 'w') as out:
            out.writelines(fastaLines)

In [None]:
def doGroupAlignment(alignDirPath,fastaDirPath,singList=None,groupTable=None):
    """Align group FASTA files with MUSCLE into `alignDirPath` (created if missing).

    Without `singList`, every file in `fastaDirPath` is aligned to
    `<name minus last 3 chars>_alignment.fa`. With `singList`, only the groups that
    extractSingGroups(groupTable, singList) reports are aligned, reading the
    `OG<group[3:]>.fa` files (the OMA output layout) and writing
    `<group>_alignment.fa`. In both modes existing outputs are skipped, and each
    output is produced at most once per call; extractSingGroups can yield the
    same group several times.
    Failed MUSCLE runs (non-zero exit) are counted and reported at the end.
    """
    os.makedirs(alignDirPath, exist_ok=True)
    count, ti = 0, time.time()
    if not singList:
        jobs = [(fastaDirPath + '/' + f, alignDirPath + '/' + f[:-3] + '_alignment.fa')
                for f in os.listdir(fastaDirPath)]
    else:
        #using the fastas output by OMA so will need to filter to only align relevant groups
        jobs = [(fastaDirPath + '/' + 'OG' + group[3:] + '.fa', alignDirPath + '/' + group + '_alignment.fa')
                for group in extractSingGroups(groupTable, singList)]
    done, failures = set(), []
    for fullInFastaPath, outFasta in jobs:
        if outFasta in done or os.path.exists(outFasta):
            continue
        done.add(outFasta)
        cline = ['muscle', '-in', fullInFastaPath, '-out', outFasta]
        p = Popen(cline, stdout=PIPE, stderr=PIPE)
        out, err = p.communicate()
        if p.returncode != 0:
            failures.append(fullInFastaPath)
        count += 1
        if count % 200 == 0:
            print(count, 'done', round(time.time() - ti, 2), 'seconds')
    if failures:
        print(len(failures), 'alignment(s) failed, e.g.', failures[0])

In [None]:
def buildTrees(alignDirPath, iqtreePath='../SOFTWARE/iqtree-1.6.12-MacOSX/bin/iqtree', threads=3):
    """Run IQ-TREE on every `.fa` alignment in `alignDirPath` that has no `<file>.log` yet.

    Args:
        alignDirPath: directory holding the alignments. IQ-TREE writes its outputs
            (including the .log used as the "done" marker) next to the alignment.
        iqtreePath: IQ-TREE executable (default: the path previously hardcoded).
        threads: value passed to `-nt`.
    Runs with 1000 ultrafast bootstraps and model set WAG,LG,JTT; non-zero exits are reported.
    """
    t = time.time()
    allFiles = os.listdir(alignDirPath)
    present = set(allFiles)
    toDo = [f for f in allFiles if f.endswith('.fa') and f + '.log' not in present]
    print('Trees to build, total:', len(toDo))
    print('Starting ...')
    for count, al in enumerate(toDo, start=1):
        # read the alignment from the directory that was scanned, not a hardcoded one
        alFile = os.path.join(alignDirPath, al)
        cmd = [iqtreePath, '-s', alFile, '-bb', str(1000), '-nt', str(threads), '-mset', 'WAG,LG,JTT']
        p = Popen(cmd, stdout=PIPE, stderr=PIPE)
        out, err = p.communicate()
        if p.returncode != 0:
            print('IQ-TREE failed for', alFile)
        if count % 100 == 0:
            print(count, 'done:', time.time() - t, 'seconds')

In [None]:
def get_species(name):
    """Species code of a tip name: its final four characters (e.g. 'gene|prot|DMEL' -> 'DMEL')."""
    speciesCode = name[-4:]
    return speciesCode
# Clades tried, in order, as outgroup when rooting a gene tree; the first clade with members in the tree is used.
_ROOTING_PREFERENCE = [['BCOP','AALB','ASTE','AAEG','AEAL','CQUI','CPIP'],
                       ['HILL'],['CCAP','BTRY'],['SLEB'],['DBUS','DALB','DVIR','DNOV','DHYD'],
                       ['DSUB','DGUA','DPSE','DPER','DMIR'],['DANA'],['DSER','DKIK'],
                       ['DELE','DFIC'],['DSUP','DSUZ','DBIA'],['DEUG'],['DERE','DYAK','DSAN']]
# Species codes whose nodes are kept when pruning the rooted tree.
_PRUNE_SPECIES = ['DBIA','DEUG','DMEL','DSIM','DMAU','DERE','DSEC','DYAK','DSUZ','DSAN','DSUP']


def _rootGeneTree(tree):
    """Root `tree` in place on the first usable clade of _ROOTING_PREFERENCE.

    Raises ValueError('Could not find suitable root') if no clade can be used.
    """
    for clade in _ROOTING_PREFERENCE:
        outNodes = [node for sp in clade for node in tree.traverse() if sp in node.name]
        if not outNodes:  # that clade doesn't exist in this tree
            continue
        if len(outNodes) == 1:  # single species, set as outgroup
            tree.set_outgroup(outNodes[0])
            return
        try:
            tree.set_outgroup(tree.get_common_ancestor(outNodes))
            return
        except TreeError:
            # the outgroup straddles the current root, so its common ancestor is the root itself.
            # Root on some non-outgroup tip first to force the clade together, then re-root.
            # Internal nodes are named '' or a bootstrap value, hence the filters.
            otherNode = next((node for node in tree.traverse()
                              if get_species(node.name) not in clade
                              and not node.name.isnumeric() and node.name != ''), None)
            if otherNode is None:  # only clade species present; try the next clade
                continue
            tree.set_outgroup(otherNode)
            tree.set_outgroup(tree.get_common_ancestor(outNodes))
            return
    raise ValueError('Could not find suitable root')


def _pruneGeneTree(tree):
    """Prune `tree` in place to nodes whose names contain a _PRUNE_SPECIES code (branch lengths kept).

    On failure prints a message and leaves the tree unpruned.
    """
    keep = [node.name for sp in _PRUNE_SPECIES for node in tree.traverse() if sp in node.name]
    try:
        tree.prune(keep, preserve_branch_length=True)
    except Exception:
        print('pruning failed')
        print(tree)


def processTrees(treeDirPath,singList,outTable,groupTable): #needs review?
    """Root and prune the gene tree of each singleton's group; insert one row per gene into `outTable`.

    Group lookup matches extractSingGroups: exact member first, then substring of
    the comma-joined member string. Tree files are
    `<treeDirPath>/<group>_alignment.fa.treefile` for 'OMA...' groups and
    `<treeDirPath>/group<group>_alignment.fa.treefile` otherwise.
    Genes without a group, without a tree file, or with an unrootable tree get
    all tree columns NULL and an explanatory excludedReason.
    """
    cursor.execute('CREATE TABLE IF NOT EXISTS '+outTable+'(id TEXT, orthogroup TEXT, realGroup TEXT, baseTree BLOB, prunedTree BLOB, excludedReason TEXT, split TEXT)')
    cursor.execute('SELECT * FROM '+groupTable)
    realGroupDict = dict(cursor.fetchall())
    memberIndex = {}
    for groupID, members in realGroupDict.items():
        for member in members.split(','):
            memberIndex.setdefault(member, groupID)
    for g in singList:
        # reset every column for each gene so no value leaks from the previous iteration
        orthoGroup = baseTree = pTree = None
        if g in memberIndex:
            realGroup = memberIndex[g]
        else:
            realGroup = next((x for x in realGroupDict if g in realGroupDict[x]), None)
        if realGroup is None:
            exclude = 'Not assigned to group'
        else:
            if str(realGroup).startswith('OMA'):
                treeFile = treeDirPath + '/' + str(realGroup) + '_alignment.fa.treefile'
            else:
                treeFile = treeDirPath + '/group' + str(realGroup) + '_alignment.fa.treefile'
            try:
                with open(treeFile, 'r') as treefile:
                    newick = treefile.readline().strip('\n')
                tree = PhyloTree(newick, format=1)
                for n in tree.traverse():
                    n.set_species_naming_function(get_species)
                _rootGeneTree(tree)
                baseTree = str(tree.write())
                _pruneGeneTree(tree)
                pTree = str(tree.write())
                orthoGroup = realGroup
                exclude = None
            except FileNotFoundError:
                realGroup = baseTree = pTree = None
                exclude = 'Not assigned to group/tree could not be built from group'
            except ValueError:
                realGroup = baseTree = pTree = None
                exclude = 'Tree not rooted'
        cursor.execute('INSERT INTO '+outTable+' VALUES(?,?,?,?,?,?,?)',(g,orthoGroup,realGroup,baseTree,pTree,exclude,None))
    db.commit()

In [None]:
def checkMissingSp(table):
    """Exclude trees in `table` that lack any of the outgroup species DSUP, DBIA, DSUZ, DEUG.

    Only rows with excludedReason NULL are checked; incomplete trees get
    excludedReason 'Outgroup missing'. Returns the ids of trees containing all four.
    For backward compatibility the ids are also appended to a module-level
    `singTreeListPrunedMissingChecked` list when one exists (the original wrote only there).
    """
    required = ('DSUP', 'DBIA', 'DSUZ', 'DEUG')
    somethingMissing = 0
    complete = []
    cursor.execute('SELECT id,prunedTree FROM '+table+' WHERE excludedReason IS NULL')
    geneTreePrunedDict = dict(cursor.fetchall())
    for g, newick in geneTreePrunedDict.items():
        names = [node.name for node in PhyloTree(newick).traverse()]
        if all(any(sp in name for name in names) for sp in required):
            complete.append(g)
        else:
            cursor.execute('UPDATE '+table+' SET excludedReason = ? WHERE id == ?', ('Outgroup missing', g))
            somethingMissing += 1
    print('Number missing an outgroup species:', somethingMissing)
    legacy = globals().get('singTreeListPrunedMissingChecked')
    if legacy is not None:
        legacy.extend(complete)
    return complete

In [None]:
def checkOutgroupDup(table):
    """Note trees whose outgroup species (DEUG, DSUP, DSUZ, DBIA) are not each present exactly once.

    Adds a `notes` column to `table` if missing. For rows with excludedReason NULL,
    trees without exactly one node per outgroup species get notes
    'Multiple copies in outgroup' (this also covers a count of zero).
    Returns (ids with one copy each, ids flagged).
    """
    try:
        cursor.execute('ALTER TABLE '+table+' ADD COLUMN notes TEXT')
    except sqlite3.OperationalError:
        pass  # column already exists
    cursor.execute('SELECT id, prunedTree FROM '+table+' WHERE excludedReason IS NULL')
    geneTreePrunedDict = dict(cursor.fetchall())
    singTreeOneEach, singTreeMulti = [], []
    for g, newick in geneTreePrunedDict.items():
        names = [node.name for node in PhyloTree(newick).traverse()]
        counts = [sum(sp in name for name in names) for sp in ('DEUG', 'DSUP', 'DSUZ', 'DBIA')]
        if all(c == 1 for c in counts):
            singTreeOneEach.append(g)
        else:
            # update the table that was passed in (previously hardcoded to singleton_trees)
            cursor.execute('UPDATE '+table+' SET notes = ? WHERE id == ?', ('Multiple copies in outgroup', g))
            singTreeMulti.append(g)
    return singTreeOneEach, singTreeMulti

In [None]:
def flatten(List):
    """Concatenate the items of each sub-iterable in `List` into one new list (one level deep)."""
    return [item for subList in List for item in subList]

def checkSplits(table):
    #tree topology checks for splitting into multiple trees in cases where multicopy in the outgroup species
    #I can check the 'check monophyly' method here when I have to update the species
    #assign topology ids to all trees, check for straighforward duplications
    outSp = ['DEUG','DBIA','DSUP','DSUZ']
    try:
        cursor.execute('ALTER TABLE '+table+' ADD COLUMN tree_top_id INTEGER')
    except:
        pass
    cursor.execute('SELECT id, prunedTree FROM '+table+' WHERE notes == "Multiple copies in outgroup"')
    geneTreePrunedDict = dict(cursor.fetchall())
    top_dict = {}
    print('Doing auto checks...',table)
    un = 0
    for g in geneTreePrunedDict:
        unsuitable = False
        tree = PhyloTree(geneTreePrunedDict[g])
        for s in outSp: 
            #initial check, get all nodes for each outgroup species, check if they form a monophyletic group
            #this should only be the case if it's a species specific dup, it's just to filter out some before the manual check
            if unsuitable:
                break
            nodeList = []
            for node in tree.traverse():
                if s in node.name:
                    nodeList.append(node.name)

            if len(nodeList) > 1:
                if tree.check_monophyly(values=nodeList,target_attr='name')[0]:
                    #singTreeMulti.pop(g) #it might get upset about altering an object I'm iterating over, might need to keep a list
                    cursor.execute('UPDATE '+table+' SET split = "F" WHERE id == ?',(g,))
                    cursor.execute('UPDATE '+table+' SET excludedReason = "Multiple copies in outgroup-split not possible/subtrees unsuitable" WHERE id == ?',(g,))
                    unsuitable = True
        if unsuitable:
            un += 1
            continue
        #set topology ids for any tree passing this point
        top = tree.get_topology_id(attr='species')
        try:
            top_dict[top].append(g)
        except KeyError:
            top_dict[top] = [g]
    print(un, 'trees unsuitable in auto checks')  
    #for each topology, look at one tree and check if it's suitable to split
    print('Assigning top IDs...')
    for top in top_dict:
        for g in top_dict[top]:
            cursor.execute('UPDATE '+table+' SET tree_top_id = ? WHERE id == ?',(top,g))

    #determine if trees can be split or not, manual inspection
    print('Manual checks...')
    manualDict = {}
    for top in top_dict:
        print(PhyloTree(geneTreePrunedDict[top_dict[top][0]]))
        opinion = input('Opinion? Answer Y for splittable, N for not ')
        while opinion != 'Y' and opinion != 'N':
            opinion = input('Opinion? ')
        manualDict[top] = opinion #must be either Y or N
    
    print('Setting outcome...')
    #set excluded reason for trees determined to be disappointments
    exTopList = [x for x in manualDict if manualDict[x] == 'N']
    exSingMultiTreesList = [top_dict[x] for x in top_dict if x in exTopList]
    exclGenes = flatten(exSingMultiTreesList)
    for g in exclGenes:
        cursor.execute('UPDATE '+table+' SET split = "F" WHERE id == ?',(g,))
        cursor.execute('UPDATE '+table+' SET excludedReason = "Multiple copies in outgroup-split not possible/subtrees unsuitable" WHERE id == ?',(g,))

    #update for trees that can be split
    splitTopList = [x for x in manualDict if manualDict[x] == 'Y']
    splitSingMultiTreesList = [top_dict[x] for x in top_dict if x in splitTopList]
    splitTreeGenes = flatten(splitSingMultiTreesList)
    for g in splitTreeGenes:
        cursor.execute('UPDATE '+table+' SET split = "T" WHERE id == ?',(g,))
    db.commit()
    print('Done',table)

In [None]:
def fancy_node_search(node,name):
    """Report whether `name` occurs as a substring of `node.name`."""
    return name in node.name
    
def _detachSubtree(tree, node1Name, node2Name):
    """Detach and return the subtree at the common ancestor of the first nodes matching both names.

    Raises IndexError if either (stripped) name matches no node.
    """
    node1 = [x for x in tree.traverse() if fancy_node_search(x, node1Name.strip())][0]
    node2 = [x for x in tree.traverse() if fancy_node_search(x, node2Name.strip())][0]
    return node1.get_common_ancestor(node2).detach()


def doSplits(table,outTable):
    """Interactively split trees marked split='T' and write the final trees to `outTable`.

    Trees in `table` with no exclusion and split not 'T' go into `outTable` as-is.
    For each split='T' tree the user picks node pairs; the subtree at their common
    ancestor is detached and, if approved, stored under the gene id extracted from
    its D. erecta tip. Inserts are flushed every 10 trees and after the last one.
    Trees rejected on second look get the 'split not possible' exclusion reason.

    Returns [redoList, nodeNonsense] (lists of (id, newick)) if either is non-empty, else None.
    """
    unsuitableReason = 'Multiple copies in outgroup-split not possible/subtrees unsuitable'
    cursor.execute('CREATE TABLE IF NOT EXISTS '+outTable+'(id TEXT, tree BLOB, excludedReason TEXT)')
    cursor.execute('SELECT id, prunedTree FROM '+table+" WHERE excludedReason IS NULL AND (split != 'T' OR split IS NULL)")
    finalTreeDict = dict(cursor.fetchall())
    exSingMultiTreesList2 = []
    finalTrees = []
    redoList = []
    nodeNonsense = []
    cursor.execute('SELECT id FROM '+table+" WHERE split == 'T'")
    splitSingMultiTreesList = [row[0] for row in cursor.fetchall()]
    if not splitSingMultiTreesList:
        print('No trees to split, inserting suitable trees...', len(finalTreeDict))
        cursor.executemany('INSERT INTO '+outTable+'(id,tree) VALUES (?,?)', list(finalTreeDict.items()))
        db.commit()
        print('Done')
    cursor.execute('SELECT id, prunedTree FROM '+table)
    geneTreePrunedDict = dict(cursor.fetchall())
    for count, g in enumerate(splitSingMultiTreesList, start=1):
        error = False
        currentTreeSplits = []
        tree = PhyloTree(geneTreePrunedDict[g])
        tree.show()
        multi = input('Multiple usable trees? Answer Yes or No (Or anything else if the tree is not suitable)')
        if multi == 'No':
            try:
                node1Name = input("Node 1? Answer Forgot if didn't check the node name")
                node2Name = input('Node 2? ')
                if node1Name == 'Forgot':
                    tree.show()
                    node1Name = input('Node 1? ')
                    node2Name = input('Node 2? ')
                splitTree = _detachSubtree(tree, node1Name, node2Name)
                splitTree.show()
                okay = input('Okay?')
                if okay == 'Y':
                    currentTreeSplits.append(splitTree)
                elif okay != 'N':
                    # ambiguous answer: confirm, as in the multi-tree branch
                    if input('All good?') == 'Y':
                        currentTreeSplits.append(splitTree)
                # 'N' rejects this split only; it must not abort the loop over the remaining genes
            except (AttributeError, IndexError):  # IndexError: a node name matched nothing
                nodeNonsense.append((g, geneTreePrunedDict[g]))
                error = True
        elif multi == 'Yes':
            done = 'Y'
            while done == 'Y':
                try:
                    tree.show()
                    node1Name = input('Node 1? ')
                    node2Name = input('Node 2? ')
                    splitTree = _detachSubtree(tree, node1Name, node2Name)
                    splitTree.show()
                    okay = input('Is the tree okay? Answer Y or N')
                    if okay == 'Y':
                        currentTreeSplits.append(splitTree)
                    elif okay == 'N':
                        break
                    else:
                        check = input('All good?')
                        if check == 'Y':
                            currentTreeSplits.append(splitTree)
                        elif check == 'N':
                            break
                    done = input('Any more trees? Answer Y or N')
                except AttributeError:
                    print('Something went wrong with the tree: unfinished!!')
                    nodeNonsense.append((g, geneTreePrunedDict[g]))
                    done = 'N'
                    error = True
                except Exception:  # not bare: Ctrl-C must still be able to stop the prompt loop
                    done = 'Y'
                    print('Try again!')
        else:
            check = input('Did you mean to do that? Answer Yes or No')
            if check == 'No':
                redoList.append((g, geneTreePrunedDict[g]))
                error = True
            elif check == 'Yes':
                exSingMultiTreesList2.append(g)

        if not error:
            finalTrees.extend(currentTreeSplits)
        else:
            print('Something went wrong somewhere for gene '+g+'!')
            print('Lists for accidents and node issues will be returned')

        # reached for every gene (including errored ones) so the final batch is always flushed
        if count % 10 == 0 or count == len(splitSingMultiTreesList):
            print("You've done " + str(count) + ', good job!')
            for splitTree in finalTrees:
                for node in splitTree.traverse():
                    # NOTE(review): tip names elsewhere carry upper-case codes (e.g. 'DERE');
                    # confirm 'Dere' actually matches the intended tips.
                    if 'Dere' in node.name:
                        geneName = re.search(r'.*_(FB.*)_', node.name).group(1)
                        finalTreeDict[geneName] = splitTree.write()
            for geneID in finalTreeDict:
                print('Running insert...')
                cursor.execute('INSERT INTO '+outTable+'(id,tree) VALUES (?,?)', (geneID, finalTreeDict[geneID]))
            db.commit()
            print('Done')
            finalTreeDict = {}
            finalTrees = []

    # remove all the trees that turned out to be no use on a second look
    for g in exSingMultiTreesList2:
        cursor.execute('UPDATE '+table+' SET excludedReason = ? WHERE id == ?', (unsuitableReason, g))
    db.commit()
    if redoList or nodeNonsense:
        return [redoList, nodeNonsense]

In [None]:
def assignDupStatus(table):
    """Classify every tree in `table` as duplicated ("D") or single-copy ("S").

    Each labelled leaf is assigned to the first ingroup species tag its name
    contains (DSIM, DERE, DMEL, DSEC, DYAK, DMAU, DSAN, in that priority).
    - Any species with more than one copy: dup_status = "D", and dupInSp is the
      comma-joined list of duplicated species (Dyak, Dsec, Dmel, Dsim, Dere,
      Dmau, Dsan order).
    - Otherwise, every species present exactly once: dup_status = "S".
    - Otherwise (no duplication but a species missing): excludedReason is set,
      since neither status can be guaranteed.

    Adds the dup_status / dupInSp TEXT columns if they do not exist yet.
    """
    # Add each column independently so an existing first column does not
    # prevent the second one from being created.
    for column in ('dup_status', 'dupInSp'):
        try:
            cursor.execute('ALTER TABLE '+table+' ADD COLUMN '+column+' TEXT')
        except sqlite3.OperationalError:
            pass  # column already exists

    cursor.execute('SELECT id, tree FROM '+table)
    finalTreeDict = dict(cursor.fetchall())

    # (substring in leaf name, label written to dupInSp); list order = match priority
    speciesTags = [('DSIM', 'Dsim'), ('DERE', 'Dere'), ('DMEL', 'Dmel'), ('DSEC', 'Dsec'),
                   ('DYAK', 'Dyak'), ('DMAU', 'Dmau'), ('DSAN', 'Dsan')]
    # order in which duplicated species are listed in dupInSp
    dupInSpOrder = ['Dyak', 'Dsec', 'Dmel', 'Dsim', 'Dere', 'Dmau', 'Dsan']

    for g, newick in finalTreeDict.items():
        tree = PhyloTree(newick)
        counts = dict.fromkeys(dupInSpOrder, 0)
        for node in tree.traverse():
            if node.name == '' or node.name.isnumeric():
                continue  # internal node: no label, or a bootstrap value
            for tag, label in speciesTags:
                if tag in node.name:
                    counts[label] += 1
                    break

        if any(c > 1 for c in counts.values()):
            dupList = [s for s in dupInSpOrder if counts[s] > 1]
            cursor.execute('UPDATE '+table+' SET dup_status = "D" WHERE id == ?', (g,))
            cursor.execute('UPDATE '+table+' SET dupInSp = ? WHERE id == ?',(','.join(dupList),g))
        elif all(c > 0 for c in counts.values()):
            cursor.execute('UPDATE '+table+' SET dup_status = "S" WHERE id == ?', (g,))
        else:  # exclude, not duplicated and some species missing so can't guarantee either status
            cursor.execute('UPDATE '+table+' SET excludedReason = "Missing species: no duplications" WHERE id == ?', (g,))
    db.commit()

In [None]:
def checkDupTiming(table):
    """Exclude duplicated-gene trees whose paralogs may predate the ingroup split.

    For every tree in `table` with dup_status "D", each species tag occurring on
    more than one labelled node is tested. The closest pair of that species'
    paralogs (by tree distance) must be closer to each other than either is to
    the DSUZ outgroup leaf. If any species fails, the tree gets
    excludedReason = "Possible ancestral duplication".

    Trees with duplications but no DSUZ leaf cannot be tested. They are reported
    and left untouched; previously the outgroup distance silently became the
    distance to the root. Prints the number of trees that pass.
    """
    # first matching tag wins, in this order
    tagsInPriority = ['DSIM', 'DERE', 'DMEL', 'DYAK', 'DSEC', 'DMAU', 'DSAN']
    cursor.execute('SELECT id, tree FROM '+table+' WHERE dup_status == "D"')
    dupTreeDict = dict(cursor.fetchall())
    okCount = 0
    for g, newick in dupTreeDict.items():
        tree = PhyloTree(newick)
        labelled = [node for node in tree.traverse() if node.name != '']

        tagCounts = Counter()
        for node in labelled:
            for tag in tagsInPriority:
                if tag in node.name:
                    tagCounts[tag] += 1
                    break
        dupSp = [sp for sp, c in tagCounts.items() if c > 1]
        if not dupSp:
            okCount += 1
            continue

        # last DSUZ leaf in traversal order is used as the outgroup
        outNode = None
        for node in labelled:
            if 'DSUZ' in node.name:
                outNode = node
        if outNode is None:
            print('No DSUZ outgroup in tree for gene '+str(g)+'; duplication timing not checked')
            continue

        ok = True
        for sp in dupSp:
            paralogs = [node for node in labelled if sp in node.name]
            # with more than two copies, test the closest pair (first one on ties)
            a, b = min(itertools.combinations(paralogs, 2), key=lambda pair: tree.get_distance(pair[0], pair[1]))
            betweenParalogs = tree.get_distance(a, b)
            if not (betweenParalogs < tree.get_distance(a, outNode)
                    and betweenParalogs < tree.get_distance(b, outNode)):
                ok = False
                break

        if ok:
            okCount += 1
        else:
            cursor.execute('UPDATE '+table+' SET excludedReason = "Possible ancestral duplication" WHERE id == ?',(g,))

    db.commit()
    print('This many trees pass to this point: ',okCount)

In [None]:
def checkCorrectOutgroups(table):
    """Exclude trees in which the ingroup species are not monophyletic.

    Considers trees that have a dup_status and no excludedReason. If a tree
    lacks some ingroup species (check_monophyly raises ValueError), only the
    species actually present are tested. Non-monophyletic trees get
    excludedReason = "Incorrect outgroups".
    """
    ingroup = ['DMEL','DSEC','DSIM','DYAK','DERE','DSAN','DMAU']
    cursor.execute('SELECT id, tree FROM '+table+' WHERE (NOT (dup_status IS NULL)) AND (excludedReason IS NULL)')
    for g, newick in dict(cursor.fetchall()).items():
        tree = PhyloTree(newick, sp_naming_function=get_species)
        try:
            monophyly = tree.check_monophyly(values=ingroup,target_attr='species')
        except ValueError:
            # duplicable trees may be missing species: test only those present
            present = [node.species for node in tree.traverse()]
            monophyly = tree.check_monophyly(values=[s for s in ingroup if s in present],target_attr='species')
        if not monophyly[0]:
            cursor.execute('UPDATE '+table+' SET excludedReason = "Incorrect outgroups" WHERE id == ?',(g,))
    db.commit()

In [None]:
def retrieve_prot_id(node):
    """Extract the protein id from a tree leaf's underscore-separated name.

    A three-field name yields its middle field (the id has no underscore).
    Any other length yields the two fields immediately before the last one,
    joined with '_'.
    """
    fields = node.name.split("_")
    if len(fields) == 3:  # no internal underscores
        return fields[1]
    return '_'.join(fields[-3:-1])
def retrieve_prot_id_align(string):
    """Return the second '|'-delimited field of an alignment header (IndexError if there is none)."""
    return string.split('|', 2)[1]
def retrieve_species(node):
    """Return the last four characters (species tag) of a string or of a node's name."""
    label = node if type(node) == str else node.name
    return label[-4:]

In [None]:
def _fetchCds(protId, sp):
    """Return the CDS for protId from dipteraCDS_raw/<sp>_raw_cds.fna, extracted with sed.

    sed prints from the first line matching protId through the next line that
    contains '>'. The matched header, that following header and the trailing
    empty field are dropped, and the remaining lines are joined.
    NOTE(review): protId is used as a sed regex ('.' in versioned ids matches any
    character). Several matching headers are not handled. A record at the end
    of the file loses its last sequence line because no closing header follows.
    Confirm these are acceptable.
    """
    cmdCDS = ['sed', '-n', '-e', '/'+ protId +'/,/>/ p', 'dipteraCDS_raw/'+sp+'_raw_cds.fna']
    out, err = Popen(cmdCDS, stdout=PIPE, stderr=PIPE).communicate()
    seqLines = out.split(bytes('\n','utf-8'))[1:-2]
    return ''.join(x.decode('utf-8') for x in seqLines)


def _writePal2nalInputs(alignFile, cdsDict):
    """Write interProt.fa and interCDS.fa (pal2nal inputs) from a muscle protein alignment.

    muscle reorders sequences, and pal2nal needs both files in the same order,
    so CDS records are written in the alignment's header order. A header is
    kept only if its '|'-delimited protein id is a substring of some cdsDict
    key (i.e. the sequence survived tree pruning). '|' in headers becomes '_'.
    """
    with open(alignFile,'r') as prot_file:
        alignText = prot_file.read()

    order = []
    alignDict = {}
    current = None
    for line in alignText.split('\n'):
        if line.startswith('>'):
            current = line
            order.append(line)
            alignDict[line] = ''  # a repeated header restarts its sequence
        elif current is not None:
            alignDict[current] += line

    with open('interCDS.fa','w') as cds_file, open('interProt.fa','w') as prot_align_file:
        for header in order:
            try:
                protId = retrieve_prot_id_align(header)
            except IndexError:
                continue  # header without a '|' protein id field
            cdsKey = next((key for key in cdsDict if protId in key), None)
            if cdsKey is None:
                continue  # sequence not included in the pruned tree
            newHeader = header.replace('|','_') + '\n'
            cds_file.write(newHeader)
            cds_file.write(cdsDict[cdsKey] + '\n')
            prot_align_file.write(newHeader)
            prot_align_file.write(alignDict[header] + '\n')


def _writePamlRecord(alignLines, start, outFiles, headerEnds):
    """Copy one sequence record of a PAML-format alignment into each file in outFiles.

    The record is alignLines[start] (its id line) plus every following line up
    to, but not including, the next line whose last four characters (newline
    stripped) are a species tag in headerEnds.
    """
    record = [alignLines[start]]
    for line in alignLines[start + 1:]:
        if line.strip('\n')[-4:] in headerEnds:
            break
        record.append(line)
    for outFile in outFiles:
        outFile.writelines(record)


def rateCalc(table,alignmentDir,origTable):
    """Compute pairwise dN, dS and dN/dS with codeml for every non-excluded tree in `table`.

    For each tree:
    1. Fetch the CDS of every labelled leaf.
    2. Order the CDS to match the gene group's muscle alignment in
       `alignmentDir`; the group is looked up as realGroup in `origTable`.
    3. Build a codon alignment with pal2nal.
    4. Run codeml (settings from codeml.ctl) twice:
       - proxy: DEUG vs DSUZ -> proxy_rate / proxy_dS / proxy_dN
       - confirm: DEUG vs each DMEL paralog, averaged -> confirm_rate / confirm_dS / confirm_dN

    Trees without a DMEL leaf, or without a DMEL sequence in the codon
    alignment, get an excludedReason. Genes whose codeml output lacks the
    needed pairs are reported and left without those rates. Scratch files in
    the working directory are overwritten: interProt.fa, interCDS.fa,
    interCodonAlign.paml, example.tree, paml_input1.paml, paml_input2.paml and
    results.out.
    """
    # species suffixes that mark id lines in the PAML alignment
    headerEndList= ['DMEL','DSIM','DSEC','DYAK','DERE','DSAN','DSUP','DSUZ','DBIA','DMAU','DEUG']
    # add each column independently so existing ones don't block the rest
    for column in ('confirm_rate', 'proxy_rate', 'confirm_dS', 'proxy_dS', 'confirm_dN', 'proxy_dN'):
        try:
            cursor.execute('ALTER TABLE '+table+' ADD COLUMN '+column+' REAL')
        except sqlite3.OperationalError:
            pass  # column already exists

    cursor.execute('SELECT id,tree FROM '+table+' WHERE excludedReason IS NULL')
    trees = cursor.fetchall()

    for base_id, tree in tqdm(trees, miniters=50, mininterval=60):
        cursor.execute('SELECT realGroup FROM '+ origTable+' WHERE id == ?',(base_id,))
        realGroup = cursor.fetchall()[0][0]

        tree2 = PhyloTree(tree)
        # number of D. mel paralogs
        DmelCount = sum(1 for node in tree2.traverse() if 'DMEL' in node.name)
        if DmelCount == 0:
            cursor.execute('UPDATE '+table+' SET excludedReason = "Missing Dmel" WHERE id == ?',(base_id,))
            print('No DMEL in tree')
            continue

        # CDS for every gene in the tree, keyed '>' + protein id + ':' + species
        cdsDict = {}
        for node in tree2.traverse():
            if node.name == '' or node.name.isnumeric():
                continue
            ID = retrieve_prot_id(node)
            sp = retrieve_species(node)
            cdsDict['>'+ID + ':' + sp] = _fetchCds(ID, sp)

        if str(realGroup).startswith('OMA'):
            alignFile = alignmentDir + '/' + str(realGroup) + '_alignment.fa'
        else:
            alignFile = alignmentDir + '/group' + str(realGroup) + '_alignment.fa'
        _writePal2nalInputs(alignFile, cdsDict)

        # codon alignment; usage: pal2nal prot_alignment dna_seq -output paml
        cm = '../SOFTWARE/pal2nal.pl interProt.fa interCDS.fa -output paml'
        out2, err2 = Popen(cm.split(' '), stdout=PIPE, stderr=PIPE).communicate()

        with open('interCodonAlign.paml','w') as cod_file, open('example.tree','w') as treeFile:
            cod_file.write(out2.decode())
            treeFile.write(tree)
        with open('interCodonAlign.paml','r') as inFile:
            alignList = inFile.readlines()

        # split into two pairwise inputs; the first line's sequence count is rewritten
        dmelFound = False
        with open('paml_input1.paml','w') as outFile, open('paml_input2.paml','w') as outFile2:
            for i, line in enumerate(alignList):
                if i == 0:
                    seqLen = [ob for ob in line.split(' ') if ob != ''][1]
                    outFile.write('  ' + '2' + '   ' + seqLen + '\n')  # proxy: DEUG + DSUZ
                    # confirm file holds DEUG plus every Dmel paralog
                    outFile2.write('  ' + str(DmelCount + 1) + '   ' + seqLen + '\n')
                elif 'DEUG' in line:
                    _writePamlRecord(alignList, i, (outFile, outFile2), headerEndList)
                elif 'DMEL' in line:
                    dmelFound = True
                    _writePamlRecord(alignList, i, (outFile2,), headerEndList)
                elif 'DSUZ' in line:
                    _writePamlRecord(alignList, i, (outFile,), headerEndList)

        if not dmelFound:
            cursor.execute('UPDATE '+table+' SET excludedReason = "missing Dmel ortholog" WHERE id == ?',(base_id,))
            print('Somehow I passed the first check but without Dmel')
            continue

        # proxy rate: DEUG v DSUZ
        cmd = codeml.Codeml(alignment='paml_input1.paml', tree='example.tree', out_file='results.out',working_dir='.')
        cmd.read_ctl_file('codeml.ctl')
        cmd.alignment='paml_input1.paml'
        cmd.verbose = True
        try:
            output = cmd.run(command='../SOFTWARE/paml4.9j/bin/codeml')
        except Exception as e:
            print(e)
            continue

        pairwise = output['pairwise']
        proxyIds = list(pairwise.keys())
        if len(proxyIds) >= 2:
            proxyRes = pairwise[proxyIds[0]][proxyIds[1]]
            cursor.execute('UPDATE '+table+' SET proxy_rate = ? WHERE id == ?',(proxyRes['omega'],base_id))
            cursor.execute('UPDATE '+table+' SET proxy_dS = ? WHERE id == ?',(proxyRes['dS'],base_id))
            cursor.execute('UPDATE '+table+' SET proxy_dN = ? WHERE id == ?',(proxyRes['dN'],base_id))
        else:
            print('Proxy alignment for '+str(base_id)+' has fewer than two sequences; proxy rates not set')

        # confirm rate: DEUG v each DMEL paralog, averaged
        cmd = codeml.Codeml(alignment='paml_input2.paml', tree='example.tree', out_file='results.out',working_dir='.')
        cmd.read_ctl_file('codeml.ctl')
        cmd.alignment='paml_input2.paml'
        try:
            output = cmd.run(command='../SOFTWARE/paml4.9j/bin/codeml')
        except Exception as e:
            print('File 2')
            print(e)
            continue

        pairwise = output['pairwise']
        deugIds = [x for x in pairwise.keys() if 'DEUG' in x]
        if not deugIds:
            print('No DEUG sequence in confirm alignment for '+str(base_id)+'; confirm rates not set')
            continue
        id1 = deugIds[0]
        comparisons = [pairwise[id1][id2] for id2 in pairwise.keys() if id2 != id1]
        if not comparisons:
            # previously this wrote stale values left over from the prior gene
            print('No DEUG/DMEL pair in confirm alignment for '+str(base_id)+'; confirm rates not set')
            continue

        dnds2 = np.mean([c['omega'] for c in comparisons])
        dn2 = np.mean([c['dN'] for c in comparisons])
        ds2 = np.mean([c['dS'] for c in comparisons])
        cursor.execute('UPDATE '+table+' SET confirm_rate = ? WHERE id == ?',(dnds2,base_id))
        cursor.execute('UPDATE '+table+' SET confirm_dS = ? WHERE id == ?',(ds2,base_id))
        cursor.execute('UPDATE '+table+' SET confirm_dN = ? WHERE id == ?',(dn2,base_id))

    db.commit()

In [None]:
def rateComp(table,pairing,limitDS=False):
    """Print Mann-Whitney U tests, medians and means comparing singleton and duplicate genes.

    Args:
        table: table written by rateCalc (previously ignored in favour of a
            hard-coded 'processed_trees').
        pairing: 'confirm' (confirm_* columns) or 'proxy' (proxy_* columns).
        limitDS: if True, keep only trees with confirm_dS <= 4. This applies to
            both pairings, as before.

    Compares dN/dS, dN and dS in that order. Only rows with no excludedReason
    and a non-null value are used; singletons have dup_status "S" and
    duplicates have "D".

    Raises:
        ValueError: if pairing is not 'confirm' or 'proxy'.
    """
    if pairing not in ('confirm', 'proxy'):
        raise ValueError("pairing must be 'confirm' or 'proxy', got " + repr(pairing))
    dsFilter = ' AND confirm_dS <= 4' if limitDS else ''

    def fetchValues(metric, status):
        """Non-null <pairing>_<metric> values for included trees with the given dup_status."""
        column = pairing + '_' + metric
        cursor.execute('SELECT ' + column + ' FROM ' + table
                       + ' WHERE NOT (' + column + ' IS NULL) AND dup_status == ? AND excludedReason IS NULL'
                       + dsFilter, (status,))
        return [x[0] for x in cursor.fetchall()]

    label = 'Confirm' if pairing == 'confirm' else 'Proxy'
    # (column suffix, print a trailing blank line) - matches the original output layout
    for metric, blankAfter in (('rate', True), ('dN', False), ('dS', True)):
        singList = fetchValues(metric, 'S')
        dupList = fetchValues(metric, 'D')
        print(label + ' ' + metric + ' p val:')
        print(mannwhitneyu(singList, dupList,alternative='two-sided'))
        print('Medians:',np.median(singList),np.median(dupList))
        print('Means:',np.mean(singList),np.mean(dupList))
        if blankAfter:
            print('')

In [None]:
def fetchPval(list1,list2):
    """Two-sided Mann-Whitney U p-value of list1 vs list2, rounded to 4 dp, as a string."""
    result = mannwhitneyu(list1, list2, alternative='two-sided')
    return str(round(result.pvalue, 4))
def generateRateCompFigure(table,pairing,limitDS=False): #CHECK - anything from here on might be altered on the mac version
    """Draw singleton-vs-duplicate boxplots of dN, dS and dN/dS, save them as EPS and show them.

    Args:
        table: table written by rateCalc. It is queried (previously a hard-coded
            'processed_trees' was) and used in the output filename.
        pairing: 'confirm' (D. eugracilis v D. melanogaster) or 'proxy'
            (D. eugracilis v D. suzukii).
        limitDS: if True, keep only trees with confirm_dS <= 4 (both pairings).

    Each panel shows the two-sided Mann-Whitney U p-value beneath it. The output
    file is '<confirm|proxy>Comp_[dsUnder4_]<table>_final.eps'.

    Raises:
        ValueError: if pairing is not 'confirm' or 'proxy'.
    """
    if pairing not in ('confirm', 'proxy'):
        raise ValueError("pairing must be 'confirm' or 'proxy', got " + repr(pairing))
    # the dS cap always uses the confirmatory dS, for both pairings
    dsFilter = ' AND confirm_dS <= 4' if limitDS else ''

    def fetchValues(metric, status):
        """Non-null <pairing>_<metric> values for included trees with the given dup_status."""
        column = pairing + '_' + metric
        cursor.execute('SELECT ' + column + ' FROM ' + table
                       + ' WHERE NOT (' + column + ' IS NULL) AND dup_status == ? AND excludedReason IS NULL'
                       + dsFilter, (status,))
        return [x[0] for x in cursor.fetchall()]

    # set the style before creating the figure; otherwise only re-runs pick it up
    sns.set()
    fig, axes = plt.subplots(1,3,figsize=(30,10))
    for ax in axes:
        ax.yaxis.set_tick_params(labelsize=15)
        ax.xaxis.set_tick_params(labelsize=15)

    titles = {
        'confirm': 'Confirmatory comparisons: D. eugracilis v D. melanogaster',
        'proxy': 'Proxy comparisons: D. eugracilis v D. suzukii',
    }
    # NOTE(review): (0.5, 1.3) is in the dS panel's data coordinates, so its placement
    # depends on the dS range; transform=axes[1].transAxes or fig.suptitle would pin it.
    axes[1].text(0.5,1.3,titles[pairing],fontdict=dict(size=16))

    # (axis, column suffix, y label, extra boxplot kwargs, x of the p-value label in figure coords)
    panels = [
        (axes[0], 'dN', 'dN', dict(showmeans=True, meanline=True), 0.24),
        (axes[1], 'dS', 'dS', {}, 0.515),
        (axes[2], 'rate', 'dN/dS', {}, 0.79),
    ]
    for ax, metric, ylabel, extra, pvalX in panels:
        singList = fetchValues(metric, 'S')
        dupList = fetchValues(metric, 'D')
        box = ax.boxplot([singList,dupList], flierprops=dict(markersize=1),labels=['Singleton','Duplicate'],patch_artist=True, **extra)
        box['boxes'][0].set_fc('#DC3220')  # singleton
        box['boxes'][1].set_fc('#005AB5')  # duplicate
        for m in box['medians']:
            m.set_color('k')
            m.set_lw(2)
        ax.set_ylabel(ylabel,fontdict=dict(fontsize=15))
        ax.text(pvalX,0.01,'p = '+fetchPval(singList,dupList), fontsize=16, ha='center',transform=fig.transFigure)

    fileName = (('confirmComp_' if pairing == 'confirm' else 'proxyComp_')
                + ('dsUnder4_' if limitDS else '') + table + '_final.eps')
    fig.savefig(fileName,bbox_inches='tight')
    plt.show()

In [None]:
#confounders: pairwise comparisons for possible confounders and correlations with rate
def gc_percent(seq):
    """Return the G+C fraction (0-1, despite the name) of a nucleotide sequence.

    Counting is case-insensitive, so lower-case bases are counted too.
    An empty sequence returns NaN instead of raising ZeroDivisionError.
    """
    if not seq:
        return float('nan')
    upper = seq.upper()
    return (upper.count('G') + upper.count('C')) / len(upper)
def _ensureConfounderColumns(table):
    """Add the cdsLen/gc/gc3/exp columns to `table` if they are not present yet."""
    cursor.execute('PRAGMA table_info(' + table + ')')
    existing = {row[1] for row in cursor.fetchall()}  # row[1] is the column name
    for column, sqlType in (('cdsLen', 'INTEGER'), ('gc', 'REAL'), ('gc3', 'REAL'), ('exp', 'REAL')):
        if column not in existing:
            cursor.execute('ALTER TABLE ' + table + ' ADD COLUMN ' + column + ' ' + sqlType)


def _readCdsByProteinID(cdsFile, wanted):
    """Return {protein_id: sequence} for FASTA records in `cdsFile` whose header's
    [protein_id=...] tag is in `wanted`; all other records are skipped."""
    seqDict = {}
    currentID = None
    chunks = []
    with open(cdsFile, 'r') as file:
        for line in file:
            line = line.strip('\n')
            if line.startswith('>'):
                # a new header closes the previous wanted record
                if currentID is not None:
                    seqDict[currentID] = ''.join(chunks)
                m = re.search(r'\[protein_id=(.*?)\]', line)
                currentID = m.group(1) if m and m.group(1) in wanted else None
                chunks = []
            elif currentID is not None:
                chunks.append(line)
    # the last record in the file has no following header, so flush it here
    if currentID is not None:
        seqDict[currentID] = ''.join(chunks)
    return seqDict


def confounderInsert(table, cdsFile='dipteraCDS_raw/DSUZ_raw_cds.fna', expFile='dSuz_exp.genes.results'):
    """Store per-gene confounder values (cdsLen, gc, gc3, exp) in `table`.

    Only rows with excludedReason NULL are considered. Their protein ID is the text
    before the first '|' of the row id.
    - cdsLen / gc / gc3: computed from the CDS in the FASTA `cdsFile`, matched on the
      header's [protein_id=...] tag. gc3 uses every third base starting at index 2.
    - exp: taken from the tab-separated `expFile`. The header line is skipped, the first
      column is 'gene-<ID>', and column index 5 holds the value (TPM).
    Missing columns are added first. Rows with no matching sequence or expression entry
    keep NULL. The default paths are the D. suzukii files this notebook was written for.
    """
    cursor.execute('SELECT id FROM ' + table + ' WHERE excludedReason IS NULL')
    tableIDs = [x[0] for x in cursor.fetchall()]
    # protein id (prefix before the first '|') -> full table id
    idDict = {re.search(r'^(.*?)\|', x).group(1): x for x in tableIDs}

    _ensureConfounderColumns(table)

    print('started')
    seqDict = _readCdsByProteinID(cdsFile, idDict)

    #length and gc content
    for protID, seq in seqDict.items():
        cursor.execute('UPDATE ' + table + ' SET cdsLen = ?, gc = ?, gc3 = ? WHERE id == ?',
                       (len(seq), gc_percent(seq), gc_percent(seq[2::3]), idDict[protID]))

    #expression
    with open(expFile, 'r') as file:
        file.readline()  # skip header row
        for line in file:
            fields = line.strip('\n').split('\t')
            m = re.search(r'gene-(.*)$', fields[0])
            if m and m.group(1) in idDict:
                cursor.execute('UPDATE ' + table + ' SET exp = ? WHERE id == ?',
                               (float(fields[5]), idDict[m.group(1)]))
    db.commit()

In [None]:
def confounderCompCorr(table):
    """Print singleton-vs-duplicate comparisons and rate correlations for each confounder.

    Uses rows of `table` with excludedReason NULL and proxy_dS < 4. For CDS length, GC,
    GC3 and expression it prints:
    - a two-sided Mann-Whitney U test between singletons (dup_status "S") and
      duplicates ("D");
    - the two group medians;
    - the Spearman correlation between proxy_rate and the feature, over rows with a
      non-NULL proxy_rate.

    CDS length and expression are compared on a log10 scale (non-positive values map
    to 0). Expression medians are also log10; CDS medians are in bp.
    Rows whose feature is NULL (e.g. no CDS or expression entry was found) are left out
    of that feature's statistics. For expression, rows with TPM == 0 are also left out.
    """
    cursor.execute('SELECT cdsLen, gc, gc3, exp FROM ' + table +' WHERE dup_status == "S" AND excludedReason IS NULL AND proxy_dS < 4')
    sRes = cursor.fetchall()
    cursor.execute('SELECT cdsLen, gc, gc3, exp FROM ' + table +' WHERE dup_status == "D" AND excludedReason IS NULL AND proxy_dS < 4')
    dRes = cursor.fetchall()
    cursor.execute('SELECT cdsLen, gc, gc3, exp, proxy_rate FROM ' + table +' WHERE excludedReason IS NULL AND proxy_rate IS NOT NULL AND proxy_dS < 4')
    totRes = cursor.fetchall()
    print(len(sRes),len(dRes))

    def _identity(v):
        return v

    def _log10OrZero(v):
        return log10(v) if v > 0 else 0

    def _isNotNull(v):
        return v is not None

    def _report(label, col, keep, cmpTransform, medianTransform):
        """Print the MWU test, group medians and rate correlation for column `col`."""
        sVals = [x[col] for x in sRes if keep(x[col])]
        dVals = [x[col] for x in dRes if keep(x[col])]
        tRows = [x for x in totRes if keep(x[col])]
        print(label)
        print('Sing vs Dup comp:')
        print(mannwhitneyu([cmpTransform(v) for v in sVals],[cmpTransform(v) for v in dVals],alternative='two-sided'))
        print('Medians:')
        print(np.median([medianTransform(v) for v in sVals]),np.median([medianTransform(v) for v in dVals]))
        print('Overall corr:')
        print(spearmanr([x[4] for x in tRows],[x[col] for x in tRows]))

    _report('CDS Length:', 0, _isNotNull, _log10OrZero, _identity)
    print()
    _report('GC content:', 1, _isNotNull, _identity, _identity)
    print()
    _report('GC3 content:', 2, _isNotNull, _identity, _identity)
    print()
    # truthiness filter drops NULL and TPM == 0 before taking logs
    _report('Expression:', 3, bool, _log10OrZero, _log10OrZero)

In [None]:
#regression on dN/dS
#going to do OLS and LOWESS
def doRegression(table):
    """Check whether the singleton/duplicate dN/dS difference survives confounder correction.

    Loads rows of `table` with excludedReason NULL and proxy_dS < 4 and keeps those whose
    exp is > 0. It then log10-transforms rate, CDS length and expression, and converts
    GC/GC3 fractions to percentages.

    For each feature (CDS, GC, GC3, Exp) it fits an OLS model 'Rate ~ feature' and a
    LOWESS curve (frac=1/3). It stores fitted values and residuals as new columns, and
    prints two-sided Mann-Whitney U tests of singletons ('S') vs duplicates ('D') on the
    log rate, on the OLS residuals and on the LOWESS residuals.

    Returns the DataFrame with columns Rate, CDS, GC, GC3, Exp, DupStatus plus
    linear_fit<feat>, linear_resid<feat>, lowess_fit<feat> and lowess_resid<feat>.
    """
    cursor.execute('SELECT proxy_rate, cdsLen, gc, gc3, exp, dup_status FROM '+table+' WHERE excludedReason IS NULL AND proxy_dS < 4')
    data = pd.DataFrame(cursor.fetchall(), columns=['Rate','CDS','GC','GC3','Exp','DupStatus'])
    # .copy() so the transforms below write to an independent frame rather than a
    # filtered view (avoids SettingWithCopyWarning)
    data = data[data['Exp'] > 0].copy()
    data['CDS'] = np.log10(data['CDS'])
    data['Exp'] = np.log10(data['Exp'])
    data['GC'] = data['GC']*100
    data['GC3'] = data['GC3']*100
    data['Rate'] = np.log10(data['Rate'])

    # group membership does not change inside the loop, so build the masks once
    isDup = data['DupStatus'] == 'D'
    isSing = data['DupStatus'] == 'S'

    def _mwu(col):
        """Two-sided MWU of column `col`, singletons vs duplicates."""
        return mannwhitneyu(data.loc[isSing, col], data.loc[isDup, col], alternative='two-sided')

    for feature in ['CDS','GC','GC3','Exp']:
        res = smf.ols('Rate ~ '+ feature, data=data).fit()
        # assignment aligns on index, so rows dropped by the formula get NaN
        data['linear_fit'+feature] = res.fittedvalues
        data['linear_resid'+feature] = res.resid

        print(feature)
        print('Original comp:')
        print(_mwu('Rate'))
        print('OLS residuals:')
        print(_mwu('linear_resid'+feature))
        print('LOWESS residuals:')
        # return_sorted=False keeps the fitted values in the same row order as the input
        data['lowess_fit'+feature] = sm.nonparametric.lowess(data['Rate'],data[feature],frac=1/3,return_sorted=False)
        data['lowess_resid'+feature] = data['Rate']-data['lowess_fit'+feature]
        print(_mwu('lowess_resid'+feature))

    return data

In [None]:
def generateResidPlots(data):
    """Draw residual diagnostics for the single-feature OLS fits.

    Builds a 3 x 4 grid with one column per feature (CDS, Exp, GC, GC3):
      row 1 - residuals against fitted values,
      row 2 - histogram of the residuals,
      row 3 - normal Q-Q plot of the residuals (scipy probplot).
    Reads the 'linear_resid<feat>' and 'linear_fit<feat>' columns of `data`,
    saves 'residualDiagnostics.svg' and shows the figure.
    """
    import seaborn as sns
    from scipy.stats import probplot
    sns.set()
    fig, axes = plt.subplots(3,4,figsize=(30,20))

    titles = {'CDS':'CDS length','Exp':'Expression level','GC':'GC content','GC3':'GC3 content'}
    for col, feat in enumerate(['CDS', 'Exp','GC','GC3']):
        scatterAx, histAx, qqAx = axes[0][col], axes[1][col], axes[2][col]
        residuals = data['linear_resid' + feat]
        fittedVals = data['linear_fit' + feat]

        # residuals against fitted values
        scatterAx.plot(fittedVals, residuals, 'bo')
        scatterAx.set_xlabel('Fitted values', fontsize=16)
        scatterAx.set_ylabel('Residuals', fontsize=16)
        for axis in (scatterAx.xaxis, scatterAx.yaxis):
            axis.set_tick_params(labelsize=16)
        scatterAx.set_title(titles[feat], fontsize=18)

        # distribution of the residuals
        histAx.hist(residuals)
        histAx.set_xlabel('Residuals', fontsize=16)
        histAx.set_ylabel('Frequency', fontsize=16)
        for axis in (histAx.xaxis, histAx.yaxis):
            axis.set_tick_params(labelsize=16)

        # normal Q-Q plot; probplot's default title is blanked afterwards
        probplot(residuals, dist='norm', plot=qqAx)
        for axis in (qqAx.xaxis, qqAx.yaxis):
            axis.set_tick_params(labelsize=16)
        qqAx.set_ylabel('Ordered values', fontsize=16)
        qqAx.set_xlabel('Theoretical quantiles', fontsize=16)
        qqAx.set_title('')

    plt.savefig('residualDiagnostics.svg',bbox_inches='tight')
    plt.show()

In [None]:
def confounderBoxPlot(data,resids):
    """Draw box plots of each confounder (or its residuals) for singletons vs duplicable genes.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs a 'DupStatus' column ('S'/'D') and, depending on `resids`, the columns
        'CDS', 'Exp', 'GC', 'GC3' or their 'linear_resid<feat>' / 'lowess_resid<feat>'
        variants.
    resids : str
        'orig' plots the feature columns, 'OLS' the OLS residual columns, and 'lowess'
        the LOWESS residual columns.

    Saves 'confounderComp_<resids>.eps' and shows the figure.

    Raises
    ------
    ValueError
        If `resids` is not one of 'orig', 'OLS', 'lowess'.
    """
    columnPrefix = {'orig': '', 'OLS': 'linear_resid', 'lowess': 'lowess_resid'}
    if resids not in columnPrefix:
        raise ValueError("resids must be 'orig', 'OLS' or 'lowess', got %r" % (resids,))

    dataDup = data[data['DupStatus']=='D']
    dataSing = data[data['DupStatus']=='S']

    # apply the theme before creating the axes; axes styling is read from rcParams
    # when the axes are created (matches the other plotting functions)
    sns.set()
    fig, axes = plt.subplots(1,4,figsize=(20,5))

    # feature -> (panel title, y-axis label for untransformed values)
    featLabels = {'CDS': ('CDS length', 'log(bp)'), 'Exp': ('Expression', 'log(TPM)'),
                  'GC': ('% GC', '%'), 'GC3': ('% GC3', '%')}
    for feat, ax in zip(['CDS','Exp','GC','GC3'],axes):
        col = columnPrefix[resids] + feat
        boxes = ax.boxplot([dataSing[col],dataDup[col]],patch_artist=True,labels=['Singleton','Duplicable'],flierprops={'ms':1})
        boxes['boxes'][0].set_fc('#DC3220')
        boxes['boxes'][1].set_fc('#005AB5')
        for m in boxes['medians']:
            m.set_color('k')
            m.set_lw(2)

        title, unit = featLabels[feat]
        ax.set_title(title)
        ax.set_ylabel(unit if resids == 'orig' else 'Residuals')
    plt.savefig('confounderComp_'+resids+'.eps',bbox_inches='tight')
    plt.show()

In [None]:
def regressionPlot(table,rType):
    """Scatter log(dN/dS) against each confounder with fitted trend lines and save as EPS.

    Rows of `table` with excludedReason NULL and proxy_dS < 4 and exp > 0 are used.
    Rate, CDS length and expression are log10-transformed and GC/GC3 are shown as
    percentages. Each of the four panels (CDS, Exp, GC, GC3) shows singleton (red) and
    duplicate (blue) points, plus trend lines for all genes (black), duplicates and
    singletons. The legend is drawn on the GC panel.

    rType : 'lowess' draws seaborn LOWESS curves; 'OLS' draws linear fits with CI bands.
    Saves '<table>_<rType>_final.eps'.

    Raises ValueError for any other rType.
    """
    if rType not in ('lowess', 'OLS'):
        raise ValueError("rType must be 'lowess' or 'OLS', got %r" % (rType,))
    useLowess = rType == 'lowess'

    cursor.execute('SELECT proxy_rate, cdsLen, gc, gc3, exp, dup_status FROM '+table+' WHERE excludedReason IS NULL AND proxy_dS < 4')
    data = pd.DataFrame(cursor.fetchall(), columns=['Rate','CDS','GC','GC3','Exp','DupStatus'])
    # .copy() so the transforms write to an independent frame (no SettingWithCopyWarning)
    data = data[data['Exp'] > 0].copy()
    data['CDS'] = np.log10(data['CDS'])
    data['Exp'] = np.log10(data['Exp'])
    data['GC'] = data['GC']*100
    data['GC3'] = data['GC3']*100
    data['Rate'] = np.log10(data['Rate'])

    dataDup = data[data['DupStatus']=='D']
    dataSing = data[data['DupStatus']=='S']

    sns.set()
    fig, axes = plt.subplots(2,2,figsize=(10,10),gridspec_kw={'hspace':0.35})

    # (column, panel row, panel column, title, x-axis label)
    panels = [('CDS', 0, 0, 'CDS length', 'log(bp)'),
              ('Exp', 0, 1, 'Expression', 'log(TPM)'),
              ('GC', 1, 0, '% GC', '%'),
              ('GC3', 1, 1, '% GC3', '%')]
    # Singleton points on the GC3 panel were styled differently for each model type in
    # the original hand-written panels; kept so the published figures are reproduced.
    singMarkerTweaks = {('lowess', 'GC3'): {'fillstyle': 'none'}, ('OLS', 'GC3'): {'ms': 1.8}}

    for feat, row, col, title, xlabel in panels:
        ax = axes[row][col]
        singStyle = {'color': '#DC3220', 'ms': 2}
        singStyle.update(singMarkerTweaks.get((rType, feat), {}))
        ax.plot(dataSing[feat],dataSing['Rate'],'o',**singStyle)
        ax.plot(dataDup[feat],dataDup['Rate'],'o',color='#005AB5',ms=2)
        # only the GC panel carries legend labels (the legend is drawn there)
        if feat == 'GC':
            labelAll, labelDup, labelSing = 'All', 'Duplicable', 'Singleton'
        else:
            labelAll = labelDup = labelSing = None
        sns.regplot(x=feat,y='Rate',data=data,ax=ax,scatter=False,color='black',lowess=useLowess,label=labelAll)
        sns.regplot(x=feat,y='Rate',data=dataDup,ax=ax,scatter=False,color='#005AB5',truncate=False,lowess=useLowess,label=labelDup)
        sns.regplot(x=feat,y='Rate',data=dataSing,ax=ax,scatter=False,color='#DC3220',lowess=useLowess,label=labelSing)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        # set after regplot, which would otherwise label the axis 'Rate'
        ax.set_ylabel('log(dN/dS)')

    axes[1][0].legend()

    plt.savefig(table+'_'+rType+'_final.eps',bbox_inches='tight')

In [None]:
#monte carlo simulations to test sig of change in p value in rate comp when comparing residuals
def residComp(feature, method, df, actualPVal, lower=True, it=1000, statusCol='Duplication_Status', seed=None):
    """Run a Monte Carlo (permutation) check of a residual-based singleton/duplicate comparison.

    In each of `it` iterations the values of `feature` are shuffled across the rows of a
    copy of `df`, which breaks any link between the feature and 'Rate'. 'Rate' is then
    regressed on the shuffled feature, and the residuals of singletons ('S') and
    duplicates ('D') are compared with a two-sided Mann-Whitney U test.

    Parameters
    ----------
    feature : str
        Column of `df` used as the single predictor (used as a formula term for OLS).
    method : str
        'ols' (statsmodels OLS, formula 'Rate ~ feature') or 'lowess' (statsmodels
        LOWESS, frac=1/3). Matched case-insensitively.
    df : pandas.DataFrame
        Needs 'Rate', `feature` and `statusCol` columns. It is not modified.
    actualPVal : float
        Observed p-value that the permutation p-values are compared against.
    lower : bool
        If True, count permutations with p <= actualPVal; if False, count p >= actualPVal.
    it : int
        Number of permutations.
    statusCol : str
        Column holding the 'S'/'D' duplication status.
    seed : int or None
        Seed for a dedicated numpy Generator. None uses numpy's global random state.

    Returns
    -------
    float
        Fraction of the `it` permutations that meet the `lower` criterion.

    Raises
    ------
    ValueError
        If `method` is neither 'ols' nor 'lowess'.
    """
    methodKey = str(method).lower()
    if methodKey not in ('ols', 'lowess'):
        raise ValueError("method must be 'ols' or 'LOWESS', got %r" % (method,))
    rng = np.random.default_rng(seed) if seed is not None else np.random

    df2 = df.copy(deep=True)
    # duplication status is never permuted, so the group masks can be built once
    isDup = df2[statusCol] == 'D'
    isSing = df2[statusCol] == 'S'
    count = 0
    for _ in range(it):
        df2[feature] = rng.permutation(df[feature].values)
        if methodKey == 'ols':
            fit = smf.ols('Rate ~ '+ feature, data=df2).fit()
            # reindex so rows dropped by the formula (missing values) line up as NaN
            resid = fit.resid.reindex(df2.index)
        else:
            smoothed = sm.nonparametric.lowess(df2['Rate'],df2[feature],frac=1/3,return_sorted=False)
            resid = df2['Rate'] - smoothed
        p = mannwhitneyu(resid[isSing],resid[isDup],alternative='two-sided').pvalue
        if (p <= actualPVal) if lower else (p >= actualPVal):
            count += 1
    return count/it

# Permutation checks for the CDS-length and expression LOWESS residual comparisons.
# The p-value literals are hard-coded, presumably the observed LOWESS-residual MWU
# p-values from an earlier run (TODO confirm they match the current data).
# NOTE(review): `df` is not defined in this section. residComp needs 'Rate', the feature
# column ('CDS_Length' / 'Expression') and 'Duplication_Status'. doRegression returns
# 'CDS'/'Exp'/'DupStatus' instead, so confirm which frame `df` is before re-running.
print(residComp('CDS_Length','LOWESS',df,0.05287523357330763,lower=False))
print(residComp('Expression','LOWESS',df,0.010352820114804085))

In [None]:
def LRT(lnLAlt, lnLNull, numParamAlt, numParamNull):
    """Return the likelihood-ratio test p-value for nested models.

    The statistic 2*(lnLAlt - lnLNull) is compared with a chi-square distribution with
    numParamAlt - numParamNull degrees of freedom.

    For correctly nested models the alternative cannot fit worse than the null, so a
    negative statistic can only come from an optimisation failure. It is clamped to 0,
    giving p = 1. Using the absolute difference instead would turn such a failure into
    an apparently significant result.
    """
    stat = max(0.0, 2*(lnLAlt - lnLNull))
    dof = numParamAlt - numParamNull
    return chi2.sf(stat, dof)

In [None]:
def get_species(node):
    """Species-naming callback for ete3 PhyloTree: the final four characters of the node name."""
    speciesCodeLength = 4
    name = node.name
    return name[-speciesCodeLength:]
def rateHypTest(table,alignmentDir,groupTable):
    """Run branch-model likelihood-ratio tests for rate acceleration and asymmetry after duplication.

    For each row of `table` with excludedReason NULL and dup_status "D":
      1. each named tree node's CDS is extracted with sed from
         dipteraCDS_raw/<sp>_raw_cds.fna;
      2. the group's protein alignment (`alignmentDir`/group<N>_alignment.fa, group
         found via `groupTable`) is rewritten with the CDS in matching order and turned
         into a codon alignment with ../SOFTWARE/pal2nal.pl;
      3. for every duplication event ete3 finds in the tree, codeml is run three times:
           - model_zero.ctl on the unlabelled tree,
           - model_acc.ctl with the common ancestor of all post-duplication genes
             labelled ' $1',
           - model_asym.ctl with the two duplicate groups labelled 1 and 2 (' $n' on a
             common-ancestor node, ' #n' appended to a single leaf).
         The .ctl contents are not shown here; presumably they set one omega overall,
         a separate post-duplication omega, and two post-duplication omegas.
    Acceleration is tested as model_acc vs model_zero and asymmetry as model_asym vs
    model_acc using LRT. Both p-value lists are then adjusted with statsmodels
    fdrcorrection using its default settings.

    Side effects: writes interCDS.fa, interProt.fa, interCodonAlign.paml, temp.treefile
    and results.out in the current directory.

    Returns a dict of parallel lists with ONE ENTRY PER DUPLICATION EVENT (a gene with
    several events appears several times in 'ID list').
    """
    cursor.execute('SELECT groupID, groupMembers FROM '+groupTable)
    groupDict = dict(cursor.fetchall())
    cursor.execute('SELECT id, tree FROM '+table+' WHERE excludedReason IS NULL AND dup_status == "D"')
    r = cursor.fetchall()
    # result accumulators: one entry appended per duplication event
    accPvals, asymmPvals = [],[]
    accLs, asymmLs = [],[]
    accRates = []
    idList = []
    for gene, origTree in tqdm(r):
        tree = PhyloTree(origTree,sp_naming_function=get_species)
        #make codon alignment, as for rateCalc
        #for each sequence
        cdsDict = {}
        for node in tree.traverse():
                # skip unnamed nodes and nodes named only by a number
                if node.name == '' or node.name.isnumeric():
                    continue
                ID = retrieve_prot_id(node)
                sp = retrieve_species(node)
                #get CDS sequence: sed prints from the line matching ID through the next line containing '>'
                # NOTE(review): ID is used as a sed regex, so '.' in versioned IDs matches any character; confirm no false matches
                cmdCDS = ['sed', '-n', '-e', '/'+ ID +'/,/>/ p', 'dipteraCDS_raw/'+sp+'_raw_cds.fna']
                pCDS = Popen(cmdCDS, stdout=PIPE,stderr=PIPE)
                out,err = pCDS.communicate()

                # [1:-2] drops the matched header, the following header and the trailing '' after the final newline
                # NOTE(review): if this record is the last one in the file there is no following header, so its last sequence line is dropped
                CDSseq = out.split(bytes('\n','utf-8'))[1:-2]
                CDSseq = ''.join([x.decode('utf-8') for x in CDSseq])
                #store CDS seqs and headers in dict
                cdsDict['>'+ID + ':' + sp] = CDSseq
        # NOTE(review): uses the ID of the last node visited above, with a substring test against each group's groupMembers value
        alignFile = alignmentDir + '/group' + [str(x) for x in groupDict if ID in groupDict[x]][0] + '_alignment.fa'
        #write only relevant sequences to prot/CDS files, in same order with same headers
        with open(alignFile,'r') as prot_file, open('interCDS.fa','w') as cds_file, open('interProt.fa','w') as prot_align_file:
        #  muscle outputs alignments in a different order to input sequences
        #  cds sequences have to be written in same order to work with pal2nal
            alignSeq = ''
            alignDict = {}
            #put entire protein alignment in a string
            for line in prot_file:
                alignSeq = alignSeq + line
            
            #get the order the headers occur in
            order= [line for line in alignSeq.split('\n') if line.startswith('>')]
            #make new headers based on this
            #order2 = ['>'+retrieve_prot_id_align(x)+':'+retrieve_species(x) for x in order]
            order2=order
    #       build {header: concatenated aligned sequence} from the protein alignment
    #       (headers are kept unchanged, so every header is in order2)
            incl = False 
            for line in alignSeq.split('\n'):
                if line.startswith('>'):
#                     newHead = '>'+retrieve_prot_id_align(line)+':'+retrieve_species(line)
                    newHead = line
                    if newHead in order2:
                        alignDict[newHead] = ''
                        current = newHead
                        incl = True
                    else:
                        incl = False
                elif incl == True:
                    alignDict[current] = alignDict[current] + line
            # get CDS in right order as well, write each header and sequence to CDS/prot file
            for x in order2:
                try:
                    i = [y for y in cdsDict if retrieve_prot_id_align(x) in y][0]
                    seq = cdsDict[i]
                except IndexError:
                    continue #alignment entries with no CDS collected from this tree are skipped; this is where sequences outside the (pruned) tree are filtered out
                
                cds_file.write(x.replace('|','_') + '\n')
                cds_file.write(seq + '\n')
                prot_align_file.write(x.replace('|','_') + '\n')
                prot_align_file.write(alignDict[x] + '\n')     
        # convert to a codon alignment with pal2nal (PAML output format)
        #     usage: pal2nal prot_alignment dna_seq -output paml
        # NOTE(review): err2 (pal2nal stderr) is never checked
    #     print('at pal2nal')
        cm = '../SOFTWARE/pal2nal.pl interProt.fa interCDS.fa -output paml'
        p2 = Popen(cm.split(' '), stdout=PIPE, stderr=PIPE)
        out2, err2 = p2.communicate()

        with open('interCodonAlign.paml','w') as cod_file:
            cod_file.write(out2.decode())
     
        # NOTE(review): alignmentFile is unused; codeml below reads 'interCodonAlign.paml' from '.'
        alignmentFile = '~/drosphilaFinal/interCodonAlign.paml'   
        # get duplication events ('D') from the species-overlap evolutionary events in the tree
        events = tree.get_descendant_evol_events()
        dupEvents = [e for e in events if e.etype == 'D']
        
        # NOTE(review): numEvents is never updated or used
        numEvents = 0
        
        for e in dupEvents:
            # re-parse for every event because node names are edited below
            tree = PhyloTree(origTree,sp_naming_function=get_species)
            # each of these two will have different rates in the asymmetry alt model
            dupGroup1 = e.in_seqs 
            dupGroup2 = e.out_seqs
            #all of these will have the same rate in the null for asymmetry and the alt for acceleration
            totalPostDup = dupGroup1.union(dupGroup2) 
            #write tree to temp treefile for null acceleration model, run PAML, extract lnL and params
            tree.write(outfile='temp.treefile')

            cmd = codeml.Codeml(alignment='interCodonAlign.paml', tree='temp.treefile', out_file='results.out',working_dir='.')
            cmd.read_ctl_file('model_zero.ctl')
            output = cmd.run(command='../SOFTWARE/paml4.9j/bin/codeml')
            allOneRate_lnL = output['NSsites'][0]['lnL']
            try:
                allOneRate_numParams = len(output['NSsites'][0]['parameters']['omega'])
            except TypeError:
                # omega parsed as a single number (len() fails) -> one parameter
                allOneRate_numParams = 1

    #   edit tree to label all post dup branches with one rate, write to file, run PAML, extract lnL and params
            for node in tree.traverse():
                if not node.is_leaf():
                    node.name = ' ' #I may or may not need this, ete is labelling these all with the name NoName if name = ''
                    #I don't know if that poses an issue to PAML
            nodeToLabel = tree.get_common_ancestor(totalPostDup)
            nodeToLabel.name = ' $1'
            # format=8 is used so the internal-node labels are written to the newick file
            tree.write(format=8,outfile='temp.treefile')

            cmd = codeml.Codeml(alignment='interCodonAlign.paml', tree='temp.treefile', out_file='results.out',working_dir='.')
            cmd.read_ctl_file('model_acc.ctl')
            output = cmd.run(command='../SOFTWARE/paml4.9j/bin/codeml')
            postDupRate_lnL = output['NSsites'][0]['lnL']
            postDupRate_numParams = len(output['NSsites'][0]['parameters']['omega'])
            postDupRate_rates = output['NSsites'][0]['parameters']['omega']
            
            #reset tree
            tree = PhyloTree(origTree,sp_naming_function=get_species)
        #   edit tree to label the post dup branches with 2 different rates, write to file, run PAML, extract lnL and params
            for node in tree.traverse():
                if not node.is_leaf():
                    node.name = ' ' #I may or may not need this, ete is labelling these all with the name NoName if name = ''
                    #I don't know if that poses an issue to PAML
            # a group of several genes is labelled on its common ancestor; a single gene on its leaf
            # NOTE(review): .pop() removes the name from the event's in_seqs/out_seqs set itself
            if len(dupGroup1) >1:
                nodeToLabel = tree.get_common_ancestor(dupGroup1)
                nodeToLabel.name = ' $1'
            else:
                name = dupGroup1.pop()
                node = tree&name
                node.name = node.name + ' #1'
                
            if len(dupGroup2) >1:
                nodeToLabel = tree.get_common_ancestor(dupGroup2)
                nodeToLabel.name = ' $2'
            else:
                name = dupGroup2.pop()
                node = tree&name
                node.name = node.name + ' #2'
                
            tree.write(format=8,outfile='temp.treefile')

            cmd = codeml.Codeml(alignment='interCodonAlign.paml', tree='temp.treefile', out_file='results.out',working_dir='.')
            cmd.read_ctl_file('model_asym.ctl')
            output = cmd.run(command='../SOFTWARE/paml4.9j/bin/codeml')
            asymmRate_lnL = output['NSsites'][0]['lnL']
            asymmRate_numParams = len(output['NSsites'][0]['parameters']['omega'])
        
    #         #do LRT: acceleration = model_acc vs model_zero, asymmetry = model_asym vs model_acc
            try:
                accLRT = LRT(postDupRate_lnL,allOneRate_lnL,postDupRate_numParams,allOneRate_numParams)
                asymLRT = LRT(asymmRate_lnL,postDupRate_lnL,asymmRate_numParams,postDupRate_numParams)
            except ZeroDivisionError:
                # NOTE(review): LRT as defined in this notebook performs no division, so this branch looks unreachable; confirm
                print('Zero Divide?')
                continue

            #append pvals for final FDR correction
            accPvals.append(accLRT)
            asymmPvals.append(asymLRT)
            accLs.append((allOneRate_lnL,postDupRate_lnL))
            asymmLs.append((postDupRate_lnL,asymmRate_lnL))
            accRates.append(postDupRate_rates)
            idList.append(gene)
           
    #FDR pval adjustment (statsmodels fdrcorrection with its default settings)
    accFDR = fdr(accPvals) #default alpha is 0.05
    asymFDR = fdr(asymmPvals)
    
    # element [1] of the fdrcorrection result holds the adjusted p-values
    accPvalsAdjusted = [x for x in accFDR[1]] 
    asymPvalsAdjusted = [x for x in asymFDR[1]]
    
    
    return {'ID list':idList,
            'Unadjusted P vals, acceleration':accPvals,
            'Adjusted P vals, acceleration':accPvalsAdjusted,
           'Unadjusted P vals, asymmetry':asymmPvals,
            'Adjusted P vals, asymmetry':asymPvalsAdjusted,
           'Null v Alt lnL, acceleration':accLs,
           'Null v alt lnL, asymmetry':asymmLs,
           'Accelerate Model rates':accRates}

In [None]:
def defFES(table,filename=None,percentile=95):
    """Flag fast-evolving singletons (FES) in `table`.

    A gene is an FES when it is a singleton (dup_status 'S') and its proxy_rate is at or
    above the `percentile`-th percentile of proxy_rate. The percentile is taken over all
    non-excluded genes (singletons and duplicates together) with proxy_dS < 4 and a
    non-NULL proxy_rate.

    The FES column is created if needed and cleared on every call, so only this call's
    genes carry FES = 'T'. If `filename` is given, the FES ids are also written to it,
    one per line. Prints the cutoff and the number of duplicates at or above it.

    Returns the list of FES ids (empty if no gene has a usable proxy_rate).
    """
    cursor.execute('PRAGMA table_info(' + table + ')')
    if 'FES' not in {row[1] for row in cursor.fetchall()}:
        cursor.execute('ALTER TABLE '+table+' ADD COLUMN FES TEXT')
    # clear flags left by an earlier run before setting this run's flags
    cursor.execute('UPDATE '+table+' SET FES = NULL WHERE FES IS NOT NULL')

    cursor.execute('SELECT id, dup_status, proxy_rate FROM '+ table +' WHERE excludedReason IS NULL AND proxy_dS < 4 AND proxy_rate IS NOT NULL')
    res = cursor.fetchall()
    if not res:
        print('No genes with a proxy_rate; no FES defined')
        return []

    cutoff = np.percentile([x[2] for x in res], percentile)
    print(str(percentile) + 'th %ile:', cutoff)
    fes = [x[0] for x in res if x[2] >= cutoff and x[1]=='S']
    dups = [x[0] for x in res if x[2] >= cutoff and x[1]=='D']
    # bound parameter instead of "T": double quotes in SQLite denote identifiers
    cursor.executemany('UPDATE '+table+' SET FES = ? WHERE id == ?', [('T', s) for s in fes])
    db.commit()
    if filename:
        with open(filename,'w') as out:
            for s in fes:
                out.write(s+'\n')
    print('Dups in top ' + str(100 - percentile) + '%', len(dups))
    return fes

In [None]:
def taxRestrictionLevels(fes,table):
    """Assign each gene in `fes` the first restriction level whose species appear in its tree.

    Levels are checked in ascending key order (1 = mosquito/midge codes ... 10 = DBIA,
    DSUP, DSUZ). A level matches when any of its species codes occurs as a substring of
    the gene's stored baseTree newick string.

    For each matched gene, '<DMEL node name>\\t<level>' is written to
    'restrictionLevels_<table>.txt'. If several nodes contain 'DMEL', the last one in
    traversal order is used. Genes whose tree has no DMEL node are reported and not
    written, rather than reusing the previous gene's DMEL id.

    Returns the list of levels for genes where some level matched (genes with no match
    are skipped, so the list can be shorter than `fes`).
    """
    restrictDict = {1:['BCOP','AALB','ASTE','AAEG','AEAL','CPIP','CQUI'],2:['HILL'],3:['CCAP','BTRY'],
               4:['SLEB'],5:['DVIR','DNOV','DHYD','DALB','DBUS'],6:['DSUB','DGUA','DPER','DPSE','DMIR'],
               7:['DANA'],8:['DSER','DKIK'],9:['DFIC','DELE'],10:['DBIA','DSUP','DSUZ']}
    levelList = []
    with open('restrictionLevels_'+table+'.txt','w') as out:
        for s in fes:
            cursor.execute('SELECT baseTree FROM '+table+' WHERE id == ?',(s,))
            treeStr = cursor.fetchall()[0][0]
            level = next((lvl for lvl, spList in restrictDict.items()
                          if any(sp in treeStr for sp in spList)), None)
            if level is None:
                continue
            levelList.append(level)
            #I need DMEL ids for abSENSE comp
            dmelNames = [node.name for node in PhyloTree(treeStr).traverse() if 'DMEL' in node.name]
            if not dmelNames:
                print('No DMEL gene in tree for', s, '- not written to file')
                continue
            out.write(dmelNames[-1]+'\t'+str(level)+'\n')
    return levelList

In [None]:
def checkMissedDups(blastDir,table,fesList):
    """Flag fast-evolving singletons whose genes have non-self BLAST hits within their own species.

    For each id in `fesList`, every named, non-numeric node of its stored tree is mapped
    to its OrthoFinder sequence ID via blastDir/SequenceIDs.txt. That species'
    self-vs-self BLAST file (blastDir/Blast<sp>_<sp>.txt.gz, tab-separated, query and
    subject in the first two columns) is then searched for exact ID matches between the
    gene and a different sequence.

    Rows with at least one such hit get possMissedDup = 'T' and possMissedParas set to
    the comma-separated sequence names of the hit partners. Partners are de-duplicated,
    because an all-vs-all BLAST reports both A->B and B->A.

    The two columns are created if missing, and each BLAST file is read once.
    """
    import gzip

    # sequence name ('|' replaced by '_', as in the tree node names) -> OrthoFinder ID
    orthoIDDict = {}
    with open(os.path.join(blastDir, 'SequenceIDs.txt'), 'r') as file:
        for line in file:
            fields = line.strip('\n').split(' ')
            orthoIDDict[fields[1].replace('|', '_')] = fields[0].strip(':')
    nameForOrthoID = {oid: name for name, oid in orthoIDDict.items()}

    # OrthoFinder IDs of the tree genes for each FES id, and the IDs needed per species
    fesGenes = {}
    wantedBySp = {}
    for s in fesList:
        cursor.execute('SELECT tree FROM '+table+' WHERE id == ?',(s,))
        tree = PhyloTree(cursor.fetchall()[0][0])
        geneIDs = [orthoIDDict[node.name] for node in tree.traverse()
                   if node.name != '' and not node.name.isnumeric()]
        fesGenes[s] = geneIDs
        for geneO in geneIDs:
            wantedBySp.setdefault(geneO.split('_')[0], set()).add(geneO)

    # OrthoFinder ID -> non-self BLAST partners, in file order (exact column matches only)
    partners = {}
    for spO, wanted in wantedBySp.items():
        blastFile = os.path.join(blastDir, 'Blast'+spO+'_'+spO+'.txt.gz')
        with gzip.open(blastFile, 'rt') as fh:
            for line in fh:
                cols = line.rstrip('\n').split('\t')
                if len(cols) < 2 or cols[0] == cols[1]:
                    continue  # malformed line or self-hit
                query, subject = cols[0], cols[1]
                if query in wanted:
                    partners.setdefault(query, []).append(subject)
                if subject in wanted:
                    partners.setdefault(subject, []).append(query)

    cursor.execute('PRAGMA table_info(' + table + ')')
    existing = {row[1] for row in cursor.fetchall()}
    for col in ('possMissedDup', 'possMissedParas'):
        if col not in existing:
            cursor.execute('ALTER TABLE ' + table + ' ADD COLUMN ' + col + ' TEXT')

    for s, geneIDs in fesGenes.items():
        hits = [p for g in geneIDs for p in partners.get(g, [])]
        if not hits:
            continue
        possParas = list(dict.fromkeys(nameForOrthoID.get(p, p) for p in hits))
        cursor.execute('UPDATE '+table+' SET possMissedDup = ?, possMissedParas = ? WHERE id == ?',
                       ('T', ','.join(possParas), s))
    db.commit()

In [None]:
def fesConfoundComp():
    """Placeholder for a fast-evolving-singleton confounder comparison.

    The cell previously declared this function with no body at all, which is a
    SyntaxError and stops "Restart & Run All" at this cell.  Until it is
    implemented, calling it fails loudly rather than silently returning None.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('fesConfoundComp has not been implemented yet')

In [None]:
def checkDupSpDist(table, cur=None):
    """Count duplicated groups per ``dupInSp`` value.

    Args:
        table: name of the tree table to query.  SQLite cannot bind table
            names, so this is spliced into the SQL text and must be a trusted
            identifier.
        cur: optional sqlite3 cursor.  Defaults to the notebook-wide
            ``cursor`` so existing calls ``checkDupSpDist(table)`` still work.

    Returns:
        collections.Counter mapping each ``dupInSp`` value to the number of
        rows with ``excludedReason IS NULL`` and ``dup_status == 'D'``.
    """
    if cur is None:
        cur = cursor
    # 'D' is single-quoted: double quotes denote identifiers in SQL and only
    # fall back to a string literal in SQLite builds that allow it.
    cur.execute('SELECT dupInSp FROM ' + table +
                " WHERE excludedReason IS NULL AND dup_status == 'D'")
    return Counter(row[0] for row in cur.fetchall())

In [None]:
# Index sequenceTab.id so lookups by id can use it.  IF NOT EXISTS keeps the
# cell idempotent: a plain CREATE INDEX raises sqlite3.OperationalError
# ("index seqID already exists") when the notebook is re-run on the same DB.
cursor.execute('CREATE INDEX IF NOT EXISTS seqID ON sequenceTab(id)')

In [None]:
# Create FASTAs for the singleton groups of the newly created orthology runs.
# Only the Orthofinder run over sing_list_all is active; the commented calls
# are the alternative runs, kept for reference.
# createGroupFastas(extractSingGroups('groups_Orthofinder',sing_list),'orthofinderGroupFastas','groups_Orthofinder')
# createGroupFastas(extractSingGroups('groups_SonicParanoid',sing_list),'sonicParaGroupFastas','groups_SonicParanoid')
# createGroupFastas(extractSingGroups('groups_Orthofinder_ultrasens',sing_list),'orthofinderUltraGroupFastas','groups_Orthofinder_ultrasens')
# createGroupFastas(extractSingGroups('groups_Orthofinder',sing_list_relaxed),'orthofinderGroupFastasRel','groups_Orthofinder')
createGroupFastas(
    extractSingGroups('groups_Orthofinder', sing_list_all),
    'orthofinderGroupFastas',
    'groups_Orthofinder',
)

In [None]:
# MUSCLE alignments, one output directory per set of group FASTAs.
# Only the Orthofinder set is active; the others are kept for reference.
# doGroupAlignment('sonicParaAlignments','sonicParaGroupFastas',singList=None)
# doGroupAlignment('omaAlignments','dipteraOMA/Output/OrthologousGroupsFasta',singList=sing_list,groupTable='groups_OMA')
# doGroupAlignment('omaAlignments_lowerThresh','omaOutput2/OrthologousGroupsFasta',singList=sing_list,groupTable='groups_OMA_lowerThresh')
# doGroupAlignment('omaAlignments_higherThresh','omaOutput3/OrthologousGroupsFasta',singList=sing_list,groupTable='groups_OMA_higherThresh')
# doGroupAlignment('orthofinderUltraAlignments','orthofinderUltraGroupFastas',singList=None)
doGroupAlignment(
    'orthofinderAlignments',
    'orthofinderGroupFastas',
    singList=None,
)

In [None]:
# Tree building, one call per alignment directory (only Orthofinder active).
# buildTrees('omaAlignments')
# buildTrees('omaAlignments_lowerThresh')
# buildTrees('omaAlignments_higherThresh')
# buildTrees('sonicParaAlignments')
# buildTrees('orthofinderUltraAlignments')
buildTrees(
    'orthofinderAlignments'
)

In [None]:
# Initial data entry: trees for the singleton groups.
# Only the Orthofinder ultrasens run is active; the other runs are kept below.
processTrees(
    'orthofinderUltraAlignments',
    sing_list,
    'singTrees_Orthofinder_ultrasens',
    'groups_Orthofinder_ultrasens',
)
# processTrees('orthofinderAlignments',sing_list,'singTrees_Orthofinder','groups_Orthofinder')
# processTrees('sonicParaAlignments',sing_list,'singTrees_SonicParanoid','groups_SonicParanoid')
# processTrees('omaAlignments',sing_list,'singTrees_OMA','groups_OMA')
# processTrees('omaAlignments_lowerThresh',sing_list,'singTrees_OMA_lowerThresh','groups_OMA_lowerThresh')
# processTrees('omaAlignments_higherThresh',sing_list,'singTrees_OMA_higherThresh','groups_OMA_higherThresh')

In [None]:
# Filtering and checks on each singleton-tree table.
# Only the Orthofinder ultrasens run is active; the others are kept commented.

# Orthofinder (default run)
# checkMissingSp('singTrees_Orthofinder')
# checkOutgroupDup('singTrees_Orthofinder')

# Orthofinder (ultrasens run)
checkMissingSp('singTrees_Orthofinder_ultrasens')
checkOutgroupDup('singTrees_Orthofinder_ultrasens')

# SonicParanoid
# checkMissingSp('singTrees_SonicParanoid')
# checkOutgroupDup('singTrees_SonicParanoid')

# OMA (defaults)
# checkMissingSp('singTrees_OMA')
# checkOutgroupDup('singTrees_OMA')

# OMA (lower inParalogTol value)
# checkMissingSp('singTrees_OMA_lowerThresh')
# checkOutgroupDup('singTrees_OMA_lowerThresh')

# OMA (higher inParalogTol value)
# checkMissingSp('singTrees_OMA_higherThresh')
# checkOutgroupDup('singTrees_OMA_higherThresh')

In [None]:
# Split checks per singleton-tree table (only the ultrasens run is active).
# checkSplits('singTrees_Orthofinder')
checkSplits(
    'singTrees_Orthofinder_ultrasens'
)
# checkSplits('singTrees_SonicParanoid')
# checkSplits('singTrees_OMA')
# checkSplits('singTrees_OMA_lowerThresh')

In [None]:
# Splitting trees and creating the final tables. Each run gets its own cell
# so the manual checking can be done one run at a time.
doSplits(
    'singTrees_Orthofinder',
    'procTrees_Orthofinder',
)

In [None]:
doSplits(
    'singTrees_SonicParanoid',
    'procTrees_SonicParanoid',
)

In [None]:
doSplits(
    'singTrees_OMA',
    'procTrees_OMA',
)

In [None]:
doSplits(
    'singTrees_OMA_lowerThresh',
    'procTrees_OMA_lowerThresh',
)

In [None]:
doSplits(
    'singTrees_OMA_higherThresh',
    'procTrees_OMA_higherThresh',
)

In [None]:
doSplits(
    'singTrees_Orthofinder_relaxed',
    'procTrees_Orthofinder_relaxed',
)

In [None]:
# Filtering the final tables, then rate calculation. The same four steps are
# listed for every run; only the relaxed Orthofinder run is active.

# Orthofinder (default)
#assignDupStatus('procTrees_Orthofinder')
#checkDupTiming('procTrees_Orthofinder')
#checkCorrectOutgroups('procTrees_Orthofinder')
# rateCalc('procTrees_Orthofinder','orthofinderAlignments','singTrees_Orthofinder')

# Orthofinder (ultrasens)
# assignDupStatus('procTrees_Orthofinder_ultrasens')
# checkDupTiming('procTrees_Orthofinder_ultrasens')
# checkCorrectOutgroups('procTrees_Orthofinder_ultrasens')
# rateCalc('procTrees_Orthofinder_ultrasens','orthofinderUltraAlignments','singTrees_Orthofinder_ultrasens')

# Orthofinder (relaxed) -- active
assignDupStatus('procTrees_Orthofinder_relaxed')
checkDupTiming('procTrees_Orthofinder_relaxed')
checkCorrectOutgroups('procTrees_Orthofinder_relaxed')
rateCalc(
    'procTrees_Orthofinder_relaxed',
    'orthofinderAlignments',
    'singTrees_Orthofinder_relaxed',
)

# SonicParanoid
# assignDupStatus('procTrees_SonicParanoid')
# checkDupTiming('procTrees_SonicParanoid')
# checkCorrectOutgroups('procTrees_SonicParanoid')
# rateCalc('procTrees_SonicParanoid','sonicParaAlignments','singTrees_SonicParanoid')

# OMA (defaults)
# assignDupStatus('procTrees_OMA')
# checkDupTiming('procTrees_OMA')
# checkCorrectOutgroups('procTrees_OMA')
# rateCalc('procTrees_OMA','omaAlignments','singTrees_OMA')

# OMA (lower threshold)
# assignDupStatus('procTrees_OMA_lowerThresh')
# checkDupTiming('procTrees_OMA_lowerThresh')
# checkCorrectOutgroups('procTrees_OMA_lowerThresh')
# rateCalc('procTrees_OMA_lowerThresh','omaAlignments_lowerThresh','singTrees_OMA_lowerThresh')

# OMA (higher threshold)
# assignDupStatus('procTrees_OMA_higherThresh')
# checkDupTiming('procTrees_OMA_higherThresh')
# checkCorrectOutgroups('procTrees_OMA_higherThresh')
# rateCalc('procTrees_OMA_higherThresh','omaAlignments_higherThresh','singTrees_OMA_higherThresh')

In [None]:
# Orthofinder rate comparisons: limitDS off, then limitDS on.
rateComp('procTrees_Orthofinder', 'confirm', limitDS=False)
rateComp('procTrees_Orthofinder', 'proxy', limitDS=False)

rateComp('procTrees_Orthofinder', 'confirm', limitDS=True)
rateComp('procTrees_Orthofinder', 'proxy', limitDS=True)

In [None]:
# Relaxed Orthofinder rate comparisons: limitDS off, then limitDS on.
# Every call targets the relaxed table; the second call previously used
# 'procTrees_Orthofinder' (copy-paste slip), so one of the four relaxed
# results was silently taken from the default table.
rateComp('procTrees_Orthofinder_relaxed','confirm',limitDS=False)
rateComp('procTrees_Orthofinder_relaxed','proxy',limitDS=False)

rateComp('procTrees_Orthofinder_relaxed','confirm',limitDS=True)
rateComp('procTrees_Orthofinder_relaxed','proxy',limitDS=True)

In [None]:
# Orthofinder ultrasens rate comparisons: limitDS off, then limitDS on.
rateComp('procTrees_Orthofinder_ultrasens', 'confirm', limitDS=False)
rateComp('procTrees_Orthofinder_ultrasens', 'proxy', limitDS=False)

rateComp('procTrees_Orthofinder_ultrasens', 'confirm', limitDS=True)
rateComp('procTrees_Orthofinder_ultrasens', 'proxy', limitDS=True)

In [None]:
# SonicParanoid rate comparisons: limitDS off, then limitDS on.
rateComp('procTrees_SonicParanoid', 'confirm', limitDS=False)
rateComp('procTrees_SonicParanoid', 'proxy', limitDS=False)

rateComp('procTrees_SonicParanoid', 'confirm', limitDS=True)
rateComp('procTrees_SonicParanoid', 'proxy', limitDS=True)

In [None]:
# OMA rate comparisons: limitDS off, then limitDS on.
rateComp('procTrees_OMA', 'confirm', limitDS=False)
rateComp('procTrees_OMA', 'proxy', limitDS=False)

rateComp('procTrees_OMA', 'confirm', limitDS=True)
rateComp('procTrees_OMA', 'proxy', limitDS=True)

In [None]:
# Rate-comparison figures, Orthofinder: limitDS off, then limitDS on.
generateRateCompFigure('procTrees_Orthofinder', 'confirm', limitDS=False)
generateRateCompFigure('procTrees_Orthofinder', 'proxy', limitDS=False)

generateRateCompFigure('procTrees_Orthofinder', 'confirm', limitDS=True)
generateRateCompFigure('procTrees_Orthofinder', 'proxy', limitDS=True)

In [None]:
# Rate-comparison figures, relaxed Orthofinder: limitDS off, then limitDS on.
generateRateCompFigure('procTrees_Orthofinder_relaxed', 'confirm', limitDS=False)
generateRateCompFigure('procTrees_Orthofinder_relaxed', 'proxy', limitDS=False)

generateRateCompFigure('procTrees_Orthofinder_relaxed', 'confirm', limitDS=True)
generateRateCompFigure('procTrees_Orthofinder_relaxed', 'proxy', limitDS=True)

In [None]:
# Rate-comparison figures, Orthofinder ultrasens: limitDS off, then limitDS on.
generateRateCompFigure('procTrees_Orthofinder_ultrasens', 'confirm', limitDS=False)
generateRateCompFigure('procTrees_Orthofinder_ultrasens', 'proxy', limitDS=False)

generateRateCompFigure('procTrees_Orthofinder_ultrasens', 'confirm', limitDS=True)
generateRateCompFigure('procTrees_Orthofinder_ultrasens', 'proxy', limitDS=True)

In [None]:
# Rate-comparison figures, SonicParanoid: limitDS off, then limitDS on.
generateRateCompFigure('procTrees_SonicParanoid', 'confirm', limitDS=False)
generateRateCompFigure('procTrees_SonicParanoid', 'proxy', limitDS=False)

generateRateCompFigure('procTrees_SonicParanoid', 'confirm', limitDS=True)
generateRateCompFigure('procTrees_SonicParanoid', 'proxy', limitDS=True)

In [None]:
# Confounder correlations for the Orthofinder table
# (the confounderInsert step is left commented out).
# confounderInsert('procTrees_Orthofinder')
confounderCompCorr(
    'procTrees_Orthofinder'
)

In [None]:
# regressDF feeds the confounderBoxPlot and residComp cells.
regressDF = doRegression(
    'procTrees_Orthofinder'
)

In [None]:
# Regression plots for the 'OLS' and 'lowess' variants.
regressionPlot('procTrees_Orthofinder', 'OLS')
regressionPlot('procTrees_Orthofinder', 'lowess')

In [None]:
# Confounder box plots for the 'orig', 'OLS' and 'lowess' variants of regressDF.
confounderBoxPlot(regressDF, 'orig')
confounderBoxPlot(regressDF, 'OLS')
confounderBoxPlot(regressDF, 'lowess')

In [None]:
# Monte Carlo simulation to see whether the change in p value is significant,
# one run per confounder. The p values are hard-coded from an earlier result.
# Each entry: (printed label, feature key for residComp, p value, lower).
residCompRuns = [
    ('CDS', 'CDS', 0.00046949663863196666, True),
    ('Expression', 'Exp', 0.00013165187064657547, True),
    ('GC', 'GC', 0.017101905456286735, False),
    # NOTE(review): this run previously passed 'Exp' under the GC3 label,
    # i.e. it re-ran the expression feature with the GC3 p value. 'GC3' is
    # assumed to be the matching key -- confirm against residComp/regressDF.
    ('GC3', 'GC3', 0.04546184449416842, False),
]
for label, feature, refP, lower in residCompRuns:
    print(label)
    print(residComp(feature, 'OLS', regressDF, refP, lower=lower, it=100000))

In [None]:
# Summary of the rate hypothesis tests (acceleration, asymmetry) for the
# Orthofinder groups.
# NOTE(review): the code below assumes every 'P vals' entry of resDict is a
# per-group list of scalar p values and 'Accelerate Model rates' is a
# parallel list of rate pairs where index 1 is the post-duplication rate --
# confirm against rateHypTest.
ALPHA = 0.05
resDict = rateHypTest('procTrees_Orthofinder','orthofinderAlignments','groups_Orthofinder')
accPvalsAdj = resDict['Adjusted P vals, acceleration']
accPvalsUnadj = resDict['Unadjusted P vals, acceleration']
accModelRates = resDict['Accelerate Model rates']

# (p value, rate pair) per group
accRates = list(zip(accPvalsAdj, accModelRates))
accRatesUnadj = list(zip(accPvalsUnadj, accModelRates))
print('Higher post dup rate count:', sum(1 for p, r in accRates if r[1] > r[0]))
print('Higher rate, sig before correction:', sum(1 for p, r in accRatesUnadj if p < ALPHA and r[1] > r[0]))
print('Higher rate, sig after correction:', sum(1 for p, r in accRates if p < ALPHA and r[1] > r[0]))

# Asymmetry entries are bare p values (no rate pair), so compare them
# directly; indexing them with [0] raises TypeError on a float.
asymRates = list(resDict['Adjusted P vals, asymmetry'])
asymRatesUnadj = list(resDict['Unadjusted P vals, asymmetry'])
print('Asym rate, sig before correction:', sum(1 for p in asymRatesUnadj if p < ALPHA))
print('Asym rate, sig after correction:', sum(1 for p in asymRates if p < ALPHA))
print()
print('Fisher\'s method for combining p values')
# combine_pvalues must receive the raw p values (not the (p, rates) tuples).
# It returns (statistic, pvalue) -- there is no `.pval` attribute -- so index
# the p value, which works for both the old tuple and the newer result object.
print('Acceleration:')
print(combine_pvalues(accPvalsUnadj, method='fisher')[1])
print('Asymmetry:')
print(combine_pvalues(asymRatesUnadj, method='fisher')[1])

In [None]:
# Flag possible missed duplications among the fast-evolving singletons of each
# Orthofinder run. The working-directory paths end in '/' because
# checkMissedDups builds the BLAST file names by plain concatenation
# (blastDir+'Blast'+sp+'_'+sp+'.txt.gz'). Without the slash it looks for
# '.../WorkingDirectoryBlastN_N.txt.gz'; gunzip then fails into an unread
# stderr pipe, and every gene silently appears to have no in-species hits.
checkMissedDups('dipteraTranslations_processed/OrthoFinder/Results_Mar30/WorkingDirectory/','procTrees_Orthofinder',defFES('procTrees_Orthofinder'))
checkMissedDups('dipteraTranslations_processed/OrthoFinder/Results_May28/WorkingDirectory/','procTrees_Orthofinder_ultrasens',defFES('procTrees_Orthofinder_ultrasens'))

In [None]:
#checking for blast hits within the sp of interest
# For each fast-evolving singleton group in procTrees_Orthofinder, print the
# genes that have more than one same-species BLAST hit (self hit plus
# others), i.e. a possible missed duplication.
blastDir = './dipteraTranslations_processed/OrthoFinder/Results_Mar30/WorkingDirectory/'

# Map tree leaf names ('|' replaced by '_') to OrthoFinder sequence IDs of the
# form '<speciesIdx>_<seqIdx>', read from SequenceIDs.txt.
orthoIDDict = {}
with open(blastDir + 'SequenceIDs.txt', 'r') as file:
    for line in file:
        fields = line.strip('\n').split(' ')
        orthoIDDict[fields[1].replace('|', '_')] = fields[0].strip(':')

# fes used to come from a later cell, so this cell failed with NameError on a
# fresh kernel; compute it here instead.
fes = defFES('procTrees_Orthofinder')
for s in fes:
    cursor.execute('SELECT tree FROM procTrees_Orthofinder WHERE id == ?', (s,))
    tree = PhyloTree(cursor.fetchall()[0][0])
    # Named nodes only: skip unnamed internal nodes and numeric (support) labels.
    genes = [node.name for node in tree.traverse()
             if node.name != '' and not node.name.isnumeric()]
    for gene in genes:
        geneO = orthoIDDict[gene]
        spO = geneO.split('_')[0]
        # Same-species BLAST results for this gene's species.
        blastFile = blastDir + 'Blast' + spO + '_' + spO + '.txt.gz'
        p1 = Popen(['gunzip', '-c', blastFile], stdout=PIPE, stderr=PIPE)
        # r'\s': '\s' in a normal string is an invalid escape (SyntaxWarning on 3.12+).
        p3 = Popen(['grep', geneO + r'\s'], stdin=p1.stdout, stdout=PIPE, stderr=PIPE)
        # Drop our copy of gunzip's stdout so only grep holds the pipe.
        p1.stdout.close()
        out, err = p3.communicate()
        gzErr = p1.stderr.read().decode()
        p1.stderr.close()
        # Don't let a missing/corrupt BLAST file masquerade as "no hits".
        if p1.wait() != 0:
            print('warning: gunzip failed for', blastFile, gzErr.strip())
        hits = [x for x in out.decode().split('\n') if x != '']
        if len(hits) > 1:
            print('dup')
            print(hits)

# TODO: record, per group, which species show a non-self hit
# (checkMissedDups stores possMissedDup / possMissedParas in the table).

In [None]:
#plots for comparing possible confounders within fast-evolving singletons
BOX_LABELS = ['Singleton', 'Fast-evolving singleton', 'Duplicable']
BOX_COLOURS = ['#DC3220', '#6c35b5', '#005AB5']


def _confoundBoxplot(ax, groups, title, ylabel):
    """Draw one confounder panel as a styled three-group boxplot.

    Args:
        ax: matplotlib Axes to draw on.
        groups: three sequences of values in BOX_LABELS order
            (singleton, fast-evolving singleton, duplicable).
        title: panel title.
        ylabel: y-axis label.

    Returns:
        The artist dict returned by ``Axes.boxplot``.
    """
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    boxes = ax.boxplot(groups, labels=BOX_LABELS, patch_artist=True, flierprops={'ms': 1})
    for box, colour in zip(boxes['boxes'], BOX_COLOURS):
        box.set_fc(colour)
    for m in boxes['medians']:
        m.set_color('k')
        m.set_lw(2)
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)
        tick.set_ha('right')
    return boxes


def _featValues(rows, idx):
    """Return column ``idx`` of each row, skipping NULL (None) values."""
    return [r[idx] for r in rows if r[idx] is not None]


fes = defFES('procTrees_Orthofinder')
fesSet = set(fes)  # O(1) membership tests for the split below
featQuery = ('SELECT id, cdsLen, gc3, exp FROM procTrees_Orthofinder '
             'WHERE excludedReason IS NULL AND dup_status == ? AND proxy_dS < 4')
cursor.execute(featQuery, ('S',))
res = cursor.fetchall()
fesFeats = [x for x in res if x[0] in fesSet]
otherFeats = [x for x in res if x[0] not in fesSet]
cursor.execute(featQuery, ('D',))
dupFeats = cursor.fetchall()

# Row layout from featQuery: 0 = id, 1 = cdsLen, 2 = gc3, 3 = exp.
# NULLs are now skipped for every feature (previously only exp), so a single
# missing cdsLen/gc3 no longer crashes the cell.
cdsF, cdsS, cdsD = (np.log10(_featValues(rows, 1)) for rows in (fesFeats, otherFeats, dupFeats))
expF, expS, expD = (np.log10(_featValues(rows, 3)) for rows in (fesFeats, otherFeats, dupFeats))
gcF, gcS, gcD = ([v * 100 for v in _featValues(rows, 2)] for rows in (fesFeats, otherFeats, dupFeats))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax1, ax2, ax3 = axes
_confoundBoxplot(ax1, [cdsS, cdsF, cdsD], 'CDS Length', 'log(bp)')
_confoundBoxplot(ax2, [expS, expF, expD], 'Expression', 'log(TPM)')
_confoundBoxplot(ax3, [gcS, gcF, gcD], '% GC3', '%')
plt.savefig('fesConfound_final.png', bbox_inches='tight')

In [None]:
taxRestrictionLevels(
    defFES('procTrees_Orthofinder'),  # fast-evolving singletons (defFES)
    'singTrees_Orthofinder',
)