## Process embeddings


In [2]:
import pandas as pd
import numpy as np
import h5py
from glob import glob

In [3]:
# Load every embedding CSV as (path, DataFrame) pairs. glob() returns paths in
# a filesystem-dependent order, so sort them to make dfs[0] (used below)
# reproducible across machines.
embeddings = sorted(glob("../data/ucyn-a_enriched/embeddings/*.csv"))
if not embeddings:
    # Fail here rather than with an IndexError on dfs[0] or an empty H5 file later
    raise FileNotFoundError("No embedding CSVs found in ../data/ucyn-a_enriched/embeddings/")
dfs = [(e, pd.read_csv(e)) for e in embeddings]

In [4]:
# Preview the dataset-name stem (file name up to its first ".") of the first embedding path
first_embedding_path = dfs[0][0]
first_embedding_path.rsplit("/", 1)[-1].split(".")[0]

In [5]:
from pathlib import Path

# Write embeddings to H5 file: one dataset per CSV, named after the CSV file
# name (up to its first ".") plus a fixed " Cytoplasm-U new_test_set" suffix.
with h5py.File("../data/ucyn-a_enriched/embeddings.h5", "w") as f:
    for path, df in dfs:
        # Path(...).name handles both "/" and the "\\" separators glob returns
        # on Windows, unlike split("/")
        dataset_name = Path(path).name.split(".")[0] + " Cytoplasm-U new_test_set"
        f.create_dataset(dataset_name, data=df.to_numpy())

## Process predictions


In [11]:
import pandas as pd
from Bio import SeqIO
from matplotlib import pyplot as plt

In [12]:
# Attach protein ids to the localization labels. Assumes localizations.txt has
# exactly one label per line, in the same order as the FASTA records — TODO confirm.
fasta_path = "../data/ucyn-a_enriched/ucyn-a_enriched.fasta"
ids = [record.id for record in SeqIO.parse(fasta_path, "fasta")]

df = pd.read_csv("../data/ucyn-a_enriched/localizations.txt", header=None, names=["localization"])
df["id"] = ids

# Plot localization distribution
label_counts = df["localization"].value_counts()
label_counts.plot(kind="bar")
plt.show()

## MuLocDeep


In [2]:
import pandas as pd
import numpy as np
from glob import glob
from Bio import SeqIO, AlignIO
from matplotlib import pyplot as plt
from collections import defaultdict

In [3]:
alignment = AlignIO.read("../data/ucyn-a_enriched/ucyn-a_enriched_cobalt.fa", "fasta")

# Re-key each aligned record: its new id is the description with the first
# space-separated token removed (everything after the first space), and the
# description is cleared.
for record in alignment:
    record.id = record.description.partition(" ")[2]
    record.description = ""

# Map re-keyed id -> gapped aligned sequence
alignment_dict = {record.id: record.seq for record in alignment}

In [4]:
# Parse predictions
attention_weights_str = "\n".join(
    open(f).read()
    for f in glob("../data/ucyn-a_enriched/mulocdeep-localization/*/attention_weights.txt")
)
sub_cellular_prediction_str = "\n".join(
    open(f).read()
    for f in glob(
        "../data/ucyn-a_enriched/mulocdeep-localization/*/sub_cellular_prediction.txt"
    )
)
sub_organellar_prediction_str = "\n".join(
    open(f).read()
    for f in glob(
        "../data/ucyn-a_enriched/mulocdeep-localization/*/sub_organellar_prediction.txt"
    )
)

attention_weights = {}
attention_weights_lines = attention_weights_str.split("\n")
while len(attention_weights_lines) > 0:
    line = attention_weights_lines.pop(0).strip()
    if line == "":
        continue
    if line.startswith(">"):
        prot_id = line[1:]
        line = attention_weights_lines.pop(0).strip()
        attention_weights[prot_id] = np.array([float(x) for x in line.split(" ")])

sub_cellular_predictions = {}
for line in sub_cellular_prediction_str.split("\n"):
    line = line.strip()
    if line == "":
        continue
    prot_id, pred = line.split(":", maxsplit=1)
    prot_id = prot_id.strip()[1:]
    sub_cellular_predictions[prot_id] = {}
    for p in pred.split("\t"):
        p = p.strip()
        if p == "":
            continue
        target, prob = p.split(":")
        if target == "prediction":
            continue
        prob = float(prob)
        sub_cellular_predictions[prot_id][target] = prob
    
sub_organellar_predictions = {}
for line in sub_organellar_prediction_str.split("\n"):
    line = line.strip()
    if line == "":
        continue
    prot_id, pred = line.split(":", maxsplit=1)
    prot_id = prot_id.strip()[1:]
    sub_organellar_predictions[prot_id] = {}
    for p in pred.split("\t"):
        p = p.strip()
        if p == "":
            continue
        target, prob = p.split(":")
        if target == "prediction":
            continue
        prob = float(prob)
        sub_organellar_predictions[prot_id][target] = prob

In [9]:
# Category set, read off the first protein's prediction dict (also referenced
# by later cells).
subcell_loc_categories = set(list(sub_cellular_predictions.values())[0].keys())

# Accumulate each category's predicted probability across all proteins
total_probs = defaultdict(lambda: 0)
for targets in sub_cellular_predictions.values():
    for target, prob in targets.items():
        total_probs[target] += prob

# Rescale in place so the bar heights sum to 1
total_p = sum(total_probs.values())
for target in total_probs:
    total_probs[target] /= total_p

plt.bar(total_probs.keys(), total_probs.values())
plt.xticks(rotation=30, ha='right')
plt.show()

In [63]:
# Plot the average attention over all aligned sequences. Each protein's
# per-residue attention is placed on the alignment columns of its ungapped
# residues, summed column-wise, and divided by the number of proteins used.

# All records of a multiple alignment share one length, so any record gives it
align_len = len(next(iter(alignment_dict.values())))
total_attn = np.zeros(align_len)
num_seqs = 0
missing = []

for seq_id, aligned_seq in alignment_dict.items():
    attn_vec = attention_weights.get(seq_id)
    if attn_vec is None:
        # No MuLocDeep output for this aligned sequence; it cannot contribute
        missing.append(seq_id)
        continue

    # Alignment column of each ungapped residue, in residue order
    residue_cols = np.array(
        [col for col, res in enumerate(str(aligned_seq)) if res != "-"], dtype=int
    )
    if len(attn_vec) > len(residue_cols):
        raise ValueError(
            f"{seq_id}: {len(attn_vec)} attention weights but only "
            f"{len(residue_cols)} residues in the alignment"
        )
    # Columns are distinct, so fancy-indexed += adds each weight exactly once;
    # a shorter attention vector covers only the leading residues
    total_attn[residue_cols[: len(attn_vec)]] += attn_vec
    num_seqs += 1

if missing:
    print(f"Skipped {len(missing)} aligned sequences with no attention weights")
if num_seqs == 0:
    raise ValueError("No aligned sequence has attention weights")

avg_attn = total_attn / num_seqs

plt.plot(avg_attn)
plt.title("Average attention over all aligned sequences")
plt.show()

In [66]:
def compute_attention_per_normalized_position(attention_weights, n_bins=None):
    """Sum attention weights over all proteins, keyed by normalized residue position.

    Each residue's position is normalized as ``index / sequence_length`` (so it
    lies in [0, 1)). Weights from every protein that share the same normalized
    position are added together.

    Note: with ``n_bins=None``, positions such as 0.0 or 0.5 collect a weight
    from every protein whose length allows them (every protein for 0.0, every
    even-length protein for 0.5). Most other positions get weights from only a
    few proteins, so the raw totals are spiky. Passing ``n_bins`` groups
    positions into equal-width bins instead.

    Args:
        attention_weights: mapping of protein id -> sequence of per-residue weights.
        n_bins: if given, accumulate into this many equal-width bins over
            [0, 1) and return every bin (empty bins have total 0).

    Returns:
        (positions, totals): ascending positions (bin left edges when
        ``n_bins`` is set) and the summed weight at each.

    Raises:
        ValueError: if ``n_bins`` is given and is less than 1.
    """
    if n_bins is not None and n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    attn = defaultdict(lambda: 0)
    for attn_vec in attention_weights.values():
        seq_len = len(attn_vec)
        for idx, w in enumerate(attn_vec):
            normalized_pos = idx / seq_len
            if n_bins is not None:
                # min() guards against float rounding pushing the last residue into bin n_bins
                normalized_pos = min(int(normalized_pos * n_bins), n_bins - 1) / n_bins
            attn[normalized_pos] += w

    if n_bins is not None:
        attn_x = [b / n_bins for b in range(n_bins)]
    else:
        attn_x = sorted(attn)
    attn_y = [attn[x] for x in attn_x]
    return attn_x, attn_y

In [85]:
attn_x, attn_y = compute_attention_per_normalized_position(attention_weights)

# Omit the first point (normalized position 0.0, i.e. residue index 0 of every
# protein) from the plot
plotted_x, plotted_y = attn_x[1:], attn_y[1:]
plt.plot(plotted_x, plotted_y)
plt.title("Total attention vs. normalized residue position")
plt.show()

In [82]:
# Ratio of N-terminal to non-N-terminal attention per predicted subcellular
# category. Proteins are grouped by their highest-probability category. A
# residue counts as N-terminal when its normalized position (index / length)
# is below n_term_frac.
n_term_frac = 0.2
subcell_loc_categories = set(list(sub_cellular_predictions.values())[0].keys())

n_term_attn_per_cat = {}
for cat in subcell_loc_categories:
    # Proteins predicted as `cat` that also have attention weights
    seq_ids = [
        k
        for k, v in sub_cellular_predictions.items()
        if v and max(v, key=v.get) == cat and k in attention_weights
    ]
    attn = {k: attention_weights[k] for k in seq_ids}
    attn_x, attn_y = compute_attention_per_normalized_position(attn)
    if sum(attn_y) == 0:
        continue
    # Split on the normalized position itself, not on the count of distinct
    # positions, so the window is the first n_term_frac of each sequence
    n_terminal_attention = sum(y for x, y in zip(attn_x, attn_y) if x < n_term_frac)
    non_n_terminal_attention = sum(y for x, y in zip(attn_x, attn_y) if x >= n_term_frac)
    if non_n_terminal_attention == 0:
        print(f"Skipping {cat}: no attention outside the N-terminal window")
        continue
    n_term_attn_per_cat[cat] = n_terminal_attention / non_n_terminal_attention

plt.bar(n_term_attn_per_cat.keys(), n_term_attn_per_cat.values())
plt.xticks(rotation=30, ha='right')
plt.title("N-terminal attention per subcellular localization category")
plt.show()

In [83]:
# Ratio of C-terminal to non-C-terminal attention per predicted subcellular
# category. A residue counts as C-terminal when its normalized position
# (index / length) is at least 1 - c_term_frac. The old index slice
# attn_y[-int(len * frac):] degenerated to the whole list (and a zero
# denominator) when int(len * frac) == 0, because -0 == 0.
c_term_frac = 0.2
c_term_start = 1 - c_term_frac

c_term_attn_per_cat = {}
for cat in subcell_loc_categories:
    # Proteins predicted as `cat` that also have attention weights
    seq_ids = [
        k
        for k, v in sub_cellular_predictions.items()
        if v and max(v, key=v.get) == cat and k in attention_weights
    ]
    attn = {k: attention_weights[k] for k in seq_ids}
    attn_x, attn_y = compute_attention_per_normalized_position(attn)
    if sum(attn_y) == 0:
        continue
    c_terminal_attention = sum(y for x, y in zip(attn_x, attn_y) if x >= c_term_start)
    non_c_terminal_attention = sum(y for x, y in zip(attn_x, attn_y) if x < c_term_start)
    if non_c_terminal_attention == 0:
        print(f"Skipping {cat}: no attention outside the C-terminal window")
        continue
    c_term_attn_per_cat[cat] = c_terminal_attention / non_c_terminal_attention

plt.bar(c_term_attn_per_cat.keys(), c_term_attn_per_cat.values())
plt.xticks(rotation=30, ha='right')
plt.title("C-terminal attention per subcellular localization category")
plt.show()