In [409]:
import pandas as pd
import os
import glob

from Bio.Seq import Seq
from Bio.Align import MultipleSeqAlignment
from Bio import AlignIO, SeqIO
import panel as pn
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, Plot, Grid, Range1d
from bokeh.models.glyphs import Text, Rect
from bokeh.layouts import gridplot
from bokeh.io import export_svg, output_notebook, install_notebook_hook, notebook
import numpy as np
import re

import matplotlib.pyplot as plt

In [410]:
def clean_column_ids(df, col):
    return df[col].map(lambda x: x.split(":")[1])

def parse_fasta(protein_fasta):
    id = protein_fasta.split("\n")[0]
    sequence = ''.join(protein_fasta.split("\n")[1:])
    return (id, sequence)

def merge_windows(wins):
    a = []
    for alignment in wins:
        k = [alignment[0]]
        for win in alignment[1:]:
            if k[len(k)-1][1] >= (win[0]-1):
                k[len(k)-1] = (k[len(k)-1][0],win[1])
            else:
                k.append(win)
        a.append(k)
    return a

In [411]:
def merge_windows(wins):
    a = [wins[0]]
    for i in range(len(wins)-1):
        prev = a[len(a)-1]
        next = wins[i+1]
        if prev[1] + 1 >= next[0]:
            a[len(a)-1] = (a[len(a)-1][0], next[1])
        else:
            a.append(next)
    return a

def read_stride(path):
    data = []
    with open(path, "r") as f:
        for line in f.readlines():
            if line[:3] == "ASG":
                data.append(re.split("\s+", line)[6])
    return data

In [412]:
clrs = ["steelblue"]
flex = ["Coil", "Strand", "Turn"]
acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "U", "O"]

class SafeSequence():
    def __init__(self, accession, sequence, ref_windows, seq_windows, stride_path):
        self.accession = accession
        self.sequence = sequence
        self.ref_windows = ref_windows
        self.seq_windows = seq_windows
        self.strides = self.read_strides(stride_path)
        self.TP, self.TN, self.FP, self.FN = 0,0,0,0
        self.calculate_hits()
        self.precision = self.TP / (self.TP + self.FP)
        self.recall = self.TP / (self.TP + self.FN)

        # base distribution in the sequence
        self.safe_distribution = {}
        self.non_safe_distribution = {}
        self.distribution = {}
        for acid in acids:
            self.distribution[acid] = 0
            self.safe_distribution[acid] = 0
            self.non_safe_distribution[acid] = 0
        self.calculate_distributions()

    def calculate_hits(self):
        for i in range(len(self.sequence)):
            safe = self.is_safe(i)
            stable = self.is_stable(i)
            self.TP += safe and stable
            self.TN += not safe and not stable
            self.FP += safe and not stable
            self.FN += not safe and stable
        assert self.TP + self.TN + self.FP + self.FN == len(self.sequence), f"Hits do not match"

    def calculate_distributions(self):
        for i in range(len(self.sequence)):
            safe = self.is_safe(i)
            aa = self.sequence[i]
            self.distribution[aa] += 1
            self.safe_distribution[aa] += safe
            self.non_safe_distribution[aa] += not safe

    def is_safe(self, i):
        for window in self.seq_windows:
            start = window[0]
            end = window[1]
            if i >= start and i <= end:
                return True
        return False

    def is_stable(self, i):
        return self.strides[i] not in flex

    def get_safety_window_colors(self, longest):
        colors = ["white"] * longest
        clr = 0
        for window in self.seq_windows:
            start = window[0]
            end = window[1]
            for i in range(start, end):
                colors[i] = clrs[clr % len(clrs)]
            
            clr += 1
        return colors

    def get_stride_colors(self, longest):
        colors = []
        for stride in self.strides:
            if stride not in flex:
                colors.append("palegreen")
            else:
                colors.append("white")
        
        return colors + (["white"] * (longest - len(self.sequence)))

    def get_strides(self):
        return "".join([stride[0] for stride in self.strides])

    def read_strides(self, path):
        return read_stride(f"{path}/{self.accession}.out")

In [413]:
def get_colors(seqs, longest):
    """make colors for bases in sequence"""
    colors = []
    for ss in seqs:
        colors += ss.get_stride_colors(longest)
        colors += ss.get_safety_window_colors(longest)
    colors += ["white" for i in range(longest)]
    return colors

In [433]:
def read_safe_sequence(safety_path, fasta_path, stride_path):
    sequences = {}
    with open(fasta_path, "r") as f:
        db_fasta = ("\n" + f.read()).split("\n>")[1:]
        for protein_fasta in db_fasta:
            id, sequence = parse_fasta(protein_fasta)
            assert sequence not in sequences.keys(), f"Duplicate sequence: {sequence} in {fasta_path}"
            sequences[sequence] = id.split(" ")[0].split(":")[1]

    safe_sequences = []
    
    with open(safety_path, "r") as f:
        ref = f.readline().split(" ")[1].strip()
        ref_accession = sequences[ref]
        line = f.readline()
        while line:
            data = line.split(" ")
            if len(data) == 3:
                seq = data[1]
                rwindows = []
                swindows = []
                for j in range(int(data[2])):
                    line = f.readline()
                    data = line.split(" ")
                    rwindows.append((int(data[0]), int(data[1])))
                    swindows.append((int(data[2]), int(data[3])))

                accession = sequences[seq]
                safe_sequences.append(SafeSequence(accession, seq, rwindows, swindows, stride_path))
            line = f.readline()
        
    return (ref, ref_accession, safe_sequences)

In [434]:
def get_columns(ref, seqs, longest):
    text = ""
    for ss in seqs:
        s = ss.sequence + (" " * (longest - len(ss.sequence)))
        text += ss.get_strides() + (" " * (longest - len(ss.sequence)))
        text += s
    text += ref + (" " * (longest - len(ref)))
    return list(text)

In [435]:
def plot_windows(ref, ref_accession, safe_sequences):
    longest = len(ref)
    for ss in safe_sequences:
        longest = max(len(ss.sequence), longest)

    text = get_columns(ref, safe_sequences, longest)
    colors = get_colors(safe_sequences, longest)
    cols = longest
    rows = len(safe_sequences)*2+1
    x = np.arange(1, cols + 1)
    y = np.arange(0, rows, 1)
    xx, yy = np.meshgrid(x, y)
    gx = xx.ravel()
    gy = yy.flatten()
    recty = gy + 0.5
    d = dict(x=gx, y=gy, recty=recty, text=text, colors=colors)
    source = ColumnDataSource(d)
    plot_height = rows * 15
    plot_width = longest * 15
    x_range = Range1d(0, cols + 1, bounds='auto')

    rects = Rect(x="x", y="recty",  width=1, height=1, fill_color="colors", line_color=None, fill_alpha=0.8)
    glyph = Text(x="x", y="y", text="text", text_align='center',text_color="black", text_font_size="4pt")
    
    p = figure(title=None, plot_width=plot_width, plot_height=plot_height,
               x_range=x_range, y_range=(0,rows),
               min_border=0, toolbar_location='below')
    p.add_glyph(source, rects)
    p.add_glyph(source, glyph)    
    p = gridplot([[p]])

    return p

In [436]:
def work(path):
    fasta_files = sorted(glob.glob(f"{path}/fasta/*.fasta"))
    stride_files = sorted(glob.glob(f"{path}/stride/*"))
    a99_files = sorted(glob.glob(f"{path}/safety.a99/*.out"))
    a75_files = sorted(glob.glob(f"{path}/safety.a75/*.out"))
    # a50_files = sorted(glob.glob(f"{path}/safety.a50/*.out"))

    alphas = {}
    alphas["safety.a50"] = ([], [])
    alphas["safety.a75"] = ([], [])
    alphas["safety.a99"] = ([], [])
    for files in list(zip(fasta_files, stride_files, a99_files, a75_files)):
        fasta_file = files[0]
        stride_file = files[1]
        sf = files[2:5]

        for safety_file in sf:
            alpha = safety_file.split("/")[-2]
            ref, ref_accession, safe_sequences = read_safe_sequence(
                safety_file, fasta_file, stride_file
            )
            tp = 0
            fp = 0
            tn = 0
            fn = 0
            for ss in safe_sequences:
                tp += ss.TP
                fp += ss.FP
                tn += ss.TN
                fn += ss.FN
            
            alphas[alpha][0].append(tp/(tp+fp))
            alphas[alpha][1].append(tp/(tp+fn))
            
    print(f"Alpha 0.50 - avg. precision: {np.mean(alphas['safety.a50'][0]):.2f} - avg. recall: {np.mean(alphas['safety.a50'][1]):.2f}")
    print(f"Alpha 0.75 - avg. precision: {np.mean(alphas['safety.a75'][0]):.2f} - avg. recall: {np.mean(alphas['safety.a75'][1]):.2f}")
    print(f"Alpha 0.99 - avg. precision: {np.mean(alphas['safety.a99'][0]):.2f} - avg. recall: {np.mean(alphas['safety.a99'][1]):.2f}")

In [437]:
work("./../out/alphafold.30.multi-step")

AssertionError: Duplicate sequence: MPTIKLQSSDGEIFEVDVEIAKQSVTIKTMLEDLGMDDEGDDDPVPLPNVNAAILKKVIQWCTHHKDDPPPPEDDENKEKRTDDIPVWDQEFLKVDQGTLFELILAANYLDIKGLLDVTCKTVANMIKGKTPEEIRKTFNIKNDFTEEEEAQVRKENQWCEEK in ./../out/alphafold.30.multi-step/fasta/cluster_104.fasta