# Imports

In [None]:
# Show the GPU status visible to this kernel (IPython shell escape, not plain Python).
!nvidia-smi

In [None]:
import atexit
import csv
import os
import tempfile
import time
import warnings
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow
import pyarrow.parquet as pq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from scipy import stats
from torch import optim
from torch.utils.data import DataLoader
from torch_geometric.data import DataLoader
from torch_geometric.nn import ChebConv, EdgeConv, GATConv, GCNConv
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_

# from google.colab import files

In [None]:
# IPython autoreload: re-import edited modules (e.g. proteinsolver) before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import kmbio
import kmtools.sci_tools
from kmbio import PDB
from kmtools import structure_tools

import proteinsolver
import proteinsolver.datasets
from proteinsolver.utils.protein import AMINO_ACIDS

# Properties

In [None]:
# Sanity check: True if PyTorch can use a CUDA device.
torch.cuda.is_available()

In [None]:
# Name of the working directory (relative to the current directory) that holds this
# notebook's inputs (structures, saved results) and receives its outputs.
NOTEBOOK_NAME = "generate_protein_sequences"

In [None]:
# Absolute path of the working directory; created if it does not exist yet
# (its parent must already exist).
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

In [None]:
# Input structure file. Set the STRUCTURE_FILE environment variable to override the
# default; the commented-out lines are alternative example structures.
# STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "1n5uA03.pdb")).resolve()
# STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "4z8jA00.pdb")).resolve()
STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "4unuA00.pdb")).resolve()
# STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "4beuA02.pdb")).resolve()
STRUCTURE_FILE

In [None]:
# Load the full structure, then wrap chain 'A' of model 0 in a new Structure whose
# id is the file name with "A" appended.
structure_all = PDB.load(STRUCTURE_FILE)
structure = PDB.Structure(STRUCTURE_FILE.name + "A", structure_all[0].extract('A'))
# The extracted structure must contain exactly one chain.
assert len(list(structure.chains)) == 1

In [None]:
# Display the extracted single-chain structure.
PDB.view_structure(structure)

# Pipeline

## Load protein data

In [1]:
# Input record for one chain: its sequence plus residue pairs given as parallel
# arrays of residue indices (row_index[k], col_index[k]) and their distances[k].
# NOTE(review): the annotations say torch tensors, but extract_seq_and_adj fills the
# three array fields with numpy arrays taken from a DataFrame's ``.values``.
ProteinData = NamedTuple(
    "ProteinData",
    [
        ("sequence", str),
        ("row_index", torch.LongTensor),
        ("col_index", torch.LongTensor),
        ("distances", torch.FloatTensor),
    ],
)

NameError: name 'NamedTuple' is not defined

In [None]:
def get_interaction_dataset_wdistances(structure_file, model_id, chain_id, r_cutoff=12):
    """Extract one chain as a domain and compute its residue-level distances.

    Parameters
    ----------
    structure_file : path-like
        Structure file readable by ``PDB.load``.
    model_id : int
        Model to use, both for counting the chain's residues and for the
        domain definition.
    chain_id : str
        Chain to extract.
    r_cutoff : float, optional
        Distance cutoff passed to ``structure_tools.get_distances`` (default 12).

    Returns
    -------
    domain
        The domain extracted with a ``DomainDef`` spanning ``1..num_residues``
        of the chain.
    distances_core
        Table from ``get_distances(..., groupby="residue")``, with columns
        including ``residue_idx_1`` and ``residue_idx_2``.

    Raises
    ------
    AssertionError
        If any row has ``residue_idx_1 > residue_idx_2``.
    """
    structure = PDB.load(structure_file)
    # Count residues in the *requested* model. Previously model 0 was always used,
    # which disagreed with the ``model_id`` used for the DomainDef below.
    num_residues = len(list(structure[model_id][chain_id].residues))
    domain_def = structure_tools.DomainDef(model_id, chain_id, 1, num_residues)
    domain = structure_tools.extract_domain(structure, [domain_def])
    distances_core = structure_tools.get_distances(domain, r_cutoff, 0, groupby="residue")
    # Callers expect residue pairs in upper-triangular order (idx_1 <= idx_2).
    assert (distances_core["residue_idx_1"] <= distances_core["residue_idx_2"]).all()
    return domain, distances_core

In [None]:
def extract_seq_and_adj(structure_file, chain_id, *, model_id=0, r_cutoff=12):
    """Build a ``ProteinData`` record for one chain of ``structure_file``.

    Parameters
    ----------
    structure_file : path-like
        Structure file passed to ``get_interaction_dataset_wdistances``.
    chain_id : str
        Chain to extract.
    model_id : int, optional
        Model to extract (default 0).
    r_cutoff : float, optional
        Distance cutoff for residue pairs (default 12).

    Returns
    -------
    ProteinData
        The chain sequence plus the arrays ``row_index`` / ``col_index``
        (residue indices of each pair) and ``distances``, taken from the
        distance table via ``.values``.

    Raises
    ------
    AssertionError
        If a residue index in the distance table is out of range for the sequence.
    """
    domain, result_df = get_interaction_dataset_wdistances(
        structure_file, model_id, chain_id, r_cutoff=r_cutoff
    )
    domain_sequence = structure_tools.get_chain_sequence(domain)
    # Every residue index in the distance table must point into the sequence.
    # Vectorized Series.max() replaces Python's built-in max over a numpy array.
    assert result_df["residue_idx_1"].max() < len(domain_sequence)
    assert result_df["residue_idx_2"].max() < len(domain_sequence)
    return ProteinData(
        domain_sequence,
        result_df["residue_idx_1"].values,
        result_df["residue_idx_2"].values,
        result_df["distance"].values,
    )

In [None]:
# Build the input record for chain A, then show it and its sequence length.
pdata = extract_seq_and_adj(STRUCTURE_FILE, 'A')
print(pdata)
print(len(pdata.sequence))

## Load predicted sequences

In [None]:
from dataclasses import dataclass, field
from typing import Any


@dataclass(order=True)
class PrioritizedItem:
    """A candidate sequence ranked by its priority ``p``.

    Only ``p`` takes part in comparisons (ordering and equality); the payload
    fields ``x`` and ``x_proba`` are excluded. Presumably this mirrors the class
    used to write the ``.torch`` results file loaded later, so keep the class
    name and field names in sync for unpickling.
    """

    p: float = field(compare=True)
    # Candidate payload (excluded from comparisons).
    x: Any = field(compare=False)
    # Scores associated with ``x`` (excluded from comparisons).
    x_proba: Any = field(compare=False)

In [None]:
# Load the saved candidates for this structure (items exposing ``.x`` / ``.x_proba``,
# presumably PrioritizedItem objects).
# NOTE(security): torch.load unpickles arbitrary Python objects; only load result files
# you produced yourself.
# map_location="cpu" lets tensors saved on a GPU load on CPU-only machines; the next
# cell moves everything to the CPU anyway.
results = torch.load(NOTEBOOK_PATH / (STRUCTURE_FILE.stem + ".torch"), map_location="cpu")

In [None]:
# Convert each loaded item into an (x, x_proba) pair of numpy arrays.
# ``.detach()`` is the supported way to drop autograd history (``.data`` is legacy).
# NOTE: this rebinds ``results``; re-running this cell without first re-running the
# loading cell fails, because the items are already tuples.
results = [(r.x.detach().cpu().numpy(), r.x_proba.detach().cpu().numpy()) for r in results]

## Create dataframe

In [None]:
# Convert the ProteinData record into proteinsolver's graph data object.
data = proteinsolver.datasets.protein.row_to_data(pdata)
# data = proteinsolver.datasets.protein.transform_edge_attr(data)

# Reference (wild-type) sequence, decoded from the node labels.
sequence_ref = "".join([AMINO_ACIDS[aa_index] for aa_index in data.x.tolist()])

In [None]:
# One row per loaded candidate: its amino-acid index array and its score array.
df = pd.DataFrame(results, columns=["seq_array", "seq_proba_array"])

In [None]:
# Decode each amino-acid index array into a sequence string.
df["sequence"] = ["".join(AMINO_ACIDS[i] for i in seq_array) for seq_array in df["seq_array"]]

In [None]:
# Total score per sequence: the sum of its seq_proba_array.
# NOTE(review): the column name assumes these values are log-probabilities; confirm.
df["sum_log_prob"] = [seq_proba_array.sum() for seq_proba_array in df["seq_proba_array"]]

In [None]:
# Length-normalized score: total score divided by sequence length.
df["avg_log_prob"] = df["sum_log_prob"] / df["sequence"].str.len()

In [None]:
# Fraction of positions at which each generated sequence matches the reference.
# The reference array is loop-invariant, so it is converted to numpy once instead of
# once per sequence; .cpu() also makes this work if data.x lives on a GPU.
reference_seq_array = data.x.detach().cpu().numpy()
df["seq_identity"] = [
    float((seq_array == reference_seq_array).sum().item()) / len(reference_seq_array)
    for seq_array in df["seq_array"]
]

In [None]:
# Keep the original row position as an explicit column before the table is re-sorted.
df["index"] = df.index

In [None]:
# Rank by length-normalized score (highest first) and keep at most 200,000 rows.
df = df.sort_values("avg_log_prob", ascending=False).iloc[:200_000]

In [None]:
# Preview the two top-ranked rows.
df.head(2)

In [None]:
# Export the ranked table as a tab-separated file named after the structure,
# without the pandas index.
columns_to_keep = ["index", "sequence", "sum_log_prob", "avg_log_prob", "seq_identity"]
df.loc[:, columns_to_keep].to_csv(NOTEBOOK_PATH / f"{STRUCTURE_FILE.stem}.csv", index=False, sep="\t")

## Make plots

In [None]:
# Dense matrix of inverse distances built from the graph's edges; column 0 of
# edge_attr is used as the distance.
# NOTE(review): a zero in edge_attr[:, 0] (e.g. a residue paired with itself) would
# produce inf here; confirm that row_to_data excludes such edges.
adj = torch_geometric.utils.sparse_to_dense(data.edge_index, 1 / data.edge_attr[:, 0])

# Draw the matrix with larger fonts and save it as an SVG next to the other outputs.
with plt.rc_context(rc={"font.size": 18}):
    fg, ax = plt.subplots(figsize=(8 * 0.9, 6 * 0.9))
    out = ax.imshow(adj, cmap="Greys")
    ax.set_ylabel("Amino acid position")
    ax.set_xlabel("Amino acid position")
    ax.tick_params("both", labelsize=16)
    cb = fg.colorbar(out, ax=ax)
    cb.set_label("1 / distance (Å$^{-1}$)")
    
fg.tight_layout()
fg.savefig(NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-adjacency.svg"))

In [None]:
# Hexbin density of normalized sequence probability vs. identity to the reference,
# saved as an SVG next to the other outputs.
with plt.rc_context(rc={"font.size": 18}):
    fg, ax = plt.subplots(figsize=(8 * 0.9, 6 * 0.9))
    # If avg_log_prob is a mean log-probability, exp() gives the per-position
    # geometric-mean probability.
    x = np.exp(df["avg_log_prob"])
    # Small vertical jitter spreads out the discrete identity values. A fixed seed keeps
    # the saved figure reproducible (the unseeded global RNG changed it on every run).
    y = df["seq_identity"].values + 0.005 * np.random.RandomState(0).randn(len(df))
    out = ax.hexbin(x, y, gridsize=50, bins="log", cmap="Greys")
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(y.min(), y.max())
    ax.set_ylabel("Sequence identity to reference")
    ax.set_xlabel("Normalized sequence probability")
    ax.tick_params("both", labelsize=16)
    cb = fg.colorbar(out, ax=ax)
    cb.set_label("Number of sequences")
fg.tight_layout()
fg.savefig(NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-hexbin.svg"))

In [None]:
# Sequence logo of all retained generated sequences, saved as SVG.
kmtools.sci_tools.make_weblogo(
    df["sequence"].values,
    units="probability",
    color_scheme="hydrophobicity",
    stacks_per_line=110,
    format_="svg",
    output_file=NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-weblogo-gen.svg"),
)

In [None]:
# Sequence logo of the reference (wild-type) sequence alone, using the same settings
# as the generated-sequence logo so the two can be compared directly.
kmtools.sci_tools.make_weblogo(
    [sequence_ref],
    units="probability",
    color_scheme="hydrophobicity",
    stacks_per_line=110,
    format_="svg",
    output_file=NOTEBOOK_PATH.joinpath(STRUCTURE_FILE.stem + "-weblogo-wt.svg"),
)