## Summary

**Parameters**

- `SEQUENCE_GENERATION_METHOD`
- `STRUCTURE_ID`
- `SLURM_ARRAY_TASK_ID`


**Notes:**

- `astar` method should be given >= 64G memory in order to generate 200k sequences.
- `astar` cannot be run in parallel.

**SLURM scripts**

```bash
STRUCTURE_ID="4unuA00" SEQUENCE_GENERATION_METHOD="astar" sbatch --mem 64G --time 72:00:00 ./scripts/run_notebook_gpu.sh \
    $(realpath ./notebooks/10_generate_protein_sequence.ipynb)
```


```bash
STRUCTURE_ID="4unuA00" SEQUENCE_GENERATION_METHOD="expectimax" sbatch --mem 32G --time 24:00:00 --array=1-10 ./scripts/run_notebook_gpu.sh \
    $(realpath ./notebooks/10_generate_protein_sequence.ipynb)
```

----

## Imports

In [1]:
import gzip
import heapq
import io
import os
import pickle
import shutil
import time
from pathlib import Path

from IPython.display import HTML
from IPython.display import display
from tqdm.notebook import tqdm

import kmtools.sci_tools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import proteinsolver
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch_geometric
from kmbio import PDB
from torch_geometric.data import Batch



## Properties

In [2]:
# Base name of this notebook; it is also used below as the name of the output directory.
NOTEBOOK_NAME = "generate_protein_sequences"

In [3]:
# Output directory for this notebook, resolved against the current working directory
# and created on first use.
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

PosixPath('/home/kimlab1/strokach/workspace/proteinsolver/notebooks/generate_protein_sequences')

In [4]:
# Identifier of the training run to use: it selects both `protein_train/<id>/model.py`
# and the entry of BEST_STATE_FILES that is loaded below.
UNIQUE_ID = "191f05de"

In [5]:
# Maps a training-run identifier (see UNIQUE_ID) to the state file that is loaded into the network.
# NOTE(review): presumably the best checkpoint of each run, going by the name — confirm.
BEST_STATE_FILES = {
    #
    "191f05de": "protein_train/191f05de/e53-s1952148-d93703104.state"
}

In [6]:
# Structure to design sequences for. The STRUCTURE_ID environment variable overrides the
# default; the commented-out lines are alternative defaults.
# structure_id = os.getenv("STRUCTURE_ID", "1n5uA03")
# structure_id = os.getenv("STRUCTURE_ID", "4z8jA00")
structure_id = os.getenv("STRUCTURE_ID", "4unuA00")
# structure_id = os.getenv("STRUCTURE_ID", "4beuA02")

In [7]:
# Input structure file: the STRUCTURE_FILE environment variable if set, otherwise
# `data/inputs/<structure_id>.pdb` two directory levels above NOTEBOOK_PATH.
STRUCTURE_FILE = Path(
    os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH.parent.parent / "data" / "inputs" / f"{structure_id}.pdb")
).resolve()
STRUCTURE_FILE

PosixPath('/home/kimlab1/strokach/workspace/proteinsolver/data/inputs/4unuA00.pdb')

In [8]:
# Per-structure presets; structures without a preset fall back to 0.15.
min_expected_proba_preset = {"1n5uA03": 0.15, "4unuA00": 0.25, "4beuA02": 0.25}

# The log of this value is passed as `cutoff` to the A* search further down.
MIN_EXPECTED_PROBA = min_expected_proba_preset.get(structure_id, 0.15)
MIN_EXPECTED_PROBA

0.25

In [9]:
# Search strategy, taken from the SEQUENCE_GENERATION_METHOD environment variable
# ("expectimax" when unset).
SEQUENCE_GENERATION_METHOD = os.getenv("SEQUENCE_GENERATION_METHOD", "expectimax")

# Fail fast on a method name that none of the cells below handle.
assert SEQUENCE_GENERATION_METHOD in ("astar", "expectimax", "root2expectimax", "root10expectimax")

In [10]:
# First output-file index for this process: SLURM array task id * 1000, or 0 outside an array job.
# NOTE(review): presumably spaced by 1000 so parallel array tasks do not write the same
# file names — confirm that no task is expected to produce more than 1000 files.
START_FILE_INDEX = int(os.getenv("SLURM_ARRAY_TASK_ID", 0)) * 1000
START_FILE_INDEX

0

In [11]:
# Show which GPUs are visible to this process. The variable is not set on CPU-only hosts
# (which the device selection below supports), so read it with a default instead of
# raising KeyError.
print(os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>"))

2


In [12]:
# Run on the first visible CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Load structure

In [13]:
# Load the structure file and wrap chain "A" of its first element in a new Structure
# named "<file name>A".
# NOTE(review): `structure_all[0]` is presumably the first model of the file — confirm.
structure_all = PDB.load(STRUCTURE_FILE)
structure = PDB.Structure(STRUCTURE_FILE.name + "A", structure_all[0].extract('A'))
# The extracted structure must contain exactly one chain.
assert len(list(structure.chains)) == 1

In [14]:
# Build a viewer for the extracted chain and show it as the cell output.
view = PDB.view_structure(structure)

view

_ColormakerRegistry()

NGLWidget()

## Load model

In [15]:
# Execute the model definition that belongs to the selected training run.
# NOTE(review): `Net` (used below) is not defined anywhere else in this notebook, so it
# presumably comes from this script — confirm.
%run protein_train/{UNIQUE_ID}/model.py

In [16]:
# Dimensions passed to `Net` when the network is constructed below.
num_features = 20
adj_input_size = 2
hidden_size = 128

# NOTE(review): the remaining settings are not read by any later cell of this notebook
# (`batch_size` is reassigned before use) — presumably kept for parity with training.
batch_size = 1
frac_present = frac_present_valid = 0.5
info_size = 1024

In [17]:
# State file of the selected training run; raises KeyError if UNIQUE_ID has no entry.
state_file = BEST_STATE_FILES[UNIQUE_ID]
state_file

'protein_train/191f05de/e53-s1952148-d93703104.state'

In [18]:
# Build the network. The input vocabulary has one more symbol than the output
# (`num_features + 1` vs `num_features`).
net = Net(
    x_input_size=num_features + 1,
    adj_input_size=adj_input_size,
    hidden_size=hidden_size,
    output_size=num_features,
)
# Restore the trained weights (mapped onto the selected device), then move the network
# to that device and switch it to evaluation mode.
net.load_state_dict(torch.load(state_file, map_location=device))
net = net.to(device).eval()

## Design pipeline

### Load protein sequence and geometry

In [19]:
# Extract the sequence and the pairwise data (row/column indices and distances, see the
# printed ProteinData) for chain "A" of the structure file.
pdata = proteinsolver.utils.extract_seq_and_adj(STRUCTURE_FILE, 'A')
print(pdata)

ProteinData(sequence='SALTQPPSASGSLGQSVTISCTGTSSDVGGYNYVSWYQQHAGKAPKVIIYEVNKRPSGVPDRFSGSKSGNTASLTVSGLQAEDEADYYCSSYEGSDNFVFGTGTKVTVL', row_index=array([  0,   0,   0, ..., 106, 106, 107]), col_index=array([  1,   2,   3, ..., 107, 108, 108]), distances=array([1.29767875, 3.76060342, 6.58874989, ..., 1.32708647, 4.12791238,
       1.32601478]))


In [20]:
# Reference (native) sequence of the structure; its length is used further down to
# scale the batch size.
sequence_ref = pdata.sequence
print(len(sequence_ref), sequence_ref)

109 SALTQPPSASGSLGQSVTISCTGTSSDVGGYNYVSWYQQHAGKAPKVIIYEVNKRPSGVPDRFSGSKSGNTASLTVSGLQAEDEADYYCSSYEGSDNFVFGTGTKVTVL


### Convert data to suitable format

In [21]:
# Convert the extracted protein data into the graph object consumed by the network.
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
# NOTE(review): the return value is discarded, so this relies on `Data.to` moving the
# tensors in place — confirm for the installed torch_geometric version.
data.to(device)

# Sanity check on the reference sequence: sum of the per-node values from `get_node_proba`.
# NOTE(review): the result is negative, so these are presumably log-probabilities — confirm.
proteinsolver.utils.get_node_proba(net, data.x, data.edge_index, data.edge_attr).sum().item()

-187.21701049804688

### Helper functions

In [22]:
@torch.no_grad()
def design_sequence(net, data, normalize_fn=None, num_categories=None):
    """Fill in every masked node of ``data.x`` by repeated sampling from ``net``.

    On each pass the network is evaluated once for the whole batch. Then, for every
    graph in the batch, the masked node whose most likely category has the highest
    probability is selected and a category is sampled for it from the network's
    distribution at that node.

    Args:
        net: Callable ``net(x, edge_index, edge_attr)`` returning one row of logits per node.
        data: Object with ``x``, ``edge_index`` and ``edge_attr`` attributes and,
            optionally, ``batch`` (graph index of every node). Without ``batch`` all
            nodes are treated as one graph. ``data.x`` itself is not modified.
        normalize_fn: Optional function applied to a node's probability vector before
            sampling (e.g. a root, to flatten the distribution). Its output is used only
            as sampling weights; the returned probabilities are the unmodified ones.
        num_categories: Value in ``x`` that marks a masked node. Defaults to the
            largest value found in ``data.x``.

    Returns:
        Tuple ``(x, x_proba)``: the completed copy of ``data.x`` and, for every node
        filled in here, the network probability of the chosen category (0 for nodes
        that were not masked on input).
    """
    x = data.x.clone()
    x_proba = torch.zeros_like(x).to(torch.float)
    if x.numel() == 0:
        return x, x_proba

    if num_categories is None:
        num_categories = data.x.max().item()

    # Use the `data` argument (not a notebook-level global) to find the graph of each node.
    batch = getattr(data, "batch", None)
    if batch is None:
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    batch_size = batch.max().item() + 1

    # Created on the same device as `x` so that it can be indexed with device-side masks.
    index_array_ref = torch.arange(x.size(0), device=x.device)
    mask_ref = x == num_categories
    while mask_ref.any():
        output = net(x, data.edge_index, data.edge_attr)
        output_proba_ref = torch.softmax(output, dim=1)
        output_proba_max_ref, _ = output_proba_ref.max(dim=1)

        for i in range(batch_size):
            mask = mask_ref & (batch == i)

            index_array = index_array_ref[mask]
            if index_array.numel() == 0:
                # This graph is already complete (graphs may differ in size).
                continue

            # Most confident masked node of this graph.
            max_probas = output_proba_max_ref[mask]
            max_proba_index = index_array[max_probas.argmax().item()]

            assert x[max_proba_index] == num_categories, x[max_proba_index]
            assert x_proba[max_proba_index] == 0, x_proba[max_proba_index]
            category_probas = output_proba_ref[max_proba_index]
            if normalize_fn is not None:
                category_probas_norm = normalize_fn(category_probas)
            else:
                category_probas_norm = category_probas
            chosen_category = torch.multinomial(category_probas_norm, 1).item()
            chosen_category_proba = category_probas[chosen_category]

            assert chosen_category != num_categories
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
        mask_ref = x == num_categories
        # Release the activations before the next forward pass.
        del output, output_proba_ref, output_proba_max_ref
    return x, x_proba

In [23]:
from dataclasses import dataclass
from dataclasses import field
from typing import Any


def load_heap_dump(heap_file):
    """Load a saved search heap, preferring the primary dump over its backup.

    The primary file is tried first, then the sibling file with the ``.pickle.bak``
    suffix. A file that exists but cannot be loaded is reported on stdout and skipped.

    Args:
        heap_file: ``pathlib.Path`` of the primary heap dump.

    Returns:
        The object stored in the first file that loads, or ``None`` if neither file
        exists or loads.
    """
    # NOTE: `torch.load` unpickles the file; only use this on dumps written by this notebook.
    candidates = (heap_file, heap_file.with_suffix(".pickle.bak"))
    for candidate in candidates:
        if not candidate.is_file():
            continue
        try:
            return torch.load(candidate)
        except Exception as e:
            print(f"Encountered error loading heap file '{candidate}': '{e}'.")
    return None


def update_heap_dump(heap_file, heap):
    """Save ``heap`` to ``heap_file``, keeping the previous dump as a backup.

    Args:
        heap_file: ``pathlib.Path`` of the primary heap dump.
        heap: Object to store with ``torch.save``.
    """
    backup_file = heap_file.with_suffix(".pickle.bak")
    try:
        # Preserve the last dump before it is overwritten.
        shutil.copy2(heap_file, backup_file)
    except FileNotFoundError:
        # First dump: there is nothing to back up yet.
        pass
    torch.save(heap, heap_file)


def _infer_device(net, fallback):
    """Return the device of ``net``'s first parameter, or ``fallback`` if it has none."""
    try:
        return next(net.parameters()).device
    except (AttributeError, StopIteration):
        return fallback


def get_descendents(net, x, x_proba, edge_index, edge_attr, cutoff, num_categories=20):
    """Expand a partially designed sequence by one position.

    The network is evaluated on ``x``. Among the masked positions, the one whose most
    likely category has the highest probability is selected, and one child is created
    for every category the network can output at that position.

    Args:
        net: Callable ``net(x, edge_index, edge_attr)`` returning one row of logits per
            node. It is evaluated on the device of its parameters, or on ``x``'s device
            if it has none.
        x: 1-D tensor of categories; ``num_categories`` marks a masked position.
        x_proba: Per-position log-probabilities; masked positions hold ``cutoff``.
        edge_index: Edge index tensor passed through to ``net``.
        edge_attr: Edge attribute tensor passed through to ``net``.
        cutoff: Placeholder log-probability stored in ``x_proba`` for masked positions.
        num_categories: Value in ``x`` that marks a masked position.

    Returns:
        List of ``(x_child, x_proba_child)`` tuples, one per category, where the
        selected position holds the category and the log of its probability. Empty if
        ``x`` has no masked position.
    """
    mask = (x == num_categories).cpu()
    if not mask.any():
        return []
    index_array = torch.arange(x.size(0))[mask]

    run_device = _infer_device(net, x.device)
    with torch.no_grad():
        output = net(x.to(run_device), edge_index.to(run_device), edge_attr.to(run_device)).cpu()
    proba = torch.softmax(output, dim=1)[mask]

    # Masked position with the most confident prediction.
    _, max_index = proba.max(dim=1)[0].max(dim=0)
    row_with_max_proba = proba[max_index]
    position = index_array[max_index].item()

    # Sanity checks: log-probabilities are never positive, and the selected position
    # must still be untouched.
    sum_log_prob = x_proba.sum()
    assert sum_log_prob.item() <= 0, x_proba
    assert x[position] == num_categories
    assert x_proba[position] == cutoff

    children = []
    for category, p in enumerate(row_with_max_proba):
        x_child = x.clone()
        x_proba_child = x_proba.clone()
        x_child[position] = category
        x_proba_child[position] = torch.log(p)
        children.append((x_child, x_proba_child))
    return children


def design_sequence_astar(
    net, data, cutoff, num_categories=20, max_results=5_000, max_heap_size=10_000_000, heap=None
):
    """Enumerate complete sequences in best-first order using a priority queue.

    Queue items are ordered by the negated sum of their per-position log-probabilities,
    where positions that are still masked count as ``cutoff``. Complete items are
    collected as results; incomplete ones are expanded with ``get_descendents``.

    Args:
        net: Network passed to ``get_descendents``.
        data: Graph with ``x`` (``num_categories`` marks masked positions),
            ``edge_index`` and ``edge_attr``.
        cutoff: Log-probability assumed for every masked position.
        num_categories: Value in ``x`` that marks a masked position.
        max_results: Stop after this many complete sequences.
        max_heap_size: When the queue grows beyond this size, only the best half of its
            items is kept.
        heap: Queue returned by a previous call, to continue that search. A new queue
            holding only ``data.x`` is created when ``None``.

    Returns:
        Tuple ``(results, heap)``. ``results`` is a list of
        ``(x, sum of probabilities, sum of log-probabilities)`` tuples; ``heap`` is the
        remaining queue, which is empty once the search space is exhausted.
    """
    if heap is None:
        x_proba_init = torch.ones_like(data.x).to(torch.float) * cutoff
        heap = [proteinsolver.utils.PrioritizedItem(0, data.x.cpu(), x_proba_init.cpu())]
    results = []

    # The context manager closes the progress bar even if the search is interrupted.
    with tqdm(total=max_results) as pbar:
        while heap and len(results) < max_results:
            item = heapq.heappop(heap)
            if not (item.x == num_categories).any():
                results.append((item.x.data, item.x_proba.exp().sum().item(), item.x_proba.sum().item()))
                pbar.update(1)
            else:
                children = get_descendents(
                    net, item.x, item.x_proba, data.edge_index, data.edge_attr, cutoff, num_categories=num_categories
                )
                for child_x, child_x_proba in children:
                    heapq.heappush(
                        heap,
                        proteinsolver.utils.PrioritizedItem(
                            -child_x_proba.sum().item(), child_x.to("cpu"), child_x_proba.to("cpu")
                        ),
                    )
            if len(heap) > max_heap_size:
                # A heap's backing list is not sorted, so slicing it would keep an
                # arbitrary subset. Sort first to keep the best half; a sorted list is
                # itself a valid heap.
                heap = sorted(heap)[: len(heap) // 2]
    return results, heap

### Run protein design using A* or expectimax search

In [24]:
# Rebuild the graph from the extracted protein data so the design runs start from a fresh object.
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
# NOTE(review): `data` is not reassigned, so later cells rely on `Data.to` moving the
# tensors in place — confirm for the installed torch_geometric version.
data.to(device)

Data(edge_attr=[4684, 2], edge_index=[2, 4684], x=[109])

In [25]:
# Lookup from category index to amino-acid letter, used below to decode designed sequences.
amino_acids = proteinsolver.utils.AMINO_ACIDS

In [26]:
def get_output_file(file_index):
    """Return the parquet path for output file number ``file_index``.

    The file name encodes the generation method, the stem of the structure file and
    the index; the file lives in this notebook's output directory.
    """
    file_name = f"designs-{SEQUENCE_GENERATION_METHOD}-{STRUCTURE_FILE.stem}-{file_index}.parquet"
    return NOTEBOOK_PATH / file_name

In [27]:
# Resume numbering after any output files that already exist for this method and
# structure, starting from this process's first index.
file_index = START_FILE_INDEX
while True:
    if not get_output_file(file_index).is_file():
        break
    file_index += 1

file_index

0

In [28]:
if SEQUENCE_GENERATION_METHOD == "astar":
    # Keep the original node values in `y` and start the search from a fully masked
    # input (20 is the mask value, matching `num_categories` below).
    data.y = data.x
    data.x = torch.ones_like(data.x) * 20

    heap_file = NOTEBOOK_PATH.joinpath(f"heap-{STRUCTURE_FILE.stem}.pickle")
    #     heap = load_heap_dump(heap_file)
    heap = None

    # Each pass continues the search from the returned heap and writes one output file.
    while True:
        results, heap = design_sequence_astar(
            net, data, cutoff=np.log(MIN_EXPECTED_PROBA), num_categories=20, max_results=20_000, heap=heap
        )
        print(len(heap))
        #         update_heap_dump(heap_file, heap)
        if results:
            # Decode the category indices into amino-acid letters; keep the two scores as-is.
            results = [("".join(amino_acids[i] for i in r[0]),) + r[1:] for r in results]
            df = pd.DataFrame(results, columns=["sequence", "probas_sum", "probas_log_sum"])
            table = pa.Table.from_pandas(df, preserve_index=False)
            pq.write_table(table, get_output_file(file_index))
            file_index += 1
        if not heap:
            # The search space is exhausted. Without this check the loop would keep
            # writing empty files forever.
            break

In [29]:
if SEQUENCE_GENERATION_METHOD.endswith("expectimax"):

    # Optional root applied to the sampling distribution; plain "expectimax" samples
    # from the network's probabilities unchanged.
    root = {"root2expectimax": 2, "root10expectimax": 10}.get(SEQUENCE_GENERATION_METHOD)

    print(SEQUENCE_GENERATION_METHOD, root)
    normalize_fn = (lambda proba: proba ** (1 / root)) if root is not None else None

    # Shrink the batch for longer sequences, but never below one graph: a batch size
    # of 0 would break the iteration count computed below.
    batch_size = max(1, int(512 * (92 / len(sequence_ref)) ** 1.5))
    print(f"batch_size: {batch_size}")

    # One batch of identical graphs whose nodes are all masked (20 is the mask value).
    data_batch = Batch.from_data_list([data.clone() for _ in range(batch_size)]).to(device)
    data_batch.x = torch.ones_like(data_batch.x) * 20

    # Runs until the kernel is interrupted or the job is stopped; each pass writes one
    # file with at least 5,000 designs.
    while True:
        results = []
        for _ in tqdm(range(int(np.ceil(5_000 / batch_size))), desc=str(file_index)):
            batch_values, batch_probas = design_sequence(
                net, data_batch, normalize_fn=normalize_fn, num_categories=20
            )
            for i in range(batch_size):
                graph_mask = data_batch.batch == i
                values = batch_values[graph_mask]
                probas = batch_probas[graph_mask]
                # Convert once to Python ints instead of indexing with 0-dim tensors.
                sequence = "".join(amino_acids[aa] for aa in values.tolist())
                probas_sum = probas.sum().item()
                probas_log_sum = probas.log().sum().item()
                results.append((sequence, probas_sum, probas_log_sum))
        df = pd.DataFrame(results, columns=["sequence", "probas_sum", "probas_log_sum"])
        table = pa.Table.from_pandas(df, preserve_index=False)
        pq.write_table(table, get_output_file(file_index))
        file_index += 1

expectimax None
batch_size: 397


HBox(children=(IntProgress(value=0, description='0', max=13, style=ProgressStyle(description_width='initial'))…

KeyboardInterrupt: 