In [9]:
#@title Enter the construct sequences in FASTA format and hit `Runtime` -> `Run all`
# Number of sequences grouped into each batch that batches() yields, i.e. how
# many constructs are handed to a single predict() call.
batch_size = 4 # @param {"type":"integer"}
# Constructs to score, as FASTA text: a '>ID' header line followed by one or
# more sequence lines per record. Blank lines are ignored by the parser.
input_fasta = """
>CONSTR_000001
MTVFFVTRLVKKHDKLSKQQIEDFAEKLMTILFETYRSHWHSDCPSKGQAFRCIRINNNQ
NKDPILERACVESNVDFSHLGLPKEMTIWVDPFEVCCRYGEKNHPFTVASFKGRWEEWEL
YQQISYAVSRASSDVSSGTSCDEESCGSHHHHHH
>CONSTR_000002
MDYTKPLEHPPVKRNEEAQVHDKLNSGMVSNMEGTAGGERPSVVNGDSGKSGGVGDPREP
LGCLQEGSGCHPTTESFEKSVREDASPLPHVCCCKQDALILQRGLHHEDGSQHIGLLHPG
DRGPDHEYVLVEEAECGSHHHHHH
>CONSTR_000003
MHHHHHHENLYFQGSLEVRGQLQSALLILGEPKEGGMPMNISIMPSSLQMKTPEGCTEIQ
LPAEVRLVPSSCRGLQFVVGDGLHLRLQTQAKLGTKLISMFNQSSQTQE
>CONSTR_000004
MECPEGQLPISSENDSTPTVSTSEVTSQQEPQILVDRGSETTYESSADIAGDEGTQIPAD
EDTQTDADSSAQAAAQAPENFQEGKDMSESQDEVPDEVENGSHHHHHH
>CONSTR_000005
MSTAPSEDIWKKFELVPSPPTSPPWGLGPGAGDPAPGIGPPEPWPGGCTGDEAESRGHSK
GWGRNYASIIRRDCMWSGFSARERLERAVSDRLAPGAPRGNPPKASAAPDCTPSLEAGNP
APAAPCPLGEPKTQACSGSESPSDSENEEIDVVTVEKRQSLGIRKPVTITVRADPLDPCM
KHFHGSHHHHHH
>CONSTR_000006
MEKARHETFAAEMRQNDKIMCILENRKKRDRKNLCRAINDFQQSFQKPETRREFDLSDPL
ALKKDLPARQSDNDVRNTISGMQGSHHHHHH
>CONSTR_000007
MLMKKAYELSVLCDCEIALIIFNSANRLFQYASTDMDRVLLKYTEYSEPHESRTNTDILE
TLKRRGIGLDGPELEPDEGPEEPGEKFRRLAGEGGDPGSHHHHHH
>CONSTR_000008
MPTESASCSTARQTKQKRKSHSLSIRRTNSSEQERTGLPRDMLEGQDSKLPSSVRSTLLE
LFGQIEREFENLYIENLELRREIDTLNERLAAEGQAIDGAELSKGQLKTKASHSTSQLSQ
KLKTTYKASTSKIVSSFKTTTSRAACQLVKEYIGHRDGIWDVSVAKTQPVVLGTASADHT
ALLWSIETGKCLVKYAGHVGSVNSIKFHPSEQLALTASGDQTAHIWRYAVQLPTPQPVAD
TSISGEDEVECSDKDEPDLDGDVSSDCPTIRVPLTSLKSHQGVVIASDWLVGGKQAVTAS
WDRTANLYDVETSELVHSLTGHDQELTHCCTHPTQRLVVTSSRDTTFRLWDFRDPSIHSV
NVFQGHTDTVTSAVFTVGDNVVSGSDDRTVKVWDLKNMRSPIATIRTDSAINRINVCVGQ
KIIALPHDNRQVRLFDMSGVRLARLPRSSRQGHRRMVCCSAWSEDHPVCNLFTCGFDRQA
IGWNINIPALLQEKGSHHHHHH
>CONSTR_000009
MHHHHHHENLYFQGSPTESASCSTARQTKQKRKSHSLSIRRTNSSEQERTGLPRDMLEGQ
DSKLPSSVRSTLLELFGQIEREFENLYIENLELRREIDTLNERLAAEGQAIDGAELSKGQ
LKTKASHSTSQLSQKLKTTYKASTSKIVSSFKTTTSRAACQLVKEYIGHRDGIWDVSVAK
TQPVVLGTASADHTALLWSIETGKCLVKYAGHVGSVNSIKFHPSEQLALTASGDQTAHIW
RYAVQLPTPQPVADTSISGEDEVECSDKDEPDLDGDVSSDCPTIRVPLTSLKSHQGVVIA
SDWLVGGKQAVTASWDRTANLYDVETSELVHSLTGHDQELTHCCTHPTQRLVVTSSRDTT
FRLWDFRDPSIHSVNVFQGHTDTVTSAVFTVGDNVVSGSDDRTVKVWDLKNMRSPIATIR
TDSAINRINVCVGQKIIALPHDNRQVRLFDMSGVRLARLPRSSRQGHRRMVCCSAWSEDH
PVCNLFTCGFDRQAIGWNINIPALLQEK
>CONSTR_000010
MRDEIATTVFFVTRLVKKHDKLSKQQIEDFAEKLMTILFETYRSHWHSDCPSKGQAFRCI
RINNNQNKDPILERACVESNVDFSHLGLPKEMTIWVDPFEVCCRYGEKNHPFTVASFKGR
WEEWELYQQISYAVSRASSDVSSGTSCDEESCSKEPRVIPKVSNPKSIYQVENLKQPFQS
WLQIPRKKNVVDGRVGLLGNTYHGSQKHPKCYRPAMHRLDRILGSHHHHHH
"""

In [None]:
#@title Install the dependencies and download the checkpoint
%%bash

# Stop at the first command that fails so a broken install is not masked
# by a later step succeeding.
set -o errexit

# Python side: the RP3Net package, with torchvision held at 0.20.1.
pip install RP3Net 'torchvision==0.20.1'

# Model weights. -nc skips the download when the file is already present,
# -nv trims wget's progress output.
CKPT_URL='https://ftp.ebi.ac.uk/pub/software/RP3Net/v0.1/checkpoints/rp3net_v0.1_d.ckpt'
wget -nv -nc "${CKPT_URL}"

In [5]:
#@title Imports
import re
import io
import pandas as pd
import RP3Net as rp3
from tqdm.notebook import tqdm
# Captures the record id that directly follows '>' on a FASTA header line.
# The id is the longest leading run of word characters and - . : # *, so
# anything from the first other character onwards (whitespace, '|', ...) is
# not part of it; e.g. '>sp|P12345' gives the id 'sp'. A '>' that is not
# immediately followed by one of the allowed characters does not match.
RE_FASTA_HEADER = re.compile(r'^>([\w\-.:#*]+)') # https://www.ncbi.nlm.nih.gov/genbank/fastaformat/

In [6]:
#@title Helper functions
def iter_fasta(io):
    """Lazily parse FASTA text into ``(id, sequence)`` pairs.

    Args:
        io: Any iterable of text lines (an open file, ``io.StringIO``, a list
            of strings, ...). The parameter name shadows the ``io`` module
            inside this function only; it is kept for compatibility.

    Yields:
        tuple[str, str]: The id captured by ``RE_FASTA_HEADER`` and the
        record's sequence lines, each stripped and joined without separators.
        A header with no sequence lines yields an empty sequence.

    Raises:
        ValueError: If a line starting with ``>`` does not match
            ``RE_FASTA_HEADER``, or if sequence text appears before the
            first header. The message carries the 1-based line number.
    """
    fasta_id, sequence = None, []
    for line_no, raw_line in enumerate(io, start=1):
        line = raw_line.strip()
        if not line:
            # Blank lines carry no information in FASTA; skip them anywhere.
            continue
        if line.startswith('>'):
            header = RE_FASTA_HEADER.match(line)
            if header is None:
                # Previously such a line was silently glued onto the previous
                # record's sequence, corrupting it.
                raise ValueError(
                    f'line {line_no}: malformed FASTA header {line!r}'
                )
            if fasta_id is not None:
                # A new header closes the record collected so far.
                yield fasta_id, ''.join(sequence)
            fasta_id, sequence = header.group(1), []
        elif fasta_id is None:
            # Sequence text with no header has no id to be reported under.
            raise ValueError(
                f'line {line_no}: sequence data found before the first FASTA header'
            )
        else:
            sequence.append(line)
    if fasta_id is not None:
        # Flush the final record, which has no following header to close it.
        yield fasta_id, ''.join(sequence)

def parse_fasta(s):
    """Parse FASTA text into a dict mapping record id to sequence.

    Args:
        s: The FASTA content as a single string.

    Returns:
        dict[str, str]: One entry per record, in the order the records
        appear in ``s``. Ids are whatever ``iter_fasta`` reports.

    Raises:
        ValueError: If two records share an id. A plain dict would keep only
            the last one, silently dropping a sequence from the results.
    """
    records = {}
    for fasta_id, seq in iter_fasta(io.StringIO(s)):
        if fasta_id in records:
            raise ValueError(
                f'duplicate FASTA id {fasta_id!r}: every record needs a unique id'
            )
        records[fasta_id] = seq
    return records

def batches():
    """Yield the records of ``input_fasta`` in chunks of ``batch_size``.

    Reads the notebook-level ``input_fasta`` and ``batch_size`` form values
    when iteration starts, parses the FASTA text once and wraps the batch
    loop in a tqdm progress bar labelled 'RP3Net Inference'.

    Yields:
        dict[str, str]: Up to ``batch_size`` ``id -> sequence`` entries, in
        input order; the final batch may be smaller.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. A negative value
            would otherwise produce an empty range and silently no batches.
    """
    if batch_size < 1:
        raise ValueError(
            f'batch_size must be a positive integer, got {batch_size!r}'
        )
    fasta_map = parse_fasta(input_fasta)
    fasta_keys = list(fasta_map)
    starts = range(0, len(fasta_keys), batch_size)
    for start in tqdm(starts, desc='RP3Net Inference'):
        # Slicing past the end is safe, so the last batch is simply shorter.
        yield {key: fasta_map[key] for key in fasta_keys[start:start + batch_size]}

In [7]:
#@title Load the model
# Build the model from RP3Net's default configuration and the checkpoint file
# 'rp3net_v0.1_d.ckpt', looked up relative to the working directory (the file
# the install cell downloads with wget).
# NOTE(review): presumably load_model also restores the trained weights from
# the checkpoint — confirm against the RP3Net documentation.
m = rp3.load_model(rp3.RP3_DEFAULT_CONFIG, 'rp3net_v0.1_d.ckpt')


In [None]:
#@title Run the prediction on GPU
# Put the model on the GPU; this cell needs a CUDA runtime.
m = m.to(device='cuda')

# Accumulate the per-batch predictions into one id -> score mapping.
# Rebuilt from scratch on every run, so re-executing the cell is safe.
scores_map = {}
for b in batches():
    scores_map.update(m.predict(b, device='cuda'))


In [11]:
#@title Print and save the results
# One row per construct, in the order the scores were collected.
df = pd.DataFrame(list(scores_map.items()), columns=['id', 'score'])
print(df)
# Write the same table to the working directory, without the row index.
df.to_csv("rp3_scores.csv", index=False)

              id     score
0  CONSTR_000001  0.691543
1  CONSTR_000002  0.971137
2  CONSTR_000003  0.931065
3  CONSTR_000004  0.972745
4  CONSTR_000005  0.928140
5  CONSTR_000006  0.977404
6  CONSTR_000007  0.744749
7  CONSTR_000008  0.009805
8  CONSTR_000009  0.009679
9  CONSTR_000010  0.433345
