In [1]:
from Bio import SeqIO
import csv
import numpy as np
import functools
import operator
import gzip

In [6]:
overlaps_fname = '/Users/esaliya/sali/git/github/esaliya/cpp/lbl.pastis/overlap.txt'
seqs_fname = '/Users/esaliya/sali/data/scope/uniqs/100/100_of_77040_unique_of_243813_astral-scopedom-seqres-gd-all-2.07-stable.fa' 


In [7]:
# All super-families dictionary. Each super family entry will have list,
# where the first element is the number of families in it and the second
# is a dictionary for each of its families (family name -> sequence count in family).
all_sfs = {}
fam_names = []
sf_names = []
limit = 'all'
with open(seqs_fname, "r") as seqf:
    count = 0
    for record in SeqIO.parse(seqf, "fasta"):
        if count != 'all' and count == limit:
            break
        l_idx = record.description.index(" ")
        r_idx = record.description.index(" ", l_idx+1)
        cls, fold, sf, fam = record.description[l_idx: r_idx].split('.')
        fam_names.append(fam)
        sf_names.append(sf)
        if sf in all_sfs:
            sf_fams = all_sfs[sf][1]
            if fam in sf_fams:
                sf_fams[fam] += 1
            else:
                sf_fams[fam] = 1
            all_sfs[sf][0] += 1
        else:
            all_sfs[sf] = [1, {fam: 1}]

        count += 1
print("Read ", count, " sequences")

Read  100  sequences


In [8]:
num_sf = len(all_sfs.keys())
num_fam = sum([len(all_sfs[k][1].keys()) for k in all_sfs])
print("Num super families: ", num_sf)               
print("Num families: ", num_fam)

all_sf_fams_seq_counts = [list(all_sfs[k][1].values()) for k in all_sfs]
fam_seq_counts = np.array(functools.reduce(operator.iconcat, all_sf_fams_seq_counts, []))
sf_seq_counts = np.array([sum(sf_fams_seq_counts) for sf_fams_seq_counts in all_sf_fams_seq_counts])
print(fam_seq_counts)
print(sf_seq_counts)

Num super families:  1
Num families:  3
[51 30 19]
[100]


In [9]:
# Number of family pairs (top triangle only, excludes diagonal as well)
num_fam_pairs = np.sum(fam_seq_counts * (fam_seq_counts - 1) / 2)
# Number of super-family pairs (top triangle only, excludes diagonal as well). Includes family pairs too.
num_sf_pairs = np.sum(sf_seq_counts * (sf_seq_counts - 1) / 2)
num_sf_only_pairs = num_sf_pairs - num_fam_pairs
print(num_fam_pairs, num_sf_pairs, num_sf_only_pairs)


1881.0 4950.0 3069.0


In [14]:
all_pair_count = count*(count - 1)/2
print("all pairs: ", all_pair_count)
print("fam_pair ratio: ", num_fam_pairs / all_pair_count)
print("sf_pair ratio: ", num_sf_pairs / all_pair_count)

all pairs:  4950.0
fam_pair ratio:  0.38
sf_pair ratio:  1.0


In [15]:
import time
t = time.process_time()
num_A, num_B, num_C = 0, 0, 0
with open(overlaps_fname, 'rt') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    print(next(csv_reader))  # ignore header
    line_count = 0
    for g_col, g_row, shared_kmer_count in csv_reader:
        
        g_col = int(g_col)
        g_row = int(g_row)
        if sf_names[g_col] == sf_names[g_row]:
            if fam_names[g_col] == fam_names[g_row]:
                num_A += 1
            else: 
                num_B += 1
        else:
            num_C += 1
        line_count += 1
        if line_count % 1000000 == 0:
            elapsed = time.process_time() - t
            print("lines ", line_count, " of ", all_pair_count, " (", round((line_count*100.0/2967542280), 2), "%) took ", elapsed, "s")

print("Line count: ", line_count)
print("Total time: ", (time.process_time() - t), "s")      
        
            

['g_col_idx', 'g_row_idx', 'common_kmer_count']
Line count:  311
Total time:  0.003462000000000298 s


In [16]:
print(num_A, num_B, num_C)
print(num_fam_pairs)

248 63 0
1881.0


## Overall accuracy

In [17]:
recall = num_A / num_fam_pairs
precision = (num_A + num_B) / (num_A+num_B+num_C)
print(recall, precision)

0.1318447634237108 1.0


## Family only accuracy

In [18]:
fam_recall = recall
fam_precision = num_A/ (num_A+num_B+num_C)
print(fam_recall, fam_precision)

0.1318447634237108 0.797427652733119


## Super family accuracy

In [19]:
sf_recall = (num_A + num_B) / (num_fam_pairs + num_sf_only_pairs)
sf_precision = precision
print(sf_recall, sf_precision)

0.06282828282828283 1.0


In [20]:
Abc = int(num_C)
aBc = int(num_sf_only_pairs - num_B)
ABc = int(num_B)
abC = 0
AbC = 0
aBC = int(num_fam_pairs - num_A)
ABC = int(num_A)
print(Abc, aBc, ABc, abC, AbC, aBC, ABC)

0 3006 63 0 0 1633 248


In [21]:
#%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
from matplotlib_venn import venn3, venn3_circles
# venn3(subsets = (10, 8, 22, 0,0,4,2), set_labels=("6-mer Overlap Pairs", "Super-family Pairs", "Family Pairs" ))
mul = int(1e5)
mul_str='e5'
plt.figure(figsize=(8,8))

v = venn3(subsets = (Abc//mul, aBc//mul, ABc//mul, abC//mul, AbC//mul, aBC//mul, ABC//mul), set_labels=("PISA output:\n6-mer Overlap\nPairs", "Super-family\nPairs", "Family\nPairs" ))
plt.gca().set_facecolor('#616161')
plt.gca().set_axis_on()


lbl = v.get_label_by_id('A')
x, y = lbl.get_position()
lbl.set_position((x+0.12, y-0.3))
lbl.set_multialignment('center')

lbl = v.get_label_by_id('B')
x, y = lbl.get_position()
lbl.set_position((x-0.3, y-0.1))
lbl.set_multialignment('center')

lbl = v.get_label_by_id('C')
x, y = lbl.get_position()
lbl.set_position((x, y+0.12))
lbl.set_multialignment('center')

v.get_patch_by_id('100').set_color('#f7b154')
v.get_patch_by_id('100').set_alpha(0.9)
lbl = v.get_label_by_id('100')
lbl.set_size('large')
lbl.set_text(f'~{Abc//mul}{mul_str}')

v.get_patch_by_id('110').set_color('#f2f2c9')
v.get_patch_by_id('110').set_alpha(0.9)
lbl = v.get_label_by_id('110')
lbl.set_size('large')
lbl.set_text(f'~{ABc//mul}{mul_str}')

v.get_patch_by_id('111').set_color('#54b3f7')
v.get_patch_by_id('111').set_alpha(0.9)
lbl = v.get_label_by_id('111')
x, y = lbl.get_position()
lbl.set_position((x+0.1, y))
lbl.set_size('large')
lbl.set_text(f'~{ABC//mul}{mul_str}')

plt.savefig('k5.venn_subs_5percent.venn.png', dpi=100)




AttributeError: 'NoneType' object has no attribute 'set_color'

In [None]:
#%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
from matplotlib_venn import venn3, venn3_circles
# venn3(subsets = (10, 8, 22, 0,0,4,2), set_labels=("6-mer Overlap Pairs", "Super-family Pairs", "Family Pairs" ))
mul = int(1e7)
v = venn3(subsets = (Abc//mul, aBc//mul, ABc//mul, abC//mul, AbC//mul, aBC//mul, ABC//mul), set_labels=("PISA output: 6-mer Overlap Pairs", "Super-family Pairs", "Family Pairs" ))
v.get_patch_by_id('100').set_color
plt.show()
