In [1]:
%load_ext autoreload
%autoreload 2

In [91]:
import Bio
from lib import DihedralAdherence
import os
from lib.constants import AMINO_ACID_CODES
from Bio import PDB
from Bio.PDB import Superimposer
from Bio.Align import PairwiseAligner
import numpy as np

In [3]:
INCLUSION_RADIUS = 15 # Angstroms
THRESHOLDS = [0.5, 1, 2, 4] # Angstroms

In [188]:
PDBMINE_URL = os.getenv("PDBMINE_URL")
PROJECT_DIR = 'casp_da'
proteins = [
  'T1024', 'T1030', 'T1030-D2', 'T1024-D1', 'T1032-D1', 'T1053-D1', 'T1027-D1', 'T1029-D1',
  'T1025-D1', 'T1028-D1', 'T1030-D1', 'T1053-D2', 'T1057-D1','T1058-D1', 'T1058-D2'
]
da = DihedralAdherence(proteins[0], [4,5,6,7], PDBMINE_URL, PROJECT_DIR, kdews=[1,32,64,128], 
                      mode='ml', weights_file='ml_runs/best_model-kde_16-32_383.pt', device='cpu')

Initializing T1024 ...
Results already exist
Casp ID: T1024 	PDB: 6t1z
Structure exists: 'pdb/pdb6t1z.ent' 
UniProt ID: Q48658


In [190]:
parser = PDB.PDBParser()
xray_structure = parser.get_structure(da.casp_protein_id, da.xray_fn)
pred_structure = parser.get_structure(da.alphafold_id, da.predictions_dir / da.alphafold_id)

Exception ignored.
Some atoms or residues may be missing in the data structure.


In [201]:
chainA = next(iter((xray_structure[0].get_chains())))
[(r.resname,AMINO_ACID_CODES.get(r.resname, 'X')) for r in chainA.get_residues() if not r.resname in AMINO_ACID_CODES]

[('HT1', 'X'),
 ('XP4', 'X'),
 ('XP4', 'X'),
 ('LMU', 'X'),
 ('LMU', 'X'),
 ('GOL', 'X'),
 ('HOH', 'X'),
 ('HOH', 'X'),
 ('HOH', 'X')]

In [225]:
# A is reference

chainA = next(iter((xray_structure[0].get_chains())))
chainB = next(iter((pred_structure[0].get_chains())))

# align
residuesA = ''.join([AMINO_ACID_CODES.get(r.resname, 'X') for r in chainA.get_residues()])
residuesB = ''.join([AMINO_ACID_CODES.get(r.resname, 'X') for r in chainB.get_residues()])
residuesB = residuesB[:300] + residuesB[310:]


aligner = PairwiseAligner()
aligner.mode = 'global'
alignments =  aligner.align(residuesA, residuesB)
print(alignments[0])
aligned = alignments[0].aligned

atomsA = list(chainA.get_atoms())
atomsB = list(chainB.get_atoms())

atomsA = []
atomsB = []
residuesA = list(chainA.get_residues())
residuesB = list(chainB.get_residues())
residuesB = residuesB[:300] + residuesB[310:]
for i,((t1,t2),(q1,q2)) in enumerate(zip(*alignments[0].aligned)):
    for j, (residueA, residueB) in enumerate(zip(residuesA[t1:t2], residuesB[q1:q2])):
        if residueA.resname != residueB.resname:
            print(f'{residueA.resname} != {residueB.resname}')
            continue
        try:
            atomsA.append(residueA['CA'].coord)
            atomsB.append(residueB['CA'].coord)
        except KeyError:
            print(f'No CA atom for {residueA.resname} or {residueB.resname}')
    
    if i < alignments[0].aligned.shape[1] - 1:
        print(alignments[0].aligned[0,i])
        print(alignments[0].aligned[0,i+1])
        print()
        if alignments[0].aligned[0,i][1] != alignments[0].aligned[0,i+1][0]:
            for k in range(alignments[0].aligned[0,i][1], alignments[0].aligned[0,i+1][0]+1):
                atomsA.append(residuesA[k]['CA'].coord)
                atomsB.append(np.full(3, np.nan))
    elif alignments[0].aligned[0,i][1] != len(residuesA):
        print('mismatch - end missing')


target            0 G-KEFWNLDKNLQLRLGIVFLGAFSYGTVFSSMTIYYNQYLGSAITGILLALSAVATFVA
                  0 --||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
query             0 -MKEFWNLDKNLQLRLGIVFLGAFSYGTVFSSMTIYYNQYLGSAITGILLALSAVATFVA

target           59 GILAGFFADRNGRKPVMVFGTIIQLLGAALAIASNLPGHVNPWSTFIAFLLISFGYNFVI
                 60 ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
query            59 GILAGFFADRNGRKPVMVFGTIIQLLGAALAIASNLPGHVNPWSTFIAFLLISFGYNFVI

target          119 TAGNAMIIDASNAENRKVVFMLDYWAQNLSVILGAALGAWLFRPAFEALLVILLLTVLVS
                120 ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
query           119 TAGNAMIIDASNAENRKVVFMLDYWAQNLSVILGAALGAWLFRPAFEALLVILLLTVLVS

target          179 FFLTTFVMTETFKPT---D----NIFQAYKTVLQDKTYMIFMGANIATTFIIMQFDNFLP
                180 |||||||||||||||---|----|||||||||||||||||||||||||||||||||||||
query           179 FFLTTFVMTETFKPTVKVDEKAENIFQAYKTVLQDKTYMIFMGANIATTFIIMQFDNFLP

target          232 VHLS

In [244]:
atomsB[287:]

array([[         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [ -7.03200006,   1.653     ,  20.91699982],
       [ -9.86699963,   4.2420001 ,  20.68199921],
       [         nan,          nan,          nan],
       [         nan,          nan,          nan],
       [-11.33699989,   6.29699993,  17.81599998],
       [-10.26399994,   9.86100006,  18.8220005 ],
       [ -6.50299978,   9.24100018,  19.49799919],
       [ -6.11299992,   7.09399986,  16.30999947],
       [ -8.01000023,   9.77000046,  14.30200005],
       [ -5.61299992,  12.48200035,  15.67300034],
       [ -2.63199997,  10.16899

In [241]:
atomsA = np.array(atomsA)
atomsB = np.array(atomsB)

In [192]:
chainA = next(iter((xray_structure[0].get_chains())))
chainB = next(iter((pred_structure[0].get_chains())))

# align
residuesA = ''.join([AMINO_ACID_CODES.get(r.resname, 'X') for r in chainA.get_residues()])
residuesB = ''.join([AMINO_ACID_CODES.get(r.resname, 'X') for r in chainB.get_residues()])

aligner = PairwiseAligner()
aligner.mode = 'global'
alignments =  aligner.align(residuesA, residuesB)
print(alignments[0])
aligned = alignments[0].aligned

atomsA = list(chainA.get_atoms())
atomsB = list(chainB.get_atoms())

# for atomsA, atomsB in zip(atomsA, atomsB):
    # if atomsA.element != atomsB.element:
        # print(atomsA, atomsB)
atomsA = []
atomsB = []
residuesA = list(chainA.get_residues())
residuesB = list(chainB.get_residues())
for j, (residueA, residueB) in enumerate(zip(residuesA, residuesB)):
    #TODO loop through ref and add nan for pred if no match
    # for j, (residueA, residueB) in enumerate(zip(residuesA, residuesB)):
    if residueA.resname != residueB.resname:
        print(f'{residueA.resname} != {residueB.resname}')
        continue
    try:
        atomsA.append(residueA['CA'].coord)
        atomsB.append(residueB['CA'].coord)
    except KeyError:
        print(f'No CA atom for {residueA.resname} or {residueB.resname}')


# sup = Bio.PDB.Superimposer()
# sup.set_atoms(atomsA, atomsB)
# sup.apply(atomsB)

atomsA = np.array(atomsA)
atomsB = np.array(atomsB)

# 273x1x3 - 273x3 = 273x273x3
# norm(x) = sqrt(sum(x**2, axis=-1))
pairwise_dists_A = np.linalg.norm(atomsA[:, None] - atomsA, axis=-1)
pairwise_dists_B = np.linalg.norm(atomsB[:, None] - atomsB, axis=-1)

# Compute lddt
lddt = []
local_atoms_mask = (pairwise_dists_A <= INCLUSION_RADIUS)
n_local_atoms = local_atoms_mask.sum()
dist_diff = np.abs(pairwise_dists_A - pairwise_dists_B)
dist_diff[~local_atoms_mask] = np.inf
lddt = np.mean([(dist_diff < thresh).sum() / n_local_atoms for thresh in THRESHOLDS])
print(lddt)

GLY != MET
ASP != VAL
ASN != LYS
ILE != VAL
PHE != ASP
GLN != GLU
ALA != LYS
TYR != ALA
LYS != GLU
THR != ASN
VAL != ILE
LEU != PHE
ASP != ALA
LYS != TYR
THR != LYS
TYR != THR
MET != VAL
ILE != LEU
PHE != GLN
MET != ASP
GLY != LYS
ALA != THR
ASN != TYR
ILE != MET
ALA != ILE
THR != PHE
THR != MET
PHE != GLY
ILE != ALA
ILE != ASN
MET != ILE
GLN != ALA
PHE != THR
ASP != THR
ASN != PHE
PHE != ILE
LEU != ILE
PRO != MET
VAL != GLN
HIS != PHE
LEU != ASP
SER != ASN
ASN != PHE
SER != LEU
PHE != PRO
LYS != VAL
THR != HIS
ILE != LEU
THR != SER
TYR != ASN
GLY != SER
GLN != PHE
ARG != LYS
MET != THR
LEU != ILE
ILE != PHE
TYR != TRP
LEU != GLY
ILE != PHE
LEU != GLU
ALA != ILE
CYS != TYR
VAL != GLY
LEU != GLN
VAL != ARG
VAL != MET
LEU != THR
MET != ILE
THR != TYR
THR != LEU
LEU != ILE
ASN != LEU
ARG != ALA
LEU != CYS
THR != VAL
LYS != LEU
ASP != VAL
TRP != VAL
SER != LEU
HIS != LEU
GLN != MET
LYS != THR
GLY != THR
PHE != LEU
ILE != ASN
TRP != ARG
GLY != LEU
SER != THR
LEU != LYS
PHE != ASP
MET != TRP