# SCA (Statistical Coupling Analysis)
### KO K00370 (narG, nitrate reductase alpha subunit)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
from mpl_toolkits.axes_grid1 import make_axes_locatable
import tqdm.notebook as tqdm
from importlib import reload

from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import pdist

import mysca.core as mysca_core
import mysca.helpers as mysca_helpers
from mysca.mappings import DEFAULT_MAP
from mysca.io import load_msa
from mysca.preprocess import preprocess_msa
from mysca.core import run_sca, run_ica
from mysca.helpers import get_top_k_conserved_retained_positions
from mysca.helpers import get_conserved_rawseq_positions
from mysca.helpers import get_rawseq_positions_in_groups
from mysca.helpers import get_group_rawseq_positions_by_entry

In [None]:
# Notebook-wide random generator. With SEED = None the generator is seeded
# from fresh OS entropy, so draws from `rng` differ between runs; set an int
# to make them reproducible.
SEED = None
rng = np.random.default_rng(SEED)

In [None]:
# Input/output roots, relative to the notebook's working directory.
DATDIRBASE = "../data"
OUTDIRBASE = "../out"

# KEGG orthology group to analyse; swap the comment to switch.
KOKEY = "K00370"
# KOKEY = "K00362"

# Per-KO input directories, all rooted at <DATDIRBASE>/<KOKEY>.
_KO_DATADIR = f"{DATDIRBASE}/{KOKEY}"
SEQDIR = f"{_KO_DATADIR}/seqs"
MSADIR = f"{_KO_DATADIR}/msas"
STRUCTDIR = f"{_KO_DATADIR}/structures"

# Raw input sequences and the alignment to analyse (the filename suggests it
# contains a reference sequence).
SEQ_FPATH = f"{SEQDIR}/input.fasta"
MSA_FPATH = f"{MSADIR}/MSA_800_with_reference.aln-fasta"

# Optional reference sequence name; not referenced by later cells shown here.
# REFSEQ_NAME = "reference"
REFSEQ_NAME = None

In [None]:
# Per-KO output directory: <OUTDIRBASE>/sca_by_ko/<KOKEY>, created if missing.
OUTDIR = "/".join((OUTDIRBASE, "sca_by_ko", KOKEY))
os.makedirs(OUTDIR, exist_ok=True)

In [None]:
# # Convert letters to numbers (0-20 for gap + AAs)
# AA_TO_INT = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY-")}
# GAP = AA_TO_INT["-"]

# AA_LIST = np.sort([k for k in AA_TO_INT.keys() if k != "-"])

# NUM_AAS = len(AA_TO_INT)  # 20 amino acids + gap

# assert NUM_AAS == 21, "Should have 20 amino acids and 1 gap marker"
# assert GAP == 20, f"GAP index should equal 20. Got {GAP}"

In [None]:
# Load MSA.
# The symbol mapping is rebuilt here, rebinding the DEFAULT_MAP imported from
# mysca.mappings: the 20 standard amino acids, "-" as the gap symbol, and an
# extra-symbol list ["X"]. How SymMap encodes "X" is not visible here — TODO
# confirm (e.g. whether it is folded into the gap code).
from mysca.mappings import SymMap

DEFAULT_MAP = SymMap("ACDEFGHIKLMNPQRSTVWY", "-", ["X"])

msa_obj_orig, msa_orig, seqids_orig = load_msa(
    MSA_FPATH,
    format="fasta",
    mapping=DEFAULT_MAP,
    verbosity=1,
)

# msa_orig is indexed (sequence, position).
NUM_SEQS, NUM_POS = msa_orig.shape
print(f"Loaded MSA shape: {msa_orig.shape} (sequences x positions)")

In [None]:
# Filtering thresholds for preprocess_msa; each constant is passed to the
# keyword argument of the same (lower-case) name, whose exact semantics are
# defined in mysca.preprocess.
GAP_TRUNCATION_THRESH = 0.5
SEQUENCE_GAP_THRESH = 0.5
REFERENCE_ID = None
REFERENCE_SIMILARITY_THRESH = 0.0
SEQUENCE_SIMILARITY_THRESH = 1.0
POSITION_GAP_THRESH = 0.25

# Returned names suggest: filtered MSA, its one-hot form, kept ids, sequence
# weights, pre-truncation gap frequency per position, retained sequence and
# position indices, and reference results — confirm against preprocess_msa.
(
    msa,
    xmsa,
    seqids,
    weights,
    fi0_pretrunc,
    retained_sequences,
    retained_positions,
    ref_results,
) = preprocess_msa(
    msa_orig,
    seqids_orig,
    mapping=DEFAULT_MAP,
    gap_truncation_thresh=GAP_TRUNCATION_THRESH,
    sequence_gap_thresh=SEQUENCE_GAP_THRESH,
    reference_id=REFERENCE_ID,
    reference_similarity_thresh=REFERENCE_SIMILARITY_THRESH,
    sequence_similarity_thresh=SEQUENCE_SIMILARITY_THRESH,
    position_gap_thresh=POSITION_GAP_THRESH,
    verbosity=1,
)

In [None]:
# Per-position gap frequency before position truncation, with the
# POSITION_GAP_THRESH cutoff drawn as a dashed horizontal line.
fig, ax = plt.subplots(1, 1)
ax.plot(fi0_pretrunc, ".")
# axhline always spans the full axis width, unlike hlines drawn between the
# autoscaled limits captured before set_xlim changes them.
ax.axhline(POSITION_GAP_THRESH, linestyle="--", color="r", label="cutoff")

ax.legend()
# The x-axis indexes pre-truncation positions, so size it from the plotted
# data rather than from the (shorter) truncated `msa`.
ax.set_xlim(0, 10 + len(fi0_pretrunc))

ax.set_xlabel("position")
ax.set_ylabel("gap frequency")
ax.set_title("Gap frequency by position")

plt.show()

In [None]:
# Background amino-acid frequencies used by SCA. The listed values sum to
# 1.003, so the array form is renormalised to sum to exactly 1.
BACKGROUND_FREQ = {
    'A': 0.078, 'C': 0.020, 'D': 0.053, 'E': 0.063,
    'F': 0.039, 'G': 0.072, 'H': 0.023, 'I': 0.053,
    'K': 0.059, 'L': 0.091, 'M': 0.022, 'N': 0.043,
    'P': 0.052, 'Q': 0.042, 'R': 0.051, 'S': 0.071,
    'T': 0.058, 'V': 0.066, 'W': 0.014, 'Y': 0.033
}

# Same frequencies laid out by the integer code DEFAULT_MAP assigns to each
# amino acid.
BACKGROUND_FREQ_ARRAY = np.zeros(20)
for aa, freq in BACKGROUND_FREQ.items():
    BACKGROUND_FREQ_ARRAY[DEFAULT_MAP[aa]] = freq
BACKGROUND_FREQ_ARRAY /= BACKGROUND_FREQ_ARRAY.sum()

In [None]:
# Regularisation strength passed to run_sca.
LAM_REGULARIZATION = 0.03

# Full SCA on the preprocessed alignment with per-sequence weights; all
# intermediate results are requested (return_keys="all").
sca_results = run_sca(
    xmsa,
    weights,
    mapping=DEFAULT_MAP,
    background_map=BACKGROUND_FREQ,
    background_arr=BACKGROUND_FREQ_ARRAY,
    regularization=LAM_REGULARIZATION,
    return_keys="all",
    pbar=True,
    leave_pbar=True,
)

In [None]:
# Unpack the SCA outputs into top-level names. Note that the corrected
# positional matrix "Cij_corr" is bound to the shorter name `Cij`.
(
    fi0, fia, fijab,
    Dia, Di,
    Cijab_raw, Cij_raw,
    phi_ia,
    Cijab_corr, Cij,
) = (
    sca_results[key]
    for key in (
        "fi0", "fia", "fijab",
        "Dia", "Di",
        "Cijab_raw", "Cij_raw",
        "phi_ia",
        "Cijab_corr", "Cij_corr",
    )
)

In [None]:
# How many of the most conserved (highest D_i) retained positions to track,
# chosen per KO; an unknown KOKEY raises KeyError.
_N_TOP_CONSERVED_BY_KO = {"K00370": 10, "K00362": 5}
N_TOP_CONSERVED = _N_TOP_CONSERVED_BY_KO[KOKEY]

# Presumably the MSA positions of the top-k positions and their D_i values —
# see mysca.helpers.get_top_k_conserved_retained_positions.
topk_conserved_msa_pos, top_conserved_Di = get_top_k_conserved_retained_positions(
    retained_positions,
    Di,
    N_TOP_CONSERVED,
)

topk_conserved_msa_pos

In [None]:
# Plot conservation: D_i for every retained position (blue, at its MSA
# coordinate) with the top-k conserved positions overplotted in green.
fig, ax = plt.subplots(1, 1, figsize=(10, 4))

for xs, ys, colour, opacity in (
    (retained_positions, Di, "Blue", 0.2),
    (topk_conserved_msa_pos, top_conserved_Di, "Green", 0.5),
):
    ax.plot(xs, ys, "o", color=colour, alpha=opacity)

ax.set_xlim(0, NUM_POS)
ax.set_xlabel("Position")
ax.set_ylabel("Relative Entropy (KL Divergence, $D_i$)")
ax.set_title("Position-wise Conservation")

plt.show()

In [None]:
# Map MSA positions to raw sequence positions: take the per-sequence index
# table for the whole original MSA, then keep only the retained sequences
# (rows) and the retained positions (columns).
from mysca.helpers import get_rawseq_indices_of_msa

rawseq_idxs = (
    get_rawseq_indices_of_msa(msa_obj_orig)
    [retained_sequences, :]
    [:, retained_positions]
)

In [None]:
# Eigendecomposition of C_ij (raw and corrected). np.linalg.eigh returns
# eigenvalues in ascending order; both are reversed so index 0 is the
# largest, with eigenvector columns reordered to match.
def _eigh_descending(mat):
    """Return (eigenvalues, eigenvectors) of symmetric ``mat``, largest first."""
    vals, vecs = np.linalg.eigh(mat)
    return vals[::-1], vecs[:, ::-1]


evals_sca_raw, evecs_sca_raw = _eigh_descending(Cij_raw)
evals_sca, evecs_sca = _eigh_descending(Cij)

# Report the smallest and largest eigenvalue of each matrix.
print("      Eigenvalue spectrum of Cij (raw): "
      f"{evals_sca_raw.min():.3g}, {evals_sca_raw.max():.3f}")
print("Eigenvalue spectrum of Cij (corrected): "
      f"{evals_sca.min():.3g}, {evals_sca.max():.3f}")

In [None]:
# Heatmaps of the raw covariation matrix and the corrected SCA matrix over
# the retained positions, one figure each. Titles follow KOKEY rather than a
# hard-coded KO, so they stay correct when KOKEY is switched.
for mat, title in (
    (Cij_raw, f"Covariance Matrix for {KOKEY}"),
    (Cij, f"SCA Matrix for {KOKEY}"),
):
    fig, ax = plt.subplots(1, 1)
    sc = ax.imshow(mat, cmap="Blues", origin="lower", vmax=None)
    fig.colorbar(sc, label="Covariation")
    ax.set_xlabel("(Retained) Position i")
    ax.set_ylabel("(Retained) Position j")
    ax.set_title(title)

plt.show()

In [None]:

# Hierarchical clustering of positions: Ward linkage on Euclidean distances
# between rows of the corrected SCA matrix, cut into n_clusters flat
# clusters, drawn as a dendrogram beside the leaf-ordered matrix.
Z = linkage(pdist(Cij, metric="euclidean"), method="ward")

n_clusters = 10
# Flat cluster label (1..n_clusters) for every position.
clusters = fcluster(Z, t=n_clusters, criterion="maxclust")

# Leaf order of the dendrogram, used to reorder Cij's rows and columns.
dendro = dendrogram(Z, no_plot=True)
leaf_indices = dendro["leaves"]

cmap = plt.cm.turbo
cluster_colors = [to_hex(cmap(i)) for i in np.linspace(0, 1, n_clusters)]

# scipy calls link_color_func only for links, i.e. non-singleton cluster ids
# n..2n-2 (n = number of leaves), never with a leaf index, so testing
# `link_idx < n` never selects a colour. Instead record, for every link, the
# flat cluster shared by all leaves beneath it (0 when they are mixed).
_n_leaves = len(clusters)
_link_cluster = {}
for _k, _children in enumerate(Z[:, :2].astype(int)):
    _child_labels = [
        clusters[c] if c < _n_leaves else _link_cluster[c] for c in _children
    ]
    _same = _child_labels[0] == _child_labels[1]
    _link_cluster[_n_leaves + _k] = _child_labels[0] if _same else 0


def color_func(link_idx):
    """Return the hex colour for dendrogram link ``link_idx``.

    Links whose leaves all lie in one flat cluster get that cluster's colour;
    links that join different flat clusters are drawn black.
    """
    label = _link_cluster.get(link_idx, 0)
    if label > 0:
        return cluster_colors[label - 1]
    return "#000000"


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 6),
                               gridspec_kw={"width_ratios": [0.2, 1]})

dendrogram(
    Z,
    orientation="left",
    ax=ax1,
    link_color_func=color_func,
    above_threshold_color="k",
)

ax1.set_ylabel("Position", fontsize="x-large")
ax1.set_xticks([])
ax1.set_yticks([])

# Cij with rows and columns permuted into dendrogram leaf order.
rearranged_data = Cij[leaf_indices][:, leaf_indices]
im = ax2.imshow(
    rearranged_data,
    aspect="auto",
    cmap="Blues",
    interpolation="nearest",
    origin="lower",
)

# Dashed lines wherever consecutive leaves change flat cluster.
boundaries = np.where(np.diff(clusters[leaf_indices]))[0]
for b in boundaries:
    ax2.axhline(b + 0.5, color="black", linestyle="--")
    ax2.axvline(b + 0.5, color="black", linestyle="--")

ax2.set_title("Clustering of Positions", fontsize="x-large")
ax2.set_xlabel("Position", fontsize="x-large")
ax2.set_xticks([])
ax2.set_yticks([])

plt.tight_layout()
plt.show()

In [None]:
# Raw input sequences keyed by record id, displayed for inspection (not
# referenced by the later cells shown here).
from Bio import SeqIO

fasta_fpath = SEQ_FPATH
fastaseqs = {str(record.id): record for record in SeqIO.parse(fasta_fpath, "fasta")}
fastaseqs

In [None]:
# Raw-sequence indices of the top-k conserved MSA positions, one entry per
# retained sequence. The structure-mapping cell treats a -1 entry as "this
# sequence lacks that conserved position".
conserved_aa_idxs = get_conserved_rawseq_positions(
    msa_obj_orig,
    retained_sequences,
    topk_conserved_msa_pos,
)
conserved_aa_idxs

In [None]:
# Load PDB structures if available, and record the C-alpha coordinates of the
# top-k conserved residues for each retained sequence that has one.
from mysca.io import get_residue_sequence_from_pdb_structure
from mysca.io import load_pdb_structure

pdb_mappings = {}          # entry id -> array of CA coords, one row per conserved position
missing_pdb_entries = []   # retained entries with no <STRUCTDIR>/<id>.pdb file
nan_filler = np.array([np.nan, np.nan, np.nan])
# `entry_id` / `seq_row` avoid shadowing the builtin `id` and avoid reusing
# one loop name for both the outer loop and the inner comprehension.
for seq_row, seqidx in enumerate(retained_sequences):
    entry_id = msa_obj_orig[int(seqidx)].id
    conserved_positions = conserved_aa_idxs[seq_row]
    pdbfpath = f"{STRUCTDIR}/{entry_id}.pdb"
    if not os.path.isfile(pdbfpath):
        missing_pdb_entries.append(entry_id)
        continue
    if -1 in conserved_positions:
        print(f"Entry {entry_id} does not contain all conserved positions.")
        continue
    structure = load_pdb_structure(pdbfpath, id=entry_id, quiet=True)
    residues = get_residue_sequence_from_pdb_structure(structure)
    # Assumes raw-sequence indices index directly into the structure's
    # residue list — TODO confirm numbering agrees. Any other negative index
    # gets NaN coordinates instead of wrapping around from the end.
    pdb_mappings[entry_id] = np.array([
        residues[pos]["CA"].coord if pos >= 0 else nan_filler
        for pos in conserved_positions
    ])


In [None]:
# Display the entry-id -> conserved-residue CA-coordinate mapping.
pdb_mappings

In [None]:
# Compute pairwise distance matrix for conserved positions: one row per
# structure (sorted entry-id order), one column per pair of conserved
# positions in scipy's condensed pdist ordering.
ncombs = N_TOP_CONSERVED * (N_TOP_CONSERVED - 1) // 2
all_pdists = np.full([len(pdb_mappings), ncombs], np.nan)

for row, entry_id in enumerate(sorted(pdb_mappings)):
    all_pdists[row] = pdist(pdb_mappings[entry_id], metric="euclidean")

In [None]:
# Heatmap of raw pairwise C-alpha distances (rows: structures, columns:
# conserved-residue pairs) with a colourbar sized to the image.
# Row label follows KOKEY instead of always saying "narG".
_GENE_BY_KO = {"K00370": "narG", "K00362": "nirB"}

fig, ax = plt.subplots(1, 1)

sc = ax.imshow(all_pdists, cmap="plasma")
ax.set_ylabel(f"{_GENE_BY_KO.get(KOKEY, KOKEY)} variant")
ax.set_xlabel("pairwise distance")
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
cbar = fig.colorbar(sc, cax=cax)
cbar.ax.set_ylabel("Distance (Angstroms)")

plt.show()

In [None]:
from sklearn.decomposition import PCA

# Z-score every residue-pair distance column across structures, show the
# result as a heatmap, then run PCA over structures to see how many
# directions capture the variation in conserved-site geometry.
_GENE_BY_KO = {"K00370": "narG", "K00362": "nirB"}

_pdist_std = all_pdists.std(0)
# A pair whose distance is identical in every structure has zero spread;
# dividing by it yields NaN, which PCA rejects. Such columns become all
# zeros instead (centred, unscaled).
_pdist_std[_pdist_std == 0] = 1.0
all_pdists_centered = (all_pdists - all_pdists.mean(0)) / _pdist_std

fig, ax = plt.subplots(1, 1)

sc = ax.imshow(all_pdists_centered, cmap="plasma")
ax.set_ylabel(f"{_GENE_BY_KO.get(KOKEY, KOKEY)} variant")
ax.set_xlabel("pairwise distance")
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
cbar = fig.colorbar(sc, cax=cax)
cbar.ax.set_ylabel("Distance (Normalized)")

_n_samples, _n_features = all_pdists_centered.shape
if _n_samples < 2:
    print(f"Skipping PCA: need at least 2 structures, got {_n_samples}.")
else:
    # sklearn requires n_components <= min(n_samples, n_features); capping
    # only by the number of pairs fails with fewer than 20 structures.
    pca = PCA(n_components=min(20, _n_samples, _n_features))
    pca.fit(all_pdists_centered)
    print(pca.explained_variance_ratio_)
    print(pca.singular_values_)
    data_pca = pca.transform(all_pdists_centered)

    # Structures projected onto the first two principal components.
    if data_pca.shape[1] >= 2:
        fig, ax = plt.subplots(1, 1)
        ax.plot(data_pca[:, 0], data_pca[:, 1], ".")
        ax.set_xlabel("PC 1")
        ax.set_ylabel("PC 2")

    # Cumulative explained variance vs. number of components.
    fig, ax = plt.subplots(1, 1)
    ax.plot(
        1 + np.arange(len(pca.explained_variance_ratio_)),
        np.cumsum(pca.explained_variance_ratio_),
    )
    ax.set_xlabel("Number of components")
    ax.set_ylabel("Cumulative explained variance ratio")

plt.show()

## Bootstrapping (null model from column-shuffled alignments)

In [None]:
niters = 10  # number of shuffled replicates for the null model


def shuffle_columns(m, rng=None):
    """Return a copy of the 2-D array ``m`` with each column permuted independently.

    Every column keeps its multiset of values, but the row order is
    randomised separately per column, breaking correlations between columns.

    Parameters
    ----------
    m : 2-D array
    rng : None, int, or numpy Generator
        Anything accepted by ``np.random.default_rng``.
    """
    generator = np.random.default_rng(rng)
    n_rows, n_cols = m.shape
    # argsort of i.i.d. uniforms gives an independent random permutation of
    # the row indices for every column.
    perm = np.argsort(generator.random((n_rows, n_cols)), axis=0)
    return np.take_along_axis(m, perm, axis=0)


# Null model: recompute the corrected SCA matrix on alignments whose columns
# were each shuffled independently, which keeps every column's multiset of
# symbols but destroys covariation between positions. The sequence weights
# of the real alignment are reused. Results are cached to disk; set
# DO_SHUFFLING = False to reload the cached array instead.
DO_SHUFFLING = True
SHUFFLE_SEED = 13243
shuffling_saveas = f"{OUTDIR}/shuffled_cijs_corrected.npy"


def _shuffled_corrected_cij(shuffle_rng):
    """Shuffle the alignment's columns once and return the corrected Cij."""
    shuffled = shuffle_columns(msa, rng=shuffle_rng)
    # One-hot over 21 codes, dropping the last one — assumes codes 0..20
    # with the gap as code 20, matching the 20-entry background array
    # (TODO confirm against SymMap).
    one_hot = np.eye(21, dtype=bool)[shuffled][:, :, :-1]
    result = run_sca(
        one_hot, weights,
        background_map=BACKGROUND_FREQ,
        mapping=DEFAULT_MAP,
        background_arr=BACKGROUND_FREQ_ARRAY,
        regularization=LAM_REGULARIZATION,
        return_keys=["Cij_corr"],
        pbar=True,
        leave_pbar=False,
    )
    return result["Cij_corr"]


if not DO_SHUFFLING:
    cijs_shuffled = np.load(shuffling_saveas)
else:
    rng_shuffler = np.random.default_rng(SHUFFLE_SEED)
    cijs_shuffled = np.full([niters, *Cij.shape], np.nan)
    for iteridx in tqdm.trange(niters):
        cijs_shuffled[iteridx] = _shuffled_corrected_cij(rng_shuffler)
    np.save(shuffling_saveas, cijs_shuffled)

In [None]:
# Eigenvalues of every shuffled Cij, one row per replicate, sorted largest
# first (eigvalsh returns them ascending).
evals_shuff = np.full((len(cijs_shuffled), *evals_sca.shape), np.nan)
for spectrum_row, shuffled_cij in zip(evals_shuff, cijs_shuffled):
    spectrum_row[:] = np.linalg.eigvalsh(shuffled_cij)[::-1]

In [None]:
# Rank-ordered eigenvalue spectra: each shuffled replicate in its own colour,
# the real data in black on top.
fig, ax = plt.subplots(1, 1)

for spectrum in evals_shuff:
    ax.plot(np.arange(1, len(spectrum) + 1), spectrum, ".", markersize=3)

ax.plot(np.arange(1, len(evals_sca) + 1), evals_sca, "k.", markersize=2)

plt.show()

In [None]:
fig, ax = plt.subplots(1, 1)

# Histogram of data eigenvalues
counts, bins, patches = ax.hist(
    evals_sca, bins=100, color="black", alpha=0.8, log=True, label="Data"
)

# Cutoff: mean + 2 std of the second-largest eigenvalue across the shuffled
# replicates (evals_shuff rows are sorted largest first).
cutoff = np.mean(evals_shuff[:, 1]) + 2 * np.std(evals_shuff[:, 1])  # See SI G of [1]
print("significant eigenvalue cutoff:", cutoff)

# evals_sca is sorted largest first, so the significant modes are the first
# kstar eigenvalues / eigenvector columns.
kstar = np.sum(evals_sca > cutoff)
sig_evals_sca = evals_sca[:kstar]
sig_evecs_sca = evecs_sca[:, :kstar]
print(f"{kstar} significant eigenvalues:", sig_evals_sca)

# Null distribution, averaged per replicate. Divide by the number of
# replicates actually present: when they are reloaded from disk
# (DO_SHUFFLING = False) that count need not equal `niters`.
n_replicates = len(evals_shuff)
bin_centers = 0.5 * (bins[1:] + bins[:-1])
h, bin_edges = np.histogram(evals_shuff.flatten(), bins=bins)
ax.axvline(cutoff, 0, 1, linestyle="--", color="grey")
ax.plot(bin_centers, h / n_replicates, color="red", lw=1.5, label="Null")

ax.legend()
ax.set_xlabel("$\\lambda$")
ax.set_ylabel("Count")
ax.set_title("Spectral decomposition")

plt.show()

In [None]:
# Shape of the significant-eigenvector matrix: (positions, kstar), since the
# previous cell keeps the first kstar columns of evecs_sca.
sig_evecs_sca.shape

In [None]:
# ICA rotation of the significant eigenvectors. run_ica receives them with
# modes as rows; the returned w_ica (presumably a kstar x kstar unmixing
# matrix — confirm in mysca.core.run_ica) rotates them so that each column
# of v_ica is one independent component over positions.
w_ica = run_ica(sig_evecs_sca.T)
v_ica = np.matmul(sig_evecs_sca, w_ica.T)
v_ica.shape

In [None]:
# Scale each IC (column) to unit Euclidean norm, then fix its sign so that
# the entry with the largest magnitude is positive.
v_ica_normalized = v_ica / np.sqrt(np.square(v_ica).sum(axis=0))
_peak_rows = np.argmax(np.abs(v_ica_normalized), axis=0)
_peak_vals = v_ica_normalized[_peak_rows, np.arange(v_ica_normalized.shape[1])]
v_ica_normalized *= np.where(_peak_vals < 0, -1, 1)

In [None]:
# Get groups: one per IC, made of the positions whose weight on that IC is
# at or above the IC's 95th percentile (roughly its top 5% of positions).
groups = [
    np.flatnonzero(ic >= np.percentile(ic, 95))
    for ic in v_ica_normalized.T
]


In [None]:
# Save groups in MSA coordinates: each group's member indices (rows of
# v_ica_normalized, i.e. columns of the preprocessed `msa`) go to
# groups/group_<i>_msapos.npy under OUTDIR.
subdir = f"{OUTDIR}/groups"
os.makedirs(subdir, exist_ok=True)

for group_idx, members in enumerate(groups):
    np.save(f"{subdir}/group_{group_idx}_msapos.npy", members)


In [None]:
# IC 1 vs IC 2 for every position (faint black) with groups 0 and 1 overlaid.
fig, ax = plt.subplots(1, 1)

ic_x = v_ica_normalized[:, 0]
ic_y = v_ica_normalized[:, 1]
sc = ax.scatter(ic_x, ic_y, c="k", alpha=0.2, edgecolor="k")

for gidx in (0, 1):
    members = groups[gidx]
    ax.scatter(
        ic_x[members], ic_y[members],
        alpha=1, edgecolor="k", label=f"Group {gidx}",
    )
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# IC 1 vs IC 3 for every position (faint black) with up to the first four
# groups overlaid.
fig, ax = plt.subplots(1, 1)

sc = ax.scatter(
    v_ica_normalized[:, 0], v_ica_normalized[:, 2],
    c="k", alpha=0.2, edgecolor="k",
)

# With fewer than four significant ICs there are fewer than four groups;
# indexing groups[3] unconditionally would raise IndexError.
for gidx in range(min(4, len(groups))):
    g = groups[gidx]
    ax.scatter(
        v_ica_normalized[g, 0], v_ica_normalized[g, 2],
        alpha=1, edgecolor="k", label=f"Group {gidx}",
    )
ax.set_xlabel("IC 1")
ax.set_ylabel("IC 3")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# 3-D scatter of the first three ICs: every position in faint black, each
# group overlaid in its own colour.
fig = plt.figure(figsize=(12, 5))
ax = fig.add_subplot(111, projection="3d")

ic1, ic2, ic3 = (v_ica_normalized[:, k] for k in range(3))
sc = ax.scatter(ic1, ic2, ic3, c="k", alpha=0.2, edgecolor="k")

for members in groups:
    ax.scatter(ic1[members], ic2[members], ic3[members], alpha=1, edgecolor="k")

ax.set_xlabel("IC 1")
ax.set_ylabel("IC 2")
ax.set_zlabel("IC 3")
ax.set_title("ICA")

ax.view_init(elev=30, azim=40)  # elev: tilt; azim: rotation about z

plt.show()


In [None]:

# 3-D scatter of IC 1, IC 2 and IC 4 (column index 3): every position in
# faint black, each group overlaid in its own colour.
fig = plt.figure(figsize=(12, 5))
ax = fig.add_subplot(111, projection="3d")
sc = ax.scatter(
    v_ica_normalized[:, 0], v_ica_normalized[:, 1], v_ica_normalized[:, 3],
    c="k", alpha=0.2, edgecolor="k",
)

for g in groups:
    ax.scatter(
        v_ica_normalized[g, 0], v_ica_normalized[g, 1], v_ica_normalized[g, 3],
        alpha=1, edgecolor="k",
    )

ax.set_xlabel("IC 1")
ax.set_ylabel("IC 2")
# The z data is column 3, i.e. the fourth IC, so label it "IC 4".
ax.set_zlabel("IC 4")
ax.set_title("ICA")

ax.view_init(elev=30, azim=40)  # elev: tilt; azim: rotation about z

plt.show()


In [None]:
# Save residue groups by sequence: for every group and every retained entry
# that has a PDB file, write that entry's group positions (as returned by
# get_group_rawseq_positions_by_entry) to
# sca_groups/group_<g>/group_<g>_<entry id>.npy under OUTDIR.
group_rawseq_positions = get_rawseq_positions_in_groups(
    rawseq_idxs, groups
)

group_rawseq_positions_by_entry = get_group_rawseq_positions_by_entry(
    msa_obj_orig, retained_sequences, groups, group_rawseq_positions
)

# Whether an entry has a structure does not depend on the group, so check
# each file once up front. `entry_id` avoids shadowing the builtin `id`.
entry_ids_with_pdb = []
for seqidx in retained_sequences:
    entry_id = msa_obj_orig[int(seqidx)].id
    if os.path.isfile(f"{STRUCTDIR}/{entry_id}.pdb"):
        entry_ids_with_pdb.append(entry_id)

for groupidx in range(len(groups)):
    subdir = f"{OUTDIR}/sca_groups/group_{groupidx}"
    # The group directory is created even if no entry has a structure.
    os.makedirs(subdir, exist_ok=True)
    for entry_id in entry_ids_with_pdb:
        group_arr = group_rawseq_positions_by_entry[entry_id][groupidx]
        np.save(f"{subdir}/group_{groupidx}_{entry_id}.npy", group_arr)


In [None]:
# Inspect the per-group raw-sequence positions of the first retained entry.
group_rawseq_positions_by_entry[msa_obj_orig[int(retained_sequences[0])].id]

In [None]:
# !for f in data/K00370/structures/Soil*.pdb; do s=$(basename $f); s=${s/.pdb/}; echo $s; sh scripts/run_pymol_sca.sh $s; done