## Summary

In this notebook we load a network trained to reconstruct protein sequences and use this network to design sequences that fit the geometry of CATH domain [1n5uA03](http://www.cathdb.info/version/latest/domain/1n5uA03).

----

## Imports

In [None]:
import os
from pathlib import Path

from IPython.display import HTML
from IPython.display import display

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import proteinsolver
import torch
import torch_geometric
from kmbio import PDB
from kmtools import sci_tools

## Properties

In [None]:
NOTEBOOK_NAME = "protein_demo"

In [None]:
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

In [None]:
UNIQUE_ID = "191f05de"

In [None]:
BEST_STATE_FILES = {
    # Maps a training-run id (see UNIQUE_ID) to the checkpoint file to load for it.
    "191f05de": "protein_train/191f05de/e53-s1952148-d93703104.state"
}

In [None]:
# Input structure; the STRUCTURE_FILE environment variable, when set, takes
# precedence over the bundled default under <NOTEBOOK_PATH>/inputs.
_default_structure_file = NOTEBOOK_PATH / "inputs" / "1n5uA03.pdb"
STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", _default_structure_file)).resolve()
STRUCTURE_FILE

In [None]:
# Keep the original preference for the second GPU ("cuda:1"), but only when it
# actually exists. Otherwise fall back to the first GPU, then to the CPU, so the
# notebook also runs on single-GPU and CPU-only machines instead of failing
# with an invalid device ordinal when tensors are moved to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
torch.cuda.is_available()

## Load structure

In [None]:
# Load the full file, then wrap what `extract('A')` returns from its first
# element (presumably chain A of the first model - TODO confirm against kmbio)
# in a new structure of its own.
structure_all = PDB.load(STRUCTURE_FILE)
chain_a = structure_all[0].extract('A')
structure = PDB.Structure(STRUCTURE_FILE.name + "A", chain_a)

# The rest of the notebook expects exactly one chain here.
n_chains = len(list(structure.chains))
assert n_chains == 1

In [None]:
PDB.view_structure(structure)

## Load model

In [None]:
%run protein_train/{UNIQUE_ID}/model.py

In [None]:
# Hyper-parameters. `num_features`, `adj_input_size` and `hidden_size` are the
# ones used to construct the network in the cells below; the remaining values
# are not referenced by the cells visible in this notebook.
batch_size = 1
num_features = 20  # number of output classes of the network
adj_input_size = 2
hidden_size = 128
frac_present = 0.5
frac_present_valid = frac_present  # validation uses the same fraction
info_size = 1024

In [None]:
state_file = BEST_STATE_FILES[UNIQUE_ID]
state_file

In [None]:
# Build the network (`Net` is expected to come from the `%run .../model.py`
# cell above), restore the trained weights, and switch to inference mode on
# the selected device.
net = Net(
    x_input_size=num_features + 1,
    adj_input_size=adj_input_size,
    hidden_size=hidden_size,
    output_size=num_features,
)
state_dict = torch.load(state_file, map_location=device)
net.load_state_dict(state_dict)
net = net.eval().to(device)

## Helper functions

In [None]:
import io
import logging
from typing import List
from unittest.mock import patch

import weblogo._cli
from PIL import Image

logger = logging.getLogger(__name__)


class _BytesIO(io.BytesIO):
    def __init__(self):
        super().__init__()
        self.buffer = self


def make_weblogo(
    seqs: List[str],
    units: str = "bits",
    color_scheme: str = "charge",
    stacks_per_line: int = 60,
    format_="png",
    output_file=None,
):
    """Create a sequence logo by running the WebLogo command-line entry point in-process.

    The sequences are fed to ``weblogo._cli.main()`` through a patched
    ``sys.stdin`` and the rendered logo is captured from a patched ``sys.stdout``.

    Args:
        seqs: Protein sequences to build the logo from.
        units: Value forwarded to WebLogo's ``--units`` option.
        color_scheme: Value forwarded to WebLogo's ``--color-scheme`` option.
        stacks_per_line: Value forwarded to WebLogo's ``--stacks-per-line`` option.
        format_: Output format; one of "eps", "png", "png_print", "pdf", "jpeg",
            "svg" or "logodata".
        output_file: Optional destination for the raw logo bytes. May be a
            ``pathlib.Path`` or a plain string path.

    Returns:
        A PIL image for the "eps", "png", "png_print" and "jpeg" formats.
        ``None`` for the other formats, and also when WebLogo raises a
        ``RuntimeError`` (the error is logged and nothing is written).
    """
    assert format_ in ["eps", "png", "png_print", "pdf", "jpeg", "svg", "logodata"]

    weblogo_args = [
        "weblogo",
        f"--format={format_}",
        f"--units={units}",
        "--sequence-type=protein",
        f"--stacks-per-line={stacks_per_line}",
        f"--color-scheme={color_scheme}",
        "--scale-width=no",
        # argv entries are not parsed by a shell, so '--fineprint=""' would set the
        # fine print to two literal quote characters. An empty value blanks it.
        "--fineprint=",
        "--composition=none",
    ]
    fin = io.StringIO()
    _write_sequences(seqs, fin)
    fin.seek(0)

    with patch("sys.stdin", fin), patch("weblogo._cli.sys.argv", weblogo_args), patch(
        "sys.stdout", new_callable=_BytesIO
    ) as patch_out:
        try:
            weblogo._cli.main()
        except RuntimeError as e:
            logger.error("Failed to create WebLogo image because of error: '%s'.", str(e))
            return None
        # Only collect the output once WebLogo has finished successfully.
        img_data = patch_out.getvalue()

    if output_file:
        # Path(...) makes string paths work too; Path instances pass through unchanged.
        with Path(output_file).open("wb") as fout:
            fout.write(img_data)

    if format_ in ["eps", "png", "png_print", "jpeg"]:
        return Image.open(io.BytesIO(img_data))
    # The remaining formats (pdf, svg, logodata) are not opened as images.
    return None


def _write_sequences(seqs, fh):
    for i in range(len(seqs)):
        fh.write(f"> seq_{i}\n")
        fh.write(seqs[i] + "\n")

## Design pipeline

### Load protein sequence and geometry

In [None]:
# NOTE: this rebinds `structure` (set to a chain-A-only structure in the
# "Load structure" section) to the full structure loaded from the file.
structure = PDB.load(STRUCTURE_FILE)
# Sequence and adjacency information for chain 'A', as returned by the helper.
pdata = proteinsolver.utils.extract_seq_and_adj(structure, 'A')
print(pdata)
print(f"Protein sequence: '{pdata.sequence}'")
print(f"Number of amino acids: {len(pdata.sequence)}")

### Convert data to suitable format

In [None]:
# Convert the extracted record into the `data` object consumed by the network
# and apply the edge-attribute transform.
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
# The return value is discarded, so this relies on `.to()` updating `data`
# in place - NOTE(review): confirm for the installed torch_geometric version.
data.to(device)

# Smoke test: run the network on the reference sequence and show a single scalar.
proteinsolver.utils.get_node_outputs(net, data.x, data.edge_index, data.edge_attr).sum().item()

### Run protein design using A* search

<span style="color: red">Feel free to interrupt after several minutes by pressing <code style="color: red">I, I</code> or going to <code style="color: red">Kernel -> Interrupt Kernel</code>.</span>

In [None]:
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
data.to(device)

# Keep the reference residues as the target and start the search from an input
# in which every position holds index 20 (presumably the "unknown residue"
# token, given that the network takes num_features + 1 input classes - TODO confirm).
data.y = data.x
x_in = torch.full_like(data.x, 20)

# `results` is passed in (presumably filled in place), so whatever has been
# collected before a manual interrupt remains available to the cells below.
results = []
try:
    proteinsolver.utils.design_protein(
        net,
        x_in,
        data.edge_index,
        data.edge_attr,
        results=results,
        cutoff=np.log(0.15),
    )
except KeyboardInterrupt:
    # Deliberate: interrupting the kernel is the documented way to stop the search.
    pass

### Convert designs into a DataFrame

In [None]:
# For every result, copy `x` and `x_proba` to the CPU as NumPy arrays
# (`x` is decoded to amino acids and `x_proba` is summed as a log-probability below).
designs = [(r.x.data.cpu().numpy(), r.x_proba.data.cpu().numpy()) for r in results]

In [None]:
# Rebuild `data` without moving it to `device`: later cells call `.numpy()` on
# `data.x`, which requires a CPU tensor.
data = proteinsolver.datasets.protein.row_to_data(pdata)
# The edge-attribute transform is intentionally left disabled here; the adjacency
# plot further down reads `data.edge_attr[:, 0]` directly and labels it as a distance.
# data = proteinsolver.datasets.protein.transform_edge_attr(data)

# Reference sequence as a string, decoded through the AMINO_ACIDS lookup.
sequence_ref = "".join(proteinsolver.utils.AMINO_ACIDS[i] for i in data.x)

In [None]:
# One row per design: the residue-index array and the matching per-position score array.
df = pd.DataFrame(designs, columns=["seq_array", "seq_proba_array"])

In [None]:
# Decode each residue-index array into a sequence string via the AMINO_ACIDS lookup.
df["sequence"] = ["".join(proteinsolver.utils.AMINO_ACIDS[i] for i in seq_array) for seq_array in df["seq_array"]]

In [None]:
# Total score of each design: the sum of its per-position values
# (treated as log-probabilities, per the column name).
df["sum_log_prob"] = [seq_proba_array.sum() for seq_proba_array in df["seq_proba_array"]]

In [None]:
# Per-residue average: the total score divided by the sequence length.
df["avg_log_prob"] = df["sum_log_prob"] / df["sequence"].str.len()

In [None]:
# Fraction of positions at which each design matches the reference sequence.
# The reference array is built once, outside the loop, and is moved to the CPU
# first so that this also works when `data` lives on another device.
sequence_ref_array = data.x.detach().cpu().numpy()
df["seq_identity"] = [
    float((seq_array == sequence_ref_array).sum()) / len(sequence_ref_array)
    for seq_array in df["seq_array"]
]

In [None]:
# Store the current row labels as a column so they are kept after the sort below
# and can be written to the output file.
df["index"] = df.index

In [None]:
# Keep at most the 200,000 designs with the highest average log-probability.
df = df.sort_values("avg_log_prob", ascending=False).iloc[:200_000]

In [None]:
df.head(2)

In [None]:
# Export only the scalar/string columns (the array columns are dropped).
# Note that the file is tab-separated even though it carries a ".csv" extension.
columns_to_keep = ["index", "sequence", "sum_log_prob", "avg_log_prob", "seq_identity"]
df[columns_to_keep].to_csv(NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + ".csv"), sep="\t", index=False)

## Make plots

### Adjacency matrix

In [None]:
# Dense position-by-position matrix holding the reciprocal of the first edge
# attribute (shown on the colour bar as 1 / distance).
inverse_distance = 1 / data.edge_attr[:, 0]
adj = torch_geometric.utils.to_dense_adj(
    edge_index=data.edge_index,
    edge_attr=inverse_distance,
).squeeze()

with plt.rc_context(rc={"font.size": 18}):
    fg, ax = plt.subplots(figsize=(8 * 0.9, 6 * 0.9))
    out = ax.imshow(adj, cmap="Greys")
    # Both axes index residues, so they share the same label.
    ax.set_xlabel("Amino acid position")
    ax.set_ylabel("Amino acid position")
    ax.tick_params("both", labelsize=16)
    cb = fg.colorbar(out, ax=ax)
    cb.set_label("1 / distance (Å$^{-1}$)")

fg.tight_layout()
adjacency_svg = NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-adjacency.svg")
fg.savefig(adjacency_svg)

### Score distributions

In [None]:
# A fixed seed keeps the jitter, and therefore the saved figure, identical
# between runs of the notebook.
jitter_rng = np.random.RandomState(0)

with plt.rc_context(rc={"font.size": 18}):
    fg, ax = plt.subplots(figsize=(8 * 0.9, 6 * 0.9))
    # exp(average log-probability) gives the per-residue (geometric mean) probability.
    x = np.exp(df["avg_log_prob"])
    # Sequence identity takes few distinct values; a little Gaussian jitter
    # spreads the points so the hexagonal bins do not collapse onto lines.
    y = df["seq_identity"].values + 0.005 * jitter_rng.randn(len(df))
    out = ax.hexbin(x, y, gridsize=50, bins="log", cmap="Greys")
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(y.min(), y.max())
    ax.set_ylabel("Sequence identity to reference")
    ax.set_xlabel("Normalized sequence probability")
    ax.tick_params("both", labelsize=16)
    cb = fg.colorbar(out, ax=ax)
    cb.set_label("Number of sequences")

fg.tight_layout()
fg.savefig(NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-hexbin.svg"))

### Sequence logo

In [None]:
# Logo over all retained designs. With format_="svg" the helper returns None,
# so the only result of this cell is the file written to `output_file`.
make_weblogo(
    df["sequence"].values,
    units="probability",
    color_scheme="hydrophobicity",
    stacks_per_line=110,
    format_="svg",
    output_file=NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-weblogo-gen.svg"),
)

In [None]:
# Logo of the single reference sequence, rendered with the same settings as the
# designs logo so the two images can be compared directly.
make_weblogo(
    [sequence_ref],
    units="probability",
    color_scheme="hydrophobicity",
    stacks_per_line=110,
    format_="svg",
    output_file=NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-weblogo-wt.svg"),
)

In [None]:
NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-weblogo-wt.svg")

In [None]:
# Build the image URLs from the same pieces used when the logos were saved
# (NOTEBOOK_NAME directory + structure file stem), so the display keeps working
# when NOTEBOOK_NAME or STRUCTURE_FILE is changed. With the default settings
# the resulting paths are identical to the previously hard-coded ones.
weblogo_prefix = f"./{NOTEBOOK_NAME}/{STRUCTURE_FILE.stem}"
display(HTML(f"""\
<div style="overflow: scroll; width: 100%">
<p style="text-align: center; margin-bottom: 0px"><b>Wild-type</b></p>
<img src="{weblogo_prefix}-weblogo-wt.svg" style="width: 100%"/>

<p style="text-align: center; margin-bottom: 0px"><b>Designs</b></p>
<img src="{weblogo_prefix}-weblogo-gen.svg" style="width: 100%"/>
</div>
"""))