In [1]:
import os
import os.path as osp
from io import StringIO

import py3Dmol
import torch
from Bio.PDB import Atom, Chain
from Bio.PDB import Model as PDBModel
from Bio.PDB import Residue, Structure
from Bio.PDB.mmcifio import MMCIFIO

import config
from model import Model

In [2]:
# Get paths from the latest run
log_dir = osp.join('..', 'logs')
latest_run_filename = sorted(os.listdir(log_dir))[-1]
latest_run_dir = osp.join(log_dir, latest_run_filename)
model_path = osp.join(latest_run_dir, 'model.pth')
test_dataset_path = osp.join(latest_run_dir, 'test_dataset.pt')

try:
    test_dataset = torch.load(test_dataset_path, weights_only=False)
except FileNotFoundError:
    print(f"Ошибка: Файл '{test_dataset_path}' не найден. Убедитесь, что вы запустили train.py.")

node_dim = test_dataset[0].num_node_features + 3
model = Model(node_dim=node_dim, edge_dim=1, hidden_dim=config.HIDDEN_DIM, num_timesteps=config.NUM_TIMESTEPS)

try:
    model.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu'))
except FileNotFoundError:
    print(f'Ошибка: Файл `{model_path}` не найден. Убедитесь, что вы запустили train.py и тренировка успешно завершилась.')

model = model.eval()

In [7]:
original_full_graph = test_dataset[0]

# Get lateral graphs out of data window
lateral_mask = ~original_full_graph.central_mask
lateral_graphs = original_full_graph.subgraph(lateral_mask)

# Count number of nodes to generate
central_backbone_mask = original_full_graph.backbone_mask & original_full_graph.central_mask
num_nodes_to_generate = central_backbone_mask.sum().item()

with torch.no_grad():
    generated_nodes, generated_pos, generated_edges = model.sample(
        lateral_graphs, num_nodes=num_nodes_to_generate
    )

original_centroid = original_full_graph.centroid
generated_pos = generated_pos + original_centroid

In [8]:
idx_to_atom = {0: 'C', 1: 'N', 2: 'O', 3: 'P'}


def graph_to_cif_string(atom_types, pos):
    structure = Structure.Structure('generated_structure')
    model = PDBModel.Model(0)
    chain = Chain.Chain('A')
    residue = Residue.Residue((' ', 1, ' '), 'UNK', ' ')

    for i, (atom_type_idx, coord) in enumerate(zip(atom_types, pos)):
        atom_name = idx_to_atom.get(atom_type_idx, 'X')
        unique_atom_name = f'{atom_name}{i+1}'
        atom = Atom.Atom(
            name=unique_atom_name, coord=coord, bfactor=0, occupancy=1.0, altloc=' ',
            fullname=unique_atom_name, serial_number=i+1, element=atom_name.strip()
        )
        residue.add(atom)

    chain.add(residue)
    model.add(chain)
    structure.add(model)

    # Сохраняем в буфер в памяти вместо файла
    cif_io = StringIO()
    io = MMCIFIO()
    io.set_structure(structure)
    io.save(cif_io)
    return cif_io.getvalue()


generated_atom_types = torch.argmax(generated_nodes, dim=1).cpu().numpy()
generated_cif_data = graph_to_cif_string(generated_atom_types, generated_pos)

original_central_graph = original_full_graph.subgraph(central_backbone_mask)
orig_atom_types = torch.argmax(original_central_graph.x, dim=1).numpy()
orig_pos = (original_central_graph.pos + original_full_graph.centroid).numpy()
original_cif_data = graph_to_cif_string(orig_atom_types, orig_pos)

In [9]:
view = py3Dmol.view(width=800, height=400, linked=False, viewergrid=(1, 2))

view.addModel(generated_cif_data, 'cif', viewer=(0, 0))
view.setStyle({'stick': {}}, viewer=(0, 0))
view.addLabel('Generated Structure', {'fontColor': 'black', 'backgroundColor': 'lightgray', 'backgroundOpacity': 0.8}, viewer=(0, 0))

view.addModel(original_cif_data, 'cif', viewer=(0, 1))
view.setStyle({'stick': {}}, viewer=(0, 1))
view.addLabel('Original Structure', {'fontColor': 'black', 'backgroundColor': 'lightgray', 'backgroundOpacity': 0.8}, viewer=(0, 1))

view.zoomTo()
view.show()

In [6]:
print(generated_cif_data)

data_generated_structure
#
loop_
_atom_site.group_PDB
_atom_site.id
_atom_site.type_symbol
_atom_site.label_atom_id
_atom_site.label_alt_id
_atom_site.label_comp_id
_atom_site.label_asym_id
_atom_site.label_entity_id
_atom_site.label_seq_id
_atom_site.pdbx_PDB_ins_code
_atom_site.Cartn_x
_atom_site.Cartn_y
_atom_site.Cartn_z
_atom_site.occupancy
_atom_site.B_iso_or_equiv
_atom_site.auth_seq_id
_atom_site.auth_asym_id
_atom_site.pdbx_PDB_model_num
ATOM 1  N N1  . UNK A ? 1 ? -1392.608  -9560.975 8126.702  1.0 0 1 A 1 
ATOM 2  N N2  . UNK A ? 1 ? -11702.601 5111.436  199.711   1.0 0 1 A 1 
ATOM 3  P P3  . UNK A ? 1 ? -1512.822  -4345.569 459.713   1.0 0 1 A 1 
ATOM 4  P P4  . UNK A ? 1 ? 3543.524   -2984.120 -3453.609 1.0 0 1 A 1 
ATOM 5  N N5  . UNK A ? 1 ? 4564.437   -317.562  -1169.019 1.0 0 1 A 1 
ATOM 6  O O6  . UNK A ? 1 ? 2471.104   6240.861  2551.557  1.0 0 1 A 1 
ATOM 7  O O7  . UNK A ? 1 ? 219.967    2695.073  1370.475  1.0 0 1 A 1 
ATOM 8  C C8  . UNK A ? 1 ? 8377.970   172.77

In [10]:
print(generated_cif_data)

data_generated_structure
#
loop_
_atom_site.group_PDB
_atom_site.id
_atom_site.type_symbol
_atom_site.label_atom_id
_atom_site.label_alt_id
_atom_site.label_comp_id
_atom_site.label_asym_id
_atom_site.label_entity_id
_atom_site.label_seq_id
_atom_site.pdbx_PDB_ins_code
_atom_site.Cartn_x
_atom_site.Cartn_y
_atom_site.Cartn_z
_atom_site.occupancy
_atom_site.B_iso_or_equiv
_atom_site.auth_seq_id
_atom_site.auth_asym_id
_atom_site.pdbx_PDB_model_num
ATOM 1  N N1  . UNK A ? 1 ? -4310.748 4374.514  5394.291  1.0 0 1 A 1 
ATOM 2  O O2  . UNK A ? 1 ? -3147.721 -3769.107 -1257.342 1.0 0 1 A 1 
ATOM 3  C C3  . UNK A ? 1 ? 8539.244  -3552.348 7377.309  1.0 0 1 A 1 
ATOM 4  P P4  . UNK A ? 1 ? 2613.244  -7986.999 708.703   1.0 0 1 A 1 
ATOM 5  N N5  . UNK A ? 1 ? 4803.815  -910.471  -104.078  1.0 0 1 A 1 
ATOM 6  P P6  . UNK A ? 1 ? -2455.930 6546.849  931.705   1.0 0 1 A 1 
ATOM 7  O O7  . UNK A ? 1 ? -3509.413 1304.592  1081.833  1.0 0 1 A 1 
ATOM 8  P P8  . UNK A ? 1 ? -3919.169 5357.927  -410