### How to train and use the model for contact map prediction

Install the dependencies listed in `requirements.txt` before running this notebook, and install BLAST+ with the following command:

`apt install ncbi-blast+`

To create a FASTA file from which a local BLAST database can be built, run the following code (set `pdb_files_folder` to the directory containing your PDB files):

In [3]:
import os
import re
import tqdm
import glob

from Bio.PDB import PDBParser
from dataset.utils import residues_id_map

In [None]:
def build_fastas(id, seq, fasta_file):
    """Append one sequence record to a FASTA file.

    Args:
        id: Record identifier. Every non-alphanumeric character is replaced
            with ``_`` so the identifier is a single whitespace-free token.
            (The name shadows the builtin but is kept for caller compatibility.)
        seq: Sequence string, written on a single line.
        fasta_file: Path of the FASTA file. It is opened in append mode, so
            records accumulate across calls.
    """
    name = re.sub(r"[^a-zA-Z0-9]", "_", id)

    # A FASTA defline is ">identifier" with no whitespace after '>' and no
    # trailing blanks; stray spaces can make parsers see an empty identifier.
    with open(fasta_file, "a", encoding="utf-8") as f:
        f.write(f">{name}\n")
        f.write(seq)
        # Terminate the sequence line and leave a blank separator line.
        f.write("\n\n")

parser = PDBParser()
pdb_files_folder = "path"  # directory containing the input .pdb files
fasta_file = "database.fasta"

# glob needs a file pattern: globbing the bare directory path returns the
# directory itself (or nothing), not the PDB files inside it.
all_pdbs = sorted(glob.glob(os.path.join(pdb_files_folder, "*.pdb")))

# build_fastas appends, so start from an empty file; otherwise re-running
# this cell duplicates every record.
open(fasta_file, "w").close()

for pdb in tqdm.tqdm(all_pdbs):
    structure = parser.get_structure(os.path.basename(os.path.splitext(pdb)[0]), pdb)
    for model in structure:
        for chain in model:
            # Keep only residues whose hetero flag is ' ' (standard residues);
            # residue names missing from residues_id_map become 'X'.
            protein_sequence = "".join(
                residues_id_map.get(residue.get_resname(), "X")
                for residue in chain
                if residue.id[0] == " "
            )
            # Chains made only of hetero residues/waters would give empty records.
            if not protein_sequence:
                continue
            build_fastas(structure.id + "_" + str(model.id) + "_" + chain.id, protein_sequence, fasta_file)

To build the local BLAST database from that FASTA file, run:

In [None]:
import subprocess

fasta_file = "database.fasta"
db_name = "database_name"

# The Bio.Blast.Applications command-line wrappers are deprecated in
# Biopython, which recommends calling the tools via subprocess directly.
result = subprocess.run(
    ["makeblastdb", "-in", fasta_file, "-dbtype", "prot", "-out", db_name],
    capture_output=True,
    text=True,
)
stdout, stderr = result.stdout, result.stderr

# Fail loudly (with makeblastdb's own error output) instead of reporting success.
if result.returncode != 0:
    raise RuntimeError(f"makeblastdb failed (exit code {result.returncode}):\n{stderr}")

print("Database created successfully.")

# Training

In [None]:
# For now these warnings are simply silenced so they don't clutter the cell
# outputs; TODO: fix their underlying causes instead of suppressing them.
import warnings
from Bio import BiopythonParserWarning
from Bio.PDB.PDBExceptions import PDBConstructionWarning

for _category in (BiopythonParserWarning, PDBConstructionWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category

In [None]:
from model import ContactMapPredictor
from dataset import PDBDataModule
from lightning_module import ContactMapLightningModule

import esm
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


train_dir = "../train"  # Directory with training PDB files
test_dir = "../test"    # Directory with test PDB files
batch_size = 1
num_workers = 0

# Pretrained ESM-2 model; it is handed to the data module below (presumably to
# compute sequence embeddings -- confirm in dataset.PDBDataModule).
esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
esm_model.eval()  # disable dropout for deterministic embeddings (as in the fair-esm README)
batch_converter = alphabet.get_batch_converter()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
esm_model = esm_model.to(device)

data_module = PDBDataModule(train_dir, test_dir, esm_model, batch_converter, device,
                            database_folder=train_dir, db_name="train_db", e_value_threshold=1e-3, k=5,
                            val_split=0.2, batch_size=batch_size, num_workers=num_workers, max_sequence_length=100)
data_module.setup()

# Take the embedding width from the ESM model (320 for esm2_t6_8M) instead of
# hard-coding it, so it stays in sync if a different ESM variant is loaded.
model = ContactMapPredictor(embeddings_size=esm_model.embed_dim, fusion_num_blocks=10)
lightning_module = ContactMapLightningModule(model=model, learning_rate=1e-3)

logger = TensorBoardLogger("logs", name="contact_map_experiment")
# The Trainer places the LightningModule itself; pin it to the same device
# type used for the ESM model / data module rather than moving it by hand.
trainer = pl.Trainer(
    max_epochs=10,
    logger=logger,
    accelerator="gpu" if use_cuda else "cpu",
    devices=1,
)

trainer.fit(lightning_module, data_module.train_dataloader(), data_module.val_dataloader())

trainer.test(lightning_module, dataloaders=data_module.test_dataloader())