Adapted from: 
- https://github.com/jingraham/neurips19-graph-protein-design/blob/master/data/build_chain_dataset.py
- https://github.com/jingraham/neurips19-graph-protein-design/blob/master/data/mmtf_util.py

In [None]:
!pip install mmtf-python

Collecting mmtf-python
  Downloading https://files.pythonhosted.org/packages/af/d2/aabefd607ff89e35704574205b382a351f531885cc98bfd91baaf192a9fe/mmtf_python-1.1.2-py2.py3-none-any.whl
Installing collected packages: mmtf-python
Successfully installed mmtf-python-1.1.2


In [None]:
import os, time, gzip, urllib, json
import mmtf
from collections import defaultdict

def download_cached(url, target_location):
    """ Download `url` to `target_location` unless that file already exists

    Args:
        url: Address passed to `urllib.request.urlopen`.
        target_location: Path of the cached file. Missing parent directories
            are created; a bare file name (no directory part) is allowed.

    Returns:
        `target_location`, unchanged, so the call can be used inline.

    Raises:
        Whatever `urllib.request.urlopen` or the file operations raise;
        nothing is swallowed here.
    """
    # Imported here because a bare `import urllib` does not guarantee that the
    # `urllib.request` submodule has been loaded.
    import urllib.request

    if os.path.isfile(target_location):
        return target_location

    target_dir = os.path.dirname(target_location)
    if target_dir:
        # dirname is '' for a bare file name, and os.makedirs('') would fail
        os.makedirs(target_dir, exist_ok=True)

    # Write to a side file first so that an interrupted download never leaves
    # a truncated file that a later call would mistake for a valid cache entry.
    partial_location = target_location + '.part'
    with urllib.request.urlopen(url) as response:
        content_length = response.headers.get('Content-Length')
        if content_length is None:
            size = 'unknown'
        else:
            size = int(float(content_length) / 1e3)
        print('Downloading {}, {} KB'.format(target_location, size))
        with open(partial_location, 'wb') as f:
            f.write(response.read())
    os.replace(partial_location, target_location)
    return target_location


def mmtf_fetch(pdb, cache_dir='cath/mmtf/'):
    """ Retrieve mmtf record from PDB with local caching

    Args:
        pdb: PDB identifier; used to build both the download URL and the
            cache file name (`<pdb>.mmtf.gz`).
        cache_dir: Directory holding the cached files. A trailing slash is
            optional.

    Returns:
        The record returned by `mmtf.parse_gzip` for the cached file.
    """
    file_name = pdb + '.mmtf.gz'
    # os.path.join instead of string concatenation, so a cache_dir without a
    # trailing separator does not get glued onto the file name.
    mmtf_file = os.path.join(cache_dir, file_name)
    url = 'http://mmtf.rcsb.org/v1.0/full/' + file_name
    mmtf_file = download_cached(url, mmtf_file)
    return mmtf.parse_gzip(mmtf_file)


def mmtf_parse(pdb_id, chain, target_atoms=('N', 'CA', 'C', 'O')):
    """ Parse mmtf file to extract the sequence and atom coordinates of one chain

    Args:
        pdb_id: PDB identifier handed to `mmtf_fetch`.
        chain: Chain name, matched against the record's `chain_name_list`.
        target_atoms: Atom names whose coordinates are collected.

    Returns:
        A dict with the keys
          'seq': sequence of the polymer entity the chain belongs to,
          'coords': {atom name: one [x, y, z] per sequence position}, with
              [nan, nan, nan] wherever the atom was not found,
          'num_chains': number of chains in model 0.
        If the matched chain index is not among the chains of model 0,
        'seq' is an empty list and every 'coords' list is empty.

    Raises:
        ValueError: if no polymer entity owns a chain named `chain`.
    """
    # MMTF traversal derived from the specification
    # https://github.com/rcsb/mmtf/blob/master/spec.md
    A = mmtf_fetch(pdb_id)

    # Build a dictionary
    mmtf_dict = {}
    mmtf_dict['seq'] = []
    mmtf_dict['coords'] = {code: [] for code in target_atoms}

    # Locate the chain of interest and the entity that carries its sequence
    model_ix = 0
    target = next(
        ((i, entity) for entity in A.entity_list for i in entity['chainIndexList']
         if entity['type'] == 'polymer' and A.chain_name_list[i] == chain),
        None,
    )
    if target is None:
        raise ValueError(
            'No polymer chain {!r} found in {!r}'.format(chain, pdb_id)
        )
    target_chain_ix, target_entity = target

    num_chains = A.chains_per_model[model_ix]
    mmtf_dict['num_chains'] = num_chains

    # group_ix / atom_ix are running offsets into the flat per-group and
    # per-atom arrays, so every chain before the target has to be walked.
    group_ix, atom_ix = 0, 0
    for chain_ix in range(num_chains):
        is_target = chain_ix == target_chain_ix
        if is_target:
            seq = target_entity['sequence']
            mmtf_dict['seq'] = seq
            # One independent NaN triple per position (no shared list objects)
            mmtf_dict['coords'] = {
                code: [[float('nan')] * 3 for _ in seq] for code in target_atoms
            }

        for _ in range(A.groups_per_chain[chain_ix]):
            group = A.group_list[A.group_type_list[group_ix]]
            atom_names = group['atomNameList']

            if is_target:
                seq_ix = A.sequence_index_list[group_ix]
                # A negative index marks a group without a sequence position;
                # using it directly would overwrite the last residue.
                if seq_ix >= 0:
                    for code in target_atoms:
                        if code in atom_names:
                            A_ix = atom_ix + atom_names.index(code)
                            mmtf_dict['coords'][code][seq_ix] = [
                                A.x_coord_list[A_ix],
                                A.y_coord_list[A_ix],
                                A.z_coord_list[A_ix],
                            ]

            atom_ix += len(atom_names)
            group_ix += 1

        if is_target:
            # Nothing after the target chain contributes to the result
            break

    return mmtf_dict

In [None]:
# Fetch the CATH v4.3 domain classification list into the local cath/ cache.
cath_base_url = 'http://download.cathdb.info/cath/releases/all-releases/v4_3_0/'
cath_domain_fn = 'cath-domain-list-v4_3_0.txt'
cath_domain_url = ''.join((cath_base_url, 'cath-classification-data/', cath_domain_fn))
cath_domain_file = 'cath/cath-domain-listv-4_3_0.txt'
# Left as the final expression so the notebook echoes the cached path
download_cached(cath_domain_url, cath_domain_file)

Downloading cath/cath-domain-listv-4_3_0.txt, 37121 KB


'cath/cath-domain-listv-4_3_0.txt'

In [None]:
# Fetch the CATH v4.2 domain classification list into the local cath/ cache.
cath_base_url = 'http://download.cathdb.info/cath/releases/all-releases/v4_2_0/'
cath_domain_fn = 'cath-domain-list-v4_2_0.txt'
cath_domain_url = ''.join((cath_base_url, 'cath-classification-data/', cath_domain_fn))
cath_domain_file = 'cath/cath-domain-listv-4_2_0.txt'
# Left as the final expression so the notebook echoes the cached path
download_cached(cath_domain_url, cath_domain_file)

Downloading cath/cath-domain-listv-4_2_0.txt, 32275 KB


'cath/cath-domain-listv-4_2_0.txt'

In [None]:
import pandas as pd

# Load the domain IDs (first whitespace-separated field) of the CATH v4.3
# domain list into column 0 and tag every row with its release in column 1.
# The path is relative, matching where the download cell cached the file.
# Header lines are skipped by their leading '#' rather than by a fixed row
# count -- assumes every header line starts with '#', as the dataset-building
# cell also does when it reads this file.
df4_3 = pd.read_csv(
    'cath/cath-domain-listv-4_3_0.txt',
    sep=r'\s+',
    comment='#',
    header=None,
    usecols=[0],
    dtype=str,        # keep IDs such as '1e10000' from being parsed as numbers
    na_filter=False,  # IDs are never missing values
)
df4_3[1] = 'v4_3'
df4_3

Unnamed: 0,0,1
0,1oaiA00,v4_3
1,1go5A00,v4_3
2,3frhA01,v4_3
3,3friA01,v4_3
4,3b89A01,v4_3
...,...,...
500233,4aybQ00,v4_3
500234,3hkzY00,v4_3
500235,3hkzZ00,v4_3
500236,3zbeA00,v4_3


In [None]:
# Load the domain IDs (first whitespace-separated field) of the CATH v4.2
# domain list into column 0 and tag every row with its release in column 1.
# The path is relative, matching where the download cell cached the file.
# Header lines are skipped by their leading '#' rather than by a fixed row
# count -- assumes every header line starts with '#', as the dataset-building
# cell also does when it reads the v4.3 file.
df4_2 = pd.read_csv(
    'cath/cath-domain-listv-4_2_0.txt',
    sep=r'\s+',
    comment='#',
    header=None,
    usecols=[0],
    dtype=str,        # keep IDs such as '1e10000' from being parsed as numbers
    na_filter=False,  # IDs are never missing values
)
df4_2[1] = 'v4_2'
df4_2

Unnamed: 0,0,1
0,1oaiA00,v4_2
1,1go5A00,v4_2
2,3frhA01,v4_2
3,3friA01,v4_2
4,3b89A01,v4_2
...,...,...
434852,2kn1A00,v4_2
434853,1vprA01,v4_2
434854,1vprA02,v4_2
434855,1jyoE00,v4_2


In [None]:
# Domains present in v4.3 but not in v4.2: stack both releases, discard every
# ID that occurs more than once, then keep the rows tagged 'v4_3'.
df_diff = (
    pd.concat([df4_3, df4_2])
    .drop_duplicates(subset=0, keep=False)
    .loc[lambda frame: frame[1] == 'v4_3']
)
df_diff

Unnamed: 0,0,1
64,5lxjA00,v4_3
79,6jb7A02,v4_3
82,6if1A02,v4_3
83,6if1B02,v4_3
86,6jb6A02,v4_3
...,...,...
500233,4aybQ00,v4_3
500234,3hkzY00,v4_3
500235,3hkzZ00,v4_3
500236,3zbeA00,v4_3


In [None]:
# Plain Python list of the domain IDs held in column 0 of df_diff
diff_list = df_diff[0].tolist()
# diff_list

In [None]:
from collections import defaultdict

# Chains whose sequence is longer than this are left out of the dataset
MAX_LENGTH = 500

# CATH base URL
cath_base_url = 'http://download.cathdb.info/cath/releases/all-releases/v4_3_0/'
# cath_base_url = 'http://download.cathdb.info/cath/releases/latest-release/'

# CATH non-redundant set at 40% identity
cath_nr40_fn = 'cath-dataset-nonredundant-S40-v4_3_0.list'
cath_nr40_url = cath_base_url + 'non-redundant-data-sets/' + cath_nr40_fn
cath_nr40_file = 'cath/' + cath_nr40_fn
download_cached(cath_nr40_url, cath_nr40_file)

with open(cath_nr40_file) as f:
    # Skip blank lines so a missing or extra trailing newline neither drops
    # the last ID nor produces an empty one.
    cath_nr40_ids = [line.strip() for line in f if line.strip()]
# The first five characters of a domain ID are the PDB code plus chain letter
cath_nr40_chains = sorted({cath_id[:5] for cath_id in cath_nr40_ids})
chain_set = sorted((name[:4], name[4]) for name in cath_nr40_chains)
print('chain set: ', len(chain_set))

# CATH hierarchical classification
cath_domain_fn = 'cath-domain-list-v4_3_0.txt'
cath_domain_url = cath_base_url + 'cath-classification-data/' + cath_domain_fn
cath_domain_file = 'cath/cath-domain-listv-4_3_0.txt'
# download_cached(cath_domain_url, cath_domain_file)

# CATH topologies, restricted to the domains listed in diff_list
diff_ids = set(diff_list)  # set: O(1) membership test per domain-list line
cath_nodes = defaultdict(set)
with open(cath_domain_file, 'r') as f:
    for line in f:
        if line.startswith('#'):
            continue
        entries = line.split()
        if len(entries) < 4:
            # Blank or truncated line: no ID plus three classification fields
            continue
        cath_id, cath_node = entries[0], '.'.join(entries[1:4])
        if cath_id in diff_ids:
            chain_name = cath_id[:4] + '.' + cath_id[4]
            cath_nodes[chain_name].add(cath_node)
# Sorted so the node order in the output file is reproducible between runs
cath_nodes = {key: sorted(val) for key, val in cath_nodes.items()}

# Build dataset
dataset = []
for chain_ix, (pdb, chain) in enumerate(chain_set):
    chain_name = pdb + '.' + chain

    # Keep only chains that own at least one domain from diff_list.
    # NOTE(review): this replaces a test against `diff_chain`, a name that is
    # never defined (every iteration raised NameError and the dataset stayed
    # empty). cath_nodes holds exactly the chains with a diff_list domain;
    # confirm this is the intended filter.
    if chain_name not in cath_nodes:
        continue

    # Load and parse coordinates
    print(chain_ix, pdb, chain)
    try:
        chain_dict = mmtf_parse(pdb, chain)
    except Exception as e:
        # Best effort: one failed download or parse must not abort the whole
        # build, but say which chain was lost and why.
        print('Skipping {}: {!r}'.format(chain_name, e))
        continue

    if len(chain_dict['seq']) > MAX_LENGTH:
        print('Too long')
        continue

    chain_dict['name'] = chain_name
    chain_dict['CATH'] = cath_nodes[chain_name]
    print('@check seq')
    print(pdb, chain, chain_dict['num_chains'], chain_dict['seq'])
    dataset.append(chain_dict)

# One JSON object per line
outfile = 'chain_set-v4_3.jsonl'
with open(outfile, 'w') as f:
    for entry in dataset:
        f.write(json.dumps(entry) + '\n')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_chain' is not defined
name 'diff_

In [None]:
import json


# Per-chain fields collected from the dataset file, in file order
seq_list = []
num_chains_list = []
name_list = []
# Every chain is filed under 'test'; 'train' and 'val' are left empty
cath_node_dict = {'test': [], 'cath_nodes':{}, 'train':[], 'val':[]}
cath_dict = {}


# Relative path: the same location the dataset file is written to.
# The with-block closes the file once all lines are read.
with open('chain_set-v4_3.jsonl', 'r') as f:
    for line in f:
        data = json.loads(line)
        name = data.get('name')
        cath = data.get('CATH')

        # Sequence data
        seq_list.append(data.get('seq'))
        num_chains_list.append(data.get('num_chains'))
        name_list.append(name)

        print('name:', name)
        cath_node_dict['test'].append(name)

        print('CATH:', cath)
        cath_dict[name] = cath

        print('')


cath_node_dict['cath_nodes'] = cath_dict

name: 1c0w.A
CATH: ['6.10.140']

name: 1e2t.A
CATH: ['6.10.140']

name: 1eg3.A
CATH: ['6.10.140']

name: 1ej6.D
CATH: ['1.10.287']

name: 1esx.A
CATH: ['1.20.5']

name: 1g5g.A
CATH: ['1.10.287']

name: 1go4.H
CATH: ['6.10.250']

name: 1gxs.B
CATH: ['6.10.250']

name: 1hkg.A
CATH: ['1.10.287']

name: 1hn3.A
CATH: ['6.10.250']

name: 1htr.P
CATH: ['6.10.140']

name: 1ig8.A
CATH: ['1.10.287']

name: 1k75.A
CATH: ['1.20.5']

name: 1kae.A
CATH: ['1.20.5']

name: 1kg1.A
CATH: ['6.20.370']

name: 1krl.B
CATH: ['6.10.140']

name: 1l5a.A
CATH: ['6.10.250']

name: 1na6.A
CATH: ['3.40.91']

name: 1na6.B
CATH: ['3.40.91']

name: 1p4d.A
CATH: ['6.10.250']

name: 1p4q.A
CATH: ['6.10.140']

name: 1p65.A
CATH: ['6.10.140']

name: 1p9c.A
CATH: ['6.10.250']

name: 1pmm.C
CATH: ['4.10.280']

name: 1prz.A
CATH: ['6.10.140']

name: 1q0v.A
CATH: ['6.10.140']

name: 1q2h.A
CATH: ['6.10.140']

name: 1qv9.A
CATH: ['6.10.140']

name: 1qzz.A
CATH: ['1.10.287']

name: 1r71.A
CATH: ['6.10.250']

name: 1rgx.A
CATH:

In [None]:
# Echo the assembled dictionary ('test', 'cath_nodes', 'train', 'val') for inspection
cath_node_dict

{'cath_nodes': {'1c0w.A': ['6.10.140'],
  '1e2t.A': ['6.10.140'],
  '1eg3.A': ['6.10.140'],
  '1ej6.D': ['1.10.287'],
  '1esx.A': ['1.20.5'],
  '1g5g.A': ['1.10.287'],
  '1go4.H': ['6.10.250'],
  '1gxs.B': ['6.10.250'],
  '1hkg.A': ['1.10.287'],
  '1hn3.A': ['6.10.250'],
  '1htr.P': ['6.10.140'],
  '1ig8.A': ['1.10.287'],
  '1k75.A': ['1.20.5'],
  '1kae.A': ['1.20.5'],
  '1kg1.A': ['6.20.370'],
  '1krl.B': ['6.10.140'],
  '1l5a.A': ['6.10.250'],
  '1na6.A': ['3.40.91'],
  '1na6.B': ['3.40.91'],
  '1p4d.A': ['6.10.250'],
  '1p4q.A': ['6.10.140'],
  '1p65.A': ['6.10.140'],
  '1p9c.A': ['6.10.250'],
  '1pmm.C': ['4.10.280'],
  '1prz.A': ['6.10.140'],
  '1q0v.A': ['6.10.140'],
  '1q2h.A': ['6.10.140'],
  '1qv9.A': ['6.10.140'],
  '1qzz.A': ['1.10.287'],
  '1r71.A': ['6.10.250'],
  '1rgx.A': ['6.10.250'],
  '1rh7.B': ['6.10.250'],
  '1rp0.A': ['6.10.250'],
  '1rp3.B': ['6.10.140'],
  '1rp3.H': ['6.10.140'],
  '1rq0.A': ['6.10.140'],
  '1sg4.A': ['6.10.250'],
  '1sg4.C': ['6.10.250'],
  '1sr

In [None]:
# Persist the dictionary as a single JSON document.
# Earlier variant, kept for reference:
# chain_set_splits = {'test': name_list, 'train':[], 'val':[]}

with open('chain_set_splits-v4_3.json', 'w') as outfile:
    outfile.write(json.dumps(cath_node_dict))

In [None]:
# Sanity check: number of entries currently held in name_list
len(name_list)

1489