In [1]:
import os
import pandas as pd
import json
from Bio.PDB import *
from Bio import SeqIO
import nglview as nv
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from biofunctions.biofunctions import *
pd.set_option('display.max_columns', 100)
%matplotlib inline

**The objective of this notebook is to derive the Language Model's inputs and outputs from the interactions found in Notebook 1, and to create the input-output dataset.**

# 1. Create input and output per pdb

1. expand_ag_chain_seq must have a copy where each entry is ag_letter-seqid -> Do this in all the get_full_seq functions
2. the same for expand_cdr_seq
3. create a dataframe where columns are ag items and rows are cdr items
4. For every item in ab_letters and ag_letters in interactions dict, add 1 in the dataframe
5. Fill the rest with null values
6. Save the matrix as an array

For each AB chain we will get:

* ab: cdr1_start - 6 to cdr1_end + 6
* ag: min_seqid - 6 to max_seqid + 6
* out: contact matrix


* ab: cdr2_start - 6 to cdr2_end + 6
* ag: min_seqid - 6 to max_seqid + 6
* out: contact matrix


* ab: cdr3_start - 6 to cdr3_end + 6
* ag: min_seqid - 6 to max_seqid + 6
* out: contact matrix


And it will look like this:

````text
<CDR2> P K T L I Y R A N R L M I G V <ag>  D A T P E D L N <out> . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . | | | | | | | | _ | | | | | | | | . _ . | | | . | | . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ . . . . . . . . . _ .
````

**Some terms:** 
* interactions_dict: a dictionary that contains all the contact information about a single pdb
* amino_acids_dict: a dictionary that contains a mapping between the 3-letter code to 1-letter code
* chains_dict: a dictionary mapping the chain_label to the chain structure of a given pdb.
* chain: a chain from the Bio.PDB.Structure.Structure
* seqid: the sequence id in the PDB according to the IMGT residue numbering
* AB: antibody
* AG: antigen
* cdr_dict: a dictionary that contains sequence information about the CDR and AG chains that are in contact. A single cdr_dict belongs to a single AB chain

````python
cdr1_start = 27
cdr1_end = 38
cdr2_start = 56
cdr2_end = 65
cdr3_start = 105
cdr3_end = 117
````

In [2]:
# Interaction data produced in Notebook 1, keyed under 'pdbs'.
with open('pdb_dict.json') as pdb_file:
    pdb_dict = json.load(pdb_file)

In [3]:
# One token per line; surrounding whitespace (including the newline) is stripped.
with open('word_vocab.txt') as vocab_file:
    vocab = list(map(str.strip, vocab_file))

In [4]:
# Build the (input, output) string pairs for every PDB, encode them with the
# vocabulary and write the encoded records to input_output.json.
total_pdbs = len(pdb_dict['pdbs'])
# Window padding passed to create_in_out_str; presumably the "start - 6 / end + 6"
# context described in the markdown above — TODO confirm in biofunctions.
n = 6
all_in_out_list = []

# `count` is the 1-based position of the PDB, so pending entries still advance it.
for count, (pdb, interactions_dict) in enumerate(pdb_dict['pdbs'].items(), start=1):

    # Falsy entries (no interaction data yet) are reported and skipped.
    if not interactions_dict:
        print(f'PDB {pdb} is pending')
        continue

    in_out_list = create_in_out_str(interactions_dict, n)

    if in_out_list:
        # extend() appends in place; `a = a + b` copied the whole list every iteration.
        all_in_out_list.extend(in_out_list)

    if count % 10 == 0:
        print(f'{count} analyzed pdbs out of {total_pdbs}')

data_parser = DataParser(all_in_out_list, vocab)
data = data_parser.encode_data()

with open('input_output.json', 'w') as f:
    json.dump(data, f)
print('Finished...')

10 analyzed pdbs out of 586
20 analyzed pdbs out of 586
30 analyzed pdbs out of 586
40 analyzed pdbs out of 586
50 analyzed pdbs out of 586
60 analyzed pdbs out of 586
70 analyzed pdbs out of 586
80 analyzed pdbs out of 586
90 analyzed pdbs out of 586
100 analyzed pdbs out of 586
110 analyzed pdbs out of 586
120 analyzed pdbs out of 586
130 analyzed pdbs out of 586
140 analyzed pdbs out of 586
150 analyzed pdbs out of 586
160 analyzed pdbs out of 586
170 analyzed pdbs out of 586
180 analyzed pdbs out of 586
190 analyzed pdbs out of 586
200 analyzed pdbs out of 586
210 analyzed pdbs out of 586
220 analyzed pdbs out of 586
230 analyzed pdbs out of 586
240 analyzed pdbs out of 586
250 analyzed pdbs out of 586
260 analyzed pdbs out of 586
270 analyzed pdbs out of 586
280 analyzed pdbs out of 586
290 analyzed pdbs out of 586
300 analyzed pdbs out of 586
310 analyzed pdbs out of 586
320 analyzed pdbs out of 586
330 analyzed pdbs out of 586
340 analyzed pdbs out of 586
350 analyzed pdbs out o

In [5]:
# First encoded example: raw input/output strings plus their id sequences (input_ids, labels).
data[0]

{'input_text': '<CDR1>LSCKALGYIFTDYEIHWVKQ<ag>DATPEDLNAK',
 'output_text': '<out>.........._..........._..........._..........._..........._..........._.||||......_.||||||.|.._.||||||.|.|_.||||.|.|.|_|.|||||||||_|||||||||||_|||||||||||_|.|||||||||_|.........._..........._..........._..........._..........._..........._.',
 'input_ids': [5,
  22,
  28,
  14,
  21,
  13,
  22,
  18,
  32,
  20,
  17,
  29,
  15,
  32,
  16,
  20,
  19,
  31,
  30,
  21,
  26,
  8,
  15,
  13,
  29,
  25,
  16,
  15,
  22,
  24,
  13,
  21],
 'labels': [9,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  12,
  12,
  12,
  12,
 

# 2. Parse data

Example of how this data is parsed for training.

In [2]:
# Encoded dataset written by section 1.
with open('input_output.json') as dataset_file:
    data = json.load(dataset_file)

In [3]:
# Reload the vocabulary so this section runs on its own: one token per line, whitespace stripped.
with open('word_vocab.txt') as vocab_file:
    vocab = list(map(str.strip, vocab_file))

In [4]:
# Fields stored in each encoded record.
data[0].keys()

dict_keys(['input_text', 'output_text', 'input_ids', 'labels'])

In [5]:
# First record as reloaded from input_output.json.
data[0]

{'input_text': '<CDR1>LSCKALGYIFTDYEIHWVKQ<ag>DATPEDLNAK',
 'output_text': '<out>.........._..........._..........._..........._..........._..........._.||||......_.||||||.|.._.||||||.|.|_.||||.|.|.|_|.|||||||||_|||||||||||_|||||||||||_|.|||||||||_|.........._..........._..........._..........._..........._..........._.',
 'input_ids': [5,
  22,
  28,
  14,
  21,
  13,
  22,
  18,
  32,
  20,
  17,
  29,
  15,
  32,
  16,
  20,
  19,
  31,
  30,
  21,
  26,
  8,
  15,
  13,
  29,
  25,
  16,
  15,
  22,
  24,
  13,
  21],
 'labels': [9,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  12,
  12,
  12,
  12,
 

In [6]:
# Wrap the loaded records and the vocabulary; the parser is indexed like a sequence below (data_parser[0]).
data_parser = DataParser(data,vocab)

In [None]:
# NOTE(review): leftover placeholder — this cell only evaluates an empty string
# literal and was never executed (In [None]); it can be deleted.
''''''

input_indices -> input_ids in dataset

label_indices -> labels in dataset

collator: just padding and attention masking


In [7]:
# Parsed examples expose the same fields as the raw JSON records.
data_parser[0].keys()

dict_keys(['input_text', 'output_text', 'input_ids', 'labels'])

In [8]:
# print() renders the long label-id list on one line instead of the multi-line repr.
print(data_parser[0]['labels'])

[9, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 10, 11, 12, 12, 12, 12, 12, 12, 11, 12, 11, 11, 10, 11, 12, 12, 12, 12, 12, 12, 11, 12, 11, 12, 10, 11, 12, 12, 12, 12, 11, 12, 11, 12, 11, 12, 10, 12, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 10, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 10, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 10, 12, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 10, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 11]


In [9]:
# token -> index lookup (if a token repeats, its last position wins)
vocab_map = dict(zip(vocab, range(len(vocab))))

In [10]:
# Take the first few parsed examples to try the collator on.
n_samples = 3
data_examples_list = [data_parser[sample_idx] for sample_idx in range(n_samples)]
data_examples_list[0]

{'input_text': '<CDR1>LSCKALGYIFTDYEIHWVKQ<ag>DATPEDLNAK',
 'output_text': '<out>.........._..........._..........._..........._..........._..........._.||||......_.||||||.|.._.||||||.|.|_.||||.|.|.|_|.|||||||||_|||||||||||_|||||||||||_|.|||||||||_|.........._..........._..........._..........._..........._..........._.',
 'input_ids': [5,
  22,
  28,
  14,
  21,
  13,
  22,
  18,
  32,
  20,
  17,
  29,
  15,
  32,
  16,
  20,
  19,
  31,
  30,
  21,
  26,
  8,
  15,
  13,
  29,
  25,
  16,
  15,
  22,
  24,
  13,
  21],
 'labels': [9,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  11,
  10,
  11,
  12,
  12,
  12,
  12,
 

In [11]:
# Collator built from the same vocabulary (per the notes above: padding and attention masking only).
data_collator = DataCollator(vocab)

In [13]:
# Collate the sampled examples into one batch and list the produced fields.
batch = data_collator(data_examples_list)
batch.keys()

dict_keys(['input_ids', 'labels', 'attention_masking'])